use anyhow::Result;
use crossbeam_channel::{Receiver, Sender};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::core::AISession;
// Capacity of each per-agent channel; `try_send` beyond this fails with backpressure.
const DEFAULT_CHANNEL_CAPACITY: usize = 1000;
// Capacity of the shared broadcast channel.
const BROADCAST_CHANNEL_CAPACITY: usize = 5000;
// Capacity of the firehose channel that mirrors every published `AgentMessage`.
const ALL_MESSAGES_CHANNEL_CAPACITY: usize = 10000;
/// Unique identifier for an agent, backed by a random (v4) UUID.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub struct AgentId(Uuid);
impl Default for AgentId {
fn default() -> Self {
Self::new()
}
}
impl AgentId {
pub fn new() -> Self {
Self(Uuid::new_v4())
}
}
impl std::fmt::Display for AgentId {
    /// Renders the id exactly like the underlying UUID.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// Coordinates a group of AI agents: a registry of live sessions plus the
/// shared messaging, task-distribution, and resource-management facilities.
pub struct MultiAgentSession {
// live agent sessions, keyed by agent id
pub agents: Arc<DashMap<AgentId, Arc<AISession>>>,
// channel fabric for direct, broadcast, and typed agent messages
pub message_bus: Arc<MessageBus>,
// round-robin task queueing and assignment
pub task_distributor: Arc<TaskDistributor>,
// file locks, rate limits, and shared byte buffers
pub resource_manager: Arc<ResourceManager>,
}
impl Default for MultiAgentSession {
fn default() -> Self {
Self::new()
}
}
impl MultiAgentSession {
    /// Creates an empty session with fresh, unshared infrastructure.
    pub fn new() -> Self {
        let agents = Arc::new(DashMap::new());
        let message_bus = Arc::new(MessageBus::new());
        let task_distributor = Arc::new(TaskDistributor::new());
        let resource_manager = Arc::new(ResourceManager::new());
        Self {
            agents,
            message_bus,
            task_distributor,
            resource_manager,
        }
    }
    /// Stores `session` under `agent_id` and opens its message-bus channels.
    pub fn register_agent(&self, agent_id: AgentId, session: Arc<AISession>) -> Result<()> {
        self.agents.insert(agent_id.clone(), session);
        self.message_bus.register_agent(agent_id)
    }
    /// Drops the agent's session and tears down its message-bus channels.
    pub fn unregister_agent(&self, agent_id: &AgentId) -> Result<()> {
        self.agents.remove(agent_id);
        self.message_bus.unregister_agent(agent_id)
    }
    /// Looks up the session registered under `agent_id`, if any.
    pub fn get_agent(&self, agent_id: &AgentId) -> Option<Arc<AISession>> {
        self.agents
            .get(agent_id)
            .map(|entry| Arc::clone(entry.value()))
    }
    /// Returns the ids of every currently registered agent.
    pub fn list_agents(&self) -> Vec<AgentId> {
        let mut ids = Vec::with_capacity(self.agents.len());
        for entry in self.agents.iter() {
            ids.push(entry.key().clone());
        }
        ids
    }
    /// Forwards a direct message to `to` via the message bus.
    pub async fn send_message(&self, from: AgentId, to: AgentId, message: Message) -> Result<()> {
        self.message_bus.send_message(from, to, message)
    }
    /// Forwards a broadcast to the bus's shared broadcast channel.
    pub async fn broadcast(&self, from: AgentId, message: BroadcastMessage) -> Result<()> {
        self.message_bus.broadcast(from, message)
    }
}
/// Channel fabric between agents: a per-agent direct channel for legacy
/// `Message` traffic, a per-agent `AgentMessage` channel, one shared
/// broadcast channel, and a firehose carrying a copy of every published
/// `AgentMessage`.
pub struct MessageBus {
// direct (legacy `Message`) channel per agent; the receiver is stored here
// so the channel stays connected until the agent unregisters
channels: DashMap<AgentId, (Sender<Message>, Receiver<Message>)>,
broadcast_sender: Sender<BroadcastMessage>,
// kept solely to hold the broadcast channel open
_broadcast_receiver: Receiver<BroadcastMessage>,
// typed `AgentMessage` channel per agent
agent_channels: DashMap<AgentId, (Sender<AgentMessage>, Receiver<AgentMessage>)>,
// `publish_to_agent` mirrors every message here, best-effort
all_messages_sender: Sender<AgentMessage>,
all_messages_receiver: Receiver<AgentMessage>,
}
impl Default for MessageBus {
fn default() -> Self {
Self::new()
}
}
impl MessageBus {
pub fn new() -> Self {
let (broadcast_sender, broadcast_receiver) =
crossbeam_channel::bounded(BROADCAST_CHANNEL_CAPACITY);
let (all_messages_sender, all_messages_receiver) =
crossbeam_channel::bounded(ALL_MESSAGES_CHANNEL_CAPACITY);
Self {
channels: DashMap::new(),
broadcast_sender,
_broadcast_receiver: broadcast_receiver,
agent_channels: DashMap::new(),
all_messages_sender,
all_messages_receiver,
}
}
pub fn register_agent(&self, agent_id: AgentId) -> Result<()> {
let (sender, receiver) = crossbeam_channel::bounded(DEFAULT_CHANNEL_CAPACITY);
self.channels.insert(agent_id.clone(), (sender, receiver));
let (agent_sender, agent_receiver) = crossbeam_channel::bounded(DEFAULT_CHANNEL_CAPACITY);
self.agent_channels
.insert(agent_id, (agent_sender, agent_receiver));
Ok(())
}
pub fn unregister_agent(&self, agent_id: &AgentId) -> Result<()> {
self.channels.remove(agent_id);
self.agent_channels.remove(agent_id);
Ok(())
}
pub fn send_message(&self, _from: AgentId, to: AgentId, message: Message) -> Result<()> {
if let Some(channel) = self.channels.get(&to) {
channel.0.try_send(message).map_err(|e| match e {
crossbeam_channel::TrySendError::Full(_) => {
anyhow::anyhow!("Agent {} channel is full (backpressure)", to)
}
crossbeam_channel::TrySendError::Disconnected(_) => {
anyhow::anyhow!("Agent {} channel disconnected", to)
}
})?;
Ok(())
} else {
Err(anyhow::anyhow!("Agent not found: {}", to))
}
}
pub fn broadcast(&self, _from: AgentId, message: BroadcastMessage) -> Result<()> {
self.broadcast_sender
.try_send(message)
.map_err(|e| match e {
crossbeam_channel::TrySendError::Full(_) => {
anyhow::anyhow!("Broadcast channel is full (backpressure)")
}
crossbeam_channel::TrySendError::Disconnected(_) => {
anyhow::anyhow!("Broadcast channel disconnected")
}
})?;
Ok(())
}
pub fn get_receiver(&self, agent_id: &AgentId) -> Option<Receiver<Message>> {
self.channels.get(agent_id).map(|entry| entry.1.clone())
}
pub fn subscribe_all(&self) -> Receiver<AgentMessage> {
self.all_messages_receiver.clone()
}
pub async fn publish_to_agent(&self, agent_id: &AgentId, message: AgentMessage) -> Result<()> {
if let Some(channel) = self.agent_channels.get(agent_id) {
channel.0.try_send(message.clone()).map_err(|e| match e {
crossbeam_channel::TrySendError::Full(_) => {
anyhow::anyhow!("Agent {} channel is full (backpressure)", agent_id)
}
crossbeam_channel::TrySendError::Disconnected(_) => {
anyhow::anyhow!("Agent {} channel disconnected", agent_id)
}
})?;
} else {
return Err(anyhow::anyhow!("Agent not found: {}", agent_id));
}
let _ = self.all_messages_sender.try_send(message);
Ok(())
}
pub fn get_agent_receiver(&self, agent_id: &AgentId) -> Option<Receiver<AgentMessage>> {
self.agent_channels
.get(agent_id)
.map(|entry| entry.1.clone())
}
}
/// Structured payload of a [`UnifiedMessage`]; a superset of what the legacy
/// `Message` and `AgentMessage` formats can express.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageContent {
/// An agent announcing itself and the capabilities it offers.
Registration {
agent_id: AgentId,
capabilities: Vec<String>,
metadata: serde_json::Value,
},
/// A task handed to a specific agent.
TaskAssignment {
task_id: TaskId,
agent_id: AgentId,
task_data: serde_json::Value,
},
/// An agent reporting a finished task and its result.
TaskCompleted {
agent_id: AgentId,
task_id: TaskId,
result: serde_json::Value,
},
/// Incremental progress on a task.
TaskProgress {
agent_id: AgentId,
task_id: TaskId,
// presumably a 0.0–1.0 fraction (tests use 0.5) — confirm with producers
progress: f32,
message: String,
},
/// An agent asking other agents for assistance.
HelpRequest {
agent_id: AgentId,
context: String,
priority: MessagePriority,
},
/// Free-form status plus metrics from an agent.
StatusUpdate {
agent_id: AgentId,
status: String,
metrics: serde_json::Value,
},
/// Arbitrary shared data.
DataShare { data: serde_json::Value },
/// A coordination request between agents.
CoordinationRequest {
request_type: String,
data: serde_json::Value,
},
/// A reply correlated to an earlier message by its id.
Response {
in_reply_to: Uuid,
data: serde_json::Value,
},
/// Escape hatch for application-defined message types.
Custom {
message_type: String,
data: serde_json::Value,
},
}
/// Envelope unifying the legacy `Message` and `AgentMessage` formats:
/// structured content plus sender, unique id, and timestamp.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnifiedMessage {
pub id: Uuid,
pub from: AgentId,
pub content: MessageContent,
pub timestamp: chrono::DateTime<chrono::Utc>,
}
impl UnifiedMessage {
pub fn new(from: AgentId, content: MessageContent) -> Self {
Self {
id: Uuid::new_v4(),
from,
content,
timestamp: chrono::Utc::now(),
}
}
pub fn from_legacy_message(msg: Message) -> Self {
let content = match msg.message_type {
MessageType::TaskAssignment => MessageContent::Custom {
message_type: "task_assignment".to_string(),
data: msg.payload,
},
MessageType::StatusUpdate => MessageContent::Custom {
message_type: "status_update".to_string(),
data: msg.payload,
},
MessageType::DataShare => MessageContent::DataShare { data: msg.payload },
MessageType::CoordinationRequest => MessageContent::CoordinationRequest {
request_type: "legacy".to_string(),
data: msg.payload,
},
MessageType::Response => MessageContent::Response {
in_reply_to: Uuid::nil(),
data: msg.payload,
},
MessageType::Custom(t) => MessageContent::Custom {
message_type: t,
data: msg.payload,
},
};
Self {
id: msg.id,
from: msg.from,
content,
timestamp: msg.timestamp,
}
}
pub fn from_agent_message(from: AgentId, msg: AgentMessage) -> Self {
let content = match msg {
AgentMessage::Registration {
agent_id,
capabilities,
metadata,
} => MessageContent::Registration {
agent_id,
capabilities,
metadata,
},
AgentMessage::TaskAssignment {
task_id,
agent_id,
task_data,
} => MessageContent::TaskAssignment {
task_id,
agent_id,
task_data,
},
AgentMessage::TaskCompleted {
agent_id,
task_id,
result,
} => MessageContent::TaskCompleted {
agent_id,
task_id,
result,
},
AgentMessage::TaskProgress {
agent_id,
task_id,
progress,
message,
} => MessageContent::TaskProgress {
agent_id,
task_id,
progress,
message,
},
AgentMessage::HelpRequest {
agent_id,
context,
priority,
} => MessageContent::HelpRequest {
agent_id,
context,
priority,
},
AgentMessage::StatusUpdate {
agent_id,
status,
metrics,
} => MessageContent::StatusUpdate {
agent_id,
status,
metrics,
},
AgentMessage::Custom { message_type, data } => {
MessageContent::Custom { message_type, data }
}
};
Self {
id: Uuid::new_v4(),
from,
content,
timestamp: chrono::Utc::now(),
}
}
}
/// Legacy point-to-point message: an opaque JSON payload tagged with a
/// [`MessageType`], sender, id, and timestamp.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
pub id: Uuid,
pub from: AgentId,
pub message_type: MessageType,
pub payload: serde_json::Value,
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// Discriminator for legacy [`Message`] payloads; `Custom` carries an
/// application-defined type name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageType {
TaskAssignment,
StatusUpdate,
DataShare,
CoordinationRequest,
Response,
Custom(String),
}
/// Typed agent-to-agent messages carried on the per-agent `AgentMessage`
/// channels and mirrored on the all-messages firehose.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AgentMessage {
/// An agent announcing itself and the capabilities it offers.
Registration {
agent_id: AgentId,
capabilities: Vec<String>,
metadata: serde_json::Value,
},
/// A task handed to a specific agent.
TaskAssignment {
task_id: TaskId,
agent_id: AgentId,
task_data: serde_json::Value,
},
/// An agent reporting a finished task and its result.
TaskCompleted {
agent_id: AgentId,
task_id: TaskId,
result: serde_json::Value,
},
/// Incremental progress on a task.
TaskProgress {
agent_id: AgentId,
task_id: TaskId,
// presumably a 0.0–1.0 fraction (tests use 0.5) — confirm with producers
progress: f32,
message: String,
},
/// An agent asking other agents for assistance.
HelpRequest {
agent_id: AgentId,
context: String,
priority: MessagePriority,
},
/// Free-form status plus metrics from an agent.
StatusUpdate {
agent_id: AgentId,
status: String,
metrics: serde_json::Value,
},
/// Escape hatch for application-defined message types.
Custom {
message_type: String,
data: serde_json::Value,
},
}
/// A prioritized announcement posted to the bus's shared broadcast channel.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BroadcastMessage {
pub id: Uuid,
pub from: AgentId,
pub content: String,
pub priority: MessagePriority,
pub timestamp: chrono::DateTime<chrono::Utc>,
}
/// Urgency attached to broadcasts and help requests, from `Low` to
/// `Critical`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessagePriority {
Low,
Normal,
High,
Critical,
}
/// Queues submitted tasks and fans them out round-robin across registered
/// agents, recording which agent received which task.
pub struct TaskDistributor {
// tasks submitted but not yet distributed
task_queue: Arc<RwLock<Vec<Task>>>,
// capabilities advertised per agent
// NOTE(review): `distribute_tasks` currently reads only the keys (the agent
// roster); the capability lists themselves are never consulted
agent_capabilities: Arc<DashMap<AgentId, Vec<String>>>,
// task -> agent mapping recorded at distribution time
assignments: Arc<DashMap<TaskId, AgentId>>,
}
impl Default for TaskDistributor {
fn default() -> Self {
Self::new()
}
}
impl TaskDistributor {
    /// Creates a distributor with an empty queue and no known agents.
    pub fn new() -> Self {
        Self {
            task_queue: Arc::new(RwLock::new(Vec::new())),
            agent_capabilities: Arc::new(DashMap::new()),
            assignments: Arc::new(DashMap::new()),
        }
    }
    /// Records (or replaces) the capability list advertised by `agent_id`.
    pub fn register_capabilities(&self, agent_id: AgentId, capabilities: Vec<String>) {
        self.agent_capabilities.insert(agent_id, capabilities);
    }
    /// Appends a task to the pending queue.
    pub async fn submit_task(&self, task: Task) -> Result<()> {
        self.task_queue.write().await.push(task);
        Ok(())
    }
    /// Drains the queue and assigns every pending task to an agent in
    /// round-robin order, recording and returning each assignment. With no
    /// registered agents the queue is left untouched.
    ///
    /// Fix: tasks are now handed out in submission (FIFO) order via
    /// `drain(..)`; the previous implementation used `Vec::pop`, which takes
    /// from the back and therefore distributed tasks newest-first.
    ///
    /// NOTE(review): `Task::required_capabilities` is still not consulted —
    /// assignment is pure round-robin regardless of registered capabilities.
    pub async fn distribute_tasks(&self) -> Result<Vec<(TaskId, AgentId)>> {
        let mut queue = self.task_queue.write().await;
        let agents: Vec<AgentId> = self
            .agent_capabilities
            .iter()
            .map(|entry| entry.key().clone())
            .collect();
        if agents.is_empty() {
            return Ok(Vec::new());
        }
        let mut assignments = Vec::with_capacity(queue.len());
        for (index, task) in queue.drain(..).enumerate() {
            let agent_id = agents[index % agents.len()].clone();
            self.assignments.insert(task.id.clone(), agent_id.clone());
            assignments.push((task.id, agent_id));
        }
        Ok(assignments)
    }
}
/// Unique identifier for a task, backed by a random (v4) UUID.
#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
pub struct TaskId(Uuid);
impl Default for TaskId {
fn default() -> Self {
Self::new()
}
}
impl TaskId {
pub fn new() -> Self {
Self(Uuid::new_v4())
}
}
impl std::fmt::Display for TaskId {
    /// Renders the id exactly like the underlying UUID.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Display::fmt(&self.0, f)
    }
}
/// A unit of work submitted to the [`TaskDistributor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
pub id: TaskId,
pub name: String,
// capabilities an agent should have to handle this task
// NOTE(review): not currently consulted by `distribute_tasks` — confirm intent
pub required_capabilities: Vec<String>,
pub payload: serde_json::Value,
pub priority: TaskPriority,
pub created_at: chrono::DateTime<chrono::Utc>,
}
/// Urgency attached to a [`Task`], from `Low` to `Critical`.
/// NOTE(review): `distribute_tasks` does not currently order by priority.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskPriority {
Low,
Normal,
High,
Critical,
}
/// Shared-resource coordination: exclusive file locks, named rate limits,
/// and a shared byte-buffer store.
pub struct ResourceManager {
// path -> owning agent for exclusive file locks
file_locks: Arc<DashMap<String, AgentId>>,
// resource name -> configured rate limiter
rate_limits: Arc<DashMap<String, RateLimit>>,
// key -> shared binary blob
shared_memory: Arc<DashMap<String, Vec<u8>>>,
}
impl Default for ResourceManager {
fn default() -> Self {
Self::new()
}
}
impl ResourceManager {
    /// Creates a manager with no locks, limits, or shared data.
    pub fn new() -> Self {
        Self {
            file_locks: Arc::new(DashMap::new()),
            rate_limits: Arc::new(DashMap::new()),
            shared_memory: Arc::new(DashMap::new()),
        }
    }
    /// Atomically acquires the exclusive lock on `path` for `agent_id`;
    /// errors if any agent (including the caller) already holds it.
    pub fn acquire_file_lock(&self, path: &str, agent_id: AgentId) -> Result<()> {
        match self.file_locks.entry(path.to_string()) {
            dashmap::mapref::entry::Entry::Occupied(_) => {
                Err(anyhow::anyhow!("File already locked: {}", path))
            }
            dashmap::mapref::entry::Entry::Vacant(entry) => {
                entry.insert(agent_id);
                Ok(())
            }
        }
    }
    /// Releases the lock on `path` if (and only if) `agent_id` owns it.
    ///
    /// Bug fix: ownership is now verified atomically as part of the removal
    /// via `DashMap::remove_if`. The previous code removed the entry first
    /// and checked the owner afterwards, so a non-owner caller received an
    /// error **and** silently stripped the real owner's lock. Releasing a
    /// path that is not locked remains a no-op `Ok(())`.
    pub fn release_file_lock(&self, path: &str, agent_id: &AgentId) -> Result<()> {
        let removed = self
            .file_locks
            .remove_if(path, |_, owner| owner == agent_id)
            .is_some();
        if !removed && self.file_locks.contains_key(path) {
            return Err(anyhow::anyhow!("Not the lock owner"));
        }
        Ok(())
    }
    /// Returns whether a request against `resource` may proceed; resources
    /// with no configured limit are unrestricted. A successful check
    /// consumes one slot from the limiter's current window.
    pub fn check_rate_limit(&self, resource: &str) -> bool {
        if let Some(limit) = self.rate_limits.get(resource) {
            limit.can_proceed()
        } else {
            true
        }
    }
    /// Installs (or replaces) a limit of `max_requests` per `interval` for
    /// `resource`.
    pub fn set_rate_limit(
        &self,
        resource: &str,
        max_requests: usize,
        interval: std::time::Duration,
    ) {
        self.rate_limits
            .insert(resource.to_string(), RateLimit::new(max_requests, interval));
    }
    /// Remaining requests in the current window, or `None` if `resource`
    /// has no configured limit.
    pub fn rate_limit_remaining(&self, resource: &str) -> Option<usize> {
        self.rate_limits
            .get(resource)
            .map(|limit| limit.remaining())
    }
    /// Stores `data` under `key`, replacing any previous value.
    pub fn write_shared_memory(&self, key: &str, data: Vec<u8>) {
        self.shared_memory.insert(key.to_string(), data);
    }
    /// Returns a copy of the data stored under `key`, if any.
    pub fn read_shared_memory(&self, key: &str) -> Option<Vec<u8>> {
        self.shared_memory.get(key).map(|entry| entry.clone())
    }
}
/// A fixed-window rate limiter: at most `max_requests` admissions per
/// `interval`. Clones share the same counters through `Arc`, so a cloned
/// limiter observes and consumes the same budget.
#[derive(Debug, Clone)]
pub struct RateLimit {
    pub max_requests: usize,
    pub interval: std::time::Duration,
    current_count: Arc<std::sync::atomic::AtomicUsize>,
    last_reset_nanos: Arc<std::sync::atomic::AtomicU64>,
}
impl RateLimit {
    /// Nanoseconds since the Unix epoch, falling back to 0 if the system
    /// clock reads before the epoch.
    fn unix_nanos() -> u64 {
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos() as u64)
            .unwrap_or(0)
    }
    /// Creates a limiter whose first window starts now with a zero count.
    pub fn new(max_requests: usize, interval: std::time::Duration) -> Self {
        RateLimit {
            max_requests,
            interval,
            current_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            last_reset_nanos: Arc::new(std::sync::atomic::AtomicU64::new(Self::unix_nanos())),
        }
    }
    /// Attempts to consume one slot in the current window, returning whether
    /// the request may proceed.
    pub fn can_proceed(&self) -> bool {
        use std::sync::atomic::Ordering;
        let now = Self::unix_nanos();
        let window_start = self.last_reset_nanos.load(Ordering::Acquire);
        // Open a new window once the old one has fully elapsed; the CAS
        // makes sure exactly one caller performs the reset.
        if now.saturating_sub(window_start) >= self.interval.as_nanos() as u64
            && self
                .last_reset_nanos
                .compare_exchange(window_start, now, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
        {
            self.current_count.store(0, Ordering::Release);
        }
        // Optimistically claim a slot; give it back when over the limit.
        let claimed = self.current_count.fetch_add(1, Ordering::AcqRel);
        if claimed >= self.max_requests {
            self.current_count.fetch_sub(1, Ordering::AcqRel);
            return false;
        }
        true
    }
    /// Slots consumed so far in the current window.
    pub fn current_count(&self) -> usize {
        self.current_count
            .load(std::sync::atomic::Ordering::Acquire)
    }
    /// Slots still available in the current window.
    pub fn remaining(&self) -> usize {
        self.max_requests.saturating_sub(self.current_count())
    }
    /// Manually starts a fresh window with a zero count.
    pub fn reset(&self) {
        use std::sync::atomic::Ordering;
        self.current_count.store(0, Ordering::Release);
        self.last_reset_nanos.store(Self::unix_nanos(), Ordering::Release);
    }
}
#[cfg(test)]
mod tests {
use super::*;
// A fresh session starts with no registered agents.
#[test]
fn test_multi_agent_session() {
let multi_session = MultiAgentSession::new();
let _agent_id = AgentId::new();
assert_eq!(multi_session.list_agents().len(), 0);
}
// A direct message sent to agent2 is retrievable from agent2's receiver.
#[test]
fn test_message_bus() {
let bus = MessageBus::new();
let agent1 = AgentId::new();
let agent2 = AgentId::new();
bus.register_agent(agent1.clone()).unwrap();
bus.register_agent(agent2.clone()).unwrap();
let message = Message {
id: Uuid::new_v4(),
from: agent1.clone(),
message_type: MessageType::StatusUpdate,
payload: serde_json::json!({"status": "ready"}),
timestamp: chrono::Utc::now(),
};
bus.send_message(agent1, agent2.clone(), message).unwrap();
if let Some(receiver) = bus.get_receiver(&agent2) {
assert!(receiver.try_recv().is_ok());
}
}
// publish_to_agent delivers to the target agent's typed channel AND mirrors
// the same message onto the all-messages firehose.
#[tokio::test]
async fn test_agent_message_publish() {
let bus = MessageBus::new();
let agent1 = AgentId::new();
let agent2 = AgentId::new();
bus.register_agent(agent1.clone()).unwrap();
bus.register_agent(agent2.clone()).unwrap();
let all_receiver = bus.subscribe_all();
let registration_msg = AgentMessage::Registration {
agent_id: agent1.clone(),
capabilities: vec!["frontend".to_string(), "react".to_string()],
metadata: serde_json::json!({"version": "1.0"}),
};
bus.publish_to_agent(&agent2, registration_msg.clone())
.await
.unwrap();
if let Some(receiver) = bus.get_agent_receiver(&agent2) {
let received = receiver.try_recv().unwrap();
match received {
AgentMessage::Registration { agent_id, .. } => {
assert_eq!(agent_id, agent1);
}
_ => panic!("Unexpected message type"),
}
}
let all_msg = all_receiver.try_recv().unwrap();
match all_msg {
AgentMessage::Registration { agent_id, .. } => {
assert_eq!(agent_id, agent1);
}
_ => panic!("Unexpected message type"),
}
}
// Every AgentMessage variant passes through publish_to_agent and arrives on
// the target agent's channel (7 variants in, 7 messages out).
#[tokio::test]
async fn test_all_agent_message_variants() {
let bus = MessageBus::new();
let agent1 = AgentId::new();
bus.register_agent(agent1.clone()).unwrap();
let messages = vec![
AgentMessage::Registration {
agent_id: agent1.clone(),
capabilities: vec!["test".to_string()],
metadata: serde_json::json!({}),
},
AgentMessage::TaskAssignment {
task_id: TaskId::new(),
agent_id: agent1.clone(),
task_data: serde_json::json!({"task": "test"}),
},
AgentMessage::TaskCompleted {
agent_id: agent1.clone(),
task_id: TaskId::new(),
result: serde_json::json!({"success": true}),
},
AgentMessage::TaskProgress {
agent_id: agent1.clone(),
task_id: TaskId::new(),
progress: 0.5,
message: "Halfway done".to_string(),
},
AgentMessage::HelpRequest {
agent_id: agent1.clone(),
context: "Need help with React".to_string(),
priority: MessagePriority::High,
},
AgentMessage::StatusUpdate {
agent_id: agent1.clone(),
status: "active".to_string(),
metrics: serde_json::json!({"cpu": 50, "memory": 1024}),
},
AgentMessage::Custom {
message_type: "test_message".to_string(),
data: serde_json::json!({"foo": "bar"}),
},
];
for msg in messages {
bus.publish_to_agent(&agent1, msg).await.unwrap();
}
if let Some(receiver) = bus.get_agent_receiver(&agent1) {
let mut count = 0;
while receiver.try_recv().is_ok() {
count += 1;
}
assert_eq!(count, 7); }
}
// Exactly `max_requests` admissions succeed within one window.
#[test]
fn test_rate_limit_basic() {
let limit = RateLimit::new(3, std::time::Duration::from_secs(60));
assert!(limit.can_proceed());
assert!(limit.can_proceed());
assert!(limit.can_proceed());
assert!(!limit.can_proceed());
assert_eq!(limit.current_count(), 3);
assert_eq!(limit.remaining(), 0);
}
// reset() opens a fresh window so admissions resume immediately.
#[test]
fn test_rate_limit_reset() {
let limit = RateLimit::new(2, std::time::Duration::from_secs(60));
assert!(limit.can_proceed());
assert!(limit.can_proceed());
assert!(!limit.can_proceed());
limit.reset();
assert!(limit.can_proceed());
assert_eq!(limit.current_count(), 1);
}
// Resources without a configured limit are unlimited; once configured, the
// limit is enforced and each successful check consumes a slot.
#[test]
fn test_resource_manager_rate_limit() {
let manager = ResourceManager::new();
assert!(manager.check_rate_limit("api"));
manager.set_rate_limit("api", 2, std::time::Duration::from_secs(60));
assert!(manager.check_rate_limit("api"));
assert!(manager.check_rate_limit("api"));
assert!(!manager.check_rate_limit("api"));
assert_eq!(manager.rate_limit_remaining("api"), Some(0));
}
}