// motosan_agent_loop/loop_.rs
1use std::collections::HashMap;
2use std::sync::Arc;
3
4use futures::future::{join_all, try_join_all};
5use motosan_agent_tool::{Tool, ToolContext, ToolDef, ToolResult};
6use tokio::sync::mpsc;
7
8use crate::context::ContextProvider;
9use crate::error::AgentError;
10use crate::llm::{LlmClient, TokenUsage, ToolCallItem};
11use crate::message::{Message, ToolCallRef};
12use crate::Result;
13
/// Policy for handling channel backpressure when a bounded queue is full.
///
/// Applied to the per-turn operations channel; configured via
/// [`ChannelConfig::ops_backpressure`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub enum BackpressurePolicy {
    /// Block the sender until capacity is available (default).
    #[default]
    Block,
    /// Drop the **incoming** (newest) op when the queue is full.
    ///
    /// Note: true drop-oldest semantics would require a ring-buffer channel.
    /// This variant currently behaves like [`Reject`] but emits [`AgentEvent::OpDropped`]
    /// instead of [`AgentEvent::OpRejected`], allowing callers to distinguish intent.
    DropOldest,
    /// Reject the new item and emit a telemetry event.
    Reject,
}
29
/// Configuration for bounded channel capacities used by [`AgentSession`].
///
/// All capacities have sensible defaults (64) and can be overridden via the
/// builder pattern on [`AgentLoopBuilder`].
///
/// Note: `AgentLoopBuilder::build` panics if either capacity is zero, since
/// `tokio::sync::mpsc::channel(0)` panics at runtime.
#[derive(Debug, Clone)]
pub struct ChannelConfig {
    /// Capacity of the user-input channel (messages sent via `AgentSession::send`).
    /// Must be >= 1; default 64.
    pub input_capacity: usize,
    /// Capacity of the per-turn operations channel (ops like Interrupt, InjectHint).
    /// Must be >= 1; default 64.
    pub ops_capacity: usize,
    /// Backpressure policy applied when the ops channel is full.
    /// Default: [`BackpressurePolicy::Block`].
    pub ops_backpressure: BackpressurePolicy,
}
43
44impl Default for ChannelConfig {
45    fn default() -> Self {
46        Self {
47            input_capacity: 64,
48            ops_capacity: 64,
49            ops_backpressure: BackpressurePolicy::Block,
50        }
51    }
52}
53
54// ---------------------------------------------------------------------------
55// Pipeline stage helpers
56// ---------------------------------------------------------------------------
57
/// Resolved tool registry for a single run, combining static and dynamic tools.
///
/// This always clones the base maps so it can own the data. When `extra_tools`
/// is empty the clone is the only cost; when extra tools are present they are
/// inserted into the cloned maps.
struct MergedTools {
    /// Name -> tool lookup used to dispatch tool calls.
    map: HashMap<String, Arc<dyn Tool>>,
    /// Tool definitions advertised to the LLM on each chat call.
    defs: Vec<ToolDef>,
}
67
68impl MergedTools {
69    /// Clone base maps and merge any extra tools into them.
70    ///
71    /// Note: this always clones `base_map` and `base_defs` because `MergedTools`
72    /// owns its data. If `extra_tools` is empty the clone is the only cost.
73    fn new(
74        base_map: &HashMap<String, Arc<dyn Tool>>,
75        base_defs: &[ToolDef],
76        extra_tools: &[Arc<dyn Tool>],
77    ) -> Self {
78        if extra_tools.is_empty() {
79            return Self {
80                map: base_map.clone(),
81                defs: base_defs.to_vec(),
82            };
83        }
84        let mut map = base_map.clone();
85        let mut defs = base_defs.to_vec();
86        for t in extra_tools {
87            map.insert(t.def().name.clone(), Arc::clone(t));
88            defs.push(t.def());
89        }
90        Self { map, defs }
91    }
92
93    fn tool_map(&self) -> &HashMap<String, Arc<dyn Tool>> {
94        &self.map
95    }
96
97    fn tool_defs(&self) -> &[ToolDef] {
98        &self.defs
99    }
100}
101
/// Accumulator for per-run mutable state shared across turn iterations.
struct TurnState {
    /// Full conversation, grown with assistant/tool messages each turn.
    messages: Vec<Message>,
    /// Every tool call made this run: (tool_name, arguments).
    all_tool_calls: Vec<(String, serde_json::Value)>,
    /// Token usage accumulated across all LLM calls in this run.
    total_usage: TokenUsage,
}
108
109impl TurnState {
110    fn new(messages: Vec<Message>) -> Self {
111        Self {
112            messages,
113            all_tool_calls: Vec::new(),
114            total_usage: TokenUsage::default(),
115        }
116    }
117
118    /// Record token usage from an LLM call.
119    fn accumulate_usage(&mut self, usage: Option<TokenUsage>) {
120        if let Some(u) = usage {
121            self.total_usage.accumulate(u);
122        }
123    }
124
125    /// Build the final result after the LLM produces a text answer.
126    fn into_result(self, answer: String, iteration: usize) -> AgentResult {
127        AgentResult {
128            answer,
129            tool_calls: self.all_tool_calls,
130            iterations: iteration,
131            usage: self.total_usage,
132            messages: self.messages,
133        }
134    }
135}
136
137/// Stage: execute tool calls, emit events, and append messages.
138///
139/// This is the common "tool-call planning/execution + result merge" stage
140/// extracted from every turn loop variant.
141fn execute_and_record_tool_calls(
142    items: &[ToolCallItem],
143    results: Vec<ToolResult>,
144    state: &mut TurnState,
145    on_event: &(impl Fn(AgentEvent) + Send + Sync),
146) {
147    // Emit ToolCompleted events and record tool calls.
148    for (tc, result) in items.iter().zip(results.iter()) {
149        on_event(AgentEvent::ToolCompleted {
150            name: tc.name.clone(),
151            result: result.clone(),
152        });
153        state
154            .all_tool_calls
155            .push((tc.name.clone(), tc.args.clone()));
156    }
157
158    // Build a single assistant message with all tool call refs.
159    let tool_call_refs: Vec<ToolCallRef> = items
160        .iter()
161        .map(|tc| ToolCallRef {
162            id: tc.id.clone(),
163            name: tc.name.clone(),
164            args: tc.args.clone(),
165        })
166        .collect();
167    state
168        .messages
169        .push(Message::assistant_with_tool_calls("", tool_call_refs));
170
171    // Append individual tool result messages (order matches tool call order).
172    for (tc, result) in items.iter().zip(results.iter()) {
173        state
174            .messages
175            .push(Message::tool_result(&tc.id, &tool_result_to_string(result)));
176    }
177}
178
179/// Stage: emit ToolStarted events for all items in a batch.
180fn emit_tool_started(items: &[ToolCallItem], on_event: &(impl Fn(AgentEvent) + Send + Sync)) {
181    for tc in items {
182        on_event(AgentEvent::ToolStarted {
183            name: tc.name.clone(),
184        });
185    }
186}
187
/// Events emitted by the agent loop for observability.
///
/// Delivered through the `on_event` callback passed to the `run*` methods.
#[derive(Debug, Clone)]
pub enum AgentEvent {
    /// A tool is about to be executed.
    ToolStarted { name: String },
    /// A tool execution completed.
    ToolCompleted { name: String, result: ToolResult },
    /// A text chunk from the LLM's streaming response.
    TextChunk(String),
    /// The full accumulated text after streaming completes.
    TextDone(String),
    /// Emitted at the start of each reasoning iteration (counting from 1).
    IterationStarted(usize),
    /// Loop was interrupted via `AgentOp::Interrupt`.
    Interrupted,
    /// Agent requested user input via the built-in `ask_user` tool.
    AskUser {
        /// Identifier correlating this request with an `AgentOp::AskUserAnswer`.
        call_id: String,
        /// The question posed to the user.
        question: String,
        /// Suggested answer options.
        options: Vec<String>,
    },
    /// A pending `ask_user` request timed out.
    AskUserTimeout { call_id: String, question: String },
    /// The ops channel reached its capacity limit.
    ///
    /// Emitted once when the queue becomes full, not on every blocked/rejected send.
    OpsSaturated {
        /// Current queue capacity.
        capacity: usize,
    },
    /// An operation was dropped due to backpressure policy
    /// (see [`BackpressurePolicy::DropOldest`]).
    OpDropped {
        /// Description of the dropped operation.
        reason: String,
    },
    /// An operation was rejected due to backpressure policy
    /// (see [`BackpressurePolicy::Reject`]).
    OpRejected {
        /// Description of the rejected operation.
        reason: String,
    },
}
229
/// Commands that can be sent into a running [`AgentLoop`].
///
/// Delivered via the ops channel accepted by
/// [`AgentLoop::run_with_ops`].
#[derive(Debug, Clone)]
pub enum AgentOp {
    /// Stop the current turn at the next safe checkpoint.
    Interrupt,
    /// Append a user message before the next LLM iteration.
    InjectUserMessage(String),
    /// Append a user-visible note before the next LLM iteration.
    InjectHint(String),
    /// Answer a pending `ask_user` request.
    AskUserAnswer {
        /// Id from [`AgentEvent::AskUser`]. `None` presumably targets the
        /// current pending request — confirm against the ops handler.
        call_id: Option<String>,
        /// The user's answer text.
        answer: String,
    },
}
245
/// The final outcome produced by [`AgentLoop::run`].
#[derive(Debug, Clone)]
pub struct AgentResult {
    /// The assistant's final textual answer.
    pub answer: String,
    /// History of tool calls made: (tool_name, arguments).
    pub tool_calls: Vec<(String, serde_json::Value)>,
    /// Number of LLM round-trips performed (the iteration that produced
    /// the final answer).
    pub iterations: usize,
    /// Accumulated token usage across all LLM calls.
    pub usage: TokenUsage,
    /// Full conversation history including tool call/result pairs.
    ///
    /// Callers can pass this to a subsequent `run()` call to continue a
    /// multi-turn conversation.
    pub messages: Vec<Message>,
}
263
/// Builder for constructing an [`AgentLoop`] with validated configuration.
///
/// Obtain one via [`AgentLoop::builder`]; finalize with
/// [`build`](Self::build), which validates channel capacities.
pub struct AgentLoopBuilder {
    /// Tools registered via `tool()` / `tools()`.
    tools: Vec<Arc<dyn Tool>>,
    /// Context providers registered via `context()` / `contexts()` /
    /// `system_prompt()`.
    context_providers: Vec<Box<dyn ContextProvider>>,
    /// Maximum LLM round-trips per run.
    max_iterations: usize,
    /// Optional per-tool-call timeout; `None` means no timeout.
    tool_timeout: Option<std::time::Duration>,
    /// Optional custom tool context; defaults to `ToolContext::default()`.
    tool_context: Option<ToolContext>,
    /// Whether the built-in `ask_user` tool is enabled.
    ask_user_enabled: bool,
    /// Timeout for `ask_user`; only meaningful when `ask_user_enabled`.
    ask_user_timeout: Option<std::time::Duration>,
    /// Bounded-channel capacities and backpressure policy.
    channel_config: ChannelConfig,
    /// MCP servers to connect lazily on `run()`.
    #[cfg(feature = "mcp-client")]
    mcp_servers: Vec<Arc<dyn crate::mcp::McpServer>>,
}
277
278impl AgentLoopBuilder {
279    /// Set the maximum number of LLM round-trips before aborting.
280    pub fn max_iterations(mut self, n: usize) -> Self {
281        self.max_iterations = n;
282        self
283    }
284
285    /// Register a tool with the agent loop.
286    pub fn tool(mut self, tool: Arc<dyn Tool>) -> Self {
287        self.tools.push(tool);
288        self
289    }
290
291    /// Register multiple tools at once.
292    pub fn tools(mut self, tools: impl IntoIterator<Item = Arc<dyn Tool>>) -> Self {
293        self.tools.extend(tools);
294        self
295    }
296
297    /// Set a static system prompt that is injected as a `System` message
298    /// at the beginning of every conversation.
299    ///
300    /// This is a convenience shortcut for registering a [`ContextProvider`]
301    /// that always returns the given string.
302    pub fn system_prompt(self, prompt: impl Into<String>) -> Self {
303        self.context(crate::context::StringContextProvider(prompt.into()))
304    }
305
306    /// Register a context provider that injects dynamic context into the
307    /// conversation before each [`AgentLoop::run`] call.
308    ///
309    /// Multiple providers can be registered; they are invoked in registration
310    /// order and each non-empty result becomes a `System` message.
311    pub fn context(mut self, provider: impl ContextProvider + 'static) -> Self {
312        self.context_providers.push(Box::new(provider));
313        self
314    }
315
316    /// Register multiple context providers at once.
317    ///
318    /// This is the batch equivalent of [`context()`](Self::context), mirroring
319    /// the [`tools()`](Self::tools) / [`tool()`](Self::tool) pattern.
320    pub fn contexts(
321        mut self,
322        providers: impl IntoIterator<Item = Box<dyn ContextProvider>>,
323    ) -> Self {
324        self.context_providers.extend(providers);
325        self
326    }
327
328    /// Register an MCP server whose tools will be connected lazily on `run()`.
329    ///
330    /// Requires the `mcp-client` feature.
331    #[cfg(feature = "mcp-client")]
332    pub fn mcp_server(mut self, server: impl crate::mcp::McpServer + 'static) -> Self {
333        self.mcp_servers.push(std::sync::Arc::new(server));
334        self
335    }
336
337    /// Register a pre-shared `Arc<dyn McpServer>` whose tools will be connected
338    /// lazily on `run()`.
339    ///
340    /// This is useful when the same MCP server instance must be shared between
341    /// multiple [`AgentLoop`]s or retained by the caller.
342    ///
343    /// Requires the `mcp-client` feature.
344    #[cfg(feature = "mcp-client")]
345    pub fn mcp_server_arc(mut self, server: std::sync::Arc<dyn crate::mcp::McpServer>) -> Self {
346        self.mcp_servers.push(server);
347        self
348    }
349
350    /// Set a per-tool-call timeout.  When a tool takes longer than `duration`
351    /// it returns a `ToolResult::error("timed out")` instead of blocking.
352    ///
353    /// By default there is no timeout.
354    pub fn tool_timeout(mut self, duration: std::time::Duration) -> Self {
355        self.tool_timeout = Some(duration);
356        self
357    }
358
359    /// Set a custom [`ToolContext`] that will be passed to every tool invocation.
360    ///
361    /// By default, [`ToolContext::default()`] is used. Use this when tools need
362    /// access to session-level state such as a working directory or environment
363    /// variables.
364    pub fn tool_context(mut self, ctx: ToolContext) -> Self {
365        self.tool_context = Some(ctx);
366        self
367    }
368
369    /// Set the full channel configuration for bounded queues.
370    ///
371    /// See [`ChannelConfig`] for details on each field.
372    pub fn channel_config(mut self, config: ChannelConfig) -> Self {
373        self.channel_config = config;
374        self
375    }
376
377    /// Set the capacity of the user-input channel.
378    ///
379    /// Default: 64.
380    pub fn input_channel_capacity(mut self, capacity: usize) -> Self {
381        self.channel_config.input_capacity = capacity;
382        self
383    }
384
385    /// Set the capacity of the per-turn operations channel.
386    ///
387    /// Default: 64.
388    pub fn ops_channel_capacity(mut self, capacity: usize) -> Self {
389        self.channel_config.ops_capacity = capacity;
390        self
391    }
392
393    /// Set the backpressure policy for the operations channel.
394    ///
395    /// Default: [`BackpressurePolicy::Block`].
396    pub fn ops_backpressure(mut self, policy: BackpressurePolicy) -> Self {
397        self.channel_config.ops_backpressure = policy;
398        self
399    }
400
401    /// Register the built-in `ask_user` tool with the default timeout (30s).
402    pub fn with_ask_user(mut self) -> Self {
403        self.ask_user_enabled = true;
404        if self.ask_user_timeout.is_none() {
405            self.ask_user_timeout = Some(std::time::Duration::from_secs(30));
406        }
407        self
408    }
409
410    /// Register the built-in `ask_user` tool with a custom timeout.
411    pub fn with_ask_user_timeout(mut self, timeout: std::time::Duration) -> Self {
412        self.ask_user_enabled = true;
413        self.ask_user_timeout = Some(timeout);
414        self
415    }
416
417    /// Consume the builder and produce an [`AgentLoop`].
418    ///
419    /// # Panics
420    ///
421    /// Panics if `input_capacity` or `ops_capacity` is zero, since
422    /// `tokio::sync::mpsc::channel(0)` panics at runtime.
423    pub fn build(self) -> AgentLoop {
424        assert!(
425            self.channel_config.input_capacity > 0,
426            "input_capacity must be >= 1"
427        );
428        assert!(
429            self.channel_config.ops_capacity > 0,
430            "ops_capacity must be >= 1"
431        );
432        let tool_map: HashMap<String, Arc<dyn Tool>> = self
433            .tools
434            .iter()
435            .map(|t| (t.def().name.clone(), Arc::clone(t)))
436            .collect();
437        let mut tool_defs: Vec<ToolDef> = self.tools.iter().map(|t| t.def()).collect();
438        if self.ask_user_enabled {
439            tool_defs.push(ask_user_tool_def());
440        }
441        AgentLoop {
442            tool_map,
443            tool_defs,
444            context_providers: self.context_providers,
445            max_iterations: self.max_iterations,
446            tool_timeout: self.tool_timeout,
447            tool_context: self.tool_context.unwrap_or_default(),
448            ask_user_timeout: if self.ask_user_enabled {
449                self.ask_user_timeout
450            } else {
451                None
452            },
453            channel_config: self.channel_config,
454            #[cfg(feature = "mcp-client")]
455            mcp_servers: self.mcp_servers,
456        }
457    }
458}
459
/// The core ReAct agent loop.
///
/// Drives an LLM through iterative reasoning and tool execution until the
/// model produces a final text answer or the iteration limit is reached.
///
/// The loop does **not** own the LLM client; instead, `run()` takes
/// `&dyn LlmClient` so the same loop can be reused with different backends.
pub struct AgentLoop {
    /// Pre-built lookup map for static tools (excludes per-run MCP tools).
    tool_map: HashMap<String, Arc<dyn Tool>>,
    /// Pre-built tool definitions for static tools (excludes per-run MCP tools).
    tool_defs: Vec<ToolDef>,
    /// Providers whose output is prepended as `System` messages on each run.
    context_providers: Vec<Box<dyn ContextProvider>>,
    /// Maximum LLM round-trips per run (10 by default via the builder).
    max_iterations: usize,
    /// Optional per-tool-call timeout.  When set, any tool that takes longer
    /// returns `ToolResult::error("timed out")` rather than blocking indefinitely.
    pub(crate) tool_timeout: Option<std::time::Duration>,
    /// Context passed to every tool invocation.
    tool_context: ToolContext,
    /// Optional timeout for the built-in `ask_user` tool.
    /// `None` when the tool is disabled.
    ask_user_timeout: Option<std::time::Duration>,
    /// Channel configuration for bounded queues and backpressure.
    channel_config: ChannelConfig,
    /// MCP servers connected lazily at the start of each run.
    #[cfg(feature = "mcp-client")]
    mcp_servers: Vec<Arc<dyn crate::mcp::McpServer>>,
}
486
487impl AgentLoop {
488    /// Create a builder with default settings.
489    pub fn builder() -> AgentLoopBuilder {
490        AgentLoopBuilder {
491            tools: Vec::new(),
492            context_providers: Vec::new(),
493            max_iterations: 10,
494            tool_timeout: None,
495            tool_context: None,
496            ask_user_enabled: false,
497            ask_user_timeout: None,
498            channel_config: ChannelConfig::default(),
499            #[cfg(feature = "mcp-client")]
500            mcp_servers: Vec::new(),
501        }
502    }
503
    /// Returns the channel configuration for this agent loop.
    ///
    /// Used by [`AgentSession`](crate::AgentSession) to create bounded
    /// channels with the configured capacities and policies.
    /// See [`ChannelConfig`] for the individual fields.
    pub fn channel_config(&self) -> &ChannelConfig {
        &self.channel_config
    }
511
512    /// Run the agent loop to completion.
513    ///
514    /// Iteratively calls the LLM, executes any requested tools, and feeds
515    /// tool results back into the conversation until the LLM produces a
516    /// final text response or `max_iterations` is exceeded.
517    ///
518    /// The `on_event` callback is invoked for each notable event (tool
519    /// started, tool completed, text chunk). Pass `|_| {}` for a no-op.
520    pub async fn run(
521        &self,
522        llm: &dyn LlmClient,
523        messages: Vec<Message>,
524        on_event: impl Fn(AgentEvent) + Send + Sync,
525    ) -> Result<AgentResult> {
526        // Connect MCP servers and collect their tools (mcp-client feature).
527        #[cfg(feature = "mcp-client")]
528        let mcp_tools: Vec<Arc<dyn Tool>> = self.connect_mcp_servers().await?;
529        #[cfg(not(feature = "mcp-client"))]
530        let mcp_tools: Vec<Arc<dyn Tool>> = vec![];
531
532        let result = self.run_inner(llm, messages, &mcp_tools, &on_event).await;
533
534        // Disconnect all MCP servers (best-effort) regardless of run outcome.
535        #[cfg(feature = "mcp-client")]
536        for server in &self.mcp_servers {
537            let _ = server.disconnect().await;
538        }
539
540        result
541    }
542
543    /// Run with an optional interactive operations channel.
544    ///
545    /// Pass `ops_rx: None` for behavior equivalent to [`run`](Self::run).
546    pub async fn run_with_ops(
547        &self,
548        llm: &dyn LlmClient,
549        messages: Vec<Message>,
550        ops_rx: Option<mpsc::Receiver<AgentOp>>,
551        on_event: impl Fn(AgentEvent) + Send + Sync,
552    ) -> Result<AgentResult> {
553        #[cfg(feature = "mcp-client")]
554        let mcp_tools: Vec<Arc<dyn Tool>> = self.connect_mcp_servers().await?;
555        #[cfg(not(feature = "mcp-client"))]
556        let mcp_tools: Vec<Arc<dyn Tool>> = vec![];
557
558        let result = self
559            .run_inner_with_ops(llm, messages, &mcp_tools, &on_event, ops_rx)
560            .await;
561
562        #[cfg(feature = "mcp-client")]
563        for server in &self.mcp_servers {
564            let _ = server.disconnect().await;
565        }
566
567        result
568    }
569
570    /// Connect all MCP servers, collecting their tools.
571    ///
572    /// If any server fails to connect, all previously-connected servers are
573    /// disconnected (best-effort) before the error is returned.
574    #[cfg(feature = "mcp-client")]
575    async fn connect_mcp_servers(&self) -> Result<Vec<Arc<dyn Tool>>> {
576        use crate::mcp::adapter::McpToolAdapter;
577
578        let mut connected: Vec<&Arc<dyn crate::mcp::McpServer>> = Vec::new();
579        let mut tools: Vec<Arc<dyn Tool>> = Vec::new();
580
581        for server in &self.mcp_servers {
582            match server.connect().await {
583                Ok(()) => {
584                    connected.push(server);
585                    match McpToolAdapter::from_server(Arc::clone(server)).await {
586                        Ok(adapter_tools) => tools.extend(adapter_tools),
587                        Err(e) => {
588                            for s in &connected {
589                                let _ = s.disconnect().await;
590                            }
591                            return Err(e);
592                        }
593                    }
594                }
595                Err(e) => {
596                    for s in &connected {
597                        let _ = s.disconnect().await;
598                    }
599                    return Err(e);
600                }
601            }
602        }
603
604        Ok(tools)
605    }
606
607    /// Run the agent loop to completion with cancellation support.
608    ///
609    /// Behaves identically to [`run`](Self::run) but accepts a
610    /// [`CancellationToken`](tokio_util::sync::CancellationToken). When the
611    /// token is cancelled the loop exits early with
612    /// [`AgentError::Cancelled`].
613    ///
614    /// Cancellation is checked at the start of each iteration **and** races
615    /// against the LLM call so an in-flight request is interrupted promptly.
616    ///
617    /// Requires the `cancellation` feature.
618    #[cfg(feature = "cancellation")]
619    pub async fn run_with_cancel(
620        &self,
621        llm: &dyn LlmClient,
622        messages: Vec<Message>,
623        cancel: tokio_util::sync::CancellationToken,
624        on_event: impl Fn(AgentEvent) + Send + Sync,
625    ) -> Result<AgentResult> {
626        #[cfg(feature = "mcp-client")]
627        let mcp_tools: Vec<Arc<dyn Tool>> = {
628            use crate::mcp::adapter::McpToolAdapter;
629            let mut tools = Vec::new();
630            for server in &self.mcp_servers {
631                server.connect().await?;
632                tools.extend(McpToolAdapter::from_server(Arc::clone(server)).await?);
633            }
634            tools
635        };
636        #[cfg(not(feature = "mcp-client"))]
637        let mcp_tools: Vec<Arc<dyn Tool>> = vec![];
638
639        let result = self
640            .run_inner_cancel(llm, messages, &mcp_tools, &on_event, &cancel)
641            .await;
642
643        #[cfg(feature = "mcp-client")]
644        for server in &self.mcp_servers {
645            let _ = server.disconnect().await;
646        }
647
648        result
649    }
650
651    /// Streaming variant of [`run_with_cancel`](Self::run_with_cancel).
652    ///
653    /// Uses [`LlmClient::chat_stream`] and races each chunk against the
654    /// cancellation token so the loop exits promptly when cancelled.
655    ///
656    /// Requires the `cancellation` feature.
657    #[cfg(feature = "cancellation")]
658    pub async fn run_streaming_with_cancel(
659        &self,
660        llm: &dyn LlmClient,
661        messages: Vec<Message>,
662        cancel: tokio_util::sync::CancellationToken,
663        on_event: impl Fn(AgentEvent) + Send + Sync,
664    ) -> Result<AgentResult> {
665        use futures::StreamExt;
666
667        #[cfg(feature = "mcp-client")]
668        let mcp_tools: Vec<Arc<dyn Tool>> = {
669            use crate::mcp::adapter::McpToolAdapter;
670            let mut tools = Vec::new();
671            for server in &self.mcp_servers {
672                server.connect().await?;
673                tools.extend(McpToolAdapter::from_server(Arc::clone(server)).await?);
674            }
675            tools
676        };
677        #[cfg(not(feature = "mcp-client"))]
678        let mcp_tools: Vec<Arc<dyn Tool>> = vec![];
679
680        let tools = MergedTools::new(&self.tool_map, &self.tool_defs, &mcp_tools);
681        let mut state = TurnState::new(self.prepare_messages(messages).await?);
682
683        let result = async {
684            for iteration in 1..=self.max_iterations {
685                // Pre-turn: cancellation check.
686                if cancel.is_cancelled() {
687                    return Err(AgentError::Cancelled);
688                }
689                on_event(AgentEvent::IterationStarted(iteration));
690
691                // Streaming LLM step with cancellation race.
692                let (accumulated_text, response) = {
693                    let mut stream = llm.chat_stream(&state.messages, tools.tool_defs());
694                    let mut accumulated = String::new();
695                    let mut final_response: Option<crate::LlmResponse> = None;
696
697                    loop {
698                        tokio::select! {
699                            chunk_opt = stream.next() => {
700                                match chunk_opt {
701                                    Some(chunk_result) => {
702                                        let chunk = chunk_result?;
703                                        match chunk {
704                                            crate::llm::StreamChunk::TextDelta(delta) => {
705                                                accumulated.push_str(&delta);
706                                                on_event(AgentEvent::TextChunk(delta));
707                                            }
708                                            crate::llm::StreamChunk::Done(resp) => {
709                                                final_response = Some(resp);
710                                            }
711                                            crate::llm::StreamChunk::Usage(usage) => {
712                                                state.total_usage.accumulate(usage);
713                                            }
714                                        }
715                                    }
716                                    None => break,
717                                }
718                            }
719                            _ = cancel.cancelled() => {
720                                return Err(AgentError::Cancelled);
721                            }
722                        }
723                    }
724
725                    let resp = final_response
726                        .unwrap_or_else(|| crate::LlmResponse::Message(accumulated.clone()));
727                    (accumulated, resp)
728                };
729
730                match response {
731                    crate::LlmResponse::Message(text) => {
732                        if accumulated_text.is_empty() && !text.is_empty() {
733                            on_event(AgentEvent::TextChunk(text.clone()));
734                        }
735                        on_event(AgentEvent::TextDone(text.clone()));
736                        return Ok(state.into_result(text, iteration));
737                    }
738                    crate::LlmResponse::ToolCalls(items) => {
739                        emit_tool_started(&items, &on_event);
740                        let results = Self::execute_tools_parallel(
741                            tools.tool_map(),
742                            &items,
743                            self.tool_timeout,
744                            &self.tool_context,
745                        )
746                        .await;
747                        execute_and_record_tool_calls(&items, results, &mut state, &on_event);
748                    }
749                }
750            }
751
752            Err(AgentError::MaxIterations(self.max_iterations))
753        }
754        .await;
755
756        #[cfg(feature = "mcp-client")]
757        for server in &self.mcp_servers {
758            let _ = server.disconnect().await;
759        }
760
761        result
762    }
763
764    /// Consume a streaming LLM response, forwarding text deltas as events
765    /// and accumulating usage.  Returns the accumulated text and the final
766    /// [`LlmResponse`](crate::LlmResponse).
767    async fn consume_stream(
768        llm: &dyn LlmClient,
769        messages: &[Message],
770        tool_defs: &[ToolDef],
771        total_usage: &mut TokenUsage,
772        on_event: &(impl Fn(AgentEvent) + Send + Sync),
773    ) -> Result<(String, crate::LlmResponse)> {
774        use futures::StreamExt;
775
776        let mut stream = llm.chat_stream(messages, tool_defs);
777        let mut accumulated = String::new();
778        let mut final_response: Option<crate::LlmResponse> = None;
779
780        while let Some(chunk_result) = stream.next().await {
781            let chunk = chunk_result?;
782            match chunk {
783                crate::llm::StreamChunk::TextDelta(delta) => {
784                    accumulated.push_str(&delta);
785                    on_event(AgentEvent::TextChunk(delta));
786                }
787                crate::llm::StreamChunk::Done(response) => {
788                    final_response = Some(response);
789                }
790                crate::llm::StreamChunk::Usage(usage) => {
791                    total_usage.accumulate(usage);
792                }
793            }
794        }
795
796        let response =
797            final_response.unwrap_or_else(|| crate::LlmResponse::Message(accumulated.clone()));
798        Ok((accumulated, response))
799    }
800
801    /// Prepend context-provider messages to the conversation.
802    ///
803    /// Each non-empty result is inserted in registration order starting at
804    /// index 0 (i.e. at the *beginning* of the conversation, as documented
805    /// on [`ContextProvider`]).
806    async fn prepare_messages(&self, mut messages: Vec<Message>) -> Result<Vec<Message>> {
807        if self.context_providers.is_empty() {
808            return Ok(messages);
809        }
810        let query: String = messages
811            .iter()
812            .rev()
813            .find(|m| m.role == crate::message::Role::User)
814            .map(|m| m.content.clone())
815            .unwrap_or_default();
816
817        let contexts = try_join_all(self.context_providers.iter().map(|p| p.build(&query))).await?;
818
819        let mut insert_idx = 0;
820        for ctx in contexts {
821            if !ctx.is_empty() {
822                messages.insert(insert_idx, Message::system(&ctx));
823                insert_idx += 1;
824            }
825        }
826        Ok(messages)
827    }
828
829    #[cfg(feature = "cancellation")]
830    async fn run_inner_cancel(
831        &self,
832        llm: &dyn LlmClient,
833        messages: Vec<Message>,
834        extra_tools: &[Arc<dyn Tool>],
835        on_event: &(impl Fn(AgentEvent) + Send + Sync),
836        cancel: &tokio_util::sync::CancellationToken,
837    ) -> Result<AgentResult> {
838        let tools = MergedTools::new(&self.tool_map, &self.tool_defs, extra_tools);
839        let mut state = TurnState::new(self.prepare_messages(messages).await?);
840
841        for iteration in 1..=self.max_iterations {
842            // Pre-turn: cancellation check.
843            if cancel.is_cancelled() {
844                return Err(AgentError::Cancelled);
845            }
846            on_event(AgentEvent::IterationStarted(iteration));
847
848            // LLM step with cancellation race.
849            let output = tokio::select! {
850                output = llm.chat(&state.messages, tools.tool_defs()) => output?,
851                _ = cancel.cancelled() => return Err(AgentError::Cancelled),
852            };
853            state.accumulate_usage(output.usage);
854
855            match output.response {
856                crate::LlmResponse::Message(text) => {
857                    on_event(AgentEvent::TextChunk(text.clone()));
858                    return Ok(state.into_result(text, iteration));
859                }
860                crate::LlmResponse::ToolCalls(items) => {
861                    emit_tool_started(&items, on_event);
862                    let results = Self::execute_tools_parallel(
863                        tools.tool_map(),
864                        &items,
865                        self.tool_timeout,
866                        &self.tool_context,
867                    )
868                    .await;
869                    execute_and_record_tool_calls(&items, results, &mut state, on_event);
870                }
871            }
872        }
873
874        Err(AgentError::MaxIterations(self.max_iterations))
875    }
876
    /// Core agent loop with support for mid-turn operations (interrupt,
    /// injected messages/hints, `ask_user` answers) delivered via `ops_rx`.
    ///
    /// Pending ops are drained non-blockingly before every iteration; an
    /// interrupt ends the turn early with the placeholder answer
    /// "(interrupted)". Tool batches go through the `ask_user`-aware policy
    /// so interactive calls can wait on the ops channel.
    async fn run_inner_with_ops(
        &self,
        llm: &dyn LlmClient,
        messages: Vec<Message>,
        extra_tools: &[Arc<dyn Tool>],
        on_event: &(impl Fn(AgentEvent) + Send + Sync),
        mut ops_rx: Option<mpsc::Receiver<AgentOp>>,
    ) -> Result<AgentResult> {
        let tools = MergedTools::new(&self.tool_map, &self.tool_defs, extra_tools);
        let mut state = TurnState::new(self.prepare_messages(messages).await?);
        let mut ops_state = OpsState::default();

        for iteration in 1..=self.max_iterations {
            // Pre-turn: ingest pending ops.
            Self::drain_ops(&mut state.messages, &mut ops_rx, &mut ops_state);
            if ops_state.interrupted {
                on_event(AgentEvent::Interrupted);
                // Report the number of *completed* iterations: this one was
                // interrupted before its LLM step ran.
                return Ok(
                    state.into_result("(interrupted)".to_string(), iteration.saturating_sub(1))
                );
            }

            on_event(AgentEvent::IterationStarted(iteration));

            // LLM step.
            let output = llm.chat(&state.messages, tools.tool_defs()).await?;
            state.accumulate_usage(output.usage);

            match output.response {
                crate::LlmResponse::Message(text) => {
                    // Final text answer: emit, record in the transcript, done.
                    on_event(AgentEvent::TextChunk(text.clone()));
                    state.messages.push(Message::assistant(&text));
                    return Ok(state.into_result(text, iteration));
                }
                crate::LlmResponse::ToolCalls(items) => {
                    // Tool-call execution stage (with ask_user policy).
                    emit_tool_started(&items, on_event);
                    let results = self
                        .execute_tools_with_policy(
                            tools.tool_map(),
                            &items,
                            &mut state.messages,
                            &mut ops_rx,
                            &mut ops_state,
                            on_event,
                        )
                        .await;
                    execute_and_record_tool_calls(&items, results, &mut state, on_event);
                }
            }
        }

        Err(AgentError::MaxIterations(self.max_iterations))
    }
931
    /// Wait for an answer to the `ask_user` call `call_id` from the ops
    /// channel.
    ///
    /// Returns an empty string in every no-answer case: the configured
    /// `ask_user_timeout` elapses, no ops channel exists, the channel closes,
    /// or an interrupt op arrives while waiting. Non-answer ops received in
    /// the meantime are still applied to `messages` / `ops_state` via
    /// [`apply_op`](Self::apply_op).
    async fn wait_for_ask_user_answer(
        &self,
        call_id: &str,
        question: &str,
        messages: &mut Vec<Message>,
        ops_rx: &mut Option<mpsc::Receiver<AgentOp>>,
        ops_state: &mut OpsState,
        on_event: &(impl Fn(AgentEvent) + Send + Sync),
    ) -> String {
        // Fast path: an answer may already have been queued by an earlier drain.
        if let Some(answer) = pop_matching_answer(&mut ops_state.pending_answers, call_id) {
            return answer;
        }

        let timeout = self.ask_user_timeout;
        let started = tokio::time::Instant::now();

        loop {
            let next_op = if let Some(rx) = ops_rx.as_mut() {
                if let Some(limit) = timeout {
                    // Re-derive the remaining budget each pass so intervening
                    // non-answer ops do not extend the overall deadline.
                    let elapsed = started.elapsed();
                    if elapsed >= limit {
                        on_event(AgentEvent::AskUserTimeout {
                            call_id: call_id.to_string(),
                            question: question.to_string(),
                        });
                        return String::new();
                    }
                    let remaining = limit - elapsed;
                    match tokio::time::timeout(remaining, rx.recv()).await {
                        Ok(op) => op,
                        Err(_) => {
                            on_event(AgentEvent::AskUserTimeout {
                                call_id: call_id.to_string(),
                                question: question.to_string(),
                            });
                            return String::new();
                        }
                    }
                } else {
                    // No timeout configured: wait indefinitely.
                    rx.recv().await
                }
            } else {
                // No ops channel at all: an answer can never arrive.
                on_event(AgentEvent::AskUserTimeout {
                    call_id: call_id.to_string(),
                    question: question.to_string(),
                });
                return String::new();
            };

            // `None` from recv means the sender side was dropped.
            let Some(op) = next_op else {
                on_event(AgentEvent::AskUserTimeout {
                    call_id: call_id.to_string(),
                    question: question.to_string(),
                });
                return String::new();
            };
            Self::apply_op(op, messages, ops_state);
            if ops_state.interrupted {
                return String::new();
            }

            if let Some(answer) = pop_matching_answer(&mut ops_state.pending_answers, call_id) {
                return answer;
            }
        }
    }
998
999    async fn run_inner(
1000        &self,
1001        llm: &dyn LlmClient,
1002        messages: Vec<Message>,
1003        extra_tools: &[Arc<dyn Tool>],
1004        on_event: &(impl Fn(AgentEvent) + Send + Sync),
1005    ) -> Result<AgentResult> {
1006        let tools = MergedTools::new(&self.tool_map, &self.tool_defs, extra_tools);
1007        let mut state = TurnState::new(self.prepare_messages(messages).await?);
1008
1009        for iteration in 1..=self.max_iterations {
1010            on_event(AgentEvent::IterationStarted(iteration));
1011
1012            // LLM step.
1013            let output = llm.chat(&state.messages, tools.tool_defs()).await?;
1014            state.accumulate_usage(output.usage);
1015
1016            match output.response {
1017                crate::LlmResponse::Message(text) => {
1018                    on_event(AgentEvent::TextChunk(text.clone()));
1019                    state.messages.push(Message::assistant(&text));
1020                    return Ok(state.into_result(text, iteration));
1021                }
1022                crate::LlmResponse::ToolCalls(items) => {
1023                    // Tool-call execution stage.
1024                    emit_tool_started(&items, on_event);
1025                    let results = Self::execute_tools_parallel(
1026                        tools.tool_map(),
1027                        &items,
1028                        self.tool_timeout,
1029                        &self.tool_context,
1030                    )
1031                    .await;
1032                    execute_and_record_tool_calls(&items, results, &mut state, on_event);
1033                }
1034            }
1035        }
1036
1037        Err(AgentError::MaxIterations(self.max_iterations))
1038    }
1039
    /// Streaming variant of [`run`](Self::run).
    ///
    /// Uses [`LlmClient::chat_stream`] so that `on_event` receives
    /// [`AgentEvent::TextChunk`] for each text delta as it arrives, plus
    /// [`AgentEvent::TextDone`] with the full accumulated text and
    /// [`AgentEvent::IterationStarted`] at each reasoning iteration.
    ///
    /// Tool execution is identical to the non-streaming path.
    pub async fn run_streaming(
        &self,
        llm: &dyn LlmClient,
        messages: Vec<Message>,
        on_event: impl Fn(AgentEvent) + Send + Sync,
    ) -> Result<AgentResult> {
        // Connect MCP servers and collect their tools (mcp-client feature).
        #[cfg(feature = "mcp-client")]
        let mcp_tools: Vec<Arc<dyn Tool>> = self.connect_mcp_servers().await?;
        #[cfg(not(feature = "mcp-client"))]
        let mcp_tools: Vec<Arc<dyn Tool>> = vec![];

        let tools = MergedTools::new(&self.tool_map, &self.tool_defs, &mcp_tools);
        let mut state = TurnState::new(self.prepare_messages(messages).await?);

        // The loop runs inside an inner future so the MCP disconnect below
        // executes on every exit path (success, error, max iterations).
        let result = async {
            for iteration in 1..=self.max_iterations {
                on_event(AgentEvent::IterationStarted(iteration));

                // Streaming LLM step.
                let (accumulated_text, response) = Self::consume_stream(
                    llm,
                    &state.messages,
                    tools.tool_defs(),
                    &mut state.total_usage,
                    &on_event,
                )
                .await?;

                match response {
                    crate::LlmResponse::Message(text) => {
                        // If the backend sent no text deltas but the final
                        // response carries text, emit it once so chunk
                        // listeners still see the content.
                        if accumulated_text.is_empty() && !text.is_empty() {
                            on_event(AgentEvent::TextChunk(text.clone()));
                        }
                        on_event(AgentEvent::TextDone(text.clone()));
                        state.messages.push(Message::assistant(&text));
                        return Ok(state.into_result(text, iteration));
                    }
                    crate::LlmResponse::ToolCalls(items) => {
                        emit_tool_started(&items, &on_event);
                        let results = Self::execute_tools_parallel(
                            tools.tool_map(),
                            &items,
                            self.tool_timeout,
                            &self.tool_context,
                        )
                        .await;
                        execute_and_record_tool_calls(&items, results, &mut state, &on_event);
                    }
                }
            }

            Err(AgentError::MaxIterations(self.max_iterations))
        }
        .await;

        // Disconnect all MCP servers (best-effort) regardless of run outcome.
        #[cfg(feature = "mcp-client")]
        for server in &self.mcp_servers {
            let _ = server.disconnect().await;
        }

        result
    }
1112
1113    fn drain_ops(
1114        messages: &mut Vec<Message>,
1115        ops_rx: &mut Option<mpsc::Receiver<AgentOp>>,
1116        ops_state: &mut OpsState,
1117    ) {
1118        if let Some(rx) = ops_rx.as_mut() {
1119            while let Ok(op) = rx.try_recv() {
1120                Self::apply_op(op, messages, ops_state);
1121            }
1122        }
1123    }
1124
1125    fn apply_op(op: AgentOp, messages: &mut Vec<Message>, ops_state: &mut OpsState) {
1126        match op {
1127            AgentOp::Interrupt => {
1128                ops_state.interrupted = true;
1129            }
1130            AgentOp::InjectUserMessage(text) => {
1131                messages.push(Message::user(&text));
1132            }
1133            AgentOp::InjectHint(hint) => {
1134                messages.push(Message::user(&format!("[Note: {hint}]")));
1135            }
1136            AgentOp::AskUserAnswer { call_id, answer } => {
1137                ops_state
1138                    .pending_answers
1139                    .push(PendingAskUserAnswer { call_id, answer });
1140            }
1141        }
1142    }
1143
    /// Execute one batch of tool calls returned by the LLM.
    ///
    /// Non-`ask_user` tools always run concurrently. When the batch contains
    /// `ask_user` calls, those are handled interactively (sequentially, in
    /// batch order) while the rest run in parallel; results are then merged
    /// back into the original call order.
    async fn execute_tools_with_policy(
        &self,
        tool_map: &HashMap<String, Arc<dyn Tool>>,
        items: &[crate::llm::ToolCallItem],
        messages: &mut Vec<Message>,
        ops_rx: &mut Option<mpsc::Receiver<AgentOp>>,
        ops_state: &mut OpsState,
        on_event: &(impl Fn(AgentEvent) + Send + Sync),
    ) -> Vec<ToolResult> {
        let policy = ToolExecutionPolicy::from_items(items);
        match policy {
            ToolExecutionPolicy::ParallelOnly => {
                Self::execute_tools_parallel(tool_map, items, self.tool_timeout, &self.tool_context)
                    .await
            }
            ToolExecutionPolicy::InteractiveAskUser => {
                // Non-interactive tools run concurrently, each tagged with its
                // original index so order can be restored after the join.
                let non_ask_future = join_all(
                    items
                        .iter()
                        .enumerate()
                        .filter(|(_, tc)| tc.name != "ask_user")
                        .map(|(idx, tc)| async move {
                            (
                                idx,
                                Self::execute_tool(
                                    tool_map,
                                    &tc.name,
                                    tc.args.clone(),
                                    self.tool_timeout,
                                    &self.tool_context,
                                )
                                .await,
                            )
                        }),
                );

                // ask_user calls are answered one at a time, in batch order,
                // blocking on the ops channel for each answer.
                let ask_future = async {
                    let mut ask_results = Vec::new();
                    for (idx, tc) in items
                        .iter()
                        .enumerate()
                        .filter(|(_, tc)| tc.name == "ask_user")
                    {
                        let (question, options) = parse_ask_user_args(&tc.args);
                        on_event(AgentEvent::AskUser {
                            call_id: tc.id.clone(),
                            question: question.clone(),
                            options,
                        });

                        let answer = self
                            .wait_for_ask_user_answer(
                                &tc.id, &question, messages, ops_rx, ops_state, on_event,
                            )
                            .await;
                        ask_results.push((idx, ToolResult::text(answer)));
                    }
                    ask_results
                };

                // Drive both halves concurrently, then stitch results back
                // together in the original call order.
                let (non_ask_results, ask_results) = futures::join!(non_ask_future, ask_future);
                let mut merged: Vec<Option<ToolResult>> = vec![None; items.len()];
                for (idx, result) in non_ask_results.into_iter().chain(ask_results) {
                    merged[idx] = Some(result);
                }
                // Every index is either an ask or non-ask call, so every slot
                // is filled by construction.
                merged
                    .into_iter()
                    .map(|result| result.expect("tool result must be present"))
                    .collect()
            }
        }
    }
1216
1217    async fn execute_tools_parallel(
1218        tool_map: &HashMap<String, Arc<dyn Tool>>,
1219        items: &[crate::llm::ToolCallItem],
1220        timeout: Option<std::time::Duration>,
1221        ctx: &ToolContext,
1222    ) -> Vec<ToolResult> {
1223        join_all(
1224            items
1225                .iter()
1226                .map(|tc| Self::execute_tool(tool_map, &tc.name, tc.args.clone(), timeout, ctx)),
1227        )
1228        .await
1229    }
1230
1231    /// Execute a single tool by name using a pre-built lookup map.
1232    /// Returns an error `ToolResult` if the tool is not found or times out.
1233    async fn execute_tool(
1234        tool_map: &HashMap<String, Arc<dyn Tool>>,
1235        name: &str,
1236        args: serde_json::Value,
1237        timeout: Option<std::time::Duration>,
1238        ctx: &ToolContext,
1239    ) -> ToolResult {
1240        let fut = async {
1241            if let Some(tool) = tool_map.get(name) {
1242                tool.call(args, &ctx).await
1243            } else {
1244                ToolResult::error(format!("unknown tool: {name}"))
1245            }
1246        };
1247        if let Some(dur) = timeout {
1248            match tokio::time::timeout(dur, fut).await {
1249                Ok(result) => result,
1250                Err(_) => ToolResult::error(format!("tool '{name}' timed out after {dur:?}")),
1251            }
1252        } else {
1253            fut.await
1254        }
1255    }
1256}
1257
1258#[derive(Debug, Clone, Copy)]
1259enum ToolExecutionPolicy {
1260    ParallelOnly,
1261    InteractiveAskUser,
1262}
1263
1264impl ToolExecutionPolicy {
1265    fn from_items(items: &[crate::llm::ToolCallItem]) -> Self {
1266        if items.iter().any(|tc| tc.name == "ask_user") {
1267            Self::InteractiveAskUser
1268        } else {
1269            Self::ParallelOnly
1270        }
1271    }
1272}
1273
/// Mutable per-turn state accumulated from incoming [`AgentOp`]s.
#[derive(Default)]
struct OpsState {
    /// Set once an `AgentOp::Interrupt` arrives; checked at loop boundaries.
    interrupted: bool,
    /// Answers received for `ask_user` calls, claimed later by call id.
    pending_answers: Vec<PendingAskUserAnswer>,
}
1279
/// An `ask_user` answer parked until the matching tool call claims it.
struct PendingAskUserAnswer {
    /// Target call id; `None` acts as a wildcard matching any call.
    call_id: Option<String>,
    /// The user's answer text.
    answer: String,
}

