Skip to main content

slim_datapath/messages/
utils.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::fmt::Display;
5use std::{collections::HashMap, time::Duration};
6
7use crate::api::proto::dataplane::v1::{
8    GroupClosePayload, GroupNackPayload, LinkConnectionType, Participant, ParticipantSettings,
9    PingPayload,
10};
11use crate::api::{
12    Content, LinkNegotiationPayload, MessageType, ProtoLink, ProtoLinkMessageType, ProtoLinkType,
13    ProtoMessage, ProtoMlsSettings as MlsSettings, ProtoName, ProtoPublish, ProtoPublishType,
14    ProtoSessionType, ProtoSubscribe, ProtoSubscribeType, ProtoSubscriptionAck,
15    ProtoSubscriptionAckType, ProtoUnsubscribe, ProtoUnsubscribeType, SessionHeader, SlimHeader,
16    proto::dataplane::v1::{
17        ApplicationPayload, CommandPayload, DiscoveryReplyPayload, DiscoveryRequestPayload,
18        EncodedName, GroupAckPayload, GroupAddPayload, GroupProposalPayload, GroupRemovePayload,
19        GroupWelcomePayload, JoinReplyPayload, JoinRequestPayload, LeaveReplyPayload,
20        LeaveRequestPayload, MlsPayload, SessionMessageType, TimerSettings,
21        command_payload::CommandPayloadType, content::ContentType,
22    },
23};
24
25use slim_version::version;
26use thiserror::Error;
27
28use crate::tables::ConnType;
29
30impl From<ConnType> for LinkConnectionType {
31    fn from(ct: ConnType) -> Self {
32        match ct {
33            ConnType::Peer => LinkConnectionType::Peer,
34            ConnType::Edge => LinkConnectionType::Edge,
35            _ => LinkConnectionType::Remote,
36        }
37    }
38}
39
/// Metadata key signalling that the entire group is being closed.
///
/// The moderator sets it on the leave message sent to every participant when
/// deletion of the channel is requested.
pub const DELETE_GROUP: &str = "DELETE_GROUP";

/// Metadata key asking for a message to skip normal sequencing and be delivered
/// directly to the specified endpoint.
///
/// Group sessions use it when the application calls `publish_to` instead of
/// `publish`. The value is [`TRUE_VAL`] for direct delivery without buffering.
pub const PUBLISH_TO: &str = "PUBLISH_TO";

/// Metadata key reporting that a participant disconnection was detected, as
/// opposed to a graceful leave.
///
/// It appears on the leave request message and is also used internally by the
/// moderator when missing ping replies reveal that a participant is gone. The
/// value is [`TRUE_VAL`] when a disconnection has been detected.
pub const DISCONNECTION_DETECTED: &str = "DISCONNECTION_DETECTED";

/// Metadata key marking a graceful departure from the session.
///
/// A participant closing its session sets it on the leave request sent to the
/// moderator, with the value [`TRUE_VAL`].
pub const LEAVING_SESSION: &str = "LEAVING_SESSION";

/// Canonical metadata string for boolean `true`.
pub const TRUE_VAL: &str = "TRUE";

/// Canonical metadata string for boolean `false`.
pub const FALSE_VAL: &str = "FALSE";

/// Largest message ID used by normally sequenced messages.
///
/// IDs in `[0, MAX_PUBLISH_ID]` follow normal sequencing; IDs above it are used
/// by `PUBLISH_TO` messages, which bypass sequencing. Splitting the `u32` range
/// in half (rounded down) gives those out-of-band messages their own ID space.
pub const MAX_PUBLISH_ID: u32 = u32::MAX >> 1;

