1use std::collections::HashMap;
5use std::fmt::Display;
6
7use tracing::debug;
8
9use super::encoder::Name;
10use crate::api::{
11 Content, MessageType, ProtoMessage, ProtoName, ProtoPublish, ProtoPublishType,
12 ProtoSessionType, ProtoSubscribe, ProtoSubscribeType, ProtoUnsubscribe, ProtoUnsubscribeType,
13 SessionHeader, SlimHeader,
14 proto::dataplane::v1::{OriginalName, SessionMessageType},
15};
16
17use thiserror::Error;
18use tracing::error;
19
/// Errors describing a missing part of a message.
///
/// `ProtoMessage::validate` returns all of these except
/// `IncomingConnectionNotFound`, which no code in this file constructs.
#[derive(Error, Debug, PartialEq)]
pub enum MessageError {
    /// The message has no SLIM header.
    #[error("SLIM header not found")]
    SlimHeaderNotFound,
    /// The SLIM header has no source name.
    #[error("source not found")]
    SourceNotFound,
    /// The SLIM header has no destination name.
    #[error("destination not found")]
    DestinationNotFound,
    /// A publish message has no session header.
    #[error("session header not found")]
    SessionHeaderNotFound,
    /// The message has no message type (publish/subscribe/unsubscribe) set.
    #[error("message type not found")]
    MessageTypeNotFound,
    /// No incoming connection is recorded.
    /// NOTE(review): not produced by any code in this file; presumably used by callers.
    #[error("incoming connection not found")]
    IncomingConnectionNotFound,
}
35
/// The string `"SLIM_IDENTITY"`; not referenced elsewhere in this file.
/// NOTE(review): presumably a metadata key carrying the sender identity — confirm with callers.
pub const SLIM_IDENTITY: &str = "SLIM_IDENTITY";
38
39impl From<&Name> for ProtoName {
41 fn from(name: &Name) -> Self {
42 Self {
43 component_0: name.components()[0],
44 component_1: name.components()[1],
45 component_2: name.components()[2],
46 component_3: name.components()[3],
47 }
48 }
49}
50
51impl Display for MessageType {
53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
54 match self {
55 MessageType::Publish(_) => write!(f, "publish"),
56 MessageType::Subscribe(_) => write!(f, "subscribe"),
57 MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
58 }
59 }
60}
61
/// Optional routing flags copied into a SLIM header by `SlimHeader::new`.
///
/// Start from `SlimHeaderFlags::default()` (fan-out 1, everything else unset)
/// and chain the `with_*` builders to set only the fields you need.
#[derive(Debug, Clone)]
pub struct SlimHeaderFlags {
    /// Fan-out value for the header.
    pub fanout: u32,
    /// Connection id for the header's `recv_from` field, if any.
    pub recv_from: Option<u64>,
    /// Connection id for the header's `forward_to` field, if any.
    pub forward_to: Option<u64>,
    /// Connection id for the header's `incoming_conn` field, if any.
    pub incoming_conn: Option<u64>,
    /// Value for the header's `error` field, if any.
    pub error: Option<bool>,
}

impl Default for SlimHeaderFlags {
    /// Fan-out of 1 with every optional field left as `None`.
    fn default() -> Self {
        Self::new(1, None, None, None, None)
    }
}

impl Display for SlimHeaderFlags {
    /// Renders every field as `name: value`, using `Debug` formatting for the
    /// optional ones.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            fanout,
            recv_from,
            forward_to,
            incoming_conn,
            error,
        } = self;

        write!(
            f,
            "fanout: {fanout}, recv_from: {recv_from:?}, forward_to: {forward_to:?}, incoming_conn: {incoming_conn:?}, error: {error:?}"
        )
    }
}

