slim_datapath/messages/utils.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::fmt::Display;
6
7use tracing::debug;
8
9use super::encoder::Name;
10use crate::api::{
11    Content, MessageType, ProtoMessage, ProtoName, ProtoPublish, ProtoPublishType,
12    ProtoSessionType, ProtoSubscribe, ProtoSubscribeType, ProtoUnsubscribe, ProtoUnsubscribeType,
13    SessionHeader, SlimHeader,
14    proto::dataplane::v1::{OriginalName, SessionMessageType},
15};
16
17use thiserror::Error;
18use tracing::error;
19
20#[derive(Error, Debug, PartialEq)]
21pub enum MessageError {
22    #[error("SLIM header not found")]
23    SlimHeaderNotFound,
24    #[error("source not found")]
25    SourceNotFound,
26    #[error("destination not found")]
27    DestinationNotFound,
28    #[error("session header not found")]
29    SessionHeaderNotFound,
30    #[error("message type not found")]
31    MessageTypeNotFound,
32    #[error("incoming connection not found")]
33    IncomingConnectionNotFound,
34}
35
36// Metadata Keys
37pub const SLIM_IDENTITY: &str = "SLIM_IDENTITY";
38
39/// ProtoName from Name
40impl From<&Name> for ProtoName {
41    fn from(name: &Name) -> Self {
42        Self {
43            component_0: name.components()[0],
44            component_1: name.components()[1],
45            component_2: name.components()[2],
46            component_3: name.components()[3],
47        }
48    }
49}
50
51/// Print message type
52impl Display for MessageType {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match self {
55            MessageType::Publish(_) => write!(f, "publish"),
56            MessageType::Subscribe(_) => write!(f, "subscribe"),
57            MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
58        }
59    }
60}
61
/// Groups the SLIM header flags for convenience.
///
/// Passed to `SlimHeader::new`, which copies every field into the header.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SlimHeaderFlags {
    /// Fanout value stored in the SLIM header (`Default` uses 1).
    pub fanout: u32,
    /// Connection override: when set, `SlimHeader::get_in_out_connections`
    /// returns this connection and no forward target.
    pub recv_from: Option<u64>,
    /// Connection to forward the message to; reported by
    /// `SlimHeader::get_in_out_connections` when `recv_from` is unset.
    pub forward_to: Option<u64>,
    /// Connection the message arrived on; `get_in_out_connections` expects it
    /// to be set.
    pub incoming_conn: Option<u64>,
    /// Error flag stored in the SLIM header.
    /// NOTE(review): semantics are defined by its readers — confirm.
    pub error: Option<bool>,
}
71
72impl Default for SlimHeaderFlags {
73    fn default() -> Self {
74        Self {
75            fanout: 1,
76            recv_from: None,
77            forward_to: None,
78            incoming_conn: None,
79            error: None,
80        }
81    }
82}
83
84impl Display for SlimHeaderFlags {
85    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
86        write!(
87            f,
88            "fanout: {}, recv_from: {:?}, forward_to: {:?}, incoming_conn: {:?}, error: {:?}",
89            self.fanout, self.recv_from, self.forward_to, self.incoming_conn, self.error
90        )
91    }
92}
93
94impl SlimHeaderFlags {
95    pub fn new(
96        fanout: u32,
97        recv_from: Option<u64>,
98        forward_to: Option<u64>,
99        incoming_conn: Option<u64>,
100        error: Option<bool>,
101    ) -> Self {
102        Self {
103            fanout,
104            recv_from,
105            forward_to,
106            incoming_conn,
107            error,
108        }
109    }
110
111    pub fn with_fanout(self, fanout: u32) -> Self {
112        Self { fanout, ..self }
113    }
114
115    pub fn with_recv_from(self, recv_from: u64) -> Self {
116        Self {
117            recv_from: Some(recv_from),
118            ..self
119        }
120    }
121
122    pub fn with_forward_to(self, forward_to: u64) -> Self {
123        Self {
124            forward_to: Some(forward_to),
125            ..self
126        }
127    }
128
129    pub fn with_incoming_conn(self, incoming_conn: u64) -> Self {
130        Self {
131            incoming_conn: Some(incoming_conn),
132            ..self
133        }
134    }
135
136    pub fn with_error(self, error: bool) -> Self {
137        Self {
138            error: Some(error),
139            ..self
140        }
141    }
142}
143
144/// SLIM Header
145/// This header is used to identify the source and destination of the message
146/// and to manage the connections used to send and receive the message
147impl SlimHeader {
148    pub fn new(source: &Name, destination: &Name, flags: Option<SlimHeaderFlags>) -> Self {
149        let flags = flags.unwrap_or_default();
150
151        Self {
152            source: Some(ProtoName::from(source)),
153            destination: Some(ProtoName::from(destination)),
154            fanout: flags.fanout,
155            recv_from: flags.recv_from,
156            forward_to: flags.forward_to,
157            incoming_conn: flags.incoming_conn,
158            error: flags.error,
159        }
160    }
161
162    pub fn clear(&mut self) {
163        self.recv_from = None;
164        self.forward_to = None;
165    }
166
167    pub fn get_recv_from(&self) -> Option<u64> {
168        self.recv_from
169    }
170
171    pub fn get_forward_to(&self) -> Option<u64> {
172        self.forward_to
173    }
174
175    pub fn get_incoming_conn(&self) -> Option<u64> {
176        self.incoming_conn
177    }
178
179    pub fn get_error(&self) -> Option<bool> {
180        self.error
181    }
182
183    pub fn get_source(&self) -> Name {
184        match &self.source {
185            Some(source) => Name::from(source),
186            None => panic!("source not found"),
187        }
188    }
189
190    pub fn get_dst(&self) -> Name {
191        match &self.destination {
192            Some(destination) => Name::from(destination),
193            None => panic!("destination not found"),
194        }
195    }
196
197    pub fn set_source(&mut self, source: &Name) {
198        self.source = Some(ProtoName::from(source));
199    }
200
201    pub fn set_destination(&mut self, dst: &Name) {
202        self.destination = Some(ProtoName::from(dst));
203    }
204
205    pub fn get_fanout(&self) -> u32 {
206        self.fanout
207    }
208
209    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
210        self.recv_from = recv_from;
211    }
212
213    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
214        self.forward_to = forward_to;
215    }
216
217    pub fn set_error(&mut self, error: Option<bool>) {
218        self.error = error;
219    }
220
221    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
222        self.incoming_conn = incoming_conn;
223    }
224
225    pub fn set_error_flag(&mut self, error: Option<bool>) {
226        self.error = error;
227    }
228
229    pub fn set_fanout(&mut self, fanout: u32) {
230        self.fanout = fanout;
231    }
232
233    // returns the connection to use to process correctly the message
234    // first connection is from where we received the packet
235    // the second is where to forward the packet if needed
236    pub fn get_in_out_connections(&self) -> (u64, Option<u64>) {
237        // when calling this function, incoming connection is set
238        let incoming = self
239            .get_incoming_conn()
240            .expect("incoming connection not found");
241
242        if let Some(val) = self.get_recv_from() {
243            debug!(
244                "received recv_from command, update state on connection {}",
245                val
246            );
247            return (val, None);
248        }
249
250        if let Some(val) = self.get_forward_to() {
251            debug!(
252                "received forward_to command, update state and forward to connection {}",
253                val
254            );
255            return (incoming, Some(val));
256        }
257
258        // by default, return the incoming connection and None
259        (incoming, None)
260    }
261}
262
263/// Session Header
264/// This header is used to identify the session and the message
265/// and to manage session state
266impl SessionHeader {
267    pub fn new(
268        session_type: i32,
269        session_message_type: i32,
270        session_id: u32,
271        message_id: u32,
272        source: &Option<Name>,
273        destination: &Option<Name>,
274    ) -> Self {
275        let src = match source {
276            Some(name) => name.components_strings().map(|c| OriginalName {
277                component_0: c[0].clone(),
278                component_1: c[1].clone(),
279                component_2: c[2].clone(),
280                component_3: name.id(),
281            }),
282            None => None,
283        };
284
285        let dst = match destination {
286            Some(name) => name.components_strings().map(|c| OriginalName {
287                component_0: c[0].clone(),
288                component_1: c[1].clone(),
289                component_2: c[2].clone(),
290                component_3: name.id(),
291            }),
292            None => None,
293        };
294
295        Self {
296            session_type,
297            session_message_type,
298            session_id,
299            message_id,
300            source: src,
301            destination: dst,
302        }
303    }
304
305    pub fn get_session_id(&self) -> u32 {
306        self.session_id
307    }
308
309    pub fn get_message_id(&self) -> u32 {
310        self.message_id
311    }
312
313    pub fn set_session_id(&mut self, session_id: u32) {
314        self.session_id = session_id;
315    }
316
317    pub fn set_message_id(&mut self, message_id: u32) {
318        self.message_id = message_id;
319    }
320
321    pub fn clear(&mut self) {
322        self.session_id = 0;
323        self.message_id = 0;
324    }
325
326    pub fn get_source(&self) -> Option<Name> {
327        match &self.source {
328            Some(src) => {
329                let n = Name::from_strings([
330                    src.component_0.clone(),
331                    src.component_1.clone(),
332                    src.component_2.clone(),
333                ])
334                .with_id(src.component_3);
335                Some(n)
336            }
337            None => None,
338        }
339    }
340
341    pub fn set_source(&mut self, source: &Name) {
342        let c_opt = source.components_strings();
343        if c_opt.is_none() {
344            return;
345        }
346
347        let c = c_opt.unwrap();
348
349        self.source = Some(OriginalName {
350            component_0: c[0].clone(),
351            component_1: c[1].clone(),
352            component_2: c[2].clone(),
353            component_3: source.id(),
354        });
355    }
356
357    pub fn get_destination(&self) -> Option<Name> {
358        match &self.destination {
359            Some(src) => {
360                let n = Name::from_strings([
361                    src.component_0.clone(),
362                    src.component_1.clone(),
363                    src.component_2.clone(),
364                ])
365                .with_id(src.component_3);
366                Some(n)
367            }
368            None => None,
369        }
370    }
371
372    pub fn set_destination(&mut self, destination: &Name) {
373        let c_opt = destination.components_strings();
374        if c_opt.is_none() {
375            return;
376        }
377
378        let c = c_opt.unwrap();
379
380        self.destination = Some(OriginalName {
381            component_0: c[0].clone(),
382            component_1: c[1].clone(),
383            component_2: c[2].clone(),
384            component_3: destination.id(),
385        });
386    }
387}
388
389/// ProtoSubscribe
390/// This message is used to subscribe to a topic
391impl ProtoSubscribe {
392    pub fn new(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
393        let header = Some(SlimHeader::new(source, dst, flags));
394
395        ProtoSubscribe {
396            header,
397            component_0: dst.components_strings().unwrap()[0].clone(),
398            component_1: dst.components_strings().unwrap()[1].clone(),
399            component_2: dst.components_strings().unwrap()[2].clone(),
400        }
401    }
402}
403
404/// From ProtoMessage to ProtoSubscribe
405impl From<ProtoMessage> for ProtoSubscribe {
406    fn from(message: ProtoMessage) -> Self {
407        match message.message_type {
408            Some(ProtoSubscribeType(s)) => s,
409            _ => panic!("message type is not subscribe"),
410        }
411    }
412}
413
414/// ProtoUnsubscribe
415/// This message is used to unsubscribe from a topic
416impl ProtoUnsubscribe {
417    pub fn new(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
418        let header = Some(SlimHeader::new(source, dst, flags));
419
420        ProtoUnsubscribe {
421            header,
422            component_0: dst.components_strings().unwrap()[0].clone(),
423            component_1: dst.components_strings().unwrap()[1].clone(),
424            component_2: dst.components_strings().unwrap()[2].clone(),
425        }
426    }
427}
428
429/// From ProtoMessage to ProtoUnsubscribe
430impl From<ProtoMessage> for ProtoUnsubscribe {
431    fn from(message: ProtoMessage) -> Self {
432        match message.message_type {
433            Some(ProtoUnsubscribeType(u)) => u,
434            _ => panic!("message type is not unsubscribe"),
435        }
436    }
437}
438
439/// ProtoPublish
440/// This message is used to publish a message, either to a shared channel or to a specific application
441impl ProtoPublish {
442    pub fn with_header(
443        header: Option<SlimHeader>,
444        session: Option<SessionHeader>,
445        payload: Option<Content>,
446    ) -> Self {
447        ProtoPublish {
448            header,
449            session,
450            msg: payload,
451        }
452    }
453
454    pub fn new(
455        source: &Name,
456        dst: &Name,
457        flags: Option<SlimHeaderFlags>,
458        content_type: &str,
459        blob: Vec<u8>,
460    ) -> Self {
461        let slim_header = Some(SlimHeader::new(source, dst, flags));
462
463        let mut session_header = SessionHeader::default();
464        session_header.set_source(source);
465        session_header.set_destination(dst);
466
467        let msg = Some(Content {
468            content_type: content_type.to_string(),
469            blob,
470        });
471
472        Self::with_header(slim_header, Some(session_header), msg)
473    }
474
475    pub fn get_slim_header(&self) -> &SlimHeader {
476        self.header.as_ref().unwrap()
477    }
478
479    pub fn get_session_header(&self) -> &SessionHeader {
480        self.session.as_ref().unwrap()
481    }
482
483    pub fn get_slim_header_as_mut(&mut self) -> &mut SlimHeader {
484        self.header.as_mut().unwrap()
485    }
486
487    pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
488        self.session.as_mut().unwrap()
489    }
490
491    pub fn get_payload(&self) -> &Content {
492        self.msg.as_ref().unwrap()
493    }
494}
495
496/// From ProtoMessage to ProtoPublish
497impl From<ProtoMessage> for ProtoPublish {
498    fn from(message: ProtoMessage) -> Self {
499        match message.message_type {
500            Some(ProtoPublishType(p)) => p,
501            _ => panic!("message type is not publish"),
502        }
503    }
504}
505
506/// ProtoMessage
507/// This represents a generic message that can be sent over the network
508impl ProtoMessage {
509    fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
510        ProtoMessage {
511            metadata,
512            message_type: Some(message_type),
513        }
514    }
515
516    pub fn new_subscribe(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
517        let subscribe = ProtoSubscribe::new(source, dst, flags);
518
519        Self::new(HashMap::new(), ProtoSubscribeType(subscribe))
520    }
521
522    pub fn new_unsubscribe(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
523        let unsubscribe = ProtoUnsubscribe::new(source, dst, flags);
524
525        Self::new(HashMap::new(), ProtoUnsubscribeType(unsubscribe))
526    }
527
528    pub fn new_publish(
529        source: &Name,
530        dst: &Name,
531        flags: Option<SlimHeaderFlags>,
532        content_type: &str,
533        blob: Vec<u8>,
534    ) -> Self {
535        let publish = ProtoPublish::new(source, dst, flags, content_type, blob);
536
537        Self::new(HashMap::new(), ProtoPublishType(publish))
538    }
539
540    pub fn new_publish_with_headers(
541        slim_header: Option<SlimHeader>,
542        session_header: Option<SessionHeader>,
543        content_type: &str,
544        blob: Vec<u8>,
545    ) -> Self {
546        let publish = ProtoPublish::with_header(
547            slim_header,
548            session_header,
549            Some(Content {
550                content_type: content_type.to_string(),
551                blob,
552            }),
553        );
554
555        Self::new(HashMap::new(), ProtoPublishType(publish))
556    }
557
558    // validate message
559    pub fn validate(&self) -> Result<(), MessageError> {
560        // make sure the message type is set
561        if self.message_type.is_none() {
562            return Err(MessageError::MessageTypeNotFound);
563        }
564
565        // make sure SLIM header is set
566        if self.try_get_slim_header().is_none() {
567            return Err(MessageError::SlimHeaderNotFound);
568        }
569
570        // Get SLIM header
571        let slim_header = self.get_slim_header();
572
573        // make sure source and destination are set
574        if slim_header.source.is_none() {
575            return Err(MessageError::SourceNotFound);
576        }
577        if slim_header.destination.is_none() {
578            return Err(MessageError::DestinationNotFound);
579        }
580
581        match &self.message_type {
582            Some(ProtoPublishType(p)) => {
583                // SLIM Header
584                if p.header.is_none() {
585                    return Err(MessageError::SlimHeaderNotFound);
586                }
587
588                // Publish message should have the session header
589                if p.session.is_none() {
590                    return Err(MessageError::SessionHeaderNotFound);
591                }
592            }
593            Some(ProtoSubscribeType(s)) => {
594                if s.header.is_none() {
595                    return Err(MessageError::SlimHeaderNotFound);
596                }
597            }
598            Some(ProtoUnsubscribeType(u)) => {
599                if u.header.is_none() {
600                    return Err(MessageError::SlimHeaderNotFound);
601                }
602            }
603            None => return Err(MessageError::MessageTypeNotFound),
604        }
605
606        Ok(())
607    }
608
609    // add metadata key in the map assigning the value val
610    // if the key exists the value is replaced by val
611    pub fn insert_metadata(&mut self, key: String, val: String) {
612        self.metadata.insert(key, val);
613    }
614
615    // remove metadata key from the map
616    pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
617        self.metadata.remove(key)
618    }
619
620    pub fn contains_metadata(&self, key: &str) -> bool {
621        self.metadata.contains_key(key)
622    }
623
624    pub fn get_metadata(&self, key: &str) -> Option<&String> {
625        self.metadata.get(key)
626    }
627
628    pub fn get_metadata_map(&self) -> HashMap<String, String> {
629        self.metadata.clone()
630    }
631
632    pub fn set_metadata_map(&mut self, map: HashMap<String, String>) {
633        for (k, v) in map.iter() {
634            self.insert_metadata(k.to_string(), v.to_string());
635        }
636    }
637
638    pub fn get_slim_header(&self) -> &SlimHeader {
639        match &self.message_type {
640            Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
641            Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
642            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
643            None => panic!("SLIM header not found"),
644        }
645    }
646
647    pub fn get_slim_header_mut(&mut self) -> &mut SlimHeader {
648        match &mut self.message_type {
649            Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
650            Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
651            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
652            None => panic!("SLIM header not found"),
653        }
654    }
655
656    pub fn try_get_slim_header(&self) -> Option<&SlimHeader> {
657        match &self.message_type {
658            Some(ProtoPublishType(publish)) => publish.header.as_ref(),
659            Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
660            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
661            None => None,
662        }
663    }
664
665    pub fn get_session_header(&self) -> &SessionHeader {
666        match &self.message_type {
667            Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
668            Some(ProtoSubscribeType(_)) => panic!("session header not found"),
669            Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
670            None => panic!("session header not found"),
671        }
672    }
673
674    pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
675        match &mut self.message_type {
676            Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
677            Some(ProtoSubscribeType(_)) => panic!("session header not found"),
678            Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
679            None => panic!("session header not found"),
680        }
681    }
682
683    pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
684        match &self.message_type {
685            Some(ProtoPublishType(publish)) => publish.session.as_ref(),
686            Some(ProtoSubscribeType(_)) => None,
687            Some(ProtoUnsubscribeType(_)) => None,
688            None => None,
689        }
690    }
691
692    pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
693        match &mut self.message_type {
694            Some(ProtoPublishType(publish)) => publish.session.as_mut(),
695            Some(ProtoSubscribeType(_)) => None,
696            Some(ProtoUnsubscribeType(_)) => None,
697            None => None,
698        }
699    }
700
701    pub fn get_id(&self) -> u32 {
702        self.get_session_header().get_message_id()
703    }
704
705    pub fn get_source(&self) -> Name {
706        if let Some(ProtoPublishType(pubslih)) = &self.message_type {
707            // try to the src dst from the session header
708            if let Some(s) = pubslih.get_session_header().get_source() {
709                return s;
710            }
711        }
712
713        // get from slim header
714        self.get_slim_header().get_source()
715    }
716
717    pub fn get_fanout(&self) -> u32 {
718        self.get_slim_header().get_fanout()
719    }
720
721    pub fn get_recv_from(&self) -> Option<u64> {
722        self.get_slim_header().get_recv_from()
723    }
724
725    pub fn get_forward_to(&self) -> Option<u64> {
726        self.get_slim_header().get_forward_to()
727    }
728
729    pub fn get_error(&self) -> Option<bool> {
730        self.get_slim_header().get_error()
731    }
732
733    pub fn get_incoming_conn(&self) -> u64 {
734        self.get_slim_header().get_incoming_conn().unwrap()
735    }
736
737    pub fn try_get_incoming_conn(&self) -> Option<u64> {
738        self.get_slim_header().get_incoming_conn()
739    }
740
741    pub fn get_dst(&self) -> Name {
742        match &self.message_type {
743            Some(ProtoPublishType(pubslih)) => {
744                // try to the get dst from the session header
745                if let Some(d) = pubslih.get_session_header().get_destination() {
746                    return d;
747                }
748                // get from the slim header
749                self.get_slim_header().get_dst()
750            }
751            Some(ProtoSubscribeType(subscribe)) => {
752                let dst = self.get_slim_header().get_dst();
753                // complete name with the original strings
754                Name::from_strings([
755                    subscribe.component_0.clone(),
756                    subscribe.component_1.clone(),
757                    subscribe.component_2.clone(),
758                ])
759                .with_id(dst.id())
760            }
761            Some(ProtoUnsubscribeType(unsubscribe)) => {
762                let dst = self.get_slim_header().get_dst();
763                // complete name with the original strings
764                Name::from_strings([
765                    unsubscribe.component_0.clone(),
766                    unsubscribe.component_1.clone(),
767                    unsubscribe.component_2.clone(),
768                ])
769                .with_id(dst.id())
770            }
771            None => {
772                // this should never happen
773                self.get_slim_header().get_dst()
774            }
775        }
776    }
777
778    pub fn get_type(&self) -> &MessageType {
779        match &self.message_type {
780            Some(t) => t,
781            None => panic!("message type not found"),
782        }
783    }
784
785    pub fn get_payload(&self) -> Option<&Content> {
786        match &self.message_type {
787            Some(ProtoPublishType(p)) => p.msg.as_ref(),
788            Some(ProtoSubscribeType(_)) => panic!("payload not found"),
789            Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
790            None => panic!("payload not found"),
791        }
792    }
793
794    pub fn get_session_message_type(&self) -> SessionMessageType {
795        self.get_session_header()
796            .session_message_type
797            .try_into()
798            .unwrap_or_default()
799    }
800
801    pub fn clear_slim_header(&mut self) {
802        self.get_slim_header_mut().clear();
803    }
804
805    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
806        self.get_slim_header_mut().set_recv_from(recv_from);
807    }
808
809    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
810        self.get_slim_header_mut().set_forward_to(forward_to);
811    }
812
813    pub fn set_error(&mut self, error: Option<bool>) {
814        self.get_slim_header_mut().set_error(error);
815    }
816
817    pub fn set_fanout(&mut self, fanout: u32) {
818        self.get_slim_header_mut().set_fanout(fanout);
819    }
820
821    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
822        self.get_slim_header_mut().set_incoming_conn(incoming_conn);
823    }
824
825    pub fn set_error_flag(&mut self, error: Option<bool>) {
826        self.get_slim_header_mut().set_error_flag(error);
827    }
828
829    pub fn set_session_message_type(&mut self, message_type: SessionMessageType) {
830        self.get_session_header_mut()
831            .set_session_message_type(message_type);
832    }
833
834    pub fn set_session_type(&mut self, session_type: ProtoSessionType) {
835        self.get_session_header_mut().set_session_type(session_type);
836    }
837
838    pub fn get_session_type(&self) -> ProtoSessionType {
839        self.get_session_header().session_type()
840    }
841
842    pub fn set_message_id(&mut self, message_id: u32) {
843        self.get_session_header_mut().set_message_id(message_id);
844    }
845
846    pub fn is_publish(&self) -> bool {
847        matches!(self.get_type(), MessageType::Publish(_))
848    }
849
850    pub fn is_subscribe(&self) -> bool {
851        matches!(self.get_type(), MessageType::Subscribe(_))
852    }
853
854    pub fn is_unsubscribe(&self) -> bool {
855        matches!(self.get_type(), MessageType::Unsubscribe(_))
856    }
857}
858
859impl AsRef<ProtoPublish> for ProtoMessage {
860    fn as_ref(&self) -> &ProtoPublish {
861        match &self.message_type {
862            Some(ProtoPublishType(p)) => p,
863            _ => panic!("message type is not publish"),
864        }
865    }
866}
867
868#[cfg(test)]
869mod tests {
870    use crate::{api::proto::dataplane::v1::SessionMessageType, messages::encoder::Name};
871
872    use super::*;
873
874    fn test_subscription_template(
875        subscription: bool,
876        source: Name,
877        dst: Name,
878        flags: Option<SlimHeaderFlags>,
879    ) {
880        let sub = {
881            if subscription {
882                ProtoMessage::new_subscribe(&source, &dst, flags.clone())
883            } else {
884                ProtoMessage::new_unsubscribe(&source, &dst, flags.clone())
885            }
886        };
887
888        let flags = if flags.is_none() {
889            Some(SlimHeaderFlags::default())
890        } else {
891            flags
892        };
893
894        assert!(!sub.is_publish());
895        assert_eq!(sub.is_subscribe(), subscription);
896        assert_eq!(sub.is_unsubscribe(), !subscription);
897        assert_eq!(flags.as_ref().unwrap().recv_from, sub.get_recv_from());
898        assert_eq!(flags.as_ref().unwrap().forward_to, sub.get_forward_to());
899        assert_eq!(None, sub.try_get_incoming_conn());
900        assert_eq!(source, sub.get_source());
901        let got_name = sub.get_dst();
902        assert_eq!(dst, got_name);
903    }
904
905    fn test_publish_template(source: Name, dst: Name, flags: Option<SlimHeaderFlags>) {
906        let pub_msg = ProtoMessage::new_publish(
907            &source,
908            &dst,
909            flags.clone(),
910            "str",
911            "this is the content of the message".into(),
912        );
913
914        let flags = if flags.is_none() {
915            Some(SlimHeaderFlags::default())
916        } else {
917            flags
918        };
919
920        assert!(pub_msg.is_publish());
921        assert!(!pub_msg.is_subscribe());
922        assert!(!pub_msg.is_unsubscribe());
923        assert_eq!(flags.as_ref().unwrap().recv_from, pub_msg.get_recv_from());
924        assert_eq!(flags.as_ref().unwrap().forward_to, pub_msg.get_forward_to());
925        assert_eq!(None, pub_msg.try_get_incoming_conn());
926        assert_eq!(source, pub_msg.get_source());
927        let got_name = pub_msg.get_dst();
928        assert_eq!(dst, got_name);
929        assert_eq!(flags.as_ref().unwrap().fanout, pub_msg.get_fanout());
930    }
931
932    #[test]
933    fn test_subscription() {
934        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
935        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
936
937        // simple
938        test_subscription_template(true, source.clone(), dst.clone(), None);
939
940        // with name id
941        test_subscription_template(true, source.clone(), dst.clone(), None);
942
943        // with recv from
944        test_subscription_template(
945            true,
946            source.clone(),
947            dst.clone(),
948            Some(SlimHeaderFlags::default().with_recv_from(50)),
949        );
950
951        // with forward to
952        test_subscription_template(
953            true,
954            source.clone(),
955            dst.clone(),
956            Some(SlimHeaderFlags::default().with_forward_to(30)),
957        );
958    }
959
960    #[test]
961    fn test_unsubscription() {
962        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
963        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
964
965        // simple
966        test_subscription_template(false, source.clone(), dst.clone(), None);
967
968        // with name id
969        test_subscription_template(false, source.clone(), dst.clone(), None);
970
971        // with recv from
972        test_subscription_template(
973            false,
974            source.clone(),
975            dst.clone(),
976            Some(SlimHeaderFlags::default().with_recv_from(50)),
977        );
978
979        // with forward to
980        test_subscription_template(
981            false,
982            source.clone(),
983            dst.clone(),
984            Some(SlimHeaderFlags::default().with_forward_to(30)),
985        );
986    }
987
988    #[test]
989    fn test_publish() {
990        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
991        let mut dst = Name::from_strings(["org", "ns", "type"]);
992
993        // simple
994        test_publish_template(
995            source.clone(),
996            dst.clone(),
997            Some(SlimHeaderFlags::default()),
998        );
999
1000        // with name id
1001        dst.set_id(2);
1002        test_publish_template(
1003            source.clone(),
1004            dst.clone(),
1005            Some(SlimHeaderFlags::default()),
1006        );
1007        dst.reset_id();
1008
1009        // with recv from
1010        test_publish_template(
1011            source.clone(),
1012            dst.clone(),
1013            Some(SlimHeaderFlags::default().with_recv_from(50)),
1014        );
1015
1016        // with forward to
1017        test_publish_template(
1018            source.clone(),
1019            dst.clone(),
1020            Some(SlimHeaderFlags::default().with_forward_to(30)),
1021        );
1022
1023        // with fanout
1024        test_publish_template(
1025            source.clone(),
1026            dst.clone(),
1027            Some(SlimHeaderFlags::default().with_fanout(2)),
1028        );
1029    }
1030
1031    #[test]
1032    fn test_conversions() {
1033        // Name to ProtoName
1034        let name = Name::from_strings(["org", "ns", "type"]).with_id(1);
1035        let proto_name = ProtoName::from(&name);
1036
1037        assert_eq!(proto_name.component_0, name.components()[0]);
1038        assert_eq!(proto_name.component_1, name.components()[1]);
1039        assert_eq!(proto_name.component_2, name.components()[2]);
1040        assert_eq!(proto_name.component_3, name.components()[3]);
1041
1042        // ProtoName to Name
1043        let name_from_proto = Name::from(&proto_name);
1044        assert_eq!(name_from_proto.components()[0], proto_name.component_0);
1045        assert_eq!(name_from_proto.components()[1], proto_name.component_1);
1046        assert_eq!(name_from_proto.components()[2], proto_name.component_2);
1047        assert_eq!(name_from_proto.components()[3], proto_name.component_3);
1048
1049        // ProtoMessage to ProtoSubscribe
1050        let dst = Name::from_strings(["org", "ns", "type"]).with_id(1);
1051        let proto_subscribe = ProtoMessage::new_subscribe(
1052            &name,
1053            &dst,
1054            Some(
1055                SlimHeaderFlags::default()
1056                    .with_recv_from(2)
1057                    .with_forward_to(3),
1058            ),
1059        );
1060        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
1061        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), name);
1062        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_dst(), dst,);
1063
1064        // ProtoMessage to ProtoUnsubscribe
1065        let proto_unsubscribe = ProtoMessage::new_unsubscribe(
1066            &name,
1067            &dst,
1068            Some(
1069                SlimHeaderFlags::default()
1070                    .with_recv_from(2)
1071                    .with_forward_to(3),
1072            ),
1073        );
1074        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
1075        assert_eq!(
1076            proto_unsubscribe.header.as_ref().unwrap().get_source(),
1077            name
1078        );
1079        assert_eq!(proto_unsubscribe.header.as_ref().unwrap().get_dst(), dst);
1080
1081        // ProtoMessage to ProtoPublish
1082        let proto_publish = ProtoMessage::new_publish(
1083            &name,
1084            &dst,
1085            Some(
1086                SlimHeaderFlags::default()
1087                    .with_recv_from(2)
1088                    .with_forward_to(3),
1089            ),
1090            "str",
1091            "this is the content of the message".into(),
1092        );
1093        let proto_publish = ProtoPublish::from(proto_publish);
1094        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), name);
1095        assert_eq!(proto_publish.header.as_ref().unwrap().get_dst(), dst);
1096    }
1097
1098    #[test]
1099    fn test_panic() {
1100        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
1101        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
1102
1103        // panic if SLIM header is not found
1104        let msg = ProtoMessage::new_subscribe(
1105            &source,
1106            &dst,
1107            Some(
1108                SlimHeaderFlags::default()
1109                    .with_recv_from(2)
1110                    .with_forward_to(3),
1111            ),
1112        );
1113
1114        // let's try to convert it to a unsubscribe
1115        // this should panic because the message type is not unsubscribe
1116        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
1117        assert!(result.is_err());
1118
1119        // try to convert to publish
1120        // this should panic because the message type is not publish
1121        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
1122        assert!(result.is_err());
1123
1124        // finally make sure the conversion to subscribe works
1125        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
1126        assert!(result.is_ok());
1127    }
1128
1129    #[test]
1130    fn test_panic_header() {
1131        // create a unusual SLIM header
1132        let header = SlimHeader {
1133            source: None,
1134            destination: None,
1135            fanout: 0,
1136            recv_from: None,
1137            forward_to: None,
1138            incoming_conn: None,
1139            error: None,
1140        };
1141
1142        // the operations to retrieve source and destination should fail with panic
1143        let result = std::panic::catch_unwind(|| header.get_source());
1144        assert!(result.is_err());
1145
1146        let result = std::panic::catch_unwind(|| header.get_dst());
1147        assert!(result.is_err());
1148
1149        // The operations to retrieve recv_from and forward_to should not fail with panic
1150        let result = std::panic::catch_unwind(|| header.get_recv_from());
1151        assert!(result.is_ok());
1152
1153        let result = std::panic::catch_unwind(|| header.get_forward_to());
1154        assert!(result.is_ok());
1155
1156        // The operations to retrieve incoming_conn should not fail with panic
1157        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
1158        assert!(result.is_ok());
1159
1160        // The operations to retrieve error should not fail with panic
1161        let result = std::panic::catch_unwind(|| header.get_error());
1162        assert!(result.is_ok());
1163    }
1164
1165    #[test]
1166    fn test_panic_session_header() {
1167        // create a unusual session header
1168        let header = SessionHeader::new(0, 0, 0, 0, &None, &None);
1169
1170        // the operations to retrieve session_id and message_id should not fail with panic
1171        let result = std::panic::catch_unwind(|| header.get_session_id());
1172        assert!(result.is_ok());
1173
1174        let result = std::panic::catch_unwind(|| header.get_message_id());
1175        assert!(result.is_ok());
1176    }
1177
1178    #[test]
1179    fn test_panic_proto_message() {
1180        // create a unusual proto message
1181        let message = ProtoMessage {
1182            metadata: HashMap::new(),
1183            message_type: None,
1184        };
1185
1186        // the operation to retrieve the header should fail with panic
1187        let result = std::panic::catch_unwind(|| message.get_slim_header());
1188        assert!(result.is_err());
1189
1190        // the operation to retrieve the message type should fail with panic
1191        let result = std::panic::catch_unwind(|| message.get_type());
1192        assert!(result.is_err());
1193
1194        // all the other ops should fail with panic as well as the header is not set
1195        let result = std::panic::catch_unwind(|| message.get_source());
1196        assert!(result.is_err());
1197        let result = std::panic::catch_unwind(|| message.get_dst());
1198        assert!(result.is_err());
1199        let result = std::panic::catch_unwind(|| message.get_recv_from());
1200        assert!(result.is_err());
1201        let result = std::panic::catch_unwind(|| message.get_forward_to());
1202        assert!(result.is_err());
1203        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
1204        assert!(result.is_err());
1205        let result = std::panic::catch_unwind(|| message.get_fanout());
1206        assert!(result.is_err());
1207    }
1208
1209    #[test]
1210    fn test_service_type_to_int() {
1211        // Get total number of service types
1212        let total_service_types = SessionMessageType::ChannelMlsAck as i32;
1213
1214        for i in 0..total_service_types {
1215            // int -> ServiceType
1216            let service_type =
1217                SessionMessageType::try_from(i).expect("failed to convert int to service type");
1218            let service_type_int = i32::from(service_type);
1219            assert_eq!(service_type_int, i32::from(service_type),);
1220        }
1221
1222        // Test invalid conversion
1223        let invalid_service_type = SessionMessageType::try_from(total_service_types + 1);
1224        assert!(invalid_service_type.is_err());
1225    }
1226}