1use super::core::{ReActAgent, ReActConfig, ReActResult, ReActStep, ReActTool};
6use crate::llm::{LLMAgent, LLMError, LLMResult};
7use ractor::{Actor, ActorProcessingErr, ActorRef};
8use std::fmt;
9use std::future::Future;
10use std::sync::Arc;
11use tokio::sync::{mpsc, oneshot};
12
13pub enum ReActMessage {
15 RunTask {
17 task: String,
18 reply: oneshot::Sender<LLMResult<ReActResult>>,
19 },
20 RunTaskStreaming {
22 task: String,
23 step_tx: mpsc::Sender<ReActStep>,
24 reply: oneshot::Sender<LLMResult<ReActResult>>,
25 },
26 RegisterTool { tool: Arc<dyn ReActTool> },
28 GetStatus {
30 reply: oneshot::Sender<ReActActorStatus>,
31 },
32 CancelTask,
34 Stop,
36}
37
38impl fmt::Debug for ReActMessage {
39 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
40 match self {
41 Self::RunTask { task, .. } => f.debug_struct("RunTask").field("task", task).finish(),
42 Self::RunTaskStreaming { task, .. } => f
43 .debug_struct("RunTaskStreaming")
44 .field("task", task)
45 .finish(),
46 Self::RegisterTool { tool } => f
47 .debug_struct("RegisterTool")
48 .field("tool_name", &tool.name())
49 .finish(),
50 Self::GetStatus { .. } => f.debug_struct("GetStatus").finish(),
51 Self::CancelTask => f.debug_struct("CancelTask").finish(),
52 Self::Stop => f.debug_struct("Stop").finish(),
53 }
54 }
55}
56
/// Point-in-time snapshot of a [`ReActActor`], returned for `GetStatus`.
#[derive(Clone, Debug)]
pub struct ReActActorStatus {
    /// Identifier reported by the actor; the handler fills it from the
    /// current task id (empty string when idle).
    pub id: String,
    /// Whether a task was marked as running when the snapshot was taken.
    pub is_running: bool,
    /// Number of tasks that finished successfully.
    pub completed_tasks: usize,
    /// Number of registered (or queued, if the agent is not built yet) tools.
    pub tool_count: usize,
    /// Id of the task in progress, if any.
    pub current_task_id: Option<String>,
}
71
72pub struct ReActActorState {
74 agent: Option<ReActAgent>,
76 llm: Option<Arc<LLMAgent>>,
78 config: ReActConfig,
80 pending_tools: Vec<Arc<dyn ReActTool>>,
82 is_running: bool,
84 completed_tasks: usize,
86 current_task_id: Option<String>,
88 #[allow(dead_code)]
90 cancelled: bool,
91}
92
93impl ReActActorState {
94 pub fn new(llm: Arc<LLMAgent>, config: ReActConfig) -> Self {
95 Self {
96 agent: None,
97 llm: Some(llm),
98 config,
99 pending_tools: Vec::new(),
100 is_running: false,
101 completed_tasks: 0,
102 current_task_id: None,
103 cancelled: false,
104 }
105 }
106
107 async fn ensure_agent(&mut self) -> LLMResult<&ReActAgent> {
109 if self.agent.is_none() {
110 let llm = self
111 .llm
112 .take()
113 .ok_or_else(|| LLMError::ConfigError("LLM already consumed".to_string()))?;
114
115 let agent = ReActAgent::new(llm, self.config.clone());
116
117 for tool in self.pending_tools.drain(..) {
119 agent.register_tool(tool).await;
120 }
121
122 self.agent = Some(agent);
123 }
124
125 self.agent
126 .as_ref()
127 .ok_or_else(|| LLMError::Other("Agent not initialized".to_string()))
128 }
129}
130
/// Actor that executes ReAct tasks for [`ReActActorRef`] callers.
///
/// It is a stateless marker type; all mutable data lives in
/// [`ReActActorState`].
#[derive(Default)]
pub struct ReActActor;

