1use crate::agent::executor::AgentExecutor;
2use crate::agent::task::Task;
3use crate::agent::{AgentDeriveT, Context, ExecutorConfig, TurnResult};
4use crate::protocol::{Event, StreamingTurnResult, SubmissionId};
5use crate::tool::{to_llm_tool, ToolCallResult, ToolT};
6use async_trait::async_trait;
7use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChoice, Tool};
8use autoagents_llm::error::LLMError;
9use autoagents_llm::{FunctionCall, ToolCall};
10use futures::{Stream, StreamExt};
11use serde::{Deserialize, Serialize};
12use serde_json::Value;
13use std::collections::HashMap;
14use std::ops::Deref;
15use std::pin::Pin;
16use std::sync::Arc;
17use thiserror::Error;
18
19#[cfg(not(target_arch = "wasm32"))]
20pub use tokio::sync::mpsc::error::SendError;
21
22#[cfg(target_arch = "wasm32")]
23pub use futures::lock::Mutex;
24#[cfg(target_arch = "wasm32")]
25use futures::SinkExt;
26#[cfg(target_arch = "wasm32")]
27type SendError = futures::channel::mpsc::SendError;
28
29use crate::agent::executor::event_helper::EventHelper;
30use crate::agent::executor::memory_helper::MemoryHelper;
31use crate::agent::executor::tool_processor::ToolProcessor;
32use crate::agent::hooks::{AgentHooks, HookOutcome};
33use crate::channel::{channel, Sender};
34use crate::utils::{receiver_into_stream, spawn_future};
35
/// Output of the ReAct executor.
///
/// Used both for the final result of a run and for the intermediate chunks
/// pushed through the stream returned by `execute_stream`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReActAgentOutput {
    /// Text produced by the LLM. On intermediate streaming chunks this is only the
    /// newly received text delta (or empty for a tool-results chunk).
    pub response: String,
    /// Results of the tool calls executed during the run (accumulated across turns
    /// for the final output and for tool-results streaming chunks).
    pub tool_calls: Vec<ToolCallResult>,
    /// `false` on intermediate streaming chunks; the final output of
    /// `execute`/`execute_stream` sets it to `true`.
    pub done: bool,
}
43
44impl From<ReActAgentOutput> for Value {
45 fn from(output: ReActAgentOutput) -> Self {
46 serde_json::to_value(output).unwrap_or(Value::Null)
47 }
48}
49impl From<ReActAgentOutput> for String {
50 fn from(output: ReActAgentOutput) -> Self {
51 output.response
52 }
53}
54
55impl ReActAgentOutput {
56 pub fn try_parse<T: for<'de> serde::Deserialize<'de>>(&self) -> Result<T, serde_json::Error> {
59 serde_json::from_str::<T>(&self.response)
60 }
61
62 pub fn parse_or_map<T, F>(&self, fallback: F) -> T
66 where
67 T: for<'de> serde::Deserialize<'de>,
68 F: FnOnce(&str) -> T,
69 {
70 self.try_parse::<T>()
71 .unwrap_or_else(|_| fallback(&self.response))
72 }
73}
74
75impl ReActAgentOutput {
76 #[allow(clippy::result_large_err)]
78 pub fn extract_agent_output<T>(val: Value) -> Result<T, ReActExecutorError>
79 where
80 T: for<'de> serde::Deserialize<'de>,
81 {
82 let react_output: Self = serde_json::from_value(val)
83 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))?;
84 serde_json::from_str(&react_output.response)
85 .map_err(|e| ReActExecutorError::AgentOutputError(e.to_string()))
86 }
87}
88
/// Errors produced by the ReAct executor.
#[derive(Error, Debug)]
pub enum ReActExecutorError {
    /// The LLM chat request or its response stream failed; holds the stringified error.
    #[error("LLM error: {0}")]
    LLMError(String),

    /// `execute` ran all `max_turns` turns without collecting any response text or
    /// tool calls.
    #[error("Maximum turns exceeded: {max_turns}")]
    MaxTurnsExceeded { max_turns: usize },

    /// Free-form error message (not constructed in the code shown here).
    #[error("Other error: {0}")]
    Other(String),

