1use crate::agent::executor::AgentExecutor;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, Context, ExecutorConfig, TurnResult};
4use crate::protocol::{Event, StreamingTurnResult, SubmissionId};
5use crate::tool::{ToolCallResult, ToolT, to_llm_tool};
6use async_trait::async_trait;
7use autoagents_llm::ToolCall;
8use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChunk, Tool};
9use autoagents_llm::error::LLMError;
10use futures::{Stream, StreamExt};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashSet;
14use std::ops::Deref;
15use std::pin::Pin;
16use std::sync::Arc;
17use thiserror::Error;
18
19#[cfg(not(target_arch = "wasm32"))]
20pub use tokio::sync::mpsc::error::SendError;
21
22#[cfg(target_arch = "wasm32")]
23use futures::SinkExt;
24#[cfg(target_arch = "wasm32")]
25pub use futures::lock::Mutex;
26#[cfg(target_arch = "wasm32")]
27type SendError = futures::channel::mpsc::SendError;
28
29use crate::agent::executor::event_helper::EventHelper;
30use crate::agent::executor::memory_helper::MemoryHelper;
31use crate::agent::executor::tool_processor::ToolProcessor;
32use crate::agent::hooks::{AgentHooks, HookOutcome};
33use crate::channel::{Sender, channel};
34use crate::utils::{receiver_into_stream, spawn_future};
35
/// Output of a [`ReActAgent`] run.
///
/// The same type is used for the final result and for the incremental updates sent
/// by the streaming executor.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActAgentOutput {
    /// Text produced by the LLM. A streaming text update carries only the latest chunk;
    /// a streaming tool-progress update carries an empty string.
    pub response: String,
    /// Results of the tool calls executed so far (accumulated across turns in the
    /// final output and in streaming tool-progress updates).
    pub tool_calls: Vec<ToolCallResult>,
    /// `false` for intermediate streaming updates; `true` on the final output.
    pub done: bool,
}
43
44impl From<ReActAgentOutput> for Value {
45 fn from(output: ReActAgentOutput) -> Self {
46 serde_json::to_value(output).unwrap_or(Value::Null)
47 }
48}
49impl From<ReActAgentOutput> for String {
50 fn from(output: ReActAgentOutput) -> Self {
51 output.response
52 }
53}
54
55impl ReActAgentOutput {
56 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
59 serde_json::from_str::<T>(&self.response)
60 }
61
62 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
66 where
67 T: for<'de> serde::Deserialize<'de>,
68 F: FnOnce(&str) -> T,
69 {
70 self.try_parse::<T>()
71 .unwrap_or_else(|_| fallback(&self.response))
72 }
73}
74
75impl ReActAgentOutput {
76 #[allow(clippy::result_large_err)]
78 pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
79 where
80 T: for<'de> serde::Deserialize<'de>,
81 {
82 let react_output: Self = serde_json::from_value(val)
83 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
84 serde_json::from_str(&react_output.response)
85 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
86 }
87}
88
/// Errors produced by the [`ReActAgent`] executor.
#[derive(Error, Debug)]
pub enum ReActExecutorError {
    /// An LLM request or stream chunk failed; holds the provider error's text.
    #[error("LLM error: {0}")]
    LLMError(String),

    /// `max_turns` turns ran without yielding any response text or tool results.
    #[error("Maximum turns exceeded: {max_turns}")]
    MaxTurnsExceeded { max_turns: usize },

    /// Free-form error message (not constructed in this file).
    #[error("Other error: {0}")]
    Other(String),

