use async_trait::async_trait;
use parking_lot::RwLock;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{mpsc, oneshot, Notify};
use tokio::time::{interval, timeout};
use crate::crypto::Aes256GcmCrypto;
use crate::types::*;
use ed25519_dalek::{SigningKey, VerifyingKey};
use rand::rngs::OsRng;
use rand::RngCore;
/// Inter-agent message bus: point-to-point delivery, topic publish/subscribe,
/// request/response, delivery tracking, and lifecycle/health reporting.
#[async_trait]
pub trait CommunicationBus {
/// Submits `message` for delivery and returns its [`MessageId`].
async fn send_message(&self, message: SecureMessage) -> Result<MessageId, CommunicationError>;
/// Returns the messages currently queued for `agent_id`.
async fn receive_messages(
&self,
agent_id: AgentId,
) -> Result<Vec<SecureMessage>, CommunicationError>;
/// Subscribes `agent_id` to `topic`.
async fn subscribe(&self, agent_id: AgentId, topic: String) -> Result<(), CommunicationError>;
/// Removes `agent_id`'s subscription to `topic`.
async fn unsubscribe(&self, agent_id: AgentId, topic: String)
-> Result<(), CommunicationError>;
/// Publishes `message` to every subscriber of `topic`.
async fn publish(
&self,
topic: String,
message: SecureMessage,
) -> Result<(), CommunicationError>;
/// Looks up the tracked delivery status of a previously sent message.
async fn get_delivery_status(
&self,
message_id: MessageId,
) -> Result<DeliveryStatus, CommunicationError>;
/// Makes `agent_id` known to the bus so messages can be queued for it.
async fn register_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError>;
/// Removes `agent_id` from the bus.
async fn unregister_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError>;
/// Sends `request_payload` to `target_agent` and waits up to
/// `timeout_duration` for the matching response payload.
async fn request(
&self,
target_agent: AgentId,
request_payload: bytes::Bytes,
timeout_duration: Duration,
) -> Result<bytes::Bytes, CommunicationError>;
/// Stops the bus.
async fn shutdown(&self) -> Result<(), CommunicationError>;
/// Reports the bus's health together with diagnostic metrics.
async fn check_health(&self) -> Result<ComponentHealth, CommunicationError>;
}
/// Tunables for [`DefaultCommunicationBus`].
#[derive(Debug, Clone)]
pub struct CommunicationConfig {
    /// Largest accepted payload, in bytes (compared against `payload.data.len()`).
    pub max_message_size: usize,
    /// Age after which queued messages are moved to the dead-letter queue by the cleanup task.
    pub message_ttl: Duration,
    /// Maximum number of messages held in one agent's queue.
    pub max_queue_size: usize,
    /// NOTE(review): not read by `DefaultCommunicationBus` as written — confirm other users.
    pub delivery_timeout: Duration,
    /// NOTE(review): not read by `DefaultCommunicationBus` as written — confirm other users.
    pub retry_attempts: u32,
    /// NOTE(review): not read by `DefaultCommunicationBus` as written — confirm other users.
    pub enable_encryption: bool,
    /// NOTE(review): not read by `DefaultCommunicationBus` as written — confirm other users.
    pub enable_compression: bool,
    /// Capacity passed to the dead-letter queue at construction.
    pub dead_letter_queue_size: usize,
}

