1use std::fmt::Display;
5use std::{collections::HashMap, time::Duration};
6
7use crate::api::proto::dataplane::v1::{
8 GroupClosePayload, GroupNackPayload, Participant, ParticipantSettings, PingPayload,
9};
10use crate::api::{
11 Content, LinkNegotiationPayload, MessageType, ProtoLink, ProtoLinkMessageType, ProtoLinkType,
12 ProtoMessage, ProtoName, ProtoPublish, ProtoPublishType, ProtoSessionType, ProtoSubscribe,
13 ProtoSubscribeType, ProtoSubscriptionAck, ProtoSubscriptionAckType, ProtoUnsubscribe,
14 ProtoUnsubscribeType, SessionHeader, SlimHeader,
15 proto::dataplane::v1::{
16 ApplicationPayload, CommandPayload, DiscoveryReplyPayload, DiscoveryRequestPayload,
17 EncodedName, GroupAckPayload, GroupAddPayload, GroupProposalPayload, GroupRemovePayload,
18 GroupWelcomePayload, JoinReplyPayload, JoinRequestPayload, LeaveReplyPayload,
19 LeaveRequestPayload, MlsPayload, SessionMessageType, TimerSettings,
20 command_payload::CommandPayloadType, content::ContentType,
21 },
22};
23
24use slim_version::version;
25use thiserror::Error;
26
/// Command string `"DELETE_GROUP"`.
///
/// NOTE(review): presumably used as a metadata key/value requesting group
/// deletion — the consumers are not visible here, confirm before relying on it.
pub const DELETE_GROUP: &str = "DELETE_GROUP";

/// Command string `"PUBLISH_TO"`.
///
/// NOTE(review): presumably a metadata key for directed publishes — confirm.
pub const PUBLISH_TO: &str = "PUBLISH_TO";

/// Marker string `"DISCONNECTION_DETECTED"`.
///
/// NOTE(review): presumably attached when a connection loss is noticed —
/// confirm against the emitting code.
pub const DISCONNECTION_DETECTED: &str = "DISCONNECTION_DETECTED";

/// Marker string `"LEAVING_SESSION"`.
pub const LEAVING_SESSION: &str = "LEAVING_SESSION";

/// Textual form of `true` (presumably for string-valued metadata — confirm).
pub const TRUE_VAL: &str = "TRUE";

/// Textual form of `false`, counterpart of [`TRUE_VAL`].
pub const FALSE_VAL: &str = "FALSE";

/// Largest publish id: half of `u32::MAX`, i.e. `2_147_483_647`.
///
/// A one-bit right shift of `u32::MAX` is exactly `u32::MAX / 2`.
pub const MAX_PUBLISH_ID: u32 = u32::MAX >> 1;
59
/// Errors produced while validating, inspecting or building SLIM messages.
#[derive(Error, Debug, PartialEq)]
pub enum MessageError {
    /// A publish/subscribe/unsubscribe message has no SLIM header.
    #[error("SLIM header not found")]
    SlimHeaderNotFound,
    /// The SLIM header has no source name.
    #[error("source not found")]
    SourceNotFound,
    /// The SLIM header source exists but has no encoded name.
    #[error("source encoded name not found")]
    SourceEncodedNameNotFound,
    /// The SLIM header has no destination name.
    #[error("destination not found")]
    DestinationNotFound,
    /// The SLIM header destination exists but has no encoded name.
    #[error("destination encoded name not found")]
    DestinationEncodedNameNotFound,
    /// A publish message has no session header.
    #[error("session header not found")]
    SessionHeaderNotFound,
    /// The message has no message type set.
    #[error("message type not found")]
    MessageTypeNotFound,
    /// The SLIM header carries no incoming connection id.
    #[error("incoming connection not found")]
    IncomingConnectionNotFound,
    /// A `Content` has no content type, or there is no payload to inspect.
    #[error("content type is not set")]
    ContentTypeNotSet,
    /// The content holds a command payload where an application payload was expected.
    #[error("content is not an application payload")]
    NotApplicationPayload,
    /// The content holds an application payload where a command payload was expected.
    #[error("content is not a command payload")]
    NotCommandPayload,
    /// A link message has no link type set.
    #[error("link type is not set")]
    LinkTypeNotSet,
    /// A typed command-payload accessor was called on a different variant.
    ///
    /// `expected` is the requested variant name; `got` is the `Debug` form of
    /// the actual variant, or `"None"` when no variant is set.
    /// NOTE(review): the strings are presumably boxed to keep this enum (and
    /// every `Result<_, MessageError>`) small — confirm before unboxing.
    #[error("invalid command payload type: expected {expected}, got {got}")]
    InvalidCommandPayloadType {
        expected: Box<String>,
        got: Box<String>,
    },

    /// NOTE(review): presumably returned by `ProtoMessageBuilder` when no
    /// source was set — its build step is not visible here, confirm.
    #[error("builder error: source is required")]
    BuilderErrorSourceRequired,
    /// NOTE(review): presumably returned by `ProtoMessageBuilder` when no
    /// destination was set — confirm.
    #[error("builder error: destination is required")]
    BuilderErrorDestinationRequired,
    /// A `Participant` has no name.
    #[error("participant name not found")]
    ParticipantNameNotFound,
    /// A `Participant` has no settings.
    #[error("participant settings not found")]
    ParticipantSettingsNotFound,
}
102
103impl ParticipantSettings {
104 pub fn bidirectional() -> Self {
107 Self {
108 sends_data: true,
109 receives_data: true,
110 }
111 }
112
113 pub fn send_only() -> Self {
115 Self {
116 sends_data: true,
117 receives_data: false,
118 }
119 }
120
121 pub fn receive_only() -> Self {
123 Self {
124 sends_data: false,
125 receives_data: true,
126 }
127 }
128
129 pub fn is_sender(&self) -> bool {
131 self.sends_data
132 }
133
134 pub fn is_receiver(&self) -> bool {
136 self.receives_data
137 }
138}
139
140impl Participant {
141 pub fn new(name: ProtoName, settings: ParticipantSettings) -> Self {
142 Self {
143 name: Some(name),
144 settings: Some(settings),
145 }
146 }
147
148 pub fn get_name(&self) -> Result<ProtoName, MessageError> {
149 match &self.name {
150 Some(name) => Ok(name.clone()),
151 None => Err(MessageError::ParticipantNameNotFound),
152 }
153 }
154
155 pub fn get_settings(&self) -> Result<&ParticipantSettings, MessageError> {
156 match &self.settings {
157 Some(settings) => Ok(settings),
158 None => Err(MessageError::ParticipantSettingsNotFound),
159 }
160 }
161}
162
163impl Display for MessageType {
165 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
166 match self {
167 MessageType::Publish(_) => write!(f, "publish"),
168 MessageType::Subscribe(_) => write!(f, "subscribe"),
169 MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
170 MessageType::Link(_) => write!(f, "link"),
171 MessageType::SubscriptionAck(_) => write!(f, "subscription_ack"),
172 }
173 }
174}
175
/// Optional routing fields used to fill in a `SlimHeader`.
///
/// [`Default`] yields a fanout of `1` with every optional field unset.
#[derive(Debug, Clone)]
pub struct SlimHeaderFlags {
    /// Copied into the header's `fanout`.
    pub fanout: u32,
    /// Copied into the header's `recv_from`.
    pub recv_from: Option<u64>,
    /// Copied into the header's `forward_to`.
    pub forward_to: Option<u64>,
    /// Copied into the header's `incoming_conn`.
    pub incoming_conn: Option<u64>,
    /// Copied into the header's `error` flag.
    pub error: Option<bool>,
}

impl Default for SlimHeaderFlags {
    fn default() -> Self {
        // Fanout of one; every optional field left unset.
        Self::new(1, None, None, None, None)
    }
}

