Skip to main content

autoagents_core/agent/executor/
turn_engine.rs

1use crate::agent::Context;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::memory_policy::{MemoryAdapter, MemoryPolicy};
4use crate::agent::executor::tool_processor::ToolProcessor;
5use crate::agent::hooks::AgentHooks;
6use crate::agent::task::Task;
7use crate::channel::{Sender, channel};
8use crate::tool::{ToolCallResult, ToolT, to_llm_tool};
9use crate::utils::{receiver_into_stream, spawn_future};
10use autoagents_llm::ToolCall;
11use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChunk, StreamResponse, Tool};
12use autoagents_llm::error::LLMError;
13use autoagents_protocol::{Event, SubmissionId};
14use futures::{Stream, StreamExt};
15use serde_json::Value;
16use std::collections::HashSet;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22use tokio::sync::mpsc;
23
24#[cfg(target_arch = "wasm32")]
25use futures::channel::mpsc;
26
/// Defines if tools are enabled for a given execution plan.
///
/// `PartialEq`/`Eq` are derived so callers can compare modes with `==`
/// instead of having to pattern match.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ToolMode {
    /// Tool definitions are offered to the model and tool calls found in the
    /// response are executed.
    Enabled,
    /// The non-streaming path sends no tool definitions and ignores any tool
    /// calls in the response.
    Disabled,
}
33
/// Defines which streaming primitive to use.
///
/// `PartialEq`/`Eq` are derived so callers can compare modes with `==`
/// instead of having to pattern match.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StreamMode {
    /// Stream through the provider's structured streaming call; only text
    /// deltas are consumed.
    Structured,
    /// Stream through the provider's tool-aware streaming call; text and
    /// completed tool calls are consumed.
    Tool,
}
40
41/// Configuration for the shared executor engine.
42#[derive(Debug, Clone)]
43pub struct TurnEngineConfig {
44    pub max_turns: usize,
45    pub tool_mode: ToolMode,
46    pub stream_mode: StreamMode,
47    pub memory_policy: MemoryPolicy,
48}
49
50impl TurnEngineConfig {
51    pub fn basic(max_turns: usize) -> Self {
52        Self {
53            max_turns,
54            tool_mode: ToolMode::Disabled,
55            stream_mode: StreamMode::Structured,
56            memory_policy: MemoryPolicy::basic(),
57        }
58    }
59
60    pub fn react(max_turns: usize) -> Self {
61        Self {
62            max_turns,
63            tool_mode: ToolMode::Enabled,
64            stream_mode: StreamMode::Tool,
65            memory_policy: MemoryPolicy::react(),
66        }
67    }
68}
69
70/// Normalized output emitted by the engine for a single turn.
71#[derive(Debug, Clone)]
72pub struct TurnEngineOutput {
73    pub response: String,
74    pub tool_calls: Vec<ToolCallResult>,
75}
76
77/// Streaming deltas emitted per turn.
78#[derive(Debug)]
79pub enum TurnDelta {
80    Text(String),
81    ToolResults(Vec<ToolCallResult>),
82    Done(crate::agent::executor::TurnResult<TurnEngineOutput>),
83}
84
85#[derive(Error, Debug)]
86pub enum TurnEngineError {
87    #[error("LLM error: {0}")]
88    LLMError(String),
89
90    #[error("Run aborted by hook")]
91    Aborted,
92
93    #[error("Other error: {0}")]
94    Other(String),
95}
96
97/// Per-run state for the turn engine.
98#[derive(Clone)]
99pub struct TurnState {
100    memory: MemoryAdapter,
101    stored_user: bool,
102}
103
104impl TurnState {
105    pub fn new(context: &Context, policy: MemoryPolicy) -> Self {
106        Self {
107            memory: MemoryAdapter::new(context.memory(), policy),
108            stored_user: false,
109        }
110    }
111
112    pub fn memory(&self) -> &MemoryAdapter {
113        &self.memory
114    }
115
116    pub fn stored_user(&self) -> bool {
117        self.stored_user
118    }
119
120    fn mark_user_stored(&mut self) {
121        self.stored_user = true;
122    }
123}
124
125/// Shared turn engine that handles memory, tools, and events consistently.
126#[derive(Debug, Clone)]
127pub struct TurnEngine {
128    config: TurnEngineConfig,
129}
130
131impl TurnEngine {
132    pub fn new(config: TurnEngineConfig) -> Self {
133        Self { config }
134    }
135
136    pub fn turn_state(&self, context: &Context) -> TurnState {
137        TurnState::new(context, self.config.memory_policy.clone())
138    }
139
140    pub async fn run_turn<H: AgentHooks>(
141        &self,
142        hooks: &H,
143        task: &Task,
144        context: &Context,
145        turn_state: &mut TurnState,
146        turn_index: usize,
147        max_turns: usize,
148    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
149        let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
150        let tx_event = context.tx().ok();
151        EventHelper::send_turn_started(
152            &tx_event,
153            task.submission_id,
154            context.config().id,
155            turn_index,
156            max_turns,
157        )
158        .await;
159
160        hooks.on_turn_start(turn_index, context).await;
161
162        let include_user_prompt =
163            should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
164        let messages = self
165            .build_messages(context, task, turn_state.memory(), include_user_prompt)
166            .await;
167
168        if should_store_user(turn_state) {
169            turn_state.memory.store_user(task).await;
170            turn_state.mark_user_stored();
171        }
172
173        let tools = context.tools();
174        let response = self.get_llm_response(context, &messages, tools).await?;
175        let response_text = response.text().unwrap_or_default();
176
177        let tool_calls = if matches!(self.config.tool_mode, ToolMode::Enabled) {
178            response.tool_calls().unwrap_or_default()
179        } else {
180            Vec::new()
181        };
182
183        if !tool_calls.is_empty() {
184            let tool_results = process_tool_calls_with_hooks(
185                hooks,
186                context,
187                task.submission_id,
188                tools,
189                &tool_calls,
190                &tx_event,
191            )
192            .await;
193
194            turn_state
195                .memory
196                .store_tool_interaction(&tool_calls, &tool_results, &response_text)
197                .await;
198            record_tool_calls_state(context, &tool_results);
199
200            EventHelper::send_turn_completed(
201                &tx_event,
202                task.submission_id,
203                context.config().id,
204                turn_index,
205                false,
206            )
207            .await;
208            hooks.on_turn_complete(turn_index, context).await;
209
210            return Ok(crate::agent::executor::TurnResult::Continue(Some(
211                TurnEngineOutput {
212                    response: response_text,
213                    tool_calls: tool_results,
214                },
215            )));
216        }
217
218        if !response_text.is_empty() {
219            turn_state.memory.store_assistant(&response_text).await;
220        }
221
222        EventHelper::send_turn_completed(
223            &tx_event,
224            task.submission_id,
225            context.config().id,
226            turn_index,
227            true,
228        )
229        .await;
230        hooks.on_turn_complete(turn_index, context).await;
231
232        Ok(crate::agent::executor::TurnResult::Complete(
233            TurnEngineOutput {
234                response: response_text,
235                tool_calls: Vec::new(),
236            },
237        ))
238    }
239
240    pub async fn run_turn_stream<H>(
241        &self,
242        hooks: H,
243        task: &Task,
244        context: Arc<Context>,
245        turn_state: &mut TurnState,
246        turn_index: usize,
247        max_turns: usize,
248    ) -> Result<
249        Pin<Box<dyn Stream<Item = Result<TurnDelta, TurnEngineError>> + Send>>,
250        TurnEngineError,
251    >
252    where
253        H: AgentHooks + Clone + Send + Sync + 'static,
254    {
255        let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
256        let include_user_prompt =
257            should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
258        let messages = self
259            .build_messages(&context, task, turn_state.memory(), include_user_prompt)
260            .await;
261
262        if should_store_user(turn_state) {
263            turn_state.memory.store_user(task).await;
264            turn_state.mark_user_stored();
265        }
266
267        let (mut tx, rx) = channel::<Result<TurnDelta, TurnEngineError>>(100);
268        let engine = self.clone();
269        let context_clone = context.clone();
270        let task = task.clone();
271        let hooks = hooks.clone();
272        let memory = turn_state.memory.clone();
273        let messages = messages.clone();
274
275        spawn_future(async move {
276            let tx_event = context_clone.tx().ok();
277            EventHelper::send_turn_started(
278                &tx_event,
279                task.submission_id,
280                context_clone.config().id,
281                turn_index,
282                max_turns,
283            )
284            .await;
285            hooks.on_turn_start(turn_index, &context_clone).await;
286
287            let result = match engine.config.stream_mode {
288                StreamMode::Structured => {
289                    engine
290                        .stream_structured(&context_clone, &task, &memory, &mut tx, &messages)
291                        .await
292                }
293                StreamMode::Tool => {
294                    engine
295                        .stream_with_tools(
296                            &hooks,
297                            &context_clone,
298                            &task,
299                            context_clone.tools(),
300                            &memory,
301                            &mut tx,
302                            &messages,
303                        )
304                        .await
305                }
306            };
307
308            match result {
309                Ok(turn_result) => {
310                    let final_turn =
311                        matches!(turn_result, crate::agent::executor::TurnResult::Complete(_));
312                    EventHelper::send_turn_completed(
313                        &tx_event,
314                        task.submission_id,
315                        context_clone.config().id,
316                        turn_index,
317                        final_turn,
318                    )
319                    .await;
320                    hooks.on_turn_complete(turn_index, &context_clone).await;
321                    let _ = tx.send(Ok(TurnDelta::Done(turn_result))).await;
322                }
323                Err(err) => {
324                    let _ = tx.send(Err(err)).await;
325                }
326            }
327        });
328
329        Ok(receiver_into_stream(rx))
330    }
331
332    async fn stream_structured(
333        &self,
334        context: &Context,
335        task: &Task,
336        memory: &MemoryAdapter,
337        tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
338        messages: &[ChatMessage],
339    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
340        let mut stream = self.get_structured_stream(context, messages).await?;
341        let mut response_text = String::new();
342
343        while let Some(chunk_result) = stream.next().await {
344            let chunk = chunk_result.map_err(|e| TurnEngineError::LLMError(e.to_string()))?;
345            let content = chunk
346                .choices
347                .first()
348                .and_then(|choice| choice.delta.content.as_ref())
349                .map_or("", |value| value)
350                .to_string();
351
352            if content.is_empty() {
353                continue;
354            }
355
356            response_text.push_str(&content);
357
358            let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;
359
360            let tx_event = context.tx().ok();
361            EventHelper::send_stream_chunk(
362                &tx_event,
363                task.submission_id,
364                StreamChunk::Text(content),
365            )
366            .await;
367        }
368
369        if !response_text.is_empty() {
370            memory.store_assistant(&response_text).await;
371        }
372
373        Ok(crate::agent::executor::TurnResult::Complete(
374            TurnEngineOutput {
375                response: response_text,
376                tool_calls: Vec::new(),
377            },
378        ))
379    }
380
381    #[allow(clippy::too_many_arguments)]
382    async fn stream_with_tools<H: AgentHooks>(
383        &self,
384        hooks: &H,
385        context: &Context,
386        task: &Task,
387        tools: &[Box<dyn ToolT>],
388        memory: &MemoryAdapter,
389        tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
390        messages: &[ChatMessage],
391    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
392        let mut stream = self.get_tool_stream(context, messages, tools).await?;
393        let mut response_text = String::new();
394        let mut tool_calls = Vec::new();
395        let mut tool_call_ids = HashSet::new();
396
397        while let Some(chunk_result) = stream.next().await {
398            let chunk = chunk_result.map_err(|e| TurnEngineError::LLMError(e.to_string()))?;
399            let chunk_clone = chunk.clone();
400
401            match chunk {
402                StreamChunk::Text(content) => {
403                    response_text.push_str(&content);
404                    let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;
405                }
406                StreamChunk::ToolUseComplete {
407                    index: _,
408                    tool_call,
409                } => {
410                    if tool_call_ids.insert(tool_call.id.clone()) {
411                        tool_calls.push(tool_call.clone());
412                        let tx_event = context.tx().ok();
413                        EventHelper::send_stream_tool_call(
414                            &tx_event,
415                            task.submission_id,
416                            serde_json::to_value(tool_call).unwrap_or(Value::Null),
417                        )
418                        .await;
419                    }
420                }
421                StreamChunk::Usage(_) => {}
422                _ => {}
423            }
424
425            let tx_event = context.tx().ok();
426            EventHelper::send_stream_chunk(&tx_event, task.submission_id, chunk_clone).await;
427        }
428
429        if tool_calls.is_empty() {
430            if !response_text.is_empty() {
431                memory.store_assistant(&response_text).await;
432            }
433            return Ok(crate::agent::executor::TurnResult::Complete(
434                TurnEngineOutput {
435                    response: response_text,
436                    tool_calls: Vec::new(),
437                },
438            ));
439        }
440
441        let tx_event = context.tx().ok();
442        let tool_results = process_tool_calls_with_hooks(
443            hooks,
444            context,
445            task.submission_id,
446            tools,
447            &tool_calls,
448            &tx_event,
449        )
450        .await;
451
452        memory
453            .store_tool_interaction(&tool_calls, &tool_results, &response_text)
454            .await;
455        record_tool_calls_state(context, &tool_results);
456
457        let _ = tx
458            .send(Ok(TurnDelta::ToolResults(tool_results.clone())))
459            .await;
460
461        Ok(crate::agent::executor::TurnResult::Continue(Some(
462            TurnEngineOutput {
463                response: response_text,
464                tool_calls: tool_results,
465            },
466        )))
467    }
468
469    async fn get_llm_response(
470        &self,
471        context: &Context,
472        messages: &[ChatMessage],
473        tools: &[Box<dyn ToolT>],
474    ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, TurnEngineError> {
475        let llm = context.llm();
476        let output_schema = context.config().output_schema.clone();
477
478        if matches!(self.config.tool_mode, ToolMode::Enabled) && !tools.is_empty() {
479            let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
480            llm.chat_with_tools(messages, Some(&tools_serialized), output_schema)
481                .await
482                .map_err(|e| TurnEngineError::LLMError(e.to_string()))
483        } else {
484            llm.chat(messages, output_schema)
485                .await
486                .map_err(|e| TurnEngineError::LLMError(e.to_string()))
487        }
488    }
489
490    async fn get_structured_stream(
491        &self,
492        context: &Context,
493        messages: &[ChatMessage],
494    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>, TurnEngineError>
495    {
496        context
497            .llm()
498            .chat_stream_struct(messages, None, context.config().output_schema.clone())
499            .await
500            .map_err(|e| TurnEngineError::LLMError(e.to_string()))
501    }
502
503    async fn get_tool_stream(
504        &self,
505        context: &Context,
506        messages: &[ChatMessage],
507        tools: &[Box<dyn ToolT>],
508    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, TurnEngineError>
509    {
510        let tools_serialized: Vec<Tool> = tools.iter().map(to_llm_tool).collect();
511        context
512            .llm()
513            .chat_stream_with_tools(
514                messages,
515                if tools_serialized.is_empty() {
516                    None
517                } else {
518                    Some(&tools_serialized)
519                },
520                context.config().output_schema.clone(),
521            )
522            .await
523            .map_err(|e| TurnEngineError::LLMError(e.to_string()))
524    }
525
526    async fn build_messages(
527        &self,
528        context: &Context,
529        task: &Task,
530        memory: &MemoryAdapter,
531        include_user_prompt: bool,
532    ) -> Vec<ChatMessage> {
533        let system_prompt = task
534            .system_prompt
535            .as_deref()
536            .unwrap_or(&context.config().description);
537        let mut messages = vec![ChatMessage {
538            role: ChatRole::System,
539            message_type: MessageType::Text,
540            content: system_prompt.to_string(),
541        }];
542
543        let recalled = memory.recall_messages(task).await;
544        messages.extend(recalled);
545
546        if include_user_prompt {
547            messages.push(user_message(task));
548        }
549
550        messages
551    }
552}
553
554pub fn record_task_state(context: &Context, task: &Task) {
555    let state = context.state();
556    #[cfg(not(target_arch = "wasm32"))]
557    if let Ok(mut guard) = state.try_lock() {
558        guard.record_task(task.clone());
559    };
560    #[cfg(target_arch = "wasm32")]
561    if let Some(mut guard) = state.try_lock() {
562        guard.record_task(task.clone());
563    };
564}
565
566fn user_message(task: &Task) -> ChatMessage {
567    if let Some((mime, image_data)) = &task.image {
568        ChatMessage {
569            role: ChatRole::User,
570            message_type: MessageType::Image(((*mime).into(), image_data.clone())),
571            content: task.prompt.clone(),
572        }
573    } else {
574        ChatMessage {
575            role: ChatRole::User,
576            message_type: MessageType::Text,
577            content: task.prompt.clone(),
578        }
579    }
580}
581
582fn should_include_user_prompt(memory: &MemoryAdapter, stored_user: bool) -> bool {
583    if !memory.is_enabled() {
584        return true;
585    }
586    if !memory.policy().recall {
587        return true;
588    }
589    if !memory.policy().store_user {
590        return true;
591    }
592    !stored_user
593}
594
595fn should_store_user(turn_state: &TurnState) -> bool {
596    if !turn_state.memory.is_enabled() {
597        return false;
598    }
599    if !turn_state.memory.policy().store_user {
600        return false;
601    }
602    !turn_state.stored_user
603}
604
/// Resolves the effective turn limit.
///
/// A `max_turns` of `0` means "not specified" and selects `fallback`, raised
/// to at least `1`. Any other value is returned unchanged.
fn normalize_max_turns(max_turns: usize, fallback: usize) -> usize {
    match max_turns {
        0 => fallback.max(1),
        explicit => explicit,
    }
}
611
612fn record_tool_calls_state(context: &Context, tool_results: &[ToolCallResult]) {
613    if tool_results.is_empty() {
614        return;
615    }
616    let state = context.state();
617    #[cfg(not(target_arch = "wasm32"))]
618    if let Ok(mut guard) = state.try_lock() {
619        for result in tool_results {
620            guard.record_tool_call(result.clone());
621        }
622    };
623    #[cfg(target_arch = "wasm32")]
624    if let Some(mut guard) = state.try_lock() {
625        for result in tool_results {
626            guard.record_tool_call(result.clone());
627        }
628    };
629}
630
631async fn process_tool_calls_with_hooks<H: AgentHooks>(
632    hooks: &H,
633    context: &Context,
634    submission_id: SubmissionId,
635    tools: &[Box<dyn ToolT>],
636    tool_calls: &[ToolCall],
637    tx_event: &Option<mpsc::Sender<Event>>,
638) -> Vec<ToolCallResult> {
639    let mut results = Vec::new();
640    for call in tool_calls {
641        if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
642            hooks,
643            context,
644            submission_id,
645            tools,
646            call,
647            tx_event,
648        )
649        .await
650        {
651            results.push(result);
652        }
653    }
654    results
655}