/// Remove and return the first queued answer that matches `call_id`.
///
/// An entry whose `call_id` is `None` matches any call. Returns `None`
/// when nothing in the queue matches.
fn pop_matching_answer(pending: &mut Vec<PendingAskUserAnswer>, call_id: &str) -> Option<String> {
    let is_match = |item: &PendingAskUserAnswer| match item.call_id.as_deref() {
        None => true,
        Some(id) => id == call_id,
    };
    let pos = pending.iter().position(is_match)?;
    Some(pending.remove(pos).answer)
}
1291
1292fn parse_ask_user_args(args: &serde_json::Value) -> (String, Vec<String>) {
1293    let question = args
1294        .get("question")
1295        .and_then(|value| value.as_str())
1296        .unwrap_or("")
1297        .to_string();
1298    let options = args
1299        .get("options")
1300        .and_then(|value| value.as_array())
1301        .map(|list| {
1302            list.iter()
1303                .filter_map(|entry| entry.as_str().map(ToString::to_string))
1304                .collect::<Vec<_>>()
1305        })
1306        .unwrap_or_default();
1307    (question, options)
1308}
1309
1310fn ask_user_tool_def() -> ToolDef {
1311    ToolDef {
1312        name: "ask_user".to_string(),
1313        description: "Ask the user a question and wait for a reply.".to_string(),
1314        input_schema: serde_json::json!({
1315            "type": "object",
1316            "properties": {
1317                "question": { "type": "string" },
1318                "options": {
1319                    "type": "array",
1320                    "items": { "type": "string" }
1321                }
1322            },
1323            "required": ["question"]
1324        }),
1325    }
1326}
1327
1328/// Convert a [`ToolResult`] into a string suitable for message content.
1329fn tool_result_to_string(result: &ToolResult) -> String {
1330    match result.as_text() {
1331        Some(text) => text.to_string(),
1332        None => {
1333            // Fall back to JSON serialization of the content.
1334            serde_json::to_string(&result.content).unwrap_or_else(|_| "<no content>".to_string())
1335        }
1336    }
1337}
1338
1339#[cfg(test)]
1340mod tests {
1341    use super::*;
1342    use crate::context::ContextProvider;
1343    use crate::llm::{ChatOutput, LlmResponse, TokenUsage, ToolCallItem};
1344    use async_trait::async_trait;
1345    use std::sync::{Arc, Mutex};
1346
1347    // -----------------------------------------------------------------------
1348    // Mock LLM client
1349    // -----------------------------------------------------------------------
1350
    /// A mock LLM client that returns pre-programmed responses in sequence.
    struct MockLlm {
        // Remaining scripted responses; popped from the front on each call.
        responses: Mutex<Vec<LlmResponse>>,
        // Usage reported on every call, or `None` to report no usage.
        usage_per_call: Option<TokenUsage>,
    }

    impl MockLlm {
        // Build a mock that reports no token usage.
        fn new(responses: Vec<LlmResponse>) -> Self {
            Self {
                responses: Mutex::new(responses),
                usage_per_call: None,
            }
        }

        // Build a mock that reports `usage` on every call.
        fn with_usage(responses: Vec<LlmResponse>, usage: TokenUsage) -> Self {
            Self {
                responses: Mutex::new(responses),
                usage_per_call: Some(usage),
            }
        }
    }
1372
    #[async_trait]
    impl LlmClient for MockLlm {
        async fn chat(
            &self,
            _messages: &[Message],
            _tools: &[ToolDef],
        ) -> crate::Result<ChatOutput> {
            // Pop the next scripted response. Running out means the test made
            // more LLM calls than it programmed, which is a test bug: panic.
            let mut responses = self.responses.lock().unwrap();
            if responses.is_empty() {
                panic!("MockLlm: no more responses");
            }
            let response = responses.remove(0);
            Ok(ChatOutput {
                response,
                usage: self.usage_per_call,
            })
        }
    }
1391
1392    // -----------------------------------------------------------------------
1393    // Mock tool
1394    // -----------------------------------------------------------------------
1395
    /// A tool that always succeeds with a fixed text result.
    struct MockTool {
        // Name exposed in the tool definition and used for lookup.
        name: String,
        // Fixed text returned by every call.
        result_text: String,
    }

    impl MockTool {
        fn new(name: &str, result_text: &str) -> Self {
            Self {
                name: name.to_string(),
                result_text: result_text.to_string(),
            }
        }
    }
1409
    impl Tool for MockTool {
        fn def(&self) -> ToolDef {
            // Minimal schema: a single optional string input.
            ToolDef {
                name: self.name.clone(),
                description: format!("Mock tool: {}", self.name),
                input_schema: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "input": { "type": "string" }
                    },
                    "required": []
                }),
            }
        }

        fn call(
            &self,
            _args: serde_json::Value,
            _ctx: &ToolContext,
        ) -> std::pin::Pin<Box<dyn std::future::Future<Output = ToolResult> + Send + '_>> {
            // Ignore args/context and resolve immediately with the canned text.
            let text = self.result_text.clone();
            Box::pin(async move { ToolResult::text(text) })
        }
    }