impl Display for SlimHeaderFlags {
    /// Renders every field as `name: value`, using `Debug` for the optional ones.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            fanout,
            recv_from,
            forward_to,
            incoming_conn,
            error,
        } = self;
        write!(
            f,
            "fanout: {}, recv_from: {:?}, forward_to: {:?}, incoming_conn: {:?}, error: {:?}",
            fanout, recv_from, forward_to, incoming_conn, error
        )
    }
}

impl SlimHeaderFlags {
    /// Creates flags from explicit values for every field.
    pub fn new(
        fanout: u32,
        recv_from: Option<u64>,
        forward_to: Option<u64>,
        incoming_conn: Option<u64>,
        error: Option<bool>,
    ) -> Self {
        Self {
            error,
            incoming_conn,
            forward_to,
            recv_from,
            fanout,
        }
    }

    /// Returns `self` with `fanout` replaced.
    pub fn with_fanout(mut self, fanout: u32) -> Self {
        self.fanout = fanout;
        self
    }

    /// Returns `self` with `recv_from` set.
    pub fn with_recv_from(mut self, recv_from: u64) -> Self {
        self.recv_from = Some(recv_from);
        self
    }

    /// Returns `self` with `forward_to` set.
    pub fn with_forward_to(mut self, forward_to: u64) -> Self {
        self.forward_to = Some(forward_to);
        self
    }

    /// Returns `self` with `incoming_conn` set.
    pub fn with_incoming_conn(mut self, incoming_conn: u64) -> Self {
        self.incoming_conn = Some(incoming_conn);
        self
    }

