Skip to main content

mofa_foundation/react/
actor.rs

1//! ReAct Actor 实现
2//!
3//! 基于 ractor 的 ReAct Agent Actor 实现
4
5use super::core::{ReActAgent, ReActConfig, ReActResult, ReActStep, ReActTool};
6use crate::llm::{LLMAgent, LLMError, LLMResult};
7use ractor::{Actor, ActorProcessingErr, ActorRef};
8use std::fmt;
9use std::future::Future;
10use std::sync::Arc;
11use tokio::sync::{mpsc, oneshot};
12
13/// ReAct Actor 消息类型
14pub enum ReActMessage {
15    /// 执行任务
16    RunTask {
17        task: String,
18        reply: oneshot::Sender<LLMResult<ReActResult>>,
19    },
20    /// 执行任务并流式返回步骤
21    RunTaskStreaming {
22        task: String,
23        step_tx: mpsc::Sender<ReActStep>,
24        reply: oneshot::Sender<LLMResult<ReActResult>>,
25    },
26    /// 注册工具
27    RegisterTool { tool: Arc<dyn ReActTool> },
28    /// 获取状态
29    GetStatus {
30        reply: oneshot::Sender<ReActActorStatus>,
31    },
32    /// 取消当前任务
33    CancelTask,
34    /// 停止 Actor
35    Stop,
36}
37
38impl fmt::Debug for ReActMessage {
39    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40        match self {
41            Self::RunTask { task, .. } => f.debug_struct("RunTask").field("task", task).finish(),
42            Self::RunTaskStreaming { task, .. } => f
43                .debug_struct("RunTaskStreaming")
44                .field("task", task)
45                .finish(),
46            Self::RegisterTool { tool } => f
47                .debug_struct("RegisterTool")
48                .field("tool_name", &tool.name())
49                .finish(),
50            Self::GetStatus { .. } => f.debug_struct("GetStatus").finish(),
51            Self::CancelTask => f.debug_struct("CancelTask").finish(),
52            Self::Stop => f.debug_struct("Stop").finish(),
53        }
54    }
55}
56
/// Status snapshot returned in response to `ReActMessage::GetStatus`.
#[derive(Debug, Clone)]
pub struct ReActActorStatus {
    /// Actor identifier.
    pub id: String,
    /// Whether a task is currently being executed.
    pub is_running: bool,
    /// Number of tasks that finished successfully.
    pub completed_tasks: usize,
    /// Number of registered tools.
    pub tool_count: usize,
    /// Identifier of the task in progress, if any.
    pub current_task_id: Option<String>,
}
71
72/// ReAct Actor 内部状态
73pub struct ReActActorState {
74    /// ReAct Agent 实例
75    agent: Option<ReActAgent>,
76    /// LLM Agent (用于延迟初始化)
77    llm: Option<Arc<LLMAgent>>,
78    /// 配置
79    config: ReActConfig,
80    /// 待注册的工具
81    pending_tools: Vec<Arc<dyn ReActTool>>,
82    /// 是否正在运行任务
83    is_running: bool,
84    /// 已完成任务数
85    completed_tasks: usize,
86    /// 当前任务 ID
87    current_task_id: Option<String>,
88    /// 取消标志
89    #[allow(dead_code)]
90    cancelled: bool,
91}
92
93impl ReActActorState {
94    pub fn new(llm: Arc<LLMAgent>, config: ReActConfig) -> Self {
95        Self {
96            agent: None,
97            llm: Some(llm),
98            config,
99            pending_tools: Vec::new(),
100            is_running: false,
101            completed_tasks: 0,
102            current_task_id: None,
103            cancelled: false,
104        }
105    }
106
107    /// 确保 Agent 已初始化
108    async fn ensure_agent(&mut self) -> LLMResult<&ReActAgent> {
109        if self.agent.is_none() {
110            let llm = self
111                .llm
112                .take()
113                .ok_or_else(|| LLMError::ConfigError("LLM already consumed".to_string()))?;
114
115            let agent = ReActAgent::new(llm, self.config.clone());
116
117            // 注册待注册的工具
118            for tool in self.pending_tools.drain(..) {
119                agent.register_tool(tool).await;
120            }
121
122            self.agent = Some(agent);
123        }
124
125        self.agent
126            .as_ref()
127            .ok_or_else(|| LLMError::Other("Agent not initialized".to_string()))
128    }
129}
130
/// The ReAct actor type. It is a unit struct; the per-actor data is kept in
/// the actor's state value rather than in the handler itself.
pub struct ReActActor;

