Skip to main content

mofa_kernel/agent/
types.rs

1//! Agent 核心类型定义
2//!
3//! 定义统一的 Agent 输入、输出和状态类型
4
5use async_trait::async_trait;
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8use std::fmt;
9
10// 导出统一类型模块
11pub mod error;
12pub mod event;
13pub mod global;
14
15pub use error::{ErrorCategory, ErrorContext, GlobalError, GlobalResult};
16pub use event::{EventBuilder, GlobalEvent};
17pub use event::{execution, lifecycle, message, plugin, state};
18// 重新导出常用类型
19pub use global::{GlobalMessage, MessageContent, MessageMetadata};
20
21// ============================================================================
22// Agent 状态
23// ============================================================================
24
/// Agent lifecycle state machine.
///
/// Legal moves between states are checked by `AgentState::can_transition_to`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize, Default)]
pub enum AgentState {
    /// Created but not yet initialized; this is the `Default` state.
    #[default]
    Created,
    /// Initialization in progress.
    Initializing,
    /// Ready to execute.
    Ready,
    /// Running.
    // NOTE(review): `can_transition_to` has no transition into or out of this
    // state, and it overlaps with `Executing` — confirm whether it is still used.
    Running,
    /// Currently executing.
    Executing,
    /// Paused.
    Paused,
    /// Interrupted.
    Interrupted,
    /// Shutting down.
    ShuttingDown,
    /// Terminated / shut down.
    Shutdown,
    /// Failed.
    Failed,
    /// Destroyed.
    // NOTE(review): like `Running`, this state does not appear in
    // `can_transition_to`, so `transition_to` can never produce it.
    Destroyed,
    /// Error state carrying a message.
    Error(String),
}
54
55impl fmt::Display for AgentState {
56    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57        match self {
58            AgentState::Created => write!(f, "Created"),
59            AgentState::Initializing => write!(f, "Initializing"),
60            AgentState::Ready => write!(f, "Ready"),
61            AgentState::Executing => write!(f, "Executing"),
62            AgentState::Paused => write!(f, "Paused"),
63            AgentState::Interrupted => write!(f, "Interrupted"),
64            AgentState::ShuttingDown => write!(f, "ShuttingDown"),
65            AgentState::Shutdown => write!(f, "Shutdown"),
66            AgentState::Failed => write!(f, "Failed"),
67            AgentState::Error(msg) => write!(f, "Error({})", msg),
68            AgentState::Running => {
69                write!(f, "Running")
70            }
71            AgentState::Destroyed => {
72                write!(f, "Destroyed")
73            }
74        }
75    }
76}
77
78impl AgentState {
79    /// 转换到目标状态
80    pub fn transition_to(
81        &self,
82        target: AgentState,
83    ) -> Result<AgentState, super::error::AgentError> {
84        if self.can_transition_to(&target) {
85            Ok(target)
86        } else {
87            Err(super::error::AgentError::invalid_state_transition(
88                self, &target,
89            ))
90        }
91    }
92
93    /// 检查是否可以转换到目标状态
94    pub fn can_transition_to(&self, target: &AgentState) -> bool {
95        use AgentState::*;
96        matches!(
97            (self, target),
98            (Created, Initializing)
99                | (Initializing, Ready)
100                | (Initializing, Error(_))
101                | (Initializing, Failed)
102                | (Ready, Executing)
103                | (Ready, ShuttingDown)
104                | (Executing, Ready)
105                | (Executing, Paused)
106                | (Executing, Interrupted)
107                | (Executing, Error(_))
108                | (Executing, Failed)
109                | (Paused, Ready)
110                | (Paused, Executing)
111                | (Paused, ShuttingDown)
112                | (Interrupted, Ready)
113                | (Interrupted, ShuttingDown)
114                | (ShuttingDown, Shutdown)
115                | (Error(_), ShuttingDown)
116                | (Error(_), Shutdown)
117                | (Failed, ShuttingDown)
118                | (Failed, Shutdown)
119        )
120    }
121
122    /// 是否为活动状态
123    pub fn is_active(&self) -> bool {
124        matches!(self, AgentState::Ready | AgentState::Executing)
125    }
126
127    /// 是否为终止状态
128    pub fn is_terminal(&self) -> bool {
129        matches!(
130            self,
131            AgentState::Shutdown | AgentState::Failed | AgentState::Error(_)
132        )
133    }
134}
135
136// ============================================================================
137// Agent 输入
138// ============================================================================
139
/// Input accepted by an agent.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub enum AgentInput {
    /// Plain text input.
    Text(String),
    /// Multiple text segments (`AgentInput::to_text` joins them with `\n`).
    Texts(Vec<String>),
    /// Structured JSON.
    Json(serde_json::Value),
    /// Key/value pairs.
    Map(HashMap<String, serde_json::Value>),
    /// Raw binary data.
    Binary(Vec<u8>),
    /// No input; this is the `Default` variant.
    #[default]
    Empty,
}
157
158impl AgentInput {
159    /// 创建文本输入
160    pub fn text(s: impl Into<String>) -> Self {
161        Self::Text(s.into())
162    }
163
164    /// 创建 JSON 输入
165    pub fn json(value: serde_json::Value) -> Self {
166        Self::Json(value)
167    }
168
169    /// 创建键值对输入
170    pub fn map(map: HashMap<String, serde_json::Value>) -> Self {
171        Self::Map(map)
172    }
173
174    /// 获取文本内容
175    pub fn as_text(&self) -> Option<&str> {
176        match self {
177            Self::Text(s) => Some(s),
178            _ => None,
179        }
180    }
181
182    /// 转换为文本
183    pub fn to_text(&self) -> String {
184        match self {
185            Self::Text(s) => s.clone(),
186            Self::Texts(v) => v.join("\n"),
187            Self::Json(v) => v.to_string(),
188            Self::Map(m) => serde_json::to_string(m).unwrap_or_default(),
189            Self::Binary(b) => String::from_utf8_lossy(b).to_string(),
190            Self::Empty => String::new(),
191        }
192    }
193
194    /// 获取 JSON 内容
195    pub fn as_json(&self) -> Option<&serde_json::Value> {
196        match self {
197            Self::Json(v) => Some(v),
198            _ => None,
199        }
200    }
201
202    /// 转换为 JSON
203    pub fn to_json(&self) -> serde_json::Value {
204        match self {
205            Self::Text(s) => serde_json::Value::String(s.clone()),
206            Self::Texts(v) => serde_json::json!(v),
207            Self::Json(v) => v.clone(),
208            Self::Map(m) => serde_json::to_value(m).unwrap_or_default(),
209            Self::Binary(b) => serde_json::json!({ "binary": base64_encode(b) }),
210            Self::Empty => serde_json::Value::Null,
211        }
212    }
213
214    /// 是否为空
215    pub fn is_empty(&self) -> bool {
216        matches!(self, Self::Empty)
217    }
218}
219
220impl From<String> for AgentInput {
221    fn from(s: String) -> Self {
222        Self::Text(s)
223    }
224}
225
226impl From<&str> for AgentInput {
227    fn from(s: &str) -> Self {
228        Self::Text(s.to_string())
229    }
230}
231
232impl From<serde_json::Value> for AgentInput {
233    fn from(v: serde_json::Value) -> Self {
234        Self::Json(v)
235    }
236}
237
238// ============================================================================
239// Agent 输出
240// ============================================================================
241
/// Result produced by an agent run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AgentOutput {
    /// Primary output content.
    pub content: OutputContent,
    /// Free-form output metadata.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Tools invoked while producing this output.
    pub tools_used: Vec<ToolUsage>,
    /// Reasoning steps, if any were recorded.
    pub reasoning_steps: Vec<ReasoningStep>,
    /// Execution time in milliseconds.
    pub duration_ms: u64,
    /// Token usage statistics, if reported.
    pub token_usage: Option<TokenUsage>,
}
258
259impl Default for AgentOutput {
260    fn default() -> Self {
261        Self {
262            content: OutputContent::Empty,
263            metadata: HashMap::new(),
264            tools_used: Vec::new(),
265            reasoning_steps: Vec::new(),
266            duration_ms: 0,
267            token_usage: None,
268        }
269    }
270}
271
272impl AgentOutput {
273    /// 创建文本输出
274    pub fn text(s: impl Into<String>) -> Self {
275        Self {
276            content: OutputContent::Text(s.into()),
277            ..Default::default()
278        }
279    }
280
281    /// 创建 JSON 输出
282    pub fn json(value: serde_json::Value) -> Self {
283        Self {
284            content: OutputContent::Json(value),
285            ..Default::default()
286        }
287    }
288
289    /// 创建错误输出
290    pub fn error(message: impl Into<String>) -> Self {
291        Self {
292            content: OutputContent::Error(message.into()),
293            ..Default::default()
294        }
295    }
296
297    /// 获取文本内容
298    pub fn as_text(&self) -> Option<&str> {
299        match &self.content {
300            OutputContent::Text(s) => Some(s),
301            _ => None,
302        }
303    }
304
305    /// 转换为文本
306    pub fn to_text(&self) -> String {
307        self.content.to_text()
308    }
309
310    /// 设置执行时间
311    pub fn with_duration(mut self, duration_ms: u64) -> Self {
312        self.duration_ms = duration_ms;
313        self
314    }
315
316    /// 添加元数据
317    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
318        self.metadata.insert(key.into(), value);
319        self
320    }
321
322    /// 添加工具使用记录
323    pub fn with_tool_usage(mut self, usage: ToolUsage) -> Self {
324        self.tools_used.push(usage);
325        self
326    }
327
328    /// 设置所有工具使用记录
329    pub fn with_tools_used(mut self, usages: Vec<ToolUsage>) -> Self {
330        self.tools_used = usages;
331        self
332    }
333
334    /// 添加推理步骤
335    pub fn with_reasoning_step(mut self, step: ReasoningStep) -> Self {
336        self.reasoning_steps.push(step);
337        self
338    }
339
340    /// 设置所有推理步骤
341    pub fn with_reasoning_steps(mut self, steps: Vec<ReasoningStep>) -> Self {
342        self.reasoning_steps = steps;
343        self
344    }
345
346    /// 设置 Token 使用
347    pub fn with_token_usage(mut self, usage: TokenUsage) -> Self {
348        self.token_usage = Some(usage);
349        self
350    }
351
352    /// 是否为错误
353    pub fn is_error(&self) -> bool {
354        matches!(self.content, OutputContent::Error(_))
355    }
356}
357
/// Payload carried by an `AgentOutput`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum OutputContent {
    /// Text output.
    Text(String),
    /// Multiple text segments.
    Texts(Vec<String>),
    /// JSON output.
    Json(serde_json::Value),
    /// Binary output.
    Binary(Vec<u8>),
    /// Marker indicating the output is streamed; it carries no data itself.
    Stream,
    /// Error message.
    Error(String),
    /// No output.
    Empty,
}
376
377impl OutputContent {
378    /// 转换为文本
379    pub fn to_text(&self) -> String {
380        match self {
381            Self::Text(s) => s.clone(),
382            Self::Texts(v) => v.join("\n"),
383            Self::Json(v) => v.to_string(),
384            Self::Binary(b) => String::from_utf8_lossy(b).to_string(),
385            Self::Stream => "[STREAM]".to_string(),
386            Self::Error(e) => format!("Error: {}", e),
387            Self::Empty => String::new(),
388        }
389    }
390}
391
392// ============================================================================
393// 辅助类型
394// ============================================================================
395
/// Record of a single tool invocation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUsage {
    /// Tool name.
    pub name: String,
    /// Arguments passed to the tool.
    pub input: serde_json::Value,
    /// Tool result (set to `None` by `ToolUsage::failure`).
    pub output: Option<serde_json::Value>,
    /// Whether the call succeeded.
    pub success: bool,
    /// Error message (set to `None` by `ToolUsage::success`).
    pub error: Option<String>,
    /// Execution time in milliseconds.
    pub duration_ms: u64,
}
412
413impl ToolUsage {
414    /// 创建成功的工具使用记录
415    pub fn success(
416        name: impl Into<String>,
417        input: serde_json::Value,
418        output: serde_json::Value,
419        duration_ms: u64,
420    ) -> Self {
421        Self {
422            name: name.into(),
423            input,
424            output: Some(output),
425            success: true,
426            error: None,
427            duration_ms,
428        }
429    }
430
431    /// 创建失败的工具使用记录
432    pub fn failure(
433        name: impl Into<String>,
434        input: serde_json::Value,
435        error: impl Into<String>,
436        duration_ms: u64,
437    ) -> Self {
438        Self {
439            name: name.into(),
440            input,
441            output: None,
442            success: false,
443            error: Some(error.into()),
444            duration_ms,
445        }
446    }
447}
448
/// One step in an agent's reasoning trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReasoningStep {
    /// Kind of step.
    pub step_type: ReasoningStepType,
    /// Step content.
    pub content: String,
    /// Sequence number of this step.
    pub step_number: usize,
    /// Timestamp in milliseconds since the Unix epoch (as set by `ReasoningStep::new`).
    pub timestamp_ms: u64,
}
461
462impl ReasoningStep {
463    /// 创建新的推理步骤
464    pub fn new(
465        step_type: ReasoningStepType,
466        content: impl Into<String>,
467        step_number: usize,
468    ) -> Self {
469        let now = std::time::SystemTime::now()
470            .duration_since(std::time::UNIX_EPOCH)
471            .unwrap_or_default()
472            .as_millis() as u64;
473
474        Self {
475            step_type,
476            content: content.into(),
477            step_number,
478            timestamp_ms: now,
479        }
480    }
481}
482
/// Kind of a reasoning step.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ReasoningStepType {
    /// A thought.
    Thought,
    /// An action.
    Action,
    /// An observation.
    Observation,
    /// A reflection.
    Reflection,
    /// A decision.
    Decision,
    /// The final answer.
    FinalAnswer,
    /// A custom step type described by the given string.
    Custom(String),
}
501
/// Token usage statistics.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct TokenUsage {
    /// Tokens consumed by the prompt.
    pub prompt_tokens: u32,
    /// Tokens in the completion.
    pub completion_tokens: u32,
    /// Total tokens (`TokenUsage::new` derives this from the prompt and completion counts).
    pub total_tokens: u32,
}
512
513impl TokenUsage {
514    pub fn new(prompt_tokens: u32, completion_tokens: u32) -> Self {
515        let total_tokens = prompt_tokens + completion_tokens;
516        Self {
517            prompt_tokens,
518            completion_tokens,
519            total_tokens,
520        }
521    }
522}
523
524// ============================================================================
525// LLM 相关类型
526// ============================================================================
527
/// LLM chat-completion request.
#[derive(Debug, Clone)]
pub struct ChatCompletionRequest {
    /// Conversation messages to send.
    pub messages: Vec<ChatMessage>,
    /// Model to use, if specified.
    pub model: Option<String>,
    /// Tool definitions offered to the model, if any.
    pub tools: Option<Vec<ToolDefinition>>,
    /// Sampling temperature, if specified.
    pub temperature: Option<f32>,
    /// Maximum token limit, if specified.
    pub max_tokens: Option<u32>,
}
542
/// A single chat message.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message role, e.g. `system`, `user`, `assistant` or `tool`.
    pub role: String,
    /// Message text, if any.
    pub content: Option<String>,
    /// ID of the tool call this message answers (for tool responses).
    pub tool_call_id: Option<String>,
    /// Tool calls requested in this message (for assistant messages that use tools).
    pub tool_calls: Option<Vec<ToolCall>>,
}
555
/// A tool call requested by the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCall {
    /// Tool call ID.
    pub id: String,
    /// Name of the tool to call.
    pub name: String,
    /// Tool arguments as JSON.
    // NOTE(review): the original comment says this may be "a JSON string or a
    // Value" — presumably provider-dependent; confirm what callers expect.
    pub arguments: serde_json::Value,
}
566
/// Definition of a tool offered to the LLM.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Human-readable tool description.
    pub description: String,
    /// JSON Schema describing the tool's parameters.
    pub parameters: serde_json::Value,
}
577
/// LLM chat-completion response.
#[derive(Debug, Clone)]
pub struct ChatCompletionResponse {
    /// Text content of the response, if any.
    pub content: Option<String>,
    /// Tool calls requested by the LLM, if any.
    pub tool_calls: Option<Vec<ToolCall>>,
    /// Token usage statistics, if reported.
    pub usage: Option<TokenUsage>,
}
588
/// LLM provider trait, defining the interface for LLM backends.
///
/// This is a core abstraction: the minimal interface every LLM provider must
/// implement. Implementations must be `Send + Sync`.
///
/// # Example
///
/// ```rust,ignore
/// use async_trait::async_trait;
/// use mofa_kernel::agent::error::AgentResult;
/// use mofa_kernel::agent::types::{LLMProvider, ChatCompletionRequest, ChatCompletionResponse};
///
/// struct MyLLMProvider;
///
/// #[async_trait]
/// impl LLMProvider for MyLLMProvider {
///     fn name(&self) -> &str { "my-llm" }
///
///     async fn chat(&self, request: ChatCompletionRequest) -> AgentResult<ChatCompletionResponse> {
///         // Call the underlying LLM here.
///         todo!()
///     }
/// }
/// ```
#[async_trait]
pub trait LLMProvider: Send + Sync {
    /// Returns the provider's name.
    fn name(&self) -> &str;

