// Source path: mofa_kernel/agent/config/schema.rs

//! Configuration schema definitions.
//!
//! Defines the configuration structures used to describe agents.

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

// ============================================================================
// Main configuration structure
// ============================================================================

12/// Agent 配置
13///
14/// 统一的 Agent 配置结构,支持多种 Agent 类型
15///
16/// # 示例
17///
18/// ```rust,ignore
19/// use mofa_kernel::agent::config::{AgentConfig, AgentType, LlmAgentConfig};
20///
21/// let config = AgentConfig {
22///     id: "my-agent".to_string(),
23///     name: "My LLM Agent".to_string(),
24///     description: Some("A helpful assistant".to_string()),
25///     agent_type: AgentType::Llm(LlmAgentConfig {
26///         model: "gpt-4".to_string(),
27///         ..Default::default()
28///     }),
29///     ..Default::default()
30/// };
31/// ```
32#[derive(Debug, Clone, Serialize, Deserialize)]
33pub struct AgentConfig {
34    /// Agent ID (唯一标识符)
35    pub id: String,
36
37    /// Agent 名称 (显示名)
38    pub name: String,
39
40    /// Agent 描述
41    #[serde(default)]
42    pub description: Option<String>,
43
44    /// Agent 类型配置
45    #[serde(flatten)]
46    pub agent_type: AgentType,
47
48    /// 组件配置
49    #[serde(default)]
50    pub components: ComponentsConfig,
51
52    /// 能力配置
53    #[serde(default)]
54    pub capabilities: CapabilitiesConfig,
55
56    /// 自定义配置
57    #[serde(default)]
58    pub custom: HashMap<String, serde_json::Value>,
59
60    /// 环境变量映射
61    #[serde(default)]
62    pub env_mappings: HashMap<String, String>,
63
64    /// 是否启用
65    #[serde(default = "default_enabled")]
66    pub enabled: bool,
67
68    /// 版本号
69    #[serde(default)]
70    pub version: Option<String>,
71}
72
/// Serde default for `enabled` flags: items are enabled unless stated otherwise.
fn default_enabled() -> bool {
    true
}

77impl Default for AgentConfig {
78    fn default() -> Self {
79        Self {
80            id: String::new(),
81            name: String::new(),
82            description: None,
83            agent_type: AgentType::default(),
84            components: ComponentsConfig::default(),
85            capabilities: CapabilitiesConfig::default(),
86            custom: HashMap::new(),
87            env_mappings: HashMap::new(),
88            enabled: true,
89            version: None,
90        }
91    }
92}
93
94impl AgentConfig {
95    /// 创建新配置
96    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
97        Self {
98            id: id.into(),
99            name: name.into(),
100            ..Default::default()
101        }
102    }
103
104    /// 设置描述
105    pub fn with_description(mut self, description: impl Into<String>) -> Self {
106        self.description = Some(description.into());
107        self
108    }
109
110    /// 设置 Agent 类型
111    pub fn with_type(mut self, agent_type: AgentType) -> Self {
112        self.agent_type = agent_type;
113        self
114    }
115
116    /// 添加自定义配置
117    pub fn with_custom(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
118        self.custom.insert(key.into(), value);
119        self
120    }
121
122    /// 获取自定义配置
123    pub fn get_custom<T: serde::de::DeserializeOwned>(&self, key: &str) -> Option<T> {
124        self.custom
125            .get(key)
126            .and_then(|v| serde_json::from_value(v.clone()).ok())
127    }
128
129    /// 验证配置
130    pub fn validate(&self) -> Result<(), Vec<String>> {
131        let mut errors = Vec::new();
132
133        if self.id.is_empty() {
134            errors.push("Agent ID cannot be empty".to_string());
135        }
136
137        if self.name.is_empty() {
138            errors.push("Agent name cannot be empty".to_string());
139        }
140
141        // 验证类型特定配置
142        if let Err(type_errors) = self.agent_type.validate() {
143            errors.extend(type_errors);
144        }
145
146        if errors.is_empty() {
147            Ok(())
148        } else {
149            Err(errors)
150        }
151    }
152}
153
// ============================================================================
// Agent type
// ============================================================================

