1use crate::models::{Capability, CellRole};
33use crate::Result;
34use serde::{Deserialize, Serialize};
35use std::collections::{HashMap, VecDeque};
36use std::sync::{Arc, Mutex};
37use std::time::{Duration, Instant};
38use tracing::{debug, instrument, warn};
39
/// Per-sender message sequence number.
///
/// `CellMessageBus::publish` hands these out in increasing order starting at 1,
/// and `CellMessageBus::deliver` uses them to discard already-seen messages.
pub type SequenceNumber = u64;
42
43#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize, Default)]
45pub enum MessagePriority {
46 Low = 0,
48 #[default]
50 Normal = 1,
51 High = 2,
53 Critical = 3,
55}
56
57#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
62pub enum RoutingContext {
63 IntraCell,
65 CellToZone,
67 ZoneToCell,
69 IntraZone,
71}
72
73impl MessagePriority {
74 pub fn escalate(self, context: RoutingContext) -> Self {
107 match context {
108 RoutingContext::CellToZone => match self {
110 MessagePriority::Low => MessagePriority::Normal,
111 MessagePriority::Normal => MessagePriority::High,
112 MessagePriority::High | MessagePriority::Critical => MessagePriority::Critical,
113 },
114 RoutingContext::IntraCell | RoutingContext::ZoneToCell | RoutingContext::IntraZone => {
116 self
117 }
118 }
119 }
120
121 pub fn can_preempt(self, other: MessagePriority) -> bool {
125 self > other
126 }
127
128 pub fn bandwidth_multiplier(self) -> f32 {
133 match self {
134 MessagePriority::Low => 0.5,
135 MessagePriority::Normal => 1.0,
136 MessagePriority::High => 1.5,
137 MessagePriority::Critical => 2.0,
138 }
139 }
140}
141
142#[derive(Debug, Clone, Serialize, Deserialize)]
144pub enum CellMessageType {
145 Join {
147 platform_id: String,
148 capabilities: Vec<Capability>,
149 },
150 Leave { platform_id: String, reason: String },
152 CapabilityAnnounce {
154 platform_id: String,
155 capabilities: Vec<Capability>,
156 },
157 LeaderAnnounce {
159 leader_id: String,
160 election_round: u32,
161 },
162 Heartbeat { platform_id: String },
164 RoleAssignment {
166 platform_id: String,
167 role: CellRole,
168 score: f64,
169 is_primary: bool,
170 },
171 StatusUpdate {
173 platform_id: String,
174 status: serde_json::Value,
175 },
176 Ack { message_seq: SequenceNumber },
178 Nack {
180 message_seq: SequenceNumber,
181 reason: String,
182 },
183}
184
185#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct CellMessage {
188 pub message_id: String,
190 pub seq: SequenceNumber,
192 pub sender: String,
194 pub squad_id: String,
196 pub priority: MessagePriority,
198 pub routing_context: RoutingContext,
200 pub payload: CellMessageType,
202 pub timestamp: u64,
204 pub ttl: u64,
206}
207
208impl CellMessage {
209 pub fn new(
211 sender: String,
212 squad_id: String,
213 seq: SequenceNumber,
214 payload: CellMessageType,
215 ) -> Self {
216 let timestamp = std::time::SystemTime::now()
217 .duration_since(std::time::UNIX_EPOCH)
218 .unwrap()
219 .as_secs();
220
221 Self {
222 message_id: format!("{}-{}", sender, seq),
223 seq,
224 sender,
225 squad_id,
226 priority: MessagePriority::Normal,
227 routing_context: RoutingContext::IntraCell,
228 payload,
229 timestamp,
230 ttl: 30, }
232 }
233
234 pub fn with_priority(mut self, priority: MessagePriority) -> Self {
236 self.priority = priority;
237 self
238 }
239
240 pub fn with_routing_context(mut self, context: RoutingContext) -> Self {
242 self.routing_context = context;
243 self
244 }
245
246 pub fn with_ttl(mut self, ttl: u64) -> Self {
248 self.ttl = ttl;
249 self
250 }
251
252 pub fn escalate_priority(&mut self) {
257 self.priority = self.priority.escalate(self.routing_context);
258 }
259
260 pub fn effective_priority(&self) -> MessagePriority {
264 self.priority.escalate(self.routing_context)
265 }
266
267 pub fn is_expired(&self) -> bool {
269 let current_time = std::time::SystemTime::now()
270 .duration_since(std::time::UNIX_EPOCH)
271 .unwrap()
272 .as_secs();
273 current_time.saturating_sub(self.timestamp) > self.ttl
274 }
275
276 pub fn role_assignment(
278 sender: String,
279 squad_id: String,
280 seq: SequenceNumber,
281 platform_id: String,
282 role: CellRole,
283 score: f64,
284 is_primary: bool,
285 ) -> Self {
286 Self::new(
287 sender,
288 squad_id,
289 seq,
290 CellMessageType::RoleAssignment {
291 platform_id,
292 role,
293 score,
294 is_primary,
295 },
296 )
297 .with_priority(MessagePriority::High)
298 }
299}
300
/// Lifecycle state of an outbound message tracked by `CellMessageBus`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DeliveryStatus {
    /// Set by `publish`; not yet retransmitted or acknowledged.
    Pending,
    /// Set by `process_retransmissions` after a resend; still awaiting an ack.
    Delivered,
    /// Set by `acknowledge`; the entry is dropped on the next retransmission pass.
    Acknowledged,
    /// Retry budget exhausted; the entry is dropped.
    Failed,
}
313
314#[derive(Debug, Clone)]
316struct TrackedMessage {
317 message: CellMessage,
318 status: DeliveryStatus,
319 retry_count: u32,
320 last_send: Instant,
321}
322
323pub type MessageHandler = Arc<dyn Fn(&CellMessage) -> Result<()> + Send + Sync>;
325
326pub struct CellMessageBus {
334 squad_id: String,
336 platform_id: String,
338 next_seq: Arc<Mutex<SequenceNumber>>,
340 outbound_queue: Arc<Mutex<VecDeque<CellMessage>>>,
342 tracked_messages: Arc<Mutex<HashMap<SequenceNumber, TrackedMessage>>>,
344 received_seqs: Arc<Mutex<HashMap<String, SequenceNumber>>>,
346 subscribers: Arc<Mutex<Vec<MessageHandler>>>,
348 retry_timeout: Duration,
350 max_retries: u32,
352}
353
354impl CellMessageBus {
355 pub fn new(squad_id: String, platform_id: String) -> Self {
357 Self {
358 squad_id,
359 platform_id,
360 next_seq: Arc::new(Mutex::new(1)),
361 outbound_queue: Arc::new(Mutex::new(VecDeque::new())),
362 tracked_messages: Arc::new(Mutex::new(HashMap::new())),
363 received_seqs: Arc::new(Mutex::new(HashMap::new())),
364 subscribers: Arc::new(Mutex::new(Vec::new())),
365 retry_timeout: Duration::from_secs(2),
366 max_retries: 3,
367 }
368 }
369
370 pub fn subscribe(&self, handler: MessageHandler) -> Result<()> {
372 let mut subscribers = self.subscribers.lock().unwrap();
373 subscribers.push(handler);
374 Ok(())
375 }
376
377 #[instrument(skip(self, payload))]
379 pub fn publish(&self, payload: CellMessageType) -> Result<SequenceNumber> {
380 let seq = {
381 let mut next_seq = self.next_seq.lock().unwrap();
382 let seq = *next_seq;
383 *next_seq += 1;
384 seq
385 };
386
387 let message = CellMessage::new(
388 self.platform_id.clone(),
389 self.squad_id.clone(),
390 seq,
391 payload,
392 );
393
394 debug!(
395 "Publishing message seq={} from {} to squad {}",
396 seq, self.platform_id, self.squad_id
397 );
398
399 let mut queue = self.outbound_queue.lock().unwrap();
401 queue.push_back(message.clone());
402 let mut vec: Vec<_> = queue.drain(..).collect();
404 vec.sort_by_key(|m| std::cmp::Reverse(m.priority));
405 queue.extend(vec);
406
407 let tracked = TrackedMessage {
409 message: message.clone(),
410 status: DeliveryStatus::Pending,
411 retry_count: 0,
412 last_send: Instant::now(),
413 };
414 self.tracked_messages.lock().unwrap().insert(seq, tracked);
415
416 Ok(seq)
417 }
418
419 #[instrument(skip(self, message))]
421 pub fn deliver(&self, message: &CellMessage) -> Result<()> {
422 if message.is_expired() {
424 debug!("Dropping expired message seq={}", message.seq);
425 return Ok(());
426 }
427
428 {
430 let mut received = self.received_seqs.lock().unwrap();
431 if let Some(&last_seq) = received.get(&message.sender) {
432 if message.seq <= last_seq {
433 debug!(
434 "Dropping duplicate message seq={} from {}",
435 message.seq, message.sender
436 );
437 return Ok(());
438 }
439 }
440 received.insert(message.sender.clone(), message.seq);
441 }
442
443 debug!(
444 "Delivering message seq={} from {} to subscribers",
445 message.seq, message.sender
446 );
447
448 let subscribers = self.subscribers.lock().unwrap();
450 for handler in subscribers.iter() {
451 if let Err(e) = handler(message) {
452 warn!("Subscriber error: {}", e);
453 }
454 }
455
456 Ok(())
457 }
458
459 pub fn acknowledge(&self, message_seq: SequenceNumber) -> Result<()> {
461 let mut tracked = self.tracked_messages.lock().unwrap();
462 if let Some(msg) = tracked.get_mut(&message_seq) {
463 msg.status = DeliveryStatus::Acknowledged;
464 debug!("Acknowledged message seq={}", message_seq);
465 }
466 Ok(())
467 }
468
469 #[instrument(skip(self))]
471 pub fn process_retransmissions(&self) -> Result<Vec<CellMessage>> {
472 let mut tracked = self.tracked_messages.lock().unwrap();
473 let mut to_retry = Vec::new();
474
475 for (seq, msg) in tracked.iter_mut() {
476 if msg.status == DeliveryStatus::Acknowledged {
477 continue;
478 }
479
480 if msg.last_send.elapsed() >= self.retry_timeout {
481 if msg.retry_count >= self.max_retries {
482 warn!(
483 "Message seq={} failed after {} retries",
484 seq, msg.retry_count
485 );
486 msg.status = DeliveryStatus::Failed;
487 } else {
488 debug!(
489 "Retransmitting message seq={} (attempt {})",
490 seq,
491 msg.retry_count + 1
492 );
493 msg.retry_count += 1;
494 msg.last_send = Instant::now();
495 msg.status = DeliveryStatus::Delivered;
496 to_retry.push(msg.message.clone());
497 }
498 }
499 }
500
501 tracked.retain(|_, msg| {
503 msg.status != DeliveryStatus::Acknowledged && msg.status != DeliveryStatus::Failed
504 });
505
506 Ok(to_retry)
507 }
508
509 pub fn get_pending_messages(&self) -> Result<Vec<CellMessage>> {
511 let mut queue = self.outbound_queue.lock().unwrap();
512 let messages: Vec<_> = queue.drain(..).collect();
513 Ok(messages)
514 }
515
516 pub fn stats(&self) -> MessageBusStats {
518 let tracked = self.tracked_messages.lock().unwrap();
519 let outbound = self.outbound_queue.lock().unwrap();
520 let received = self.received_seqs.lock().unwrap();
521 let subscribers = self.subscribers.lock().unwrap();
522
523 MessageBusStats {
524 pending_outbound: outbound.len(),
525 tracked_messages: tracked.len(),
526 unique_senders: received.len(),
527 subscriber_count: subscribers.len(),
528 next_seq: *self.next_seq.lock().unwrap(),
529 }
530 }
531}
532
533#[derive(Debug, Clone)]
535pub struct MessageBusStats {
536 pub pending_outbound: usize,
537 pub tracked_messages: usize,
538 pub unique_senders: usize,
539 pub subscriber_count: usize,
540 pub next_seq: SequenceNumber,
541}
542
#[cfg(test)]
mod tests {
    use super::*;