    /// Sending an [`Event`] over the native (non-wasm) channel failed.
    #[cfg(not(target_arch = "wasm32"))]
    #[error("Event error: {0}")]
    EventError(#[from] SendError<Event>),

    /// Sending an event over the wasm channel failed.
    #[cfg(target_arch = "wasm32")]
    #[error("Event error: {0}")]
    EventError(#[from] SendError),

    /// The serialized agent output, or the JSON in its `response`, could not be
    /// deserialized (see `ReActAgentOutput::extract_agent_output`).
    #[error("Extracting Agent Output Error: {0}")]
    AgentOutputError(String),
}
111
/// ReAct-style executor around an agent definition `T`.
///
/// Each turn queries the LLM. If it requests tools, they are run and the loop
/// continues; a turn without tool calls ends the run.
/// `T` is held behind an [`Arc`], so clones are cheap and share the same definition
/// (the streaming executor clones `self` into a spawned task).
#[derive(Debug)]
pub struct ReActAgent<T: AgentDeriveT> {
    inner: Arc<T>,
}
120
121impl<T: AgentDeriveT> Clone for ReActAgent<T> {
122 fn clone(&self) -> Self {
123 Self {
124 inner: Arc::clone(&self.inner),
125 }
126 }
127}
128
129impl<T: AgentDeriveT> ReActAgent<T> {
130 pub fn new(inner: T) -> Self {
131 Self {
132 inner: Arc::new(inner),
133 }
134 }
135}
136
137impl<T: AgentDeriveT> Deref for ReActAgent<T> {
138 type Target = T;
139
140 fn deref(&self) -> &Self::Target {
141 &self.inner
142 }
143}
144
145#[async_trait]
147impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
148 type Output = <T as AgentDeriveT>::Output;
149
150 fn description(&self) -> &'static str {
151 self.inner.description()
152 }
153
154 fn output_schema(&self) -> Option<Value> {
155 self.inner.output_schema()
156 }
157
158 fn name(&self) -> &'static str {
159 self.inner.name()
160 }
161
162 fn tools(&self) -> Vec<Box<dyn ToolT>> {
163 self.inner.tools()
164 }
165}
166
167#[async_trait]
168impl<T> AgentHooks for ReActAgent<T>
169where
170 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
171{
172 async fn on_agent_create(&self) {
173 self.inner.on_agent_create().await
174 }
175
176 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
177 self.inner.on_run_start(task, ctx).await
178 }
179
180 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
181 self.inner.on_run_complete(task, result, ctx).await
182 }
183
184 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
185 self.inner.on_turn_start(turn_index, ctx).await
186 }
187
188 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
189 self.inner.on_turn_complete(turn_index, ctx).await
190 }
191
192 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
193 self.inner.on_tool_call(tool_call, ctx).await
194 }
195
196 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
197 self.inner.on_tool_start(tool_call, ctx).await
198 }
199
200 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
201 self.inner.on_tool_result(tool_call, result, ctx).await
202 }
203
204 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
205 self.inner.on_tool_error(tool_call, err, ctx).await
206 }
207 async fn on_agent_shutdown(&self) {
208 self.inner.on_agent_shutdown().await
209 }
210}
211
212impl<T: AgentDeriveT + AgentHooks> ReActAgent<T> {
213 async fn process_turn(
215 &self,
216 context: &Context,
217 tools: &[Box<dyn ToolT>],
218 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
219 let messages = self.prepare_messages(context).await;
220 let response = self.get_llm_response(context, &messages, tools).await?;
221 let response_text = response.text().unwrap_or_default();
222
223 if let Some(tool_calls) = response.tool_calls() {
224 self.handle_tool_calls(context, tools, tool_calls.clone(), response_text)
225 .await
226 } else {
227 self.handle_text_response(context, response_text).await
228 }
229 }
230
231 async fn get_llm_response(
233 &self,
234 context: &Context,
235 messages: &[ChatMessage],
236 tools: &[Box<dyn ToolT>],
237 ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, ReActExecutorError> {
238 let llm = context.llm();
239 let agent_config = context.config();
240 let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
241
242 if tools.is_empty() {
243 llm.chat(messages, agent_config.output_schema.clone())
244 .await
245 .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
246 } else {
247 llm.chat_with_tools(
248 messages,
249 Some(&tools_serialized),
250 agent_config.output_schema.clone(),
251 )
252 .await
253 .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
254 }
255 }
256
257 async fn handle_tool_calls(
259 &self,
260 context: &Context,
261 tools: &[Box<dyn ToolT>],
262 tool_calls: Vec<ToolCall>,
263 response_text: String,
264 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
265 let tx_event = context.tx().ok();
266
267 let mut tool_results = Vec::new();
269 for call in &tool_calls {
270 if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
271 self, context, tools, call, &tx_event,
272 )
273 .await
274 {
275 tool_results.push(result);
276 }
277 }
278
279 MemoryHelper::store_tool_interaction(
281 &context.memory(),
282 &tool_calls,
283 &tool_results,
284 &response_text,
285 )
286 .await;
287
288 {
290 let state = context.state();
291 #[cfg(not(target_arch = "wasm32"))]
292 if let Ok(mut guard) = state.try_lock() {
293 for result in &tool_results {
294 guard.record_tool_call(result.clone());
295 }
296 };
297 #[cfg(target_arch = "wasm32")]
298 if let Some(mut guard) = state.try_lock() {
299 for result in &tool_results {
300 guard.record_tool_call(result.clone());
301 }
302 };
303 }
304
305 Ok(TurnResult::Continue(Some(ReActAgentOutput {
306 response: response_text,
307 done: true,
308 tool_calls: tool_results,
309 })))
310 }
311
312 async fn handle_text_response(
314 &self,
315 context: &Context,
316 response_text: String,
317 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
318 if !response_text.is_empty() {
319 MemoryHelper::store_assistant_response(&context.memory(), response_text.clone()).await;
320 }
321
322 Ok(TurnResult::Complete(ReActAgentOutput {
323 response: response_text,
324 done: true,
325 tool_calls: vec![],
326 }))
327 }
328
329 async fn prepare_messages(&self, context: &Context) -> Vec<ChatMessage> {
331 let mut messages = vec![ChatMessage {
332 role: ChatRole::System,
333 message_type: MessageType::Text,
334 content: context.config().description.clone(),
335 }];
336
337 let recalled = MemoryHelper::recall_messages(&context.memory()).await;
338 messages.extend(recalled);
339
340 messages
341 }
342
343 async fn process_streaming_turn(
345 &self,
346 context: &Context,
347 tools: &[Box<dyn ToolT>],
348 tx: &mut Sender<Result<ReActAgentOutput, ReActExecutorError>>,
349 submission_id: SubmissionId,
350 ) -> Result<StreamingTurnResult, ReActExecutorError> {
351 let messages = self.prepare_messages(context).await;
352 let mut stream = self.get_llm_stream(context, &messages, tools).await?;
353
354 let mut response_text = String::new();
355 let mut tool_calls: Vec<ToolCall> = Vec::new();
356 let mut tool_call_ids: HashSet<String> = HashSet::new();
357
358 while let Some(chunk_result) = stream.next().await {
360 let chunk: StreamChunk =
361 chunk_result.map_err(|e| ReActExecutorError::LLMError(e.to_string()))?;
362 let chunk_clone = chunk.clone();
363
364 match chunk {
365 StreamChunk::Text(content) => {
366 response_text.push_str(&content);
367 let _ = tx
368 .send(Ok(ReActAgentOutput {
369 response: content.to_string(),
370 tool_calls: vec![],
371 done: false,
372 }))
373 .await;
374 }
375 StreamChunk::ToolUseComplete {
376 index: _,
377 tool_call,
378 } => {
379 if tool_call_ids.insert(tool_call.id.clone()) {
380 tool_calls.push(tool_call.clone());
381
382 let tx_event = context.tx().ok();
383 EventHelper::send_stream_tool_call(
384 &tx_event,
385 submission_id,
386 serde_json::to_value(tool_call).unwrap_or(Value::Null),
387 )
388 .await;
389 }
390 }
391 StreamChunk::Usage(_) => {
392 }
394 _ => {
395 }
397 }
398 let tx_event = context.tx().ok();
400 EventHelper::send_stream_chunk(&tx_event, submission_id, chunk_clone).await;
401 }
402
403 if tool_calls.is_empty() {
405 if !response_text.is_empty() {
406 MemoryHelper::store_assistant_response(&context.memory(), response_text.clone())
407 .await;
408 }
409 return Ok(StreamingTurnResult::Complete(response_text));
410 }
411
412 let tx_event = context.tx().ok();
413 let tool_results =
414 ToolProcessor::process_tool_calls(tools, tool_calls.clone(), tx_event).await;
415
416 MemoryHelper::store_tool_interaction(
417 &context.memory(),
418 &tool_calls,
419 &tool_results,
420 &response_text,
421 )
422 .await;
423
424 let state = context.state();
425 let mut guard = state.lock().await;
426 for result in &tool_results {
427 guard.record_tool_call(result.clone());
428 }
429
430 Ok(StreamingTurnResult::ToolCallsProcessed(tool_results))
431 }
432
433 async fn get_llm_stream(
435 &self,
436 context: &Context,
437 messages: &[ChatMessage],
438 tools: &[Box<dyn ToolT>],
439 ) -> Result<
440 Pin<Box<dyn Stream<Item = Result<autoagents_llm::chat::StreamChunk, LLMError>> + Send>>,
441 ReActExecutorError,
442 > {
443 let llm = context.llm();
444 let agent_config = context.config();
445 let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
446
447 llm.chat_stream_with_tools(
448 messages,
449 if tools.is_empty() {
450 None
451 } else {
452 Some(&tools_serialized)
453 },
454 agent_config.output_schema.clone(),
455 )
456 .await
457 .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
458 }
459}
460
461#[async_trait]
463impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
464 type Output = ReActAgentOutput;
465 type Error = ReActExecutorError;
466
467 fn config(&self) -> ExecutorConfig {
468 ExecutorConfig { max_turns: 10 }
469 }
470
471 async fn execute(
472 &self,
473 task: &Task,
474 context: Arc<Context>,
475 ) -> Result<Self::Output, Self::Error> {
476 MemoryHelper::store_user_message(
478 &context.memory(),
479 task.prompt.clone(),
480 task.image.clone(),
481 )
482 .await;
483
484 {
486 let state = context.state();
487 #[cfg(not(target_arch = "wasm32"))]
488 if let Ok(mut guard) = state.try_lock() {
489 guard.record_task(task.clone());
490 };
491 #[cfg(target_arch = "wasm32")]
492 if let Some(mut guard) = state.try_lock() {
493 guard.record_task(task.clone());
494 };
495 }
496
497 let tx_event = context.tx().ok();
499 EventHelper::send_task_started(
500 &tx_event,
501 task.submission_id,
502 context.config().id,
503 task.prompt.clone(),
504 context.config().name.clone(),
505 )
506 .await;
507
508 let max_turns = self.config().max_turns;
510 let mut accumulated_tool_calls = Vec::new();
511 let mut final_response = String::new();
512
513 for turn_num in 0..max_turns {
514 let tools = context.tools();
515 EventHelper::send_turn_started(&tx_event, turn_num, max_turns).await;
516
517 self.on_turn_start(turn_num, &context).await;
519
520 match self.process_turn(&context, tools).await? {
521 TurnResult::Complete(result) => {
522 if !accumulated_tool_calls.is_empty() {
523 return Ok(ReActAgentOutput {
524 response: result.response,
525 done: true,
526 tool_calls: accumulated_tool_calls,
527 });
528 }
529 EventHelper::send_turn_completed(&tx_event, turn_num, false).await;
530 self.on_turn_complete(turn_num, &context).await;
532 return Ok(result);
533 }
534 TurnResult::Continue(Some(partial_result)) => {
535 accumulated_tool_calls.extend(partial_result.tool_calls);
536 if !partial_result.response.is_empty() {
537 final_response = partial_result.response;
538 }
539 }
540 TurnResult::Continue(None) => continue,
541 }
542 }
543
544 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
545 EventHelper::send_task_completed(
546 &tx_event,
547 task.submission_id,
548 context.config().id,
549 final_response.clone(),
550 context.config().name.clone(),
551 )
552 .await;
553 Ok(ReActAgentOutput {
554 response: final_response,
555 done: true,
556 tool_calls: accumulated_tool_calls,
557 })
558 } else {
559 Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
560 }
561 }
562
563 async fn execute_stream(
564 &self,
565 task: &Task,
566 context: Arc<Context>,
567 ) -> Result<
568 Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
569 Self::Error,
570 > {
571 MemoryHelper::store_user_message(
573 &context.memory(),
574 task.prompt.clone(),
575 task.image.clone(),
576 )
577 .await;
578
579 {
581 let state = context.state();
582 #[cfg(not(target_arch = "wasm32"))]
583 if let Ok(mut guard) = state.try_lock() {
584 guard.record_task(task.clone());
585 };
586 #[cfg(target_arch = "wasm32")]
587 if let Some(mut guard) = state.try_lock() {
588 guard.record_task(task.clone());
589 };
590 }
591
592 let tx_event = context.tx().ok();
594 EventHelper::send_task_started(
595 &tx_event,
596 task.submission_id,
597 context.config().id,
598 task.prompt.clone(),
599 context.config().name.clone(),
600 )
601 .await;
602
603 let (mut tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
605
606 let executor = self.clone();
608 let context_clone = context.clone();
609 let submission_id = task.submission_id;
610 let max_turns = executor.config().max_turns;
611
612 spawn_future(async move {
614 let mut accumulated_tool_calls = Vec::new();
615 let mut final_response = String::new();
616 let tools = context_clone.tools();
617
618 for turn in 0..max_turns {
619 let tx_event = context_clone.tx().ok();
621 EventHelper::send_turn_started(&tx_event, turn, max_turns).await;
622
623 match executor
625 .process_streaming_turn(&context_clone, tools, &mut tx, submission_id)
626 .await
627 {
628 Ok(StreamingTurnResult::Complete(response)) => {
629 final_response = response;
630 EventHelper::send_turn_completed(&tx_event, turn, true).await;
631 break;
632 }
633 Ok(StreamingTurnResult::ToolCallsProcessed(tool_results)) => {
634 accumulated_tool_calls.extend(tool_results);
635
636 let _ = tx
637 .send(Ok(ReActAgentOutput {
638 response: String::new(),
639 done: false,
640 tool_calls: accumulated_tool_calls.clone(),
641 }))
642 .await;
643
644 EventHelper::send_turn_completed(&tx_event, turn, false).await;
645 }
646 Err(e) => {
647 let _ = tx.send(Err(e)).await;
648 return;
649 }
650 }
651 }
652
653 let tx_event = context_clone.tx().ok();
655 EventHelper::send_stream_complete(&tx_event, submission_id).await;
656
657 let _ = tx
658 .send(Ok(ReActAgentOutput {
659 response: final_response,
660 done: true,
661 tool_calls: accumulated_tool_calls,
662 }))
663 .await;
664 });
665
666 Ok(receiver_into_stream(rx))
667 }
668}
669
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// Builds a finished output with the given response and no tool calls.
    fn output_with_response(response: &str) -> ReActAgentOutput {
        ReActAgentOutput {
            response: response.to_string(),
            done: true,
            tool_calls: vec![],
        }
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = TestAgentOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_output = ReActAgentOutput {
            response: serde_json::to_string(&agent_output).unwrap(),
            done: true,
            tool_calls: vec![],
        };

        let react_value = serde_json::to_value(react_output).unwrap();
        let extracted: TestAgentOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_rejects_non_json_response() {
        let react_value = serde_json::to_value(output_with_response("not json")).unwrap();
        let result = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(react_value);
        assert!(matches!(result, Err(ReActExecutorError::AgentOutputError(_))));
    }

    #[test]
    fn test_extract_agent_output_rejects_malformed_envelope() {
        let result = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(
            serde_json::json!({ "unexpected": true }),
        );
        assert!(matches!(result, Err(ReActExecutorError::AgentOutputError(_))));
    }

    #[test]
    fn test_try_parse_and_parse_or_map() {
        let structured = output_with_response(r#"{"value":7,"message":"ok"}"#);
        let parsed: TestAgentOutput = structured.try_parse().unwrap();
        assert_eq!(
            parsed,
            TestAgentOutput {
                value: 7,
                message: "ok".to_string(),
            }
        );

        let plain = output_with_response("plain text");
        assert!(plain.try_parse::<TestAgentOutput>().is_err());
        let mapped = plain.parse_or_map(|raw| TestAgentOutput {
            value: -1,
            message: raw.to_string(),
        });
        assert_eq!(
            mapped,
            TestAgentOutput {
                value: -1,
                message: "plain text".to_string(),
            }
        );
    }

    #[test]
    fn test_conversions() {
        let output = output_with_response("hello");

        let json = Value::from(output.clone());
        assert_eq!(json["response"], Value::String("hello".to_string()));
        assert_eq!(json["done"], Value::Bool(true));

        assert_eq!(String::from(output), "hello");
    }
}