//! autoagents_core/agent/executor/turn_engine.rs
//!
//! Shared turn engine used by agent executors.

1use crate::agent::Context;
2use crate::agent::executor::event_helper::EventHelper;
3use crate::agent::executor::memory_policy::{MemoryAdapter, MemoryPolicy};
4use crate::agent::executor::tool_processor::ToolProcessor;
5use crate::agent::hooks::AgentHooks;
6use crate::agent::task::Task;
7use crate::channel::{Sender, channel};
8use crate::tool::{ToolCallResult, ToolT, to_llm_tool};
9use crate::utils::{receiver_into_stream, spawn_future};
10use autoagents_llm::ToolCall;
11use autoagents_llm::chat::{ChatMessage, ChatRole, MessageType, StreamChunk, StreamResponse};
12use autoagents_llm::error::LLMError;
13use autoagents_protocol::{Event, SubmissionId};
14use futures::{Stream, StreamExt};
15use serde_json::Value;
16use std::collections::HashSet;
17use std::pin::Pin;
18use std::sync::Arc;
19use thiserror::Error;
20
21#[cfg(not(target_arch = "wasm32"))]
22use tokio::sync::mpsc;
23
24#[cfg(target_arch = "wasm32")]
25use futures::channel::mpsc;
26
27/// Defines if tools are enabled for a given execution plan.
28#[derive(Debug, Clone, Copy)]
29pub enum ToolMode {
30    Enabled,
31    Disabled,
32}
33
34/// Defines which streaming primitive to use.
35#[derive(Debug, Clone, Copy)]
36pub enum StreamMode {
37    Structured,
38    Tool,
39}
40
41/// Configuration for the shared executor engine.
/// Configuration for the shared executor engine.
#[derive(Debug, Clone)]
pub struct TurnEngineConfig {
    /// Default turn budget, used when a caller passes `max_turns == 0` to the engine.
    pub max_turns: usize,
    /// Whether tool calls returned by the LLM are executed.
    pub tool_mode: ToolMode,
    /// Which streaming primitive `run_turn_stream` uses.
    pub stream_mode: StreamMode,
    /// How memory recall/storage behaves during a run.
    pub memory_policy: MemoryPolicy,
}
49
50impl TurnEngineConfig {
51    pub fn basic(max_turns: usize) -> Self {
52        Self {
53            max_turns,
54            tool_mode: ToolMode::Disabled,
55            stream_mode: StreamMode::Structured,
56            memory_policy: MemoryPolicy::basic(),
57        }
58    }
59
60    pub fn react(max_turns: usize) -> Self {
61        Self {
62            max_turns,
63            tool_mode: ToolMode::Enabled,
64            stream_mode: StreamMode::Tool,
65            memory_policy: MemoryPolicy::react(),
66        }
67    }
68}
69
70/// Normalized output emitted by the engine for a single turn.
/// Normalized output emitted by the engine for a single turn.
#[derive(Debug, Clone)]
pub struct TurnEngineOutput {
    /// Assistant text produced during the turn.
    pub response: String,
    /// Accumulated reasoning/thinking text, if the provider emitted any.
    pub reasoning_content: String,
    /// Results of tool calls executed during the turn (empty when no tools ran).
    pub tool_calls: Vec<ToolCallResult>,
}
77
78/// Streaming deltas emitted per turn.
/// Streaming deltas emitted per turn.
#[derive(Debug)]
pub enum TurnDelta {
    /// Incremental assistant text.
    Text(String),
    /// Incremental reasoning/thinking text.
    ReasoningContent(String),
    /// Results of the tool calls executed at the end of a tool turn.
    ToolResults(Vec<ToolCallResult>),
    /// Terminal item carrying the turn's final result.
    Done(crate::agent::executor::TurnResult<TurnEngineOutput>),
}
86
/// Errors surfaced by the turn engine.
#[derive(Error, Debug)]
pub enum TurnEngineError {
    /// Propagated failure from the underlying LLM provider.
    #[error("LLM error: {0}")]
    LLMError(
        #[from]
        #[source]
        LLMError,
    ),

    /// A hook requested that the run be aborted.
    #[error("Run aborted by hook")]
    Aborted,

    /// Catch-all for non-LLM failures.
    #[error("Other error: {0}")]
    Other(String),
}
102
103/// Per-run state for the turn engine.
/// Per-run state for the turn engine.
#[derive(Clone)]
pub struct TurnState {
    // Memory adapter wrapping the context's memory with the engine's policy.
    memory: MemoryAdapter,
    // True once the user prompt has been persisted to memory for this run.
    stored_user: bool,
}
109
110impl TurnState {
111    pub fn new(context: &Context, policy: MemoryPolicy) -> Self {
112        Self {
113            memory: MemoryAdapter::new(context.memory(), policy),
114            stored_user: false,
115        }
116    }
117
118    pub fn memory(&self) -> &MemoryAdapter {
119        &self.memory
120    }
121
122    pub fn stored_user(&self) -> bool {
123        self.stored_user
124    }
125
126    fn mark_user_stored(&mut self) {
127        self.stored_user = true;
128    }
129}
130
131/// Shared turn engine that handles memory, tools, and events consistently.
/// Shared turn engine that handles memory, tools, and events consistently.
#[derive(Debug, Clone)]
pub struct TurnEngine {
    // Static configuration: turn budget, tool/stream modes, memory policy.
    config: TurnEngineConfig,
}
136
impl TurnEngine {
    /// Creates an engine with the given static configuration.
    pub fn new(config: TurnEngineConfig) -> Self {
        Self { config }
    }

    /// Creates per-run state for a context, seeded with this engine's
    /// memory policy.
    pub fn turn_state(&self, context: &Context) -> TurnState {
        TurnState::new(context, self.config.memory_policy.clone())
    }

