1use std::fmt::Display;
5use std::{collections::HashMap, time::Duration};
6
7use super::encoder::Name;
8use crate::api::proto::dataplane::v1::{GroupClosePayload, GroupNackPayload, PingPayload};
9use crate::api::{
10 Content, LinkNegotiationPayload, MessageType, ProtoLink, ProtoLinkMessageType, ProtoLinkType,
11 ProtoMessage, ProtoName, ProtoPublish, ProtoPublishType, ProtoSessionType, ProtoSubscribe,
12 ProtoSubscribeType, ProtoSubscriptionAck, ProtoSubscriptionAckType, ProtoUnsubscribe,
13 ProtoUnsubscribeType, SessionHeader, SlimHeader,
14 proto::dataplane::v1::{
15 ApplicationPayload, CommandPayload, DiscoveryReplyPayload, DiscoveryRequestPayload,
16 EncodedName, GroupAckPayload, GroupAddPayload, GroupProposalPayload, GroupRemovePayload,
17 GroupWelcomePayload, JoinReplyPayload, JoinRequestPayload, LeaveReplyPayload,
18 LeaveRequestPayload, MlsPayload, SessionMessageType, StringName, TimerSettings,
19 command_payload::CommandPayloadType, content::ContentType,
20 },
21};
22
23use thiserror::Error;
24
/// String constant `"IS_MODERATOR"`.
///
/// NOTE(review): presumably used as a message metadata key marking the
/// moderator role — confirm at the use sites.
pub const IS_MODERATOR: &str = "IS_MODERATOR";

/// String constant `"DELETE_GROUP"`.
///
/// NOTE(review): presumably a metadata key requesting group deletion — confirm.
pub const DELETE_GROUP: &str = "DELETE_GROUP";

/// String constant `"PUBLISH_TO"`.
///
/// NOTE(review): presumably a metadata key steering where a publish is sent — confirm.
pub const PUBLISH_TO: &str = "PUBLISH_TO";

/// String constant `"DISCONNECTION_DETECTED"`.
///
/// NOTE(review): presumably a metadata key signalling a detected disconnection — confirm.
pub const DISCONNECTION_DETECTED: &str = "DISCONNECTION_DETECTED";

/// String constant `"LEAVING_SESSION"`.
///
/// NOTE(review): presumably a metadata key signalling that a participant is leaving — confirm.
pub const LEAVING_SESSION: &str = "LEAVING_SESSION";

/// Textual boolean `"TRUE"`; presumably the value paired with the keys above — confirm.
pub const TRUE_VAL: &str = "TRUE";

/// Textual boolean `"FALSE"`; presumably the value paired with the keys above — confirm.
pub const FALSE_VAL: &str = "FALSE";

/// Named maximum for publish ids: `u32::MAX / 2` (= 2_147_483_647).
///
/// NOTE(review): the reason only half of the `u32` range is used is not visible here — confirm.
pub const MAX_PUBLISH_ID: u32 = u32::MAX / 2;
61
62#[derive(Error, Debug, PartialEq)]
63pub enum MessageError {
64 #[error("SLIM header not found")]
65 SlimHeaderNotFound,
66 #[error("source not found")]
67 SourceNotFound,
68 #[error("destination not found")]
69 DestinationNotFound,
70 #[error("session header not found")]
71 SessionHeaderNotFound,
72 #[error("message type not found")]
73 MessageTypeNotFound,
74 #[error("incoming connection not found")]
75 IncomingConnectionNotFound,
76 #[error("content type is not set")]
77 ContentTypeNotSet,
78 #[error("content is not an application payload")]
79 NotApplicationPayload,
80 #[error("content is not a command payload")]
81 NotCommandPayload,
82 #[error("link type is not set")]
83 LinkTypeNotSet,
84 #[error("invalid command payload type: expected {expected}, got {got}")]
85 InvalidCommandPayloadType {
86 expected: Box<String>,
87 got: Box<String>,
88 },
89
90 #[error("builder error: source is required")]
92 BuilderErrorSourceRequired,
93 #[error("builder error: destination is required")]
94 BuilderErrorDestinationRequired,
95}
96
97impl From<&Name> for ProtoName {
99 fn from(name: &Name) -> Self {
100 Self {
101 name: Some(EncodedName {
102 component_0: name.components()[0],
103 component_1: name.components()[1],
104 component_2: name.components()[2],
105 component_3: name.components()[3],
106 }),
107 str_name: Some(StringName {
108 str_component_0: name.components_strings()[0].clone(),
109 str_component_1: name.components_strings()[1].clone(),
110 str_component_2: name.components_strings()[2].clone(),
111 }),
112 }
113 }
114}
115
116impl Display for MessageType {
118 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
119 match self {
120 MessageType::Publish(_) => write!(f, "publish"),
121 MessageType::Subscribe(_) => write!(f, "subscribe"),
122 MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
123 MessageType::Link(_) => write!(f, "link"),
124 MessageType::SubscriptionAck(_) => write!(f, "subscription_ack"),
125 }
126 }
127}
128
/// Routing flags that are copied into a `SlimHeader` when one is built.
#[derive(Debug, Clone)]
pub struct SlimHeaderFlags {
    /// Fanout value; defaults to 1.
    pub fanout: u32,
    /// Optional `recv_from` connection id.
    pub recv_from: Option<u64>,
    /// Optional `forward_to` connection id.
    pub forward_to: Option<u64>,
    /// Optional id of the connection the message arrived on.
    pub incoming_conn: Option<u64>,
    /// Optional error flag.
    pub error: Option<bool>,
}

impl Default for SlimHeaderFlags {
    /// Fanout of 1 with every optional flag left unset.
    fn default() -> Self {
        Self::new(1, None, None, None, None)
    }
}

impl Display for SlimHeaderFlags {
    /// Renders every field as `name: value`; the `Option` fields use their `Debug` form.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            fanout,
            recv_from,
            forward_to,
            incoming_conn,
            error,
        } = self;
        write!(
            f,
            "fanout: {fanout}, recv_from: {recv_from:?}, forward_to: {forward_to:?}, incoming_conn: {incoming_conn:?}, error: {error:?}"
        )
    }
}

impl SlimHeaderFlags {
    /// Builds a flag set from explicit values.
    pub fn new(
        fanout: u32,
        recv_from: Option<u64>,
        forward_to: Option<u64>,
        incoming_conn: Option<u64>,
        error: Option<bool>,
    ) -> Self {
        Self {
            fanout,
            recv_from,
            forward_to,
            incoming_conn,
            error,
        }
    }

    /// Returns the flags with `fanout` replaced.
    pub fn with_fanout(mut self, fanout: u32) -> Self {
        self.fanout = fanout;
        self
    }

    /// Returns the flags with `recv_from` set.
    pub fn with_recv_from(mut self, recv_from: u64) -> Self {
        self.recv_from = Some(recv_from);
        self
    }

    /// Returns the flags with `forward_to` set.
    pub fn with_forward_to(mut self, forward_to: u64) -> Self {
        self.forward_to = Some(forward_to);
        self
    }

    /// Returns the flags with `incoming_conn` set.
    pub fn with_incoming_conn(mut self, incoming_conn: u64) -> Self {
        self.incoming_conn = Some(incoming_conn);
        self
    }

