Skip to main content

ai_session/coordination/
mod.rs

//! Multi-agent coordination: agent registry, inter-agent message bus, task distribution and shared-resource management.
2
3use anyhow::Result;
4use crossbeam_channel::{Receiver, Sender};
5use dashmap::DashMap;
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use crate::core::AISession;
12
/// Capacity of each per-agent message channel.
const DEFAULT_CHANNEL_CAPACITY: usize = 1_000;

/// Capacity of the shared broadcast channel, kept larger than the per-agent
/// capacity so bursts of broadcasts can be absorbed.
const BROADCAST_CHANNEL_CAPACITY: usize = 5_000;

/// Capacity of the monitoring channel that mirrors published agent messages;
/// sized generously because a monitor may lag behind the agents.
const ALL_MESSAGES_CHANNEL_CAPACITY: usize = 10_000;
21
22/// Agent identifier
23#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
24pub struct AgentId(Uuid);
25
26impl Default for AgentId {
27    fn default() -> Self {
28        Self::new()
29    }
30}
31
32impl AgentId {
33    /// Create a new agent ID
34    pub fn new() -> Self {
35        Self(Uuid::new_v4())
36    }
37}
38
39impl std::fmt::Display for AgentId {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        write!(f, "{}", self.0)
42    }
43}
44
45/// Multi-agent session coordinator
46pub struct MultiAgentSession {
47    /// Active agent sessions
48    pub agents: Arc<DashMap<AgentId, Arc<AISession>>>,
49    /// Message bus for inter-agent communication
50    pub message_bus: Arc<MessageBus>,
51    /// Task distributor
52    pub task_distributor: Arc<TaskDistributor>,
53    /// Resource manager
54    pub resource_manager: Arc<ResourceManager>,
55}
56
57impl Default for MultiAgentSession {
58    fn default() -> Self {
59        Self::new()
60    }
61}
62
63impl MultiAgentSession {
64    /// Create a new multi-agent session
65    pub fn new() -> Self {
66        Self {
67            agents: Arc::new(DashMap::new()),
68            message_bus: Arc::new(MessageBus::new()),
69            task_distributor: Arc::new(TaskDistributor::new()),
70            resource_manager: Arc::new(ResourceManager::new()),
71        }
72    }
73
74    /// Register an agent
75    pub fn register_agent(&self, agent_id: AgentId, session: Arc<AISession>) -> Result<()> {
76        self.agents.insert(agent_id.clone(), session);
77        self.message_bus.register_agent(agent_id)?;
78        Ok(())
79    }
80
81    /// Unregister an agent
82    pub fn unregister_agent(&self, agent_id: &AgentId) -> Result<()> {
83        self.agents.remove(agent_id);
84        self.message_bus.unregister_agent(agent_id)?;
85        Ok(())
86    }
87
88    /// Get an agent session
89    pub fn get_agent(&self, agent_id: &AgentId) -> Option<Arc<AISession>> {
90        self.agents.get(agent_id).map(|entry| entry.clone())
91    }
92
93    /// List all agents
94    pub fn list_agents(&self) -> Vec<AgentId> {
95        self.agents
96            .iter()
97            .map(|entry| entry.key().clone())
98            .collect()
99    }
100
101    /// Send a message to an agent
102    pub async fn send_message(&self, from: AgentId, to: AgentId, message: Message) -> Result<()> {
103        self.message_bus.send_message(from, to, message)
104    }
105
106    /// Broadcast a message to all agents
107    pub async fn broadcast(&self, from: AgentId, message: BroadcastMessage) -> Result<()> {
108        self.message_bus.broadcast(from, message)
109    }
110}
111
112/// Message bus for inter-agent communication
113pub struct MessageBus {
114    /// Message channels for each agent
115    channels: DashMap<AgentId, (Sender<Message>, Receiver<Message>)>,
116    /// Broadcast channel
117    broadcast_sender: Sender<BroadcastMessage>,
118    _broadcast_receiver: Receiver<BroadcastMessage>,
119    /// Agent message channels for ccswarm integration
120    agent_channels: DashMap<AgentId, (Sender<AgentMessage>, Receiver<AgentMessage>)>,
121    /// All messages channel for monitoring
122    all_messages_sender: Sender<AgentMessage>,
123    all_messages_receiver: Receiver<AgentMessage>,
124}
125
126impl Default for MessageBus {
127    fn default() -> Self {
128        Self::new()
129    }
130}
131
132impl MessageBus {
133    /// Create a new message bus with bounded channels
134    ///
135    /// Uses bounded channels to prevent memory exhaustion under load:
136    /// - Broadcast channel: 5000 messages (high capacity for burst traffic)
137    /// - All messages channel: 10000 messages (monitoring may fall behind)
138    pub fn new() -> Self {
139        let (broadcast_sender, broadcast_receiver) =
140            crossbeam_channel::bounded(BROADCAST_CHANNEL_CAPACITY);
141        let (all_messages_sender, all_messages_receiver) =
142            crossbeam_channel::bounded(ALL_MESSAGES_CHANNEL_CAPACITY);
143        Self {
144            channels: DashMap::new(),
145            broadcast_sender,
146            _broadcast_receiver: broadcast_receiver,
147            agent_channels: DashMap::new(),
148            all_messages_sender,
149            all_messages_receiver,
150        }
151    }
152
153    /// Register an agent with bounded message channels
154    ///
155    /// Each agent gets a bounded channel with DEFAULT_CHANNEL_CAPACITY to
156    /// prevent any single slow agent from causing memory exhaustion.
157    pub fn register_agent(&self, agent_id: AgentId) -> Result<()> {
158        let (sender, receiver) = crossbeam_channel::bounded(DEFAULT_CHANNEL_CAPACITY);
159        self.channels.insert(agent_id.clone(), (sender, receiver));
160
161        // Also register agent message channel
162        let (agent_sender, agent_receiver) = crossbeam_channel::bounded(DEFAULT_CHANNEL_CAPACITY);
163        self.agent_channels
164            .insert(agent_id, (agent_sender, agent_receiver));
165        Ok(())
166    }
167
168    /// Unregister an agent
169    pub fn unregister_agent(&self, agent_id: &AgentId) -> Result<()> {
170        self.channels.remove(agent_id);
171        self.agent_channels.remove(agent_id);
172        Ok(())
173    }
174
175    /// Send a message to a specific agent (non-blocking)
176    ///
177    /// Returns an error if the agent's channel is full or the agent is not found.
178    /// This prevents slow agents from blocking the sender.
179    pub fn send_message(&self, _from: AgentId, to: AgentId, message: Message) -> Result<()> {
180        if let Some(channel) = self.channels.get(&to) {
181            channel.0.try_send(message).map_err(|e| match e {
182                crossbeam_channel::TrySendError::Full(_) => {
183                    anyhow::anyhow!("Agent {} channel is full (backpressure)", to)
184                }
185                crossbeam_channel::TrySendError::Disconnected(_) => {
186                    anyhow::anyhow!("Agent {} channel disconnected", to)
187                }
188            })?;
189            Ok(())
190        } else {
191            Err(anyhow::anyhow!("Agent not found: {}", to))
192        }
193    }
194
195    /// Broadcast a message to all agents (non-blocking)
196    ///
197    /// Returns an error if the broadcast channel is full.
198    pub fn broadcast(&self, _from: AgentId, message: BroadcastMessage) -> Result<()> {
199        self.broadcast_sender
200            .try_send(message)
201            .map_err(|e| match e {
202                crossbeam_channel::TrySendError::Full(_) => {
203                    anyhow::anyhow!("Broadcast channel is full (backpressure)")
204                }
205                crossbeam_channel::TrySendError::Disconnected(_) => {
206                    anyhow::anyhow!("Broadcast channel disconnected")
207                }
208            })?;
209        Ok(())
210    }
211
212    /// Get receiver for an agent
213    pub fn get_receiver(&self, agent_id: &AgentId) -> Option<Receiver<Message>> {
214        self.channels.get(agent_id).map(|entry| entry.1.clone())
215    }
216
217    /// Subscribe to all messages (for monitoring)
218    pub fn subscribe_all(&self) -> Receiver<AgentMessage> {
219        self.all_messages_receiver.clone()
220    }
221
222    /// Publish a message to a specific agent (non-blocking)
223    ///
224    /// Returns an error if either the agent's channel or the monitoring channel is full.
225    /// This prevents slow consumers from causing memory exhaustion.
226    pub async fn publish_to_agent(&self, agent_id: &AgentId, message: AgentMessage) -> Result<()> {
227        // Send to the specific agent
228        if let Some(channel) = self.agent_channels.get(agent_id) {
229            channel.0.try_send(message.clone()).map_err(|e| match e {
230                crossbeam_channel::TrySendError::Full(_) => {
231                    anyhow::anyhow!("Agent {} channel is full (backpressure)", agent_id)
232                }
233                crossbeam_channel::TrySendError::Disconnected(_) => {
234                    anyhow::anyhow!("Agent {} channel disconnected", agent_id)
235                }
236            })?;
237        } else {
238            return Err(anyhow::anyhow!("Agent not found: {}", agent_id));
239        }
240
241        // Also send to the all messages channel for monitoring (drop if full to avoid blocking)
242        // Monitoring is best-effort - we don't want to fail the primary send
243        let _ = self.all_messages_sender.try_send(message);
244
245        Ok(())
246    }
247
248    /// Get agent message receiver for a specific agent
249    pub fn get_agent_receiver(&self, agent_id: &AgentId) -> Option<Receiver<AgentMessage>> {
250        self.agent_channels
251            .get(agent_id)
252            .map(|entry| entry.1.clone())
253    }
254}
255
// ============================================================================
// Unified Message System
// ============================================================================
259
260/// Unified message content - the actual message data
261///
262/// This enum consolidates all message types into a single, well-typed structure
263/// following the DRY principle. Each variant contains all necessary data for
264/// that specific message type.
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub enum MessageContent {
267    /// Agent registration message
268    Registration {
269        agent_id: AgentId,
270        capabilities: Vec<String>,
271        metadata: serde_json::Value,
272    },
273    /// Task assignment to agent
274    TaskAssignment {
275        task_id: TaskId,
276        agent_id: AgentId,
277        task_data: serde_json::Value,
278    },
279    /// Task completion notification
280    TaskCompleted {
281        agent_id: AgentId,
282        task_id: TaskId,
283        result: serde_json::Value,
284    },
285    /// Task progress update
286    TaskProgress {
287        agent_id: AgentId,
288        task_id: TaskId,
289        progress: f32,
290        message: String,
291    },
292    /// Help request from agent
293    HelpRequest {
294        agent_id: AgentId,
295        context: String,
296        priority: MessagePriority,
297    },
298    /// Status update from agent
299    StatusUpdate {
300        agent_id: AgentId,
301        status: String,
302        metrics: serde_json::Value,
303    },
304    /// Data sharing between agents
305    DataShare { data: serde_json::Value },
306    /// Coordination request
307    CoordinationRequest {
308        request_type: String,
309        data: serde_json::Value,
310    },
311    /// Response to a previous message
312    Response {
313        in_reply_to: Uuid,
314        data: serde_json::Value,
315    },
316    /// Custom message type
317    Custom {
318        message_type: String,
319        data: serde_json::Value,
320    },
321}
322
323/// Unified inter-agent message with metadata
324///
325/// This structure combines message metadata (id, from, timestamp) with
326/// the actual message content. This is the primary message type for
327/// inter-agent communication.
328#[derive(Debug, Clone, Serialize, Deserialize)]
329pub struct UnifiedMessage {
330    /// Unique message ID
331    pub id: Uuid,
332    /// Sender agent ID
333    pub from: AgentId,
334    /// Message content
335    pub content: MessageContent,
336    /// When the message was created
337    pub timestamp: chrono::DateTime<chrono::Utc>,
338}
339
340impl UnifiedMessage {
341    /// Create a new unified message
342    pub fn new(from: AgentId, content: MessageContent) -> Self {
343        Self {
344            id: Uuid::new_v4(),
345            from,
346            content,
347            timestamp: chrono::Utc::now(),
348        }
349    }
350
351    /// Create from legacy Message type
352    pub fn from_legacy_message(msg: Message) -> Self {
353        let content = match msg.message_type {
354            MessageType::TaskAssignment => MessageContent::Custom {
355                message_type: "task_assignment".to_string(),
356                data: msg.payload,
357            },
358            MessageType::StatusUpdate => MessageContent::Custom {
359                message_type: "status_update".to_string(),
360                data: msg.payload,
361            },
362            MessageType::DataShare => MessageContent::DataShare { data: msg.payload },
363            MessageType::CoordinationRequest => MessageContent::CoordinationRequest {
364                request_type: "legacy".to_string(),
365                data: msg.payload,
366            },
367            MessageType::Response => MessageContent::Response {
368                in_reply_to: Uuid::nil(),
369                data: msg.payload,
370            },
371            MessageType::Custom(t) => MessageContent::Custom {
372                message_type: t,
373                data: msg.payload,
374            },
375        };
376
377        Self {
378            id: msg.id,
379            from: msg.from,
380            content,
381            timestamp: msg.timestamp,
382        }
383    }
384
385    /// Create from AgentMessage (backward compatibility)
386    pub fn from_agent_message(from: AgentId, msg: AgentMessage) -> Self {
387        let content = match msg {
388            AgentMessage::Registration {
389                agent_id,
390                capabilities,
391                metadata,
392            } => MessageContent::Registration {
393                agent_id,
394                capabilities,
395                metadata,
396            },
397            AgentMessage::TaskAssignment {
398                task_id,
399                agent_id,
400                task_data,
401            } => MessageContent::TaskAssignment {
402                task_id,
403                agent_id,
404                task_data,
405            },
406            AgentMessage::TaskCompleted {
407                agent_id,
408                task_id,
409                result,
410            } => MessageContent::TaskCompleted {
411                agent_id,
412                task_id,
413                result,
414            },
415            AgentMessage::TaskProgress {
416                agent_id,
417                task_id,
418                progress,
419                message,
420            } => MessageContent::TaskProgress {
421                agent_id,
422                task_id,
423                progress,
424                message,
425            },
426            AgentMessage::HelpRequest {
427                agent_id,
428                context,
429                priority,
430            } => MessageContent::HelpRequest {
431                agent_id,
432                context,
433                priority,
434            },
435            AgentMessage::StatusUpdate {
436                agent_id,
437                status,
438                metrics,
439            } => MessageContent::StatusUpdate {
440                agent_id,
441                status,
442                metrics,
443            },
444            AgentMessage::Custom { message_type, data } => {
445                MessageContent::Custom { message_type, data }
446            }
447        };
448
449        Self {
450            id: Uuid::new_v4(),
451            from,
452            content,
453            timestamp: chrono::Utc::now(),
454        }
455    }
456}
457
// ============================================================================
// Backward Compatibility Types (deprecated: prefer UnifiedMessage)
// ============================================================================
461
462/// Legacy inter-agent message
463///
464/// **Deprecated**: Use `UnifiedMessage` instead for new code.
465/// This type is maintained for backward compatibility.
466#[derive(Debug, Clone, Serialize, Deserialize)]
467pub struct Message {
468    /// Message ID
469    pub id: Uuid,
470    /// Sender agent
471    pub from: AgentId,
472    /// Message type
473    pub message_type: MessageType,
474    /// Message payload
475    pub payload: serde_json::Value,
476    /// Timestamp
477    pub timestamp: chrono::DateTime<chrono::Utc>,
478}
479
480/// Legacy message types
481///
482/// **Deprecated**: Use `MessageContent` variants instead.
483#[derive(Debug, Clone, Serialize, Deserialize)]
484pub enum MessageType {
485    /// Task assignment
486    TaskAssignment,
487    /// Status update
488    StatusUpdate,
489    /// Data sharing
490    DataShare,
491    /// Coordination request
492    CoordinationRequest,
493    /// Response
494    Response,
495    /// Custom message
496    Custom(String),
497}
498
499/// Legacy agent message for ccswarm integration
500///
501/// **Deprecated**: Use `UnifiedMessage` with `MessageContent` instead.
502/// This type is maintained for backward compatibility.
503#[derive(Debug, Clone, Serialize, Deserialize)]
504pub enum AgentMessage {
505    /// Agent registration
506    Registration {
507        agent_id: AgentId,
508        capabilities: Vec<String>,
509        metadata: serde_json::Value,
510    },
511    /// Task assignment to agent
512    TaskAssignment {
513        task_id: TaskId,
514        agent_id: AgentId,
515        task_data: serde_json::Value,
516    },
517    /// Task completion notification
518    TaskCompleted {
519        agent_id: AgentId,
520        task_id: TaskId,
521        result: serde_json::Value,
522    },
523    /// Task progress update
524    TaskProgress {
525        agent_id: AgentId,
526        task_id: TaskId,
527        progress: f32,
528        message: String,
529    },
530    /// Help request from agent
531    HelpRequest {
532        agent_id: AgentId,
533        context: String,
534        priority: MessagePriority,
535    },
536    /// Status update from agent
537    StatusUpdate {
538        agent_id: AgentId,
539        status: String,
540        metrics: serde_json::Value,
541    },
542    /// Custom message type
543    Custom {
544        message_type: String,
545        data: serde_json::Value,
546    },
547}
548
549/// Broadcast message
550#[derive(Debug, Clone, Serialize, Deserialize)]
551pub struct BroadcastMessage {
552    /// Message ID
553    pub id: Uuid,
554    /// Sender agent
555    pub from: AgentId,
556    /// Message content
557    pub content: String,
558    /// Priority
559    pub priority: MessagePriority,
560    /// Timestamp
561    pub timestamp: chrono::DateTime<chrono::Utc>,
562}
563
564/// Message priority
565#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
566pub enum MessagePriority {
567    Low,
568    Normal,
569    High,
570    Critical,
571}
572
573/// Task distributor for workload management
574pub struct TaskDistributor {
575    /// Task queue
576    task_queue: Arc<RwLock<Vec<Task>>>,
577    /// Agent capabilities
578    agent_capabilities: Arc<DashMap<AgentId, Vec<String>>>,
579    /// Task assignments
580    assignments: Arc<DashMap<TaskId, AgentId>>,
581}
582
583impl Default for TaskDistributor {
584    fn default() -> Self {
585        Self::new()
586    }
587}
588
589impl TaskDistributor {
590    /// Create a new task distributor
591    pub fn new() -> Self {
592        Self {
593            task_queue: Arc::new(RwLock::new(Vec::new())),
594            agent_capabilities: Arc::new(DashMap::new()),
595            assignments: Arc::new(DashMap::new()),
596        }
597    }
598
599    /// Register agent capabilities
600    pub fn register_capabilities(&self, agent_id: AgentId, capabilities: Vec<String>) {
601        self.agent_capabilities.insert(agent_id, capabilities);
602    }
603
604    /// Submit a task
605    pub async fn submit_task(&self, task: Task) -> Result<()> {
606        self.task_queue.write().await.push(task);
607        Ok(())
608    }
609
610    /// Assign tasks to agents
611    pub async fn distribute_tasks(&self) -> Result<Vec<(TaskId, AgentId)>> {
612        let mut assignments = Vec::new();
613        let mut queue = self.task_queue.write().await;
614
615        // Simple round-robin distribution
616        // In a real implementation, this would use sophisticated matching
617        let agents: Vec<AgentId> = self
618            .agent_capabilities
619            .iter()
620            .map(|entry| entry.key().clone())
621            .collect();
622
623        if agents.is_empty() {
624            return Ok(assignments);
625        }
626
627        let mut agent_index = 0;
628        while let Some(task) = queue.pop() {
629            let agent_id = &agents[agent_index % agents.len()];
630            self.assignments.insert(task.id.clone(), agent_id.clone());
631            assignments.push((task.id, agent_id.clone()));
632            agent_index += 1;
633        }
634
635        Ok(assignments)
636    }
637}
638
639/// Task identifier
640#[derive(Debug, Clone, Hash, Eq, PartialEq, Serialize, Deserialize)]
641pub struct TaskId(Uuid);
642
643impl Default for TaskId {
644    fn default() -> Self {
645        Self::new()
646    }
647}
648
649impl TaskId {
650    /// Create a new task ID
651    pub fn new() -> Self {
652        Self(Uuid::new_v4())
653    }
654}
655
656impl std::fmt::Display for TaskId {
657    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
658        write!(f, "{}", self.0)
659    }
660}
661
662/// Task definition
663#[derive(Debug, Clone, Serialize, Deserialize)]
664pub struct Task {
665    /// Task ID
666    pub id: TaskId,
667    /// Task name
668    pub name: String,
669    /// Required capabilities
670    pub required_capabilities: Vec<String>,
671    /// Task payload
672    pub payload: serde_json::Value,
673    /// Priority
674    pub priority: TaskPriority,
675    /// Created at
676    pub created_at: chrono::DateTime<chrono::Utc>,
677}
678
679/// Task priority
680#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
681pub enum TaskPriority {
682    Low,
683    Normal,
684    High,
685    Critical,
686}
687
688/// Resource manager for preventing conflicts
689pub struct ResourceManager {
690    /// File locks
691    file_locks: Arc<DashMap<String, AgentId>>,
692    /// API rate limits
693    rate_limits: Arc<DashMap<String, RateLimit>>,
694    /// Shared memory pool
695    shared_memory: Arc<DashMap<String, Vec<u8>>>,
696}
697
698impl Default for ResourceManager {
699    fn default() -> Self {
700        Self::new()
701    }
702}
703
704impl ResourceManager {
705    /// Create a new resource manager
706    pub fn new() -> Self {
707        Self {
708            file_locks: Arc::new(DashMap::new()),
709            rate_limits: Arc::new(DashMap::new()),
710            shared_memory: Arc::new(DashMap::new()),
711        }
712    }
713
714    /// Acquire a file lock
715    pub fn acquire_file_lock(&self, path: &str, agent_id: AgentId) -> Result<()> {
716        match self.file_locks.entry(path.to_string()) {
717            dashmap::mapref::entry::Entry::Occupied(_) => {
718                Err(anyhow::anyhow!("File already locked: {}", path))
719            }
720            dashmap::mapref::entry::Entry::Vacant(entry) => {
721                entry.insert(agent_id);
722                Ok(())
723            }
724        }
725    }
726
727    /// Release a file lock
728    pub fn release_file_lock(&self, path: &str, agent_id: &AgentId) -> Result<()> {
729        if let Some((_, owner)) = self.file_locks.remove(path)
730            && owner != *agent_id
731        {
732            return Err(anyhow::anyhow!("Not the lock owner"));
733        }
734        Ok(())
735    }
736
737    /// Check rate limit for a resource
738    ///
739    /// Returns true if the request is within rate limits for the given resource,
740    /// or if no rate limit is configured for that resource.
741    pub fn check_rate_limit(&self, resource: &str) -> bool {
742        if let Some(limit) = self.rate_limits.get(resource) {
743            limit.can_proceed()
744        } else {
745            true
746        }
747    }
748
749    /// Set a rate limit for a resource
750    ///
751    /// # Arguments
752    /// * `resource` - The resource identifier (e.g., "api", "file_ops")
753    /// * `max_requests` - Maximum requests allowed per interval
754    /// * `interval` - Time window for rate limiting
755    pub fn set_rate_limit(
756        &self,
757        resource: &str,
758        max_requests: usize,
759        interval: std::time::Duration,
760    ) {
761        self.rate_limits
762            .insert(resource.to_string(), RateLimit::new(max_requests, interval));
763    }
764
765    /// Get remaining rate limit for a resource
766    pub fn rate_limit_remaining(&self, resource: &str) -> Option<usize> {
767        self.rate_limits
768            .get(resource)
769            .map(|limit| limit.remaining())
770    }
771
772    /// Write to shared memory
773    pub fn write_shared_memory(&self, key: &str, data: Vec<u8>) {
774        self.shared_memory.insert(key.to_string(), data);
775    }
776
777    /// Read from shared memory
778    pub fn read_shared_memory(&self, key: &str) -> Option<Vec<u8>> {
779        self.shared_memory.get(key).map(|entry| entry.clone())
780    }
781}
782
/// Fixed-window rate limiter backed by atomics.
///
/// Allows at most `max_requests` successful [`RateLimit::can_proceed`] calls
/// per window of length `interval`. The first call made at least `interval`
/// after the current window started opens a new window and resets the count.
///
/// Window timing uses the monotonic clock ([`std::time::Instant`]), so
/// wall-clock adjustments cannot stall or prematurely reset a window. Clones
/// share the same counters, which live behind `Arc`.
#[derive(Debug, Clone)]
pub struct RateLimit {
    /// Maximum requests per interval
    pub max_requests: usize,
    /// Time interval for the rate limit window
    pub interval: std::time::Duration,
    /// Requests accepted in the current window (never exceeds `max_requests`
    /// through `can_proceed`)
    current_count: Arc<std::sync::atomic::AtomicUsize>,
    /// Start of the current window, in nanoseconds elapsed since `origin`
    last_reset_nanos: Arc<std::sync::atomic::AtomicU64>,
    /// Monotonic reference point all window timestamps are measured from
    origin: std::time::Instant,
}

