1pub mod policy_gate;
6pub mod remote;
7
8pub use remote::RemoteCommunicationBus;
9
10use async_trait::async_trait;
11use parking_lot::RwLock;
12use std::collections::HashMap;
13use std::sync::Arc;
14use std::time::{Duration, SystemTime};
15use tokio::sync::{mpsc, oneshot, Notify};
16use tokio::time::{interval, timeout};
17
18use crate::crypto::Aes256GcmCrypto;
19use crate::types::*;
20use ed25519_dalek::{SigningKey, VerifyingKey};
21use rand::rngs::OsRng;
22use rand::RngCore;
23
24#[async_trait]
26pub trait CommunicationBus {
27 async fn send_message(&self, message: SecureMessage) -> Result<MessageId, CommunicationError>;
29
30 async fn receive_messages(
32 &self,
33 agent_id: AgentId,
34 ) -> Result<Vec<SecureMessage>, CommunicationError>;
35
36 async fn subscribe(&self, agent_id: AgentId, topic: String) -> Result<(), CommunicationError>;
38
39 async fn unsubscribe(&self, agent_id: AgentId, topic: String)
41 -> Result<(), CommunicationError>;
42
43 async fn publish(
45 &self,
46 topic: String,
47 message: SecureMessage,
48 ) -> Result<(), CommunicationError>;
49
50 async fn get_delivery_status(
52 &self,
53 message_id: MessageId,
54 ) -> Result<DeliveryStatus, CommunicationError>;
55
56 async fn register_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError>;
58
59 async fn unregister_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError>;
61
62 async fn request(
64 &self,
65 target_agent: AgentId,
66 request_payload: bytes::Bytes,
67 timeout_duration: Duration,
68 ) -> Result<bytes::Bytes, CommunicationError>;
69
70 async fn shutdown(&self) -> Result<(), CommunicationError>;
72
73 async fn check_health(&self) -> Result<ComponentHealth, CommunicationError>;
75
76 fn create_internal_message(
78 &self,
79 sender: AgentId,
80 recipient: AgentId,
81 payload_data: bytes::Bytes,
82 message_type: MessageType,
83 ttl: std::time::Duration,
84 ) -> SecureMessage;
85}
86
/// Limits and feature switches for a communication bus.
#[derive(Debug, Clone)]
pub struct CommunicationConfig {
    /// Largest payload, in bytes, accepted by `send_message` / `publish`.
    pub max_message_size: usize,
    /// Age after which a queued message is treated as expired by cleanup.
    pub message_ttl: Duration,
    /// Maximum number of messages a single agent queue may hold.
    pub max_queue_size: usize,
    /// Delivery timeout. NOTE(review): not read anywhere in this file.
    pub delivery_timeout: Duration,
    /// Retry budget. NOTE(review): not read anywhere in this file.
    pub retry_attempts: u32,
    /// Encryption switch. NOTE(review): not read anywhere in this file.
    pub enable_encryption: bool,
    /// Compression switch. NOTE(review): not read anywhere in this file.
    pub enable_compression: bool,
    /// Capacity handed to the dead-letter queue on construction.
    pub dead_letter_queue_size: usize,
}