impl SlimHeaderFlags {
    /// Builds a flag set with every field given explicitly.
    pub fn new(
        fanout: u32,
        recv_from: Option<u64>,
        forward_to: Option<u64>,
        incoming_conn: Option<u64>,
        error: Option<bool>,
    ) -> Self {
        Self {
            fanout,
            recv_from,
            forward_to,
            incoming_conn,
            error,
        }
    }

    /// Returns `self` with `fanout` replaced.
    pub fn with_fanout(mut self, fanout: u32) -> Self {
        self.fanout = fanout;
        self
    }

    /// Returns `self` with `recv_from` set to `Some(recv_from)`.
    pub fn with_recv_from(mut self, recv_from: u64) -> Self {
        self.recv_from = Some(recv_from);
        self
    }

    /// Returns `self` with `forward_to` set to `Some(forward_to)`.
    pub fn with_forward_to(mut self, forward_to: u64) -> Self {
        self.forward_to = Some(forward_to);
        self
    }

    /// Returns `self` with `incoming_conn` set to `Some(incoming_conn)`.
    pub fn with_incoming_conn(mut self, incoming_conn: u64) -> Self {
        self.incoming_conn = Some(incoming_conn);
        self
    }

    /// Returns `self` with `error` set to `Some(error)`.
    pub fn with_error(mut self, error: bool) -> Self {
        self.error = Some(error);
        self
    }
}
143
144impl SlimHeader {
148 pub fn new(source: &Name, destination: &Name, flags: Option<SlimHeaderFlags>) -> Self {
149 let flags = flags.unwrap_or_default();
150
151 Self {
152 source: Some(ProtoName::from(source)),
153 destination: Some(ProtoName::from(destination)),
154 fanout: flags.fanout,
155 recv_from: flags.recv_from,
156 forward_to: flags.forward_to,
157 incoming_conn: flags.incoming_conn,
158 error: flags.error,
159 }
160 }
161
162 pub fn clear(&mut self) {
163 self.recv_from = None;
164 self.forward_to = None;
165 }
166
167 pub fn get_recv_from(&self) -> Option<u64> {
168 self.recv_from
169 }
170
171 pub fn get_forward_to(&self) -> Option<u64> {
172 self.forward_to
173 }
174
175 pub fn get_incoming_conn(&self) -> Option<u64> {
176 self.incoming_conn
177 }
178
179 pub fn get_error(&self) -> Option<bool> {
180 self.error
181 }
182
183 pub fn get_source(&self) -> Name {
184 match &self.source {
185 Some(source) => Name::from(source),
186 None => panic!("source not found"),
187 }
188 }
189
190 pub fn get_dst(&self) -> Name {
191 match &self.destination {
192 Some(destination) => Name::from(destination),
193 None => panic!("destination not found"),
194 }
195 }
196
197 pub fn set_source(&mut self, source: &Name) {
198 self.source = Some(ProtoName::from(source));
199 }
200
201 pub fn set_destination(&mut self, dst: &Name) {
202 self.destination = Some(ProtoName::from(dst));
203 }
204
205 pub fn get_fanout(&self) -> u32 {
206 self.fanout
207 }
208
209 pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
210 self.recv_from = recv_from;
211 }
212
213 pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
214 self.forward_to = forward_to;
215 }
216
217 pub fn set_error(&mut self, error: Option<bool>) {
218 self.error = error;
219 }
220
221 pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
222 self.incoming_conn = incoming_conn;
223 }
224
225 pub fn set_error_flag(&mut self, error: Option<bool>) {
226 self.error = error;
227 }
228
229 pub fn set_fanout(&mut self, fanout: u32) {
230 self.fanout = fanout;
231 }
232
233 pub fn get_in_out_connections(&self) -> (u64, Option<u64>) {
237 let incoming = self
239 .get_incoming_conn()
240 .expect("incoming connection not found");
241
242 if let Some(val) = self.get_recv_from() {
243 debug!(
244 "received recv_from command, update state on connection {}",
245 val
246 );
247 return (val, None);
248 }
249
250 if let Some(val) = self.get_forward_to() {
251 debug!(
252 "received forward_to command, update state and forward to connection {}",
253 val
254 );
255 return (incoming, Some(val));
256 }
257
258 (incoming, None)
260 }
261}
262
263impl SessionHeader {
267 pub fn new(
268 session_type: i32,
269 session_message_type: i32,
270 session_id: u32,
271 message_id: u32,
272 source: &Option<Name>,
273 destination: &Option<Name>,
274 ) -> Self {
275 let src = match source {
276 Some(name) => name.components_strings().map(|c| OriginalName {
277 component_0: c[0].clone(),
278 component_1: c[1].clone(),
279 component_2: c[2].clone(),
280 component_3: name.id(),
281 }),
282 None => None,
283 };
284
285 let dst = match destination {
286 Some(name) => name.components_strings().map(|c| OriginalName {
287 component_0: c[0].clone(),
288 component_1: c[1].clone(),
289 component_2: c[2].clone(),
290 component_3: name.id(),
291 }),
292 None => None,
293 };
294
295 Self {
296 session_type,
297 session_message_type,
298 session_id,
299 message_id,
300 source: src,
301 destination: dst,
302 }
303 }
304
305 pub fn get_session_id(&self) -> u32 {
306 self.session_id
307 }
308
309 pub fn get_message_id(&self) -> u32 {
310 self.message_id
311 }
312
313 pub fn set_session_id(&mut self, session_id: u32) {
314 self.session_id = session_id;
315 }
316
317 pub fn set_message_id(&mut self, message_id: u32) {
318 self.message_id = message_id;
319 }
320
321 pub fn clear(&mut self) {
322 self.session_id = 0;
323 self.message_id = 0;
324 }
325
326 pub fn get_source(&self) -> Option<Name> {
327 match &self.source {
328 Some(src) => {
329 let n = Name::from_strings([
330 src.component_0.clone(),
331 src.component_1.clone(),
332 src.component_2.clone(),
333 ])
334 .with_id(src.component_3);
335 Some(n)
336 }
337 None => None,
338 }
339 }
340
341 pub fn set_source(&mut self, source: &Name) {
342 let c_opt = source.components_strings();
343 if c_opt.is_none() {
344 return;
345 }
346
347 let c = c_opt.unwrap();
348
349 self.source = Some(OriginalName {
350 component_0: c[0].clone(),
351 component_1: c[1].clone(),
352 component_2: c[2].clone(),
353 component_3: source.id(),
354 });
355 }
356
357 pub fn get_destination(&self) -> Option<Name> {
358 match &self.destination {
359 Some(src) => {
360 let n = Name::from_strings([
361 src.component_0.clone(),
362 src.component_1.clone(),
363 src.component_2.clone(),
364 ])
365 .with_id(src.component_3);
366 Some(n)
367 }
368 None => None,
369 }
370 }
371
372 pub fn set_destination(&mut self, destination: &Name) {
373 let c_opt = destination.components_strings();
374 if c_opt.is_none() {
375 return;
376 }
377
378 let c = c_opt.unwrap();
379
380 self.destination = Some(OriginalName {
381 component_0: c[0].clone(),
382 component_1: c[1].clone(),
383 component_2: c[2].clone(),
384 component_3: destination.id(),
385 });
386 }
387}
388
389impl ProtoSubscribe {
392 pub fn new(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
393 let header = Some(SlimHeader::new(source, dst, flags));
394
395 ProtoSubscribe {
396 header,
397 component_0: dst.components_strings().unwrap()[0].clone(),
398 component_1: dst.components_strings().unwrap()[1].clone(),
399 component_2: dst.components_strings().unwrap()[2].clone(),
400 }
401 }
402}
403
404impl From<ProtoMessage> for ProtoSubscribe {
406 fn from(message: ProtoMessage) -> Self {
407 match message.message_type {
408 Some(ProtoSubscribeType(s)) => s,
409 _ => panic!("message type is not subscribe"),
410 }
411 }
412}
413
414impl ProtoUnsubscribe {
417 pub fn new(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
418 let header = Some(SlimHeader::new(source, dst, flags));
419
420 ProtoUnsubscribe {
421 header,
422 component_0: dst.components_strings().unwrap()[0].clone(),
423 component_1: dst.components_strings().unwrap()[1].clone(),
424 component_2: dst.components_strings().unwrap()[2].clone(),
425 }
426 }
427}
428
429impl From<ProtoMessage> for ProtoUnsubscribe {
431 fn from(message: ProtoMessage) -> Self {
432 match message.message_type {
433 Some(ProtoUnsubscribeType(u)) => u,
434 _ => panic!("message type is not unsubscribe"),
435 }
436 }
437}
438
439impl ProtoPublish {
442 pub fn with_header(
443 header: Option<SlimHeader>,
444 session: Option<SessionHeader>,
445 payload: Option<Content>,
446 ) -> Self {
447 ProtoPublish {
448 header,
449 session,
450 msg: payload,
451 }
452 }
453
454 pub fn new(
455 source: &Name,
456 dst: &Name,
457 flags: Option<SlimHeaderFlags>,
458 content_type: &str,
459 blob: Vec<u8>,
460 ) -> Self {
461 let slim_header = Some(SlimHeader::new(source, dst, flags));
462
463 let mut session_header = SessionHeader::default();
464 session_header.set_source(source);
465 session_header.set_destination(dst);
466
467 let msg = Some(Content {
468 content_type: content_type.to_string(),
469 blob,
470 });
471
472 Self::with_header(slim_header, Some(session_header), msg)
473 }
474
475 pub fn get_slim_header(&self) -> &SlimHeader {
476 self.header.as_ref().unwrap()
477 }
478
479 pub fn get_session_header(&self) -> &SessionHeader {
480 self.session.as_ref().unwrap()
481 }
482
483 pub fn get_slim_header_as_mut(&mut self) -> &mut SlimHeader {
484 self.header.as_mut().unwrap()
485 }
486
487 pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
488 self.session.as_mut().unwrap()
489 }
490
491 pub fn get_payload(&self) -> &Content {
492 self.msg.as_ref().unwrap()
493 }
494}
495
496impl From<ProtoMessage> for ProtoPublish {
498 fn from(message: ProtoMessage) -> Self {
499 match message.message_type {
500 Some(ProtoPublishType(p)) => p,
501 _ => panic!("message type is not publish"),
502 }
503 }
504}
505
506impl ProtoMessage {
509 fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
510 ProtoMessage {
511 metadata,
512 message_type: Some(message_type),
513 }
514 }
515
516 pub fn new_subscribe(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
517 let subscribe = ProtoSubscribe::new(source, dst, flags);
518
519 Self::new(HashMap::new(), ProtoSubscribeType(subscribe))
520 }
521
522 pub fn new_unsubscribe(source: &Name, dst: &Name, flags: Option<SlimHeaderFlags>) -> Self {
523 let unsubscribe = ProtoUnsubscribe::new(source, dst, flags);
524
525 Self::new(HashMap::new(), ProtoUnsubscribeType(unsubscribe))
526 }
527
528 pub fn new_publish(
529 source: &Name,
530 dst: &Name,
531 flags: Option<SlimHeaderFlags>,
532 content_type: &str,
533 blob: Vec<u8>,
534 ) -> Self {
535 let publish = ProtoPublish::new(source, dst, flags, content_type, blob);
536
537 Self::new(HashMap::new(), ProtoPublishType(publish))
538 }
539
540 pub fn new_publish_with_headers(
541 slim_header: Option<SlimHeader>,
542 session_header: Option<SessionHeader>,
543 content_type: &str,
544 blob: Vec<u8>,
545 ) -> Self {
546 let publish = ProtoPublish::with_header(
547 slim_header,
548 session_header,
549 Some(Content {
550 content_type: content_type.to_string(),
551 blob,
552 }),
553 );
554
555 Self::new(HashMap::new(), ProtoPublishType(publish))
556 }
557
558 pub fn validate(&self) -> Result<(), MessageError> {
560 if self.message_type.is_none() {
562 return Err(MessageError::MessageTypeNotFound);
563 }
564
565 if self.try_get_slim_header().is_none() {
567 return Err(MessageError::SlimHeaderNotFound);
568 }
569
570 let slim_header = self.get_slim_header();
572
573 if slim_header.source.is_none() {
575 return Err(MessageError::SourceNotFound);
576 }
577 if slim_header.destination.is_none() {
578 return Err(MessageError::DestinationNotFound);
579 }
580
581 match &self.message_type {
582 Some(ProtoPublishType(p)) => {
583 if p.header.is_none() {
585 return Err(MessageError::SlimHeaderNotFound);
586 }
587
588 if p.session.is_none() {
590 return Err(MessageError::SessionHeaderNotFound);
591 }
592 }
593 Some(ProtoSubscribeType(s)) => {
594 if s.header.is_none() {
595 return Err(MessageError::SlimHeaderNotFound);
596 }
597 }
598 Some(ProtoUnsubscribeType(u)) => {
599 if u.header.is_none() {
600 return Err(MessageError::SlimHeaderNotFound);
601 }
602 }
603 None => return Err(MessageError::MessageTypeNotFound),
604 }
605
606 Ok(())
607 }
608
609 pub fn insert_metadata(&mut self, key: String, val: String) {
612 self.metadata.insert(key, val);
613 }
614
615 pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
617 self.metadata.remove(key)
618 }
619
620 pub fn contains_metadata(&self, key: &str) -> bool {
621 self.metadata.contains_key(key)
622 }
623
624 pub fn get_metadata(&self, key: &str) -> Option<&String> {
625 self.metadata.get(key)
626 }
627
628 pub fn get_metadata_map(&self) -> HashMap<String, String> {
629 self.metadata.clone()
630 }
631
632 pub fn set_metadata_map(&mut self, map: HashMap<String, String>) {
633 for (k, v) in map.iter() {
634 self.insert_metadata(k.to_string(), v.to_string());
635 }
636 }
637
638 pub fn get_slim_header(&self) -> &SlimHeader {
639 match &self.message_type {
640 Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
641 Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
642 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
643 None => panic!("SLIM header not found"),
644 }
645 }
646
647 pub fn get_slim_header_mut(&mut self) -> &mut SlimHeader {
648 match &mut self.message_type {
649 Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
650 Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
651 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
652 None => panic!("SLIM header not found"),
653 }
654 }
655
656 pub fn try_get_slim_header(&self) -> Option<&SlimHeader> {
657 match &self.message_type {
658 Some(ProtoPublishType(publish)) => publish.header.as_ref(),
659 Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
660 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
661 None => None,
662 }
663 }
664
665 pub fn get_session_header(&self) -> &SessionHeader {
666 match &self.message_type {
667 Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
668 Some(ProtoSubscribeType(_)) => panic!("session header not found"),
669 Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
670 None => panic!("session header not found"),
671 }
672 }
673
674 pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
675 match &mut self.message_type {
676 Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
677 Some(ProtoSubscribeType(_)) => panic!("session header not found"),
678 Some(ProtoUnsubscribeType(_)) => panic!("session header not found"),
679 None => panic!("session header not found"),
680 }
681 }
682
683 pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
684 match &self.message_type {
685 Some(ProtoPublishType(publish)) => publish.session.as_ref(),
686 Some(ProtoSubscribeType(_)) => None,
687 Some(ProtoUnsubscribeType(_)) => None,
688 None => None,
689 }
690 }
691
692 pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
693 match &mut self.message_type {
694 Some(ProtoPublishType(publish)) => publish.session.as_mut(),
695 Some(ProtoSubscribeType(_)) => None,
696 Some(ProtoUnsubscribeType(_)) => None,
697 None => None,
698 }
699 }
700
701 pub fn get_id(&self) -> u32 {
702 self.get_session_header().get_message_id()
703 }
704
705 pub fn get_source(&self) -> Name {
706 if let Some(ProtoPublishType(pubslih)) = &self.message_type {
707 if let Some(s) = pubslih.get_session_header().get_source() {
709 return s;
710 }
711 }
712
713 self.get_slim_header().get_source()
715 }
716
717 pub fn get_fanout(&self) -> u32 {
718 self.get_slim_header().get_fanout()
719 }
720
721 pub fn get_recv_from(&self) -> Option<u64> {
722 self.get_slim_header().get_recv_from()
723 }
724
725 pub fn get_forward_to(&self) -> Option<u64> {
726 self.get_slim_header().get_forward_to()
727 }
728
729 pub fn get_error(&self) -> Option<bool> {
730 self.get_slim_header().get_error()
731 }
732
733 pub fn get_incoming_conn(&self) -> u64 {
734 self.get_slim_header().get_incoming_conn().unwrap()
735 }
736
737 pub fn try_get_incoming_conn(&self) -> Option<u64> {
738 self.get_slim_header().get_incoming_conn()
739 }
740
741 pub fn get_dst(&self) -> Name {
742 match &self.message_type {
743 Some(ProtoPublishType(pubslih)) => {
744 if let Some(d) = pubslih.get_session_header().get_destination() {
746 return d;
747 }
748 self.get_slim_header().get_dst()
750 }
751 Some(ProtoSubscribeType(subscribe)) => {
752 let dst = self.get_slim_header().get_dst();
753 Name::from_strings([
755 subscribe.component_0.clone(),
756 subscribe.component_1.clone(),
757 subscribe.component_2.clone(),
758 ])
759 .with_id(dst.id())
760 }
761 Some(ProtoUnsubscribeType(unsubscribe)) => {
762 let dst = self.get_slim_header().get_dst();
763 Name::from_strings([
765 unsubscribe.component_0.clone(),
766 unsubscribe.component_1.clone(),
767 unsubscribe.component_2.clone(),
768 ])
769 .with_id(dst.id())
770 }
771 None => {
772 self.get_slim_header().get_dst()
774 }
775 }
776 }
777
778 pub fn get_type(&self) -> &MessageType {
779 match &self.message_type {
780 Some(t) => t,
781 None => panic!("message type not found"),
782 }
783 }
784
785 pub fn get_payload(&self) -> Option<&Content> {
786 match &self.message_type {
787 Some(ProtoPublishType(p)) => p.msg.as_ref(),
788 Some(ProtoSubscribeType(_)) => panic!("payload not found"),
789 Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
790 None => panic!("payload not found"),
791 }
792 }
793
794 pub fn get_session_message_type(&self) -> SessionMessageType {
795 self.get_session_header()
796 .session_message_type
797 .try_into()
798 .unwrap_or_default()
799 }
800
801 pub fn clear_slim_header(&mut self) {
802 self.get_slim_header_mut().clear();
803 }
804
805 pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
806 self.get_slim_header_mut().set_recv_from(recv_from);
807 }
808
809 pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
810 self.get_slim_header_mut().set_forward_to(forward_to);
811 }
812
813 pub fn set_error(&mut self, error: Option<bool>) {
814 self.get_slim_header_mut().set_error(error);
815 }
816
817 pub fn set_fanout(&mut self, fanout: u32) {
818 self.get_slim_header_mut().set_fanout(fanout);
819 }
820
821 pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
822 self.get_slim_header_mut().set_incoming_conn(incoming_conn);
823 }
824
825 pub fn set_error_flag(&mut self, error: Option<bool>) {
826 self.get_slim_header_mut().set_error_flag(error);
827 }
828
829 pub fn set_session_message_type(&mut self, message_type: SessionMessageType) {
830 self.get_session_header_mut()
831 .set_session_message_type(message_type);
832 }
833
834 pub fn set_session_type(&mut self, session_type: ProtoSessionType) {
835 self.get_session_header_mut().set_session_type(session_type);
836 }
837
838 pub fn get_session_type(&self) -> ProtoSessionType {
839 self.get_session_header().session_type()
840 }
841
842 pub fn set_message_id(&mut self, message_id: u32) {
843 self.get_session_header_mut().set_message_id(message_id);
844 }
845
846 pub fn is_publish(&self) -> bool {
847 matches!(self.get_type(), MessageType::Publish(_))
848 }
849
850 pub fn is_subscribe(&self) -> bool {
851 matches!(self.get_type(), MessageType::Subscribe(_))
852 }
853
854 pub fn is_unsubscribe(&self) -> bool {
855 matches!(self.get_type(), MessageType::Unsubscribe(_))
856 }
857}
858
859impl AsRef<ProtoPublish> for ProtoMessage {
860 fn as_ref(&self) -> &ProtoPublish {
861 match &self.message_type {
862 Some(ProtoPublishType(p)) => p,
863 _ => panic!("message type is not publish"),
864 }
865 }
866}
867
#[cfg(test)]
mod tests {
    use crate::{api::proto::dataplane::v1::SessionMessageType, messages::encoder::Name};