impl RateLimit {
    /// Create a new rate limit whose first window starts now.
    ///
    /// # Arguments
    /// * `max_requests` - Maximum requests allowed per interval (0 rejects everything)
    /// * `interval` - Time window for rate limiting
    pub fn new(max_requests: usize, interval: std::time::Duration) -> Self {
        Self {
            max_requests,
            interval,
            current_count: Arc::new(std::sync::atomic::AtomicUsize::new(0)),
            // Offset 0 from `origin`, i.e. the window starts at construction.
            last_reset_nanos: Arc::new(std::sync::atomic::AtomicU64::new(0)),
            origin: std::time::Instant::now(),
        }
    }

    /// Whole nanoseconds in `duration`, saturating at `u64::MAX` instead of
    /// silently truncating the `u128` value. Truncation could turn a very long
    /// interval into 0 and disable the limit.
    fn saturating_nanos(duration: std::time::Duration) -> u64 {
        u64::try_from(duration.as_nanos()).unwrap_or(u64::MAX)
    }

    /// Nanoseconds elapsed on the monotonic clock since `origin`.
    fn now_nanos(&self) -> u64 {
        Self::saturating_nanos(self.origin.elapsed())
    }

    /// Check if request can proceed and increment counter if allowed.
    ///
    /// Returns true if the request is within rate limits, false otherwise.
    /// Non-blocking and safe to call from several threads at once.
    pub fn can_proceed(&self) -> bool {
        use std::sync::atomic::Ordering;

        let now = self.now_nanos();
        let window_start = self.last_reset_nanos.load(Ordering::Acquire);

        // When the window has expired, only the thread whose CAS succeeds starts the new one.
        if now.saturating_sub(window_start) >= Self::saturating_nanos(self.interval)
            && self
                .last_reset_nanos
                .compare_exchange(window_start, now, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
        {
            self.current_count.store(0, Ordering::Release);
        }

        // Increment only while below the limit. A fetch_add with a compensating
        // fetch_sub could wrap the counter to usize::MAX if a concurrent reset
        // stored 0 in between; this form never overshoots and never decrements.
        self.current_count
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |count| {
                if count < self.max_requests {
                    Some(count + 1)
                } else {
                    None
                }
            })
            .is_ok()
    }