1434
1435    // -----------------------------------------------------------------------
1436    // Tests
1437    // -----------------------------------------------------------------------
1438
1439    #[tokio::test]
1440    async fn direct_message_response() {
1441        let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".to_string())]);
1442
1443        let agent = AgentLoop::builder().build();
1444        let result = agent
1445            .run(&llm, vec![Message::user("Hi")], |_| {})
1446            .await
1447            .unwrap();
1448
1449        assert_eq!(result.answer, "Hello!");
1450        assert!(result.tool_calls.is_empty());
1451        assert_eq!(result.iterations, 1);
1452    }
1453
1454    #[tokio::test]
1455    async fn single_tool_call_then_message() {
1456        let llm = MockLlm::new(vec![
1457            LlmResponse::single_tool_call(
1458                "call_1".to_string(),
1459                "search".to_string(),
1460                serde_json::json!({"input": "rust"}),
1461            ),
1462            LlmResponse::Message("Found results about Rust.".to_string()),
1463        ]);
1464
1465        let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "Rust is a systems language"));
1466
1467        let agent = AgentLoop::builder().tool(tool).build();
1468        let result = agent
1469            .run(&llm, vec![Message::user("Search for rust")], |_| {})
1470            .await
1471            .unwrap();
1472
1473        assert_eq!(result.answer, "Found results about Rust.");
1474        assert_eq!(result.tool_calls.len(), 1);
1475        assert_eq!(result.tool_calls[0].0, "search");
1476        assert_eq!(result.iterations, 2);
1477    }
1478
    #[tokio::test]
    async fn parallel_tool_calls() {
        // Two tool calls in one LLM response must both run and be recorded
        // in request order before the final answer.
        let llm = MockLlm::new(vec![
            LlmResponse::ToolCalls(vec![
                ToolCallItem {
                    id: "call_1".to_string(),
                    name: "search".to_string(),
                    args: serde_json::json!({"input": "a"}),
                },
                ToolCallItem {
                    id: "call_2".to_string(),
                    name: "fetch".to_string(),
                    args: serde_json::json!({"input": "b"}),
                },
            ]),
            LlmResponse::Message("Combined results.".to_string()),
        ]);

        let search: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result_a"));
        let fetch: Arc<dyn Tool> = Arc::new(MockTool::new("fetch", "result_b"));

        let agent = AgentLoop::builder().tool(search).tool(fetch).build();
        let result = agent
            .run(&llm, vec![Message::user("Do both")], |_| {})
            .await
            .unwrap();

        assert_eq!(result.answer, "Combined results.");
        assert_eq!(result.tool_calls.len(), 2);
        // Order must match the request order, not completion order.
        assert_eq!(result.tool_calls[0].0, "search");
        assert_eq!(result.tool_calls[1].0, "fetch");
        assert_eq!(result.iterations, 2);
    }