impl Default for CommunicationConfig {
    /// Defaults: 1 MiB payloads, one-hour TTL, 10 000 messages per queue,
    /// 30 s delivery timeout, 3 retries, encryption and compression enabled,
    /// and a 1 000-entry dead-letter queue.
    fn default() -> Self {
        const ONE_MIB: usize = 1024 * 1024;
        const ONE_HOUR: Duration = Duration::from_secs(3600);
        const THIRTY_SECONDS: Duration = Duration::from_secs(30);

        Self {
            max_message_size: ONE_MIB,
            message_ttl: ONE_HOUR,
            max_queue_size: 10000,
            delivery_timeout: THIRTY_SECONDS,
            retry_attempts: 3,
            enable_encryption: true,
            enable_compression: true,
            dead_letter_queue_size: 1000,
        }
    }
}
114
115pub struct DefaultCommunicationBus {
117 config: CommunicationConfig,
118 message_queues: Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
119 subscriptions: Arc<RwLock<HashMap<String, Vec<AgentId>>>>,
120 message_tracker: Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
121 dead_letter_queue: Arc<RwLock<DeadLetterQueue>>,
122 pending_requests: Arc<RwLock<HashMap<RequestId, oneshot::Sender<bytes::Bytes>>>>,
123 event_sender: mpsc::UnboundedSender<CommunicationEvent>,
124 shutdown_notify: Arc<Notify>,
125 is_running: Arc<RwLock<bool>>,
126 signing_key: SigningKey,
127 verifying_key: VerifyingKey,
128 system_agent_id: AgentId,
129 #[allow(dead_code)]
130 crypto: Aes256GcmCrypto,
131}
132
133impl DefaultCommunicationBus {
134 pub async fn new(config: CommunicationConfig) -> Result<Self, CommunicationError> {
136 let message_queues = Arc::new(RwLock::new(HashMap::new()));
137 let subscriptions = Arc::new(RwLock::new(HashMap::new()));
138 let message_tracker = Arc::new(RwLock::new(HashMap::new()));
139 let dead_letter_queue = Arc::new(RwLock::new(DeadLetterQueue::new(
140 config.dead_letter_queue_size,
141 )));
142 let pending_requests = Arc::new(RwLock::new(HashMap::new()));
143 let (event_sender, event_receiver) = mpsc::unbounded_channel();
144 let shutdown_notify = Arc::new(Notify::new());
145 let is_running = Arc::new(RwLock::new(true));
146
147 let mut secret_bytes = [0u8; 32];
149 OsRng.fill_bytes(&mut secret_bytes);
150 let signing_key = SigningKey::from_bytes(&secret_bytes);
151 let verifying_key = signing_key.verifying_key();
152
153 let system_agent_id = AgentId::new();
155
156 let crypto = Aes256GcmCrypto::new();
157
158 let bus = Self {
159 config,
160 message_queues,
161 subscriptions,
162 message_tracker,
163 dead_letter_queue,
164 pending_requests,
165 event_sender,
166 shutdown_notify,
167 is_running,
168 signing_key,
169 verifying_key,
170 system_agent_id,
171 crypto,
172 };
173
174 bus.start_event_loop(event_receiver).await;
176 bus.start_cleanup_loop().await;
177
178 Ok(bus)
179 }
180
181 async fn start_event_loop(
183 &self,
184 mut event_receiver: mpsc::UnboundedReceiver<CommunicationEvent>,
185 ) {
186 let message_queues = self.message_queues.clone();
187 let subscriptions = self.subscriptions.clone();
188 let message_tracker = self.message_tracker.clone();
189 let dead_letter_queue = self.dead_letter_queue.clone();
190 let pending_requests = self.pending_requests.clone();
191 let shutdown_notify = self.shutdown_notify.clone();
192 let config = self.config.clone();
193
194 tokio::spawn(async move {
195 loop {
196 tokio::select! {
197 event = event_receiver.recv() => {
198 if let Some(event) = event {
199 Self::process_communication_event(
200 event,
201 &message_queues,
202 &subscriptions,
203 &message_tracker,
204 &dead_letter_queue,
205 &pending_requests,
206 &config,
207 ).await;
208 } else {
209 break;
210 }
211 }
212 _ = shutdown_notify.notified() => {
213 break;
214 }
215 }
216 }
217 });
218 }
219
220 async fn start_cleanup_loop(&self) {
222 let message_queues = self.message_queues.clone();
223 let message_tracker = self.message_tracker.clone();
224 let dead_letter_queue = self.dead_letter_queue.clone();
225 let shutdown_notify = self.shutdown_notify.clone();
226 let is_running = self.is_running.clone();
227 let message_ttl = self.config.message_ttl;
228
229 tokio::spawn(async move {
230 let mut interval = interval(Duration::from_secs(60)); loop {
233 tokio::select! {
234 _ = interval.tick() => {
235 if !*is_running.read() {
236 break;
237 }
238
239 Self::cleanup_expired_messages(&message_queues, &message_tracker, &dead_letter_queue, message_ttl).await;
240 }
241 _ = shutdown_notify.notified() => {
242 break;
243 }
244 }
245 }
246 });
247 }
248
249 async fn process_communication_event(
251 event: CommunicationEvent,
252 message_queues: &Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
253 subscriptions: &Arc<RwLock<HashMap<String, Vec<AgentId>>>>,
254 message_tracker: &Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
255 dead_letter_queue: &Arc<RwLock<DeadLetterQueue>>,
256 pending_requests: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<bytes::Bytes>>>>,
257 config: &CommunicationConfig,
258 ) {
259 match event {
260 CommunicationEvent::MessageSent { message } => {
261 let recipient = message.recipient;
262 let message_id = message.id;
263
264 if let MessageType::Response(request_id) = &message.message_type {
266 if let Some(sender) = pending_requests.write().remove(request_id) {
267 let _ = sender.send(message.payload.data.clone());
269 tracing::debug!(
270 "Response {} sent for request {:?}",
271 message_id,
272 request_id
273 );
274 return;
275 }
276 }
277
278 let mut tracker_map = message_tracker.write();
281 let mut queues = message_queues.write();
282
283 tracker_map.insert(message_id, MessageTracker::new(message.clone()));
284
285 if let Some(recipient_id) = recipient {
287 if let Some(queue) = queues.get_mut(&recipient_id) {
288 if queue.can_accept_message(config) {
289 queue.add_message(message);
290
291 if let Some(tracker) = tracker_map.get_mut(&message_id) {
293 tracker.status = DeliveryStatus::Delivered;
294 tracker.delivered_at = Some(SystemTime::now());
295 }
296
297 tracing::debug!(
298 "Message {} delivered to agent {}",
299 message_id,
300 recipient_id
301 );
302 } else {
303 dead_letter_queue
305 .write()
306 .add_message(message, DeadLetterReason::QueueFull);
307
308 if let Some(tracker) = tracker_map.get_mut(&message_id) {
309 tracker.status = DeliveryStatus::Failed;
310 tracker.failure_reason = Some("Queue full".to_string());
311 }
312
313 tracing::warn!(
314 "Message {} failed to deliver: queue full for agent {}",
315 message_id,
316 recipient_id
317 );
318 }
319 } else {
320 dead_letter_queue
322 .write()
323 .add_message(message, DeadLetterReason::AgentNotFound);
324
325 if let Some(tracker) = tracker_map.get_mut(&message_id) {
326 tracker.status = DeliveryStatus::Failed;
327 tracker.failure_reason = Some("Agent not registered".to_string());
328 }
329
330 tracing::warn!(
331 "Message {} failed to deliver: agent {:?} not registered",
332 message_id,
333 recipient
334 );
335 }
336 } else {
337 dead_letter_queue
339 .write()
340 .add_message(message, DeadLetterReason::AgentNotFound);
341
342 if let Some(tracker) = message_tracker.write().get_mut(&message_id) {
343 tracker.status = DeliveryStatus::Failed;
344 tracker.failure_reason = Some("Agent not registered".to_string());
345 }
346
347 tracing::warn!(
348 "Message {} failed to deliver: agent {:?} not registered",
349 message_id,
350 recipient
351 );
352 }
353 }
354 CommunicationEvent::TopicPublished { topic, message } => {
355 let subscribers = subscriptions
356 .read()
357 .get(&topic)
358 .cloned()
359 .unwrap_or_default();
360 let subscriber_count = subscribers.len();
361
362 for subscriber in &subscribers {
363 let mut subscriber_message = message.clone();
364 subscriber_message.recipient = Some(*subscriber);
365 subscriber_message.id = MessageId::new();
366
367 Box::pin(Self::process_communication_event(
369 CommunicationEvent::MessageSent {
370 message: subscriber_message,
371 },
372 message_queues,
373 subscriptions,
374 message_tracker,
375 dead_letter_queue,
376 pending_requests,
377 config,
378 ))
379 .await;
380 }
381
382 tracing::debug!(
383 "Published message to topic {} for {} subscribers",
384 topic,
385 subscriber_count
386 );
387 }
388 CommunicationEvent::AgentRegistered { agent_id } => {
389 message_queues.write().insert(agent_id, MessageQueue::new());
390 tracing::info!("Registered agent {} for communication", agent_id);
391 }
392 CommunicationEvent::AgentUnregistered { agent_id } => {
393 message_queues.write().remove(&agent_id);
394
395 let mut subs = subscriptions.write();
397 for subscribers in subs.values_mut() {
398 subscribers.retain(|&id| id != agent_id);
399 }
400
401 tracing::info!("Unregistered agent {} from communication", agent_id);
402 }
403 }
404 }
405
406 async fn cleanup_expired_messages(
408 message_queues: &Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
409 message_tracker: &Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
410 dead_letter_queue: &Arc<RwLock<DeadLetterQueue>>,
411 message_ttl: Duration,
412 ) {
413 let now = SystemTime::now();
414 let mut expired_messages = Vec::new();
415
416 {
418 let mut queues = message_queues.write();
419 let mut stale_queues = 0;
420 for queue in queues.values_mut() {
421 let expired = queue.remove_expired_messages(now, message_ttl);
422 expired_messages.extend(expired);
423
424 if queue.is_stale(message_ttl * 3) {
426 stale_queues += 1;
427 }
428 }
429
430 if stale_queues > 0 {
431 tracing::debug!("Found {} stale message queues", stale_queues);
432 }
433 }
434
435 {
437 let mut dlq = dead_letter_queue.write();
438 for message in expired_messages {
439 dlq.add_message(message.clone(), DeadLetterReason::Expired);
440
441 if let Some(tracker) = message_tracker.write().get_mut(&message.id) {
443 tracker.status = DeliveryStatus::Failed;
444 tracker.failure_reason = Some("Message expired".to_string());
445 }
446 }
447 }
448
449 {
451 let mut tracker = message_tracker.write();
452 let mut retry_candidates = Vec::new();
453
454 tracker.retain(|message_id, t| {
455 let age = t.get_age();
456 if age < message_ttl * 2 {
457 if t.should_retry(message_ttl) {
459 retry_candidates.push(*message_id);
460
461 let msg = t.get_message();
463 tracing::debug!(
464 "Message {} eligible for retry: size={} bytes, age={:?}s, sender={}",
465 message_id,
466 t.get_message_size(),
467 t.get_age().as_secs(),
468 msg.sender
469 );
470 }
471 true
472 } else {
473 false
474 }
475 });
476
477 if !retry_candidates.is_empty() {
479 tracing::debug!(
480 "Found {} messages eligible for retry",
481 retry_candidates.len()
482 );
483 }
484 }
485 }
486
487 fn send_event(&self, event: CommunicationEvent) -> Result<(), CommunicationError> {
489 self.event_sender
490 .send(event)
491 .map_err(|_| CommunicationError::EventProcessingFailed {
492 reason: "Failed to send communication event".into(),
493 })
494 }
495
496 fn generate_nonce() -> Vec<u8> {
498 use aes_gcm::{aead::AeadCore, Aes256Gcm};
499 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
500 nonce.to_vec()
501 }
502
503 fn sign_message_data(&self, data: &[u8]) -> MessageSignature {
505 use ed25519_dalek::Signer;
506
507 let signature = self.signing_key.sign(data);
508 MessageSignature {
509 signature: signature.to_bytes().to_vec(),
510 algorithm: SignatureAlgorithm::Ed25519,
511 public_key: self.verifying_key.to_bytes().to_vec(),
512 }
513 }
514
515 fn verify_message_signature(&self, message: &SecureMessage) -> Result<(), CommunicationError> {
523 use ed25519_dalek::{Signature, Verifier};
524
525 if !matches!(message.signature.algorithm, SignatureAlgorithm::Ed25519) {
526 return Err(CommunicationError::SignatureInvalid {
527 message_id: message.id,
528 reason: format!(
529 "unsupported signature algorithm {:?}",
530 message.signature.algorithm
531 )
532 .into_boxed_str(),
533 });
534 }
535
536 let sig_bytes: &[u8; 64] =
537 message
538 .signature
539 .signature
540 .as_slice()
541 .try_into()
542 .map_err(|_| CommunicationError::SignatureInvalid {
543 message_id: message.id,
544 reason: format!(
545 "malformed signature length: expected 64, got {}",
546 message.signature.signature.len()
547 )
548 .into_boxed_str(),
549 })?;
550 let signature = Signature::from_bytes(sig_bytes);
551 let data = [
552 message.payload.data.as_ref(),
553 message.payload.nonce.as_slice(),
554 ]
555 .concat();
556 self.verifying_key.verify(&data, &signature).map_err(|e| {
557 CommunicationError::SignatureInvalid {
558 message_id: message.id,
559 reason: format!("verification failed: {}", e).into_boxed_str(),
560 }
561 })
562 }
563
564 fn create_secure_request_message(
566 &self,
567 target_agent: AgentId,
568 request_id: RequestId,
569 request_payload: bytes::Bytes,
570 timeout_duration: Duration,
571 ) -> Result<SecureMessage, CommunicationError> {
572 Ok(self.create_internal_message(
573 self.system_agent_id,
574 target_agent,
575 request_payload,
576 MessageType::Request(request_id),
577 timeout_duration,
578 ))
579 }
580}
581
582#[async_trait]
583impl CommunicationBus for DefaultCommunicationBus {
584 async fn send_message(&self, message: SecureMessage) -> Result<MessageId, CommunicationError> {
585 if !*self.is_running.read() {
586 return Err(CommunicationError::ShuttingDown);
587 }
588
589 if message.payload.data.len() > self.config.max_message_size {
591 return Err(CommunicationError::MessageTooLarge {
592 size: message.payload.data.len(),
593 max_size: self.config.max_message_size,
594 });
595 }
596
597 self.verify_message_signature(&message)?;
602
603 let message_id = message.id;
604
605 self.send_event(CommunicationEvent::MessageSent { message })?;
606
607 Ok(message_id)
608 }
609
610 async fn receive_messages(
611 &self,
612 agent_id: AgentId,
613 ) -> Result<Vec<SecureMessage>, CommunicationError> {
614 let mut queues = self.message_queues.write();
615 if let Some(queue) = queues.get_mut(&agent_id) {
616 Ok(queue.drain_messages())
617 } else {
618 Err(CommunicationError::AgentNotRegistered { agent_id })
619 }
620 }
621
622 async fn subscribe(&self, agent_id: AgentId, topic: String) -> Result<(), CommunicationError> {
623 let mut subscriptions = self.subscriptions.write();
624 subscriptions
625 .entry(topic.clone())
626 .or_default()
627 .push(agent_id);
628
629 tracing::info!("Agent {} subscribed to topic {}", agent_id, topic);
630 Ok(())
631 }
632
633 async fn unsubscribe(
634 &self,
635 agent_id: AgentId,
636 topic: String,
637 ) -> Result<(), CommunicationError> {
638 let mut subscriptions = self.subscriptions.write();
639 if let Some(subscribers) = subscriptions.get_mut(&topic) {
640 subscribers.retain(|&id| id != agent_id);
641 if subscribers.is_empty() {
642 subscriptions.remove(&topic);
643 }
644 }
645
646 tracing::info!("Agent {} unsubscribed from topic {}", agent_id, topic);
647 Ok(())
648 }
649
650 async fn publish(
651 &self,
652 topic: String,
653 message: SecureMessage,
654 ) -> Result<(), CommunicationError> {
655 if !*self.is_running.read() {
656 return Err(CommunicationError::ShuttingDown);
657 }
658
659 if message.payload.data.len() > self.config.max_message_size {
663 return Err(CommunicationError::MessageTooLarge {
664 size: message.payload.data.len(),
665 max_size: self.config.max_message_size,
666 });
667 }
668
669 self.verify_message_signature(&message)?;
672
673 self.send_event(CommunicationEvent::TopicPublished { topic, message })?;
674 Ok(())
675 }
676
677 async fn get_delivery_status(
678 &self,
679 message_id: MessageId,
680 ) -> Result<DeliveryStatus, CommunicationError> {
681 self.message_tracker
682 .read()
683 .get(&message_id)
684 .map(|tracker| tracker.status.clone())
685 .ok_or(CommunicationError::MessageNotFound { message_id })
686 }
687
688 async fn register_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError> {
689 self.send_event(CommunicationEvent::AgentRegistered { agent_id })?;
690 Ok(())
691 }
692
693 async fn unregister_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError> {
694 self.send_event(CommunicationEvent::AgentUnregistered { agent_id })?;
695 Ok(())
696 }
697
698 async fn request(
699 &self,
700 target_agent: AgentId,
701 request_payload: bytes::Bytes,
702 timeout_duration: Duration,
703 ) -> Result<bytes::Bytes, CommunicationError> {
704 if !*self.is_running.read() {
705 return Err(CommunicationError::ShuttingDown);
706 }
707
708 let request_id = RequestId::new();
710 let (response_sender, response_receiver) = oneshot::channel();
711
712 self.pending_requests
714 .write()
715 .insert(request_id, response_sender);
716
717 let request_message = self.create_secure_request_message(
719 target_agent,
720 request_id,
721 request_payload,
722 timeout_duration,
723 )?;
724
725 self.send_message(request_message).await?;
727
728 match timeout(timeout_duration, response_receiver).await {
730 Ok(Ok(response_payload)) => Ok(response_payload),
731 Ok(Err(_)) => {
732 self.pending_requests.write().remove(&request_id);
734 Err(CommunicationError::RequestCancelled { request_id })
735 }
736 Err(_) => {
737 self.pending_requests.write().remove(&request_id);
739 Err(CommunicationError::RequestTimeout {
740 request_id,
741 timeout: timeout_duration,
742 })
743 }
744 }
745 }
746
747 async fn shutdown(&self) -> Result<(), CommunicationError> {
748 tracing::info!("Shutting down communication bus");
749
750 *self.is_running.write() = false;
751 self.shutdown_notify.notify_waiters();
752
753 let agent_ids: Vec<AgentId> = self.message_queues.read().keys().copied().collect();
755
756 for agent_id in agent_ids {
757 if let Err(e) = self.unregister_agent(agent_id).await {
758 tracing::error!(
759 "Failed to unregister agent {} during shutdown: {}",
760 agent_id,
761 e
762 );
763 }
764 }
765
766 Ok(())
767 }
768
769 async fn check_health(&self) -> Result<ComponentHealth, CommunicationError> {
770 let is_running = *self.is_running.read();
771 if !is_running {
772 return Ok(ComponentHealth::unhealthy(
773 "Communication bus is shut down".to_string(),
774 ));
775 }
776
777 let queue_count = self.message_queues.read().len();
778 let topic_count = self.subscriptions.read().len();
779 let tracker_count = self.message_tracker.read().len();
780 let pending_requests = self.pending_requests.read().len();
781
782 let mut total_queued_messages = 0;
784 let mut full_queues = 0;
785
786 {
787 let queues = self.message_queues.read();
788 for queue in queues.values() {
789 total_queued_messages += queue.messages.len();
790 if queue.messages.len() >= self.config.max_queue_size * 9 / 10 {
791 full_queues += 1;
793 }
794 }
795 }
796
797 let dead_letter_count = self.dead_letter_queue.read().messages.len();
798
799 let status = if dead_letter_count > 100 {
800 ComponentHealth::degraded(format!(
801 "High dead letter queue: {} messages",
802 dead_letter_count
803 ))
804 } else if full_queues > 0 {
805 ComponentHealth::degraded(format!("{} message queues near capacity", full_queues))
806 } else if pending_requests > 50 {
807 ComponentHealth::degraded(format!("Many pending requests: {}", pending_requests))
808 } else {
809 ComponentHealth::healthy(Some(format!(
810 "{} agents registered, {} active topics",
811 queue_count, topic_count
812 )))
813 };
814
815 Ok(status
816 .with_metric("registered_agents".to_string(), queue_count.to_string())
817 .with_metric("active_topics".to_string(), topic_count.to_string())
818 .with_metric(
819 "queued_messages".to_string(),
820 total_queued_messages.to_string(),
821 )
822 .with_metric("pending_requests".to_string(), pending_requests.to_string())
823 .with_metric("dead_letters".to_string(), dead_letter_count.to_string())
824 .with_metric("message_trackers".to_string(), tracker_count.to_string()))
825 }
826
827 fn create_internal_message(
828 &self,
829 sender: AgentId,
830 recipient: AgentId,
831 payload_data: bytes::Bytes,
832 message_type: MessageType,
833 ttl: Duration,
834 ) -> SecureMessage {
835 let nonce = Self::generate_nonce();
836
837 let payload = EncryptedPayload {
838 data: payload_data,
839 nonce,
840 encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
841 };
842
843 let message_data_to_sign = [payload.data.as_ref(), &payload.nonce].concat();
845 let signature = self.sign_message_data(&message_data_to_sign);
846
847 SecureMessage {
848 id: MessageId::new(),
849 sender,
850 recipient: Some(recipient),
851 topic: None,
852 message_type,
853 payload,
854 signature,
855 ttl,
856 timestamp: SystemTime::now(),
857 }
858 }
859}
860
861#[derive(Debug, Clone)]
863struct MessageQueue {
864 messages: Vec<SecureMessage>,
865 created_at: SystemTime,
866}
867
868impl MessageQueue {
869 fn new() -> Self {
870 Self {
871 messages: Vec::new(),
872 created_at: SystemTime::now(),
873 }
874 }
875
876 fn can_accept_message(&self, config: &CommunicationConfig) -> bool {
877 self.messages.len() < config.max_queue_size
878 }
879
880 fn add_message(&mut self, message: SecureMessage) {
881 self.messages.push(message);
882 }
883
884 fn drain_messages(&mut self) -> Vec<SecureMessage> {
885 std::mem::take(&mut self.messages)
886 }
887
888 fn remove_expired_messages(&mut self, now: SystemTime, ttl: Duration) -> Vec<SecureMessage> {
889 let mut expired = Vec::new();
890
891 self.messages.retain(|message| {
892 let age = now.duration_since(message.timestamp).unwrap_or_default();
893 if age > ttl {
894 expired.push(message.clone());
895 false
896 } else {
897 true
898 }
899 });
900
901 expired
902 }
903
904 fn get_queue_age(&self) -> Duration {
905 SystemTime::now()
906 .duration_since(self.created_at)
907 .unwrap_or_default()
908 }
909
910 fn is_stale(&self, max_age: Duration) -> bool {
911 self.get_queue_age() > max_age
912 }
913}
914
915#[derive(Debug, Clone)]
917struct MessageTracker {
918 message: SecureMessage,
919 status: DeliveryStatus,
920 created_at: SystemTime,
921 delivered_at: Option<SystemTime>,
922 failure_reason: Option<String>,
923}
924
925impl MessageTracker {
926 fn new(message: SecureMessage) -> Self {
927 Self {
928 message,
929 status: DeliveryStatus::Pending,
930 created_at: SystemTime::now(),
931 delivered_at: None,
932 failure_reason: None,
933 }
934 }
935
936 fn get_message(&self) -> &SecureMessage {
938 &self.message
939 }
940
941 fn get_message_size(&self) -> usize {
943 self.message.payload.data.len()
944 }
945
946 fn get_age(&self) -> Duration {
948 SystemTime::now()
949 .duration_since(self.created_at)
950 .unwrap_or_default()
951 }
952
953 fn should_retry(&self, max_age: Duration) -> bool {
955 matches!(self.status, DeliveryStatus::Failed) && self.get_age() < max_age
956 }
957}
958
/// Delivery state recorded for a tracked message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeliveryStatus {
    /// Accepted by the bus but not yet placed in a recipient queue.
    Pending,
    /// Placed in the recipient's queue.
    Delivered,
    /// Could not be queued, or was removed from a queue after expiring.
    Failed,
    /// Reserved for expired messages. NOTE(review): nothing in this file
    /// assigns this variant; expiry is currently recorded as `Failed`.
    Expired,
}
967
968#[derive(Debug, Clone)]
970enum CommunicationEvent {
971 MessageSent {
972 message: SecureMessage,
973 },
974 TopicPublished {
975 topic: String,
976 message: SecureMessage,
977 },
978 AgentRegistered {
979 agent_id: AgentId,
980 },
981 AgentUnregistered {
982 agent_id: AgentId,
983 },
984}
985
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{EncryptedPayload, MessageType};

    /// Pause that lets the background event loop drain queued events.
    const SETTLE: Duration = Duration::from_millis(50);

    /// Builds a bus with the default configuration.
    async fn default_bus() -> DefaultCommunicationBus {
        DefaultCommunicationBus::new(CommunicationConfig::default())
            .await
            .unwrap()
    }

    /// Waits for the event loop to process what has been queued so far.
    async fn settle() {
        tokio::time::sleep(SETTLE).await;
    }

    /// Builds a well-formed message signed by a throwaway key, i.e. by a key
    /// no bus under test owns.
    fn create_test_message(sender: AgentId, recipient: AgentId) -> SecureMessage {
        use crate::types::RequestId;
        use aes_gcm::{aead::AeadCore, Aes256Gcm};
        use ed25519_dalek::Signer;

        let mut seed = [0u8; 32];
        OsRng.fill_bytes(&mut seed);
        let foreign_key = SigningKey::from_bytes(&seed);
        let foreign_public = foreign_key.verifying_key();

        let nonce = Aes256Gcm::generate_nonce(&mut OsRng).to_vec();
        let data: bytes::Bytes = b"test message".to_vec().into();

        let signed_bytes = [data.as_ref(), &nonce].concat();
        let signature = foreign_key.sign(&signed_bytes);

        SecureMessage {
            id: MessageId::new(),
            sender,
            recipient: Some(recipient),
            message_type: MessageType::Request(RequestId::new()),
            topic: Some("test".to_string()),
            payload: EncryptedPayload {
                data,
                nonce,
                encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
            },
            signature: MessageSignature {
                signature: signature.to_bytes().to_vec(),
                algorithm: SignatureAlgorithm::Ed25519,
                public_key: foreign_public.to_bytes().to_vec(),
            },
            ttl: Duration::from_secs(3600),
            timestamp: SystemTime::now(),
        }
    }

    /// A registered agent can read its (empty) inbox.
    #[tokio::test]
    async fn test_agent_registration() {
        let bus = default_bus().await;
        let agent_id = AgentId::new();

        let registration = bus.register_agent(agent_id).await;
        assert!(registration.is_ok());

        settle().await;

        let inbox = bus.receive_messages(agent_id).await;
        assert!(inbox.is_ok());
    }

    /// A bus-signed message reaches the recipient and is tracked as delivered.
    #[tokio::test]
    async fn test_message_sending() {
        let bus = default_bus().await;
        let sender = AgentId::new();
        let recipient = AgentId::new();

        bus.register_agent(sender).await.unwrap();
        bus.register_agent(recipient).await.unwrap();

        settle().await;

        let outgoing = bus.create_internal_message(
            sender,
            recipient,
            bytes::Bytes::from_static(b"test message"),
            MessageType::Request(crate::types::RequestId::new()),
            Duration::from_secs(60),
        );
        let message_id = bus.send_message(outgoing).await.unwrap();

        settle().await;

        let status = bus.get_delivery_status(message_id).await.unwrap();
        assert_eq!(status, DeliveryStatus::Delivered);

        let inbox = bus.receive_messages(recipient).await.unwrap();
        assert_eq!(inbox.len(), 1);
        assert_eq!(inbox[0].sender, sender);
    }

    /// Every subscriber of a topic receives its own copy of a publication.
    #[tokio::test]
    async fn test_topic_subscription() {
        let bus = default_bus().await;
        let publisher = AgentId::new();
        let subscriber1 = AgentId::new();
        let subscriber2 = AgentId::new();

        for agent in [publisher, subscriber1, subscriber2].iter() {
            bus.register_agent(*agent).await.unwrap();
        }

        let topic = "test_topic".to_string();
        bus.subscribe(subscriber1, topic.clone()).await.unwrap();
        bus.subscribe(subscriber2, topic.clone()).await.unwrap();

        settle().await;

        // The recipient is a placeholder; fan-out overwrites it per subscriber.
        let publication = bus.create_internal_message(
            publisher,
            AgentId::new(),
            bytes::Bytes::from_static(b"test message"),
            MessageType::Publish(topic.clone()),
            Duration::from_secs(60),
        );
        bus.publish(topic, publication).await.unwrap();

        settle().await;

        let inbox1 = bus.receive_messages(subscriber1).await.unwrap();
        let inbox2 = bus.receive_messages(subscriber2).await.unwrap();

        assert_eq!(inbox1.len(), 1);
        assert_eq!(inbox2.len(), 1);
        assert_eq!(inbox1[0].sender, publisher);
        assert_eq!(inbox2[0].sender, publisher);
    }

    /// Payloads above `max_message_size` are rejected before signature checks.
    #[tokio::test]
    async fn test_message_size_limit() {
        let config = CommunicationConfig {
            max_message_size: 100,
            ..Default::default()
        };

        let bus = DefaultCommunicationBus::new(config).await.unwrap();
        let sender = AgentId::new();
        let recipient = AgentId::new();

        bus.register_agent(sender).await.unwrap();
        bus.register_agent(recipient).await.unwrap();

        let mut oversized = bus.create_internal_message(
            sender,
            recipient,
            bytes::Bytes::from_static(b"placeholder"),
            MessageType::Request(crate::types::RequestId::new()),
            Duration::from_secs(60),
        );
        // Swap in a payload twice the configured limit.
        oversized.payload.data = vec![0u8; 200].into();

        let result = bus.send_message(oversized).await;
        assert!(result.is_err());

        match result {
            Err(CommunicationError::MessageTooLarge { size, max_size }) => {
                assert_eq!(size, 200);
                assert_eq!(max_size, 100);
            }
            _ => panic!("Expected MessageTooLarge error"),
        }
    }

    /// Messages signed by a key the bus does not own are refused on both the
    /// send and the publish path.
    #[tokio::test]
    async fn test_send_rejects_foreign_signature() {
        let bus = default_bus().await;
        let sender = AgentId::new();
        let recipient = AgentId::new();
        bus.register_agent(sender).await.unwrap();
        bus.register_agent(recipient).await.unwrap();

        let foreign = create_test_message(sender, recipient);

        let send_err = bus.send_message(foreign.clone()).await.unwrap_err();
        assert!(
            matches!(send_err, CommunicationError::SignatureInvalid { .. }),
            "expected SignatureInvalid, got {:?}",
            send_err
        );

        let publish_err = bus.publish("t".to_string(), foreign).await.unwrap_err();
        assert!(matches!(
            publish_err,
            CommunicationError::SignatureInvalid { .. }
        ));
    }

    /// A message signed by one bus is rejected by another until it is
    /// re-created (and thereby re-signed) on the receiving bus.
    #[tokio::test]
    async fn test_cross_bus_message_requires_rewrap() {
        let bus_a = default_bus().await;
        let bus_b = default_bus().await;

        let sender = AgentId::new();
        let recipient = AgentId::new();
        bus_b.register_agent(recipient).await.unwrap();

        let a_signed = bus_a.create_internal_message(
            sender,
            recipient,
            bytes::Bytes::from_static(b"hello"),
            MessageType::Direct(recipient),
            Duration::from_secs(60),
        );

        let rejection = bus_b.send_message(a_signed.clone()).await.unwrap_err();
        assert!(
            matches!(rejection, CommunicationError::SignatureInvalid { .. }),
            "cross-bus message must be refused without re-wrap, got {:?}",
            rejection
        );

        let b_signed = bus_b.create_internal_message(
            a_signed.sender,
            recipient,
            a_signed.payload.data.clone(),
            MessageType::Direct(recipient),
            Duration::from_secs(60),
        );
        bus_b
            .send_message(b_signed)
            .await
            .expect("re-wrapped message must be accepted");
    }

    /// Downgrading the algorithm tag to `None` does not bypass verification.
    #[tokio::test]
    async fn test_send_rejects_none_signature_algorithm() {
        let bus = default_bus().await;
        let sender = AgentId::new();
        let recipient = AgentId::new();
        bus.register_agent(sender).await.unwrap();
        bus.register_agent(recipient).await.unwrap();

        let mut unsigned = bus.create_internal_message(
            sender,
            recipient,
            bytes::Bytes::from_static(b"x"),
            MessageType::Direct(recipient),
            Duration::from_secs(60),
        );
        unsigned.signature.algorithm = SignatureAlgorithm::None;
        unsigned.signature.signature.clear();

        let rejection = bus.send_message(unsigned).await.unwrap_err();
        assert!(matches!(
            rejection,
            CommunicationError::SignatureInvalid { .. }
        ));
    }

    /// After unregistration the agent's inbox no longer exists.
    #[tokio::test]
    async fn test_agent_unregistration() {
        let bus = default_bus().await;
        let agent_id = AgentId::new();

        bus.register_agent(agent_id).await.unwrap();
        settle().await;

        bus.unregister_agent(agent_id).await.unwrap();
        settle().await;

        let inbox = bus.receive_messages(agent_id).await;
        assert!(inbox.is_err());
    }

    /// A request nobody answers fails with `RequestTimeout` carrying the
    /// requested duration.
    #[tokio::test]
    async fn test_request_response_timeout() {
        let bus = default_bus().await;
        let target_agent = AgentId::new();

        bus.register_agent(target_agent).await.unwrap();
        settle().await;

        let request_payload = bytes::Bytes::from("test request");
        let wait = Duration::from_millis(100);

        let result = bus.request(target_agent, request_payload, wait).await;
        assert!(result.is_err());

        match result {
            Err(CommunicationError::RequestTimeout {
                request_id: _,
                timeout: actual_timeout,
            }) => assert_eq!(actual_timeout, wait),
            _ => panic!("Expected RequestTimeout error"),
        }
    }

    /// A `Response` carrying the request id completes the pending `request`.
    #[tokio::test]
    async fn test_request_response_success() {
        let bus = default_bus().await;
        let requester = AgentId::new();
        let responder = AgentId::new();

        bus.register_agent(requester).await.unwrap();
        bus.register_agent(responder).await.unwrap();
        settle().await;

        let request_payload = bytes::Bytes::from("test request");
        let response_payload = bytes::Bytes::from("test response");

        // Share the bus between the requesting task and the responder below.
        let shared_bus = Arc::new(bus);
        let requesting_bus = shared_bus.clone();
        let pending = tokio::spawn(async move {
            requesting_bus
                .request(responder, request_payload, Duration::from_secs(5))
                .await
        });

        // Give the request time to land in the responder's inbox.
        tokio::time::sleep(Duration::from_millis(100)).await;

        let inbox = shared_bus.receive_messages(responder).await.unwrap();
        assert_eq!(inbox.len(), 1);
        assert!(matches!(inbox[0].message_type, MessageType::Request(_)));

        if let MessageType::Request(request_id) = &inbox[0].message_type {
            let reply = shared_bus.create_internal_message(
                responder,
                requester,
                response_payload.clone(),
                MessageType::Response(*request_id),
                Duration::from_secs(3600),
            );

            shared_bus.send_message(reply).await.unwrap();
        }

        let outcome = pending.await.unwrap();
        assert!(outcome.is_ok());
        assert_eq!(outcome.unwrap(), response_payload);
    }
}