1use std::collections::HashMap;
41use std::sync::Arc;
42use std::time::{Duration, Instant};
43use tokio::sync::{broadcast, RwLock};
44use tracing::{debug, info, warn};
45
46use cortexai_core::types::{AgentId, AgentRole, Capability};
47
48#[derive(Debug, Clone)]
50pub struct AgentInfo {
51 pub id: AgentId,
53 pub name: String,
55 pub role: AgentRole,
57 pub capabilities: Vec<Capability>,
59 pub status: AgentDiscoveryStatus,
61 pub last_heartbeat: Instant,
63 pub metadata: HashMap<String, serde_json::Value>,
65 pub endpoint: Option<String>,
67}
68
69impl AgentInfo {
70 pub fn new(id: AgentId, name: impl Into<String>, role: AgentRole) -> Self {
71 Self {
72 id,
73 name: name.into(),
74 role,
75 capabilities: Vec::new(),
76 status: AgentDiscoveryStatus::Online,
77 last_heartbeat: Instant::now(),
78 metadata: HashMap::new(),
79 endpoint: None,
80 }
81 }
82
83 pub fn with_capabilities(mut self, capabilities: Vec<Capability>) -> Self {
84 self.capabilities = capabilities;
85 self
86 }
87
88 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
89 self.endpoint = Some(endpoint.into());
90 self
91 }
92
93 pub fn add_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
94 self.metadata.insert(key.into(), value);
95 self
96 }
97
98 pub fn has_capability(&self, capability: &Capability) -> bool {
100 self.capabilities.contains(capability)
101 }
102
103 pub fn heartbeat(&mut self) {
105 self.last_heartbeat = Instant::now();
106 self.status = AgentDiscoveryStatus::Online;
107 }
108
109 pub fn is_stale(&self, timeout: Duration) -> bool {
111 self.last_heartbeat.elapsed() > timeout
112 }
113}
114
115impl Default for AgentInfo {
116 fn default() -> Self {
117 Self {
118 id: AgentId::generate(),
119 name: "Unknown".to_string(),
120 role: AgentRole::Executor,
121 capabilities: Vec::new(),
122 status: AgentDiscoveryStatus::Unknown,
123 last_heartbeat: Instant::now(),
124 metadata: HashMap::new(),
125 endpoint: None,
126 }
127 }
128}
129
/// Liveness state of a registered agent.
///
/// The enum is fieldless, so it derives `Copy` (cheap to pass and compare by value)
/// and `Hash` (usable as a set or map key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AgentDiscoveryStatus {
    /// A heartbeat was recorded recently; `AgentInfo::heartbeat` sets this state.
    Online,
    /// Heartbeats have been late for a while; the registry's cleanup sweep assigns this
    /// to `Online` agents whose last heartbeat is older than its degraded threshold.
    Degraded,
    /// The agent is considered unreachable.
    /// NOTE(review): the registry's stale-agent sweep removes agents outright rather than
    /// assigning this state — confirm who is expected to set it.
    Offline,
    /// No liveness information yet (the state of `AgentInfo::default()`).
    Unknown,
}
142
143#[derive(Debug, Clone)]
145pub enum DiscoveryEvent {
146 AgentJoined(AgentInfo),
148 AgentLeft(AgentId),
150 HeartbeatReceived(AgentId),
152 CapabilitiesUpdated(AgentId, Vec<Capability>),
154 StatusChanged(AgentId, AgentDiscoveryStatus),
156}
157
/// Timing knobs for a [`DiscoveryRegistry`].
#[derive(Clone, Debug)]
pub struct DiscoveryConfig {
    /// Expected spacing between agent heartbeats.
    /// NOTE(review): nothing in this file reads this value back (`HeartbeatSender` keeps its
    /// own interval) — confirm the intended consumer.
    pub heartbeat_interval: Duration,
    /// Heartbeat age after which an `Online` agent should be reported as `Degraded`.
    /// NOTE(review): confirm the cleanup sweep reads this rather than deriving its own threshold.
    pub degraded_timeout: Duration,
    /// Heartbeat age after which the cleanup sweep removes an agent from the registry.
    pub offline_timeout: Duration,
    /// Whether the registry should run its background cleanup task.
    pub auto_cleanup: bool,
    /// Period of the background cleanup task.
    pub cleanup_interval: Duration,
}
172
173impl Default for DiscoveryConfig {
174 fn default() -> Self {
175 Self {
176 heartbeat_interval: Duration::from_secs(30),
177 degraded_timeout: Duration::from_secs(60),
178 offline_timeout: Duration::from_secs(120),
179 auto_cleanup: true,
180 cleanup_interval: Duration::from_secs(60),
181 }
182 }
183}
184
185pub struct DiscoveryRegistry {
187 agents: Arc<RwLock<HashMap<AgentId, AgentInfo>>>,
188 #[allow(dead_code)]
189 config: DiscoveryConfig,
190 event_tx: broadcast::Sender<DiscoveryEvent>,
191 _cleanup_handle: Option<tokio::task::JoinHandle<()>>,
192}
193
194impl DiscoveryRegistry {
195 pub fn new() -> Self {
197 Self::with_config(DiscoveryConfig::default())
198 }
199
200 pub fn with_config(config: DiscoveryConfig) -> Self {
202 let (event_tx, _) = broadcast::channel(100);
203 let agents = Arc::new(RwLock::new(HashMap::new()));
204
205 let cleanup_handle = if config.auto_cleanup {
206 let agents_clone = agents.clone();
207 let event_tx_clone = event_tx.clone();
208 let cleanup_interval = config.cleanup_interval;
209 let offline_timeout = config.offline_timeout;
210
211 Some(tokio::spawn(async move {
212 let mut interval = tokio::time::interval(cleanup_interval);
213 loop {
214 interval.tick().await;
215 Self::cleanup_stale_agents(&agents_clone, &event_tx_clone, offline_timeout)
216 .await;
217 }
218 }))
219 } else {
220 None
221 };
222
223 Self {
224 agents,
225 config,
226 event_tx,
227 _cleanup_handle: cleanup_handle,
228 }
229 }
230
231 pub async fn register(&self, info: AgentInfo) {
233 let id = info.id.clone();
234 let is_new = {
235 let agents = self.agents.read().await;
236 !agents.contains_key(&id)
237 };
238
239 {
240 let mut agents = self.agents.write().await;
241 agents.insert(id.clone(), info.clone());
242 }
243
244 if is_new {
245 info!(agent = %id, "Agent registered");
246 let _ = self.event_tx.send(DiscoveryEvent::AgentJoined(info));
247 }
248 }
249
250 pub async fn unregister(&self, id: &AgentId) {
252 let removed = {
253 let mut agents = self.agents.write().await;
254 agents.remove(id)
255 };
256
257 if removed.is_some() {
258 info!(agent = %id, "Agent unregistered");
259 let _ = self.event_tx.send(DiscoveryEvent::AgentLeft(id.clone()));
260 }
261 }
262
263 pub async fn heartbeat(&self, id: &AgentId) {
265 let mut agents = self.agents.write().await;
266 if let Some(agent) = agents.get_mut(id) {
267 let old_status = agent.status.clone();
268 agent.heartbeat();
269
270 if old_status != AgentDiscoveryStatus::Online {
271 drop(agents);
272 let _ = self.event_tx.send(DiscoveryEvent::StatusChanged(
273 id.clone(),
274 AgentDiscoveryStatus::Online,
275 ));
276 }
277
278 debug!(agent = %id, "Heartbeat received");
279 let _ = self
280 .event_tx
281 .send(DiscoveryEvent::HeartbeatReceived(id.clone()));
282 }
283 }
284
285 pub async fn update_capabilities(&self, id: &AgentId, capabilities: Vec<Capability>) {
287 let mut agents = self.agents.write().await;
288 if let Some(agent) = agents.get_mut(id) {
289 agent.capabilities = capabilities.clone();
290 drop(agents);
291
292 info!(agent = %id, caps = ?capabilities, "Capabilities updated");
293 let _ = self.event_tx.send(DiscoveryEvent::CapabilitiesUpdated(
294 id.clone(),
295 capabilities,
296 ));
297 }
298 }
299
300 pub async fn get(&self, id: &AgentId) -> Option<AgentInfo> {
302 let agents = self.agents.read().await;
303 agents.get(id).cloned()
304 }
305
306 pub async fn list_all(&self) -> Vec<AgentInfo> {
308 let agents = self.agents.read().await;
309 agents.values().cloned().collect()
310 }
311
312 pub async fn list_online(&self) -> Vec<AgentInfo> {
314 let agents = self.agents.read().await;
315 agents
316 .values()
317 .filter(|a| a.status == AgentDiscoveryStatus::Online)
318 .cloned()
319 .collect()
320 }
321
322 pub async fn find_by_capability(&self, capability: &Capability) -> Vec<AgentInfo> {
324 let agents = self.agents.read().await;
325 agents
326 .values()
327 .filter(|a| a.has_capability(capability))
328 .cloned()
329 .collect()
330 }
331
332 pub async fn find_by_role(&self, role: &AgentRole) -> Vec<AgentInfo> {
334 let agents = self.agents.read().await;
335 agents
336 .values()
337 .filter(|a| std::mem::discriminant(&a.role) == std::mem::discriminant(role))
338 .cloned()
339 .collect()
340 }
341
342 pub async fn find_by_capabilities(&self, capabilities: &[Capability]) -> Vec<AgentInfo> {
344 let agents = self.agents.read().await;
345 agents
346 .values()
347 .filter(|a| capabilities.iter().all(|c| a.has_capability(c)))
348 .cloned()
349 .collect()
350 }
351
352 pub async fn find_by_any_capability(&self, capabilities: &[Capability]) -> Vec<AgentInfo> {
354 let agents = self.agents.read().await;
355 agents
356 .values()
357 .filter(|a| capabilities.iter().any(|c| a.has_capability(c)))
358 .cloned()
359 .collect()
360 }
361
362 pub fn subscribe(&self) -> broadcast::Receiver<DiscoveryEvent> {
364 self.event_tx.subscribe()
365 }
366
367 pub async fn agent_count(&self) -> usize {
369 let agents = self.agents.read().await;
370 agents.len()
371 }
372
373 pub async fn online_count(&self) -> usize {
375 let agents = self.agents.read().await;
376 agents
377 .values()
378 .filter(|a| a.status == AgentDiscoveryStatus::Online)
379 .count()
380 }
381
382 async fn cleanup_stale_agents(
384 agents: &Arc<RwLock<HashMap<AgentId, AgentInfo>>>,
385 event_tx: &broadcast::Sender<DiscoveryEvent>,
386 offline_timeout: Duration,
387 ) {
388 let mut agents_write = agents.write().await;
389
390 let stale_ids: Vec<AgentId> = agents_write
392 .iter()
393 .filter(|(_, a)| a.is_stale(offline_timeout))
394 .map(|(id, _)| id.clone())
395 .collect();
396
397 for id in stale_ids {
398 if let Some(agent) = agents_write.remove(&id) {
399 warn!(agent = %id, "Removing stale agent");
400 let _ = event_tx.send(DiscoveryEvent::AgentLeft(agent.id));
401 }
402 }
403
404 let degraded_timeout = offline_timeout / 2;
406 for agent in agents_write.values_mut() {
407 if agent.status == AgentDiscoveryStatus::Online && agent.is_stale(degraded_timeout) {
408 agent.status = AgentDiscoveryStatus::Degraded;
409 let _ = event_tx.send(DiscoveryEvent::StatusChanged(
410 agent.id.clone(),
411 AgentDiscoveryStatus::Degraded,
412 ));
413 }
414 }
415 }
416}
417
418impl Default for DiscoveryRegistry {
419 fn default() -> Self {
420 Self::new()
421 }
422}
423
424pub struct HeartbeatSender {
426 registry: Arc<DiscoveryRegistry>,
427 agent_id: AgentId,
428 interval: Duration,
429 handle: Option<tokio::task::JoinHandle<()>>,
430}
431
432impl HeartbeatSender {
433 pub fn new(registry: Arc<DiscoveryRegistry>, agent_id: AgentId) -> Self {
435 Self {
436 registry,
437 agent_id,
438 interval: Duration::from_secs(30),
439 handle: None,
440 }
441 }
442
443 pub fn with_interval(mut self, interval: Duration) -> Self {
445 self.interval = interval;
446 self
447 }
448
449 pub fn start(&mut self) {
451 let registry = self.registry.clone();
452 let agent_id = self.agent_id.clone();
453 let interval = self.interval;
454
455 let handle = tokio::spawn(async move {
456 let mut ticker = tokio::time::interval(interval);
457 loop {
458 ticker.tick().await;
459 registry.heartbeat(&agent_id).await;
460 }
461 });
462
463 self.handle = Some(handle);
464 info!(agent = %self.agent_id, interval = ?self.interval, "Heartbeat sender started");
465 }
466
467 pub fn stop(&mut self) {
469 if let Some(handle) = self.handle.take() {
470 handle.abort();
471 info!(agent = %self.agent_id, "Heartbeat sender stopped");
472 }
473 }
474}
475
476impl Drop for HeartbeatSender {
477 fn drop(&mut self) {
478 self.stop();
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = DiscoveryRegistry::new();

        let info = AgentInfo::new(AgentId::new("test"), "Test Agent", AgentRole::Executor)
            .with_capabilities(vec![Capability::Analysis]);

        registry.register(info.clone()).await;

        let retrieved = registry.get(&AgentId::new("test")).await;
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().name, "Test Agent");
    }

    #[tokio::test]
    async fn test_reregister_replaces_record_without_second_join_event() {
        let registry = DiscoveryRegistry::new();
        let mut events = registry.subscribe();
        let id = AgentId::new("dup");

        registry
            .register(AgentInfo::new(id.clone(), "First", AgentRole::Executor))
            .await;
        registry
            .register(AgentInfo::new(id.clone(), "Second", AgentRole::Executor))
            .await;

        assert!(matches!(events.try_recv(), Ok(DiscoveryEvent::AgentJoined(_))));
        assert!(events.try_recv().is_err());
        assert_eq!(registry.agent_count().await, 1);
        assert_eq!(registry.get(&id).await.unwrap().name, "Second");
    }

    #[tokio::test]
    async fn test_find_by_capability() {
        let registry = DiscoveryRegistry::new();

        let agent1 = AgentInfo::new(AgentId::new("a1"), "Agent 1", AgentRole::Researcher)
            .with_capabilities(vec![Capability::Analysis, Capability::WebSearch]);

        let agent2 = AgentInfo::new(AgentId::new("a2"), "Agent 2", AgentRole::Writer)
            .with_capabilities(vec![Capability::ContentGeneration]);

        registry.register(agent1).await;
        registry.register(agent2).await;

        let analysts = registry.find_by_capability(&Capability::Analysis).await;
        assert_eq!(analysts.len(), 1);
        assert_eq!(analysts[0].id, AgentId::new("a1"));

        let writers = registry
            .find_by_capability(&Capability::ContentGeneration)
            .await;
        assert_eq!(writers.len(), 1);
        assert_eq!(writers[0].id, AgentId::new("a2"));
    }

    #[tokio::test]
    async fn test_find_by_role_and_capability_sets() {
        let registry = DiscoveryRegistry::new();

        registry
            .register(
                AgentInfo::new(AgentId::new("a1"), "Agent 1", AgentRole::Researcher)
                    .with_capabilities(vec![Capability::Analysis, Capability::WebSearch]),
            )
            .await;
        registry
            .register(
                AgentInfo::new(AgentId::new("a2"), "Agent 2", AgentRole::Writer)
                    .with_capabilities(vec![Capability::ContentGeneration]),
            )
            .await;

        let writers = registry.find_by_role(&AgentRole::Writer).await;
        assert_eq!(writers.len(), 1);
        assert_eq!(writers[0].id, AgentId::new("a2"));

        // `find_by_capabilities` requires every listed capability.
        let both = registry
            .find_by_capabilities(&[Capability::Analysis, Capability::WebSearch])
            .await;
        assert_eq!(both.len(), 1);
        assert_eq!(both[0].id, AgentId::new("a1"));
        assert!(registry
            .find_by_capabilities(&[Capability::Analysis, Capability::ContentGeneration])
            .await
            .is_empty());

        // `find_by_any_capability` requires at least one.
        let either = registry
            .find_by_any_capability(&[Capability::WebSearch, Capability::ContentGeneration])
            .await;
        assert_eq!(either.len(), 2);

        assert_eq!(registry.online_count().await, 2);
        assert_eq!(registry.list_online().await.len(), 2);
    }

    #[tokio::test]
    async fn test_unregister() {
        let registry = DiscoveryRegistry::new();

        let info = AgentInfo::new(AgentId::new("test"), "Test Agent", AgentRole::Executor);
        registry.register(info).await;

        assert_eq!(registry.agent_count().await, 1);

        registry.unregister(&AgentId::new("test")).await;
        assert_eq!(registry.agent_count().await, 0);
    }

    #[tokio::test]
    async fn test_unregister_unknown_agent_emits_nothing() {
        let registry = DiscoveryRegistry::new();
        let mut events = registry.subscribe();

        registry.unregister(&AgentId::new("ghost")).await;

        assert!(events.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_heartbeat() {
        let registry = DiscoveryRegistry::new();
        let id = AgentId::new("test");
        registry
            .register(AgentInfo::new(id.clone(), "Test Agent", AgentRole::Executor))
            .await;

        let mut events = registry.subscribe();
        registry.heartbeat(&id).await;

        let agent = registry.get(&id).await.unwrap();
        assert_eq!(agent.status, AgentDiscoveryStatus::Online);

        // Already online: only a heartbeat event, no status change.
        assert!(matches!(
            events.try_recv(),
            Ok(DiscoveryEvent::HeartbeatReceived(ref got)) if *got == id
        ));
        assert!(events.try_recv().is_err());
    }

    #[tokio::test]
    async fn test_heartbeat_restores_online_status() {
        let registry = DiscoveryRegistry::new();
        let id = AgentId::new("test");
        registry
            .register(AgentInfo::new(id.clone(), "Test Agent", AgentRole::Executor))
            .await;
        registry
            .agents
            .write()
            .await
            .get_mut(&id)
            .unwrap()
            .status = AgentDiscoveryStatus::Degraded;

        let mut events = registry.subscribe();
        registry.heartbeat(&id).await;

        assert_eq!(
            registry.get(&id).await.unwrap().status,
            AgentDiscoveryStatus::Online
        );
        assert!(matches!(
            events.try_recv(),
            Ok(DiscoveryEvent::StatusChanged(ref got, AgentDiscoveryStatus::Online)) if *got == id
        ));
        assert!(matches!(
            events.try_recv(),
            Ok(DiscoveryEvent::HeartbeatReceived(ref got)) if *got == id
        ));
    }

    #[tokio::test]
    async fn test_heartbeat_unknown_agent_is_ignored() {
        let registry = DiscoveryRegistry::new();
        let mut events = registry.subscribe();

        registry.heartbeat(&AgentId::new("ghost")).await;

        assert!(events.try_recv().is_err());
        assert_eq!(registry.agent_count().await, 0);
    }

    #[tokio::test]
    async fn test_update_capabilities() {
        let registry = DiscoveryRegistry::new();

        let info = AgentInfo::new(AgentId::new("test"), "Test Agent", AgentRole::Executor);
        registry.register(info).await;

        let mut events = registry.subscribe();
        registry
            .update_capabilities(
                &AgentId::new("test"),
                vec![Capability::Analysis, Capability::Prediction],
            )
            .await;

        let agent = registry.get(&AgentId::new("test")).await.unwrap();
        assert_eq!(agent.capabilities.len(), 2);
        assert!(agent.has_capability(&Capability::Analysis));
        assert!(matches!(
            events.try_recv(),
            Ok(DiscoveryEvent::CapabilitiesUpdated(_, ref caps)) if caps.len() == 2
        ));
    }

    #[test]
    fn test_agent_info_stale_check() {
        let mut info = AgentInfo::new(AgentId::new("test"), "Test", AgentRole::Executor);

        assert!(!info.is_stale(Duration::from_secs(60)));

        // `Instant` arithmetic underflows (and panics) on hosts whose monotonic clock started
        // less than two minutes ago, e.g. freshly booted CI VMs. In that case skip this part.
        if let Some(long_ago) = Instant::now().checked_sub(Duration::from_secs(120)) {
            info.last_heartbeat = long_ago;
            assert!(info.is_stale(Duration::from_secs(60)));

            info.heartbeat();
            assert!(!info.is_stale(Duration::from_secs(60)));
            assert_eq!(info.status, AgentDiscoveryStatus::Online);
        }
    }
}