impl Default for CommunicationConfig {
    /// 1 MiB payloads, one-hour TTL, 10 000-message queues, 30 s delivery
    /// timeout, 3 retries, encryption and compression flags on, 1 000 dead letters.
    fn default() -> Self {
        const MIB: usize = 1 << 20;
        let one_hour = Duration::from_secs(60 * 60);
        Self {
            max_message_size: MIB,
            message_ttl: one_hour,
            max_queue_size: 10_000,
            delivery_timeout: Duration::from_secs(30),
            retry_attempts: 3,
            enable_encryption: true,
            enable_compression: true,
            dead_letter_queue_size: 1_000,
        }
    }
}
/// In-process [`CommunicationBus`] backed by per-agent in-memory queues.
///
/// State changes are funneled through an unbounded event channel and applied by a
/// background task spawned in `new`; a second background task periodically moves
/// expired messages to the dead-letter queue and prunes old delivery trackers.
pub struct DefaultCommunicationBus {
/// Tunables supplied at construction.
config: CommunicationConfig,
/// Per-agent inbox, keyed by registered agent id.
message_queues: Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
/// Topic name -> agent ids subscribed to it.
subscriptions: Arc<RwLock<HashMap<String, Vec<AgentId>>>>,
/// Delivery status per sent message; pruned by the cleanup task.
message_tracker: Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
/// Messages that could not be delivered (queue full, unknown agent, expired).
dead_letter_queue: Arc<RwLock<DeadLetterQueue>>,
/// Reply channels for in-flight `request` calls, keyed by request id.
pending_requests: Arc<RwLock<HashMap<RequestId, oneshot::Sender<bytes::Bytes>>>>,
/// Producer side of the channel consumed by the background event loop.
event_sender: mpsc::UnboundedSender<CommunicationEvent>,
/// Signalled by `shutdown` to stop the background tasks.
shutdown_notify: Arc<Notify>,
/// Cleared by `shutdown`; checked before accepting sends and by the cleanup loop.
is_running: Arc<RwLock<bool>>,
/// Ed25519 key used to sign request messages created by the bus.
signing_key: SigningKey,
/// Public half of `signing_key`, embedded in every signature produced.
verifying_key: VerifyingKey,
/// Sender id stamped on messages the bus itself creates (requests).
system_agent_id: AgentId,
/// NOTE(review): not used by any method in this file (hence the allow); request
/// payloads are labelled AES-256-GCM but never actually encrypted — confirm intent.
#[allow(dead_code)]
crypto: Aes256GcmCrypto,
}
impl DefaultCommunicationBus {
/// Builds a bus with empty state and a fresh Ed25519 identity, then spawns the
/// background event-processing and cleanup tasks.
///
/// Must be called from within a tokio runtime (it uses `tokio::spawn`).
pub async fn new(config: CommunicationConfig) -> Result<Self, CommunicationError> {
    // Seed the signing key from the OS CSPRNG.
    let signing_key = {
        let mut seed = [0u8; 32];
        OsRng.fill_bytes(&mut seed);
        SigningKey::from_bytes(&seed)
    };
    let verifying_key = signing_key.verifying_key();
    let (event_sender, event_receiver) = mpsc::unbounded_channel();
    // Read the capacity before `config` is moved into the struct below.
    let dead_letters = DeadLetterQueue::new(config.dead_letter_queue_size);

    let bus = Self {
        message_queues: Arc::new(RwLock::new(HashMap::new())),
        subscriptions: Arc::new(RwLock::new(HashMap::new())),
        message_tracker: Arc::new(RwLock::new(HashMap::new())),
        dead_letter_queue: Arc::new(RwLock::new(dead_letters)),
        pending_requests: Arc::new(RwLock::new(HashMap::new())),
        event_sender,
        shutdown_notify: Arc::new(Notify::new()),
        is_running: Arc::new(RwLock::new(true)),
        signing_key,
        verifying_key,
        system_agent_id: AgentId::new(),
        crypto: Aes256GcmCrypto::new(),
        config,
    };

    bus.start_event_loop(event_receiver).await;
    bus.start_cleanup_loop().await;
    Ok(bus)
}
/// Spawns the task that applies [`CommunicationEvent`]s to the shared state.
///
/// The task exits when the event channel closes or when `shutdown` signals
/// `shutdown_notify`.
async fn start_event_loop(
    &self,
    mut event_receiver: mpsc::UnboundedReceiver<CommunicationEvent>,
) {
    let message_queues = self.message_queues.clone();
    let subscriptions = self.subscriptions.clone();
    let message_tracker = self.message_tracker.clone();
    let dead_letter_queue = self.dead_letter_queue.clone();
    let pending_requests = self.pending_requests.clone();
    let shutdown_notify = self.shutdown_notify.clone();
    let is_running = self.is_running.clone();
    let config = self.config.clone();
    tokio::spawn(async move {
        // Register for the shutdown signal once, before looping. `notify_waiters`
        // only wakes `Notified` futures that already exist; re-creating one per
        // iteration loses any signal sent while an event is being processed.
        let shutdown = shutdown_notify.notified();
        tokio::pin!(shutdown);
        // `shutdown()` clears `is_running` before notifying, so this check covers
        // a signal that fired before this task got to register above.
        if !*is_running.read() {
            return;
        }
        loop {
            tokio::select! {
                event = event_receiver.recv() => {
                    match event {
                        Some(event) => {
                            Self::process_communication_event(
                                event,
                                &message_queues,
                                &subscriptions,
                                &message_tracker,
                                &dead_letter_queue,
                                &pending_requests,
                                &config,
                            )
                            .await;
                        }
                        // All senders dropped: nothing more can arrive.
                        None => break,
                    }
                }
                _ = &mut shutdown => {
                    break;
                }
            }
        }
    });
}
/// Spawns the task that runs `cleanup_expired_messages` once a minute.
///
/// The task exits when `is_running` is found cleared on a tick or when
/// `shutdown` signals `shutdown_notify`.
async fn start_cleanup_loop(&self) {
    let message_queues = self.message_queues.clone();
    let message_tracker = self.message_tracker.clone();
    let dead_letter_queue = self.dead_letter_queue.clone();
    let shutdown_notify = self.shutdown_notify.clone();
    let is_running = self.is_running.clone();
    let message_ttl = self.config.message_ttl;
    tokio::spawn(async move {
        // Register once so a `notify_waiters` sent while a cleanup pass is running
        // is not lost (a fresh `notified()` per iteration would miss it).
        let shutdown = shutdown_notify.notified();
        tokio::pin!(shutdown);
        let mut ticker = interval(Duration::from_secs(60));
        loop {
            tokio::select! {
                _ = ticker.tick() => {
                    // Also catches a shutdown that happened before registration:
                    // the first tick completes immediately.
                    if !*is_running.read() {
                        break;
                    }
                    Self::cleanup_expired_messages(
                        &message_queues,
                        &message_tracker,
                        &dead_letter_queue,
                        message_ttl,
                    )
                    .await;
                }
                _ = &mut shutdown => {
                    break;
                }
            }
        }
    });
}
/// Applies one [`CommunicationEvent`] to the shared bus state.
///
/// - `MessageSent`: a `Response` whose request is still pending completes that
///   request's oneshot and is not queued. Any other message gets a tracker and is
///   appended to the recipient's queue, or dead-lettered (queue full / recipient
///   missing or unregistered) with the tracker marked `Failed`.
/// - `TopicPublished`: re-sends a copy, with a fresh id, to each current subscriber.
/// - `AgentRegistered`: creates the agent's queue if it does not already exist.
/// - `AgentUnregistered`: drops the agent's queue and subscriptions, and removes
///   topics left without subscribers.
async fn process_communication_event(
    event: CommunicationEvent,
    message_queues: &Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
    subscriptions: &Arc<RwLock<HashMap<String, Vec<AgentId>>>>,
    message_tracker: &Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
    dead_letter_queue: &Arc<RwLock<DeadLetterQueue>>,
    pending_requests: &Arc<RwLock<HashMap<RequestId, oneshot::Sender<bytes::Bytes>>>>,
    config: &CommunicationConfig,
) {
    match event {
        CommunicationEvent::MessageSent { message } => {
            let recipient = message.recipient;
            let message_id = message.id;
            if let MessageType::Response(request_id) = &message.message_type {
                // Release the pending-requests lock before signalling the waiter.
                let waiting = pending_requests.write().remove(request_id);
                if let Some(sender) = waiting {
                    // The requester may already have given up; that is not an error here.
                    let _ = sender.send(message.payload.data.clone());
                    tracing::debug!(
                        "Response {} sent for request {:?}",
                        message_id,
                        request_id
                    );
                    return;
                }
            }
            message_tracker
                .write()
                .insert(message_id, MessageTracker::new(message.clone()));
            let mut queues = message_queues.write();
            let target = match recipient {
                Some(id) => queues.get_mut(&id).map(|queue| (id, queue)),
                None => None,
            };
            match target {
                Some((recipient_id, queue)) if queue.can_accept_message(config) => {
                    queue.add_message(message);
                    if let Some(tracker) = message_tracker.write().get_mut(&message_id) {
                        tracker.status = DeliveryStatus::Delivered;
                        tracker.delivered_at = Some(SystemTime::now());
                    }
                    tracing::debug!(
                        "Message {} delivered to agent {}",
                        message_id,
                        recipient_id
                    );
                }
                Some((recipient_id, _)) => {
                    dead_letter_queue
                        .write()
                        .add_message(message, DeadLetterReason::QueueFull);
                    if let Some(tracker) = message_tracker.write().get_mut(&message_id) {
                        tracker.status = DeliveryStatus::Failed;
                        tracker.failure_reason = Some("Queue full".to_string());
                    }
                    tracing::warn!(
                        "Message {} failed to deliver: queue full for agent {}",
                        message_id,
                        recipient_id
                    );
                }
                // No recipient at all, or a recipient with no registered queue.
                None => {
                    dead_letter_queue
                        .write()
                        .add_message(message, DeadLetterReason::AgentNotFound);
                    if let Some(tracker) = message_tracker.write().get_mut(&message_id) {
                        tracker.status = DeliveryStatus::Failed;
                        tracker.failure_reason = Some("Agent not registered".to_string());
                    }
                    tracing::warn!(
                        "Message {} failed to deliver: agent {:?} not registered",
                        message_id,
                        recipient
                    );
                }
            }
        }
        CommunicationEvent::TopicPublished { topic, message } => {
            // Snapshot the subscriber list so no lock is held across the awaits below.
            let subscribers = subscriptions
                .read()
                .get(&topic)
                .cloned()
                .unwrap_or_default();
            let subscriber_count = subscribers.len();
            for subscriber in &subscribers {
                let mut subscriber_message = message.clone();
                subscriber_message.recipient = Some(*subscriber);
                subscriber_message.id = MessageId::new();
                // Boxed because an async fn cannot recurse into itself unboxed.
                Box::pin(Self::process_communication_event(
                    CommunicationEvent::MessageSent {
                        message: subscriber_message,
                    },
                    message_queues,
                    subscriptions,
                    message_tracker,
                    dead_letter_queue,
                    pending_requests,
                    config,
                ))
                .await;
            }
            tracing::debug!(
                "Published message to topic {} for {} subscribers",
                topic,
                subscriber_count
            );
        }
        CommunicationEvent::AgentRegistered { agent_id } => {
            // Re-registering must not silently discard messages already queued
            // (and already marked Delivered) for this agent.
            message_queues
                .write()
                .entry(agent_id)
                .or_insert_with(MessageQueue::new);
            tracing::info!("Registered agent {} for communication", agent_id);
        }
        CommunicationEvent::AgentUnregistered { agent_id } => {
            message_queues.write().remove(&agent_id);
            // Drop topics left with no subscribers, matching `unsubscribe`.
            subscriptions.write().retain(|_, subscribers| {
                subscribers.retain(|&id| id != agent_id);
                !subscribers.is_empty()
            });
            tracing::info!("Unregistered agent {} from communication", agent_id);
        }
    }
}
/// One cleanup pass:
/// 1. moves queued messages older than `message_ttl` to the dead-letter queue and
///    marks their trackers `Expired`;
/// 2. drops trackers older than `2 * message_ttl` and logs the `Failed` ones that
///    are still young enough to be retried.
///
/// Multiples of `message_ttl` saturate, so a TTL near `Duration::MAX` ("never
/// expire") does not panic and kill the cleanup task.
async fn cleanup_expired_messages(
    message_queues: &Arc<RwLock<HashMap<AgentId, MessageQueue>>>,
    message_tracker: &Arc<RwLock<HashMap<MessageId, MessageTracker>>>,
    dead_letter_queue: &Arc<RwLock<DeadLetterQueue>>,
    message_ttl: Duration,
) {
    let now = SystemTime::now();
    let stale_after = message_ttl.saturating_mul(3);
    let tracker_retention = message_ttl.saturating_mul(2);

    let mut expired_messages = Vec::new();
    {
        let mut queues = message_queues.write();
        let mut stale_queues = 0;
        for queue in queues.values_mut() {
            expired_messages.extend(queue.remove_expired_messages(now, message_ttl));
            if queue.is_stale(stale_after) {
                stale_queues += 1;
            }
        }
        if stale_queues > 0 {
            tracing::debug!("Found {} stale message queues", stale_queues);
        }
    }

    if !expired_messages.is_empty() {
        // Lock order dead-letter queue -> tracker, as in the original code.
        let mut dlq = dead_letter_queue.write();
        let mut trackers = message_tracker.write();
        for message in expired_messages {
            if let Some(tracker) = trackers.get_mut(&message.id) {
                tracker.status = DeliveryStatus::Expired;
                tracker.failure_reason = Some("Message expired".to_string());
            }
            // Move (not clone) the message into the dead-letter queue.
            dlq.add_message(message, DeadLetterReason::Expired);
        }
    }

    let mut trackers = message_tracker.write();
    let mut retry_candidates = 0usize;
    trackers.retain(|message_id, tracker| {
        let age = tracker.get_age();
        if age >= tracker_retention {
            return false;
        }
        if tracker.should_retry(message_ttl) {
            retry_candidates += 1;
            tracing::debug!(
                "Message {} eligible for retry: size={} bytes, age={:?}s, sender={}",
                message_id,
                tracker.get_message_size(),
                age.as_secs(),
                tracker.get_message().sender
            );
        }
        true
    });
    if retry_candidates > 0 {
        tracing::debug!("Found {} messages eligible for retry", retry_candidates);
    }
}
/// Hands `event` to the background event loop.
///
/// Fails with `EventProcessingFailed` once the loop's receiver has been dropped.
fn send_event(&self, event: CommunicationEvent) -> Result<(), CommunicationError> {
    match self.event_sender.send(event) {
        Ok(()) => Ok(()),
        Err(_) => Err(CommunicationError::EventProcessingFailed {
            reason: "Failed to send communication event".to_string(),
        }),
    }
}
/// Returns a fresh random AES-256-GCM-sized nonce drawn from the OS CSPRNG.
fn generate_nonce() -> Vec<u8> {
    use aes_gcm::{aead::AeadCore, Aes256Gcm};
    Vec::from(Aes256Gcm::generate_nonce(&mut OsRng).as_slice())
}
/// Signs `data` with the bus's Ed25519 key and bundles the signature with the
/// matching public key.
fn sign_message_data(&self, data: &[u8]) -> MessageSignature {
    use ed25519_dalek::Signer;
    let signature_bytes = self.signing_key.sign(data).to_bytes();
    let public_key_bytes = self.verifying_key.to_bytes();
    MessageSignature {
        signature: Vec::from(signature_bytes),
        algorithm: SignatureAlgorithm::Ed25519,
        public_key: Vec::from(public_key_bytes),
    }
}
/// Builds the `Request` message sent by [`CommunicationBus::request`], signed over
/// `payload.data || payload.nonce`, with `ttl` set to the request timeout.
///
/// NOTE(review): `data` is the caller's payload as-is; nothing here encrypts it,
/// although `encryption_algorithm` claims AES-256-GCM. The signature also does not
/// cover id/sender/recipient/timestamp — confirm this is intended.
fn create_secure_request_message(
    &self,
    target_agent: AgentId,
    request_id: RequestId,
    request_payload: bytes::Bytes,
    timeout_duration: Duration,
) -> Result<SecureMessage, CommunicationError> {
    let payload = EncryptedPayload {
        data: request_payload,
        nonce: Self::generate_nonce(),
        encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
    };

    let mut signed_bytes = Vec::with_capacity(payload.data.len() + payload.nonce.len());
    signed_bytes.extend_from_slice(&payload.data);
    signed_bytes.extend_from_slice(&payload.nonce);
    let signature = self.sign_message_data(&signed_bytes);

    let message = SecureMessage {
        id: MessageId::new(),
        sender: self.system_agent_id,
        recipient: Some(target_agent),
        topic: None,
        message_type: MessageType::Request(request_id),
        payload,
        signature,
        ttl: timeout_duration,
        timestamp: SystemTime::now(),
    };
    Ok(message)
}
}
#[async_trait]
impl CommunicationBus for DefaultCommunicationBus {
/// Validates `message` (bus running, payload within `max_message_size`) and
/// queues it for the event loop, returning its id.
async fn send_message(&self, message: SecureMessage) -> Result<MessageId, CommunicationError> {
    let accepting = *self.is_running.read();
    if !accepting {
        return Err(CommunicationError::ShuttingDown);
    }
    let limit = self.config.max_message_size;
    let size = message.payload.data.len();
    if size > limit {
        return Err(CommunicationError::MessageTooLarge {
            size,
            max_size: limit,
        });
    }
    let id = message.id;
    self.send_event(CommunicationEvent::MessageSent { message })
        .map(|()| id)
}
async fn receive_messages(
&self,
agent_id: AgentId,
) -> Result<Vec<SecureMessage>, CommunicationError> {
let mut queues = self.message_queues.write();
if let Some(queue) = queues.get_mut(&agent_id) {
Ok(queue.drain_messages())
} else {
Err(CommunicationError::AgentNotRegistered { agent_id })
}
}
/// Subscribes `agent_id` to `topic`, creating the topic on first use.
///
/// Idempotent: subscribing an agent that is already subscribed has no effect.
/// A duplicate entry would make `publish` deliver each message to the agent twice.
async fn subscribe(&self, agent_id: AgentId, topic: String) -> Result<(), CommunicationError> {
    {
        let mut subscriptions = self.subscriptions.write();
        let subscribers = subscriptions.entry(topic.clone()).or_default();
        if !subscribers.contains(&agent_id) {
            subscribers.push(agent_id);
        }
    }
    tracing::info!("Agent {} subscribed to topic {}", agent_id, topic);
    Ok(())
}
/// Removes `agent_id` from `topic`, dropping the topic once nobody is left.
/// Unknown topics are ignored.
async fn unsubscribe(
    &self,
    agent_id: AgentId,
    topic: String,
) -> Result<(), CommunicationError> {
    let mut subscriptions = self.subscriptions.write();
    let topic_now_empty = match subscriptions.get_mut(&topic) {
        Some(subscribers) => {
            subscribers.retain(|id| *id != agent_id);
            subscribers.is_empty()
        }
        None => false,
    };
    if topic_now_empty {
        subscriptions.remove(&topic);
    }
    tracing::info!("Agent {} unsubscribed from topic {}", agent_id, topic);
    Ok(())
}
/// Queues `message` for fan-out to every subscriber of `topic`.
///
/// Applies the same checks as `send_message` (bus running, payload within
/// `max_message_size`). Without them, publishing would bypass the size limit,
/// because per-subscriber copies are delivered by the event loop directly.
async fn publish(
    &self,
    topic: String,
    message: SecureMessage,
) -> Result<(), CommunicationError> {
    if !*self.is_running.read() {
        return Err(CommunicationError::ShuttingDown);
    }
    let size = message.payload.data.len();
    if size > self.config.max_message_size {
        return Err(CommunicationError::MessageTooLarge {
            size,
            max_size: self.config.max_message_size,
        });
    }
    self.send_event(CommunicationEvent::TopicPublished { topic, message })?;
    Ok(())
}
/// Returns the tracked status of `message_id`, or `MessageNotFound` if it was
/// never tracked or its tracker has already been pruned.
async fn get_delivery_status(
    &self,
    message_id: MessageId,
) -> Result<DeliveryStatus, CommunicationError> {
    let status = self
        .message_tracker
        .read()
        .get(&message_id)
        .map(|tracker| tracker.status.clone());
    match status {
        Some(found) => Ok(found),
        None => Err(CommunicationError::MessageNotFound { message_id }),
    }
}
/// Asks the event loop to create a queue for `agent_id`; takes effect asynchronously.
async fn register_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError> {
    let event = CommunicationEvent::AgentRegistered { agent_id };
    self.send_event(event)
}
/// Asks the event loop to drop `agent_id`'s queue and subscriptions; takes effect
/// asynchronously.
async fn unregister_agent(&self, agent_id: AgentId) -> Result<(), CommunicationError> {
    let event = CommunicationEvent::AgentUnregistered { agent_id };
    self.send_event(event)
}
/// Sends a signed `Request` to `target_agent` and waits up to `timeout_duration`
/// for a `Response` carrying the same request id.
///
/// # Errors
/// - `ShuttingDown` if the bus is stopped;
/// - any error from building or sending the request (e.g. `MessageTooLarge`);
/// - `RequestCancelled` if the reply channel is dropped without a response;
/// - `RequestTimeout` if no response arrives in time.
///
/// On every error path the pending-request entry is removed. Previously a failed
/// send left it in `pending_requests` forever.
async fn request(
    &self,
    target_agent: AgentId,
    request_payload: bytes::Bytes,
    timeout_duration: Duration,
) -> Result<bytes::Bytes, CommunicationError> {
    if !*self.is_running.read() {
        return Err(CommunicationError::ShuttingDown);
    }
    let request_id = RequestId::new();
    let (response_sender, response_receiver) = oneshot::channel();
    // Register the reply slot before sending so a fast response cannot miss it.
    self.pending_requests
        .write()
        .insert(request_id, response_sender);

    let request_message = match self.create_secure_request_message(
        target_agent,
        request_id,
        request_payload,
        timeout_duration,
    ) {
        Ok(message) => message,
        Err(error) => {
            self.pending_requests.write().remove(&request_id);
            return Err(error);
        }
    };
    if let Err(error) = self.send_message(request_message).await {
        self.pending_requests.write().remove(&request_id);
        return Err(error);
    }

    match timeout(timeout_duration, response_receiver).await {
        Ok(Ok(response_payload)) => Ok(response_payload),
        Ok(Err(_)) => {
            self.pending_requests.write().remove(&request_id);
            Err(CommunicationError::RequestCancelled { request_id })
        }
        Err(_) => {
            self.pending_requests.write().remove(&request_id);
            Err(CommunicationError::RequestTimeout {
                request_id,
                timeout: timeout_duration,
            })
        }
    }
}
/// Stops the bus: rejects new sends, signals the background tasks to exit, and
/// tears down per-agent state.
///
/// Teardown happens directly here rather than via `AgentUnregistered` events. The
/// event loop is told to stop first, so queued events would typically never be
/// processed. Pending request reply channels are dropped, so callers still
/// waiting in `request` fail promptly with `RequestCancelled` instead of waiting
/// out their timeout.
async fn shutdown(&self) -> Result<(), CommunicationError> {
    tracing::info!("Shutting down communication bus");
    *self.is_running.write() = false;
    self.shutdown_notify.notify_waiters();

    let agent_ids: std::collections::HashSet<AgentId> = self
        .message_queues
        .write()
        .drain()
        .map(|(agent_id, _)| agent_id)
        .collect();
    self.subscriptions.write().retain(|_, subscribers| {
        subscribers.retain(|id| !agent_ids.contains(id));
        !subscribers.is_empty()
    });
    self.pending_requests.write().clear();

    for agent_id in &agent_ids {
        tracing::info!("Unregistered agent {} from communication", agent_id);
    }
    Ok(())
}
/// Reports bus health with metrics.
///
/// The result is `unhealthy` once shut down. It is `degraded` when the dead-letter
/// queue holds more than 100 messages, when any queue is at or above 90% of
/// `max_queue_size`, or when more than 50 requests are pending. Otherwise it is
/// `healthy`.
async fn check_health(&self) -> Result<ComponentHealth, CommunicationError> {
    if !*self.is_running.read() {
        return Ok(ComponentHealth::unhealthy(
            "Communication bus is shut down".to_string(),
        ));
    }

    // floor(n * 9 / 10), computed without forming `n * 9`, which overflows
    // (panicking in debug builds) for very large configured limits.
    let max_queue_size = self.config.max_queue_size;
    let near_capacity = max_queue_size / 10 * 9 + max_queue_size % 10 * 9 / 10;

    // One read lock, so agent count, queued total and full-queue count all
    // describe the same snapshot.
    let (queue_count, total_queued_messages, full_queues) = {
        let queues = self.message_queues.read();
        let mut total = 0usize;
        let mut full = 0usize;
        for queue in queues.values() {
            let len = queue.messages.len();
            total += len;
            if len >= near_capacity {
                full += 1;
            }
        }
        (queues.len(), total, full)
    };
    let topic_count = self.subscriptions.read().len();
    let tracker_count = self.message_tracker.read().len();
    let pending_requests = self.pending_requests.read().len();
    let dead_letter_count = self.dead_letter_queue.read().messages.len();

    let status = if dead_letter_count > 100 {
        ComponentHealth::degraded(format!(
            "High dead letter queue: {} messages",
            dead_letter_count
        ))
    } else if full_queues > 0 {
        ComponentHealth::degraded(format!("{} message queues near capacity", full_queues))
    } else if pending_requests > 50 {
        ComponentHealth::degraded(format!("Many pending requests: {}", pending_requests))
    } else {
        ComponentHealth::healthy(Some(format!(
            "{} agents registered, {} active topics",
            queue_count, topic_count
        )))
    };
    Ok(status
        .with_metric("registered_agents".to_string(), queue_count.to_string())
        .with_metric("active_topics".to_string(), topic_count.to_string())
        .with_metric(
            "queued_messages".to_string(),
            total_queued_messages.to_string(),
        )
        .with_metric("pending_requests".to_string(), pending_requests.to_string())
        .with_metric("dead_letters".to_string(), dead_letter_count.to_string())
        .with_metric("message_trackers".to_string(), tracker_count.to_string()))
}
}
/// An agent's inbox: messages in arrival order plus the queue's creation time.
#[derive(Debug, Clone)]
struct MessageQueue {
    messages: Vec<SecureMessage>,
    created_at: SystemTime,
}