    /// Sending an `Event` failed (non-wasm targets: tokio mpsc `SendError`).
    #[cfg(not(target_arch = "wasm32"))]
    #[error("Event error: {0}")]
    EventError(#[from] SendError<Event>),

    /// Sending an event failed (wasm32: `futures` mpsc `SendError`).
    #[cfg(target_arch = "wasm32")]
    #[error("Event error: {0}")]
    EventError(#[from] SendError),

    /// `extract_agent_output` could not deserialize the executor output or its
    /// `response` payload; holds the serde error message.
    #[error("Extracting Agent Output Error: {0}")]
    AgentOutputError(String),
}
111
/// Executor wrapper that runs a user-defined agent `T` in a ReAct-style loop
/// (LLM turn, then tool calls, then another LLM turn, up to a turn budget).
///
/// The agent is held behind an `Arc`, so cloning a `ReActAgent` is cheap and all
/// clones share the same inner agent; `execute_stream` relies on this to move a
/// clone into a spawned task.
#[derive(Debug)]
pub struct ReActAgent<T: AgentDeriveT> {
    /// The wrapped agent; metadata and hook calls are forwarded to it.
    inner: Arc<T>,
}
120
121impl<T: AgentDeriveT> Clone for ReActAgent<T> {
122 fn clone(&self) -> Self {
123 Self {
124 inner: Arc::clone(&self.inner),
125 }
126 }
127}
128
129impl<T: AgentDeriveT> ReActAgent<T> {
130 pub fn new(inner: T) -> Self {
131 Self {
132 inner: Arc::new(inner),
133 }
134 }
135}
136
137impl<T: AgentDeriveT> Deref for ReActAgent<T> {
138 type Target = T;
139
140 fn deref(&self) -> &Self::Target {
141 &self.inner
142 }
143}
144
145#[async_trait]
147impl<T: AgentDeriveT> AgentDeriveT for ReActAgent<T> {
148 type Output = <T as AgentDeriveT>::Output;
149
150 fn description(&self) -> &'static str {
151 self.inner.description()
152 }
153
154 fn output_schema(&self) -> Option<Value> {
155 self.inner.output_schema()
156 }
157
158 fn name(&self) -> &'static str {
159 self.inner.name()
160 }
161
162 fn tools(&self) -> Vec<Box<dyn ToolT>> {
163 self.inner.tools()
164 }
165}
166
167#[async_trait]
168impl<T> AgentHooks for ReActAgent<T>
169where
170 T: AgentDeriveT + AgentHooks + Send + Sync + 'static,
171{
172 async fn on_agent_create(&self) {
173 self.inner.on_agent_create().await
174 }
175
176 async fn on_run_start(&self, task: &Task, ctx: &Context) -> HookOutcome {
177 self.inner.on_run_start(task, ctx).await
178 }
179
180 async fn on_run_complete(&self, task: &Task, result: &Self::Output, ctx: &Context) {
181 self.inner.on_run_complete(task, result, ctx).await
182 }
183
184 async fn on_turn_start(&self, turn_index: usize, ctx: &Context) {
185 self.inner.on_turn_start(turn_index, ctx).await
186 }
187
188 async fn on_turn_complete(&self, turn_index: usize, ctx: &Context) {
189 self.inner.on_turn_complete(turn_index, ctx).await
190 }
191
192 async fn on_tool_call(&self, tool_call: &ToolCall, ctx: &Context) -> HookOutcome {
193 self.inner.on_tool_call(tool_call, ctx).await
194 }
195
196 async fn on_tool_start(&self, tool_call: &ToolCall, ctx: &Context) {
197 self.inner.on_tool_start(tool_call, ctx).await
198 }
199
200 async fn on_tool_result(&self, tool_call: &ToolCall, result: &ToolCallResult, ctx: &Context) {
201 self.inner.on_tool_result(tool_call, result, ctx).await
202 }
203
204 async fn on_tool_error(&self, tool_call: &ToolCall, err: Value, ctx: &Context) {
205 self.inner.on_tool_error(tool_call, err, ctx).await
206 }
207 async fn on_agent_shutdown(&self) {
208 self.inner.on_agent_shutdown().await
209 }
210}
211
212impl<T: AgentDeriveT + AgentHooks> ReActAgent<T> {
213 async fn process_turn(
215 &self,
216 context: &Context,
217 tools: &[Box<dyn ToolT>],
218 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
219 let messages = self.prepare_messages(context).await;
220 let response = self.get_llm_response(context, &messages, tools).await?;
221 let response_text = response.text().unwrap_or_default();
222
223 if let Some(tool_calls) = response.tool_calls() {
224 self.handle_tool_calls(context, tools, tool_calls.clone(), response_text)
225 .await
226 } else {
227 self.handle_text_response(context, response_text).await
228 }
229 }
230
231 async fn get_llm_response(
233 &self,
234 context: &Context,
235 messages: &[ChatMessage],
236 tools: &[Box<dyn ToolT>],
237 ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, ReActExecutorError> {
238 let llm = context.llm();
239 let agent_config = context.config();
240 let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
241
242 llm.chat(
243 messages,
244 if tools.is_empty() {
245 None
246 } else {
247 Some(&tools_serialized)
248 },
249 agent_config.output_schema.clone(),
250 )
251 .await
252 .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
253 }
254
255 async fn handle_tool_calls(
257 &self,
258 context: &Context,
259 tools: &[Box<dyn ToolT>],
260 tool_calls: Vec<ToolCall>,
261 response_text: String,
262 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
263 let tx_event = context.tx().ok();
264
265 let mut tool_results = Vec::new();
267 for call in &tool_calls {
268 if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
269 self, context, tools, call, &tx_event,
270 )
271 .await
272 {
273 tool_results.push(result);
274 }
275 }
276
277 MemoryHelper::store_tool_interaction(
279 &context.memory(),
280 &tool_calls,
281 &tool_results,
282 &response_text,
283 )
284 .await;
285
286 {
288 let state = context.state();
289 #[cfg(not(target_arch = "wasm32"))]
290 if let Ok(mut guard) = state.try_lock() {
291 for result in &tool_results {
292 guard.record_tool_call(result.clone());
293 }
294 };
295 #[cfg(target_arch = "wasm32")]
296 if let Some(mut guard) = state.try_lock() {
297 for result in &tool_results {
298 guard.record_tool_call(result.clone());
299 }
300 };
301 }
302
303 Ok(TurnResult::Continue(Some(ReActAgentOutput {
304 response: response_text,
305 done: true,
306 tool_calls: tool_results,
307 })))
308 }
309
310 async fn handle_text_response(
312 &self,
313 context: &Context,
314 response_text: String,
315 ) -> Result<TurnResult<ReActAgentOutput>, ReActExecutorError> {
316 if !response_text.is_empty() {
317 MemoryHelper::store_assistant_response(&context.memory(), response_text.clone()).await;
318 }
319
320 Ok(TurnResult::Complete(ReActAgentOutput {
321 response: response_text,
322 done: true,
323 tool_calls: vec![],
324 }))
325 }
326
327 async fn prepare_messages(&self, context: &Context) -> Vec<ChatMessage> {
329 let mut messages = vec![ChatMessage {
330 role: ChatRole::System,
331 message_type: MessageType::Text,
332 content: context.config().description.clone(),
333 }];
334
335 let recalled = MemoryHelper::recall_messages(&context.memory()).await;
336 messages.extend(recalled);
337
338 messages
339 }
340
341 async fn process_streaming_turn(
343 &self,
344 context: &Context,
345 tools: &[Box<dyn ToolT>],
346 tx: &mut Sender<Result<ReActAgentOutput, ReActExecutorError>>,
347 submission_id: SubmissionId,
348 ) -> Result<StreamingTurnResult, ReActExecutorError> {
349 let messages = self.prepare_messages(context).await;
350 let mut stream = self.get_llm_stream(context, &messages, tools).await?;
351
352 let mut response_text = String::new();
353 let mut tool_calls_map: HashMap<usize, (Option<String>, Option<String>, String)> =
354 HashMap::new();
355
356 while let Some(chunk_result) = stream.next().await {
358 let chunk = chunk_result.map_err(|e| ReActExecutorError::LLMError(e.to_string()))?;
359
360 if let Some(choice) = chunk.choices.first() {
361 if let Some(content) = &choice.delta.content {
363 response_text.push_str(content);
364 let _ = tx
365 .send(Ok(ReActAgentOutput {
366 response: content.to_string(),
367 tool_calls: vec![],
368 done: false,
369 }))
370 .await;
371 }
372
373 self.process_stream_tool_calls(&mut tool_calls_map, choice);
375
376 let tx_event = context.tx().ok();
378 EventHelper::send_stream_chunk(&tx_event, submission_id, choice.clone()).await;
379 }
380 }
381
382 self.finalize_stream_tool_calls(
384 context,
385 tools,
386 tool_calls_map,
387 submission_id,
388 response_text,
389 )
390 .await
391 }
392
393 async fn get_llm_stream(
395 &self,
396 context: &Context,
397 messages: &[ChatMessage],
398 tools: &[Box<dyn ToolT>],
399 ) -> Result<
400 Pin<Box<dyn Stream<Item = Result<autoagents_llm::chat::StreamResponse, LLMError>> + Send>>,
401 ReActExecutorError,
402 > {
403 let llm = context.llm();
404 let agent_config = context.config();
405 let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
406
407 llm.chat_stream_struct(
408 messages,
409 if tools.is_empty() {
410 None
411 } else {
412 Some(&tools_serialized)
413 },
414 agent_config.output_schema.clone(),
415 )
416 .await
417 .map_err(|e| ReActExecutorError::LLMError(e.to_string()))
418 }
419
420 fn process_stream_tool_calls(
422 &self,
423 tool_calls_map: &mut HashMap<usize, (Option<String>, Option<String>, String)>,
424 choice: &StreamChoice,
425 ) {
426 if let Some(tool_call_deltas) = &choice.delta.tool_calls {
427 for delta in tool_call_deltas {
428 let entry =
429 tool_calls_map
430 .entry(delta.index)
431 .or_insert((None, None, String::new()));
432
433 if let Some(function) = &delta.function {
434 if !function.name.is_empty() {
435 entry.0 = Some(function.name.clone());
436 }
437 entry.2.push_str(&function.arguments);
438 }
439 }
440 }
441 }
442
443 async fn finalize_stream_tool_calls(
445 &self,
446 context: &Context,
447 tools: &[Box<dyn ToolT>],
448 tool_calls_map: HashMap<usize, (Option<String>, Option<String>, String)>,
449 submission_id: SubmissionId,
450 response_text: String,
451 ) -> Result<StreamingTurnResult, ReActExecutorError> {
452 if tool_calls_map.is_empty() {
453 if !response_text.is_empty() {
454 MemoryHelper::store_assistant_response(&context.memory(), response_text.clone())
455 .await;
456 }
457 return Ok(StreamingTurnResult::Complete(response_text));
458 }
459
460 let mut sorted_calls: Vec<_> = tool_calls_map.into_iter().collect();
462 sorted_calls.sort_by_key(|(index, _)| *index);
463
464 let collected_tool_calls: Vec<ToolCall> = sorted_calls
465 .into_iter()
466 .filter_map(|(_, (name, id, args))| {
467 name.map(|name| ToolCall {
468 id: id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
469 call_type: "function".to_string(),
470 function: FunctionCall {
471 name,
472 arguments: args,
473 },
474 })
475 })
476 .collect();
477
478 let tx_event = context.tx().ok();
480 for tool_call in &collected_tool_calls {
481 EventHelper::send_stream_tool_call(
482 &tx_event,
483 submission_id,
484 serde_json::to_value(tool_call).unwrap_or(Value::Null),
485 )
486 .await;
487 }
488
489 let tool_results =
491 ToolProcessor::process_tool_calls(tools, collected_tool_calls.clone(), tx_event).await;
492
493 MemoryHelper::store_tool_interaction(
495 &context.memory(),
496 &collected_tool_calls,
497 &tool_results,
498 &response_text,
499 )
500 .await;
501
502 let state = context.state();
504 let mut guard = state.lock().await;
505 for result in &tool_results {
506 guard.record_tool_call(result.clone());
507 }
508
509 Ok(StreamingTurnResult::ToolCallsProcessed(tool_results))
510 }
511}
512
513#[async_trait]
515impl<T: AgentDeriveT + AgentHooks> AgentExecutor for ReActAgent<T> {
516 type Output = ReActAgentOutput;
517 type Error = ReActExecutorError;
518
519 fn config(&self) -> ExecutorConfig {
520 ExecutorConfig { max_turns: 10 }
521 }
522
523 async fn execute(
524 &self,
525 task: &Task,
526 context: Arc<Context>,
527 ) -> Result<Self::Output, Self::Error> {
528 MemoryHelper::store_user_message(
530 &context.memory(),
531 task.prompt.clone(),
532 task.image.clone(),
533 )
534 .await;
535
536 {
538 let state = context.state();
539 #[cfg(not(target_arch = "wasm32"))]
540 if let Ok(mut guard) = state.try_lock() {
541 guard.record_task(task.clone());
542 };
543 #[cfg(target_arch = "wasm32")]
544 if let Some(mut guard) = state.try_lock() {
545 guard.record_task(task.clone());
546 };
547 }
548
549 let tx_event = context.tx().ok();
551 EventHelper::send_task_started(
552 &tx_event,
553 task.submission_id,
554 context.config().id,
555 task.prompt.clone(),
556 context.config().name.clone(),
557 )
558 .await;
559
560 let max_turns = self.config().max_turns;
562 let mut accumulated_tool_calls = Vec::new();
563 let mut final_response = String::new();
564
565 for turn_num in 0..max_turns {
566 let tools = context.tools();
567 EventHelper::send_turn_started(&tx_event, turn_num, max_turns).await;
568
569 self.on_turn_start(turn_num, &context).await;
571
572 match self.process_turn(&context, tools).await? {
573 TurnResult::Complete(result) => {
574 if !accumulated_tool_calls.is_empty() {
575 return Ok(ReActAgentOutput {
576 response: result.response,
577 done: true,
578 tool_calls: accumulated_tool_calls,
579 });
580 }
581 EventHelper::send_turn_completed(&tx_event, turn_num, false).await;
582 self.on_turn_complete(turn_num, &context).await;
584 return Ok(result);
585 }
586 TurnResult::Continue(Some(partial_result)) => {
587 accumulated_tool_calls.extend(partial_result.tool_calls);
588 if !partial_result.response.is_empty() {
589 final_response = partial_result.response;
590 }
591 }
592 TurnResult::Continue(None) => continue,
593 }
594 }
595
596 if !final_response.is_empty() || !accumulated_tool_calls.is_empty() {
597 EventHelper::send_task_completed(
598 &tx_event,
599 task.submission_id,
600 context.config().id,
601 final_response.clone(),
602 context.config().name.clone(),
603 )
604 .await;
605 Ok(ReActAgentOutput {
606 response: final_response,
607 done: true,
608 tool_calls: accumulated_tool_calls,
609 })
610 } else {
611 Err(ReActExecutorError::MaxTurnsExceeded { max_turns })
612 }
613 }
614
615 async fn execute_stream(
616 &self,
617 task: &Task,
618 context: Arc<Context>,
619 ) -> Result<
620 Pin<Box<dyn Stream<Item = Result<ReActAgentOutput, Self::Error>> + Send>>,
621 Self::Error,
622 > {
623 MemoryHelper::store_user_message(
625 &context.memory(),
626 task.prompt.clone(),
627 task.image.clone(),
628 )
629 .await;
630
631 {
633 let state = context.state();
634 #[cfg(not(target_arch = "wasm32"))]
635 if let Ok(mut guard) = state.try_lock() {
636 guard.record_task(task.clone());
637 };
638 #[cfg(target_arch = "wasm32")]
639 if let Some(mut guard) = state.try_lock() {
640 guard.record_task(task.clone());
641 };
642 }
643
644 let tx_event = context.tx().ok();
646 EventHelper::send_task_started(
647 &tx_event,
648 task.submission_id,
649 context.config().id,
650 task.prompt.clone(),
651 context.config().name.clone(),
652 )
653 .await;
654
655 let (mut tx, rx) = channel::<Result<ReActAgentOutput, ReActExecutorError>>(100);
657
658 let executor = self.clone();
660 let context_clone = context.clone();
661 let submission_id = task.submission_id;
662 let max_turns = executor.config().max_turns;
663
664 spawn_future(async move {
666 let mut accumulated_tool_calls = Vec::new();
667 let mut final_response = String::new();
668 let tools = context_clone.tools();
669
670 for turn in 0..max_turns {
671 let tx_event = context_clone.tx().ok();
673 EventHelper::send_turn_started(&tx_event, turn, max_turns).await;
674
675 match executor
677 .process_streaming_turn(&context_clone, tools, &mut tx, submission_id)
678 .await
679 {
680 Ok(StreamingTurnResult::Complete(response)) => {
681 final_response = response;
682 EventHelper::send_turn_completed(&tx_event, turn, true).await;
683 break;
684 }
685 Ok(StreamingTurnResult::ToolCallsProcessed(tool_results)) => {
686 accumulated_tool_calls.extend(tool_results);
687
688 let _ = tx
689 .send(Ok(ReActAgentOutput {
690 response: String::new(),
691 done: false,
692 tool_calls: accumulated_tool_calls.clone(),
693 }))
694 .await;
695
696 EventHelper::send_turn_completed(&tx_event, turn, false).await;
697 }
698 Err(e) => {
699 let _ = tx.send(Err(e)).await;
700 return;
701 }
702 }
703 }
704
705 let tx_event = context_clone.tx().ok();
707 EventHelper::send_stream_complete(&tx_event, submission_id).await;
708
709 let _ = tx
710 .send(Ok(ReActAgentOutput {
711 response: final_response,
712 done: true,
713 tool_calls: accumulated_tool_calls,
714 }))
715 .await;
716 });
717
718 Ok(receiver_into_stream(rx))
719 }
720}
721
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
    struct TestAgentOutput {
        value: i32,
        message: String,
    }

    /// Builds a final (`done: true`) output with no tool calls.
    fn output_with_response(response: impl Into<String>) -> ReActAgentOutput {
        ReActAgentOutput {
            response: response.into(),
            done: true,
            tool_calls: vec![],
        }
    }

    #[test]
    fn test_extract_agent_output_success() {
        let agent_output = TestAgentOutput {
            value: 42,
            message: "Hello, world!".to_string(),
        };

        let react_output = output_with_response(serde_json::to_string(&agent_output).unwrap());

        let react_value = serde_json::to_value(react_output).unwrap();
        let extracted: TestAgentOutput =
            ReActAgentOutput::extract_agent_output(react_value).unwrap();
        assert_eq!(extracted, agent_output);
    }

    #[test]
    fn test_extract_agent_output_rejects_non_react_value() {
        let err = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(Value::String(
            "not an output".to_string(),
        ))
        .unwrap_err();
        assert!(matches!(err, ReActExecutorError::AgentOutputError(_)));
    }

    #[test]
    fn test_extract_agent_output_rejects_non_json_response() {
        let value = serde_json::to_value(output_with_response("plain text")).unwrap();
        let err = ReActAgentOutput::extract_agent_output::<TestAgentOutput>(value).unwrap_err();
        assert!(matches!(err, ReActExecutorError::AgentOutputError(_)));
    }

    #[test]
    fn test_try_parse_and_parse_or_map() {
        let json = output_with_response(r#"{"value":1,"message":"ok"}"#);
        let expected = TestAgentOutput {
            value: 1,
            message: "ok".to_string(),
        };
        assert_eq!(json.try_parse::<TestAgentOutput>().unwrap(), expected);
        assert_eq!(
            json.parse_or_map(|_| TestAgentOutput {
                value: -1,
                message: String::new(),
            }),
            expected
        );

        let text = output_with_response("free text");
        assert!(text.try_parse::<TestAgentOutput>().is_err());
        let mapped = text.parse_or_map(|raw| TestAgentOutput {
            value: -1,
            message: raw.to_string(),
        });
        assert_eq!(
            mapped,
            TestAgentOutput {
                value: -1,
                message: "free text".to_string(),
            }
        );
    }

    #[test]
    fn test_conversions() {
        let output = output_with_response("hello");

        let value: Value = output.clone().into();
        assert_eq!(value["response"], "hello");
        assert_eq!(value["done"], true);

        let text: String = output.into();
        assert_eq!(text, "hello");
    }
}
750}