agp_datapath/messages/
utils.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashMap;
5use std::fmt::Display;
6
7use tracing::debug;
8
9use super::encoder::{Agent, AgentType, DEFAULT_AGENT_ID};
10use crate::pubsub::{
11    AgpHeader, Content, MessageType, ProtoAgent, ProtoMessage, ProtoPublish, ProtoPublishType,
12    ProtoSubscribe, ProtoSubscribeType, ProtoUnsubscribe, ProtoUnsubscribeType, SessionHeader,
13    proto::pubsub::v1::SessionHeaderType,
14};
15
16use thiserror::Error;
17use tracing::error;
18
19#[derive(Error, Debug, PartialEq)]
20pub enum MessageError {
21    #[error("AGP header not found")]
22    AgpHeaderNotFound,
23    #[error("source not found")]
24    SourceNotFound,
25    #[error("destination not found")]
26    DestinationNotFound,
27    #[error("session header not found")]
28    SessionHeaderNotFound,
29    #[error("message type not found")]
30    MessageTypeNotFound,
31    #[error("incoming connection not found")]
32    IncomingConnectionNotFound,
33}
34
35/// ProtoAgent from Agent
36impl From<&Agent> for ProtoAgent {
37    fn from(agent: &Agent) -> Self {
38        let mut id = None;
39        if agent.agent_id() != DEFAULT_AGENT_ID {
40            id = Some(agent.agent_id())
41        }
42
43        Self {
44            organization: agent.agent_type().organization(),
45            namespace: agent.agent_type().namespace(),
46            agent_type: agent.agent_type().agent_type(),
47            agent_id: id,
48        }
49    }
50}
51
52/// ProtoAgent from AgentType
53impl From<(&AgentType, Option<u64>)> for ProtoAgent {
54    fn from((agent_type, agent_id): (&AgentType, Option<u64>)) -> Self {
55        Self {
56            organization: agent_type.organization(),
57            namespace: agent_type.namespace(),
58            agent_type: agent_type.agent_type(),
59            agent_id,
60        }
61    }
62}
63
64/// Print message type
65impl Display for MessageType {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        match self {
68            MessageType::Publish(_) => write!(f, "publish"),
69            MessageType::Subscribe(_) => write!(f, "subscribe"),
70            MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
71        }
72    }
73}
74
/// Struct grouping the AGP header flags for convenience.
///
/// The default value has `fanout = 1` and every optional flag unset.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgpHeaderFlags {
    /// Value stored in the header's `fanout` field.
    pub fanout: u32,
    /// Value stored in the header's `recv_from` field.
    pub recv_from: Option<u64>,
    /// Value stored in the header's `forward_to` field.
    pub forward_to: Option<u64>,
    /// Value stored in the header's `incoming_conn` field (the connection
    /// the message was received on).
    pub incoming_conn: Option<u64>,
    /// Value stored in the header's `error` field.
    pub error: Option<bool>,
}

impl Default for AgpHeaderFlags {
    fn default() -> Self {
        Self {
            fanout: 1,
            recv_from: None,
            forward_to: None,
            incoming_conn: None,
            error: None,
        }
    }
}

impl Display for AgpHeaderFlags {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "fanout: {}, recv_from: {:?}, forward_to: {:?}, incoming_conn: {:?}, error: {:?}",
            self.fanout, self.recv_from, self.forward_to, self.incoming_conn, self.error
        )
    }
}

impl AgpHeaderFlags {
    /// Creates a flag set with every field given explicitly.
    pub fn new(
        fanout: u32,
        recv_from: Option<u64>,
        forward_to: Option<u64>,
        incoming_conn: Option<u64>,
        error: Option<bool>,
    ) -> Self {
        Self {
            fanout,
            recv_from,
            forward_to,
            incoming_conn,
            error,
        }
    }

    /// Returns a copy of the flags with `fanout` replaced.
    #[must_use]
    pub fn with_fanout(self, fanout: u32) -> Self {
        Self { fanout, ..self }
    }

    /// Returns a copy of the flags with `recv_from` set.
    #[must_use]
    pub fn with_recv_from(self, recv_from: u64) -> Self {
        Self {
            recv_from: Some(recv_from),
            ..self
        }
    }

    /// Returns a copy of the flags with `forward_to` set.
    #[must_use]
    pub fn with_forward_to(self, forward_to: u64) -> Self {
        Self {
            forward_to: Some(forward_to),
            ..self
        }
    }