impl MessageQueue {
    /// Empty queue stamped with the current time.
    fn new() -> Self {
        let created_at = SystemTime::now();
        Self {
            messages: Vec::new(),
            created_at,
        }
    }

    /// True while the queue holds fewer than `config.max_queue_size` messages.
    fn can_accept_message(&self, config: &CommunicationConfig) -> bool {
        config.max_queue_size > self.messages.len()
    }

    /// Appends `message` at the back of the queue.
    fn add_message(&mut self, message: SecureMessage) {
        self.messages.push(message)
    }

    /// Takes every queued message, leaving the queue empty.
    fn drain_messages(&mut self) -> Vec<SecureMessage> {
        std::mem::take(&mut self.messages)
    }

    /// Removes and returns (in queue order) messages whose timestamp is more than
    /// `ttl` before `now`. Timestamps in the future count as age zero.
    fn remove_expired_messages(&mut self, now: SystemTime, ttl: Duration) -> Vec<SecureMessage> {
        let is_expired =
            |message: &SecureMessage| now.duration_since(message.timestamp).unwrap_or_default() > ttl;
        let expired: Vec<SecureMessage> = self
            .messages
            .iter()
            .filter(|message| is_expired(message))
            .cloned()
            .collect();
        self.messages.retain(|message| !is_expired(message));
        expired
    }