/// TTL given to messages that do not set one explicitly.
pub const DEFAULT_TTL: u32 = 8;
75
76#[derive(Error, Debug, PartialEq)]
77pub enum MessageError {
78    #[error("SLIM header not found")]
79    SlimHeaderNotFound,
80    #[error("source not found")]
81    SourceNotFound,
82    #[error("source encoded name not found")]
83    SourceEncodedNameNotFound,
84    #[error("destination not found")]
85    DestinationNotFound,
86    #[error("destination encoded name not found")]
87    DestinationEncodedNameNotFound,
88    #[error("session header not found")]
89    SessionHeaderNotFound,
90    #[error("message type not found")]
91    MessageTypeNotFound,
92    #[error("incoming connection not found")]
93    IncomingConnectionNotFound,
94    #[error("content type is not set")]
95    ContentTypeNotSet,
96    #[error("content is not an application payload")]
97    NotApplicationPayload,
98    #[error("content is not a command payload")]
99    NotCommandPayload,
100    #[error("link type is not set")]
101    LinkTypeNotSet,
102    #[error("invalid command payload type: expected {expected}, got {got}")]
103    InvalidCommandPayloadType {
104        expected: Box<String>,
105        got: Box<String>,
106    },
107
108    // Builder errors
109    #[error("builder error: source is required")]
110    BuilderErrorSourceRequired,
111    #[error("builder error: destination is required")]
112    BuilderErrorDestinationRequired,
113    #[error("participant name not found")]
114    ParticipantNameNotFound,
115    #[error("participant settings not found")]
116    ParticipantSettingsNotFound,
117}
118
119impl ParticipantSettings {
120    /// Creates bidirectional participant settings (both sends and receives data).
121    /// This is the most common configuration for participants in a session.
122    pub fn bidirectional() -> Self {
123        Self {
124            sends_data: true,
125            receives_data: true,
126        }
127    }
128
129    /// Creates send-only participant settings.
130    pub fn send_only() -> Self {
131        Self {
132            sends_data: true,
133            receives_data: false,
134        }
135    }
136
137    /// Creates receive-only participant settings.
138    pub fn receive_only() -> Self {
139        Self {
140            sends_data: false,
141            receives_data: true,
142        }
143    }
144
145    /// Returns whether this participant produces data messages.
146    pub fn is_sender(&self) -> bool {
147        self.sends_data
148    }
149
150    /// Returns whether this participant consumes data messages.
151    pub fn is_receiver(&self) -> bool {
152        self.receives_data
153    }
154}
155
156impl Participant {
157    pub fn new(name: ProtoName, settings: ParticipantSettings) -> Self {
158        Self {
159            name: Some(name),
160            settings: Some(settings),
161        }
162    }
163
164    pub fn get_name(&self) -> Result<ProtoName, MessageError> {
165        match &self.name {
166            Some(name) => Ok(name.clone()),
167            None => Err(MessageError::ParticipantNameNotFound),
168        }
169    }
170
171    pub fn get_settings(&self) -> Result<&ParticipantSettings, MessageError> {
172        match &self.settings {
173            Some(settings) => Ok(settings),
174            None => Err(MessageError::ParticipantSettingsNotFound),
175        }
176    }
177}
178
179/// Print message type
180impl Display for MessageType {
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        match self {
183            MessageType::Publish(_) => write!(f, "publish"),
184            MessageType::Subscribe(_) => write!(f, "subscribe"),
185            MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
186            MessageType::Link(_) => write!(f, "link"),
187            MessageType::SubscriptionAck(_) => write!(f, "subscription_ack"),
188        }
189    }
190}
191
/// Struct grouping the SLIM header flags for convenience.
///
/// `SlimHeader::new` copies each field into the `SlimHeader` field of the same
/// name, so a complete set of routing flags can be built, defaulted and passed
/// around as a single value.
#[derive(Debug, Clone)]
pub struct SlimHeaderFlags {
    /// Value copied into `SlimHeader::fanout` (`1` by default).
    pub fanout: u32,
    /// Value copied into `SlimHeader::recv_from`.
    pub recv_from: Option<u64>,
    /// Value copied into `SlimHeader::forward_to`.
    pub forward_to: Option<u64>,
    /// Value copied into `SlimHeader::incoming_conn`.
    pub incoming_conn: Option<u64>,
    /// Value copied into `SlimHeader::error`.
    pub error: Option<bool>,
    /// Time-to-live copied into `SlimHeader::ttl`; see `SlimHeader::decrement_ttl`.
    pub ttl: u32,
}
202
203impl Default for SlimHeaderFlags {
204    fn default() -> Self {
205        Self {
206            fanout: 1,
207            recv_from: None,
208            forward_to: None,
209            incoming_conn: None,
210            error: None,
211            ttl: DEFAULT_TTL,
212        }
213    }
214}
215
216impl Display for SlimHeaderFlags {
217    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218        write!(
219            f,
220            "fanout: {}, recv_from: {:?}, forward_to: {:?}, incoming_conn: {:?}, error: {:?}, ttl: {:?}",
221            self.fanout, self.recv_from, self.forward_to, self.incoming_conn, self.error, self.ttl
222        )
223    }
224}
225
226impl SlimHeaderFlags {
227    pub fn new(
228        fanout: u32,
229        recv_from: Option<u64>,
230        forward_to: Option<u64>,
231        incoming_conn: Option<u64>,
232        error: Option<bool>,
233    ) -> Self {
234        Self {
235            fanout,
236            recv_from,
237            forward_to,
238            incoming_conn,
239            error,
240            ttl: DEFAULT_TTL,
241        }
242    }
243
244    pub fn with_fanout(self, fanout: u32) -> Self {
245        Self { fanout, ..self }
246    }
247
248    pub fn with_recv_from(self, recv_from: u64) -> Self {
249        Self {
250            recv_from: Some(recv_from),
251            ..self
252        }
253    }
254
255    pub fn with_forward_to(self, forward_to: u64) -> Self {
256        Self {
257            forward_to: Some(forward_to),
258            ..self
259        }
260    }
261
262    pub fn with_incoming_conn(self, incoming_conn: u64) -> Self {
263        Self {
264            incoming_conn: Some(incoming_conn),
265            ..self
266        }
267    }
268
269    pub fn with_error(self, error: bool) -> Self {
270        Self {
271            error: Some(error),
272            ..self
273        }
274    }
275
276    pub fn with_ttl(self, ttl: u32) -> Self {
277        Self { ttl, ..self }
278    }
279}
280
281/// SLIM Header
282/// This header is used to identify the source and destination of the message
283/// and to manage the connections used to send and receive the message
284impl SlimHeader {
285    pub fn new(
286        source: ProtoName,
287        destination: ProtoName,
288        identity: &str,
289        flags: Option<SlimHeaderFlags>,
290    ) -> Self {
291        let flags = flags.unwrap_or_default();
292        Self {
293            source: Some(source),
294            destination: Some(destination),
295            identity: identity.to_string(),
296            fanout: flags.fanout,
297            version: version().to_string(),
298            recv_from: flags.recv_from,
299            forward_to: flags.forward_to,
300            incoming_conn: flags.incoming_conn,
301            error: flags.error,
302            header_mac: None,
303            ttl: flags.ttl,
304        }
305    }
306
307    pub fn clear_flags(&mut self) {
308        self.recv_from = None;
309        self.forward_to = None;
310    }
311
312    pub fn get_fanout(&self) -> u32 {
313        self.fanout
314    }
315
316    pub fn get_recv_from(&self) -> Option<u64> {
317        self.recv_from
318    }
319
320    pub fn get_forward_to(&self) -> Option<u64> {
321        self.forward_to
322    }
323
324    pub fn get_incoming_conn(&self) -> Option<u64> {
325        self.incoming_conn
326    }
327
328    pub fn get_error(&self) -> Option<bool> {
329        self.error
330    }
331
332    pub fn get_source(&self) -> ProtoName {
333        self.source.clone().expect("source not found")
334    }
335
336    pub fn get_encoded_source(&self) -> EncodedName {
337        self.source.as_ref().unwrap().name.unwrap()
338    }
339
340    pub fn get_dst(&self) -> ProtoName {
341        self.destination.clone().expect("destination not found")
342    }
343
344    pub fn get_encoded_dst(&self) -> EncodedName {
345        self.destination.as_ref().unwrap().name.unwrap()
346    }
347
348    pub fn get_identity(&self) -> String {
349        self.identity.clone()
350    }
351
352    pub fn get_version(&self) -> String {
353        self.version.clone()
354    }
355
356    pub fn set_source(&mut self, source: ProtoName) {
357        self.source = Some(source);
358    }
359
360    pub fn set_destination(&mut self, dst: ProtoName) {
361        self.destination = Some(dst);
362    }
363
364    pub fn set_identity(&mut self, identity: String) {
365        self.identity = identity;
366    }
367
368    pub fn set_fanout(&mut self, fanout: u32) {
369        self.fanout = fanout;
370    }
371
372    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
373        self.recv_from = recv_from;
374    }
375
376    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
377        self.forward_to = forward_to;
378    }
379
380    pub fn set_error(&mut self, error: Option<bool>) {
381        self.error = error;
382    }
383
384    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
385        self.incoming_conn = incoming_conn;
386    }
387
388    pub fn set_error_flag(&mut self, error: Option<bool>) {
389        self.error = error;
390    }
391
392    pub fn get_ttl(&self) -> u32 {
393        self.ttl
394    }
395
396    pub fn set_ttl(&mut self, ttl: u32) {
397        self.ttl = ttl;
398    }
399
400    /// Decrements TTL by 1. Returns the new value.
401    pub fn decrement_ttl(&mut self) -> u32 {
402        self.ttl = self.ttl.saturating_sub(1);
403        self.ttl
404    }
405
406    #[cfg(not(target_arch = "wasm32"))]
407    // returns (incoming, recv_from, forward_to) for subscription processing
408    pub(crate) fn get_connections(&self) -> (u64, Option<u64>, Option<u64>) {
409        // when calling this function, incoming connection is set
410        let incoming = self
411            .get_incoming_conn()
412            .expect("incoming connection not found");
413
414        (incoming, self.get_recv_from(), self.get_forward_to())
415    }
416}
417
418/// Session Header
419/// This header is used to identify the session and the message
420/// and to manage session state
421impl SessionHeader {
422    pub fn new(
423        session_type: i32,
424        session_message_type: i32,
425        session_id: u32,
426        message_id: u32,
427    ) -> Self {
428        Self {
429            session_type,
430            session_message_type,
431            session_id,
432            message_id,
433        }
434    }
435
436    pub fn get_session_id(&self) -> u32 {
437        self.session_id
438    }
439
440    pub fn get_message_id(&self) -> u32 {
441        self.message_id
442    }
443
444    pub fn set_session_id(&mut self, session_id: u32) {
445        self.session_id = session_id;
446    }
447
448    pub fn set_message_id(&mut self, message_id: u32) {
449        self.message_id = message_id;
450    }
451
452    pub fn clear(&mut self) {
453        self.session_id = 0;
454        self.message_id = 0;
455    }
456}
457
458/// SessionMessageType
459/// Helper methods for session message types
460impl SessionMessageType {
461    /// Check if a message type is a command message (not application data)
462    pub fn is_command_message(&self) -> bool {
463        matches!(
464            self,
465            SessionMessageType::DiscoveryRequest
466                | SessionMessageType::DiscoveryReply
467                | SessionMessageType::JoinRequest
468                | SessionMessageType::JoinReply
469                | SessionMessageType::LeaveRequest
470                | SessionMessageType::LeaveReply
471                | SessionMessageType::GroupAdd
472                | SessionMessageType::GroupRemove
473                | SessionMessageType::GroupWelcome
474                | SessionMessageType::GroupClose
475                | SessionMessageType::GroupProposal
476                | SessionMessageType::GroupAck
477                | SessionMessageType::GroupNack
478                | SessionMessageType::Ping
479        )
480    }
481}
482
483/// ProtoSubscribe
484/// This message is used to subscribe to a topic
485impl ProtoSubscribe {
486    fn new(
487        source: ProtoName,
488        dst: ProtoName,
489        identity: Option<&str>,
490        flags: Option<SlimHeaderFlags>,
491    ) -> Self {
492        let id = identity.unwrap_or("");
493        let header = Some(SlimHeader::new(source, dst, id, flags));
494
495        ProtoSubscribe {
496            header,
497            subscription_id: 0,
498        }
499    }
500}
501
502/// From ProtoMessage to ProtoSubscribe
503impl From<ProtoMessage> for ProtoSubscribe {
504    fn from(message: ProtoMessage) -> Self {
505        match message.message_type {
506            Some(ProtoSubscribeType(s)) => s,
507            _ => panic!("message type is not subscribe"),
508        }
509    }
510}
511
512/// ProtoUnsubscribe
513/// This message is used to unsubscribe from a topic
514impl ProtoUnsubscribe {
515    fn new(
516        source: ProtoName,
517        dst: ProtoName,
518        identity: Option<&str>,
519        flags: Option<SlimHeaderFlags>,
520    ) -> Self {
521        let id = identity.unwrap_or("");
522        let header = Some(SlimHeader::new(source, dst, id, flags));
523
524        ProtoUnsubscribe {
525            header,
526            subscription_id: 0,
527        }
528    }
529}
530
531/// From ProtoMessage to ProtoUnsubscribe
532impl From<ProtoMessage> for ProtoUnsubscribe {
533    fn from(message: ProtoMessage) -> Self {
534        match message.message_type {
535            Some(ProtoUnsubscribeType(u)) => u,
536            _ => panic!("message type is not unsubscribe"),
537        }
538    }
539}
540
541/// ProtoPublish
542/// This message is used to publish a message, either to a shared channel or to a specific application
543impl ProtoPublish {
544    fn with_header(
545        header: Option<SlimHeader>,
546        session: Option<SessionHeader>,
547        payload: Option<Content>,
548    ) -> Self {
549        ProtoPublish {
550            header,
551            session,
552            msg: payload,
553        }
554    }
555
556    pub fn get_slim_header(&self) -> &SlimHeader {
557        self.header.as_ref().unwrap()
558    }
559
560    pub fn get_session_header(&self) -> &SessionHeader {
561        self.session.as_ref().unwrap()
562    }
563
564    pub fn get_slim_header_as_mut(&mut self) -> &mut SlimHeader {
565        self.header.as_mut().unwrap()
566    }
567
568    pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
569        self.session.as_mut().unwrap()
570    }
571
572    pub fn get_payload(&self) -> &Content {
573        self.msg.as_ref().unwrap()
574    }
575
576    pub fn set_payload(&mut self, payload: Content) {
577        self.msg = Some(payload);
578    }
579
580    pub fn is_command(&self) -> bool {
581        match &self.get_payload().content_type.as_ref().unwrap() {
582            ContentType::AppPayload(_) => false,
583            ContentType::CommandPayload(_) => true,
584        }
585    }
586
587    pub fn get_application_payload(&self) -> &ApplicationPayload {
588        match self.get_payload().content_type.as_ref().unwrap() {
589            ContentType::AppPayload(application_payload) => application_payload,
590            ContentType::CommandPayload(_) => panic!("the payload is not an application payload"),
591        }
592    }
593
594    pub fn get_command_payload(&self) -> &CommandPayload {
595        match &self.get_payload().content_type.as_ref().unwrap() {
596            ContentType::AppPayload(_) => panic!("the payaoad is not a command payload"),
597            ContentType::CommandPayload(command_payload) => command_payload,
598        }
599    }
600}
601
602/// From ProtoMessage to ProtoPublish
603impl From<ProtoMessage> for ProtoPublish {
604    fn from(message: ProtoMessage) -> Self {
605        match message.message_type {
606            Some(ProtoPublishType(p)) => p,
607            _ => panic!("message type is not publish"),
608        }
609    }
610}
611
// ProtoMessage
// This represents a generic message that can be sent over the network.
// These are plain comments on purpose: written as `///` doc comments they would
// attach to the macro below rather than to `impl ProtoMessage`.

