1use std::fmt::Display;
5use std::{collections::HashMap, time::Duration};
6
7use crate::api::proto::dataplane::v1::{
8 GroupClosePayload, GroupNackPayload, LinkConnectionType, Participant, ParticipantSettings,
9 PingPayload,
10};
11use crate::api::{
12 Content, LinkNegotiationPayload, MessageType, ProtoLink, ProtoLinkMessageType, ProtoLinkType,
13 ProtoMessage, ProtoMlsSettings as MlsSettings, ProtoName, ProtoPublish, ProtoPublishType,
14 ProtoSessionType, ProtoSubscribe, ProtoSubscribeType, ProtoSubscriptionAck,
15 ProtoSubscriptionAckType, ProtoUnsubscribe, ProtoUnsubscribeType, SessionHeader, SlimHeader,
16 proto::dataplane::v1::{
17 ApplicationPayload, CommandPayload, DiscoveryReplyPayload, DiscoveryRequestPayload,
18 EncodedName, GroupAckPayload, GroupAddPayload, GroupProposalPayload, GroupRemovePayload,
19 GroupWelcomePayload, JoinReplyPayload, JoinRequestPayload, LeaveReplyPayload,
20 LeaveRequestPayload, MlsPayload, SessionMessageType, TimerSettings,
21 command_payload::CommandPayloadType, content::ContentType,
22 },
23};
24
25use slim_version::version;
26use thiserror::Error;
27
28use crate::tables::ConnType;
29
30impl From<ConnType> for LinkConnectionType {
31 fn from(ct: ConnType) -> Self {
32 match ct {
33 ConnType::Peer => LinkConnectionType::Peer,
34 ConnType::Edge => LinkConnectionType::Edge,
35 _ => LinkConnectionType::Remote,
36 }
37 }
38}
39
/// String constant `"DELETE_GROUP"`.
///
/// NOTE(review): presumably used as a message-metadata key — confirm against callers.
pub const DELETE_GROUP: &str = "DELETE_GROUP";

/// String constant `"PUBLISH_TO"`.
///
/// NOTE(review): presumably used as a message-metadata key — confirm against callers.
pub const PUBLISH_TO: &str = "PUBLISH_TO";

/// String constant `"DISCONNECTION_DETECTED"`.
///
/// NOTE(review): presumably used as a message-metadata key — confirm against callers.
pub const DISCONNECTION_DETECTED: &str = "DISCONNECTION_DETECTED";

/// String constant `"LEAVING_SESSION"`.
///
/// NOTE(review): presumably used as a message-metadata key — confirm against callers.
pub const LEAVING_SESSION: &str = "LEAVING_SESSION";

/// String spelling of a true flag value (`"TRUE"`).
pub const TRUE_VAL: &str = "TRUE";

/// String spelling of a false flag value (`"FALSE"`).
pub const FALSE_VAL: &str = "FALSE";

/// Half of `u32::MAX`, i.e. 2_147_483_647.
///
/// NOTE(review): presumably the largest id handed out for publish messages — confirm.
pub const MAX_PUBLISH_ID: u32 = u32::MAX >> 1;

/// TTL used by `SlimHeaderFlags::default()` and `SlimHeaderFlags::new`.
pub const DEFAULT_TTL: u32 = 8;
75
/// Errors produced while inspecting, validating, or building dataplane messages.
#[derive(Error, Debug, PartialEq)]
pub enum MessageError {
    /// A publish/subscribe/unsubscribe message carries no SLIM header.
    #[error("SLIM header not found")]
    SlimHeaderNotFound,
    /// The SLIM header has no source name.
    #[error("source not found")]
    SourceNotFound,
    /// The source name is present but has no encoded name.
    #[error("source encoded name not found")]
    SourceEncodedNameNotFound,
    /// The SLIM header has no destination name.
    #[error("destination not found")]
    DestinationNotFound,
    /// The destination name is present but has no encoded name.
    #[error("destination encoded name not found")]
    DestinationEncodedNameNotFound,
    /// A publish message carries no session header.
    #[error("session header not found")]
    SessionHeaderNotFound,
    /// The message's `message_type` is unset.
    #[error("message type not found")]
    MessageTypeNotFound,
    /// NOTE(review): not constructed in the visible part of this file; presumably
    /// reported when a header lacks `incoming_conn` — confirm against callers.
    #[error("incoming connection not found")]
    IncomingConnectionNotFound,
    /// A `Content` has no content type, or a publish message has no payload at all.
    #[error("content type is not set")]
    ContentTypeNotSet,
    /// An application payload was requested but the content holds a command payload.
    #[error("content is not an application payload")]
    NotApplicationPayload,
    /// A command payload was requested but the content holds an application payload.
    #[error("content is not a command payload")]
    NotCommandPayload,
    /// A link message has no link type set.
    #[error("link type is not set")]
    LinkTypeNotSet,
    /// A typed `CommandPayload` accessor was called on a payload of another kind (or none).
    ///
    /// NOTE(review): the strings are boxed, presumably to keep `MessageError` (and every
    /// `Result<_, MessageError>`) small — confirm before changing the field types.
    #[error("invalid command payload type: expected {expected}, got {got}")]
    InvalidCommandPayloadType {
        /// Name of the `CommandPayloadType` variant the accessor expected.
        expected: Box<String>,
        /// `Debug` rendering of the variant actually present, or `"None"` when unset.
        got: Box<String>,
    },

