Skip to main content

mofa_kernel/agent/
capabilities.rs

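//! Agent 能力定义 — agent capability definitions.
//!
//! 定义 Agent 的能力发现和匹配机制 — defines how agents advertise their
//! capabilities and how those capabilities are matched against task requirements.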
4
5use super::types::{InputType, OutputType};
6use serde::{Deserialize, Serialize};
7use std::collections::{HashMap, HashSet};
8
9/// 推理策略
10#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
11pub enum ReasoningStrategy {
12    /// 直接 LLM 推理
13    #[default]
14    Direct,
15    /// ReAct 风格的思考-行动-观察循环
16    ReAct {
17        /// 最大迭代次数
18        max_iterations: usize,
19    },
20    /// 思维链推理
21    ChainOfThought,
22    /// 思维树探索
23    TreeOfThought {
24        /// 分支因子
25        branching_factor: usize,
26    },
27    /// 自定义推理模式
28    Custom(String),
29}
30
31/// Agent 能力描述
32///
33/// 用于能力发现和任务路由
34#[derive(Debug, Clone, Default, Serialize, Deserialize)]
35pub struct AgentCapabilities {
36    /// 能力标签 (如 "llm", "coding", "research")
37    pub tags: HashSet<String>,
38    /// 支持的输入类型
39    pub input_types: HashSet<InputType>,
40    /// 支持的输出类型
41    pub output_types: HashSet<OutputType>,
42    /// 最大上下文长度 (对于 LLM 类 Agent)
43    pub max_context_length: Option<usize>,
44    /// 支持的推理策略
45    pub reasoning_strategies: Vec<ReasoningStrategy>,
46    /// 是否支持流式输出
47    pub supports_streaming: bool,
48    /// 是否支持多轮对话
49    pub supports_conversation: bool,
50    /// 是否支持工具调用
51    pub supports_tools: bool,
52    /// 是否支持多 Agent 协调
53    pub supports_coordination: bool,
54    /// 自定义能力标志
55    pub custom: HashMap<String, serde_json::Value>,
56}
57
58impl AgentCapabilities {
59    /// 创建新的能力描述
60    pub fn new() -> Self {
61        Self::default()
62    }
63
64    /// 创建构建器
65    pub fn builder() -> AgentCapabilitiesBuilder {
66        AgentCapabilitiesBuilder::default()
67    }
68
69    /// 检查是否有指定标签
70    pub fn has_tag(&self, tag: &str) -> bool {
71        self.tags.contains(tag)
72    }
73
74    /// 检查是否支持指定输入类型
75    pub fn supports_input(&self, input_type: &InputType) -> bool {
76        self.input_types.contains(input_type)
77    }
78
79    /// 检查是否支持指定输出类型
80    pub fn supports_output(&self, output_type: &OutputType) -> bool {
81        self.output_types.contains(output_type)
82    }
83
84    /// 检查是否匹配需求
85    pub fn matches(&self, requirements: &AgentRequirements) -> bool {
86        // 检查必需标签
87        if !requirements
88            .required_tags
89            .iter()
90            .all(|t| self.tags.contains(t))
91        {
92            return false;
93        }
94
95        // 检查输入类型
96        if !requirements
97            .input_types
98            .iter()
99            .all(|t| self.input_types.contains(t))
100        {
101            return false;
102        }
103
104        // 检查输出类型
105        if !requirements
106            .output_types
107            .iter()
108            .all(|t| self.output_types.contains(t))
109        {
110            return false;
111        }
112
113        // 检查功能要求
114        if requirements.requires_streaming && !self.supports_streaming {
115            return false;
116        }
117        if requirements.requires_tools && !self.supports_tools {
118            return false;
119        }
120        if requirements.requires_conversation && !self.supports_conversation {
121            return false;
122        }
123        if requirements.requires_coordination && !self.supports_coordination {
124            return false;
125        }
126
127        true
128    }
129
130    /// 计算与需求的匹配分数 (0.0 - 1.0)
131    pub fn match_score(&self, requirements: &AgentRequirements) -> f64 {
132        if !self.matches(requirements) {
133            return 0.0;
134        }
135
136        let mut score = 0.0;
137        let mut weight = 0.0;
138
139        // 标签匹配
140        weight += 1.0;
141        if !requirements.required_tags.is_empty() {
142            let matched = requirements
143                .required_tags
144                .iter()
145                .filter(|t| self.tags.contains(*t))
146                .count();
147            score += matched as f64 / requirements.required_tags.len() as f64;
148        } else {
149            score += 1.0;
150        }
151
152        // 优选标签匹配
153        if !requirements.preferred_tags.is_empty() {
154            weight += 0.5;
155            let matched = requirements
156                .preferred_tags
157                .iter()
158                .filter(|t| self.tags.contains(*t))
159                .count();
160            score += 0.5 * (matched as f64 / requirements.preferred_tags.len() as f64);
161        }
162
163        // 额外能力加分
164        if self.supports_streaming {
165            score += 0.1;
166            weight += 0.1;
167        }
168        if self.supports_tools {
169            score += 0.1;
170            weight += 0.1;
171        }
172
173        score / weight
174    }
175}
176
177/// Agent 能力构建器
178#[derive(Debug, Default)]
179pub struct AgentCapabilitiesBuilder {
180    capabilities: AgentCapabilities,
181}
182
183impl AgentCapabilitiesBuilder {
184    /// 创建新的构建器
185    pub fn new() -> Self {
186        Self::default()
187    }
188
189    /// 添加标签
190    pub fn tag(mut self, tag: impl Into<String>) -> Self {
191        self.capabilities.tags.insert(tag.into());
192        self
193    }
194
195    /// 添加标签 (别名)
196    pub fn with_tag(self, tag: impl Into<String>) -> Self {
197        self.tag(tag)
198    }
199
200    /// 添加多个标签
201    pub fn tags(mut self, tags: impl IntoIterator<Item = impl Into<String>>) -> Self {
202        for tag in tags {
203            self.capabilities.tags.insert(tag.into());
204        }
205        self
206    }
207
208    /// 添加输入类型
209    pub fn input_type(mut self, input_type: InputType) -> Self {
210        self.capabilities.input_types.insert(input_type);
211        self
212    }
213
214    /// 添加输入类型 (别名)
215    pub fn with_input_type(self, input_type: InputType) -> Self {
216        self.input_type(input_type)
217    }
218
219    /// 添加输出类型
220    pub fn output_type(mut self, output_type: OutputType) -> Self {
221        self.capabilities.output_types.insert(output_type);
222        self
223    }
224
225    /// 添加输出类型 (别名)
226    pub fn with_output_type(self, output_type: OutputType) -> Self {
227        self.output_type(output_type)
228    }
229
230    /// 设置最大上下文长度
231    pub fn max_context_length(mut self, length: usize) -> Self {
232        self.capabilities.max_context_length = Some(length);
233        self
234    }
235
236    /// 添加推理策略
237    pub fn reasoning_strategy(mut self, strategy: ReasoningStrategy) -> Self {
238        self.capabilities.reasoning_strategies.push(strategy);
239        self
240    }
241
242    /// 添加推理策略 (别名)
243    pub fn with_reasoning_strategy(self, strategy: ReasoningStrategy) -> Self {
244        self.reasoning_strategy(strategy)
245    }
246
247    /// 设置流式输出支持
248    pub fn supports_streaming(mut self, supports: bool) -> Self {
249        self.capabilities.supports_streaming = supports;
250        self
251    }
252
253    /// 设置多轮对话支持
254    pub fn supports_conversation(mut self, supports: bool) -> Self {
255        self.capabilities.supports_conversation = supports;
256        self
257    }
258
259    /// 设置工具调用支持
260    pub fn supports_tools(mut self, supports: bool) -> Self {
261        self.capabilities.supports_tools = supports;
262        self
263    }
264
265    /// 设置多 Agent 协调支持
266    pub fn supports_coordination(mut self, supports: bool) -> Self {
267        self.capabilities.supports_coordination = supports;
268        self
269    }
270
271    /// 添加自定义能力
272    pub fn custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
273        self.capabilities.custom.insert(key.into(), value);
274        self
275    }
276
277    /// 构建能力描述
278    pub fn build(self) -> AgentCapabilities {
279        self.capabilities
280    }
281}
282
283/// Agent 需求描述
284///
285/// 用于查找满足特定需求的 Agent
286#[derive(Debug, Clone, Default, Serialize, Deserialize)]
287pub struct AgentRequirements {
288    /// 必需的标签
289    pub required_tags: HashSet<String>,
290    /// 优选的标签 (用于排序)
291    pub preferred_tags: HashSet<String>,
292    /// 必需的输入类型
293    pub input_types: HashSet<InputType>,
294    /// 必需的输出类型
295    pub output_types: HashSet<OutputType>,
296    /// 是否需要流式输出
297    pub requires_streaming: bool,
298    /// 是否需要工具支持
299    pub requires_tools: bool,
300    /// 是否需要多轮对话
301    pub requires_conversation: bool,
302    /// 是否需要多 Agent 协调
303    pub requires_coordination: bool,
304}
305
306impl AgentRequirements {
307    /// 创建新的需求描述
308    pub fn new() -> Self {
309        Self::default()
310    }
311
312    /// 创建构建器
313    pub fn builder() -> AgentRequirementsBuilder {
314        AgentRequirementsBuilder::default()
315    }
316
317    /// 检查给定的能力是否满足需求
318    pub fn matches(&self, capabilities: &AgentCapabilities) -> bool {
319        // 检查必需标签
320        for tag in &self.required_tags {
321            if !capabilities.tags.contains(tag) {
322                return false;
323            }
324        }
325
326        // 检查输入类型
327        for input_type in &self.input_types {
328            if !capabilities.input_types.contains(input_type) {
329                return false;
330            }
331        }
332
333        // 检查输出类型
334        for output_type in &self.output_types {
335            if !capabilities.output_types.contains(output_type) {
336                return false;
337            }
338        }
339
340        // 检查流式输出
341        if self.requires_streaming && !capabilities.supports_streaming {
342            return false;
343        }
344
345        // 检查工具支持
346        if self.requires_tools && !capabilities.supports_tools {
347            return false;
348        }
349
350        // 检查多轮对话
351        if self.requires_conversation && !capabilities.supports_conversation {
352            return false;
353        }
354
355        // 检查多 Agent 协调
356        if self.requires_coordination && !capabilities.supports_coordination {
357            return false;
358        }
359
360        true
361    }
362
363    /// 计算匹配分数 (用于排序)
364    pub fn score(&self, capabilities: &AgentCapabilities) -> f32 {
365        if !self.matches(capabilities) {
366            return 0.0;
367        }
368
369        let mut score = 1.0;
370
371        // 优选标签匹配加分
372        let preferred_count = self
373            .preferred_tags
374            .iter()
375            .filter(|tag| capabilities.tags.contains(*tag))
376            .count();
377
378        if !self.preferred_tags.is_empty() {
379            score += (preferred_count as f32) / (self.preferred_tags.len() as f32);
380        }
381
382        score
383    }
384}
385
386/// Agent 需求构建器
387#[derive(Debug, Default)]
388pub struct AgentRequirementsBuilder {
389    requirements: AgentRequirements,
390}
391
392impl AgentRequirementsBuilder {
393    /// 创建新的构建器
394    pub fn new() -> Self {
395        Self::default()
396    }
397
398    /// 添加必需标签
399    pub fn require_tag(mut self, tag: impl Into<String>) -> Self {
400        self.requirements.required_tags.insert(tag.into());
401        self
402    }
403
404    /// 添加优选标签
405    pub fn prefer_tag(mut self, tag: impl Into<String>) -> Self {
406        self.requirements.preferred_tags.insert(tag.into());
407        self
408    }
409
410    /// 添加输入类型需求
411    pub fn require_input(mut self, input_type: InputType) -> Self {
412        self.requirements.input_types.insert(input_type);
413        self
414    }
415
416    /// 添加输出类型需求
417    pub fn require_output(mut self, output_type: OutputType) -> Self {
418        self.requirements.output_types.insert(output_type);
419        self
420    }
421
422    /// 要求流式输出
423    pub fn require_streaming(mut self) -> Self {
424        self.requirements.requires_streaming = true;
425        self
426    }
427
428    /// 要求工具支持
429    pub fn require_tools(mut self) -> Self {
430        self.requirements.requires_tools = true;
431        self
432    }
433
434    /// 要求多轮对话
435    pub fn require_conversation(mut self) -> Self {
436        self.requirements.requires_conversation = true;
437        self
438    }
439
440    /// 要求多 Agent 协调
441    pub fn require_coordination(mut self) -> Self {
442        self.requirements.requires_coordination = true;
443        self
444    }
445
446    /// 构建需求描述
447    pub fn build(self) -> AgentRequirements {
448        self.requirements
449    }
450}
451
#[cfg(test)]
mod tests {
    use super::*;

    /// Tolerance for `f64` score comparisons.
    const EPS: f64 = 1e-9;

    #[test]
    fn test_capabilities_builder() {
        let caps = AgentCapabilities::builder()
            .tag("llm")
            .tag("coding")
            .input_type(InputType::Text)
            .output_type(OutputType::Text)
            .supports_streaming(true)
            .supports_tools(true)
            .build();

        assert!(caps.has_tag("llm"));
        assert!(caps.has_tag("coding"));
        assert!(!caps.has_tag("research"));
        assert!(caps.supports_input(&InputType::Text));
        assert!(caps.supports_output(&OutputType::Text));
        assert!(caps.supports_streaming);
        assert!(caps.supports_tools);
    }

    #[test]
    fn test_capabilities_matching() {
        let caps = AgentCapabilities::builder()
            .tag("llm")
            .tag("coding")
            .input_type(InputType::Text)
            .output_type(OutputType::Text)
            .supports_tools(true)
            .build();

        let requirements = AgentRequirements::builder()
            .require_tag("llm")
            .require_input(InputType::Text)
            .require_tools()
            .build();

        assert!(caps.matches(&requirements));
        assert!(requirements.matches(&caps));
    }

    #[test]
    fn test_capabilities_mismatch() {
        let caps = AgentCapabilities::builder()
            .tag("llm")
            .input_type(InputType::Text)
            .build();

        let requirements = AgentRequirements::builder()
            .require_tag("coding") // Not present
            .build();

        assert!(!caps.matches(&requirements));
        assert!(!requirements.matches(&caps));
    }

    #[test]
    fn test_output_type_mismatch() {
        let caps = AgentCapabilities::builder()
            .input_type(InputType::Text)
            .build();

        let requirements = AgentRequirements::builder()
            .require_output(OutputType::Text)
            .build();

        assert!(!caps.matches(&requirements));
        assert!(!requirements.matches(&caps));
    }

    #[test]
    fn test_feature_flags_must_be_supported() {
        let full = AgentCapabilities::builder()
            .supports_streaming(true)
            .supports_tools(true)
            .supports_conversation(true)
            .supports_coordination(true)
            .build();
        let bare = AgentCapabilities::new();

        for requirements in [
            AgentRequirements::builder().require_streaming().build(),
            AgentRequirements::builder().require_tools().build(),
            AgentRequirements::builder().require_conversation().build(),
            AgentRequirements::builder().require_coordination().build(),
        ] {
            assert!(full.matches(&requirements));
            assert!(requirements.matches(&full));
            assert!(!bare.matches(&requirements));
            assert!(!requirements.matches(&bare));
        }
    }

    #[test]
    fn test_empty_requirements_match_anything() {
        let requirements = AgentRequirements::new();
        let caps = AgentCapabilities::new();

        assert!(caps.matches(&requirements));
        assert!(requirements.matches(&caps));
        assert_eq!(requirements.score(&caps), 1.0);
        assert!((caps.match_score(&requirements) - 1.0).abs() < EPS);
    }

    #[test]
    fn test_match_score() {
        let caps = AgentCapabilities::builder()
            .tag("llm")
            .tag("coding")
            .tag("research")
            .supports_streaming(true)
            .supports_tools(true)
            .build();

        let requirements = AgentRequirements::builder()
            .require_tag("llm")
            .prefer_tag("coding")
            .prefer_tag("research")
            .build();

        // Every preferred tag is present. The streaming/tools bonuses add equally
        // to score and weight, so the result is exactly the maximum.
        let score = caps.match_score(&requirements);
        assert!((score - 1.0).abs() < EPS, "score = {score}");
    }

    #[test]
    fn test_match_score_partial_preferred() {
        let caps = AgentCapabilities::builder().tag("llm").tag("coding").build();

        let requirements = AgentRequirements::builder()
            .require_tag("llm")
            .prefer_tag("coding")
            .prefer_tag("research")
            .build();

        // (1.0 required + 0.5 * 1/2 preferred) / (1.0 + 0.5) weight
        let expected = 1.25 / 1.5;
        let score = caps.match_score(&requirements);
        assert!((score - expected).abs() < EPS, "score = {score}");
    }

    #[test]
    fn test_scores_are_zero_when_unmatched() {
        let caps = AgentCapabilities::builder().tag("llm").build();
        let requirements = AgentRequirements::builder().require_streaming().build();

        assert!(!caps.matches(&requirements));
        assert_eq!(caps.match_score(&requirements), 0.0);
        assert_eq!(requirements.score(&caps), 0.0);
    }

    #[test]
    fn test_requirements_score_counts_preferred_tags() {
        let caps = AgentCapabilities::builder().tags(["llm", "coding"]).build();

        let requirements = AgentRequirements::builder()
            .require_tag("llm")
            .prefer_tag("coding")
            .prefer_tag("research")
            .build();

        // 1.0 base + 1 of 2 preferred tags.
        assert_eq!(requirements.score(&caps), 1.5);
    }
}