/// Generates `ProtoMessage` methods that each extract one specific command payload.
///
/// Each `method => getter(Type)` entry expands to
/// `pub fn method(&self) -> Result<&Type, MessageError>`, which obtains the
/// command payload with `self.extract_command_payload()?` and then calls
/// `getter()` on it. Errors from either step are returned to the caller.
macro_rules! impl_payload_extractors {
    ($($method_name:ident => $getter_method:ident($payload_type:ty)),* $(,)?) => {
        $(
            /// Extracts a specific command payload from the message.
            pub fn $method_name(&self) -> Result<&$payload_type, MessageError> {
                self.extract_command_payload()?.$getter_method()
            }
        )*
    };
}
625
626impl ProtoMessage {
627    fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
628        ProtoMessage {
629            metadata,
630            message_type: Some(message_type),
631        }
632    }
633
634    fn validate_link(link: &ProtoLink) -> Result<(), MessageError> {
635        if link.link_type.is_none() {
636            return Err(MessageError::LinkTypeNotSet);
637        }
638        Ok(())
639    }
640
641    fn validate_routed_header(slim_header: &SlimHeader) -> Result<(), MessageError> {
642        match &slim_header.source {
643            None => return Err(MessageError::SourceNotFound),
644            Some(src) if src.name.is_none() => return Err(MessageError::SourceEncodedNameNotFound),
645            _ => {}
646        }
647        match &slim_header.destination {
648            None => return Err(MessageError::DestinationNotFound),
649            Some(dst) if dst.name.is_none() => {
650                return Err(MessageError::DestinationEncodedNameNotFound);
651            }
652            _ => {}
653        }
654        Ok(())
655    }
656
657    fn validate_publish(p: &ProtoPublish) -> Result<(), MessageError> {
658        let hdr = p.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
659        Self::validate_routed_header(hdr)?;
660        if p.session.is_none() {
661            return Err(MessageError::SessionHeaderNotFound);
662        }
663        Ok(())
664    }
665
666    fn validate_subscribe(s: &ProtoSubscribe) -> Result<(), MessageError> {
667        let hdr = s.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
668        Self::validate_routed_header(hdr)
669    }
670
671    fn validate_unsubscribe(u: &ProtoUnsubscribe) -> Result<(), MessageError> {
672        let hdr = u.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
673        Self::validate_routed_header(hdr)
674    }
675
676    // validate message
677    pub fn validate(&self) -> Result<(), MessageError> {
678        match &self.message_type {
679            None => Err(MessageError::MessageTypeNotFound),
680            Some(ProtoLinkMessageType(link)) => Self::validate_link(link),
681            Some(ProtoPublishType(p)) => Self::validate_publish(p),
682            Some(ProtoSubscribeType(s)) => Self::validate_subscribe(s),
683            Some(ProtoUnsubscribeType(u)) => Self::validate_unsubscribe(u),
684            Some(ProtoSubscriptionAckType(_)) => Ok(()),
685        }
686    }
687
688    // add metadata key in the map assigning the value val
689    // if the key exists the value is replaced by val
690    pub fn insert_metadata(&mut self, key: String, val: String) {
691        self.metadata.insert(key, val);
692    }
693
694    // remove metadata key from the map
695    pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
696        self.metadata.remove(key)
697    }
698
699    pub fn contains_metadata(&self, key: &str) -> bool {
700        self.metadata.contains_key(key)
701    }
702
703    pub fn get_metadata(&self, key: &str) -> Option<&String> {
704        self.metadata.get(key)
705    }
706
707    pub fn get_metadata_map(&self) -> HashMap<String, String> {
708        self.metadata.clone()
709    }
710
711    pub fn set_metadata_map(&mut self, map: HashMap<String, String>) {
712        for (k, v) in map.iter() {
713            self.insert_metadata(k.to_string(), v.to_string());
714        }
715    }
716
717    pub fn get_slim_header(&self) -> &SlimHeader {
718        match &self.message_type {
719            Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
720            Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
721            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
722            Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
723                panic!("SLIM header not found")
724            }
725        }
726    }
727
728    pub fn get_slim_header_mut(&mut self) -> &mut SlimHeader {
729        match &mut self.message_type {
730            Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
731            Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
732            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
733            Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
734                panic!("SLIM header not found")
735            }
736        }
737    }
738
739    pub fn try_get_slim_header(&self) -> Option<&SlimHeader> {
740        match &self.message_type {
741            Some(ProtoPublishType(publish)) => publish.header.as_ref(),
742            Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
743            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
744            Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => None,
745        }
746    }
747
748    pub fn get_session_header(&self) -> &SessionHeader {
749        match &self.message_type {
750            Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
751            Some(ProtoSubscribeType(_))
752            | Some(ProtoUnsubscribeType(_))
753            | Some(ProtoLinkMessageType(_))
754            | Some(ProtoSubscriptionAckType(_))
755            | None => panic!("session header not found"),
756        }
757    }
758
759    pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
760        match &mut self.message_type {
761            Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
762            Some(ProtoSubscribeType(_))
763            | Some(ProtoUnsubscribeType(_))
764            | Some(ProtoLinkMessageType(_))
765            | Some(ProtoSubscriptionAckType(_))
766            | None => panic!("session header not found"),
767        }
768    }
769
770    pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
771        match &self.message_type {
772            Some(ProtoPublishType(publish)) => publish.session.as_ref(),
773            Some(ProtoSubscribeType(_))
774            | Some(ProtoUnsubscribeType(_))
775            | Some(ProtoLinkMessageType(_))
776            | Some(ProtoSubscriptionAckType(_))
777            | None => None,
778        }
779    }
780
781    pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
782        match &mut self.message_type {
783            Some(ProtoPublishType(publish)) => publish.session.as_mut(),
784            Some(ProtoSubscribeType(_))
785            | Some(ProtoUnsubscribeType(_))
786            | Some(ProtoLinkMessageType(_))
787            | Some(ProtoSubscriptionAckType(_))
788            | None => None,
789        }
790    }
791
792    pub fn get_id(&self) -> u32 {
793        self.get_session_header().get_message_id()
794    }
795
796    pub fn get_source(&self) -> ProtoName {
797        self.get_slim_header().get_source()
798    }
799
800    pub fn get_encoded_source(&self) -> EncodedName {
801        self.get_slim_header().get_encoded_source()
802    }
803
804    pub fn get_dst(&self) -> ProtoName {
805        self.get_slim_header().get_dst()
806    }
807
808    pub fn get_encoded_dst(&self) -> EncodedName {
809        self.get_slim_header().get_encoded_dst()
810    }
811
812    pub fn get_identity(&self) -> String {
813        self.get_slim_header().get_identity()
814    }
815
816    pub fn get_fanout(&self) -> u32 {
817        self.get_slim_header().get_fanout()
818    }
819
820    pub fn get_recv_from(&self) -> Option<u64> {
821        self.get_slim_header().get_recv_from()
822    }
823
824    pub fn get_forward_to(&self) -> Option<u64> {
825        self.get_slim_header().get_forward_to()
826    }
827
828    pub fn get_error(&self) -> Option<bool> {
829        self.get_slim_header().get_error()
830    }
831
832    pub fn get_incoming_conn(&self) -> u64 {
833        self.get_slim_header().get_incoming_conn().unwrap()
834    }
835
836    pub fn try_get_incoming_conn(&self) -> Option<u64> {
837        self.get_slim_header().get_incoming_conn()
838    }
839
840    pub fn get_type(&self) -> &MessageType {
841        match &self.message_type {
842            Some(t) => t,
843            None => panic!("message type not found"),
844        }
845    }
846
847    pub fn get_payload(&self) -> Option<&Content> {
848        match &self.message_type {
849            Some(ProtoPublishType(p)) => p.msg.as_ref(),
850            Some(ProtoSubscribeType(_)) => panic!("payload not found"),
851            Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
852            Some(ProtoLinkMessageType(_)) => panic!("payload not found"),
853            Some(ProtoSubscriptionAckType(_)) => panic!("payload not found"),
854            None => panic!("payload not found"),
855        }
856    }
857
858    pub fn set_payload(&mut self, payload: Content) {
859        match &mut self.message_type {
860            Some(ProtoPublishType(p)) => p.set_payload(payload),
861            Some(ProtoSubscribeType(_)) => panic!("no payload allowed"),
862            Some(ProtoUnsubscribeType(_)) => panic!("no payload allowed"),
863            Some(ProtoLinkMessageType(_)) => panic!("no payload allowed"),
864            Some(ProtoSubscriptionAckType(_)) => panic!("no payload allowed"),
865            None => panic!("no payload allowed"),
866        }
867    }
868
869    pub fn get_session_message_type(&self) -> SessionMessageType {
870        self.get_session_header()
871            .session_message_type
872            .try_into()
873            .unwrap_or_default()
874    }
875
876    pub fn clear_slim_header(&mut self) {
877        if self.is_link() || self.is_subscription_ack() {
878            return;
879        }
880        self.get_slim_header_mut().clear_flags();
881    }
882
883    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
884        self.get_slim_header_mut().set_recv_from(recv_from);
885    }
886
887    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
888        self.get_slim_header_mut().set_forward_to(forward_to);
889    }
890
891    pub fn set_error(&mut self, error: Option<bool>) {
892        self.get_slim_header_mut().set_error(error);
893    }
894
895    pub fn set_fanout(&mut self, fanout: u32) {
896        self.get_slim_header_mut().set_fanout(fanout);
897    }
898
899    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
900        self.get_slim_header_mut().set_incoming_conn(incoming_conn);
901    }
902
903    pub fn set_error_flag(&mut self, error: Option<bool>) {
904        self.get_slim_header_mut().set_error_flag(error);
905    }
906
907    pub fn get_ttl(&self) -> u32 {
908        self.get_slim_header().get_ttl()
909    }
910
911    pub fn set_ttl(&mut self, ttl: u32) {
912        self.get_slim_header_mut().set_ttl(ttl);
913    }
914
915    /// Decrements TTL by 1. Returns the new value.
916    pub fn decrement_ttl(&mut self) -> u32 {
917        self.get_slim_header_mut().decrement_ttl()
918    }
919
920    pub fn set_session_message_type(&mut self, message_type: SessionMessageType) {
921        self.get_session_header_mut()
922            .set_session_message_type(message_type);
923    }
924
925    pub fn set_session_type(&mut self, session_type: ProtoSessionType) {
926        self.get_session_header_mut().set_session_type(session_type);
927    }
928
929    pub fn get_session_type(&self) -> ProtoSessionType {
930        self.get_session_header().session_type()
931    }
932
933    pub fn set_message_id(&mut self, message_id: u32) {
934        self.get_session_header_mut().set_message_id(message_id);
935    }
936
937    pub fn is_publish(&self) -> bool {
938        matches!(self.get_type(), MessageType::Publish(_))
939    }
940
941    pub fn is_subscribe(&self) -> bool {
942        matches!(self.get_type(), MessageType::Subscribe(_))
943    }
944
945    pub fn is_unsubscribe(&self) -> bool {
946        matches!(self.get_type(), MessageType::Unsubscribe(_))
947    }
948
949    pub fn is_link(&self) -> bool {
950        matches!(self.get_type(), MessageType::Link(_))
951    }
952
953    /// Extract the [`LinkNegotiationPayload`] from a Link message, if present.
954    pub fn get_link_negotiation_payload(&self) -> Option<LinkNegotiationPayload> {
955        match &self.message_type {
956            Some(ProtoLinkMessageType(link)) => match &link.link_type {
957                Some(ProtoLinkType::LinkNegotiation(payload)) => Some(payload.clone()),
958                _ => None,
959            },
960            _ => None,
961        }
962    }
963
964    pub fn is_subscription_ack(&self) -> bool {
965        matches!(self.get_type(), MessageType::SubscriptionAck(_))
966    }
967
968    pub fn is_traceable(&self) -> bool {
969        !self.is_link() && !self.is_subscription_ack()
970    }
971
972    pub fn get_subscription_ack(&self) -> &ProtoSubscriptionAck {
973        match &self.message_type {
974            Some(ProtoSubscriptionAckType(ack)) => ack,
975            _ => panic!("message type is not subscription_ack"),
976        }
977    }
978
979    /// Returns the subscription_id from a Subscribe/Unsubscribe message, or None if absent/zero.
980    pub fn get_subscription_id(&self) -> Option<u64> {
981        match &self.message_type {
982            Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => Some(s.subscription_id),
983            Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => Some(u.subscription_id),
984            _ => None,
985        }
986    }
987
988    /// Takes and clears the subscription_id from a Subscribe/Unsubscribe message.
989    /// Returns None if the field is absent or zero.
990    pub fn take_subscription_id(&mut self) -> Option<u64> {
991        match &mut self.message_type {
992            Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => {
993                Some(std::mem::take(&mut s.subscription_id))
994            }
995            Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => {
996                Some(std::mem::take(&mut u.subscription_id))
997            }
998            _ => None,
999        }
1000    }
1001
1002    /// Sets the subscription_id on a Subscribe/Unsubscribe message (no-op for other types).
1003    pub fn set_subscription_id(&mut self, subscription_id: u64) {
1004        match &mut self.message_type {
1005            Some(ProtoSubscribeType(s)) => s.subscription_id = subscription_id,
1006            Some(ProtoUnsubscribeType(u)) => u.subscription_id = subscription_id,
1007            _ => {}
1008        }
1009    }
1010
1011    /// Extracts the command payload from the message.
1012    ///
1013    /// # Errors
1014    /// Returns `MessageError` if the payload is missing or cannot be converted.
1015    pub fn extract_command_payload(&self) -> Result<&CommandPayload, MessageError> {
1016        self.get_payload()
1017            .ok_or(MessageError::ContentTypeNotSet)?
1018            .as_command_payload()
1019    }
1020
    // Generate all payload extraction methods.
    //
    // Each entry reads `method_name => command_payload_getter(PayloadType)`. The
    // getters named on the right-hand side are generated on `CommandPayload` by
    // `impl_command_payload_getters!` in this file. `impl_payload_extractors!` itself
    // is defined elsewhere; presumably it chains `extract_command_payload` with the
    // named getter (TODO confirm against the macro definition).
    impl_payload_extractors! {
        extract_discovery_request => as_discovery_request_payload(DiscoveryRequestPayload),
        extract_discovery_reply => as_discovery_reply_payload(DiscoveryReplyPayload),
        extract_join_request => as_join_request_payload(JoinRequestPayload),
        extract_join_reply => as_join_reply_payload(JoinReplyPayload),
        extract_leave_request => as_leave_request_payload(LeaveRequestPayload),
        extract_leave_reply => as_leave_reply_payload(LeaveReplyPayload),
        extract_group_add => as_group_add_payload(GroupAddPayload),
        extract_group_remove => as_group_remove_payload(GroupRemovePayload),
        extract_group_welcome => as_welcome_payload(GroupWelcomePayload),
        extract_group_close => as_group_close_payload(GroupClosePayload),
        extract_group_proposal => as_group_proposal_payload(GroupProposalPayload),
        extract_group_ack => as_group_ack_payload(GroupAckPayload),
        extract_group_nack => as_group_nack_payload(GroupNackPayload),
        extract_ping => as_ping_payload(PingPayload),
    }