    use super::*;

    /// Builds a subscribe (`subscription == true`) or unsubscribe message and
    /// checks its accessors against the inputs. `None` flags are compared
    /// against `SlimHeaderFlags::default()`.
    fn test_subscription_template(
        subscription: bool,
        source: Name,
        dst: Name,
        flags: Option<SlimHeaderFlags>,
    ) {
        let sub = if subscription {
            ProtoMessage::new_subscribe(&source, &dst, flags.clone())
        } else {
            ProtoMessage::new_unsubscribe(&source, &dst, flags.clone())
        };

        let flags = flags.unwrap_or_default();

        assert!(!sub.is_publish());
        assert_eq!(sub.is_subscribe(), subscription);
        assert_eq!(sub.is_unsubscribe(), !subscription);
        assert_eq!(flags.recv_from, sub.get_recv_from());
        assert_eq!(flags.forward_to, sub.get_forward_to());
        assert_eq!(flags.fanout, sub.get_fanout());
        assert_eq!(None, sub.try_get_incoming_conn());
        assert_eq!(source, sub.get_source());
        assert_eq!(dst, sub.get_dst());
    }

    /// Builds a publish message and checks its accessors against the inputs.
    /// `None` flags are compared against `SlimHeaderFlags::default()`.
    fn test_publish_template(source: Name, dst: Name, flags: Option<SlimHeaderFlags>) {
        let pub_msg = ProtoMessage::new_publish(
            &source,
            &dst,
            flags.clone(),
            "str",
            "this is the content of the message".into(),
        );

        let flags = flags.unwrap_or_default();

        assert!(pub_msg.is_publish());
        assert!(!pub_msg.is_subscribe());
        assert!(!pub_msg.is_unsubscribe());
        assert_eq!(flags.recv_from, pub_msg.get_recv_from());
        assert_eq!(flags.forward_to, pub_msg.get_forward_to());
        assert_eq!(None, pub_msg.try_get_incoming_conn());
        assert_eq!(source, pub_msg.get_source());
        assert_eq!(dst, pub_msg.get_dst());
        assert_eq!(flags.fanout, pub_msg.get_fanout());
    }

