Skip to main content

rab/agent/
loop.rs

1use crate::agent::extension::{Cancel, Extension, ToolOutput};
2use crate::agent::provider::{Provider, StopReason, StreamEvent, ToolDef};
3use crate::agent::types::{AgentMessage, PendingMessageQueue, Role, ToolCall, ToolExecutionMode};
4use futures::future::join_all;
5
6/// Collect tool definitions from all extensions.
7pub fn collect_tool_defs(extensions: &[Box<dyn Extension>]) -> Vec<ToolDef> {
8    let mut defs = Vec::new();
9    for ext in extensions {
10        for tool in ext.tools() {
11            if !defs.iter().any(|d: &ToolDef| d.name == tool.name()) {
12                defs.push(ToolDef {
13                    name: tool.name().to_string(),
14                    description: tool.description().to_string(),
15                    parameters: tool.parameters(),
16                });
17            }
18        }
19    }
20    defs
21}
22
23/// Emitted by the loop for consumers (print mode writes to stdout; TUI later renders).
24#[derive(Debug, Clone)]
25#[allow(dead_code)]
26pub enum AgentEvent {
27    AgentStart,
28    TurnStart,
29    TextDelta {
30        delta: String,
31    },
32    ThinkingDelta {
33        delta: String,
34    },
35    ToolCall {
36        id: String,
37        name: String,
38        args: serde_json::Value,
39    },
40    /// Progressive args update (pi calls renderCall multiple times).
41    ToolCallArgsUpdate {
42        id: String,
43        args: serde_json::Value,
44    },
45    ToolResult {
46        id: String,
47        name: String,
48        content: String,
49        compact: Option<String>,
50        is_error: bool,
51    },
52    /// Intermediate tool execution progress (bash streaming output).
53    ToolProgress {
54        content: String,
55        is_error: bool,
56    },
57    /// Stream was aborted or errored. TextDelta/ThinkingDelta may have been sent before.
58    Aborted {
59        reason: String,
60    },
61    /// A user message was injected from the steering or follow-up queue.
62    UserMessage {
63        content: String,
64    },
65    TurnEnd,
66    AgentEnd {
67        messages: Vec<AgentMessage>,
68    },
69}
70
71/// Transform function: rewrites messages before each LLM call.
72pub type TransformFn = Box<dyn Fn(&[AgentMessage]) -> Vec<AgentMessage> + Send + Sync>;
73
74/// Prepare-next-turn function: optionally modifies context between turns.
75pub type PrepareNextTurnFn = Box<dyn Fn(&[AgentMessage]) -> Option<TurnUpdate> + Send + Sync>;
76
77/// Should-stop-after-turn predicate: early-stop check.
78pub type ShouldStopFn = Box<dyn Fn(&[AgentMessage]) -> bool + Send + Sync>;
79
80/// Optional return value from `prepare_next_turn` to modify context for the next turn.
81pub struct TurnUpdate {
82    /// Replace the full message context for the next LLM call.
83    pub context: Option<Vec<AgentMessage>>,
84}
85
86/// Configuration for the agent loop.
87pub struct LoopConfig<'a> {
88    pub model: String,
89    pub system_prompt: String,
90    pub tools: Vec<ToolDef>,
91    pub agent_tools: &'a [Box<dyn crate::agent::extension::AgentTool>],
92    pub extensions: &'a [Box<dyn Extension>],
93    /// Tool execution mode: parallel (default) or sequential.
94    pub tool_execution: ToolExecutionMode,
95    /// Optional steering queue: messages delivered after the current assistant turn's
96    /// tool calls finish, before the next LLM call.
97    pub steering_queue: Option<&'a std::sync::Mutex<PendingMessageQueue>>,
98    /// Optional follow-up queue: messages delivered only after the agent has no more
99    /// tool calls (fully idle).
100    pub follow_up_queue: Option<&'a std::sync::Mutex<PendingMessageQueue>>,
101    /// Optional transform applied to the message list before each LLM call.
102    /// Receives the current messages and returns (possibly modified) messages.
103    /// Pi-compatible: `transformContext` for context window management, pruning, etc.
104    pub transform_context: Option<TransformFn>,
105    /// Optional callback invoked after each turn completes.
106    /// Can return a `TurnUpdate` to modify the context for the next turn.
107    /// Pi-compatible: `prepareNextTurn`.
108    pub prepare_next_turn: Option<PrepareNextTurnFn>,
109    /// Optional predicate invoked after each turn completes.
110    /// Return true to stop the agent loop early.
111    /// Pi-compatible: `shouldStopAfterTurn`.
112    pub should_stop_after_turn: Option<ShouldStopFn>,
113}
114
115/// Find a tool by name across all extensions.
116fn find_tool<'a>(
117    tools: &'a [Box<dyn crate::agent::extension::AgentTool>],
118    name: &str,
119) -> Option<&'a dyn crate::agent::extension::AgentTool> {
120    tools.iter().find(|t| t.name() == name).map(|t| t.as_ref())
121}
122
/// Upper bound on LLM round-trips in a single agent run before the loop
/// force-stops (prevents infinite LLM loops). The counter is shared across
/// follow-up rounds, so it limits the whole run, not one turn.
const MAX_TOOL_ITERATIONS: usize = 25;
125
/// What one tool call produced, as collected by the execution phase before
/// it is turned into a tool-result message and a `ToolResult` event.
struct ToolExecOutcome {
    /// Id of the tool call this outcome answers.
    id: String,
    /// Name of the tool that was called.
    name: String,
    /// Result text (or error text) stored in the tool-result message.
    content: String,
    /// Optional compact rendering forwarded in the `ToolResult` event.
    compact: Option<String>,
    /// Whether the outcome represents a failure.
    is_error: bool,
    /// When true and ALL tools in the batch are terminal, skip further LLM calls.
    terminate: bool,
}
136
137/// Run the full agent loop. Returns all new messages added during the run.
138/// `history` contains pre-existing messages from a previous session (if continuing).
139///
140/// Supports parallel tool execution (default) and sequential, plus steering/follow-up
141/// message queues for mid-stream interruption (pi-compatible).
142pub async fn run_agent_loop(
143    prompts: Vec<AgentMessage>,
144    history: Vec<AgentMessage>,
145    config: &LoopConfig<'_>,
146    provider: &dyn Provider,
147    emit: &mut (dyn FnMut(AgentEvent) + Send),
148) -> anyhow::Result<Vec<AgentMessage>> {
149    let mut messages: Vec<AgentMessage> = Vec::new();
150    messages.extend(history);
151    messages.extend(prompts.clone());
152
153    let mut new_messages: Vec<AgentMessage> = prompts.clone();
154
155    emit(AgentEvent::AgentStart);
156    emit(AgentEvent::TurnStart);
157
158    let mut iteration_count: usize = 0;
159
160    // ── Outer loop: continues when follow-up messages arrive ──
161    // (pi-compatible: after agent would stop, check follow-up queue and continue)
162    loop {
163        // ── Inner loop: stream LLM → execute tools → repeat ──
164        let mut has_more_tool_calls = true;
165
166        while has_more_tool_calls {
167            iteration_count += 1;
168            if iteration_count > MAX_TOOL_ITERATIONS {
169                let msg = format!(
170                    "Agent loop exceeded maximum iterations ({}). Last response may be incomplete.",
171                    MAX_TOOL_ITERATIONS
172                );
173                emit(AgentEvent::Aborted {
174                    reason: msg.clone(),
175                });
176                emit(AgentEvent::AgentEnd {
177                    messages: new_messages.clone(),
178                });
179                return Ok(new_messages);
180            }
181
182            // Check steering messages before each LLM call
183            // (pi-compatible: delivered after current turn's tool calls finish,
184            //  before next LLM call)
185            drain_steering(config, &mut messages, &mut new_messages, emit);
186
187            // 1. Stream LLM response
188            // Apply transform_context if configured (pi-compatible: rewrite messages
189            // for the LLM call without modifying the stored transcript).
190            let llm_messages: &[AgentMessage] = &messages;
191            let _transformed_holder;
192            let llm_messages = if let Some(ref transform) = config.transform_context {
193                _transformed_holder = transform(llm_messages);
194                &_transformed_holder
195            } else {
196                llm_messages
197            };
198            let mut stream = provider
199                .stream(
200                    &config.model,
201                    &config.system_prompt,
202                    llm_messages,
203                    &config.tools,
204                )
205                .await?;
206
207            // 2. Collect streaming response
208            let mut response_text = String::new();
209            let mut tool_calls: Vec<ToolCall> = Vec::new();
210            let mut stop_reason = StopReason::EndTurn;
211
212            while let Some(event) = futures::StreamExt::next(&mut stream).await {
213                match event {
214                    StreamEvent::TextDelta { text } => {
215                        response_text.push_str(&text);
216                        emit(AgentEvent::TextDelta { delta: text });
217                    }
218                    StreamEvent::ThinkingDelta { text } => {
219                        emit(AgentEvent::ThinkingDelta { delta: text });
220                    }
221                    StreamEvent::ToolCall {
222                        id,
223                        name,
224                        arguments,
225                    } => {
226                        let args: serde_json::Value = serde_json::from_str(&arguments)
227                            .unwrap_or(serde_json::Value::String(arguments.clone()));
228
229                        if let Some(existing) = tool_calls.iter_mut().find(|tc| tc.id == id) {
230                            existing.arguments = args;
231                        } else {
232                            tool_calls.push(ToolCall {
233                                id,
234                                name,
235                                arguments: args,
236                            });
237                        }
238                    }
239                    StreamEvent::Done {
240                        text,
241                        stop_reason: sr,
242                        tool_calls: tcs,
243                        ..
244                    } => {
245                        if response_text.is_empty() && !text.is_empty() {
246                            emit(AgentEvent::TextDelta {
247                                delta: text.clone(),
248                            });
249                        }
250                        response_text = text;
251                        stop_reason = sr;
252                        if !tcs.is_empty() {
253                            tool_calls = tcs;
254                        }
255                    }
256                    StreamEvent::Error { message } => {
257                        emit(AgentEvent::Aborted {
258                            reason: message.clone(),
259                        });
260                        emit(AgentEvent::ToolResult {
261                            id: String::new(),
262                            name: String::new(),
263                            content: message.clone(),
264                            compact: None,
265                            is_error: true,
266                        });
267                        let error_msg =
268                            AgentMessage::tool_result(String::new(), message.clone(), true);
269                        new_messages.push(error_msg);
270                        emit(AgentEvent::AgentEnd {
271                            messages: new_messages.clone(),
272                        });
273                        return Ok(new_messages);
274                    }
275                }
276            }
277
278            // Create assistant message
279            let assistant_msg = AgentMessage {
280                id: uuid::Uuid::new_v4().to_string(),
281                parent_id: None,
282                role: Role::Assistant,
283                content: response_text.clone(),
284                tool_calls: tool_calls.clone(),
285                tool_call_id: None,
286                usage: None,
287                is_error: false,
288                timestamp: chrono::Utc::now().timestamp_millis(),
289            };
290
291            messages.push(assistant_msg.clone());
292            new_messages.push(assistant_msg);
293
294            // Handle errors
295            if stop_reason == StopReason::Error {
296                emit(AgentEvent::AgentEnd {
297                    messages: new_messages.clone(),
298                });
299                return Ok(new_messages);
300            }
301
302            // 3. Execute tool calls
303            if !tool_calls.is_empty() {
304                // Check if any tool in this batch declares sequential execution mode.
305                // If so, the entire batch runs sequentially (pi-compatible per-tool override).
306                let has_sequential_tool = tool_calls.iter().any(|tc| {
307                    config
308                        .agent_tools
309                        .iter()
310                        .find(|t| t.name() == tc.name)
311                        .map(|t| t.execution_mode() == ToolExecutionMode::Sequential)
312                        .unwrap_or(false)
313                });
314
315                let effective_mode = if has_sequential_tool {
316                    ToolExecutionMode::Sequential
317                } else {
318                    config.tool_execution
319                };
320
321                let outcomes = match effective_mode {
322                    ToolExecutionMode::Parallel => {
323                        execute_tool_calls_parallel(&tool_calls, config, emit).await
324                    }
325                    ToolExecutionMode::Sequential => {
326                        execute_tool_calls_sequential(&tool_calls, config, emit).await
327                    }
328                };
329
330                let all_terminate = !outcomes.is_empty() && outcomes.iter().all(|o| o.terminate);
331
332                for outcome in outcomes {
333                    let msg =
334                        AgentMessage::tool_result(&outcome.id, &outcome.content, outcome.is_error);
335                    emit(AgentEvent::ToolResult {
336                        id: outcome.id,
337                        name: outcome.name,
338                        content: outcome.content,
339                        compact: outcome.compact,
340                        is_error: outcome.is_error,
341                    });
342                    messages.push(msg.clone());
343                    new_messages.push(msg);
344                }
345
346                // Prepare next turn (pi-compatible: allows modifying context between turns
347                // even when tools were called)
348                apply_prepare_next_turn(config, &mut messages, &new_messages);
349
350                if all_terminate {
351                    // All tools returned terminate=true, stop further LLM calls
352                    emit(AgentEvent::TurnEnd);
353                    break;
354                }
355
356                // Inner loop continues — tool results go back to LLM
357                continue;
358            }
359
360            // 4. No tool calls — inner turn complete
361            has_more_tool_calls = false;
362            emit(AgentEvent::TurnEnd);
363
364            // Prepare next turn after the turn fully completes
365            apply_prepare_next_turn(config, &mut messages, &new_messages);
366
367            // Check should_stop_after_turn (pi-compatible: early-stop predicate)
368            if apply_should_stop_after_turn(config, &new_messages) {
369                emit(AgentEvent::AgentEnd {
370                    messages: new_messages.clone(),
371                });
372                return Ok(new_messages);
373            }
374        }
375
376        // 5. Agent would stop. Check for follow-up messages.
377        // (pi-compatible: follow-up messages are delivered only after agent is idle)
378        if !drain_follow_up(config, &mut messages, &mut new_messages, emit) {
379            break;
380        }
381    }
382
383    emit(AgentEvent::AgentEnd {
384        messages: new_messages.clone(),
385    });
386    Ok(new_messages)
387}
388
389/// Drain steering messages into the message list, emitting UserMessage events.
390fn drain_steering(
391    config: &LoopConfig<'_>,
392    messages: &mut Vec<AgentMessage>,
393    new_messages: &mut Vec<AgentMessage>,
394    emit: &mut (dyn FnMut(AgentEvent) + Send),
395) -> bool {
396    let Some(queue) = config.steering_queue else {
397        return false;
398    };
399    let drained = queue.lock().unwrap().drain();
400    if drained.is_empty() {
401        return false;
402    }
403    for msg in drained {
404        emit(AgentEvent::UserMessage {
405            content: msg.content.clone(),
406        });
407        messages.push(msg.clone());
408        new_messages.push(msg);
409    }
410    true
411}
412
413/// Drain follow-up messages into the message list, emitting UserMessage events.
414/// Returns true if any messages were drained (caller should continue outer loop).
415fn drain_follow_up(
416    config: &LoopConfig<'_>,
417    messages: &mut Vec<AgentMessage>,
418    new_messages: &mut Vec<AgentMessage>,
419    emit: &mut (dyn FnMut(AgentEvent) + Send),
420) -> bool {
421    let Some(queue) = config.follow_up_queue else {
422        return false;
423    };
424    let drained = queue.lock().unwrap().drain();
425    if drained.is_empty() {
426        return false;
427    }
428    for msg in drained {
429        emit(AgentEvent::UserMessage {
430            content: msg.content.clone(),
431        });
432        messages.push(msg.clone());
433        new_messages.push(msg);
434    }
435    true
436}
437
438/// Apply `prepare_next_turn` callback if configured.
439/// Modifies the message context for the next turn (pi-compatible).
440fn apply_prepare_next_turn(
441    config: &LoopConfig<'_>,
442    messages: &mut Vec<AgentMessage>,
443    new_messages: &[AgentMessage],
444) {
445    if let Some(ref prepare) = config.prepare_next_turn
446        && let Some(update) = prepare(new_messages)
447        && let Some(ctx) = update.context
448    {
449        *messages = ctx;
450    }
451}
452
453/// Apply `should_stop_after_turn` callback if configured.
454/// Returns true if the agent loop should stop early (pi-compatible).
455fn apply_should_stop_after_turn(config: &LoopConfig<'_>, new_messages: &[AgentMessage]) -> bool {
456    config
457        .should_stop_after_turn
458        .as_ref()
459        .map(|stop| stop(new_messages))
460        .unwrap_or(false)
461}
462
463/// Execute tool calls sequentially (one at a time, in order).
464async fn execute_tool_calls_sequential(
465    tool_calls: &[ToolCall],
466    config: &LoopConfig<'_>,
467    emit: &mut (dyn FnMut(AgentEvent) + Send),
468) -> Vec<ToolExecOutcome> {
469    let mut outcomes = Vec::new();
470
471    for tc in tool_calls {
472        emit(AgentEvent::ToolCall {
473            id: tc.id.clone(),
474            name: tc.name.clone(),
475            args: tc.arguments.clone(),
476        });
477
478        // Check before_tool_call hooks
479        let mut blocked = false;
480        for ext in config.extensions {
481            if let Some(reason) = ext.before_tool_call(tc).await {
482                outcomes.push(ToolExecOutcome {
483                    id: tc.id.clone(),
484                    name: tc.name.clone(),
485                    content: format!("Tool execution blocked: {:?}", reason),
486                    compact: None,
487                    is_error: true,
488                    terminate: false,
489                });
490                blocked = true;
491                break;
492            }
493        }
494        if blocked {
495            continue;
496        }
497
498        // Execute the tool with progress forwarding
499        let outcome = execute_single_tool(
500            tc,
501            config.agent_tools,
502            config.extensions,
503            None, // sequential: progress is emitted inline, not via channel
504        )
505        .await;
506        outcomes.push(outcome);
507    }
508
509    outcomes
510}
511
512/// Execute tool calls in parallel (pi-compatible):
513/// Phase 1 (sequential preflight): emit ToolCall events, check before_tool_call hooks.
514/// Phase 2 (concurrent execution): execute all non-blocked tools concurrently via join_all.
515/// Phase 3 (sequential post-processing): collect outcomes in original tool call order.
516async fn execute_tool_calls_parallel(
517    tool_calls: &[ToolCall],
518    config: &LoopConfig<'_>,
519    emit: &mut (dyn FnMut(AgentEvent) + Send),
520) -> Vec<ToolExecOutcome> {
521    let mut outcomes: Vec<ToolExecOutcome> = Vec::with_capacity(tool_calls.len());
522    let mut futures: Vec<
523        std::pin::Pin<Box<dyn std::future::Future<Output = ToolExecOutcome> + Send + '_>>,
524    > = Vec::new();
525
526    // Note: progress updates from parallel tool execution are not forwarded
527    // to `emit` because `emit` is FnMut (not Sync) and can't be shared across
528    // concurrent futures. To wire progress, pass a shared mpsc channel instead.
529    // `execute_single_tool` accepts `progress_tx: Option<UnboundedSender<AgentEvent>>`
530    // for this purpose. Pass Some(channel) when progress forwarding is needed.
531
532    // ── Phase 1: Sequential preflight ──
533    // Emit ToolCall events and check before_tool_call hooks one at a time.
534    for tc in tool_calls {
535        emit(AgentEvent::ToolCall {
536            id: tc.id.clone(),
537            name: tc.name.clone(),
538            args: tc.arguments.clone(),
539        });
540
541        let mut blocked = false;
542        for ext in config.extensions {
543            if let Some(reason) = ext.before_tool_call(tc).await {
544                outcomes.push(ToolExecOutcome {
545                    id: tc.id.clone(),
546                    name: tc.name.clone(),
547                    content: format!("Tool execution blocked: {:?}", reason),
548                    compact: None,
549                    is_error: true,
550                    terminate: false,
551                });
552                blocked = true;
553                break;
554            }
555        }
556        if blocked {
557            continue;
558        }
559
560        // ── Phase 2: Collect non-blocked tools as concurrent futures ──
561        // `execute_single_tool` takes an optional progress channel for streaming
562        // tool output. When None, progress updates are not forwarded (the channel
563        // is still created internally for the tool's `on_update` but discarded).
564        let tc_clone = tc.clone();
565        futures.push(Box::pin(async move {
566            execute_single_tool(
567                &tc_clone,
568                config.agent_tools,
569                config.extensions,
570                None, // progress_tx: pass Some(channel) to get streaming updates
571            )
572            .await
573        }));
574    }
575
576    // ── Phase 3: Await all concurrent executions, preserving preflight order ──
577    if !futures.is_empty() {
578        let results = join_all(futures).await;
579        outcomes.extend(results);
580    }
581
582    outcomes
583}
584
585/// Execute a single tool call and return the outcome.
586/// If `progress_tx` is provided, tool progress updates (from `on_update`) are
587/// forwarded as `AgentEvent::ToolProgress` events.
588async fn execute_single_tool(
589    tc: &ToolCall,
590    agent_tools: &[Box<dyn crate::agent::extension::AgentTool>],
591    extensions: &[Box<dyn Extension>],
592    progress_tx: Option<tokio::sync::mpsc::UnboundedSender<AgentEvent>>,
593) -> ToolExecOutcome {
594    let cancel = Cancel::new();
595
596    if let Some(tool) = find_tool(agent_tools, &tc.name) {
597        // Apply prepare_arguments if the tool defines it (pi-compatible)
598        let args = tool.prepare_arguments(tc.arguments.clone());
599
600        // Wire on_update: if progress forwarding is requested, create a channel
601        // so the tool can stream progress updates back to the agent.
602        let on_update = progress_tx.as_ref().map(|_| {
603            let (tool_tx, mut tool_rx) = tokio::sync::mpsc::unbounded_channel::<ToolOutput>();
604            if let Some(ref tx) = progress_tx {
605                let tx = tx.clone();
606                tokio::spawn(async move {
607                    while let Some(output) = tool_rx.recv().await {
608                        let _ = tx.send(AgentEvent::ToolProgress {
609                            content: output.content,
610                            is_error: output.is_error,
611                        });
612                    }
613                });
614            }
615            tool_tx
616        });
617
618        match tool.execute(tc.id.clone(), args, cancel, on_update).await {
619            Ok(output) => {
620                // Check after_tool_call hooks
621                let mut final_result = output.content.clone();
622                for ext in extensions {
623                    if let Some(overridden) = ext.after_tool_call(tc, &final_result).await {
624                        final_result = overridden;
625                    }
626                }
627
628                ToolExecOutcome {
629                    id: tc.id.clone(),
630                    name: tc.name.clone(),
631                    content: final_result,
632                    compact: output.compact,
633                    is_error: false,
634                    terminate: output.terminate,
635                }
636            }
637            Err(e) => ToolExecOutcome {
638                id: tc.id.clone(),
639                name: tc.name.clone(),
640                content: format!("{:#}", e),
641                compact: None,
642                is_error: true,
643                terminate: false,
644            },
645        }
646    } else {
647        ToolExecOutcome {
648            id: tc.id.clone(),
649            name: tc.name.clone(),
650            content: format!("Tool '{}' not found", tc.name),
651            compact: None,
652            is_error: true,
653            terminate: false,
654        }
655    }
656}
657
658#[cfg(test)]
659mod tests {
660    use super::*;
661    use crate::agent::extension::{AgentTool, BlockReason, Cancel, ToolOutput};
662    use crate::agent::provider::StreamEvent;
663    use crate::agent::types::{
664        AgentMessage, PendingMessageQueue, QueueMode, Role, ToolCall, ToolExecutionMode,
665    };
666    use async_trait::async_trait;
667    use futures::Stream;
668    use std::pin::Pin;
669    use std::sync::Arc;
670
671    // ── Mock Provider ──
672    struct MockProvider {
673        responses: Arc<std::sync::Mutex<Vec<MockResponse>>>,
674        // Track messages sent to the provider for assertions
675        sent_messages: Arc<std::sync::Mutex<Vec<Vec<AgentMessage>>>>,
676    }
677
678    struct MockResponse {
679        text: String,
680        tool_calls: Vec<ToolCall>,
681        stop_reason: StopReason,
682        thinking: String,
683    }
684
685    impl MockProvider {
686        fn new() -> Self {
687            Self {
688                responses: Arc::new(std::sync::Mutex::new(Vec::new())),
689                sent_messages: Arc::new(std::sync::Mutex::new(Vec::new())),
690            }
691        }
692
693        fn add_response(&self, text: &str) {
694            self.responses.lock().unwrap().push(MockResponse {
695                text: text.to_string(),
696                tool_calls: vec![],
697                stop_reason: StopReason::EndTurn,
698                thinking: String::new(),
699            });
700        }
701
702        fn add_tool_call_response(&self, text: &str, tool_calls: Vec<ToolCall>) {
703            self.responses.lock().unwrap().push(MockResponse {
704                text: text.to_string(),
705                tool_calls,
706                stop_reason: StopReason::ToolUse,
707                thinking: String::new(),
708            });
709        }
710
711        #[allow(dead_code)]
712        fn sent_message_count(&self) -> usize {
713            self.sent_messages.lock().unwrap().len()
714        }
715
716        #[allow(dead_code)]
717        fn last_sent_message_count(&self) -> usize {
718            let msgs = self.sent_messages.lock().unwrap();
719            msgs.last().map(|m| m.len()).unwrap_or(0)
720        }
721    }
722
723    #[async_trait]
724    impl Provider for MockProvider {
725        async fn stream(
726            &self,
727            _model: &str,
728            _system: &str,
729            messages: &[AgentMessage],
730            _tools: &[ToolDef],
731        ) -> anyhow::Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>> {
732            // Record messages sent to this provider call
733            self.sent_messages.lock().unwrap().push(messages.to_vec());
734
735            let mut resp = self.responses.lock().unwrap();
736            let response = if resp.is_empty() {
737                // Default: return end turn with empty text
738                MockResponse {
739                    text: String::new(),
740                    tool_calls: vec![],
741                    stop_reason: StopReason::EndTurn,
742                    thinking: String::new(),
743                }
744            } else {
745                resp.remove(0)
746            };
747            drop(resp);
748
749            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
750
751            // Send thinking if present
752            if !response.thinking.is_empty() {
753                let _ = tx.send(StreamEvent::ThinkingDelta {
754                    text: response.thinking.clone(),
755                });
756            }
757
758            // Send text deltas
759            if !response.text.is_empty() {
760                let _ = tx.send(StreamEvent::TextDelta {
761                    text: response.text.clone(),
762                });
763            }
764
765            // Send done
766            let _ = tx.send(StreamEvent::Done {
767                text: response.text,
768                usage: crate::agent::types::Usage::default(),
769                stop_reason: response.stop_reason,
770                tool_calls: response.tool_calls,
771            });
772
773            // Convert receiver to stream using futures::stream::unfold
774            use futures::stream::unfold;
775            let stream = unfold(rx, |mut rx| async move {
776                rx.recv().await.map(|event| (event, rx))
777            });
778            Ok(Box::pin(stream))
779        }
780    }
781
782    // ── Mock Tool ──
783    struct MockTool {
784        name: String,
785        execution_mode: ToolExecutionMode,
786        execute_delay: std::time::Duration,
787        executed: Arc<std::sync::Mutex<Vec<String>>>,
788        terminate: bool,
789    }
790
791    impl MockTool {
792        fn new(name: &str) -> Self {
793            Self {
794                name: name.to_string(),
795                execution_mode: ToolExecutionMode::Parallel,
796                execute_delay: std::time::Duration::ZERO,
797                executed: Arc::new(std::sync::Mutex::new(Vec::new())),
798                terminate: false,
799            }
800        }
801
802        #[allow(dead_code)]
803        fn with_sequential(mut self) -> Self {
804            self.execution_mode = ToolExecutionMode::Sequential;
805            self
806        }
807
808        fn with_delay(mut self, delay: std::time::Duration) -> Self {
809            self.execute_delay = delay;
810            self
811        }
812
813        fn with_terminate(mut self) -> Self {
814            self.terminate = true;
815            self
816        }
817    }
818
819    #[async_trait]
820    impl AgentTool for MockTool {
821        fn name(&self) -> &str {
822            &self.name
823        }
824        fn description(&self) -> &str {
825            "mock tool"
826        }
827        fn parameters(&self) -> serde_json::Value {
828            serde_json::json!({})
829        }
830        fn label(&self) -> &str {
831            &self.name
832        }
833        fn execution_mode(&self) -> ToolExecutionMode {
834            self.execution_mode
835        }
836
837        async fn execute(
838            &self,
839            tool_call_id: String,
840            _args: serde_json::Value,
841            _cancel: Cancel,
842            _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
843        ) -> anyhow::Result<ToolOutput> {
844            self.executed.lock().unwrap().push(tool_call_id.clone());
845
846            if self.execute_delay > std::time::Duration::ZERO {
847                tokio::time::sleep(self.execute_delay).await;
848            }
849
850            Ok(ToolOutput {
851                content: format!("executed: {}", tool_call_id),
852                compact: None,
853                is_error: false,
854                terminate: self.terminate,
855            })
856        }
857    }
858
859    // ── Helper: collect events ──
860    #[derive(Debug, Clone)]
861    struct EventRecorder {
862        events: Arc<std::sync::Mutex<Vec<AgentEvent>>>,
863    }
864
865    impl EventRecorder {
866        fn new() -> Self {
867            Self {
868                events: Arc::new(std::sync::Mutex::new(Vec::new())),
869            }
870        }
871
872        fn record(&self, event: AgentEvent) {
873            self.events.lock().unwrap().push(event);
874        }
875
876        fn events(&self) -> Vec<AgentEvent> {
877            self.events.lock().unwrap().clone()
878        }
879
880        fn event_types(&self) -> Vec<String> {
881            self.events()
882                .iter()
883                .map(|e| match e {
884                    AgentEvent::AgentStart => "agent_start".to_string(),
885                    AgentEvent::TurnStart => "turn_start".to_string(),
886                    AgentEvent::TextDelta { .. } => "text_delta".to_string(),
887                    AgentEvent::ThinkingDelta { .. } => "thinking_delta".to_string(),
888                    AgentEvent::ToolCall { .. } => "tool_call".to_string(),
889                    AgentEvent::ToolCallArgsUpdate { .. } => "tool_call_args_update".to_string(),
890                    AgentEvent::ToolResult { .. } => "tool_result".to_string(),
891                    AgentEvent::ToolProgress { .. } => "tool_progress".to_string(),
892                    AgentEvent::Aborted { .. } => "aborted".to_string(),
893                    AgentEvent::UserMessage { .. } => "user_message".to_string(),
894                    AgentEvent::TurnEnd => "turn_end".to_string(),
895                    AgentEvent::AgentEnd { .. } => "agent_end".to_string(),
896                })
897                .collect()
898        }
899
900        fn text_deltas(&self) -> Vec<String> {
901            self.events()
902                .iter()
903                .filter_map(|e| {
904                    if let AgentEvent::TextDelta { delta } = e {
905                        Some(delta.clone())
906                    } else {
907                        None
908                    }
909                })
910                .collect()
911        }
912    }
913
914    // ── Tests ──
915
916    /// Test basic text-only response (no tool calls).
917    #[tokio::test]
918    async fn test_basic_text_response() {
919        let provider = MockProvider::new();
920        provider.add_response("Hello, world!");
921
922        let recorder = EventRecorder::new();
923        let mut emit = |e: AgentEvent| recorder.record(e);
924
925        let config = LoopConfig {
926            model: "test".to_string(),
927            system_prompt: "You are helpful.".to_string(),
928            tools: vec![],
929            agent_tools: &[],
930            extensions: &[],
931            tool_execution: ToolExecutionMode::Parallel,
932            steering_queue: None,
933            follow_up_queue: None,
934            transform_context: None,
935            prepare_next_turn: None,
936            should_stop_after_turn: None,
937        };
938
939        let prompt = AgentMessage::user("Hi");
940        let result = run_agent_loop(vec![prompt], vec![], &config, &provider, &mut emit)
941            .await
942            .unwrap();
943
944        // Should have user message + assistant message
945        assert_eq!(result.len(), 2);
946        assert_eq!(result[0].role, Role::User);
947        assert_eq!(result[1].role, Role::Assistant);
948
949        // Check event sequence
950        let types = recorder.event_types();
951        assert!(types.contains(&"agent_start".to_string()));
952        assert!(types.contains(&"text_delta".to_string()));
953        assert!(types.contains(&"turn_end".to_string()));
954        assert!(types.contains(&"agent_end".to_string()));
955
956        // Check text content
957        let texts = recorder.text_deltas();
958        assert!(texts.iter().any(|t| t == "Hello, world!"));
959    }
960
961    /// Test sequential tool execution.
962    #[tokio::test]
963    async fn test_sequential_tool_execution() {
964        let tool = MockTool::new("echo");
965        let tool_executed = Arc::clone(&tool.executed);
966        let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(tool)];
967
968        let provider = MockProvider::new();
969        provider.add_tool_call_response(
970            "",
971            vec![
972                ToolCall {
973                    id: "call-1".to_string(),
974                    name: "echo".to_string(),
975                    arguments: serde_json::json!({}),
976                },
977                ToolCall {
978                    id: "call-2".to_string(),
979                    name: "echo".to_string(),
980                    arguments: serde_json::json!({}),
981                },
982            ],
983        );
984        provider.add_response("Done after tools.");
985
986        let recorder = EventRecorder::new();
987        let mut emit = |e: AgentEvent| recorder.record(e);
988
989        let config = LoopConfig {
990            model: "test".to_string(),
991            system_prompt: "".to_string(),
992            tools: vec![],
993            agent_tools: &agent_tools,
994            extensions: &[],
995            tool_execution: ToolExecutionMode::Sequential,
996            steering_queue: None,
997            follow_up_queue: None,
998            transform_context: None,
999            prepare_next_turn: None,
1000            should_stop_after_turn: None,
1001        };
1002
1003        let result = run_agent_loop(
1004            vec![AgentMessage::user("run tools")],
1005            vec![],
1006            &config,
1007            &provider,
1008            &mut emit,
1009        )
1010        .await
1011        .unwrap();
1012
1013        // 1 user + 1 assistant (tool call) + 2 tool results + 1 assistant (final)
1014        assert_eq!(result.len(), 5);
1015
1016        let executed = tool_executed.lock().unwrap().clone();
1017        assert_eq!(executed.len(), 2);
1018        assert_eq!(executed[0], "call-1");
1019        assert_eq!(executed[1], "call-2");
1020
1021        // Verify event sequence includes tool calls and results
1022        let types = recorder.event_types();
1023        assert!(types.contains(&"tool_call".to_string()));
1024        assert!(types.contains(&"tool_result".to_string()));
1025    }
1026
1027    /// Test parallel tool execution: tools run concurrently.
1028    #[tokio::test]
1029    async fn test_parallel_tool_execution() {
1030        let fast_tool =
1031            Arc::new(MockTool::new("fast").with_delay(std::time::Duration::from_millis(50)));
1032        let slow_tool =
1033            Arc::new(MockTool::new("slow").with_delay(std::time::Duration::from_millis(100)));
1034        let _fast_executed = Arc::clone(&fast_tool.executed);
1035        let _slow_executed = Arc::clone(&slow_tool.executed);
1036
1037        // Track start times to verify concurrency
1038        let start_times: Arc<std::sync::Mutex<Vec<(String, std::time::Instant)>>> =
1039            Arc::new(std::sync::Mutex::new(Vec::new()));
1040        let start_times_clone = Arc::clone(&start_times);
1041
1042        struct TrackingTool {
1043            inner: MockTool,
1044            start_times: Arc<std::sync::Mutex<Vec<(String, std::time::Instant)>>>,
1045        }
1046        #[async_trait]
1047        impl AgentTool for TrackingTool {
1048            fn name(&self) -> &str {
1049                self.inner.name()
1050            }
1051            fn description(&self) -> &str {
1052                "tracking"
1053            }
1054            fn parameters(&self) -> serde_json::Value {
1055                serde_json::json!({})
1056            }
1057            fn label(&self) -> &str {
1058                self.inner.name()
1059            }
1060            async fn execute(
1061                &self,
1062                tool_call_id: String,
1063                args: serde_json::Value,
1064                cancel: Cancel,
1065                on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1066            ) -> anyhow::Result<ToolOutput> {
1067                self.start_times
1068                    .lock()
1069                    .unwrap()
1070                    .push((tool_call_id.clone(), std::time::Instant::now()));
1071                self.inner
1072                    .execute(tool_call_id, args, cancel, on_update)
1073                    .await
1074            }
1075        }
1076
1077        let agent_tools: Vec<Box<dyn AgentTool>> = vec![
1078            Box::new(TrackingTool {
1079                inner: MockTool::new("slow").with_delay(std::time::Duration::from_millis(100)),
1080                start_times: Arc::clone(&start_times),
1081            }),
1082            Box::new(TrackingTool {
1083                inner: MockTool::new("fast").with_delay(std::time::Duration::from_millis(50)),
1084                start_times: Arc::clone(&start_times_clone),
1085            }),
1086        ];
1087
1088        let provider = MockProvider::new();
1089        provider.add_tool_call_response(
1090            "",
1091            vec![
1092                ToolCall {
1093                    id: "slow-1".to_string(),
1094                    name: "slow".to_string(),
1095                    arguments: serde_json::json!({}),
1096                },
1097                ToolCall {
1098                    id: "fast-1".to_string(),
1099                    name: "fast".to_string(),
1100                    arguments: serde_json::json!({}),
1101                },
1102            ],
1103        );
1104        provider.add_response("All tools done.");
1105
1106        let recorder = EventRecorder::new();
1107        let mut emit = |e: AgentEvent| recorder.record(e);
1108
1109        let config = LoopConfig {
1110            model: "test".to_string(),
1111            system_prompt: "".to_string(),
1112            tools: vec![],
1113            agent_tools: &agent_tools,
1114            extensions: &[],
1115            tool_execution: ToolExecutionMode::Parallel,
1116            steering_queue: None,
1117            follow_up_queue: None,
1118            transform_context: None,
1119            prepare_next_turn: None,
1120            should_stop_after_turn: None,
1121        };
1122
1123        run_agent_loop(
1124            vec![AgentMessage::user("run tools")],
1125            vec![],
1126            &config,
1127            &provider,
1128            &mut emit,
1129        )
1130        .await
1131        .unwrap();
1132
1133        let times = start_times.lock().unwrap();
1134        assert_eq!(times.len(), 2, "both tools should have started");
1135
1136        // Both tools should have started — in parallel mode, the second tool (fast)
1137        // starts before the first (slow) finishes. We just verify both started.
1138        let names: Vec<&str> = times.iter().map(|(n, _)| n.as_str()).collect();
1139        assert!(names.contains(&"slow-1"));
1140        assert!(names.contains(&"fast-1"));
1141    }
1142
1143    /// Test that per-tool sequential mode forces the entire batch to be sequential.
1144    #[tokio::test]
1145    async fn test_per_tool_sequential_mode() {
1146        let executed = Arc::new(std::sync::Mutex::new(Vec::new()));
1147        {
1148            // Override tools to track execution order
1149            let _seq_exec = Arc::clone(&executed);
1150            let _par_exec = Arc::clone(&executed);
1151
1152            struct SeqTool;
1153            #[async_trait]
1154            impl AgentTool for SeqTool {
1155                fn name(&self) -> &str {
1156                    "sequential_tool"
1157                }
1158                fn description(&self) -> &str {
1159                    ""
1160                }
1161                fn parameters(&self) -> serde_json::Value {
1162                    serde_json::json!({})
1163                }
1164                fn label(&self) -> &str {
1165                    "sequential_tool"
1166                }
1167                fn execution_mode(&self) -> ToolExecutionMode {
1168                    ToolExecutionMode::Sequential
1169                }
1170                async fn execute(
1171                    &self,
1172                    tool_call_id: String,
1173                    _args: serde_json::Value,
1174                    _cancel: Cancel,
1175                    _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1176                ) -> anyhow::Result<ToolOutput> {
1177                    // Simulate work
1178                    tokio::time::sleep(std::time::Duration::from_millis(30)).await;
1179                    Ok(ToolOutput::ok(format!("done: {}", tool_call_id)))
1180                }
1181            }
1182
1183            struct ParTool {
1184                executed: Arc<std::sync::Mutex<Vec<String>>>,
1185            }
1186            #[async_trait]
1187            impl AgentTool for ParTool {
1188                fn name(&self) -> &str {
1189                    "parallel_tool"
1190                }
1191                fn description(&self) -> &str {
1192                    ""
1193                }
1194                fn parameters(&self) -> serde_json::Value {
1195                    serde_json::json!({})
1196                }
1197                fn label(&self) -> &str {
1198                    "parallel_tool"
1199                }
1200                async fn execute(
1201                    &self,
1202                    tool_call_id: String,
1203                    _args: serde_json::Value,
1204                    _cancel: Cancel,
1205                    _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1206                ) -> anyhow::Result<ToolOutput> {
1207                    self.executed.lock().unwrap().push(tool_call_id.clone());
1208                    Ok(ToolOutput::ok(format!("done: {}", tool_call_id)))
1209                }
1210            }
1211
1212            let agent_tools: Vec<Box<dyn AgentTool>> = vec![
1213                Box::new(SeqTool),
1214                Box::new(ParTool {
1215                    executed: Arc::clone(&executed),
1216                }),
1217            ];
1218
1219            let provider = MockProvider::new();
1220            provider.add_tool_call_response(
1221                "",
1222                vec![
1223                    ToolCall {
1224                        id: "seq-1".to_string(),
1225                        name: "sequential_tool".to_string(),
1226                        arguments: serde_json::json!({}),
1227                    },
1228                    ToolCall {
1229                        id: "par-1".to_string(),
1230                        name: "parallel_tool".to_string(),
1231                        arguments: serde_json::json!({}),
1232                    },
1233                ],
1234            );
1235            provider.add_response("Done.");
1236
1237            let recorder = EventRecorder::new();
1238            let mut emit = |e: AgentEvent| recorder.record(e);
1239
1240            let config = LoopConfig {
1241                model: "test".to_string(),
1242                system_prompt: "".to_string(),
1243                tools: vec![],
1244                agent_tools: &agent_tools,
1245                extensions: &[],
1246                tool_execution: ToolExecutionMode::Parallel,
1247                steering_queue: None,
1248                follow_up_queue: None,
1249                transform_context: None,
1250                prepare_next_turn: None,
1251                should_stop_after_turn: None,
1252            };
1253
1254            run_agent_loop(
1255                vec![AgentMessage::user("run")],
1256                vec![],
1257                &config,
1258                &provider,
1259                &mut emit,
1260            )
1261            .await
1262            .unwrap();
1263
1264            // Both should execute (one is sequential by declaration)
1265            let exec_order = executed.lock().unwrap().clone();
1266            assert_eq!(
1267                exec_order.len(),
1268                1,
1269                "only parallel_tool records in executed"
1270            );
1271        }
1272    }
1273
1274    /// Test that terminate flag on ALL tools stops the loop.
1275    #[tokio::test]
1276    async fn test_terminate_stops_loop() {
1277        let agent_tools: Vec<Box<dyn AgentTool>> =
1278            vec![Box::new(MockTool::new("final").with_terminate())];
1279
1280        let provider = MockProvider::new();
1281        provider.add_tool_call_response(
1282            "",
1283            vec![ToolCall {
1284                id: "final-1".to_string(),
1285                name: "final".to_string(),
1286                arguments: serde_json::json!({}),
1287            }],
1288        );
1289        // No second response — loop should stop after terminate
1290
1291        let recorder = EventRecorder::new();
1292        let mut emit = |e: AgentEvent| recorder.record(e);
1293
1294        let config = LoopConfig {
1295            model: "test".to_string(),
1296            system_prompt: "".to_string(),
1297            tools: vec![],
1298            agent_tools: &agent_tools,
1299            extensions: &[],
1300            tool_execution: ToolExecutionMode::Parallel,
1301            steering_queue: None,
1302            follow_up_queue: None,
1303            transform_context: None,
1304            prepare_next_turn: None,
1305            should_stop_after_turn: None,
1306        };
1307
1308        let result = run_agent_loop(
1309            vec![AgentMessage::user("final")],
1310            vec![],
1311            &config,
1312            &provider,
1313            &mut emit,
1314        )
1315        .await
1316        .unwrap();
1317
1318        // Should have: user msg + assistant (tool call) + tool result
1319        // NOT a second assistant (which would come from another LLM call)
1320        assert_eq!(
1321            result.len(),
1322            3,
1323            "should stop after terminate without second LLM call"
1324        );
1325
1326        let types = recorder.event_types();
1327        assert!(types.contains(&"turn_end".to_string()));
1328        assert!(types.contains(&"agent_end".to_string()));
1329    }
1330
1331    /// Test that transform_context rewrites messages before LLM call.
1332    #[tokio::test]
1333    async fn test_transform_context() {
1334        let provider = MockProvider::new();
1335        provider.add_response("Response");
1336
1337        let transform_called = Arc::new(std::sync::Mutex::new(false));
1338        let transform_called_clone = Arc::clone(&transform_called);
1339
1340        let config = LoopConfig {
1341            model: "test".to_string(),
1342            system_prompt: "".to_string(),
1343            tools: vec![],
1344            agent_tools: &[],
1345            extensions: &[],
1346            tool_execution: ToolExecutionMode::Parallel,
1347            steering_queue: None,
1348            follow_up_queue: None,
1349            transform_context: Some(Box::new(move |msgs| {
1350                *transform_called_clone.lock().unwrap() = true;
1351                msgs.to_vec()
1352            })),
1353            prepare_next_turn: None,
1354            should_stop_after_turn: None,
1355        };
1356
1357        let mut emit = |_: AgentEvent| {};
1358        run_agent_loop(
1359            vec![AgentMessage::user("hi")],
1360            vec![],
1361            &config,
1362            &provider,
1363            &mut emit,
1364        )
1365        .await
1366        .unwrap();
1367
1368        assert!(
1369            *transform_called.lock().unwrap(),
1370            "transform_context should be called"
1371        );
1372    }
1373
1374    /// Test that prepare_next_turn can modify context.
1375    #[tokio::test]
1376    async fn test_prepare_next_turn() {
1377        let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(MockTool::new("echo"))];
1378        let provider = MockProvider::new();
1379        provider.add_tool_call_response(
1380            "",
1381            vec![ToolCall {
1382                id: "tool-1".to_string(),
1383                name: "echo".to_string(),
1384                arguments: serde_json::json!({}),
1385            }],
1386        );
1387        provider.add_response("After prepare.");
1388
1389        let prepare_called = Arc::new(std::sync::Mutex::new(false));
1390        let prepare_called_clone = Arc::clone(&prepare_called);
1391
1392        let config = LoopConfig {
1393            model: "test".to_string(),
1394            system_prompt: "".to_string(),
1395            tools: vec![],
1396            agent_tools: &agent_tools,
1397            extensions: &[],
1398            tool_execution: ToolExecutionMode::Sequential,
1399            steering_queue: None,
1400            follow_up_queue: None,
1401            transform_context: None,
1402            prepare_next_turn: Some(Box::new(move |_new_msgs| {
1403                *prepare_called_clone.lock().unwrap() = true;
1404                None // don't modify context
1405            })),
1406            should_stop_after_turn: None,
1407        };
1408
1409        let mut emit = |_: AgentEvent| {};
1410        run_agent_loop(
1411            vec![AgentMessage::user("run")],
1412            vec![],
1413            &config,
1414            &provider,
1415            &mut emit,
1416        )
1417        .await
1418        .unwrap();
1419
1420        assert!(
1421            *prepare_called.lock().unwrap(),
1422            "prepare_next_turn should be called"
1423        );
1424    }
1425
1426    /// Test that should_stop_after_turn can stop the loop early.
1427    #[tokio::test]
1428    async fn test_should_stop_after_turn() {
1429        let provider = MockProvider::new();
1430        provider.add_response("First turn.");
1431
1432        let stop = Arc::new(std::sync::Mutex::new(true));
1433        let stop_clone = Arc::clone(&stop);
1434
1435        let config = LoopConfig {
1436            model: "test".to_string(),
1437            system_prompt: "".to_string(),
1438            tools: vec![],
1439            agent_tools: &[],
1440            extensions: &[],
1441            tool_execution: ToolExecutionMode::Parallel,
1442            steering_queue: None,
1443            follow_up_queue: None,
1444            transform_context: None,
1445            prepare_next_turn: None,
1446            should_stop_after_turn: Some(Box::new(move |_| *stop_clone.lock().unwrap())),
1447        };
1448
1449        let recorder = EventRecorder::new();
1450        let mut emit = |e: AgentEvent| recorder.record(e);
1451        run_agent_loop(
1452            vec![AgentMessage::user("hi")],
1453            vec![],
1454            &config,
1455            &provider,
1456            &mut emit,
1457        )
1458        .await
1459        .unwrap();
1460
1461        // Should have exactly 1 assistant message (no second turn)
1462        let types = recorder.event_types();
1463        let agent_end_count = types.iter().filter(|t| *t == "agent_end").count();
1464        assert_eq!(agent_end_count, 1, "should end exactly once");
1465    }
1466
1467    /// Test steering queue: messages injected between turns.
1468    #[tokio::test]
1469    async fn test_steering_queue() {
1470        let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(MockTool::new("echo"))];
1471        let provider = MockProvider::new();
1472        provider.add_tool_call_response(
1473            "",
1474            vec![ToolCall {
1475                id: "tool-1".to_string(),
1476                name: "echo".to_string(),
1477                arguments: serde_json::json!({}),
1478            }],
1479        );
1480        provider.add_response("After tool.");
1481        provider.add_response("After steering.");
1482
1483        let steering_queue = std::sync::Mutex::new(PendingMessageQueue::new(QueueMode::OneAtATime));
1484        // Queue a steering message — it should be injected after tool result, before 2nd LLM call
1485        steering_queue
1486            .lock()
1487            .unwrap()
1488            .enqueue(AgentMessage::user("steer here"));
1489
1490        let recorder = EventRecorder::new();
1491        let mut emit = |e: AgentEvent| recorder.record(e);
1492
1493        let config = LoopConfig {
1494            model: "test".to_string(),
1495            system_prompt: "".to_string(),
1496            tools: vec![],
1497            agent_tools: &agent_tools,
1498            extensions: &[],
1499            tool_execution: ToolExecutionMode::Sequential,
1500            steering_queue: Some(&steering_queue),
1501            follow_up_queue: None,
1502            transform_context: None,
1503            prepare_next_turn: None,
1504            should_stop_after_turn: None,
1505        };
1506
1507        let result = run_agent_loop(
1508            vec![AgentMessage::user("run")],
1509            vec![],
1510            &config,
1511            &provider,
1512            &mut emit,
1513        )
1514        .await
1515        .unwrap();
1516
1517        // The steering message should appear as a UserMessage in events and as
1518        // an injected user message in the result
1519        let types = recorder.event_types();
1520        let user_msg_count = types.iter().filter(|t| *t == "user_message").count();
1521        assert!(
1522            user_msg_count >= 1,
1523            "steering should produce at least one user_message event, got {}",
1524            user_msg_count
1525        );
1526
1527        // Result should contain: user prompt + assistant (tool call) + tool result + steering user + assistant
1528        let user_messages: Vec<&AgentMessage> =
1529            result.iter().filter(|m| m.role == Role::User).collect();
1530        assert_eq!(
1531            user_messages.len(),
1532            2,
1533            "should have original prompt + steering message"
1534        );
1535    }
1536
1537    /// Test follow-up queue: messages injected after agent is idle.
1538    #[tokio::test]
1539    async fn test_follow_up_queue() {
1540        let provider = MockProvider::new();
1541        provider.add_response("First response.");
1542        provider.add_response("Follow-up response.");
1543
1544        let follow_up_queue =
1545            std::sync::Mutex::new(PendingMessageQueue::new(QueueMode::OneAtATime));
1546        follow_up_queue
1547            .lock()
1548            .unwrap()
1549            .enqueue(AgentMessage::user("follow up"));
1550
1551        let recorder = EventRecorder::new();
1552        let mut emit = |e: AgentEvent| recorder.record(e);
1553
1554        let config = LoopConfig {
1555            model: "test".to_string(),
1556            system_prompt: "".to_string(),
1557            tools: vec![],
1558            agent_tools: &[],
1559            extensions: &[],
1560            tool_execution: ToolExecutionMode::Parallel,
1561            steering_queue: None,
1562            follow_up_queue: Some(&follow_up_queue),
1563            transform_context: None,
1564            prepare_next_turn: None,
1565            should_stop_after_turn: None,
1566        };
1567
1568        let result = run_agent_loop(
1569            vec![AgentMessage::user("first")],
1570            vec![],
1571            &config,
1572            &provider,
1573            &mut emit,
1574        )
1575        .await
1576        .unwrap();
1577
1578        // Should have: user + assistant (first) + user (follow-up) + assistant (follow-up)
1579        assert_eq!(
1580            result.len(),
1581            4,
1582            "follow-up should add another user+assistant pair"
1583        );
1584        assert_eq!(
1585            result[2].content, "follow up",
1586            "third message should be the injected follow-up"
1587        );
1588
1589        let types = recorder.event_types();
1590        assert!(types.contains(&"user_message".to_string()));
1591    }
1592
/// Test PendingMessageQueue drain modes.
///
/// Covers `OneAtATime` (one message per `drain`), `All` (everything in one
/// `drain`) and `clear`. The queue API is synchronous, so this is a plain
/// `#[test]` and needs no async runtime.
#[test]
fn test_message_queue_modes() {
    // OneAtATime: drain one message at a time, in insertion order
    let mut queue = PendingMessageQueue::new(QueueMode::OneAtATime);
    queue.enqueue(AgentMessage::user("msg1"));
    queue.enqueue(AgentMessage::user("msg2"));

    let batch1 = queue.drain();
    assert_eq!(batch1.len(), 1, "OneAtATime should drain 1");
    assert_eq!(batch1[0].content, "msg1");

    let batch2 = queue.drain();
    assert_eq!(batch2.len(), 1, "OneAtATime should drain 1 on second call");
    assert_eq!(batch2[0].content, "msg2");

    assert!(
        queue.drain().is_empty(),
        "should be empty after both drained"
    );

    // All: drain all at once
    let mut queue = PendingMessageQueue::new(QueueMode::All);
    queue.enqueue(AgentMessage::user("a"));
    queue.enqueue(AgentMessage::user("b"));

    let all = queue.drain();
    assert_eq!(all.len(), 2, "All mode should drain both");
    assert!(queue.drain().is_empty(), "should be empty after drain");

    // Clear: discards pending messages without draining them
    let mut queue = PendingMessageQueue::new(QueueMode::OneAtATime);
    queue.enqueue(AgentMessage::user("x"));
    queue.clear();
    assert!(queue.is_empty());
}
1629
1630    /// Test that prepare_arguments is called on the tool.
1631    #[tokio::test]
1632    async fn test_prepare_arguments() {
1633        struct PrepTool;
1634        #[async_trait]
1635        impl AgentTool for PrepTool {
1636            fn name(&self) -> &str {
1637                "prep_tool"
1638            }
1639            fn description(&self) -> &str {
1640                ""
1641            }
1642            fn parameters(&self) -> serde_json::Value {
1643                serde_json::json!({})
1644            }
1645            fn label(&self) -> &str {
1646                "prep_tool"
1647            }
1648            fn prepare_arguments(&self, args: serde_json::Value) -> serde_json::Value {
1649                let mut m = serde_json::Map::new();
1650                m.insert("prepared".to_string(), serde_json::json!(true));
1651                if let Some(obj) = args.as_object() {
1652                    for (k, v) in obj {
1653                        m.insert(k.clone(), v.clone());
1654                    }
1655                }
1656                serde_json::Value::Object(m)
1657            }
1658            async fn execute(
1659                &self,
1660                _tool_call_id: String,
1661                args: serde_json::Value,
1662                _cancel: Cancel,
1663                _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1664            ) -> anyhow::Result<ToolOutput> {
1665                // Verify prepare_arguments was called: args should have "prepared": true
1666                assert_eq!(args.get("prepared").and_then(|v| v.as_bool()), Some(true));
1667                Ok(ToolOutput::ok("prepared ok"))
1668            }
1669        }
1670
1671        let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(PrepTool)];
1672        let provider = MockProvider::new();
1673        provider.add_tool_call_response(
1674            "",
1675            vec![ToolCall {
1676                id: "tool-1".to_string(),
1677                name: "prep_tool".to_string(),
1678                arguments: serde_json::json!({"original": "value"}),
1679            }],
1680        );
1681        provider.add_response("Done.");
1682
1683        let config = LoopConfig {
1684            model: "test".to_string(),
1685            system_prompt: "".to_string(),
1686            tools: vec![],
1687            agent_tools: &agent_tools,
1688            extensions: &[],
1689            tool_execution: ToolExecutionMode::Sequential,
1690            steering_queue: None,
1691            follow_up_queue: None,
1692            transform_context: None,
1693            prepare_next_turn: None,
1694            should_stop_after_turn: None,
1695        };
1696
1697        let mut emit = |_: AgentEvent| {};
1698        let result = run_agent_loop(
1699            vec![AgentMessage::user("prep")],
1700            vec![],
1701            &config,
1702            &provider,
1703            &mut emit,
1704        )
1705        .await;
1706
1707        assert!(
1708            result.is_ok(),
1709            "prepare_arguments should work without error"
1710        );
1711    }
1712
1713    /// Test that before_tool_call can block execution.
1714    #[tokio::test]
1715    async fn test_before_tool_call_blocks() {
1716        struct BlockingExt;
1717        #[async_trait]
1718        impl Extension for BlockingExt {
1719            fn name(&self) -> std::borrow::Cow<'static, str> {
1720                std::borrow::Cow::Borrowed("blocker")
1721            }
1722            async fn before_tool_call(&self, _tc: &ToolCall) -> Option<BlockReason> {
1723                Some(BlockReason::Security("blocked for test".into()))
1724            }
1725        }
1726
1727        let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(MockTool::new("echo"))];
1728        let extensions: Vec<Box<dyn Extension>> = vec![Box::new(BlockingExt)];
1729
1730        let provider = MockProvider::new();
1731        provider.add_tool_call_response(
1732            "",
1733            vec![ToolCall {
1734                id: "tool-1".to_string(),
1735                name: "echo".to_string(),
1736                arguments: serde_json::json!({}),
1737            }],
1738        );
1739        provider.add_response("After blocked tool.");
1740
1741        let recorder = EventRecorder::new();
1742        let mut emit = |e: AgentEvent| recorder.record(e);
1743
1744        let config = LoopConfig {
1745            model: "test".to_string(),
1746            system_prompt: "".to_string(),
1747            tools: vec![],
1748            agent_tools: &agent_tools,
1749            extensions: &extensions,
1750            tool_execution: ToolExecutionMode::Sequential,
1751            steering_queue: None,
1752            follow_up_queue: None,
1753            transform_context: None,
1754            prepare_next_turn: None,
1755            should_stop_after_turn: None,
1756        };
1757
1758        let result = run_agent_loop(
1759            vec![AgentMessage::user("block test")],
1760            vec![],
1761            &config,
1762            &provider,
1763            &mut emit,
1764        )
1765        .await
1766        .unwrap();
1767
1768        // Should have: user + assistant (tool call) + tool result (blocked) + assistant (after)
1769        assert!(
1770            result.len() >= 3,
1771            "blocked tool should still produce a result"
1772        );
1773
1774        // Find the tool result
1775        let tool_results: Vec<&AgentMessage> = result
1776            .iter()
1777            .filter(|m| m.role == Role::ToolResult)
1778            .collect();
1779        assert!(!tool_results.is_empty());
1780        assert!(
1781            tool_results[0].is_error,
1782            "blocked tool result should be error"
1783        );
1784        assert!(
1785            tool_results[0].content.contains("blocked"),
1786            "blocked result should mention block reason"
1787        );
1788    }
1789
1790    /// Test error response from provider leads to graceful abort.
1791    #[tokio::test]
1792    async fn test_provider_error_aborts() {
1793        // A provider that returns an error
1794        struct ErrorProvider;
1795        #[async_trait]
1796        impl Provider for ErrorProvider {
1797            async fn stream(
1798                &self,
1799                _model: &str,
1800                _system: &str,
1801                _messages: &[AgentMessage],
1802                _tools: &[ToolDef],
1803            ) -> anyhow::Result<Pin<Box<dyn Stream<Item = StreamEvent> + Send>>> {
1804                anyhow::bail!("provider error")
1805            }
1806        }
1807
1808        let recorder = EventRecorder::new();
1809        let mut emit = |e: AgentEvent| recorder.record(e);
1810
1811        let config = LoopConfig {
1812            model: "test".to_string(),
1813            system_prompt: "".to_string(),
1814            tools: vec![],
1815            agent_tools: &[],
1816            extensions: &[],
1817            tool_execution: ToolExecutionMode::Parallel,
1818            steering_queue: None,
1819            follow_up_queue: None,
1820            transform_context: None,
1821            prepare_next_turn: None,
1822            should_stop_after_turn: None,
1823        };
1824
1825        let result = run_agent_loop(
1826            vec![AgentMessage::user("hi")],
1827            vec![],
1828            &config,
1829            &ErrorProvider,
1830            &mut emit,
1831        )
1832        .await;
1833
1834        // Provider error should propagate as an Err
1835        assert!(result.is_err(), "provider error should propagate");
1836    }
1837
1838    /// Test that tool execution errors are reported as tool results.
1839    #[tokio::test]
1840    async fn test_tool_execution_error() {
1841        struct ErrorTool;
1842        #[async_trait]
1843        impl AgentTool for ErrorTool {
1844            fn name(&self) -> &str {
1845                "error_tool"
1846            }
1847            fn description(&self) -> &str {
1848                ""
1849            }
1850            fn parameters(&self) -> serde_json::Value {
1851                serde_json::json!({})
1852            }
1853            fn label(&self) -> &str {
1854                "error_tool"
1855            }
1856            async fn execute(
1857                &self,
1858                _tool_call_id: String,
1859                _args: serde_json::Value,
1860                _cancel: Cancel,
1861                _on_update: Option<tokio::sync::mpsc::UnboundedSender<ToolOutput>>,
1862            ) -> anyhow::Result<ToolOutput> {
1863                anyhow::bail!("tool crashed")
1864            }
1865        }
1866
1867        let agent_tools: Vec<Box<dyn AgentTool>> = vec![Box::new(ErrorTool)];
1868        let provider = MockProvider::new();
1869        provider.add_tool_call_response(
1870            "",
1871            vec![ToolCall {
1872                id: "tool-1".to_string(),
1873                name: "error_tool".to_string(),
1874                arguments: serde_json::json!({}),
1875            }],
1876        );
1877        provider.add_response("After error.");
1878
1879        let recorder = EventRecorder::new();
1880        let mut emit = |e: AgentEvent| recorder.record(e);
1881
1882        let config = LoopConfig {
1883            model: "test".to_string(),
1884            system_prompt: "".to_string(),
1885            tools: vec![],
1886            agent_tools: &agent_tools,
1887            extensions: &[],
1888            tool_execution: ToolExecutionMode::Sequential,
1889            steering_queue: None,
1890            follow_up_queue: None,
1891            transform_context: None,
1892            prepare_next_turn: None,
1893            should_stop_after_turn: None,
1894        };
1895
1896        let result = run_agent_loop(
1897            vec![AgentMessage::user("error test")],
1898            vec![],
1899            &config,
1900            &provider,
1901            &mut emit,
1902        )
1903        .await
1904        .unwrap();
1905
1906        // Should have error tool result
1907        let tool_results: Vec<&AgentMessage> = result
1908            .iter()
1909            .filter(|m| m.role == Role::ToolResult)
1910            .collect();
1911        assert!(!tool_results.is_empty());
1912        assert!(tool_results[0].is_error);
1913    }
1914
1915    /// Test that tool not found produces an error tool result.
1916    #[tokio::test]
1917    async fn test_tool_not_found() {
1918        let provider = MockProvider::new();
1919        provider.add_tool_call_response(
1920            "",
1921            vec![ToolCall {
1922                id: "tool-1".to_string(),
1923                name: "nonexistent".to_string(),
1924                arguments: serde_json::json!({}),
1925            }],
1926        );
1927        provider.add_response("After missing tool.");
1928
1929        // Empty agent_tools — the tool won't be found
1930        let agent_tools: Vec<Box<dyn AgentTool>> = vec![];
1931
1932        let recorder = EventRecorder::new();
1933        let mut emit = |e: AgentEvent| recorder.record(e);
1934
1935        let config = LoopConfig {
1936            model: "test".to_string(),
1937            system_prompt: "".to_string(),
1938            tools: vec![],
1939            agent_tools: &agent_tools,
1940            extensions: &[],
1941            tool_execution: ToolExecutionMode::Sequential,
1942            steering_queue: None,
1943            follow_up_queue: None,
1944            transform_context: None,
1945            prepare_next_turn: None,
1946            should_stop_after_turn: None,
1947        };
1948
1949        let result = run_agent_loop(
1950            vec![AgentMessage::user("test")],
1951            vec![],
1952            &config,
1953            &provider,
1954            &mut emit,
1955        )
1956        .await
1957        .unwrap();
1958
1959        let tool_results: Vec<&AgentMessage> = result
1960            .iter()
1961            .filter(|m| m.role == Role::ToolResult)
1962            .collect();
1963        assert!(!tool_results.is_empty());
1964        assert!(tool_results[0].is_error);
1965        assert!(tool_results[0].content.contains("not found"));
1966    }
1967}