1038}
1039
1040impl Content {
1041    pub fn as_application_payload(&self) -> Result<&ApplicationPayload, MessageError> {
1042        match &self.content_type {
1043            Some(ContentType::AppPayload(app_payload)) => Ok(app_payload),
1044            Some(ContentType::CommandPayload(_)) => Err(MessageError::NotApplicationPayload),
1045            None => Err(MessageError::ContentTypeNotSet),
1046        }
1047    }
1048
1049    pub fn as_command_payload(&self) -> Result<&CommandPayload, MessageError> {
1050        match &self.content_type {
1051            Some(ContentType::AppPayload(_)) => Err(MessageError::NotCommandPayload),
1052            Some(ContentType::CommandPayload(comm_payload)) => Ok(comm_payload),
1053            None => Err(MessageError::ContentTypeNotSet),
1054        }
1055    }
1056}
1057
1058impl ApplicationPayload {
1059    pub fn new(payload_type: &str, blob: Vec<u8>) -> Self {
1060        Self {
1061            payload_type: payload_type.to_string(),
1062            blob,
1063        }
1064    }
1065
1066    pub fn as_content(&self) -> Content {
1067        Content {
1068            content_type: Some(ContentType::AppPayload(self.clone())),
1069        }
1070    }
1071}
1072
// Generates one typed getter per `CommandPayloadType` variant.
//
// Each generated getter borrows the inner payload when the variant matches.
// Otherwise it returns `MessageError::InvalidCommandPayloadType`, where `expected` is
// the variant name and `got` is the Debug rendering of the actual variant, or "None"
// when no payload type is set.
// NOTE(review): the Debug rendering includes the full payload (e.g. MLS bytes), which
// may be large or sensitive if these errors get logged — consider reporting only
// the variant name.
macro_rules! impl_command_payload_getters {
    ($(
        $method_name:ident => $variant:ident($payload_type:ty)
    ),* $(,)?) => {
        $(
            pub fn $method_name(&self) -> Result<&$payload_type, MessageError> {
                if let Some(CommandPayloadType::$variant(payload)) = &self.command_payload_type {
                    return Ok(payload);
                }
                let got = match &self.command_payload_type {
                    Some(other) => format!("{:?}", other),
                    None => "None".to_string(),
                };
                Err(MessageError::InvalidCommandPayloadType {
                    expected: Box::new(stringify!($variant).to_string()),
                    got: Box::new(got),
                })
            }
        )*
    };
}
1095
1096impl CommandPayload {
1097    pub fn as_content(self) -> Content {
1098        Content {
1099            content_type: Some(ContentType::CommandPayload(self)),
1100        }
1101    }
1102
1103    // Getter methods for all CommandPayloadType variants
1104    impl_command_payload_getters! {
1105        as_discovery_request_payload => DiscoveryRequest(DiscoveryRequestPayload),
1106        as_discovery_reply_payload => DiscoveryReply(DiscoveryReplyPayload),
1107        as_join_request_payload => JoinRequest(JoinRequestPayload),
1108        as_join_reply_payload => JoinReply(JoinReplyPayload),
1109        as_leave_request_payload => LeaveRequest(LeaveRequestPayload),
1110        as_leave_reply_payload => LeaveReply(LeaveReplyPayload),
1111        as_group_add_payload => GroupAdd(GroupAddPayload),
1112        as_group_remove_payload => GroupRemove(GroupRemovePayload),
1113        as_welcome_payload => GroupWelcome(GroupWelcomePayload),
1114        as_group_close_payload => GroupClose(GroupClosePayload),
1115        as_group_proposal_payload => GroupProposal(GroupProposalPayload),
1116        as_group_ack_payload => GroupAck(GroupAckPayload),
1117        as_group_nack_payload => GroupNack(GroupNackPayload),
1118        as_ping_payload => Ping(PingPayload),
1119    }
1120}
1121
1122impl AsRef<ProtoPublish> for ProtoMessage {
1123    fn as_ref(&self) -> &ProtoPublish {
1124        match &self.message_type {
1125            Some(ProtoPublishType(p)) => p,
1126            _ => panic!("message type is not publish"),
1127        }
1128    }
1129}
1130
/// Factory for creating [`CommandPayload`] instances.
///
/// There is one method per command payload kind. Each method consumes the builder
/// and returns the finished [`CommandPayload`].
///
/// # Examples
///
/// ## Discovery Request
/// ```
/// use slim_datapath::api::CommandPayload;
///
/// let payload = CommandPayload::builder().discovery_request();
/// ```
///
/// ## Join Request with Timer Settings
/// ```
/// use slim_datapath::api::CommandPayload;
/// use slim_datapath::api::ProtoName;
/// use std::time::Duration;
///
/// let channel = ProtoName::from_strings(["org", "namespace", "channel"]);
/// let payload = CommandPayload::builder().join_request(
///     Some(5),  // max_retries
///     Some(Duration::from_secs(10)),  // timeout
///     Some(channel),
///     None, // mls_settings
/// );
/// ```
///
/// ## Group Operations
/// ```
/// use slim_datapath::api::{CommandPayload, Participant};
/// use slim_datapath::api::ProtoName;
///
/// let participant = Participant { name: Some(ProtoName::from_strings(["org", "ns", "user1"])), settings: None };
/// let participants = vec![
///     Participant { name: Some(ProtoName::from_strings(["org", "ns", "user2"])), settings: None },
///     Participant { name: Some(ProtoName::from_strings(["org", "ns", "user3"])), settings: None },
/// ];
///
/// // Add participant
/// let add_payload = CommandPayload::builder().group_add(
///     participant.clone(),
///     participants.clone(),
///     None,  // mls payload
/// );
/// ```
#[derive(Debug, Clone, Copy)]
pub struct CommandPayloadBuilder;
1178
1179impl CommandPayloadBuilder {
1180    /// Creates a new CommandPayloadBuilder
1181    pub fn new() -> Self {
1182        Self
1183    }
1184
1185    /// Creates a discovery request payload
1186    pub fn discovery_request(self) -> CommandPayload {
1187        let payload = DiscoveryRequestPayload {};
1188        CommandPayload {
1189            command_payload_type: Some(CommandPayloadType::DiscoveryRequest(payload)),
1190        }
1191    }
1192
1193    /// Creates a discovery reply payload
1194    pub fn discovery_reply(self) -> CommandPayload {
1195        let payload = DiscoveryReplyPayload {};
1196        CommandPayload {
1197            command_payload_type: Some(CommandPayloadType::DiscoveryReply(payload)),
1198        }
1199    }
1200
1201    /// Creates a join request payload
1202    #[allow(deprecated)]
1203    pub fn join_request(
1204        self,
1205        max_retries: Option<u32>,
1206        timer_duration: Option<Duration>,
1207        channel: Option<ProtoName>,
1208        mls_settings: Option<MlsSettings>,
1209    ) -> CommandPayload {
1210        let proto_channel = channel;
1211
1212        let timer_settings = if let Some(t) = timer_duration
1213            && let Some(m) = max_retries
1214        {
1215            Some(TimerSettings {
1216                timeout: t.as_millis() as u32,
1217                max_retries: m,
1218            })
1219        } else {
1220            None
1221        };
1222
1223        let payload = JoinRequestPayload {
1224            timer_settings,
1225            channel: proto_channel,
1226            mls_settings,
1227        };
1228        CommandPayload {
1229            command_payload_type: Some(CommandPayloadType::JoinRequest(payload)),
1230        }
1231    }
1232
1233    /// Creates a join reply payload
1234    pub fn join_reply(
1235        self,
1236        key_package: Option<Vec<u8>>,
1237        participant: Participant,
1238    ) -> CommandPayload {
1239        let payload = JoinReplyPayload {
1240            key_package,
1241            participant: Some(participant),
1242        };
1243        CommandPayload {
1244            command_payload_type: Some(CommandPayloadType::JoinReply(payload)),
1245        }
1246    }
1247
1248    /// Creates a leave request payload
1249    pub fn leave_request(self) -> CommandPayload {
1250        let payload = LeaveRequestPayload {};
1251        CommandPayload {
1252            command_payload_type: Some(CommandPayloadType::LeaveRequest(payload)),
1253        }
1254    }
1255
1256    /// Creates a leave reply payload
1257    pub fn leave_reply(self) -> CommandPayload {
1258        let payload = LeaveReplyPayload {};
1259        CommandPayload {
1260            command_payload_type: Some(CommandPayloadType::LeaveReply(payload)),
1261        }
1262    }
1263
1264    /// Creates a group add payload
1265    pub fn group_add(
1266        self,
1267        new_participant: Participant,
1268        participants: Vec<Participant>,
1269        mls: Option<MlsPayload>,
1270    ) -> CommandPayload {
1271        let payload = GroupAddPayload {
1272            new_participant: Some(new_participant),
1273            participants,
1274            mls,
1275        };
1276        CommandPayload {
1277            command_payload_type: Some(CommandPayloadType::GroupAdd(payload)),
1278        }
1279    }
1280
1281    /// Creates a group remove payload
1282    pub fn group_remove(
1283        self,
1284        removed_participant: ProtoName,
1285        participants: Vec<ProtoName>,
1286        mls: Option<MlsPayload>,
1287    ) -> CommandPayload {
1288        let payload = GroupRemovePayload {
1289            removed_participant: Some(removed_participant),
1290            participants,
1291            mls,
1292        };
1293        CommandPayload {
1294            command_payload_type: Some(CommandPayloadType::GroupRemove(payload)),
1295        }
1296    }
1297
1298    /// Creates a group welcome payload
1299    pub fn group_welcome(
1300        self,
1301        participants: Vec<Participant>,
1302        mls: Option<MlsPayload>,
1303    ) -> CommandPayload {
1304        let payload = GroupWelcomePayload { participants, mls };
1305        CommandPayload {
1306            command_payload_type: Some(CommandPayloadType::GroupWelcome(payload)),
1307        }
1308    }
1309
1310    /// Creates a group close payload
1311    pub fn group_close(self, participants: Vec<ProtoName>) -> CommandPayload {
1312        let payload = GroupClosePayload { participants };
1313        CommandPayload {
1314            command_payload_type: Some(CommandPayloadType::GroupClose(payload)),
1315        }
1316    }
1317
1318    /// Creates a group proposal payload
1319    pub fn group_proposal(
1320        self,
1321        source: Option<ProtoName>,
1322        mls_proposal: Vec<u8>,
1323    ) -> CommandPayload {
1324        let payload = GroupProposalPayload {
1325            source,
1326            mls_proposal,
1327        };
1328        CommandPayload {
1329            command_payload_type: Some(CommandPayloadType::GroupProposal(payload)),
1330        }
1331    }
1332
1333    /// Creates a group ack payload
1334    pub fn group_ack(self) -> CommandPayload {
1335        let payload = GroupAckPayload {};
1336        CommandPayload {
1337            command_payload_type: Some(CommandPayloadType::GroupAck(payload)),
1338        }
1339    }
1340
1341    /// Creates a group nack payload
1342    pub fn group_nack(self) -> CommandPayload {
1343        let payload = GroupNackPayload {};
1344        CommandPayload {
1345            command_payload_type: Some(CommandPayloadType::GroupNack(payload)),
1346        }
1347    }
1348
1349    /// Creates a ping payload
1350    pub fn ping(self) -> CommandPayload {
1351        let payload = PingPayload {};
1352        CommandPayload {
1353            command_payload_type: Some(CommandPayloadType::Ping(payload)),
1354        }
1355    }
1356}
1357
1358impl Default for CommandPayloadBuilder {
1359    fn default() -> Self {
1360        Self::new()
1361    }
1362}
1363
1364impl CommandPayload {
1365    /// Creates a new builder for CommandPayload
1366    pub fn builder() -> CommandPayloadBuilder {
1367        CommandPayloadBuilder::new()
1368    }
1369}
1370
1371/// Builder for creating ProtoMessage instances with a fluent API
1372///
1373/// # Examples
1374///
1375/// ## Basic Publish Message
1376/// ```
1377/// use slim_datapath::api::{ProtoMessage, ProtoSessionType};
1378/// use slim_datapath::api::ProtoName;
1379///
1380/// let source = ProtoName::from_strings(["org", "ns", "app"]).with_id(1);
1381/// let dest = ProtoName::from_strings(["org", "ns", "service"]).with_id(2);
1382///
1383/// let msg = ProtoMessage::builder()
1384///     .source(source)
1385///     .destination(dest)
1386///     .session_type(ProtoSessionType::PointToPoint)
1387///     .session_id(123)
1388///     .application_payload("text", b"Hello".to_vec())
1389///     .build_publish()
1390///     .unwrap();
1391/// ```
1392///
1393/// ## Session Control Message
1394/// ```
1395/// use slim_datapath::api::{CommandPayload, ProtoMessage, ProtoSessionType, ProtoSessionMessageType};
1396/// use slim_datapath::api::ProtoName;
1397///
1398/// let source = ProtoName::from_strings(["org", "ns", "app"]);
1399/// let dest = ProtoName::from_strings(["org", "ns", "service"]);
1400///
1401/// let cmd = CommandPayload::builder().discovery_request();
1402///
1403/// let msg = ProtoMessage::builder()
1404///     .source(source)
1405///     .destination(dest)
1406///     .session_type(ProtoSessionType::PointToPoint)
1407///     .session_message_type(ProtoSessionMessageType::DiscoveryRequest)
1408///     .session_id(42)
1409///     .command_payload(cmd)
1410///     .build_publish()
1411///     .unwrap();
1412/// ```
1413///
1414/// ## Multicast with Broadcast
1415/// ```
1416/// use slim_datapath::api::{ProtoMessage, ProtoSessionType};
1417/// use slim_datapath::api::ProtoName;
1418///
1419/// let source = ProtoName::from_strings(["org", "ns", "app"]);
1420/// let dest = ProtoName::from_strings(["org", "ns", "channel"]);
1421///
1422/// let msg = ProtoMessage::builder()
1423///     .source(source)
1424///     .destination(dest)
1425///     .session_type(ProtoSessionType::Multicast)
1426///     .fanout(256)
1427///     .application_payload("event", b"broadcast event".to_vec())
1428///     .metadata("priority", "high")
1429///     .build_publish()
1430///     .unwrap();
1431/// ```
1432///
1433/// ## Subscribe/Unsubscribe Messages
1434/// ```
1435/// use slim_datapath::api::ProtoMessage;
1436/// use slim_datapath::api::ProtoName;
1437///
1438/// let source = ProtoName::from_strings(["org", "ns", "app"]);
1439/// let dest = ProtoName::from_strings(["org", "ns", "topic"]);
1440///
1441/// // Subscribe
1442/// let sub_msg = ProtoMessage::builder()
1443///     .source(source.clone())
1444///     .destination(dest.clone())
1445///     .recv_from(100)
1446///     .build_subscribe()
1447///     .unwrap();
1448///
1449/// // Unsubscribe
1450/// let unsub_msg = ProtoMessage::builder()
1451///     .source(source)
1452///     .destination(dest)
1453///     .build_unsubscribe()
1454///     .unwrap();
1455/// ```
1456pub struct ProtoMessageBuilder {
1457    source: Option<ProtoName>,
1458    destination: Option<ProtoName>,
1459    identity: Option<String>,
1460    flags: Option<SlimHeaderFlags>,
1461    session_type: Option<ProtoSessionType>,
1462    session_message_type: Option<SessionMessageType>,
1463    session_id: Option<u32>,
1464    message_id: Option<u32>,
1465    payload: Option<Content>,
1466    metadata: HashMap<String, String>,
1467    subscription_id: Option<u64>,
1468}
1469
1470impl ProtoMessageBuilder {
1471    /// Creates a new ProtoMessageBuilder
1472    pub fn new() -> Self {
1473        Self {
1474            source: None,
1475            destination: None,
1476            identity: None,
1477            flags: None,
1478            session_type: None,
1479            session_message_type: None,
1480            session_id: None,
1481            message_id: None,
1482            payload: None,
1483            metadata: HashMap::new(),
1484            subscription_id: None,
1485        }
1486    }
1487
1488    /// Sets the source name
1489    pub fn source(mut self, source: ProtoName) -> Self {
1490        self.source = Some(source);
1491        self
1492    }
1493
1494    /// Sets the destination name
1495    pub fn destination(mut self, destination: ProtoName) -> Self {
1496        self.destination = Some(destination);
1497        self
1498    }
1499
1500    /// Sets the identity string
1501    pub fn identity(mut self, identity: impl Into<String>) -> Self {
1502        self.identity = Some(identity.into());
1503        self
1504    }
1505
1506    /// Sets the SLIM header flags
1507    pub fn flags(mut self, flags: SlimHeaderFlags) -> Self {
1508        self.flags = Some(flags);
1509        self
1510    }
1511
1512    /// Sets the fanout value
1513    pub fn fanout(mut self, fanout: u32) -> Self {
1514        self.flags.get_or_insert_default().fanout = fanout;
1515        self
1516    }
1517
1518    /// Sets the recv_from connection
1519    pub fn recv_from(mut self, recv_from: u64) -> Self {
1520        self.flags.get_or_insert_default().recv_from = Some(recv_from);
1521        self
1522    }
1523
1524    /// Sets the forward_to connection
1525    pub fn forward_to(mut self, forward_to: u64) -> Self {
1526        self.flags.get_or_insert_default().forward_to = Some(forward_to);
1527        self
1528    }
1529
1530    /// Sets the incoming connection
1531    pub fn incoming_conn(mut self, incoming_conn: u64) -> Self {
1532        self.flags.get_or_insert_default().incoming_conn = Some(incoming_conn);
1533        self
1534    }
1535
1536    /// Sets the error flag
1537    pub fn error(mut self, error: bool) -> Self {
1538        self.flags.get_or_insert_default().error = Some(error);
1539        self
1540    }
1541
1542    /// Sets the TTL (time-to-live) value
1543    pub fn ttl(mut self, ttl: u32) -> Self {
1544        self.flags.get_or_insert_default().ttl = ttl;
1545        self
1546    }
1547
1548    /// Sets the session type
1549    pub fn session_type(mut self, session_type: ProtoSessionType) -> Self {
1550        self.session_type = Some(session_type);
1551        self
1552    }
1553
1554    /// Sets the session message type
1555    pub fn session_message_type(mut self, session_message_type: SessionMessageType) -> Self {
1556        self.session_message_type = Some(session_message_type);
1557        self
1558    }
1559
1560    /// Sets the session ID
1561    pub fn session_id(mut self, session_id: u32) -> Self {
1562        self.session_id = Some(session_id);
1563        self
1564    }
1565
1566    /// Sets the message ID
1567    pub fn message_id(mut self, message_id: u32) -> Self {
1568        self.message_id = Some(message_id);
1569        self
1570    }
1571
1572    /// Sets the message payload
1573    pub fn payload(mut self, payload: Content) -> Self {
1574        self.payload = Some(payload);
1575        self
1576    }
1577
1578    /// Sets an application payload
1579    pub fn application_payload(mut self, payload_type: &str, blob: Vec<u8>) -> Self {
1580        let app_payload = ApplicationPayload::new(payload_type, blob);
1581        self.payload = Some(app_payload.as_content());
1582        self
1583    }
1584
1585    /// Sets a command payload
1586    pub fn command_payload(mut self, payload: CommandPayload) -> Self {
1587        self.payload = Some(payload.as_content());
1588        self
1589    }
1590
1591    /// Sets a pre-built SlimHeader (for low-level use cases)
1592    ///
1593    /// This is a convenience method for cases where you already have a constructed SlimHeader.
1594    /// For most cases, prefer using the individual builder methods like `source()`, `destination()`, etc.
1595    pub fn with_slim_header(mut self, header: SlimHeader) -> Self {
1596        // Extract fields from the header
1597        if let Some(src) = header.source.clone() {
1598            self.source = Some(src);
1599        }
1600        if let Some(dst) = header.destination.clone() {
1601            self.destination = Some(dst);
1602        }
1603        if !header.identity.is_empty() {
1604            self.identity = Some(header.identity.clone());
1605        }
1606
1607        // Extract flags
1608        let flags = SlimHeaderFlags {
1609            fanout: header.fanout,
1610            recv_from: header.recv_from,
1611            forward_to: header.forward_to,
1612            incoming_conn: header.incoming_conn,
1613            error: header.error,
1614            ttl: header.ttl,
1615        };
1616        self.flags = Some(flags);
1617        self
1618    }
1619
1620    /// Sets a pre-built SessionHeader (for low-level use cases)
1621    ///
1622    /// This is a convenience method for cases where you already have a constructed SessionHeader.
1623    /// For most cases, prefer using the individual builder methods like `session_type()`, `session_message_type()`, etc.
1624    pub fn with_session_header(mut self, header: SessionHeader) -> Self {
1625        self.session_type = Some(
1626            ProtoSessionType::try_from(header.session_type)
1627                .unwrap_or(ProtoSessionType::PointToPoint),
1628        );
1629        self.session_message_type = Some(
1630            SessionMessageType::try_from(header.session_message_type)
1631                .unwrap_or(SessionMessageType::Msg),
1632        );
1633        self.session_id = Some(header.session_id);
1634        self.message_id = Some(header.message_id);
1635        self
1636    }
1637
1638    /// Adds metadata to the message
1639    pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1640        self.metadata.insert(key.into(), value.into());
1641        self
1642    }
1643
1644    /// Adds multiple metadata entries
1645    pub fn metadata_map(mut self, map: HashMap<String, String>) -> Self {
1646        self.metadata.extend(map);
1647        self
1648    }
1649
1650    /// Sets the subscription_id for subscribe/unsubscribe messages.
1651    pub fn subscription_id(mut self, id: u64) -> Self {
1652        self.subscription_id = Some(id);
1653        self
1654    }
1655
1656    /// Builds a publish message
1657    pub fn build_publish(self) -> Result<ProtoMessage, MessageError> {
1658        let source = self
1659            .source
1660            .ok_or(MessageError::BuilderErrorSourceRequired)?;
1661        let destination = self
1662            .destination
1663            .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1664
1665        let slim_header = Some(SlimHeader::new(
1666            source,
1667            destination,
1668            self.identity.as_deref().unwrap_or(""),
1669            self.flags,
1670        ));
1671
1672        let session_header = if self.session_type.is_some() || self.session_message_type.is_some() {
1673            Some(SessionHeader::new(
1674                self.session_type
1675                    .unwrap_or(ProtoSessionType::PointToPoint)
1676                    .into(),
1677                self.session_message_type
1678                    .unwrap_or(SessionMessageType::Msg)
1679                    .into(),
1680                self.session_id.unwrap_or(0),
1681                self.message_id.unwrap_or_else(rand::random),
1682            ))
1683        } else {
1684            Some(SessionHeader::default())
1685        };
1686
1687        let publish = ProtoPublish::with_header(slim_header, session_header, self.payload);
1688        let message = ProtoMessage::new(self.metadata, ProtoPublishType(publish));
1689        Ok(message)
1690    }
1691
1692    /// Builds a subscribe message
1693    pub fn build_subscribe(self) -> Result<ProtoMessage, MessageError> {
1694        let source = self
1695            .source
1696            .ok_or(MessageError::BuilderErrorSourceRequired)?;
1697        let destination = self
1698            .destination
1699            .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1700
1701        let mut subscribe =
1702            ProtoSubscribe::new(source, destination, self.identity.as_deref(), self.flags);
1703        subscribe.subscription_id = self.subscription_id.unwrap_or_default();
1704
1705        Ok(ProtoMessage::new(
1706            self.metadata,
1707            ProtoSubscribeType(subscribe),
1708        ))
1709    }
1710
1711    /// Builds an unsubscribe message
1712    pub fn build_unsubscribe(self) -> Result<ProtoMessage, MessageError> {
1713        let source = self
1714            .source
1715            .ok_or(MessageError::BuilderErrorSourceRequired)?;
1716        let destination = self
1717            .destination
1718            .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1719
1720        let mut unsubscribe =
1721            ProtoUnsubscribe::new(source, destination, self.identity.as_deref(), self.flags);
1722        unsubscribe.subscription_id = self.subscription_id.unwrap_or_default();
1723
1724        Ok(ProtoMessage::new(
1725            self.metadata,
1726            ProtoUnsubscribeType(unsubscribe),
1727        ))
1728    }
1729
1730    /// Builds a subscription ack message.
1731    /// SubscriptionAck messages are delivered directly to the requesting connection
1732    /// and are never routed through the subscription table.
1733    pub fn build_subscription_ack(
1734        self,
1735        subscription_id: u64,
1736        success: bool,
1737        error: impl Into<String>,
1738    ) -> ProtoMessage {
1739        let ack = ProtoSubscriptionAck {
1740            subscription_id,
1741            success,
1742            error: error.into(),
1743        };
1744        ProtoMessage::new(self.metadata, ProtoSubscriptionAckType(ack))
1745    }
1746
1747    /// Builds a link negotiation message.
1748    /// Link messages are link-local and never routed; they carry no SLIM header.
1749    #[allow(clippy::too_many_arguments)]
1750    pub fn build_link_negotiation(
1751        self,
1752        link_id: impl Into<String>,
1753        slim_version: impl Into<String>,
1754        is_reply: bool,
1755        link_ecdh_public_key: Option<Vec<u8>>,
1756        connection_type: LinkConnectionType,
1757        node_id: impl Into<String>,
1758        deployment_name: impl Into<String>,
1759    ) -> ProtoMessage {
1760        let link_ecdh_public_key = link_ecdh_public_key.unwrap_or_default();
1761        let link = ProtoLink {
1762            link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
1763                link_id: link_id.into(),
1764                slim_version: slim_version.into(),
1765                is_reply,
1766                link_ecdh_public_key,
1767                connection_type: connection_type.into(),
1768                node_id: node_id.into(),
1769                deployment_name: deployment_name.into(),
1770            })),
1771        };
1772        ProtoMessage::new(self.metadata, ProtoLinkMessageType(link))
1773    }
1774}
1775
1776impl Default for ProtoMessageBuilder {
1777    fn default() -> Self {
1778        Self::new()
1779    }
1780}
1781
1782impl ProtoMessage {
1783    /// Creates a new builder for ProtoMessage
1784    pub fn builder() -> ProtoMessageBuilder {
1785        ProtoMessageBuilder::new()
1786    }
1787}
1788
#[cfg(test)]
mod tests {
    use crate::api::proto::dataplane::v1::SessionMessageType;