    /// Returns `self` with the `error` flag set.
    pub fn with_error(mut self, error: bool) -> Self {
        self.error = Some(error);
        self
    }
}
257
258impl SlimHeader {
262 pub fn new(
263 source: ProtoName,
264 destination: ProtoName,
265 identity: &str,
266 flags: Option<SlimHeaderFlags>,
267 ) -> Self {
268 let flags = flags.unwrap_or_default();
269 Self {
270 source: Some(source),
271 destination: Some(destination),
272 identity: identity.to_string(),
273 fanout: flags.fanout,
274 version: version().to_string(),
275 recv_from: flags.recv_from,
276 forward_to: flags.forward_to,
277 incoming_conn: flags.incoming_conn,
278 error: flags.error,
279 header_mac: None,
280 }
281 }
282
283 pub fn clear_flags(&mut self) {
284 self.recv_from = None;
285 self.forward_to = None;
286 }
287
288 pub fn get_fanout(&self) -> u32 {
289 self.fanout
290 }
291
292 pub fn get_recv_from(&self) -> Option<u64> {
293 self.recv_from
294 }
295
296 pub fn get_forward_to(&self) -> Option<u64> {
297 self.forward_to
298 }
299
300 pub fn get_incoming_conn(&self) -> Option<u64> {
301 self.incoming_conn
302 }
303
304 pub fn get_error(&self) -> Option<bool> {
305 self.error
306 }
307
308 pub fn get_source(&self) -> ProtoName {
309 self.source.clone().expect("source not found")
310 }
311
312 pub fn get_encoded_source(&self) -> EncodedName {
313 self.source.as_ref().unwrap().name.unwrap()
314 }
315
316 pub fn get_dst(&self) -> ProtoName {
317 self.destination.clone().expect("destination not found")
318 }
319
320 pub fn get_encoded_dst(&self) -> EncodedName {
321 self.destination.as_ref().unwrap().name.unwrap()
322 }
323
324 pub fn get_identity(&self) -> String {
325 self.identity.clone()
326 }
327
328 pub fn get_version(&self) -> String {
329 self.version.clone()
330 }
331
332 pub fn set_source(&mut self, source: ProtoName) {
333 self.source = Some(source);
334 }
335
336 pub fn set_destination(&mut self, dst: ProtoName) {
337 self.destination = Some(dst);
338 }
339
340 pub fn set_identity(&mut self, identity: String) {
341 self.identity = identity;
342 }
343
344 pub fn set_fanout(&mut self, fanout: u32) {
345 self.fanout = fanout;
346 }
347
348 pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
349 self.recv_from = recv_from;
350 }
351
352 pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
353 self.forward_to = forward_to;
354 }
355
356 pub fn set_error(&mut self, error: Option<bool>) {
357 self.error = error;
358 }
359
360 pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
361 self.incoming_conn = incoming_conn;
362 }
363
364 pub fn set_error_flag(&mut self, error: Option<bool>) {
365 self.error = error;
366 }
367
368 pub(crate) fn get_connections(&self) -> (u64, Option<u64>, Option<u64>) {
370 let incoming = self
372 .get_incoming_conn()
373 .expect("incoming connection not found");
374
375 (incoming, self.get_recv_from(), self.get_forward_to())
376 }
377}
378
379impl SessionHeader {
383 pub fn new(
384 session_type: i32,
385 session_message_type: i32,
386 session_id: u32,
387 message_id: u32,
388 ) -> Self {
389 Self {
390 session_type,
391 session_message_type,
392 session_id,
393 message_id,
394 }
395 }
396
397 pub fn get_session_id(&self) -> u32 {
398 self.session_id
399 }
400
401 pub fn get_message_id(&self) -> u32 {
402 self.message_id
403 }
404
405 pub fn set_session_id(&mut self, session_id: u32) {
406 self.session_id = session_id;
407 }
408
409 pub fn set_message_id(&mut self, message_id: u32) {
410 self.message_id = message_id;
411 }
412
413 pub fn clear(&mut self) {
414 self.session_id = 0;
415 self.message_id = 0;
416 }
417}
418
419impl SessionMessageType {
422 pub fn is_command_message(&self) -> bool {
424 matches!(
425 self,
426 SessionMessageType::DiscoveryRequest
427 | SessionMessageType::DiscoveryReply
428 | SessionMessageType::JoinRequest
429 | SessionMessageType::JoinReply
430 | SessionMessageType::LeaveRequest
431 | SessionMessageType::LeaveReply
432 | SessionMessageType::GroupAdd
433 | SessionMessageType::GroupRemove
434 | SessionMessageType::GroupWelcome
435 | SessionMessageType::GroupClose
436 | SessionMessageType::GroupProposal
437 | SessionMessageType::GroupAck
438 | SessionMessageType::GroupNack
439 | SessionMessageType::Ping
440 )
441 }
442}
443
444impl ProtoSubscribe {
447 fn new(
448 source: ProtoName,
449 dst: ProtoName,
450 identity: Option<&str>,
451 flags: Option<SlimHeaderFlags>,
452 ) -> Self {
453 let id = identity.unwrap_or("");
454 let header = Some(SlimHeader::new(source, dst, id, flags));
455
456 ProtoSubscribe {
457 header,
458 subscription_id: 0,
459 }
460 }
461}
462
463impl From<ProtoMessage> for ProtoSubscribe {
465 fn from(message: ProtoMessage) -> Self {
466 match message.message_type {
467 Some(ProtoSubscribeType(s)) => s,
468 _ => panic!("message type is not subscribe"),
469 }
470 }
471}
472
473impl ProtoUnsubscribe {
476 fn new(
477 source: ProtoName,
478 dst: ProtoName,
479 identity: Option<&str>,
480 flags: Option<SlimHeaderFlags>,
481 ) -> Self {
482 let id = identity.unwrap_or("");
483 let header = Some(SlimHeader::new(source, dst, id, flags));
484
485 ProtoUnsubscribe {
486 header,
487 subscription_id: 0,
488 }
489 }
490}
491
492impl From<ProtoMessage> for ProtoUnsubscribe {
494 fn from(message: ProtoMessage) -> Self {
495 match message.message_type {
496 Some(ProtoUnsubscribeType(u)) => u,
497 _ => panic!("message type is not unsubscribe"),
498 }
499 }
500}
501
502impl ProtoPublish {
505 fn with_header(
506 header: Option<SlimHeader>,
507 session: Option<SessionHeader>,
508 payload: Option<Content>,
509 ) -> Self {
510 ProtoPublish {
511 header,
512 session,
513 msg: payload,
514 }
515 }
516
517 pub fn get_slim_header(&self) -> &SlimHeader {
518 self.header.as_ref().unwrap()
519 }
520
521 pub fn get_session_header(&self) -> &SessionHeader {
522 self.session.as_ref().unwrap()
523 }
524
525 pub fn get_slim_header_as_mut(&mut self) -> &mut SlimHeader {
526 self.header.as_mut().unwrap()
527 }
528
529 pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
530 self.session.as_mut().unwrap()
531 }
532
533 pub fn get_payload(&self) -> &Content {
534 self.msg.as_ref().unwrap()
535 }
536
537 pub fn set_payload(&mut self, payload: Content) {
538 self.msg = Some(payload);
539 }
540
541 pub fn is_command(&self) -> bool {
542 match &self.get_payload().content_type.as_ref().unwrap() {
543 ContentType::AppPayload(_) => false,
544 ContentType::CommandPayload(_) => true,
545 }
546 }
547
548 pub fn get_application_payload(&self) -> &ApplicationPayload {
549 match self.get_payload().content_type.as_ref().unwrap() {
550 ContentType::AppPayload(application_payload) => application_payload,
551 ContentType::CommandPayload(_) => panic!("the payload is not an application payload"),
552 }
553 }
554
555 pub fn get_command_payload(&self) -> &CommandPayload {
556 match &self.get_payload().content_type.as_ref().unwrap() {
557 ContentType::AppPayload(_) => panic!("the payaoad is not a command payload"),
558 ContentType::CommandPayload(command_payload) => command_payload,
559 }
560 }
561}
562
563impl From<ProtoMessage> for ProtoPublish {
565 fn from(message: ProtoMessage) -> Self {
566 match message.message_type {
567 Some(ProtoPublishType(p)) => p,
568 _ => panic!("message type is not publish"),
569 }
570 }
571}
572
/// Generates `pub fn <extractor>(&self) -> Result<&<Target>, MessageError>`
/// methods.
///
/// Each generated method fetches the command payload via
/// `self.extract_command_payload()` and then calls the named typed accessor
/// on it. Both steps propagate their error unchanged.
macro_rules! impl_payload_extractors {
    ($($extractor:ident => $accessor:ident($target:ty)),* $(,)?) => {
        $(
            pub fn $extractor(&self) -> Result<&$target, MessageError> {
                let command = self.extract_command_payload()?;
                command.$accessor()
            }
        )*
    };
}
586
587impl ProtoMessage {
588 fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
589 ProtoMessage {
590 metadata,
591 message_type: Some(message_type),
592 }
593 }
594
595 fn validate_link(link: &ProtoLink) -> Result<(), MessageError> {
596 if link.link_type.is_none() {
597 return Err(MessageError::LinkTypeNotSet);
598 }
599 Ok(())
600 }
601
602 fn validate_routed_header(slim_header: &SlimHeader) -> Result<(), MessageError> {
603 match &slim_header.source {
604 None => return Err(MessageError::SourceNotFound),
605 Some(src) if src.name.is_none() => return Err(MessageError::SourceEncodedNameNotFound),
606 _ => {}
607 }
608 match &slim_header.destination {
609 None => return Err(MessageError::DestinationNotFound),
610 Some(dst) if dst.name.is_none() => {
611 return Err(MessageError::DestinationEncodedNameNotFound);
612 }
613 _ => {}
614 }
615 Ok(())
616 }
617
618 fn validate_publish(p: &ProtoPublish) -> Result<(), MessageError> {
619 let hdr = p.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
620 Self::validate_routed_header(hdr)?;
621 if p.session.is_none() {
622 return Err(MessageError::SessionHeaderNotFound);
623 }
624 Ok(())
625 }
626
627 fn validate_subscribe(s: &ProtoSubscribe) -> Result<(), MessageError> {
628 let hdr = s.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
629 Self::validate_routed_header(hdr)
630 }
631
632 fn validate_unsubscribe(u: &ProtoUnsubscribe) -> Result<(), MessageError> {
633 let hdr = u.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
634 Self::validate_routed_header(hdr)
635 }
636
637 pub fn validate(&self) -> Result<(), MessageError> {
639 match &self.message_type {
640 None => Err(MessageError::MessageTypeNotFound),
641 Some(ProtoLinkMessageType(link)) => Self::validate_link(link),
642 Some(ProtoPublishType(p)) => Self::validate_publish(p),
643 Some(ProtoSubscribeType(s)) => Self::validate_subscribe(s),
644 Some(ProtoUnsubscribeType(u)) => Self::validate_unsubscribe(u),
645 Some(ProtoSubscriptionAckType(_)) => Ok(()),
646 }
647 }
648
649 pub fn insert_metadata(&mut self, key: String, val: String) {
652 self.metadata.insert(key, val);
653 }
654
655 pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
657 self.metadata.remove(key)
658 }
659
660 pub fn contains_metadata(&self, key: &str) -> bool {
661 self.metadata.contains_key(key)
662 }
663
664 pub fn get_metadata(&self, key: &str) -> Option<&String> {
665 self.metadata.get(key)
666 }
667
668 pub fn get_metadata_map(&self) -> HashMap<String, String> {
669 self.metadata.clone()
670 }
671
672 pub fn set_metadata_map(&mut self, map: HashMap<String, String>) {
673 for (k, v) in map.iter() {
674 self.insert_metadata(k.to_string(), v.to_string());
675 }
676 }
677
678 pub fn get_slim_header(&self) -> &SlimHeader {
679 match &self.message_type {
680 Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
681 Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
682 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
683 Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
684 panic!("SLIM header not found")
685 }
686 }
687 }
688
689 pub fn get_slim_header_mut(&mut self) -> &mut SlimHeader {
690 match &mut self.message_type {
691 Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
692 Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
693 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
694 Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
695 panic!("SLIM header not found")
696 }
697 }
698 }
699
700 pub fn try_get_slim_header(&self) -> Option<&SlimHeader> {
701 match &self.message_type {
702 Some(ProtoPublishType(publish)) => publish.header.as_ref(),
703 Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
704 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
705 Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => None,
706 }
707 }
708
709 pub fn get_session_header(&self) -> &SessionHeader {
710 match &self.message_type {
711 Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
712 Some(ProtoSubscribeType(_))
713 | Some(ProtoUnsubscribeType(_))
714 | Some(ProtoLinkMessageType(_))
715 | Some(ProtoSubscriptionAckType(_))
716 | None => panic!("session header not found"),
717 }
718 }
719
720 pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
721 match &mut self.message_type {
722 Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
723 Some(ProtoSubscribeType(_))
724 | Some(ProtoUnsubscribeType(_))
725 | Some(ProtoLinkMessageType(_))
726 | Some(ProtoSubscriptionAckType(_))
727 | None => panic!("session header not found"),
728 }
729 }
730
731 pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
732 match &self.message_type {
733 Some(ProtoPublishType(publish)) => publish.session.as_ref(),
734 Some(ProtoSubscribeType(_))
735 | Some(ProtoUnsubscribeType(_))
736 | Some(ProtoLinkMessageType(_))
737 | Some(ProtoSubscriptionAckType(_))
738 | None => None,
739 }
740 }
741
742 pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
743 match &mut self.message_type {
744 Some(ProtoPublishType(publish)) => publish.session.as_mut(),
745 Some(ProtoSubscribeType(_))
746 | Some(ProtoUnsubscribeType(_))
747 | Some(ProtoLinkMessageType(_))
748 | Some(ProtoSubscriptionAckType(_))
749 | None => None,
750 }
751 }
752
753 pub fn get_id(&self) -> u32 {
754 self.get_session_header().get_message_id()
755 }
756
757 pub fn get_source(&self) -> ProtoName {
758 self.get_slim_header().get_source()
759 }
760
761 pub fn get_encoded_source(&self) -> EncodedName {
762 self.get_slim_header().get_encoded_source()
763 }
764
765 pub fn get_dst(&self) -> ProtoName {
766 self.get_slim_header().get_dst()
767 }
768
769 pub fn get_encoded_dst(&self) -> EncodedName {
770 self.get_slim_header().get_encoded_dst()
771 }
772
773 pub fn get_identity(&self) -> String {
774 self.get_slim_header().get_identity()
775 }
776
777 pub fn get_fanout(&self) -> u32 {
778 self.get_slim_header().get_fanout()
779 }
780
781 pub fn get_recv_from(&self) -> Option<u64> {
782 self.get_slim_header().get_recv_from()
783 }
784
785 pub fn get_forward_to(&self) -> Option<u64> {
786 self.get_slim_header().get_forward_to()
787 }
788
789 pub fn get_error(&self) -> Option<bool> {
790 self.get_slim_header().get_error()
791 }
792
793 pub fn get_incoming_conn(&self) -> u64 {
794 self.get_slim_header().get_incoming_conn().unwrap()
795 }
796
797 pub fn try_get_incoming_conn(&self) -> Option<u64> {
798 self.get_slim_header().get_incoming_conn()
799 }
800
801 pub fn get_type(&self) -> &MessageType {
802 match &self.message_type {
803 Some(t) => t,
804 None => panic!("message type not found"),
805 }
806 }
807
808 pub fn get_payload(&self) -> Option<&Content> {
809 match &self.message_type {
810 Some(ProtoPublishType(p)) => p.msg.as_ref(),
811 Some(ProtoSubscribeType(_)) => panic!("payload not found"),
812 Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
813 Some(ProtoLinkMessageType(_)) => panic!("payload not found"),
814 Some(ProtoSubscriptionAckType(_)) => panic!("payload not found"),
815 None => panic!("payload not found"),
816 }
817 }
818
819 pub fn set_payload(&mut self, payload: Content) {
820 match &mut self.message_type {
821 Some(ProtoPublishType(p)) => p.set_payload(payload),
822 Some(ProtoSubscribeType(_)) => panic!("no payload allowed"),
823 Some(ProtoUnsubscribeType(_)) => panic!("no payload allowed"),
824 Some(ProtoLinkMessageType(_)) => panic!("no payload allowed"),
825 Some(ProtoSubscriptionAckType(_)) => panic!("no payload allowed"),
826 None => panic!("no payload allowed"),
827 }
828 }
829
830 pub fn get_session_message_type(&self) -> SessionMessageType {
831 self.get_session_header()
832 .session_message_type
833 .try_into()
834 .unwrap_or_default()
835 }
836
837 pub fn clear_slim_header(&mut self) {
838 if self.is_link() || self.is_subscription_ack() {
839 return;
840 }
841 self.get_slim_header_mut().clear_flags();
842 }
843
844 pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
845 self.get_slim_header_mut().set_recv_from(recv_from);
846 }
847
848 pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
849 self.get_slim_header_mut().set_forward_to(forward_to);
850 }
851
852 pub fn set_error(&mut self, error: Option<bool>) {
853 self.get_slim_header_mut().set_error(error);
854 }
855
856 pub fn set_fanout(&mut self, fanout: u32) {
857 self.get_slim_header_mut().set_fanout(fanout);
858 }
859
860 pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
861 self.get_slim_header_mut().set_incoming_conn(incoming_conn);
862 }
863
864 pub fn set_error_flag(&mut self, error: Option<bool>) {
865 self.get_slim_header_mut().set_error_flag(error);
866 }
867
868 pub fn set_session_message_type(&mut self, message_type: SessionMessageType) {
869 self.get_session_header_mut()
870 .set_session_message_type(message_type);
871 }
872
873 pub fn set_session_type(&mut self, session_type: ProtoSessionType) {
874 self.get_session_header_mut().set_session_type(session_type);
875 }
876
877 pub fn get_session_type(&self) -> ProtoSessionType {
878 self.get_session_header().session_type()
879 }
880
881 pub fn set_message_id(&mut self, message_id: u32) {
882 self.get_session_header_mut().set_message_id(message_id);
883 }
884
885 pub fn is_publish(&self) -> bool {
886 matches!(self.get_type(), MessageType::Publish(_))
887 }
888
889 pub fn is_subscribe(&self) -> bool {
890 matches!(self.get_type(), MessageType::Subscribe(_))
891 }
892
893 pub fn is_unsubscribe(&self) -> bool {
894 matches!(self.get_type(), MessageType::Unsubscribe(_))
895 }
896
897 pub fn is_link(&self) -> bool {
898 matches!(self.get_type(), MessageType::Link(_))
899 }
900
901 pub fn is_subscription_ack(&self) -> bool {
902 matches!(self.get_type(), MessageType::SubscriptionAck(_))
903 }
904
905 pub fn is_traceable(&self) -> bool {
906 !self.is_link() && !self.is_subscription_ack()
907 }
908
909 pub fn get_subscription_ack(&self) -> &ProtoSubscriptionAck {
910 match &self.message_type {
911 Some(ProtoSubscriptionAckType(ack)) => ack,
912 _ => panic!("message type is not subscription_ack"),
913 }
914 }
915
916 pub fn get_subscription_id(&self) -> Option<u64> {
918 match &self.message_type {
919 Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => Some(s.subscription_id),
920 Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => Some(u.subscription_id),
921 _ => None,
922 }
923 }
924
925 pub fn take_subscription_id(&mut self) -> Option<u64> {
928 match &mut self.message_type {
929 Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => {
930 Some(std::mem::take(&mut s.subscription_id))
931 }
932 Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => {
933 Some(std::mem::take(&mut u.subscription_id))
934 }
935 _ => None,
936 }
937 }
938
939 pub fn set_subscription_id(&mut self, subscription_id: u64) {
941 match &mut self.message_type {
942 Some(ProtoSubscribeType(s)) => s.subscription_id = subscription_id,
943 Some(ProtoUnsubscribeType(u)) => u.subscription_id = subscription_id,
944 _ => {}
945 }
946 }
947
948 pub fn extract_command_payload(&self) -> Result<&CommandPayload, MessageError> {
953 self.get_payload()
954 .ok_or(MessageError::ContentTypeNotSet)?
955 .as_command_payload()
956 }
957
958 impl_payload_extractors! {
960 extract_discovery_request => as_discovery_request_payload(DiscoveryRequestPayload),
961 extract_discovery_reply => as_discovery_reply_payload(DiscoveryReplyPayload),
962 extract_join_request => as_join_request_payload(JoinRequestPayload),
963 extract_join_reply => as_join_reply_payload(JoinReplyPayload),
964 extract_leave_request => as_leave_request_payload(LeaveRequestPayload),
965 extract_leave_reply => as_leave_reply_payload(LeaveReplyPayload),
966 extract_group_add => as_group_add_payload(GroupAddPayload),
967 extract_group_remove => as_group_remove_payload(GroupRemovePayload),
968 extract_group_welcome => as_welcome_payload(GroupWelcomePayload),
969 extract_group_close => as_group_close_payload(GroupClosePayload),
970 extract_group_proposal => as_group_proposal_payload(GroupProposalPayload),
971 extract_group_ack => as_group_ack_payload(GroupAckPayload),
972 extract_group_nack => as_group_nack_payload(GroupNackPayload),
973 extract_ping => as_ping_payload(PingPayload),
974 }
975}
976
977impl Content {
978 pub fn as_application_payload(&self) -> Result<&ApplicationPayload, MessageError> {
979 match &self.content_type {
980 Some(ContentType::AppPayload(app_payload)) => Ok(app_payload),
981 Some(ContentType::CommandPayload(_)) => Err(MessageError::NotApplicationPayload),
982 None => Err(MessageError::ContentTypeNotSet),
983 }
984 }
985
986 pub fn as_command_payload(&self) -> Result<&CommandPayload, MessageError> {
987 match &self.content_type {
988 Some(ContentType::AppPayload(_)) => Err(MessageError::NotCommandPayload),
989 Some(ContentType::CommandPayload(comm_payload)) => Ok(comm_payload),
990 None => Err(MessageError::ContentTypeNotSet),
991 }
992 }
993}
994
995impl ApplicationPayload {
996 pub fn new(payload_type: &str, blob: Vec<u8>) -> Self {
997 Self {
998 payload_type: payload_type.to_string(),
999 blob,
1000 }
1001 }
1002
1003 pub fn as_content(&self) -> Content {
1004 Content {
1005 content_type: Some(ContentType::AppPayload(self.clone())),
1006 }
1007 }
1008}
1009
/// Generates typed accessors `pub fn <name>(&self) -> Result<&<Target>, MessageError>`
/// over `self.command_payload_type`.
///
/// A generated accessor returns the inner payload when the variant matches.
/// Otherwise it returns [`MessageError::InvalidCommandPayloadType`], where
/// `expected` is the variant name and `got` is the `Debug` form of the actual
/// variant, or `"None"` when no variant is set.
macro_rules! impl_command_payload_getters {
    ($(
        $accessor:ident => $wanted:ident($target:ty)
    ),* $(,)?) => {
        $(
            pub fn $accessor(&self) -> Result<&$target, MessageError> {
                let got = match &self.command_payload_type {
                    Some(CommandPayloadType::$wanted(inner)) => return Ok(inner),
                    Some(other) => format!("{:?}", other),
                    None => "None".to_string(),
                };
                Err(MessageError::InvalidCommandPayloadType {
                    expected: Box::new(stringify!($wanted).to_string()),
                    got: Box::new(got),
                })
            }
        )*
    };
}
1032
1033impl CommandPayload {
1034 pub fn as_content(self) -> Content {
1035 Content {
1036 content_type: Some(ContentType::CommandPayload(self)),
1037 }
1038 }
1039
1040 impl_command_payload_getters! {
1042 as_discovery_request_payload => DiscoveryRequest(DiscoveryRequestPayload),
1043 as_discovery_reply_payload => DiscoveryReply(DiscoveryReplyPayload),
1044 as_join_request_payload => JoinRequest(JoinRequestPayload),
1045 as_join_reply_payload => JoinReply(JoinReplyPayload),
1046 as_leave_request_payload => LeaveRequest(LeaveRequestPayload),
1047 as_leave_reply_payload => LeaveReply(LeaveReplyPayload),
1048 as_group_add_payload => GroupAdd(GroupAddPayload),
1049 as_group_remove_payload => GroupRemove(GroupRemovePayload),
1050 as_welcome_payload => GroupWelcome(GroupWelcomePayload),
1051 as_group_close_payload => GroupClose(GroupClosePayload),
1052 as_group_proposal_payload => GroupProposal(GroupProposalPayload),
1053 as_group_ack_payload => GroupAck(GroupAckPayload),
1054 as_group_nack_payload => GroupNack(GroupNackPayload),
1055 as_ping_payload => Ping(PingPayload),
1056 }
1057}
1058
1059impl AsRef<ProtoPublish> for ProtoMessage {
1060 fn as_ref(&self) -> &ProtoPublish {
1061 match &self.message_type {
1062 Some(ProtoPublishType(p)) => p,
1063 _ => panic!("message type is not publish"),
1064 }
1065 }
1066}
1067
1068pub struct CommandPayloadBuilder;
1115
1116impl CommandPayloadBuilder {
1117 pub fn new() -> Self {
1119 Self
1120 }
1121
1122 pub fn discovery_request(self) -> CommandPayload {
1124 let payload = DiscoveryRequestPayload {};
1125 CommandPayload {
1126 command_payload_type: Some(CommandPayloadType::DiscoveryRequest(payload)),
1127 }
1128 }
1129
1130 pub fn discovery_reply(self) -> CommandPayload {
1132 let payload = DiscoveryReplyPayload {};
1133 CommandPayload {
1134 command_payload_type: Some(CommandPayloadType::DiscoveryReply(payload)),
1135 }
1136 }
1137
1138 pub fn join_request(
1140 self,
1141 enable_mls: bool,
1142 max_retries: Option<u32>,
1143 timer_duration: Option<Duration>,
1144 channel: Option<ProtoName>,
1145 ) -> CommandPayload {
1146 let proto_channel = channel;
1147
1148 let timer_settings = if let Some(t) = timer_duration
1149 && let Some(m) = max_retries
1150 {
1151 Some(TimerSettings {
1152 timeout: t.as_millis() as u32,
1153 max_retries: m,
1154 })
1155 } else {
1156 None
1157 };
1158
1159 let payload = JoinRequestPayload {
1160 enable_mls,
1161 timer_settings,
1162 channel: proto_channel,
1163 };
1164 CommandPayload {
1165 command_payload_type: Some(CommandPayloadType::JoinRequest(payload)),
1166 }
1167 }
1168
1169 pub fn join_reply(
1171 self,
1172 key_package: Option<Vec<u8>>,
1173 participant: Participant,
1174 ) -> CommandPayload {
1175 let payload = JoinReplyPayload {
1176 key_package,
1177 participant: Some(participant),
1178 };
1179 CommandPayload {
1180 command_payload_type: Some(CommandPayloadType::JoinReply(payload)),
1181 }
1182 }
1183
1184 pub fn leave_request(self) -> CommandPayload {
1186 let payload = LeaveRequestPayload {};
1187 CommandPayload {
1188 command_payload_type: Some(CommandPayloadType::LeaveRequest(payload)),
1189 }
1190 }
1191
1192 pub fn leave_reply(self) -> CommandPayload {
1194 let payload = LeaveReplyPayload {};
1195 CommandPayload {
1196 command_payload_type: Some(CommandPayloadType::LeaveReply(payload)),
1197 }
1198 }
1199
1200 pub fn group_add(
1202 self,
1203 new_participant: Participant,
1204 participants: Vec<Participant>,
1205 mls: Option<MlsPayload>,
1206 ) -> CommandPayload {
1207 let payload = GroupAddPayload {
1208 new_participant: Some(new_participant),
1209 participants,
1210 mls,
1211 };
1212 CommandPayload {
1213 command_payload_type: Some(CommandPayloadType::GroupAdd(payload)),
1214 }
1215 }
1216
1217 pub fn group_remove(
1219 self,
1220 removed_participant: ProtoName,
1221 participants: Vec<ProtoName>,
1222 mls: Option<MlsPayload>,
1223 ) -> CommandPayload {
1224 let payload = GroupRemovePayload {
1225 removed_participant: Some(removed_participant),
1226 participants,
1227 mls,
1228 };
1229 CommandPayload {
1230 command_payload_type: Some(CommandPayloadType::GroupRemove(payload)),
1231 }
1232 }
1233
1234 pub fn group_welcome(
1236 self,
1237 participants: Vec<Participant>,
1238 mls: Option<MlsPayload>,
1239 ) -> CommandPayload {
1240 let payload = GroupWelcomePayload { participants, mls };
1241 CommandPayload {
1242 command_payload_type: Some(CommandPayloadType::GroupWelcome(payload)),
1243 }
1244 }
1245
1246 pub fn group_close(self, participants: Vec<ProtoName>) -> CommandPayload {
1248 let payload = GroupClosePayload { participants };
1249 CommandPayload {
1250 command_payload_type: Some(CommandPayloadType::GroupClose(payload)),
1251 }
1252 }
1253
1254 pub fn group_proposal(
1256 self,
1257 source: Option<ProtoName>,
1258 mls_proposal: Vec<u8>,
1259 ) -> CommandPayload {
1260 let payload = GroupProposalPayload {
1261 source,
1262 mls_proposal,
1263 };
1264 CommandPayload {
1265 command_payload_type: Some(CommandPayloadType::GroupProposal(payload)),
1266 }
1267 }
1268
1269 pub fn group_ack(self) -> CommandPayload {
1271 let payload = GroupAckPayload {};
1272 CommandPayload {
1273 command_payload_type: Some(CommandPayloadType::GroupAck(payload)),
1274 }
1275 }
1276
1277 pub fn group_nack(self) -> CommandPayload {
1279 let payload = GroupNackPayload {};
1280 CommandPayload {
1281 command_payload_type: Some(CommandPayloadType::GroupNack(payload)),
1282 }
1283 }
1284
1285 pub fn ping(self) -> CommandPayload {
1287 let payload = PingPayload {};
1288 CommandPayload {
1289 command_payload_type: Some(CommandPayloadType::Ping(payload)),
1290 }
1291 }
1292}
1293
1294impl Default for CommandPayloadBuilder {
1295 fn default() -> Self {
1296 Self::new()
1297 }
1298}
1299
1300impl CommandPayload {
1301 pub fn builder() -> CommandPayloadBuilder {
1303 CommandPayloadBuilder::new()
1304 }
1305}
1306
1307pub struct ProtoMessageBuilder {
1393 source: Option<ProtoName>,
1394 destination: Option<ProtoName>,
1395 identity: Option<String>,
1396 flags: Option<SlimHeaderFlags>,
1397 session_type: Option<ProtoSessionType>,
1398 session_message_type: Option<SessionMessageType>,
1399 session_id: Option<u32>,
1400 message_id: Option<u32>,
1401 payload: Option<Content>,
1402 metadata: HashMap<String, String>,
1403 subscription_id: Option<u64>,
1404}
1405
1406impl ProtoMessageBuilder {
1407 pub fn new() -> Self {
1409 Self {
1410 source: None,
1411 destination: None,
1412 identity: None,
1413 flags: None,
1414 session_type: None,
1415 session_message_type: None,
1416 session_id: None,
1417 message_id: None,
1418 payload: None,
1419 metadata: HashMap::new(),
1420 subscription_id: None,
1421 }
1422 }
1423
1424 pub fn source(mut self, source: ProtoName) -> Self {
1426 self.source = Some(source);
1427 self
1428 }
1429
1430 pub fn destination(mut self, destination: ProtoName) -> Self {
1432 self.destination = Some(destination);
1433 self
1434 }
1435
1436 pub fn identity(mut self, identity: impl Into<String>) -> Self {
1438 self.identity = Some(identity.into());
1439 self
1440 }
1441
1442 pub fn flags(mut self, flags: SlimHeaderFlags) -> Self {
1444 self.flags = Some(flags);
1445 self
1446 }
1447
1448 pub fn fanout(mut self, fanout: u32) -> Self {
1450 let flags = self.flags.take().unwrap_or_default();
1451 self.flags = Some(flags.with_fanout(fanout));
1452 self
1453 }
1454
1455 pub fn recv_from(mut self, recv_from: u64) -> Self {
1457 let flags = self.flags.take().unwrap_or_default();
1458 self.flags = Some(flags.with_recv_from(recv_from));
1459 self
1460 }
1461
1462 pub fn forward_to(mut self, forward_to: u64) -> Self {
1464 let flags = self.flags.take().unwrap_or_default();
1465 self.flags = Some(flags.with_forward_to(forward_to));
1466 self
1467 }
1468
1469 pub fn incoming_conn(mut self, incoming_conn: u64) -> Self {
1471 let flags = self.flags.take().unwrap_or_default();
1472 self.flags = Some(flags.with_incoming_conn(incoming_conn));
1473 self
1474 }
1475
1476 pub fn error(mut self, error: bool) -> Self {
1478 let flags = self.flags.take().unwrap_or_default();
1479 self.flags = Some(flags.with_error(error));
1480 self
1481 }
1482
1483 pub fn session_type(mut self, session_type: ProtoSessionType) -> Self {
1485 self.session_type = Some(session_type);
1486 self
1487 }
1488
1489 pub fn session_message_type(mut self, session_message_type: SessionMessageType) -> Self {
1491 self.session_message_type = Some(session_message_type);
1492 self
1493 }
1494
1495 pub fn session_id(mut self, session_id: u32) -> Self {
1497 self.session_id = Some(session_id);
1498 self
1499 }
1500
1501 pub fn message_id(mut self, message_id: u32) -> Self {
1503 self.message_id = Some(message_id);
1504 self
1505 }
1506
1507 pub fn payload(mut self, payload: Content) -> Self {
1509 self.payload = Some(payload);
1510 self
1511 }
1512
1513 pub fn application_payload(mut self, payload_type: &str, blob: Vec<u8>) -> Self {
1515 let app_payload = ApplicationPayload::new(payload_type, blob);
1516 self.payload = Some(app_payload.as_content());
1517 self
1518 }
1519
1520 pub fn command_payload(mut self, payload: CommandPayload) -> Self {
1522 self.payload = Some(payload.as_content());
1523 self
1524 }
1525
1526 pub fn with_slim_header(mut self, header: SlimHeader) -> Self {
1531 if let Some(src) = header.source.clone() {
1533 self.source = Some(src);
1534 }
1535 if let Some(dst) = header.destination.clone() {
1536 self.destination = Some(dst);
1537 }
1538 if !header.identity.is_empty() {
1539 self.identity = Some(header.identity.clone());
1540 }
1541
1542 let flags = SlimHeaderFlags {
1544 fanout: header.fanout,
1545 recv_from: header.recv_from,
1546 forward_to: header.forward_to,
1547 incoming_conn: header.incoming_conn,
1548 error: header.error,
1549 };
1550 self.flags = Some(flags);
1551 self
1552 }
1553
1554 pub fn with_session_header(mut self, header: SessionHeader) -> Self {
1559 self.session_type = Some(
1560 ProtoSessionType::try_from(header.session_type)
1561 .unwrap_or(ProtoSessionType::PointToPoint),
1562 );
1563 self.session_message_type = Some(
1564 SessionMessageType::try_from(header.session_message_type)
1565 .unwrap_or(SessionMessageType::Msg),
1566 );
1567 self.session_id = Some(header.session_id);
1568 self.message_id = Some(header.message_id);
1569 self
1570 }
1571
1572 pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1574 self.metadata.insert(key.into(), value.into());
1575 self
1576 }
1577
1578 pub fn metadata_map(mut self, map: HashMap<String, String>) -> Self {
1580 self.metadata.extend(map);
1581 self
1582 }
1583
1584 pub fn subscription_id(mut self, id: u64) -> Self {
1586 self.subscription_id = Some(id);
1587 self
1588 }
1589
1590 pub fn build_publish(self) -> Result<ProtoMessage, MessageError> {
1592 let source = self
1593 .source
1594 .ok_or(MessageError::BuilderErrorSourceRequired)?;
1595 let destination = self
1596 .destination
1597 .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1598
1599 let slim_header = Some(SlimHeader::new(
1600 source,
1601 destination,
1602 self.identity.as_deref().unwrap_or(""),
1603 self.flags,
1604 ));
1605
1606 let session_header = if self.session_type.is_some() || self.session_message_type.is_some() {
1607 Some(SessionHeader::new(
1608 self.session_type
1609 .unwrap_or(ProtoSessionType::PointToPoint)
1610 .into(),
1611 self.session_message_type
1612 .unwrap_or(SessionMessageType::Msg)
1613 .into(),
1614 self.session_id.unwrap_or(0),
1615 self.message_id.unwrap_or_else(rand::random),
1616 ))
1617 } else {
1618 Some(SessionHeader::default())
1619 };
1620
1621 let publish = ProtoPublish::with_header(slim_header, session_header, self.payload);
1622 let message = ProtoMessage::new(self.metadata, ProtoPublishType(publish));
1623 Ok(message)
1624 }
1625
1626 pub fn build_subscribe(self) -> Result<ProtoMessage, MessageError> {
1628 let source = self
1629 .source
1630 .ok_or(MessageError::BuilderErrorSourceRequired)?;
1631 let destination = self
1632 .destination
1633 .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1634
1635 let mut subscribe =
1636 ProtoSubscribe::new(source, destination, self.identity.as_deref(), self.flags);
1637 subscribe.subscription_id = self.subscription_id.unwrap_or_default();
1638
1639 Ok(ProtoMessage::new(
1640 self.metadata,
1641 ProtoSubscribeType(subscribe),
1642 ))
1643 }
1644
1645 pub fn build_unsubscribe(self) -> Result<ProtoMessage, MessageError> {
1647 let source = self
1648 .source
1649 .ok_or(MessageError::BuilderErrorSourceRequired)?;
1650 let destination = self
1651 .destination
1652 .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1653
1654 let mut unsubscribe =
1655 ProtoUnsubscribe::new(source, destination, self.identity.as_deref(), self.flags);
1656 unsubscribe.subscription_id = self.subscription_id.unwrap_or_default();
1657
1658 Ok(ProtoMessage::new(
1659 self.metadata,
1660 ProtoUnsubscribeType(unsubscribe),
1661 ))
1662 }
1663
1664 pub fn build_subscription_ack(
1668 self,
1669 subscription_id: u64,
1670 success: bool,
1671 error: impl Into<String>,
1672 ) -> ProtoMessage {
1673 let ack = ProtoSubscriptionAck {
1674 subscription_id,
1675 success,
1676 error: error.into(),
1677 };
1678 ProtoMessage::new(self.metadata, ProtoSubscriptionAckType(ack))
1679 }
1680
1681 pub fn build_link_negotiation(
1684 self,
1685 link_id: impl Into<String>,
1686 slim_version: impl Into<String>,
1687 is_reply: bool,
1688 link_ecdh_public_key: Option<Vec<u8>>,
1689 ) -> ProtoMessage {
1690 let link_ecdh_public_key = link_ecdh_public_key.unwrap_or_default();
1691 let link = ProtoLink {
1692 link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
1693 link_id: link_id.into(),
1694 slim_version: slim_version.into(),
1695 is_reply,
1696 link_ecdh_public_key,
1697 })),
1698 };
1699 ProtoMessage::new(self.metadata, ProtoLinkMessageType(link))
1700 }
1701}
1702
1703impl Default for ProtoMessageBuilder {
1704 fn default() -> Self {
1705 Self::new()
1706 }
1707}
1708
1709impl ProtoMessage {
1710 pub fn builder() -> ProtoMessageBuilder {
1712 ProtoMessageBuilder::new()
1713 }
1714}
1715
#[cfg(test)]
mod tests {
    use crate::api::proto::dataplane::v1::SessionMessageType;