1512
1513    #[tokio::test]
1514    async fn max_iterations_exceeded() {
1515        // LLM keeps requesting tool calls indefinitely.
1516        let responses: Vec<LlmResponse> = (0..5)
1517            .map(|i| {
1518                LlmResponse::single_tool_call(
1519                    format!("call_{i}"),
1520                    "search".to_string(),
1521                    serde_json::json!({}),
1522                )
1523            })
1524            .collect();
1525        let llm = MockLlm::new(responses);
1526
1527        let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));
1528
1529        let agent = AgentLoop::builder().tool(tool).max_iterations(3).build();
1530        let err = agent
1531            .run(&llm, vec![Message::user("loop forever")], |_| {})
1532            .await
1533            .unwrap_err();
1534
1535        assert!(matches!(err, AgentError::MaxIterations(3)));
1536    }
1537
    #[tokio::test]
    async fn events_emitted_correctly() {
        // Label every emitted event and assert the exact ordered sequence
        // for one tool round-trip followed by a text answer.
        let llm = MockLlm::new(vec![
            LlmResponse::single_tool_call(
                "call_1".to_string(),
                "search".to_string(),
                serde_json::json!({}),
            ),
            LlmResponse::Message("Done.".to_string()),
        ]);

        let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));
        let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = events.clone();

        let agent = AgentLoop::builder().tool(tool).build();
        let _ = agent
            .run(&llm, vec![Message::user("test")], move |event| {
                // Exhaustive match: a new AgentEvent variant forces this test
                // to be updated.
                let label = match &event {
                    AgentEvent::ToolStarted { name } => format!("started:{name}"),
                    AgentEvent::ToolCompleted { name, .. } => format!("completed:{name}"),
                    AgentEvent::TextChunk(t) => format!("text:{t}"),
                    AgentEvent::TextDone(t) => format!("done:{t}"),
                    AgentEvent::IterationStarted(n) => format!("iter:{n}"),
                    AgentEvent::Interrupted => "interrupted".to_string(),
                    AgentEvent::AskUser { question, .. } => format!("ask_user:{question}"),
                    AgentEvent::AskUserTimeout { call_id, .. } => {
                        format!("ask_user_timeout:{call_id}")
                    }
                    AgentEvent::OpsSaturated { capacity } => {
                        format!("ops_saturated:{capacity}")
                    }
                    AgentEvent::OpDropped { reason } => format!("op_dropped:{reason}"),
                    AgentEvent::OpRejected { reason } => format!("op_rejected:{reason}"),
                };
                events_clone.lock().unwrap().push(label);
            })
            .await
            .unwrap();

        let events = events.lock().unwrap();
        assert_eq!(
            *events,
            vec![
                "iter:1",
                "started:search",
                "completed:search",
                "iter:2",
                "text:Done.",
            ]
        );
    }