    /// Time since the queue was created (not since its last activity).
    fn get_queue_age(&self) -> Duration {
        match SystemTime::now().duration_since(self.created_at) {
            Ok(age) => age,
            Err(_) => Duration::default(),
        }
    }

    /// True when the queue was created more than `max_age` ago.
    fn is_stale(&self, max_age: Duration) -> bool {
        max_age < self.get_queue_age()
    }
}
/// Per-message delivery record: a copy of the message, its current status, and
/// timing/failure details.
#[derive(Debug, Clone)]
struct MessageTracker {
    message: SecureMessage,
    status: DeliveryStatus,
    created_at: SystemTime,
    delivered_at: Option<SystemTime>,
    failure_reason: Option<String>,
}

impl MessageTracker {
    /// Starts tracking `message` in the `Pending` state.
    fn new(message: SecureMessage) -> Self {
        let created_at = SystemTime::now();
        Self {
            status: DeliveryStatus::Pending,
            delivered_at: None,
            failure_reason: None,
            created_at,
            message,
        }
    }

    /// The tracked message.
    fn get_message(&self) -> &SecureMessage {
        &self.message
    }

    /// Payload length of the tracked message, in bytes.
    fn get_message_size(&self) -> usize {
        self.get_message().payload.data.len()
    }

    /// Time since tracking began; zero if the clock moved backwards.
    fn get_age(&self) -> Duration {
        match SystemTime::now().duration_since(self.created_at) {
            Ok(age) => age,
            Err(_) => Duration::default(),
        }
    }