    /// Returns a copy of the flags with `incoming_conn` set.
    #[must_use]
    pub fn with_incoming_conn(self, incoming_conn: u64) -> Self {
        Self {
            incoming_conn: Some(incoming_conn),
            ..self
        }
    }

    /// Returns a copy of the flags with the `error` flag set.
    #[must_use]
    pub fn with_error(self, error: bool) -> Self {
        Self {
            error: Some(error),
            ..self
        }
    }
}
156
157/// AGP Header
158/// This header is used to identify the source and destination of the message
159/// and to manage the connections used to send and receive the message
160impl AgpHeader {
161    pub fn new(
162        source: &Agent,
163        name_type: &AgentType,
164        name_id: Option<u64>,
165        flags: Option<AgpHeaderFlags>,
166    ) -> Self {
167        let flags = flags.unwrap_or_default();
168
169        Self {
170            source: Some(ProtoAgent::from(source)),
171            destination: Some(ProtoAgent::from((name_type, name_id))),
172            fanout: flags.fanout,
173            recv_from: flags.recv_from,
174            forward_to: flags.forward_to,
175            incoming_conn: flags.incoming_conn,
176            error: flags.error,
177        }
178    }
179
180    pub fn clear(&mut self) {
181        self.recv_from = None;
182        self.forward_to = None;
183    }
184
185    pub fn get_recv_from(&self) -> Option<u64> {
186        self.recv_from
187    }
188
189    pub fn get_forward_to(&self) -> Option<u64> {
190        self.forward_to
191    }
192
193    pub fn get_incoming_conn(&self) -> Option<u64> {
194        self.incoming_conn
195    }
196
197    pub fn get_error(&self) -> Option<bool> {
198        self.error
199    }
200
201    pub fn get_source(&self) -> Agent {
202        match &self.source {
203            Some(source) => Agent::from(source),
204            None => panic!("source not found"),
205        }
206    }
207
208    pub fn get_dst(&self) -> (AgentType, Option<u64>) {
209        match &self.destination {
210            Some(destination) => (AgentType::from(destination), destination.agent_id),
211            None => panic!("destination not found"),
212        }
213    }
214
215    pub fn get_fanout(&self) -> u32 {
216        self.fanout
217    }
218
219    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
220        self.recv_from = recv_from;
221    }
222
223    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
224        self.forward_to = forward_to;
225    }
226
227    pub fn set_error(&mut self, error: Option<bool>) {
228        self.error = error;
229    }
230
231    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
232        self.incoming_conn = incoming_conn;
233    }
234
235    pub fn set_error_flag(&mut self, error: Option<bool>) {
236        self.error = error;
237    }
238
239    pub fn set_fanout(&mut self, fanout: u32) {
240        self.fanout = fanout;
241    }
242
243    // returns the connection to use to process correctly the message
244    // first connection is from where we received the packet
245    // the second is where to forward the packet if needed
246    pub fn get_in_out_connections(&self) -> (u64, Option<u64>) {
247        // when calling this function, incoming connection is set
248        let incoming = self
249            .get_incoming_conn()
250            .expect("incoming connection not found");
251
252        if let Some(val) = self.get_recv_from() {
253            debug!(
254                "received recv_from command, update state on connection {}",
255                val
256            );
257            return (val, None);
258        }
259
260        if let Some(val) = self.get_forward_to() {
261            debug!(
262                "received forward_to command, update state and forward to connection {}",
263                val
264            );
265            return (incoming, Some(val));
266        }
267
268        // by default, return the incoming connection and None
269        (incoming, None)
270    }
271}
272
273/// Session Header
274/// This header is used to identify the session and the message
275/// and to manage session state
276impl SessionHeader {
277    pub fn new(header_type: i32, session_id: u32, message_id: u32) -> Self {
278        Self {
279            header_type,
280            session_id,
281            message_id,
282        }
283    }
284
285    pub fn get_session_id(&self) -> u32 {
286        self.session_id
287    }
288
289    pub fn get_message_id(&self) -> u32 {
290        self.message_id
291    }
292
293    pub fn set_session_id(&mut self, session_id: u32) {
294        self.session_id = session_id;
295    }
296
297    pub fn set_message_id(&mut self, message_id: u32) {
298        self.message_id = message_id;
299    }
300
301    pub fn clear(&mut self) {
302        self.session_id = 0;
303        self.message_id = 0;
304    }
305}
306
307/// ProtoSubscribe
308/// This message is used to subscribe to a topic
309impl ProtoSubscribe {
310    pub fn with_header(header: Option<AgpHeader>) -> Self {
311        ProtoSubscribe { header }
312    }
313
314    pub fn new(
315        source: &Agent,
316        agent_type: &AgentType,
317        agent_id: Option<u64>,
318        flags: Option<AgpHeaderFlags>,
319    ) -> Self {
320        let header = Some(AgpHeader::new(source, agent_type, agent_id, flags));
321
322        Self::with_header(header)
323    }
324}
325
326/// From ProtoMessage to ProtoSubscribe
327impl From<ProtoMessage> for ProtoSubscribe {
328    fn from(message: ProtoMessage) -> Self {
329        match message.message_type {
330            Some(ProtoSubscribeType(s)) => s,
331            _ => panic!("message type is not subscribe"),
332        }
333    }
334}
335
336/// ProtoUnsubscribe
337/// This message is used to unsubscribe from a topic
338impl ProtoUnsubscribe {
339    pub fn with_header(header: Option<AgpHeader>) -> Self {
340        ProtoUnsubscribe { header }
341    }
342
343    pub fn new(
344        source: &Agent,
345        agent_type: &AgentType,
346        agent_id: Option<u64>,
347        flags: Option<AgpHeaderFlags>,
348    ) -> Self {
349        let header = Some(AgpHeader::new(source, agent_type, agent_id, flags));
350
351        Self::with_header(header)
352    }
353}
354
355/// From ProtoMessage to ProtoUnsubscribe
356impl From<ProtoMessage> for ProtoUnsubscribe {
357    fn from(message: ProtoMessage) -> Self {
358        match message.message_type {
359            Some(ProtoUnsubscribeType(u)) => u,
360            _ => panic!("message type is not unsubscribe"),
361        }
362    }
363}
364
365/// ProtoPublish
366/// This message is used to publish a message to a topic/agent
367impl ProtoPublish {
368    pub fn with_header(
369        header: Option<AgpHeader>,
370        session: Option<SessionHeader>,
371        payload: Option<Content>,
372    ) -> Self {
373        ProtoPublish {
374            header,
375            session,
376            msg: payload,
377        }
378    }
379
380    pub fn new(
381        source: &Agent,
382        agent_type: &AgentType,
383        agent_id: Option<u64>,
384        flags: Option<AgpHeaderFlags>,
385        content_type: &str,
386        blob: Vec<u8>,
387    ) -> Self {
388        let agp_header = Some(AgpHeader::new(source, agent_type, agent_id, flags));
389
390        let session_header = Some(SessionHeader::default());
391
392        let msg = Some(Content {
393            content_type: content_type.to_string(),
394            blob,
395        });
396
397        Self::with_header(agp_header, session_header, msg)
398    }
399
400    pub fn get_agp_header(&self) -> &AgpHeader {
401        self.header.as_ref().unwrap()
402    }
403
404    pub fn get_session_header(&self) -> &SessionHeader {
405        self.session.as_ref().unwrap()
406    }
407
408    pub fn get_agp_header_as_mut(&mut self) -> &mut AgpHeader {
409        self.header.as_mut().unwrap()
410    }
411
412    pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
413        self.session.as_mut().unwrap()
414    }
415
416    pub fn get_payload(&self) -> &Content {
417        self.msg.as_ref().unwrap()
418    }
419}
420
421/// From ProtoMessage to ProtoPublish
422impl From<ProtoMessage> for ProtoPublish {
423    fn from(message: ProtoMessage) -> Self {
424        match message.message_type {
425            Some(ProtoPublishType(p)) => p,
426            _ => panic!("message type is not publish"),
427        }
428    }
429}
430
431/// ProtoMessage
432/// This represents a generic message that can be sent over the network
433impl ProtoMessage {
434    fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
435        ProtoMessage {
436            metadata,
437            message_type: Some(message_type),
438        }
439    }
440
441    pub fn new_subscribe(
442        source: &Agent,
443        agent_type: &AgentType,
444        agent_id: Option<u64>,
445        flags: Option<AgpHeaderFlags>,
446    ) -> Self {
447        let subscribe = ProtoSubscribe::new(source, agent_type, agent_id, flags);
448
449        Self::new(HashMap::new(), ProtoSubscribeType(subscribe))
450    }
451
452    pub fn new_unsubscribe(
453        source: &Agent,
454        agent_type: &AgentType,
455        agent_id: Option<u64>,
456        flags: Option<AgpHeaderFlags>,
457    ) -> Self {
458        let unsubscribe = ProtoUnsubscribe::new(source, agent_type, agent_id, flags);
459
460        Self::new(HashMap::new(), ProtoUnsubscribeType(unsubscribe))
461    }
462
463    pub fn new_publish(
464        source: &Agent,
465        agent_type: &AgentType,
466        agent_id: Option<u64>,
467        flags: Option<AgpHeaderFlags>,
468        content_type: &str,
469        blob: Vec<u8>,
470    ) -> Self {
471        let publish = ProtoPublish::new(source, agent_type, agent_id, flags, content_type, blob);
472
473        Self::new(HashMap::new(), ProtoPublishType(publish))
474    }
475
476    pub fn new_publish_with_headers(
477        agp_header: Option<AgpHeader>,
478        session_header: Option<SessionHeader>,
479        content_type: &str,
480        blob: Vec<u8>,
481    ) -> Self {
482        let publish = ProtoPublish::with_header(
483            agp_header,
484            session_header,
485            Some(Content {
486                content_type: content_type.to_string(),
487                blob,
488            }),
489        );
490
491        Self::new(HashMap::new(), ProtoPublishType(publish))
492    }
493
494    // validate message
495    pub fn validate(&self) -> Result<(), MessageError> {
496        // make sure the message type is set
497        if self.message_type.is_none() {
498            return Err(MessageError::MessageTypeNotFound);
499        }
500
501        // make sure AGP header is set
502        if self.try_get_agp_header().is_none() {
503            return Err(MessageError::AgpHeaderNotFound);
504        }
505
506        // Get AGP header
507        let agp_header = self.get_agp_header();
508
509        // make sure source and destination are set
510        if agp_header.source.is_none() {
511            return Err(MessageError::SourceNotFound);
512        }
513        if agp_header.destination.is_none() {
514            return Err(MessageError::DestinationNotFound);
515        }
516
517        match &self.message_type {
518            Some(ProtoPublishType(p)) => {
519                // AGP Header
520                if p.header.is_none() {
521                    return Err(MessageError::AgpHeaderNotFound);
522                }
523
524                // Publish message should have the session header
525                if p.session.is_none() {
526                    return Err(MessageError::SessionHeaderNotFound);
527                }
528            }
529            Some(ProtoSubscribeType(s)) => {
530                if s.header.is_none() {
531                    return Err(MessageError::AgpHeaderNotFound);
532                }
533            }
534            Some(ProtoUnsubscribeType(u)) => {
535                if u.header.is_none() {
536                    return Err(MessageError::AgpHeaderNotFound);
537                }
538            }
539            None => return Err(MessageError::MessageTypeNotFound),
540        }
541
542        Ok(())
543    }
544
545    pub fn get_agp_header(&self) -> &AgpHeader {
546        match &self.message_type {
547            Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
548            Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
549            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
550            None => panic!("AGP header not found"),
551        }
552    }
553
554    pub fn get_agp_header_mut(&mut self) -> &mut AgpHeader {
555        match &mut self.message_type {
556            Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
557            Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
558            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
559            None => panic!("AGP header not found"),
560        }
561    }
562
563    pub fn try_get_agp_header(&self) -> Option<&AgpHeader> {
564        match &self.message_type {
565            Some(ProtoPublishType(publish)) => publish.header.as_ref(),
566            Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
567            Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
568            None => None,
569        }
570    }
571
572    pub fn get_session_header(&self) -> &SessionHeader {
573        match &self.message_type {
574            Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
575            Some(ProtoSubscribeType(_)) => panic!("session header not found"),
576            Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
577            None => panic!("session header not found"),
578        }
579    }
580
581    pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
582        match &mut self.message_type {
583            Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
584            Some(ProtoSubscribeType(_)) => panic!("session header not found"),
585            Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
586            None => panic!("session header not found"),
587        }
588    }
589
590    pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
591        match &self.message_type {
592            Some(ProtoPublishType(publish)) => publish.session.as_ref(),
593            Some(ProtoSubscribeType(_)) => None,
594            Some(ProtoUnsubscribeType(_)) => None,
595            None => None,
596        }
597    }
598
599    pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
600        match &mut self.message_type {
601            Some(ProtoPublishType(publish)) => publish.session.as_mut(),
602            Some(ProtoSubscribeType(_)) => None,
603            Some(ProtoUnsubscribeType(_)) => None,
604            None => None,
605        }
606    }
607
608    pub fn get_id(&self) -> u32 {
609        self.get_session_header().get_message_id()
610    }
611
612    pub fn get_source(&self) -> Agent {
613        self.get_agp_header().get_source()
614    }
615
616    pub fn get_fanout(&self) -> u32 {
617        self.get_agp_header().get_fanout()
618    }
619
620    pub fn get_recv_from(&self) -> Option<u64> {
621        self.get_agp_header().get_recv_from()
622    }
623
624    pub fn get_forward_to(&self) -> Option<u64> {
625        self.get_agp_header().get_forward_to()
626    }
627
628    pub fn get_error(&self) -> Option<bool> {
629        self.get_agp_header().get_error()
630    }
631
632    pub fn get_incoming_conn(&self) -> u64 {
633        self.get_agp_header().get_incoming_conn().unwrap()
634    }
635
636    pub fn try_get_incoming_conn(&self) -> Option<u64> {
637        self.get_agp_header().get_incoming_conn()
638    }
639
640    pub fn get_source_agent(&self) -> Agent {
641        self.get_agp_header().get_source()
642    }
643
644    pub fn get_name(&self) -> (AgentType, Option<u64>) {
645        self.get_agp_header().get_dst()
646    }
647
648    pub fn get_name_as_agent(&self) -> Agent {
649        let (a_type, a_id) = self.get_agp_header().get_dst();
650        let id = match a_id {
651            None => DEFAULT_AGENT_ID,
652            Some(id) => id,
653        };
654        Agent::new(a_type, id)
655    }
656
657    pub fn get_type(&self) -> &MessageType {
658        match &self.message_type {
659            Some(t) => t,
660            None => panic!("message type not found"),
661        }
662    }
663
664    pub fn get_payload(&self) -> Option<&Content> {
665        match &self.message_type {
666            Some(ProtoPublishType(p)) => p.msg.as_ref(),
667            Some(ProtoSubscribeType(_)) => panic!("payload not found"),
668            Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
669            None => panic!("payload not found"),
670        }
671    }
672
673    pub fn get_header_type(&self) -> SessionHeaderType {
674        self.get_session_header()
675            .header_type
676            .try_into()
677            .unwrap_or_default()
678    }
679
680    pub fn clear_agp_header(&mut self) {
681        self.get_agp_header_mut().clear();
682    }
683
684    pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
685        self.get_agp_header_mut().set_recv_from(recv_from);
686    }
687
688    pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
689        self.get_agp_header_mut().set_forward_to(forward_to);
690    }
691
692    pub fn set_error(&mut self, error: Option<bool>) {
693        self.get_agp_header_mut().set_error(error);
694    }
695
696    pub fn set_fanout(&mut self, fanout: u32) {
697        self.get_agp_header_mut().set_fanout(fanout);
698    }
699
700    pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
701        self.get_agp_header_mut().set_incoming_conn(incoming_conn);
702    }
703
704    pub fn set_error_flag(&mut self, error: Option<bool>) {
705        self.get_agp_header_mut().set_error_flag(error);
706    }
707
708    pub fn set_header_type(&mut self, header_type: SessionHeaderType) {
709        self.get_session_header_mut().set_header_type(header_type);
710    }
711
712    pub fn set_message_id(&mut self, message_id: u32) {
713        self.get_session_header_mut().set_message_id(message_id);
714    }
715
716    pub fn is_publish(&self) -> bool {
717        matches!(self.get_type(), MessageType::Publish(_))
718    }
719
720    pub fn is_subscribe(&self) -> bool {
721        matches!(self.get_type(), MessageType::Subscribe(_))
722    }
723
724    pub fn is_unsubscribe(&self) -> bool {
725        matches!(self.get_type(), MessageType::Unsubscribe(_))
726    }
727}
728
729impl AsRef<ProtoPublish> for ProtoMessage {
730    fn as_ref(&self) -> &ProtoPublish {
731        match &self.message_type {
732            Some(ProtoPublishType(p)) => p,
733            _ => panic!("message type is not publish"),
734        }
735    }
736}
737
738#[cfg(test)]
739mod tests {
740    use crate::{
741        messages::encoder::{Agent, AgentType},
742        pubsub::proto::pubsub::v1::SessionHeaderType,
743    };
744
745    use super::*;
746
747    fn test_subscription_template(
748        subscription: bool,
749        source: Agent,
750        name: AgentType,
751        name_id: Option<u64>,
752        flags: Option<AgpHeaderFlags>,
753    ) {
754        let sub = {
755            if subscription {
756                ProtoMessage::new_subscribe(&source, &name, name_id, flags.clone())
757            } else {
758                ProtoMessage::new_unsubscribe(&source, &name, name_id, flags.clone())
759            }
760        };
761
762        let flags = if flags.is_none() {
763            Some(AgpHeaderFlags::default())
764        } else {
765            flags
766        };
767
768        assert!(!sub.is_publish());
769        assert_eq!(sub.is_subscribe(), subscription);
770        assert_eq!(sub.is_unsubscribe(), !subscription);
771        assert_eq!(flags.as_ref().unwrap().recv_from, sub.get_recv_from());
772        assert_eq!(flags.as_ref().unwrap().forward_to, sub.get_forward_to());
773        assert_eq!(None, sub.try_get_incoming_conn());
774        assert_eq!(source, sub.get_source());
775        let (got_name, got_name_id) = sub.get_name();
776        assert_eq!(name, got_name);
777        assert_eq!(name_id, got_name_id);
778    }
779
780    fn test_publish_template(
781        source: Agent,
782        name: AgentType,
783        name_id: Option<u64>,
784        flags: Option<AgpHeaderFlags>,
785    ) {
786        let pub_msg = ProtoMessage::new_publish(
787            &source,
788            &name,
789            name_id,
790            flags.clone(),
791            "str",
792            "this is the content of the message".into(),
793        );
794
795        let flags = if flags.is_none() {
796            Some(AgpHeaderFlags::default())
797        } else {
798            flags
799        };
800
801        assert!(pub_msg.is_publish());
802        assert!(!pub_msg.is_subscribe());
803        assert!(!pub_msg.is_unsubscribe());
804        assert_eq!(flags.as_ref().unwrap().recv_from, pub_msg.get_recv_from());
805        assert_eq!(flags.as_ref().unwrap().forward_to, pub_msg.get_forward_to());
806        assert_eq!(None, pub_msg.try_get_incoming_conn());
807        assert_eq!(source, pub_msg.get_source());
808        let (got_name, got_name_id) = pub_msg.get_name();
809        assert_eq!(name, got_name);
810        assert_eq!(name_id, got_name_id);
811        assert_eq!(flags.as_ref().unwrap().fanout, pub_msg.get_fanout());
812    }
813
814    #[test]
815    fn test_subscription() {
816        let source = Agent::from_strings("org", "ns", "type", 1);
817        let name = AgentType::from_strings("org", "ns", "type");
818
819        // simple
820        test_subscription_template(true, source.clone(), name.clone(), None, None);
821
822        // with name id
823        test_subscription_template(true, source.clone(), name.clone(), Some(2), None);
824
825        // with recv from
826        test_subscription_template(
827            true,
828            source.clone(),
829            name.clone(),
830            None,
831            Some(AgpHeaderFlags::default().with_recv_from(50)),
832        );
833
834        // with forward to
835        test_subscription_template(
836            true,
837            source.clone(),
838            name.clone(),
839            None,
840            Some(AgpHeaderFlags::default().with_forward_to(30)),
841        );
842    }
843
844    #[test]
845    fn test_unsubscription() {
846        let source = Agent::from_strings("org", "ns", "type", 1);
847        let name = AgentType::from_strings("org", "ns", "type");
848
849        // simple
850        test_subscription_template(false, source.clone(), name.clone(), None, None);
851
852        // with name id
853        test_subscription_template(false, source.clone(), name.clone(), Some(2), None);
854
855        // with recv from
856        test_subscription_template(
857            false,
858            source.clone(),
859            name.clone(),
860            None,
861            Some(AgpHeaderFlags::default().with_recv_from(50)),
862        );
863
864        // with forward to
865        test_subscription_template(
866            false,
867            source.clone(),
868            name.clone(),
869            None,
870            Some(AgpHeaderFlags::default().with_forward_to(30)),
871        );
872    }
873
874    #[test]
875    fn test_publish() {
876        let source = Agent::from_strings("org", "ns", "type", 1);
877        let name = AgentType::from_strings("org", "ns", "type");
878
879        // simple
880        test_publish_template(
881            source.clone(),
882            name.clone(),
883            None,
884            Some(AgpHeaderFlags::default()),
885        );
886
887        // with name id
888        test_publish_template(
889            source.clone(),
890            name.clone(),
891            Some(2),
892            Some(AgpHeaderFlags::default()),
893        );
894
895        // with recv from
896        test_publish_template(
897            source.clone(),
898            name.clone(),
899            None,
900            Some(AgpHeaderFlags::default().with_recv_from(50)),
901        );
902
903        // with forward to
904        test_publish_template(
905            source.clone(),
906            name.clone(),
907            None,
908            Some(AgpHeaderFlags::default().with_forward_to(30)),
909        );
910
911        // with fanout
912        test_publish_template(
913            source.clone(),
914            name.clone(),
915            None,
916            Some(AgpHeaderFlags::default().with_fanout(2)),
917        );
918    }
919
920    #[test]
921    fn test_conversions() {
922        // ProtoAgent to Agent
923        let agent = Agent::from_strings("org", "ns", "type", 1);
924        let proto_agent = ProtoAgent::from(&agent);
925
926        assert_eq!(proto_agent.organization, agent.agent_type().organization());
927        assert_eq!(proto_agent.namespace, agent.agent_type().namespace());
928        assert_eq!(proto_agent.agent_type, agent.agent_type().agent_type());
929        assert_eq!(proto_agent.agent_id.unwrap(), agent.agent_id());
930
931        // AgentType to ProtoAgent
932        let agent_type = AgentType::from_strings("org", "ns", "type");
933        let proto_agent = ProtoAgent::from((&agent_type, Some(1)));
934
935        assert_eq!(proto_agent.organization, agent_type.organization());
936        assert_eq!(proto_agent.namespace, agent_type.namespace());
937        assert_eq!(proto_agent.agent_type, agent_type.agent_type());
938        assert_eq!(proto_agent.agent_id.unwrap(), 1);
939
940        // ProtoMessage to ProtoSubscribe
941        let proto_subscribe = ProtoMessage::new_subscribe(
942            &agent,
943            &agent_type,
944            Some(1),
945            Some(
946                AgpHeaderFlags::default()
947                    .with_recv_from(2)
948                    .with_forward_to(3),
949            ),
950        );
951        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
952        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), agent);
953        assert_eq!(
954            proto_subscribe.header.as_ref().unwrap().get_dst(),
955            (agent_type.clone(), Some(1))
956        );
957
958        // ProtoMessage to ProtoUnsubscribe
959        let proto_unsubscribe = ProtoMessage::new_unsubscribe(
960            &agent,
961            &agent_type,
962            Some(1),
963            Some(
964                AgpHeaderFlags::default()
965                    .with_recv_from(2)
966                    .with_forward_to(3),
967            ),
968        );
969        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
970        assert_eq!(
971            proto_unsubscribe.header.as_ref().unwrap().get_source(),
972            agent
973        );
974        assert_eq!(
975            proto_unsubscribe.header.as_ref().unwrap().get_dst(),
976            (agent_type.clone(), Some(1))
977        );
978
979        // ProtoMessage to ProtoPublish
980        let proto_publish = ProtoMessage::new_publish(
981            &agent,
982            &agent_type,
983            Some(1),
984            Some(
985                AgpHeaderFlags::default()
986                    .with_recv_from(2)
987                    .with_forward_to(3),
988            ),
989            "str",
990            "this is the content of the message".into(),
991        );
992        let proto_publish = ProtoPublish::from(proto_publish);
993        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), agent);
994        assert_eq!(
995            proto_publish.header.as_ref().unwrap().get_dst(),
996            (agent_type.clone(), Some(1))
997        );
998    }
999
1000    #[test]
1001    fn test_panic() {
1002        let source = Agent::from_strings("org", "ns", "type", 1);
1003        let name = AgentType::from_strings("org", "ns", "type");
1004
1005        // panic if AGP header is not found
1006        let msg = ProtoMessage::new_subscribe(
1007            &source,
1008            &name,
1009            None,
1010            Some(
1011                AgpHeaderFlags::default()
1012                    .with_recv_from(2)
1013                    .with_forward_to(3),
1014            ),
1015        );
1016
1017        // let's try to convert it to a unsubscribe
1018        // this should panic because the message type is not unsubscribe
1019        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
1020        assert!(result.is_err());
1021
1022        // try to convert to publish
1023        // this should panic because the message type is not publish
1024        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
1025        assert!(result.is_err());
1026
1027        // finally make sure the conversion to subscribe works
1028        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
1029        assert!(result.is_ok());
1030    }
1031
1032    #[test]
1033    fn test_panic_header() {
1034        // create a unusual AGP header
1035        let header = AgpHeader {
1036            source: None,
1037            destination: None,
1038            fanout: 0,
1039            recv_from: None,
1040            forward_to: None,
1041            incoming_conn: None,
1042            error: None,
1043        };
1044
1045        // the operations to retrieve source and destination should fail with panic
1046        let result = std::panic::catch_unwind(|| header.get_source());
1047        assert!(result.is_err());
1048
1049        let result = std::panic::catch_unwind(|| header.get_dst());
1050        assert!(result.is_err());
1051
1052        // The operations to retrieve recv_from and forward_to should not fail with panic
1053        let result = std::panic::catch_unwind(|| header.get_recv_from());
1054        assert!(result.is_ok());
1055
1056        let result = std::panic::catch_unwind(|| header.get_forward_to());
1057        assert!(result.is_ok());
1058
1059        // The operations to retrieve incoming_conn should not fail with panic
1060        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
1061        assert!(result.is_ok());
1062
1063        // The operations to retrieve error should not fail with panic
1064        let result = std::panic::catch_unwind(|| header.get_error());
1065        assert!(result.is_ok());
1066    }
1067
1068    #[test]
1069    fn test_panic_session_header() {
1070        // create a unusual session header
1071        let header = SessionHeader {
1072            header_type: 0,
1073            session_id: 0,
1074            message_id: 0,
1075        };
1076
1077        // the operations to retrieve session_id and message_id should not fail with panic
1078        let result = std::panic::catch_unwind(|| header.get_session_id());
1079        assert!(result.is_ok());
1080
1081        let result = std::panic::catch_unwind(|| header.get_message_id());
1082        assert!(result.is_ok());
1083    }
1084
1085    #[test]
1086    fn test_panic_proto_message() {
1087        // create a unusual proto message
1088        let message = ProtoMessage {
1089            metadata: HashMap::new(),
1090            message_type: None,
1091        };
1092
1093        // the operation to retrieve the header should fail with panic
1094        let result = std::panic::catch_unwind(|| message.get_agp_header());
1095        assert!(result.is_err());
1096
1097        // the operation to retrieve the message type should fail with panic
1098        let result = std::panic::catch_unwind(|| message.get_type());
1099        assert!(result.is_err());
1100
1101        // all the other ops should fail with panic as well as the header is not set
1102        let result = std::panic::catch_unwind(|| message.get_source());
1103        assert!(result.is_err());
1104        let result = std::panic::catch_unwind(|| message.get_name());
1105        assert!(result.is_err());
1106        let result = std::panic::catch_unwind(|| message.get_recv_from());
1107        assert!(result.is_err());
1108        let result = std::panic::catch_unwind(|| message.get_forward_to());
1109        assert!(result.is_err());
1110        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
1111        assert!(result.is_err());
1112        let result = std::panic::catch_unwind(|| message.get_fanout());
1113        assert!(result.is_err());
1114    }
1115
1116    #[test]
1117    fn test_service_type_to_int() {
1118        // Get total number of service types
1119        let total_service_types = SessionHeaderType::BeaconPubSub as i32;
1120
1121        for i in 0..total_service_types {
1122            // int -> ServiceType
1123            let service_type =
1124                SessionHeaderType::try_from(i).expect("failed to convert int to service type");
1125            let service_type_int = i32::from(service_type);
1126            assert_eq!(service_type_int, i32::from(service_type),);
1127        }
1128
1129        // Test invalid conversion
1130        let invalid_service_type = SessionHeaderType::try_from(total_service_types + 1);
1131        assert!(invalid_service_type.is_err());
1132    }
1133}