    /// Heartbeat payload from `platform_id`; the most common fixture below.
    fn heartbeat(platform_id: &str) -> CellMessageType {
        CellMessageType::Heartbeat {
            platform_id: platform_id.to_string(),
        }
    }

    /// Bus for `node_1` in `squad_alpha`, shared by most bus tests.
    fn alpha_bus() -> CellMessageBus {
        CellMessageBus::new("squad_alpha".to_string(), "node_1".to_string())
    }

    #[test]
    fn test_message_creation() {
        let message = CellMessage::new(
            "node_1".to_string(),
            "squad_alpha".to_string(),
            1,
            heartbeat("node_1"),
        );

        assert_eq!(message.seq, 1);
        assert_eq!(message.message_id, "node_1-1");
        assert_eq!(message.sender, "node_1");
        assert_eq!(message.squad_id, "squad_alpha");
        assert_eq!(message.priority, MessagePriority::Normal);
        assert_eq!(message.ttl, 30);
        assert!(!message.is_expired());
    }

    #[test]
    fn test_message_expiration() {
        let mut message = CellMessage::new(
            "node_1".to_string(),
            "squad_alpha".to_string(),
            1,
            heartbeat("node_1"),
        )
        .with_ttl(0);

        // Backdate to the epoch so the message is far older than its TTL.
        message.timestamp = 0;

        assert!(message.is_expired());
    }