1590
    #[tokio::test]
    async fn unknown_tool_produces_error_result() {
        // Calling a tool that was never registered must not abort the run:
        // the loop feeds an error ToolResult back to the model instead.
        let llm = MockLlm::new(vec![
            LlmResponse::single_tool_call(
                "call_1".to_string(),
                "nonexistent".to_string(),
                serde_json::json!({}),
            ),
            LlmResponse::Message("Handled error.".to_string()),
        ]);

        let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = events.clone();

        let agent = AgentLoop::builder().build();
        let result = agent
            .run(&llm, vec![Message::user("call missing")], move |event| {
                // Record only error-flagged tool completions.
                if let AgentEvent::ToolCompleted { result, .. } = &event {
                    if result.is_error {
                        events_clone
                            .lock()
                            .unwrap()
                            .push("error_tool_result".to_string());
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(result.answer, "Handled error.");
        let events = events.lock().unwrap();
        assert!(events.contains(&"error_tool_result".to_string()));
    }
1624
1625    #[tokio::test]
1626    async fn builder_tools_method() {
1627        let tools: Vec<Arc<dyn Tool>> = vec![
1628            Arc::new(MockTool::new("a", "ra")),
1629            Arc::new(MockTool::new("b", "rb")),
1630        ];
1631
1632        let llm = MockLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1633
1634        let agent = AgentLoop::builder().tools(tools).build();
1635        let result = agent
1636            .run(&llm, vec![Message::user("hi")], |_| {})
1637            .await
1638            .unwrap();
1639
1640        assert_eq!(result.answer, "ok");
1641    }
1642
1643    #[tokio::test]
1644    async fn noop_event_callback() {
1645        let llm = MockLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1646        let agent = AgentLoop::builder().build();
1647        // Verify that |_| {} compiles and works as the noop callback.
1648        let result = agent
1649            .run(&llm, vec![Message::user("hi")], |_| {})
1650            .await
1651            .unwrap();
1652        assert_eq!(result.answer, "ok");
1653    }
1654
1655    // -----------------------------------------------------------------------
1656    // Mock context provider
1657    // -----------------------------------------------------------------------
1658
1659    struct MockContextProvider {
1660        context: String,
1661    }
1662
1663    impl MockContextProvider {
1664        fn new(context: &str) -> Self {
1665            Self {
1666                context: context.to_string(),
1667            }
1668        }
1669    }
1670
1671    #[async_trait]
1672    impl ContextProvider for MockContextProvider {
1673        async fn build(&self, _query: &str) -> crate::Result<String> {
1674            Ok(self.context.clone())
1675        }
1676    }
1677
1678    /// A context provider that echoes the query back, useful for verifying
1679    /// that the correct query string is passed.
1680    struct EchoContextProvider;
1681
1682    #[async_trait]
1683    impl ContextProvider for EchoContextProvider {
1684        async fn build(&self, query: &str) -> crate::Result<String> {
1685            Ok(format!("echo: {query}"))
1686        }
1687    }
1688
1689    // -----------------------------------------------------------------------
1690    // Context provider tests
1691    // -----------------------------------------------------------------------
1692
1693    /// A capturing LLM that records the messages it receives, then returns
1694    /// a canned response.
1695    struct CapturingLlm {
1696        captured: Mutex<Vec<Vec<Message>>>,
1697        responses: Mutex<Vec<LlmResponse>>,
1698    }
1699
1700    impl CapturingLlm {
1701        fn new(responses: Vec<LlmResponse>) -> Self {
1702            Self {
1703                captured: Mutex::new(Vec::new()),
1704                responses: Mutex::new(responses),
1705            }
1706        }
1707
1708        fn captured_messages(&self) -> Vec<Vec<Message>> {
1709            self.captured.lock().unwrap().clone()
1710        }
1711    }
1712
1713    #[async_trait]
1714    impl LlmClient for CapturingLlm {
1715        async fn chat(
1716            &self,
1717            messages: &[Message],
1718            _tools: &[motosan_agent_tool::ToolDef],
1719        ) -> crate::Result<ChatOutput> {
1720            self.captured.lock().unwrap().push(messages.to_vec());
1721            let mut responses = self.responses.lock().unwrap();
1722            if responses.is_empty() {
1723                panic!("CapturingLlm: no more responses");
1724            }
1725            Ok(ChatOutput::new(responses.remove(0)))
1726        }
1727    }
1728
1729    #[tokio::test]
1730    async fn context_provider_injects_system_message() {
1731        let llm = CapturingLlm::new(vec![LlmResponse::Message("answer".to_string())]);
1732
1733        let agent = AgentLoop::builder()
1734            .context(MockContextProvider::new("You have access to RAG docs."))
1735            .build();
1736
1737        let result = agent
1738            .run(&llm, vec![Message::user("tell me about rust")], |_| {})
1739            .await
1740            .unwrap();
1741
1742        assert_eq!(result.answer, "answer");
1743
1744        // Verify the LLM received the context as a system message
1745        // *prepended* before the user message.
1746        let calls = llm.captured_messages();
1747        assert_eq!(calls.len(), 1);
1748        let msgs = &calls[0];
1749        // Should have: system context message + user message
1750        assert_eq!(msgs.len(), 2);
1751        assert_eq!(msgs[0].role, crate::message::Role::System);
1752        assert_eq!(msgs[0].content, "You have access to RAG docs.");
1753        assert_eq!(msgs[1].role, crate::message::Role::User);
1754        assert_eq!(msgs[1].content, "tell me about rust");
1755    }
1756
1757    #[tokio::test]
1758    async fn empty_context_is_skipped() {
1759        let llm = CapturingLlm::new(vec![LlmResponse::Message("answer".to_string())]);
1760
1761        let agent = AgentLoop::builder()
1762            .context(MockContextProvider::new("")) // empty — should be skipped
1763            .build();
1764
1765        let result = agent
1766            .run(&llm, vec![Message::user("hi")], |_| {})
1767            .await
1768            .unwrap();
1769
1770        assert_eq!(result.answer, "answer");
1771
1772        let calls = llm.captured_messages();
1773        assert_eq!(calls.len(), 1);
1774        let msgs = &calls[0];
1775        // Only the user message, no system context.
1776        assert_eq!(msgs.len(), 1);
1777        assert_eq!(msgs[0].role, crate::message::Role::User);
1778    }
1779
1780    #[tokio::test]
1781    async fn multiple_context_providers() {
1782        let llm = CapturingLlm::new(vec![LlmResponse::Message("done".to_string())]);
1783
1784        let agent = AgentLoop::builder()
1785            .context(MockContextProvider::new("RAG context here"))
1786            .context(MockContextProvider::new("")) // empty — skipped
1787            .context(MockContextProvider::new("User profile: premium"))
1788            .build();
1789
1790        let result = agent
1791            .run(&llm, vec![Message::user("query")], |_| {})
1792            .await
1793            .unwrap();
1794
1795        assert_eq!(result.answer, "done");
1796
1797        let calls = llm.captured_messages();
1798        assert_eq!(calls.len(), 1);
1799        let msgs = &calls[0];
1800        // 2 non-empty context messages (prepended) + user
1801        assert_eq!(msgs.len(), 3);
1802        assert_eq!(msgs[0].role, crate::message::Role::System);
1803        assert_eq!(msgs[0].content, "RAG context here");
1804        assert_eq!(msgs[1].role, crate::message::Role::System);
1805        assert_eq!(msgs[1].content, "User profile: premium");
1806        assert_eq!(msgs[2].role, crate::message::Role::User);
1807    }
1808
1809    #[tokio::test]
1810    async fn context_provider_receives_user_query() {
1811        let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1812
1813        let agent = AgentLoop::builder().context(EchoContextProvider).build();
1814
1815        let result = agent
1816            .run(&llm, vec![Message::user("my question")], |_| {})
1817            .await
1818            .unwrap();
1819
1820        assert_eq!(result.answer, "ok");
1821
1822        let calls = llm.captured_messages();
1823        let msgs = &calls[0];
1824        assert_eq!(msgs.len(), 2);
1825        assert_eq!(msgs[0].role, crate::message::Role::System);
1826        assert_eq!(msgs[0].content, "echo: my question");
1827    }
1828
1829    #[tokio::test]
1830    async fn no_context_providers_leaves_messages_unchanged() {
1831        let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1832
1833        let agent = AgentLoop::builder().build(); // no context providers
1834
1835        let _ = agent
1836            .run(&llm, vec![Message::user("hi")], |_| {})
1837            .await
1838            .unwrap();
1839
1840        let calls = llm.captured_messages();
1841        let msgs = &calls[0];
1842        assert_eq!(msgs.len(), 1);
1843        assert_eq!(msgs[0].content, "hi");
1844    }
1845
1846    #[tokio::test]
1847    async fn builder_contexts_batch_method() {
1848        let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1849
1850        let providers: Vec<Box<dyn ContextProvider>> = vec![
1851            Box::new(MockContextProvider::new("ctx-a")),
1852            Box::new(MockContextProvider::new("ctx-b")),
1853        ];
1854
1855        let agent = AgentLoop::builder().contexts(providers).build();
1856        let result = agent
1857            .run(&llm, vec![Message::user("hi")], |_| {})
1858            .await
1859            .unwrap();
1860
1861        assert_eq!(result.answer, "ok");
1862
1863        let calls = llm.captured_messages();
1864        let msgs = &calls[0];
1865        // 2 context messages prepended + user message
1866        assert_eq!(msgs.len(), 3);
1867        assert_eq!(msgs[0].role, crate::message::Role::System);
1868        assert_eq!(msgs[0].content, "ctx-a");
1869        assert_eq!(msgs[1].role, crate::message::Role::System);
1870        assert_eq!(msgs[1].content, "ctx-b");
1871        assert_eq!(msgs[2].role, crate::message::Role::User);
1872        assert_eq!(msgs[2].content, "hi");
1873    }
1874
1875    /// A context provider that sleeps for a given duration, used to verify
1876    /// parallel execution.
1877    struct DelayContextProvider {
1878        context: String,
1879        delay: std::time::Duration,
1880    }
1881
1882    impl DelayContextProvider {
1883        fn new(context: &str, delay: std::time::Duration) -> Self {
1884            Self {
1885                context: context.to_string(),
1886                delay,
1887            }
1888        }
1889    }
1890
1891    #[async_trait]
1892    impl ContextProvider for DelayContextProvider {
1893        async fn build(&self, _query: &str) -> crate::Result<String> {
1894            tokio::time::sleep(self.delay).await;
1895            Ok(self.context.clone())
1896        }
1897    }
1898
    /// End-to-end check that all registered context providers are built
    /// concurrently rather than awaited one after another, and that their
    /// outputs are still prepended in registration order.
    ///
    /// NOTE(review): this asserts on wall-clock time; the 250ms threshold is
    /// generous, but it could still flake on a heavily loaded machine.
    #[tokio::test]
    async fn context_providers_run_in_parallel() {
        let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
        let delay = std::time::Duration::from_millis(100);

        // Three providers, each sleeping 100ms before yielding its context.
        let agent = AgentLoop::builder()
            .context(DelayContextProvider::new("ctx-a", delay))
            .context(DelayContextProvider::new("ctx-b", delay))
            .context(DelayContextProvider::new("ctx-c", delay))
            .build();

        let start = std::time::Instant::now();
        let result = agent
            .run(&llm, vec![Message::user("hi")], |_| {})
            .await
            .unwrap();
        let elapsed = start.elapsed();

        // If sequential, ~300ms; if parallel, ~100ms.
        // Use 250ms as a generous threshold.
        assert!(
            elapsed < std::time::Duration::from_millis(250),
            "Expected parallel execution (<250ms), but took {elapsed:?}",
        );

        assert_eq!(result.answer, "ok");

        // Verify order is preserved.
        let calls = llm.captured_messages();
        let msgs = &calls[0];
        assert_eq!(msgs.len(), 4); // 3 context + 1 user
        assert_eq!(msgs[0].content, "ctx-a");
        assert_eq!(msgs[1].content, "ctx-b");
        assert_eq!(msgs[2].content, "ctx-c");
        assert_eq!(msgs[3].content, "hi");
    }
1935
1936    // -----------------------------------------------------------------------
1937    // system_prompt shortcut tests
1938    // -----------------------------------------------------------------------
1939
1940    #[tokio::test]
1941    async fn system_prompt_injects_system_message() {
1942        let llm = CapturingLlm::new(vec![LlmResponse::Message("ok".to_string())]);
1943
1944        let agent = AgentLoop::builder()
1945            .system_prompt("You are a helpful assistant.")
1946            .build();
1947
1948        let result = agent
1949            .run(&llm, vec![Message::user("hello")], |_| {})
1950            .await
1951            .unwrap();
1952
1953        assert_eq!(result.answer, "ok");
1954
1955        let calls = llm.captured_messages();
1956        assert_eq!(calls.len(), 1);
1957        let msgs = &calls[0];
1958        assert_eq!(msgs.len(), 2);
1959        assert_eq!(msgs[0].role, crate::message::Role::System);
1960        assert_eq!(msgs[0].content, "You are a helpful assistant.");
1961        assert_eq!(msgs[1].role, crate::message::Role::User);
1962        assert_eq!(msgs[1].content, "hello");
1963    }
1964
1965    // -----------------------------------------------------------------------
1966    // Streaming tests
1967    // -----------------------------------------------------------------------
1968
1969    /// A mock LLM that emits text deltas via `chat_stream` before the final
1970    /// `Done` chunk, simulating true streaming behaviour.
1971    struct StreamingMockLlm {
1972        /// Each entry is a sequence of StreamChunks for one call.
1973        responses: Mutex<Vec<Vec<crate::llm::StreamChunk>>>,
1974    }
1975
1976    impl StreamingMockLlm {
1977        fn new(responses: Vec<Vec<crate::llm::StreamChunk>>) -> Self {
1978            Self {
1979                responses: Mutex::new(responses),
1980            }
1981        }
1982    }
1983
1984    #[async_trait]
1985    impl LlmClient for StreamingMockLlm {
1986        async fn chat(
1987            &self,
1988            _messages: &[Message],
1989            _tools: &[ToolDef],
1990        ) -> crate::Result<ChatOutput> {
1991            panic!("StreamingMockLlm: chat() should not be called");
1992        }
1993
1994        fn chat_stream<'a>(
1995            &'a self,
1996            _messages: &'a [Message],
1997            _tools: &'a [ToolDef],
1998        ) -> std::pin::Pin<
1999            Box<dyn futures::Stream<Item = crate::Result<crate::llm::StreamChunk>> + Send + 'a>,
2000        > {
2001            let chunks = {
2002                let mut responses = self.responses.lock().unwrap();
2003                if responses.is_empty() {
2004                    panic!("StreamingMockLlm: no more responses");
2005                }
2006                responses.remove(0)
2007            };
2008            Box::pin(futures::stream::iter(chunks.into_iter().map(Ok)))
2009        }
2010    }
2011
    /// One streamed turn: two text deltas followed by the final `Done` chunk.
    /// The deltas must be surfaced as `TextChunk` events in order, then the
    /// aggregated text as `TextDone`.
    #[tokio::test]
    async fn run_streaming_emits_text_chunks_and_done() {
        let llm = StreamingMockLlm::new(vec![vec![
            crate::llm::StreamChunk::TextDelta("Hel".into()),
            crate::llm::StreamChunk::TextDelta("lo!".into()),
            crate::llm::StreamChunk::Done(LlmResponse::Message("Hello!".into())),
        ]]);

        // Shared log of event labels pushed from the callback.
        let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = events.clone();

        let agent = AgentLoop::builder().build();
        let result = agent
            .run_streaming(&llm, vec![Message::user("Hi")], move |event| {
                // Reduce each event to a compact label for order assertions.
                let label = match &event {
                    AgentEvent::TextChunk(t) => format!("chunk:{t}"),
                    AgentEvent::TextDone(t) => format!("done:{t}"),
                    AgentEvent::IterationStarted(n) => format!("iter:{n}"),
                    AgentEvent::ToolStarted { name } => format!("started:{name}"),
                    AgentEvent::ToolCompleted { name, .. } => format!("completed:{name}"),
                    AgentEvent::Interrupted => "interrupted".to_string(),
                    AgentEvent::AskUser { question, .. } => format!("ask_user:{question}"),
                    AgentEvent::AskUserTimeout { call_id, .. } => {
                        format!("ask_user_timeout:{call_id}")
                    }
                    _ => format!("{event:?}"),
                };
                events_clone.lock().unwrap().push(label);
            })
            .await
            .unwrap();

        // The Done chunk's message becomes the final answer.
        assert_eq!(result.answer, "Hello!");
        assert_eq!(result.iterations, 1);

        let events = events.lock().unwrap();
        assert_eq!(
            *events,
            vec!["iter:1", "chunk:Hel", "chunk:lo!", "done:Hello!"]
        );
    }
2053
    /// Two streamed turns: a tool-call turn followed by a streamed text turn.
    /// Verifies the tool events and text deltas interleave in the expected
    /// order across iterations.
    #[tokio::test]
    async fn run_streaming_with_tool_calls() {
        let llm = StreamingMockLlm::new(vec![
            // First call: tool call (no text deltas)
            vec![crate::llm::StreamChunk::Done(
                LlmResponse::single_tool_call(
                    "call_1".into(),
                    "search".into(),
                    serde_json::json!({"input": "rust"}),
                ),
            )],
            // Second call: streamed text response
            vec![
                crate::llm::StreamChunk::TextDelta("Found ".into()),
                crate::llm::StreamChunk::TextDelta("it.".into()),
                crate::llm::StreamChunk::Done(LlmResponse::Message("Found it.".into())),
            ],
        ]);

        let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "Rust is great"));
        // Shared log of event labels pushed from the callback.
        let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = events.clone();

        let agent = AgentLoop::builder().tool(tool).build();
        let result = agent
            .run_streaming(&llm, vec![Message::user("search rust")], move |event| {
                // Reduce each event to a compact label for order assertions.
                let label = match &event {
                    AgentEvent::TextChunk(t) => format!("chunk:{t}"),
                    AgentEvent::TextDone(t) => format!("done:{t}"),
                    AgentEvent::IterationStarted(n) => format!("iter:{n}"),
                    AgentEvent::ToolStarted { name } => format!("started:{name}"),
                    AgentEvent::ToolCompleted { name, .. } => format!("completed:{name}"),
                    AgentEvent::Interrupted => "interrupted".to_string(),
                    AgentEvent::AskUser { question, .. } => format!("ask_user:{question}"),
                    AgentEvent::AskUserTimeout { call_id, .. } => {
                        format!("ask_user_timeout:{call_id}")
                    }
                    _ => format!("{event:?}"),
                };
                events_clone.lock().unwrap().push(label);
            })
            .await
            .unwrap();

        assert_eq!(result.answer, "Found it.");
        assert_eq!(result.tool_calls.len(), 1);
        assert_eq!(result.iterations, 2);

        // Tool events of iteration 1 must precede the text of iteration 2.
        let events = events.lock().unwrap();
        assert_eq!(
            *events,
            vec![
                "iter:1",
                "started:search",
                "completed:search",
                "iter:2",
                "chunk:Found ",
                "chunk:it.",
                "done:Found it.",
            ]
        );
    }
2116
2117    #[tokio::test]
2118    async fn run_streaming_fallback_non_streaming_llm() {
2119        // MockLlm doesn't override chat_stream, so it uses the default
2120        // fallback that delegates to chat().
2121        let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".into())]);
2122
2123        let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
2124        let events_clone = events.clone();
2125
2126        let agent = AgentLoop::builder().build();
2127        let result = agent
2128            .run_streaming(&llm, vec![Message::user("Hi")], move |event| {
2129                let label = match &event {
2130                    AgentEvent::TextChunk(t) => format!("chunk:{t}"),
2131                    AgentEvent::TextDone(t) => format!("done:{t}"),
2132                    AgentEvent::IterationStarted(n) => format!("iter:{n}"),
2133                    _ => "other".into(),
2134                };
2135                events_clone.lock().unwrap().push(label);
2136            })
2137            .await
2138            .unwrap();
2139
2140        assert_eq!(result.answer, "Hello!");
2141
2142        let events = events.lock().unwrap();
2143        // Default fallback: no text deltas, so run_streaming emits
2144        // the full text as a TextChunk, then TextDone.
2145        assert_eq!(*events, vec!["iter:1", "chunk:Hello!", "done:Hello!"]);
2146    }
2147
2148    // -----------------------------------------------------------------------
2149    // Token usage tests
2150    // -----------------------------------------------------------------------
2151
2152    #[tokio::test]
2153    async fn usage_is_nonzero_after_mocked_llm_call() {
2154        let usage = TokenUsage {
2155            input_tokens: 100,
2156            output_tokens: 50,
2157        };
2158        let llm = MockLlm::with_usage(vec![LlmResponse::Message("Hello!".to_string())], usage);
2159
2160        let agent = AgentLoop::builder().build();
2161        let result = agent
2162            .run(&llm, vec![Message::user("Hi")], |_| {})
2163            .await
2164            .unwrap();
2165
2166        assert_eq!(result.usage.input_tokens, 100);
2167        assert_eq!(result.usage.output_tokens, 50);
2168    }
2169
2170    #[tokio::test]
2171    async fn usage_accumulates_across_iterations() {
2172        let usage = TokenUsage {
2173            input_tokens: 10,
2174            output_tokens: 20,
2175        };
2176        let llm = MockLlm::with_usage(
2177            vec![
2178                LlmResponse::single_tool_call(
2179                    "call_1".to_string(),
2180                    "search".to_string(),
2181                    serde_json::json!({}),
2182                ),
2183                LlmResponse::Message("Done.".to_string()),
2184            ],
2185            usage,
2186        );
2187
2188        let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));
2189        let agent = AgentLoop::builder().tool(tool).build();
2190        let result = agent
2191            .run(&llm, vec![Message::user("test")], |_| {})
2192            .await
2193            .unwrap();
2194
2195        // Two LLM calls, each with usage (10, 20)
2196        assert_eq!(result.usage.input_tokens, 20);
2197        assert_eq!(result.usage.output_tokens, 40);
2198        assert_eq!(result.iterations, 2);
2199    }
2200
2201    #[tokio::test]
2202    async fn usage_zero_when_llm_reports_no_usage() {
2203        let llm = MockLlm::new(vec![LlmResponse::Message("ok".to_string())]);
2204
2205        let agent = AgentLoop::builder().build();
2206        let result = agent
2207            .run(&llm, vec![Message::user("hi")], |_| {})
2208            .await
2209            .unwrap();
2210
2211        assert_eq!(result.usage.input_tokens, 0);
2212        assert_eq!(result.usage.output_tokens, 0);
2213    }
2214
2215    #[tokio::test]
2216    async fn streaming_usage_accumulates() {
2217        let llm = StreamingMockLlm::new(vec![vec![
2218            crate::llm::StreamChunk::TextDelta("Hi".into()),
2219            crate::llm::StreamChunk::Done(LlmResponse::Message("Hi".into())),
2220            crate::llm::StreamChunk::Usage(TokenUsage {
2221                input_tokens: 50,
2222                output_tokens: 25,
2223            }),
2224        ]]);
2225
2226        let agent = AgentLoop::builder().build();
2227        let result = agent
2228            .run_streaming(&llm, vec![Message::user("Hi")], |_| {})
2229            .await
2230            .unwrap();
2231
2232        assert_eq!(result.usage.input_tokens, 50);
2233        assert_eq!(result.usage.output_tokens, 25);
2234    }
2235
2236    #[tokio::test]
2237    async fn run_streaming_max_iterations() {
2238        let responses: Vec<Vec<crate::llm::StreamChunk>> = (0..5)
2239            .map(|i| {
2240                vec![crate::llm::StreamChunk::Done(
2241                    LlmResponse::single_tool_call(
2242                        format!("call_{i}"),
2243                        "search".into(),
2244                        serde_json::json!({}),
2245                    ),
2246                )]
2247            })
2248            .collect();
2249        let llm = StreamingMockLlm::new(responses);
2250
2251        let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));
2252        let agent = AgentLoop::builder().tool(tool).max_iterations(3).build();
2253        let err = agent
2254            .run_streaming(&llm, vec![Message::user("loop")], |_| {})
2255            .await
2256            .unwrap_err();
2257
2258        assert!(matches!(err, AgentError::MaxIterations(3)));
2259    }
2260
2261    // -----------------------------------------------------------------------
2262    // messages field tests
2263    // -----------------------------------------------------------------------
2264
2265    #[tokio::test]
2266    async fn result_messages_contains_full_history() {
2267        let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".to_string())]);
2268
2269        let agent = AgentLoop::builder().build();
2270        let result = agent
2271            .run(&llm, vec![Message::user("Hi")], |_| {})
2272            .await
2273            .unwrap();
2274
2275        // Should contain the user message + assistant reply.
2276        assert_eq!(result.messages.len(), 2);
2277        assert_eq!(result.messages[0].role, crate::message::Role::User);
2278        assert_eq!(result.messages[0].content, "Hi");
2279        assert_eq!(result.messages[1].role, crate::message::Role::Assistant);
2280        assert_eq!(result.messages[1].content, "Hello!");
2281    }
2282
2283    #[tokio::test]
2284    async fn result_messages_includes_tool_call_pairs() {
2285        let llm = MockLlm::new(vec![
2286            LlmResponse::single_tool_call(
2287                "call_1".to_string(),
2288                "search".to_string(),
2289                serde_json::json!({"input": "rust"}),
2290            ),
2291            LlmResponse::Message("Found it.".to_string()),
2292        ]);
2293
2294        let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result text"));
2295
2296        let agent = AgentLoop::builder().tool(tool).build();
2297        let result = agent
2298            .run(&llm, vec![Message::user("Search")], |_| {})
2299            .await
2300            .unwrap();
2301
2302        // user -> assistant(tool_call) -> tool_result -> assistant(final)
2303        assert_eq!(result.messages.len(), 4);
2304        assert_eq!(result.messages[0].role, crate::message::Role::User);
2305        assert_eq!(result.messages[1].role, crate::message::Role::Assistant);
2306        assert_eq!(result.messages[1].tool_calls.len(), 1);
2307        assert_eq!(result.messages[1].tool_calls[0].name, "search");
2308        assert_eq!(result.messages[2].role, crate::message::Role::Tool);
2309        assert_eq!(result.messages[3].role, crate::message::Role::Assistant);
2310        assert_eq!(result.messages[3].content, "Found it.");
2311    }
2312
2313    #[tokio::test]
2314    async fn multi_turn_continuation_via_messages() {
2315        // First turn: simple Q&A
2316        let llm1 = MockLlm::new(vec![LlmResponse::Message("I'm fine!".to_string())]);
2317        let agent = AgentLoop::builder().build();
2318        let result1 = agent
2319            .run(&llm1, vec![Message::user("How are you?")], |_| {})
2320            .await
2321            .unwrap();
2322
2323        // Second turn: continue using result.messages
2324        let llm2 = MockLlm::new(vec![LlmResponse::Message("Goodbye!".to_string())]);
2325        let mut next_messages = result1.messages;
2326        next_messages.push(Message::user("Bye!"));
2327
2328        let result2 = agent.run(&llm2, next_messages, |_| {}).await.unwrap();
2329
2330        assert_eq!(result2.answer, "Goodbye!");
2331        // History: user("How are you?") + assistant("I'm fine!") + user("Bye!") + assistant("Goodbye!")
2332        assert_eq!(result2.messages.len(), 4);
2333        assert_eq!(result2.messages[0].content, "How are you?");
2334        assert_eq!(result2.messages[1].content, "I'm fine!");
2335        assert_eq!(result2.messages[2].content, "Bye!");
2336        assert_eq!(result2.messages[3].content, "Goodbye!");
2337    }
2338
2339    #[tokio::test]
2340    async fn streaming_result_messages_contains_full_history() {
2341        let llm = StreamingMockLlm::new(vec![vec![
2342            crate::llm::StreamChunk::TextDelta("Hi".to_string()),
2343            crate::llm::StreamChunk::TextDelta(" there".to_string()),
2344            crate::llm::StreamChunk::Done(LlmResponse::Message("Hi there".to_string())),
2345        ]]);
2346
2347        let agent = AgentLoop::builder().build();
2348        let result = agent
2349            .run_streaming(&llm, vec![Message::user("Hello")], |_| {})
2350            .await
2351            .unwrap();
2352
2353        assert_eq!(result.messages.len(), 2);
2354        assert_eq!(result.messages[0].role, crate::message::Role::User);
2355        assert_eq!(result.messages[0].content, "Hello");
2356        assert_eq!(result.messages[1].role, crate::message::Role::Assistant);
2357        assert_eq!(result.messages[1].content, "Hi there");
2358    }
2359
2360    #[test]
2361    fn tool_execution_policy_detects_ask_user_presence() {
2362        let parallel_items = vec![ToolCallItem {
2363            id: "call_1".to_string(),
2364            name: "search".to_string(),
2365            args: serde_json::json!({}),
2366        }];
2367        let interactive_items = vec![
2368            ToolCallItem {
2369                id: "call_1".to_string(),
2370                name: "search".to_string(),
2371                args: serde_json::json!({}),
2372            },
2373            ToolCallItem {
2374                id: "call_2".to_string(),
2375                name: "ask_user".to_string(),
2376                args: serde_json::json!({"question": "continue?"}),
2377            },
2378        ];
2379
2380        assert!(matches!(
2381            ToolExecutionPolicy::from_items(&parallel_items),
2382            ToolExecutionPolicy::ParallelOnly
2383        ));
2384        assert!(matches!(
2385            ToolExecutionPolicy::from_items(&interactive_items),
2386            ToolExecutionPolicy::InteractiveAskUser
2387        ));
2388    }
2389
2390    // -----------------------------------------------------------------------
2391    // Pipeline stage regression tests
2392    // -----------------------------------------------------------------------
2393
    /// Verify that tool results are appended in the same order as tool calls.
    /// This is a regression test for the `execute_and_record_tool_calls` stage:
    /// tool result messages must correlate 1:1 with the tool call items.
    #[tokio::test]
    async fn tool_result_ordering_matches_call_ordering() {
        // One assistant turn issuing three tool calls at once, then a final
        // text turn.
        let llm = MockLlm::new(vec![
            LlmResponse::ToolCalls(vec![
                ToolCallItem {
                    id: "c1".to_string(),
                    name: "alpha".to_string(),
                    args: serde_json::json!({}),
                },
                ToolCallItem {
                    id: "c2".to_string(),
                    name: "beta".to_string(),
                    args: serde_json::json!({}),
                },
                ToolCallItem {
                    id: "c3".to_string(),
                    name: "gamma".to_string(),
                    args: serde_json::json!({}),
                },
            ]),
            LlmResponse::Message("done".to_string()),
        ]);

        let alpha: Arc<dyn Tool> = Arc::new(MockTool::new("alpha", "res_alpha"));
        let beta: Arc<dyn Tool> = Arc::new(MockTool::new("beta", "res_beta"));
        let gamma: Arc<dyn Tool> = Arc::new(MockTool::new("gamma", "res_gamma"));

        let agent = AgentLoop::builder()
            .tool(alpha)
            .tool(beta)
            .tool(gamma)
            .build();

        let result = agent
            .run(&llm, vec![Message::user("go")], |_| {})
            .await
            .unwrap();

        // Verify tool_calls record order.
        assert_eq!(result.tool_calls[0].0, "alpha");
        assert_eq!(result.tool_calls[1].0, "beta");
        assert_eq!(result.tool_calls[2].0, "gamma");

        // Verify message history: user -> assistant(tool_calls) -> 3x tool_result -> assistant(done)
        assert_eq!(result.messages.len(), 6);
        // The assistant message at index 1 should carry all 3 tool call refs.
        assert_eq!(result.messages[1].tool_calls.len(), 3);
        assert_eq!(result.messages[1].tool_calls[0].id, "c1");
        assert_eq!(result.messages[1].tool_calls[1].id, "c2");
        assert_eq!(result.messages[1].tool_calls[2].id, "c3");
        // Tool result messages at indices 2, 3, 4 must match call order.
        assert_eq!(result.messages[2].tool_call_id.as_deref(), Some("c1"));
        assert_eq!(result.messages[2].content, "res_alpha");
        assert_eq!(result.messages[3].tool_call_id.as_deref(), Some("c2"));
        assert_eq!(result.messages[3].content, "res_beta");
        assert_eq!(result.messages[4].tool_call_id.as_deref(), Some("c3"));
        assert_eq!(result.messages[4].content, "res_gamma");
    }