    /// True for a `Failed` message that is younger than `max_age`.
    fn should_retry(&self, max_age: Duration) -> bool {
        self.status == DeliveryStatus::Failed && self.get_age() < max_age
    }
}
/// Delivery state recorded for a tracked message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DeliveryStatus {
/// Tracked but not yet placed in a recipient queue (initial state of a tracker).
Pending,
/// Placed into the recipient's queue; does not imply the recipient has read it.
Delivered,
/// Could not be delivered (e.g. recipient queue full or agent not registered);
/// the tracker's failure reason holds the detail.
Failed,
/// NOTE(review): presumably for messages that outlived their TTL — confirm which
/// code paths assign this variant.
Expired,
}
/// Commands sent from the public API to the background event loop, which applies
/// them to the bus's shared state.
#[derive(Debug, Clone)]
enum CommunicationEvent {
/// Deliver `message` to its recipient, or complete a pending request if it is a
/// response to one.
MessageSent {
message: SecureMessage,
},
/// Fan `message` out to every agent subscribed to `topic`.
TopicPublished {
topic: String,
message: SecureMessage,
},
/// Set up a message queue for `agent_id`.
AgentRegistered {
agent_id: AgentId,
},
/// Drop `agent_id`'s queue and remove it from all topic subscriptions.
AgentUnregistered {
agent_id: AgentId,
},
}
// Integration-style tests for DefaultCommunicationBus.
//
// Registration, delivery and publishing are applied asynchronously by the bus's
// background event loop, so several tests sleep briefly (50-100 ms) before
// asserting. These waits are timing-based and could be flaky on a heavily loaded
// machine.
#[cfg(test)]
mod tests {
use super::*;
use crate::types::{EncryptedPayload, MessageType};
// Builds a Request message from `sender` to `recipient` with dummy (all-zero)
// nonce/signature/key bytes; the bus does not verify signatures.
fn create_test_message(sender: AgentId, recipient: AgentId) -> SecureMessage {
use crate::types::RequestId;
SecureMessage {
id: MessageId::new(),
sender,
recipient: Some(recipient),
message_type: MessageType::Request(RequestId::new()),
topic: Some("test".to_string()),
payload: EncryptedPayload {
data: b"test message".to_vec().into(),
nonce: [0u8; 12].to_vec(),
encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
},
signature: MessageSignature {
signature: vec![0u8; 64],
algorithm: SignatureAlgorithm::Ed25519,
public_key: vec![0u8; 32],
},
ttl: Duration::from_secs(3600),
timestamp: SystemTime::now(),
}
}
// A registered agent can read its (empty) queue.
#[tokio::test]
async fn test_agent_registration() {
let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
.await
.unwrap();
let agent_id = AgentId::new();
let result = bus.register_agent(agent_id).await;
assert!(result.is_ok());
tokio::time::sleep(Duration::from_millis(50)).await;
let messages = bus.receive_messages(agent_id).await;
assert!(messages.is_ok());
}
// A direct message is tracked as Delivered and appears in the recipient's queue.
#[tokio::test]
async fn test_message_sending() {
let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
.await
.unwrap();
let sender = AgentId::new();
let recipient = AgentId::new();
bus.register_agent(sender).await.unwrap();
bus.register_agent(recipient).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
let message = create_test_message(sender, recipient);
let message_id = bus.send_message(message).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
let status = bus.get_delivery_status(message_id).await.unwrap();
assert_eq!(status, DeliveryStatus::Delivered);
let messages = bus.receive_messages(recipient).await.unwrap();
assert_eq!(messages.len(), 1);
assert_eq!(messages[0].sender, sender);
}
// Publishing delivers exactly one copy to each subscriber; the original
// recipient field is overwritten per subscriber.
#[tokio::test]
async fn test_topic_subscription() {
let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
.await
.unwrap();
let publisher = AgentId::new();
let subscriber1 = AgentId::new();
let subscriber2 = AgentId::new();
bus.register_agent(publisher).await.unwrap();
bus.register_agent(subscriber1).await.unwrap();
bus.register_agent(subscriber2).await.unwrap();
let topic = "test_topic".to_string();
bus.subscribe(subscriber1, topic.clone()).await.unwrap();
bus.subscribe(subscriber2, topic.clone()).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
let message = create_test_message(publisher, AgentId::new()); bus.publish(topic, message).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
let messages1 = bus.receive_messages(subscriber1).await.unwrap();
let messages2 = bus.receive_messages(subscriber2).await.unwrap();
assert_eq!(messages1.len(), 1);
assert_eq!(messages2.len(), 1);
assert_eq!(messages1[0].sender, publisher);
assert_eq!(messages2[0].sender, publisher);
}
// send_message rejects payloads above max_message_size synchronously, reporting
// both the actual and the allowed size.
#[tokio::test]
async fn test_message_size_limit() {
let config = CommunicationConfig {
max_message_size: 100, ..Default::default()
};
let bus = DefaultCommunicationBus::new(config).await.unwrap();
let sender = AgentId::new();
let recipient = AgentId::new();
bus.register_agent(sender).await.unwrap();
bus.register_agent(recipient).await.unwrap();
let mut message = create_test_message(sender, recipient);
message.payload.data = vec![0u8; 200].into();
let result = bus.send_message(message).await;
assert!(result.is_err());
if let Err(CommunicationError::MessageTooLarge { size, max_size }) = result {
assert_eq!(size, 200);
assert_eq!(max_size, 100);
} else {
panic!("Expected MessageTooLarge error");
}
}
// After unregistration the agent's queue is gone, so receive_messages errors.
#[tokio::test]
async fn test_agent_unregistration() {
let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
.await
.unwrap();
let agent_id = AgentId::new();
bus.register_agent(agent_id).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
bus.unregister_agent(agent_id).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
let result = bus.receive_messages(agent_id).await;
assert!(result.is_err());
}
// With nobody answering, request() fails with RequestTimeout carrying the
// caller's timeout.
#[tokio::test]
async fn test_request_response_timeout() {
let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
.await
.unwrap();
let target_agent = AgentId::new();
bus.register_agent(target_agent).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
let request_payload = bytes::Bytes::from("test request");
let timeout = Duration::from_millis(100);
let result = bus.request(target_agent, request_payload, timeout).await;
assert!(result.is_err());
if let Err(CommunicationError::RequestTimeout {
request_id: _,
timeout: actual_timeout,
}) = result
{
assert_eq!(actual_timeout, timeout);
} else {
panic!("Expected RequestTimeout error");
}
}
// Full round trip: the responder reads the Request from its queue and answers
// with a Response carrying the same request id; the pending request resolves
// to the response payload.
#[tokio::test]
async fn test_request_response_success() {
let bus = DefaultCommunicationBus::new(CommunicationConfig::default())
.await
.unwrap();
let requester = AgentId::new();
let responder = AgentId::new();
bus.register_agent(requester).await.unwrap();
bus.register_agent(responder).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
let request_payload = bytes::Bytes::from("test request");
let response_payload = bytes::Bytes::from("test response");
let bus_clone = Arc::new(bus);
let request_bus = bus_clone.clone();
let request_handle = tokio::spawn(async move {
request_bus
.request(responder, request_payload, Duration::from_secs(5))
.await
});
// Give the spawned request time to be delivered to the responder's queue.
tokio::time::sleep(Duration::from_millis(100)).await;
let messages = bus_clone.receive_messages(responder).await.unwrap();
assert_eq!(messages.len(), 1);
assert!(matches!(messages[0].message_type, MessageType::Request(_)));
if let MessageType::Request(request_id) = &messages[0].message_type {
let response_message = SecureMessage {
id: MessageId::new(),
sender: responder,
recipient: Some(requester),
topic: None,
message_type: MessageType::Response(*request_id),
payload: EncryptedPayload {
data: response_payload.clone(),
nonce: vec![0u8; 12],
encryption_algorithm: EncryptionAlgorithm::Aes256Gcm,
},
signature: MessageSignature {
signature: vec![0u8; 64],
algorithm: SignatureAlgorithm::Ed25519,
public_key: vec![0u8; 32],
},
ttl: Duration::from_secs(3600),
timestamp: SystemTime::now(),
};
bus_clone.send_message(response_message).await.unwrap();
}
let result = request_handle.await.unwrap();
assert!(result.is_ok());
assert_eq!(result.unwrap(), response_payload);
}
}