    use super::*;

    /// Builds a subscribe (or unsubscribe, when `subscription` is false)
    /// message and checks that the header accessors reflect the inputs.
    ///
    /// When `flags` is `None` the builder's flags are left untouched, so the
    /// message is expected to carry `SlimHeaderFlags::default()` values.
    fn test_subscription_template(
        subscription: bool,
        source: ProtoName,
        dst: ProtoName,
        identity: Option<&str>,
        flags: Option<SlimHeaderFlags>,
    ) {
        let sub = {
            let mut builder = ProtoMessage::builder()
                .source(source.clone())
                .destination(dst.clone());

            if let Some(id) = identity {
                builder = builder.identity(id);
            }

            if let Some(f) = flags.clone() {
                builder = builder.flags(f);
            }

            if subscription {
                builder.build_subscribe().unwrap()
            } else {
                builder.build_unsubscribe().unwrap()
            }
        };

        // Flags the built message is expected to carry.
        let expected = flags.unwrap_or_default();

        assert!(!sub.is_publish());
        assert_eq!(sub.is_subscribe(), subscription);
        assert_eq!(sub.is_unsubscribe(), !subscription);
        assert_eq!(expected.recv_from, sub.get_recv_from());
        assert_eq!(expected.forward_to, sub.get_forward_to());
        assert_eq!(None, sub.try_get_incoming_conn());
        assert_eq!(source, sub.get_source());
        assert_eq!(dst, sub.get_dst());
    }

