Skip to main content

slim_datapath/messages/
utils.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Display;
5use std::{collections::HashMap, time::Duration};
6
7use super::encoder::Name;
8use crate::api::proto::dataplane::v1::{GroupClosePayload, GroupNackPayload, PingPayload};
9use crate::api::{
10    Content, LinkNegotiationPayload, MessageType, ProtoLink, ProtoLinkMessageType, ProtoLinkType,
11    ProtoMessage, ProtoName, ProtoPublish, ProtoPublishType, ProtoSessionType, ProtoSubscribe,
12    ProtoSubscribeType, ProtoSubscriptionAck, ProtoSubscriptionAckType, ProtoUnsubscribe,
13    ProtoUnsubscribeType, SessionHeader, SlimHeader,
14    proto::dataplane::v1::{
15        ApplicationPayload, CommandPayload, DiscoveryReplyPayload, DiscoveryRequestPayload,
16        EncodedName, GroupAckPayload, GroupAddPayload, GroupProposalPayload, GroupRemovePayload,
17        GroupWelcomePayload, JoinReplyPayload, JoinRequestPayload, LeaveReplyPayload,
18        LeaveRequestPayload, MlsPayload, SessionMessageType, StringName, TimerSettings,
19        command_payload::CommandPayloadType, content::ContentType,
20    },
21};
22
23use thiserror::Error;
24
/// Metadata key placed on a Join request to flag its sender as the moderator
/// (initiator) of the session. Set to [`TRUE_VAL`] for a moderator.
pub const IS_MODERATOR: &str = "IS_MODERATOR";

/// Metadata key placed on a leave request to ask the moderator to close (delete)
/// the whole group rather than just remove one participant.
/// Set to [`TRUE_VAL`] to request deletion.
pub const DELETE_GROUP: &str = "DELETE_GROUP";

/// Metadata key marking a group-session message sent through the application's
/// `publish_to` API (instead of `publish`). Such messages skip normal sequencing
/// and are delivered straight to the named endpoint without buffering.
/// Set to [`TRUE_VAL`].
pub const PUBLISH_TO: &str = "PUBLISH_TO";

/// Metadata key signalling that a participant vanished rather than leaving
/// gracefully. It appears on leave requests and is also used internally by the
/// moderator when ping replies from a participant stop arriving.
/// Set to [`TRUE_VAL`] when a disconnection has been detected.
pub const DISCONNECTION_DETECTED: &str = "DISCONNECTION_DETECTED";

/// Metadata key on the leave request a participant sends to the moderator when
/// it closes the session on its own (graceful departure). Set to [`TRUE_VAL`].
pub const LEAVING_SESSION: &str = "LEAVING_SESSION";

/// Canonical metadata encoding of boolean `true`.
pub const TRUE_VAL: &str = "TRUE";

/// Canonical metadata encoding of boolean `false`.
pub const FALSE_VAL: &str = "FALSE";

/// Upper bound (inclusive) of the message-ID space used for normally sequenced messages.
///
/// IDs in `0..=MAX_PUBLISH_ID` are sequenced as usual; IDs above it are reserved for
/// out-of-band (`PUBLISH_TO`) messages that bypass sequencing. The value is half of
/// `u32::MAX` (2_147_483_647), splitting the `u32` range into two ID spaces.
pub const MAX_PUBLISH_ID: u32 = u32::MAX >> 1;
61
62#[derive(Error, Debug, PartialEq)]
63pub enum MessageError {
64    #[error("SLIM header not found")]
65    SlimHeaderNotFound,
66    #[error("source not found")]
67    SourceNotFound,
68    #[error("destination not found")]
69    DestinationNotFound,
70    #[error("session header not found")]
71    SessionHeaderNotFound,
72    #[error("message type not found")]
73    MessageTypeNotFound,
74    #[error("incoming connection not found")]
75    IncomingConnectionNotFound,
76    #[error("content type is not set")]
77    ContentTypeNotSet,
78    #[error("content is not an application payload")]
79    NotApplicationPayload,
80    #[error("content is not a command payload")]
81    NotCommandPayload,
82    #[error("link type is not set")]
83    LinkTypeNotSet,
84    #[error("invalid command payload type: expected {expected}, got {got}")]
85    InvalidCommandPayloadType {
86        expected: Box<String>,
87        got: Box<String>,
88    },
89
90    // Builder errors
91    #[error("builder error: source is required")]
92    BuilderErrorSourceRequired,
93    #[error("builder error: destination is required")]
94    BuilderErrorDestinationRequired,
95}
96
97/// ProtoName from Name
98impl From<&Name> for ProtoName {
99    fn from(name: &Name) -> Self {
100        Self {
101            name: Some(EncodedName {
102                component_0: name.components()[0],
103                component_1: name.components()[1],
104                component_2: name.components()[2],
105                component_3: name.components()[3],
106            }),
107            str_name: Some(StringName {
108                str_component_0: name.components_strings()[0].clone(),
109                str_component_1: name.components_strings()[1].clone(),
110                str_component_2: name.components_strings()[2].clone(),
111            }),
112        }
113    }
114}
115
116/// Print message type
117impl Display for MessageType {
118    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119        match self {
120            MessageType::Publish(_) => write!(f, "publish"),
121            MessageType::Subscribe(_) => write!(f, "subscribe"),
122            MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
123            MessageType::Link(_) => write!(f, "link"),
124            MessageType::SubscriptionAck(_) => write!(f, "subscription_ack"),
125        }
126    }
127}
128
/// The routing flags of a SLIM header, bundled together for convenience.
///
/// The default value has a fanout of 1 and every optional flag unset.
#[derive(Debug, Clone)]
pub struct SlimHeaderFlags {
    /// Value for the header's `fanout` field.
    pub fanout: u32,
    /// Value for the header's `recv_from` field.
    pub recv_from: Option<u64>,
    /// Value for the header's `forward_to` field.
    pub forward_to: Option<u64>,
    /// Value for the header's `incoming_conn` field.
    pub incoming_conn: Option<u64>,
    /// Value for the header's `error` field.
    pub error: Option<bool>,
}

impl Default for SlimHeaderFlags {
    fn default() -> Self {
        Self::new(1, None, None, None, None)
    }
}

impl Display for SlimHeaderFlags {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            fanout,
            recv_from,
            forward_to,
            incoming_conn,
            error,
        } = self;
        write!(f, "fanout: {fanout}, recv_from: {recv_from:?}, ")?;
        write!(
            f,
            "forward_to: {forward_to:?}, incoming_conn: {incoming_conn:?}, error: {error:?}"
        )
    }
}

impl SlimHeaderFlags {
    /// Builds a flag set from explicit values for every field.
    pub fn new(
        fanout: u32,
        recv_from: Option<u64>,
        forward_to: Option<u64>,
        incoming_conn: Option<u64>,
        error: Option<bool>,
    ) -> Self {
        Self {
            fanout,
            recv_from,
            forward_to,
            incoming_conn,
            error,
        }
    }

    /// Returns `self` with `fanout` replaced.
    pub fn with_fanout(mut self, fanout: u32) -> Self {
        self.fanout = fanout;
        self
    }

    /// Returns `self` with `recv_from` set to `Some(recv_from)`.
    pub fn with_recv_from(mut self, recv_from: u64) -> Self {
        self.recv_from = Some(recv_from);
        self
    }

    /// Returns `self` with `forward_to` set to `Some(forward_to)`.
    pub fn with_forward_to(mut self, forward_to: u64) -> Self {
        self.forward_to = Some(forward_to);
        self
    }

    /// Returns `self` with `incoming_conn` set to `Some(incoming_conn)`.
    pub fn with_incoming_conn(mut self, incoming_conn: u64) -> Self {
        self.incoming_conn = Some(incoming_conn);
        self
    }