    #[test]
    fn test_message_priority() {
        let message = CellMessage::new(
            "node_1".to_string(),
            "squad_alpha".to_string(),
            1,
            heartbeat("node_1"),
        )
        .with_priority(MessagePriority::Critical);

        assert_eq!(message.priority, MessagePriority::Critical);
    }

    #[test]
    fn test_message_bus_creation() {
        let bus = alpha_bus();

        assert_eq!(bus.squad_id, "squad_alpha");
        assert_eq!(bus.platform_id, "node_1");

        let stats = bus.stats();
        assert_eq!(stats.pending_outbound, 0);
        assert_eq!(stats.next_seq, 1);
    }

    #[test]
    fn test_publish_message() {
        let bus = alpha_bus();

        let seq = bus.publish(heartbeat("node_1")).unwrap();
        assert_eq!(seq, 1);

        let stats = bus.stats();
        assert_eq!(stats.next_seq, 2);
        assert_eq!(stats.pending_outbound, 1);
        assert_eq!(stats.tracked_messages, 1);
    }

    #[test]
    fn test_priority_ordering() {
        let bus = alpha_bus();

        let demoted = bus.publish(heartbeat("node_1")).unwrap();
        // Demote the queued message so the next (Normal) publish must be
        // placed ahead of it.
        {
            let mut queue = bus.outbound_queue.lock().unwrap();
            queue.front_mut().expect("message was just published").priority =
                MessagePriority::Low;
        }
        let normal = bus
            .publish(CellMessageType::LeaderAnnounce {
                leader_id: "node_2".to_string(),
                election_round: 1,
            })
            .unwrap();

        let order: Vec<SequenceNumber> = bus
            .get_pending_messages()
            .unwrap()
            .iter()
            .map(|m| m.seq)
            .collect();
        assert_eq!(order, vec![normal, demoted]);
    }

