1use async_trait::async_trait;
6use parking_lot::RwLock;
7use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::sync::{mpsc, oneshot, Notify};
11use tokio::time::{interval, timeout};
12
13use crate::crypto::Aes256GcmCrypto;
14use crate::types::*;
15use ed25519_dalek::{SigningKey, VerifyingKey};
16use rand::rngs::OsRng;
17use rand::RngCore;
18
19#[async_trait]
21pub trait CommunicationBus {
22 async fn send_message(&self, message: SecureMessage) -> Result<MessageId, CommunicationError>;
24
25 async fn receive_messages(
27 &self,
28 agent_id: AgentId,
29 ) -> Result<Vec<SecureMessage>, CommunicationError>;
30
31 async fn subscribe(&self, agent_id: AgentId, topic: String) -> Result<(), CommunicationError>;
33
34 async fn unsubscribe(&self, agent_id: AgentId, topic: String)
36 -> Result<(), CommunicationError>;
37
38 async fn publish(
40 &self,
41 topic: String,
42 message: SecureMessage,
43 ) -> Result<(), CommunicationError>;
44
45 async fn get_delivery_status(
47 &self,
48 message_id: MessageId,
49 ) -> Result<DeliveryStatus, CommunicationError>;
50
51 async fn register_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError>;
53
54 async fn unregister_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError>;
56
57 async fn request(
59 &self,
60 target_agent: AgentId,
61 request_payload: bytes::Bytes,
62 timeout_duration: Duration,
63 ) -> Result<bytes::Bytes, CommunicationError>;
64
65 async fn shutdown(&self) -> Result<(), CommunicationError>;
67
68 async fn check_health(&self) -> Result<ComponentHealth, CommunicationError>;
70}
71
/// Tuning parameters for [`DefaultCommunicationBus`].
#[derive(Debug, Clone)]
pub struct CommunicationConfig {
    /// Largest payload (in bytes) accepted by `send_message`; larger payloads are rejected.
    pub max_message_size: usize,
    /// Age after which a queued, undrained message is dead-lettered by the periodic cleanup.
    pub message_ttl: Duration,
    /// Maximum number of undrained messages an agent's queue may hold.
    pub max_queue_size: usize,
    /// Delivery deadline. NOTE(review): not read by any code visible in this module.
    pub delivery_timeout: Duration,
    /// Retry budget. NOTE(review): not read by any code visible in this module.
    pub retry_attempts: u32,
    /// Encryption toggle. NOTE(review): not read by any code visible in this module.
    pub enable_encryption: bool,
    /// Compression toggle. NOTE(review): not read by any code visible in this module.
    pub enable_compression: bool,
    /// Capacity handed to the dead-letter queue constructor.
    pub dead_letter_queue_size: usize,
}