158/// Agent 类型
159#[derive(Debug, Clone, Serialize, Deserialize)]
160#[serde(tag = "type", rename_all = "snake_case")]
161pub enum AgentType {
162    /// LLM Agent
163    Llm(LlmAgentConfig),
164
165    /// ReAct Agent
166    #[serde(rename = "react")]
167    ReAct(ReActAgentConfig),
168
169    /// 工作流 Agent
170    Workflow(WorkflowAgentConfig),
171
172    /// 团队 Agent
173    Team(TeamAgentConfig),
174
175    /// 自定义 Agent
176    Custom {
177        /// 类路径或插件标识
178        class_path: String,
179        /// 自定义配置
180        #[serde(default)]
181        config: HashMap<String, serde_json::Value>,
182    },
183}
184
185impl Default for AgentType {
186    fn default() -> Self {
187        Self::Llm(LlmAgentConfig::default())
188    }
189}
190
191impl AgentType {
192    /// 获取类型名称
193    pub fn type_name(&self) -> &str {
194        match self {
195            Self::Llm(_) => "llm",
196            Self::ReAct(_) => "react",
197            Self::Workflow(_) => "workflow",
198            Self::Team(_) => "team",
199            Self::Custom { .. } => "custom",
200        }
201    }
202
203    /// 验证类型配置
204    pub fn validate(&self) -> Result<(), Vec<String>> {
205        match self {
206            Self::Llm(config) => config.validate(),
207            Self::ReAct(config) => config.validate(),
208            Self::Workflow(config) => config.validate(),
209            Self::Team(config) => config.validate(),
210            Self::Custom { class_path, .. } => {
211                if class_path.is_empty() {
212                    Err(vec!["Custom agent class_path cannot be empty".to_string()])
213                } else {
214                    Ok(())
215                }
216            }
217        }
218    }
219}
220
// ============================================================================
// LLM agent configuration
// ============================================================================

225/// LLM Agent 配置
226#[derive(Debug, Clone, Serialize, Deserialize)]
227pub struct LlmAgentConfig {
228    /// 模型名称
229    pub model: String,
230
231    /// 系统提示词
232    #[serde(default)]
233    pub system_prompt: Option<String>,
234
235    /// 温度参数
236    #[serde(default = "default_temperature")]
237    pub temperature: f32,
238
239    /// 最大 token 数
240    #[serde(default)]
241    pub max_tokens: Option<u32>,
242
243    /// Top P 参数
244    #[serde(default)]
245    pub top_p: Option<f32>,
246
247    /// 停止序列
248    #[serde(default)]
249    pub stop_sequences: Vec<String>,
250
251    /// 是否启用流式输出
252    #[serde(default)]
253    pub streaming: bool,
254
255    /// API Key 环境变量名
256    #[serde(default)]
257    pub api_key_env: Option<String>,
258
259    /// API Base URL
260    #[serde(default)]
261    pub base_url: Option<String>,
262
263    /// 额外参数
264    #[serde(default)]
265    pub extra: HashMap<String, serde_json::Value>,
266}
267
/// Serde default for `LlmAgentConfig::temperature`.
fn default_temperature() -> f32 {
    0.7
}

272impl Default for LlmAgentConfig {
273    fn default() -> Self {
274        Self {
275            model: "gpt-4".to_string(),
276            system_prompt: None,
277            temperature: 0.7,
278            max_tokens: None,
279            top_p: None,
280            stop_sequences: Vec::new(),
281            streaming: false,
282            api_key_env: None,
283            base_url: None,
284            extra: HashMap::new(),
285        }
286    }
287}
288
289impl LlmAgentConfig {
290    /// 验证配置
291    pub fn validate(&self) -> Result<(), Vec<String>> {
292        let mut errors = Vec::new();
293
294        if self.model.is_empty() {
295            errors.push("LLM model cannot be empty".to_string());
296        }
297
298        if self.temperature < 0.0 || self.temperature > 2.0 {
299            errors.push("Temperature must be between 0.0 and 2.0".to_string());
300        }
301
302        if let Some(top_p) = self.top_p
303            && (!(0.0..=1.0).contains(&top_p))
304        {
305            errors.push("Top P must be between 0.0 and 1.0".to_string());
306        }
307
308        if errors.is_empty() {
309            Ok(())
310        } else {
311            Err(errors)
312        }
313    }
314}
315
// ============================================================================
// ReAct agent configuration
// ============================================================================