    /// Runs a single non-streaming turn: builds messages, calls the LLM,
    /// optionally executes tool calls, and updates memory and shared state.
    ///
    /// Returns `TurnResult::Continue` when tools ran (the caller should loop
    /// for another turn) and `TurnResult::Complete` when the model produced
    /// a final answer.
    pub async fn run_turn<H: AgentHooks>(
        &self,
        hooks: &H,
        task: &Task,
        context: &Context,
        turn_state: &mut TurnState,
        turn_index: usize,
        max_turns: usize,
    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
        // A zero turn budget falls back to the engine's configured default.
        let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
        let tx_event = context.tx().ok();
        EventHelper::send_turn_started(
            &tx_event,
            task.submission_id,
            context.config().id,
            turn_index,
            max_turns,
        )
        .await;

        hooks.on_turn_start(turn_index, context).await;

        let include_user_prompt =
            should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
        let messages = self
            .build_messages(context, task, turn_state.memory(), include_user_prompt)
            .await;
        let store_user = should_store_user(turn_state);

        let tools = context.tools();
        let response = self.get_llm_response(context, &messages, tools).await?;
        let response_text = response.text().unwrap_or_default();
        let reasoning_content = response.thinking().unwrap_or_default();
        // Persist the user prompt only after the LLM call succeeded (`?`
        // above), so a failed call leaves memory untouched.
        if store_user {
            turn_state.memory.store_user(task).await;
            turn_state.mark_user_stored();
        }

        // Tool calls from the response are honored only when this engine is
        // configured with tools enabled.
        let tool_calls = if matches!(self.config.tool_mode, ToolMode::Enabled) {
            response.tool_calls().unwrap_or_default()
        } else {
            Vec::new()
        };

        if !tool_calls.is_empty() {
            let tool_results = process_tool_calls_with_hooks(
                hooks,
                context,
                task.submission_id,
                tools,
                &tool_calls,
                &tx_event,
            )
            .await;

            turn_state
                .memory
                .store_tool_interaction(&tool_calls, &tool_results, &response_text)
                .await;
            record_tool_calls_state(context, &tool_results);

            // `false`: tools ran, so this is not the final turn.
            EventHelper::send_turn_completed(
                &tx_event,
                task.submission_id,
                context.config().id,
                turn_index,
                false,
            )
            .await;
            hooks.on_turn_complete(turn_index, context).await;

            return Ok(crate::agent::executor::TurnResult::Continue(Some(
                TurnEngineOutput {
                    response: response_text,
                    reasoning_content,
                    tool_calls: tool_results,
                },
            )));
        }

        if !response_text.is_empty() {
            turn_state.memory.store_assistant(&response_text).await;
        }

        // `true`: no tools ran, so the turn produced the final answer.
        EventHelper::send_turn_completed(
            &tx_event,
            task.submission_id,
            context.config().id,
            turn_index,
            true,
        )
        .await;
        hooks.on_turn_complete(turn_index, context).await;

        Ok(crate::agent::executor::TurnResult::Complete(
            TurnEngineOutput {
                response: response_text,
                reasoning_content,
                tool_calls: Vec::new(),
            },
        ))
    }

    /// Runs a single streaming turn on a background task, returning a stream
    /// of `TurnDelta` items that ends with `TurnDelta::Done` (or an error).
    ///
    /// NOTE(review): `mark_user_stored` is set here, before the spawned task
    /// actually writes the user message to memory inside the stream helpers;
    /// if the stream fails early, the flag may be set without a stored
    /// message — confirm this is the intended trade-off.
    pub async fn run_turn_stream<H>(
        &self,
        hooks: H,
        task: &Task,
        context: Arc<Context>,
        turn_state: &mut TurnState,
        turn_index: usize,
        max_turns: usize,
    ) -> Result<
        Pin<Box<dyn Stream<Item = Result<TurnDelta, TurnEngineError>> + Send>>,
        TurnEngineError,
    >
    where
        H: AgentHooks + Clone + Send + Sync + 'static,
    {
        let max_turns = normalize_max_turns(max_turns, self.config.max_turns);
        let include_user_prompt =
            should_include_user_prompt(turn_state.memory(), turn_state.stored_user());
        let messages = self
            .build_messages(&context, task, turn_state.memory(), include_user_prompt)
            .await;
        let store_user = should_store_user(turn_state);
        if store_user {
            turn_state.mark_user_stored();
        }

        // Channel bridging the spawned producer task to the returned stream.
        let (mut tx, rx) = channel::<Result<TurnDelta, TurnEngineError>>(100);
        let engine = self.clone();
        let context_clone = context.clone();
        let task = task.clone();
        let hooks = hooks.clone();
        let memory = turn_state.memory.clone();
        let messages = messages.clone();

        spawn_future(async move {
            let tx_event = context_clone.tx().ok();
            EventHelper::send_turn_started(
                &tx_event,
                task.submission_id,
                context_clone.config().id,
                turn_index,
                max_turns,
            )
            .await;
            hooks.on_turn_start(turn_index, &context_clone).await;

            // Dispatch to the streaming primitive selected by configuration.
            let result = match engine.config.stream_mode {
                StreamMode::Structured => {
                    engine
                        .stream_structured(
                            &context_clone,
                            &task,
                            &memory,
                            &mut tx,
                            &messages,
                            store_user,
                        )
                        .await
                }
                StreamMode::Tool => {
                    engine
                        .stream_with_tools(
                            &hooks,
                            &context_clone,
                            &task,
                            context_clone.tools(),
                            &memory,
                            &mut tx,
                            &messages,
                            store_user,
                        )
                        .await
                }
            };

            match result {
                Ok(turn_result) => {
                    let final_turn =
                        matches!(turn_result, crate::agent::executor::TurnResult::Complete(_));
                    EventHelper::send_turn_completed(
                        &tx_event,
                        task.submission_id,
                        context_clone.config().id,
                        turn_index,
                        final_turn,
                    )
                    .await;
                    hooks.on_turn_complete(turn_index, &context_clone).await;
                    // Terminal item; send errors are ignored (receiver may be gone).
                    let _ = tx.send(Ok(TurnDelta::Done(turn_result))).await;
                }
                Err(err) => {
                    let _ = tx.send(Err(err)).await;
                }
            }
        });

        Ok(receiver_into_stream(rx))
    }

    /// Streams a structured (text/reasoning-only) response, forwarding each
    /// delta to both the caller channel and the event bus, then stores the
    /// accumulated assistant text in memory. Always completes the turn.
    async fn stream_structured(
        &self,
        context: &Context,
        task: &Task,
        memory: &MemoryAdapter,
        tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
        messages: &[ChatMessage],
        store_user: bool,
    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
        let mut stream = self.get_structured_stream(context, messages).await?;
        // Store the user message only after the stream was created
        // successfully, mirroring the non-streaming path.
        if store_user {
            memory.store_user(task).await;
        }
        let mut response_text = String::default();
        let mut reasoning_content = String::default();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result.map_err(TurnEngineError::LLMError)?;
            // Only the first choice's delta is consumed.
            let delta = chunk.choices.first().map(|choice| &choice.delta);
            let content = delta
                .and_then(|d| d.content.as_ref())
                .map(String::as_str)
                .unwrap_or("")
                .to_string();
            let reasoning = delta
                .and_then(|d| d.reasoning_content.as_ref())
                .map(String::as_str)
                .unwrap_or("")
                .to_string();

            let tx_event = context.tx().ok();
            if !content.is_empty() {
                response_text.push_str(&content);
                let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;
                EventHelper::send_stream_chunk(
                    &tx_event,
                    task.submission_id,
                    StreamChunk::Text(content),
                )
                .await;
            }
            if !reasoning.is_empty() {
                reasoning_content.push_str(&reasoning);
                let _ = tx
                    .send(Ok(TurnDelta::ReasoningContent(reasoning.clone())))
                    .await;
                EventHelper::send_stream_chunk(
                    &tx_event,
                    task.submission_id,
                    StreamChunk::ReasoningContent(reasoning),
                )
                .await;
            }
        }