    /// Returns `self` with `error` set to `Some(error)`.
    pub fn with_error(mut self, error: bool) -> Self {
        self.error = Some(error);
        self
    }
}
210
211/// SLIM Header
212/// This header is used to identify the source and destination of the message
213/// and to manage the connections used to send and receive the message
214impl SlimHeader {
215    pub fn new(
216        source: &Name,
217        destination: &Name,
218        identity: &str,
219        flags: Option<SlimHeaderFlags>,
220    ) -> Self {
221        let flags = flags.unwrap_or_default();
222
223        Self {
224            source: Some(ProtoName::from(source)),
225            destination: Some(ProtoName::from(destination)),
226            identity: identity.to_string(),
227            fanout: flags.fanout,
228            recv_from: flags.recv_from,
229            forward_to: flags.forward_to,
230            incoming_conn: flags.incoming_conn,
231            error: flags.error,
232        }
233    }
234
235    pub fn clear_flags(&mut self) {
236        self.recv_from = None;
237        self.forward_to = None;
238    }
239
240    pub fn get_fanout(&self) -> u32 {
241        self.fanout
242    }
243
244    pub fn get_recv_from(&self) -> Option<u64> {
245        self.recv_from
246    }
247
248    pub fn get_forward_to(&self) -> Option<u64> {
249        self.forward_to
250    }
251
252    pub fn get_incoming_conn(&self) -> Option<u64> {
253        self.incoming_conn
254    }
255
256    pub fn get_error(&self) -> Option<bool> {
257        self.error
258    }
259
260    pub fn get_source(&self) -> Name {
261        match &self.source {
262            Some(source) => Name::from(source),
263            None => panic!("source not found"),
264        }
265    }
266
267    pub fn get_dst(&self) -> Name {
268        match &self.destination {
269            Some(destination) => Name::from(destination),
270            None => panic!("destination not found"),
271        }
272    }
273
274    pub fn get_identity(&self) -> String {
275        self.identity.clone()
276    }
277
278    pub fn set_source(&mut self, source: &Name) {
279        self.source = Some(ProtoName::from(source));
280    }
281
282    pub fn set_destination(&mut self, dst: &Name) {
283        self.destination = Some(ProtoName::from(dst));
284    }
285
286    pub fn set_identity(&mut self, identity: String) {
287        self.identity = identity;
288    }
289
290    pub fn set_fanout(&mut self, fanout: u32) {
291        self.fanout = fanout;
292    }
293
294    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
295        self.recv_from = recv_from;
296    }
297
298    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
299        self.forward_to = forward_to;
300    }
301
302    pub fn set_error(&mut self, error: Option<bool>) {
303        self.error = error;
304    }
305
306    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
307        self.incoming_conn = incoming_conn;
308    }
309
310    pub fn set_error_flag(&mut self, error: Option<bool>) {
311        self.error = error;
312    }
313
314    // returns (incoming, recv_from, forward_to) for subscription processing
315    pub(crate) fn get_connections(&self) -> (u64, Option<u64>, Option<u64>) {
316        // when calling this function, incoming connection is set
317        let incoming = self
318            .get_incoming_conn()
319            .expect("incoming connection not found");
320
321        (incoming, self.get_recv_from(), self.get_forward_to())
322    }
323}
324
325/// Session Header
326/// This header is used to identify the session and the message
327/// and to manage session state
328impl SessionHeader {
329    pub fn new(
330        session_type: i32,
331        session_message_type: i32,
332        session_id: u32,
333        message_id: u32,
334    ) -> Self {
335        Self {
336            session_type,
337            session_message_type,
338            session_id,
339            message_id,
340        }
341    }
342
343    pub fn get_session_id(&self) -> u32 {
344        self.session_id
345    }
346
347    pub fn get_message_id(&self) -> u32 {
348        self.message_id
349    }
350
351    pub fn set_session_id(&mut self, session_id: u32) {
352        self.session_id = session_id;
353    }
354
355    pub fn set_message_id(&mut self, message_id: u32) {
356        self.message_id = message_id;
357    }
358
359    pub fn clear(&mut self) {
360        self.session_id = 0;
361        self.message_id = 0;
362    }
363}
364
365/// SessionMessageType
366/// Helper methods for session message types
367impl SessionMessageType {
368    /// Check if a message type is a command message (not application data)
369    pub fn is_command_message(&self) -> bool {
370        matches!(
371            self,
372            SessionMessageType::DiscoveryRequest
373                | SessionMessageType::DiscoveryReply
374                | SessionMessageType::JoinRequest
375                | SessionMessageType::JoinReply
376                | SessionMessageType::LeaveRequest
377                | SessionMessageType::LeaveReply
378                | SessionMessageType::GroupAdd
379                | SessionMessageType::GroupRemove
380                | SessionMessageType::GroupWelcome
381                | SessionMessageType::GroupClose
382                | SessionMessageType::GroupProposal
383                | SessionMessageType::GroupAck
384                | SessionMessageType::GroupNack
385                | SessionMessageType::Ping
386        )
387    }
388}
389
390/// ProtoSubscribe
391/// This message is used to subscribe to a topic
392impl ProtoSubscribe {
393    fn new(
394        source: &Name,
395        dst: &Name,
396        identity: Option<&str>,
397        flags: Option<SlimHeaderFlags>,
398    ) -> Self {
399        let id = identity.unwrap_or("");
400        let header = Some(SlimHeader::new(source, dst, id, flags));
401
402        ProtoSubscribe {
403            header,
404            subscription_id: 0,
405        }
406    }
407}
408
409/// From ProtoMessage to ProtoSubscribe
410impl From<ProtoMessage> for ProtoSubscribe {
411    fn from(message: ProtoMessage) -> Self {
412        match message.message_type {
413            Some(ProtoSubscribeType(s)) => s,
414            _ => panic!("message type is not subscribe"),
415        }
416    }
417}
418
419/// ProtoUnsubscribe
420/// This message is used to unsubscribe from a topic
421impl ProtoUnsubscribe {
422    fn new(
423        source: &Name,
424        dst: &Name,
425        identity: Option<&str>,
426        flags: Option<SlimHeaderFlags>,
427    ) -> Self {
428        let id = identity.unwrap_or("");
429        let header = Some(SlimHeader::new(source, dst, id, flags));
430
431        ProtoUnsubscribe {
432            header,
433            subscription_id: 0,
434        }
435    }
436}
437
438/// From ProtoMessage to ProtoUnsubscribe
439impl From<ProtoMessage> for ProtoUnsubscribe {
440    fn from(message: ProtoMessage) -> Self {
441        match message.message_type {
442            Some(ProtoUnsubscribeType(u)) => u,
443            _ => panic!("message type is not unsubscribe"),
444        }
445    }
446}
447
448/// ProtoPublish
449/// This message is used to publish a message, either to a shared channel or to a specific application
450impl ProtoPublish {
451    fn with_header(
452        header: Option<SlimHeader>,
453        session: Option<SessionHeader>,
454        payload: Option<Content>,
455    ) -> Self {
456        ProtoPublish {
457            header,
458            session,
459            msg: payload,
460        }
461    }
462
463    pub fn get_slim_header(&self) -> &SlimHeader {
464        self.header.as_ref().unwrap()
465    }
466
467    pub fn get_session_header(&self) -> &SessionHeader {
468        self.session.as_ref().unwrap()
469    }
470
471    pub fn get_slim_header_as_mut(&mut self) -> &mut SlimHeader {
472        self.header.as_mut().unwrap()
473    }
474
475    pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
476        self.session.as_mut().unwrap()
477    }
478
479    pub fn get_payload(&self) -> &Content {
480        self.msg.as_ref().unwrap()
481    }
482
483    pub fn set_payload(&mut self, payload: Content) {
484        self.msg = Some(payload);
485    }
486
487    pub fn is_command(&self) -> bool {
488        match &self.get_payload().content_type.as_ref().unwrap() {
489            ContentType::AppPayload(_) => false,
490            ContentType::CommandPayload(_) => true,
491        }
492    }
493
494    pub fn get_application_payload(&self) -> &ApplicationPayload {
495        match self.get_payload().content_type.as_ref().unwrap() {
496            ContentType::AppPayload(application_payload) => application_payload,
497            ContentType::CommandPayload(_) => panic!("the payload is not an application payload"),
498        }
499    }
500
501    pub fn get_command_payload(&self) -> &CommandPayload {
502        match &self.get_payload().content_type.as_ref().unwrap() {
503            ContentType::AppPayload(_) => panic!("the payaoad is not a command payload"),
504            ContentType::CommandPayload(command_payload) => command_payload,
505        }
506    }
507}
508
509/// From ProtoMessage to ProtoPublish
510impl From<ProtoMessage> for ProtoPublish {
511    fn from(message: ProtoMessage) -> Self {
512        match message.message_type {
513            Some(ProtoPublishType(p)) => p,
514            _ => panic!("message type is not publish"),
515        }
516    }
517}
518
// ProtoMessage: a generic message that can be sent over the network.