    #[test]
    fn test_equal_priority_is_fifo() {
        let bus = alpha_bus();

        let first = bus.publish(heartbeat("node_1")).unwrap();
        let second = bus
            .publish(CellMessageType::LeaderAnnounce {
                leader_id: "node_2".to_string(),
                election_round: 1,
            })
            .unwrap();

        let order: Vec<SequenceNumber> = bus
            .get_pending_messages()
            .unwrap()
            .iter()
            .map(|m| m.seq)
            .collect();
        assert_eq!(order, vec![first, second]);
        assert_eq!(bus.stats().pending_outbound, 0);
    }

    #[test]
    fn test_duplicate_detection() {
        let bus = alpha_bus();

        let deliveries = Arc::new(Mutex::new(0usize));
        let counter = Arc::clone(&deliveries);
        bus.subscribe(Arc::new(move |_msg: &CellMessage| {
            *counter.lock().unwrap() += 1;
            Ok(())
        }))
        .unwrap();

        let message = CellMessage::new(
            "node_2".to_string(),
            "squad_alpha".to_string(),
            1,
            heartbeat("node_2"),
        );

        bus.deliver(&message).unwrap();
        // Same sender and seq again: must be suppressed.
        bus.deliver(&message).unwrap();

        assert_eq!(*deliveries.lock().unwrap(), 1);
        assert_eq!(bus.stats().unique_senders, 1);
    }