    /// Returns the flags with `error` set.
    pub fn with_error(mut self, error: bool) -> Self {
        self.error = Some(error);
        self
    }
}
210
211impl SlimHeader {
215 pub fn new(
216 source: &Name,
217 destination: &Name,
218 identity: &str,
219 flags: Option<SlimHeaderFlags>,
220 ) -> Self {
221 let flags = flags.unwrap_or_default();
222
223 Self {
224 source: Some(ProtoName::from(source)),
225 destination: Some(ProtoName::from(destination)),
226 identity: identity.to_string(),
227 fanout: flags.fanout,
228 recv_from: flags.recv_from,
229 forward_to: flags.forward_to,
230 incoming_conn: flags.incoming_conn,
231 error: flags.error,
232 }
233 }
234
235 pub fn clear_flags(&mut self) {
236 self.recv_from = None;
237 self.forward_to = None;
238 }
239
240 pub fn get_fanout(&self) -> u32 {
241 self.fanout
242 }
243
244 pub fn get_recv_from(&self) -> Option<u64> {
245 self.recv_from
246 }
247
248 pub fn get_forward_to(&self) -> Option<u64> {
249 self.forward_to
250 }
251
252 pub fn get_incoming_conn(&self) -> Option<u64> {
253 self.incoming_conn
254 }
255
256 pub fn get_error(&self) -> Option<bool> {
257 self.error
258 }
259
260 pub fn get_source(&self) -> Name {
261 match &self.source {
262 Some(source) => Name::from(source),
263 None => panic!("source not found"),
264 }
265 }
266
267 pub fn get_dst(&self) -> Name {
268 match &self.destination {
269 Some(destination) => Name::from(destination),
270 None => panic!("destination not found"),
271 }
272 }
273
274 pub fn get_identity(&self) -> String {
275 self.identity.clone()
276 }
277
278 pub fn set_source(&mut self, source: &Name) {
279 self.source = Some(ProtoName::from(source));
280 }
281
282 pub fn set_destination(&mut self, dst: &Name) {
283 self.destination = Some(ProtoName::from(dst));
284 }
285
286 pub fn set_identity(&mut self, identity: String) {
287 self.identity = identity;
288 }
289
290 pub fn set_fanout(&mut self, fanout: u32) {
291 self.fanout = fanout;
292 }
293
294 pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
295 self.recv_from = recv_from;
296 }
297
298 pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
299 self.forward_to = forward_to;
300 }
301
302 pub fn set_error(&mut self, error: Option<bool>) {
303 self.error = error;
304 }
305
306 pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
307 self.incoming_conn = incoming_conn;
308 }
309
310 pub fn set_error_flag(&mut self, error: Option<bool>) {
311 self.error = error;
312 }
313
314 pub(crate) fn get_connections(&self) -> (u64, Option<u64>, Option<u64>) {
316 let incoming = self
318 .get_incoming_conn()
319 .expect("incoming connection not found");
320
321 (incoming, self.get_recv_from(), self.get_forward_to())
322 }
323}
324
325impl SessionHeader {
329 pub fn new(
330 session_type: i32,
331 session_message_type: i32,
332 session_id: u32,
333 message_id: u32,
334 ) -> Self {
335 Self {
336 session_type,
337 session_message_type,
338 session_id,
339 message_id,
340 }
341 }
342
343 pub fn get_session_id(&self) -> u32 {
344 self.session_id
345 }
346
347 pub fn get_message_id(&self) -> u32 {
348 self.message_id
349 }
350
351 pub fn set_session_id(&mut self, session_id: u32) {
352 self.session_id = session_id;
353 }
354
355 pub fn set_message_id(&mut self, message_id: u32) {
356 self.message_id = message_id;
357 }
358
359 pub fn clear(&mut self) {
360 self.session_id = 0;
361 self.message_id = 0;
362 }
363}
364
365impl SessionMessageType {
368 pub fn is_command_message(&self) -> bool {
370 matches!(
371 self,
372 SessionMessageType::DiscoveryRequest
373 | SessionMessageType::DiscoveryReply
374 | SessionMessageType::JoinRequest
375 | SessionMessageType::JoinReply
376 | SessionMessageType::LeaveRequest
377 | SessionMessageType::LeaveReply
378 | SessionMessageType::GroupAdd
379 | SessionMessageType::GroupRemove
380 | SessionMessageType::GroupWelcome
381 | SessionMessageType::GroupClose
382 | SessionMessageType::GroupProposal
383 | SessionMessageType::GroupAck
384 | SessionMessageType::GroupNack
385 | SessionMessageType::Ping
386 )
387 }
388}
389
390impl ProtoSubscribe {
393 fn new(
394 source: &Name,
395 dst: &Name,
396 identity: Option<&str>,
397 flags: Option<SlimHeaderFlags>,
398 ) -> Self {
399 let id = identity.unwrap_or("");
400 let header = Some(SlimHeader::new(source, dst, id, flags));
401
402 ProtoSubscribe {
403 header,
404 subscription_id: 0,
405 }
406 }
407}
408
409impl From<ProtoMessage> for ProtoSubscribe {
411 fn from(message: ProtoMessage) -> Self {
412 match message.message_type {
413 Some(ProtoSubscribeType(s)) => s,
414 _ => panic!("message type is not subscribe"),
415 }
416 }
417}
418
419impl ProtoUnsubscribe {
422 fn new(
423 source: &Name,
424 dst: &Name,
425 identity: Option<&str>,
426 flags: Option<SlimHeaderFlags>,
427 ) -> Self {
428 let id = identity.unwrap_or("");
429 let header = Some(SlimHeader::new(source, dst, id, flags));
430
431 ProtoUnsubscribe {
432 header,
433 subscription_id: 0,
434 }
435 }
436}
437
438impl From<ProtoMessage> for ProtoUnsubscribe {
440 fn from(message: ProtoMessage) -> Self {
441 match message.message_type {
442 Some(ProtoUnsubscribeType(u)) => u,
443 _ => panic!("message type is not unsubscribe"),
444 }
445 }
446}
447
448impl ProtoPublish {
451 fn with_header(
452 header: Option<SlimHeader>,
453 session: Option<SessionHeader>,
454 payload: Option<Content>,
455 ) -> Self {
456 ProtoPublish {
457 header,
458 session,
459 msg: payload,
460 }
461 }
462
463 pub fn get_slim_header(&self) -> &SlimHeader {
464 self.header.as_ref().unwrap()
465 }
466
467 pub fn get_session_header(&self) -> &SessionHeader {
468 self.session.as_ref().unwrap()
469 }
470
471 pub fn get_slim_header_as_mut(&mut self) -> &mut SlimHeader {
472 self.header.as_mut().unwrap()
473 }
474
475 pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
476 self.session.as_mut().unwrap()
477 }
478
479 pub fn get_payload(&self) -> &Content {
480 self.msg.as_ref().unwrap()
481 }
482
483 pub fn set_payload(&mut self, payload: Content) {
484 self.msg = Some(payload);
485 }
486
487 pub fn is_command(&self) -> bool {
488 match &self.get_payload().content_type.as_ref().unwrap() {
489 ContentType::AppPayload(_) => false,
490 ContentType::CommandPayload(_) => true,
491 }
492 }
493
494 pub fn get_application_payload(&self) -> &ApplicationPayload {
495 match self.get_payload().content_type.as_ref().unwrap() {
496 ContentType::AppPayload(application_payload) => application_payload,
497 ContentType::CommandPayload(_) => panic!("the payload is not an application payload"),
498 }
499 }
500
501 pub fn get_command_payload(&self) -> &CommandPayload {
502 match &self.get_payload().content_type.as_ref().unwrap() {
503 ContentType::AppPayload(_) => panic!("the payaoad is not a command payload"),
504 ContentType::CommandPayload(command_payload) => command_payload,
505 }
506 }
507}
508
509impl From<ProtoMessage> for ProtoPublish {
511 fn from(message: ProtoMessage) -> Self {
512 match message.message_type {
513 Some(ProtoPublishType(p)) => p,
514 _ => panic!("message type is not publish"),
515 }
516 }
517}
518
/// Generates `pub fn name(&self) -> Result<&T, MessageError>` methods.
///
/// Each generated method calls `self.extract_command_payload()` and then the
/// named getter on the resulting command payload, propagating either error.
macro_rules! impl_payload_extractors {
    ($($method_name:ident => $getter_method:ident($payload_type:ty)),* $(,)?) => {
        $(
            #[doc = concat!(
                "Extracts the command payload and narrows it with `",
                stringify!($getter_method),
                "`."
            )]
            pub fn $method_name(&self) -> Result<&$payload_type, MessageError> {
                self.extract_command_payload()
                    .and_then(|command| command.$getter_method())
            }
        )*
    };
}
532
533impl ProtoMessage {
534 fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
535 ProtoMessage {
536 metadata,
537 message_type: Some(message_type),
538 }
539 }
540
541 fn validate_link(link: &ProtoLink) -> Result<(), MessageError> {
542 if link.link_type.is_none() {
543 return Err(MessageError::LinkTypeNotSet);
544 }
545 Ok(())
546 }
547
548 fn validate_routed_header(slim_header: &SlimHeader) -> Result<(), MessageError> {
549 if slim_header.source.is_none() {
550 return Err(MessageError::SourceNotFound);
551 }
552 if slim_header.destination.is_none() {
553 return Err(MessageError::DestinationNotFound);
554 }
555 Ok(())
556 }
557
558 fn validate_publish(p: &ProtoPublish) -> Result<(), MessageError> {
559 let hdr = p.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
560 Self::validate_routed_header(hdr)?;
561 if p.session.is_none() {
562 return Err(MessageError::SessionHeaderNotFound);
563 }
564 Ok(())
565 }
566
567 fn validate_subscribe(s: &ProtoSubscribe) -> Result<(), MessageError> {
568 let hdr = s.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
569 Self::validate_routed_header(hdr)
570 }
571
572 fn validate_unsubscribe(u: &ProtoUnsubscribe) -> Result<(), MessageError> {
573 let hdr = u.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
574 Self::validate_routed_header(hdr)
575 }
576
577 pub fn validate(&self) -> Result<(), MessageError> {
579 match &self.message_type {
580 None => Err(MessageError::MessageTypeNotFound),
581 Some(ProtoLinkMessageType(link)) => Self::validate_link(link),
582 Some(ProtoPublishType(p)) => Self::validate_publish(p),
583 Some(ProtoSubscribeType(s)) => Self::validate_subscribe(s),
584 Some(ProtoUnsubscribeType(u)) => Self::validate_unsubscribe(u),
585 Some(ProtoSubscriptionAckType(_)) => Ok(()),
586 }
587 }
588
589 pub fn insert_metadata(&mut self, key: String, val: String) {
592 self.metadata.insert(key, val);
593 }
594
595 pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
597 self.metadata.remove(key)
598 }
599
600 pub fn contains_metadata(&self, key: &str) -> bool {
601 self.metadata.contains_key(key)
602 }
603
604 pub fn get_metadata(&self, key: &str) -> Option<&String> {
605 self.metadata.get(key)
606 }
607
608 pub fn get_metadata_map(&self) -> HashMap<String, String> {
609 self.metadata.clone()
610 }
611
612 pub fn set_metadata_map(&mut self, map: HashMap<String, String>) {
613 for (k, v) in map.iter() {
614 self.insert_metadata(k.to_string(), v.to_string());
615 }
616 }
617
618 pub fn get_slim_header(&self) -> &SlimHeader {
619 match &self.message_type {
620 Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
621 Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
622 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
623 Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
624 panic!("SLIM header not found")
625 }
626 }
627 }
628
629 pub fn get_slim_header_mut(&mut self) -> &mut SlimHeader {
630 match &mut self.message_type {
631 Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
632 Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
633 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
634 Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
635 panic!("SLIM header not found")
636 }
637 }
638 }
639
640 pub fn try_get_slim_header(&self) -> Option<&SlimHeader> {
641 match &self.message_type {
642 Some(ProtoPublishType(publish)) => publish.header.as_ref(),
643 Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
644 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
645 Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => None,
646 }
647 }
648
649 pub fn get_session_header(&self) -> &SessionHeader {
650 match &self.message_type {
651 Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
652 Some(ProtoSubscribeType(_))
653 | Some(ProtoUnsubscribeType(_))
654 | Some(ProtoLinkMessageType(_))
655 | Some(ProtoSubscriptionAckType(_))
656 | None => panic!("session header not found"),
657 }
658 }
659
660 pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
661 match &mut self.message_type {
662 Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
663 Some(ProtoSubscribeType(_))
664 | Some(ProtoUnsubscribeType(_))
665 | Some(ProtoLinkMessageType(_))
666 | Some(ProtoSubscriptionAckType(_))
667 | None => panic!("session header not found"),
668 }
669 }
670
671 pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
672 match &self.message_type {
673 Some(ProtoPublishType(publish)) => publish.session.as_ref(),
674 Some(ProtoSubscribeType(_))
675 | Some(ProtoUnsubscribeType(_))
676 | Some(ProtoLinkMessageType(_))
677 | Some(ProtoSubscriptionAckType(_))
678 | None => None,
679 }
680 }
681
682 pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
683 match &mut self.message_type {
684 Some(ProtoPublishType(publish)) => publish.session.as_mut(),
685 Some(ProtoSubscribeType(_))
686 | Some(ProtoUnsubscribeType(_))
687 | Some(ProtoLinkMessageType(_))
688 | Some(ProtoSubscriptionAckType(_))
689 | None => None,
690 }
691 }
692
693 pub fn get_id(&self) -> u32 {
694 self.get_session_header().get_message_id()
695 }
696
697 pub fn get_source(&self) -> Name {
698 self.get_slim_header().get_source()
699 }
700
701 pub fn get_dst(&self) -> Name {
702 self.get_slim_header().get_dst()
703 }
704
705 pub fn get_identity(&self) -> String {
706 self.get_slim_header().get_identity()
707 }
708
709 pub fn get_fanout(&self) -> u32 {
710 self.get_slim_header().get_fanout()
711 }
712
713 pub fn get_recv_from(&self) -> Option<u64> {
714 self.get_slim_header().get_recv_from()
715 }
716
717 pub fn get_forward_to(&self) -> Option<u64> {
718 self.get_slim_header().get_forward_to()
719 }
720
721 pub fn get_error(&self) -> Option<bool> {
722 self.get_slim_header().get_error()
723 }
724
725 pub fn get_incoming_conn(&self) -> u64 {
726 self.get_slim_header().get_incoming_conn().unwrap()
727 }
728
729 pub fn try_get_incoming_conn(&self) -> Option<u64> {
730 self.get_slim_header().get_incoming_conn()
731 }
732
733 pub fn get_type(&self) -> &MessageType {
734 match &self.message_type {
735 Some(t) => t,
736 None => panic!("message type not found"),
737 }
738 }
739
740 pub fn get_payload(&self) -> Option<&Content> {
741 match &self.message_type {
742 Some(ProtoPublishType(p)) => p.msg.as_ref(),
743 Some(ProtoSubscribeType(_)) => panic!("payload not found"),
744 Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
745 Some(ProtoLinkMessageType(_)) => panic!("payload not found"),
746 Some(ProtoSubscriptionAckType(_)) => panic!("payload not found"),
747 None => panic!("payload not found"),
748 }
749 }
750
751 pub fn set_payload(&mut self, payload: Content) {
752 match &mut self.message_type {
753 Some(ProtoPublishType(p)) => p.set_payload(payload),
754 Some(ProtoSubscribeType(_)) => panic!("no payload allowed"),
755 Some(ProtoUnsubscribeType(_)) => panic!("no payload allowed"),
756 Some(ProtoLinkMessageType(_)) => panic!("no payload allowed"),
757 Some(ProtoSubscriptionAckType(_)) => panic!("no payload allowed"),
758 None => panic!("no payload allowed"),
759 }
760 }
761
762 pub fn get_session_message_type(&self) -> SessionMessageType {
763 self.get_session_header()
764 .session_message_type
765 .try_into()
766 .unwrap_or_default()
767 }
768
769 pub fn clear_slim_header(&mut self) {
770 if self.is_link() || self.is_subscription_ack() {
771 return;
772 }
773 self.get_slim_header_mut().clear_flags();
774 }
775
776 pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
777 self.get_slim_header_mut().set_recv_from(recv_from);
778 }
779
780 pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
781 self.get_slim_header_mut().set_forward_to(forward_to);
782 }
783
784 pub fn set_error(&mut self, error: Option<bool>) {
785 self.get_slim_header_mut().set_error(error);
786 }
787
788 pub fn set_fanout(&mut self, fanout: u32) {
789 self.get_slim_header_mut().set_fanout(fanout);
790 }
791
792 pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
793 self.get_slim_header_mut().set_incoming_conn(incoming_conn);
794 }
795
796 pub fn set_error_flag(&mut self, error: Option<bool>) {
797 self.get_slim_header_mut().set_error_flag(error);
798 }
799
800 pub fn set_session_message_type(&mut self, message_type: SessionMessageType) {
801 self.get_session_header_mut()
802 .set_session_message_type(message_type);
803 }
804
805 pub fn set_session_type(&mut self, session_type: ProtoSessionType) {
806 self.get_session_header_mut().set_session_type(session_type);
807 }
808
809 pub fn get_session_type(&self) -> ProtoSessionType {
810 self.get_session_header().session_type()
811 }
812
813 pub fn set_message_id(&mut self, message_id: u32) {
814 self.get_session_header_mut().set_message_id(message_id);
815 }
816
817 pub fn is_publish(&self) -> bool {
818 matches!(self.get_type(), MessageType::Publish(_))
819 }
820
821 pub fn is_subscribe(&self) -> bool {
822 matches!(self.get_type(), MessageType::Subscribe(_))
823 }
824
825 pub fn is_unsubscribe(&self) -> bool {
826 matches!(self.get_type(), MessageType::Unsubscribe(_))
827 }
828
829 pub fn is_link(&self) -> bool {
830 matches!(self.get_type(), MessageType::Link(_))
831 }
832
833 pub fn is_subscription_ack(&self) -> bool {
834 matches!(self.get_type(), MessageType::SubscriptionAck(_))
835 }
836
837 pub fn is_traceable(&self) -> bool {
838 !self.is_link() && !self.is_subscription_ack()
839 }
840
841 pub fn get_subscription_ack(&self) -> &ProtoSubscriptionAck {
842 match &self.message_type {
843 Some(ProtoSubscriptionAckType(ack)) => ack,
844 _ => panic!("message type is not subscription_ack"),
845 }
846 }
847
848 pub fn get_subscription_id(&self) -> Option<u64> {
850 match &self.message_type {
851 Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => Some(s.subscription_id),
852 Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => Some(u.subscription_id),
853 _ => None,
854 }
855 }
856
857 pub fn take_subscription_id(&mut self) -> Option<u64> {
860 match &mut self.message_type {
861 Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => {
862 Some(std::mem::take(&mut s.subscription_id))
863 }
864 Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => {
865 Some(std::mem::take(&mut u.subscription_id))
866 }
867 _ => None,
868 }
869 }
870
871 pub fn set_subscription_id(&mut self, subscription_id: u64) {
873 match &mut self.message_type {
874 Some(ProtoSubscribeType(s)) => s.subscription_id = subscription_id,
875 Some(ProtoUnsubscribeType(u)) => u.subscription_id = subscription_id,
876 _ => {}
877 }
878 }
879
880 pub fn extract_command_payload(&self) -> Result<&CommandPayload, MessageError> {
885 self.get_payload()
886 .ok_or(MessageError::ContentTypeNotSet)?
887 .as_command_payload()
888 }
889
890 impl_payload_extractors! {
892 extract_discovery_request => as_discovery_request_payload(DiscoveryRequestPayload),
893 extract_discovery_reply => as_discovery_reply_payload(DiscoveryReplyPayload),
894 extract_join_request => as_join_request_payload(JoinRequestPayload),
895 extract_join_reply => as_join_reply_payload(JoinReplyPayload),
896 extract_leave_request => as_leave_request_payload(LeaveRequestPayload),
897 extract_leave_reply => as_leave_reply_payload(LeaveReplyPayload),
898 extract_group_add => as_group_add_payload(GroupAddPayload),
899 extract_group_remove => as_group_remove_payload(GroupRemovePayload),
900 extract_group_welcome => as_welcome_payload(GroupWelcomePayload),
901 extract_group_close => as_group_close_payload(GroupClosePayload),
902 extract_group_proposal => as_group_proposal_payload(GroupProposalPayload),
903 extract_group_ack => as_group_ack_payload(GroupAckPayload),
904 extract_group_nack => as_group_nack_payload(GroupNackPayload),
905 extract_ping => as_ping_payload(PingPayload),
906 }
907}
908
909impl Content {
910 pub fn as_application_payload(&self) -> Result<&ApplicationPayload, MessageError> {
911 match &self.content_type {
912 Some(ContentType::AppPayload(app_payload)) => Ok(app_payload),
913 Some(ContentType::CommandPayload(_)) => Err(MessageError::NotApplicationPayload),
914 None => Err(MessageError::ContentTypeNotSet),
915 }
916 }
917
918 pub fn as_command_payload(&self) -> Result<&CommandPayload, MessageError> {
919 match &self.content_type {
920 Some(ContentType::AppPayload(_)) => Err(MessageError::NotCommandPayload),
921 Some(ContentType::CommandPayload(comm_payload)) => Ok(comm_payload),
922 None => Err(MessageError::ContentTypeNotSet),
923 }
924 }
925}
926
927impl ApplicationPayload {
928 pub fn new(payload_type: &str, blob: Vec<u8>) -> Self {
929 Self {
930 payload_type: payload_type.to_string(),
931 blob,
932 }
933 }
934
935 pub fn as_content(&self) -> Content {
936 Content {
937 content_type: Some(ContentType::AppPayload(self.clone())),
938 }
939 }
940}
941
/// Generates `pub fn name(&self) -> Result<&T, MessageError>` getters.
///
/// Each getter matches one `CommandPayloadType` variant. On mismatch it returns
/// `MessageError::InvalidCommandPayloadType`, with `got` set to the `Debug` form
/// of the actual variant, or `"None"` when nothing is set.
macro_rules! impl_command_payload_getters {
    ($(
        $method_name:ident => $variant:ident($payload_type:ty)
    ),* $(,)?) => {
        $(
            #[doc = concat!(
                "Returns the payload when it is `CommandPayloadType::",
                stringify!($variant),
                "`."
            )]
            pub fn $method_name(&self) -> Result<&$payload_type, MessageError> {
                let got = match &self.command_payload_type {
                    Some(CommandPayloadType::$variant(payload)) => return Ok(payload),
                    Some(other) => format!("{:?}", other),
                    None => "None".to_string(),
                };
                Err(MessageError::InvalidCommandPayloadType {
                    expected: Box::new(stringify!($variant).to_string()),
                    got: Box::new(got),
                })
            }
        )*
    };
}
964
965impl CommandPayload {
966 pub fn as_content(self) -> Content {
967 Content {
968 content_type: Some(ContentType::CommandPayload(self)),
969 }
970 }
971
972 impl_command_payload_getters! {
974 as_discovery_request_payload => DiscoveryRequest(DiscoveryRequestPayload),
975 as_discovery_reply_payload => DiscoveryReply(DiscoveryReplyPayload),
976 as_join_request_payload => JoinRequest(JoinRequestPayload),
977 as_join_reply_payload => JoinReply(JoinReplyPayload),
978 as_leave_request_payload => LeaveRequest(LeaveRequestPayload),
979 as_leave_reply_payload => LeaveReply(LeaveReplyPayload),
980 as_group_add_payload => GroupAdd(GroupAddPayload),
981 as_group_remove_payload => GroupRemove(GroupRemovePayload),
982 as_welcome_payload => GroupWelcome(GroupWelcomePayload),
983 as_group_close_payload => GroupClose(GroupClosePayload),
984 as_group_proposal_payload => GroupProposal(GroupProposalPayload),
985 as_group_ack_payload => GroupAck(GroupAckPayload),
986 as_group_nack_payload => GroupNack(GroupNackPayload),
987 as_ping_payload => Ping(PingPayload),
988 }
989}
990
991impl AsRef<ProtoPublish> for ProtoMessage {
992 fn as_ref(&self) -> &ProtoPublish {
993 match &self.message_type {
994 Some(ProtoPublishType(p)) => p,
995 _ => panic!("message type is not publish"),
996 }
997 }
998}
999
/// Stateless factory for [`CommandPayload`] values; obtain one with `CommandPayload::builder()`.
///
/// Derives `Debug` (public types should be debuggable) and `Copy`, since it
/// carries no state.
#[derive(Debug, Clone, Copy)]
pub struct CommandPayloadBuilder;
1049
1050impl CommandPayloadBuilder {
1051 pub fn new() -> Self {
1053 Self
1054 }
1055
1056 pub fn discovery_request(self, destination: Option<Name>) -> CommandPayload {
1058 let proto_destination = destination.as_ref().map(ProtoName::from);
1059 let payload = DiscoveryRequestPayload {
1060 destination: proto_destination,
1061 };
1062 CommandPayload {
1063 command_payload_type: Some(CommandPayloadType::DiscoveryRequest(payload)),
1064 }
1065 }
1066
1067 pub fn discovery_reply(self) -> CommandPayload {
1069 let payload = DiscoveryReplyPayload {};
1070 CommandPayload {
1071 command_payload_type: Some(CommandPayloadType::DiscoveryReply(payload)),
1072 }
1073 }
1074
1075 pub fn join_request(
1077 self,
1078 enable_mls: bool,
1079 max_retries: Option<u32>,
1080 timer_duration: Option<Duration>,
1081 channel: Option<Name>,
1082 ) -> CommandPayload {
1083 let proto_channel = channel.as_ref().map(ProtoName::from);
1084
1085 let timer_settings = if let Some(t) = timer_duration
1086 && let Some(m) = max_retries
1087 {
1088 Some(TimerSettings {
1089 timeout: t.as_millis() as u32,
1090 max_retries: m,
1091 })
1092 } else {
1093 None
1094 };
1095
1096 let payload = JoinRequestPayload {
1097 enable_mls,
1098 timer_settings,
1099 channel: proto_channel,
1100 };
1101 CommandPayload {
1102 command_payload_type: Some(CommandPayloadType::JoinRequest(payload)),
1103 }
1104 }
1105
1106 pub fn join_reply(self, key_package: Option<Vec<u8>>) -> CommandPayload {
1108 let payload = JoinReplyPayload { key_package };
1109 CommandPayload {
1110 command_payload_type: Some(CommandPayloadType::JoinReply(payload)),
1111 }
1112 }
1113
1114 pub fn leave_request(self, destination: Option<Name>) -> CommandPayload {
1116 let proto_destination = destination.as_ref().map(ProtoName::from);
1117 let payload = LeaveRequestPayload {
1118 destination: proto_destination,
1119 };
1120 CommandPayload {
1121 command_payload_type: Some(CommandPayloadType::LeaveRequest(payload)),
1122 }
1123 }
1124
1125 pub fn leave_reply(self) -> CommandPayload {
1127 let payload = LeaveReplyPayload {};
1128 CommandPayload {
1129 command_payload_type: Some(CommandPayloadType::LeaveReply(payload)),
1130 }
1131 }
1132
1133 pub fn group_add(
1135 self,
1136 new_participant: Name,
1137 participants: Vec<Name>,
1138 mls: Option<MlsPayload>,
1139 ) -> CommandPayload {
1140 let proto_new_participant = Some(ProtoName::from(&new_participant));
1141 let proto_participants = participants.iter().map(ProtoName::from).collect();
1142
1143 let payload = GroupAddPayload {
1144 new_participant: proto_new_participant,
1145 participants: proto_participants,
1146 mls,
1147 };
1148 CommandPayload {
1149 command_payload_type: Some(CommandPayloadType::GroupAdd(payload)),
1150 }
1151 }
1152
1153 pub fn group_remove(
1155 self,
1156 removed_participant: Name,
1157 participants: Vec<Name>,
1158 mls: Option<MlsPayload>,
1159 ) -> CommandPayload {
1160 let proto_removed_participant = Some(ProtoName::from(&removed_participant));
1161 let proto_participants = participants.iter().map(ProtoName::from).collect();
1162
1163 let payload = GroupRemovePayload {
1164 removed_participant: proto_removed_participant,
1165 participants: proto_participants,
1166 mls,
1167 };
1168 CommandPayload {
1169 command_payload_type: Some(CommandPayloadType::GroupRemove(payload)),
1170 }
1171 }
1172
1173 pub fn group_welcome(self, participants: Vec<Name>, mls: Option<MlsPayload>) -> CommandPayload {
1175 let proto_participants = participants.iter().map(ProtoName::from).collect();
1176
1177 let payload = GroupWelcomePayload {
1178 participants: proto_participants,
1179 mls,
1180 };
1181 CommandPayload {
1182 command_payload_type: Some(CommandPayloadType::GroupWelcome(payload)),
1183 }
1184 }
1185
1186 pub fn group_close(self, participants: Vec<Name>) -> CommandPayload {
1188 let proto_participants = participants.iter().map(ProtoName::from).collect();
1189
1190 let payload = GroupClosePayload {
1191 participants: proto_participants,
1192 };
1193 CommandPayload {
1194 command_payload_type: Some(CommandPayloadType::GroupClose(payload)),
1195 }
1196 }
1197
1198 pub fn group_proposal(self, source: Option<Name>, mls_proposal: Vec<u8>) -> CommandPayload {
1200 let proto_source = source.as_ref().map(ProtoName::from);
1201 let payload = GroupProposalPayload {
1202 source: proto_source,
1203 mls_proposal,
1204 };
1205 CommandPayload {
1206 command_payload_type: Some(CommandPayloadType::GroupProposal(payload)),
1207 }
1208 }
1209
1210 pub fn group_ack(self) -> CommandPayload {
1212 let payload = GroupAckPayload {};
1213 CommandPayload {
1214 command_payload_type: Some(CommandPayloadType::GroupAck(payload)),
1215 }
1216 }
1217
1218 pub fn group_nack(self) -> CommandPayload {
1220 let payload = GroupNackPayload {};
1221 CommandPayload {
1222 command_payload_type: Some(CommandPayloadType::GroupNack(payload)),
1223 }
1224 }
1225
1226 pub fn ping(self) -> CommandPayload {
1228 let payload = PingPayload {};
1229 CommandPayload {
1230 command_payload_type: Some(CommandPayloadType::Ping(payload)),
1231 }
1232 }
1233}
1234
1235impl Default for CommandPayloadBuilder {
1236 fn default() -> Self {
1237 Self::new()
1238 }
1239}
1240
1241impl CommandPayload {
1242 pub fn builder() -> CommandPayloadBuilder {
1244 CommandPayloadBuilder::new()
1245 }
1246}
1247
1248pub struct ProtoMessageBuilder {
1334 source: Option<Name>,
1335 destination: Option<Name>,
1336 identity: Option<String>,
1337 flags: Option<SlimHeaderFlags>,
1338 session_type: Option<ProtoSessionType>,
1339 session_message_type: Option<SessionMessageType>,
1340 session_id: Option<u32>,
1341 message_id: Option<u32>,
1342 payload: Option<Content>,
1343 metadata: HashMap<String, String>,
1344 subscription_id: Option<u64>,
1345}
1346
1347impl ProtoMessageBuilder {
1348 pub fn new() -> Self {
1350 Self {
1351 source: None,
1352 destination: None,
1353 identity: None,
1354 flags: None,
1355 session_type: None,
1356 session_message_type: None,
1357 session_id: None,
1358 message_id: None,
1359 payload: None,
1360 metadata: HashMap::new(),
1361 subscription_id: None,
1362 }
1363 }
1364
1365 pub fn source(mut self, source: Name) -> Self {
1367 self.source = Some(source);
1368 self
1369 }
1370
1371 pub fn destination(mut self, destination: Name) -> Self {
1373 self.destination = Some(destination);
1374 self
1375 }
1376
1377 pub fn identity(mut self, identity: impl Into<String>) -> Self {
1379 self.identity = Some(identity.into());
1380 self
1381 }
1382
1383 pub fn flags(mut self, flags: SlimHeaderFlags) -> Self {
1385 self.flags = Some(flags);
1386 self
1387 }
1388
1389 pub fn fanout(mut self, fanout: u32) -> Self {
1391 let flags = self.flags.take().unwrap_or_default();
1392 self.flags = Some(flags.with_fanout(fanout));
1393 self
1394 }
1395
1396 pub fn recv_from(mut self, recv_from: u64) -> Self {
1398 let flags = self.flags.take().unwrap_or_default();
1399 self.flags = Some(flags.with_recv_from(recv_from));
1400 self
1401 }
1402
1403 pub fn forward_to(mut self, forward_to: u64) -> Self {
1405 let flags = self.flags.take().unwrap_or_default();
1406 self.flags = Some(flags.with_forward_to(forward_to));
1407 self
1408 }
1409
1410 pub fn incoming_conn(mut self, incoming_conn: u64) -> Self {
1412 let flags = self.flags.take().unwrap_or_default();
1413 self.flags = Some(flags.with_incoming_conn(incoming_conn));
1414 self
1415 }
1416
1417 pub fn error(mut self, error: bool) -> Self {
1419 let flags = self.flags.take().unwrap_or_default();
1420 self.flags = Some(flags.with_error(error));
1421 self
1422 }
1423
1424 pub fn session_type(mut self, session_type: ProtoSessionType) -> Self {
1426 self.session_type = Some(session_type);
1427 self
1428 }
1429
1430 pub fn session_message_type(mut self, session_message_type: SessionMessageType) -> Self {
1432 self.session_message_type = Some(session_message_type);
1433 self
1434 }
1435
1436 pub fn session_id(mut self, session_id: u32) -> Self {
1438 self.session_id = Some(session_id);
1439 self
1440 }
1441
1442 pub fn message_id(mut self, message_id: u32) -> Self {
1444 self.message_id = Some(message_id);
1445 self
1446 }
1447
1448 pub fn payload(mut self, payload: Content) -> Self {
1450 self.payload = Some(payload);
1451 self
1452 }
1453
1454 pub fn application_payload(mut self, payload_type: &str, blob: Vec<u8>) -> Self {
1456 let app_payload = ApplicationPayload::new(payload_type, blob);
1457 self.payload = Some(app_payload.as_content());
1458 self
1459 }
1460
1461 pub fn command_payload(mut self, payload: CommandPayload) -> Self {
1463 self.payload = Some(payload.as_content());
1464 self
1465 }
1466
1467 pub fn with_slim_header(mut self, header: SlimHeader) -> Self {
1472 if let Some(src) = &header.source {
1474 self.source = Some(Name::from(src));
1475 }
1476 if let Some(dst) = &header.destination {
1477 self.destination = Some(Name::from(dst));
1478 }
1479 if !header.identity.is_empty() {
1480 self.identity = Some(header.identity.clone());
1481 }
1482
1483 let flags = SlimHeaderFlags {
1485 fanout: header.fanout,
1486 recv_from: header.recv_from,
1487 forward_to: header.forward_to,
1488 incoming_conn: header.incoming_conn,
1489 error: header.error,
1490 };
1491 self.flags = Some(flags);
1492 self
1493 }
1494
1495 pub fn with_session_header(mut self, header: SessionHeader) -> Self {
1500 self.session_type = Some(
1501 ProtoSessionType::try_from(header.session_type)
1502 .unwrap_or(ProtoSessionType::PointToPoint),
1503 );
1504 self.session_message_type = Some(
1505 SessionMessageType::try_from(header.session_message_type)
1506 .unwrap_or(SessionMessageType::Msg),
1507 );
1508 self.session_id = Some(header.session_id);
1509 self.message_id = Some(header.message_id);
1510 self
1511 }
1512
1513 pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1515 self.metadata.insert(key.into(), value.into());
1516 self
1517 }
1518
1519 pub fn metadata_map(mut self, map: HashMap<String, String>) -> Self {
1521 self.metadata.extend(map);
1522 self
1523 }
1524
1525 pub fn subscription_id(mut self, id: u64) -> Self {
1527 self.subscription_id = Some(id);
1528 self
1529 }
1530
1531 pub fn build_publish(self) -> Result<ProtoMessage, MessageError> {
1533 let source = self
1534 .source
1535 .ok_or(MessageError::BuilderErrorSourceRequired)?;
1536 let destination = self
1537 .destination
1538 .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1539
1540 let slim_header = Some(SlimHeader::new(
1541 &source,
1542 &destination,
1543 self.identity.as_deref().unwrap_or(""),
1544 self.flags,
1545 ));
1546
1547 let session_header = if self.session_type.is_some() || self.session_message_type.is_some() {
1548 Some(SessionHeader::new(
1549 self.session_type
1550 .unwrap_or(ProtoSessionType::PointToPoint)
1551 .into(),
1552 self.session_message_type
1553 .unwrap_or(SessionMessageType::Msg)
1554 .into(),
1555 self.session_id.unwrap_or(0),
1556 self.message_id.unwrap_or_else(rand::random),
1557 ))
1558 } else {
1559 Some(SessionHeader::default())
1560 };
1561
1562 let publish = ProtoPublish::with_header(slim_header, session_header, self.payload);
1563 let message = ProtoMessage::new(self.metadata, ProtoPublishType(publish));
1564 Ok(message)
1565 }
1566
1567 pub fn build_subscribe(self) -> Result<ProtoMessage, MessageError> {
1569 let source = self
1570 .source
1571 .ok_or(MessageError::BuilderErrorSourceRequired)?;
1572 let destination = self
1573 .destination
1574 .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1575
1576 let mut subscribe =
1577 ProtoSubscribe::new(&source, &destination, self.identity.as_deref(), self.flags);
1578 subscribe.subscription_id = self.subscription_id.unwrap_or_default();
1579
1580 Ok(ProtoMessage::new(
1581 self.metadata,
1582 ProtoSubscribeType(subscribe),
1583 ))
1584 }
1585
1586 pub fn build_unsubscribe(self) -> Result<ProtoMessage, MessageError> {
1588 let source = self
1589 .source
1590 .ok_or(MessageError::BuilderErrorSourceRequired)?;
1591 let destination = self
1592 .destination
1593 .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1594
1595 let mut unsubscribe =
1596 ProtoUnsubscribe::new(&source, &destination, self.identity.as_deref(), self.flags);
1597 unsubscribe.subscription_id = self.subscription_id.unwrap_or_default();
1598
1599 Ok(ProtoMessage::new(
1600 self.metadata,
1601 ProtoUnsubscribeType(unsubscribe),
1602 ))
1603 }
1604
1605 pub fn build_subscription_ack(
1609 self,
1610 subscription_id: u64,
1611 success: bool,
1612 error: impl Into<String>,
1613 ) -> ProtoMessage {
1614 let ack = ProtoSubscriptionAck {
1615 subscription_id,
1616 success,
1617 error: error.into(),
1618 };
1619 ProtoMessage::new(self.metadata, ProtoSubscriptionAckType(ack))
1620 }
1621
1622 pub fn build_link_negotiation(
1625 self,
1626 link_id: impl Into<String>,
1627 slim_version: impl Into<String>,
1628 is_reply: bool,
1629 ) -> ProtoMessage {
1630 let link = ProtoLink {
1631 link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
1632 link_id: link_id.into(),
1633 slim_version: slim_version.into(),
1634 is_reply,
1635 })),
1636 };
1637 ProtoMessage::new(self.metadata, ProtoLinkMessageType(link))
1638 }
1639}
1640
1641impl Default for ProtoMessageBuilder {
1642 fn default() -> Self {
1643 Self::new()
1644 }
1645}
1646
1647impl ProtoMessage {
1648 pub fn builder() -> ProtoMessageBuilder {
1650 ProtoMessageBuilder::new()
1651 }
1652}
1653
#[cfg(test)]
mod tests {
    use crate::{api::proto::dataplane::v1::SessionMessageType, messages::encoder::Name};

