// Source file: autoagents_core/agent/executor/turn_engine.rs

1use crate::agent::Context;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::memory_policy::{MemoryAdapter, MemoryPolicy};
4use crate::agent::executor::tool_processor::ToolProcessor;
5use crate::agent::hooks::AgentHooks;
6use crate::agent::task::Task;
7use crate::channel::{Sender, channel};
8use crate::tool::{ToolCallResult, ToolT, to_llm_tool};
9use crate::utils::{receiver_into_stream, spawn_future};
10use autoagents_llm::ToolCall;
11use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChunk, StreamResponse};
12use autoagents_llm::error::LLMError;
13use autoagents_protocol::{Event, SubmissionId};
14use futures::{Stream, StreamExt};
15use serde_json::Value;
16use std::collections::HashSet;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22use tokio::sync::mpsc;
23
24#[cfg(target_arch = "wasm32")]
25use futures::channel::mpsc;
26
/// Defines whether tool use is enabled for a given execution plan.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare modes with `==`
/// or use them as map keys instead of relying on `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ToolMode {
    /// Tool use is allowed for this plan.
    Enabled,
    /// Tool use is turned off for this plan.
    Disabled,
}
33
/// Defines which LLM streaming primitive the turn engine uses.
///
/// Derives `PartialEq`/`Eq`/`Hash` so callers can compare modes with `==`
/// or use them as map keys instead of relying on `matches!`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StreamMode {
    /// Stream text deltas through the structured streaming API; no tools are sent.
    Structured,
    /// Stream through the tool-aware streaming API, which can surface tool calls.
    Tool,
}
40
/// Configuration for the shared executor engine.
///
/// Use `TurnEngineConfig::basic` or `TurnEngineConfig::react` for the standard
/// presets, or fill in the fields directly for a custom combination.
#[derive(Debug, Clone)]
pub struct TurnEngineConfig {
    /// Fallback turn budget, used when a caller passes `max_turns == 0`
    /// (clamped to at least 1). The engine only reports this value in
    /// turn-started events; it does not enforce the limit itself.
    pub max_turns: usize,
    /// Whether tool use is enabled when requesting LLM responses and
    /// deciding whether returned tool calls are processed.
    pub tool_mode: ToolMode,
    /// Which streaming primitive is used for streamed turns.
    pub stream_mode: StreamMode,
    /// Memory store/recall policy applied to each run.
    pub memory_policy: MemoryPolicy,
}
49
50impl TurnEngineConfig {
51    pub fn basic(max_turns: usize) -> Self {
52        Self {
53            max_turns,
54            tool_mode: ToolMode::Disabled,
55            stream_mode: StreamMode::Structured,
56            memory_policy: MemoryPolicy::basic(),
57        }
58    }
59
60    pub fn react(max_turns: usize) -> Self {
61        Self {
62            max_turns,
63            tool_mode: ToolMode::Enabled,
64            stream_mode: StreamMode::Tool,
65            memory_policy: MemoryPolicy::react(),
66        }
67    }
68}
69
/// Normalized output emitted by the engine for a single turn.
#[derive(Debug, Clone)]
pub struct TurnEngineOutput {
    /// Assistant text produced during the turn; may be empty.
    pub response: String,
    /// Results of the tool calls executed during the turn; empty when the
    /// turn completed without running any tools.
    pub tool_calls: Vec<ToolCallResult>,
}
76
/// Streaming deltas emitted per turn.
#[derive(Debug)]
pub enum TurnDelta {
    /// An incremental piece of assistant text, forwarded as it arrives.
    Text(String),
    /// Results of the tools executed after the LLM stream finished
    /// (only sent by the tool streaming path).
    ToolResults(Vec<ToolCallResult>),
    /// Final outcome of the turn, sent after it finishes successfully; on
    /// failure an `Err` item is sent instead.
    Done(crate::agent::executor::TurnResult<TurnEngineOutput>),
}
84
/// Errors produced while running a turn.
#[derive(Error, Debug)]
pub enum TurnEngineError {
    /// The LLM request or stream failed; holds the underlying error rendered as text.
    #[error("LLM error: {0}")]
    LLMError(String),

    /// The run was aborted by an agent hook.
    ///
    /// NOTE(review): not constructed in this module; presumably raised by callers.
    #[error("Run aborted by hook")]
    Aborted,