/// Generates typed command-payload accessors on `ProtoMessage`.
///
/// Each `name => accessor(PayloadType)` entry expands to
/// `pub fn name(&self) -> Result<&PayloadType, MessageError>`. The generated method
/// first obtains the command payload through `extract_command_payload`, then calls
/// `accessor` on it. It is meant to be invoked inside `impl ProtoMessage`.
macro_rules! impl_payload_extractors {
    ($($fn_name:ident => $accessor:ident($payload:ty)),* $(,)?) => {
        $(
            #[doc = concat!(
                "Extracts the `",
                stringify!($payload),
                "` carried by this message's command payload."
            )]
            pub fn $fn_name(&self) -> Result<&$payload, MessageError> {
                let command = self.extract_command_payload()?;
                command.$accessor()
            }
        )*
    };
}
532
533impl ProtoMessage {
534    fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
535        ProtoMessage {
536            metadata,
537            message_type: Some(message_type),
538        }
539    }
540
541    fn validate_link(link: &ProtoLink) -> Result<(), MessageError> {
542        if link.link_type.is_none() {
543            return Err(MessageError::LinkTypeNotSet);
544        }
545        Ok(())
546    }
547
548    fn validate_routed_header(slim_header: &SlimHeader) -> Result<(), MessageError> {
549        if slim_header.source.is_none() {
550            return Err(MessageError::SourceNotFound);
551        }
552        if slim_header.destination.is_none() {
553            return Err(MessageError::DestinationNotFound);
554        }
555        Ok(())
556    }
557
558    fn validate_publish(p: &ProtoPublish) -> Result<(), MessageError> {
559        let hdr = p.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
560        Self::validate_routed_header(hdr)?;
561        if p.session.is_none() {
562            return Err(MessageError::SessionHeaderNotFound);
563        }
564        Ok(())
565    }
566
567    fn validate_subscribe(s: &ProtoSubscribe) -> Result<(), MessageError> {
568        let hdr = s.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
569        Self::validate_routed_header(hdr)
570    }
571
572    fn validate_unsubscribe(u: &ProtoUnsubscribe) -> Result<(), MessageError> {
573        let hdr = u.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
574        Self::validate_routed_header(hdr)
575    }
576
577    // validate message
578    pub fn validate(&self) -> Result<(), MessageError> {
579        match &self.message_type {
580            None => Err(MessageError::MessageTypeNotFound),
581            Some(ProtoLinkMessageType(link)) => Self::validate_link(link),
582            Some(ProtoPublishType(p)) => Self::validate_publish(p),
583            Some(ProtoSubscribeType(s)) => Self::validate_subscribe(s),
584            Some(ProtoUnsubscribeType(u)) => Self::validate_unsubscribe(u),
585            Some(ProtoSubscriptionAckType(_)) => Ok(()),
586        }
587    }
588
589    // add metadata key in the map assigning the value val
590    // if the key exists the value is replaced by val
591    pub fn insert_metadata(&mut self, key: String, val: String) {
592        self.metadata.insert(key, val);
593    }
594
595    // remove metadata key from the map
596    pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
597        self.metadata.remove(key)
598    }
599
600    pub fn contains_metadata(&self, key: &str) -> bool {
601        self.metadata.contains_key(key)
602    }
603
604    pub fn get_metadata(&self, key: &str) -> Option<&String> {
605        self.metadata.get(key)
606    }
607
608    pub fn get_metadata_map(&self) -> HashMap<String, String> {
609        self.metadata.clone()
610    }
611
612    pub fn set_metadata_map(&mut self, map: HashMap<String, String>) {
613        for (k, v) in map.iter() {
614            self.insert_metadata(k.to_string(), v.to_string());
615        }
616    }
617
618    pub fn get_slim_header(&self) -> &SlimHeader {
619        match &self.message_type {
620            Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
621            Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
622            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
623            Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
624                panic!("SLIM header not found")
625            }
626        }
627    }
628
629    pub fn get_slim_header_mut(&mut self) -> &mut SlimHeader {
630        match &mut self.message_type {
631            Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
632            Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
633            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
634            Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
635                panic!("SLIM header not found")
636            }
637        }
638    }
639
640    pub fn try_get_slim_header(&self) -> Option<&SlimHeader> {
641        match &self.message_type {
642            Some(ProtoPublishType(publish)) => publish.header.as_ref(),
643            Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
644            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
645            Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => None,
646        }
647    }
648
649    pub fn get_session_header(&self) -> &SessionHeader {
650        match &self.message_type {
651            Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
652            Some(ProtoSubscribeType(_))
653            | Some(ProtoUnsubscribeType(_))
654            | Some(ProtoLinkMessageType(_))
655            | Some(ProtoSubscriptionAckType(_))
656            | None => panic!("session header not found"),
657        }
658    }
659
660    pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
661        match &mut self.message_type {
662            Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
663            Some(ProtoSubscribeType(_))
664            | Some(ProtoUnsubscribeType(_))
665            | Some(ProtoLinkMessageType(_))
666            | Some(ProtoSubscriptionAckType(_))
667            | None => panic!("session header not found"),
668        }
669    }
670
671    pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
672        match &self.message_type {
673            Some(ProtoPublishType(publish)) => publish.session.as_ref(),
674            Some(ProtoSubscribeType(_))
675            | Some(ProtoUnsubscribeType(_))
676            | Some(ProtoLinkMessageType(_))
677            | Some(ProtoSubscriptionAckType(_))
678            | None => None,
679        }
680    }
681
682    pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
683        match &mut self.message_type {
684            Some(ProtoPublishType(publish)) => publish.session.as_mut(),
685            Some(ProtoSubscribeType(_))
686            | Some(ProtoUnsubscribeType(_))
687            | Some(ProtoLinkMessageType(_))
688            | Some(ProtoSubscriptionAckType(_))
689            | None => None,
690        }
691    }
692
693    pub fn get_id(&self) -> u32 {
694        self.get_session_header().get_message_id()
695    }
696
697    pub fn get_source(&self) -> Name {
698        self.get_slim_header().get_source()
699    }
700
701    pub fn get_dst(&self) -> Name {
702        self.get_slim_header().get_dst()
703    }
704
705    pub fn get_identity(&self) -> String {
706        self.get_slim_header().get_identity()
707    }
708
709    pub fn get_fanout(&self) -> u32 {
710        self.get_slim_header().get_fanout()
711    }
712
713    pub fn get_recv_from(&self) -> Option<u64> {
714        self.get_slim_header().get_recv_from()
715    }
716
717    pub fn get_forward_to(&self) -> Option<u64> {
718        self.get_slim_header().get_forward_to()
719    }
720
721    pub fn get_error(&self) -> Option<bool> {
722        self.get_slim_header().get_error()
723    }
724
725    pub fn get_incoming_conn(&self) -> u64 {
726        self.get_slim_header().get_incoming_conn().unwrap()
727    }
728
729    pub fn try_get_incoming_conn(&self) -> Option<u64> {
730        self.get_slim_header().get_incoming_conn()
731    }
732
733    pub fn get_type(&self) -> &MessageType {
734        match &self.message_type {
735            Some(t) => t,
736            None => panic!("message type not found"),
737        }
738    }
739
740    pub fn get_payload(&self) -> Option<&Content> {
741        match &self.message_type {
742            Some(ProtoPublishType(p)) => p.msg.as_ref(),
743            Some(ProtoSubscribeType(_)) => panic!("payload not found"),
744            Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
745            Some(ProtoLinkMessageType(_)) => panic!("payload not found"),
746            Some(ProtoSubscriptionAckType(_)) => panic!("payload not found"),
747            None => panic!("payload not found"),
748        }
749    }
750
751    pub fn set_payload(&mut self, payload: Content) {
752        match &mut self.message_type {
753            Some(ProtoPublishType(p)) => p.set_payload(payload),
754            Some(ProtoSubscribeType(_)) => panic!("no payload allowed"),
755            Some(ProtoUnsubscribeType(_)) => panic!("no payload allowed"),
756            Some(ProtoLinkMessageType(_)) => panic!("no payload allowed"),
757            Some(ProtoSubscriptionAckType(_)) => panic!("no payload allowed"),
758            None => panic!("no payload allowed"),
759        }
760    }
761
762    pub fn get_session_message_type(&self) -> SessionMessageType {
763        self.get_session_header()
764            .session_message_type
765            .try_into()
766            .unwrap_or_default()
767    }
768
769    pub fn clear_slim_header(&mut self) {
770        if self.is_link() || self.is_subscription_ack() {
771            return;
772        }
773        self.get_slim_header_mut().clear_flags();
774    }
775
776    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
777        self.get_slim_header_mut().set_recv_from(recv_from);
778    }
779
780    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
781        self.get_slim_header_mut().set_forward_to(forward_to);
782    }
783
784    pub fn set_error(&mut self, error: Option<bool>) {
785        self.get_slim_header_mut().set_error(error);
786    }
787
788    pub fn set_fanout(&mut self, fanout: u32) {
789        self.get_slim_header_mut().set_fanout(fanout);
790    }
791
792    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
793        self.get_slim_header_mut().set_incoming_conn(incoming_conn);
794    }
795
796    pub fn set_error_flag(&mut self, error: Option<bool>) {
797        self.get_slim_header_mut().set_error_flag(error);
798    }
799
800    pub fn set_session_message_type(&mut self, message_type: SessionMessageType) {
801        self.get_session_header_mut()
802            .set_session_message_type(message_type);
803    }
804
805    pub fn set_session_type(&mut self, session_type: ProtoSessionType) {
806        self.get_session_header_mut().set_session_type(session_type);
807    }
808
809    pub fn get_session_type(&self) -> ProtoSessionType {
810        self.get_session_header().session_type()
811    }
812
813    pub fn set_message_id(&mut self, message_id: u32) {
814        self.get_session_header_mut().set_message_id(message_id);
815    }
816
817    pub fn is_publish(&self) -> bool {
818        matches!(self.get_type(), MessageType::Publish(_))
819    }
820
821    pub fn is_subscribe(&self) -> bool {
822        matches!(self.get_type(), MessageType::Subscribe(_))
823    }
824
825    pub fn is_unsubscribe(&self) -> bool {
826        matches!(self.get_type(), MessageType::Unsubscribe(_))
827    }
828
829    pub fn is_link(&self) -> bool {
830        matches!(self.get_type(), MessageType::Link(_))
831    }
832
833    pub fn is_subscription_ack(&self) -> bool {
834        matches!(self.get_type(), MessageType::SubscriptionAck(_))
835    }
836
837    pub fn is_traceable(&self) -> bool {
838        !self.is_link() && !self.is_subscription_ack()
839    }
840
841    pub fn get_subscription_ack(&self) -> &ProtoSubscriptionAck {
842        match &self.message_type {
843            Some(ProtoSubscriptionAckType(ack)) => ack,
844            _ => panic!("message type is not subscription_ack"),
845        }
846    }
847
848    /// Returns the subscription_id from a Subscribe/Unsubscribe message, or None if absent/zero.
849    pub fn get_subscription_id(&self) -> Option<u64> {
850        match &self.message_type {
851            Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => Some(s.subscription_id),
852            Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => Some(u.subscription_id),
853            _ => None,
854        }
855    }
856
857    /// Takes and clears the subscription_id from a Subscribe/Unsubscribe message.
858    /// Returns None if the field is absent or zero.
859    pub fn take_subscription_id(&mut self) -> Option<u64> {
860        match &mut self.message_type {
861            Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => {
862                Some(std::mem::take(&mut s.subscription_id))
863            }
864            Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => {
865                Some(std::mem::take(&mut u.subscription_id))
866            }
867            _ => None,
868        }
869    }
870
871    /// Sets the subscription_id on a Subscribe/Unsubscribe message (no-op for other types).
872    pub fn set_subscription_id(&mut self, subscription_id: u64) {
873        match &mut self.message_type {
874            Some(ProtoSubscribeType(s)) => s.subscription_id = subscription_id,
875            Some(ProtoUnsubscribeType(u)) => u.subscription_id = subscription_id,
876            _ => {}
877        }
878    }
879
880    /// Extracts the command payload from the message.
881    ///
882    /// # Errors
883    /// Returns `MessageError` if the payload is missing or cannot be converted.
884    pub fn extract_command_payload(&self) -> Result<&CommandPayload, MessageError> {
885        self.get_payload()
886            .ok_or(MessageError::ContentTypeNotSet)?
887            .as_command_payload()
888    }
889
    // Generate all payload extraction methods.
    // Each entry reads `accessor => getter(PayloadType)`: it declares a typed
    // `extract_*` accessor on the message, pairs it with the matching
    // `CommandPayload::as_*` getter (generated by `impl_command_payload_getters!`
    // further down in this file) and the payload type that getter returns.
    // NOTE(review): the expansion lives in `impl_payload_extractors!`, which is
    // not visible here; presumably each accessor goes through
    // `extract_command_payload` first — confirm against the macro definition.
    impl_payload_extractors! {
        extract_discovery_request => as_discovery_request_payload(DiscoveryRequestPayload),
        extract_discovery_reply => as_discovery_reply_payload(DiscoveryReplyPayload),
        extract_join_request => as_join_request_payload(JoinRequestPayload),
        extract_join_reply => as_join_reply_payload(JoinReplyPayload),
        extract_leave_request => as_leave_request_payload(LeaveRequestPayload),
        extract_leave_reply => as_leave_reply_payload(LeaveReplyPayload),
        extract_group_add => as_group_add_payload(GroupAddPayload),
        extract_group_remove => as_group_remove_payload(GroupRemovePayload),
        extract_group_welcome => as_welcome_payload(GroupWelcomePayload),
        extract_group_close => as_group_close_payload(GroupClosePayload),
        extract_group_proposal => as_group_proposal_payload(GroupProposalPayload),
        extract_group_ack => as_group_ack_payload(GroupAckPayload),
        extract_group_nack => as_group_nack_payload(GroupNackPayload),
        extract_ping => as_ping_payload(PingPayload),
    }