    use super::*;

    /// Builds a subscribe (or unsubscribe) message and checks its routing
    /// fields. When `flags` is `None`, default flags are expected.
    fn test_subscription_template(
        subscription: bool,
        source: Name,
        dst: Name,
        identity: Option<&str>,
        flags: Option<SlimHeaderFlags>,
    ) {
        let expected_flags = flags.clone().unwrap_or_default();

        let sub = {
            let mut builder = ProtoMessage::builder()
                .source(source.clone())
                .destination(dst.clone());

            if let Some(id) = identity {
                builder = builder.identity(id);
            }

            if let Some(f) = flags {
                builder = builder.flags(f);
            }

            if subscription {
                builder.build_subscribe().unwrap()
            } else {
                builder.build_unsubscribe().unwrap()
            }
        };

        assert!(!sub.is_publish());
        assert_eq!(sub.is_subscribe(), subscription);
        assert_eq!(sub.is_unsubscribe(), !subscription);
        assert_eq!(expected_flags.recv_from, sub.get_recv_from());
        assert_eq!(expected_flags.forward_to, sub.get_forward_to());
        assert_eq!(None, sub.try_get_incoming_conn());
        assert_eq!(source, sub.get_source());
        assert_eq!(dst, sub.get_dst());
    }

    /// Builds a publish message and checks its routing fields. When `flags`
    /// is `None`, default flags are expected.
    fn test_publish_template(
        source: Name,
        dst: Name,
        identity: Option<&str>,
        flags: Option<SlimHeaderFlags>,
    ) {
        let expected_flags = flags.clone().unwrap_or_default();

        let mut builder = ProtoMessage::builder()
            .source(source.clone())
            .destination(dst.clone())
            .application_payload("str", "this is the content of the message".into());

        if let Some(id) = identity {
            builder = builder.identity(id);
        }

        if let Some(f) = flags {
            builder = builder.flags(f);
        }

        let pub_msg = builder.build_publish().unwrap();

        assert!(pub_msg.is_publish());
        assert!(!pub_msg.is_subscribe());
        assert!(!pub_msg.is_unsubscribe());
        assert_eq!(expected_flags.recv_from, pub_msg.get_recv_from());
        assert_eq!(expected_flags.forward_to, pub_msg.get_forward_to());
        assert_eq!(None, pub_msg.try_get_incoming_conn());
        assert_eq!(source, pub_msg.get_source());
        assert_eq!(dst, pub_msg.get_dst());
        assert_eq!(expected_flags.fanout, pub_msg.get_fanout());
    }