    /// A builder was finalized without a source.
    /// NOTE(review): not constructed in the visible part of this file.
    #[error("builder error: source is required")]
    BuilderErrorSourceRequired,
    /// A builder was finalized without a destination.
    /// NOTE(review): not constructed in the visible part of this file.
    #[error("builder error: destination is required")]
    BuilderErrorDestinationRequired,
    /// A `Participant` has no name.
    #[error("participant name not found")]
    ParticipantNameNotFound,
    /// A `Participant` has no settings.
    #[error("participant settings not found")]
    ParticipantSettingsNotFound,
}
118
119impl ParticipantSettings {
120 pub fn bidirectional() -> Self {
123 Self {
124 sends_data: true,
125 receives_data: true,
126 }
127 }
128
129 pub fn send_only() -> Self {
131 Self {
132 sends_data: true,
133 receives_data: false,
134 }
135 }
136
137 pub fn receive_only() -> Self {
139 Self {
140 sends_data: false,
141 receives_data: true,
142 }
143 }
144
145 pub fn is_sender(&self) -> bool {
147 self.sends_data
148 }
149
150 pub fn is_receiver(&self) -> bool {
152 self.receives_data
153 }
154}
155
156impl Participant {
157 pub fn new(name: ProtoName, settings: ParticipantSettings) -> Self {
158 Self {
159 name: Some(name),
160 settings: Some(settings),
161 }
162 }
163
164 pub fn get_name(&self) -> Result<ProtoName, MessageError> {
165 match &self.name {
166 Some(name) => Ok(name.clone()),
167 None => Err(MessageError::ParticipantNameNotFound),
168 }
169 }
170
171 pub fn get_settings(&self) -> Result<&ParticipantSettings, MessageError> {
172 match &self.settings {
173 Some(settings) => Ok(settings),
174 None => Err(MessageError::ParticipantSettingsNotFound),
175 }
176 }
177}
178
179impl Display for MessageType {
181 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182 match self {
183 MessageType::Publish(_) => write!(f, "publish"),
184 MessageType::Subscribe(_) => write!(f, "subscribe"),
185 MessageType::Unsubscribe(_) => write!(f, "unsubscribe"),
186 MessageType::Link(_) => write!(f, "link"),
187 MessageType::SubscriptionAck(_) => write!(f, "subscription_ack"),
188 }
189 }
190}
191
/// Optional routing parameters consumed by `SlimHeader::new`.
///
/// `Default` yields a fanout of 1, no connection hints, no error flag and a TTL of
/// `DEFAULT_TTL`.
#[derive(Debug, Clone)]
pub struct SlimHeaderFlags {
    /// Copied into the header's `fanout` field.
    pub fanout: u32,
    /// Copied into the header's `recv_from` field.
    pub recv_from: Option<u64>,
    /// Copied into the header's `forward_to` field.
    pub forward_to: Option<u64>,
    /// Copied into the header's `incoming_conn` field.
    pub incoming_conn: Option<u64>,
    /// Copied into the header's `error` field.
    pub error: Option<bool>,
    /// Copied into the header's `ttl` field.
    pub ttl: u32,
}
202
203impl Default for SlimHeaderFlags {
204 fn default() -> Self {
205 Self {
206 fanout: 1,
207 recv_from: None,
208 forward_to: None,
209 incoming_conn: None,
210 error: None,
211 ttl: DEFAULT_TTL,
212 }
213 }
214}
215
216impl Display for SlimHeaderFlags {
217 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
218 write!(
219 f,
220 "fanout: {}, recv_from: {:?}, forward_to: {:?}, incoming_conn: {:?}, error: {:?}, ttl: {:?}",
221 self.fanout, self.recv_from, self.forward_to, self.incoming_conn, self.error, self.ttl
222 )
223 }
224}
225
226impl SlimHeaderFlags {
227 pub fn new(
228 fanout: u32,
229 recv_from: Option<u64>,
230 forward_to: Option<u64>,
231 incoming_conn: Option<u64>,
232 error: Option<bool>,
233 ) -> Self {
234 Self {
235 fanout,
236 recv_from,
237 forward_to,
238 incoming_conn,
239 error,
240 ttl: DEFAULT_TTL,
241 }
242 }
243
244 pub fn with_fanout(self, fanout: u32) -> Self {
245 Self { fanout, ..self }
246 }
247
248 pub fn with_recv_from(self, recv_from: u64) -> Self {
249 Self {
250 recv_from: Some(recv_from),
251 ..self
252 }
253 }
254
255 pub fn with_forward_to(self, forward_to: u64) -> Self {
256 Self {
257 forward_to: Some(forward_to),
258 ..self
259 }
260 }
261
262 pub fn with_incoming_conn(self, incoming_conn: u64) -> Self {
263 Self {
264 incoming_conn: Some(incoming_conn),
265 ..self
266 }
267 }
268
269 pub fn with_error(self, error: bool) -> Self {
270 Self {
271 error: Some(error),
272 ..self
273 }
274 }
275
276 pub fn with_ttl(self, ttl: u32) -> Self {
277 Self { ttl, ..self }
278 }
279}
280
281impl SlimHeader {
285 pub fn new(
286 source: ProtoName,
287 destination: ProtoName,
288 identity: &str,
289 flags: Option<SlimHeaderFlags>,
290 ) -> Self {
291 let flags = flags.unwrap_or_default();
292 Self {
293 source: Some(source),
294 destination: Some(destination),
295 identity: identity.to_string(),
296 fanout: flags.fanout,
297 version: version().to_string(),
298 recv_from: flags.recv_from,
299 forward_to: flags.forward_to,
300 incoming_conn: flags.incoming_conn,
301 error: flags.error,
302 header_mac: None,
303 ttl: flags.ttl,
304 }
305 }
306
307 pub fn clear_flags(&mut self) {
308 self.recv_from = None;
309 self.forward_to = None;
310 }
311
312 pub fn get_fanout(&self) -> u32 {
313 self.fanout
314 }
315
316 pub fn get_recv_from(&self) -> Option<u64> {
317 self.recv_from
318 }
319
320 pub fn get_forward_to(&self) -> Option<u64> {
321 self.forward_to
322 }
323
324 pub fn get_incoming_conn(&self) -> Option<u64> {
325 self.incoming_conn
326 }
327
328 pub fn get_error(&self) -> Option<bool> {
329 self.error
330 }
331
332 pub fn get_source(&self) -> ProtoName {
333 self.source.clone().expect("source not found")
334 }
335
336 pub fn get_encoded_source(&self) -> EncodedName {
337 self.source.as_ref().unwrap().name.unwrap()
338 }
339
340 pub fn get_dst(&self) -> ProtoName {
341 self.destination.clone().expect("destination not found")
342 }
343
344 pub fn get_encoded_dst(&self) -> EncodedName {
345 self.destination.as_ref().unwrap().name.unwrap()
346 }
347
348 pub fn get_identity(&self) -> String {
349 self.identity.clone()
350 }
351
352 pub fn get_version(&self) -> String {
353 self.version.clone()
354 }
355
356 pub fn set_source(&mut self, source: ProtoName) {
357 self.source = Some(source);
358 }
359
360 pub fn set_destination(&mut self, dst: ProtoName) {
361 self.destination = Some(dst);
362 }
363
364 pub fn set_identity(&mut self, identity: String) {
365 self.identity = identity;
366 }
367
368 pub fn set_fanout(&mut self, fanout: u32) {
369 self.fanout = fanout;
370 }
371
372 pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
373 self.recv_from = recv_from;
374 }
375
376 pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
377 self.forward_to = forward_to;
378 }
379
380 pub fn set_error(&mut self, error: Option<bool>) {
381 self.error = error;
382 }
383
384 pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
385 self.incoming_conn = incoming_conn;
386 }
387
388 pub fn set_error_flag(&mut self, error: Option<bool>) {
389 self.error = error;
390 }
391
392 pub fn get_ttl(&self) -> u32 {
393 self.ttl
394 }
395
396 pub fn set_ttl(&mut self, ttl: u32) {
397 self.ttl = ttl;
398 }
399
400 pub fn decrement_ttl(&mut self) -> u32 {
402 self.ttl = self.ttl.saturating_sub(1);
403 self.ttl
404 }
405
406 #[cfg(not(target_arch = "wasm32"))]
407 pub(crate) fn get_connections(&self) -> (u64, Option<u64>, Option<u64>) {
409 let incoming = self
411 .get_incoming_conn()
412 .expect("incoming connection not found");
413
414 (incoming, self.get_recv_from(), self.get_forward_to())
415 }
416}
417
418impl SessionHeader {
422 pub fn new(
423 session_type: i32,
424 session_message_type: i32,
425 session_id: u32,
426 message_id: u32,
427 ) -> Self {
428 Self {
429 session_type,
430 session_message_type,
431 session_id,
432 message_id,
433 }
434 }
435
436 pub fn get_session_id(&self) -> u32 {
437 self.session_id
438 }
439
440 pub fn get_message_id(&self) -> u32 {
441 self.message_id
442 }
443
444 pub fn set_session_id(&mut self, session_id: u32) {
445 self.session_id = session_id;
446 }
447
448 pub fn set_message_id(&mut self, message_id: u32) {
449 self.message_id = message_id;
450 }
451
452 pub fn clear(&mut self) {
453 self.session_id = 0;
454 self.message_id = 0;
455 }
456}
457
458impl SessionMessageType {
461 pub fn is_command_message(&self) -> bool {
463 matches!(
464 self,
465 SessionMessageType::DiscoveryRequest
466 | SessionMessageType::DiscoveryReply
467 | SessionMessageType::JoinRequest
468 | SessionMessageType::JoinReply
469 | SessionMessageType::LeaveRequest
470 | SessionMessageType::LeaveReply
471 | SessionMessageType::GroupAdd
472 | SessionMessageType::GroupRemove
473 | SessionMessageType::GroupWelcome
474 | SessionMessageType::GroupClose
475 | SessionMessageType::GroupProposal
476 | SessionMessageType::GroupAck
477 | SessionMessageType::GroupNack
478 | SessionMessageType::Ping
479 )
480 }
481}
482
483impl ProtoSubscribe {
486 fn new(
487 source: ProtoName,
488 dst: ProtoName,
489 identity: Option<&str>,
490 flags: Option<SlimHeaderFlags>,
491 ) -> Self {
492 let id = identity.unwrap_or("");
493 let header = Some(SlimHeader::new(source, dst, id, flags));
494
495 ProtoSubscribe {
496 header,
497 subscription_id: 0,
498 }
499 }
500}
501
502impl From<ProtoMessage> for ProtoSubscribe {
504 fn from(message: ProtoMessage) -> Self {
505 match message.message_type {
506 Some(ProtoSubscribeType(s)) => s,
507 _ => panic!("message type is not subscribe"),
508 }
509 }
510}
511
512impl ProtoUnsubscribe {
515 fn new(
516 source: ProtoName,
517 dst: ProtoName,
518 identity: Option<&str>,
519 flags: Option<SlimHeaderFlags>,
520 ) -> Self {
521 let id = identity.unwrap_or("");
522 let header = Some(SlimHeader::new(source, dst, id, flags));
523
524 ProtoUnsubscribe {
525 header,
526 subscription_id: 0,
527 }
528 }
529}
530
531impl From<ProtoMessage> for ProtoUnsubscribe {
533 fn from(message: ProtoMessage) -> Self {
534 match message.message_type {
535 Some(ProtoUnsubscribeType(u)) => u,
536 _ => panic!("message type is not unsubscribe"),
537 }
538 }
539}
540
541impl ProtoPublish {
544 fn with_header(
545 header: Option<SlimHeader>,
546 session: Option<SessionHeader>,
547 payload: Option<Content>,
548 ) -> Self {
549 ProtoPublish {
550 header,
551 session,
552 msg: payload,
553 }
554 }
555
556 pub fn get_slim_header(&self) -> &SlimHeader {
557 self.header.as_ref().unwrap()
558 }
559
560 pub fn get_session_header(&self) -> &SessionHeader {
561 self.session.as_ref().unwrap()
562 }
563
564 pub fn get_slim_header_as_mut(&mut self) -> &mut SlimHeader {
565 self.header.as_mut().unwrap()
566 }
567
568 pub fn get_session_header_as_mut(&mut self) -> &mut SessionHeader {
569 self.session.as_mut().unwrap()
570 }
571
572 pub fn get_payload(&self) -> &Content {
573 self.msg.as_ref().unwrap()
574 }
575
576 pub fn set_payload(&mut self, payload: Content) {
577 self.msg = Some(payload);
578 }
579
580 pub fn is_command(&self) -> bool {
581 match &self.get_payload().content_type.as_ref().unwrap() {
582 ContentType::AppPayload(_) => false,
583 ContentType::CommandPayload(_) => true,
584 }
585 }
586
587 pub fn get_application_payload(&self) -> &ApplicationPayload {
588 match self.get_payload().content_type.as_ref().unwrap() {
589 ContentType::AppPayload(application_payload) => application_payload,
590 ContentType::CommandPayload(_) => panic!("the payload is not an application payload"),
591 }
592 }
593
594 pub fn get_command_payload(&self) -> &CommandPayload {
595 match &self.get_payload().content_type.as_ref().unwrap() {
596 ContentType::AppPayload(_) => panic!("the payaoad is not a command payload"),
597 ContentType::CommandPayload(command_payload) => command_payload,
598 }
599 }
600}
601
602impl From<ProtoMessage> for ProtoPublish {
604 fn from(message: ProtoMessage) -> Self {
605 match message.message_type {
606 Some(ProtoPublishType(p)) => p,
607 _ => panic!("message type is not publish"),
608 }
609 }
610}
611
/// Generates `extract_*` methods that pull a typed command payload out of a message.
///
/// Each entry `extractor => accessor(PayloadType)` expands to a method named `extractor`
/// that calls `self.extract_command_payload()` and then `accessor()` on the result,
/// propagating either step's `MessageError`.
macro_rules! impl_payload_extractors {
    ($($extractor:ident => $accessor:ident($payload:ty)),* $(,)?) => {
        $(
            pub fn $extractor(&self) -> Result<&$payload, MessageError> {
                let command = self.extract_command_payload()?;
                command.$accessor()
            }
        )*
    };
}
625
626impl ProtoMessage {
627 fn new(metadata: HashMap<String, String>, message_type: MessageType) -> Self {
628 ProtoMessage {
629 metadata,
630 message_type: Some(message_type),
631 }
632 }
633
634 fn validate_link(link: &ProtoLink) -> Result<(), MessageError> {
635 if link.link_type.is_none() {
636 return Err(MessageError::LinkTypeNotSet);
637 }
638 Ok(())
639 }
640
641 fn validate_routed_header(slim_header: &SlimHeader) -> Result<(), MessageError> {
642 match &slim_header.source {
643 None => return Err(MessageError::SourceNotFound),
644 Some(src) if src.name.is_none() => return Err(MessageError::SourceEncodedNameNotFound),
645 _ => {}
646 }
647 match &slim_header.destination {
648 None => return Err(MessageError::DestinationNotFound),
649 Some(dst) if dst.name.is_none() => {
650 return Err(MessageError::DestinationEncodedNameNotFound);
651 }
652 _ => {}
653 }
654 Ok(())
655 }
656
657 fn validate_publish(p: &ProtoPublish) -> Result<(), MessageError> {
658 let hdr = p.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
659 Self::validate_routed_header(hdr)?;
660 if p.session.is_none() {
661 return Err(MessageError::SessionHeaderNotFound);
662 }
663 Ok(())
664 }
665
666 fn validate_subscribe(s: &ProtoSubscribe) -> Result<(), MessageError> {
667 let hdr = s.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
668 Self::validate_routed_header(hdr)
669 }
670
671 fn validate_unsubscribe(u: &ProtoUnsubscribe) -> Result<(), MessageError> {
672 let hdr = u.header.as_ref().ok_or(MessageError::SlimHeaderNotFound)?;
673 Self::validate_routed_header(hdr)
674 }
675
676 pub fn validate(&self) -> Result<(), MessageError> {
678 match &self.message_type {
679 None => Err(MessageError::MessageTypeNotFound),
680 Some(ProtoLinkMessageType(link)) => Self::validate_link(link),
681 Some(ProtoPublishType(p)) => Self::validate_publish(p),
682 Some(ProtoSubscribeType(s)) => Self::validate_subscribe(s),
683 Some(ProtoUnsubscribeType(u)) => Self::validate_unsubscribe(u),
684 Some(ProtoSubscriptionAckType(_)) => Ok(()),
685 }
686 }
687
688 pub fn insert_metadata(&mut self, key: String, val: String) {
691 self.metadata.insert(key, val);
692 }
693
694 pub fn remove_metadata(&mut self, key: &str) -> Option<String> {
696 self.metadata.remove(key)
697 }
698
699 pub fn contains_metadata(&self, key: &str) -> bool {
700 self.metadata.contains_key(key)
701 }
702
703 pub fn get_metadata(&self, key: &str) -> Option<&String> {
704 self.metadata.get(key)
705 }
706
707 pub fn get_metadata_map(&self) -> HashMap<String, String> {
708 self.metadata.clone()
709 }
710
711 pub fn set_metadata_map(&mut self, map: HashMap<String, String>) {
712 for (k, v) in map.iter() {
713 self.insert_metadata(k.to_string(), v.to_string());
714 }
715 }
716
717 pub fn get_slim_header(&self) -> &SlimHeader {
718 match &self.message_type {
719 Some(ProtoPublishType(publish)) => publish.header.as_ref().unwrap(),
720 Some(ProtoSubscribeType(sub)) => sub.header.as_ref().unwrap(),
721 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref().unwrap(),
722 Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
723 panic!("SLIM header not found")
724 }
725 }
726 }
727
728 pub fn get_slim_header_mut(&mut self) -> &mut SlimHeader {
729 match &mut self.message_type {
730 Some(ProtoPublishType(publish)) => publish.header.as_mut().unwrap(),
731 Some(ProtoSubscribeType(sub)) => sub.header.as_mut().unwrap(),
732 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_mut().unwrap(),
733 Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => {
734 panic!("SLIM header not found")
735 }
736 }
737 }
738
739 pub fn try_get_slim_header(&self) -> Option<&SlimHeader> {
740 match &self.message_type {
741 Some(ProtoPublishType(publish)) => publish.header.as_ref(),
742 Some(ProtoSubscribeType(sub)) => sub.header.as_ref(),
743 Some(ProtoUnsubscribeType(unsub)) => unsub.header.as_ref(),
744 Some(ProtoLinkMessageType(_)) | Some(ProtoSubscriptionAckType(_)) | None => None,
745 }
746 }
747
748 pub fn get_session_header(&self) -> &SessionHeader {
749 match &self.message_type {
750 Some(ProtoPublishType(publish)) => publish.session.as_ref().unwrap(),
751 Some(ProtoSubscribeType(_))
752 | Some(ProtoUnsubscribeType(_))
753 | Some(ProtoLinkMessageType(_))
754 | Some(ProtoSubscriptionAckType(_))
755 | None => panic!("session header not found"),
756 }
757 }
758
759 pub fn get_session_header_mut(&mut self) -> &mut SessionHeader {
760 match &mut self.message_type {
761 Some(ProtoPublishType(publish)) => publish.session.as_mut().unwrap(),
762 Some(ProtoSubscribeType(_))
763 | Some(ProtoUnsubscribeType(_))
764 | Some(ProtoLinkMessageType(_))
765 | Some(ProtoSubscriptionAckType(_))
766 | None => panic!("session header not found"),
767 }
768 }
769
770 pub fn try_get_session_header(&self) -> Option<&SessionHeader> {
771 match &self.message_type {
772 Some(ProtoPublishType(publish)) => publish.session.as_ref(),
773 Some(ProtoSubscribeType(_))
774 | Some(ProtoUnsubscribeType(_))
775 | Some(ProtoLinkMessageType(_))
776 | Some(ProtoSubscriptionAckType(_))
777 | None => None,
778 }
779 }
780
781 pub fn try_get_session_header_mut(&mut self) -> Option<&mut SessionHeader> {
782 match &mut self.message_type {
783 Some(ProtoPublishType(publish)) => publish.session.as_mut(),
784 Some(ProtoSubscribeType(_))
785 | Some(ProtoUnsubscribeType(_))
786 | Some(ProtoLinkMessageType(_))
787 | Some(ProtoSubscriptionAckType(_))
788 | None => None,
789 }
790 }
791
792 pub fn get_id(&self) -> u32 {
793 self.get_session_header().get_message_id()
794 }
795
796 pub fn get_source(&self) -> ProtoName {
797 self.get_slim_header().get_source()
798 }
799
800 pub fn get_encoded_source(&self) -> EncodedName {
801 self.get_slim_header().get_encoded_source()
802 }
803
804 pub fn get_dst(&self) -> ProtoName {
805 self.get_slim_header().get_dst()
806 }
807
808 pub fn get_encoded_dst(&self) -> EncodedName {
809 self.get_slim_header().get_encoded_dst()
810 }
811
812 pub fn get_identity(&self) -> String {
813 self.get_slim_header().get_identity()
814 }
815
816 pub fn get_fanout(&self) -> u32 {
817 self.get_slim_header().get_fanout()
818 }
819
820 pub fn get_recv_from(&self) -> Option<u64> {
821 self.get_slim_header().get_recv_from()
822 }
823
824 pub fn get_forward_to(&self) -> Option<u64> {
825 self.get_slim_header().get_forward_to()
826 }
827
828 pub fn get_error(&self) -> Option<bool> {
829 self.get_slim_header().get_error()
830 }
831
832 pub fn get_incoming_conn(&self) -> u64 {
833 self.get_slim_header().get_incoming_conn().unwrap()
834 }
835
836 pub fn try_get_incoming_conn(&self) -> Option<u64> {
837 self.get_slim_header().get_incoming_conn()
838 }
839
840 pub fn get_type(&self) -> &MessageType {
841 match &self.message_type {
842 Some(t) => t,
843 None => panic!("message type not found"),
844 }
845 }
846
847 pub fn get_payload(&self) -> Option<&Content> {
848 match &self.message_type {
849 Some(ProtoPublishType(p)) => p.msg.as_ref(),
850 Some(ProtoSubscribeType(_)) => panic!("payload not found"),
851 Some(ProtoUnsubscribeType(_)) => panic!("payload not found"),
852 Some(ProtoLinkMessageType(_)) => panic!("payload not found"),
853 Some(ProtoSubscriptionAckType(_)) => panic!("payload not found"),
854 None => panic!("payload not found"),
855 }
856 }
857
858 pub fn set_payload(&mut self, payload: Content) {
859 match &mut self.message_type {
860 Some(ProtoPublishType(p)) => p.set_payload(payload),
861 Some(ProtoSubscribeType(_)) => panic!("no payload allowed"),
862 Some(ProtoUnsubscribeType(_)) => panic!("no payload allowed"),
863 Some(ProtoLinkMessageType(_)) => panic!("no payload allowed"),
864 Some(ProtoSubscriptionAckType(_)) => panic!("no payload allowed"),
865 None => panic!("no payload allowed"),
866 }
867 }
868
869 pub fn get_session_message_type(&self) -> SessionMessageType {
870 self.get_session_header()
871 .session_message_type
872 .try_into()
873 .unwrap_or_default()
874 }
875
876 pub fn clear_slim_header(&mut self) {
877 if self.is_link() || self.is_subscription_ack() {
878 return;
879 }
880 self.get_slim_header_mut().clear_flags();
881 }
882
883 pub fn set_recv_from(&mut self, recv_from: Option<u64>) {
884 self.get_slim_header_mut().set_recv_from(recv_from);
885 }
886
887 pub fn set_forward_to(&mut self, forward_to: Option<u64>) {
888 self.get_slim_header_mut().set_forward_to(forward_to);
889 }
890
891 pub fn set_error(&mut self, error: Option<bool>) {
892 self.get_slim_header_mut().set_error(error);
893 }
894
895 pub fn set_fanout(&mut self, fanout: u32) {
896 self.get_slim_header_mut().set_fanout(fanout);
897 }
898
899 pub fn set_incoming_conn(&mut self, incoming_conn: Option<u64>) {
900 self.get_slim_header_mut().set_incoming_conn(incoming_conn);
901 }
902
903 pub fn set_error_flag(&mut self, error: Option<bool>) {
904 self.get_slim_header_mut().set_error_flag(error);
905 }
906
907 pub fn get_ttl(&self) -> u32 {
908 self.get_slim_header().get_ttl()
909 }
910
911 pub fn set_ttl(&mut self, ttl: u32) {
912 self.get_slim_header_mut().set_ttl(ttl);
913 }
914
915 pub fn decrement_ttl(&mut self) -> u32 {
917 self.get_slim_header_mut().decrement_ttl()
918 }
919
920 pub fn set_session_message_type(&mut self, message_type: SessionMessageType) {
921 self.get_session_header_mut()
922 .set_session_message_type(message_type);
923 }
924
925 pub fn set_session_type(&mut self, session_type: ProtoSessionType) {
926 self.get_session_header_mut().set_session_type(session_type);
927 }
928
929 pub fn get_session_type(&self) -> ProtoSessionType {
930 self.get_session_header().session_type()
931 }
932
933 pub fn set_message_id(&mut self, message_id: u32) {
934 self.get_session_header_mut().set_message_id(message_id);
935 }
936
937 pub fn is_publish(&self) -> bool {
938 matches!(self.get_type(), MessageType::Publish(_))
939 }
940
941 pub fn is_subscribe(&self) -> bool {
942 matches!(self.get_type(), MessageType::Subscribe(_))
943 }
944
945 pub fn is_unsubscribe(&self) -> bool {
946 matches!(self.get_type(), MessageType::Unsubscribe(_))
947 }
948
949 pub fn is_link(&self) -> bool {
950 matches!(self.get_type(), MessageType::Link(_))
951 }
952
953 pub fn get_link_negotiation_payload(&self) -> Option<LinkNegotiationPayload> {
955 match &self.message_type {
956 Some(ProtoLinkMessageType(link)) => match &link.link_type {
957 Some(ProtoLinkType::LinkNegotiation(payload)) => Some(payload.clone()),
958 _ => None,
959 },
960 _ => None,
961 }
962 }
963
964 pub fn is_subscription_ack(&self) -> bool {
965 matches!(self.get_type(), MessageType::SubscriptionAck(_))
966 }
967
968 pub fn is_traceable(&self) -> bool {
969 !self.is_link() && !self.is_subscription_ack()
970 }
971
972 pub fn get_subscription_ack(&self) -> &ProtoSubscriptionAck {
973 match &self.message_type {
974 Some(ProtoSubscriptionAckType(ack)) => ack,
975 _ => panic!("message type is not subscription_ack"),
976 }
977 }
978
979 pub fn get_subscription_id(&self) -> Option<u64> {
981 match &self.message_type {
982 Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => Some(s.subscription_id),
983 Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => Some(u.subscription_id),
984 _ => None,
985 }
986 }
987
988 pub fn take_subscription_id(&mut self) -> Option<u64> {
991 match &mut self.message_type {
992 Some(ProtoSubscribeType(s)) if s.subscription_id != 0 => {
993 Some(std::mem::take(&mut s.subscription_id))
994 }
995 Some(ProtoUnsubscribeType(u)) if u.subscription_id != 0 => {
996 Some(std::mem::take(&mut u.subscription_id))
997 }
998 _ => None,
999 }
1000 }
1001
1002 pub fn set_subscription_id(&mut self, subscription_id: u64) {
1004 match &mut self.message_type {
1005 Some(ProtoSubscribeType(s)) => s.subscription_id = subscription_id,
1006 Some(ProtoUnsubscribeType(u)) => u.subscription_id = subscription_id,
1007 _ => {}
1008 }
1009 }
1010
1011 pub fn extract_command_payload(&self) -> Result<&CommandPayload, MessageError> {
1016 self.get_payload()
1017 .ok_or(MessageError::ContentTypeNotSet)?
1018 .as_command_payload()
1019 }
1020
1021 impl_payload_extractors! {
1023 extract_discovery_request => as_discovery_request_payload(DiscoveryRequestPayload),
1024 extract_discovery_reply => as_discovery_reply_payload(DiscoveryReplyPayload),
1025 extract_join_request => as_join_request_payload(JoinRequestPayload),
1026 extract_join_reply => as_join_reply_payload(JoinReplyPayload),
1027 extract_leave_request => as_leave_request_payload(LeaveRequestPayload),
1028 extract_leave_reply => as_leave_reply_payload(LeaveReplyPayload),
1029 extract_group_add => as_group_add_payload(GroupAddPayload),
1030 extract_group_remove => as_group_remove_payload(GroupRemovePayload),
1031 extract_group_welcome => as_welcome_payload(GroupWelcomePayload),
1032 extract_group_close => as_group_close_payload(GroupClosePayload),
1033 extract_group_proposal => as_group_proposal_payload(GroupProposalPayload),
1034 extract_group_ack => as_group_ack_payload(GroupAckPayload),
1035 extract_group_nack => as_group_nack_payload(GroupNackPayload),
1036 extract_ping => as_ping_payload(PingPayload),
1037 }
1038}
1039
1040impl Content {
1041 pub fn as_application_payload(&self) -> Result<&ApplicationPayload, MessageError> {
1042 match &self.content_type {
1043 Some(ContentType::AppPayload(app_payload)) => Ok(app_payload),
1044 Some(ContentType::CommandPayload(_)) => Err(MessageError::NotApplicationPayload),
1045 None => Err(MessageError::ContentTypeNotSet),
1046 }
1047 }
1048
1049 pub fn as_command_payload(&self) -> Result<&CommandPayload, MessageError> {
1050 match &self.content_type {
1051 Some(ContentType::AppPayload(_)) => Err(MessageError::NotCommandPayload),
1052 Some(ContentType::CommandPayload(comm_payload)) => Ok(comm_payload),
1053 None => Err(MessageError::ContentTypeNotSet),
1054 }
1055 }
1056}
1057
1058impl ApplicationPayload {
1059 pub fn new(payload_type: &str, blob: Vec<u8>) -> Self {
1060 Self {
1061 payload_type: payload_type.to_string(),
1062 blob,
1063 }
1064 }
1065
1066 pub fn as_content(&self) -> Content {
1067 Content {
1068 content_type: Some(ContentType::AppPayload(self.clone())),
1069 }
1070 }
1071}
1072
/// Generates typed accessors on `CommandPayload`.
///
/// Each entry `method => Variant(PayloadType)` expands to a method returning the inner
/// payload when `command_payload_type` is `Some(Variant(..))`, and otherwise
/// `MessageError::InvalidCommandPayloadType` naming the expected variant and the
/// `Debug` form of what was found (`"None"` when unset).
macro_rules! impl_command_payload_getters {
    ($(
        $name:ident => $variant:ident($payload:ty)
    ),* $(,)?) => {
        $(
            pub fn $name(&self) -> Result<&$payload, MessageError> {
                let got = match &self.command_payload_type {
                    Some(CommandPayloadType::$variant(payload)) => return Ok(payload),
                    Some(other) => format!("{:?}", other),
                    None => "None".to_string(),
                };
                Err(MessageError::InvalidCommandPayloadType {
                    expected: Box::new(stringify!($variant).to_string()),
                    got: Box::new(got),
                })
            }
        )*
    };
}
1095
1096impl CommandPayload {
1097 pub fn as_content(self) -> Content {
1098 Content {
1099 content_type: Some(ContentType::CommandPayload(self)),
1100 }
1101 }
1102
1103 impl_command_payload_getters! {
1105 as_discovery_request_payload => DiscoveryRequest(DiscoveryRequestPayload),
1106 as_discovery_reply_payload => DiscoveryReply(DiscoveryReplyPayload),
1107 as_join_request_payload => JoinRequest(JoinRequestPayload),
1108 as_join_reply_payload => JoinReply(JoinReplyPayload),
1109 as_leave_request_payload => LeaveRequest(LeaveRequestPayload),
1110 as_leave_reply_payload => LeaveReply(LeaveReplyPayload),
1111 as_group_add_payload => GroupAdd(GroupAddPayload),
1112 as_group_remove_payload => GroupRemove(GroupRemovePayload),
1113 as_welcome_payload => GroupWelcome(GroupWelcomePayload),
1114 as_group_close_payload => GroupClose(GroupClosePayload),
1115 as_group_proposal_payload => GroupProposal(GroupProposalPayload),
1116 as_group_ack_payload => GroupAck(GroupAckPayload),
1117 as_group_nack_payload => GroupNack(GroupNackPayload),
1118 as_ping_payload => Ping(PingPayload),
1119 }
1120}
1121
1122impl AsRef<ProtoPublish> for ProtoMessage {
1123 fn as_ref(&self) -> &ProtoPublish {
1124 match &self.message_type {
1125 Some(ProtoPublishType(p)) => p,
1126 _ => panic!("message type is not publish"),
1127 }
1128 }
1129}
1130
/// Stateless (unit-struct) builder whose methods each produce a `CommandPayload`
/// for one `CommandPayloadType` variant.
pub struct CommandPayloadBuilder;
1178
1179impl CommandPayloadBuilder {
1180 pub fn new() -> Self {
1182 Self
1183 }
1184
1185 pub fn discovery_request(self) -> CommandPayload {
1187 let payload = DiscoveryRequestPayload {};
1188 CommandPayload {
1189 command_payload_type: Some(CommandPayloadType::DiscoveryRequest(payload)),
1190 }
1191 }
1192
1193 pub fn discovery_reply(self) -> CommandPayload {
1195 let payload = DiscoveryReplyPayload {};
1196 CommandPayload {
1197 command_payload_type: Some(CommandPayloadType::DiscoveryReply(payload)),
1198 }
1199 }
1200
1201 #[allow(deprecated)]
1203 pub fn join_request(
1204 self,
1205 max_retries: Option<u32>,
1206 timer_duration: Option<Duration>,
1207 channel: Option<ProtoName>,
1208 mls_settings: Option<MlsSettings>,
1209 ) -> CommandPayload {
1210 let proto_channel = channel;
1211
1212 let timer_settings = if let Some(t) = timer_duration
1213 && let Some(m) = max_retries
1214 {
1215 Some(TimerSettings {
1216 timeout: t.as_millis() as u32,
1217 max_retries: m,
1218 })
1219 } else {
1220 None
1221 };
1222
1223 let payload = JoinRequestPayload {
1224 timer_settings,
1225 channel: proto_channel,
1226 mls_settings,
1227 };
1228 CommandPayload {
1229 command_payload_type: Some(CommandPayloadType::JoinRequest(payload)),
1230 }
1231 }
1232
1233 pub fn join_reply(
1235 self,
1236 key_package: Option<Vec<u8>>,
1237 participant: Participant,
1238 ) -> CommandPayload {
1239 let payload = JoinReplyPayload {
1240 key_package,
1241 participant: Some(participant),
1242 };
1243 CommandPayload {
1244 command_payload_type: Some(CommandPayloadType::JoinReply(payload)),
1245 }
1246 }
1247
1248 pub fn leave_request(self) -> CommandPayload {
1250 let payload = LeaveRequestPayload {};
1251 CommandPayload {
1252 command_payload_type: Some(CommandPayloadType::LeaveRequest(payload)),
1253 }
1254 }
1255
1256 pub fn leave_reply(self) -> CommandPayload {
1258 let payload = LeaveReplyPayload {};
1259 CommandPayload {
1260 command_payload_type: Some(CommandPayloadType::LeaveReply(payload)),
1261 }
1262 }
1263
1264 pub fn group_add(
1266 self,
1267 new_participant: Participant,
1268 participants: Vec<Participant>,
1269 mls: Option<MlsPayload>,
1270 ) -> CommandPayload {
1271 let payload = GroupAddPayload {
1272 new_participant: Some(new_participant),
1273 participants,
1274 mls,
1275 };
1276 CommandPayload {
1277 command_payload_type: Some(CommandPayloadType::GroupAdd(payload)),
1278 }
1279 }
1280
1281 pub fn group_remove(
1283 self,
1284 removed_participant: ProtoName,
1285 participants: Vec<ProtoName>,
1286 mls: Option<MlsPayload>,
1287 ) -> CommandPayload {
1288 let payload = GroupRemovePayload {
1289 removed_participant: Some(removed_participant),
1290 participants,
1291 mls,
1292 };
1293 CommandPayload {
1294 command_payload_type: Some(CommandPayloadType::GroupRemove(payload)),
1295 }
1296 }
1297
1298 pub fn group_welcome(
1300 self,
1301 participants: Vec<Participant>,
1302 mls: Option<MlsPayload>,
1303 ) -> CommandPayload {
1304 let payload = GroupWelcomePayload { participants, mls };
1305 CommandPayload {
1306 command_payload_type: Some(CommandPayloadType::GroupWelcome(payload)),
1307 }
1308 }
1309
1310 pub fn group_close(self, participants: Vec<ProtoName>) -> CommandPayload {
1312 let payload = GroupClosePayload { participants };
1313 CommandPayload {
1314 command_payload_type: Some(CommandPayloadType::GroupClose(payload)),
1315 }
1316 }
1317
1318 pub fn group_proposal(
1320 self,
1321 source: Option<ProtoName>,
1322 mls_proposal: Vec<u8>,
1323 ) -> CommandPayload {
1324 let payload = GroupProposalPayload {
1325 source,
1326 mls_proposal,
1327 };
1328 CommandPayload {
1329 command_payload_type: Some(CommandPayloadType::GroupProposal(payload)),
1330 }
1331 }
1332
1333 pub fn group_ack(self) -> CommandPayload {
1335 let payload = GroupAckPayload {};
1336 CommandPayload {
1337 command_payload_type: Some(CommandPayloadType::GroupAck(payload)),
1338 }
1339 }
1340
1341 pub fn group_nack(self) -> CommandPayload {
1343 let payload = GroupNackPayload {};
1344 CommandPayload {
1345 command_payload_type: Some(CommandPayloadType::GroupNack(payload)),
1346 }
1347 }
1348
1349 pub fn ping(self) -> CommandPayload {
1351 let payload = PingPayload {};
1352 CommandPayload {
1353 command_payload_type: Some(CommandPayloadType::Ping(payload)),
1354 }
1355 }
1356}
1357
1358impl Default for CommandPayloadBuilder {
1359 fn default() -> Self {
1360 Self::new()
1361 }
1362}
1363
1364impl CommandPayload {
1365 pub fn builder() -> CommandPayloadBuilder {
1367 CommandPayloadBuilder::new()
1368 }
1369}
1370
1371pub struct ProtoMessageBuilder {
1457 source: Option<ProtoName>,
1458 destination: Option<ProtoName>,
1459 identity: Option<String>,
1460 flags: Option<SlimHeaderFlags>,
1461 session_type: Option<ProtoSessionType>,
1462 session_message_type: Option<SessionMessageType>,
1463 session_id: Option<u32>,
1464 message_id: Option<u32>,
1465 payload: Option<Content>,
1466 metadata: HashMap<String, String>,
1467 subscription_id: Option<u64>,
1468}
1469
1470impl ProtoMessageBuilder {
1471 pub fn new() -> Self {
1473 Self {
1474 source: None,
1475 destination: None,
1476 identity: None,
1477 flags: None,
1478 session_type: None,
1479 session_message_type: None,
1480 session_id: None,
1481 message_id: None,
1482 payload: None,
1483 metadata: HashMap::new(),
1484 subscription_id: None,
1485 }
1486 }
1487
1488 pub fn source(mut self, source: ProtoName) -> Self {
1490 self.source = Some(source);
1491 self
1492 }
1493
1494 pub fn destination(mut self, destination: ProtoName) -> Self {
1496 self.destination = Some(destination);
1497 self
1498 }
1499
1500 pub fn identity(mut self, identity: impl Into<String>) -> Self {
1502 self.identity = Some(identity.into());
1503 self
1504 }
1505
1506 pub fn flags(mut self, flags: SlimHeaderFlags) -> Self {
1508 self.flags = Some(flags);
1509 self
1510 }
1511
1512 pub fn fanout(mut self, fanout: u32) -> Self {
1514 self.flags.get_or_insert_default().fanout = fanout;
1515 self
1516 }
1517
1518 pub fn recv_from(mut self, recv_from: u64) -> Self {
1520 self.flags.get_or_insert_default().recv_from = Some(recv_from);
1521 self
1522 }
1523
1524 pub fn forward_to(mut self, forward_to: u64) -> Self {
1526 self.flags.get_or_insert_default().forward_to = Some(forward_to);
1527 self
1528 }
1529
1530 pub fn incoming_conn(mut self, incoming_conn: u64) -> Self {
1532 self.flags.get_or_insert_default().incoming_conn = Some(incoming_conn);
1533 self
1534 }
1535
1536 pub fn error(mut self, error: bool) -> Self {
1538 self.flags.get_or_insert_default().error = Some(error);
1539 self
1540 }
1541
1542 pub fn ttl(mut self, ttl: u32) -> Self {
1544 self.flags.get_or_insert_default().ttl = ttl;
1545 self
1546 }
1547
1548 pub fn session_type(mut self, session_type: ProtoSessionType) -> Self {
1550 self.session_type = Some(session_type);
1551 self
1552 }
1553
1554 pub fn session_message_type(mut self, session_message_type: SessionMessageType) -> Self {
1556 self.session_message_type = Some(session_message_type);
1557 self
1558 }
1559
1560 pub fn session_id(mut self, session_id: u32) -> Self {
1562 self.session_id = Some(session_id);
1563 self
1564 }
1565
1566 pub fn message_id(mut self, message_id: u32) -> Self {
1568 self.message_id = Some(message_id);
1569 self
1570 }
1571
1572 pub fn payload(mut self, payload: Content) -> Self {
1574 self.payload = Some(payload);
1575 self
1576 }
1577
1578 pub fn application_payload(mut self, payload_type: &str, blob: Vec<u8>) -> Self {
1580 let app_payload = ApplicationPayload::new(payload_type, blob);
1581 self.payload = Some(app_payload.as_content());
1582 self
1583 }
1584
1585 pub fn command_payload(mut self, payload: CommandPayload) -> Self {
1587 self.payload = Some(payload.as_content());
1588 self
1589 }
1590
1591 pub fn with_slim_header(mut self, header: SlimHeader) -> Self {
1596 if let Some(src) = header.source.clone() {
1598 self.source = Some(src);
1599 }
1600 if let Some(dst) = header.destination.clone() {
1601 self.destination = Some(dst);
1602 }
1603 if !header.identity.is_empty() {
1604 self.identity = Some(header.identity.clone());
1605 }
1606
1607 let flags = SlimHeaderFlags {
1609 fanout: header.fanout,
1610 recv_from: header.recv_from,
1611 forward_to: header.forward_to,
1612 incoming_conn: header.incoming_conn,
1613 error: header.error,
1614 ttl: header.ttl,
1615 };
1616 self.flags = Some(flags);
1617 self
1618 }
1619
1620 pub fn with_session_header(mut self, header: SessionHeader) -> Self {
1625 self.session_type = Some(
1626 ProtoSessionType::try_from(header.session_type)
1627 .unwrap_or(ProtoSessionType::PointToPoint),
1628 );
1629 self.session_message_type = Some(
1630 SessionMessageType::try_from(header.session_message_type)
1631 .unwrap_or(SessionMessageType::Msg),
1632 );
1633 self.session_id = Some(header.session_id);
1634 self.message_id = Some(header.message_id);
1635 self
1636 }
1637
1638 pub fn metadata(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
1640 self.metadata.insert(key.into(), value.into());
1641 self
1642 }
1643
1644 pub fn metadata_map(mut self, map: HashMap<String, String>) -> Self {
1646 self.metadata.extend(map);
1647 self
1648 }
1649
1650 pub fn subscription_id(mut self, id: u64) -> Self {
1652 self.subscription_id = Some(id);
1653 self
1654 }
1655
1656 pub fn build_publish(self) -> Result<ProtoMessage, MessageError> {
1658 let source = self
1659 .source
1660 .ok_or(MessageError::BuilderErrorSourceRequired)?;
1661 let destination = self
1662 .destination
1663 .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1664
1665 let slim_header = Some(SlimHeader::new(
1666 source,
1667 destination,
1668 self.identity.as_deref().unwrap_or(""),
1669 self.flags,
1670 ));
1671
1672 let session_header = if self.session_type.is_some() || self.session_message_type.is_some() {
1673 Some(SessionHeader::new(
1674 self.session_type
1675 .unwrap_or(ProtoSessionType::PointToPoint)
1676 .into(),
1677 self.session_message_type
1678 .unwrap_or(SessionMessageType::Msg)
1679 .into(),
1680 self.session_id.unwrap_or(0),
1681 self.message_id.unwrap_or_else(rand::random),
1682 ))
1683 } else {
1684 Some(SessionHeader::default())
1685 };
1686
1687 let publish = ProtoPublish::with_header(slim_header, session_header, self.payload);
1688 let message = ProtoMessage::new(self.metadata, ProtoPublishType(publish));
1689 Ok(message)
1690 }
1691
1692 pub fn build_subscribe(self) -> Result<ProtoMessage, MessageError> {
1694 let source = self
1695 .source
1696 .ok_or(MessageError::BuilderErrorSourceRequired)?;
1697 let destination = self
1698 .destination
1699 .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1700
1701 let mut subscribe =
1702 ProtoSubscribe::new(source, destination, self.identity.as_deref(), self.flags);
1703 subscribe.subscription_id = self.subscription_id.unwrap_or_default();
1704
1705 Ok(ProtoMessage::new(
1706 self.metadata,
1707 ProtoSubscribeType(subscribe),
1708 ))
1709 }
1710
1711 pub fn build_unsubscribe(self) -> Result<ProtoMessage, MessageError> {
1713 let source = self
1714 .source
1715 .ok_or(MessageError::BuilderErrorSourceRequired)?;
1716 let destination = self
1717 .destination
1718 .ok_or(MessageError::BuilderErrorDestinationRequired)?;
1719
1720 let mut unsubscribe =
1721 ProtoUnsubscribe::new(source, destination, self.identity.as_deref(), self.flags);
1722 unsubscribe.subscription_id = self.subscription_id.unwrap_or_default();
1723
1724 Ok(ProtoMessage::new(
1725 self.metadata,
1726 ProtoUnsubscribeType(unsubscribe),
1727 ))
1728 }
1729
1730 pub fn build_subscription_ack(
1734 self,
1735 subscription_id: u64,
1736 success: bool,
1737 error: impl Into<String>,
1738 ) -> ProtoMessage {
1739 let ack = ProtoSubscriptionAck {
1740 subscription_id,
1741 success,
1742 error: error.into(),
1743 };
1744 ProtoMessage::new(self.metadata, ProtoSubscriptionAckType(ack))
1745 }
1746
1747 #[allow(clippy::too_many_arguments)]
1750 pub fn build_link_negotiation(
1751 self,
1752 link_id: impl Into<String>,
1753 slim_version: impl Into<String>,
1754 is_reply: bool,
1755 link_ecdh_public_key: Option<Vec<u8>>,
1756 connection_type: LinkConnectionType,
1757 node_id: impl Into<String>,
1758 deployment_name: impl Into<String>,
1759 ) -> ProtoMessage {
1760 let link_ecdh_public_key = link_ecdh_public_key.unwrap_or_default();
1761 let link = ProtoLink {
1762 link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
1763 link_id: link_id.into(),
1764 slim_version: slim_version.into(),
1765 is_reply,
1766 link_ecdh_public_key,
1767 connection_type: connection_type.into(),
1768 node_id: node_id.into(),
1769 deployment_name: deployment_name.into(),
1770 })),
1771 };
1772 ProtoMessage::new(self.metadata, ProtoLinkMessageType(link))
1773 }
1774}
1775
1776impl Default for ProtoMessageBuilder {
1777 fn default() -> Self {
1778 Self::new()
1779 }
1780}
1781
1782impl ProtoMessage {
1783 pub fn builder() -> ProtoMessageBuilder {
1785 ProtoMessageBuilder::new()
1786 }
1787}
1788
#[cfg(test)]
mod tests {
    use crate::api::proto::dataplane::v1::SessionMessageType;