907}
908
909impl Content {
910    pub fn as_application_payload(&self) -> Result<&ApplicationPayload, MessageError> {
911        match &self.content_type {
912            Some(ContentType::AppPayload(app_payload)) => Ok(app_payload),
913            Some(ContentType::CommandPayload(_)) => Err(MessageError::NotApplicationPayload),
914            None => Err(MessageError::ContentTypeNotSet),
915        }
916    }
917
918    pub fn as_command_payload(&self) -> Result<&CommandPayload, MessageError> {
919        match &self.content_type {
920            Some(ContentType::AppPayload(_)) => Err(MessageError::NotCommandPayload),
921            Some(ContentType::CommandPayload(comm_payload)) => Ok(comm_payload),
922            None => Err(MessageError::ContentTypeNotSet),
923        }
924    }
925}
926
927impl ApplicationPayload {
928    pub fn new(payload_type: &str, blob: Vec<u8>) -> Self {
929        Self {
930            payload_type: payload_type.to_string(),
931            blob,
932        }
933    }
934
935    pub fn as_content(&self) -> Content {
936        Content {
937            content_type: Some(ContentType::AppPayload(self.clone())),
938        }
939    }
940}
941
// Generates one typed getter per `CommandPayloadType` variant.
//
// Each entry `method => Variant(PayloadType)` expands to
// `pub fn method(&self) -> Result<&PayloadType, MessageError>`. That method
// returns the inner payload when `command_payload_type` holds `Variant`.
// Otherwise it returns `MessageError::InvalidCommandPayloadType`, carrying the
// expected variant name and the `Debug` rendering of what was actually found
// (`"None"` when the field is unset).
macro_rules! impl_command_payload_getters {
    ($(
        $method_name:ident => $variant:ident($payload_type:ty)
    ),* $(,)?) => {
        $(
            pub fn $method_name(&self) -> Result<&$payload_type, MessageError> {
                let got = match self.command_payload_type.as_ref() {
                    Some(CommandPayloadType::$variant(payload)) => return Ok(payload),
                    Some(other) => format!("{:?}", other),
                    None => "None".to_string(),
                };
                Err(MessageError::InvalidCommandPayloadType {
                    expected: Box::new(stringify!($variant).to_string()),
                    got: Box::new(got),
                })
            }
        )*
    };
}
964
965impl CommandPayload {
966    pub fn as_content(self) -> Content {
967        Content {
968            content_type: Some(ContentType::CommandPayload(self)),
969        }
970    }
971
972    // Getter methods for all CommandPayloadType variants
973    impl_command_payload_getters! {
974        as_discovery_request_payload => DiscoveryRequest(DiscoveryRequestPayload),
975        as_discovery_reply_payload => DiscoveryReply(DiscoveryReplyPayload),
976        as_join_request_payload => JoinRequest(JoinRequestPayload),
977        as_join_reply_payload => JoinReply(JoinReplyPayload),
978        as_leave_request_payload => LeaveRequest(LeaveRequestPayload),
979        as_leave_reply_payload => LeaveReply(LeaveReplyPayload),
980        as_group_add_payload => GroupAdd(GroupAddPayload),
981        as_group_remove_payload => GroupRemove(GroupRemovePayload),
982        as_welcome_payload => GroupWelcome(GroupWelcomePayload),
983        as_group_close_payload => GroupClose(GroupClosePayload),
984        as_group_proposal_payload => GroupProposal(GroupProposalPayload),
985        as_group_ack_payload => GroupAck(GroupAckPayload),
986        as_group_nack_payload => GroupNack(GroupNackPayload),
987        as_ping_payload => Ping(PingPayload),
988    }
989}
990
991impl AsRef<ProtoPublish> for ProtoMessage {
992    fn as_ref(&self) -> &ProtoPublish {
993        match &self.message_type {
994            Some(ProtoPublishType(p)) => p,
995            _ => panic!("message type is not publish"),
996        }
997    }
998}
999
/// Builder for creating CommandPayload instances with a fluent API
///
/// Provides methods for creating all types of command payloads.
///
/// The builder is a stateless unit struct: every method consumes it and
/// returns a finished [`CommandPayload`] directly, so there is no separate
/// `build()` step.
///
/// # Examples
///
/// ## Discovery Request
/// ```
/// use slim_datapath::api::CommandPayload;
/// use slim_datapath::messages::Name;
///
/// let dest = Name::from_strings(["org", "namespace", "service"]);
/// let payload = CommandPayload::builder().discovery_request(Some(dest));
/// ```
///
/// ## Join Request with Timer Settings
/// ```
/// use slim_datapath::api::CommandPayload;
/// use slim_datapath::messages::Name;
/// use std::time::Duration;
///
/// let channel = Name::from_strings(["org", "namespace", "channel"]);
/// let payload = CommandPayload::builder().join_request(
///     true,  // enable_mls
///     Some(5),  // max_retries
///     Some(Duration::from_secs(10)),  // timeout
///     Some(channel),
/// );
/// ```
///
/// ## Group Operations
/// ```
/// use slim_datapath::api::CommandPayload;
/// use slim_datapath::messages::Name;
///
/// let participant = Name::from_strings(["org", "ns", "user1"]);
/// let participants = vec![
///     Name::from_strings(["org", "ns", "user2"]),
///     Name::from_strings(["org", "ns", "user3"]),
/// ];
///
/// // Add participant
/// let add_payload = CommandPayload::builder().group_add(
///     participant.clone(),
///     participants.clone(),
///     None,  // mls payload
/// );
/// ```
pub struct CommandPayloadBuilder;
1049
1050impl CommandPayloadBuilder {
1051    /// Creates a new CommandPayloadBuilder
1052    pub fn new() -> Self {
1053        Self
1054    }
1055
1056    /// Creates a discovery request payload
1057    pub fn discovery_request(self, destination: Option<Name>) -> CommandPayload {
1058        let proto_destination = destination.as_ref().map(ProtoName::from);
1059        let payload = DiscoveryRequestPayload {
1060            destination: proto_destination,
1061        };
1062        CommandPayload {
1063            command_payload_type: Some(CommandPayloadType::DiscoveryRequest(payload)),
1064        }
1065    }
1066
1067    /// Creates a discovery reply payload
1068    pub fn discovery_reply(self) -> CommandPayload {
1069        let payload = DiscoveryReplyPayload {};
1070        CommandPayload {
1071            command_payload_type: Some(CommandPayloadType::DiscoveryReply(payload)),
1072        }
1073    }
1074
1075    /// Creates a join request payload
1076    pub fn join_request(
1077        self,
1078        enable_mls: bool,
1079        max_retries: Option<u32>,
1080        timer_duration: Option<Duration>,
1081        channel: Option<Name>,
1082    ) -> CommandPayload {
1083        let proto_channel = channel.as_ref().map(ProtoName::from);
1084
1085        let timer_settings = if let Some(t) = timer_duration
1086            && let Some(m) = max_retries
1087        {
1088            Some(TimerSettings {
1089                timeout: t.as_millis() as u32,
1090                max_retries: m,
1091            })
1092        } else {
1093            None
1094        };
1095
1096        let payload = JoinRequestPayload {
1097            enable_mls,
1098            timer_settings,
1099            channel: proto_channel,
1100        };
1101        CommandPayload {
1102            command_payload_type: Some(CommandPayloadType::JoinRequest(payload)),
1103        }
1104    }
1105
1106    /// Creates a join reply payload
1107    pub fn join_reply(self, key_package: Option<Vec<u8>>) -> CommandPayload {
1108        let payload = JoinReplyPayload { key_package };
1109        CommandPayload {
1110            command_payload_type: Some(CommandPayloadType::JoinReply(payload)),
1111        }
1112    }
1113
1114    /// Creates a leave request payload
1115    pub fn leave_request(self, destination: Option<Name>) -> CommandPayload {
1116        let proto_destination = destination.as_ref().map(ProtoName::from);
1117        let payload = LeaveRequestPayload {
1118            destination: proto_destination,
1119        };
1120        CommandPayload {
1121            command_payload_type: Some(CommandPayloadType::LeaveRequest(payload)),
1122        }
1123    }
1124
1125    /// Creates a leave reply payload
1126    pub fn leave_reply(self) -> CommandPayload {
1127        let payload = LeaveReplyPayload {};
1128        CommandPayload {
1129            command_payload_type: Some(CommandPayloadType::LeaveReply(payload)),
1130        }
1131    }
1132
1133    /// Creates a group add payload
1134    pub fn group_add(
1135        self,
1136        new_participant: Name,
1137        participants: Vec<Name>,
1138        mls: Option<MlsPayload>,
1139    ) -> CommandPayload {
1140        let proto_new_participant = Some(ProtoName::from(&new_participant));
1141        let proto_participants = participants.iter().map(ProtoName::from).collect();
1142
1143        let payload = GroupAddPayload {
1144            new_participant: proto_new_participant,
1145            participants: proto_participants,
1146            mls,
1147        };
1148        CommandPayload {
1149            command_payload_type: Some(CommandPayloadType::GroupAdd(payload)),
1150        }
1151    }
1152
1153    /// Creates a group remove payload
1154    pub fn group_remove(
1155        self,
1156        removed_participant: Name,
1157        participants: Vec<Name>,
1158        mls: Option<MlsPayload>,
1159    ) -> CommandPayload {
1160        let proto_removed_participant = Some(ProtoName::from(&removed_participant));
1161        let proto_participants = participants.iter().map(ProtoName::from).collect();
1162
1163        let payload = GroupRemovePayload {
1164            removed_participant: proto_removed_participant,
1165            participants: proto_participants,
1166            mls,
1167        };
1168        CommandPayload {
1169            command_payload_type: Some(CommandPayloadType::GroupRemove(payload)),
1170        }
1171    }
1172
1173    /// Creates a group welcome payload
1174    pub fn group_welcome(self, participants: Vec<Name>, mls: Option<MlsPayload>) -> CommandPayload {
1175        let proto_participants = participants.iter().map(ProtoName::from).collect();
1176
1177        let payload = GroupWelcomePayload {
1178            participants: proto_participants,
1179            mls,
1180        };
1181        CommandPayload {
1182            command_payload_type: Some(CommandPayloadType::GroupWelcome(payload)),
1183        }
1184    }
1185
1186    /// Creates a group close payload
1187    pub fn group_close(self, participants: Vec<Name>) -> CommandPayload {
1188        let proto_participants = participants.iter().map(ProtoName::from).collect();
1189
1190        let payload = GroupClosePayload {
1191            participants: proto_participants,
1192        };
1193        CommandPayload {
1194            command_payload_type: Some(CommandPayloadType::GroupClose(payload)),
1195        }
1196    }
1197
1198    /// Creates a group proposal payload
1199    pub fn group_proposal(self, source: Option<Name>, mls_proposal: Vec<u8>) -> CommandPayload {
1200        let proto_source = source.as_ref().map(ProtoName::from);
1201        let payload = GroupProposalPayload {
1202            source: proto_source,
1203            mls_proposal,
1204        };
1205        CommandPayload {
1206            command_payload_type: Some(CommandPayloadType::GroupProposal(payload)),
1207        }
1208    }
1209
1210    /// Creates a group ack payload
1211    pub fn group_ack(self) -> CommandPayload {
1212        let payload = GroupAckPayload {};
1213        CommandPayload {
1214            command_payload_type: Some(CommandPayloadType::GroupAck(payload)),
1215        }
1216    }
1217
1218    /// Creates a group nack payload
1219    pub fn group_nack(self) -> CommandPayload {
1220        let payload = GroupNackPayload {};
1221        CommandPayload {
1222            command_payload_type: Some(CommandPayloadType::GroupNack(payload)),
1223        }
1224    }
1225
1226    /// Creates a ping payload
1227    pub fn ping(self) -> CommandPayload {
1228        let payload = PingPayload {};
1229        CommandPayload {
1230            command_payload_type: Some(CommandPayloadType::Ping(payload)),
1231        }
1232    }
1233}
1234
1235impl Default for CommandPayloadBuilder {
1236    fn default() -> Self {
1237        Self::new()
1238    }
1239}
1240
1241impl CommandPayload {
1242    /// Creates a new builder for CommandPayload
1243    pub fn builder() -> CommandPayloadBuilder {
1244        CommandPayloadBuilder::new()
1245    }
1246}
1247
/// Builder for creating ProtoMessage instances with a fluent API
///
/// # Examples
///
/// ## Basic Publish Message
/// ```
/// use slim_datapath::api::{ProtoMessage, ProtoSessionType};
/// use slim_datapath::messages::Name;
///
/// let source = Name::from_strings(["org", "ns", "app"]).with_id(1);
/// let dest = Name::from_strings(["org", "ns", "service"]).with_id(2);
///
/// let msg = ProtoMessage::builder()
///     .source(source)
///     .destination(dest)
///     .session_type(ProtoSessionType::PointToPoint)
///     .session_id(123)
///     .application_payload("text", b"Hello".to_vec())
///     .build_publish()
///     .unwrap();
/// ```
///
/// ## Session Control Message
/// ```
/// use slim_datapath::api::{CommandPayload, ProtoMessage, ProtoSessionType, ProtoSessionMessageType};
/// use slim_datapath::messages::Name;
///
/// let source = Name::from_strings(["org", "ns", "app"]);
/// let dest = Name::from_strings(["org", "ns", "service"]);
///
/// let cmd = CommandPayload::builder().discovery_request(Some(dest.clone()));
///
/// let msg = ProtoMessage::builder()
///     .source(source)
///     .destination(dest)
///     .session_type(ProtoSessionType::PointToPoint)
///     .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
///     .session_id(42)
///     .command_payload(cmd)
///     .build_publish()
///     .unwrap();
/// ```
///
/// ## Multicast with Broadcast
/// ```
/// use slim_datapath::api::{ProtoMessage, ProtoSessionType};
/// use slim_datapath::messages::Name;
///
/// let source = Name::from_strings(["org", "ns", "app"]);
/// let dest = Name::from_strings(["org", "ns", "channel"]);
///
/// let msg = ProtoMessage::builder()
///     .source(source)
///     .destination(dest)
///     .session_type(ProtoSessionType::Multicast)
///     .fanout(256)
///     .application_payload("event", b"broadcast event".to_vec())
///     .metadata("priority", "high")
///     .build_publish()
///     .unwrap();
/// ```
///
/// ## Subscribe/Unsubscribe Messages
/// ```
/// use slim_datapath::api::ProtoMessage;
/// use slim_datapath::messages::Name;
///
/// let source = Name::from_strings(["org", "ns", "app"]);
/// let dest = Name::from_strings(["org", "ns", "topic"]);
///
/// // Subscribe
/// let sub_msg = ProtoMessage::builder()
///     .source(source.clone())
///     .destination(dest.clone())
///     .recv_from(100)
///     .build_subscribe()
///     .unwrap();
///
/// // Unsubscribe
/// let unsub_msg = ProtoMessage::builder()
///     .source(source)
///     .destination(dest)
///     .build_unsubscribe()
///     .unwrap();
/// ```
pub struct ProtoMessageBuilder {
    /// Source name; required by `build_publish`, `build_subscribe` and `build_unsubscribe`.
    source: Option<Name>,
    /// Destination name; required by `build_publish`, `build_subscribe` and `build_unsubscribe`.
    destination: Option<Name>,
    /// Optional identity string; `build_publish` substitutes `""` when unset.
    identity: Option<String>,
    /// SLIM header flags, forwarded as-is to the header constructors. The
    /// individual flag setters start from `SlimHeaderFlags::default()` when unset.
    flags: Option<SlimHeaderFlags>,
    /// Session type; only used by `build_publish`.
    session_type: Option<ProtoSessionType>,
    /// Session message type; only used by `build_publish`.
    session_message_type: Option<SessionMessageType>,
    /// Session id; `build_publish` uses 0 when unset.
    session_id: Option<u32>,
    /// Message id; `build_publish` picks a random one when unset.
    message_id: Option<u32>,
    /// Message content (application or command payload); only used by `build_publish`.
    payload: Option<Content>,
    /// Metadata entries copied into every message produced by the builder.
    metadata: HashMap<String, String>,
    /// Subscription id for subscribe/unsubscribe messages; encoded as 0 when unset.
    subscription_id: Option<u64>,
}
1346
1347impl ProtoMessageBuilder {
1348    /// Creates a new ProtoMessageBuilder
1349    pub fn new() -> Self {
1350        Self {
1351            source: None,
1352            destination: None,
1353            identity: None,
1354            flags: None,
1355            session_type: None,
1356            session_message_type: None,
1357            session_id: None,
1358            message_id: None,
1359            payload: None,
1360            metadata: HashMap::new(),
1361            subscription_id: None,
1362        }
1363    }
1364
1365    /// Sets the source name
1366    pub fn source(mut self, source: Name) -> Self {
1367        self.source = Some(source);
1368        self
1369    }
1370
1371    /// Sets the destination name
1372    pub fn destination(mut self, destination: Name) -> Self {
1373        self.destination = Some(destination);
1374        self
1375    }
1376
1377    /// Sets the identity string
1378    pub fn identity(mut self, identity: impl Into<String>) -> Self {
1379        self.identity = Some(identity.into());
1380        self
1381    }
1382
1383    /// Sets the SLIM header flags
1384    pub fn flags(mut self, flags: SlimHeaderFlags) -> Self {
1385        self.flags = Some(flags);
1386        self
1387    }
1388
1389    /// Sets the fanout value
1390    pub fn fanout(mut self, fanout: u32) -> Self {
1391        let flags = self.flags.take().unwrap_or_default();
1392        self.flags = Some(flags.with_fanout(fanout));
1393        self
1394    }
1395
1396    /// Sets the recv_from connection
1397    pub fn recv_from(mut self, recv_from: u64) -> Self {
1398        let flags = self.flags.take().unwrap_or_default();
1399        self.flags = Some(flags.with_recv_from(recv_from));
1400        self
1401    }
1402
1403    /// Sets the forward_to connection
1404    pub fn forward_to(mut self, forward_to: u64) -> Self {
1405        let flags = self.flags.take().unwrap_or_default();
1406        self.flags = Some(flags.with_forward_to(forward_to));
1407        self
1408    }
1409
1410    /// Sets the incoming connection
1411    pub fn incoming_conn(mut self, incoming_conn: u64) -> Self {
1412        let flags = self.flags.take().unwrap_or_default();
1413        self.flags = Some(flags.with_incoming_conn(incoming_conn));
1414        self
1415    }
1416
1417    /// Sets the error flag
1418    pub fn error(mut self, error: bool) -> Self {
1419        let flags = self.flags.take().unwrap_or_default();
1420        self.flags = Some(flags.with_error(error));
1421        self
1422    }
1423
1424    /// Sets the session type
1425    pub fn session_type(mut self, session_type: ProtoSessionType) -> Self {
1426        self.session_type = Some(session_type);
1427        self
1428    }
1429
1430    /// Sets the session message type
1431    pub fn session_message_type(mut self, session_message_type: SessionMessageType) -> Self {
1432        self.session_message_type = Some(session_message_type);
1433        self
1434    }
1435
1436    /// Sets the session ID
1437    pub fn session_id(mut self, session_id: u32) -> Self {
1438        self.session_id = Some(session_id);
1439        self
1440    }
1441
1442    /// Sets the message ID
1443    pub fn message_id(mut self, message_id: u32) -> Self {
1444        self.message_id = Some(message_id);
1445        self
1446    }
1447
1448    /// Sets the message payload
1449    pub fn payload(mut self, payload: Content) -> Self {
1450        self.payload = Some(payload);
1451        self
1452    }
1453
1454    /// Sets an application payload
1455    pub fn application_payload(mut self, payload_type: &str, blob: Vec<u8>) -> Self {
1456        let app_payload = ApplicationPayload::new(payload_type, blob);
1457        self.payload = Some(app_payload.as_content());
1458        self
1459    }
1460
1461    /// Sets a command payload
1462    pub fn command_payload(mut self, payload: CommandPayload) -> Self {
1463        self.payload = Some(payload.as_content());
1464        self
1465    }
1466
1467    /// Sets a pre-built SlimHeader (for low-level use cases)
1468    ///
1469    /// This is a convenience method for cases where you already have a constructed SlimHeader.
1470    /// For most cases, prefer using the individual builder methods like `source()`, `destination()`, etc.
1471    pub fn with_slim_header(mut self, header: SlimHeader) -> Self {
1472        // Extract fields from the header
1473        if let Some(src) = &header.source {
1474            self.source = Some(Name::from(src));
1475        }
1476        if let Some(dst) = &header.destination {
1477            self.destination = Some(Name::from(dst));
1478        }
1479        if !header.identity.is_empty() {
1480            self.identity = Some(header.identity.clone());
1481        }
1482
1483        // Extract flags
1484        let flags = SlimHeaderFlags {
1485            fanout: header.fanout,
1486            recv_from: header.recv_from,
1487            forward_to: header.forward_to,
1488            incoming_conn: header.incoming_conn,
1489            error: header.error,
1490        };
1491        self.flags = Some(flags);
1492        self
1493    }
1494
1495    /// Sets a pre-built SessionHeader (for low-level use cases)
1496    ///
1497    /// This is a convenience method for cases where you already have a constructed SessionHeader.
1498    /// For most cases, prefer using the individual builder methods like `session_type()`, `session_message_type()`, etc.
1499    pub fn with_session_header(mut self, header: SessionHeader) -> Self {
1500        self.session_type = Some(
1501            ProtoSessionType::try_from(header.session_type)
1502                .unwrap_or(ProtoSessionType::PointToPoint),
1503        );
1504        self.session_message_type = Some(
1505            SessionMessageType::try_from(header.session_message_type)
1506                .unwrap_or(SessionMessageType::Msg),
1507        );
1508        self.session_id = Some(header.session_id);
1509        self.message_id = Some(header.message_id);
1510        self
1511    }
1512
1513    /// Adds metadata to the message
1514    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1515        self.metadata.insert(key.into(), value.into());
1516        self
1517    }
1518
1519    /// Adds multiple metadata entries
1520    pub fn metadata_map(mut self, map: HashMap<String, String>) -> Self {
1521        self.metadata.extend(map);
1522        self
1523    }
1524
1525    /// Sets the subscription_id for subscribe/unsubscribe messages.
1526    pub fn subscription_id(mut self, id: u64) -> Self {
1527        self.subscription_id = Some(id);
1528        self
1529    }
1530
1531    /// Builds a publish message
1532    pub fn build_publish(self) -> Result<ProtoMessage, MessageError> {
1533        let source = self
1534            .source
1535            .ok_or(MessageError::BuilderErrorSourceRequired)?;
1536        let destination = self
1537            .destination
1538            .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1539
1540        let slim_header = Some(SlimHeader::new(
1541            &source,
1542            &destination,
1543            self.identity.as_deref().unwrap_or(""),
1544            self.flags,
1545        ));
1546
1547        let session_header = if self.session_type.is_some() || self.session_message_type.is_some() {
1548            Some(SessionHeader::new(
1549                self.session_type
1550                    .unwrap_or(ProtoSessionType::PointToPoint)
1551                    .into(),
1552                self.session_message_type
1553                    .unwrap_or(SessionMessageType::Msg)
1554                    .into(),
1555                self.session_id.unwrap_or(0),
1556                self.message_id.unwrap_or_else(rand::random),
1557            ))
1558        } else {
1559            Some(SessionHeader::default())
1560        };
1561
1562        let publish = ProtoPublish::with_header(slim_header, session_header, self.payload);
1563        let message = ProtoMessage::new(self.metadata, ProtoPublishType(publish));
1564        Ok(message)
1565    }
1566
1567    /// Builds a subscribe message
1568    pub fn build_subscribe(self) -> Result<ProtoMessage, MessageError> {
1569        let source = self
1570            .source
1571            .ok_or(MessageError::BuilderErrorSourceRequired)?;
1572        let destination = self
1573            .destination
1574            .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1575
1576        let mut subscribe =
1577            ProtoSubscribe::new(&source, &destination, self.identity.as_deref(), self.flags);
1578        subscribe.subscription_id = self.subscription_id.unwrap_or_default();
1579
1580        Ok(ProtoMessage::new(
1581            self.metadata,
1582            ProtoSubscribeType(subscribe),
1583        ))
1584    }
1585
1586    /// Builds an unsubscribe message
1587    pub fn build_unsubscribe(self) -> Result<ProtoMessage, MessageError> {
1588        let source = self
1589            .source
1590            .ok_or(MessageError::BuilderErrorSourceRequired)?;
1591        let destination = self
1592            .destination
1593            .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1594
1595        let mut unsubscribe =
1596            ProtoUnsubscribe::new(&source, &destination, self.identity.as_deref(), self.flags);
1597        unsubscribe.subscription_id = self.subscription_id.unwrap_or_default();
1598
1599        Ok(ProtoMessage::new(
1600            self.metadata,
1601            ProtoUnsubscribeType(unsubscribe),
1602        ))
1603    }
1604
1605    /// Builds a subscription ack message.
1606    /// SubscriptionAck messages are delivered directly to the requesting connection
1607    /// and are never routed through the subscription table.
1608    pub fn build_subscription_ack(
1609        self,
1610        subscription_id: u64,
1611        success: bool,
1612        error: impl Into<String>,
1613    ) -> ProtoMessage {
1614        let ack = ProtoSubscriptionAck {
1615            subscription_id,
1616            success,
1617            error: error.into(),
1618        };
1619        ProtoMessage::new(self.metadata, ProtoSubscriptionAckType(ack))
1620    }
1621
1622    /// Builds a link negotiation message.
1623    /// Link messages are link-local and never routed; they carry no SLIM header.
1624    pub fn build_link_negotiation(
1625        self,
1626        link_id: impl Into<String>,
1627        slim_version: impl Into<String>,
1628        is_reply: bool,
1629    ) -> ProtoMessage {
1630        let link = ProtoLink {
1631            link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
1632                link_id: link_id.into(),
1633                slim_version: slim_version.into(),
1634                is_reply,
1635            })),
1636        };
1637        ProtoMessage::new(self.metadata, ProtoLinkMessageType(link))
1638    }
1639}
1640
1641impl Default for ProtoMessageBuilder {
1642    fn default() -> Self {
1643        Self::new()
1644    }
1645}
1646
1647impl ProtoMessage {
1648    /// Creates a new builder for ProtoMessage
1649    pub fn builder() -> ProtoMessageBuilder {
1650        ProtoMessageBuilder::new()
1651    }
1652}
1653
1654#[cfg(test)]
1655mod tests {
1656    use crate::{api::proto::dataplane::v1::SessionMessageType, messages::encoder::Name};
1657
1658    use super::*;
1659
1660    fn test_subscription_template(
1661        subscription: bool,
1662        source: Name,
1663        dst: Name,
1664        identity: Option<&str>,
1665        flags: Option<SlimHeaderFlags>,
1666    ) {
1667        let sub = {
1668            let mut builder = ProtoMessage::builder()
1669                .source(source.clone())
1670                .destination(dst.clone());
1671
1672            if let Some(id) = identity {
1673                builder = builder.identity(id);
1674            }
1675
1676            if let Some(f) = flags.clone() {
1677                builder = builder.flags(f);
1678            }
1679
1680            if subscription {
1681                builder.build_subscribe().unwrap()
1682            } else {
1683                builder.build_unsubscribe().unwrap()
1684            }
1685        };
1686
1687        let flags = if flags.is_none() {
1688            Some(SlimHeaderFlags::default())
1689        } else {
1690            flags
1691        };
1692
1693        assert!(!sub.is_publish());
1694        assert_eq!(sub.is_subscribe(), subscription);
1695        assert_eq!(sub.is_unsubscribe(), !subscription);
1696        assert_eq!(flags.as_ref().unwrap().recv_from, sub.get_recv_from());
1697        assert_eq!(flags.as_ref().unwrap().forward_to, sub.get_forward_to());
1698        assert_eq!(None, sub.try_get_incoming_conn());
1699        assert_eq!(source, sub.get_source());
1700        let got_name = sub.get_dst();
1701        assert_eq!(dst, got_name);
1702    }
1703
1704    fn test_publish_template(
1705        source: Name,
1706        dst: Name,
1707        identity: Option<&str>,
1708        flags: Option<SlimHeaderFlags>,
1709    ) {
1710        let mut builder = ProtoMessage::builder()
1711            .source(source.clone())
1712            .destination(dst.clone())
1713            .application_payload("str", "this is the content of the message".into());
1714
1715        if let Some(id) = identity {
1716            builder = builder.identity(id);
1717        }
1718
1719        if let Some(f) = flags.clone() {
1720            builder = builder.flags(f);
1721        }
1722
1723        let pub_msg = builder.build_publish().unwrap();
1724
1725        let flags = if flags.is_none() {
1726            Some(SlimHeaderFlags::default())
1727        } else {
1728            flags
1729        };
1730
1731        assert!(pub_msg.is_publish());
1732        assert!(!pub_msg.is_subscribe());
1733        assert!(!pub_msg.is_unsubscribe());
1734        assert_eq!(flags.as_ref().unwrap().recv_from, pub_msg.get_recv_from());
1735        assert_eq!(flags.as_ref().unwrap().forward_to, pub_msg.get_forward_to());
1736        assert_eq!(None, pub_msg.try_get_incoming_conn());
1737        assert_eq!(source, pub_msg.get_source());
1738        let got_name = pub_msg.get_dst();
1739        assert_eq!(dst, got_name);
1740        assert_eq!(flags.as_ref().unwrap().fanout, pub_msg.get_fanout());
1741    }
1742
1743    #[test]
1744    fn test_subscription() {
1745        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
1746        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
1747
1748        // simple
1749        test_subscription_template(true, source.clone(), dst.clone(), None, None);
1750
1751        // with name id
1752        test_subscription_template(true, source.clone(), dst.clone(), None, None);
1753
1754        // with recv from
1755        test_subscription_template(
1756            true,
1757            source.clone(),
1758            dst.clone(),
1759            None,
1760            Some(SlimHeaderFlags::default().with_recv_from(50)),
1761        );
1762
1763        // with forward to
1764        test_subscription_template(
1765            true,
1766            source.clone(),
1767            dst.clone(),
1768            None,
1769            Some(SlimHeaderFlags::default().with_forward_to(30)),
1770        );
1771    }
1772
1773    #[test]
1774    fn test_unsubscription() {
1775        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
1776        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
1777
1778        // simple
1779        test_subscription_template(false, source.clone(), dst.clone(), None, None);
1780
1781        // with name id
1782        test_subscription_template(false, source.clone(), dst.clone(), None, None);
1783
1784        // with recv from
1785        test_subscription_template(
1786            false,
1787            source.clone(),
1788            dst.clone(),
1789            None,
1790            Some(SlimHeaderFlags::default().with_recv_from(50)),
1791        );
1792
1793        // with forward to
1794        test_subscription_template(
1795            false,
1796            source.clone(),
1797            dst.clone(),
1798            None,
1799            Some(SlimHeaderFlags::default().with_forward_to(30)),
1800        );
1801    }
1802
1803    #[test]
1804    fn test_publish() {
1805        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
1806        let mut dst = Name::from_strings(["org", "ns", "type"]);
1807
1808        // simple
1809        test_publish_template(
1810            source.clone(),
1811            dst.clone(),
1812            None,
1813            Some(SlimHeaderFlags::default()),
1814        );
1815
1816        // with name id
1817        dst.set_id(2);
1818        test_publish_template(
1819            source.clone(),
1820            dst.clone(),
1821            None,
1822            Some(SlimHeaderFlags::default()),
1823        );
1824        dst.reset_id();
1825
1826        // with recv from
1827        test_publish_template(
1828            source.clone(),
1829            dst.clone(),
1830            None,
1831            Some(SlimHeaderFlags::default().with_recv_from(50)),
1832        );
1833
1834        // with forward to
1835        test_publish_template(
1836            source.clone(),
1837            dst.clone(),
1838            None,
1839            Some(SlimHeaderFlags::default().with_forward_to(30)),
1840        );
1841
1842        // with fanout
1843        test_publish_template(
1844            source.clone(),
1845            dst.clone(),
1846            None,
1847            Some(SlimHeaderFlags::default().with_fanout(2)),
1848        );
1849    }
1850
1851    #[test]
1852    fn test_conversions() {
1853        // Name to ProtoName
1854        let name = Name::from_strings(["org", "ns", "type"]).with_id(1);
1855        let proto_name = ProtoName::from(&name);
1856
1857        assert_eq!(
1858            proto_name.name.as_ref().unwrap().component_0,
1859            name.components()[0]
1860        );
1861        assert_eq!(
1862            proto_name.name.as_ref().unwrap().component_1,
1863            name.components()[1]
1864        );
1865        assert_eq!(
1866            proto_name.name.as_ref().unwrap().component_2,
1867            name.components()[2]
1868        );
1869        assert_eq!(
1870            proto_name.name.as_ref().unwrap().component_3,
1871            name.components()[3]
1872        );
1873
1874        // ProtoName to Name
1875        let name_from_proto = Name::from(&proto_name);
1876        assert_eq!(
1877            name_from_proto.components()[0],
1878            proto_name.name.as_ref().unwrap().component_0
1879        );
1880        assert_eq!(
1881            name_from_proto.components()[1],
1882            proto_name.name.as_ref().unwrap().component_1
1883        );
1884        assert_eq!(
1885            name_from_proto.components()[2],
1886            proto_name.name.as_ref().unwrap().component_2
1887        );
1888        assert_eq!(
1889            name_from_proto.components()[3],
1890            proto_name.name.as_ref().unwrap().component_3
1891        );
1892
1893        // ProtoMessage to ProtoSubscribe
1894        let dst = Name::from_strings(["org", "ns", "type"]).with_id(1);
1895        let proto_subscribe = ProtoMessage::builder()
1896            .source(name.clone())
1897            .destination(dst.clone())
1898            .flags(
1899                SlimHeaderFlags::default()
1900                    .with_recv_from(2)
1901                    .with_forward_to(3),
1902            )
1903            .build_subscribe()
1904            .unwrap();
1905        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
1906        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), name);
1907        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_dst(), dst,);
1908
1909        // ProtoMessage to ProtoUnsubscribe
1910        let proto_unsubscribe = ProtoMessage::builder()
1911            .source(name.clone())
1912            .destination(dst.clone())
1913            .flags(
1914                SlimHeaderFlags::default()
1915                    .with_recv_from(2)
1916                    .with_forward_to(3),
1917            )
1918            .build_unsubscribe()
1919            .unwrap();
1920        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
1921        assert_eq!(
1922            proto_unsubscribe.header.as_ref().unwrap().get_source(),
1923            name
1924        );
1925        assert_eq!(proto_unsubscribe.header.as_ref().unwrap().get_dst(), dst);
1926
1927        // ProtoMessage to ProtoPublish
1928        let proto_publish = ProtoMessage::builder()
1929            .source(name.clone())
1930            .destination(dst.clone())
1931            .flags(
1932                SlimHeaderFlags::default()
1933                    .with_recv_from(2)
1934                    .with_forward_to(3),
1935            )
1936            .application_payload("str", "this is the content of the message".into())
1937            .build_publish()
1938            .unwrap();
1939        let proto_publish = ProtoPublish::from(proto_publish);
1940        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), name);
1941        assert_eq!(proto_publish.header.as_ref().unwrap().get_dst(), dst);
1942    }
1943
1944    #[test]
1945    fn test_panic() {
1946        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
1947        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
1948
1949        // panic if SLIM header is not found
1950        let msg = ProtoMessage::builder()
1951            .source(source.clone())
1952            .destination(dst.clone())
1953            .flags(
1954                SlimHeaderFlags::default()
1955                    .with_recv_from(2)
1956                    .with_forward_to(3),
1957            )
1958            .build_subscribe()
1959            .unwrap();
1960
1961        // let's try to convert it to a unsubscribe
1962        // this should panic because the message type is not unsubscribe
1963        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
1964        assert!(result.is_err());
1965
1966        // try to convert to publish
1967        // this should panic because the message type is not publish
1968        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
1969        assert!(result.is_err());
1970
1971        // finally make sure the conversion to subscribe works
1972        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
1973        assert!(result.is_ok());
1974    }
1975
1976    #[test]
1977    fn test_panic_header() {
1978        // create a unusual SLIM header
1979        let header = SlimHeader {
1980            source: None,
1981            destination: None,
1982            identity: String::new(),
1983            fanout: 0,
1984            recv_from: None,
1985            forward_to: None,
1986            incoming_conn: None,
1987            error: None,
1988        };
1989
1990        // the operations to retrieve source and destination should fail with panic
1991        let result = std::panic::catch_unwind(|| header.get_source());
1992        assert!(result.is_err());
1993
1994        let result = std::panic::catch_unwind(|| header.get_dst());
1995        assert!(result.is_err());
1996
1997        // The operations to retrieve recv_from and forward_to should not fail with panic
1998        let result = std::panic::catch_unwind(|| header.get_recv_from());
1999        assert!(result.is_ok());
2000
2001        let result = std::panic::catch_unwind(|| header.get_forward_to());
2002        assert!(result.is_ok());
2003
2004        // The operations to retrieve incoming_conn should not fail with panic
2005        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
2006        assert!(result.is_ok());
2007
2008        // The operations to retrieve error should not fail with panic
2009        let result = std::panic::catch_unwind(|| header.get_error());
2010        assert!(result.is_ok());
2011    }
2012
2013    #[test]
2014    fn test_panic_session_header() {
2015        // create a unusual session header
2016        let header = SessionHeader::new(0, 0, 0, 0);
2017
2018        // the operations to retrieve session_id and message_id should not fail with panic
2019        let result = std::panic::catch_unwind(|| header.get_session_id());
2020        assert!(result.is_ok());
2021
2022        let result = std::panic::catch_unwind(|| header.get_message_id());
2023        assert!(result.is_ok());
2024    }
2025
2026    #[test]
2027    fn test_panic_proto_message() {
2028        // create a unusual proto message
2029        let message = ProtoMessage {
2030            metadata: HashMap::new(),
2031            message_type: None,
2032        };
2033
2034        // the operation to retrieve the header should fail with panic
2035        let result = std::panic::catch_unwind(|| message.get_slim_header());
2036        assert!(result.is_err());
2037
2038        // the operation to retrieve the message type should fail with panic
2039        let result = std::panic::catch_unwind(|| message.get_type());
2040        assert!(result.is_err());
2041
2042        // all the other ops should fail with panic as well as the header is not set
2043        let result = std::panic::catch_unwind(|| message.get_source());
2044        assert!(result.is_err());
2045        let result = std::panic::catch_unwind(|| message.get_dst());
2046        assert!(result.is_err());
2047        let result = std::panic::catch_unwind(|| message.get_recv_from());
2048        assert!(result.is_err());
2049        let result = std::panic::catch_unwind(|| message.get_forward_to());
2050        assert!(result.is_err());
2051        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
2052        assert!(result.is_err());
2053        let result = std::panic::catch_unwind(|| message.get_fanout());
2054        assert!(result.is_err());
2055    }
2056
2057    #[test]
2058    fn test_service_type_to_int() {
2059        // Get total number of service types
2060        let total_service_types = SessionMessageType::Ping as i32;
2061
2062        for i in 0..total_service_types {
2063            // int -> ServiceType
2064            let service_type =
2065                SessionMessageType::try_from(i).expect("failed to convert int to service type");
2066            let service_type_int = i32::from(service_type);
2067            assert_eq!(service_type_int, i32::from(service_type),);
2068        }
2069
2070        // Test invalid conversion
2071        let invalid_service_type = SessionMessageType::try_from(total_service_types + 1);
2072        assert!(invalid_service_type.is_err());
2073    }
2074
2075    #[test]
2076    fn test_proto_message_builder() {
2077        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
2078        let dest = Name::from_strings(["org", "ns", "app"]).with_id(2);
2079
2080        // Test basic publish message
2081        let msg = ProtoMessage::builder()
2082            .source(source.clone())
2083            .destination(dest.clone())
2084            .application_payload("test", b"hello world".to_vec())
2085            .build_publish()
2086            .unwrap();
2087
2088        assert!(msg.is_publish());
2089        assert_eq!(msg.get_source(), source);
2090        assert_eq!(msg.get_dst(), dest);
2091
2092        // Test with session headers
2093        let msg = ProtoMessage::builder()
2094            .source(source.clone())
2095            .destination(dest.clone())
2096            .session_type(ProtoSessionType::Multicast)
2097            .session_message_type(SessionMessageType::Msg)
2098            .session_id(42)
2099            .message_id(100)
2100            .fanout(256)
2101            .application_payload("test", b"broadcast".to_vec())
2102            .build_publish()
2103            .unwrap();
2104
2105        assert_eq!(msg.get_session_type(), ProtoSessionType::Multicast);
2106        assert_eq!(msg.get_id(), 100);
2107        assert_eq!(msg.get_fanout(), 256);
2108
2109        // Test with metadata
2110        let msg = ProtoMessage::builder()
2111            .source(source.clone())
2112            .destination(dest.clone())
2113            .metadata("key1", "value1")
2114            .metadata("key2", "value2")
2115            .application_payload("test", vec![1, 2, 3])
2116            .build_publish()
2117            .unwrap();
2118
2119        assert_eq!(msg.get_metadata("key1"), Some(&"value1".to_string()));
2120        assert_eq!(msg.get_metadata("key2"), Some(&"value2".to_string()));
2121
2122        // Test subscribe message
2123        let msg = ProtoMessage::builder()
2124            .source(source.clone())
2125            .destination(dest.clone())
2126            .recv_from(10)
2127            .build_subscribe()
2128            .unwrap();
2129
2130        assert!(msg.is_subscribe());
2131        assert_eq!(msg.get_recv_from(), Some(10));
2132
2133        // Test unsubscribe message
2134        let msg = ProtoMessage::builder()
2135            .source(source.clone())
2136            .destination(dest.clone())
2137            .forward_to(20)
2138            .build_unsubscribe()
2139            .unwrap();
2140
2141        assert!(msg.is_unsubscribe());
2142        assert_eq!(msg.get_forward_to(), Some(20));
2143    }
2144
2145    #[test]
2146    fn test_command_payload_builder() {
2147        let dest = Name::from_strings(["org", "ns", "app"]);
2148
2149        // Test discovery request
2150        let payload = CommandPayload::builder().discovery_request(Some(dest.clone()));
2151        let extracted = payload.as_discovery_request_payload().unwrap();
2152        assert!(extracted.destination.is_some());
2153
2154        // Test discovery reply
2155        let payload = CommandPayload::builder().discovery_reply();
2156        assert!(payload.as_discovery_reply_payload().is_ok());
2157
2158        // Test join request
2159        let payload = CommandPayload::builder().join_request(
2160            true,
2161            Some(5),
2162            Some(Duration::from_secs(10)),
2163            Some(dest.clone()),
2164        );
2165        let extracted = payload.as_join_request_payload().unwrap();
2166        assert!(extracted.enable_mls);
2167        assert!(extracted.timer_settings.is_some());
2168
2169        // Test join reply
2170        let payload = CommandPayload::builder().join_reply(Some(vec![1, 2, 3]));
2171        let extracted = payload.as_join_reply_payload().unwrap();
2172        assert_eq!(extracted.key_package, Some(vec![1, 2, 3]));
2173
2174        // Test leave request
2175        let payload = CommandPayload::builder().leave_request(Some(dest.clone()));
2176        assert!(payload.as_leave_request_payload().is_ok());
2177
2178        // Test leave reply
2179        let payload = CommandPayload::builder().leave_reply();
2180        assert!(payload.as_leave_reply_payload().is_ok());
2181
2182        // Test group add
2183        let participants = vec![dest.clone()];
2184        let payload = CommandPayload::builder().group_add(dest.clone(), participants.clone(), None);
2185        let extracted = payload.as_group_add_payload().unwrap();
2186        assert!(extracted.new_participant.is_some());
2187
2188        // Test group remove
2189        let payload =
2190            CommandPayload::builder().group_remove(dest.clone(), participants.clone(), None);
2191        let extracted = payload.as_group_remove_payload().unwrap();
2192        assert!(extracted.removed_participant.is_some());
2193
2194        // Test group welcome
2195        let payload = CommandPayload::builder().group_welcome(participants.clone(), None);
2196        let extracted = payload.as_welcome_payload().unwrap();
2197        assert!(!extracted.participants.is_empty());
2198
2199        // Test group proposal
2200        let payload = CommandPayload::builder().group_proposal(Some(dest.clone()), vec![4, 5, 6]);
2201        let extracted = payload.as_group_proposal_payload().unwrap();
2202        assert_eq!(extracted.mls_proposal, vec![4, 5, 6]);
2203
2204        // Test group ack
2205        let payload = CommandPayload::builder().group_ack();
2206        assert!(payload.as_group_ack_payload().is_ok());
2207
2208        // Test group nack
2209        let payload = CommandPayload::builder().group_nack();
2210        assert!(payload.as_group_nack_payload().is_ok());
2211
2212        // Test ping
2213        let payload = CommandPayload::builder().ping();
2214        assert!(payload.as_ping_payload().is_ok());
2215    }
2216
2217    #[test]
2218    fn test_builder_with_command_payload() {
2219        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
2220        let dest = Name::from_strings(["org", "ns", "app"]).with_id(2);
2221
2222        let cmd_payload = CommandPayload::builder().discovery_request(Some(dest.clone()));
2223
2224        let msg = ProtoMessage::builder()
2225            .source(source.clone())
2226            .destination(dest.clone())
2227            .session_type(ProtoSessionType::PointToPoint)
2228            .session_message_type(SessionMessageType::DiscoveryRequest)
2229            .session_id(1)
2230            .command_payload(cmd_payload)
2231            .build_publish()
2232            .unwrap();
2233
2234        assert!(msg.is_publish());
2235        assert_eq!(
2236            msg.get_session_message_type(),
2237            SessionMessageType::DiscoveryRequest
2238        );
2239
2240        // Verify we can extract the payload
2241        let extracted = msg.extract_discovery_request().unwrap();
2242        assert!(extracted.destination.is_some());
2243    }
2244
2245    #[test]
2246    fn test_validate_link_without_link_type() {
2247        let link = ProtoLink { link_type: None };
2248        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
2249        assert!(matches!(msg.validate(), Err(MessageError::LinkTypeNotSet)));
2250    }
2251
2252    #[test]
2253    fn test_validate_link_with_link_type() {
2254        let link = ProtoLink {
2255            link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
2256                link_id: "abc".into(),
2257                slim_version: "1.0.0".into(),
2258                is_reply: false,
2259            })),
2260        };
2261        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
2262        assert!(msg.validate().is_ok());
2263    }
2264
2265    #[test]
2266    fn test_build_link_negotiation_request() {
2267        let msg = ProtoMessage::builder().build_link_negotiation("my-id", "1.2.3", false);
2268        assert!(msg.is_link());
2269        assert!(!msg.is_publish());
2270        assert!(!msg.is_subscribe());
2271        assert!(msg.validate().is_ok());
2272    }
2273
2274    #[test]
2275    fn test_build_link_negotiation_reply() {
2276        let msg = ProtoMessage::builder().build_link_negotiation("my-id", "1.2.3", true);
2277        assert!(msg.is_link());
2278        assert!(msg.validate().is_ok());
2279    }
2280}