1use crate::llm::{LLMAgent, LLMError, LLMResult, Tool};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9#[derive(Debug, Clone, Serialize, Deserialize)]
11pub enum ReActStepType {
12 Thought,
14 Action,
16 Observation,
18 FinalAnswer,
20}
21
/// One recorded entry in a ReAct trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActStep {
    /// Which kind of step this is.
    pub step_type: ReActStepType,
    /// Step text. For steps built with `ReActStep::action` this is the rendered
    /// `Action: <tool>[<input>]` string.
    pub content: String,
    /// Name of the invoked tool; the constructors in this module set it only
    /// for `Action` steps. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_name: Option<String>,
    /// Raw input passed to the tool; the constructors in this module set it
    /// only for `Action` steps. Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tool_input: Option<String>,
    /// Position of this step within its run (`ReActAgent::run` starts at 1).
    pub step_number: usize,
    /// Creation time in milliseconds since the Unix epoch (0 if the system
    /// clock reads earlier than the epoch).
    pub timestamp: u64,
}
40
41impl ReActStep {
42 pub fn thought(content: impl Into<String>, step_number: usize) -> Self {
43 Self {
44 step_type: ReActStepType::Thought,
45 content: content.into(),
46 tool_name: None,
47 tool_input: None,
48 step_number,
49 timestamp: Self::current_timestamp(),
50 }
51 }
52
53 pub fn action(
54 tool_name: impl Into<String>,
55 tool_input: impl Into<String>,
56 step_number: usize,
57 ) -> Self {
58 let tool_name = tool_name.into();
59 let tool_input = tool_input.into();
60 Self {
61 step_type: ReActStepType::Action,
62 content: format!("Action: {}[{}]", tool_name, tool_input),
63 tool_name: Some(tool_name),
64 tool_input: Some(tool_input),
65 step_number,
66 timestamp: Self::current_timestamp(),
67 }
68 }
69
70 pub fn observation(content: impl Into<String>, step_number: usize) -> Self {
71 Self {
72 step_type: ReActStepType::Observation,
73 content: content.into(),
74 tool_name: None,
75 tool_input: None,
76 step_number,
77 timestamp: Self::current_timestamp(),
78 }
79 }
80
81 pub fn final_answer(content: impl Into<String>, step_number: usize) -> Self {
82 Self {
83 step_type: ReActStepType::FinalAnswer,
84 content: content.into(),
85 tool_name: None,
86 tool_input: None,
87 step_number,
88 timestamp: Self::current_timestamp(),
89 }
90 }
91
92 fn current_timestamp() -> u64 {
93 std::time::SystemTime::now()
94 .duration_since(std::time::UNIX_EPOCH)
95 .unwrap_or_default()
96 .as_millis() as u64
97 }
98}
99
/// Outcome of one `ReActAgent::run` call, including the full step trace.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActResult {
    /// Identifier of the run (`ReActAgent::run` generates a UUIDv7 string).
    pub task_id: String,
    /// The task text the run was started with.
    pub task: String,
    /// The final answer; empty when the run failed.
    pub answer: String,
    /// Every step recorded during the run, in order.
    pub steps: Vec<ReActStep>,
    /// `true` only when the model produced a final answer.
    pub success: bool,
    /// Failure reason; `None` on success and omitted from serialized output then.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Number of loop iterations (LLM calls) consumed.
    pub iterations: usize,
    /// Wall-clock duration of the run in milliseconds.
    pub duration_ms: u64,
}
121
122impl ReActResult {
123 pub fn success(
124 task_id: impl Into<String>,
125 task: impl Into<String>,
126 answer: impl Into<String>,
127 steps: Vec<ReActStep>,
128 iterations: usize,
129 duration_ms: u64,
130 ) -> Self {
131 Self {
132 task_id: task_id.into(),
133 task: task.into(),
134 answer: answer.into(),
135 steps,
136 success: true,
137 error: None,
138 iterations,
139 duration_ms,
140 }
141 }
142
143 pub fn failed(
144 task_id: impl Into<String>,
145 task: impl Into<String>,
146 error: impl Into<String>,
147 steps: Vec<ReActStep>,
148 iterations: usize,
149 duration_ms: u64,
150 ) -> Self {
151 Self {
152 task_id: task_id.into(),
153 task: task.into(),
154 answer: String::new(),
155 steps,
156 success: false,
157 error: Some(error.into()),
158 iterations,
159 duration_ms,
160 }
161 }
162}
163
164#[async_trait::async_trait]
168pub trait ReActTool: Send + Sync {
169 fn name(&self) -> &str;
171
172 fn description(&self) -> &str;
174
175 fn parameters_schema(&self) -> Option<serde_json::Value> {
177 None
178 }
179
180 async fn execute(&self, input: &str) -> Result<String, String>;
188
189 fn to_llm_tool(&self) -> Tool {
191 let params = self.parameters_schema().unwrap_or_else(|| {
192 serde_json::json!({
193 "type": "object",
194 "properties": {
195 "input": {
196 "type": "string",
197 "description": "The input for the tool"
198 }
199 },
200 "required": ["input"]
201 })
202 });
203
204 Tool::function(self.name(), self.description(), params)
205 }
206}
207
/// Tunables for a [`ReActAgent`] run loop.
#[derive(Debug, Clone)]
pub struct ReActConfig {
    /// Maximum number of LLM calls before a run is reported as failed.
    pub max_iterations: usize,
    /// Whether LLM output should be streamed.
    ///
    /// NOTE(review): `ReActAgent::run` in this module does not read this field;
    /// confirm whether it is honoured anywhere.
    pub stream_output: bool,
    /// Sampling temperature for the LLM.
    ///
    /// NOTE(review): `ReActAgent::run` calls `LLMAgent::ask` without passing
    /// this value; confirm whether it is applied elsewhere.
    pub temperature: f32,
    /// Custom system prompt. When set, it replaces the built-in prompt
    /// entirely, including the generated tool list.
    pub system_prompt: Option<String>,
    /// When `true`, each step is logged via `tracing::info!`.
    pub verbose: bool,
    /// Upper bound on tokens generated per LLM call (`None` = no explicit cap).
    ///
    /// NOTE(review): not currently forwarded by `ReActAgent::run`.
    pub max_tokens_per_step: Option<u32>,
}