320/// ReAct Agent 配置
321#[derive(Debug, Clone, Serialize, Deserialize)]
322pub struct ReActAgentConfig {
323    /// LLM 配置
324    pub llm: LlmAgentConfig,
325
326    /// 最大推理步数
327    #[serde(default = "default_max_steps")]
328    pub max_steps: usize,
329
330    /// 工具配置
331    #[serde(default)]
332    pub tools: Vec<ToolConfig>,
333
334    /// 是否启用并行工具调用
335    #[serde(default)]
336    pub parallel_tool_calls: bool,
337
338    /// 思考格式
339    #[serde(default)]
340    pub thought_format: Option<String>,
341}
342
/// Serde default for `ReActAgentConfig::max_steps`.
fn default_max_steps() -> usize {
    10
}

347impl Default for ReActAgentConfig {
348    fn default() -> Self {
349        Self {
350            llm: LlmAgentConfig::default(),
351            max_steps: 10,
352            tools: Vec::new(),
353            parallel_tool_calls: false,
354            thought_format: None,
355        }
356    }
357}
358
359impl ReActAgentConfig {
360    /// 验证配置
361    pub fn validate(&self) -> Result<(), Vec<String>> {
362        let mut errors = Vec::new();
363
364        if let Err(llm_errors) = self.llm.validate() {
365            errors.extend(llm_errors);
366        }
367
368        if self.max_steps == 0 {
369            errors.push("ReAct max_steps must be greater than 0".to_string());
370        }
371
372        if errors.is_empty() {
373            Ok(())
374        } else {
375            Err(errors)
376        }
377    }
378}
379
380/// 工具配置
381#[derive(Debug, Clone, Serialize, Deserialize)]
382pub struct ToolConfig {
383    /// 工具名称
384    pub name: String,
385
386    /// 工具类型
387    #[serde(default)]
388    pub tool_type: ToolType,
389
390    /// 工具配置
391    #[serde(default)]
392    pub config: HashMap<String, serde_json::Value>,
393
394    /// 是否启用
395    #[serde(default = "default_enabled")]
396    pub enabled: bool,
397}
398
399/// 工具类型
400#[derive(Debug, Clone, Default, Serialize, Deserialize)]
401#[serde(rename_all = "snake_case")]
402pub enum ToolType {
403    /// 内置工具
404    #[default]
405    Builtin,
406    /// MCP 工具
407    Mcp,
408    /// 自定义工具
409    Custom,
410    /// 插件工具
411    Plugin,
412}
413
// ============================================================================
// Workflow agent configuration
// ============================================================================

418/// Workflow Agent 配置
419#[derive(Debug, Clone, Serialize, Deserialize, Default)]
420pub struct WorkflowAgentConfig {
421    /// 工作流步骤
422    pub steps: Vec<WorkflowStep>,
423
424    /// 是否启用并行执行
425    #[serde(default)]
426    pub parallel: bool,
427
428    /// 错误处理策略
429    #[serde(default)]
430    pub error_strategy: ErrorStrategy,
431}
432
433impl WorkflowAgentConfig {
434    /// 验证配置
435    pub fn validate(&self) -> Result<(), Vec<String>> {
436        let mut errors = Vec::new();
437
438        if self.steps.is_empty() {
439            errors.push("Workflow steps cannot be empty".to_string());
440        }
441
442        for (i, step) in self.steps.iter().enumerate() {
443            if step.agent_id.is_empty() {
444                errors.push(format!("Workflow step {} agent_id cannot be empty", i));
445            }
446        }
447
448        if errors.is_empty() {
449            Ok(())
450        } else {
451            Err(errors)
452        }
453    }
454}
455
456/// 工作流步骤
457#[derive(Debug, Clone, Serialize, Deserialize)]
458pub struct WorkflowStep {
459    /// 步骤 ID
460    pub id: String,
461
462    /// Agent ID
463    pub agent_id: String,
464
465    /// 输入映射
466    #[serde(default)]
467    pub input_mapping: HashMap<String, String>,
468
469    /// 输出映射
470    #[serde(default)]
471    pub output_mapping: HashMap<String, String>,
472
473    /// 条件表达式
474    #[serde(default)]
475    pub condition: Option<String>,
476
477    /// 超时 (毫秒)
478    #[serde(default)]
479    pub timeout_ms: Option<u64>,
480}
481
482/// 错误处理策略
483#[derive(Debug, Clone, Default, Serialize, Deserialize)]
484#[serde(rename_all = "snake_case")]
485pub enum ErrorStrategy {
486    /// 快速失败
487    #[default]
488    FailFast,
489    /// 继续执行
490    Continue,
491    /// 重试
492    Retry { max_retries: usize, delay_ms: u64 },
493    /// 回退
494    Fallback { fallback_agent_id: String },
495}
496
// ============================================================================
// Team agent configuration
// ============================================================================

