// File: slim_datapath/messages/utils.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::fmt::Display;
6
7use tracing::debug;
8
9use super::encoder::Name;
10use crate::api::{
11    Content, MessageType, ProtoMessage, ProtoName, ProtoPublish, ProtoPublishType,
12    ProtoSessionType, ProtoSubscribe, ProtoSubscribeType, ProtoUnsubscribe, ProtoUnsubscribeType,
13    SessionHeader, SlimHeader,
14    proto::dataplane::v1::{OriginalName, SessionMessageType},
15};
16
17use thiserror::Error;
18use tracing::error;
19
/// Errors describing a message that is missing a required part.
///
/// All variants except `IncomingConnectionNotFound` are returned by
/// `ProtoMessage::validate`.
#[derive(Error, Debug, PartialEq)]
pub enum MessageError {
    /// The message carries no SLIM header.
    #[error("SLIM header not found")]
    SlimHeaderNotFound,
    /// The SLIM header has no source name.
    #[error("source not found")]
    SourceNotFound,
    /// The SLIM header has no destination name.
    #[error("destination not found")]
    DestinationNotFound,
    /// A publish message carries no session header.
    #[error("session header not found")]
    SessionHeaderNotFound,
    /// No message type (publish / subscribe / unsubscribe) is set.
    #[error("message type not found")]
    MessageTypeNotFound,
    /// No incoming connection is recorded on the message (not produced by `validate`).
    #[error("incoming connection not found")]
    IncomingConnectionNotFound,
}
35
// Metadata Keys

/// Key used in the `ProtoMessage` metadata map (see `insert_metadata` / `get_metadata`).
/// NOTE(review): presumably carries the sender's SLIM identity — confirm against callers.
pub const SLIM_IDENTITY: &str = "SLIM_IDENTITY";
38
39/// ProtoName from Name
40impl From<&Name> for ProtoName {
41    fn from(name: &Name) -> Self {
42        Self {
43            component_0: name.components()[0],
44            component_1: name.components()[1],
45            component_2: name.components()[2],
46            component_3: name.components()[3],
47        }
48    }
49}
50
51/// Print message type
52impl Display for MessageType {
53    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54        match self {
55            MessageType::Publish(_) => write!(f, "publish"),
56            MessageType::Subscribe(_) => write!(f, "subscribe"),
57            MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
58        }
59    }
60}
61
/// Groups the optional SLIM header flags so they can be passed around as one value.
///
/// Build an instance either with [`SlimHeaderFlags::new`] or by chaining the
/// `with_*` builder methods on top of [`SlimHeaderFlags::default`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SlimHeaderFlags {
    /// Value copied into the SLIM header's `fanout` field (defaults to 1).
    pub fanout: u32,
    /// Value copied into the SLIM header's `recv_from` field.
    pub recv_from: Option<u64>,
    /// Value copied into the SLIM header's `forward_to` field.
    pub forward_to: Option<u64>,
    /// Value copied into the SLIM header's `incoming_conn` field.
    pub incoming_conn: Option<u64>,
    /// Value copied into the SLIM header's `error` field.
    pub error: Option<bool>,
}

impl Default for SlimHeaderFlags {
    /// `fanout` is 1; every optional flag is unset.
    fn default() -> Self {
        Self {
            fanout: 1,
            recv_from: None,
            forward_to: None,
            incoming_conn: None,
            error: None,
        }
    }
}

impl Display for SlimHeaderFlags {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "fanout: {}, recv_from: {:?}, forward_to: {:?}, incoming_conn: {:?}, error: {:?}",
            self.fanout, self.recv_from, self.forward_to, self.incoming_conn, self.error
        )
    }
}

impl SlimHeaderFlags {
    /// Creates a flag set with every field given explicitly.
    pub fn new(
        fanout: u32,
        recv_from: Option<u64>,
        forward_to: Option<u64>,
        incoming_conn: Option<u64>,
        error: Option<bool>,
    ) -> Self {
        Self {
            fanout,
            recv_from,
            forward_to,
            incoming_conn,
            error,
        }
    }

    // The builders below consume `self` and return the updated copy; calling one
    // without using the result silently discards the change, hence `#[must_use]`.

    /// Returns a copy with `fanout` replaced.
    #[must_use]
    pub fn with_fanout(self, fanout: u32) -> Self {
        Self { fanout, ..self }
    }

    /// Returns a copy with `recv_from` set.
    #[must_use]
    pub fn with_recv_from(self, recv_from: u64) -> Self {
        Self {
            recv_from: Some(recv_from),
            ..self
        }
    }

    /// Returns a copy with `forward_to` set.
    #[must_use]
    pub fn with_forward_to(self, forward_to: u64) -> Self {
        Self {
            forward_to: Some(forward_to),
            ..self
        }
    }

