slim_datapath/messages/
utils.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::fmt::Display;
6
7use tracing::debug;
8
9use super::encoder::Name;
10use crate::api::{
11    Content, MessageType, ProtoMessage, ProtoName, ProtoPublish, ProtoPublishType,
12    ProtoSessionType, ProtoSubscribe, ProtoSubscribeType, ProtoUnsubscribe, ProtoUnsubscribeType,
13    SessionHeader, SlimHeader, proto::pubsub::v1::SessionMessageType,
14};
15
16use thiserror::Error;
17use tracing::error;
18
19#[derive(Error, Debug, PartialEq)]
20pub enum MessageError {
21    #[error("SLIM header not found")]
22    SlimHeaderNotFound,
23    #[error("source not found")]
24    SourceNotFound,
25    #[error("destination not found")]
26    DestinationNotFound,
27    #[error("session header not found")]
28    SessionHeaderNotFound,
29    #[error("message type not found")]
30    MessageTypeNotFound,
31    #[error("incoming connection not found")]
32    IncomingConnectionNotFound,
33}
34
// Metadata Keys
/// Metadata key `"SLIM_IDENTITY"`, used with the message metadata map
/// (`ProtoMessage::insert_metadata` / `ProtoMessage::get_metadata`).
/// NOTE(review): presumably the value carries the sender's identity; its
/// format is not defined in this file — confirm with callers.
pub const SLIM_IDENTITY: &str = "SLIM_IDENTITY";
37
38/// ProtoName from Name
39impl From<&Name> for ProtoName {
40    fn from(name: &Name) -> Self {
41        Self {
42            component_0: name.components()[0],
43            component_1: name.components()[1],
44            component_2: name.components()[2],
45            component_3: name.components()[3],
46        }
47    }
48}
49
50/// Print message type
51impl Display for MessageType {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        match self {
54            MessageType::Publish(_) => write!(f, "publish"),
55            MessageType::Subscribe(_) => write!(f, "subscribe"),
56            MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
57        }
58    }
59}
60
/// Struct grouping the SLIM header flags for convenience.
///
/// `SlimHeader::new` copies every field of this struct into the header
/// unchanged. `PartialEq`/`Eq` are derived so flag sets can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SlimHeaderFlags {
    /// Fanout value stored in the header (`SlimHeaderFlags::default()` uses 1).
    pub fanout: u32,
    /// When set, `SlimHeader::get_in_out_connections` returns this connection
    /// as the one to process the message on, with no forwarding target.
    pub recv_from: Option<u64>,
    /// When set (and `recv_from` is not), `SlimHeader::get_in_out_connections`
    /// returns this connection as the forwarding target.
    pub forward_to: Option<u64>,
    /// Connection the message was received on, if known.
    pub incoming_conn: Option<u64>,
    /// Error flag; its meaning is not defined in this file.
    pub error: Option<bool>,
}
70
71impl Default for SlimHeaderFlags {
72    fn default() -> Self {
73        Self {
74            fanout: 1,
75            recv_from: None,
76            forward_to: None,
77            incoming_conn: None,
78            error: None,
79        }
80    }
81}
82
83impl Display for SlimHeaderFlags {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        write!(
86            f,
87            "fanout: {}, recv_from: {:?}, forward_to: {:?}, incoming_conn: {:?}, error: {:?}",
88            self.fanout, self.recv_from, self.forward_to, self.incoming_conn, self.error
89        )
90    }
91}
92
93impl SlimHeaderFlags {
94    pub fn new(
95        fanout: u32,
96        recv_from: Option<u64>,
97        forward_to: Option<u64>,
98        incoming_conn: Option<u64>,
99        error: Option<bool>,
100    ) -> Self {
101        Self {
102            fanout,
103            recv_from,
104            forward_to,
105            incoming_conn,
106            error,
107        }
108    }
109
110    pub fn with_fanout(self, fanout: u32) -> Self {
111        Self { fanout, ..self }
112    }
113
114    pub fn with_recv_from(self, recv_from: u64) -> Self {
115        Self {
116            recv_from: Some(recv_from),
117            ..self
118        }
119    }
120
121    pub fn with_forward_to(self, forward_to: u64) -> Self {
122        Self {
123            forward_to: Some(forward_to),
124            ..self
125        }
126    }
127
128    pub fn with_incoming_conn(self, incoming_conn: u64) -> Self {
129        Self {
130            incoming_conn: Some(incoming_conn),
131            ..self
132        }
133    }
134
135    pub fn with_error(self, error: bool) -> Self {
136        Self {
137            error: Some(error),
138            ..self
139        }
140    }
141}
142
143/// SLIM Header
144/// This header is used to identify the source and destination of the message
145/// and to manage the connections used to send and receive the message
146impl SlimHeader {
147    pub fn new(source: &Name, destination: &Name, flags: Option<SlimHeaderFlags>) -> Self {
148        let flags = flags.unwrap_or_default();
149
150        Self {
151            source: Some(ProtoName::from(source)),
152            destination: Some(ProtoName::from(destination)),
153            fanout: flags.fanout,
154            recv_from: flags.recv_from,
155            forward_to: flags.forward_to,
156            incoming_conn: flags.incoming_conn,
157            error: flags.error,
158        }
159    }
160
161    pub fn clear(&mut self) {
162        self.recv_from = None;
163        self.forward_to = None;
164    }
165
166    pub fn get_recv_from(&self) -> Option<u64> {
167        self.recv_from
168    }
169
170    pub fn get_forward_to(&self) -> Option<u64> {
171        self.forward_to
172    }
173
174    pub fn get_incoming_conn(&self) -> Option<u64> {
175        self.incoming_conn
176    }
177
178    pub fn get_error(&self) -> Option<bool> {
179        self.error
180    }
181
182    pub fn get_source(&self) -> Name {
183        match &self.source {
184            Some(source) => Name::from(source),
185            None => panic!("source not found"),
186        }
187    }
188
189    pub fn get_dst(&self) -> Name {
190        match &self.destination {
191            Some(destination) => Name::from(destination),
192            None => panic!("destination not found"),
193        }
194    }
195
196    pub fn set_source(&mut self, source: &Name) {
197        self.source = Some(ProtoName::from(source));
198    }
199
200    pub fn set_destination(&mut self, dst: &Name) {
201        self.destination = Some(ProtoName::from(dst));
202    }
203
204    pub fn get_fanout(&self) -> u32 {
205        self.fanout
206    }
207
208    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
209        self.recv_from = recv_from;
210    }
211
212    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
213        self.forward_to = forward_to;
214    }
215
216    pub fn set_error(&mut self, error: Option<bool>) {
217        self.error = error;
218    }
219
220    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
221        self.incoming_conn = incoming_conn;
222    }
223
224    pub fn set_error_flag(&mut self, error: Option<bool>) {
225        self.error = error;
226    }
227
228    pub fn set_fanout(&mut self, fanout: u32) {
229        self.fanout = fanout;
230    }
231
232    // returns the connection to use to process correctly the message
233    // first connection is from where we received the packet
234    // the second is where to forward the packet if needed
235    pub fn get_in_out_connections(&self) -> (u64, Option<u64>) {
236        // when calling this function, incoming connection is set
237        let incoming = self
238            .get_incoming_conn()
239            .expect("incoming connection not found");
240
241        if let Some(val) = self.get_recv_from() {
242            debug!(
243                "received recv_from command, update state on connection {}",
244                val
245            );
246            return (val, None);
247        }
248
249        if let Some(val) = self.get_forward_to() {
250            debug!(
251                "received forward_to command, update state and forward to connection {}",
252                val
253            );
254            return (incoming, Some(val));
255        }
256
257        // by default, return the incoming connection and None
258        (incoming, None)
259    }
260}
261
262/// Session Header
263/// This header is used to identify the session and the message
264/// and to manage session state
265impl SessionHeader {
266    pub fn new(
267        session_type: i32,
268        session_message_type: i32,
269        session_id: u32,
270        message_id: u32,
271    ) -> Self {
272        Self {
273            session_type,
274            session_message_type,
275            session_id,
276            message_id,
277        }
278    }
279
280    pub fn get_session_id(&self) -> u32 {
281        self.session_id
282    }
283
284    pub fn get_message_id(&self) -> u32 {
285        self.message_id
286    }
287
288    pub fn set_session_id(&mut self, session_id: u32) {
289        self.session_id = session_id;
290    }
291
292    pub fn set_message_id(&mut self, message_id: u32) {
293        self.message_id = message_id;
294    }
295
296    pub fn clear(&mut self) {
297        self.session_id = 0;
298        self.message_id = 0;
299    }
300}
301
302/// ProtoSubscribe
303/// This message is used to subscribe to a topic
304impl ProtoSubscribe {
305    pub fn new(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
306        let header = Some(SlimHeader::new(source, dst, flags));
307
308        ProtoSubscribe {
309            header,
310            component_0: dst.components_strings().unwrap()[0].clone(),
311            component_1: dst.components_strings().unwrap()[1].clone(),
312            component_2: dst.components_strings().unwrap()[2].clone(),
313        }
314    }
315}
316
317/// From ProtoMessage to ProtoSubscribe
318impl From<ProtoMessage> for ProtoSubscribe {
319    fn from(message: ProtoMessage) -> Self {
320        match message.message_type {
321            Some(ProtoSubscribeType(s)) => s,
322            _ => panic!("message type is not subscribe"),
323        }
324    }
325}
326
327/// ProtoUnsubscribe
328/// This message is used to unsubscribe from a topic
329impl ProtoUnsubscribe {
330    pub fn with_header(header: Option<SlimHeader>) -> Self {
331        ProtoUnsubscribe { header }
332    }
333
334    pub fn new(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
335        let header = Some(SlimHeader::new(source, dst, flags));
336
337        Self::with_header(header)
338    }
339}
340
341/// From ProtoMessage to ProtoUnsubscribe
342impl From<ProtoMessage> for ProtoUnsubscribe {
343    fn from(message: ProtoMessage) -> Self {
344        match message.message_type {
345            Some(ProtoUnsubscribeType(u)) => u,
346            _ => panic!("message type is not unsubscribe"),
347        }
348    }
349}
350
351/// ProtoPublish
352/// This message is used to publish a message, either to a shared channel or to a specific application
353impl ProtoPublish {
354    pub fn with_header(
355        header: Option<SlimHeader>,
356        session: Option<SessionHeader>,
357        payload: Option<Content>,
358    ) -> Self {
359        ProtoPublish {
360            header,
361            session,
362            msg: payload,
363        }
364    }
365
366    pub fn new(
367        source: &Name,
368        dst: &Name,
369        flags: Option<SlimHeaderFlags>,
370        content_type: &str,
371        blob: Vec<u8>,
372    ) -> Self {
373        let slim_header = Some(SlimHeader::new(source, dst, flags));
374
375        let session_header = Some(SessionHeader::default());
376
377        let msg = Some(Content {
378            content_type: content_type.to_string(),
379            blob,
380        });
381
382        Self::with_header(slim_header, session_header, msg)
383    }
384
385    pub fn get_slim_header(&self) -> &SlimHeader {
386        self.header.as_ref().unwrap()
387    }
388
389    pub fn get_session_header(&self) -> &SessionHeader {
390        self.session.as_ref().unwrap()
391    }
392
393    pub fn get_slim_header_as_mut(&mut self) -> &mut SlimHeader {
394        self.header.as_mut().unwrap()
395    }
396
397    pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
398        self.session.as_mut().unwrap()
399    }
400
401    pub fn get_payload(&self) -> &Content {
402        self.msg.as_ref().unwrap()
403    }
404}
405
406/// From ProtoMessage to ProtoPublish
407impl From<ProtoMessage> for ProtoPublish {
408    fn from(message: ProtoMessage) -> Self {
409        match message.message_type {
410            Some(ProtoPublishType(p)) => p,
411            _ => panic!("message type is not publish"),
412        }
413    }
414}
415
416/// ProtoMessage
417/// This represents a generic message that can be sent over the network
418impl ProtoMessage {
419    fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
420        ProtoMessage {
421            metadata,
422            message_type: Some(message_type),
423        }
424    }
425
426    pub fn new_subscribe(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
427        let subscribe = ProtoSubscribe::new(source, dst, flags);
428
429        Self::new(HashMap::new(), ProtoSubscribeType(subscribe))
430    }
431
432    pub fn new_unsubscribe(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
433        let unsubscribe = ProtoUnsubscribe::new(source, dst, flags);
434
435        Self::new(HashMap::new(), ProtoUnsubscribeType(unsubscribe))
436    }
437
438    pub fn new_publish(
439        source: &Name,
440        dst: &Name,
441        flags: Option<SlimHeaderFlags>,
442        content_type: &str,
443        blob: Vec<u8>,
444    ) -> Self {
445        let publish = ProtoPublish::new(source, dst, flags, content_type, blob);
446
447        Self::new(HashMap::new(), ProtoPublishType(publish))
448    }
449
450    pub fn new_publish_with_headers(
451        slim_header: Option<SlimHeader>,
452        session_header: Option<SessionHeader>,
453        content_type: &str,
454        blob: Vec<u8>,
455    ) -> Self {
456        let publish = ProtoPublish::with_header(
457            slim_header,
458            session_header,
459            Some(Content {
460                content_type: content_type.to_string(),
461                blob,
462            }),
463        );
464
465        Self::new(HashMap::new(), ProtoPublishType(publish))
466    }
467
468    // validate message
469    pub fn validate(&self) -> Result<(), MessageError> {
470        // make sure the message type is set
471        if self.message_type.is_none() {
472            return Err(MessageError::MessageTypeNotFound);
473        }
474
475        // make sure SLIM header is set
476        if self.try_get_slim_header().is_none() {
477            return Err(MessageError::SlimHeaderNotFound);
478        }
479
480        // Get SLIM header
481        let slim_header = self.get_slim_header();
482
483        // make sure source and destination are set
484        if slim_header.source.is_none() {
485            return Err(MessageError::SourceNotFound);
486        }
487        if slim_header.destination.is_none() {
488            return Err(MessageError::DestinationNotFound);
489        }
490
491        match &self.message_type {
492            Some(ProtoPublishType(p)) => {
493                // SLIM Header
494                if p.header.is_none() {
495                    return Err(MessageError::SlimHeaderNotFound);
496                }
497
498                // Publish message should have the session header
499                if p.session.is_none() {
500                    return Err(MessageError::SessionHeaderNotFound);
501                }
502            }
503            Some(ProtoSubscribeType(s)) => {
504                if s.header.is_none() {
505                    return Err(MessageError::SlimHeaderNotFound);
506                }
507            }
508            Some(ProtoUnsubscribeType(u)) => {
509                if u.header.is_none() {
510                    return Err(MessageError::SlimHeaderNotFound);
511                }
512            }
513            None => return Err(MessageError::MessageTypeNotFound),
514        }
515
516        Ok(())
517    }
518
519    // add metadata key in the map assigning the value val
520    // if the key exists the value is replaced by val
521    pub fn insert_metadata(&mut self, key: String, val: String) {
522        self.metadata.insert(key, val);
523    }
524
525    // remove metadata key from the map
526    pub fn remove_metadata(&mut self, key: &str) {
527        self.metadata.remove(key);
528    }
529
530    pub fn contains_metadata(&self, key: &str) -> bool {
531        self.metadata.contains_key(key)
532    }
533
534    pub fn get_metadata(&self, key: &str) -> Option<&String> {
535        self.metadata.get(key)
536    }
537
538    pub fn get_metadata_map(&self) -> HashMap<String, String> {
539        self.metadata.clone()
540    }
541
542    pub fn set_metadata_map(&mut self, map: HashMap<String, String>) {
543        for (k, v) in map.iter() {
544            self.insert_metadata(k.to_string(), v.to_string());
545        }
546    }
547
548    pub fn get_slim_header(&self) -> &SlimHeader {
549        match &self.message_type {
550            Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
551            Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
552            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
553            None => panic!("SLIM header not found"),
554        }
555    }
556
557    pub fn get_slim_header_mut(&mut self) -> &mut SlimHeader {
558        match &mut self.message_type {
559            Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
560            Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
561            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
562            None => panic!("SLIM header not found"),
563        }
564    }
565
566    pub fn try_get_slim_header(&self) -> Option<&SlimHeader> {
567        match &self.message_type {
568            Some(ProtoPublishType(publish)) => publish.header.as_ref(),
569            Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
570            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
571            None => None,
572        }
573    }
574
575    pub fn get_session_header(&self) -> &SessionHeader {
576        match &self.message_type {
577            Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
578            Some(ProtoSubscribeType(_)) => panic!("session header not found"),
579            Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
580            None => panic!("session header not found"),
581        }
582    }
583
584    pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
585        match &mut self.message_type {
586            Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
587            Some(ProtoSubscribeType(_)) => panic!("session header not found"),
588            Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
589            None => panic!("session header not found"),
590        }
591    }
592
593    pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
594        match &self.message_type {
595            Some(ProtoPublishType(publish)) => publish.session.as_ref(),
596            Some(ProtoSubscribeType(_)) => None,
597            Some(ProtoUnsubscribeType(_)) => None,
598            None => None,
599        }
600    }
601
602    pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
603        match &mut self.message_type {
604            Some(ProtoPublishType(publish)) => publish.session.as_mut(),
605            Some(ProtoSubscribeType(_)) => None,
606            Some(ProtoUnsubscribeType(_)) => None,
607            None => None,
608        }
609    }
610
611    pub fn get_id(&self) -> u32 {
612        self.get_session_header().get_message_id()
613    }
614
615    pub fn get_source(&self) -> Name {
616        self.get_slim_header().get_source()
617    }
618
619    pub fn get_fanout(&self) -> u32 {
620        self.get_slim_header().get_fanout()
621    }
622
623    pub fn get_recv_from(&self) -> Option<u64> {
624        self.get_slim_header().get_recv_from()
625    }
626
627    pub fn get_forward_to(&self) -> Option<u64> {
628        self.get_slim_header().get_forward_to()
629    }
630
631    pub fn get_error(&self) -> Option<bool> {
632        self.get_slim_header().get_error()
633    }
634
635    pub fn get_incoming_conn(&self) -> u64 {
636        self.get_slim_header().get_incoming_conn().unwrap()
637    }
638
639    pub fn try_get_incoming_conn(&self) -> Option<u64> {
640        self.get_slim_header().get_incoming_conn()
641    }
642
643    pub fn get_dst(&self) -> Name {
644        let dst = self.get_slim_header().get_dst();
645
646        // complete name with the original strings if the message is a subscribe
647        if let Some(ProtoSubscribeType(subscribe)) = &self.message_type {
648            return Name::from_strings([
649                subscribe.component_0.clone(),
650                subscribe.component_1.clone(),
651                subscribe.component_2.clone(),
652            ])
653            .with_id(dst.id());
654        }
655
656        dst
657    }
658
659    pub fn get_type(&self) -> &MessageType {
660        match &self.message_type {
661            Some(t) => t,
662            None => panic!("message type not found"),
663        }
664    }
665
666    pub fn get_payload(&self) -> Option<&Content> {
667        match &self.message_type {
668            Some(ProtoPublishType(p)) => p.msg.as_ref(),
669            Some(ProtoSubscribeType(_)) => panic!("payload not found"),
670            Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
671            None => panic!("payload not found"),
672        }
673    }
674
675    pub fn get_session_message_type(&self) -> SessionMessageType {
676        self.get_session_header()
677            .session_message_type
678            .try_into()
679            .unwrap_or_default()
680    }
681
682    pub fn clear_slim_header(&mut self) {
683        self.get_slim_header_mut().clear();
684    }
685
686    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
687        self.get_slim_header_mut().set_recv_from(recv_from);
688    }
689
690    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
691        self.get_slim_header_mut().set_forward_to(forward_to);
692    }
693
694    pub fn set_error(&mut self, error: Option<bool>) {
695        self.get_slim_header_mut().set_error(error);
696    }
697
698    pub fn set_fanout(&mut self, fanout: u32) {
699        self.get_slim_header_mut().set_fanout(fanout);
700    }
701
702    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
703        self.get_slim_header_mut().set_incoming_conn(incoming_conn);
704    }
705
706    pub fn set_error_flag(&mut self, error: Option<bool>) {
707        self.get_slim_header_mut().set_error_flag(error);
708    }
709
710    pub fn set_session_message_type(&mut self, message_type: SessionMessageType) {
711        self.get_session_header_mut()
712            .set_session_message_type(message_type);
713    }
714
715    pub fn set_session_type(&mut self, session_type: ProtoSessionType) {
716        self.get_session_header_mut().set_session_type(session_type);
717    }
718
719    pub fn get_session_type(&self) -> ProtoSessionType {
720        self.get_session_header().session_type()
721    }
722
723    pub fn set_message_id(&mut self, message_id: u32) {
724        self.get_session_header_mut().set_message_id(message_id);
725    }
726
727    pub fn is_publish(&self) -> bool {
728        matches!(self.get_type(), MessageType::Publish(_))
729    }
730
731    pub fn is_subscribe(&self) -> bool {
732        matches!(self.get_type(), MessageType::Subscribe(_))
733    }
734
735    pub fn is_unsubscribe(&self) -> bool {
736        matches!(self.get_type(), MessageType::Unsubscribe(_))
737    }
738}
739
740impl AsRef<ProtoPublish> for ProtoMessage {
741    fn as_ref(&self) -> &ProtoPublish {
742        match &self.message_type {
743            Some(ProtoPublishType(p)) => p,
744            _ => panic!("message type is not publish"),
745        }
746    }
747}
748
749#[cfg(test)]
750mod tests {
751    use crate::{api::proto::pubsub::v1::SessionMessageType, messages::encoder::Name};
752
753    use super::*;
754
755    fn test_subscription_template(
756        subscription: bool,
757        source: Name,
758        dst: Name,
759        flags: Option<SlimHeaderFlags>,
760    ) {
761        let sub = {
762            if subscription {
763                ProtoMessage::new_subscribe(&source, &dst, flags.clone())
764            } else {
765                ProtoMessage::new_unsubscribe(&source, &dst, flags.clone())
766            }
767        };
768
769        let flags = if flags.is_none() {
770            Some(SlimHeaderFlags::default())
771        } else {
772            flags
773        };
774
775        assert!(!sub.is_publish());
776        assert_eq!(sub.is_subscribe(), subscription);
777        assert_eq!(sub.is_unsubscribe(), !subscription);
778        assert_eq!(flags.as_ref().unwrap().recv_from, sub.get_recv_from());
779        assert_eq!(flags.as_ref().unwrap().forward_to, sub.get_forward_to());
780        assert_eq!(None, sub.try_get_incoming_conn());
781        assert_eq!(source, sub.get_source());
782        let got_name = sub.get_dst();
783        assert_eq!(dst, got_name);
784    }
785
786    fn test_publish_template(source: Name, dst: Name, flags: Option<SlimHeaderFlags>) {
787        let pub_msg = ProtoMessage::new_publish(
788            &source,
789            &dst,
790            flags.clone(),
791            "str",
792            "this is the content of the message".into(),
793        );
794
795        let flags = if flags.is_none() {
796            Some(SlimHeaderFlags::default())
797        } else {
798            flags
799        };
800
801        assert!(pub_msg.is_publish());
802        assert!(!pub_msg.is_subscribe());
803        assert!(!pub_msg.is_unsubscribe());
804        assert_eq!(flags.as_ref().unwrap().recv_from, pub_msg.get_recv_from());
805        assert_eq!(flags.as_ref().unwrap().forward_to, pub_msg.get_forward_to());
806        assert_eq!(None, pub_msg.try_get_incoming_conn());
807        assert_eq!(source, pub_msg.get_source());
808        let got_name = pub_msg.get_dst();
809        assert_eq!(dst, got_name);
810        assert_eq!(flags.as_ref().unwrap().fanout, pub_msg.get_fanout());
811    }
812
813    #[test]
814    fn test_subscription() {
815        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
816        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
817
818        // simple
819        test_subscription_template(true, source.clone(), dst.clone(), None);
820
821        // with name id
822        test_subscription_template(true, source.clone(), dst.clone(), None);
823
824        // with recv from
825        test_subscription_template(
826            true,
827            source.clone(),
828            dst.clone(),
829            Some(SlimHeaderFlags::default().with_recv_from(50)),
830        );
831
832        // with forward to
833        test_subscription_template(
834            true,
835            source.clone(),
836            dst.clone(),
837            Some(SlimHeaderFlags::default().with_forward_to(30)),
838        );
839    }
840
841    #[test]
842    fn test_unsubscription() {
843        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
844        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
845
846        // simple
847        test_subscription_template(false, source.clone(), dst.clone(), None);
848
849        // with name id
850        test_subscription_template(false, source.clone(), dst.clone(), None);
851
852        // with recv from
853        test_subscription_template(
854            false,
855            source.clone(),
856            dst.clone(),
857            Some(SlimHeaderFlags::default().with_recv_from(50)),
858        );
859
860        // with forward to
861        test_subscription_template(
862            false,
863            source.clone(),
864            dst.clone(),
865            Some(SlimHeaderFlags::default().with_forward_to(30)),
866        );
867    }
868
869    #[test]
870    fn test_publish() {
871        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
872        let mut dst = Name::from_strings(["org", "ns", "type"]);
873
874        // simple
875        test_publish_template(
876            source.clone(),
877            dst.clone(),
878            Some(SlimHeaderFlags::default()),
879        );
880
881        // with name id
882        dst.set_id(2);
883        test_publish_template(
884            source.clone(),
885            dst.clone(),
886            Some(SlimHeaderFlags::default()),
887        );
888        dst.reset_id();
889
890        // with recv from
891        test_publish_template(
892            source.clone(),
893            dst.clone(),
894            Some(SlimHeaderFlags::default().with_recv_from(50)),
895        );
896
897        // with forward to
898        test_publish_template(
899            source.clone(),
900            dst.clone(),
901            Some(SlimHeaderFlags::default().with_forward_to(30)),
902        );
903
904        // with fanout
905        test_publish_template(
906            source.clone(),
907            dst.clone(),
908            Some(SlimHeaderFlags::default().with_fanout(2)),
909        );
910    }
911
912    #[test]
913    fn test_conversions() {
914        // Name to ProtoName
915        let name = Name::from_strings(["org", "ns", "type"]).with_id(1);
916        let proto_name = ProtoName::from(&name);
917
918        assert_eq!(proto_name.component_0, name.components()[0]);
919        assert_eq!(proto_name.component_1, name.components()[1]);
920        assert_eq!(proto_name.component_2, name.components()[2]);
921        assert_eq!(proto_name.component_3, name.components()[3]);
922
923        // ProtoName to Name
924        let name_from_proto = Name::from(&proto_name);
925        assert_eq!(name_from_proto.components()[0], proto_name.component_0);
926        assert_eq!(name_from_proto.components()[1], proto_name.component_1);
927        assert_eq!(name_from_proto.components()[2], proto_name.component_2);
928        assert_eq!(name_from_proto.components()[3], proto_name.component_3);
929
930        // ProtoMessage to ProtoSubscribe
931        let dst = Name::from_strings(["org", "ns", "type"]).with_id(1);
932        let proto_subscribe = ProtoMessage::new_subscribe(
933            &name,
934            &dst,
935            Some(
936                SlimHeaderFlags::default()
937                    .with_recv_from(2)
938                    .with_forward_to(3),
939            ),
940        );
941        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
942        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), name);
943        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_dst(), dst,);
944
945        // ProtoMessage to ProtoUnsubscribe
946        let proto_unsubscribe = ProtoMessage::new_unsubscribe(
947            &name,
948            &dst,
949            Some(
950                SlimHeaderFlags::default()
951                    .with_recv_from(2)
952                    .with_forward_to(3),
953            ),
954        );
955        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
956        assert_eq!(
957            proto_unsubscribe.header.as_ref().unwrap().get_source(),
958            name
959        );
960        assert_eq!(proto_unsubscribe.header.as_ref().unwrap().get_dst(), dst);
961
962        // ProtoMessage to ProtoPublish
963        let proto_publish = ProtoMessage::new_publish(
964            &name,
965            &dst,
966            Some(
967                SlimHeaderFlags::default()
968                    .with_recv_from(2)
969                    .with_forward_to(3),
970            ),
971            "str",
972            "this is the content of the message".into(),
973        );
974        let proto_publish = ProtoPublish::from(proto_publish);
975        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), name);
976        assert_eq!(proto_publish.header.as_ref().unwrap().get_dst(), dst);
977    }
978
979    #[test]
980    fn test_panic() {
981        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
982        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
983
984        // panic if SLIM header is not found
985        let msg = ProtoMessage::new_subscribe(
986            &source,
987            &dst,
988            Some(
989                SlimHeaderFlags::default()
990                    .with_recv_from(2)
991                    .with_forward_to(3),
992            ),
993        );
994
995        // let's try to convert it to a unsubscribe
996        // this should panic because the message type is not unsubscribe
997        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
998        assert!(result.is_err());
999
1000        // try to convert to publish
1001        // this should panic because the message type is not publish
1002        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
1003        assert!(result.is_err());
1004
1005        // finally make sure the conversion to subscribe works
1006        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
1007        assert!(result.is_ok());
1008    }
1009
1010    #[test]
1011    fn test_panic_header() {
1012        // create a unusual SLIM header
1013        let header = SlimHeader {
1014            source: None,
1015            destination: None,
1016            fanout: 0,
1017            recv_from: None,
1018            forward_to: None,
1019            incoming_conn: None,
1020            error: None,
1021        };
1022
1023        // the operations to retrieve source and destination should fail with panic
1024        let result = std::panic::catch_unwind(|| header.get_source());
1025        assert!(result.is_err());
1026
1027        let result = std::panic::catch_unwind(|| header.get_dst());
1028        assert!(result.is_err());
1029
1030        // The operations to retrieve recv_from and forward_to should not fail with panic
1031        let result = std::panic::catch_unwind(|| header.get_recv_from());
1032        assert!(result.is_ok());
1033
1034        let result = std::panic::catch_unwind(|| header.get_forward_to());
1035        assert!(result.is_ok());
1036
1037        // The operations to retrieve incoming_conn should not fail with panic
1038        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
1039        assert!(result.is_ok());
1040
1041        // The operations to retrieve error should not fail with panic
1042        let result = std::panic::catch_unwind(|| header.get_error());
1043        assert!(result.is_ok());
1044    }
1045
1046    #[test]
1047    fn test_panic_session_header() {
1048        // create a unusual session header
1049        let header = SessionHeader::new(0, 0, 0, 0);
1050
1051        // the operations to retrieve session_id and message_id should not fail with panic
1052        let result = std::panic::catch_unwind(|| header.get_session_id());
1053        assert!(result.is_ok());
1054
1055        let result = std::panic::catch_unwind(|| header.get_message_id());
1056        assert!(result.is_ok());
1057    }
1058
1059    #[test]
1060    fn test_panic_proto_message() {
1061        // create a unusual proto message
1062        let message = ProtoMessage {
1063            metadata: HashMap::new(),
1064            message_type: None,
1065        };
1066
1067        // the operation to retrieve the header should fail with panic
1068        let result = std::panic::catch_unwind(|| message.get_slim_header());
1069        assert!(result.is_err());
1070
1071        // the operation to retrieve the message type should fail with panic
1072        let result = std::panic::catch_unwind(|| message.get_type());
1073        assert!(result.is_err());
1074
1075        // all the other ops should fail with panic as well as the header is not set
1076        let result = std::panic::catch_unwind(|| message.get_source());
1077        assert!(result.is_err());
1078        let result = std::panic::catch_unwind(|| message.get_dst());
1079        assert!(result.is_err());
1080        let result = std::panic::catch_unwind(|| message.get_recv_from());
1081        assert!(result.is_err());
1082        let result = std::panic::catch_unwind(|| message.get_forward_to());
1083        assert!(result.is_err());
1084        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
1085        assert!(result.is_err());
1086        let result = std::panic::catch_unwind(|| message.get_fanout());
1087        assert!(result.is_err());
1088    }
1089
1090    #[test]
1091    fn test_service_type_to_int() {
1092        // Get total number of service types
1093        let total_service_types = SessionMessageType::ChannelMlsAck as i32;
1094
1095        for i in 0..total_service_types {
1096            // int -> ServiceType
1097            let service_type =
1098                SessionMessageType::try_from(i).expect("failed to convert int to service type");
1099            let service_type_int = i32::from(service_type);
1100            assert_eq!(service_type_int, i32::from(service_type),);
1101        }
1102
1103        // Test invalid conversion
1104        let invalid_service_type = SessionMessageType::try_from(total_service_types + 1);
1105        assert!(invalid_service_type.is_err());
1106    }
1107}