    use super::*;

    /// Builds a subscribe (or unsubscribe) message and checks its type and
    /// addressing. When `flags` is `None`, the message is expected to carry
    /// default flags.
    fn test_subscription_template(
        subscription: bool,
        source: ProtoName,
        dst: ProtoName,
        identity: Option<&str>,
        flags: Option<SlimHeaderFlags>,
    ) {
        let sub = {
            let mut builder = ProtoMessage::builder()
                .source(source.clone())
                .destination(dst.clone());

            if let Some(id) = identity {
                builder = builder.identity(id);
            }

            if let Some(f) = flags.clone() {
                builder = builder.flags(f);
            }

            if subscription {
                builder.build_subscribe().unwrap()
            } else {
                builder.build_unsubscribe().unwrap()
            }
        };

        let flags = flags.unwrap_or_default();

        assert!(!sub.is_publish());
        assert_eq!(sub.is_subscribe(), subscription);
        assert_eq!(sub.is_unsubscribe(), !subscription);
        assert_eq!(flags.recv_from, sub.get_recv_from());
        assert_eq!(flags.forward_to, sub.get_forward_to());
        assert_eq!(None, sub.try_get_incoming_conn());
        assert_eq!(source, sub.get_source());
        let got_name = sub.get_dst();
        assert_eq!(dst, got_name);
    }

