Skip to main content

cortexai_agents/
discovery.rs

//! Agent Discovery Module
//!
//! Enables dynamic agent discovery through heartbeat and capability broadcasting.
//!
//! ## Features
//!
//! - **Heartbeat**: Agents periodically announce their presence
//! - **Capabilities**: Agents advertise what they can do
//! - **Registry**: Central registry for discovering agents
//! - **Events**: Subscribe to agent join/leave, heartbeat, capability and status events
//!
//! ## Example
//!
//! ```rust,ignore
//! use cortexai_agents::discovery::{AgentInfo, DiscoveryEvent, DiscoveryRegistry};
//! use cortexai_core::types::{AgentId, Capability};
//!
//! let registry = DiscoveryRegistry::new();
//!
//! // Register agent with capabilities
//! registry.register(AgentInfo {
//!     id: AgentId::new("researcher"),
//!     capabilities: vec![Capability::Analysis, Capability::WebSearch],
//!     ..Default::default()
//! }).await;
//!
//! // Find agents by capability (takes the capability by reference)
//! let agents = registry.find_by_capability(&Capability::Analysis).await;
//!
//! // Subscribe to discovery events. `recv` returns a `Result`: it yields `Err`
//! // once the channel closes or if this receiver lagged behind the sender.
//! let mut rx = registry.subscribe();
//! while let Ok(event) = rx.recv().await {
//!     match event {
//!         DiscoveryEvent::AgentJoined(info) => println!("Agent joined: {}", info.id),
//!         DiscoveryEvent::AgentLeft(id) => println!("Agent left: {}", id),
//!         DiscoveryEvent::HeartbeatReceived(id) => println!("Heartbeat from: {}", id),
//!         other => println!("Other event: {:?}", other),
//!     }
//! }
//! ```
39
40use std::collections::HashMap;
41use std::sync::Arc;
42use std::time::{Duration, Instant};
43use tokio::sync::{broadcast, RwLock};
44use tracing::{debug, info, warn};
45
46use cortexai_core::types::{AgentId, AgentRole, Capability};
47
48/// Information about a discovered agent
49#[derive(Debug, Clone)]
50pub struct AgentInfo {
51    /// Agent identifier
52    pub id: AgentId,
53    /// Agent name
54    pub name: String,
55    /// Agent role
56    pub role: AgentRole,
57    /// Agent capabilities
58    pub capabilities: Vec<Capability>,
59    /// Agent status
60    pub status: AgentDiscoveryStatus,
61    /// Last heartbeat time
62    pub last_heartbeat: Instant,
63    /// Agent metadata
64    pub metadata: HashMap<String, serde_json::Value>,
65    /// Agent endpoint (for remote agents)
66    pub endpoint: Option<String>,
67}
68
69impl AgentInfo {
70    pub fn new(id: AgentId, name: impl Into<String>, role: AgentRole) -> Self {
71        Self {
72            id,
73            name: name.into(),
74            role,
75            capabilities: Vec::new(),
76            status: AgentDiscoveryStatus::Online,
77            last_heartbeat: Instant::now(),
78            metadata: HashMap::new(),
79            endpoint: None,
80        }
81    }
82
83    pub fn with_capabilities(mut self, capabilities: Vec<Capability>) -> Self {
84        self.capabilities = capabilities;
85        self
86    }
87
88    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
89        self.endpoint = Some(endpoint.into());
90        self
91    }
92
93    pub fn add_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
94        self.metadata.insert(key.into(), value);
95        self
96    }
97
98    /// Check if agent has a specific capability
99    pub fn has_capability(&self, capability: &Capability) -> bool {
100        self.capabilities.contains(capability)
101    }
102
103    /// Update heartbeat timestamp
104    pub fn heartbeat(&mut self) {
105        self.last_heartbeat = Instant::now();
106        self.status = AgentDiscoveryStatus::Online;
107    }
108
109    /// Check if agent is considered stale (no heartbeat for too long)
110    pub fn is_stale(&self, timeout: Duration) -> bool {
111        self.last_heartbeat.elapsed() > timeout
112    }
113}
114
115impl Default for AgentInfo {
116    fn default() -> Self {
117        Self {
118            id: AgentId::generate(),
119            name: "Unknown".to_string(),
120            role: AgentRole::Executor,
121            capabilities: Vec::new(),
122            status: AgentDiscoveryStatus::Unknown,
123            last_heartbeat: Instant::now(),
124            metadata: HashMap::new(),
125            endpoint: None,
126        }
127    }
128}
129
/// Liveness of an agent as tracked by the discovery registry.
///
/// This is a small fieldless enum, so it is `Copy` and `Hash`. It can be
/// passed by value and used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AgentDiscoveryStatus {
    /// Agent is online and has heartbeated recently.
    Online,
    /// Agent missed recent heartbeats (set by the registry's stale-agent pass).
    Degraded,
    /// Agent is offline/unreachable. The registry in this module never assigns
    /// this status; it removes stale agents instead.
    Offline,
    /// Status is not yet known (the status used by `AgentInfo::default`).
    Unknown,
}
142
/// Events broadcast by the discovery registry to every subscriber.
#[derive(Debug, Clone)]
pub enum DiscoveryEvent {
    /// An agent id was registered for the first time; carries a copy of its info.
    /// Re-registering an already-known id does not emit this event.
    AgentJoined(AgentInfo),
    /// An agent was explicitly unregistered or removed by the stale-agent cleanup.
    AgentLeft(AgentId),
    /// A heartbeat was recorded for a registered agent.
    HeartbeatReceived(AgentId),
    /// A registered agent's capability list was replaced; carries the new list.
    CapabilitiesUpdated(AgentId, Vec<Capability>),
    /// A registered agent's status changed; carries the new status.
    StatusChanged(AgentId, AgentDiscoveryStatus),
}
157
/// Timing and cleanup settings for the discovery registry.
///
/// `Default` yields:
/// - a 30 s heartbeat interval;
/// - a 60 s degraded timeout and a 120 s offline timeout;
/// - auto-cleanup enabled, with a 60 s cleanup interval.
#[derive(Debug, Clone)]
pub struct DiscoveryConfig {
    /// Expected interval between agent heartbeats. The registry does not read
    /// this value itself; `HeartbeatSender` has its own interval setting.
    pub heartbeat_interval: Duration,
    /// Timeout before marking an agent as degraded.
    pub degraded_timeout: Duration,
    /// Silence after which an agent counts as offline. The auto-cleanup task
    /// removes agents that have been silent for longer than this.
    pub offline_timeout: Duration,
    /// Whether a background task should remove offline agents automatically.
    pub auto_cleanup: bool,
    /// How often the background cleanup task runs.
    pub cleanup_interval: Duration,
}

