// mofa_runtime/agent/registry.rs

//! Agent registry.
//!
//! Provides agent registration, discovery, and factory-based creation.
4
5use crate::agent::capabilities::{AgentCapabilities, AgentRequirements};
6use crate::agent::context::AgentContext;
7use crate::agent::core::MoFAAgent;
8use crate::agent::error::{AgentError, AgentResult};
9use crate::agent::traits::AgentMetadata;
10use crate::agent::types::AgentState;
11use mofa_kernel::agent::config::AgentConfig;
12use mofa_kernel::agent::registry::AgentFactory;
13use serde::{Deserialize, Serialize};
14use std::collections::HashMap;
15use std::sync::Arc;
16use tokio::sync::RwLock;
17
// ============================================================================
// Agent registry entry
// ============================================================================
21
22/// Agent 注册条目
23struct AgentEntry {
24    /// Agent 实例
25    agent: Arc<RwLock<dyn MoFAAgent>>,
26    /// 元数据
27    metadata: AgentMetadata,
28    /// 注册时间
29    registered_at: u64,
30}
31
// ============================================================================
// Capability index
// ============================================================================
35
/// Reverse lookup from capability attributes to agent ids.
///
/// Lets tag and reasoning-strategy queries avoid scanning every agent.
struct CapabilityIndex {
    /// tag -> ids of the agents that declare it
    by_tag: HashMap<String, Vec<String>>,
    /// reasoning strategy (keyed by its `Debug` text) -> ids of the agents supporting it
    by_strategy: HashMap<String, Vec<String>>,
}
45
46impl CapabilityIndex {
47    fn new() -> Self {
48        Self {
49            by_tag: HashMap::new(),
50            by_strategy: HashMap::new(),
51        }
52    }
53
54    /// 添加索引
55    fn index(&mut self, agent_id: &str, capabilities: &AgentCapabilities) {
56        // 索引标签
57        for tag in &capabilities.tags {
58            self.by_tag
59                .entry(tag.clone())
60                .or_default()
61                .push(agent_id.to_string());
62        }
63
64        // 索引推理策略
65        for strategy in &capabilities.reasoning_strategies {
66            let strategy_name = format!("{:?}", strategy);
67            self.by_strategy
68                .entry(strategy_name)
69                .or_default()
70                .push(agent_id.to_string());
71        }
72    }
73
74    /// 移除索引
75    fn unindex(&mut self, agent_id: &str) {
76        for ids in self.by_tag.values_mut() {
77            ids.retain(|id| id != agent_id);
78        }
79        for ids in self.by_strategy.values_mut() {
80            ids.retain(|id| id != agent_id);
81        }
82    }
83
84    /// 按标签查找
85    fn find_by_tag(&self, tag: &str) -> Vec<String> {
86        self.by_tag.get(tag).cloned().unwrap_or_default()
87    }
88
89    /// 按多个标签查找 (交集)
90    fn find_by_tags(&self, tags: &[String]) -> Vec<String> {
91        if tags.is_empty() {
92            return vec![];
93        }
94
95        let mut result: Option<Vec<String>> = None;
96        for tag in tags {
97            let ids = self.find_by_tag(tag);
98            result = match result {
99                None => Some(ids),
100                Some(existing) => {
101                    let intersection: Vec<String> =
102                        existing.into_iter().filter(|id| ids.contains(id)).collect();
103                    Some(intersection)
104                }
105            };
106        }
107        result.unwrap_or_default()
108    }
109}
110
// ============================================================================
// Agent registry
// ============================================================================
114
115/// Agent 注册中心
116///
117/// 提供 Agent 的注册、发现、工厂创建功能
118///
119/// # 示例
120///
121/// ```rust,ignore
122/// use mofa_runtime::agent::AgentRegistry;
123/// use mofa_kernel::agent::config::AgentConfig;
124///
125/// let registry = AgentRegistry::new();
126///
127/// // 注册工厂
128/// registry.register_factory(Arc::new(LLMAgentFactory)).await?;
129///
130/// // 通过工厂创建 Agent
131/// let config = AgentConfig::new("agent-1", "My Agent", "llm");
132/// let agent = registry.create("llm", config).await?;
133///
134/// // 注册 Agent
135/// registry.register(agent).await?;
136///
137/// // 查找 Agent
138/// let found = registry.get("agent-1").await;
139/// ```
140pub struct AgentRegistry {
141    /// 已注册的 Agent
142    agents: Arc<RwLock<HashMap<String, AgentEntry>>>,
143    /// 能力索引
144    capability_index: Arc<RwLock<CapabilityIndex>>,
145    /// Agent 工厂
146    factories: Arc<RwLock<HashMap<String, Arc<dyn AgentFactory>>>>,
147}
148
149impl AgentRegistry {
150    /// 创建新的注册中心
151    pub fn new() -> Self {
152        Self {
153            agents: Arc::new(RwLock::new(HashMap::new())),
154            capability_index: Arc::new(RwLock::new(CapabilityIndex::new())),
155            factories: Arc::new(RwLock::new(HashMap::new())),
156        }
157    }
158
159    // ========================================================================
160    // Agent 管理
161    // ========================================================================
162
163    /// 注册 Agent
164    pub async fn register(&self, agent: Arc<RwLock<dyn MoFAAgent>>) -> AgentResult<()> {
165        let agent_guard = agent.read().await;
166        let id = agent_guard.id().to_string();
167        let name = agent_guard.name().to_string();
168        let capabilities = agent_guard.capabilities().clone();
169        let state = agent_guard.state();
170        drop(agent_guard);
171
172        let now = std::time::SystemTime::now()
173            .duration_since(std::time::UNIX_EPOCH)
174            .unwrap_or_default()
175            .as_millis() as u64;
176
177        let metadata = AgentMetadata {
178            id: id.clone(),
179            name,
180            description: None,
181            version: None,
182            capabilities: capabilities.clone(),
183            state,
184        };
185
186        let entry = AgentEntry {
187            agent,
188            metadata,
189            registered_at: now,
190        };
191
192        // 更新能力索引
193        {
194            let mut index = self.capability_index.write().await;
195            index.index(&id, &capabilities);
196        }
197
198        // 注册 Agent
199        {
200            let mut agents = self.agents.write().await;
201            agents.insert(id, entry);
202        }
203
204        Ok(())
205    }
206
207    /// 获取 Agent
208    pub async fn get(&self, id: &str) -> Option<Arc<RwLock<dyn MoFAAgent>>> {
209        let agents = self.agents.read().await;
210        agents.get(id).map(|e| e.agent.clone())
211    }
212
213    /// 移除 Agent
214    pub async fn unregister(&self, id: &str) -> AgentResult<bool> {
215        // 更新能力索引
216        {
217            let mut index = self.capability_index.write().await;
218            index.unindex(id);
219        }
220
221        // 移除 Agent
222        let mut agents = self.agents.write().await;
223        Ok(agents.remove(id).is_some())
224    }
225
226    /// 获取 Agent 元数据
227    pub async fn get_metadata(&self, id: &str) -> Option<AgentMetadata> {
228        let agents = self.agents.read().await;
229        agents.get(id).map(|e| e.metadata.clone())
230    }
231
232    /// 列出所有 Agent
233    pub async fn list(&self) -> Vec<AgentMetadata> {
234        let agents = self.agents.read().await;
235        agents.values().map(|e| e.metadata.clone()).collect()
236    }
237
238    /// 获取 Agent 数量
239    pub async fn count(&self) -> usize {
240        let agents = self.agents.read().await;
241        agents.len()
242    }
243
244    /// 检查 Agent 是否存在
245    pub async fn contains(&self, id: &str) -> bool {
246        let agents = self.agents.read().await;
247        agents.contains_key(id)
248    }
249
250    // ========================================================================
251    // 能力查询
252    // ========================================================================
253
254    /// 按能力要求查找 Agent
255    pub async fn find_by_capabilities(
256        &self,
257        requirements: &AgentRequirements,
258    ) -> Vec<AgentMetadata> {
259        let agents = self.agents.read().await;
260
261        agents
262            .values()
263            .filter(|entry| requirements.matches(&entry.metadata.capabilities))
264            .map(|entry| entry.metadata.clone())
265            .collect()
266    }
267
268    /// 按标签查找 Agent
269    pub async fn find_by_tag(&self, tag: &str) -> Vec<AgentMetadata> {
270        let index = self.capability_index.read().await;
271        let ids = index.find_by_tag(tag);
272        drop(index);
273
274        let agents = self.agents.read().await;
275        ids.iter()
276            .filter_map(|id| agents.get(id).map(|e| e.metadata.clone()))
277            .collect()
278    }
279
280    /// 按多个标签查找 Agent (交集)
281    pub async fn find_by_tags(&self, tags: &[String]) -> Vec<AgentMetadata> {
282        let index = self.capability_index.read().await;
283        let ids = index.find_by_tags(tags);
284        drop(index);
285
286        let agents = self.agents.read().await;
287        ids.iter()
288            .filter_map(|id| agents.get(id).map(|e| e.metadata.clone()))
289            .collect()
290    }
291
292    /// 按状态查找 Agent
293    pub async fn find_by_state(&self, state: AgentState) -> Vec<AgentMetadata> {
294        let agents = self.agents.read().await;
295
296        agents
297            .values()
298            .filter(|entry| entry.metadata.state == state)
299            .map(|entry| entry.metadata.clone())
300            .collect()
301    }
302
303    // ========================================================================
304    // 工厂管理
305    // ========================================================================
306
307    /// 注册 Agent 工厂
308    pub async fn register_factory(&self, factory: Arc<dyn AgentFactory>) -> AgentResult<()> {
309        let type_id = factory.type_id().to_string();
310        let mut factories = self.factories.write().await;
311        factories.insert(type_id, factory);
312        Ok(())
313    }
314
315    /// 获取 Agent 工厂
316    pub async fn get_factory(&self, type_id: &str) -> Option<Arc<dyn AgentFactory>> {
317        let factories = self.factories.read().await;
318        factories.get(type_id).cloned()
319    }
320
321    /// 移除 Agent 工厂
322    pub async fn unregister_factory(&self, type_id: &str) -> AgentResult<bool> {
323        let mut factories = self.factories.write().await;
324        Ok(factories.remove(type_id).is_some())
325    }
326
327    /// 列出所有工厂类型
328    pub async fn list_factory_types(&self) -> Vec<String> {
329        let factories = self.factories.read().await;
330        factories.keys().cloned().collect()
331    }
332
333    /// 通过工厂创建 Agent
334    pub async fn create(
335        &self,
336        type_id: &str,
337        config: AgentConfig,
338    ) -> AgentResult<Arc<RwLock<dyn MoFAAgent>>> {
339        let factory = self
340            .get_factory(type_id)
341            .await
342            .ok_or_else(|| AgentError::NotFound(format!("Factory not found: {}", type_id)))?;
343
344        factory.validate_config(&config)?;
345        factory.create(config).await
346    }
347
348    /// 创建并注册 Agent
349    pub async fn create_and_register(
350        &self,
351        type_id: &str,
352        config: AgentConfig,
353    ) -> AgentResult<Arc<RwLock<dyn MoFAAgent>>> {
354        let agent = self.create(type_id, config).await?;
355        self.register(agent.clone()).await?;
356        Ok(agent)
357    }
358
359    // ========================================================================
360    // 批量操作
361    // ========================================================================
362
363    /// 初始化所有 Agent
364    pub async fn initialize_all(&self, ctx: &AgentContext) -> AgentResult<Vec<String>> {
365        let agents = self.agents.read().await;
366        let mut initialized = Vec::new();
367
368        for (id, entry) in agents.iter() {
369            let mut agent = entry.agent.write().await;
370            if agent.state() == AgentState::Created {
371                agent.initialize(ctx).await?;
372                initialized.push(id.clone());
373            }
374        }
375
376        Ok(initialized)
377    }
378
379    /// 关闭所有 Agent
380    pub async fn shutdown_all(&self) -> AgentResult<Vec<String>> {
381        let agents = self.agents.read().await;
382        let mut shutdown = Vec::new();
383
384        for (id, entry) in agents.iter() {
385            let mut agent = entry.agent.write().await;
386            let state = agent.state();
387            if state != AgentState::Shutdown && state != AgentState::Failed {
388                agent.shutdown().await?;
389                shutdown.push(id.clone());
390            }
391        }
392
393        Ok(shutdown)
394    }
395
396    /// 清空所有 Agent
397    pub async fn clear(&self) -> AgentResult<usize> {
398        // 先关闭所有 Agent
399        self.shutdown_all().await?;
400
401        // 清空索引
402        {
403            let mut index = self.capability_index.write().await;
404            *index = CapabilityIndex::new();
405        }
406
407        // 清空 Agent
408        let mut agents = self.agents.write().await;
409        let count = agents.len();
410        agents.clear();
411
412        Ok(count)
413    }
414}
415
416impl Default for AgentRegistry {
417    fn default() -> Self {
418        Self::new()
419    }
420}
421
// ============================================================================
// Registry statistics
// ============================================================================
425
426/// 注册中心统计
427#[derive(Debug, Clone, Default, Serialize, Deserialize)]
428pub struct RegistryStats {
429    /// 总 Agent 数
430    pub total_agents: usize,
431    /// 各状态 Agent 数
432    pub by_state: HashMap<String, usize>,
433    /// 各标签 Agent 数
434    pub by_tag: HashMap<String, usize>,
435    /// 工厂类型数
436    pub factory_count: usize,
437}
438
439impl AgentRegistry {
440    /// 获取统计信息
441    pub async fn stats(&self) -> RegistryStats {
442        let agents = self.agents.read().await;
443        let factories = self.factories.read().await;
444
445        let mut by_state: HashMap<String, usize> = HashMap::new();
446        let mut by_tag: HashMap<String, usize> = HashMap::new();
447
448        for entry in agents.values() {
449            // 统计状态
450            let state_name = format!("{:?}", entry.metadata.state);
451            *by_state.entry(state_name).or_insert(0) += 1;
452
453            // 统计标签
454            for tag in &entry.metadata.capabilities.tags {
455                *by_tag.entry(tag.clone()).or_insert(0) += 1;
456            }
457        }
458
459        RegistryStats {
460            total_agents: agents.len(),
461            by_state,
462            by_tag,
463            factory_count: factories.len(),
464        }
465    }
466}
467
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::capabilities::AgentCapabilities;
    use crate::agent::context::AgentContext;
    use crate::agent::core::MoFAAgent;
    use crate::agent::error::AgentResult;
    use crate::agent::types::{AgentInput, AgentOutput, AgentState};
    use async_trait::async_trait;

    /// Minimal test agent, implemented inline (no dependency on BaseAgent).
    struct TestAgent {
        id: String,
        name: String,
        capabilities: AgentCapabilities,
        state: AgentState,
    }

    impl TestAgent {
        fn new(id: &str, name: &str) -> Self {
            Self {
                id: id.to_string(),
                name: name.to_string(),
                capabilities: AgentCapabilities::default(),
                state: AgentState::Created,
            }
        }
    }

    #[async_trait]
    impl MoFAAgent for TestAgent {
        fn id(&self) -> &str {
            &self.id
        }

        fn name(&self) -> &str {
            &self.name
        }

        fn capabilities(&self) -> &AgentCapabilities {
            &self.capabilities
        }

        fn state(&self) -> AgentState {
            self.state.clone()
        }

        async fn initialize(&mut self, _ctx: &AgentContext) -> AgentResult<()> {
            self.state = AgentState::Ready;
            Ok(())
        }

        async fn execute(
            &mut self,
            _input: AgentInput,
            _ctx: &AgentContext,
        ) -> AgentResult<AgentOutput> {
            Ok(AgentOutput::text("test output"))
        }

        async fn shutdown(&mut self) -> AgentResult<()> {
            self.state = AgentState::Shutdown;
            Ok(())
        }
    }

    /// Factory producing `TestAgent`s under the `"test"` type id.
    struct TestAgentFactory;

    #[async_trait]
    impl AgentFactory for TestAgentFactory {
        async fn create(&self, config: AgentConfig) -> AgentResult<Arc<RwLock<dyn MoFAAgent>>> {
            let agent = TestAgent::new(&config.id, &config.name);
            Ok(Arc::new(RwLock::new(agent)))
        }

        fn type_id(&self) -> &str {
            "test"
        }

        fn default_capabilities(&self) -> AgentCapabilities {
            AgentCapabilities::builder().with_tag("test").build()
        }
    }

    /// Registry holding `agent-1` (tags: llm, chat) and `agent-2` (tags: react, chat).
    async fn registry_with_tagged_agents() -> AgentRegistry {
        let registry = AgentRegistry::new();

        let mut agent1 = TestAgent::new("agent-1", "Agent 1");
        agent1.capabilities = AgentCapabilities::builder()
            .with_tag("llm")
            .with_tag("chat")
            .build();

        let mut agent2 = TestAgent::new("agent-2", "Agent 2");
        agent2.capabilities = AgentCapabilities::builder()
            .with_tag("react")
            .with_tag("chat")
            .build();

        registry
            .register(Arc::new(RwLock::new(agent1)))
            .await
            .unwrap();
        registry
            .register(Arc::new(RwLock::new(agent2)))
            .await
            .unwrap();
        registry
    }

    #[tokio::test]
    async fn test_register_and_get() {
        let registry = AgentRegistry::new();
        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test Agent")));

        registry.register(agent).await.unwrap();

        assert!(registry.get("agent-1").await.is_some());
        assert!(registry.get("nonexistent").await.is_none());
    }

    #[tokio::test]
    async fn test_factory_create() {
        let registry = AgentRegistry::new();
        registry
            .register_factory(Arc::new(TestAgentFactory))
            .await
            .unwrap();

        let config = AgentConfig::new("agent-2", "Created Agent");
        let agent = registry.create("test", config).await.unwrap();

        let agent_guard = agent.read().await;
        assert_eq!(agent_guard.id(), "agent-2");
        assert_eq!(agent_guard.name(), "Created Agent");
    }

    #[tokio::test]
    async fn test_create_unknown_factory_fails() {
        let registry = AgentRegistry::new();
        let config = AgentConfig::new("agent-x", "Missing");
        assert!(registry.create("missing", config).await.is_err());
    }

    #[tokio::test]
    async fn test_create_and_register() {
        let registry = AgentRegistry::new();
        registry
            .register_factory(Arc::new(TestAgentFactory))
            .await
            .unwrap();

        registry
            .create_and_register("test", AgentConfig::new("agent-3", "Registered"))
            .await
            .unwrap();

        assert!(registry.contains("agent-3").await);
        assert_eq!(registry.list_factory_types().await, vec!["test".to_string()]);
    }

    #[tokio::test]
    async fn test_find_by_tag() {
        let registry = registry_with_tagged_agents().await;

        assert_eq!(registry.find_by_tag("chat").await.len(), 2);
        assert_eq!(registry.find_by_tag("llm").await.len(), 1);
        assert!(registry.find_by_tag("unknown").await.is_empty());
    }

    #[tokio::test]
    async fn test_find_by_tags_intersection() {
        let registry = registry_with_tagged_agents().await;

        let both = registry
            .find_by_tags(&["chat".to_string(), "llm".to_string()])
            .await;
        assert_eq!(both.len(), 1);
        assert_eq!(both[0].id, "agent-1");

        assert!(registry
            .find_by_tags(&["llm".to_string(), "react".to_string()])
            .await
            .is_empty());
        assert!(registry.find_by_tags(&[]).await.is_empty());
    }

    #[tokio::test]
    async fn test_find_by_state() {
        let registry = registry_with_tagged_agents().await;
        assert_eq!(registry.find_by_state(AgentState::Created).await.len(), 2);
    }

    #[tokio::test]
    async fn test_unregister() {
        let registry = AgentRegistry::new();
        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test Agent")));

        registry.register(agent).await.unwrap();
        assert!(registry.contains("agent-1").await);

        assert!(registry.unregister("agent-1").await.unwrap());
        assert!(!registry.contains("agent-1").await);
        // A second removal reports that nothing was there.
        assert!(!registry.unregister("agent-1").await.unwrap());
    }

    #[tokio::test]
    async fn test_unregister_removes_index_records() {
        let registry = registry_with_tagged_agents().await;

        registry.unregister("agent-1").await.unwrap();

        assert_eq!(registry.find_by_tag("chat").await.len(), 1);
        assert!(registry.find_by_tag("llm").await.is_empty());
    }

    #[tokio::test]
    async fn test_shutdown_all_is_idempotent() {
        let registry = registry_with_tagged_agents().await;

        let mut first = registry.shutdown_all().await.unwrap();
        first.sort();
        assert_eq!(first, vec!["agent-1".to_string(), "agent-2".to_string()]);

        // Already-shut-down agents are skipped.
        assert!(registry.shutdown_all().await.unwrap().is_empty());
    }

    #[tokio::test]
    async fn test_clear() {
        let registry = registry_with_tagged_agents().await;

        assert_eq!(registry.clear().await.unwrap(), 2);
        assert_eq!(registry.count().await, 0);
        assert!(registry.find_by_tag("chat").await.is_empty());
    }

    #[tokio::test]
    async fn test_stats() {
        let registry = AgentRegistry::new();
        registry
            .register_factory(Arc::new(TestAgentFactory))
            .await
            .unwrap();

        let agent = Arc::new(RwLock::new(TestAgent::new("agent-1", "Test")));
        registry.register(agent).await.unwrap();

        let stats = registry.stats().await;
        assert_eq!(stats.total_agents, 1);
        assert_eq!(stats.factory_count, 1);
    }

    #[tokio::test]
    async fn test_stats_counts_tags() {
        let registry = registry_with_tagged_agents().await;

        let stats = registry.stats().await;
        assert_eq!(stats.total_agents, 2);
        assert_eq!(stats.by_tag.get("chat"), Some(&2));
        assert_eq!(stats.by_tag.get("llm"), Some(&1));
        assert_eq!(stats.factory_count, 0);
    }
}