1use crate::agent::capabilities::{AgentCapabilities, AgentRequirements};
6use crate::agent::context::AgentContext;
7use crate::agent::core::MoFAAgent;
8use crate::agent::error::{AgentError, AgentResult};
9use crate::agent::traits::AgentMetadata;
10use crate::agent::types::AgentState;
11use mofa_kernel::agent::config::AgentConfig;
12use mofa_kernel::agent::registry::AgentFactory;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
18struct AgentEntry {
24 agent: Arc<RwLock<dyn MoFAAgent>>,
26 metadata: AgentMetadata,
28 registered_at: u64,
30}
31
/// Reverse lookup tables from capability descriptors to agent ids.
struct CapabilityIndex {
    /// Reasoning-strategy name (its `Debug` form) -> ids of agents using it.
    by_strategy: HashMap<String, Vec<String>>,
    /// Capability tag -> ids of agents carrying that tag.
    by_tag: HashMap<String, Vec<String>>,
}
45
46impl CapabilityIndex {
47 fn new() -> Self {
48 Self {
49 by_tag: HashMap::new(),
50 by_strategy: HashMap::new(),
51 }
52 }
53
54 fn index(&mut self, agent_id: &str, capabilities: &AgentCapabilities) {
56 for tag in &capabilities.tags {
58 self.by_tag
59 .entry(tag.clone())
60 .or_default()
61 .push(agent_id.to_string());
62 }
63
64 for strategy in &capabilities.reasoning_strategies {
66 let strategy_name = format!("{:?}", strategy);
67 self.by_strategy
68 .entry(strategy_name)
69 .or_default()
70 .push(agent_id.to_string());
71 }
72 }
73
74 fn unindex(&mut self, agent_id: &str) {
76 for ids in self.by_tag.values_mut() {
77 ids.retain(|id| id != agent_id);
78 }
79 for ids in self.by_strategy.values_mut() {
80 ids.retain(|id| id != agent_id);
81 }
82 }
83
84 fn find_by_tag(&self, tag: &str) -> Vec<String> {
86 self.by_tag.get(tag).cloned().unwrap_or_default()
87 }
88
89 fn find_by_tags(&self, tags: &[String]) -> Vec<String> {
91 if tags.is_empty() {
92 return vec![];
93 }
94
95 let mut result: Option<Vec<String>> = None;
96 for tag in tags {
97 let ids = self.find_by_tag(tag);
98 result = match result {
99 None => Some(ids),
100 Some(existing) => {
101 let intersection: Vec<String> =
102 existing.into_iter().filter(|id| ids.contains(id)).collect();
103 Some(intersection)
104 }
105 };
106 }
107 result.unwrap_or_default()
108 }
109}
110
/// Registry of live agents, a capability index over them, and the factories
/// that can create new agents.
///
/// Every piece of state sits behind an `Arc<tokio::sync::RwLock<..>>`, so all
/// methods take `&self` and lock internally.
pub struct AgentRegistry {
    /// Registered agents keyed by the id each agent reported when registered.
    agents: Arc<RwLock<HashMap<String, AgentEntry>>>,
    /// Tag / reasoning-strategy index backing `find_by_tag` and `find_by_tags`.
    capability_index: Arc<RwLock<CapabilityIndex>>,
    /// Agent factories keyed by the factory's `type_id()`.
    factories: Arc<RwLock<HashMap<String, Arc<dyn AgentFactory>>>>,
}
148
149impl AgentRegistry {
150 pub fn new() -> Self {
152 Self {
153 agents: Arc::new(RwLock::new(HashMap::new())),
154 capability_index: Arc::new(RwLock::new(CapabilityIndex::new())),
155 factories: Arc::new(RwLock::new(HashMap::new())),
156 }
157 }
158
159 pub async fn register(&self, agent: Arc<RwLock<dyn MoFAAgent>>) -> AgentResult<()> {
165 let agent_guard = agent.read().await;
166 let id = agent_guard.id().to_string();
167 let name = agent_guard.name().to_string();
168 let capabilities = agent_guard.capabilities().clone();
169 let state = agent_guard.state();
170 drop(agent_guard);
171
172 let now = std::time::SystemTime::now()
173 .duration_since(std::time::UNIX_EPOCH)
174 .unwrap_or_default()
175 .as_millis() as u64;
176
177 let metadata = AgentMetadata {
178 id: id.clone(),
179 name,
180 description: None,
181 version: None,
182 capabilities: capabilities.clone(),
183 state,
184 };
185
186 let entry = AgentEntry {
187 agent,
188 metadata,
189 registered_at: now,
190 };
191
192 {
194 let mut index = self.capability_index.write().await;
195 index.index(&id, &capabilities);
196 }
197
198 {
200 let mut agents = self.agents.write().await;
201 agents.insert(id, entry);
202 }
203
204 Ok(())
205 }
206
207 pub async fn get(&self, id: &str) -> Option<Arc<RwLock<dyn MoFAAgent>>> {
209 let agents = self.agents.read().await;
210 agents.get(id).map(|e| e.agent.clone())
211 }
212
213 pub async fn unregister(&self, id: &str) -> AgentResult<bool> {
215 {
217 let mut index = self.capability_index.write().await;
218 index.unindex(id);
219 }
220
221 let mut agents = self.agents.write().await;
223 Ok(agents.remove(id).is_some())
224 }
225
226 pub async fn get_metadata(&self, id: &str) -> Option<AgentMetadata> {
228 let agents = self.agents.read().await;
229 agents.get(id).map(|e| e.metadata.clone())
230 }
231
232 pub async fn list(&self) -> Vec<AgentMetadata> {
234 let agents = self.agents.read().await;
235 agents.values().map(|e| e.metadata.clone()).collect()
236 }
237
238 pub async fn count(&self) -> usize {
240 let agents = self.agents.read().await;
241 agents.len()
242 }
243
244 pub async fn contains(&self, id: &str) -> bool {
246 let agents = self.agents.read().await;
247 agents.contains_key(id)
248 }
249
250 pub async fn find_by_capabilities(
256 &self,
257 requirements: &AgentRequirements,
258 ) -> Vec<AgentMetadata> {
259 let agents = self.agents.read().await;
260
261 agents
262 .values()
263 .filter(|entry| requirements.matches(&entry.metadata.capabilities))
264 .map(|entry| entry.metadata.clone())
265 .collect()
266 }
267
268 pub async fn find_by_tag(&self, tag: &str) -> Vec<AgentMetadata> {
270 let index = self.capability_index.read().await;
271 let ids = index.find_by_tag(tag);
272 drop(index);
273
274 let agents = self.agents.read().await;
275 ids.iter()
276 .filter_map(|id| agents.get(id).map(|e| e.metadata.clone()))
277 .collect()
278 }
279
280 pub async fn find_by_tags(&self, tags: &[String]) -> Vec<AgentMetadata> {
282 let index = self.capability_index.read().await;
283 let ids = index.find_by_tags(tags);
284 drop(index);
285
286 let agents = self.agents.read().await;
287 ids.iter()
288 .filter_map(|id| agents.get(id).map(|e| e.metadata.clone()))
289 .collect()
290 }
291
292 pub async fn find_by_state(&self, state: AgentState) -> Vec<AgentMetadata> {
294 let agents = self.agents.read().await;
295
296 agents
297 .values()
298 .filter(|entry| entry.metadata.state == state)
299 .map(|entry| entry.metadata.clone())
300 .collect()
301 }
302
303 pub async fn register_factory(&self, factory: Arc<dyn AgentFactory>) -> AgentResult<()> {
309 let type_id = factory.type_id().to_string();
310 let mut factories = self.factories.write().await;
311 factories.insert(type_id, factory);
312 Ok(())
313 }
314
315 pub async fn get_factory(&self, type_id: &str) -> Option<Arc<dyn AgentFactory>> {
317 let factories = self.factories.read().await;
318 factories.get(type_id).cloned()
319 }
320
321 pub async fn unregister_factory(&self, type_id: &str) -> AgentResult<bool> {
323 let mut factories = self.factories.write().await;
324 Ok(factories.remove(type_id).is_some())
325 }
326
327 pub async fn list_factory_types(&self) -> Vec<String> {
329 let factories = self.factories.read().await;
330 factories.keys().cloned().collect()
331 }
332
333 pub async fn create(
335 &self,
336 type_id: &str,
337 config: AgentConfig,
338 ) -> AgentResult<Arc<RwLock<dyn MoFAAgent>>> {
339 let factory = self
340 .get_factory(type_id)
341 .await
342 .ok_or_else(|| AgentError::NotFound(format!("Factory not found: {}", type_id)))?;
343
344 factory.validate_config(&config)?;
345 factory.create(config).await
346 }
347
348 pub async fn create_and_register(
350 &self,
351 type_id: &str,
352 config: AgentConfig,
353 ) -> AgentResult<Arc<RwLock<dyn MoFAAgent>>> {
354 let agent = self.create(type_id, config).await?;
355 self.register(agent.clone()).await?;
356 Ok(agent)
357 }
358
359 pub async fn initialize_all(&self, ctx: &AgentContext) -> AgentResult<Vec<String>> {
365 let agents = self.agents.read().await;
366 let mut initialized = Vec::new();
367
368 for (id, entry) in agents.iter() {
369 let mut agent = entry.agent.write().await;
370 if agent.state() == AgentState::Created {
371 agent.initialize(ctx).await?;
372 initialized.push(id.clone());
373 }
374 }
375
376 Ok(initialized)
377 }
378
379 pub async fn shutdown_all(&self) -> AgentResult<Vec<String>> {
381 let agents = self.agents.read().await;
382 let mut shutdown = Vec::new();
383
384 for (id, entry) in agents.iter() {
385 let mut agent = entry.agent.write().await;
386 let state = agent.state();
387 if state != AgentState::Shutdown && state != AgentState::Failed {
388 agent.shutdown().await?;
389 shutdown.push(id.clone());
390 }
391 }
392
393 Ok(shutdown)
394 }
395
396 pub async fn clear(&self) -> AgentResult<usize> {
398 self.shutdown_all().await?;
400
401 {
403 let mut index = self.capability_index.write().await;
404 *index = CapabilityIndex::new();
405 }
406
407 let mut agents = self.agents.write().await;
409 let count = agents.len();
410 agents.clear();
411
412 Ok(count)
413 }
414}
415
416impl Default for AgentRegistry {
417 fn default() -> Self {
418 Self::new()
419 }
420}
421
422#[derive(Debug, Clone, Default, Serialize, Deserialize)]
428pub struct RegistryStats {
429 pub total_agents: usize,
431 pub by_state: HashMap<String, usize>,
433 pub by_tag: HashMap<String, usize>,
435 pub factory_count: usize,
437}
438
439impl AgentRegistry {
440 pub async fn stats(&self) -> RegistryStats {
442 let agents = self.agents.read().await;
443 let factories = self.factories.read().await;
444
445 let mut by_state: HashMap<String, usize> = HashMap::new();
446 let mut by_tag: HashMap<String, usize> = HashMap::new();
447
448 for entry in agents.values() {
449 let state_name = format!("{:?}", entry.metadata.state);
451 *by_state.entry(state_name).or_insert(0) += 1;
452
453 for tag in &entry.metadata.capabilities.tags {
455 *by_tag.entry(tag.clone()).or_insert(0) += 1;
456 }
457 }
458
459 RegistryStats {
460 total_agents: agents.len(),
461 by_state,
462 by_tag,
463 factory_count: factories.len(),
464 }
465 }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::capabilities::AgentCapabilities;
    use crate::agent::context::AgentContext;
    use crate::agent::core::MoFAAgent;
    use crate::agent::error::AgentResult;
    use crate::agent::types::{AgentInput, AgentOutput, AgentState};
    use async_trait::async_trait;

    /// Minimal in-memory agent: lifecycle methods only update `state`, and
    /// `execute` always returns a fixed text output.
    struct TestAgent {
        id: String,
        name: String,
        capabilities: AgentCapabilities,
        state: AgentState,
    }

    impl TestAgent {
        /// Creates an agent in `AgentState::Created` with default capabilities.
        fn new(id: &str, name: &str) -> Self {
            Self {
                id: id.to_string(),
                name: name.to_string(),
                capabilities: AgentCapabilities::default(),
                state: AgentState::Created,
            }
        }
    }

    #[async_trait]
    impl MoFAAgent for TestAgent {
        fn id(&self) -> &str {
            &self.id
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn capabilities(&self) -> &AgentCapabilities {
            &self.capabilities
        }

        fn state(&self) -> AgentState {
            self.state.clone()
        }

        // Moves the agent to `Ready`; the context is ignored.
        async fn initialize(&mut self, _ctx: &AgentContext) -> AgentResult<()> {
            self.state = AgentState::Ready;
            Ok(())
        }

        async fn execute(
            &mut self,
            _input: AgentInput,
            _ctx: &AgentContext,
        ) -> AgentResult<AgentOutput> {
            Ok(AgentOutput::text("test output"))
        }

        // Moves the agent to `Shutdown`.
        async fn shutdown(&mut self) -> AgentResult<()> {
            self.state = AgentState::Shutdown;
            Ok(())
        }
    }

    /// Factory registered under type id `"test"` that builds `TestAgent`s
    /// from the config's `id` and `name`.
    struct TestAgentFactory;

    #[async_trait]
    impl AgentFactory for TestAgentFactory {
        async fn create(&self, config: AgentConfig) -> AgentResult<Arc<RwLock<dyn MoFAAgent>>> {
            let agent = TestAgent::new(&config.id, &config.name);
            Ok(Arc::new(RwLock::new(agent)))
        }

        fn type_id(&self) -> &str {
            "test"
        }

        fn default_capabilities(&self) -> AgentCapabilities {
            AgentCapabilities::builder().with_tag("test").build()
        }
    }

    // A registered agent is retrievable by id; unknown ids yield `None`.
    #[tokio::test]
    async fn test_register_and_get() {
        let registry = AgentRegistry::new();
        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test Agent")));

        registry.register(agent).await.unwrap();

        let found = registry.get("agent-1").await;
        assert!(found.is_some());

        let not_found = registry.get("nonexistent").await;
        assert!(not_found.is_none());
    }

    // `create` dispatches to the factory by type id and passes the config through.
    #[tokio::test]
    async fn test_factory_create() {
        let registry = AgentRegistry::new();
        registry
            .register_factory(Arc::new(TestAgentFactory))
            .await
            .unwrap();

        let config = AgentConfig::new("agent-2", "Created Agent");
        let agent = registry.create("test", config).await.unwrap();

        let agent_guard = agent.read().await;
        assert_eq!(agent_guard.id(), "agent-2");
        assert_eq!(agent_guard.name(), "Created Agent");
    }

    // Tag lookups return every agent carrying the tag: "chat" is shared by
    // both agents, "llm" only by agent-1.
    #[tokio::test]
    async fn test_find_by_tag() {
        let registry = AgentRegistry::new();

        let mut agent1 = TestAgent::new("agent-1", "Agent 1");
        agent1.capabilities = AgentCapabilities::builder()
            .with_tag("llm")
            .with_tag("chat")
            .build();

        let mut agent2 = TestAgent::new("agent-2", "Agent 2");
        agent2.capabilities = AgentCapabilities::builder()
            .with_tag("react")
            .with_tag("chat")
            .build();

        registry
            .register(Arc::new(RwLock::new(agent1)))
            .await
            .unwrap();
        registry
            .register(Arc::new(RwLock::new(agent2)))
            .await
            .unwrap();

        let chat_agents = registry.find_by_tag("chat").await;
        assert_eq!(chat_agents.len(), 2);

        let llm_agents = registry.find_by_tag("llm").await;
        assert_eq!(llm_agents.len(), 1);
    }

    // Unregistering removes the agent from the registry.
    #[tokio::test]
    async fn test_unregister() {
        let registry = AgentRegistry::new();
        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test Agent")));

        registry.register(agent).await.unwrap();
        assert!(registry.contains("agent-1").await);

        registry.unregister("agent-1").await.unwrap();
        assert!(!registry.contains("agent-1").await);
    }

    // Stats count both registered agents and registered factories.
    #[tokio::test]
    async fn test_stats() {
        let registry = AgentRegistry::new();
        registry
            .register_factory(Arc::new(TestAgentFactory))
            .await
            .unwrap();

        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test")));
        registry.register(agent).await.unwrap();

        let stats = registry.stats().await;
        assert_eq!(stats.total_agents, 1);
        assert_eq!(stats.factory_count, 1);
    }
}