    /// Returns a copy with `incoming_conn` set.
    #[must_use]
    pub fn with_incoming_conn(self, incoming_conn: u64) -> Self {
        Self {
            incoming_conn: Some(incoming_conn),
            ..self
        }
    }

    /// Returns a copy with `error` set.
    #[must_use]
    pub fn with_error(self, error: bool) -> Self {
        Self {
            error: Some(error),
            ..self
        }
    }
}
143
144/// SLIM Header
145/// This header is used to identify the source and destination of the message
146/// and to manage the connections used to send and receive the message
147impl SlimHeader {
148    pub fn new(source: &Name, destination: &Name, flags: Option<SlimHeaderFlags>) -> Self {
149        let flags = flags.unwrap_or_default();
150
151        Self {
152            source: Some(ProtoName::from(source)),
153            destination: Some(ProtoName::from(destination)),
154            fanout: flags.fanout,
155            recv_from: flags.recv_from,
156            forward_to: flags.forward_to,
157            incoming_conn: flags.incoming_conn,
158            error: flags.error,
159        }
160    }
161
162    pub fn clear(&mut self) {
163        self.recv_from = None;
164        self.forward_to = None;
165    }
166
167    pub fn get_recv_from(&self) -> Option<u64> {
168        self.recv_from
169    }
170
171    pub fn get_forward_to(&self) -> Option<u64> {
172        self.forward_to
173    }
174
175    pub fn get_incoming_conn(&self) -> Option<u64> {
176        self.incoming_conn
177    }
178
179    pub fn get_error(&self) -> Option<bool> {
180        self.error
181    }
182
183    pub fn get_source(&self) -> Name {
184        match &self.source {
185            Some(source) => Name::from(source),
186            None => panic!("source not found"),
187        }
188    }
189
190    pub fn get_dst(&self) -> Name {
191        match &self.destination {
192            Some(destination) => Name::from(destination),
193            None => panic!("destination not found"),
194        }
195    }
196
197    pub fn set_source(&mut self, source: &Name) {
198        self.source = Some(ProtoName::from(source));
199    }
200
201    pub fn set_destination(&mut self, dst: &Name) {
202        self.destination = Some(ProtoName::from(dst));
203    }
204
205    pub fn get_fanout(&self) -> u32 {
206        self.fanout
207    }
208
209    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
210        self.recv_from = recv_from;
211    }
212
213    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
214        self.forward_to = forward_to;
215    }
216
217    pub fn set_error(&mut self, error: Option<bool>) {
218        self.error = error;
219    }
220
221    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
222        self.incoming_conn = incoming_conn;
223    }
224
225    pub fn set_error_flag(&mut self, error: Option<bool>) {
226        self.error = error;
227    }
228
229    pub fn set_fanout(&mut self, fanout: u32) {
230        self.fanout = fanout;
231    }
232
233    // returns the connection to use to process correctly the message
234    // first connection is from where we received the packet
235    // the second is where to forward the packet if needed
236    pub fn get_in_out_connections(&self) -> (u64, Option<u64>) {
237        // when calling this function, incoming connection is set
238        let incoming = self
239            .get_incoming_conn()
240            .expect("incoming connection not found");
241
242        if let Some(val) = self.get_recv_from() {
243            debug!(
244                "received recv_from command, update state on connection {}",
245                val
246            );
247            return (val, None);
248        }
249
250        if let Some(val) = self.get_forward_to() {
251            debug!(
252                "received forward_to command, update state and forward to connection {}",
253                val
254            );
255            return (incoming, Some(val));
256        }
257
258        // by default, return the incoming connection and None
259        (incoming, None)
260    }
261}
262
263/// Session Header
264/// This header is used to identify the session and the message
265/// and to manage session state
266impl SessionHeader {
267    pub fn new(
268        session_type: i32,
269        session_message_type: i32,
270        session_id: u32,
271        message_id: u32,
272        source: &Option<Name>,
273        destination: &Option<Name>,
274    ) -> Self {
275        let src = match source {
276            Some(name) => name.components_strings().map(|c| OriginalName {
277                component_0: c[0].clone(),
278                component_1: c[1].clone(),
279                component_2: c[2].clone(),
280                component_3: name.id(),
281            }),
282            None => None,
283        };
284
285        let dst = match destination {
286            Some(name) => name.components_strings().map(|c| OriginalName {
287                component_0: c[0].clone(),
288                component_1: c[1].clone(),
289                component_2: c[2].clone(),
290                component_3: name.id(),
291            }),
292            None => None,
293        };
294
295        Self {
296            session_type,
297            session_message_type,
298            session_id,
299            message_id,
300            source: src,
301            destination: dst,
302        }
303    }
304
305    pub fn get_session_id(&self) -> u32 {
306        self.session_id
307    }
308
309    pub fn get_message_id(&self) -> u32 {
310        self.message_id
311    }
312
313    pub fn set_session_id(&mut self, session_id: u32) {
314        self.session_id = session_id;
315    }
316
317    pub fn set_message_id(&mut self, message_id: u32) {
318        self.message_id = message_id;
319    }
320
321    pub fn clear(&mut self) {
322        self.session_id = 0;
323        self.message_id = 0;
324    }
325
326    pub fn get_source(&self) -> Option<Name> {
327        match &self.source {
328            Some(src) => {
329                let n = Name::from_strings([
330                    src.component_0.clone(),
331                    src.component_1.clone(),
332                    src.component_2.clone(),
333                ])
334                .with_id(src.component_3);
335                Some(n)
336            }
337            None => None,
338        }
339    }
340
341    pub fn set_source(&mut self, source: &Name) {
342        let c_opt = source.components_strings();
343        if c_opt.is_none() {
344            return;
345        }
346
347        let c = c_opt.unwrap();
348
349        self.source = Some(OriginalName {
350            component_0: c[0].clone(),
351            component_1: c[1].clone(),
352            component_2: c[2].clone(),
353            component_3: source.id(),
354        });
355    }
356
357    pub fn get_destination(&self) -> Option<Name> {
358        match &self.destination {
359            Some(src) => {
360                let n = Name::from_strings([
361                    src.component_0.clone(),
362                    src.component_1.clone(),
363                    src.component_2.clone(),
364                ])
365                .with_id(src.component_3);
366                Some(n)
367            }
368            None => None,
369        }
370    }
371
372    pub fn set_destination(&mut self, destination: &Name) {
373        let c_opt = destination.components_strings();
374        if c_opt.is_none() {
375            return;
376        }
377
378        let c = c_opt.unwrap();
379
380        self.destination = Some(OriginalName {
381            component_0: c[0].clone(),
382            component_1: c[1].clone(),
383            component_2: c[2].clone(),
384            component_3: destination.id(),
385        });
386    }
387}
388
389/// ProtoSubscribe
390/// This message is used to subscribe to a topic
391impl ProtoSubscribe {
392    pub fn new(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
393        let header = Some(SlimHeader::new(source, dst, flags));
394
395        ProtoSubscribe {
396            header,
397            component_0: dst.components_strings().unwrap()[0].clone(),
398            component_1: dst.components_strings().unwrap()[1].clone(),
399            component_2: dst.components_strings().unwrap()[2].clone(),
400        }
401    }
402}
403
404/// From ProtoMessage to ProtoSubscribe
405impl From<ProtoMessage> for ProtoSubscribe {
406    fn from(message: ProtoMessage) -> Self {
407        match message.message_type {
408            Some(ProtoSubscribeType(s)) => s,
409            _ => panic!("message type is not subscribe"),
410        }
411    }
412}
413
414/// ProtoUnsubscribe
415/// This message is used to unsubscribe from a topic
416impl ProtoUnsubscribe {
417    pub fn new(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
418        let header = Some(SlimHeader::new(source, dst, flags));
419
420        ProtoUnsubscribe {
421            header,
422            component_0: dst.components_strings().unwrap()[0].clone(),
423            component_1: dst.components_strings().unwrap()[1].clone(),
424            component_2: dst.components_strings().unwrap()[2].clone(),
425        }
426    }
427}
428
429/// From ProtoMessage to ProtoUnsubscribe
430impl From<ProtoMessage> for ProtoUnsubscribe {
431    fn from(message: ProtoMessage) -> Self {
432        match message.message_type {
433            Some(ProtoUnsubscribeType(u)) => u,
434            _ => panic!("message type is not unsubscribe"),
435        }
436    }
437}
438
439/// ProtoPublish
440/// This message is used to publish a message, either to a shared channel or to a specific application
441impl ProtoPublish {
442    pub fn with_header(
443        header: Option<SlimHeader>,
444        session: Option<SessionHeader>,
445        payload: Option<Content>,
446    ) -> Self {
447        ProtoPublish {
448            header,
449            session,
450            msg: payload,
451        }
452    }
453
454    pub fn new(
455        source: &Name,
456        dst: &Name,
457        flags: Option<SlimHeaderFlags>,
458        content_type: &str,
459        blob: Vec<u8>,
460    ) -> Self {
461        let slim_header = Some(SlimHeader::new(source, dst, flags));
462
463        let mut session_header = SessionHeader::default();
464        session_header.set_source(source);
465        session_header.set_destination(dst);
466
467        let msg = Some(Content {
468            content_type: content_type.to_string(),
469            blob,
470        });
471
472        Self::with_header(slim_header, Some(session_header), msg)
473    }
474
475    pub fn get_slim_header(&self) -> &SlimHeader {
476        self.header.as_ref().unwrap()
477    }
478
479    pub fn get_session_header(&self) -> &SessionHeader {
480        self.session.as_ref().unwrap()
481    }
482
483    pub fn get_slim_header_as_mut(&mut self) -> &mut SlimHeader {
484        self.header.as_mut().unwrap()
485    }
486
487    pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
488        self.session.as_mut().unwrap()
489    }
490
491    pub fn get_payload(&self) -> &Content {
492        self.msg.as_ref().unwrap()
493    }
494}
495
496/// From ProtoMessage to ProtoPublish
497impl From<ProtoMessage> for ProtoPublish {
498    fn from(message: ProtoMessage) -> Self {
499        match message.message_type {
500            Some(ProtoPublishType(p)) => p,
501            _ => panic!("message type is not publish"),
502        }
503    }
504}
505
506/// ProtoMessage
507/// This represents a generic message that can be sent over the network
508impl ProtoMessage {
509    fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
510        ProtoMessage {
511            metadata,
512            message_type: Some(message_type),
513        }
514    }
515
516    pub fn new_subscribe(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
517        let subscribe = ProtoSubscribe::new(source, dst, flags);
518
519        Self::new(HashMap::new(), ProtoSubscribeType(subscribe))
520    }
521
522    pub fn new_unsubscribe(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
523        let unsubscribe = ProtoUnsubscribe::new(source, dst, flags);
524
525        Self::new(HashMap::new(), ProtoUnsubscribeType(unsubscribe))
526    }
527
528    pub fn new_publish(
529        source: &Name,
530        dst: &Name,
531        flags: Option<SlimHeaderFlags>,
532        content_type: &str,
533        blob: Vec<u8>,
534    ) -> Self {
535        let publish = ProtoPublish::new(source, dst, flags, content_type, blob);
536
537        Self::new(HashMap::new(), ProtoPublishType(publish))
538    }
539
540    pub fn new_publish_with_headers(
541        slim_header: Option<SlimHeader>,
542        session_header: Option<SessionHeader>,
543        content_type: &str,
544        blob: Vec<u8>,
545    ) -> Self {
546        let publish = ProtoPublish::with_header(
547            slim_header,
548            session_header,
549            Some(Content {
550                content_type: content_type.to_string(),
551                blob,
552            }),
553        );
554
555        Self::new(HashMap::new(), ProtoPublishType(publish))
556    }
557
558    // validate message
559    pub fn validate(&self) -> Result<(), MessageError> {
560        // make sure the message type is set
561        if self.message_type.is_none() {
562            return Err(MessageError::MessageTypeNotFound);
563        }
564
565        // make sure SLIM header is set
566        if self.try_get_slim_header().is_none() {
567            return Err(MessageError::SlimHeaderNotFound);
568        }
569
570        // Get SLIM header
571        let slim_header = self.get_slim_header();
572
573        // make sure source and destination are set
574        if slim_header.source.is_none() {
575            return Err(MessageError::SourceNotFound);
576        }
577        if slim_header.destination.is_none() {
578            return Err(MessageError::DestinationNotFound);
579        }
580
581        match &self.message_type {
582            Some(ProtoPublishType(p)) => {
583                // SLIM Header
584                if p.header.is_none() {
585                    return Err(MessageError::SlimHeaderNotFound);
586                }
587
588                // Publish message should have the session header
589                if p.session.is_none() {
590                    return Err(MessageError::SessionHeaderNotFound);
591                }
592            }
593            Some(ProtoSubscribeType(s)) => {
594                if s.header.is_none() {
595                    return Err(MessageError::SlimHeaderNotFound);
596                }
597            }
598            Some(ProtoUnsubscribeType(u)) => {
599                if u.header.is_none() {
600                    return Err(MessageError::SlimHeaderNotFound);
601                }
602            }
603            None => return Err(MessageError::MessageTypeNotFound),
604        }
605
606        Ok(())
607    }
608
609    // add metadata key in the map assigning the value val
610    // if the key exists the value is replaced by val
611    pub fn insert_metadata(&mut self, key: String, val: String) {
612        self.metadata.insert(key, val);
613    }
614
615    // remove metadata key from the map
616    pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
617        self.metadata.remove(key)
618    }
619
620    pub fn contains_metadata(&self, key: &str) -> bool {
621        self.metadata.contains_key(key)
622    }
623
624    pub fn get_metadata(&self, key: &str) -> Option<&String> {
625        self.metadata.get(key)
626    }
627
628    pub fn get_metadata_map(&self) -> HashMap<String, String> {
629        self.metadata.clone()
630    }
631
632    pub fn set_metadata_map(&mut self, map: HashMap<String, String>) {
633        for (k, v) in map.iter() {
634            self.insert_metadata(k.to_string(), v.to_string());
635        }
636    }
637
638    pub fn get_slim_header(&self) -> &SlimHeader {
639        match &self.message_type {
640            Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
641            Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
642            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
643            None => panic!("SLIM header not found"),
644        }
645    }
646
647    pub fn get_slim_header_mut(&mut self) -> &mut SlimHeader {
648        match &mut self.message_type {
649            Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
650            Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
651            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
652            None => panic!("SLIM header not found"),
653        }
654    }
655
656    pub fn try_get_slim_header(&self) -> Option<&SlimHeader> {
657        match &self.message_type {
658            Some(ProtoPublishType(publish)) => publish.header.as_ref(),
659            Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
660            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
661            None => None,
662        }
663    }
664
665    pub fn get_session_header(&self) -> &SessionHeader {
666        match &self.message_type {
667            Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
668            Some(ProtoSubscribeType(_)) => panic!("session header not found"),
669            Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
670            None => panic!("session header not found"),
671        }
672    }
673
674    pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
675        match &mut self.message_type {
676            Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
677            Some(ProtoSubscribeType(_)) => panic!("session header not found"),
678            Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
679            None => panic!("session header not found"),
680        }
681    }
682
683    pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
684        match &self.message_type {
685            Some(ProtoPublishType(publish)) => publish.session.as_ref(),
686            Some(ProtoSubscribeType(_)) => None,
687            Some(ProtoUnsubscribeType(_)) => None,
688            None => None,
689        }
690    }
691
692    pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
693        match &mut self.message_type {
694            Some(ProtoPublishType(publish)) => publish.session.as_mut(),
695            Some(ProtoSubscribeType(_)) => None,
696            Some(ProtoUnsubscribeType(_)) => None,
697            None => None,
698        }
699    }
700
701    pub fn get_id(&self) -> u32 {
702        self.get_session_header().get_message_id()
703    }
704
705    pub fn get_source(&self) -> Name {
706        if let Some(ProtoPublishType(pubslih)) = &self.message_type {
707            // try to the src dst from the session header
708            if let Some(s) = pubslih.get_session_header().get_source() {
709                return s;
710            }
711        }
712
713        // get from slim header
714        self.get_slim_header().get_source()
715    }
716
717    pub fn get_fanout(&self) -> u32 {
718        self.get_slim_header().get_fanout()
719    }
720
721    pub fn get_recv_from(&self) -> Option<u64> {
722        self.get_slim_header().get_recv_from()
723    }
724
725    pub fn get_forward_to(&self) -> Option<u64> {
726        self.get_slim_header().get_forward_to()
727    }
728
729    pub fn get_error(&self) -> Option<bool> {
730        self.get_slim_header().get_error()
731    }
732
733    pub fn get_incoming_conn(&self) -> u64 {
734        self.get_slim_header().get_incoming_conn().unwrap()
735    }
736
737    pub fn try_get_incoming_conn(&self) -> Option<u64> {
738        self.get_slim_header().get_incoming_conn()
739    }
740
741    pub fn get_dst(&self) -> Name {
742        match &self.message_type {
743            Some(ProtoPublishType(publish)) => {
744                // try to the get dst from the session header
745                if let Some(d) = publish.get_session_header().get_destination() {
746                    return d;
747                }
748                // get from the slim header
749                self.get_slim_header().get_dst()
750            }
751            Some(ProtoSubscribeType(subscribe)) => {
752                let dst = self.get_slim_header().get_dst();
753                // complete name with the original strings
754                Name::from_strings([
755                    subscribe.component_0.clone(),
756                    subscribe.component_1.clone(),
757                    subscribe.component_2.clone(),
758                ])
759                .with_id(dst.id())
760            }
761            Some(ProtoUnsubscribeType(unsubscribe)) => {
762                let dst = self.get_slim_header().get_dst();
763                // complete name with the original strings
764                Name::from_strings([
765                    unsubscribe.component_0.clone(),
766                    unsubscribe.component_1.clone(),
767                    unsubscribe.component_2.clone(),
768                ])
769                .with_id(dst.id())
770            }
771            None => {
772                unreachable!("destination not found");
773            }
774        }
775    }
776
777    pub fn get_type(&self) -> &MessageType {
778        match &self.message_type {
779            Some(t) => t,
780            None => panic!("message type not found"),
781        }
782    }
783
784    pub fn get_payload(&self) -> Option<&Content> {
785        match &self.message_type {
786            Some(ProtoPublishType(p)) => p.msg.as_ref(),
787            Some(ProtoSubscribeType(_)) => panic!("payload not found"),
788            Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
789            None => panic!("payload not found"),
790        }
791    }
792
793    pub fn get_session_message_type(&self) -> SessionMessageType {
794        self.get_session_header()
795            .session_message_type
796            .try_into()
797            .unwrap_or_default()
798    }
799
800    pub fn clear_slim_header(&mut self) {
801        self.get_slim_header_mut().clear();
802    }
803
804    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
805        self.get_slim_header_mut().set_recv_from(recv_from);
806    }
807
808    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
809        self.get_slim_header_mut().set_forward_to(forward_to);
810    }
811
812    pub fn set_error(&mut self, error: Option<bool>) {
813        self.get_slim_header_mut().set_error(error);
814    }
815
816    pub fn set_fanout(&mut self, fanout: u32) {
817        self.get_slim_header_mut().set_fanout(fanout);
818    }
819
820    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
821        self.get_slim_header_mut().set_incoming_conn(incoming_conn);
822    }
823
824    pub fn set_error_flag(&mut self, error: Option<bool>) {
825        self.get_slim_header_mut().set_error_flag(error);
826    }
827
828    pub fn set_session_message_type(&mut self, message_type: SessionMessageType) {
829        self.get_session_header_mut()
830            .set_session_message_type(message_type);
831    }
832
833    pub fn set_session_type(&mut self, session_type: ProtoSessionType) {
834        self.get_session_header_mut().set_session_type(session_type);
835    }
836
837    pub fn get_session_type(&self) -> ProtoSessionType {
838        self.get_session_header().session_type()
839    }
840
841    pub fn set_message_id(&mut self, message_id: u32) {
842        self.get_session_header_mut().set_message_id(message_id);
843    }
844
845    pub fn is_publish(&self) -> bool {
846        matches!(self.get_type(), MessageType::Publish(_))
847    }
848
849    pub fn is_subscribe(&self) -> bool {
850        matches!(self.get_type(), MessageType::Subscribe(_))
851    }
852
853    pub fn is_unsubscribe(&self) -> bool {
854        matches!(self.get_type(), MessageType::Unsubscribe(_))
855    }
856}
857
858impl AsRef<ProtoPublish> for ProtoMessage {
859    fn as_ref(&self) -> &ProtoPublish {
860        match &self.message_type {
861            Some(ProtoPublishType(p)) => p,
862            _ => panic!("message type is not publish"),
863        }
864    }
865}
866
867#[cfg(test)]
868mod tests {
869    use crate::{api::proto::dataplane::v1::SessionMessageType, messages::encoder::Name};
870
871    use super::*;
872
873    fn test_subscription_template(
874        subscription: bool,
875        source: Name,
876        dst: Name,
877        flags: Option<SlimHeaderFlags>,
878    ) {
879        let sub = {
880            if subscription {
881                ProtoMessage::new_subscribe(&source, &dst, flags.clone())
882            } else {
883                ProtoMessage::new_unsubscribe(&source, &dst, flags.clone())
884            }
885        };
886
887        let flags = if flags.is_none() {
888            Some(SlimHeaderFlags::default())
889        } else {
890            flags
891        };
892
893        assert!(!sub.is_publish());
894        assert_eq!(sub.is_subscribe(), subscription);
895        assert_eq!(sub.is_unsubscribe(), !subscription);
896        assert_eq!(flags.as_ref().unwrap().recv_from, sub.get_recv_from());
897        assert_eq!(flags.as_ref().unwrap().forward_to, sub.get_forward_to());
898        assert_eq!(None, sub.try_get_incoming_conn());
899        assert_eq!(source, sub.get_source());
900        let got_name = sub.get_dst();
901        assert_eq!(dst, got_name);
902    }
903
904    fn test_publish_template(source: Name, dst: Name, flags: Option<SlimHeaderFlags>) {
905        let pub_msg = ProtoMessage::new_publish(
906            &source,
907            &dst,
908            flags.clone(),
909            "str",
910            "this is the content of the message".into(),
911        );
912
913        let flags = if flags.is_none() {
914            Some(SlimHeaderFlags::default())
915        } else {
916            flags
917        };
918
919        assert!(pub_msg.is_publish());
920        assert!(!pub_msg.is_subscribe());
921        assert!(!pub_msg.is_unsubscribe());
922        assert_eq!(flags.as_ref().unwrap().recv_from, pub_msg.get_recv_from());
923        assert_eq!(flags.as_ref().unwrap().forward_to, pub_msg.get_forward_to());
924        assert_eq!(None, pub_msg.try_get_incoming_conn());
925        assert_eq!(source, pub_msg.get_source());
926        let got_name = pub_msg.get_dst();
927        assert_eq!(dst, got_name);
928        assert_eq!(flags.as_ref().unwrap().fanout, pub_msg.get_fanout());
929    }
930
931    #[test]
932    fn test_subscription() {
933        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
934        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
935
936        // simple
937        test_subscription_template(true, source.clone(), dst.clone(), None);
938
939        // with name id
940        test_subscription_template(true, source.clone(), dst.clone(), None);
941
942        // with recv from
943        test_subscription_template(
944            true,
945            source.clone(),
946            dst.clone(),
947            Some(SlimHeaderFlags::default().with_recv_from(50)),
948        );
949
950        // with forward to
951        test_subscription_template(
952            true,
953            source.clone(),
954            dst.clone(),
955            Some(SlimHeaderFlags::default().with_forward_to(30)),
956        );
957    }
958
959    #[test]
960    fn test_unsubscription() {
961        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
962        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
963
964        // simple
965        test_subscription_template(false, source.clone(), dst.clone(), None);
966
967        // with name id
968        test_subscription_template(false, source.clone(), dst.clone(), None);
969
970        // with recv from
971        test_subscription_template(
972            false,
973            source.clone(),
974            dst.clone(),
975            Some(SlimHeaderFlags::default().with_recv_from(50)),
976        );
977
978        // with forward to
979        test_subscription_template(
980            false,
981            source.clone(),
982            dst.clone(),
983            Some(SlimHeaderFlags::default().with_forward_to(30)),
984        );
985    }
986
987    #[test]
988    fn test_publish() {
989        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
990        let mut dst = Name::from_strings(["org", "ns", "type"]);
991
992        // simple
993        test_publish_template(
994            source.clone(),
995            dst.clone(),
996            Some(SlimHeaderFlags::default()),
997        );
998
999        // with name id
1000        dst.set_id(2);
1001        test_publish_template(
1002            source.clone(),
1003            dst.clone(),
1004            Some(SlimHeaderFlags::default()),
1005        );
1006        dst.reset_id();
1007
1008        // with recv from
1009        test_publish_template(
1010            source.clone(),
1011            dst.clone(),
1012            Some(SlimHeaderFlags::default().with_recv_from(50)),
1013        );
1014
1015        // with forward to
1016        test_publish_template(
1017            source.clone(),
1018            dst.clone(),
1019            Some(SlimHeaderFlags::default().with_forward_to(30)),
1020        );
1021
1022        // with fanout
1023        test_publish_template(
1024            source.clone(),
1025            dst.clone(),
1026            Some(SlimHeaderFlags::default().with_fanout(2)),
1027        );
1028    }
1029
1030    #[test]
1031    fn test_conversions() {
1032        // Name to ProtoName
1033        let name = Name::from_strings(["org", "ns", "type"]).with_id(1);
1034        let proto_name = ProtoName::from(&name);
1035
1036        assert_eq!(proto_name.component_0, name.components()[0]);
1037        assert_eq!(proto_name.component_1, name.components()[1]);
1038        assert_eq!(proto_name.component_2, name.components()[2]);
1039        assert_eq!(proto_name.component_3, name.components()[3]);
1040
1041        // ProtoName to Name
1042        let name_from_proto = Name::from(&proto_name);
1043        assert_eq!(name_from_proto.components()[0], proto_name.component_0);
1044        assert_eq!(name_from_proto.components()[1], proto_name.component_1);
1045        assert_eq!(name_from_proto.components()[2], proto_name.component_2);
1046        assert_eq!(name_from_proto.components()[3], proto_name.component_3);
1047
1048        // ProtoMessage to ProtoSubscribe
1049        let dst = Name::from_strings(["org", "ns", "type"]).with_id(1);
1050        let proto_subscribe = ProtoMessage::new_subscribe(
1051            &name,
1052            &dst,
1053            Some(
1054                SlimHeaderFlags::default()
1055                    .with_recv_from(2)
1056                    .with_forward_to(3),
1057            ),
1058        );
1059        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
1060        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), name);
1061        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_dst(), dst,);
1062
1063        // ProtoMessage to ProtoUnsubscribe
1064        let proto_unsubscribe = ProtoMessage::new_unsubscribe(
1065            &name,
1066            &dst,
1067            Some(
1068                SlimHeaderFlags::default()
1069                    .with_recv_from(2)
1070                    .with_forward_to(3),
1071            ),
1072        );
1073        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
1074        assert_eq!(
1075            proto_unsubscribe.header.as_ref().unwrap().get_source(),
1076            name
1077        );
1078        assert_eq!(proto_unsubscribe.header.as_ref().unwrap().get_dst(), dst);
1079
1080        // ProtoMessage to ProtoPublish
1081        let proto_publish = ProtoMessage::new_publish(
1082            &name,
1083            &dst,
1084            Some(
1085                SlimHeaderFlags::default()
1086                    .with_recv_from(2)
1087                    .with_forward_to(3),
1088            ),
1089            "str",
1090            "this is the content of the message".into(),
1091        );
1092        let proto_publish = ProtoPublish::from(proto_publish);
1093        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), name);
1094        assert_eq!(proto_publish.header.as_ref().unwrap().get_dst(), dst);
1095    }
1096
1097    #[test]
1098    fn test_panic() {
1099        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
1100        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);
1101
1102        // panic if SLIM header is not found
1103        let msg = ProtoMessage::new_subscribe(
1104            &source,
1105            &dst,
1106            Some(
1107                SlimHeaderFlags::default()
1108                    .with_recv_from(2)
1109                    .with_forward_to(3),
1110            ),
1111        );
1112
1113        // let's try to convert it to a unsubscribe
1114        // this should panic because the message type is not unsubscribe
1115        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
1116        assert!(result.is_err());
1117
1118        // try to convert to publish
1119        // this should panic because the message type is not publish
1120        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
1121        assert!(result.is_err());
1122
1123        // finally make sure the conversion to subscribe works
1124        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
1125        assert!(result.is_ok());
1126    }
1127
1128    #[test]
1129    fn test_panic_header() {
1130        // create a unusual SLIM header
1131        let header = SlimHeader {
1132            source: None,
1133            destination: None,
1134            fanout: 0,
1135            recv_from: None,
1136            forward_to: None,
1137            incoming_conn: None,
1138            error: None,
1139        };
1140
1141        // the operations to retrieve source and destination should fail with panic
1142        let result = std::panic::catch_unwind(|| header.get_source());
1143        assert!(result.is_err());
1144
1145        let result = std::panic::catch_unwind(|| header.get_dst());
1146        assert!(result.is_err());
1147
1148        // The operations to retrieve recv_from and forward_to should not fail with panic
1149        let result = std::panic::catch_unwind(|| header.get_recv_from());
1150        assert!(result.is_ok());
1151
1152        let result = std::panic::catch_unwind(|| header.get_forward_to());
1153        assert!(result.is_ok());
1154
1155        // The operations to retrieve incoming_conn should not fail with panic
1156        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
1157        assert!(result.is_ok());
1158
1159        // The operations to retrieve error should not fail with panic
1160        let result = std::panic::catch_unwind(|| header.get_error());
1161        assert!(result.is_ok());
1162    }
1163
1164    #[test]
1165    fn test_panic_session_header() {
1166        // create a unusual session header
1167        let header = SessionHeader::new(0, 0, 0, 0, &None, &None);
1168
1169        // the operations to retrieve session_id and message_id should not fail with panic
1170        let result = std::panic::catch_unwind(|| header.get_session_id());
1171        assert!(result.is_ok());
1172
1173        let result = std::panic::catch_unwind(|| header.get_message_id());
1174        assert!(result.is_ok());
1175    }
1176
1177    #[test]
1178    fn test_panic_proto_message() {
1179        // create a unusual proto message
1180        let message = ProtoMessage {
1181            metadata: HashMap::new(),
1182            message_type: None,
1183        };
1184
1185        // the operation to retrieve the header should fail with panic
1186        let result = std::panic::catch_unwind(|| message.get_slim_header());
1187        assert!(result.is_err());
1188
1189        // the operation to retrieve the message type should fail with panic
1190        let result = std::panic::catch_unwind(|| message.get_type());
1191        assert!(result.is_err());
1192
1193        // all the other ops should fail with panic as well as the header is not set
1194        let result = std::panic::catch_unwind(|| message.get_source());
1195        assert!(result.is_err());
1196        let result = std::panic::catch_unwind(|| message.get_dst());
1197        assert!(result.is_err());
1198        let result = std::panic::catch_unwind(|| message.get_recv_from());
1199        assert!(result.is_err());
1200        let result = std::panic::catch_unwind(|| message.get_forward_to());
1201        assert!(result.is_err());
1202        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
1203        assert!(result.is_err());
1204        let result = std::panic::catch_unwind(|| message.get_fanout());
1205        assert!(result.is_err());
1206    }
1207
1208    #[test]
1209    fn test_service_type_to_int() {
1210        // Get total number of service types
1211        let total_service_types = SessionMessageType::ChannelMlsAck as i32;
1212
1213        for i in 0..total_service_types {
1214            // int -> ServiceType
1215            let service_type =
1216                SessionMessageType::try_from(i).expect("failed to convert int to service type");
1217            let service_type_int = i32::from(service_type);
1218            assert_eq!(service_type_int, i32::from(service_type),);
1219        }
1220
1221        // Test invalid conversion
1222        let invalid_service_type = SessionMessageType::try_from(total_service_types + 1);
1223        assert!(invalid_service_type.is_err());
1224    }
1225}