    /// Builds a publish message carrying a small application payload and
    /// checks that the header accessors reflect the inputs.
    ///
    /// When `flags` is `None` the message is expected to carry
    /// `SlimHeaderFlags::default()` values.
    fn test_publish_template(
        source: ProtoName,
        dst: ProtoName,
        identity: Option<&str>,
        flags: Option<SlimHeaderFlags>,
    ) {
        let mut builder = ProtoMessage::builder()
            .source(source.clone())
            .destination(dst.clone())
            .application_payload("str", "this is the content of the message".into());

        if let Some(id) = identity {
            builder = builder.identity(id);
        }

        if let Some(f) = flags.clone() {
            builder = builder.flags(f);
        }

        let pub_msg = builder.build_publish().unwrap();

        // Flags the built message is expected to carry.
        let expected = flags.unwrap_or_default();

        assert!(pub_msg.is_publish());
        assert!(!pub_msg.is_subscribe());
        assert!(!pub_msg.is_unsubscribe());
        assert_eq!(expected.recv_from, pub_msg.get_recv_from());
        assert_eq!(expected.forward_to, pub_msg.get_forward_to());
        assert_eq!(None, pub_msg.try_get_incoming_conn());
        assert_eq!(source, pub_msg.get_source());
        assert_eq!(dst, pub_msg.get_dst());
        assert_eq!(expected.fanout, pub_msg.get_fanout());
    }