    #[test]
    fn test_subscription() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);

        // Missing flags and explicit default flags must behave the same.
        test_subscription_template(true, source.clone(), dst.clone(), None);
        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default()),
        );

        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );
    }

    #[test]
    fn test_unsubscription() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);

        // Missing flags and explicit default flags must behave the same.
        test_subscription_template(false, source.clone(), dst.clone(), None);
        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default()),
        );

        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );
    }

    #[test]
    fn test_publish() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let mut dst = Name::from_strings(["org", "ns", "type"]);

        test_publish_template(
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default()),
        );

        dst.set_id(2);
        test_publish_template(
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default()),
        );
        dst.reset_id();

        test_publish_template(
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_publish_template(
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );

        test_publish_template(
            source.clone(),
            dst.clone(),
            Some(SlimHeaderFlags::default().with_fanout(2)),
        );
    }

    #[test]
    fn test_conversions() {
        let name = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let proto_name = ProtoName::from(&name);

        assert_eq!(proto_name.component_0, name.components()[0]);
        assert_eq!(proto_name.component_1, name.components()[1]);
        assert_eq!(proto_name.component_2, name.components()[2]);
        assert_eq!(proto_name.component_3, name.components()[3]);

        let name_from_proto = Name::from(&proto_name);
        assert_eq!(name_from_proto.components()[0], proto_name.component_0);
        assert_eq!(name_from_proto.components()[1], proto_name.component_1);
        assert_eq!(name_from_proto.components()[2], proto_name.component_2);
        assert_eq!(name_from_proto.components()[3], proto_name.component_3);

        let dst = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let proto_subscribe = ProtoMessage::new_subscribe(
            &name,
            &dst,
            Some(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            ),
        );
        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_dst(), dst);

        let proto_unsubscribe = ProtoMessage::new_unsubscribe(
            &name,
            &dst,
            Some(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            ),
        );
        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
        assert_eq!(
            proto_unsubscribe.header.as_ref().unwrap().get_source(),
            name
        );
        assert_eq!(proto_unsubscribe.header.as_ref().unwrap().get_dst(), dst);

        let proto_publish = ProtoMessage::new_publish(
            &name,
            &dst,
            Some(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            ),
            "str",
            "this is the content of the message".into(),
        );
        let proto_publish = ProtoPublish::from(proto_publish);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_dst(), dst);
    }

    #[test]
    fn test_panic() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);

        let msg = ProtoMessage::new_subscribe(
            &source,
            &dst,
            Some(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            ),
        );

        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_header() {
        let header = SlimHeader {
            source: None,
            destination: None,
            fanout: 0,
            recv_from: None,
            forward_to: None,
            incoming_conn: None,
            error: None,
        };

        let result = std::panic::catch_unwind(|| header.get_source());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| header.get_dst());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| header.get_recv_from());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_forward_to());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_error());
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_session_header() {
        let header = SessionHeader::new(0, 0, 0, 0, &None, &None);

        let result = std::panic::catch_unwind(|| header.get_session_id());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_message_id());
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_proto_message() {
        let message = ProtoMessage {
            metadata: HashMap::new(),
            message_type: None,
        };

        let result = std::panic::catch_unwind(|| message.get_slim_header());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| message.get_type());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| message.get_source());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_dst());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_recv_from());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_forward_to());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_fanout());
        assert!(result.is_err());
    }

    #[test]
    fn test_service_type_to_int() {
        // `ChannelMlsAck` is taken as the highest variant: every value from 0
        // up to and including it must convert, and the next one must not.
        let last_service_type = SessionMessageType::ChannelMlsAck as i32;

        for i in 0..=last_service_type {
            let service_type =
                SessionMessageType::try_from(i).expect("failed to convert int to service type");
            // Round-trip: converting back must yield the original integer.
            assert_eq!(i32::from(service_type), i);
        }

        let invalid_service_type = SessionMessageType::try_from(last_service_type + 1);
        assert!(invalid_service_type.is_err());
    }
}