    /// Builds a publish message and checks its type, addressing and flags.
    /// When `flags` is `None`, the message is expected to carry default flags.
    fn test_publish_template(
        source: ProtoName,
        dst: ProtoName,
        identity: Option<&str>,
        flags: Option<SlimHeaderFlags>,
    ) {
        let mut builder = ProtoMessage::builder()
            .source(source.clone())
            .destination(dst.clone())
            .application_payload("str", "this is the content of the message".into());

        if let Some(id) = identity {
            builder = builder.identity(id);
        }

        if let Some(f) = flags.clone() {
            builder = builder.flags(f);
        }

        let pub_msg = builder.build_publish().unwrap();

        let flags = flags.unwrap_or_default();

        assert!(pub_msg.is_publish());
        assert!(!pub_msg.is_subscribe());
        assert!(!pub_msg.is_unsubscribe());
        assert_eq!(flags.recv_from, pub_msg.get_recv_from());
        assert_eq!(flags.forward_to, pub_msg.get_forward_to());
        assert_eq!(None, pub_msg.try_get_incoming_conn());
        assert_eq!(source, pub_msg.get_source());
        let got_name = pub_msg.get_dst();
        assert_eq!(dst, got_name);
        assert_eq!(flags.fanout, pub_msg.get_fanout());
    }

