Skip to main content

mofa_foundation/react/
core.rs

1//! ReAct 核心类型和逻辑
2
3use crate::llm::{LLMAgent, LLMError, LLMResult, Tool};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9/// ReAct 步骤类型
10#[derive(Debug, Clone, Serialize, Deserialize)]
11pub enum ReActStepType {
12    /// 思考步骤
13    Thought,
14    /// 行动步骤
15    Action,
16    /// 观察步骤 (工具执行结果)
17    Observation,
18    /// 最终答案
19    FinalAnswer,
20}
21
22/// ReAct 执行步骤
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct ReActStep {
25    /// 步骤类型
26    pub step_type: ReActStepType,
27    /// 步骤内容
28    pub content: String,
29    /// 使用的工具名称 (仅 Action 步骤)
30    #[serde(skip_serializing_if = "Option::is_none")]
31    pub tool_name: Option<String>,
32    /// 工具输入 (仅 Action 步骤)
33    #[serde(skip_serializing_if = "Option::is_none")]
34    pub tool_input: Option<String>,
35    /// 步骤序号
36    pub step_number: usize,
37    /// 时间戳 (毫秒)
38    pub timestamp: u64,
39}
40
41impl ReActStep {
42    pub fn thought(content: impl Into<String>, step_number: usize) -> Self {
43        Self {
44            step_type: ReActStepType::Thought,
45            content: content.into(),
46            tool_name: None,
47            tool_input: None,
48            step_number,
49            timestamp: Self::current_timestamp(),
50        }
51    }
52
53    pub fn action(
54        tool_name: impl Into<String>,
55        tool_input: impl Into<String>,
56        step_number: usize,
57    ) -> Self {
58        let tool_name = tool_name.into();
59        let tool_input = tool_input.into();
60        Self {
61            step_type: ReActStepType::Action,
62            content: format!("Action: {}[{}]", tool_name, tool_input),
63            tool_name: Some(tool_name),
64            tool_input: Some(tool_input),
65            step_number,
66            timestamp: Self::current_timestamp(),
67        }
68    }
69
70    pub fn observation(content: impl Into<String>, step_number: usize) -> Self {
71        Self {
72            step_type: ReActStepType::Observation,
73            content: content.into(),
74            tool_name: None,
75            tool_input: None,
76            step_number,
77            timestamp: Self::current_timestamp(),
78        }
79    }
80
81    pub fn final_answer(content: impl Into<String>, step_number: usize) -> Self {
82        Self {
83            step_type: ReActStepType::FinalAnswer,
84            content: content.into(),
85            tool_name: None,
86            tool_input: None,
87            step_number,
88            timestamp: Self::current_timestamp(),
89        }
90    }
91
92    fn current_timestamp() -> u64 {
93        std::time::SystemTime::now()
94            .duration_since(std::time::UNIX_EPOCH)
95            .unwrap_or_default()
96            .as_millis() as u64
97    }
98}
99
100/// ReAct 执行结果
101#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct ReActResult {
103    /// 任务 ID
104    pub task_id: String,
105    /// 原始任务
106    pub task: String,
107    /// 最终答案
108    pub answer: String,
109    /// 执行步骤
110    pub steps: Vec<ReActStep>,
111    /// 是否成功
112    pub success: bool,
113    /// 错误信息 (如果失败)
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub error: Option<String>,
116    /// 总迭代次数
117    pub iterations: usize,
118    /// 总耗时 (毫秒)
119    pub duration_ms: u64,
120}
121
122impl ReActResult {
123    pub fn success(
124        task_id: impl Into<String>,
125        task: impl Into<String>,
126        answer: impl Into<String>,
127        steps: Vec<ReActStep>,
128        iterations: usize,
129        duration_ms: u64,
130    ) -> Self {
131        Self {
132            task_id: task_id.into(),
133            task: task.into(),
134            answer: answer.into(),
135            steps,
136            success: true,
137            error: None,
138            iterations,
139            duration_ms,
140        }
141    }
142
143    pub fn failed(
144        task_id: impl Into<String>,
145        task: impl Into<String>,
146        error: impl Into<String>,
147        steps: Vec<ReActStep>,
148        iterations: usize,
149        duration_ms: u64,
150    ) -> Self {
151        Self {
152            task_id: task_id.into(),
153            task: task.into(),
154            answer: String::new(),
155            steps,
156            success: false,
157            error: Some(error.into()),
158            iterations,
159            duration_ms,
160        }
161    }
162}
163
164/// ReAct 工具 trait
165///
166/// 实现此 trait 以创建自定义工具
167#[async_trait::async_trait]
168pub trait ReActTool: Send + Sync {
169    /// 工具名称 (用于 LLM 调用)
170    fn name(&self) -> &str;
171
172    /// 工具描述 (用于 LLM 理解工具功能)
173    fn description(&self) -> &str;
174
175    /// 参数 JSON Schema (可选)
176    fn parameters_schema(&self) -> Option<serde_json::Value> {
177        None
178    }
179
180    /// 执行工具
181    ///
182    /// # 参数
183    /// - `input`: 工具输入 (可以是 JSON 字符串或普通文本)
184    ///
185    /// # 返回
186    /// 工具执行结果
187    async fn execute(&self, input: &str) -> Result<String, String>;
188
189    /// 转换为 LLM Tool 定义
190    fn to_llm_tool(&self) -> Tool {
191        let params = self.parameters_schema().unwrap_or_else(|| {
192            serde_json::json!({
193                "type": "object",
194                "properties": {
195                    "input": {
196                        "type": "string",
197                        "description": "The input for the tool"
198                    }
199                },
200                "required": ["input"]
201            })
202        });
203
204        Tool::function(self.name(), self.description(), params)
205    }
206}
207
/// Tuning knobs for a ReAct agent.
///
/// NOTE(review): within this file only `max_iterations`, `system_prompt` and `verbose`
/// are read by `ReActAgent`; `stream_output`, `temperature` and `max_tokens_per_step`
/// are not forwarded to the LLM here — confirm they are consumed elsewhere.
#[derive(Debug, Clone)]
pub struct ReActConfig {
    /// Maximum number of reason/act loop iterations per run.
    pub max_iterations: usize,
    /// Whether streaming output is requested.
    pub stream_output: bool,
    /// Sampling temperature intended for the reasoning calls.
    pub temperature: f32,
    /// Custom system prompt that replaces the generated one when set.
    pub system_prompt: Option<String>,
    /// Whether to log each thought, action and observation.
    pub verbose: bool,
    /// Optional cap on tokens generated per step.
    pub max_tokens_per_step: Option<u32>,
}
224
225impl Default for ReActConfig {
226    fn default() -> Self {
227        Self {
228            max_iterations: 10,
229            stream_output: false,
230            temperature: 0.7,
231            system_prompt: None,
232            verbose: true,
233            max_tokens_per_step: Some(2048),
234        }
235    }
236}
237
238impl ReActConfig {
239    pub fn new() -> Self {
240        Self::default()
241    }
242
243    pub fn with_max_iterations(mut self, max: usize) -> Self {
244        self.max_iterations = max;
245        self
246    }
247
248    pub fn with_stream_output(mut self, enabled: bool) -> Self {
249        self.stream_output = enabled;
250        self
251    }
252
253    pub fn with_temperature(mut self, temp: f32) -> Self {
254        self.temperature = temp;
255        self
256    }
257
258    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
259        self.system_prompt = Some(prompt.into());
260        self
261    }
262
263    pub fn with_verbose(mut self, verbose: bool) -> Self {
264        self.verbose = verbose;
265        self
266    }
267}
268
269/// ReAct Agent 核心实现
270pub struct ReActAgent {
271    /// LLM Agent
272    llm: Arc<LLMAgent>,
273    /// 工具注册表
274    tools: Arc<RwLock<HashMap<String, Arc<dyn ReActTool>>>>,
275    /// 配置
276    config: ReActConfig,
277}
278
279impl ReActAgent {
280    /// 创建构建器
281    pub fn builder() -> ReActAgentBuilder {
282        ReActAgentBuilder::new()
283    }
284
285    /// 使用 LLM 和配置创建
286    pub fn new(llm: Arc<LLMAgent>, config: ReActConfig) -> Self {
287        Self {
288            llm,
289            tools: Arc::new(RwLock::new(HashMap::new())),
290            config,
291        }
292    }
293
294    /// 注册工具
295    pub async fn register_tool(&self, tool: Arc<dyn ReActTool>) {
296        let mut tools = self.tools.write().await;
297        tools.insert(tool.name().to_string(), tool);
298    }
299
300    /// 获取所有工具
301    pub async fn get_tools(&self) -> Vec<Arc<dyn ReActTool>> {
302        let tools = self.tools.read().await;
303        tools.values().cloned().collect()
304    }
305
306    /// 执行任务
307    pub async fn run(&self, task: impl Into<String>) -> LLMResult<ReActResult> {
308        let task = task.into();
309        let task_id = uuid::Uuid::now_v7().to_string();
310        let start_time = std::time::Instant::now();
311
312        let mut steps = Vec::new();
313        let mut step_number = 0;
314
315        // 构建系统提示词
316        let system_prompt = self.build_system_prompt().await;
317
318        // 构建初始消息
319        let mut conversation = vec![format!("Task: {}", task)];
320
321        for iteration in 0..self.config.max_iterations {
322            step_number += 1;
323
324            // 获取 LLM 响应
325            let prompt = self.build_prompt(&system_prompt, &conversation).await;
326            let response = self.llm.ask(&prompt).await?;
327
328            // 解析响应
329            let parsed = self.parse_response(&response);
330
331            match parsed {
332                ParsedResponse::Thought(thought) => {
333                    steps.push(ReActStep::thought(&thought, step_number));
334                    conversation.push(format!("Thought: {}", thought));
335
336                    if self.config.verbose {
337                        tracing::info!("Thought: {}", thought);
338                    }
339                }
340                ParsedResponse::Action { tool, input } => {
341                    steps.push(ReActStep::action(&tool, &input, step_number));
342                    conversation.push(format!("Action: {}[{}]", tool, input));
343
344                    if self.config.verbose {
345                        tracing::info!("Action: {}[{}]", tool, input);
346                    }
347
348                    // 执行工具
349                    step_number += 1;
350                    let observation = self.execute_tool(&tool, &input).await;
351                    steps.push(ReActStep::observation(&observation, step_number));
352                    conversation.push(format!("Observation: {}", observation));
353
354                    if self.config.verbose {
355                        tracing::info!("Observation: {}", observation);
356                    }
357                }
358                ParsedResponse::FinalAnswer(answer) => {
359                    steps.push(ReActStep::final_answer(&answer, step_number));
360
361                    if self.config.verbose {
362                        tracing::info!("Final Answer: {}", answer);
363                    }
364
365                    return Ok(ReActResult::success(
366                        task_id,
367                        &task,
368                        answer,
369                        steps,
370                        iteration + 1,
371                        start_time.elapsed().as_millis() as u64,
372                    ));
373                }
374                ParsedResponse::Error(err) => {
375                    return Ok(ReActResult::failed(
376                        task_id,
377                        &task,
378                        err,
379                        steps,
380                        iteration + 1,
381                        start_time.elapsed().as_millis() as u64,
382                    ));
383                }
384            }
385        }
386
387        // 达到最大迭代次数
388        Ok(ReActResult::failed(
389            task_id,
390            &task,
391            format!("Max iterations ({}) exceeded", self.config.max_iterations),
392            steps,
393            self.config.max_iterations,
394            start_time.elapsed().as_millis() as u64,
395        ))
396    }
397
398    /// 构建系统提示词
399    async fn build_system_prompt(&self) -> String {
400        if let Some(ref custom_prompt) = self.config.system_prompt {
401            return custom_prompt.clone();
402        }
403
404        let tools = self.tools.read().await;
405        let tool_descriptions: Vec<String> = tools
406            .values()
407            .map(|t| format!("- {}: {}", t.name(), t.description()))
408            .collect();
409
410        format!(
411            r#"You are a ReAct (Reasoning and Acting) agent. You solve tasks by thinking step by step and using available tools.
412
413Available tools:
414{}
415
416You must respond in one of these formats:
417
4181. When you need to think:
419Thought: <your reasoning about what to do next>
420
4212. When you want to use a tool:
422Action: <tool_name>[<input>]
423
4243. When you have the final answer:
425Final Answer: <your final answer to the task>
426
427Rules:
428- Always start with a Thought
429- Use tools when you need external information
430- Be concise and focused
431- Provide a Final Answer when you have enough information
432- If a tool returns an error, think about alternatives"#,
433            tool_descriptions.join("\n")
434        )
435    }
436
437    /// 构建完整提示词
438    async fn build_prompt(&self, system_prompt: &str, conversation: &[String]) -> String {
439        format!("{}\n\n{}", system_prompt, conversation.join("\n"))
440    }
441
442    /// 解析 LLM 响应
443    fn parse_response(&self, response: &str) -> ParsedResponse {
444        let response = response.trim();
445
446        // 检查 Final Answer
447        if let Some(answer) = response.strip_prefix("Final Answer:") {
448            return ParsedResponse::FinalAnswer(answer.trim().to_string());
449        }
450
451        // 检查 Action
452        if let Some(action_part) = response.strip_prefix("Action:") {
453            let action_part = action_part.trim();
454            if let Some(bracket_start) = action_part.find('[')
455                && let Some(bracket_end) = action_part.rfind(']')
456            {
457                let tool = action_part[..bracket_start].trim().to_string();
458                let input = action_part[bracket_start + 1..bracket_end]
459                    .trim()
460                    .to_string();
461                return ParsedResponse::Action { tool, input };
462            }
463            return ParsedResponse::Error(format!("Invalid action format: {}", action_part));
464        }
465
466        // 检查 Thought
467        if let Some(thought) = response.strip_prefix("Thought:") {
468            return ParsedResponse::Thought(thought.trim().to_string());
469        }
470
471        // 尝试从混合响应中提取
472        for line in response.lines() {
473            let line = line.trim();
474            if line.starts_with("Final Answer:") {
475                return ParsedResponse::FinalAnswer(
476                    line.strip_prefix("Final Answer:")
477                        .unwrap()
478                        .trim()
479                        .to_string(),
480                );
481            }
482            if line.starts_with("Action:") {
483                let action_part = line.strip_prefix("Action:").unwrap().trim();
484                if let Some(bracket_start) = action_part.find('[')
485                    && let Some(bracket_end) = action_part.rfind(']')
486                {
487                    let tool = action_part[..bracket_start].trim().to_string();
488                    let input = action_part[bracket_start + 1..bracket_end]
489                        .trim()
490                        .to_string();
491                    return ParsedResponse::Action { tool, input };
492                }
493            }
494            if line.starts_with("Thought:") {
495                return ParsedResponse::Thought(
496                    line.strip_prefix("Thought:").unwrap().trim().to_string(),
497                );
498            }
499        }
500
501        // 默认作为 Thought 处理
502        ParsedResponse::Thought(response.to_string())
503    }
504
505    /// 执行工具
506    async fn execute_tool(&self, tool_name: &str, input: &str) -> String {
507        let tools = self.tools.read().await;
508
509        match tools.get(tool_name) {
510            Some(tool) => match tool.execute(input).await {
511                Ok(result) => result,
512                Err(e) => format!("Tool error: {}", e),
513            },
514            None => format!(
515                "Tool '{}' not found. Available tools: {:?}",
516                tool_name,
517                tools.keys().collect::<Vec<_>>()
518            ),
519        }
520    }
521}
522
/// One parsed LLM reply, as produced by `ReActAgent::parse_response`.
///
/// Derives `Debug`/`PartialEq` so parsed replies can be logged and compared in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ParsedResponse {
    /// Reasoning only; the agent loop asks the model again.
    Thought(String),
    /// A request to run `tool` with `input`.
    Action { tool: String, input: String },
    /// The final answer; the run ends successfully.
    FinalAnswer(String),
    /// The reply could not be parsed; the run ends as failed with this message.
    Error(String),
}
530
531/// ReAct Agent 构建器
532pub struct ReActAgentBuilder {
533    llm: Option<Arc<LLMAgent>>,
534    tools: Vec<Arc<dyn ReActTool>>,
535    config: ReActConfig,
536}
537
538impl ReActAgentBuilder {
539    pub fn new() -> Self {
540        Self {
541            llm: None,
542            tools: Vec::new(),
543            config: ReActConfig::default(),
544        }
545    }
546
547    /// 设置 LLM Agent
548    pub fn with_llm(mut self, llm: Arc<LLMAgent>) -> Self {
549        self.llm = Some(llm);
550        self
551    }
552
553    /// 添加工具
554    pub fn with_tool(mut self, tool: Arc<dyn ReActTool>) -> Self {
555        self.tools.push(tool);
556        self
557    }
558
559    /// 添加多个工具
560    pub fn with_tools(mut self, tools: Vec<Arc<dyn ReActTool>>) -> Self {
561        self.tools.extend(tools);
562        self
563    }
564
565    /// 设置最大迭代次数
566    pub fn with_max_iterations(mut self, max: usize) -> Self {
567        self.config.max_iterations = max;
568        self
569    }
570
571    /// 设置温度
572    pub fn with_temperature(mut self, temp: f32) -> Self {
573        self.config.temperature = temp;
574        self
575    }
576
577    /// 设置系统提示词
578    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
579        self.config.system_prompt = Some(prompt.into());
580        self
581    }
582
583    /// 设置是否详细输出
584    pub fn with_verbose(mut self, verbose: bool) -> Self {
585        self.config.verbose = verbose;
586        self
587    }
588
589    /// 设置完整配置
590    pub fn with_config(mut self, config: ReActConfig) -> Self {
591        self.config = config;
592        self
593    }
594
595    /// 构建 ReAct Agent
596    pub fn build(self) -> LLMResult<ReActAgent> {
597        let llm = self
598            .llm
599            .ok_or_else(|| LLMError::ConfigError("LLM agent not set".to_string()))?;
600
601        let agent = ReActAgent::new(llm, self.config);
602
603        // 在运行时注册工具
604        let tools = self.tools;
605        let agent_tools = agent.tools.clone();
606
607        tokio::spawn(async move {
608            let mut tool_map = agent_tools.write().await;
609            for tool in tools {
610                tool_map.insert(tool.name().to_string(), tool);
611            }
612        });
613
614        Ok(agent)
615    }
616
617    /// 异步构建 (确保工具已注册)
618    pub async fn build_async(self) -> LLMResult<ReActAgent> {
619        let llm = self
620            .llm
621            .ok_or_else(|| LLMError::ConfigError("LLM agent not set".to_string()))?;
622
623        let agent = ReActAgent::new(llm, self.config);
624
625        // 注册工具
626        for tool in self.tools {
627            agent.register_tool(tool).await;
628        }
629
630        Ok(agent)
631    }
632}
633
634impl Default for ReActAgentBuilder {
635    fn default() -> Self {
636        Self::new()
637    }
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    /// Each constructor tags its step with the matching kind; actions also keep the tool name.
    #[test]
    fn test_react_step_creation() {
        let steps = [
            ReActStep::thought("I need to search for information", 1),
            ReActStep::action("search", "capital of France", 2),
            ReActStep::observation("Paris is the capital of France", 3),
            ReActStep::final_answer("Paris", 4),
        ];

        assert!(matches!(steps[0].step_type, ReActStepType::Thought));
        assert!(matches!(steps[1].step_type, ReActStepType::Action));
        assert_eq!(steps[1].tool_name.as_deref(), Some("search"));
        assert!(matches!(steps[2].step_type, ReActStepType::Observation));
        assert!(matches!(steps[3].step_type, ReActStepType::FinalAnswer));
    }

    /// Builder-style setters on the config overwrite the defaults.
    #[test]
    fn test_react_config() {
        let config = ReActConfig::new()
            .with_verbose(false)
            .with_temperature(0.5)
            .with_max_iterations(5);

        assert!(!config.verbose);
        assert_eq!(config.temperature, 0.5);
        assert_eq!(config.max_iterations, 5);
    }
}