impl Default for CommunicationConfig {
    /// 1 MiB messages, 1 h TTL, 10 000-message queues, 30 s delivery timeout,
    /// 3 retries, encryption and compression enabled, 1 000 dead letters.
    fn default() -> Self {
        const ONE_MIB: usize = 1 << 20;
        const ONE_HOUR: Duration = Duration::from_secs(60 * 60);

        CommunicationConfig {
            max_message_size: ONE_MIB,
            message_ttl: ONE_HOUR,
            max_queue_size: 10_000,
            delivery_timeout: Duration::from_secs(30),
            retry_attempts: 3,
            enable_encryption: true,
            enable_compression: true,
            dead_letter_queue_size: 1_000,
        }
    }
}
99
100pub struct DefaultCommunicationBus {
102 config: CommunicationConfig,
103 message_queues: Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
104 subscriptions: Arc<RwLock<HashMap<String, Vec<AgentId>>>>,
105 message_tracker: Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
106 dead_letter_queue: Arc<RwLock<DeadLetterQueue>>,
107 pending_requests: Arc<RwLock<HashMap<RequestId, oneshot::Sender<bytes::Bytes>>>>,
108 event_sender: mpsc::UnboundedSender<CommunicationEvent>,
109 shutdown_notify: Arc<Notify>,
110 is_running: Arc<RwLock<bool>>,
111 signing_key: SigningKey,
112 verifying_key: VerifyingKey,
113 system_agent_id: AgentId,
114 #[allow(dead_code)]
115 crypto: Aes256GcmCrypto,
116}
117
118impl DefaultCommunicationBus {
119 pub async fn new(config: CommunicationConfig) -> Result<Self, CommunicationError> {
121 let message_queues = Arc::new(RwLock::new(HashMap::new()));
122 let subscriptions = Arc::new(RwLock::new(HashMap::new()));
123 let message_tracker = Arc::new(RwLock::new(HashMap::new()));
124 let dead_letter_queue = Arc::new(RwLock::new(DeadLetterQueue::new(
125 config.dead_letter_queue_size,
126 )));
127 let pending_requests = Arc::new(RwLock::new(HashMap::new()));
128 let (event_sender, event_receiver) = mpsc::unbounded_channel();
129 let shutdown_notify = Arc::new(Notify::new());
130 let is_running = Arc::new(RwLock::new(true));
131
132 let mut secret_bytes = [0u8; 32];
134 OsRng.fill_bytes(&mut secret_bytes);
135 let signing_key = SigningKey::from_bytes(&secret_bytes);
136 let verifying_key = signing_key.verifying_key();
137
138 let system_agent_id = AgentId::new();
140
141 let crypto = Aes256GcmCrypto::new();
142
143 let bus = Self {
144 config,
145 message_queues,
146 subscriptions,
147 message_tracker,
148 dead_letter_queue,
149 pending_requests,
150 event_sender,
151 shutdown_notify,
152 is_running,
153 signing_key,
154 verifying_key,
155 system_agent_id,
156 crypto,
157 };
158
159 bus.start_event_loop(event_receiver).await;
161 bus.start_cleanup_loop().await;
162
163 Ok(bus)
164 }
165
166 async fn start_event_loop(
168 &self,
169 mut event_receiver: mpsc::UnboundedReceiver<CommunicationEvent>,
170 ) {
171 let message_queues = self.message_queues.clone();
172 let subscriptions = self.subscriptions.clone();
173 let message_tracker = self.message_tracker.clone();
174 let dead_letter_queue = self.dead_letter_queue.clone();
175 let pending_requests = self.pending_requests.clone();
176 let shutdown_notify = self.shutdown_notify.clone();
177 let config = self.config.clone();
178
179 tokio::spawn(async move {
180 loop {
181 tokio::select! {
182 event = event_receiver.recv() => {
183 if let Some(event) = event {
184 Self::process_communication_event(
185 event,
186 &message_queues,
187 &subscriptions,
188 &message_tracker,
189 &dead_letter_queue,
190 &pending_requests,
191 &config,
192 ).await;
193 } else {
194 break;
195 }
196 }
197 _ = shutdown_notify.notified() => {
198 break;
199 }
200 }
201 }
202 });
203 }
204
205 async fn start_cleanup_loop(&self) {
207 let message_queues = self.message_queues.clone();
208 let message_tracker = self.message_tracker.clone();
209 let dead_letter_queue = self.dead_letter_queue.clone();
210 let shutdown_notify = self.shutdown_notify.clone();
211 let is_running = self.is_running.clone();
212 let message_ttl = self.config.message_ttl;
213
214 tokio::spawn(async move {
215 let mut interval = interval(Duration::from_secs(60)); loop {
218 tokio::select! {
219 _ = interval.tick() => {
220 if !*is_running.read() {
221 break;
222 }
223
224 Self::cleanup_expired_messages(&message_queues, &message_tracker, &dead_letter_queue, message_ttl).await;
225 }
226 _ = shutdown_notify.notified() => {
227 break;
228 }
229 }
230 }
231 });
232 }
233
234 async fn process_communication_event(
236 event: CommunicationEvent,
237 message_queues: &Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
238 subscriptions: &Arc<RwLock<HashMap<String, Vec<AgentId>>>>,
239 message_tracker: &Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
240 dead_letter_queue: &Arc<RwLock<DeadLetterQueue>>,
241 pending_requests: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<bytes::Bytes>>>>,
242 config: &CommunicationConfig,
243 ) {
244 match event {
245 CommunicationEvent::MessageSent { message } => {
246 let recipient = message.recipient;
247 let message_id = message.id;
248
249 if let MessageType::Response(request_id) = &message.message_type {
251 if let Some(sender) = pending_requests.write().remove(request_id) {
252 let _ = sender.send(message.payload.data.clone());
254 tracing::debug!(
255 "Response {} sent for request {:?}",
256 message_id,
257 request_id
258 );
259 return;
260 }
261 }
262
263 message_tracker
265 .write()
266 .insert(message_id, MessageTracker::new(message.clone()));
267
268 let mut queues = message_queues.write();
270 if let Some(recipient_id) = recipient {
271 if let Some(queue) = queues.get_mut(&recipient_id) {
272 if queue.can_accept_message(config) {
273 queue.add_message(message);
274
275 if let Some(tracker) = message_tracker.write().get_mut(&message_id) {
277 tracker.status = DeliveryStatus::Delivered;
278 tracker.delivered_at = Some(SystemTime::now());
279 }
280
281 tracing::debug!(
282 "Message {} delivered to agent {}",
283 message_id,
284 recipient_id
285 );
286 } else {
287 dead_letter_queue
289 .write()
290 .add_message(message, DeadLetterReason::QueueFull);
291
292 if let Some(tracker) = message_tracker.write().get_mut(&message_id) {
293 tracker.status = DeliveryStatus::Failed;
294 tracker.failure_reason = Some("Queue full".to_string());
295 }
296
297 tracing::warn!(
298 "Message {} failed to deliver: queue full for agent {}",
299 message_id,
300 recipient_id
301 );
302 }
303 } else {
304 dead_letter_queue
306 .write()
307 .add_message(message, DeadLetterReason::AgentNotFound);
308
309 if let Some(tracker) = message_tracker.write().get_mut(&message_id) {
310 tracker.status = DeliveryStatus::Failed;
311 tracker.failure_reason = Some("Agent not registered".to_string());
312 }
313
314 tracing::warn!(
315 "Message {} failed to deliver: agent {:?} not registered",
316 message_id,
317 recipient
318 );
319 }
320 } else {
321 dead_letter_queue
323 .write()
324 .add_message(message, DeadLetterReason::AgentNotFound);
325
326 if let Some(tracker) = message_tracker.write().get_mut(&message_id) {
327 tracker.status = DeliveryStatus::Failed;
328 tracker.failure_reason = Some("Agent not registered".to_string());
329 }
330
331 tracing::warn!(
332 "Message {} failed to deliver: agent {:?} not registered",
333 message_id,
334 recipient
335 );
336 }
337 }
338 CommunicationEvent::TopicPublished { topic, message } => {
339 let subscribers = subscriptions
340 .read()
341 .get(&topic)
342 .cloned()
343 .unwrap_or_default();
344 let subscriber_count = subscribers.len();
345
346 for subscriber in &subscribers {
347 let mut subscriber_message = message.clone();
348 subscriber_message.recipient = Some(*subscriber);
349 subscriber_message.id = MessageId::new();
350
351 Box::pin(Self::process_communication_event(
353 CommunicationEvent::MessageSent {
354 message: subscriber_message,
355 },
356 message_queues,
357 subscriptions,
358 message_tracker,
359 dead_letter_queue,
360 pending_requests,
361 config,
362 ))
363 .await;
364 }
365
366 tracing::debug!(
367 "Published message to topic {} for {} subscribers",
368 topic,
369 subscriber_count
370 );
371 }
372 CommunicationEvent::AgentRegistered { agent_id } => {
373 message_queues.write().insert(agent_id, MessageQueue::new());
374 tracing::info!("Registered agent {} for communication", agent_id);
375 }
376 CommunicationEvent::AgentUnregistered { agent_id } => {
377 message_queues.write().remove(&agent_id);
378
379 let mut subs = subscriptions.write();
381 for subscribers in subs.values_mut() {
382 subscribers.retain(|&id| id != agent_id);
383 }
384
385 tracing::info!("Unregistered agent {} from communication", agent_id);
386 }
387 }
388 }
389
390 async fn cleanup_expired_messages(
392 message_queues: &Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
393 message_tracker: &Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
394 dead_letter_queue: &Arc<RwLock<DeadLetterQueue>>,
395 message_ttl: Duration,
396 ) {
397 let now = SystemTime::now();
398 let mut expired_messages = Vec::new();
399
400 {
402 let mut queues = message_queues.write();
403 let mut stale_queues = 0;
404 for queue in queues.values_mut() {
405 let expired = queue.remove_expired_messages(now, message_ttl);
406 expired_messages.extend(expired);
407
408 if queue.is_stale(message_ttl * 3) {
410 stale_queues += 1;
411 }
412 }
413
414 if stale_queues > 0 {
415 tracing::debug!("Found {} stale message queues", stale_queues);
416 }
417 }
418
419 {
421 let mut dlq = dead_letter_queue.write();
422 for message in expired_messages {
423 dlq.add_message(message.clone(), DeadLetterReason::Expired);
424
425 if let Some(tracker) = message_tracker.write().get_mut(&message.id) {
427 tracker.status = DeliveryStatus::Failed;
428 tracker.failure_reason = Some("Message expired".to_string());
429 }
430 }
431 }
432
433 {
435 let mut tracker = message_tracker.write();
436 let mut retry_candidates = Vec::new();
437
438 tracker.retain(|message_id, t| {
439 let age = t.get_age();
440 if age < message_ttl * 2 {
441 if t.should_retry(message_ttl) {
443 retry_candidates.push(*message_id);
444
445 let msg = t.get_message();
447 tracing::debug!(
448 "Message {} eligible for retry: size={} bytes, age={:?}s, sender={}",
449 message_id,
450 t.get_message_size(),
451 t.get_age().as_secs(),
452 msg.sender
453 );
454 }
455 true
456 } else {
457 false
458 }
459 });
460
461 if !retry_candidates.is_empty() {
463 tracing::debug!(
464 "Found {} messages eligible for retry",
465 retry_candidates.len()
466 );
467 }
468 }
469 }
470
471 fn send_event(&self, event: CommunicationEvent) -> Result<(), CommunicationError> {
473 self.event_sender
474 .send(event)
475 .map_err(|_| CommunicationError::EventProcessingFailed {
476 reason: "Failed to send communication event".to_string(),
477 })
478 }
479
480 fn generate_nonce() -> Vec<u8> {
482 use aes_gcm::{aead::AeadCore, Aes256Gcm};
483 let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
484 nonce.to_vec()
485 }
486
487 fn sign_message_data(&self, data: &[u8]) -> MessageSignature {
489 use ed25519_dalek::Signer;
490
491 let signature = self.signing_key.sign(data);
492 MessageSignature {
493 signature: signature.to_bytes().to_vec(),
494 algorithm: SignatureAlgorithm::Ed25519,
495 public_key: self.verifying_key.to_bytes().to_vec(),
496 }
497 }
498
499 fn create_secure_request_message(
501 &self,
502 target_agent: AgentId,
503 request_id: RequestId,
504 request_payload: bytes::Bytes,
505 timeout_duration: Duration,
506 ) -> Result<SecureMessage, CommunicationError> {
507 let nonce = Self::generate_nonce();
509
510 let payload = EncryptedPayload {
512 data: request_payload,
513 nonce,
514 encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
515 };
516
517 let message_data_to_sign = [payload.data.as_ref(), &payload.nonce].concat();
519
520 let signature = self.sign_message_data(&message_data_to_sign);
522
523 Ok(SecureMessage {
524 id: MessageId::new(),
525 sender: self.system_agent_id,
526 recipient: Some(target_agent),
527 topic: None,
528 message_type: MessageType::Request(request_id),
529 payload,
530 signature,
531 ttl: timeout_duration,
532 timestamp: SystemTime::now(),
533 })
534 }
535}
536
537#[async_trait]
538impl CommunicationBus for DefaultCommunicationBus {
539 async fn send_message(&self, message: SecureMessage) -> Result<MessageId, CommunicationError> {
540 if !*self.is_running.read() {
541 return Err(CommunicationError::ShuttingDown);
542 }
543
544 if message.payload.data.len() > self.config.max_message_size {
546 return Err(CommunicationError::MessageTooLarge {
547 size: message.payload.data.len(),
548 max_size: self.config.max_message_size,
549 });
550 }
551
552 let message_id = message.id;
553
554 self.send_event(CommunicationEvent::MessageSent { message })?;
555
556 Ok(message_id)
557 }
558
559 async fn receive_messages(
560 &self,
561 agent_id: AgentId,
562 ) -> Result<Vec<SecureMessage>, CommunicationError> {
563 let mut queues = self.message_queues.write();
564 if let Some(queue) = queues.get_mut(&agent_id) {
565 Ok(queue.drain_messages())
566 } else {
567 Err(CommunicationError::AgentNotRegistered { agent_id })
568 }
569 }
570
571 async fn subscribe(&self, agent_id: AgentId, topic: String) -> Result<(), CommunicationError> {
572 let mut subscriptions = self.subscriptions.write();
573 subscriptions
574 .entry(topic.clone())
575 .or_default()
576 .push(agent_id);
577
578 tracing::info!("Agent {} subscribed to topic {}", agent_id, topic);
579 Ok(())
580 }
581
582 async fn unsubscribe(
583 &self,
584 agent_id: AgentId,
585 topic: String,
586 ) -> Result<(), CommunicationError> {
587 let mut subscriptions = self.subscriptions.write();
588 if let Some(subscribers) = subscriptions.get_mut(&topic) {
589 subscribers.retain(|&id| id != agent_id);
590 if subscribers.is_empty() {
591 subscriptions.remove(&topic);
592 }
593 }
594
595 tracing::info!("Agent {} unsubscribed from topic {}", agent_id, topic);
596 Ok(())
597 }
598
599 async fn publish(
600 &self,
601 topic: String,
602 message: SecureMessage,
603 ) -> Result<(), CommunicationError> {
604 if !*self.is_running.read() {
605 return Err(CommunicationError::ShuttingDown);
606 }
607
608 self.send_event(CommunicationEvent::TopicPublished { topic, message })?;
609 Ok(())
610 }
611
612 async fn get_delivery_status(
613 &self,
614 message_id: MessageId,
615 ) -> Result<DeliveryStatus, CommunicationError> {
616 self.message_tracker
617 .read()
618 .get(&message_id)
619 .map(|tracker| tracker.status.clone())
620 .ok_or(CommunicationError::MessageNotFound { message_id })
621 }
622
623 async fn register_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError> {
624 self.send_event(CommunicationEvent::AgentRegistered { agent_id })?;
625 Ok(())
626 }
627
628 async fn unregister_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError> {
629 self.send_event(CommunicationEvent::AgentUnregistered { agent_id })?;
630 Ok(())
631 }
632
633 async fn request(
634 &self,
635 target_agent: AgentId,
636 request_payload: bytes::Bytes,
637 timeout_duration: Duration,
638 ) -> Result<bytes::Bytes, CommunicationError> {
639 if !*self.is_running.read() {
640 return Err(CommunicationError::ShuttingDown);
641 }
642
643 let request_id = RequestId::new();
645 let (response_sender, response_receiver) = oneshot::channel();
646
647 self.pending_requests
649 .write()
650 .insert(request_id, response_sender);
651
652 let request_message = self.create_secure_request_message(
654 target_agent,
655 request_id,
656 request_payload,
657 timeout_duration,
658 )?;
659
660 self.send_message(request_message).await?;
662
663 match timeout(timeout_duration, response_receiver).await {
665 Ok(Ok(response_payload)) => Ok(response_payload),
666 Ok(Err(_)) => {
667 self.pending_requests.write().remove(&request_id);
669 Err(CommunicationError::RequestCancelled { request_id })
670 }
671 Err(_) => {
672 self.pending_requests.write().remove(&request_id);
674 Err(CommunicationError::RequestTimeout {
675 request_id,
676 timeout: timeout_duration,
677 })
678 }
679 }
680 }
681
682 async fn shutdown(&self) -> Result<(), CommunicationError> {
683 tracing::info!("Shutting down communication bus");
684
685 *self.is_running.write() = false;
686 self.shutdown_notify.notify_waiters();
687
688 let agent_ids: Vec<AgentId> = self.message_queues.read().keys().copied().collect();
690
691 for agent_id in agent_ids {
692 if let Err(e) = self.unregister_agent(agent_id).await {
693 tracing::error!(
694 "Failed to unregister agent {} during shutdown: {}",
695 agent_id,
696 e
697 );
698 }
699 }
700
701 Ok(())
702 }
703
704 async fn check_health(&self) -> Result<ComponentHealth, CommunicationError> {
705 let is_running = *self.is_running.read();
706 if !is_running {
707 return Ok(ComponentHealth::unhealthy(
708 "Communication bus is shut down".to_string(),
709 ));
710 }
711
712 let queue_count = self.message_queues.read().len();
713 let topic_count = self.subscriptions.read().len();
714 let tracker_count = self.message_tracker.read().len();
715 let pending_requests = self.pending_requests.read().len();
716
717 let mut total_queued_messages = 0;
719 let mut full_queues = 0;
720
721 {
722 let queues = self.message_queues.read();
723 for queue in queues.values() {
724 total_queued_messages += queue.messages.len();
725 if queue.messages.len() >= self.config.max_queue_size * 9 / 10 {
726 full_queues += 1;
728 }
729 }
730 }
731
732 let dead_letter_count = self.dead_letter_queue.read().messages.len();
733
734 let status = if dead_letter_count > 100 {
735 ComponentHealth::degraded(format!(
736 "High dead letter queue: {} messages",
737 dead_letter_count
738 ))
739 } else if full_queues > 0 {
740 ComponentHealth::degraded(format!("{} message queues near capacity", full_queues))
741 } else if pending_requests > 50 {
742 ComponentHealth::degraded(format!("Many pending requests: {}", pending_requests))
743 } else {
744 ComponentHealth::healthy(Some(format!(
745 "{} agents registered, {} active topics",
746 queue_count, topic_count
747 )))
748 };
749
750 Ok(status
751 .with_metric("registered_agents".to_string(), queue_count.to_string())
752 .with_metric("active_topics".to_string(), topic_count.to_string())
753 .with_metric(
754 "queued_messages".to_string(),
755 total_queued_messages.to_string(),
756 )
757 .with_metric("pending_requests".to_string(), pending_requests.to_string())
758 .with_metric("dead_letters".to_string(), dead_letter_count.to_string())
759 .with_metric("message_trackers".to_string(), tracker_count.to_string()))
760 }
761}
762
763#[derive(Debug, Clone)]
765struct MessageQueue {
766 messages: Vec<SecureMessage>,
767 created_at: SystemTime,
768}
769
770impl MessageQueue {
771 fn new() -> Self {
772 Self {
773 messages: Vec::new(),
774 created_at: SystemTime::now(),
775 }
776 }
777
778 fn can_accept_message(&self, config: &CommunicationConfig) -> bool {
779 self.messages.len() < config.max_queue_size
780 }
781
782 fn add_message(&mut self, message: SecureMessage) {
783 self.messages.push(message);
784 }
785
786 fn drain_messages(&mut self) -> Vec<SecureMessage> {
787 std::mem::take(&mut self.messages)
788 }
789
790 fn remove_expired_messages(&mut self, now: SystemTime, ttl: Duration) -> Vec<SecureMessage> {
791 let mut expired = Vec::new();
792
793 self.messages.retain(|message| {
794 let age = now.duration_since(message.timestamp).unwrap_or_default();
795 if age > ttl {
796 expired.push(message.clone());
797 false
798 } else {
799 true
800 }
801 });
802
803 expired
804 }
805
806 fn get_queue_age(&self) -> Duration {
807 SystemTime::now()
808 .duration_since(self.created_at)
809 .unwrap_or_default()
810 }
811
812 fn is_stale(&self, max_age: Duration) -> bool {
813 self.get_queue_age() > max_age
814 }
815}
816
817#[derive(Debug, Clone)]
819struct MessageTracker {
820 message: SecureMessage,
821 status: DeliveryStatus,
822 created_at: SystemTime,
823 delivered_at: Option<SystemTime>,
824 failure_reason: Option<String>,
825}
826
827impl MessageTracker {
828 fn new(message: SecureMessage) -> Self {
829 Self {
830 message,
831 status: DeliveryStatus::Pending,
832 created_at: SystemTime::now(),
833 delivered_at: None,
834 failure_reason: None,
835 }
836 }
837
838 fn get_message(&self) -> &SecureMessage {
840 &self.message
841 }
842
843 fn get_message_size(&self) -> usize {
845 self.message.payload.data.len()
846 }
847
848 fn get_age(&self) -> Duration {
850 SystemTime::now()
851 .duration_since(self.created_at)
852 .unwrap_or_default()
853 }
854
855 fn should_retry(&self, max_age: Duration) -> bool {
857 matches!(self.status, DeliveryStatus::Failed) && self.get_age() < max_age
858 }
859}
860
/// Delivery state of a tracked message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeliveryStatus {
    /// Accepted for tracking; no delivery outcome recorded yet.
    Pending,
    /// Placed in the recipient's queue (not necessarily read yet).
    Delivered,
    /// Could not be delivered; the tracker records a failure reason.
    Failed,
    /// Outlived its time-to-live.
    Expired,
}
869
870#[derive(Debug, Clone)]
872enum CommunicationEvent {
873 MessageSent {
874 message: SecureMessage,
875 },
876 TopicPublished {
877 topic: String,
878 message: SecureMessage,
879 },
880 AgentRegistered {
881 agent_id: AgentId,
882 },
883 AgentUnregistered {
884 agent_id: AgentId,
885 },
886}
887
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{EncryptedPayload, MessageType, RequestId};

    /// Pause that lets the background event loop apply the events sent so far.
    const SETTLE_TIME: Duration = Duration::from_millis(50);

    /// Bus with the default configuration.
    async fn default_bus() -> DefaultCommunicationBus {
        DefaultCommunicationBus::new(CommunicationConfig::default())
            .await
            .unwrap()
    }

    /// Waits `SETTLE_TIME` so queued events are processed.
    async fn settle() {
        tokio::time::sleep(SETTLE_TIME).await;
    }

    /// All-zero placeholder signature; these tests never check signatures.
    fn placeholder_signature() -> MessageSignature {
        MessageSignature {
            signature: vec![0u8; 64],
            algorithm: SignatureAlgorithm::Ed25519,
            public_key: vec![0u8; 32],
        }
    }

    /// Minimal `Request` message from `sender` to `recipient` on topic "test".
    fn create_test_message(sender: AgentId, recipient: AgentId) -> SecureMessage {
        SecureMessage {
            id: MessageId::new(),
            sender,
            recipient: Some(recipient),
            message_type: MessageType::Request(RequestId::new()),
            topic: Some("test".to_string()),
            payload: EncryptedPayload {
                data: b"test message".to_vec().into(),
                nonce: vec![0u8; 12],
                encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
            },
            signature: placeholder_signature(),
            ttl: Duration::from_secs(3600),
            timestamp: SystemTime::now(),
        }
    }

    #[tokio::test]
    async fn test_agent_registration() {
        let bus = default_bus().await;
        let agent_id = AgentId::new();

        assert!(bus.register_agent(agent_id).await.is_ok());
        settle().await;

        assert!(bus.receive_messages(agent_id).await.is_ok());
    }

    #[tokio::test]
    async fn test_message_sending() {
        let bus = default_bus().await;
        let sender = AgentId::new();
        let recipient = AgentId::new();

        for agent in [sender, recipient] {
            bus.register_agent(agent).await.unwrap();
        }
        settle().await;

        let message_id = bus
            .send_message(create_test_message(sender, recipient))
            .await
            .unwrap();
        settle().await;

        let status = bus.get_delivery_status(message_id).await.unwrap();
        assert_eq!(status, DeliveryStatus::Delivered);

        let inbox = bus.receive_messages(recipient).await.unwrap();
        assert_eq!(inbox.len(), 1);
        assert_eq!(inbox[0].sender, sender);
    }

    #[tokio::test]
    async fn test_topic_subscription() {
        let bus = default_bus().await;
        let publisher = AgentId::new();
        let subscriber1 = AgentId::new();
        let subscriber2 = AgentId::new();

        for agent in [publisher, subscriber1, subscriber2] {
            bus.register_agent(agent).await.unwrap();
        }

        let topic = "test_topic".to_string();
        for subscriber in [subscriber1, subscriber2] {
            bus.subscribe(subscriber, topic.clone()).await.unwrap();
        }
        settle().await;

        // The recipient is overwritten per subscriber during fan-out.
        let message = create_test_message(publisher, AgentId::new());
        bus.publish(topic, message).await.unwrap();
        settle().await;

        let inbox1 = bus.receive_messages(subscriber1).await.unwrap();
        let inbox2 = bus.receive_messages(subscriber2).await.unwrap();

        assert_eq!(inbox1.len(), 1);
        assert_eq!(inbox2.len(), 1);
        assert_eq!(inbox1[0].sender, publisher);
        assert_eq!(inbox2[0].sender, publisher);
    }

    #[tokio::test]
    async fn test_message_size_limit() {
        let config = CommunicationConfig {
            max_message_size: 100,
            ..Default::default()
        };
        let bus = DefaultCommunicationBus::new(config).await.unwrap();
        let sender = AgentId::new();
        let recipient = AgentId::new();

        bus.register_agent(sender).await.unwrap();
        bus.register_agent(recipient).await.unwrap();

        let mut oversized = create_test_message(sender, recipient);
        oversized.payload.data = vec![0u8; 200].into();

        let result = bus.send_message(oversized).await;
        assert!(result.is_err());

        match result {
            Err(CommunicationError::MessageTooLarge { size, max_size }) => {
                assert_eq!(size, 200);
                assert_eq!(max_size, 100);
            }
            _ => panic!("Expected MessageTooLarge error"),
        }
    }

    #[tokio::test]
    async fn test_agent_unregistration() {
        let bus = default_bus().await;
        let agent_id = AgentId::new();

        bus.register_agent(agent_id).await.unwrap();
        settle().await;

        bus.unregister_agent(agent_id).await.unwrap();
        settle().await;

        assert!(bus.receive_messages(agent_id).await.is_err());
    }

    #[tokio::test]
    async fn test_request_response_timeout() {
        let bus = default_bus().await;
        let target_agent = AgentId::new();

        bus.register_agent(target_agent).await.unwrap();
        settle().await;

        let deadline = Duration::from_millis(100);
        let result = bus
            .request(target_agent, bytes::Bytes::from("test request"), deadline)
            .await;
        assert!(result.is_err());

        match result {
            Err(CommunicationError::RequestTimeout {
                request_id: _,
                timeout: actual_timeout,
            }) => assert_eq!(actual_timeout, deadline),
            _ => panic!("Expected RequestTimeout error"),
        }
    }

    #[tokio::test]
    async fn test_request_response_success() {
        let bus = Arc::new(default_bus().await);
        let requester = AgentId::new();
        let responder = AgentId::new();

        bus.register_agent(requester).await.unwrap();
        bus.register_agent(responder).await.unwrap();
        settle().await;

        let request_payload = bytes::Bytes::from("test request");
        let response_payload = bytes::Bytes::from("test response");

        let request_bus = Arc::clone(&bus);
        let request_handle = tokio::spawn(async move {
            request_bus
                .request(responder, request_payload, Duration::from_secs(5))
                .await
        });

        tokio::time::sleep(Duration::from_millis(100)).await;

        let inbox = bus.receive_messages(responder).await.unwrap();
        assert_eq!(inbox.len(), 1);

        let request_id = match &inbox[0].message_type {
            MessageType::Request(id) => *id,
            _ => panic!("expected a Request message"),
        };

        let response_message = SecureMessage {
            id: MessageId::new(),
            sender: responder,
            recipient: Some(requester),
            topic: None,
            message_type: MessageType::Response(request_id),
            payload: EncryptedPayload {
                data: response_payload.clone(),
                nonce: vec![0u8; 12],
                encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
            },
            signature: placeholder_signature(),
            ttl: Duration::from_secs(3600),
            timestamp: SystemTime::now(),
        };
        bus.send_message(response_message).await.unwrap();

        let result = request_handle.await.unwrap();
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), response_payload);
    }
}