    /// Runs the scenarios shared by the subscribe and unsubscribe tests.
    ///
    /// `subscription` selects which message kind is built.
    fn run_subscription_cases(subscription: bool) {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst_without_id = ProtoName::from_strings(["org", "ns", "type"]);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(2);

        // simple: destination without an explicit name id
        test_subscription_template(subscription, source.clone(), dst_without_id, None, None);

        // with name id
        test_subscription_template(subscription, source.clone(), dst.clone(), None, None);

        // with recv from
        test_subscription_template(
            subscription,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        // with forward to
        test_subscription_template(
            subscription,
            source,
            dst,
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );
    }

    #[test]
    fn test_subscription() {
        run_subscription_cases(true);
    }

    #[test]
    fn test_unsubscription() {
        run_subscription_cases(false);
    }

    #[test]
    fn test_publish() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let mut dst = ProtoName::from_strings(["org", "ns", "type"]);

        // simple
        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default()),
        );

        // with name id
        dst.set_id(2);
        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default()),
        );
        dst.reset_id();

        // with recv from
        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        // with forward to
        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );

        // with fanout
        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_fanout(2)),
        );
    }

    /// `From<ProtoMessage>` conversions into the concrete message types keep
    /// the header's source and destination.
    #[test]
    fn test_conversions() {
        // ProtoMessage to ProtoSubscribe
        let name = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let proto_subscribe = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_subscribe()
            .unwrap();
        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_dst(), dst);

        // ProtoMessage to ProtoUnsubscribe
        let proto_unsubscribe = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_unsubscribe()
            .unwrap();
        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
        assert_eq!(
            proto_unsubscribe.header.as_ref().unwrap().get_source(),
            name
        );
        assert_eq!(proto_unsubscribe.header.as_ref().unwrap().get_dst(), dst);

        // ProtoMessage to ProtoPublish
        let proto_publish = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .application_payload("str", "this is the content of the message".into())
            .build_publish()
            .unwrap();
        let proto_publish = ProtoPublish::from(proto_publish);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_dst(), dst);
    }

    /// Converting a message into the wrong concrete type panics; converting
    /// into the matching type does not.
    #[test]
    fn test_panic() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(2);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_subscribe()
            .unwrap();

        // let's try to convert it to a unsubscribe
        // this should panic because the message type is not unsubscribe
        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
        assert!(result.is_err());

        // try to convert to publish
        // this should panic because the message type is not publish
        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
        assert!(result.is_err());

        // finally make sure the conversion to subscribe works
        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
        assert!(result.is_ok());
    }

    /// With no source/destination set, the name getters panic while the
    /// optional-field getters do not.
    #[test]
    fn test_panic_header() {
        // create a unusual SLIM header
        let header = SlimHeader {
            source: None,
            destination: None,
            identity: String::new(),
            fanout: 0,
            version: version().to_string(),
            recv_from: None,
            forward_to: None,
            incoming_conn: None,
            error: None,
            header_mac: None,
            ttl: DEFAULT_TTL,
        };

        // the operations to retrieve source and destination should fail with panic
        let result = std::panic::catch_unwind(|| header.get_source());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| header.get_dst());
        assert!(result.is_err());

        // The operations to retrieve recv_from and forward_to should not fail with panic
        let result = std::panic::catch_unwind(|| header.get_recv_from());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_forward_to());
        assert!(result.is_ok());

        // The operations to retrieve incoming_conn should not fail with panic
        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
        assert!(result.is_ok());

        // The operations to retrieve error should not fail with panic
        let result = std::panic::catch_unwind(|| header.get_error());
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_session_header() {
        // create a unusual session header
        let header = SessionHeader::new(0, 0, 0, 0);

        // the operations to retrieve session_id and message_id should not fail with panic
        let result = std::panic::catch_unwind(|| header.get_session_id());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_message_id());
        assert!(result.is_ok());
    }

    /// A message with no `message_type` panics on every header-derived getter.
    #[test]
    fn test_panic_proto_message() {
        // create a unusual proto message
        let message = ProtoMessage {
            metadata: HashMap::new(),
            message_type: None,
        };

        // the operation to retrieve the header should fail with panic
        let result = std::panic::catch_unwind(|| message.get_slim_header());
        assert!(result.is_err());

        // the operation to retrieve the message type should fail with panic
        let result = std::panic::catch_unwind(|| message.get_type());
        assert!(result.is_err());

        // all the other ops should fail with panic as well as the header is not set
        let result = std::panic::catch_unwind(|| message.get_source());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_dst());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_recv_from());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_forward_to());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_fanout());
        assert!(result.is_err());
    }

    /// Every integer from 0 up to `Ping` (assumed to be the highest variant)
    /// round-trips through `SessionMessageType`, and the next integer is
    /// rejected.
    #[test]
    fn test_service_type_to_int() {
        // Highest valid discriminant; the test assumes `Ping` is the last variant.
        let last_service_type = SessionMessageType::Ping as i32;

        // Inclusive range: `Ping` itself must also convert.
        for i in 0..=last_service_type {
            // int -> ServiceType -> int must give back the original value
            let service_type =
                SessionMessageType::try_from(i).expect("failed to convert int to service type");
            assert_eq!(i32::from(service_type), i);
        }

        // Test invalid conversion
        let invalid_service_type = SessionMessageType::try_from(last_service_type + 1);
        assert!(invalid_service_type.is_err());
    }

    #[test]
    fn test_proto_message_builder() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dest = ProtoName::from_strings(["org", "ns", "app"]).with_id(2);

        // Test basic publish message
        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .application_payload("test", b"hello world".to_vec())
            .build_publish()
            .unwrap();

        assert!(msg.is_publish());
        assert_eq!(msg.get_source(), source);
        assert_eq!(msg.get_dst(), dest);

        // Test with session headers
        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(SessionMessageType::Msg)
            .session_id(42)
            .message_id(100)
            .fanout(256)
            .application_payload("test", b"broadcast".to_vec())
            .build_publish()
            .unwrap();

        assert_eq!(msg.get_session_type(), ProtoSessionType::Multicast);
        assert_eq!(msg.get_id(), 100);
        assert_eq!(msg.get_fanout(), 256);

        // Test with metadata
        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .metadata("key1", "value1")
            .metadata("key2", "value2")
            .application_payload("test", vec![1, 2, 3])
            .build_publish()
            .unwrap();

        assert_eq!(msg.get_metadata("key1"), Some(&"value1".to_string()));
        assert_eq!(msg.get_metadata("key2"), Some(&"value2".to_string()));

        // Test subscribe message
        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .recv_from(10)
            .build_subscribe()
            .unwrap();

        assert!(msg.is_subscribe());
        assert_eq!(msg.get_recv_from(), Some(10));

        // Test unsubscribe message
        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .forward_to(20)
            .build_unsubscribe()
            .unwrap();

        assert!(msg.is_unsubscribe());
        assert_eq!(msg.get_forward_to(), Some(20));
    }

    /// Each `CommandPayload` builder method produces a payload that the
    /// matching `as_*_payload` accessor can extract.
    #[test]
    fn test_command_payload_builder() {
        let dest = ProtoName::from_strings(["org", "ns", "app"]);

        // Test discovery request
        let payload = CommandPayload::builder().discovery_request();
        assert!(payload.as_discovery_request_payload().is_ok());

        // Test discovery reply
        let payload = CommandPayload::builder().discovery_reply();
        assert!(payload.as_discovery_reply_payload().is_ok());

        // Test join request
        let payload = CommandPayload::builder().join_request(
            Some(5),
            Some(Duration::from_secs(10)),
            Some(dest.clone()),
            Some(MlsSettings::default()),
        );
        let extracted = payload.as_join_request_payload().unwrap();
        assert!(extracted.mls_settings.is_some());
        assert!(extracted.timer_settings.is_some());

        // Test join reply
        let participant = Participant::new(dest.clone(), ParticipantSettings::bidirectional());
        let payload =
            CommandPayload::builder().join_reply(Some(vec![1, 2, 3]), participant.clone());
        let extracted = payload.as_join_reply_payload().unwrap();
        assert_eq!(extracted.key_package, Some(vec![1, 2, 3]));
        assert_eq!(extracted.participant, Some(participant));

        // Test leave request
        let payload = CommandPayload::builder().leave_request();
        assert!(payload.as_leave_request_payload().is_ok());

        // Test leave reply
        let payload = CommandPayload::builder().leave_reply();
        assert!(payload.as_leave_reply_payload().is_ok());

        // Test group add
        let participant = Participant::new(dest.clone(), ParticipantSettings::bidirectional());
        let participants = vec![participant.clone()];
        let payload =
            CommandPayload::builder().group_add(participant.clone(), participants.clone(), None);
        let extracted = payload.as_group_add_payload().unwrap();
        assert_eq!(extracted.new_participant, Some(participant));
        assert_eq!(extracted.participants, participants);

        // Test group remove
        let payload =
            CommandPayload::builder().group_remove(dest.clone(), vec![dest.clone()], None);
        let extracted = payload.as_group_remove_payload().unwrap();
        assert!(extracted.removed_participant.is_some());

        // Test group welcome
        let payload = CommandPayload::builder().group_welcome(participants.clone(), None);
        let extracted = payload.as_welcome_payload().unwrap();
        assert!(!extracted.participants.is_empty());

        // Test group proposal
        let payload = CommandPayload::builder().group_proposal(Some(dest.clone()), vec![4, 5, 6]);
        let extracted = payload.as_group_proposal_payload().unwrap();
        assert_eq!(extracted.mls_proposal, vec![4, 5, 6]);

        // Test group ack
        let payload = CommandPayload::builder().group_ack();
        assert!(payload.as_group_ack_payload().is_ok());

        // Test group nack
        let payload = CommandPayload::builder().group_nack();
        assert!(payload.as_group_nack_payload().is_ok());

        // Test ping
        let payload = CommandPayload::builder().ping();
        assert!(payload.as_ping_payload().is_ok());
    }

    #[test]
    fn test_builder_with_command_payload() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dest = ProtoName::from_strings(["org", "ns", "app"]).with_id(2);

        let cmd_payload = CommandPayload::builder().discovery_request();

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .session_type(ProtoSessionType::PointToPoint)
            .session_message_type(SessionMessageType::DiscoveryRequest)
            .session_id(1)
            .command_payload(cmd_payload)
            .build_publish()
            .unwrap();

        assert!(msg.is_publish());
        assert_eq!(
            msg.get_session_message_type(),
            SessionMessageType::DiscoveryRequest
        );
    }

    /// A link message without a `link_type` fails validation.
    #[test]
    fn test_validate_link_without_link_type() {
        let link = ProtoLink { link_type: None };
        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
        assert!(matches!(msg.validate(), Err(MessageError::LinkTypeNotSet)));
    }

    #[test]
    fn test_validate_link_with_link_type() {
        let link = ProtoLink {
            link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
                link_id: "abc".into(),
                slim_version: "1.0.0".into(),
                is_reply: false,
                link_ecdh_public_key: vec![],
                connection_type: 0,
                node_id: String::new(),
                deployment_name: String::new(),
            })),
        };
        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_build_link_negotiation_request() {
        let msg = ProtoMessage::builder().build_link_negotiation(
            "my-id",
            "1.2.3",
            false,
            None,
            LinkConnectionType::Remote,
            "",
            "",
        );
        assert!(msg.is_link());
        assert!(!msg.is_publish());
        assert!(!msg.is_subscribe());
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_build_link_negotiation_reply() {
        let msg = ProtoMessage::builder().build_link_negotiation(
            "my-id",
            "1.2.3",
            true,
            None,
            LinkConnectionType::Remote,
            "",
            "",
        );
        assert!(msg.is_link());
        assert!(msg.validate().is_ok());
    }

    /// A subscribe whose source name has no encoded name is rejected.
    #[test]
    fn test_validate_subscribe_missing_source_encoded_name() {
        let valid = ProtoName::from_strings(["org", "ns", "agent"]);
        let hdr = SlimHeader {
            source: Some(ProtoName {
                name: None,
                str_name: None,
            }),
            destination: Some(valid),
            ..Default::default()
        };
        let msg = ProtoMessage::new(
            HashMap::new(),
            ProtoSubscribeType(ProtoSubscribe {
                header: Some(hdr),
                ..Default::default()
            }),
        );
        assert!(matches!(
            msg.validate(),
            Err(MessageError::SourceEncodedNameNotFound)
        ));
    }

    /// A subscribe whose destination name has no encoded name is rejected.
    #[test]
    fn test_validate_subscribe_missing_destination_encoded_name() {
        let valid = ProtoName::from_strings(["org", "ns", "agent"]);
        let hdr = SlimHeader {
            source: Some(valid),
            destination: Some(ProtoName {
                name: None,
                str_name: None,
            }),
            ..Default::default()
        };
        let msg = ProtoMessage::new(
            HashMap::new(),
            ProtoSubscribeType(ProtoSubscribe {
                header: Some(hdr),
                ..Default::default()
            }),
        );
        assert!(matches!(
            msg.validate(),
            Err(MessageError::DestinationEncodedNameNotFound)
        ));
    }

    #[test]
    fn test_participant_settings_convenience_methods() {
        let bidirectional = ParticipantSettings::bidirectional();
        assert!(bidirectional.sends_data);
        assert!(bidirectional.receives_data);
        assert!(bidirectional.is_sender());
        assert!(bidirectional.is_receiver());

        let send_only = ParticipantSettings::send_only();
        assert!(send_only.sends_data);
        assert!(!send_only.receives_data);
        assert!(send_only.is_sender());
        assert!(!send_only.is_receiver());

        let receive_only = ParticipantSettings::receive_only();
        assert!(!receive_only.sends_data);
        assert!(receive_only.receives_data);
        assert!(!receive_only.is_sender());
        assert!(receive_only.is_receiver());
    }
}