    use super::*;

    /// Builds a subscribe (or unsubscribe) message and checks that kind,
    /// routing flags and names round-trip through the builder.
    fn test_subscription_template(
        subscription: bool,
        source: ProtoName,
        dst: ProtoName,
        identity: Option<&str>,
        flags: Option<SlimHeaderFlags>,
    ) {
        let mut builder = ProtoMessage::builder()
            .source(source.clone())
            .destination(dst.clone());

        if let Some(id) = identity {
            builder = builder.identity(id);
        }

        if let Some(f) = flags.clone() {
            builder = builder.flags(f);
        }

        let sub = if subscription {
            builder.build_subscribe().unwrap()
        } else {
            builder.build_unsubscribe().unwrap()
        };

        // When no flags are supplied, the expected values are the defaults.
        let expected = flags.unwrap_or_default();

        assert!(!sub.is_publish());
        assert_eq!(sub.is_subscribe(), subscription);
        assert_eq!(sub.is_unsubscribe(), !subscription);
        assert_eq!(expected.recv_from, sub.get_recv_from());
        assert_eq!(expected.forward_to, sub.get_forward_to());
        assert_eq!(None, sub.try_get_incoming_conn());
        assert_eq!(source, sub.get_source());
        assert_eq!(dst, sub.get_dst());
    }