501/// Team Agent 配置
502#[derive(Debug, Clone, Serialize, Deserialize, Default)]
503pub struct TeamAgentConfig {
504    /// 团队成员
505    pub members: Vec<TeamMember>,
506
507    /// 协调模式
508    #[serde(default)]
509    pub coordination: CoordinationMode,
510
511    /// 领导者 Agent ID (用于 Hierarchical 模式)
512    #[serde(default)]
513    pub leader_id: Option<String>,
514
515    /// 任务分发策略
516    #[serde(default)]
517    pub dispatch_strategy: DispatchStrategy,
518}
519
520impl TeamAgentConfig {
521    /// 验证配置
522    pub fn validate(&self) -> Result<(), Vec<String>> {
523        let mut errors = Vec::new();
524
525        if self.members.is_empty() {
526            errors.push("Team members cannot be empty".to_string());
527        }
528
529        if matches!(self.coordination, CoordinationMode::Hierarchical) && self.leader_id.is_none() {
530            errors.push("Hierarchical coordination requires leader_id".to_string());
531        }
532
533        for member in &self.members {
534            if member.agent_id.is_empty() {
535                errors.push("Team member agent_id cannot be empty".to_string());
536            }
537        }
538
539        if errors.is_empty() {
540            Ok(())
541        } else {
542            Err(errors)
543        }
544    }
545}
546
547/// 团队成员
548#[derive(Debug, Clone, Serialize, Deserialize)]
549pub struct TeamMember {
550    /// Agent ID
551    pub agent_id: String,
552
553    /// 角色
554    #[serde(default)]
555    pub role: Option<String>,
556
557    /// 权重 (用于负载均衡)
558    #[serde(default = "default_weight")]
559    pub weight: f32,
560
561    /// 是否为可选成员
562    #[serde(default)]
563    pub optional: bool,
564}
565
/// Serde default for `TeamMember::weight`.
fn default_weight() -> f32 {
    1.0
}

570/// 协调模式
571#[derive(Debug, Clone, Default, Serialize, Deserialize)]
572#[serde(rename_all = "snake_case")]
573pub enum CoordinationMode {
574    /// 顺序执行
575    #[default]
576    Sequential,
577    /// 并行执行
578    Parallel,
579    /// 层级执行
580    Hierarchical,
581    /// 共识模式
582    Consensus,
583    /// 投票模式
584    Voting,
585    /// 辩论模式
586    Debate,
587}
588
589/// 任务分发策略
590#[derive(Debug, Clone, Default, Serialize, Deserialize)]
591#[serde(rename_all = "snake_case")]
592pub enum DispatchStrategy {
593    /// 广播 (所有成员)
594    #[default]
595    Broadcast,
596    /// 轮询
597    RoundRobin,
598    /// 随机
599    Random,
600    /// 负载均衡
601    LoadBalanced,
602    /// 按能力匹配
603    CapabilityBased,
604}
605
// ============================================================================
// Component configuration
// ============================================================================