impl ReActActor {
    /// Creates a new ReAct actor handler.
    pub fn new() -> Self {
        <Self as Default>::default()
    }
}

impl Default for ReActActor {
    /// Returns the unit handler value.
    fn default() -> Self {
        ReActActor
    }
}
146
147impl Actor for ReActActor {
148    type Msg = ReActMessage;
149    type State = ReActActorState;
150    type Arguments = (Arc<LLMAgent>, ReActConfig, Vec<Arc<dyn ReActTool>>);
151
152    fn pre_start(
153        &self,
154        _myself: ActorRef<Self::Msg>,
155        args: Self::Arguments,
156    ) -> impl Future<Output = Result<Self::State, ActorProcessingErr>> + Send {
157        async move {
158            let (llm, config, tools) = args;
159            let mut state = ReActActorState::new(llm, config);
160            state.pending_tools = tools;
161            Ok(state)
162        }
163    }
164
165    fn handle(
166        &self,
167        myself: ActorRef<Self::Msg>,
168        message: Self::Msg,
169        state: &mut Self::State,
170    ) -> impl Future<Output = Result<(), ActorProcessingErr>> + Send {
171        // 我们需要在 future 之前处理 state 的可变借用
172        // 为了避免生命周期问题,我们需要将处理逻辑分离出来
173        handle_message(myself, message, state)
174    }
175}
176
177/// 处理消息的异步函数
178async fn handle_message(
179    myself: ActorRef<ReActMessage>,
180    message: ReActMessage,
181    state: &mut ReActActorState,
182) -> Result<(), ActorProcessingErr> {
183    match message {
184        ReActMessage::RunTask { task, reply } => {
185            if state.is_running {
186                let _ = reply.send(Err(LLMError::Other(
187                    "Agent is already running a task".to_string(),
188                )));
189                return Ok(());
190            }
191
192            state.is_running = true;
193            state.cancelled = false;
194            state.current_task_id = Some(uuid::Uuid::now_v7().to_string());
195
196            let result = match state.ensure_agent().await {
197                Ok(agent) => agent.run(&task).await,
198                Err(e) => Err(e),
199            };
200
201            state.is_running = false;
202            state.current_task_id = None;
203
204            if result.is_ok() {
205                state.completed_tasks += 1;
206            }
207
208            let _ = reply.send(result);
209        }
210
211        ReActMessage::RunTaskStreaming {
212            task,
213            step_tx,
214            reply,
215        } => {
216            if state.is_running {
217                let _ = reply.send(Err(LLMError::Other(
218                    "Agent is already running a task".to_string(),
219                )));
220                return Ok(());
221            }
222
223            state.is_running = true;
224            state.cancelled = false;
225            let task_id = uuid::Uuid::now_v7().to_string();
226            state.current_task_id = Some(task_id.clone());
227
228            // 执行带步骤回调的任务
229            let result = match state.ensure_agent().await {
230                Ok(agent) => {
231                    // 运行任务
232                    let result = agent.run(&task).await;
233
234                    // 发送所有步骤
235                    if let Ok(ref res) = result {
236                        for step in &res.steps {
237                            let _ = step_tx.send(step.clone()).await;
238                        }
239                    }
240
241                    result
242                }
243                Err(e) => Err(e),
244            };
245
246            state.is_running = false;
247            state.current_task_id = None;
248
249            if result.is_ok() {
250                state.completed_tasks += 1;
251            }
252
253            let _ = reply.send(result);
254        }
255
256        ReActMessage::RegisterTool { tool } => {
257            if let Some(ref agent) = state.agent {
258                agent.register_tool(tool).await;
259            } else {
260                state.pending_tools.push(tool);
261            }
262        }
263
264        ReActMessage::GetStatus { reply } => {
265            let tool_count = if let Some(ref agent) = state.agent {
266                agent.get_tools().await.len()
267            } else {
268                state.pending_tools.len()
269            };
270
271            let status = ReActActorStatus {
272                id: state.current_task_id.clone().unwrap_or_default(),
273                is_running: state.is_running,
274                completed_tasks: state.completed_tasks,
275                tool_count,
276                current_task_id: state.current_task_id.clone(),
277            };
278
279            let _ = reply.send(status);
280        }
281
282        ReActMessage::CancelTask => {
283            state.cancelled = true;
284        }
285
286        ReActMessage::Stop => {
287            myself.stop(Some("Stop requested".to_string()));
288        }
289    }
290
291    Ok(())
292}
293
294/// ReAct Actor 引用包装
295///
296/// 提供便捷的方法与 ReAct Actor 交互
297pub struct ReActActorRef {
298    actor: ActorRef<ReActMessage>,
299}
300
301impl ReActActorRef {
302    /// 从 ActorRef 创建
303    pub fn new(actor: ActorRef<ReActMessage>) -> Self {
304        Self { actor }
305    }
306
307    /// 执行任务
308    pub async fn run_task(&self, task: impl Into<String>) -> LLMResult<ReActResult> {
309        let (tx, rx) = oneshot::channel();
310        self.actor
311            .send_message(ReActMessage::RunTask {
312                task: task.into(),
313                reply: tx,
314            })
315            .map_err(|e| LLMError::Other(format!("Failed to send message: {}", e)))?;
316
317        rx.await
318            .map_err(|e| LLMError::Other(format!("Failed to receive response: {}", e)))?
319    }
320
321    /// 执行任务并流式返回步骤
322    pub async fn run_task_streaming(
323        &self,
324        task: impl Into<String>,
325    ) -> LLMResult<(
326        mpsc::Receiver<ReActStep>,
327        oneshot::Receiver<LLMResult<ReActResult>>,
328    )> {
329        let (step_tx, step_rx) = mpsc::channel(100);
330        let (result_tx, result_rx) = oneshot::channel();
331
332        self.actor
333            .send_message(ReActMessage::RunTaskStreaming {
334                task: task.into(),
335                step_tx,
336                reply: result_tx,
337            })
338            .map_err(|e| LLMError::Other(format!("Failed to send message: {}", e)))?;
339
340        Ok((step_rx, result_rx))
341    }
342
343    /// 注册工具
344    pub fn register_tool(&self, tool: Arc<dyn ReActTool>) -> LLMResult<()> {
345        self.actor
346            .send_message(ReActMessage::RegisterTool { tool })
347            .map_err(|e| LLMError::Other(format!("Failed to register tool: {}", e)))
348    }
349
350    /// 获取状态
351    pub async fn get_status(&self) -> LLMResult<ReActActorStatus> {
352        let (tx, rx) = oneshot::channel();
353        self.actor
354            .send_message(ReActMessage::GetStatus { reply: tx })
355            .map_err(|e| LLMError::Other(format!("Failed to send message: {}", e)))?;
356
357        rx.await
358            .map_err(|e| LLMError::Other(format!("Failed to receive status: {}", e)))
359    }
360
361    /// 取消当前任务
362    pub fn cancel_task(&self) -> LLMResult<()> {
363        self.actor
364            .send_message(ReActMessage::CancelTask)
365            .map_err(|e| LLMError::Other(format!("Failed to cancel task: {}", e)))
366    }
367
368    /// 停止 Actor
369    pub fn stop(&self) -> LLMResult<()> {
370        self.actor
371            .send_message(ReActMessage::Stop)
372            .map_err(|e| LLMError::Other(format!("Failed to stop actor: {}", e)))
373    }
374
375    /// 获取内部 ActorRef
376    pub fn inner(&self) -> &ActorRef<ReActMessage> {
377        &self.actor
378    }
379}
380
381/// 启动 ReAct Actor
382///
383/// # 示例
384///
385/// ```rust,ignore
386/// let (actor_ref, handle) = spawn_react_actor(
387///     "my-react-agent",
388///     llm_agent,
389///     ReActConfig::default(),
390///     vec![Arc::new(SearchTool)],
391/// ).await?;
392///
393/// let result = actor_ref.run_task("What is Rust?").await?;
394/// ```
395pub async fn spawn_react_actor(
396    name: impl Into<String>,
397    llm: Arc<LLMAgent>,
398    config: ReActConfig,
399    tools: Vec<Arc<dyn ReActTool>>,
400) -> LLMResult<(ReActActorRef, tokio::task::JoinHandle<()>)> {
401    let (actor_ref, handle) =
402        Actor::spawn(Some(name.into()), ReActActor::new(), (llm, config, tools))
403            .await
404            .map_err(|e| LLMError::Other(format!("Failed to spawn actor: {}", e)))?;
405
406    Ok((ReActActorRef::new(actor_ref), handle))
407}
408
409/// AutoAgent - 自动选择最佳策略的智能 Agent
410///
411/// 根据任务类型自动选择:
412/// - 简单问答:直接 LLM 回答
413/// - 需要搜索:使用搜索工具
414/// - 需要计算:使用计算工具
415/// - 复杂任务:使用完整 ReAct 循环
416pub struct AutoAgent {
417    /// ReAct Agent
418    react_agent: Arc<ReActAgent>,
419    /// 直接 LLM Agent (用于简单问答)
420    llm: Arc<LLMAgent>,
421    /// 是否启用自动模式选择
422    auto_mode: bool,
423}
424
425impl AutoAgent {
426    /// 创建 AutoAgent
427    pub fn new(llm: Arc<LLMAgent>, react_agent: Arc<ReActAgent>) -> Self {
428        Self {
429            react_agent,
430            llm,
431            auto_mode: true,
432        }
433    }
434
435    /// 设置是否自动选择模式
436    pub fn with_auto_mode(mut self, enabled: bool) -> Self {
437        self.auto_mode = enabled;
438        self
439    }
440
441    /// 执行任务
442    pub async fn run(&self, task: impl Into<String>) -> LLMResult<AutoAgentResult> {
443        let task = task.into();
444        let start = std::time::Instant::now();
445
446        if !self.auto_mode {
447            // 强制使用 ReAct
448            let result = self.react_agent.run(&task).await?;
449            let answer = result.answer.clone();
450            return Ok(AutoAgentResult {
451                mode: ExecutionMode::ReAct,
452                answer,
453                react_result: Some(result),
454                duration_ms: start.elapsed().as_millis() as u64,
455            });
456        }
457
458        // 分析任务复杂度
459        let complexity = self.analyze_complexity(&task).await;
460
461        match complexity {
462            TaskComplexity::Simple => {
463                // 直接 LLM 回答
464                let answer = self.llm.ask(&task).await?;
465                Ok(AutoAgentResult {
466                    mode: ExecutionMode::Direct,
467                    answer,
468                    react_result: None,
469                    duration_ms: start.elapsed().as_millis() as u64,
470                })
471            }
472            TaskComplexity::RequiresTool | TaskComplexity::Complex => {
473                // 使用 ReAct
474                let result = self.react_agent.run(&task).await?;
475                let answer = result.answer.clone();
476                Ok(AutoAgentResult {
477                    mode: ExecutionMode::ReAct,
478                    answer,
479                    react_result: Some(result),
480                    duration_ms: start.elapsed().as_millis() as u64,
481                })
482            }
483        }
484    }
485
486    /// 分析任务复杂度
487    async fn analyze_complexity(&self, task: &str) -> TaskComplexity {
488        // 简单的关键词分析
489        let task_lower = task.to_lowercase();
490
491        // 需要工具的关键词
492        let tool_keywords = [
493            "search",
494            "find",
495            "lookup",
496            "calculate",
497            "compute",
498            "weather",
499            "current",
500            "latest",
501            "today",
502            "now",
503        ];
504
505        // 复杂任务关键词
506        let complex_keywords = [
507            "analyze",
508            "compare",
509            "research",
510            "investigate",
511            "step by step",
512            "explain in detail",
513        ];
514
515        for keyword in complex_keywords {
516            if task_lower.contains(keyword) {
517                return TaskComplexity::Complex;
518            }
519        }
520
521        for keyword in tool_keywords {
522            if task_lower.contains(keyword) {
523                return TaskComplexity::RequiresTool;
524            }
525        }
526
527        // 问号数量
528        let question_marks = task.matches('?').count();
529        if question_marks > 1 {
530            return TaskComplexity::Complex;
531        }
532
533        TaskComplexity::Simple
534    }
535}
536
/// Complexity class assigned to a task by `AutoAgent`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TaskComplexity {
    /// Simple task - answered directly by the LLM.
    Simple,
    /// The task appears to need a tool - `AutoAgent::run` executes it through
    /// the ReAct agent.
    RequiresTool,
    /// Complex task - executed through the ReAct agent.
    Complex,
}
547
/// Execution mode that produced an `AutoAgentResult`.
///
/// Implements `PartialEq`/`Eq` so callers can compare the reported mode
/// directly (for example in assertions).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionMode {
    /// The LLM answered directly.
    Direct,
    /// The task was run in ReAct mode.
    ReAct,
}
556
557/// AutoAgent 执行结果
558#[derive(Debug, Clone)]
559pub struct AutoAgentResult {
560    /// 执行模式
561    pub mode: ExecutionMode,
562    /// 答案
563    pub answer: String,
564    /// ReAct 结果 (如果使用 ReAct 模式)
565    pub react_result: Option<ReActResult>,
566    /// 耗时 (毫秒)
567    pub duration_ms: u64,
568}