    /// Sends a chat-completion request and returns the model's response.
    ///
    /// Failures are reported through the returned `AgentResult`.
    async fn chat(
        &self,
        request: ChatCompletionRequest,
    ) -> super::error::AgentResult<ChatCompletionResponse>;
}
620
621// ============================================================================
622// 中断处理
623// ============================================================================
624
/// Outcome of handling an interrupt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum InterruptResult {
    /// Interrupt acknowledged; execution continues.
    Acknowledged,
    /// The interrupt paused the agent.
    Paused,
    /// Interrupted, with an optional partial result.
    Interrupted {
        /// Partial result as text, if any.
        partial_result: Option<String>,
    },
    /// The interrupt terminated the current task.
    TaskTerminated {
        /// Partial output produced before termination, if any.
        partial_result: Option<AgentOutput>,
    },
    /// Interrupt ignored (the agent is in a critical section).
    Ignored,
}
645
646// ============================================================================
647// 输入输出类型
648// ============================================================================
649
/// Input modality an agent supports.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum InputType {
    /// Text.
    Text,
    /// Image.
    Image,
    /// Audio.
    Audio,
    /// Video.
    Video,
    /// Structured data; the string presumably names the format or schema — TODO confirm.
    Structured(String),
    /// Raw binary data.
    Binary,
}
660
/// Output modality an agent supports.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum OutputType {
    /// Text.
    Text,
    /// JSON.
    Json,
    /// Structured JSON.
    // NOTE(review): the difference from `Json` (e.g. schema-constrained output)
    // is not visible in this file — confirm and document.
    StructuredJson,
    /// Streamed output.
    Stream,
    /// Raw binary data.
    Binary,
    /// Mixed modalities.
    Multimodal,
}
671
672// ============================================================================
673// 辅助函数
674// ============================================================================
675
/// Encodes `data` as standard base64 (RFC 4648 §4 alphabet) with `=` padding.
///
/// Used by `AgentInput::to_json` to embed `Binary` payloads in JSON.
/// An empty slice encodes to an empty string.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Every input group of up to 3 bytes (a short final group included) yields exactly 4 chars.
    let mut encoded = String::with_capacity(data.len().div_ceil(3) * 4);

    for chunk in data.chunks(3) {
        // `chunks` never yields an empty slice, so index 0 always exists; missing
        // trailing bytes are treated as zero bits and replaced by '=' below.
        let b0 = chunk[0];
        let b1 = chunk.get(1).copied().unwrap_or(0);
        let b2 = chunk.get(2).copied().unwrap_or(0);
        let group = (u32::from(b0) << 16) | (u32::from(b1) << 8) | u32::from(b2);

        let sextet = |shift: u32| char::from(ALPHABET[((group >> shift) & 0x3F) as usize]);
        encoded.push(sextet(18));
        encoded.push(sextet(12));
        encoded.push(if chunk.len() > 1 { sextet(6) } else { '=' });
        encoded.push(if chunk.len() > 2 { sextet(0) } else { '=' });
    }

    encoded
}
708
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_agent_state_transitions() {
        let state = AgentState::Created;
        assert!(state.can_transition_to(&AgentState::Initializing));
        assert!(!state.can_transition_to(&AgentState::Executing));
    }

    #[test]
    fn test_agent_state_error_transitions_ignore_message() {
        // `Error(_)` entries in the transition table must match any payload.
        assert!(AgentState::Executing.can_transition_to(&AgentState::Error("boom".to_string())));
        assert!(AgentState::Error("boom".to_string()).can_transition_to(&AgentState::Shutdown));
        assert!(!AgentState::Shutdown.can_transition_to(&AgentState::Ready));
    }

    #[test]
    fn test_agent_state_display() {
        assert_eq!(AgentState::Ready.to_string(), "Ready");
        assert_eq!(
            AgentState::Error("boom".to_string()).to_string(),
            "Error(boom)"
        );
    }

    #[test]
    fn test_agent_input_text() {
        let input = AgentInput::text("Hello");
        assert_eq!(input.as_text(), Some("Hello"));
        assert_eq!(input.to_text(), "Hello");
    }

    #[test]
    fn test_agent_input_binary_to_json_is_base64() {
        let input = AgentInput::Binary(b"Man".to_vec());
        assert_eq!(input.to_json(), serde_json::json!({ "binary": "TWFu" }));
    }

    #[test]
    fn test_base64_encode_rfc4648_vectors() {
        // Test vectors from RFC 4648 §10, covering 0, 1 and 2 padding characters.
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foob"), "Zm9vYg==");
        assert_eq!(base64_encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn test_agent_output_text() {
        let output = AgentOutput::text("World")
            .with_duration(100)
            .with_metadata("key", serde_json::json!("value"));

        assert_eq!(output.as_text(), Some("World"));
        assert_eq!(output.duration_ms, 100);
        assert!(output.metadata.contains_key("key"));
    }

    #[test]
    fn test_output_content_to_text() {
        assert_eq!(OutputContent::Error("bad".to_string()).to_text(), "Error: bad");
        assert_eq!(OutputContent::Stream.to_text(), "[STREAM]");
        assert_eq!(OutputContent::Empty.to_text(), "");
        assert!(AgentOutput::error("bad").is_error());
    }

    #[test]
    fn test_tool_usage() {
        let usage = ToolUsage::success(
            "calculator",
            serde_json::json!({"a": 1, "b": 2}),
            serde_json::json!(3),
            50,
        );
        assert!(usage.success);
        assert_eq!(usage.name, "calculator");
    }

    #[test]
    fn test_token_usage_total() {
        let usage = TokenUsage::new(10, 5);
        assert_eq!(usage.total_tokens, 15);
    }
}