    /// Builds a publish message and checks that kind, routing flags, names
    /// and fanout round-trip through the builder.
    fn test_publish_template(
        source: ProtoName,
        dst: ProtoName,
        identity: Option<&str>,
        flags: Option<SlimHeaderFlags>,
    ) {
        let mut builder = ProtoMessage::builder()
            .source(source.clone())
            .destination(dst.clone())
            .application_payload("str", "this is the content of the message".into());

        if let Some(id) = identity {
            builder = builder.identity(id);
        }

        if let Some(f) = flags.clone() {
            builder = builder.flags(f);
        }

        let pub_msg = builder.build_publish().unwrap();

        // When no flags are supplied, the expected values are the defaults.
        let expected = flags.unwrap_or_default();

        assert!(pub_msg.is_publish());
        assert!(!pub_msg.is_subscribe());
        assert!(!pub_msg.is_unsubscribe());
        assert_eq!(expected.recv_from, pub_msg.get_recv_from());
        assert_eq!(expected.forward_to, pub_msg.get_forward_to());
        assert_eq!(None, pub_msg.try_get_incoming_conn());
        assert_eq!(source, pub_msg.get_source());
        assert_eq!(dst, pub_msg.get_dst());
        assert_eq!(expected.fanout, pub_msg.get_fanout());
    }