2455
2456    /// Verify that the pipeline stages emit events in the correct order:
2457    /// IterationStarted -> ToolStarted (all) -> ToolCompleted (all) -> next iteration.
2458    /// This tests the boundary between the LLM step and tool-call execution stage.
2459    #[tokio::test]
2460    async fn stage_boundary_event_ordering() {
2461        let llm = MockLlm::new(vec![
2462            LlmResponse::ToolCalls(vec![
2463                ToolCallItem {
2464                    id: "c1".to_string(),
2465                    name: "x".to_string(),
2466                    args: serde_json::json!({}),
2467                },
2468                ToolCallItem {
2469                    id: "c2".to_string(),
2470                    name: "y".to_string(),
2471                    args: serde_json::json!({}),
2472                },
2473            ]),
2474            LlmResponse::Message("final".to_string()),
2475        ]);
2476
2477        let x: Arc<dyn Tool> = Arc::new(MockTool::new("x", "rx"));
2478        let y: Arc<dyn Tool> = Arc::new(MockTool::new("y", "ry"));
2479
2480        let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
2481        let events_clone = events.clone();
2482
2483        let agent = AgentLoop::builder().tool(x).tool(y).build();
2484        let _ = agent
2485            .run(&llm, vec![Message::user("go")], move |event| {
2486                let label = match &event {
2487                    AgentEvent::IterationStarted(n) => format!("iter:{n}"),
2488                    AgentEvent::ToolStarted { name } => format!("started:{name}"),
2489                    AgentEvent::ToolCompleted { name, .. } => format!("completed:{name}"),
2490                    AgentEvent::TextChunk(t) => format!("text:{t}"),
2491                    AgentEvent::TextDone(_) => "done".to_string(),
2492                    _ => "other".to_string(),
2493                };
2494                events_clone.lock().unwrap().push(label);
2495            })
2496            .await
2497            .unwrap();
2498
2499        let events = events.lock().unwrap();
2500        // Stage boundaries:
2501        // 1. iter:1 (iteration starts)
2502        // 2. started:x, started:y (tool-call planning emits all ToolStarted before execution)
2503        // 3. completed:x, completed:y (execution completes, results recorded)
2504        // 4. iter:2 (next iteration)
2505        // 5. text:final (LLM returns message)
2506        assert_eq!(events[0], "iter:1");
2507        assert_eq!(events[1], "started:x");
2508        assert_eq!(events[2], "started:y");
2509        assert_eq!(events[3], "completed:x");
2510        assert_eq!(events[4], "completed:y");
2511        assert_eq!(events[5], "iter:2");
2512        assert_eq!(events[6], "text:final");
2513    }
2514
2515    /// Verify that `TurnState::into_result` correctly captures all accumulated
2516    /// state (tool calls, usage, messages).
2517    #[tokio::test]
2518    async fn turn_state_accumulation_across_iterations() {
2519        let usage = TokenUsage {
2520            input_tokens: 7,
2521            output_tokens: 3,
2522        };
2523        let llm = MockLlm::with_usage(
2524            vec![
2525                LlmResponse::single_tool_call(
2526                    "c1".to_string(),
2527                    "t1".to_string(),
2528                    serde_json::json!({}),
2529                ),
2530                LlmResponse::single_tool_call(
2531                    "c2".to_string(),
2532                    "t2".to_string(),
2533                    serde_json::json!({}),
2534                ),
2535                LlmResponse::Message("end".to_string()),
2536            ],
2537            usage,
2538        );
2539
2540        let t1: Arc<dyn Tool> = Arc::new(MockTool::new("t1", "r1"));
2541        let t2: Arc<dyn Tool> = Arc::new(MockTool::new("t2", "r2"));
2542
2543        let agent = AgentLoop::builder().tool(t1).tool(t2).build();
2544        let result = agent
2545            .run(&llm, vec![Message::user("go")], |_| {})
2546            .await
2547            .unwrap();
2548
2549        assert_eq!(result.iterations, 3);
2550        assert_eq!(result.tool_calls.len(), 2);
2551        // 3 LLM calls * 7 input tokens each
2552        assert_eq!(result.usage.input_tokens, 21);
2553        assert_eq!(result.usage.output_tokens, 9);
2554        // Message history: user + (assistant+tool_result) * 2 + assistant(final)
2555        // user, assistant(tc), tool_result, assistant(tc), tool_result, assistant(final)
2556        assert_eq!(result.messages.len(), 6);
2557    }
2558
2559    /// Verify that the streaming path produces identical message history
2560    /// to the non-streaming path for the same logical conversation.
2561    #[tokio::test]
2562    async fn streaming_and_non_streaming_produce_same_messages() {
2563        // Non-streaming
2564        let llm_sync = MockLlm::new(vec![
2565            LlmResponse::single_tool_call(
2566                "c1".to_string(),
2567                "search".to_string(),
2568                serde_json::json!({"q": "rust"}),
2569            ),
2570            LlmResponse::Message("Found it.".to_string()),
2571        ]);
2572        let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result text"));
2573        let agent = AgentLoop::builder().tool(tool).build();
2574        let result_sync = agent
2575            .run(&llm_sync, vec![Message::user("Search")], |_| {})
2576            .await
2577            .unwrap();
2578
2579        // Streaming (using MockLlm which falls back to chat via default chat_stream)
2580        let llm_stream = MockLlm::new(vec![
2581            LlmResponse::single_tool_call(
2582                "c1".to_string(),
2583                "search".to_string(),
2584                serde_json::json!({"q": "rust"}),
2585            ),
2586            LlmResponse::Message("Found it.".to_string()),
2587        ]);
2588        let tool2: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result text"));
2589        let agent2 = AgentLoop::builder().tool(tool2).build();
2590        let result_stream = agent2
2591            .run_streaming(&llm_stream, vec![Message::user("Search")], |_| {})
2592            .await
2593            .unwrap();
2594
2595        // Both paths should produce the same message history.
2596        assert_eq!(result_sync.messages.len(), result_stream.messages.len());
2597        for (a, b) in result_sync
2598            .messages
2599            .iter()
2600            .zip(result_stream.messages.iter())
2601        {
2602            assert_eq!(a.role, b.role);
2603            assert_eq!(a.content, b.content);
2604            assert_eq!(a.tool_call_id, b.tool_call_id);
2605            assert_eq!(a.tool_calls.len(), b.tool_calls.len());
2606        }
2607        assert_eq!(result_sync.iterations, result_stream.iterations);
2608        assert_eq!(result_sync.tool_calls.len(), result_stream.tool_calls.len());
2609    }
2610
2611    // -----------------------------------------------------------------------
2612    // Cancellation tests (cancellation feature)
2613    // -----------------------------------------------------------------------
2614
2615    #[cfg(feature = "cancellation")]
2616    mod cancellation_tests {
2617        use super::*;
2618        use tokio_util::sync::CancellationToken;
2619
2620        #[tokio::test]
2621        async fn cancel_before_run_returns_cancelled() {
2622            let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".into())]);
2623            let agent = AgentLoop::builder().build();
2624
2625            let token = CancellationToken::new();
2626            token.cancel();
2627
2628            let err = agent
2629                .run_with_cancel(&llm, vec![Message::user("Hi")], token, |_| {})
2630                .await
2631                .unwrap_err();
2632
2633            assert!(matches!(err, AgentError::Cancelled));
2634        }
2635
2636        #[tokio::test]
2637        async fn cancel_mid_run_returns_cancelled() {
2638            let responses: Vec<LlmResponse> = (0..5)
2639                .map(|i| {
2640                    LlmResponse::single_tool_call(
2641                        format!("call_{i}"),
2642                        "search".to_string(),
2643                        serde_json::json!({}),
2644                    )
2645                })
2646                .collect();
2647            let llm = MockLlm::new(responses);
2648            let tool: Arc<dyn Tool> = Arc::new(MockTool::new("search", "result"));
2649
2650            let agent = AgentLoop::builder().tool(tool).max_iterations(5).build();
2651
2652            let token = CancellationToken::new();
2653            let child = token.child_token();
2654            let iterations = Arc::new(Mutex::new(0usize));
2655            let iterations_clone = iterations.clone();
2656            let token_clone = token.clone();
2657
2658            let err = agent
2659                .run_with_cancel(&llm, vec![Message::user("loop")], child, move |event| {
2660                    if let AgentEvent::IterationStarted(_) = &event {
2661                        let mut count = iterations_clone.lock().unwrap();
2662                        *count += 1;
2663                        if *count >= 2 {
2664                            token_clone.cancel();
2665                        }
2666                    }
2667                })
2668                .await
2669                .unwrap_err();
2670
2671            assert!(matches!(err, AgentError::Cancelled));
2672            let count = *iterations.lock().unwrap();
2673            assert!(
2674                count >= 2 && count <= 3,
2675                "unexpected iteration count: {count}"
2676            );
2677        }
2678
2679        #[tokio::test]
2680        async fn cancel_streaming_returns_cancelled() {
2681            let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".into())]);
2682            let agent = AgentLoop::builder().build();
2683
2684            let token = CancellationToken::new();
2685            token.cancel();
2686
2687            let err = agent
2688                .run_streaming_with_cancel(&llm, vec![Message::user("Hi")], token, |_| {})
2689                .await
2690                .unwrap_err();
2691
2692            assert!(matches!(err, AgentError::Cancelled));
2693        }
2694
2695        #[tokio::test]
2696        async fn uncancelled_token_runs_normally() {
2697            let llm = MockLlm::new(vec![LlmResponse::Message("Hello!".into())]);
2698            let agent = AgentLoop::builder().build();
2699
2700            let token = CancellationToken::new();
2701
2702            let result = agent
2703                .run_with_cancel(&llm, vec![Message::user("Hi")], token, |_| {})
2704                .await
2705                .unwrap();
2706
2707            assert_eq!(result.answer, "Hello!");
2708            assert_eq!(result.iterations, 1);
2709        }
2710    }
2711}
2712
/// Integration tests for MCP server support on [`AgentLoop`].
/// Only compiled when the `mcp-client` feature is enabled.
#[cfg(all(test, feature = "mcp-client"))]
mod mcp_integration_tests {
    use super::*;
    use crate::mcp::McpServer;
    use async_trait::async_trait;
    use motosan_agent_tool::ToolDef;
    use serde_json::json;

