1use crate::agent::executor::AgentExecutor;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, Context, ExecutorConfig, TurnResult};
4use crate::protocol::{Event, StreamingTurnResult, SubmissionId};
5use crate::tool::{ToolCallResult, ToolT};
6use async_trait::async_trait;
7use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChoice, Tool};
8use autoagents_llm::error::LLMError;
9use autoagents_llm::{FunctionCall, ToolCall};
10use futures::{Stream, StreamExt};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::ops::Deref;
15use std::pin::Pin;
16use std::sync::Arc;
17use thiserror::Error;
18
19#[cfg(not(target_arch = "wasm32"))]
20pub use tokio::sync::mpsc::error::SendError;
21
22#[cfg(target_arch = "wasm32")]
23pub use futures::lock::Mutex;
24#[cfg(target_arch = "wasm32")]
25use futures::SinkExt;
26#[cfg(target_arch = "wasm32")]
27type SendError = futures::channel::mpsc::SendError;
28
29use crate::agent::executor::event_helper::EventHelper;
30use crate::agent::executor::memory_helper::MemoryHelper;
31use crate::agent::executor::tool_processor::ToolProcessor;
32use crate::agent::hooks::{AgentHooks, HookOutcome};
33use crate::channel::{channel, Sender};
34use crate::utils::{receiver_into_stream, spawn_future};
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
38pub struct ReActAgentOutput {
39 pub response: String,
40 pub tool_calls: Vec<ToolCallResult>,
41 pub done: bool,
42}
43
44impl From<ReActAgentOutput> for Value {
45 fn from(output: ReActAgentOutput) -> Self {
46 serde_json::to_value(output).unwrap_or(Value::Null)
47 }
48}
49impl From<ReActAgentOutput> for String {
50 fn from(output: ReActAgentOutput) -> Self {
51 output.response
52 }
53}
54
55impl ReActAgentOutput {
56 #[allow(clippy::result_large_err)]
58 pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
59 where
60 T: for<'de> serde::Deserialize<'de>,
61 {
62 let react_output: Self = serde_json::from_value(val)
63 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
64 serde_json::from_str(&react_output.response)
65 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
66 }
67}
68
69#[derive(Error, Debug)]
70pub enum ReActExecutorError {
71 #[error("LLM error: {0}")]
72 LLMError(String),
73
74 #[error("Maximum turns exceeded: {max_turns}")]
75 MaxTurnsExceeded { max_turns: usize },
76
77 #[error("Other error: {0}")]
78 Other(String),
79
80 #[cfg(not(target_arch = "wasm32"))]
81 #[error("Event error: {0}")]
82 EventError(#[from] SendError<Event>),
83
84 #[cfg(target_arch = "wasm32")]
85 #[error("Event error: {0}")]
86 EventError(#[from] SendError),
87
88 #[error("Extracting Agent Output Error: {0}")]
89 AgentOutputError(String),
90}
91
92#[derive(Debug)]
94pub struct ReActAgent<T: AgentDeriveT> {
95 inner: Arc<T>,
96}
97
98impl<T: AgentDeriveT> Clone for ReActAgent<T> {
99 fn clone(&self) -> Self {
100 Self {
101 inner: Arc::clone(&self.inner),
102 }
103 }
104}
105
106impl<T: AgentDeriveT> ReActAgent<T> {
107 pub fn new(inner: T) -> Self {
108 Self {
109 inner: Arc::new(inner),
110 }
111 }
112}
113
114impl<T: AgentDeriveT> Deref for ReActAgent<T> {
115 type Target = T;
116
117 fn deref(&self) -> &Self::Target {
118 &self.inner
119 }
120}
121
122#[async_trait]
124impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
125 type Output = <T as AgentDeriveT>::Output;
126
127 fn description(&self) -> &'static str {
128 self.inner.description()
129 }
130
131 fn output_schema(&self) -> Option<Value> {
132 self.inner.output_schema()
133 }
134
135 fn name(&self) -> &'static str {
136 self.inner.name()
137 }
138
139 fn tools(&self) -> Vec<Box<dyn ToolT>> {
140 self.inner.tools()
141 }
142}
143
144#[async_trait]
145impl<T> AgentHooks for ReActAgent<T>
146where
147 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
148{
149 async fn on_agent_create(&self) {
150 self.inner.on_agent_create().await
151 }
152
153 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
154 self.inner.on_run_start(task, ctx).await
155 }
156
157 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
158 self.inner.on_run_complete(task, result, ctx).await
159 }
160
161 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
162 self.inner.on_turn_start(turn_index, ctx).await
163 }
164
165 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
166 self.inner.on_turn_complete(turn_index, ctx).await
167 }
168
169 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
170 self.inner.on_tool_call(tool_call, ctx).await
171 }
172
173 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
174 self.inner.on_tool_start(tool_call, ctx).await
175 }
176
177 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
178 self.inner.on_tool_result(tool_call, result, ctx).await
179 }
180
181 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
182 self.inner.on_tool_error(tool_call, err, ctx).await
183 }
184 async fn on_agent_shutdown(&self) {
185 self.inner.on_agent_shutdown().await
186 }
187}
188
189impl<T: AgentDeriveT + AgentHooks> ReActAgent<T> {
190 async fn process_turn(
192 &self,
193 context: &Context,
194 tools: &[Box<dyn ToolT>],
195 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
196 let messages = self.prepare_messages(context).await;
197 let response = self.get_llm_response(context, &messages, tools).await?;
198 let response_text = response.text().unwrap_or_default();
199
200 if let Some(tool_calls) = response.tool_calls() {
201 self.handle_tool_calls(context, tools, tool_calls.clone(), response_text)
202 .await
203 } else {
204 self.handle_text_response(context, response_text).await
205 }
206 }
207
208 async fn get_llm_response(
210 &self,
211 context: &Context,
212 messages: &[ChatMessage],
213 tools: &[Box<dyn ToolT>],
214 ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, ReActExecutorError> {
215 let llm = context.llm();
216 let agent_config = context.config();
217 let tools_serialized: Vec<Tool> = tools.iter().map(Tool::from).collect();
218
219 llm.chat(
220 messages,
221 if tools.is_empty() {
222 None
223 } else {
224 Some(&tools_serialized)
225 },
226 agent_config.output_schema.clone(),
227 )
228 .await
229 .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
230 }
231
232 async fn handle_tool_calls(
234 &self,
235 context: &Context,
236 tools: &[Box<dyn ToolT>],
237 tool_calls: Vec<ToolCall>,
238 response_text: String,
239 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
240 let tx_event = context.tx().ok();
241
242 let mut tool_results = Vec::new();
244 for call in &tool_calls {
245 if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
246 self, context, tools, call, &tx_event,
247 )
248 .await
249 {
250 tool_results.push(result);
251 }
252 }
253
254 MemoryHelper::store_tool_interaction(
256 &context.memory(),
257 &tool_calls,
258 &tool_results,
259 &response_text,
260 )
261 .await;
262
263 {
265 let state = context.state();
266 #[cfg(not(target_arch = "wasm32"))]
267 if let Ok(mut guard) = state.try_lock() {
268 for result in &tool_results {
269 guard.record_tool_call(result.clone());
270 }
271 };
272 #[cfg(target_arch = "wasm32")]
273 if let Some(mut guard) = state.try_lock() {
274 for result in &tool_results {
275 guard.record_tool_call(result.clone());
276 }
277 };
278 }
279
280 Ok(TurnResult::Continue(Some(ReActAgentOutput {
281 response: response_text,
282 done: true,
283 tool_calls: tool_results,
284 })))
285 }
286
287 async fn handle_text_response(
289 &self,
290 context: &Context,
291 response_text: String,
292 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
293 if !response_text.is_empty() {
294 MemoryHelper::store_assistant_response(&context.memory(), response_text.clone()).await;
295 }
296
297 Ok(TurnResult::Complete(ReActAgentOutput {
298 response: response_text,
299 done: true,
300 tool_calls: vec![],
301 }))
302 }
303
304 async fn prepare_messages(&self, context: &Context) -> Vec<ChatMessage> {
306 let mut messages = vec![ChatMessage {
307 role: ChatRole::System,
308 message_type: MessageType::Text,
309 content: context.config().description.clone(),
310 }];
311
312 let recalled = MemoryHelper::recall_messages(&context.memory()).await;
313 messages.extend(recalled);
314
315 messages
316 }
317
318 async fn process_streaming_turn(
320 &self,
321 context: &Context,
322 tools: &[Box<dyn ToolT>],
323 tx: &mut Sender<Result<ReActAgentOutput, ReActExecutorError>>,
324 submission_id: SubmissionId,
325 ) -> Result<StreamingTurnResult, ReActExecutorError> {
326 let messages = self.prepare_messages(context).await;
327 let mut stream = self.get_llm_stream(context, &messages, tools).await?;
328
329 let mut response_text = String::new();
330 let mut tool_calls_map: HashMap<usize, (Option<String>, Option<String>, String)> =
331 HashMap::new();
332
333 while let Some(chunk_result) = stream.next().await {
335 let chunk = chunk_result.map_err(|e| ReActExecutorError::LLMError(e.to_string()))?;
336
337 if let Some(choice) = chunk.choices.first() {
338 if let Some(content) = &choice.delta.content {
340 response_text.push_str(content);
341 let _ = tx
342 .send(Ok(ReActAgentOutput {
343 response: content.to_string(),
344 tool_calls: vec![],
345 done: false,
346 }))
347 .await;
348 }
349
350 self.process_stream_tool_calls(&mut tool_calls_map, choice);
352
353 let tx_event = context.tx().ok();
355 EventHelper::send_stream_chunk(&tx_event, submission_id, choice.clone()).await;
356 }
357 }
358
359 self.finalize_stream_tool_calls(
361 context,
362 tools,
363 tool_calls_map,
364 submission_id,
365 response_text,
366 )
367 .await
368 }
369
370 async fn get_llm_stream(
372 &self,
373 context: &Context,
374 messages: &[ChatMessage],
375 tools: &[Box<dyn ToolT>],
376 ) -> Result<
377 Pin<Box<dyn Stream<Item = Result<autoagents_llm::chat::StreamResponse, LLMError>> + Send>>,
378 ReActExecutorError,
379 > {
380 let llm = context.llm();
381 let agent_config = context.config();
382 let tools_serialized: Vec<Tool> = tools.iter().map(Tool::from).collect();
383
384 llm.chat_stream_struct(
385 messages,
386 if tools.is_empty() {
387 None
388 } else {
389 Some(&tools_serialized)
390 },
391 agent_config.output_schema.clone(),
392 )
393 .await
394 .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
395 }
396
397 fn process_stream_tool_calls(
399 &self,
400 tool_calls_map: &mut HashMap<usize, (Option<String>, Option<String>, String)>,
401 choice: &StreamChoice,
402 ) {
403 if let Some(tool_call_deltas) = &choice.delta.tool_calls {
404 for delta in tool_call_deltas {
405 let entry =
406 tool_calls_map
407 .entry(delta.index)
408 .or_insert((None, None, String::new()));
409
410 if let Some(function) = &delta.function {
411 if !function.name.is_empty() {
412 entry.0 = Some(function.name.clone());
413 }
414 entry.2.push_str(&function.arguments);
415 }
416 }
417 }
418 }
419
420 async fn finalize_stream_tool_calls(
422 &self,
423 context: &Context,
424 tools: &[Box<dyn ToolT>],
425 tool_calls_map: HashMap<usize, (Option<String>, Option<String>, String)>,
426 submission_id: SubmissionId,
427 response_text: String,
428 ) -> Result<StreamingTurnResult, ReActExecutorError> {
429 if tool_calls_map.is_empty() {
430 if !response_text.is_empty() {
431 MemoryHelper::store_assistant_response(&context.memory(), response_text.clone())
432 .await;
433 }
434 return Ok(StreamingTurnResult::Complete(response_text));
435 }
436
437 let mut sorted_calls: Vec<_> = tool_calls_map.into_iter().collect();
439 sorted_calls.sort_by_key(|(index, _)| *index);
440
441 let collected_tool_calls: Vec<ToolCall> = sorted_calls
442 .into_iter()
443 .filter_map(|(_, (name, id, args))| {
444 name.map(|name| ToolCall {
445 id: id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
446 call_type: "function".to_string(),
447 function: FunctionCall {
448 name,
449 arguments: args,
450 },
451 })
452 })
453 .collect();
454
455 let tx_event = context.tx().ok();
457 for tool_call in &collected_tool_calls {
458 EventHelper::send_stream_tool_call(
459 &tx_event,
460 submission_id,
461 serde_json::to_value(tool_call).unwrap_or(Value::Null),
462 )
463 .await;
464 }
465
466 let tool_results =
468 ToolProcessor::process_tool_calls(tools, collected_tool_calls.clone(), tx_event).await;
469
470 MemoryHelper::store_tool_interaction(
472 &context.memory(),
473 &collected_tool_calls,
474 &tool_results,
475 &response_text,
476 )
477 .await;
478
479 let state = context.state();
481 let mut guard = state.lock().await;
482 for result in &tool_results {
483 guard.record_tool_call(result.clone());
484 }
485
486 Ok(StreamingTurnResult::ToolCallsProcessed(tool_results))
487 }
488}
489
490#[async_trait]
492impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
493 type Output = ReActAgentOutput;
494 type Error = ReActExecutorError;
495
496 fn config(&self) -> ExecutorConfig {
497 ExecutorConfig { max_turns: 10 }
498 }
499
500 async fn execute(
501 &self,
502 task: &Task,
503 context: Arc<Context>,
504 ) -> Result<Self::Output, Self::Error> {
505 MemoryHelper::store_user_message(
507 &context.memory(),
508 task.prompt.clone(),
509 task.image.clone(),
510 )
511 .await;
512
513 {
515 let state = context.state();
516 #[cfg(not(target_arch = "wasm32"))]
517 if let Ok(mut guard) = state.try_lock() {
518 guard.record_task(task.clone());
519 };
520 #[cfg(target_arch = "wasm32")]
521 if let Some(mut guard) = state.try_lock() {
522 guard.record_task(task.clone());
523 };
524 }
525
526 let tx_event = context.tx().ok();
528 EventHelper::send_task_started(
529 &tx_event,
530 task.submission_id,
531 context.config().id,
532 task.prompt.clone(),
533 )
534 .await;
535
536 let max_turns = self.config().max_turns;
538 let mut accumulated_tool_calls = Vec::new();
539 let mut final_response = String::new();
540
541 for turn_num in 0..max_turns {
542 let tools = context.tools();
543 EventHelper::send_turn_started(&tx_event, turn_num, max_turns).await;
544
545 self.on_turn_start(turn_num, &context).await;
547
548 match self.process_turn(&context, tools).await? {
549 TurnResult::Complete(result) => {
550 if !accumulated_tool_calls.is_empty() {
551 return Ok(ReActAgentOutput {
552 response: result.response,
553 done: true,
554 tool_calls: accumulated_tool_calls,
555 });
556 }
557 EventHelper::send_turn_completed(&tx_event, turn_num, false).await;
558 self.on_turn_complete(turn_num, &context).await;
560 return Ok(result);
561 }
562 TurnResult::Continue(Some(partial_result)) => {
563 accumulated_tool_calls.extend(partial_result.tool_calls);
564 if !partial_result.response.is_empty() {
565 final_response = partial_result.response;
566 }
567 }
568 TurnResult::Continue(None) => continue,
569 }
570 }
571
572 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
573 EventHelper::send_task_completed(&tx_event, task.submission_id, final_response.clone())
574 .await;
575 Ok(ReActAgentOutput {
576 response: final_response,
577 done: true,
578 tool_calls: accumulated_tool_calls,
579 })
580 } else {
581 Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
582 }
583 }
584
585 async fn execute_stream(
586 &self,
587 task: &Task,
588 context: Arc<Context>,
589 ) -> Result<
590 Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
591 Self::Error,
592 > {
593 MemoryHelper::store_user_message(
595 &context.memory(),
596 task.prompt.clone(),
597 task.image.clone(),
598 )
599 .await;
600
601 {
603 let state = context.state();
604 #[cfg(not(target_arch = "wasm32"))]
605 if let Ok(mut guard) = state.try_lock() {
606 guard.record_task(task.clone());
607 };
608 #[cfg(target_arch = "wasm32")]
609 if let Some(mut guard) = state.try_lock() {
610 guard.record_task(task.clone());
611 };
612 }
613
614 let tx_event = context.tx().ok();
616 EventHelper::send_task_started(
617 &tx_event,
618 task.submission_id,
619 context.config().id,
620 task.prompt.clone(),
621 )
622 .await;
623
624 let (mut tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
626
627 let executor = self.clone();
629 let context_clone = context.clone();
630 let submission_id = task.submission_id;
631 let max_turns = executor.config().max_turns;
632
633 spawn_future(async move {
635 let mut accumulated_tool_calls = Vec::new();
636 let mut final_response = String::new();
637 let tools = context_clone.tools();
638
639 for turn in 0..max_turns {
640 let tx_event = context_clone.tx().ok();
642 EventHelper::send_turn_started(&tx_event, turn, max_turns).await;
643
644 match executor
646 .process_streaming_turn(&context_clone, tools, &mut tx, submission_id)
647 .await
648 {
649 Ok(StreamingTurnResult::Complete(response)) => {
650 final_response = response;
651 EventHelper::send_turn_completed(&tx_event, turn, true).await;
652 break;
653 }
654 Ok(StreamingTurnResult::ToolCallsProcessed(tool_results)) => {
655 accumulated_tool_calls.extend(tool_results);
656
657 let _ = tx
658 .send(Ok(ReActAgentOutput {
659 response: String::new(),
660 done: false,
661 tool_calls: accumulated_tool_calls.clone(),
662 }))
663 .await;
664
665 EventHelper::send_turn_completed(&tx_event, turn, false).await;
666 }
667 Err(e) => {
668 let _ = tx.send(Err(e)).await;
669 return;
670 }
671 }
672 }
673
674 let tx_event = context_clone.tx().ok();
676 EventHelper::send_stream_complete(&tx_event, submission_id).await;
677
678 let _ = tx
679 .send(Ok(ReActAgentOutput {
680 response: final_response,
681 done: true,
682 tool_calls: accumulated_tool_calls,
683 }))
684 .await;
685 });
686
687 Ok(receiver_into_stream(rx))
688 }
689}
690
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// Wraps `response` in a serialized `ReActAgentOutput`, the form that
    /// `extract_agent_output` expects.
    fn react_value(response: String) -> Value {
        let react_output = ReActAgentOutput {
            response,
            done: true,
            tool_calls: vec![],
        };
        serde_json::to_value(react_output).unwrap()
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = TestAgentOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let value = react_value(serde_json::to_string(&agent_output).unwrap());
        let extracted: TestAgentOutput = ReActAgentOutput::extract_agent_output(value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_rejects_non_json_response() {
        let value = react_value("plain text, not JSON".to_string());
        let result = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(value);
        assert!(matches!(result, Err(ReActExecutorError::AgentOutputError(_))));
    }

    #[test]
    fn test_extract_agent_output_rejects_non_react_value() {
        // Missing `response`, `tool_calls` and `done`: not a ReActAgentOutput.
        let value = serde_json::json!({ "value": 1 });
        let result = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(value);
        assert!(matches!(result, Err(ReActExecutorError::AgentOutputError(_))));
    }

    #[test]
    fn test_string_conversion_keeps_only_response() {
        let output = ReActAgentOutput {
            response: "answer".to_string(),
            done: false,
            tool_calls: vec![],
        };
        assert_eq!(String::from(output), "answer");
    }
}
719}