impl Default for DiscoveryConfig {
    fn default() -> Self {
        let one_minute = Duration::from_secs(60);
        Self {
            heartbeat_interval: Duration::from_secs(30),
            degraded_timeout: one_minute,
            offline_timeout: one_minute * 2,
            auto_cleanup: true,
            cleanup_interval: one_minute,
        }
    }
}
184
185/// Central registry for agent discovery
186pub struct DiscoveryRegistry {
187    agents: Arc<RwLock<HashMap<AgentId, AgentInfo>>>,
188    #[allow(dead_code)]
189    config: DiscoveryConfig,
190    event_tx: broadcast::Sender<DiscoveryEvent>,
191    _cleanup_handle: Option<tokio::task::JoinHandle<()>>,
192}
193
194impl DiscoveryRegistry {
195    /// Create a new discovery registry with default config
196    pub fn new() -> Self {
197        Self::with_config(DiscoveryConfig::default())
198    }
199
200    /// Create with custom configuration
201    pub fn with_config(config: DiscoveryConfig) -> Self {
202        let (event_tx, _) = broadcast::channel(100);
203        let agents = Arc::new(RwLock::new(HashMap::new()));
204
205        let cleanup_handle = if config.auto_cleanup {
206            let agents_clone = agents.clone();
207            let event_tx_clone = event_tx.clone();
208            let cleanup_interval = config.cleanup_interval;
209            let offline_timeout = config.offline_timeout;
210
211            Some(tokio::spawn(async move {
212                let mut interval = tokio::time::interval(cleanup_interval);
213                loop {
214                    interval.tick().await;
215                    Self::cleanup_stale_agents(&agents_clone, &event_tx_clone, offline_timeout)
216                        .await;
217                }
218            }))
219        } else {
220            None
221        };
222
223        Self {
224            agents,
225            config,
226            event_tx,
227            _cleanup_handle: cleanup_handle,
228        }
229    }
230
231    /// Register a new agent
232    pub async fn register(&self, info: AgentInfo) {
233        let id = info.id.clone();
234        let is_new = {
235            let agents = self.agents.read().await;
236            !agents.contains_key(&id)
237        };
238
239        {
240            let mut agents = self.agents.write().await;
241            agents.insert(id.clone(), info.clone());
242        }
243
244        if is_new {
245            info!(agent = %id, "Agent registered");
246            let _ = self.event_tx.send(DiscoveryEvent::AgentJoined(info));
247        }
248    }
249
250    /// Unregister an agent
251    pub async fn unregister(&self, id: &AgentId) {
252        let removed = {
253            let mut agents = self.agents.write().await;
254            agents.remove(id)
255        };
256
257        if removed.is_some() {
258            info!(agent = %id, "Agent unregistered");
259            let _ = self.event_tx.send(DiscoveryEvent::AgentLeft(id.clone()));
260        }
261    }
262
263    /// Record a heartbeat from an agent
264    pub async fn heartbeat(&self, id: &AgentId) {
265        let mut agents = self.agents.write().await;
266        if let Some(agent) = agents.get_mut(id) {
267            let old_status = agent.status.clone();
268            agent.heartbeat();
269
270            if old_status != AgentDiscoveryStatus::Online {
271                drop(agents);
272                let _ = self.event_tx.send(DiscoveryEvent::StatusChanged(
273                    id.clone(),
274                    AgentDiscoveryStatus::Online,
275                ));
276            }
277
278            debug!(agent = %id, "Heartbeat received");
279            let _ = self
280                .event_tx
281                .send(DiscoveryEvent::HeartbeatReceived(id.clone()));
282        }
283    }
284
285    /// Update agent capabilities
286    pub async fn update_capabilities(&self, id: &AgentId, capabilities: Vec<Capability>) {
287        let mut agents = self.agents.write().await;
288        if let Some(agent) = agents.get_mut(id) {
289            agent.capabilities = capabilities.clone();
290            drop(agents);
291
292            info!(agent = %id, caps = ?capabilities, "Capabilities updated");
293            let _ = self.event_tx.send(DiscoveryEvent::CapabilitiesUpdated(
294                id.clone(),
295                capabilities,
296            ));
297        }
298    }
299
300    /// Get agent info by ID
301    pub async fn get(&self, id: &AgentId) -> Option<AgentInfo> {
302        let agents = self.agents.read().await;
303        agents.get(id).cloned()
304    }
305
306    /// List all registered agents
307    pub async fn list_all(&self) -> Vec<AgentInfo> {
308        let agents = self.agents.read().await;
309        agents.values().cloned().collect()
310    }
311
312    /// List online agents only
313    pub async fn list_online(&self) -> Vec<AgentInfo> {
314        let agents = self.agents.read().await;
315        agents
316            .values()
317            .filter(|a| a.status == AgentDiscoveryStatus::Online)
318            .cloned()
319            .collect()
320    }
321
322    /// Find agents by capability
323    pub async fn find_by_capability(&self, capability: &Capability) -> Vec<AgentInfo> {
324        let agents = self.agents.read().await;
325        agents
326            .values()
327            .filter(|a| a.has_capability(capability))
328            .cloned()
329            .collect()
330    }
331
332    /// Find agents by role
333    pub async fn find_by_role(&self, role: &AgentRole) -> Vec<AgentInfo> {
334        let agents = self.agents.read().await;
335        agents
336            .values()
337            .filter(|a| std::mem::discriminant(&a.role) == std::mem::discriminant(role))
338            .cloned()
339            .collect()
340    }
341
342    /// Find agents matching multiple capabilities (AND)
343    pub async fn find_by_capabilities(&self, capabilities: &[Capability]) -> Vec<AgentInfo> {
344        let agents = self.agents.read().await;
345        agents
346            .values()
347            .filter(|a| capabilities.iter().all(|c| a.has_capability(c)))
348            .cloned()
349            .collect()
350    }
351
352    /// Find agents matching any of the capabilities (OR)
353    pub async fn find_by_any_capability(&self, capabilities: &[Capability]) -> Vec<AgentInfo> {
354        let agents = self.agents.read().await;
355        agents
356            .values()
357            .filter(|a| capabilities.iter().any(|c| a.has_capability(c)))
358            .cloned()
359            .collect()
360    }
361
362    /// Subscribe to discovery events
363    pub fn subscribe(&self) -> broadcast::Receiver<DiscoveryEvent> {
364        self.event_tx.subscribe()
365    }
366
367    /// Get the number of registered agents
368    pub async fn agent_count(&self) -> usize {
369        let agents = self.agents.read().await;
370        agents.len()
371    }
372
373    /// Get the number of online agents
374    pub async fn online_count(&self) -> usize {
375        let agents = self.agents.read().await;
376        agents
377            .values()
378            .filter(|a| a.status == AgentDiscoveryStatus::Online)
379            .count()
380    }
381
382    /// Cleanup stale agents
383    async fn cleanup_stale_agents(
384        agents: &Arc<RwLock<HashMap<AgentId, AgentInfo>>>,
385        event_tx: &broadcast::Sender<DiscoveryEvent>,
386        offline_timeout: Duration,
387    ) {
388        let mut agents_write = agents.write().await;
389
390        // Collect IDs of agents to remove
391        let stale_ids: Vec<AgentId> = agents_write
392            .iter()
393            .filter(|(_, a)| a.is_stale(offline_timeout))
394            .map(|(id, _)| id.clone())
395            .collect();
396
397        for id in stale_ids {
398            if let Some(agent) = agents_write.remove(&id) {
399                warn!(agent = %id, "Removing stale agent");
400                let _ = event_tx.send(DiscoveryEvent::AgentLeft(agent.id));
401            }
402        }
403
404        // Update status for degraded agents
405        let degraded_timeout = offline_timeout / 2;
406        for agent in agents_write.values_mut() {
407            if agent.status == AgentDiscoveryStatus::Online && agent.is_stale(degraded_timeout) {
408                agent.status = AgentDiscoveryStatus::Degraded;
409                let _ = event_tx.send(DiscoveryEvent::StatusChanged(
410                    agent.id.clone(),
411                    AgentDiscoveryStatus::Degraded,
412                ));
413            }
414        }
415    }
416}
417
418impl Default for DiscoveryRegistry {
419    fn default() -> Self {
420        Self::new()
421    }
422}
423
424/// Heartbeat sender for an agent
425pub struct HeartbeatSender {
426    registry: Arc<DiscoveryRegistry>,
427    agent_id: AgentId,
428    interval: Duration,
429    handle: Option<tokio::task::JoinHandle<()>>,
430}
431
432impl HeartbeatSender {
433    /// Create a new heartbeat sender
434    pub fn new(registry: Arc<DiscoveryRegistry>, agent_id: AgentId) -> Self {
435        Self {
436            registry,
437            agent_id,
438            interval: Duration::from_secs(30),
439            handle: None,
440        }
441    }
442
443    /// Set heartbeat interval
444    pub fn with_interval(mut self, interval: Duration) -> Self {
445        self.interval = interval;
446        self
447    }
448
449    /// Start sending heartbeats
450    pub fn start(&mut self) {
451        let registry = self.registry.clone();
452        let agent_id = self.agent_id.clone();
453        let interval = self.interval;
454
455        let handle = tokio::spawn(async move {
456            let mut ticker = tokio::time::interval(interval);
457            loop {
458                ticker.tick().await;
459                registry.heartbeat(&agent_id).await;
460            }
461        });
462
463        self.handle = Some(handle);
464        info!(agent = %self.agent_id, interval = ?self.interval, "Heartbeat sender started");
465    }
466
467    /// Stop sending heartbeats
468    pub fn stop(&mut self) {
469        if let Some(handle) = self.handle.take() {
470            handle.abort();
471            info!(agent = %self.agent_id, "Heartbeat sender stopped");
472        }
473    }
474}
475
476impl Drop for HeartbeatSender {
477    fn drop(&mut self) {
478        self.stop();
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::broadcast::error::TryRecvError;

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = DiscoveryRegistry::new();

        let info = AgentInfo::new(AgentId::new("test"), "Test Agent", AgentRole::Executor)
            .with_capabilities(vec![Capability::Analysis]);

        registry.register(info.clone()).await;

        let retrieved = registry.get(&AgentId::new("test")).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, "Test Agent");
    }

    #[tokio::test]
    async fn test_register_emits_joined_once() {
        let registry = DiscoveryRegistry::new();
        let mut rx = registry.subscribe();

        let info = AgentInfo::new(AgentId::new("dup"), "Dup", AgentRole::Executor);
        registry.register(info.clone()).await;
        // Re-registering a known id replaces it without a second AgentJoined.
        registry.register(info).await;

        match rx.try_recv() {
            Ok(DiscoveryEvent::AgentJoined(joined)) => assert_eq!(joined.id, AgentId::new("dup")),
            other => panic!("expected AgentJoined, got {:?}", other),
        }
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
        assert_eq!(registry.agent_count().await, 1);
    }

    #[tokio::test]
    async fn test_find_by_capability() {
        let registry = DiscoveryRegistry::new();

        let agent1 = AgentInfo::new(AgentId::new("a1"), "Agent 1", AgentRole::Researcher)
            .with_capabilities(vec![Capability::Analysis, Capability::WebSearch]);

        let agent2 = AgentInfo::new(AgentId::new("a2"), "Agent 2", AgentRole::Writer)
            .with_capabilities(vec![Capability::ContentGeneration]);

        registry.register(agent1).await;
        registry.register(agent2).await;

        let analysts = registry.find_by_capability(&Capability::Analysis).await;
        assert_eq!(analysts.len(), 1);
        assert_eq!(analysts[0].id, AgentId::new("a1"));

        let writers = registry
            .find_by_capability(&Capability::ContentGeneration)
            .await;
        assert_eq!(writers.len(), 1);
        assert_eq!(writers[0].id, AgentId::new("a2"));
    }

    #[tokio::test]
    async fn test_find_by_capabilities_and_or() {
        let registry = DiscoveryRegistry::new();

        registry
            .register(
                AgentInfo::new(AgentId::new("both"), "Both", AgentRole::Researcher)
                    .with_capabilities(vec![Capability::Analysis, Capability::WebSearch]),
            )
            .await;
        registry
            .register(
                AgentInfo::new(AgentId::new("one"), "One", AgentRole::Researcher)
                    .with_capabilities(vec![Capability::Analysis]),
            )
            .await;

        let all = registry
            .find_by_capabilities(&[Capability::Analysis, Capability::WebSearch])
            .await;
        assert_eq!(all.len(), 1);
        assert_eq!(all[0].id, AgentId::new("both"));

        let any = registry
            .find_by_any_capability(&[Capability::WebSearch, Capability::ContentGeneration])
            .await;
        assert_eq!(any.len(), 1);

        let any_analysis = registry
            .find_by_any_capability(&[Capability::Analysis])
            .await;
        assert_eq!(any_analysis.len(), 2);
    }

    #[tokio::test]
    async fn test_unregister() {
        let registry = DiscoveryRegistry::new();

        let info = AgentInfo::new(AgentId::new("test"), "Test Agent", AgentRole::Executor);
        registry.register(info).await;

        assert_eq!(registry.agent_count().await, 1);

        registry.unregister(&AgentId::new("test")).await;
        assert_eq!(registry.agent_count().await, 0);
    }

    #[tokio::test]
    async fn test_unregister_emits_left_only_for_known_agent() {
        let registry = DiscoveryRegistry::new();
        registry
            .register(AgentInfo::new(AgentId::new("gone"), "Gone", AgentRole::Executor))
            .await;
        let mut rx = registry.subscribe();

        registry.unregister(&AgentId::new("missing")).await;
        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));

        registry.unregister(&AgentId::new("gone")).await;
        assert!(matches!(
            rx.try_recv(),
            Ok(DiscoveryEvent::AgentLeft(id)) if id == AgentId::new("gone")
        ));
    }

    #[tokio::test]
    async fn test_heartbeat() {
        let registry = DiscoveryRegistry::new();

        let info = AgentInfo::new(AgentId::new("test"), "Test Agent", AgentRole::Executor);
        registry.register(info).await;

        // Send heartbeat
        registry.heartbeat(&AgentId::new("test")).await;

        let agent = registry.get(&AgentId::new("test")).await.unwrap();
        assert_eq!(agent.status, AgentDiscoveryStatus::Online);
    }

    #[tokio::test]
    async fn test_heartbeat_restores_online_and_orders_events() {
        let registry = DiscoveryRegistry::new();

        let mut info = AgentInfo::new(AgentId::new("test"), "Test Agent", AgentRole::Executor);
        info.status = AgentDiscoveryStatus::Degraded;
        registry.register(info).await;
        let mut rx = registry.subscribe();

        registry.heartbeat(&AgentId::new("test")).await;

        match rx.try_recv() {
            Ok(DiscoveryEvent::StatusChanged(id, status)) => {
                assert_eq!(id, AgentId::new("test"));
                assert_eq!(status, AgentDiscoveryStatus::Online);
            }
            other => panic!("expected StatusChanged, got {:?}", other),
        }
        assert!(matches!(
            rx.try_recv(),
            Ok(DiscoveryEvent::HeartbeatReceived(_))
        ));
        assert_eq!(registry.online_count().await, 1);
    }

    #[tokio::test]
    async fn test_heartbeat_unknown_agent_is_ignored() {
        let registry = DiscoveryRegistry::new();
        let mut rx = registry.subscribe();

        registry.heartbeat(&AgentId::new("ghost")).await;

        assert!(matches!(rx.try_recv(), Err(TryRecvError::Empty)));
        assert_eq!(registry.agent_count().await, 0);
    }

    #[tokio::test]
    async fn test_update_capabilities() {
        let registry = DiscoveryRegistry::new();

        let info = AgentInfo::new(AgentId::new("test"), "Test Agent", AgentRole::Executor);
        registry.register(info).await;

        registry
            .update_capabilities(
                &AgentId::new("test"),
                vec![Capability::Analysis, Capability::Prediction],
            )
            .await;

        let agent = registry.get(&AgentId::new("test")).await.unwrap();
        assert_eq!(agent.capabilities.len(), 2);
        assert!(agent.has_capability(&Capability::Analysis));
    }

    #[test]
    fn test_agent_info_stale_check() {
        let mut info = AgentInfo::new(AgentId::new("test"), "Test", AgentRole::Executor);

        // A fresh agent is not stale.
        assert!(!info.is_stale(Duration::from_secs(60)));

        // Simulate an old heartbeat by backdating the timestamp.
        info.last_heartbeat = Instant::now() - Duration::from_secs(120);
        assert!(info.is_stale(Duration::from_secs(60)));

        // A heartbeat refreshes the timestamp.
        info.heartbeat();
        assert!(!info.is_stale(Duration::from_secs(60)));
    }
}