    #[test]
    fn test_subscription() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(2);

        test_subscription_template(true, source.clone(), dst.clone(), None, None);

        // Same message with an identity attached, exercising that builder path.
        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            Some("test-identity"),
            None,
        );

        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );
    }

    #[test]
    fn test_unsubscription() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(2);

        test_subscription_template(false, source.clone(), dst.clone(), None, None);

        // Same message with an identity attached, exercising that builder path.
        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            Some("test-identity"),
            None,
        );

        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );
    }

    #[test]
    fn test_publish() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let mut dst = ProtoName::from_strings(["org", "ns", "type"]);

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default()),
        );

        dst.set_id(2);
        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default()),
        );
        dst.reset_id();

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_fanout(2)),
        );
    }

    #[test]
    fn test_conversions() {
        let name = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let proto_subscribe = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_subscribe()
            .unwrap();
        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_dst(), dst);

        let proto_unsubscribe = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_unsubscribe()
            .unwrap();
        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
        assert_eq!(
            proto_unsubscribe.header.as_ref().unwrap().get_source(),
            name
        );
        assert_eq!(proto_unsubscribe.header.as_ref().unwrap().get_dst(), dst);

        let proto_publish = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .application_payload("str", "this is the content of the message".into())
            .build_publish()
            .unwrap();
        let proto_publish = ProtoPublish::from(proto_publish);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_dst(), dst);
    }

    #[test]
    fn test_panic() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(2);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_subscribe()
            .unwrap();

        // Converting a subscribe message into any other kind must panic.
        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_header() {
        let header = SlimHeader {
            source: None,
            destination: None,
            identity: String::new(),
            fanout: 0,
            version: version().to_string(),
            recv_from: None,
            forward_to: None,
            incoming_conn: None,
            error: None,
            header_mac: None,
        };

        // Missing names panic; missing optional flags do not.
        let result = std::panic::catch_unwind(|| header.get_source());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| header.get_dst());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| header.get_recv_from());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_forward_to());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_error());
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_session_header() {
        let header = SessionHeader::new(0, 0, 0, 0);

        let result = std::panic::catch_unwind(|| header.get_session_id());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_message_id());
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_proto_message() {
        let message = ProtoMessage {
            metadata: HashMap::new(),
            message_type: None,
        };

        // Every accessor on a message without a type must panic.
        let result = std::panic::catch_unwind(|| message.get_slim_header());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| message.get_type());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| message.get_source());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_dst());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_recv_from());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_forward_to());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_fanout());
        assert!(result.is_err());
    }

    #[test]
    fn test_service_type_to_int() {
        // Values 0..=Ping are expected to be valid, contiguous discriminants
        // that round-trip through i32; Ping + 1 is expected to be invalid.
        let last_service_type = SessionMessageType::Ping as i32;

        for i in 0..=last_service_type {
            let service_type =
                SessionMessageType::try_from(i).expect("failed to convert int to service type");
            assert_eq!(i32::from(service_type), i);
        }

        let invalid_service_type = SessionMessageType::try_from(last_service_type + 1);
        assert!(invalid_service_type.is_err());
    }

    #[test]
    fn test_proto_message_builder() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dest = ProtoName::from_strings(["org", "ns", "app"]).with_id(2);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .application_payload("test", b"hello world".to_vec())
            .build_publish()
            .unwrap();

        assert!(msg.is_publish());
        assert_eq!(msg.get_source(), source);
        assert_eq!(msg.get_dst(), dest);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(SessionMessageType::Msg)
            .session_id(42)
            .message_id(100)
            .fanout(256)
            .application_payload("test", b"broadcast".to_vec())
            .build_publish()
            .unwrap();

        assert_eq!(msg.get_session_type(), ProtoSessionType::Multicast);
        assert_eq!(msg.get_id(), 100);
        assert_eq!(msg.get_fanout(), 256);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .metadata("key1", "value1")
            .metadata("key2", "value2")
            .application_payload("test", vec![1, 2, 3])
            .build_publish()
            .unwrap();

        assert_eq!(msg.get_metadata("key1"), Some(&"value1".to_string()));
        assert_eq!(msg.get_metadata("key2"), Some(&"value2".to_string()));

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .recv_from(10)
            .build_subscribe()
            .unwrap();

        assert!(msg.is_subscribe());
        assert_eq!(msg.get_recv_from(), Some(10));

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .forward_to(20)
            .build_unsubscribe()
            .unwrap();

        assert!(msg.is_unsubscribe());
        assert_eq!(msg.get_forward_to(), Some(20));
    }

    #[test]
    fn test_command_payload_builder() {
        let dest = ProtoName::from_strings(["org", "ns", "app"]);

        let payload = CommandPayload::builder().discovery_request();
        assert!(payload.as_discovery_request_payload().is_ok());

        let payload = CommandPayload::builder().discovery_reply();
        assert!(payload.as_discovery_reply_payload().is_ok());

        let payload = CommandPayload::builder().join_request(
            true,
            Some(5),
            Some(Duration::from_secs(10)),
            Some(dest.clone()),
        );
        let extracted = payload.as_join_request_payload().unwrap();
        assert!(extracted.enable_mls);
        assert!(extracted.timer_settings.is_some());

        let participant = Participant::new(dest.clone(), ParticipantSettings::bidirectional());
        let payload =
            CommandPayload::builder().join_reply(Some(vec![1, 2, 3]), participant.clone());
        let extracted = payload.as_join_reply_payload().unwrap();
        assert_eq!(extracted.key_package, Some(vec![1, 2, 3]));
        assert_eq!(extracted.participant, Some(participant));

        let payload = CommandPayload::builder().leave_request();
        assert!(payload.as_leave_request_payload().is_ok());

        let payload = CommandPayload::builder().leave_reply();
        assert!(payload.as_leave_reply_payload().is_ok());

        let participant = Participant::new(dest.clone(), ParticipantSettings::bidirectional());
        let participants = vec![participant.clone()];
        let payload =
            CommandPayload::builder().group_add(participant.clone(), participants.clone(), None);
        let extracted = payload.as_group_add_payload().unwrap();
        assert_eq!(extracted.new_participant, Some(participant));
        assert_eq!(extracted.participants, participants);

        let payload =
            CommandPayload::builder().group_remove(dest.clone(), vec![dest.clone()], None);
        let extracted = payload.as_group_remove_payload().unwrap();
        assert!(extracted.removed_participant.is_some());

        let payload = CommandPayload::builder().group_welcome(participants.clone(), None);
        let extracted = payload.as_welcome_payload().unwrap();
        assert!(!extracted.participants.is_empty());

        let payload = CommandPayload::builder().group_proposal(Some(dest.clone()), vec![4, 5, 6]);
        let extracted = payload.as_group_proposal_payload().unwrap();
        assert_eq!(extracted.mls_proposal, vec![4, 5, 6]);

        let payload = CommandPayload::builder().group_ack();
        assert!(payload.as_group_ack_payload().is_ok());

        let payload = CommandPayload::builder().group_nack();
        assert!(payload.as_group_nack_payload().is_ok());

        let payload = CommandPayload::builder().ping();
        assert!(payload.as_ping_payload().is_ok());
    }

    #[test]
    fn test_builder_with_command_payload() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dest = ProtoName::from_strings(["org", "ns", "app"]).with_id(2);

        let cmd_payload = CommandPayload::builder().discovery_request();

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .session_type(ProtoSessionType::PointToPoint)
            .session_message_type(SessionMessageType::DiscoveryRequest)
            .session_id(1)
            .command_payload(cmd_payload)
            .build_publish()
            .unwrap();

        assert!(msg.is_publish());
        assert_eq!(
            msg.get_session_message_type(),
            SessionMessageType::DiscoveryRequest
        );
    }

    #[test]
    fn test_validate_link_without_link_type() {
        let link = ProtoLink { link_type: None };
        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
        assert!(matches!(msg.validate(), Err(MessageError::LinkTypeNotSet)));
    }

    #[test]
    fn test_validate_link_with_link_type() {
        let link = ProtoLink {
            link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
                link_id: "abc".into(),
                slim_version: "1.0.0".into(),
                is_reply: false,
                link_ecdh_public_key: vec![],
            })),
        };
        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_build_link_negotiation_request() {
        let msg = ProtoMessage::builder().build_link_negotiation("my-id", "1.2.3", false, None);
        assert!(msg.is_link());
        assert!(!msg.is_publish());
        assert!(!msg.is_subscribe());
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_build_link_negotiation_reply() {
        let msg = ProtoMessage::builder().build_link_negotiation("my-id", "1.2.3", true, None);
        assert!(msg.is_link());
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_subscribe_missing_source_encoded_name() {
        let valid = ProtoName::from_strings(["org", "ns", "agent"]);
        let hdr = SlimHeader {
            source: Some(ProtoName {
                name: None,
                str_name: None,
            }),
            destination: Some(valid),
            ..Default::default()
        };
        let msg = ProtoMessage::new(
            HashMap::new(),
            ProtoSubscribeType(ProtoSubscribe {
                header: Some(hdr),
                ..Default::default()
            }),
        );
        assert!(matches!(
            msg.validate(),
            Err(MessageError::SourceEncodedNameNotFound)
        ));
    }

    #[test]
    fn test_validate_subscribe_missing_destination_encoded_name() {
        let valid = ProtoName::from_strings(["org", "ns", "agent"]);
        let hdr = SlimHeader {
            source: Some(valid),
            destination: Some(ProtoName {
                name: None,
                str_name: None,
            }),
            ..Default::default()
        };
        let msg = ProtoMessage::new(
            HashMap::new(),
            ProtoSubscribeType(ProtoSubscribe {
                header: Some(hdr),
                ..Default::default()
            }),
        );
        assert!(matches!(
            msg.validate(),
            Err(MessageError::DestinationEncodedNameNotFound)
        ));
    }

    #[test]
    fn test_participant_settings_convenience_methods() {
        let bidirectional = ParticipantSettings::bidirectional();
        assert!(bidirectional.sends_data);
        assert!(bidirectional.receives_data);
        assert!(bidirectional.is_sender());
        assert!(bidirectional.is_receiver());

        let send_only = ParticipantSettings::send_only();
        assert!(send_only.sends_data);
        assert!(!send_only.receives_data);
        assert!(send_only.is_sender());
        assert!(!send_only.is_receiver());

        let receive_only = ParticipantSettings::receive_only();
        assert!(!receive_only.sends_data);
        assert!(receive_only.receives_data);
        assert!(!receive_only.is_sender());
        assert!(receive_only.is_receiver());
    }
}