    /// Any other failure, described by the message.
    #[error("Other error: {0}")]
    Other(String),
}
96
/// Per-run state for the turn engine.
///
/// Create one per run (via `TurnEngine::turn_state`) and pass the same value to
/// every turn so the user prompt is persisted to memory only once.
#[derive(Clone)]
pub struct TurnState {
    // Memory adapter bound to the context's memory and the engine's policy.
    memory: MemoryAdapter,
    // Whether the user prompt for this run has already been stored in memory.
    stored_user: bool,
}
103
104impl TurnState {
105    pub fn new(context: &Context, policy: MemoryPolicy) -> Self {
106        Self {
107            memory: MemoryAdapter::new(context.memory(), policy),
108            stored_user: false,
109        }
110    }
111
112    pub fn memory(&self) -> &MemoryAdapter {
113        &self.memory
114    }
115
116    pub fn stored_user(&self) -> bool {
117        self.stored_user
118    }
119
120    fn mark_user_stored(&mut self) {
121        self.stored_user = true;
122    }
123}
124
/// Shared turn engine that handles memory, tools, and events consistently.
///
/// The engine holds only its configuration; per-run data lives in `TurnState`.
#[derive(Debug, Clone)]
pub struct TurnEngine {
    config: TurnEngineConfig,
}
130
131impl TurnEngine {
132    pub fn new(config: TurnEngineConfig) -> Self {
133        Self { config }
134    }
135
136    pub fn turn_state(&self, context: &Context) -> TurnState {
137        TurnState::new(context, self.config.memory_policy.clone())
138    }
139
140    pub async fn run_turn<H: AgentHooks>(
141        &self,
142        hooks: &H,
143        task: &Task,
144        context: &Context,
145        turn_state: &mut TurnState,
146        turn_index: usize,
147        max_turns: usize,
148    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
149        let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
150        let tx_event = context.tx().ok();
151        EventHelper::send_turn_started(
152            &tx_event,
153            task.submission_id,
154            context.config().id,
155            turn_index,
156            max_turns,
157        )
158        .await;
159
160        hooks.on_turn_start(turn_index, context).await;
161
162        let include_user_prompt =
163            should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
164        let messages = self
165            .build_messages(context, task, turn_state.memory(), include_user_prompt)
166            .await;
167
168        if should_store_user(turn_state) {
169            turn_state.memory.store_user(task).await;
170            turn_state.mark_user_stored();
171        }
172
173        let tools = context.tools();
174        let response = self.get_llm_response(context, &messages, tools).await?;
175        let response_text = response.text().unwrap_or_default();
176
177        let tool_calls = if matches!(self.config.tool_mode, ToolMode::Enabled) {
178            response.tool_calls().unwrap_or_default()
179        } else {
180            Vec::new()
181        };
182
183        if !tool_calls.is_empty() {
184            let tool_results = process_tool_calls_with_hooks(
185                hooks,
186                context,
187                task.submission_id,
188                tools,
189                &tool_calls,
190                &tx_event,
191            )
192            .await;
193
194            turn_state
195                .memory
196                .store_tool_interaction(&tool_calls, &tool_results, &response_text)
197                .await;
198            record_tool_calls_state(context, &tool_results);
199
200            EventHelper::send_turn_completed(
201                &tx_event,
202                task.submission_id,
203                context.config().id,
204                turn_index,
205                false,
206            )
207            .await;
208            hooks.on_turn_complete(turn_index, context).await;
209
210            return Ok(crate::agent::executor::TurnResult::Continue(Some(
211                TurnEngineOutput {
212                    response: response_text,
213                    tool_calls: tool_results,
214                },
215            )));
216        }
217
218        if !response_text.is_empty() {
219            turn_state.memory.store_assistant(&response_text).await;
220        }
221
222        EventHelper::send_turn_completed(
223            &tx_event,
224            task.submission_id,
225            context.config().id,
226            turn_index,
227            true,
228        )
229        .await;
230        hooks.on_turn_complete(turn_index, context).await;
231
232        Ok(crate::agent::executor::TurnResult::Complete(
233            TurnEngineOutput {
234                response: response_text,
235                tool_calls: Vec::new(),
236            },
237        ))
238    }
239
240    pub async fn run_turn_stream<H>(
241        &self,
242        hooks: H,
243        task: &Task,
244        context: Arc<Context>,
245        turn_state: &mut TurnState,
246        turn_index: usize,
247        max_turns: usize,
248    ) -> Result<
249        Pin<Box<dyn Stream<Item = Result<TurnDelta, TurnEngineError>> + Send>>,
250        TurnEngineError,
251    >
252    where
253        H: AgentHooks + Clone + Send + Sync + 'static,
254    {
255        let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
256        let include_user_prompt =
257            should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
258        let messages = self
259            .build_messages(&context, task, turn_state.memory(), include_user_prompt)
260            .await;
261
262        if should_store_user(turn_state) {
263            turn_state.memory.store_user(task).await;
264            turn_state.mark_user_stored();
265        }
266
267        let (mut tx, rx) = channel::<Result<TurnDelta, TurnEngineError>>(100);
268        let engine = self.clone();
269        let context_clone = context.clone();
270        let task = task.clone();
271        let hooks = hooks.clone();
272        let memory = turn_state.memory.clone();
273        let messages = messages.clone();
274
275        spawn_future(async move {
276            let tx_event = context_clone.tx().ok();
277            EventHelper::send_turn_started(
278                &tx_event,
279                task.submission_id,
280                context_clone.config().id,
281                turn_index,
282                max_turns,
283            )
284            .await;
285            hooks.on_turn_start(turn_index, &context_clone).await;
286
287            let result = match engine.config.stream_mode {
288                StreamMode::Structured => {
289                    engine
290                        .stream_structured(&context_clone, &task, &memory, &mut tx, &messages)
291                        .await
292                }
293                StreamMode::Tool => {
294                    engine
295                        .stream_with_tools(
296                            &hooks,
297                            &context_clone,
298                            &task,
299                            context_clone.tools(),
300                            &memory,
301                            &mut tx,
302                            &messages,
303                        )
304                        .await
305                }
306            };
307
308            match result {
309                Ok(turn_result) => {
310                    let final_turn =
311                        matches!(turn_result, crate::agent::executor::TurnResult::Complete(_));
312                    EventHelper::send_turn_completed(
313                        &tx_event,
314                        task.submission_id,
315                        context_clone.config().id,
316                        turn_index,
317                        final_turn,
318                    )
319                    .await;
320                    hooks.on_turn_complete(turn_index, &context_clone).await;
321                    let _ = tx.send(Ok(TurnDelta::Done(turn_result))).await;
322                }
323                Err(err) => {
324                    let _ = tx.send(Err(err)).await;
325                }
326            }
327        });
328
329        Ok(receiver_into_stream(rx))
330    }
331
332    async fn stream_structured(
333        &self,
334        context: &Context,
335        task: &Task,
336        memory: &MemoryAdapter,
337        tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
338        messages: &[ChatMessage],
339    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
340        let mut stream = self.get_structured_stream(context, messages).await?;
341        let mut response_text = String::default();
342
343        while let Some(chunk_result) = stream.next().await {
344            let chunk = chunk_result.map_err(|e| TurnEngineError::LLMError(e.to_string()))?;
345            let content = chunk
346                .choices
347                .first()
348                .and_then(|choice| choice.delta.content.as_ref())
349                .map_or("", |value| value)
350                .to_string();
351
352            if content.is_empty() {
353                continue;
354            }
355
356            response_text.push_str(&content);
357
358            let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;
359
360            let tx_event = context.tx().ok();
361            EventHelper::send_stream_chunk(
362                &tx_event,
363                task.submission_id,
364                StreamChunk::Text(content),
365            )
366            .await;
367        }
368
369        if !response_text.is_empty() {
370            memory.store_assistant(&response_text).await;
371        }
372
373        Ok(crate::agent::executor::TurnResult::Complete(
374            TurnEngineOutput {
375                response: response_text,
376                tool_calls: Vec::default(),
377            },
378        ))
379    }
380
381    #[allow(clippy::too_many_arguments)]
382    async fn stream_with_tools<H: AgentHooks>(
383        &self,
384        hooks: &H,
385        context: &Context,
386        task: &Task,
387        tools: &[Box<dyn ToolT>],
388        memory: &MemoryAdapter,
389        tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
390        messages: &[ChatMessage],
391    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
392        let mut stream = self.get_tool_stream(context, messages, tools).await?;
393        let mut response_text = String::default();
394        let mut tool_calls = Vec::default();
395        let mut tool_call_ids: HashSet<String> = HashSet::default();
396
397        while let Some(chunk_result) = stream.next().await {
398            let chunk = chunk_result.map_err(|e| TurnEngineError::LLMError(e.to_string()))?;
399            let chunk_clone = chunk.clone();
400
401            match chunk {
402                StreamChunk::Text(content) => {
403                    response_text.push_str(&content);
404                    let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;
405                }
406                StreamChunk::ToolUseComplete { tool_call, .. } => {
407                    if tool_call_ids.insert(tool_call.id.clone()) {
408                        tool_calls.push(tool_call.clone());
409                        let tx_event = context.tx().ok();
410                        EventHelper::send_stream_tool_call(
411                            &tx_event,
412                            task.submission_id,
413                            serde_json::to_value(tool_call).unwrap_or(Value::Null),
414                        )
415                        .await;
416                    }
417                }
418                StreamChunk::Usage(_) => {}
419                _ => {}
420            }
421
422            let tx_event = context.tx().ok();
423            EventHelper::send_stream_chunk(&tx_event, task.submission_id, chunk_clone).await;
424        }
425
426        if tool_calls.is_empty() {
427            if !response_text.is_empty() {
428                memory.store_assistant(&response_text).await;
429            }
430            return Ok(crate::agent::executor::TurnResult::Complete(
431                TurnEngineOutput {
432                    response: response_text,
433                    tool_calls: Vec::new(),
434                },
435            ));
436        }
437
438        let tx_event = context.tx().ok();
439        let tool_results = process_tool_calls_with_hooks(
440            hooks,
441            context,
442            task.submission_id,
443            tools,
444            &tool_calls,
445            &tx_event,
446        )
447        .await;
448
449        memory
450            .store_tool_interaction(&tool_calls, &tool_results, &response_text)
451            .await;
452        record_tool_calls_state(context, &tool_results);
453
454        let _ = tx
455            .send(Ok(TurnDelta::ToolResults(tool_results.clone())))
456            .await;
457
458        Ok(crate::agent::executor::TurnResult::Continue(Some(
459            TurnEngineOutput {
460                response: response_text,
461                tool_calls: tool_results,
462            },
463        )))
464    }
465
466    async fn get_llm_response(
467        &self,
468        context: &Context,
469        messages: &[ChatMessage],
470        tools: &[Box<dyn ToolT>],
471    ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, TurnEngineError> {
472        let llm = context.llm();
473        let output_schema = context.config().output_schema.clone();
474
475        if matches!(self.config.tool_mode, ToolMode::Enabled) && !tools.is_empty() {
476            let cached = context.serialized_tools();
477            let tools_serialized = if let Some(cached) = cached {
478                cached
479            } else {
480                Arc::new(tools.iter().map(to_llm_tool).collect::<Vec<_>>())
481            };
482            llm.chat_with_tools(messages, Some(&tools_serialized), output_schema)
483                .await
484                .map_err(|e| TurnEngineError::LLMError(e.to_string()))
485        } else {
486            llm.chat(messages, output_schema)
487                .await
488                .map_err(|e| TurnEngineError::LLMError(e.to_string()))
489        }
490    }
491
492    async fn get_structured_stream(
493        &self,
494        context: &Context,
495        messages: &[ChatMessage],
496    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>, TurnEngineError>
497    {
498        context
499            .llm()
500            .chat_stream_struct(messages, None, context.config().output_schema.clone())
501            .await
502            .map_err(|e| TurnEngineError::LLMError(e.to_string()))
503    }
504
505    async fn get_tool_stream(
506        &self,
507        context: &Context,
508        messages: &[ChatMessage],
509        tools: &[Box<dyn ToolT>],
510    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, TurnEngineError>
511    {
512        let cached = context.serialized_tools();
513        let tools_serialized = if let Some(cached) = cached {
514            cached
515        } else {
516            Arc::new(tools.iter().map(to_llm_tool).collect::<Vec<_>>())
517        };
518        context
519            .llm()
520            .chat_stream_with_tools(
521                messages,
522                if tools_serialized.is_empty() {
523                    None
524                } else {
525                    Some(&tools_serialized)
526                },
527                context.config().output_schema.clone(),
528            )
529            .await
530            .map_err(|e| TurnEngineError::LLMError(e.to_string()))
531    }
532
533    async fn build_messages(
534        &self,
535        context: &Context,
536        task: &Task,
537        memory: &MemoryAdapter,
538        include_user_prompt: bool,
539    ) -> Vec<ChatMessage> {
540        let system_prompt = task
541            .system_prompt
542            .as_deref()
543            .unwrap_or_else(|| &context.config().description);
544        let mut messages = vec![ChatMessage {
545            role: ChatRole::System,
546            message_type: MessageType::Text,
547            content: system_prompt.to_string(),
548        }];
549
550        let recalled = memory.recall_messages(task).await;
551        messages.extend(recalled);
552
553        if include_user_prompt {
554            messages.push(user_message(task));
555        }
556
557        messages
558    }
559}
560
561pub fn record_task_state(context: &Context, task: &Task) {
562    let state = context.state();
563    #[cfg(not(target_arch = "wasm32"))]
564    if let Ok(mut guard) = state.try_lock() {
565        guard.record_task(task.clone());
566    };
567    #[cfg(target_arch = "wasm32")]
568    if let Some(mut guard) = state.try_lock() {
569        guard.record_task(task.clone());
570    };
571}
572
573fn user_message(task: &Task) -> ChatMessage {
574    if let Some((mime, image_data)) = &task.image {
575        ChatMessage {
576            role: ChatRole::User,
577            message_type: MessageType::Image(((*mime).into(), image_data.clone())),
578            content: task.prompt.clone(),
579        }
580    } else {
581        ChatMessage {
582            role: ChatRole::User,
583            message_type: MessageType::Text,
584            content: task.prompt.clone(),
585        }
586    }
587}
588
589fn should_include_user_prompt(memory: &MemoryAdapter, stored_user: bool) -> bool {
590    if !memory.is_enabled() {
591        return true;
592    }
593    if !memory.policy().recall {
594        return true;
595    }
596    if !memory.policy().store_user {
597        return true;
598    }
599    !stored_user
600}
601
602fn should_store_user(turn_state: &TurnState) -> bool {
603    if !turn_state.memory.is_enabled() {
604        return false;
605    }
606    if !turn_state.memory.policy().store_user {
607        return false;
608    }
609    !turn_state.stored_user
610}
611
/// Resolves the effective turn budget.
///
/// A non-zero `max_turns` is used as-is. `0` means "unspecified" and falls
/// back to `fallback`, clamped so the result is never below 1.
fn normalize_max_turns(max_turns: usize, fallback: usize) -> usize {
    match max_turns {
        0 => fallback.max(1),
        requested => requested,
    }
}
618
619fn record_tool_calls_state(context: &Context, tool_results: &[ToolCallResult]) {
620    if tool_results.is_empty() {
621        return;
622    }
623    let state = context.state();
624    #[cfg(not(target_arch = "wasm32"))]
625    if let Ok(mut guard) = state.try_lock() {
626        for result in tool_results {
627            guard.record_tool_call(result.clone());
628        }
629    };
630    #[cfg(target_arch = "wasm32")]
631    if let Some(mut guard) = state.try_lock() {
632        for result in tool_results {
633            guard.record_tool_call(result.clone());
634        }
635    };
636}
637
638async fn process_tool_calls_with_hooks<H: AgentHooks>(
639    hooks: &H,
640    context: &Context,
641    submission_id: SubmissionId,
642    tools: &[Box<dyn ToolT>],
643    tool_calls: &[ToolCall],
644    tx_event: &Option<mpsc::Sender<Event>>,
645) -> Vec<ToolCallResult> {
646    let mut results = Vec::new();
647    for call in tool_calls {
648        if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
649            hooks,
650            context,
651            submission_id,
652            tools,
653            call,
654            tx_event,
655        )
656        .await
657        {
658            results.push(result);
659        }
660    }
661    results
662}
663
664#[cfg(test)]
665mod tests {
666    use super::*;
667    use crate::agent::task::Task;
668    use crate::agent::{AgentConfig, Context};
669    use crate::tests::{ConfigurableLLMProvider, StaticChatResponse};
670    use async_trait::async_trait;
671    use autoagents_llm::ToolCall;
672    use autoagents_llm::chat::{StreamChoice, StreamChunk, StreamDelta, StreamResponse};
673    use autoagents_protocol::ActorID;
674    use futures::StreamExt;
675
676    #[derive(Debug)]
677    struct LocalTool {
678        name: String,
679        output: serde_json::Value,
680    }
681
682    impl LocalTool {
683        fn new(name: &str, output: serde_json::Value) -> Self {
684            Self {
685                name: name.to_string(),
686                output,
687            }
688        }
689    }
690
691    impl crate::tool::ToolT for LocalTool {
692        fn name(&self) -> &str {
693            &self.name
694        }
695
696        fn description(&self) -> &str {
697            "local tool"
698        }
699
700        fn args_schema(&self) -> serde_json::Value {
701            serde_json::json!({"type": "object"})
702        }
703    }
704
705    #[async_trait]
706    impl crate::tool::ToolRuntime for LocalTool {
707        async fn execute(
708            &self,
709            _args: serde_json::Value,
710        ) -> Result<serde_json::Value, crate::tool::ToolCallError> {
711            Ok(self.output.clone())
712        }
713    }
714
715    #[test]
716    fn test_turn_engine_config_basic() {
717        let config = TurnEngineConfig::basic(5);
718        assert_eq!(config.max_turns, 5);
719        assert!(matches!(config.tool_mode, ToolMode::Disabled));
720        assert!(matches!(config.stream_mode, StreamMode::Structured));
721        assert!(config.memory_policy.recall);
722    }
723
724    #[test]
725    fn test_turn_engine_config_react() {
726        let config = TurnEngineConfig::react(10);
727        assert_eq!(config.max_turns, 10);
728        assert!(matches!(config.tool_mode, ToolMode::Enabled));
729        assert!(matches!(config.stream_mode, StreamMode::Tool));
730        assert!(config.memory_policy.recall);
731    }
732
733    #[test]
734    fn test_normalize_max_turns_nonzero() {
735        assert_eq!(normalize_max_turns(5, 10), 5);
736    }
737
738    #[test]
739    fn test_normalize_max_turns_zero_uses_fallback() {
740        assert_eq!(normalize_max_turns(0, 10), 10);
741    }
742
743    #[test]
744    fn test_normalize_max_turns_zero_fallback_zero() {
745        assert_eq!(normalize_max_turns(0, 0), 1);
746    }
747
748    #[test]
749    fn test_should_include_user_prompt_no_memory() {
750        let adapter = MemoryAdapter::new(None, MemoryPolicy::basic());
751        assert!(should_include_user_prompt(&adapter, false));
752    }
753
754    #[test]
755    fn test_should_include_user_prompt_recall_disabled() {
756        let mut policy = MemoryPolicy::basic();
757        policy.recall = false;
758        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
759            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
760        let adapter = MemoryAdapter::new(
761            Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
762            policy,
763        );
764        assert!(should_include_user_prompt(&adapter, false));
765    }
766
767    #[test]
768    fn test_should_include_user_prompt_store_user_disabled() {
769        let mut policy = MemoryPolicy::basic();
770        policy.store_user = false;
771        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
772            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
773        let adapter = MemoryAdapter::new(
774            Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
775            policy,
776        );
777        assert!(should_include_user_prompt(&adapter, false));
778    }
779
780    #[test]
781    fn test_should_include_user_prompt_already_stored() {
782        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
783            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
784        let adapter = MemoryAdapter::new(
785            Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
786            MemoryPolicy::basic(),
787        );
788        // stored_user = true => should not include
789        assert!(!should_include_user_prompt(&adapter, true));
790    }
791
792    #[test]
793    fn test_should_store_user_no_memory() {
794        let state = TurnState {
795            memory: MemoryAdapter::new(None, MemoryPolicy::basic()),
796            stored_user: false,
797        };
798        assert!(!should_store_user(&state));
799    }
800
801    #[test]
802    fn test_should_store_user_already_stored() {
803        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
804            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
805        let state = TurnState {
806            memory: MemoryAdapter::new(
807                Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
808                MemoryPolicy::basic(),
809            ),
810            stored_user: true,
811        };
812        assert!(!should_store_user(&state));
813    }
814
815    #[test]
816    fn test_user_message_text() {
817        let task = Task::new("hello");
818        let msg = user_message(&task);
819        assert!(matches!(msg.role, ChatRole::User));
820        assert!(matches!(msg.message_type, MessageType::Text));
821        assert_eq!(msg.content, "hello");
822    }
823
824    #[test]
825    fn test_user_message_image() {
826        let mut task = Task::new("describe");
827        task.image = Some((autoagents_protocol::ImageMime::PNG, vec![1, 2, 3]));
828        let msg = user_message(&task);
829        assert!(matches!(msg.role, ChatRole::User));
830        assert!(matches!(msg.message_type, MessageType::Image(_)));
831    }
832
833    #[test]
834    fn test_turn_state_new_and_mark_user_stored() {
835        let config = AgentConfig {
836            id: ActorID::new_v4(),
837            name: "test".to_string(),
838            description: "test".to_string(),
839            output_schema: None,
840        };
841        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
842        let context = Context::new(llm, None).with_config(config);
843
844        let mut state = TurnState::new(&context, MemoryPolicy::basic());
845        assert!(!state.stored_user());
846        state.mark_user_stored();
847        assert!(state.stored_user());
848    }
849
850    #[tokio::test]
851    async fn test_build_messages_with_system_prompt() {
852        let config = AgentConfig {
853            id: ActorID::new_v4(),
854            name: "test".to_string(),
855            description: "default desc".to_string(),
856            output_schema: None,
857        };
858        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
859        let context = Context::new(llm, None).with_config(config);
860
861        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
862        let adapter = MemoryAdapter::new(None, MemoryPolicy::basic());
863        let mut task = Task::new("user input");
864        task.system_prompt = Some("custom system".to_string());
865
866        let messages = engine.build_messages(&context, &task, &adapter, true).await;
867        // System + user
868        assert_eq!(messages.len(), 2);
869        assert_eq!(messages[0].content, "custom system");
870        assert_eq!(messages[0].role, ChatRole::System);
871        assert_eq!(messages[1].content, "user input");
872    }
873
874    #[tokio::test]
875    async fn test_build_messages_without_user_prompt() {
876        let config = AgentConfig {
877            id: ActorID::new_v4(),
878            name: "test".to_string(),
879            description: "desc".to_string(),
880            output_schema: None,
881        };
882        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
883        let context = Context::new(llm, None).with_config(config);
884
885        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
886        let adapter = MemoryAdapter::new(None, MemoryPolicy::basic());
887        let task = Task::new("user input");
888
889        let messages = engine
890            .build_messages(&context, &task, &adapter, false)
891            .await;
892        // Only system prompt
893        assert_eq!(messages.len(), 1);
894        assert_eq!(messages[0].role, ChatRole::System);
895    }
896
897    #[tokio::test]
898    async fn test_run_turn_no_tools_single_turn() {
899        use crate::tests::MockAgentImpl;
900        let config = AgentConfig {
901            id: ActorID::new_v4(),
902            name: "test".to_string(),
903            description: "test desc".to_string(),
904            output_schema: None,
905        };
906        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
907        let context = Context::new(llm, None).with_config(config);
908
909        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
910        let mut turn_state = engine.turn_state(&context);
911        let task = Task::new("test prompt");
912        let hooks = MockAgentImpl::new("test", "test");
913
914        let result = engine
915            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 1)
916            .await;
917        assert!(result.is_ok());
918        let turn_result = result.unwrap();
919        assert!(matches!(
920            turn_result,
921            crate::agent::executor::TurnResult::Complete(_)
922        ));
923        if let crate::agent::executor::TurnResult::Complete(output) = turn_result {
924            assert_eq!(output.response, "Mock response");
925        }
926    }
927
928    #[tokio::test]
929    async fn test_run_turn_with_tool_calls_continues() {
930        use crate::tests::MockAgentImpl;
931        let tool_call = ToolCall {
932            id: "call_1".to_string(),
933            call_type: "function".to_string(),
934            function: autoagents_llm::FunctionCall {
935                name: "tool_a".to_string(),
936                arguments: r#"{"value":1}"#.to_string(),
937            },
938        };
939
940        let llm = Arc::new(ConfigurableLLMProvider {
941            chat_response: StaticChatResponse {
942                text: Some("Use tool".to_string()),
943                tool_calls: Some(vec![tool_call.clone()]),
944                usage: None,
945                thinking: None,
946            },
947            ..ConfigurableLLMProvider::default()
948        });
949
950        let config = AgentConfig {
951            id: ActorID::new_v4(),
952            name: "tool_agent".to_string(),
953            description: "desc".to_string(),
954            output_schema: None,
955        };
956        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
957        let context = Context::new(llm, None)
958            .with_config(config)
959            .with_tools(vec![Box::new(tool)]);
960
961        let engine = TurnEngine::new(TurnEngineConfig {
962            max_turns: 2,
963            tool_mode: ToolMode::Enabled,
964            stream_mode: StreamMode::Structured,
965            memory_policy: MemoryPolicy::basic(),
966        });
967        let mut turn_state = engine.turn_state(&context);
968        let task = Task::new("prompt");
969        let hooks = MockAgentImpl::new("test", "test");
970
971        let result = engine
972            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 2)
973            .await
974            .unwrap();
975
976        match result {
977            crate::agent::executor::TurnResult::Continue(Some(output)) => {
978                assert_eq!(output.response, "Use tool");
979                assert_eq!(output.tool_calls.len(), 1);
980                assert!(output.tool_calls[0].success);
981            }
982            _ => panic!("expected Continue(Some)"),
983        }
984
985        #[cfg(not(target_arch = "wasm32"))]
986        if let Ok(state) = context.state().try_lock() {
987            assert_eq!(state.tool_calls.len(), 1);
988        }
989    }
990
991    #[tokio::test]
992    async fn test_run_turn_tool_mode_disabled_ignores_tool_calls() {
993        use crate::tests::MockAgentImpl;
994        let tool_call = ToolCall {
995            id: "call_1".to_string(),
996            call_type: "function".to_string(),
997            function: autoagents_llm::FunctionCall {
998                name: "tool_a".to_string(),
999                arguments: r#"{"value":1}"#.to_string(),
1000            },
1001        };
1002
1003        let llm = Arc::new(ConfigurableLLMProvider {
1004            chat_response: StaticChatResponse {
1005                text: Some("No tools".to_string()),
1006                tool_calls: Some(vec![tool_call]),
1007                usage: None,
1008                thinking: None,
1009            },
1010            ..ConfigurableLLMProvider::default()
1011        });
1012
1013        let config = AgentConfig {
1014            id: ActorID::new_v4(),
1015            name: "tool_agent".to_string(),
1016            description: "desc".to_string(),
1017            output_schema: None,
1018        };
1019        let context = Context::new(llm, None).with_config(config);
1020
1021        let engine = TurnEngine::new(TurnEngineConfig {
1022            max_turns: 1,
1023            tool_mode: ToolMode::Disabled,
1024            stream_mode: StreamMode::Structured,
1025            memory_policy: MemoryPolicy::basic(),
1026        });
1027        let mut turn_state = engine.turn_state(&context);
1028        let task = Task::new("prompt");
1029        let hooks = MockAgentImpl::new("test", "test");
1030
1031        let result = engine
1032            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 1)
1033            .await
1034            .unwrap();
1035
1036        match result {
1037            crate::agent::executor::TurnResult::Complete(output) => {
1038                assert_eq!(output.response, "No tools");
1039                assert!(output.tool_calls.is_empty());
1040            }
1041            _ => panic!("expected Complete"),
1042        }
1043    }
1044
1045    #[tokio::test]
1046    async fn test_run_turn_stream_structured_aggregates_text() {
1047        use crate::tests::MockAgentImpl;
1048        let llm = Arc::new(ConfigurableLLMProvider {
1049            structured_stream: vec![
1050                StreamResponse {
1051                    choices: vec![StreamChoice {
1052                        delta: StreamDelta {
1053                            content: Some("Hello ".to_string()),
1054                            tool_calls: None,
1055                        },
1056                    }],
1057                    usage: None,
1058                },
1059                StreamResponse {
1060                    choices: vec![StreamChoice {
1061                        delta: StreamDelta {
1062                            content: Some("world".to_string()),
1063                            tool_calls: None,
1064                        },
1065                    }],
1066                    usage: None,
1067                },
1068            ],
1069            ..ConfigurableLLMProvider::default()
1070        });
1071
1072        let config = AgentConfig {
1073            id: ActorID::new_v4(),
1074            name: "stream_agent".to_string(),
1075            description: "desc".to_string(),
1076            output_schema: None,
1077        };
1078        let context = Arc::new(Context::new(llm, None).with_config(config));
1079        let engine = TurnEngine::new(TurnEngineConfig {
1080            max_turns: 1,
1081            tool_mode: ToolMode::Disabled,
1082            stream_mode: StreamMode::Structured,
1083            memory_policy: MemoryPolicy::basic(),
1084        });
1085        let mut turn_state = engine.turn_state(&context);
1086        let task = Task::new("prompt");
1087        let hooks = MockAgentImpl::new("test", "test");
1088
1089        let mut stream = engine
1090            .run_turn_stream(hooks, &task, context, &mut turn_state, 0, 1)
1091            .await
1092            .unwrap();
1093
1094        let mut final_text = String::default();
1095        while let Some(delta) = stream.next().await {
1096            if let Ok(TurnDelta::Done(result)) = delta {
1097                final_text = match result {
1098                    crate::agent::executor::TurnResult::Complete(output) => output.response,
1099                    crate::agent::executor::TurnResult::Continue(Some(output)) => output.response,
1100                    crate::agent::executor::TurnResult::Continue(None) => String::default(),
1101                };
1102                break;
1103            }
1104        }
1105
1106        assert_eq!(final_text, "Hello world");
1107    }
1108
1109    #[tokio::test]
1110    async fn test_run_turn_stream_with_tools_executes_tools() {
1111        use crate::tests::MockAgentImpl;
1112        let tool_call = ToolCall {
1113            id: "call_1".to_string(),
1114            call_type: "function".to_string(),
1115            function: autoagents_llm::FunctionCall {
1116                name: "tool_a".to_string(),
1117                arguments: r#"{"value":1}"#.to_string(),
1118            },
1119        };
1120
1121        let llm = Arc::new(ConfigurableLLMProvider {
1122            stream_chunks: vec![
1123                StreamChunk::Text("thinking".to_string()),
1124                StreamChunk::ToolUseComplete {
1125                    index: 0,
1126                    tool_call: tool_call.clone(),
1127                },
1128                StreamChunk::Done {
1129                    stop_reason: "tool_use".to_string(),
1130                },
1131            ],
1132            ..ConfigurableLLMProvider::default()
1133        });
1134
1135        let config = AgentConfig {
1136            id: ActorID::new_v4(),
1137            name: "tool_stream_agent".to_string(),
1138            description: "desc".to_string(),
1139            output_schema: None,
1140        };
1141        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
1142        let context = Arc::new(
1143            Context::new(llm, None)
1144                .with_config(config)
1145                .with_tools(vec![Box::new(tool)]),
1146        );
1147        let engine = TurnEngine::new(TurnEngineConfig {
1148            max_turns: 1,
1149            tool_mode: ToolMode::Enabled,
1150            stream_mode: StreamMode::Tool,
1151            memory_policy: MemoryPolicy::basic(),
1152        });
1153        let mut turn_state = engine.turn_state(&context);
1154        let task = Task::new("prompt");
1155        let hooks = MockAgentImpl::new("test", "test");
1156
1157        let mut stream = engine
1158            .run_turn_stream(hooks, &task, context, &mut turn_state, 0, 1)
1159            .await
1160            .unwrap();
1161
1162        let mut final_result = None;
1163        while let Some(delta) = stream.next().await {
1164            if let Ok(TurnDelta::Done(result)) = delta {
1165                final_result = Some(result);
1166                break;
1167            }
1168        }
1169
1170        match final_result.expect("done") {
1171            crate::agent::executor::TurnResult::Continue(Some(output)) => {
1172                assert_eq!(output.tool_calls.len(), 1);
1173                assert!(output.tool_calls[0].success);
1174            }
1175            _ => panic!("expected Continue(Some)"),
1176        }
1177    }
1178}