    #[test]
    fn test_subscriber_notification() {
        let bus = alpha_bus();

        let received = Arc::new(Mutex::new(false));
        let received_clone = received.clone();

        bus.subscribe(Arc::new(move |_msg: &CellMessage| {
            *received_clone.lock().unwrap() = true;
            Ok(())
        }))
        .unwrap();

        let message = CellMessage::new(
            "node_2".to_string(),
            "squad_alpha".to_string(),
            1,
            heartbeat("node_2"),
        );

        bus.deliver(&message).unwrap();

        assert!(*received.lock().unwrap());
    }

    #[test]
    fn test_acknowledgment() {
        let bus = alpha_bus();

        let seq = bus.publish(heartbeat("node_1")).unwrap();

        bus.acknowledge(seq).unwrap();

        let tracked = bus.tracked_messages.lock().unwrap();
        assert_eq!(
            tracked.get(&seq).unwrap().status,
            DeliveryStatus::Acknowledged
        );
    }

    #[test]
    fn test_retransmission() {
        let mut bus = alpha_bus();
        bus.retry_timeout = Duration::from_millis(10);

        let seq = bus.publish(heartbeat("node_1")).unwrap();

        // Simulate the initial send.
        let _ = bus.get_pending_messages().unwrap();

        std::thread::sleep(Duration::from_millis(15));

        let retries = bus.process_retransmissions().unwrap();

        assert_eq!(retries.len(), 1);
        assert_eq!(retries[0].seq, seq);
    }

    #[test]
    fn test_retransmission_gives_up_after_max_retries() {
        let mut bus = alpha_bus();
        // A zero timeout makes every pass treat the message as overdue.
        bus.retry_timeout = Duration::ZERO;

        let seq = bus.publish(heartbeat("node_1")).unwrap();

        for attempt in 1..=bus.max_retries {
            let retries = bus.process_retransmissions().unwrap();
            assert_eq!(retries.len(), 1, "attempt {attempt}");
            assert_eq!(retries[0].seq, seq);
        }

        assert!(bus.process_retransmissions().unwrap().is_empty());
        assert_eq!(bus.stats().tracked_messages, 0);
    }

    #[test]
    fn test_acknowledged_message_not_retransmitted() {
        let mut bus = alpha_bus();
        bus.retry_timeout = Duration::ZERO;

        let seq = bus.publish(heartbeat("node_1")).unwrap();
        bus.acknowledge(seq).unwrap();

        assert!(bus.process_retransmissions().unwrap().is_empty());
        assert_eq!(bus.stats().tracked_messages, 0);
    }

    #[test]
    fn test_role_assignment_message() {
        let msg = CellMessage::role_assignment(
            "node_1".to_string(),
            "squad_1".to_string(),
            1,
            "node_2".to_string(),
            CellRole::Sensor,
            0.85,
            true,
        );

        assert_eq!(msg.sender, "node_1");
        assert_eq!(msg.squad_id, "squad_1");
        assert_eq!(msg.seq, 1);
        assert_eq!(msg.priority, MessagePriority::High);

        match msg.payload {
            CellMessageType::RoleAssignment {
                platform_id,
                role,
                score,
                is_primary,
            } => {
                assert_eq!(platform_id, "node_2");
                assert_eq!(role, CellRole::Sensor);
                assert_eq!(score, 0.85);
                assert!(is_primary);
            }
            _ => panic!("Expected RoleAssignment message"),
        }
    }

    #[test]
    fn test_priority_escalation_upward() {
        assert_eq!(
            MessagePriority::Low.escalate(RoutingContext::CellToZone),
            MessagePriority::Normal
        );
        assert_eq!(
            MessagePriority::Normal.escalate(RoutingContext::CellToZone),
            MessagePriority::High
        );
        assert_eq!(
            MessagePriority::High.escalate(RoutingContext::CellToZone),
            MessagePriority::Critical
        );
        assert_eq!(
            MessagePriority::Critical.escalate(RoutingContext::CellToZone),
            MessagePriority::Critical
        );
    }

    #[test]
    fn test_priority_escalation_lateral() {
        assert_eq!(
            MessagePriority::Low.escalate(RoutingContext::IntraCell),
            MessagePriority::Low
        );
        assert_eq!(
            MessagePriority::Normal.escalate(RoutingContext::IntraCell),
            MessagePriority::Normal
        );
        assert_eq!(
            MessagePriority::High.escalate(RoutingContext::IntraCell),
            MessagePriority::High
        );
        assert_eq!(
            MessagePriority::Normal.escalate(RoutingContext::IntraZone),
            MessagePriority::Normal
        );
    }