    #[test]
    fn test_subscription() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(2);

        test_subscription_template(true, source.clone(), dst.clone(), None, None);

        // Same as above, but exercising the identity setter path.
        test_subscription_template(true, source.clone(), dst.clone(), Some("identity"), None);

        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_subscription_template(
            true,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );
    }

    #[test]
    fn test_unsubscription() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(2);

        test_subscription_template(false, source.clone(), dst.clone(), None, None);

        // Same as above, but exercising the identity setter path.
        test_subscription_template(false, source.clone(), dst.clone(), Some("identity"), None);

        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_subscription_template(
            false,
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );
    }

    #[test]
    fn test_publish() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let mut dst = ProtoName::from_strings(["org", "ns", "type"]);

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default()),
        );

        dst.set_id(2);
        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default()),
        );
        dst.reset_id();

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_recv_from(50)),
        );

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_forward_to(30)),
        );

        test_publish_template(
            source.clone(),
            dst.clone(),
            None,
            Some(SlimHeaderFlags::default().with_fanout(2)),
        );
    }

    #[test]
    fn test_conversions() {
        let name = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let proto_subscribe = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_subscribe()
            .unwrap();
        let proto_subscribe = ProtoSubscribe::from(proto_subscribe);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_subscribe.header.as_ref().unwrap().get_dst(), dst);

        let proto_unsubscribe = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_unsubscribe()
            .unwrap();
        let proto_unsubscribe = ProtoUnsubscribe::from(proto_unsubscribe);
        assert_eq!(
            proto_unsubscribe.header.as_ref().unwrap().get_source(),
            name
        );
        assert_eq!(proto_unsubscribe.header.as_ref().unwrap().get_dst(), dst);

        let proto_publish = ProtoMessage::builder()
            .source(name.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .application_payload("str", "this is the content of the message".into())
            .build_publish()
            .unwrap();
        let proto_publish = ProtoPublish::from(proto_publish);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_source(), name);
        assert_eq!(proto_publish.header.as_ref().unwrap().get_dst(), dst);
    }

    #[test]
    fn test_panic() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dst = ProtoName::from_strings(["org", "ns", "type"]).with_id(2);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dst.clone())
            .flags(
                SlimHeaderFlags::default()
                    .with_recv_from(2)
                    .with_forward_to(3),
            )
            .build_subscribe()
            .unwrap();

        // Converting a subscribe into any other message kind must panic.
        let result = std::panic::catch_unwind(|| ProtoUnsubscribe::from(msg.clone()));
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| ProtoPublish::from(msg.clone()));
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| ProtoSubscribe::from(msg));
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_header() {
        let header = SlimHeader {
            source: None,
            destination: None,
            identity: String::new(),
            fanout: 0,
            version: version().to_string(),
            recv_from: None,
            forward_to: None,
            incoming_conn: None,
            error: None,
            header_mac: None,
            ttl: DEFAULT_TTL,
        };

        // Missing source/destination panic; optional flag getters do not.
        let result = std::panic::catch_unwind(|| header.get_source());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| header.get_dst());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| header.get_recv_from());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_forward_to());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_incoming_conn());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_error());
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_session_header() {
        let header = SessionHeader::new(0, 0, 0, 0);

        let result = std::panic::catch_unwind(|| header.get_session_id());
        assert!(result.is_ok());

        let result = std::panic::catch_unwind(|| header.get_message_id());
        assert!(result.is_ok());
    }

    #[test]
    fn test_panic_proto_message() {
        let message = ProtoMessage {
            metadata: HashMap::new(),
            message_type: None,
        };

        // Every header accessor must panic on a message without a type.
        let result = std::panic::catch_unwind(|| message.get_slim_header());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| message.get_type());
        assert!(result.is_err());

        let result = std::panic::catch_unwind(|| message.get_source());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_dst());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_recv_from());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_forward_to());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_incoming_conn());
        assert!(result.is_err());
        let result = std::panic::catch_unwind(|| message.get_fanout());
        assert!(result.is_err());
    }

    #[test]
    fn test_service_type_to_int() {
        // Assumes `Ping` is the highest-numbered variant and that the values
        // 0..=Ping are contiguous, as the original version of this test did.
        let max_service_type = SessionMessageType::Ping as i32;

        for i in 0..=max_service_type {
            let service_type =
                SessionMessageType::try_from(i).expect("failed to convert int to service type");
            // Round-trip: int -> enum -> int must give back the same value.
            assert_eq!(i32::from(service_type), i);
        }

        let invalid_service_type = SessionMessageType::try_from(max_service_type + 1);
        assert!(invalid_service_type.is_err());
    }

    #[test]
    fn test_proto_message_builder() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dest = ProtoName::from_strings(["org", "ns", "app"]).with_id(2);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .application_payload("test", b"hello world".to_vec())
            .build_publish()
            .unwrap();

        assert!(msg.is_publish());
        assert_eq!(msg.get_source(), source);
        assert_eq!(msg.get_dst(), dest);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .session_type(ProtoSessionType::Multicast)
            .session_message_type(SessionMessageType::Msg)
            .session_id(42)
            .message_id(100)
            .fanout(256)
            .application_payload("test", b"broadcast".to_vec())
            .build_publish()
            .unwrap();

        assert_eq!(msg.get_session_type(), ProtoSessionType::Multicast);
        assert_eq!(msg.get_id(), 100);
        assert_eq!(msg.get_fanout(), 256);

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .metadata("key1", "value1")
            .metadata("key2", "value2")
            .application_payload("test", vec![1, 2, 3])
            .build_publish()
            .unwrap();

        assert_eq!(msg.get_metadata("key1"), Some(&"value1".to_string()));
        assert_eq!(msg.get_metadata("key2"), Some(&"value2".to_string()));

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .recv_from(10)
            .build_subscribe()
            .unwrap();

        assert!(msg.is_subscribe());
        assert_eq!(msg.get_recv_from(), Some(10));

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .forward_to(20)
            .build_unsubscribe()
            .unwrap();

        assert!(msg.is_unsubscribe());
        assert_eq!(msg.get_forward_to(), Some(20));
    }

    #[test]
    fn test_command_payload_builder() {
        let dest = ProtoName::from_strings(["org", "ns", "app"]);

        let payload = CommandPayload::builder().discovery_request();
        assert!(payload.as_discovery_request_payload().is_ok());

        let payload = CommandPayload::builder().discovery_reply();
        assert!(payload.as_discovery_reply_payload().is_ok());

        let payload = CommandPayload::builder().join_request(
            Some(5),
            Some(Duration::from_secs(10)),
            Some(dest.clone()),
            Some(MlsSettings::default()),
        );
        let extracted = payload.as_join_request_payload().unwrap();
        assert!(extracted.mls_settings.is_some());
        assert!(extracted.timer_settings.is_some());

        let participant = Participant::new(dest.clone(), ParticipantSettings::bidirectional());
        let payload =
            CommandPayload::builder().join_reply(Some(vec![1, 2, 3]), participant.clone());
        let extracted = payload.as_join_reply_payload().unwrap();
        assert_eq!(extracted.key_package, Some(vec![1, 2, 3]));
        assert_eq!(extracted.participant, Some(participant));

        let payload = CommandPayload::builder().leave_request();
        assert!(payload.as_leave_request_payload().is_ok());

        let payload = CommandPayload::builder().leave_reply();
        assert!(payload.as_leave_reply_payload().is_ok());

        let participant = Participant::new(dest.clone(), ParticipantSettings::bidirectional());
        let participants = vec![participant.clone()];
        let payload =
            CommandPayload::builder().group_add(participant.clone(), participants.clone(), None);
        let extracted = payload.as_group_add_payload().unwrap();
        assert_eq!(extracted.new_participant, Some(participant));
        assert_eq!(extracted.participants, participants);

        let payload =
            CommandPayload::builder().group_remove(dest.clone(), vec![dest.clone()], None);
        let extracted = payload.as_group_remove_payload().unwrap();
        assert!(extracted.removed_participant.is_some());

        let payload = CommandPayload::builder().group_welcome(participants.clone(), None);
        let extracted = payload.as_welcome_payload().unwrap();
        assert!(!extracted.participants.is_empty());

        let payload = CommandPayload::builder().group_proposal(Some(dest.clone()), vec![4, 5, 6]);
        let extracted = payload.as_group_proposal_payload().unwrap();
        assert_eq!(extracted.mls_proposal, vec![4, 5, 6]);

        let payload = CommandPayload::builder().group_ack();
        assert!(payload.as_group_ack_payload().is_ok());

        let payload = CommandPayload::builder().group_nack();
        assert!(payload.as_group_nack_payload().is_ok());

        let payload = CommandPayload::builder().ping();
        assert!(payload.as_ping_payload().is_ok());
    }

    #[test]
    fn test_builder_with_command_payload() {
        let source = ProtoName::from_strings(["org", "ns", "type"]).with_id(1);
        let dest = ProtoName::from_strings(["org", "ns", "app"]).with_id(2);

        let cmd_payload = CommandPayload::builder().discovery_request();

        let msg = ProtoMessage::builder()
            .source(source.clone())
            .destination(dest.clone())
            .session_type(ProtoSessionType::PointToPoint)
            .session_message_type(SessionMessageType::DiscoveryRequest)
            .session_id(1)
            .command_payload(cmd_payload)
            .build_publish()
            .unwrap();

        assert!(msg.is_publish());
        assert_eq!(
            msg.get_session_message_type(),
            SessionMessageType::DiscoveryRequest
        );
    }

    #[test]
    fn test_validate_link_without_link_type() {
        let link = ProtoLink { link_type: None };
        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
        assert!(matches!(msg.validate(), Err(MessageError::LinkTypeNotSet)));
    }

    #[test]
    fn test_validate_link_with_link_type() {
        let link = ProtoLink {
            link_type: Some(ProtoLinkType::LinkNegotiation(LinkNegotiationPayload {
                link_id: "abc".into(),
                slim_version: "1.0.0".into(),
                is_reply: false,
                link_ecdh_public_key: vec![],
                connection_type: 0,
                node_id: String::new(),
                deployment_name: String::new(),
            })),
        };
        let msg = ProtoMessage::new(HashMap::new(), ProtoLinkMessageType(link));
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_build_link_negotiation_request() {
        let msg = ProtoMessage::builder().build_link_negotiation(
            "my-id",
            "1.2.3",
            false,
            None,
            LinkConnectionType::Remote,
            "",
            "",
        );
        assert!(msg.is_link());
        assert!(!msg.is_publish());
        assert!(!msg.is_subscribe());
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_build_link_negotiation_reply() {
        let msg = ProtoMessage::builder().build_link_negotiation(
            "my-id",
            "1.2.3",
            true,
            None,
            LinkConnectionType::Remote,
            "",
            "",
        );
        assert!(msg.is_link());
        assert!(msg.validate().is_ok());
    }

    #[test]
    fn test_validate_subscribe_missing_source_encoded_name() {
        let valid = ProtoName::from_strings(["org", "ns", "agent"]);
        let hdr = SlimHeader {
            source: Some(ProtoName {
                name: None,
                str_name: None,
            }),
            destination: Some(valid),
            ..Default::default()
        };
        let msg = ProtoMessage::new(
            HashMap::new(),
            ProtoSubscribeType(ProtoSubscribe {
                header: Some(hdr),
                ..Default::default()
            }),
        );
        assert!(matches!(
            msg.validate(),
            Err(MessageError::SourceEncodedNameNotFound)
        ));
    }

    #[test]
    fn test_validate_subscribe_missing_destination_encoded_name() {
        let valid = ProtoName::from_strings(["org", "ns", "agent"]);
        let hdr = SlimHeader {
            source: Some(valid),
            destination: Some(ProtoName {
                name: None,
                str_name: None,
            }),
            ..Default::default()
        };
        let msg = ProtoMessage::new(
            HashMap::new(),
            ProtoSubscribeType(ProtoSubscribe {
                header: Some(hdr),
                ..Default::default()
            }),
        );
        assert!(matches!(
            msg.validate(),
            Err(MessageError::DestinationEncodedNameNotFound)
        ));
    }

    #[test]
    fn test_participant_settings_convenience_methods() {
        let bidirectional = ParticipantSettings::bidirectional();
        assert!(bidirectional.sends_data);
        assert!(bidirectional.receives_data);
        assert!(bidirectional.is_sender());
        assert!(bidirectional.is_receiver());

        let send_only = ParticipantSettings::send_only();
        assert!(send_only.sends_data);
        assert!(!send_only.receives_data);
        assert!(send_only.is_sender());
        assert!(!send_only.is_receiver());

        let receive_only = ParticipantSettings::receive_only();
        assert!(!receive_only.sends_data);
        assert!(receive_only.receives_data);
        assert!(!receive_only.is_sender());
        assert!(receive_only.is_receiver());
    }
}