    /// Get the current count of requests in this window.
    pub fn current_count(&self) -> usize {
        self.current_count
            .load(std::sync::atomic::Ordering::Acquire)
    }

    /// Get remaining requests in this window.
    pub fn remaining(&self) -> usize {
        let current = self.current_count();
        self.max_requests.saturating_sub(current)
    }

    /// Reset the counter and start a new window now.
    pub fn reset(&self) {
        use std::sync::atomic::Ordering;

        let now = self.now_nanos();
        self.current_count.store(0, Ordering::Release);
        self.last_reset_nanos.store(now, Ordering::Release);
    }
}
884
885#[cfg(test)]
886mod tests {
887    use super::*;
888
889    #[test]
890    fn test_multi_agent_session() {
891        let multi_session = MultiAgentSession::new();
892        let _agent_id = AgentId::new();
893
894        // Would need a mock session for testing
895        assert_eq!(multi_session.list_agents().len(), 0);
896    }
897
898    #[test]
899    fn test_message_bus() {
900        let bus = MessageBus::new();
901        let agent1 = AgentId::new();
902        let agent2 = AgentId::new();
903
904        bus.register_agent(agent1.clone()).unwrap();
905        bus.register_agent(agent2.clone()).unwrap();
906
907        let message = Message {
908            id: Uuid::new_v4(),
909            from: agent1.clone(),
910            message_type: MessageType::StatusUpdate,
911            payload: serde_json::json!({"status": "ready"}),
912            timestamp: chrono::Utc::now(),
913        };
914
915        bus.send_message(agent1, agent2.clone(), message).unwrap();
916
917        if let Some(receiver) = bus.get_receiver(&agent2) {
918            assert!(receiver.try_recv().is_ok());
919        }
920    }
921
922    #[tokio::test]
923    async fn test_agent_message_publish() {
924        let bus = MessageBus::new();
925        let agent1 = AgentId::new();
926        let agent2 = AgentId::new();
927
928        bus.register_agent(agent1.clone()).unwrap();
929        bus.register_agent(agent2.clone()).unwrap();
930
931        // Subscribe to all messages
932        let all_receiver = bus.subscribe_all();
933
934        // Create a registration message
935        let registration_msg = AgentMessage::Registration {
936            agent_id: agent1.clone(),
937            capabilities: vec!["frontend".to_string(), "react".to_string()],
938            metadata: serde_json::json!({"version": "1.0"}),
939        };
940
941        // Publish to agent2
942        bus.publish_to_agent(&agent2, registration_msg.clone())
943            .await
944            .unwrap();
945
946        // Check agent2 received the message
947        if let Some(receiver) = bus.get_agent_receiver(&agent2) {
948            let received = receiver.try_recv().unwrap();
949            match received {
950                AgentMessage::Registration { agent_id, .. } => {
951                    assert_eq!(agent_id, agent1);
952                }
953                _ => panic!("Unexpected message type"),
954            }
955        }
956
957        // Check all_messages channel received it too
958        let all_msg = all_receiver.try_recv().unwrap();
959        match all_msg {
960            AgentMessage::Registration { agent_id, .. } => {
961                assert_eq!(agent_id, agent1);
962            }
963            _ => panic!("Unexpected message type"),
964        }
965    }
966
967    #[tokio::test]
968    async fn test_all_agent_message_variants() {
969        let bus = MessageBus::new();
970        let agent1 = AgentId::new();
971        bus.register_agent(agent1.clone()).unwrap();
972
973        // Test all message variants
974        let messages = vec![
975            AgentMessage::Registration {
976                agent_id: agent1.clone(),
977                capabilities: vec!["test".to_string()],
978                metadata: serde_json::json!({}),
979            },
980            AgentMessage::TaskAssignment {
981                task_id: TaskId::new(),
982                agent_id: agent1.clone(),
983                task_data: serde_json::json!({"task": "test"}),
984            },
985            AgentMessage::TaskCompleted {
986                agent_id: agent1.clone(),
987                task_id: TaskId::new(),
988                result: serde_json::json!({"success": true}),
989            },
990            AgentMessage::TaskProgress {
991                agent_id: agent1.clone(),
992                task_id: TaskId::new(),
993                progress: 0.5,
994                message: "Halfway done".to_string(),
995            },
996            AgentMessage::HelpRequest {
997                agent_id: agent1.clone(),
998                context: "Need help with React".to_string(),
999                priority: MessagePriority::High,
1000            },
1001            AgentMessage::StatusUpdate {
1002                agent_id: agent1.clone(),
1003                status: "active".to_string(),
1004                metrics: serde_json::json!({"cpu": 50, "memory": 1024}),
1005            },
1006            AgentMessage::Custom {
1007                message_type: "test_message".to_string(),
1008                data: serde_json::json!({"foo": "bar"}),
1009            },
1010        ];
1011
1012        for msg in messages {
1013            bus.publish_to_agent(&agent1, msg).await.unwrap();
1014        }
1015
1016        // Verify all messages were received
1017        if let Some(receiver) = bus.get_agent_receiver(&agent1) {
1018            let mut count = 0;
1019            while receiver.try_recv().is_ok() {
1020                count += 1;
1021            }
1022            assert_eq!(count, 7); // All 7 message variants
1023        }
1024    }
1025
1026    #[test]
1027    fn test_rate_limit_basic() {
1028        let limit = RateLimit::new(3, std::time::Duration::from_secs(60));
1029
1030        // First 3 requests should succeed
1031        assert!(limit.can_proceed());
1032        assert!(limit.can_proceed());
1033        assert!(limit.can_proceed());
1034
1035        // 4th request should fail
1036        assert!(!limit.can_proceed());
1037
1038        // Check remaining
1039        assert_eq!(limit.current_count(), 3);
1040        assert_eq!(limit.remaining(), 0);
1041    }
1042
1043    #[test]
1044    fn test_rate_limit_reset() {
1045        let limit = RateLimit::new(2, std::time::Duration::from_secs(60));
1046
1047        assert!(limit.can_proceed());
1048        assert!(limit.can_proceed());
1049        assert!(!limit.can_proceed());
1050
1051        // Reset should allow more requests
1052        limit.reset();
1053        assert!(limit.can_proceed());
1054        assert_eq!(limit.current_count(), 1);
1055    }
1056
1057    #[test]
1058    fn test_resource_manager_rate_limit() {
1059        let manager = ResourceManager::new();
1060
1061        // No rate limit set - should allow
1062        assert!(manager.check_rate_limit("api"));
1063
1064        // Set rate limit
1065        manager.set_rate_limit("api", 2, std::time::Duration::from_secs(60));
1066
1067        // Check rate limit
1068        assert!(manager.check_rate_limit("api"));
1069        assert!(manager.check_rate_limit("api"));
1070        assert!(!manager.check_rate_limit("api"));
1071
1072        // Check remaining
1073        assert_eq!(manager.rate_limit_remaining("api"), Some(0));
1074    }
1075}