    /// Minimal MCP server exposing a single `echo` tool that mirrors back its
    /// `msg` argument. `connect`/`disconnect` are no-ops.
    struct EchoMcpServer {
        name: String,
    }

    #[async_trait]
    impl McpServer for EchoMcpServer {
        fn name(&self) -> &str {
            &self.name
        }
        async fn connect(&self) -> crate::Result<()> {
            Ok(())
        }
        async fn list_tools(&self) -> crate::Result<Vec<ToolDef>> {
            Ok(vec![ToolDef {
                name: "echo".to_string(),
                description: "Echo input".to_string(),
                input_schema: json!({"type": "object", "properties": {"msg": {"type": "string"}}}),
            }])
        }
        async fn call_tool(&self, _name: &str, args: serde_json::Value) -> crate::Result<String> {
            // Ignores the tool name; echoes `msg` (empty string if absent or non-string).
            Ok(format!(
                "echo: {}",
                args.get("msg").and_then(|v| v.as_str()).unwrap_or("")
            ))
        }
        async fn disconnect(&self) -> crate::Result<()> {
            Ok(())
        }
    }

    /// The builder accepts an owned MCP server and registers exactly one.
    #[test]
    fn builder_accepts_mcp_server() {
        let agent = AgentLoop::builder()
            .mcp_server(EchoMcpServer {
                name: "test_mcp".to_string(),
            })
            .max_iterations(5)
            .build();
        assert_eq!(agent.mcp_servers.len(), 1);
    }