impl Default for ReActConfig {
    /// 10 iterations, no streaming, temperature 0.7, built-in system prompt,
    /// verbose logging on, and a 2048-token per-step cap.
    fn default() -> Self {
        Self {
            max_iterations: 10,
            stream_output: false,
            temperature: 0.7,
            system_prompt: None,
            verbose: true,
            max_tokens_per_step: Some(2048),
        }
    }
}

/// Fluent setters; each consumes and returns the config.
impl ReActConfig {
    /// Same as [`ReActConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of LLM calls per run.
    pub fn with_max_iterations(mut self, max: usize) -> Self {
        self.max_iterations = max;
        self
    }

    /// Enables or disables streamed output.
    pub fn with_stream_output(mut self, enabled: bool) -> Self {
        self.stream_output = enabled;
        self
    }

    /// Sets the sampling temperature.
    pub fn with_temperature(mut self, temp: f32) -> Self {
        self.temperature = temp;
        self
    }

    /// Replaces the built-in system prompt with `prompt`.
    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.system_prompt = Some(prompt.into());
        self
    }

    /// Enables or disables per-step `tracing` logs.
    pub fn with_verbose(mut self, verbose: bool) -> Self {
        self.verbose = verbose;
        self
    }

    /// Sets the per-step token cap; pass `None` to remove it.
    ///
    /// Completes the builder API: previously this was the only field without a setter.
    pub fn with_max_tokens_per_step(mut self, max: Option<u32>) -> Self {
        self.max_tokens_per_step = max;
        self
    }
}
268
/// Agent that solves tasks by alternating LLM reasoning with tool calls (ReAct loop).
pub struct ReActAgent {
    /// LLM queried once per loop iteration.
    llm: Arc<LLMAgent>,
    /// Tool registry keyed by `ReActTool::name`; behind an async `RwLock` so
    /// tools can be registered through `&self`.
    tools: Arc<RwLock<HashMap<String, Arc<dyn ReActTool>>>>,
    /// Loop configuration.
    config: ReActConfig,
}
278
279impl ReActAgent {
280 pub fn builder() -> ReActAgentBuilder {
282 ReActAgentBuilder::new()
283 }
284
285 pub fn new(llm: Arc<LLMAgent>, config: ReActConfig) -> Self {
287 Self {
288 llm,
289 tools: Arc::new(RwLock::new(HashMap::new())),
290 config,
291 }
292 }
293
294 pub async fn register_tool(&self, tool: Arc<dyn ReActTool>) {
296 let mut tools = self.tools.write().await;
297 tools.insert(tool.name().to_string(), tool);
298 }
299
300 pub async fn get_tools(&self) -> Vec<Arc<dyn ReActTool>> {
302 let tools = self.tools.read().await;
303 tools.values().cloned().collect()
304 }
305
306 pub async fn run(&self, task: impl Into<String>) -> LLMResult<ReActResult> {
308 let task = task.into();
309 let task_id = uuid::Uuid::now_v7().to_string();
310 let start_time = std::time::Instant::now();
311
312 let mut steps = Vec::new();
313 let mut step_number = 0;
314
315 let system_prompt = self.build_system_prompt().await;
317
318 let mut conversation = vec![format!("Task: {}", task)];
320
321 for iteration in 0..self.config.max_iterations {
322 step_number += 1;
323
324 let prompt = self.build_prompt(&system_prompt, &conversation).await;
326 let response = self.llm.ask(&prompt).await?;
327
328 let parsed = self.parse_response(&response);
330
331 match parsed {
332 ParsedResponse::Thought(thought) => {
333 steps.push(ReActStep::thought(&thought, step_number));
334 conversation.push(format!("Thought: {}", thought));
335
336 if self.config.verbose {
337 tracing::info!("Thought: {}", thought);
338 }
339 }
340 ParsedResponse::Action { tool, input } => {
341 steps.push(ReActStep::action(&tool, &input, step_number));
342 conversation.push(format!("Action: {}[{}]", tool, input));
343
344 if self.config.verbose {
345 tracing::info!("Action: {}[{}]", tool, input);
346 }
347
348 step_number += 1;
350 let observation = self.execute_tool(&tool, &input).await;
351 steps.push(ReActStep::observation(&observation, step_number));
352 conversation.push(format!("Observation: {}", observation));
353
354 if self.config.verbose {
355 tracing::info!("Observation: {}", observation);
356 }
357 }
358 ParsedResponse::FinalAnswer(answer) => {
359 steps.push(ReActStep::final_answer(&answer, step_number));
360
361 if self.config.verbose {
362 tracing::info!("Final Answer: {}", answer);
363 }
364
365 return Ok(ReActResult::success(
366 task_id,
367 &task,
368 answer,
369 steps,
370 iteration + 1,
371 start_time.elapsed().as_millis() as u64,
372 ));
373 }
374 ParsedResponse::Error(err) => {
375 return Ok(ReActResult::failed(
376 task_id,
377 &task,
378 err,
379 steps,
380 iteration + 1,
381 start_time.elapsed().as_millis() as u64,
382 ));
383 }
384 }
385 }
386
387 Ok(ReActResult::failed(
389 task_id,
390 &task,
391 format!("Max iterations ({}) exceeded", self.config.max_iterations),
392 steps,
393 self.config.max_iterations,
394 start_time.elapsed().as_millis() as u64,
395 ))
396 }
397
398 async fn build_system_prompt(&self) -> String {
400 if let Some(ref custom_prompt) = self.config.system_prompt {
401 return custom_prompt.clone();
402 }
403
404 let tools = self.tools.read().await;
405 let tool_descriptions: Vec<String> = tools
406 .values()
407 .map(|t| format!("- {}: {}", t.name(), t.description()))
408 .collect();
409
410 format!(
411 r#"You are a ReAct (Reasoning and Acting) agent. You solve tasks by thinking step by step and using available tools.
412
413Available tools:
414{}
415
416You must respond in one of these formats:
417
4181. When you need to think:
419Thought: <your reasoning about what to do next>
420
4212. When you want to use a tool:
422Action: <tool_name>[<input>]
423
4243. When you have the final answer:
425Final Answer: <your final answer to the task>
426
427Rules:
428- Always start with a Thought
429- Use tools when you need external information
430- Be concise and focused
431- Provide a Final Answer when you have enough information
432- If a tool returns an error, think about alternatives"#,
433 tool_descriptions.join("\n")
434 )
435 }
436
437 async fn build_prompt(&self, system_prompt: &str, conversation: &[String]) -> String {
439 format!("{}\n\n{}", system_prompt, conversation.join("\n"))
440 }
441
442 fn parse_response(&self, response: &str) -> ParsedResponse {
444 let response = response.trim();
445
446 if let Some(answer) = response.strip_prefix("Final Answer:") {
448 return ParsedResponse::FinalAnswer(answer.trim().to_string());
449 }
450
451 if let Some(action_part) = response.strip_prefix("Action:") {
453 let action_part = action_part.trim();
454 if let Some(bracket_start) = action_part.find('[')
455 && let Some(bracket_end) = action_part.rfind(']')
456 {
457 let tool = action_part[..bracket_start].trim().to_string();
458 let input = action_part[bracket_start + 1..bracket_end]
459 .trim()
460 .to_string();
461 return ParsedResponse::Action { tool, input };
462 }
463 return ParsedResponse::Error(format!("Invalid action format: {}", action_part));
464 }
465
466 if let Some(thought) = response.strip_prefix("Thought:") {
468 return ParsedResponse::Thought(thought.trim().to_string());
469 }
470
471 for line in response.lines() {
473 let line = line.trim();
474 if line.starts_with("Final Answer:") {
475 return ParsedResponse::FinalAnswer(
476 line.strip_prefix("Final Answer:")
477 .unwrap()
478 .trim()
479 .to_string(),
480 );
481 }
482 if line.starts_with("Action:") {
483 let action_part = line.strip_prefix("Action:").unwrap().trim();
484 if let Some(bracket_start) = action_part.find('[')
485 && let Some(bracket_end) = action_part.rfind(']')
486 {
487 let tool = action_part[..bracket_start].trim().to_string();
488 let input = action_part[bracket_start + 1..bracket_end]
489 .trim()
490 .to_string();
491 return ParsedResponse::Action { tool, input };
492 }
493 }
494 if line.starts_with("Thought:") {
495 return ParsedResponse::Thought(
496 line.strip_prefix("Thought:").unwrap().trim().to_string(),
497 );
498 }
499 }
500
501 ParsedResponse::Thought(response.to_string())
503 }
504
505 async fn execute_tool(&self, tool_name: &str, input: &str) -> String {
507 let tools = self.tools.read().await;
508
509 match tools.get(tool_name) {
510 Some(tool) => match tool.execute(input).await {
511 Ok(result) => result,
512 Err(e) => format!("Tool error: {}", e),
513 },
514 None => format!(
515 "Tool '{}' not found. Available tools: {:?}",
516 tool_name,
517 tools.keys().collect::<Vec<_>>()
518 ),
519 }
520 }
521}
522
/// One LLM response classified by `ReActAgent::parse_response`.
///
/// Derives `Debug`/`PartialEq` so parse results can be logged and asserted on.
#[derive(Debug, Clone, PartialEq, Eq)]
enum ParsedResponse {
    /// Reasoning text to append to the transcript.
    Thought(String),
    /// Request to run `tool` on the bracketed `input`.
    Action { tool: String, input: String },
    /// The model's final answer; ends the run.
    FinalAnswer(String),
    /// The response could not be interpreted; ends the run as a failure.
    Error(String),
}
530
/// Fluent builder for [`ReActAgent`]; an LLM must be supplied before building.
pub struct ReActAgentBuilder {
    /// LLM to drive the agent; required by `build`.
    llm: Option<Arc<LLMAgent>>,
    /// Tools to register, in insertion order (a later tool with the same name wins).
    tools: Vec<Arc<dyn ReActTool>>,
    /// Loop configuration, starting from `ReActConfig::default()`.
    config: ReActConfig,
}
537
538impl ReActAgentBuilder {
539 pub fn new() -> Self {
540 Self {
541 llm: None,
542 tools: Vec::new(),
543 config: ReActConfig::default(),
544 }
545 }
546
547 pub fn with_llm(mut self, llm: Arc<LLMAgent>) -> Self {
549 self.llm = Some(llm);
550 self
551 }
552
553 pub fn with_tool(mut self, tool: Arc<dyn ReActTool>) -> Self {
555 self.tools.push(tool);
556 self
557 }
558
559 pub fn with_tools(mut self, tools: Vec<Arc<dyn ReActTool>>) -> Self {
561 self.tools.extend(tools);
562 self
563 }
564
565 pub fn with_max_iterations(mut self, max: usize) -> Self {
567 self.config.max_iterations = max;
568 self
569 }
570
571 pub fn with_temperature(mut self, temp: f32) -> Self {
573 self.config.temperature = temp;
574 self
575 }
576
577 pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
579 self.config.system_prompt = Some(prompt.into());
580 self
581 }
582
583 pub fn with_verbose(mut self, verbose: bool) -> Self {
585 self.config.verbose = verbose;
586 self
587 }
588
589 pub fn with_config(mut self, config: ReActConfig) -> Self {
591 self.config = config;
592 self
593 }
594
595 pub fn build(self) -> LLMResult<ReActAgent> {
597 let llm = self
598 .llm
599 .ok_or_else(|| LLMError::ConfigError("LLM agent not set".to_string()))?;
600
601 let agent = ReActAgent::new(llm, self.config);
602
603 let tools = self.tools;
605 let agent_tools = agent.tools.clone();
606
607 tokio::spawn(async move {
608 let mut tool_map = agent_tools.write().await;
609 for tool in tools {
610 tool_map.insert(tool.name().to_string(), tool);
611 }
612 });
613
614 Ok(agent)
615 }
616
617 pub async fn build_async(self) -> LLMResult<ReActAgent> {
619 let llm = self
620 .llm
621 .ok_or_else(|| LLMError::ConfigError("LLM agent not set".to_string()))?;
622
623 let agent = ReActAgent::new(llm, self.config);
624
625 for tool in self.tools {
627 agent.register_tool(tool).await;
628 }
629
630 Ok(agent)
631 }
632}
633
634impl Default for ReActAgentBuilder {
635 fn default() -> Self {
636 Self::new()
637 }
638}
639
#[cfg(test)]
mod tests {
    use super::*;

    /// Each constructor tags the step with the matching kind; only actions
    /// carry a tool name.
    #[test]
    fn test_react_step_creation() {
        let thought = ReActStep::thought("I need to search for information", 1);
        let action = ReActStep::action("search", "capital of France", 2);
        let observation = ReActStep::observation("Paris is the capital of France", 3);
        let answer = ReActStep::final_answer("Paris", 4);

        assert!(matches!(thought.step_type, ReActStepType::Thought));
        assert!(matches!(action.step_type, ReActStepType::Action));
        assert_eq!(action.tool_name.as_deref(), Some("search"));
        assert!(matches!(observation.step_type, ReActStepType::Observation));
        assert!(matches!(answer.step_type, ReActStepType::FinalAnswer));
    }

    /// Chained setters overwrite the corresponding default fields.
    #[test]
    fn test_react_config() {
        let cfg = ReActConfig::new()
            .with_max_iterations(5)
            .with_temperature(0.5)
            .with_verbose(false);

        assert_eq!(cfg.max_iterations, 5);
        assert_eq!(cfg.temperature, 0.5);
        assert!(!cfg.verbose);
    }
}