impl ReActActor {
    /// Creates the actor marker value.
    pub fn new() -> Self {
        ReActActor
    }
}
146
147impl Actor for ReActActor {
148 type Msg = ReActMessage;
149 type State = ReActActorState;
150 type Arguments = (Arc<LLMAgent>, ReActConfig, Vec<Arc<dyn ReActTool>>);
151
152 fn pre_start(
153 &self,
154 _myself: ActorRef<Self::Msg>,
155 args: Self::Arguments,
156 ) -> impl Future<Output = Result<Self::State, ActorProcessingErr>> + Send {
157 async move {
158 let (llm, config, tools) = args;
159 let mut state = ReActActorState::new(llm, config);
160 state.pending_tools = tools;
161 Ok(state)
162 }
163 }
164
165 fn handle(
166 &self,
167 myself: ActorRef<Self::Msg>,
168 message: Self::Msg,
169 state: &mut Self::State,
170 ) -> impl Future<Output = Result<(), ActorProcessingErr>> + Send {
171 handle_message(myself, message, state)
174 }
175}
176
177async fn handle_message(
179 myself: ActorRef<ReActMessage>,
180 message: ReActMessage,
181 state: &mut ReActActorState,
182) -> Result<(), ActorProcessingErr> {
183 match message {
184 ReActMessage::RunTask { task, reply } => {
185 if state.is_running {
186 let _ = reply.send(Err(LLMError::Other(
187 "Agent is already running a task".to_string(),
188 )));
189 return Ok(());
190 }
191
192 state.is_running = true;
193 state.cancelled = false;
194 state.current_task_id = Some(uuid::Uuid::now_v7().to_string());
195
196 let result = match state.ensure_agent().await {
197 Ok(agent) => agent.run(&task).await,
198 Err(e) => Err(e),
199 };
200
201 state.is_running = false;
202 state.current_task_id = None;
203
204 if result.is_ok() {
205 state.completed_tasks += 1;
206 }
207
208 let _ = reply.send(result);
209 }
210
211 ReActMessage::RunTaskStreaming {
212 task,
213 step_tx,
214 reply,
215 } => {
216 if state.is_running {
217 let _ = reply.send(Err(LLMError::Other(
218 "Agent is already running a task".to_string(),
219 )));
220 return Ok(());
221 }
222
223 state.is_running = true;
224 state.cancelled = false;
225 let task_id = uuid::Uuid::now_v7().to_string();
226 state.current_task_id = Some(task_id.clone());
227
228 let result = match state.ensure_agent().await {
230 Ok(agent) => {
231 let result = agent.run(&task).await;
233
234 if let Ok(ref res) = result {
236 for step in &res.steps {
237 let _ = step_tx.send(step.clone()).await;
238 }
239 }
240
241 result
242 }
243 Err(e) => Err(e),
244 };
245
246 state.is_running = false;
247 state.current_task_id = None;
248
249 if result.is_ok() {
250 state.completed_tasks += 1;
251 }
252
253 let _ = reply.send(result);
254 }
255
256 ReActMessage::RegisterTool { tool } => {
257 if let Some(ref agent) = state.agent {
258 agent.register_tool(tool).await;
259 } else {
260 state.pending_tools.push(tool);
261 }
262 }
263
264 ReActMessage::GetStatus { reply } => {
265 let tool_count = if let Some(ref agent) = state.agent {
266 agent.get_tools().await.len()
267 } else {
268 state.pending_tools.len()
269 };
270
271 let status = ReActActorStatus {
272 id: state.current_task_id.clone().unwrap_or_default(),
273 is_running: state.is_running,
274 completed_tasks: state.completed_tasks,
275 tool_count,
276 current_task_id: state.current_task_id.clone(),
277 };
278
279 let _ = reply.send(status);
280 }
281
282 ReActMessage::CancelTask => {
283 state.cancelled = true;
284 }
285
286 ReActMessage::Stop => {
287 myself.stop(Some("Stop requested".to_string()));
288 }
289 }
290
291 Ok(())
292}
293
294pub struct ReActActorRef {
298 actor: ActorRef<ReActMessage>,
299}
300
301impl ReActActorRef {
302 pub fn new(actor: ActorRef<ReActMessage>) -> Self {
304 Self { actor }
305 }
306
307 pub async fn run_task(&self, task: impl Into<String>) -> LLMResult<ReActResult> {
309 let (tx, rx) = oneshot::channel();
310 self.actor
311 .send_message(ReActMessage::RunTask {
312 task: task.into(),
313 reply: tx,
314 })
315 .map_err(|e| LLMError::Other(format!("Failed to send message: {}", e)))?;
316
317 rx.await
318 .map_err(|e| LLMError::Other(format!("Failed to receive response: {}", e)))?
319 }
320
321 pub async fn run_task_streaming(
323 &self,
324 task: impl Into<String>,
325 ) -> LLMResult<(
326 mpsc::Receiver<ReActStep>,
327 oneshot::Receiver<LLMResult<ReActResult>>,
328 )> {
329 let (step_tx, step_rx) = mpsc::channel(100);
330 let (result_tx, result_rx) = oneshot::channel();
331
332 self.actor
333 .send_message(ReActMessage::RunTaskStreaming {
334 task: task.into(),
335 step_tx,
336 reply: result_tx,
337 })
338 .map_err(|e| LLMError::Other(format!("Failed to send message: {}", e)))?;
339
340 Ok((step_rx, result_rx))
341 }
342
343 pub fn register_tool(&self, tool: Arc<dyn ReActTool>) -> LLMResult<()> {
345 self.actor
346 .send_message(ReActMessage::RegisterTool { tool })
347 .map_err(|e| LLMError::Other(format!("Failed to register tool: {}", e)))
348 }
349
350 pub async fn get_status(&self) -> LLMResult<ReActActorStatus> {
352 let (tx, rx) = oneshot::channel();
353 self.actor
354 .send_message(ReActMessage::GetStatus { reply: tx })
355 .map_err(|e| LLMError::Other(format!("Failed to send message: {}", e)))?;
356
357 rx.await
358 .map_err(|e| LLMError::Other(format!("Failed to receive status: {}", e)))
359 }
360
361 pub fn cancel_task(&self) -> LLMResult<()> {
363 self.actor
364 .send_message(ReActMessage::CancelTask)
365 .map_err(|e| LLMError::Other(format!("Failed to cancel task: {}", e)))
366 }
367
368 pub fn stop(&self) -> LLMResult<()> {
370 self.actor
371 .send_message(ReActMessage::Stop)
372 .map_err(|e| LLMError::Other(format!("Failed to stop actor: {}", e)))
373 }
374
375 pub fn inner(&self) -> &ActorRef<ReActMessage> {
377 &self.actor
378 }
379}
380
381pub async fn spawn_react_actor(
396 name: impl Into<String>,
397 llm: Arc<LLMAgent>,
398 config: ReActConfig,
399 tools: Vec<Arc<dyn ReActTool>>,
400) -> LLMResult<(ReActActorRef, tokio::task::JoinHandle<()>)> {
401 let (actor_ref, handle) =
402 Actor::spawn(Some(name.into()), ReActActor::new(), (llm, config, tools))
403 .await
404 .map_err(|e| LLMError::Other(format!("Failed to spawn actor: {}", e)))?;
405
406 Ok((ReActActorRef::new(actor_ref), handle))
407}
408
409pub struct AutoAgent {
417 react_agent: Arc<ReActAgent>,
419 llm: Arc<LLMAgent>,
421 auto_mode: bool,
423}
424
425impl AutoAgent {
426 pub fn new(llm: Arc<LLMAgent>, react_agent: Arc<ReActAgent>) -> Self {
428 Self {
429 react_agent,
430 llm,
431 auto_mode: true,
432 }
433 }
434
435 pub fn with_auto_mode(mut self, enabled: bool) -> Self {
437 self.auto_mode = enabled;
438 self
439 }
440
441 pub async fn run(&self, task: impl Into<String>) -> LLMResult<AutoAgentResult> {
443 let task = task.into();
444 let start = std::time::Instant::now();
445
446 if !self.auto_mode {
447 let result = self.react_agent.run(&task).await?;
449 let answer = result.answer.clone();
450 return Ok(AutoAgentResult {
451 mode: ExecutionMode::ReAct,
452 answer,
453 react_result: Some(result),
454 duration_ms: start.elapsed().as_millis() as u64,
455 });
456 }
457
458 let complexity = self.analyze_complexity(&task).await;
460
461 match complexity {
462 TaskComplexity::Simple => {
463 let answer = self.llm.ask(&task).await?;
465 Ok(AutoAgentResult {
466 mode: ExecutionMode::Direct,
467 answer,
468 react_result: None,
469 duration_ms: start.elapsed().as_millis() as u64,
470 })
471 }
472 TaskComplexity::RequiresTool | TaskComplexity::Complex => {
473 let result = self.react_agent.run(&task).await?;
475 let answer = result.answer.clone();
476 Ok(AutoAgentResult {
477 mode: ExecutionMode::ReAct,
478 answer,
479 react_result: Some(result),
480 duration_ms: start.elapsed().as_millis() as u64,
481 })
482 }
483 }
484 }
485
486 async fn analyze_complexity(&self, task: &str) -> TaskComplexity {
488 let task_lower = task.to_lowercase();
490
491 let tool_keywords = [
493 "search",
494 "find",
495 "lookup",
496 "calculate",
497 "compute",
498 "weather",
499 "current",
500 "latest",
501 "today",
502 "now",
503 ];
504
505 let complex_keywords = [
507 "analyze",
508 "compare",
509 "research",
510 "investigate",
511 "step by step",
512 "explain in detail",
513 ];
514
515 for keyword in complex_keywords {
516 if task_lower.contains(keyword) {
517 return TaskComplexity::Complex;
518 }
519 }
520
521 for keyword in tool_keywords {
522 if task_lower.contains(keyword) {
523 return TaskComplexity::RequiresTool;
524 }
525 }
526
527 let question_marks = task.matches('?').count();
529 if question_marks > 1 {
530 return TaskComplexity::Complex;
531 }
532
533 TaskComplexity::Simple
534 }
535}
536
/// Classification produced by `AutoAgent`'s routing heuristic.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum TaskComplexity {
    /// No routing keyword found; answered by a direct LLM call.
    Simple,
    /// Mentions something that likely needs a tool (search, calculation,
    /// up-to-date data).
    RequiresTool,
    /// Asks for analysis or multi-step work, or contains several questions.
    Complex,
}
547
/// Execution path taken by `AutoAgent::run`.
///
/// Derives `PartialEq`/`Eq` like `TaskComplexity`, so callers can write
/// `result.mode == ExecutionMode::ReAct` instead of pattern matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionMode {
    /// Answered with a single direct LLM call; no ReAct trace is produced.
    Direct,
    /// Ran the ReAct loop; the full result is kept in the run's output.
    ReAct,
}
556
557#[derive(Debug, Clone)]
559pub struct AutoAgentResult {
560 pub mode: ExecutionMode,
562 pub answer: String,
564 pub react_result: Option<ReActResult>,
566 pub duration_ms: u64,
568}