    /// The builder also accepts a pre-shared `Arc<dyn McpServer>`, allowing
    /// several agents to share one server instance.
    #[test]
    fn builder_accepts_shared_arc_mcp_server() {
        let shared: Arc<dyn McpServer> = Arc::new(EchoMcpServer {
            name: "shared_mcp".to_string(),
        });

        let agent_a = AgentLoop::builder()
            .mcp_server_arc(Arc::clone(&shared))
            .max_iterations(5)
            .build();

        let agent_b = AgentLoop::builder()
            .mcp_server_arc(Arc::clone(&shared))
            .max_iterations(5)
            .build();

        assert_eq!(agent_a.mcp_servers.len(), 1);
        assert_eq!(agent_b.mcp_servers.len(), 1);
        // Both agents must hold the *same* server instance, not a copy.
        assert!(Arc::ptr_eq(
            &agent_a.mcp_servers[0],
            &agent_b.mcp_servers[0]
        ));
    }

    /// A simple mock LLM for MCP streaming tests.
    /// Uses the default `chat_stream` fallback (delegates to `chat`).
    struct McpTestLlm {
        // Scripted responses, consumed front-to-back by `chat`.
        responses: std::sync::Mutex<Vec<crate::llm::LlmResponse>>,
    }

    impl McpTestLlm {
        /// Builds a mock whose successive `chat` calls pop from `responses`.
        fn new(responses: Vec<crate::llm::LlmResponse>) -> Self {
            Self {
                responses: std::sync::Mutex::new(responses),
            }
        }
    }

    #[async_trait]
    impl crate::llm::LlmClient for McpTestLlm {
        async fn chat(
            &self,
            _messages: &[Message],
            _tools: &[ToolDef],
        ) -> crate::Result<crate::llm::ChatOutput> {
            let mut responses = self.responses.lock().unwrap();
            // Exhausting the script is a test bug: panic with a clear message.
            assert!(!responses.is_empty(), "McpTestLlm: no more responses");
            let response = responses.remove(0);
            Ok(crate::llm::ChatOutput {
                response,
                usage: None,
            })
        }
    }

    /// End-to-end: the streaming run discovers an MCP tool (namespaced as
    /// `server__tool`), executes it, and records call + result correctly.
    #[tokio::test]
    async fn run_streaming_with_mcp_server() {
        use crate::llm::{LlmResponse, ToolCallItem};
        use std::sync::Mutex;

        // LLM first requests the MCP tool (namespaced as server__tool),
        // then returns a final message.
        // Uses the default chat_stream fallback (delegates to chat).
        let llm = McpTestLlm::new(vec![
            LlmResponse::ToolCalls(vec![ToolCallItem {
                id: "call_mcp".to_string(),
                name: "test_mcp__echo".to_string(),
                args: json!({"msg": "hello"}),
            }]),
            LlmResponse::Message("MCP says: echo hello".to_string()),
        ]);

        let events: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let events_clone = events.clone();

        let agent = AgentLoop::builder()
            .mcp_server(EchoMcpServer {
                name: "test_mcp".to_string(),
            })
            .max_iterations(5)
            .build();

        let result = agent
            .run_streaming(&llm, vec![Message::user("call echo")], move |event| {
                // Exhaustive match (no `_` arm) so new AgentEvent variants
                // force a compile error here and get a label added.
                let label = match &event {
                    AgentEvent::ToolStarted { name } => format!("started:{name}"),
                    AgentEvent::ToolCompleted { name, result } => {
                        format!("completed:{name}:{}", result.as_text().unwrap_or(""))
                    }
                    AgentEvent::TextChunk(t) => format!("chunk:{t}"),
                    AgentEvent::TextDone(t) => format!("done:{t}"),
                    AgentEvent::IterationStarted(n) => format!("iter:{n}"),
                    AgentEvent::Interrupted => "interrupted".to_string(),
                    AgentEvent::AskUser { question, .. } => format!("ask_user:{question}"),
                    AgentEvent::AskUserTimeout { call_id, .. } => {
                        format!("ask_user_timeout:{call_id}")
                    }
                };
                events_clone.lock().unwrap().push(label);
            })
            .await
            .unwrap();

        assert_eq!(result.answer, "MCP says: echo hello");
        assert_eq!(result.tool_calls.len(), 1);
        // The MCP tool is recorded under its namespaced name.
        assert_eq!(result.tool_calls[0].0, "test_mcp__echo");

        let events = events.lock().unwrap();
        // The MCP tool should have been executed and returned "echo: hello".
        assert!(events
            .iter()
            .any(|e| e.starts_with("completed:test_mcp__echo:echo: hello")));
    }

    /// MCP server that tracks connect/disconnect calls via shared counters.
    struct TrackingMcpServer {
        name: String,
        // When true, `connect` fails with an Mcp error instead of succeeding.
        fail_connect: bool,
        // Set to true on successful connect; observed by the test afterwards.
        connected: Arc<std::sync::atomic::AtomicBool>,
        // Incremented on every `disconnect` call.
        disconnect_count: Arc<std::sync::atomic::AtomicUsize>,
    }

    #[async_trait]
    impl McpServer for TrackingMcpServer {
        fn name(&self) -> &str {
            &self.name
        }
        async fn connect(&self) -> crate::Result<()> {
            if self.fail_connect {
                Err(crate::AgentError::Mcp("connect failed".into()))
            } else {
                self.connected
                    .store(true, std::sync::atomic::Ordering::SeqCst);
                Ok(())
            }
        }
        async fn list_tools(&self) -> crate::Result<Vec<ToolDef>> {
            Ok(vec![])
        }
        async fn call_tool(&self, _name: &str, _args: serde_json::Value) -> crate::Result<String> {
            Ok(String::new())
        }
        async fn disconnect(&self) -> crate::Result<()> {
            self.disconnect_count
                .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
            Ok(())
        }
    }

    /// If one server's `connect` fails, servers that already connected must
    /// be disconnected (no leaked connections), and never-connected servers
    /// must not receive a `disconnect` call.
    #[tokio::test]
    async fn partial_connect_failure_disconnects_already_connected() {
        use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};

        let server1_connected = Arc::new(AtomicBool::new(false));
        let server1_disconnects = Arc::new(AtomicUsize::new(0));
        let server2_disconnects = Arc::new(AtomicUsize::new(0));

        // Server 1 connects fine; server 2 (registered after) fails.
        let server1: Arc<dyn McpServer> = Arc::new(TrackingMcpServer {
            name: "ok_server".to_string(),
            fail_connect: false,
            connected: server1_connected.clone(),
            disconnect_count: server1_disconnects.clone(),
        });
        let server2: Arc<dyn McpServer> = Arc::new(TrackingMcpServer {
            name: "fail_server".to_string(),
            fail_connect: true,
            connected: Arc::new(AtomicBool::new(false)),
            disconnect_count: server2_disconnects.clone(),
        });

        let agent = AgentLoop::builder()
            .mcp_server_arc(server1)
            .mcp_server_arc(server2)
            .max_iterations(1)
            .build();

        // Empty script: the run must fail during MCP connect, before any LLM call.
        let llm = McpTestLlm::new(vec![]);
        let err = agent
            .run(&llm, vec![Message::user("hi")], |_| {})
            .await
            .unwrap_err();

        assert!(
            matches!(err, crate::AgentError::Mcp(_)),
            "expected connect error"
        );
        // Server 1 was successfully connected, so it must have been disconnected.
        assert!(server1_connected.load(Ordering::SeqCst));
        assert_eq!(server1_disconnects.load(Ordering::SeqCst), 1);
        // Server 2 never connected, so disconnect should not have been called.
        assert_eq!(server2_disconnects.load(Ordering::SeqCst), 0);
    }
}
2954}