    #[test]
    fn test_subscription() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);

        test_subscription_template(true, source.clone(), dst.clone(), None, None);

        // Exercise the identity path as well.
        test_subscription_template(true, source.clone(), dst.clone(), Some("identity"), None);

        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );
    }

    #[test]
    fn test_unsubscription() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);

        test_subscription_template(false, source.clone(), dst.clone(), None, None);

        // Exercise the identity path as well.
        test_subscription_template(false, source.clone(), dst.clone(), Some("identity"), None);

        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );
    }

    #[test]
    fn test_publish() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let mut dst = Name::from_strings(["org", "ns", "type"]);

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default()),
        );

        dst.set_id(2);
        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default()),
        );
        dst.reset_id();

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_fanout(2)),
        );

        // Exercise the identity path and the default-flags path.
        test_publish_template(source.clone(), dst.clone(), Some("identity"), None);
    }

    #[test]
    fn test_conversions() {
        let name = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let proto_name = ProtoName::from(&name);

        assert_eq!(
            proto_name.name.as_ref().unwrap().component_0,
            name.components()[0]
        );
        assert_eq!(
            proto_name.name.as_ref().unwrap().component_1,
            name.components()[1]
        );
        assert_eq!(
            proto_name.name.as_ref().unwrap().component_2,
            name.components()[2]
        );
        assert_eq!(
            proto_name.name.as_ref().unwrap().component_3,
            name.components()[3]
        );

        let name_from_proto = Name::from(&proto_name);
        assert_eq!(
            name_from_proto.components()[0],
            proto_name.name.as_ref().unwrap().component_0
        );
        assert_eq!(
            name_from_proto.components()[1],
            proto_name.name.as_ref().unwrap().component_1
        );
        assert_eq!(
            name_from_proto.components()[2],
            proto_name.name.as_ref().unwrap().component_2
        );
        assert_eq!(
            name_from_proto.components()[3],
            proto_name.name.as_ref().unwrap().component_3
        );

        let dst = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let proto_subscribe = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_subscribe()
            .unwrap();
        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_dst(), dst);

        let proto_unsubscribe = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_unsubscribe()
            .unwrap();
        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
        assert_eq!(
            proto_unsubscribe.header.as_ref().unwrap().get_source(),
            name
        );
        assert_eq!(proto_unsubscribe.header.as_ref().unwrap().get_dst(), dst);

        let proto_publish = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .application_payload("str", "this is the content of the message".into())
            .build_publish()
            .unwrap();
        let proto_publish = ProtoPublish::from(proto_publish);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_dst(), dst);
    }

    #[test]
    fn test_panic() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = Name::from_strings(["org", "ns", "type"]).with_id(2);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_subscribe()
            .unwrap();

        // Converting a subscribe message into any other message kind must panic.
        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_header() {
        let header = SlimHeader {
            source: None,
            destination: None,
            identity: String::new(),
            fanout: 0,
            recv_from: None,
            forward_to: None,
            incoming_conn: None,
            error: None,
        };

        // Missing names panic; missing optional flags do not.
        let result = std::panic::catch_unwind(|| header.get_source());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| header.get_dst());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| header.get_recv_from());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_forward_to());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_error());
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_session_header() {
        let header = SessionHeader::new(0, 0, 0, 0);

        let result = std::panic::catch_unwind(|| header.get_session_id());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_message_id());
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_proto_message() {
        let message = ProtoMessage {
            metadata: HashMap::new(),
            message_type: None,
        };

        // Every header accessor must panic on a message with no type.
        let result = std::panic::catch_unwind(|| message.get_slim_header());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| message.get_type());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| message.get_source());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_dst());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_recv_from());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_forward_to());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_fanout());
        assert!(result.is_err());
    }

    #[test]
    fn test_service_type_to_int() {
        // Assumes `Ping` is the highest-valued variant and that discriminants
        // are contiguous from 0 — revisit if the enum grows.
        let total_service_types = SessionMessageType::Ping as i32;

        for i in 0..=total_service_types {
            let service_type =
                SessionMessageType::try_from(i).expect("failed to convert int to service type");
            // Round-trip: converting back must yield the original discriminant.
            assert_eq!(i32::from(service_type), i);
        }

        let invalid_service_type = SessionMessageType::try_from(total_service_types + 1);
        assert!(invalid_service_type.is_err());
    }

    #[test]
    fn test_proto_message_builder() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let dest = Name::from_strings(["org", "ns", "app"]).with_id(2);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .application_payload("test", b"hello world".to_vec())
            .build_publish()
            .unwrap();

        assert!(msg.is_publish());
        assert_eq!(msg.get_source(), source);
        assert_eq!(msg.get_dst(), dest);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(SessionMessageType::Msg)
            .session_id(42)
            .message_id(100)
            .fanout(256)
            .application_payload("test", b"broadcast".to_vec())
            .build_publish()
            .unwrap();

        assert_eq!(msg.get_session_type(), ProtoSessionType::Multicast);
        assert_eq!(msg.get_id(), 100);
        assert_eq!(msg.get_fanout(), 256);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .metadata("key1", "value1")
            .metadata("key2", "value2")
            .application_payload("test", vec![1, 2, 3])
            .build_publish()
            .unwrap();

        assert_eq!(msg.get_metadata("key1"), Some(&"value1".to_string()));
        assert_eq!(msg.get_metadata("key2"), Some(&"value2".to_string()));

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .recv_from(10)
            .build_subscribe()
            .unwrap();

        assert!(msg.is_subscribe());
        assert_eq!(msg.get_recv_from(), Some(10));

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .forward_to(20)
            .build_unsubscribe()
            .unwrap();

        assert!(msg.is_unsubscribe());
        assert_eq!(msg.get_forward_to(), Some(20));
    }

    #[test]
    fn test_command_payload_builder() {
        let dest = Name::from_strings(["org", "ns", "app"]);

        let payload = CommandPayload::builder().discovery_request(Some(dest.clone()));
        let extracted = payload.as_discovery_request_payload().unwrap();
        assert!(extracted.destination.is_some());

        let payload = CommandPayload::builder().discovery_reply();
        assert!(payload.as_discovery_reply_payload().is_ok());

        let payload = CommandPayload::builder().join_request(
            true,
            Some(5),
            Some(Duration::from_secs(10)),
            Some(dest.clone()),
        );
        let extracted = payload.as_join_request_payload().unwrap();
        assert!(extracted.enable_mls);
        assert!(extracted.timer_settings.is_some());

        let payload = CommandPayload::builder().join_reply(Some(vec![1, 2, 3]));
        let extracted = payload.as_join_reply_payload().unwrap();
        assert_eq!(extracted.key_package, Some(vec![1, 2, 3]));

        let payload = CommandPayload::builder().leave_request(Some(dest.clone()));
        assert!(payload.as_leave_request_payload().is_ok());

        let payload = CommandPayload::builder().leave_reply();
        assert!(payload.as_leave_reply_payload().is_ok());

        let participants = vec![dest.clone()];
        let payload = CommandPayload::builder().group_add(dest.clone(), participants.clone(), None);
        let extracted = payload.as_group_add_payload().unwrap();
        assert!(extracted.new_participant.is_some());

        let payload =
            CommandPayload::builder().group_remove(dest.clone(), participants.clone(), None);
        let extracted = payload.as_group_remove_payload().unwrap();
        assert!(extracted.removed_participant.is_some());

        let payload = CommandPayload::builder().group_welcome(participants.clone(), None);
        let extracted = payload.as_welcome_payload().unwrap();
        assert!(!extracted.participants.is_empty());

        let payload = CommandPayload::builder().group_proposal(Some(dest.clone()), vec![4, 5, 6]);
        let extracted = payload.as_group_proposal_payload().unwrap();
        assert_eq!(extracted.mls_proposal, vec![4, 5, 6]);

        let payload = CommandPayload::builder().group_ack();
        assert!(payload.as_group_ack_payload().is_ok());

        let payload = CommandPayload::builder().group_nack();
        assert!(payload.as_group_nack_payload().is_ok());

        let payload = CommandPayload::builder().ping();
        assert!(payload.as_ping_payload().is_ok());
    }

    #[test]
    fn test_builder_with_command_payload() {
        let source = Name::from_strings(["org", "ns", "type"]).with_id(1);
        let dest = Name::from_strings(["org", "ns", "app"]).with_id(2);

        let cmd_payload = CommandPayload::builder().discovery_request(Some(dest.clone()));

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .session_type(ProtoSessionType::PointToPoint)
            .session_message_type(SessionMessageType::DiscoveryRequest)
            .session_id(1)
            .command_payload(cmd_payload)
            .build_publish()
            .unwrap();

        assert!(msg.is_publish());
        assert_eq!(
            msg.get_session_message_type(),
            SessionMessageType::DiscoveryRequest
        );

        let extracted = msg.extract_discovery_request().unwrap();
        assert!(extracted.destination.is_some());
    }

    #[test]
    fn test_validate_link_without_link_type() {
        let link = ProtoLink { link_type: None };
        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
        assert!(matches!(msg.validate(), Err(MessageError::LinkTypeNotSet)));
    }

    #[test]
    fn test_validate_link_with_link_type() {
        let link = ProtoLink {
            link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
                link_id: "abc".into(),
                slim_version: "1.0.0".into(),
                is_reply: false,
            })),
        };
        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_build_link_negotiation_request() {
        let msg = ProtoMessage::builder().build_link_negotiation("my-id", "1.2.3", false);
        assert!(msg.is_link());
        assert!(!msg.is_publish());
        assert!(!msg.is_subscribe());
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_build_link_negotiation_reply() {
        let msg = ProtoMessage::builder().build_link_negotiation("my-id", "1.2.3", true);
        assert!(msg.is_link());
        assert!(msg.validate().is_ok());
    }
}