610/// 组件配置
611#[derive(Debug, Clone, Default, Serialize, Deserialize)]
612pub struct ComponentsConfig {
613    /// 推理器配置
614    #[serde(default)]
615    pub reasoner: Option<ReasonerConfig>,
616
617    /// 记忆配置
618    #[serde(default)]
619    pub memory: Option<MemoryConfig>,
620
621    /// 协调器配置
622    #[serde(default)]
623    pub coordinator: Option<CoordinatorConfig>,
624}
625
626/// 推理器配置
627#[derive(Debug, Clone, Serialize, Deserialize)]
628pub struct ReasonerConfig {
629    /// 推理策略
630    #[serde(default)]
631    pub strategy: ReasonerStrategy,
632
633    /// 自定义配置
634    #[serde(default)]
635    pub config: HashMap<String, serde_json::Value>,
636}
637
638/// 推理策略
639#[derive(Debug, Clone, Default, Serialize, Deserialize)]
640#[serde(rename_all = "snake_case")]
641pub enum ReasonerStrategy {
642    #[default]
643    Direct,
644    ChainOfThought,
645    TreeOfThought,
646    ReAct,
647    Custom,
648}
649
650/// 记忆配置
651#[derive(Debug, Clone, Serialize, Deserialize)]
652pub struct MemoryConfig {
653    /// 记忆类型
654    #[serde(default)]
655    pub memory_type: MemoryType,
656
657    /// 最大记忆项数
658    #[serde(default)]
659    pub max_items: Option<usize>,
660
661    /// 向量数据库配置
662    #[serde(default)]
663    pub vector_db: Option<VectorDbConfig>,
664}
665
666/// 记忆类型
667#[derive(Debug, Clone, Default, Serialize, Deserialize)]
668#[serde(rename_all = "snake_case")]
669pub enum MemoryType {
670    #[default]
671    InMemory,
672    Redis,
673    Sqlite,
674    VectorDb,
675    Custom,
676}
677
678/// 向量数据库配置
679#[derive(Debug, Clone, Serialize, Deserialize)]
680pub struct VectorDbConfig {
681    /// 数据库类型
682    pub db_type: String,
683    /// 连接 URL
684    pub url: String,
685    /// 集合/索引名称
686    #[serde(default)]
687    pub collection: Option<String>,
688}
689
690/// 协调器配置
691#[derive(Debug, Clone, Serialize, Deserialize)]
692pub struct CoordinatorConfig {
693    /// 协调模式
694    #[serde(default)]
695    pub pattern: CoordinationMode,
696
697    /// 超时 (毫秒)
698    #[serde(default)]
699    pub timeout_ms: Option<u64>,
700
701    /// 自定义配置
702    #[serde(default)]
703    pub config: HashMap<String, serde_json::Value>,
704}
705
// ============================================================================
// Capability configuration
// ============================================================================

710/// 能力配置
711#[derive(Debug, Clone, Default, Serialize, Deserialize)]
712pub struct CapabilitiesConfig {
713    /// 标签
714    #[serde(default)]
715    pub tags: Vec<String>,
716
717    /// 支持的输入类型
718    #[serde(default)]
719    pub input_types: Vec<String>,
720
721    /// 支持的输出类型
722    #[serde(default)]
723    pub output_types: Vec<String>,
724
725    /// 是否支持流式输出
726    #[serde(default)]
727    pub supports_streaming: bool,
728
729    /// 是否支持工具调用
730    #[serde(default)]
731    pub supports_tools: bool,
732
733    /// 是否支持多 Agent 协调
734    #[serde(default)]
735    pub supports_coordination: bool,
736
737    /// 推理策略
738    #[serde(default)]
739    pub reasoning_strategies: Vec<String>,
740}
741
#[cfg(test)]
mod tests {
    use super::*;

    /// A named LLM agent with default settings passes validation.
    #[test]
    fn test_agent_config_validation() {
        let config = AgentConfig::new("test-agent", "Test Agent")
            .with_type(AgentType::Llm(LlmAgentConfig::default()));

        assert!(config.validate().is_ok());
    }

    /// The default configuration has an empty id and name, so it is invalid.
    #[test]
    fn test_empty_config_validation() {
        let config = AgentConfig::default();
        assert!(config.validate().is_err());
    }

    /// The serialized form of an LLM agent contains its model and temperature.
    #[test]
    fn test_llm_config_serialization() {
        let config = AgentConfig {
            id: "llm-agent".to_string(),
            name: "LLM Agent".to_string(),
            agent_type: AgentType::Llm(LlmAgentConfig {
                model: "gpt-4".to_string(),
                temperature: 0.8,
                ..Default::default()
            }),
            ..Default::default()
        };

        let json = serde_json::to_string_pretty(&config).unwrap();
        assert!(json.contains("gpt-4"));
        assert!(json.contains("0.8"));
    }

    /// The serialized form of a ReAct agent contains its tag and step limit.
    #[test]
    fn test_react_config_serialization() {
        let config = AgentConfig {
            id: "react-agent".to_string(),
            name: "ReAct Agent".to_string(),
            agent_type: AgentType::ReAct(ReActAgentConfig {
                max_steps: 15,
                ..Default::default()
            }),
            ..Default::default()
        };

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("react"));
        assert!(json.contains("15"));
    }

    /// Hierarchical coordination without a leader is rejected.
    #[test]
    fn test_team_config_validation() {
        let config = TeamAgentConfig {
            members: vec![TeamMember {
                agent_id: "agent-1".to_string(),
                role: Some("worker".to_string()),
                weight: 1.0,
                optional: false,
            }],
            coordination: CoordinationMode::Hierarchical,
            leader_id: None, // Missing leader
            dispatch_strategy: DispatchStrategy::Broadcast,
        };

        assert!(config.validate().is_err());
    }
}