        if !response_text.is_empty() {
            memory.store_assistant(&response_text).await;
        }

        Ok(crate::agent::executor::TurnResult::Complete(
            TurnEngineOutput {
                response: response_text,
                reasoning_content,
                tool_calls: Vec::default(),
            },
        ))
    }

    /// Streams a tool-capable response: text/reasoning deltas are forwarded
    /// as they arrive, completed tool calls are deduplicated by id and then
    /// executed after the stream ends; their results are stored in memory
    /// and sent to the caller. Returns `Continue` when tools ran, otherwise
    /// `Complete`.
    #[allow(clippy::too_many_arguments)]
    async fn stream_with_tools<H: AgentHooks>(
        &self,
        hooks: &H,
        context: &Context,
        task: &Task,
        tools: &[Box<dyn ToolT>],
        memory: &MemoryAdapter,
        tx: &mut Sender<Result<TurnDelta, TurnEngineError>>,
        messages: &[ChatMessage],
        store_user: bool,
    ) -> Result<crate::agent::executor::TurnResult<TurnEngineOutput>, TurnEngineError> {
        let mut stream = self.get_tool_stream(context, messages, tools).await?;
        // Store the user message only after the stream was created
        // successfully, mirroring the non-streaming path.
        if store_user {
            memory.store_user(task).await;
        }
        let mut response_text = String::default();
        let mut reasoning_content = String::default();
        let mut tool_calls = Vec::default();
        // Guards against duplicate ToolUseComplete chunks for the same call id.
        let mut tool_call_ids: HashSet<String> = HashSet::default();

        while let Some(chunk_result) = stream.next().await {
            let chunk = chunk_result.map_err(TurnEngineError::LLMError)?;
            let chunk_clone = chunk.clone();

            match chunk {
                StreamChunk::Text(content) => {
                    response_text.push_str(&content);
                    let _ = tx.send(Ok(TurnDelta::Text(content.clone()))).await;
                }
                StreamChunk::ReasoningContent(content) => {
                    reasoning_content.push_str(&content);
                    let _ = tx.send(Ok(TurnDelta::ReasoningContent(content))).await;
                }
                StreamChunk::ToolUseComplete { tool_call, .. } => {
                    if tool_call_ids.insert(tool_call.id.clone()) {
                        tool_calls.push(tool_call.clone());
                        let tx_event = context.tx().ok();
                        EventHelper::send_stream_tool_call(
                            &tx_event,
                            task.submission_id,
                            serde_json::to_value(tool_call).unwrap_or(Value::Null),
                        )
                        .await;
                    }
                }
                // Usage and other chunk kinds produce no caller-facing delta;
                // they are still forwarded as raw stream events below.
                StreamChunk::Usage(_) => {}
                _ => {}
            }

            // Every chunk is mirrored to the event bus unmodified.
            let tx_event = context.tx().ok();
            EventHelper::send_stream_chunk(&tx_event, task.submission_id, chunk_clone).await;
        }

        if tool_calls.is_empty() {
            if !response_text.is_empty() {
                memory.store_assistant(&response_text).await;
            }
            return Ok(crate::agent::executor::TurnResult::Complete(
                TurnEngineOutput {
                    response: response_text,
                    reasoning_content,
                    tool_calls: Vec::new(),
                },
            ));
        }

        let tx_event = context.tx().ok();
        let tool_results = process_tool_calls_with_hooks(
            hooks,
            context,
            task.submission_id,
            tools,
            &tool_calls,
            &tx_event,
        )
        .await;

        memory
            .store_tool_interaction(&tool_calls, &tool_results, &response_text)
            .await;
        record_tool_calls_state(context, &tool_results);

        let _ = tx
            .send(Ok(TurnDelta::ToolResults(tool_results.clone())))
            .await;

        Ok(crate::agent::executor::TurnResult::Continue(Some(
            TurnEngineOutput {
                response: response_text,
                reasoning_content,
                tool_calls: tool_results,
            },
        )))
    }

    /// Issues the non-streaming chat request, attaching tools only when the
    /// engine has tools enabled and the context provides at least one.
    /// Serialized tool definitions are reused from the context cache when
    /// available; otherwise they are serialized on the fly.
    async fn get_llm_response(
        &self,
        context: &Context,
        messages: &[ChatMessage],
        tools: &[Box<dyn ToolT>],
    ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, TurnEngineError> {
        let llm = context.llm();
        let output_schema = context.config().output_schema.clone();

        if matches!(self.config.tool_mode, ToolMode::Enabled) && !tools.is_empty() {
            let cached = context.serialized_tools();
            let tools_serialized = if let Some(cached) = cached {
                cached
            } else {
                Arc::new(tools.iter().map(to_llm_tool).collect::<Vec<_>>())
            };
            llm.chat_with_tools(messages, Some(&tools_serialized), output_schema)
                .await
                .map_err(TurnEngineError::LLMError)
        } else {
            llm.chat(messages, output_schema)
                .await
                .map_err(TurnEngineError::LLMError)
        }
    }

    /// Opens a structured-response stream (no tools attached).
    async fn get_structured_stream(
        &self,
        context: &Context,
        messages: &[ChatMessage],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>, TurnEngineError>
    {
        context
            .llm()
            .chat_stream_struct(messages, None, context.config().output_schema.clone())
            .await
            .map_err(TurnEngineError::LLMError)
    }

    /// Opens a tool-capable chunk stream, attaching the serialized tool list
    /// (cached on the context when possible) unless it is empty.
    async fn get_tool_stream(
        &self,
        context: &Context,
        messages: &[ChatMessage],
        tools: &[Box<dyn ToolT>],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, TurnEngineError>
    {
        let cached = context.serialized_tools();
        let tools_serialized = if let Some(cached) = cached {
            cached
        } else {
            Arc::new(tools.iter().map(to_llm_tool).collect::<Vec<_>>())
        };
        context
            .llm()
            .chat_stream_with_tools(
                messages,
                if tools_serialized.is_empty() {
                    None
                } else {
                    Some(&tools_serialized)
                },
                context.config().output_schema.clone(),
            )
            .await
            .map_err(TurnEngineError::LLMError)
    }

    /// Assembles the outgoing message list: a system prompt (task override
    /// or the agent description), recalled memory, and — when requested —
    /// the user prompt itself.
    async fn build_messages(
        &self,
        context: &Context,
        task: &Task,
        memory: &MemoryAdapter,
        include_user_prompt: bool,
    ) -> Vec<ChatMessage> {
        let system_prompt = task
            .system_prompt
            .as_deref()
            .unwrap_or_else(|| &context.config().description);
        let mut messages = vec![ChatMessage {
            role: ChatRole::System,
            message_type: MessageType::Text,
            content: system_prompt.to_string(),
        }];

        let recalled = memory.recall_messages(task).await;
        messages.extend(recalled);

        if include_user_prompt {
            messages.push(user_message(task));
        }

        messages
    }
}
606
/// Best-effort recording of the task into the shared agent state.
///
/// Uses `try_lock` and silently skips recording when the state lock is
/// contended, so the turn loop is never blocked. The wasm branch differs
/// only in the lock API's return type (`Option` vs `Result`).
pub fn record_task_state(context: &Context, task: &Task) {
    let state = context.state();
    #[cfg(not(target_arch = "wasm32"))]
    if let Ok(mut guard) = state.try_lock() {
        guard.record_task(task.clone());
    };
    #[cfg(target_arch = "wasm32")]
    if let Some(mut guard) = state.try_lock() {
        guard.record_task(task.clone());
    };
}
618
619fn user_message(task: &Task) -> ChatMessage {
620    if let Some((mime, image_data)) = &task.image {
621        ChatMessage {
622            role: ChatRole::User,
623            message_type: MessageType::Image(((*mime).into(), image_data.clone())),
624            content: task.prompt.clone(),
625        }
626    } else {
627        ChatMessage {
628            role: ChatRole::User,
629            message_type: MessageType::Text,
630            content: task.prompt.clone(),
631        }
632    }
633}
634
635fn should_include_user_prompt(memory: &MemoryAdapter, stored_user: bool) -> bool {
636    if !memory.is_enabled() {
637        return true;
638    }
639    if !memory.policy().recall {
640        return true;
641    }
642    if !memory.policy().store_user {
643        return true;
644    }
645    !stored_user
646}
647
648fn should_store_user(turn_state: &TurnState) -> bool {
649    if !turn_state.memory.is_enabled() {
650        return false;
651    }
652    if !turn_state.memory.policy().store_user {
653        return false;
654    }
655    !turn_state.stored_user
656}
657
/// Resolves the effective turn budget: a zero request falls back to the
/// engine default, clamped to at least one turn.
fn normalize_max_turns(max_turns: usize, fallback: usize) -> usize {
    match max_turns {
        0 => fallback.max(1),
        n => n,
    }
}
664
/// Best-effort recording of tool-call results into the shared agent state.
///
/// Like `record_task_state`, this uses `try_lock` and silently drops the
/// recording when the lock is contended; the wasm branch differs only in
/// the lock API's return type.
fn record_tool_calls_state(context: &Context, tool_results: &[ToolCallResult]) {
    if tool_results.is_empty() {
        return;
    }
    let state = context.state();
    #[cfg(not(target_arch = "wasm32"))]
    if let Ok(mut guard) = state.try_lock() {
        for result in tool_results {
            guard.record_tool_call(result.clone());
        }
    };
    #[cfg(target_arch = "wasm32")]
    if let Some(mut guard) = state.try_lock() {
        for result in tool_results {
            guard.record_tool_call(result.clone());
        }
    };
}
683
/// Executes each tool call sequentially through `ToolProcessor`, giving the
/// hooks a chance to observe each call; calls for which the processor
/// returns `None` are omitted from the results.
async fn process_tool_calls_with_hooks<H: AgentHooks>(
    hooks: &H,
    context: &Context,
    submission_id: SubmissionId,
    tools: &[Box<dyn ToolT>],
    tool_calls: &[ToolCall],
    tx_event: &Option<mpsc::Sender<Event>>,
) -> Vec<ToolCallResult> {
    let mut results = Vec::new();
    for call in tool_calls {
        if let Some(result) = ToolProcessor::process_single_tool_call_with_hooks(
            hooks,
            context,
            submission_id,
            tools,
            call,
            tx_event,
        )
        .await
        {
            results.push(result);
        }
    }
    results
}
709
710#[cfg(test)]
711mod tests {
712    use super::*;
713    use crate::agent::memory::{MemoryProvider, SlidingWindowMemory};
714    use crate::agent::task::Task;
715    use crate::agent::{AgentConfig, Context};
716    use crate::tests::{ConfigurableLLMProvider, StaticChatResponse};
717    use async_trait::async_trait;
718    use autoagents_llm::LLMProvider;
719    use autoagents_llm::ToolCall;
720    use autoagents_llm::chat::{StreamChoice, StreamChunk, StreamDelta, StreamResponse};
721    use autoagents_llm::error::GuardrailPhase;
722    use autoagents_protocol::ActorID;
723    use futures::StreamExt;
724
    /// Minimal test tool that always returns a fixed JSON value.
    #[derive(Debug)]
    struct LocalTool {
        name: String,
        output: serde_json::Value,
    }
730
    impl LocalTool {
        /// Creates a tool with the given name and canned output.
        fn new(name: &str, output: serde_json::Value) -> Self {
            Self {
                name: name.to_string(),
                output,
            }
        }
    }
739
    // Static metadata: fixed description and a permissive object schema.
    impl crate::tool::ToolT for LocalTool {
        fn name(&self) -> &str {
            &self.name
        }

        fn description(&self) -> &str {
            "local tool"
        }

        fn args_schema(&self) -> serde_json::Value {
            serde_json::json!({"type": "object"})
        }
    }
753
    // Execution ignores its arguments and returns the canned output.
    #[async_trait]
    impl crate::tool::ToolRuntime for LocalTool {
        async fn execute(
            &self,
            _args: serde_json::Value,
        ) -> Result<serde_json::Value, crate::tool::ToolCallError> {
            Ok(self.output.clone())
        }
    }
763
    /// LLM provider stub whose every chat entry point fails with a guardrail block.
    #[derive(Debug)]
    struct GuardrailRejectLLMProvider;
766
    /// Builds the canonical guardrail-block error used by the reject provider.
    fn guardrail_block_error() -> LLMError {
        LLMError::GuardrailBlocked {
            phase: GuardrailPhase::Input,
            guard: "prompt-injection".to_string().into(),
            rule_id: "prompt_injection_detected".to_string().into(),
            category: "prompt_injection".to_string().into(),
            severity: "high".to_string().into(),
            message: "detected suspicious instruction pattern: jailbreak"
                .to_string()
                .into(),
        }
    }
779
    // Every chat entry point — streaming or not, with or without tools —
    // fails with the same guardrail-block error.
    #[async_trait]
    impl autoagents_llm::chat::ChatProvider for GuardrailRejectLLMProvider {
        async fn chat(
            &self,
            _messages: &[ChatMessage],
            _json_schema: Option<autoagents_llm::chat::StructuredOutputFormat>,
        ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, LLMError> {
            Err(guardrail_block_error())
        }

        async fn chat_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<autoagents_llm::chat::StructuredOutputFormat>,
        ) -> Result<Box<dyn autoagents_llm::chat::ChatResponse>, LLMError> {
            Err(guardrail_block_error())
        }

        async fn chat_stream_struct(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<autoagents_llm::chat::StructuredOutputFormat>,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamResponse, LLMError>> + Send>>, LLMError>
        {
            Err(guardrail_block_error())
        }

        async fn chat_stream_with_tools(
            &self,
            _messages: &[ChatMessage],
            _tools: Option<&[autoagents_llm::chat::Tool]>,
            _json_schema: Option<autoagents_llm::chat::StructuredOutputFormat>,
        ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamChunk, LLMError>> + Send>>, LLMError>
        {
            Err(guardrail_block_error())
        }
    }
819
    // Completion is unused by these tests; returns an empty response.
    #[async_trait]
    impl autoagents_llm::completion::CompletionProvider for GuardrailRejectLLMProvider {
        async fn complete(
            &self,
            _req: &autoagents_llm::completion::CompletionRequest,
            _json_schema: Option<autoagents_llm::chat::StructuredOutputFormat>,
        ) -> Result<autoagents_llm::completion::CompletionResponse, LLMError> {
            Ok(autoagents_llm::completion::CompletionResponse {
                text: String::default(),
            })
        }
    }
832
    // Embedding is unused by these tests; returns no vectors.
    #[async_trait]
    impl autoagents_llm::embedding::EmbeddingProvider for GuardrailRejectLLMProvider {
        async fn embed(&self, _input: Vec<String>) -> Result<Vec<Vec<f32>>, LLMError> {
            Ok(Vec::new())
        }
    }
839
    // Default trait methods are sufficient for these tests.
    #[async_trait]
    impl autoagents_llm::models::ModelsProvider for GuardrailRejectLLMProvider {}
842
    // Marker impl combining the provider traits implemented above.
    impl LLMProvider for GuardrailRejectLLMProvider {}
844
    /// Builds a test context with a fixed agent config and a 20-message
    /// sliding-window memory.
    fn context_with_memory(llm: Arc<dyn LLMProvider>) -> Context {
        let config = AgentConfig {
            id: ActorID::new_v4(),
            name: "memory_agent".to_string(),
            description: "desc".to_string(),
            output_schema: None,
        };
        let memory: Box<dyn MemoryProvider> = Box::new(SlidingWindowMemory::new(20));
        Context::new(llm, None)
            .with_config(config)
            .with_memory(Some(Arc::new(tokio::sync::Mutex::new(memory))))
    }
857
    /// Reads back everything currently stored in the context's memory.
    async fn recalled_messages(context: &Context) -> Vec<ChatMessage> {
        let memory = context.memory().expect("memory should exist");
        memory
            .lock()
            .await
            .recall("", None)
            .await
            .expect("memory recall should succeed")
    }
867
    /// `basic` config: tools disabled, structured streaming, recall enabled.
    #[test]
    fn test_turn_engine_config_basic() {
        let config = TurnEngineConfig::basic(5);
        assert_eq!(config.max_turns, 5);
        assert!(matches!(config.tool_mode, ToolMode::Disabled));
        assert!(matches!(config.stream_mode, StreamMode::Structured));
        assert!(config.memory_policy.recall);
    }
876
    /// `react` config: tools enabled, tool streaming, recall enabled.
    #[test]
    fn test_turn_engine_config_react() {
        let config = TurnEngineConfig::react(10);
        assert_eq!(config.max_turns, 10);
        assert!(matches!(config.tool_mode, ToolMode::Enabled));
        assert!(matches!(config.stream_mode, StreamMode::Tool));
        assert!(config.memory_policy.recall);
    }
885
    /// A failing LLM call must surface the error and leave memory empty:
    /// the user prompt is only stored after a successful response.
    #[tokio::test]
    async fn test_run_turn_llm_error_does_not_store_user_message() {
        use crate::tests::MockAgentImpl;

        let llm: Arc<dyn LLMProvider> = Arc::new(GuardrailRejectLLMProvider);
        let context = context_with_memory(llm);
        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
        let mut turn_state = engine.turn_state(&context);
        let task = Task::new("jailbreak");
        let hooks = MockAgentImpl::new("test", "test");

        let result = engine
            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 1)
            .await;
        assert!(matches!(
            result,
            Err(TurnEngineError::LLMError(LLMError::GuardrailBlocked { .. }))
        ));

        let stored = recalled_messages(&context).await;
        assert!(stored.is_empty());
    }
908
    /// A successful turn stores exactly one user message and one assistant
    /// message in memory.
    #[tokio::test]
    async fn test_run_turn_success_stores_user_once_in_memory() {
        use crate::tests::MockAgentImpl;

        let llm: Arc<dyn LLMProvider> = Arc::new(ConfigurableLLMProvider::default());
        let context = context_with_memory(llm);
        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
        let mut turn_state = engine.turn_state(&context);
        let task = Task::new("hello");
        let hooks = MockAgentImpl::new("test", "test");

        let result = engine
            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 1)
            .await;
        assert!(matches!(
            result,
            Ok(crate::agent::executor::TurnResult::Complete(_))
        ));

        let stored = recalled_messages(&context).await;
        let user_count = stored
            .iter()
            .filter(|m| m.role == ChatRole::User && m.content == "hello")
            .count();
        let assistant_count = stored
            .iter()
            .filter(|m| m.role == ChatRole::Assistant)
            .count();

        assert_eq!(user_count, 1);
        assert_eq!(assistant_count, 1);
    }
941
942    #[test]
943    fn test_normalize_max_turns_nonzero() {
944        assert_eq!(normalize_max_turns(5, 10), 5);
945    }
946
947    #[test]
948    fn test_normalize_max_turns_zero_uses_fallback() {
949        assert_eq!(normalize_max_turns(0, 10), 10);
950    }
951
952    #[test]
953    fn test_normalize_max_turns_zero_fallback_zero() {
954        assert_eq!(normalize_max_turns(0, 0), 1);
955    }
956
957    #[test]
958    fn test_should_include_user_prompt_no_memory() {
959        let adapter = MemoryAdapter::new(None, MemoryPolicy::basic());
960        assert!(should_include_user_prompt(&adapter, false));
961    }
962
963    #[test]
964    fn test_should_include_user_prompt_recall_disabled() {
965        let mut policy = MemoryPolicy::basic();
966        policy.recall = false;
967        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
968            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
969        let adapter = MemoryAdapter::new(
970            Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
971            policy,
972        );
973        assert!(should_include_user_prompt(&adapter, false));
974    }
975
976    #[test]
977    fn test_should_include_user_prompt_store_user_disabled() {
978        let mut policy = MemoryPolicy::basic();
979        policy.store_user = false;
980        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
981            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
982        let adapter = MemoryAdapter::new(
983            Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
984            policy,
985        );
986        assert!(should_include_user_prompt(&adapter, false));
987    }
988
989    #[test]
990    fn test_should_include_user_prompt_already_stored() {
991        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
992            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
993        let adapter = MemoryAdapter::new(
994            Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
995            MemoryPolicy::basic(),
996        );
997        // stored_user = true => should not include
998        assert!(!should_include_user_prompt(&adapter, true));
999    }
1000
1001    #[test]
1002    fn test_should_store_user_no_memory() {
1003        let state = TurnState {
1004            memory: MemoryAdapter::new(None, MemoryPolicy::basic()),
1005            stored_user: false,
1006        };
1007        assert!(!should_store_user(&state));
1008    }
1009
1010    #[test]
1011    fn test_should_store_user_already_stored() {
1012        let mem: Box<dyn crate::agent::memory::MemoryProvider> =
1013            Box::new(crate::agent::memory::SlidingWindowMemory::new(10));
1014        let state = TurnState {
1015            memory: MemoryAdapter::new(
1016                Some(std::sync::Arc::new(tokio::sync::Mutex::new(mem))),
1017                MemoryPolicy::basic(),
1018            ),
1019            stored_user: true,
1020        };
1021        assert!(!should_store_user(&state));
1022    }
1023
1024    #[test]
1025    fn test_user_message_text() {
1026        let task = Task::new("hello");
1027        let msg = user_message(&task);
1028        assert!(matches!(msg.role, ChatRole::User));
1029        assert!(matches!(msg.message_type, MessageType::Text));
1030        assert_eq!(msg.content, "hello");
1031    }
1032
1033    #[test]
1034    fn test_user_message_image() {
1035        let mut task = Task::new("describe");
1036        task.image = Some((autoagents_protocol::ImageMime::PNG, vec![1, 2, 3]));
1037        let msg = user_message(&task);
1038        assert!(matches!(msg.role, ChatRole::User));
1039        assert!(matches!(msg.message_type, MessageType::Image(_)));
1040    }
1041
1042    #[test]
1043    fn test_turn_state_new_and_mark_user_stored() {
1044        let config = AgentConfig {
1045            id: ActorID::new_v4(),
1046            name: "test".to_string(),
1047            description: "test".to_string(),
1048            output_schema: None,
1049        };
1050        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
1051        let context = Context::new(llm, None).with_config(config);
1052
1053        let mut state = TurnState::new(&context, MemoryPolicy::basic());
1054        assert!(!state.stored_user());
1055        state.mark_user_stored();
1056        assert!(state.stored_user());
1057    }
1058
1059    #[tokio::test]
1060    async fn test_build_messages_with_system_prompt() {
1061        let config = AgentConfig {
1062            id: ActorID::new_v4(),
1063            name: "test".to_string(),
1064            description: "default desc".to_string(),
1065            output_schema: None,
1066        };
1067        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
1068        let context = Context::new(llm, None).with_config(config);
1069
1070        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
1071        let adapter = MemoryAdapter::new(None, MemoryPolicy::basic());
1072        let mut task = Task::new("user input");
1073        task.system_prompt = Some("custom system".to_string());
1074
1075        let messages = engine.build_messages(&context, &task, &adapter, true).await;
1076        // System + user
1077        assert_eq!(messages.len(), 2);
1078        assert_eq!(messages[0].content, "custom system");
1079        assert_eq!(messages[0].role, ChatRole::System);
1080        assert_eq!(messages[1].content, "user input");
1081    }
1082
1083    #[tokio::test]
1084    async fn test_build_messages_without_user_prompt() {
1085        let config = AgentConfig {
1086            id: ActorID::new_v4(),
1087            name: "test".to_string(),
1088            description: "desc".to_string(),
1089            output_schema: None,
1090        };
1091        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
1092        let context = Context::new(llm, None).with_config(config);
1093
1094        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
1095        let adapter = MemoryAdapter::new(None, MemoryPolicy::basic());
1096        let task = Task::new("user input");
1097
1098        let messages = engine
1099            .build_messages(&context, &task, &adapter, false)
1100            .await;
1101        // Only system prompt
1102        assert_eq!(messages.len(), 1);
1103        assert_eq!(messages[0].role, ChatRole::System);
1104    }
1105
1106    #[tokio::test]
1107    async fn test_run_turn_no_tools_single_turn() {
1108        use crate::tests::MockAgentImpl;
1109        let config = AgentConfig {
1110            id: ActorID::new_v4(),
1111            name: "test".to_string(),
1112            description: "test desc".to_string(),
1113            output_schema: None,
1114        };
1115        let llm = std::sync::Arc::new(crate::tests::MockLLMProvider {});
1116        let context = Context::new(llm, None).with_config(config);
1117
1118        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
1119        let mut turn_state = engine.turn_state(&context);
1120        let task = Task::new("test prompt");
1121        let hooks = MockAgentImpl::new("test", "test");
1122
1123        let result = engine
1124            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 1)
1125            .await;
1126        assert!(result.is_ok());
1127        let turn_result = result.unwrap();
1128        assert!(matches!(
1129            turn_result,
1130            crate::agent::executor::TurnResult::Complete(_)
1131        ));
1132        if let crate::agent::executor::TurnResult::Complete(output) = turn_result {
1133            assert_eq!(output.response, "Mock response");
1134        }
1135    }
1136
1137    #[tokio::test]
1138    async fn test_run_turn_with_tool_calls_continues() {
1139        use crate::tests::MockAgentImpl;
1140        let tool_call = ToolCall {
1141            id: "call_1".to_string(),
1142            call_type: "function".to_string(),
1143            function: autoagents_llm::FunctionCall {
1144                name: "tool_a".to_string(),
1145                arguments: r#"{"value":1}"#.to_string(),
1146            },
1147        };
1148
1149        let llm = Arc::new(ConfigurableLLMProvider {
1150            chat_response: StaticChatResponse {
1151                text: Some("Use tool".to_string()),
1152                tool_calls: Some(vec![tool_call.clone()]),
1153                usage: None,
1154                thinking: None,
1155            },
1156            ..ConfigurableLLMProvider::default()
1157        });
1158
1159        let config = AgentConfig {
1160            id: ActorID::new_v4(),
1161            name: "tool_agent".to_string(),
1162            description: "desc".to_string(),
1163            output_schema: None,
1164        };
1165        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
1166        let context = Context::new(llm, None)
1167            .with_config(config)
1168            .with_tools(vec![Box::new(tool)]);
1169
1170        let engine = TurnEngine::new(TurnEngineConfig {
1171            max_turns: 2,
1172            tool_mode: ToolMode::Enabled,
1173            stream_mode: StreamMode::Structured,
1174            memory_policy: MemoryPolicy::basic(),
1175        });
1176        let mut turn_state = engine.turn_state(&context);
1177        let task = Task::new("prompt");
1178        let hooks = MockAgentImpl::new("test", "test");
1179
1180        let result = engine
1181            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 2)
1182            .await
1183            .unwrap();
1184
1185        match result {
1186            crate::agent::executor::TurnResult::Continue(Some(output)) => {
1187                assert_eq!(output.response, "Use tool");
1188                assert_eq!(output.tool_calls.len(), 1);
1189                assert!(output.tool_calls[0].success);
1190            }
1191            _ => panic!("expected Continue(Some)"),
1192        }
1193
1194        #[cfg(not(target_arch = "wasm32"))]
1195        if let Ok(state) = context.state().try_lock() {
1196            assert_eq!(state.tool_calls.len(), 1);
1197        }
1198    }
1199
1200    #[tokio::test]
1201    async fn test_run_turn_tool_mode_disabled_ignores_tool_calls() {
1202        use crate::tests::MockAgentImpl;
1203        let tool_call = ToolCall {
1204            id: "call_1".to_string(),
1205            call_type: "function".to_string(),
1206            function: autoagents_llm::FunctionCall {
1207                name: "tool_a".to_string(),
1208                arguments: r#"{"value":1}"#.to_string(),
1209            },
1210        };
1211
1212        let llm = Arc::new(ConfigurableLLMProvider {
1213            chat_response: StaticChatResponse {
1214                text: Some("No tools".to_string()),
1215                tool_calls: Some(vec![tool_call]),
1216                usage: None,
1217                thinking: None,
1218            },
1219            ..ConfigurableLLMProvider::default()
1220        });
1221
1222        let config = AgentConfig {
1223            id: ActorID::new_v4(),
1224            name: "tool_agent".to_string(),
1225            description: "desc".to_string(),
1226            output_schema: None,
1227        };
1228        let context = Context::new(llm, None).with_config(config);
1229
1230        let engine = TurnEngine::new(TurnEngineConfig {
1231            max_turns: 1,
1232            tool_mode: ToolMode::Disabled,
1233            stream_mode: StreamMode::Structured,
1234            memory_policy: MemoryPolicy::basic(),
1235        });
1236        let mut turn_state = engine.turn_state(&context);
1237        let task = Task::new("prompt");
1238        let hooks = MockAgentImpl::new("test", "test");
1239
1240        let result = engine
1241            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 1)
1242            .await
1243            .unwrap();
1244
1245        match result {
1246            crate::agent::executor::TurnResult::Complete(output) => {
1247                assert_eq!(output.response, "No tools");
1248                assert!(output.tool_calls.is_empty());
1249            }
1250            _ => panic!("expected Complete"),
1251        }
1252    }
1253
1254    #[tokio::test]
1255    async fn test_run_turn_propagates_reasoning_content() {
1256        use crate::tests::MockAgentImpl;
1257
1258        let llm = Arc::new(ConfigurableLLMProvider {
1259            chat_response: StaticChatResponse {
1260                text: Some("answer".to_string()),
1261                tool_calls: None,
1262                usage: None,
1263                thinking: Some("reasoning".to_string()),
1264            },
1265            ..ConfigurableLLMProvider::default()
1266        });
1267
1268        let config = AgentConfig {
1269            id: ActorID::new_v4(),
1270            name: "reasoning_agent".to_string(),
1271            description: "desc".to_string(),
1272            output_schema: None,
1273        };
1274        let context = Context::new(llm, None).with_config(config);
1275        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
1276        let mut turn_state = engine.turn_state(&context);
1277        let task = Task::new("prompt");
1278        let hooks = MockAgentImpl::new("test", "test");
1279
1280        let result = engine
1281            .run_turn(&hooks, &task, &context, &mut turn_state, 0, 1)
1282            .await
1283            .unwrap();
1284
1285        match result {
1286            crate::agent::executor::TurnResult::Complete(output) => {
1287                assert_eq!(output.response, "answer");
1288                assert_eq!(output.reasoning_content, "reasoning");
1289            }
1290            _ => panic!("expected Complete"),
1291        }
1292    }
1293
1294    #[tokio::test]
1295    async fn test_run_turn_stream_structured_aggregates_text() {
1296        use crate::tests::MockAgentImpl;
1297        let llm = Arc::new(ConfigurableLLMProvider {
1298            structured_stream: vec![
1299                StreamResponse {
1300                    choices: vec![StreamChoice {
1301                        delta: StreamDelta {
1302                            content: Some("Hello ".to_string()),
1303                            reasoning_content: None,
1304                            tool_calls: None,
1305                        },
1306                    }],
1307                    usage: None,
1308                },
1309                StreamResponse {
1310                    choices: vec![StreamChoice {
1311                        delta: StreamDelta {
1312                            content: Some("world".to_string()),
1313                            reasoning_content: None,
1314                            tool_calls: None,
1315                        },
1316                    }],
1317                    usage: None,
1318                },
1319            ],
1320            ..ConfigurableLLMProvider::default()
1321        });
1322
1323        let config = AgentConfig {
1324            id: ActorID::new_v4(),
1325            name: "stream_agent".to_string(),
1326            description: "desc".to_string(),
1327            output_schema: None,
1328        };
1329        let context = Arc::new(Context::new(llm, None).with_config(config));
1330        let engine = TurnEngine::new(TurnEngineConfig {
1331            max_turns: 1,
1332            tool_mode: ToolMode::Disabled,
1333            stream_mode: StreamMode::Structured,
1334            memory_policy: MemoryPolicy::basic(),
1335        });
1336        let mut turn_state = engine.turn_state(&context);
1337        let task = Task::new("prompt");
1338        let hooks = MockAgentImpl::new("test", "test");
1339
1340        let mut stream = engine
1341            .run_turn_stream(hooks, &task, context, &mut turn_state, 0, 1)
1342            .await
1343            .unwrap();
1344
1345        let mut final_text = String::default();
1346        while let Some(delta) = stream.next().await {
1347            if let Ok(TurnDelta::Done(result)) = delta {
1348                final_text = match result {
1349                    crate::agent::executor::TurnResult::Complete(output) => output.response,
1350                    crate::agent::executor::TurnResult::Continue(Some(output)) => output.response,
1351                    crate::agent::executor::TurnResult::Continue(None) => String::default(),
1352                };
1353                break;
1354            }
1355        }
1356
1357        assert_eq!(final_text, "Hello world");
1358    }
1359
1360    #[tokio::test]
1361    async fn test_run_turn_stream_structured_emits_reasoning_content() {
1362        use crate::tests::MockAgentImpl;
1363        let llm = Arc::new(ConfigurableLLMProvider {
1364            structured_stream: vec![StreamResponse {
1365                choices: vec![StreamChoice {
1366                    delta: StreamDelta {
1367                        content: None,
1368                        reasoning_content: Some("think".to_string()),
1369                        tool_calls: None,
1370                    },
1371                }],
1372                usage: None,
1373            }],
1374            ..ConfigurableLLMProvider::default()
1375        });
1376
1377        let config = AgentConfig {
1378            id: ActorID::new_v4(),
1379            name: "stream_reasoning_agent".to_string(),
1380            description: "desc".to_string(),
1381            output_schema: None,
1382        };
1383        let context = Arc::new(Context::new(llm, None).with_config(config));
1384        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
1385        let mut turn_state = engine.turn_state(&context);
1386        let task = Task::new("prompt");
1387        let hooks = MockAgentImpl::new("test", "test");
1388
1389        let mut stream = engine
1390            .run_turn_stream(hooks, &task, context, &mut turn_state, 0, 1)
1391            .await
1392            .unwrap();
1393
1394        let mut saw_delta = false;
1395        let mut final_reasoning = String::default();
1396        while let Some(delta) = stream.next().await {
1397            match delta {
1398                Ok(TurnDelta::ReasoningContent(text)) => {
1399                    saw_delta = true;
1400                    assert_eq!(text, "think");
1401                }
1402                Ok(TurnDelta::Done(result)) => {
1403                    final_reasoning = match result {
1404                        crate::agent::executor::TurnResult::Complete(output) => {
1405                            output.reasoning_content
1406                        }
1407                        crate::agent::executor::TurnResult::Continue(Some(output)) => {
1408                            output.reasoning_content
1409                        }
1410                        crate::agent::executor::TurnResult::Continue(None) => String::default(),
1411                    };
1412                    break;
1413                }
1414                _ => {}
1415            }
1416        }
1417
1418        assert!(saw_delta);
1419        assert_eq!(final_reasoning, "think");
1420    }
1421
1422    #[tokio::test]
1423    async fn test_run_turn_stream_with_tools_executes_tools() {
1424        use crate::tests::MockAgentImpl;
1425        let tool_call = ToolCall {
1426            id: "call_1".to_string(),
1427            call_type: "function".to_string(),
1428            function: autoagents_llm::FunctionCall {
1429                name: "tool_a".to_string(),
1430                arguments: r#"{"value":1}"#.to_string(),
1431            },
1432        };
1433
1434        let llm = Arc::new(ConfigurableLLMProvider {
1435            stream_chunks: vec![
1436                StreamChunk::Text("thinking".to_string()),
1437                StreamChunk::ToolUseComplete {
1438                    index: 0,
1439                    tool_call: tool_call.clone(),
1440                },
1441                StreamChunk::Done {
1442                    stop_reason: "tool_use".to_string(),
1443                },
1444            ],
1445            ..ConfigurableLLMProvider::default()
1446        });
1447
1448        let config = AgentConfig {
1449            id: ActorID::new_v4(),
1450            name: "tool_stream_agent".to_string(),
1451            description: "desc".to_string(),
1452            output_schema: None,
1453        };
1454        let tool = LocalTool::new("tool_a", serde_json::json!({"ok": true}));
1455        let context = Arc::new(
1456            Context::new(llm, None)
1457                .with_config(config)
1458                .with_tools(vec![Box::new(tool)]),
1459        );
1460        let engine = TurnEngine::new(TurnEngineConfig {
1461            max_turns: 1,
1462            tool_mode: ToolMode::Enabled,
1463            stream_mode: StreamMode::Tool,
1464            memory_policy: MemoryPolicy::basic(),
1465        });
1466        let mut turn_state = engine.turn_state(&context);
1467        let task = Task::new("prompt");
1468        let hooks = MockAgentImpl::new("test", "test");
1469
1470        let mut stream = engine
1471            .run_turn_stream(hooks, &task, context, &mut turn_state, 0, 1)
1472            .await
1473            .unwrap();
1474
1475        let mut final_result = None;
1476        while let Some(delta) = stream.next().await {
1477            if let Ok(TurnDelta::Done(result)) = delta {
1478                final_result = Some(result);
1479                break;
1480            }
1481        }
1482
1483        match final_result.expect("done") {
1484            crate::agent::executor::TurnResult::Continue(Some(output)) => {
1485                assert_eq!(output.tool_calls.len(), 1);
1486                assert!(output.tool_calls[0].success);
1487            }
1488            _ => panic!("expected Continue(Some)"),
1489        }
1490    }
1491
1492    #[tokio::test]
1493    async fn test_run_turn_stream_llm_error_does_not_store_user_message() {
1494        use crate::tests::MockAgentImpl;
1495
1496        let llm: Arc<dyn LLMProvider> = Arc::new(GuardrailRejectLLMProvider);
1497        let context = Arc::new(context_with_memory(llm));
1498        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
1499        let mut turn_state = engine.turn_state(&context);
1500        let task = Task::new("jailbreak");
1501        let hooks = MockAgentImpl::new("test", "test");
1502
1503        let mut stream = engine
1504            .run_turn_stream(hooks, &task, context.clone(), &mut turn_state, 0, 1)
1505            .await
1506            .expect("stream should initialize");
1507
1508        let first = stream
1509            .next()
1510            .await
1511            .expect("stream should emit an error event");
1512        assert!(matches!(
1513            first,
1514            Err(TurnEngineError::LLMError(LLMError::GuardrailBlocked { .. }))
1515        ));
1516
1517        let stored = recalled_messages(&context).await;
1518        assert!(stored.is_empty());
1519    }
1520
1521    #[tokio::test]
1522    async fn test_run_turn_stream_success_stores_user_once_in_memory() {
1523        use crate::tests::MockAgentImpl;
1524
1525        let llm: Arc<dyn LLMProvider> = Arc::new(ConfigurableLLMProvider {
1526            structured_stream: vec![StreamResponse {
1527                choices: vec![StreamChoice {
1528                    delta: StreamDelta {
1529                        content: Some("hello".to_string()),
1530                        reasoning_content: None,
1531                        tool_calls: None,
1532                    },
1533                }],
1534                usage: None,
1535            }],
1536            ..ConfigurableLLMProvider::default()
1537        });
1538        let context = Arc::new(context_with_memory(llm));
1539        let engine = TurnEngine::new(TurnEngineConfig::basic(1));
1540        let mut turn_state = engine.turn_state(&context);
1541        let task = Task::new("hello");
1542        let hooks = MockAgentImpl::new("test", "test");
1543
1544        let mut stream = engine
1545            .run_turn_stream(hooks, &task, context.clone(), &mut turn_state, 0, 1)
1546            .await
1547            .expect("stream should initialize");
1548
1549        while let Some(delta) = stream.next().await {
1550            if matches!(delta, Ok(TurnDelta::Done(_))) {
1551                break;
1552            }
1553        }
1554
1555        let stored = recalled_messages(&context).await;
1556        let user_count = stored
1557            .iter()
1558            .filter(|m| m.role == ChatRole::User && m.content == "hello")
1559            .count();
1560        let assistant_count = stored
1561            .iter()
1562            .filter(|m| m.role == ChatRole::Assistant)
1563            .count();
1564
1565        assert_eq!(user_count, 1);
1566        assert_eq!(assistant_count, 1);
1567    }
1568}