Skip to main content

ai_session/coordination/
mod.rs

1//! Multi-agent coordination functionality
2
3use anyhow::Result;
4use crossbeam_channel::{Receiver, Sender};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use crate::core::AISession;
12
13/// Agent identifier
14#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
15pub struct AgentId(Uuid);
16
17impl Default for AgentId {
18    fn default() -> Self {
19        Self::new()
20    }
21}
22
23impl AgentId {
24    /// Create a new agent ID
25    pub fn new() -> Self {
26        Self(Uuid::new_v4())
27    }
28}
29
30impl std::fmt::Display for AgentId {
31    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
32        write!(f, "{}", self.0)
33    }
34}
35
36/// Multi-agent session coordinator
37pub struct MultiAgentSession {
38    /// Active agent sessions
39    pub agents: Arc<DashMap<AgentId, Arc<AISession>>>,
40    /// Message bus for inter-agent communication
41    pub message_bus: Arc<MessageBus>,
42    /// Task distributor
43    pub task_distributor: Arc<TaskDistributor>,
44    /// Resource manager
45    pub resource_manager: Arc<ResourceManager>,
46}
47
48impl Default for MultiAgentSession {
49    fn default() -> Self {
50        Self::new()
51    }
52}
53
54impl MultiAgentSession {
55    /// Create a new multi-agent session
56    pub fn new() -> Self {
57        Self {
58            agents: Arc::new(DashMap::new()),
59            message_bus: Arc::new(MessageBus::new()),
60            task_distributor: Arc::new(TaskDistributor::new()),
61            resource_manager: Arc::new(ResourceManager::new()),
62        }
63    }
64
65    /// Register an agent
66    pub fn register_agent(&self, agent_id: AgentId, session: Arc<AISession>) -> Result<()> {
67        self.agents.insert(agent_id.clone(), session);
68        self.message_bus.register_agent(agent_id)?;
69        Ok(())
70    }
71
72    /// Unregister an agent
73    pub fn unregister_agent(&self, agent_id: &AgentId) -> Result<()> {
74        self.agents.remove(agent_id);
75        self.message_bus.unregister_agent(agent_id)?;
76        Ok(())
77    }
78
79    /// Get an agent session
80    pub fn get_agent(&self, agent_id: &AgentId) -> Option<Arc<AISession>> {
81        self.agents.get(agent_id).map(|entry| entry.clone())
82    }
83
84    /// List all agents
85    pub fn list_agents(&self) -> Vec<AgentId> {
86        self.agents
87            .iter()
88            .map(|entry| entry.key().clone())
89            .collect()
90    }
91
92    /// Send a message to an agent
93    pub async fn send_message(&self, from: AgentId, to: AgentId, message: Message) -> Result<()> {
94        self.message_bus.send_message(from, to, message)
95    }
96
97    /// Broadcast a message to all agents
98    pub async fn broadcast(&self, from: AgentId, message: BroadcastMessage) -> Result<()> {
99        self.message_bus.broadcast(from, message)
100    }
101}
102
103/// Message bus for inter-agent communication
104pub struct MessageBus {
105    /// Message channels for each agent
106    channels: DashMap<AgentId, (Sender<Message>, Receiver<Message>)>,
107    /// Broadcast channel
108    broadcast_sender: Sender<BroadcastMessage>,
109    _broadcast_receiver: Receiver<BroadcastMessage>,
110    /// Agent message channels for ccswarm integration
111    agent_channels: DashMap<AgentId, (Sender<AgentMessage>, Receiver<AgentMessage>)>,
112    /// All messages channel for monitoring
113    all_messages_sender: Sender<AgentMessage>,
114    all_messages_receiver: Receiver<AgentMessage>,
115}
116
117impl Default for MessageBus {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
123impl MessageBus {
124    /// Create a new message bus
125    pub fn new() -> Self {
126        let (broadcast_sender, broadcast_receiver) = crossbeam_channel::unbounded();
127        let (all_messages_sender, all_messages_receiver) = crossbeam_channel::unbounded();
128        Self {
129            channels: DashMap::new(),
130            broadcast_sender,
131            _broadcast_receiver: broadcast_receiver,
132            agent_channels: DashMap::new(),
133            all_messages_sender,
134            all_messages_receiver,
135        }
136    }
137
138    /// Register an agent
139    pub fn register_agent(&self, agent_id: AgentId) -> Result<()> {
140        let (sender, receiver) = crossbeam_channel::unbounded();
141        self.channels.insert(agent_id.clone(), (sender, receiver));
142
143        // Also register agent message channel
144        let (agent_sender, agent_receiver) = crossbeam_channel::unbounded();
145        self.agent_channels
146            .insert(agent_id, (agent_sender, agent_receiver));
147        Ok(())
148    }
149
150    /// Unregister an agent
151    pub fn unregister_agent(&self, agent_id: &AgentId) -> Result<()> {
152        self.channels.remove(agent_id);
153        self.agent_channels.remove(agent_id);
154        Ok(())
155    }
156
157    /// Send a message to a specific agent
158    pub fn send_message(&self, _from: AgentId, to: AgentId, message: Message) -> Result<()> {
159        if let Some(channel) = self.channels.get(&to) {
160            channel.0.send(message)?;
161            Ok(())
162        } else {
163            Err(anyhow::anyhow!("Agent not found: {}", to))
164        }
165    }
166
167    /// Broadcast a message to all agents
168    pub fn broadcast(&self, _from: AgentId, message: BroadcastMessage) -> Result<()> {
169        self.broadcast_sender.send(message)?;
170        Ok(())
171    }
172
173    /// Get receiver for an agent
174    pub fn get_receiver(&self, agent_id: &AgentId) -> Option<Receiver<Message>> {
175        self.channels.get(agent_id).map(|entry| entry.1.clone())
176    }
177
178    /// Subscribe to all messages (for monitoring)
179    pub fn subscribe_all(&self) -> Receiver<AgentMessage> {
180        self.all_messages_receiver.clone()
181    }
182
183    /// Publish a message to a specific agent
184    pub async fn publish_to_agent(&self, agent_id: &AgentId, message: AgentMessage) -> Result<()> {
185        // Send to the specific agent
186        if let Some(channel) = self.agent_channels.get(agent_id) {
187            channel.0.send(message.clone())?;
188        } else {
189            return Err(anyhow::anyhow!("Agent not found: {}", agent_id));
190        }
191
192        // Also send to the all messages channel for monitoring
193        self.all_messages_sender.send(message)?;
194
195        Ok(())
196    }
197
198    /// Get agent message receiver for a specific agent
199    pub fn get_agent_receiver(&self, agent_id: &AgentId) -> Option<Receiver<AgentMessage>> {
200        self.agent_channels
201            .get(agent_id)
202            .map(|entry| entry.1.clone())
203    }
204}
205
206/// Inter-agent message
207#[derive(Debug, Clone, Serialize, Deserialize)]
208pub struct Message {
209    /// Message ID
210    pub id: Uuid,
211    /// Sender agent
212    pub from: AgentId,
213    /// Message type
214    pub message_type: MessageType,
215    /// Message payload
216    pub payload: serde_json::Value,
217    /// Timestamp
218    pub timestamp: chrono::DateTime<chrono::Utc>,
219}
220
221/// Message types
222#[derive(Debug, Clone, Serialize, Deserialize)]
223pub enum MessageType {
224    /// Task assignment
225    TaskAssignment,
226    /// Status update
227    StatusUpdate,
228    /// Data sharing
229    DataShare,
230    /// Coordination request
231    CoordinationRequest,
232    /// Response
233    Response,
234    /// Custom message
235    Custom(String),
236}
237
238/// Agent message for ccswarm integration
239#[derive(Debug, Clone, Serialize, Deserialize)]
240pub enum AgentMessage {
241    /// Agent registration
242    Registration {
243        agent_id: AgentId,
244        capabilities: Vec<String>,
245        metadata: serde_json::Value,
246    },
247    /// Task assignment to agent
248    TaskAssignment {
249        task_id: TaskId,
250        agent_id: AgentId,
251        task_data: serde_json::Value,
252    },
253    /// Task completion notification
254    TaskCompleted {
255        agent_id: AgentId,
256        task_id: TaskId,
257        result: serde_json::Value,
258    },
259    /// Task progress update
260    TaskProgress {
261        agent_id: AgentId,
262        task_id: TaskId,
263        progress: f32,
264        message: String,
265    },
266    /// Help request from agent
267    HelpRequest {
268        agent_id: AgentId,
269        context: String,
270        priority: MessagePriority,
271    },
272    /// Status update from agent
273    StatusUpdate {
274        agent_id: AgentId,
275        status: String,
276        metrics: serde_json::Value,
277    },
278    /// Custom message type
279    Custom {
280        message_type: String,
281        data: serde_json::Value,
282    },
283}
284
285/// Broadcast message
286#[derive(Debug, Clone, Serialize, Deserialize)]
287pub struct BroadcastMessage {
288    /// Message ID
289    pub id: Uuid,
290    /// Sender agent
291    pub from: AgentId,
292    /// Message content
293    pub content: String,
294    /// Priority
295    pub priority: MessagePriority,
296    /// Timestamp
297    pub timestamp: chrono::DateTime<chrono::Utc>,
298}
299
300/// Message priority
301#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
302pub enum MessagePriority {
303    Low,
304    Normal,
305    High,
306    Critical,
307}
308
309/// Task distributor for workload management
310pub struct TaskDistributor {
311    /// Task queue
312    task_queue: Arc<RwLock<Vec<Task>>>,
313    /// Agent capabilities
314    agent_capabilities: Arc<DashMap<AgentId, Vec<String>>>,
315    /// Task assignments
316    assignments: Arc<DashMap<TaskId, AgentId>>,
317}
318
319impl Default for TaskDistributor {
320    fn default() -> Self {
321        Self::new()
322    }
323}
324
325impl TaskDistributor {
326    /// Create a new task distributor
327    pub fn new() -> Self {
328        Self {
329            task_queue: Arc::new(RwLock::new(Vec::new())),
330            agent_capabilities: Arc::new(DashMap::new()),
331            assignments: Arc::new(DashMap::new()),
332        }
333    }
334
335    /// Register agent capabilities
336    pub fn register_capabilities(&self, agent_id: AgentId, capabilities: Vec<String>) {
337        self.agent_capabilities.insert(agent_id, capabilities);
338    }
339
340    /// Submit a task
341    pub async fn submit_task(&self, task: Task) -> Result<()> {
342        self.task_queue.write().await.push(task);
343        Ok(())
344    }
345
346    /// Assign tasks to agents
347    pub async fn distribute_tasks(&self) -> Result<Vec<(TaskId, AgentId)>> {
348        let mut assignments = Vec::new();
349        let mut queue = self.task_queue.write().await;
350
351        // Simple round-robin distribution
352        // In a real implementation, this would use sophisticated matching
353        let agents: Vec<AgentId> = self
354            .agent_capabilities
355            .iter()
356            .map(|entry| entry.key().clone())
357            .collect();
358
359        if agents.is_empty() {
360            return Ok(assignments);
361        }
362
363        let mut agent_index = 0;
364        while let Some(task) = queue.pop() {
365            let agent_id = &agents[agent_index % agents.len()];
366            self.assignments.insert(task.id.clone(), agent_id.clone());
367            assignments.push((task.id, agent_id.clone()));
368            agent_index += 1;
369        }
370
371        Ok(assignments)
372    }
373}
374
375/// Task identifier
376#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
377pub struct TaskId(Uuid);
378
379impl Default for TaskId {
380    fn default() -> Self {
381        Self::new()
382    }
383}
384
385impl TaskId {
386    /// Create a new task ID
387    pub fn new() -> Self {
388        Self(Uuid::new_v4())
389    }
390}
391
392impl std::fmt::Display for TaskId {
393    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
394        write!(f, "{}", self.0)
395    }
396}
397
398/// Task definition
399#[derive(Debug, Clone, Serialize, Deserialize)]
400pub struct Task {
401    /// Task ID
402    pub id: TaskId,
403    /// Task name
404    pub name: String,
405    /// Required capabilities
406    pub required_capabilities: Vec<String>,
407    /// Task payload
408    pub payload: serde_json::Value,
409    /// Priority
410    pub priority: TaskPriority,
411    /// Created at
412    pub created_at: chrono::DateTime<chrono::Utc>,
413}
414
415/// Task priority
416#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
417pub enum TaskPriority {
418    Low,
419    Normal,
420    High,
421    Critical,
422}
423
424/// Resource manager for preventing conflicts
425pub struct ResourceManager {
426    /// File locks
427    file_locks: Arc<DashMap<String, AgentId>>,
428    /// API rate limits
429    rate_limits: Arc<DashMap<String, RateLimit>>,
430    /// Shared memory pool
431    shared_memory: Arc<DashMap<String, Vec<u8>>>,
432}
433
434impl Default for ResourceManager {
435    fn default() -> Self {
436        Self::new()
437    }
438}
439
440impl ResourceManager {
441    /// Create a new resource manager
442    pub fn new() -> Self {
443        Self {
444            file_locks: Arc::new(DashMap::new()),
445            rate_limits: Arc::new(DashMap::new()),
446            shared_memory: Arc::new(DashMap::new()),
447        }
448    }
449
450    /// Acquire a file lock
451    pub fn acquire_file_lock(&self, path: &str, agent_id: AgentId) -> Result<()> {
452        match self.file_locks.entry(path.to_string()) {
453            dashmap::mapref::entry::Entry::Occupied(_) => {
454                Err(anyhow::anyhow!("File already locked: {}", path))
455            }
456            dashmap::mapref::entry::Entry::Vacant(entry) => {
457                entry.insert(agent_id);
458                Ok(())
459            }
460        }
461    }
462
463    /// Release a file lock
464    pub fn release_file_lock(&self, path: &str, agent_id: &AgentId) -> Result<()> {
465        if let Some((_, owner)) = self.file_locks.remove(path)
466            && owner != *agent_id
467        {
468            return Err(anyhow::anyhow!("Not the lock owner"));
469        }
470        Ok(())
471    }
472
473    /// Check rate limit
474    pub fn check_rate_limit(&self, resource: &str) -> bool {
475        if let Some(limit) = self.rate_limits.get(resource) {
476            limit.can_proceed()
477        } else {
478            true
479        }
480    }
481
482    /// Write to shared memory
483    pub fn write_shared_memory(&self, key: &str, data: Vec<u8>) {
484        self.shared_memory.insert(key.to_string(), data);
485    }
486
487    /// Read from shared memory
488    pub fn read_shared_memory(&self, key: &str) -> Option<Vec<u8>> {
489        self.shared_memory.get(key).map(|entry| entry.clone())
490    }
491}
492
493/// Rate limit information
494#[derive(Debug, Clone)]
495pub struct RateLimit {
496    /// Maximum requests per interval
497    pub max_requests: usize,
498    /// Time interval
499    pub interval: std::time::Duration,
500    /// Current count
501    pub current_count: Arc<RwLock<usize>>,
502    /// Last reset time
503    pub last_reset: Arc<RwLock<std::time::Instant>>,
504}
505
506impl RateLimit {
507    /// Check if request can proceed
508    pub fn can_proceed(&self) -> bool {
509        // Simplified implementation
510        true
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_agent_session() {
        // A fresh session starts empty. Registering an agent would need a
        // mock `AISession`, which this module does not provide.
        let multi_session = MultiAgentSession::new();
        assert!(multi_session.list_agents().is_empty());
    }

    #[test]
    fn test_message_bus() {
        let bus = MessageBus::new();
        let agent1 = AgentId::new();
        let agent2 = AgentId::new();

        bus.register_agent(agent1.clone()).unwrap();
        bus.register_agent(agent2.clone()).unwrap();

        let message = Message {
            id: Uuid::new_v4(),
            from: agent1.clone(),
            message_type: MessageType::StatusUpdate,
            payload: serde_json::json!({"status": "ready"}),
            timestamp: chrono::Utc::now(),
        };

        // Sending to an unregistered agent must fail.
        assert!(
            bus.send_message(agent1.clone(), AgentId::new(), message.clone())
                .is_err()
        );

        bus.send_message(agent1, agent2.clone(), message).unwrap();

        // `expect` instead of `if let` so a missing receiver fails the test
        // rather than silently skipping the assertion.
        let receiver = bus.get_receiver(&agent2).expect("agent2 is registered");
        assert!(receiver.try_recv().is_ok());
    }

    #[test]
    fn test_file_lock_acquire_release() {
        let manager = ResourceManager::new();
        let owner = AgentId::new();
        let other = AgentId::new();

        manager.acquire_file_lock("a.txt", owner.clone()).unwrap();
        assert!(manager.acquire_file_lock("a.txt", other.clone()).is_err());
        assert!(manager.release_file_lock("a.txt", &other).is_err());

        manager.release_file_lock("a.txt", &owner).unwrap();
        manager.acquire_file_lock("a.txt", other.clone()).unwrap();

        // Releasing an unlocked path is not an error.
        manager.release_file_lock("missing.txt", &owner).unwrap();
    }

    #[tokio::test]
    async fn test_agent_message_publish() {
        let bus = MessageBus::new();
        let agent1 = AgentId::new();
        let agent2 = AgentId::new();

        bus.register_agent(agent1.clone()).unwrap();
        bus.register_agent(agent2.clone()).unwrap();

        // Subscribe to all messages
        let all_receiver = bus.subscribe_all();

        // Create a registration message
        let registration_msg = AgentMessage::Registration {
            agent_id: agent1.clone(),
            capabilities: vec!["frontend".to_string(), "react".to_string()],
            metadata: serde_json::json!({"version": "1.0"}),
        };

        // Publish to agent2
        bus.publish_to_agent(&agent2, registration_msg.clone())
            .await
            .unwrap();

        // Check agent2 received the message
        let receiver = bus
            .get_agent_receiver(&agent2)
            .expect("agent2 is registered");
        match receiver.try_recv().unwrap() {
            AgentMessage::Registration { agent_id, .. } => {
                assert_eq!(agent_id, agent1);
            }
            _ => panic!("Unexpected message type"),
        }

        // Check all_messages channel received it too
        match all_receiver.try_recv().unwrap() {
            AgentMessage::Registration { agent_id, .. } => {
                assert_eq!(agent_id, agent1);
            }
            _ => panic!("Unexpected message type"),
        }
    }

    #[tokio::test]
    async fn test_all_agent_message_variants() {
        let bus = MessageBus::new();
        let agent1 = AgentId::new();
        bus.register_agent(agent1.clone()).unwrap();

        // Test all message variants
        let messages = vec![
            AgentMessage::Registration {
                agent_id: agent1.clone(),
                capabilities: vec!["test".to_string()],
                metadata: serde_json::json!({}),
            },
            AgentMessage::TaskAssignment {
                task_id: TaskId::new(),
                agent_id: agent1.clone(),
                task_data: serde_json::json!({"task": "test"}),
            },
            AgentMessage::TaskCompleted {
                agent_id: agent1.clone(),
                task_id: TaskId::new(),
                result: serde_json::json!({"success": true}),
            },
            AgentMessage::TaskProgress {
                agent_id: agent1.clone(),
                task_id: TaskId::new(),
                progress: 0.5,
                message: "Halfway done".to_string(),
            },
            AgentMessage::HelpRequest {
                agent_id: agent1.clone(),
                context: "Need help with React".to_string(),
                priority: MessagePriority::High,
            },
            AgentMessage::StatusUpdate {
                agent_id: agent1.clone(),
                status: "active".to_string(),
                metrics: serde_json::json!({"cpu": 50, "memory": 1024}),
            },
            AgentMessage::Custom {
                message_type: "test_message".to_string(),
                data: serde_json::json!({"foo": "bar"}),
            },
        ];
        let expected = messages.len();

        for msg in messages {
            bus.publish_to_agent(&agent1, msg).await.unwrap();
        }

        // Verify all messages were received
        let receiver = bus
            .get_agent_receiver(&agent1)
            .expect("agent1 is registered");
        assert_eq!(receiver.try_iter().count(), expected); // All 7 message variants
    }
}