1use anyhow::Result;
4use crossbeam_channel::{Receiver, Sender};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use crate::core::AISession;
12
13#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
15pub struct AgentId(Uuid);
16
17impl Default for AgentId {
18 fn default() -> Self {
19 Self::new()
20 }
21}
22
23impl AgentId {
24 pub fn new() -> Self {
26 Self(Uuid::new_v4())
27 }
28}
29
30impl std::fmt::Display for AgentId {
31 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32 write!(f, "{}", self.0)
33 }
34}
35
36pub struct MultiAgentSession {
38 pub agents: Arc<DashMap<AgentId, Arc<AISession>>>,
40 pub message_bus: Arc<MessageBus>,
42 pub task_distributor: Arc<TaskDistributor>,
44 pub resource_manager: Arc<ResourceManager>,
46}
47
48impl Default for MultiAgentSession {
49 fn default() -> Self {
50 Self::new()
51 }
52}
53
54impl MultiAgentSession {
55 pub fn new() -> Self {
57 Self {
58 agents: Arc::new(DashMap::new()),
59 message_bus: Arc::new(MessageBus::new()),
60 task_distributor: Arc::new(TaskDistributor::new()),
61 resource_manager: Arc::new(ResourceManager::new()),
62 }
63 }
64
65 pub fn register_agent(&self, agent_id: AgentId, session: Arc<AISession>) -> Result<()> {
67 self.agents.insert(agent_id.clone(), session);
68 self.message_bus.register_agent(agent_id)?;
69 Ok(())
70 }
71
72 pub fn unregister_agent(&self, agent_id: &AgentId) -> Result<()> {
74 self.agents.remove(agent_id);
75 self.message_bus.unregister_agent(agent_id)?;
76 Ok(())
77 }
78
79 pub fn get_agent(&self, agent_id: &AgentId) -> Option<Arc<AISession>> {
81 self.agents.get(agent_id).map(|entry| entry.clone())
82 }
83
84 pub fn list_agents(&self) -> Vec<AgentId> {
86 self.agents
87 .iter()
88 .map(|entry| entry.key().clone())
89 .collect()
90 }
91
92 pub async fn send_message(&self, from: AgentId, to: AgentId, message: Message) -> Result<()> {
94 self.message_bus.send_message(from, to, message)
95 }
96
97 pub async fn broadcast(&self, from: AgentId, message: BroadcastMessage) -> Result<()> {
99 self.message_bus.broadcast(from, message)
100 }
101}
102
103pub struct MessageBus {
105 channels: DashMap<AgentId, (Sender<Message>, Receiver<Message>)>,
107 broadcast_sender: Sender<BroadcastMessage>,
109 _broadcast_receiver: Receiver<BroadcastMessage>,
110 agent_channels: DashMap<AgentId, (Sender<AgentMessage>, Receiver<AgentMessage>)>,
112 all_messages_sender: Sender<AgentMessage>,
114 all_messages_receiver: Receiver<AgentMessage>,
115}
116
117impl Default for MessageBus {
118 fn default() -> Self {
119 Self::new()
120 }
121}
122
123impl MessageBus {
124 pub fn new() -> Self {
126 let (broadcast_sender, broadcast_receiver) = crossbeam_channel::unbounded();
127 let (all_messages_sender, all_messages_receiver) = crossbeam_channel::unbounded();
128 Self {
129 channels: DashMap::new(),
130 broadcast_sender,
131 _broadcast_receiver: broadcast_receiver,
132 agent_channels: DashMap::new(),
133 all_messages_sender,
134 all_messages_receiver,
135 }
136 }
137
138 pub fn register_agent(&self, agent_id: AgentId) -> Result<()> {
140 let (sender, receiver) = crossbeam_channel::unbounded();
141 self.channels.insert(agent_id.clone(), (sender, receiver));
142
143 let (agent_sender, agent_receiver) = crossbeam_channel::unbounded();
145 self.agent_channels
146 .insert(agent_id, (agent_sender, agent_receiver));
147 Ok(())
148 }
149
150 pub fn unregister_agent(&self, agent_id: &AgentId) -> Result<()> {
152 self.channels.remove(agent_id);
153 self.agent_channels.remove(agent_id);
154 Ok(())
155 }
156
157 pub fn send_message(&self, _from: AgentId, to: AgentId, message: Message) -> Result<()> {
159 if let Some(channel) = self.channels.get(&to) {
160 channel.0.send(message)?;
161 Ok(())
162 } else {
163 Err(anyhow::anyhow!("Agent not found: {}", to))
164 }
165 }
166
167 pub fn broadcast(&self, _from: AgentId, message: BroadcastMessage) -> Result<()> {
169 self.broadcast_sender.send(message)?;
170 Ok(())
171 }
172
173 pub fn get_receiver(&self, agent_id: &AgentId) -> Option<Receiver<Message>> {
175 self.channels.get(agent_id).map(|entry| entry.1.clone())
176 }
177
178 pub fn subscribe_all(&self) -> Receiver<AgentMessage> {
180 self.all_messages_receiver.clone()
181 }
182
183 pub async fn publish_to_agent(&self, agent_id: &AgentId, message: AgentMessage) -> Result<()> {
185 if let Some(channel) = self.agent_channels.get(agent_id) {
187 channel.0.send(message.clone())?;
188 } else {
189 return Err(anyhow::anyhow!("Agent not found: {}", agent_id));
190 }
191
192 self.all_messages_sender.send(message)?;
194
195 Ok(())
196 }
197
198 pub fn get_agent_receiver(&self, agent_id: &AgentId) -> Option<Receiver<AgentMessage>> {
200 self.agent_channels
201 .get(agent_id)
202 .map(|entry| entry.1.clone())
203 }
204}
205
/// A point-to-point message placed in a single agent's inbox by `MessageBus::send_message`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Identifier of this message.
    pub id: Uuid,
    /// Agent the message claims to come from; the bus does not check it against the caller.
    pub from: AgentId,
    /// Category of the message; the bus does not inspect it.
    pub message_type: MessageType,
    /// Free-form JSON body. Its shape presumably depends on `message_type` — confirm with producers.
    pub payload: serde_json::Value,
    /// Timestamp supplied by the sender (UTC); the bus neither sets nor checks it.
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
220
/// Category of a [`Message`].
///
/// Nothing in the message bus interprets these values. Their meaning is presumably a
/// convention between agents.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MessageType {
    /// Work being handed to the receiving agent.
    TaskAssignment,
    /// A status report.
    StatusUpdate,
    /// Data being shared with the receiver.
    DataShare,
    /// A request to coordinate with the receiver.
    CoordinationRequest,
    /// A reply to an earlier message.
    Response,
    /// Application-defined category identified by a free-form name.
    Custom(String),
}
237
/// Structured event exchanged between agents.
///
/// `MessageBus::publish_to_agent` places one in a single agent's inbox. It also copies the
/// event onto the bus-wide stream returned by `MessageBus::subscribe_all`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum AgentMessage {
    /// An agent announcing itself and what it can do.
    Registration {
        /// The agent being registered.
        agent_id: AgentId,
        /// Capability names the agent advertises (tests use e.g. `"frontend"`, `"react"`).
        capabilities: Vec<String>,
        /// Free-form JSON metadata.
        metadata: serde_json::Value,
    },
    /// A task handed to an agent.
    TaskAssignment {
        /// Task being assigned.
        task_id: TaskId,
        /// Agent receiving the task.
        agent_id: AgentId,
        /// Free-form JSON describing the work.
        task_data: serde_json::Value,
    },
    /// An agent reporting that a task finished.
    TaskCompleted {
        /// Agent that ran the task.
        agent_id: AgentId,
        /// Task that finished.
        task_id: TaskId,
        /// Free-form JSON result.
        result: serde_json::Value,
    },
    /// Intermediate progress on a task.
    TaskProgress {
        /// Agent running the task.
        agent_id: AgentId,
        /// Task in progress.
        task_id: TaskId,
        /// Presumably a fraction in `0.0..=1.0` (tests use `0.5`) — confirm with producers.
        progress: f32,
        /// Progress note.
        message: String,
    },
    /// An agent asking for assistance.
    HelpRequest {
        /// Agent asking for help.
        agent_id: AgentId,
        /// Description of the help needed.
        context: String,
        /// Urgency of the request.
        priority: MessagePriority,
    },
    /// A status report from an agent.
    StatusUpdate {
        /// Reporting agent.
        agent_id: AgentId,
        /// Status label (tests use `"active"`).
        status: String,
        /// Free-form JSON metrics.
        metrics: serde_json::Value,
    },
    /// Application-defined event.
    Custom {
        /// Free-form name identifying the event kind.
        message_type: String,
        /// Free-form JSON body.
        data: serde_json::Value,
    },
}
284
/// A message pushed onto the bus-wide broadcast stream by `MessageBus::broadcast`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BroadcastMessage {
    /// Identifier of this broadcast.
    pub id: Uuid,
    /// Agent the broadcast claims to come from; the bus does not check it.
    pub from: AgentId,
    /// Text of the broadcast.
    pub content: String,
    /// Urgency of the broadcast.
    pub priority: MessagePriority,
    /// Timestamp supplied by the sender (UTC).
    pub timestamp: chrono::DateTime<chrono::Utc>,
}
299
/// Urgency attached to `AgentMessage::HelpRequest` and [`BroadcastMessage`].
///
/// Only equality is derived (no `PartialOrd`/`Ord`), so values can be compared with `==`
/// but not ordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MessagePriority {
    Low,
    Normal,
    High,
    Critical,
}
308
309pub struct TaskDistributor {
311 task_queue: Arc<RwLock<Vec<Task>>>,
313 agent_capabilities: Arc<DashMap<AgentId, Vec<String>>>,
315 assignments: Arc<DashMap<TaskId, AgentId>>,
317}
318
319impl Default for TaskDistributor {
320 fn default() -> Self {
321 Self::new()
322 }
323}
324
325impl TaskDistributor {
326 pub fn new() -> Self {
328 Self {
329 task_queue: Arc::new(RwLock::new(Vec::new())),
330 agent_capabilities: Arc::new(DashMap::new()),
331 assignments: Arc::new(DashMap::new()),
332 }
333 }
334
335 pub fn register_capabilities(&self, agent_id: AgentId, capabilities: Vec<String>) {
337 self.agent_capabilities.insert(agent_id, capabilities);
338 }
339
340 pub async fn submit_task(&self, task: Task) -> Result<()> {
342 self.task_queue.write().await.push(task);
343 Ok(())
344 }
345
346 pub async fn distribute_tasks(&self) -> Result<Vec<(TaskId, AgentId)>> {
348 let mut assignments = Vec::new();
349 let mut queue = self.task_queue.write().await;
350
351 let agents: Vec<AgentId> = self
354 .agent_capabilities
355 .iter()
356 .map(|entry| entry.key().clone())
357 .collect();
358
359 if agents.is_empty() {
360 return Ok(assignments);
361 }
362
363 let mut agent_index = 0;
364 while let Some(task) = queue.pop() {
365 let agent_id = &agents[agent_index % agents.len()];
366 self.assignments.insert(task.id.clone(), agent_id.clone());
367 assignments.push((task.id, agent_id.clone()));
368 agent_index += 1;
369 }
370
371 Ok(assignments)
372 }
373}
374
375#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
377pub struct TaskId(Uuid);
378
379impl Default for TaskId {
380 fn default() -> Self {
381 Self::new()
382 }
383}
384
385impl TaskId {
386 pub fn new() -> Self {
388 Self(Uuid::new_v4())
389 }
390}
391
392impl std::fmt::Display for TaskId {
393 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394 write!(f, "{}", self.0)
395 }
396}
397
/// A unit of work submitted to a `TaskDistributor`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Task {
    /// Identifier of the task; used as the key of the distributor's assignment map.
    pub id: TaskId,
    /// Name of the task.
    pub name: String,
    /// Capability names an agent is expected to have in order to run this task.
    pub required_capabilities: Vec<String>,
    /// Free-form JSON describing the work.
    pub payload: serde_json::Value,
    /// Urgency of the task.
    pub priority: TaskPriority,
    /// Creation timestamp supplied by the submitter (UTC).
    pub created_at: chrono::DateTime<chrono::Utc>,
}
414
/// Urgency of a [`Task`].
///
/// Only equality is derived (no `PartialOrd`/`Ord`). Code that needs an ordering must map
/// the variants explicitly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskPriority {
    Low,
    Normal,
    High,
    Critical,
}
423
424pub struct ResourceManager {
426 file_locks: Arc<DashMap<String, AgentId>>,
428 rate_limits: Arc<DashMap<String, RateLimit>>,
430 shared_memory: Arc<DashMap<String, Vec<u8>>>,
432}
433
434impl Default for ResourceManager {
435 fn default() -> Self {
436 Self::new()
437 }
438}
439
440impl ResourceManager {
441 pub fn new() -> Self {
443 Self {
444 file_locks: Arc::new(DashMap::new()),
445 rate_limits: Arc::new(DashMap::new()),
446 shared_memory: Arc::new(DashMap::new()),
447 }
448 }
449
450 pub fn acquire_file_lock(&self, path: &str, agent_id: AgentId) -> Result<()> {
452 match self.file_locks.entry(path.to_string()) {
453 dashmap::mapref::entry::Entry::Occupied(_) => {
454 Err(anyhow::anyhow!("File already locked: {}", path))
455 }
456 dashmap::mapref::entry::Entry::Vacant(entry) => {
457 entry.insert(agent_id);
458 Ok(())
459 }
460 }
461 }
462
463 pub fn release_file_lock(&self, path: &str, agent_id: &AgentId) -> Result<()> {
465 if let Some((_, owner)) = self.file_locks.remove(path)
466 && owner != *agent_id
467 {
468 return Err(anyhow::anyhow!("Not the lock owner"));
469 }
470 Ok(())
471 }
472
473 pub fn check_rate_limit(&self, resource: &str) -> bool {
475 if let Some(limit) = self.rate_limits.get(resource) {
476 limit.can_proceed()
477 } else {
478 true
479 }
480 }
481
482 pub fn write_shared_memory(&self, key: &str, data: Vec<u8>) {
484 self.shared_memory.insert(key.to_string(), data);
485 }
486
487 pub fn read_shared_memory(&self, key: &str) -> Option<Vec<u8>> {
489 self.shared_memory.get(key).map(|entry| entry.clone())
490 }
491}
492
493#[derive(Debug, Clone)]
495pub struct RateLimit {
496 pub max_requests: usize,
498 pub interval: std::time::Duration,
500 pub current_count: Arc<RwLock<usize>>,
502 pub last_reset: Arc<RwLock<std::time::Instant>>,
504}
505
506impl RateLimit {
507 pub fn can_proceed(&self) -> bool {
509 true
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a status-update `Message` from `from`.
    fn status_message(from: &AgentId) -> Message {
        Message {
            id: Uuid::new_v4(),
            from: from.clone(),
            message_type: MessageType::StatusUpdate,
            payload: serde_json::json!({"status": "ready"}),
            timestamp: chrono::Utc::now(),
        }
    }

    #[test]
    fn test_multi_agent_session() {
        let multi_session = MultiAgentSession::new();

        assert!(multi_session.list_agents().is_empty());
        assert!(multi_session.get_agent(&AgentId::new()).is_none());
    }

    #[test]
    fn test_message_bus() {
        let bus = MessageBus::new();
        let agent1 = AgentId::new();
        let agent2 = AgentId::new();

        bus.register_agent(agent1.clone()).unwrap();
        bus.register_agent(agent2.clone()).unwrap();

        let message = status_message(&agent1);
        let message_id = message.id;
        bus.send_message(agent1.clone(), agent2.clone(), message).unwrap();

        // `expect` instead of `if let`, so a missing inbox fails the test rather than skipping it.
        let inbox = bus.get_receiver(&agent2).expect("agent2 is registered");
        let received = inbox
            .try_recv()
            .expect("message should be delivered to agent2");
        assert_eq!(received.id, message_id);
        assert_eq!(received.from, agent1);

        let sender_inbox = bus.get_receiver(&agent1).expect("agent1 is registered");
        assert!(sender_inbox.try_recv().is_err(), "sender's inbox must stay empty");
    }

    #[test]
    fn test_send_to_unregistered_agent_fails() {
        let bus = MessageBus::new();
        let agent1 = AgentId::new();
        bus.register_agent(agent1.clone()).unwrap();

        let message = status_message(&agent1);
        assert!(bus.send_message(agent1, AgentId::new(), message).is_err());
    }

    #[tokio::test]
    async fn test_agent_message_publish() {
        let bus = MessageBus::new();
        let agent1 = AgentId::new();
        let agent2 = AgentId::new();

        bus.register_agent(agent1.clone()).unwrap();
        bus.register_agent(agent2.clone()).unwrap();

        let all_receiver = bus.subscribe_all();

        let registration_msg = AgentMessage::Registration {
            agent_id: agent1.clone(),
            capabilities: vec!["frontend".to_string(), "react".to_string()],
            metadata: serde_json::json!({"version": "1.0"}),
        };

        bus.publish_to_agent(&agent2, registration_msg.clone())
            .await
            .unwrap();

        let receiver = bus
            .get_agent_receiver(&agent2)
            .expect("agent2 is registered");
        match receiver.try_recv().expect("agent2 should receive the message") {
            AgentMessage::Registration { agent_id, .. } => assert_eq!(agent_id, agent1),
            _ => panic!("Unexpected message type"),
        }

        match all_receiver
            .try_recv()
            .expect("bus-wide stream should receive a copy")
        {
            AgentMessage::Registration { agent_id, .. } => assert_eq!(agent_id, agent1),
            _ => panic!("Unexpected message type"),
        }
    }

    #[tokio::test]
    async fn test_publish_to_unregistered_agent_fails() {
        let bus = MessageBus::new();
        let message = AgentMessage::Custom {
            message_type: "noop".to_string(),
            data: serde_json::json!({}),
        };

        assert!(bus.publish_to_agent(&AgentId::new(), message).await.is_err());
    }

    #[tokio::test]
    async fn test_all_agent_message_variants() {
        let bus = MessageBus::new();
        let agent1 = AgentId::new();
        bus.register_agent(agent1.clone()).unwrap();

        let messages = vec![
            AgentMessage::Registration {
                agent_id: agent1.clone(),
                capabilities: vec!["test".to_string()],
                metadata: serde_json::json!({}),
            },
            AgentMessage::TaskAssignment {
                task_id: TaskId::new(),
                agent_id: agent1.clone(),
                task_data: serde_json::json!({"task": "test"}),
            },
            AgentMessage::TaskCompleted {
                agent_id: agent1.clone(),
                task_id: TaskId::new(),
                result: serde_json::json!({"success": true}),
            },
            AgentMessage::TaskProgress {
                agent_id: agent1.clone(),
                task_id: TaskId::new(),
                progress: 0.5,
                message: "Halfway done".to_string(),
            },
            AgentMessage::HelpRequest {
                agent_id: agent1.clone(),
                context: "Need help with React".to_string(),
                priority: MessagePriority::High,
            },
            AgentMessage::StatusUpdate {
                agent_id: agent1.clone(),
                status: "active".to_string(),
                metrics: serde_json::json!({"cpu": 50, "memory": 1024}),
            },
            AgentMessage::Custom {
                message_type: "test_message".to_string(),
                data: serde_json::json!({"foo": "bar"}),
            },
        ];
        let expected = messages.len();

        for msg in messages {
            bus.publish_to_agent(&agent1, msg).await.unwrap();
        }

        let receiver = bus
            .get_agent_receiver(&agent1)
            .expect("agent1 is registered");
        assert_eq!(receiver.try_iter().count(), expected);
    }

    #[test]
    fn test_file_lock_is_exclusive() {
        let resources = ResourceManager::new();
        let owner = AgentId::new();
        let other = AgentId::new();

        resources
            .acquire_file_lock("src/main.rs", owner.clone())
            .unwrap();
        assert!(resources
            .acquire_file_lock("src/main.rs", other.clone())
            .is_err());
        assert!(resources.release_file_lock("src/main.rs", &other).is_err());

        resources.release_file_lock("src/main.rs", &owner).unwrap();
        resources.acquire_file_lock("src/main.rs", other).unwrap();
    }

    #[tokio::test]
    async fn test_distribute_assigns_unconstrained_task() {
        let distributor = TaskDistributor::new();
        let agent = AgentId::new();
        distributor.register_capabilities(agent.clone(), vec!["frontend".to_string()]);

        let task = Task {
            id: TaskId::new(),
            name: "render".to_string(),
            required_capabilities: Vec::new(),
            payload: serde_json::json!({}),
            priority: TaskPriority::Normal,
            created_at: chrono::Utc::now(),
        };
        let task_id = task.id.clone();
        distributor.submit_task(task).await.unwrap();

        let assigned = distributor.distribute_tasks().await.unwrap();
        assert_eq!(assigned, vec![(task_id, agent)]);
        assert!(distributor.distribute_tasks().await.unwrap().is_empty());
    }
}