    #[test]
    fn test_priority_escalation_downward() {
        assert_eq!(
            MessagePriority::Low.escalate(RoutingContext::ZoneToCell),
            MessagePriority::Low
        );
        assert_eq!(
            MessagePriority::Normal.escalate(RoutingContext::ZoneToCell),
            MessagePriority::Normal
        );
        assert_eq!(
            MessagePriority::Critical.escalate(RoutingContext::ZoneToCell),
            MessagePriority::Critical
        );
    }

    #[test]
    fn test_message_routing_context() {
        let msg = CellMessage::new(
            "node_1".to_string(),
            "squad_alpha".to_string(),
            1,
            heartbeat("node_1"),
        );
        assert_eq!(msg.routing_context, RoutingContext::IntraCell);

        let msg2 = CellMessage::new(
            "node_1".to_string(),
            "squad_alpha".to_string(),
            2,
            heartbeat("node_1"),
        )
        .with_routing_context(RoutingContext::CellToZone);

        assert_eq!(msg2.routing_context, RoutingContext::CellToZone);
    }

    #[test]
    fn test_message_escalate_priority() {
        let mut msg = CellMessage::new(
            "node_1".to_string(),
            "squad_alpha".to_string(),
            1,
            heartbeat("node_1"),
        )
        .with_priority(MessagePriority::Normal)
        .with_routing_context(RoutingContext::CellToZone);

        assert_eq!(msg.priority, MessagePriority::Normal);

        msg.escalate_priority();

        assert_eq!(msg.priority, MessagePriority::High);
    }

    #[test]
    fn test_message_effective_priority() {
        let msg = CellMessage::new(
            "node_1".to_string(),
            "squad_alpha".to_string(),
            1,
            heartbeat("node_1"),
        )
        .with_priority(MessagePriority::Low)
        .with_routing_context(RoutingContext::CellToZone);

        // The stored priority is untouched; only the computed view escalates.
        assert_eq!(msg.priority, MessagePriority::Low);
        assert_eq!(msg.effective_priority(), MessagePriority::Normal);
    }

    #[test]
    fn test_priority_preemption() {
        assert!(MessagePriority::Critical.can_preempt(MessagePriority::High));
        assert!(MessagePriority::High.can_preempt(MessagePriority::Normal));
        assert!(MessagePriority::Normal.can_preempt(MessagePriority::Low));

        assert!(!MessagePriority::Low.can_preempt(MessagePriority::Normal));
        assert!(!MessagePriority::Normal.can_preempt(MessagePriority::High));
        assert!(!MessagePriority::Normal.can_preempt(MessagePriority::Normal));
    }

    #[test]
    fn test_priority_bandwidth_multiplier() {
        assert_eq!(MessagePriority::Low.bandwidth_multiplier(), 0.5);
        assert_eq!(MessagePriority::Normal.bandwidth_multiplier(), 1.0);
        assert_eq!(MessagePriority::High.bandwidth_multiplier(), 1.5);
        assert_eq!(MessagePriority::Critical.bandwidth_multiplier(), 2.0);
    }

    #[test]
    fn test_priority_level_ordering() {
        assert!(MessagePriority::Low < MessagePriority::Normal);
        assert!(MessagePriority::Normal < MessagePriority::High);
        assert!(MessagePriority::High < MessagePriority::Critical);
        assert!(MessagePriority::Low < MessagePriority::Critical);
    }

    #[test]
    fn test_hierarchical_message_workflow() {
        let payload = CellMessageType::StatusUpdate {
            platform_id: "node_1".to_string(),
            status: serde_json::json!({"health": "ok"}),
        };

        let mut msg = CellMessage::new(
            "node_1".to_string(),
            "squad_alpha".to_string(),
            1,
            payload,
        )
        .with_priority(MessagePriority::Normal)
        .with_routing_context(RoutingContext::IntraCell);

        assert_eq!(msg.priority, MessagePriority::Normal);
        assert_eq!(msg.routing_context, RoutingContext::IntraCell);

        // Re-route upward; escalation now applies.
        msg.routing_context = RoutingContext::CellToZone;
        msg.escalate_priority();

        assert_eq!(msg.priority, MessagePriority::High);
        assert_eq!(msg.routing_context, RoutingContext::CellToZone);
    }
}