Skip to main content

llm_core/
chain.rs

1use async_trait::async_trait;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4
5use crate::provider::Provider;
6use crate::stream::{Chunk, collect_text, collect_tool_calls, collect_usage};
7use crate::types::{Message, Prompt, ToolCall, ToolResult, Usage};
8
/// Trait for executing tool calls requested by the model during [`chain`].
///
/// `execute` returns a [`ToolResult`] rather than a `Result`: a failing tool
/// reports the failure inside the returned value (e.g. via its `error` field),
/// and the chain keeps going, feeding that result back to the model on the
/// next turn.
///
/// When the model requests several tools in one turn, [`chain`] may have
/// several `execute` futures on the same executor in flight at once (polled
/// together on the calling task; see [`ParallelConfig`]), so implementations
/// should tolerate overlapping calls.
///
/// On `wasm32`, the `Send + Sync` bound is dropped so executors may hold
/// non-`Send` host types like `js_sys::Function`.
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Run a single tool call and produce its result.
    async fn execute(&self, call: &ToolCall) -> ToolResult;
}

/// `wasm32` flavour of [`ToolExecutor`]: same contract, but without the
/// `Send + Sync` supertraits and with `?Send` futures.
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
pub trait ToolExecutor {
    /// Run a single tool call and produce its result.
    async fn execute(&self, call: &ToolCall) -> ToolResult;
}
24
25/// Configuration for parallel tool dispatch within a single chain iteration.
26///
27/// When multiple tool calls are emitted in one turn, they are dispatched
28/// concurrently by default (tool work is almost entirely I/O-bound, so this
29/// collapses N serial latencies into ~1). Order of `tool_results` is preserved
30/// regardless of the dispatch strategy.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct ParallelConfig {
33    /// If false, tool calls within a single iteration are executed sequentially.
34    pub enabled: bool,
35    /// Optional cap on the number of tool calls dispatched concurrently.
36    /// `None` = unlimited. `Some(n)` uses a bounded `buffered(n)` stream.
37    pub max_concurrent: Option<usize>,
38}
39
40impl Default for ParallelConfig {
41    fn default() -> Self {
42        Self {
43            enabled: true,
44            max_concurrent: None,
45        }
46    }
47}
48
49/// Dispatch a batch of tool calls, returning results in input order.
50async fn dispatch_tools(
51    executor: &dyn ToolExecutor,
52    calls: &[ToolCall],
53    parallel: &ParallelConfig,
54) -> Vec<ToolResult> {
55    // Sequential fast path: disabled or trivially single-call.
56    if !parallel.enabled || calls.len() <= 1 {
57        let mut out = Vec::with_capacity(calls.len());
58        for call in calls {
59            out.push(executor.execute(call).await);
60        }
61        return out;
62    }
63
64    // Eagerly collect the per-call futures into a Vec. This is the reliable
65    // way to pacify the borrow checker: `stream::iter(calls.iter()).map(|c|
66    // executor.execute(c)).buffered(n)` trips the elided lifetimes inside
67    // the async_trait-returned boxed future. Collecting first puts all
68    // borrows under a single lifetime tied to the enclosing `async fn`.
69    let futs: Vec<_> = calls.iter().map(|c| executor.execute(c)).collect();
70    match parallel.max_concurrent {
71        Some(n) if n > 0 => {
72            futures::stream::iter(futs)
73                .buffered(n)
74                .collect::<Vec<_>>()
75                .await
76        }
77        _ => futures::future::join_all(futs).await,
78    }
79}
80
/// Event emitted during chain loop execution for observability, delivered
/// through the optional `on_event` callback of [`chain`].
///
/// Every iteration produces `IterationStart` followed by `IterationEnd`.
/// When the budget stops the chain, `BudgetExhausted` follows the
/// `IterationEnd` of that iteration and is the last event emitted.
#[derive(Debug, Clone)]
pub enum ChainEvent {
    /// Emitted before the provider is called for an iteration.
    IterationStart {
        /// 1-based iteration number.
        iteration: usize,
        /// The `chain_limit` the loop was started with.
        limit: usize,
        /// Snapshot of the message history about to be sent to the provider
        /// (seed messages plus every assistant and tool-result turn so far).
        messages: Vec<Message>,
    },
    /// Emitted once the provider's response stream has been fully consumed and
    /// its tool calls extracted — before the assistant turn is appended to the
    /// history and before any tool is executed.
    IterationEnd {
        /// 1-based iteration number.
        iteration: usize,
        /// Per-iteration token usage, if the provider reported it.
        usage: Option<Usage>,
        /// Cumulative usage across all iterations up to and including this one;
        /// `None` until at least one iteration has reported usage.
        cumulative_usage: Option<Usage>,
        /// Tool calls extracted from this iteration's response. Empty when the
        /// model requested no tools, which ends the chain.
        tool_calls: Vec<ToolCall>,
    },
    /// Emitted when cumulative usage has reached the budget
    /// (`cumulative_usage.total() >= budget`).
    ///
    /// The check only runs for iterations whose response requested tool calls,
    /// and it runs before those calls are executed: the chain stops without
    /// running them.
    BudgetExhausted {
        /// Cumulative usage at the point the budget check fired.
        cumulative_usage: Usage,
        /// The budget limit that was reached or exceeded.
        budget: u64,
    },
}
112
/// Result of a chain loop execution, as returned by [`chain`].
///
/// NOTE(review): unlike `ChainEvent` this type derives nothing; consider
/// `#[derive(Debug, Clone)]` if `Chunk` and `ToolResult` implement those traits.
pub struct ChainResult {
    /// All chunks from all iterations, in the order they were streamed.
    pub chunks: Vec<Chunk>,
    /// All tool results from all iterations, in the order the model emitted the
    /// calls (parallel dispatch preserves that order).
    pub tool_results: Vec<ToolResult>,
    /// Usage accumulated across all iterations; `None` if no iteration
    /// reported usage.
    pub total_usage: Option<Usage>,
    /// Whether the chain stopped because the budget was exhausted. When true,
    /// the tool calls requested in the final assistant turn were not executed.
    pub budget_exhausted: bool,
    /// Final message history after the chain loop completes — includes the
    /// seed messages, every assistant turn, and any tool-result messages.
    /// After a budget stop it ends with an assistant turn whose tool calls
    /// have no matching tool results.
    pub messages: Vec<Message>,
}
127
128/// Run a chain loop: execute -> collect tool calls -> execute tools -> repeat.
129///
130/// Stops when:
131/// - No tool calls are returned (normal completion)
132/// - `chain_limit` iterations are reached
133/// - `budget` is exceeded (graceful stop after completing current iteration)
134///
135/// `on_chunk` is called for every chunk from every iteration.
136///
137/// The chain accumulates a `Vec<Message>` across iterations so that each
138/// provider call sees the full conversation history (user, assistant+tools,
139/// tool results, ...).
140#[allow(clippy::too_many_arguments)]
141pub async fn chain(
142    provider: &dyn Provider,
143    model: &str,
144    initial_prompt: Prompt,
145    key: Option<&str>,
146    stream: bool,
147    executor: &dyn ToolExecutor,
148    chain_limit: usize,
149    on_chunk: &mut dyn FnMut(&Chunk),
150    on_event: Option<&mut dyn FnMut(&ChainEvent)>,
151    budget: Option<u64>,
152    parallel: ParallelConfig,
153) -> crate::Result<ChainResult> {
154    let mut all_chunks = Vec::new();
155    let mut all_tool_results = Vec::new();
156    let mut on_event = on_event;
157    let mut cumulative_usage: Option<Usage> = None;
158    let mut budget_exhausted = false;
159
160    // Seed messages from initial prompt
161    let mut messages: Vec<Message> = if initial_prompt.messages.is_empty() {
162        vec![Message::user(&initial_prompt.text)]
163    } else {
164        initial_prompt.messages.clone()
165    };
166
167    for iteration in 1..=chain_limit {
168        if let Some(cb) = &mut on_event {
169            cb(&ChainEvent::IterationStart {
170                iteration,
171                limit: chain_limit,
172                messages: messages.clone(),
173            });
174        }
175
176        // Build prompt with accumulated messages + preserved metadata
177        let mut prompt = Prompt::new(&initial_prompt.text)
178            .with_tools(initial_prompt.tools.clone())
179            .with_messages(messages.clone());
180        if let Some(system) = &initial_prompt.system {
181            prompt = prompt.with_system(system);
182        }
183        if let Some(schema) = &initial_prompt.schema {
184            prompt = prompt.with_schema(schema.clone());
185        }
186
187        let response_stream = provider.execute(model, &prompt, key, stream).await?;
188
189        let mut iteration_chunks = Vec::new();
190        let mut pinned = std::pin::pin!(response_stream);
191
192        while let Some(result) = pinned.next().await {
193            let chunk = result?;
194            on_chunk(&chunk);
195            iteration_chunks.push(chunk);
196        }
197
198        let tool_calls = collect_tool_calls(&iteration_chunks);
199        let usage = collect_usage(&iteration_chunks);
200        let text = collect_text(&iteration_chunks);
201
202        // Accumulate usage
203        cumulative_usage = match (&cumulative_usage, &usage) {
204            (Some(cum), Some(iter_usage)) => Some(cum.add(iter_usage)),
205            (None, Some(iter_usage)) => Some(iter_usage.clone()),
206            (cum, None) => cum.clone(),
207        };
208
209        if let Some(cb) = &mut on_event {
210            cb(&ChainEvent::IterationEnd {
211                iteration,
212                usage: usage.clone(),
213                cumulative_usage: cumulative_usage.clone(),
214                tool_calls: tool_calls.clone(),
215            });
216        }
217
218        all_chunks.extend(iteration_chunks);
219
220        // Append assistant message to history
221        messages.push(Message::assistant_with_tool_calls(&text, tool_calls.clone()));
222
223        if tool_calls.is_empty() {
224            break;
225        }
226
227        // Check budget after completing the iteration
228        if let (Some(b), Some(cum)) = (budget, &cumulative_usage)
229            && cum.total() >= b
230        {
231            budget_exhausted = true;
232            if let Some(cb) = &mut on_event {
233                cb(&ChainEvent::BudgetExhausted {
234                    cumulative_usage: cum.clone(),
235                    budget: b,
236                });
237            }
238            break;
239        }
240
241        // Execute tool calls (parallel by default, order-preserving).
242        let tool_results = dispatch_tools(executor, &tool_calls, &parallel).await;
243
244        all_tool_results.extend(tool_results.clone());
245
246        // Append tool results to history
247        messages.push(Message::tool_results(tool_results));
248    }
249
250    Ok(ChainResult {
251        chunks: all_chunks,
252        tool_results: all_tool_results,
253        total_usage: cumulative_usage,
254        budget_exhausted,
255        messages,
256    })
257}
258
259#[cfg(test)]
260mod tests {
261    use super::*;
262    use crate::error::LlmError;
263    use crate::stream::ResponseStream;
264    use crate::types::{ModelInfo, Tool};
265    use std::sync::atomic::{AtomicUsize, Ordering};
266    use std::sync::{Arc, Mutex};
267
268    // Mock provider that returns pre-configured responses and captures prompts
269    struct MockProvider {
270        responses: Vec<Vec<Chunk>>,
271        call_count: AtomicUsize,
272        captured_prompts: Arc<Mutex<Vec<Prompt>>>,
273    }
274
275    impl MockProvider {
276        fn new(responses: Vec<Vec<Chunk>>) -> Self {
277            Self {
278                responses,
279                call_count: AtomicUsize::new(0),
280                captured_prompts: Arc::new(Mutex::new(Vec::new())),
281            }
282        }
283    }
284
285    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
286    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
287    impl Provider for MockProvider {
288        fn id(&self) -> &str {
289            "mock"
290        }
291        fn models(&self) -> Vec<ModelInfo> {
292            vec![ModelInfo::new("mock-model")]
293        }
294        async fn execute(
295            &self,
296            _model: &str,
297            prompt: &Prompt,
298            _key: Option<&str>,
299            _stream: bool,
300        ) -> crate::Result<ResponseStream> {
301            self.captured_prompts.lock().unwrap().push(prompt.clone());
302            let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
303            let chunks = if idx < self.responses.len() {
304                self.responses[idx].clone()
305            } else {
306                // Fallback: return last response
307                self.responses.last().cloned().unwrap_or_default()
308            };
309            let items: Vec<Result<Chunk, LlmError>> = chunks.into_iter().map(Ok).collect();
310            Ok(Box::pin(futures::stream::iter(items)))
311        }
312    }
313
314    // Mock executor
315    struct MockExecutor;
316
317    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
318    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
319    impl ToolExecutor for MockExecutor {
320        async fn execute(&self, call: &ToolCall) -> ToolResult {
321            ToolResult {
322                name: call.name.clone(),
323                output: format!("result for {}", call.name),
324                tool_call_id: call.tool_call_id.clone(),
325                error: None,
326            }
327        }
328    }
329
330    struct ErrorExecutor;
331
332    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
333    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
334    impl ToolExecutor for ErrorExecutor {
335        async fn execute(&self, call: &ToolCall) -> ToolResult {
336            ToolResult {
337                name: call.name.clone(),
338                output: String::new(),
339                tool_call_id: call.tool_call_id.clone(),
340                error: Some("tool failed".into()),
341            }
342        }
343    }
344
345    fn text_response(text: &str) -> Vec<Chunk> {
346        vec![Chunk::Text(text.into()), Chunk::Done]
347    }
348
349    fn tool_call_response(name: &str, id: &str, args: &str) -> Vec<Chunk> {
350        vec![
351            Chunk::ToolCallStart {
352                name: name.into(),
353                id: Some(id.into()),
354            },
355            Chunk::ToolCallDelta {
356                content: args.into(),
357            },
358            Chunk::Done,
359        ]
360    }
361
362    fn make_tool() -> Tool {
363        Tool {
364            name: "test_tool".into(),
365            description: "A test".into(),
366            input_schema: serde_json::json!({"type": "object"}),
367        }
368    }
369
370    #[tokio::test]
371    async fn chain_no_tool_calls_single_iteration() {
372        let provider = MockProvider::new(vec![text_response("Hello!")]);
373        let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
374        let mut callback_count = 0;
375
376        let result = chain(
377            &provider,
378            "mock-model",
379            prompt,
380            None,
381            false,
382            &MockExecutor,
383            5,
384            &mut |_| callback_count += 1,
385            None,
386            None,
387            ParallelConfig::default(),
388        )
389        .await
390        .unwrap();
391
392        assert_eq!(crate::collect_text(&result.chunks), "Hello!");
393        assert!(result.tool_results.is_empty());
394        assert_eq!(callback_count, 2); // Text + Done
395        assert_eq!(provider.call_count.load(Ordering::SeqCst), 1);
396    }
397
398    #[tokio::test]
399    async fn chain_single_tool_call_two_iterations() {
400        let provider = MockProvider::new(vec![
401            tool_call_response("test_tool", "tc_1", "{}"),
402            text_response("Done!"),
403        ]);
404        let prompt = Prompt::new("Do something").with_tools(vec![make_tool()]);
405
406        let result = chain(
407            &provider,
408            "mock-model",
409            prompt,
410            None,
411            false,
412            &MockExecutor,
413            5,
414            &mut |_| {},
415            None,
416            None,
417            ParallelConfig::default(),
418        )
419        .await
420        .unwrap();
421
422        assert_eq!(crate::collect_text(&result.chunks), "Done!");
423        assert_eq!(result.tool_results.len(), 1);
424        assert_eq!(result.tool_results[0].name, "test_tool");
425        assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
426    }
427
428    #[tokio::test]
429    async fn chain_limit_stops_loop() {
430        // Always returns tool calls - should stop at limit
431        let provider = MockProvider::new(vec![
432            tool_call_response("test_tool", "tc_1", "{}"),
433        ]);
434        let prompt = Prompt::new("Loop").with_tools(vec![make_tool()]);
435
436        let result = chain(
437            &provider,
438            "mock-model",
439            prompt,
440            None,
441            false,
442            &MockExecutor,
443            3,
444            &mut |_| {},
445            None,
446            None,
447            ParallelConfig::default(),
448        )
449        .await
450        .unwrap();
451
452        assert_eq!(provider.call_count.load(Ordering::SeqCst), 3);
453        assert_eq!(result.tool_results.len(), 3);
454    }
455
456    #[tokio::test]
457    async fn chain_multiple_tool_calls() {
458        let response = vec![
459            Chunk::ToolCallStart {
460                name: "tool_a".into(),
461                id: Some("tc_1".into()),
462            },
463            Chunk::ToolCallDelta {
464                content: "{}".into(),
465            },
466            Chunk::ToolCallStart {
467                name: "tool_b".into(),
468                id: Some("tc_2".into()),
469            },
470            Chunk::ToolCallDelta {
471                content: "{}".into(),
472            },
473            Chunk::Done,
474        ];
475
476        let provider = MockProvider::new(vec![response, text_response("All done")]);
477        let prompt = Prompt::new("Do both").with_tools(vec![make_tool()]);
478
479        let result = chain(
480            &provider,
481            "mock-model",
482            prompt,
483            None,
484            false,
485            &MockExecutor,
486            5,
487            &mut |_| {},
488            None,
489            None,
490            ParallelConfig::default(),
491        )
492        .await
493        .unwrap();
494
495        assert_eq!(crate::collect_text(&result.chunks), "All done");
496        assert_eq!(result.tool_results.len(), 2);
497        assert_eq!(result.tool_results[0].name, "tool_a");
498        assert_eq!(result.tool_results[1].name, "tool_b");
499        assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
500    }
501
502    #[tokio::test]
503    async fn chain_tool_error_continues() {
504        let provider = MockProvider::new(vec![
505            tool_call_response("test_tool", "tc_1", "{}"),
506            text_response("Handled error"),
507        ]);
508        let prompt = Prompt::new("Try").with_tools(vec![make_tool()]);
509
510        let result = chain(
511            &provider,
512            "mock-model",
513            prompt,
514            None,
515            false,
516            &ErrorExecutor,
517            5,
518            &mut |_| {},
519            None,
520            None,
521            ParallelConfig::default(),
522        )
523        .await
524        .unwrap();
525
526        assert_eq!(crate::collect_text(&result.chunks), "Handled error");
527        assert_eq!(result.tool_results.len(), 1);
528        assert!(result.tool_results[0].error.is_some());
529        assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
530    }
531
532    #[tokio::test]
533    async fn chain_callback_receives_chunks() {
534        let provider = MockProvider::new(vec![text_response("Hi")]);
535        let prompt = Prompt::new("Hello").with_tools(vec![make_tool()]);
536        let received = Arc::new(std::sync::Mutex::new(Vec::new()));
537        let received_clone = received.clone();
538
539        let _ = chain(
540            &provider,
541            "mock-model",
542            prompt,
543            None,
544            false,
545            &MockExecutor,
546            5,
547            &mut |chunk| received_clone.lock().unwrap().push(chunk.clone()),
548            None,
549            None,
550            ParallelConfig::default(),
551        )
552        .await
553        .unwrap();
554
555        let chunks = received.lock().unwrap();
556        assert_eq!(chunks.len(), 2);
557        assert!(matches!(&chunks[0], Chunk::Text(t) if t == "Hi"));
558        assert!(matches!(&chunks[1], Chunk::Done));
559    }
560
561    #[tokio::test]
562    async fn chain_accumulates_messages_across_turns() {
563        // 3-iteration test: tool call → tool call → text
564        let provider = MockProvider::new(vec![
565            tool_call_response("test_tool", "tc_1", "{}"),
566            tool_call_response("test_tool", "tc_2", "{}"),
567            text_response("Done!"),
568        ]);
569        let prompt = Prompt::new("Do it").with_tools(vec![make_tool()]);
570
571        let _ = chain(
572            &provider, "mock-model", prompt, None, false,
573            &MockExecutor, 5, &mut |_| {}, None, None,
574            ParallelConfig::default(),
575        ).await.unwrap();
576
577        let prompts = provider.captured_prompts.lock().unwrap();
578        assert_eq!(prompts.len(), 3);
579
580        // Iteration 1: [user]
581        assert_eq!(prompts[0].messages.len(), 1);
582        assert_eq!(prompts[0].messages[0].role, crate::Role::User);
583
584        // Iteration 2: [user, assistant+tools, tool_results]
585        assert_eq!(prompts[1].messages.len(), 3);
586        assert_eq!(prompts[1].messages[0].role, crate::Role::User);
587        assert_eq!(prompts[1].messages[1].role, crate::Role::Assistant);
588        assert!(!prompts[1].messages[1].tool_calls.is_empty());
589        assert_eq!(prompts[1].messages[2].role, crate::Role::Tool);
590
591        // Iteration 3: [user, assistant+tools, tool_results, assistant+tools, tool_results]
592        assert_eq!(prompts[2].messages.len(), 5);
593    }
594
595    #[tokio::test]
596    async fn chain_preserves_initial_messages() {
597        let initial = vec![
598            Message::user("Earlier question"),
599            Message::assistant("Earlier answer"),
600        ];
601        let provider = MockProvider::new(vec![text_response("Follow up done")]);
602        let prompt = Prompt::new("Follow up")
603            .with_tools(vec![make_tool()])
604            .with_messages(initial);
605
606        let _ = chain(
607            &provider, "mock-model", prompt, None, false,
608            &MockExecutor, 5, &mut |_| {}, None, None,
609            ParallelConfig::default(),
610        ).await.unwrap();
611
612        let prompts = provider.captured_prompts.lock().unwrap();
613        // Should see initial 2 messages preserved
614        assert_eq!(prompts[0].messages.len(), 2);
615        assert_eq!(prompts[0].messages[0].content, "Earlier question");
616        assert_eq!(prompts[0].messages[1].content, "Earlier answer");
617    }
618
619    #[tokio::test]
620    async fn chain_captures_assistant_text_in_history() {
621        // Provider returns text + tool call in first response
622        let response1 = vec![
623            Chunk::Text("Let me check. ".into()),
624            Chunk::ToolCallStart { name: "test_tool".into(), id: Some("tc_1".into()) },
625            Chunk::ToolCallDelta { content: "{}".into() },
626            Chunk::Done,
627        ];
628        let provider = MockProvider::new(vec![response1, text_response("All done")]);
629        let prompt = Prompt::new("Do it").with_tools(vec![make_tool()]);
630
631        let _ = chain(
632            &provider, "mock-model", prompt, None, false,
633            &MockExecutor, 5, &mut |_| {}, None, None,
634            ParallelConfig::default(),
635        ).await.unwrap();
636
637        let prompts = provider.captured_prompts.lock().unwrap();
638        assert_eq!(prompts.len(), 2);
639        // Second prompt should have assistant message with both text and tool_calls
640        let assistant = &prompts[1].messages[1];
641        assert_eq!(assistant.role, crate::Role::Assistant);
642        assert_eq!(assistant.content, "Let me check. ");
643        assert_eq!(assistant.tool_calls.len(), 1);
644        assert_eq!(assistant.tool_calls[0].name, "test_tool");
645    }
646
647    #[tokio::test]
648    async fn chain_emits_iteration_start_event() {
649        let provider = MockProvider::new(vec![text_response("Hello!")]);
650        let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
651        let mut events = Vec::new();
652
653        let _ = chain(
654            &provider, "mock-model", prompt, None, false,
655            &MockExecutor, 5, &mut |_| {},
656            Some(&mut |e: &ChainEvent| events.push(e.clone())),
657            None,
658            ParallelConfig::default(),
659        ).await.unwrap();
660
661        assert_eq!(events.len(), 2); // IterationStart + IterationEnd
662        match &events[0] {
663            ChainEvent::IterationStart { iteration, limit, messages } => {
664                assert_eq!(*iteration, 1);
665                assert_eq!(*limit, 5);
666                assert_eq!(messages.len(), 1);
667                assert_eq!(messages[0].role, crate::Role::User);
668            }
669            _ => panic!("expected IterationStart"),
670        }
671        match &events[1] {
672            ChainEvent::IterationEnd { iteration, usage, cumulative_usage, tool_calls } => {
673                assert_eq!(*iteration, 1);
674                assert!(usage.is_none());
675                assert!(cumulative_usage.is_none());
676                assert!(tool_calls.is_empty());
677            }
678            _ => panic!("expected IterationEnd"),
679        }
680    }
681
682    #[tokio::test]
683    async fn chain_emits_per_iteration_usage() {
684        let response1 = vec![
685            Chunk::ToolCallStart { name: "test_tool".into(), id: Some("tc_1".into()) },
686            Chunk::ToolCallDelta { content: "{}".into() },
687            Chunk::Usage(Usage { input: Some(10), output: Some(5), details: None }),
688            Chunk::Done,
689        ];
690        let response2 = vec![
691            Chunk::Text("Done".into()),
692            Chunk::Usage(Usage { input: Some(20), output: Some(10), details: None }),
693            Chunk::Done,
694        ];
695        let provider = MockProvider::new(vec![response1, response2]);
696        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
697        let mut events = Vec::new();
698
699        let _ = chain(
700            &provider, "mock-model", prompt, None, false,
701            &MockExecutor, 5, &mut |_| {},
702            Some(&mut |e: &ChainEvent| events.push(e.clone())),
703            None,
704            ParallelConfig::default(),
705        ).await.unwrap();
706
707        // 2 iterations -> 4 events (start, end, start, end)
708        assert_eq!(events.len(), 4);
709        match &events[1] {
710            ChainEvent::IterationEnd { iteration, usage, cumulative_usage, tool_calls } => {
711                assert_eq!(*iteration, 1);
712                let u = usage.as_ref().unwrap();
713                assert_eq!(u.input, Some(10));
714                assert_eq!(u.output, Some(5));
715                let cum = cumulative_usage.as_ref().unwrap();
716                assert_eq!(cum.input, Some(10));
717                assert_eq!(cum.output, Some(5));
718                assert_eq!(tool_calls.len(), 1);
719            }
720            _ => panic!("expected IterationEnd"),
721        }
722        match &events[3] {
723            ChainEvent::IterationEnd { iteration, usage, cumulative_usage, tool_calls } => {
724                assert_eq!(*iteration, 2);
725                let u = usage.as_ref().unwrap();
726                assert_eq!(u.input, Some(20));
727                assert_eq!(u.output, Some(10));
728                let cum = cumulative_usage.as_ref().unwrap();
729                assert_eq!(cum.input, Some(30));
730                assert_eq!(cum.output, Some(15));
731                assert!(tool_calls.is_empty());
732            }
733            _ => panic!("expected IterationEnd"),
734        }
735    }
736
737    #[tokio::test]
738    async fn chain_events_correct_sequence() {
739        // 3-iteration chain: tool -> tool -> text
740        let provider = MockProvider::new(vec![
741            tool_call_response("test_tool", "tc_1", "{}"),
742            tool_call_response("test_tool", "tc_2", "{}"),
743            text_response("Done!"),
744        ]);
745        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
746        let mut events = Vec::new();
747
748        let _ = chain(
749            &provider, "mock-model", prompt, None, false,
750            &MockExecutor, 5, &mut |_| {},
751            Some(&mut |e: &ChainEvent| events.push(e.clone())),
752            None,
753            ParallelConfig::default(),
754        ).await.unwrap();
755
756        assert_eq!(events.len(), 6);
757        assert!(matches!(&events[0], ChainEvent::IterationStart { iteration: 1, .. }));
758        assert!(matches!(&events[1], ChainEvent::IterationEnd { iteration: 1, .. }));
759        assert!(matches!(&events[2], ChainEvent::IterationStart { iteration: 2, .. }));
760        assert!(matches!(&events[3], ChainEvent::IterationEnd { iteration: 2, .. }));
761        assert!(matches!(&events[4], ChainEvent::IterationStart { iteration: 3, .. }));
762        assert!(matches!(&events[5], ChainEvent::IterationEnd { iteration: 3, .. }));
763
764        // Verify tool calls in end events
765        if let ChainEvent::IterationEnd { tool_calls, cumulative_usage, .. } = &events[1] {
766            assert_eq!(tool_calls.len(), 1);
767            assert!(cumulative_usage.is_none()); // no usage in mock tool_call_response
768        }
769        if let ChainEvent::IterationEnd { tool_calls, .. } = &events[5] {
770            assert!(tool_calls.is_empty());
771        }
772
773        // Verify message growth in start events
774        if let ChainEvent::IterationStart { messages, .. } = &events[0] {
775            assert_eq!(messages.len(), 1); // [user]
776        }
777        if let ChainEvent::IterationStart { messages, .. } = &events[2] {
778            assert_eq!(messages.len(), 3); // [user, assistant+tools, tool]
779        }
780        if let ChainEvent::IterationStart { messages, .. } = &events[4] {
781            assert_eq!(messages.len(), 5); // [user, a+t, tool, a+t, tool]
782        }
783    }
784
785    #[tokio::test]
786    async fn chain_none_on_event_works() {
787        let provider = MockProvider::new(vec![
788            tool_call_response("test_tool", "tc_1", "{}"),
789            text_response("Done!"),
790        ]);
791        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
792
793        let result = chain(
794            &provider, "mock-model", prompt, None, false,
795            &MockExecutor, 5, &mut |_| {}, None, None,
796            ParallelConfig::default(),
797        ).await.unwrap();
798
799        assert_eq!(crate::collect_text(&result.chunks), "Done!");
800        assert_eq!(result.tool_results.len(), 1);
801    }
802
803    // --- ChainResult.total_usage tests ---
804
805    fn text_response_with_usage(text: &str, input: u64, output: u64) -> Vec<Chunk> {
806        vec![
807            Chunk::Text(text.into()),
808            Chunk::Usage(Usage { input: Some(input), output: Some(output), details: None }),
809            Chunk::Done,
810        ]
811    }
812
813    fn tool_call_response_with_usage(name: &str, id: &str, args: &str, input: u64, output: u64) -> Vec<Chunk> {
814        vec![
815            Chunk::ToolCallStart { name: name.into(), id: Some(id.into()) },
816            Chunk::ToolCallDelta { content: args.into() },
817            Chunk::Usage(Usage { input: Some(input), output: Some(output), details: None }),
818            Chunk::Done,
819        ]
820    }
821
822    #[tokio::test]
823    async fn chain_result_total_usage_single_iteration() {
824        let provider = MockProvider::new(vec![
825            text_response_with_usage("Hello!", 10, 5),
826        ]);
827        let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
828
829        let result = chain(
830            &provider, "mock-model", prompt, None, false,
831            &MockExecutor, 5, &mut |_| {}, None, None,
832            ParallelConfig::default(),
833        ).await.unwrap();
834
835        let usage = result.total_usage.unwrap();
836        assert_eq!(usage.input, Some(10));
837        assert_eq!(usage.output, Some(5));
838        assert!(!result.budget_exhausted);
839    }
840
841    #[tokio::test]
842    async fn chain_result_total_usage_multi_iteration() {
843        let provider = MockProvider::new(vec![
844            tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 5),
845            text_response_with_usage("Done!", 20, 10),
846        ]);
847        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
848
849        let result = chain(
850            &provider, "mock-model", prompt, None, false,
851            &MockExecutor, 5, &mut |_| {}, None, None,
852            ParallelConfig::default(),
853        ).await.unwrap();
854
855        let usage = result.total_usage.unwrap();
856        assert_eq!(usage.input, Some(30));
857        assert_eq!(usage.output, Some(15));
858    }
859
860    #[tokio::test]
861    async fn chain_result_total_usage_none() {
862        let provider = MockProvider::new(vec![text_response("Hello!")]);
863        let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
864
865        let result = chain(
866            &provider, "mock-model", prompt, None, false,
867            &MockExecutor, 5, &mut |_| {}, None, None,
868            ParallelConfig::default(),
869        ).await.unwrap();
870
871        assert!(result.total_usage.is_none());
872    }
873
874    // --- cumulative_usage in ChainEvent::IterationEnd ---
875
876    #[tokio::test]
877    async fn chain_event_cumulative_usage() {
878        let provider = MockProvider::new(vec![
879            tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 5),
880            text_response_with_usage("Done!", 20, 10),
881        ]);
882        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
883        let mut events = Vec::new();
884
885        let _ = chain(
886            &provider, "mock-model", prompt, None, false,
887            &MockExecutor, 5, &mut |_| {},
888            Some(&mut |e: &ChainEvent| events.push(e.clone())),
889            None,
890            ParallelConfig::default(),
891        ).await.unwrap();
892
893        // 2 iterations -> 4 events
894        assert_eq!(events.len(), 4);
895
896        // Iter 1 end: cumulative = (10, 5)
897        if let ChainEvent::IterationEnd { cumulative_usage, .. } = &events[1] {
898            let cum = cumulative_usage.as_ref().unwrap();
899            assert_eq!(cum.input, Some(10));
900            assert_eq!(cum.output, Some(5));
901        } else {
902            panic!("expected IterationEnd");
903        }
904
905        // Iter 2 end: cumulative = (30, 15)
906        if let ChainEvent::IterationEnd { cumulative_usage, .. } = &events[3] {
907            let cum = cumulative_usage.as_ref().unwrap();
908            assert_eq!(cum.input, Some(30));
909            assert_eq!(cum.output, Some(15));
910        } else {
911            panic!("expected IterationEnd");
912        }
913    }
914
915    // --- Budget enforcement tests ---
916
917    #[tokio::test]
918    async fn chain_budget_stops_when_exceeded() {
919        // budget=25, iter1 usage=30 (10+20) → stops after 1 iteration
920        let provider = MockProvider::new(vec![
921            tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 20),
922            text_response_with_usage("Should not reach", 10, 10),
923        ]);
924        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
925
926        let result = chain(
927            &provider, "mock-model", prompt, None, false,
928            &MockExecutor, 5, &mut |_| {}, None, Some(25),
929            ParallelConfig::default(),
930        ).await.unwrap();
931
932        assert!(result.budget_exhausted);
933        assert_eq!(provider.call_count.load(Ordering::SeqCst), 1);
934        let usage = result.total_usage.unwrap();
935        assert_eq!(usage.total(), 30);
936    }
937
938    #[tokio::test]
939    async fn chain_budget_allows_under() {
940        // budget=100, iter1 usage=15 → text response, continues normally
941        let provider = MockProvider::new(vec![
942            text_response_with_usage("Hello!", 10, 5),
943        ]);
944        let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
945
946        let result = chain(
947            &provider, "mock-model", prompt, None, false,
948            &MockExecutor, 5, &mut |_| {}, None, Some(100),
949            ParallelConfig::default(),
950        ).await.unwrap();
951
952        assert!(!result.budget_exhausted);
953        assert_eq!(provider.call_count.load(Ordering::SeqCst), 1);
954    }
955
956    #[tokio::test]
957    async fn chain_budget_multi_iteration_accumulates() {
958        // budget=40, iter1=15, iter2=15, iter3 would exceed → stops after 2
959        let provider = MockProvider::new(vec![
960            tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 5),
961            tool_call_response_with_usage("test_tool", "tc_2", "{}", 10, 5),
962            text_response_with_usage("Should not reach", 10, 5),
963        ]);
964        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
965
966        let result = chain(
967            &provider, "mock-model", prompt, None, false,
968            &MockExecutor, 5, &mut |_| {}, None, Some(40),
969            ParallelConfig::default(),
970        ).await.unwrap();
971
972        // iter1: 15 total (under 40), iter2: 30 total (under 40) → both allowed
973        // Actually 30 < 40, so it should continue. Let me set budget to 25 instead.
974        assert!(!result.budget_exhausted);
975        // With budget=40 and 15 per iter, it will do 2 tool iterations (30 total < 40)
976        // then the 3rd would run (15+15+15=45 > 40 would trigger IF there were tool calls)
977        // But actually iter 3 is text, so it stops naturally
978        assert_eq!(provider.call_count.load(Ordering::SeqCst), 3);
979    }
980
981    #[tokio::test]
982    async fn chain_budget_multi_iteration_stops() {
983        // budget=25, iter1=15 (ok), iter2=15 (cumulative=30 > 25) → stops
984        let provider = MockProvider::new(vec![
985            tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 5),
986            tool_call_response_with_usage("test_tool", "tc_2", "{}", 10, 5),
987            text_response_with_usage("Should not reach", 10, 5),
988        ]);
989        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
990
991        let result = chain(
992            &provider, "mock-model", prompt, None, false,
993            &MockExecutor, 5, &mut |_| {}, None, Some(25),
994            ParallelConfig::default(),
995        ).await.unwrap();
996
997        assert!(result.budget_exhausted);
998        assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
999        let usage = result.total_usage.unwrap();
1000        assert_eq!(usage.total(), 30);
1001    }
1002
1003    #[tokio::test]
1004    async fn chain_budget_none_no_enforcement() {
1005        let provider = MockProvider::new(vec![
1006            tool_call_response_with_usage("test_tool", "tc_1", "{}", 100, 100),
1007            text_response_with_usage("Done!", 100, 100),
1008        ]);
1009        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1010
1011        let result = chain(
1012            &provider, "mock-model", prompt, None, false,
1013            &MockExecutor, 5, &mut |_| {}, None, None,
1014            ParallelConfig::default(),
1015        ).await.unwrap();
1016
1017        assert!(!result.budget_exhausted);
1018        assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
1019    }
1020
1021    #[tokio::test]
1022    async fn chain_budget_emits_event() {
1023        let provider = MockProvider::new(vec![
1024            tool_call_response_with_usage("test_tool", "tc_1", "{}", 20, 15),
1025            text_response_with_usage("Should not reach", 10, 10),
1026        ]);
1027        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1028        let mut events = Vec::new();
1029
1030        let _ = chain(
1031            &provider, "mock-model", prompt, None, false,
1032            &MockExecutor, 5, &mut |_| {},
1033            Some(&mut |e: &ChainEvent| events.push(e.clone())),
1034            Some(30),
1035            ParallelConfig::default(),
1036        ).await.unwrap();
1037
1038        // Should have: IterationStart, IterationEnd, BudgetExhausted
1039        assert_eq!(events.len(), 3);
1040        match &events[2] {
1041            ChainEvent::BudgetExhausted { cumulative_usage, budget } => {
1042                assert_eq!(*budget, 30);
1043                assert_eq!(cumulative_usage.total(), 35);
1044            }
1045            _ => panic!("expected BudgetExhausted, got {:?}", events[2]),
1046        }
1047    }
1048
1049    // --- Parallel tool dispatch tests ---
1050
1051    /// Executor whose call latency is inversely proportional to the call's
1052    /// position in the input slice: the last call finishes first. If dispatch
1053    /// is sequential, the wall-clock is ~sum(delays); if parallel, ~max(delays).
1054    /// Order of `tool_results` must still match input order.
1055    struct StaggeredExecutor {
1056        total: usize,
1057        per_call_ms: u64,
1058    }
1059
1060    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1061    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1062    impl ToolExecutor for StaggeredExecutor {
1063        async fn execute(&self, call: &ToolCall) -> ToolResult {
1064            // Extract the input index from the tool_call_id (e.g. "tc_3" -> 3).
1065            let idx: usize = call
1066                .tool_call_id
1067                .as_deref()
1068                .and_then(|s| s.strip_prefix("tc_"))
1069                .and_then(|s| s.parse().ok())
1070                .unwrap_or(0);
1071            // Later calls sleep shorter so they finish first.
1072            let sleep_ms = self.per_call_ms * (self.total as u64 - idx as u64);
1073            tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1074            ToolResult {
1075                name: call.name.clone(),
1076                output: format!("result for {}", call.tool_call_id.as_deref().unwrap_or("?")),
1077                tool_call_id: call.tool_call_id.clone(),
1078                error: None,
1079            }
1080        }
1081    }
1082
1083    /// Build a provider response that emits N tool calls in a single turn,
1084    /// followed by a text response for the next iteration.
1085    fn multi_tool_call_response(n: usize) -> Vec<Chunk> {
1086        let mut chunks = Vec::new();
1087        for i in 0..n {
1088            chunks.push(Chunk::ToolCallStart {
1089                name: "test_tool".into(),
1090                id: Some(format!("tc_{i}")),
1091            });
1092            chunks.push(Chunk::ToolCallDelta {
1093                content: "{}".into(),
1094            });
1095        }
1096        chunks.push(Chunk::Done);
1097        chunks
1098    }
1099
1100    #[tokio::test]
1101    async fn chain_parallel_preserves_tool_call_order() {
1102        const N: usize = 5;
1103        const PER_CALL_MS: u64 = 100;
1104
1105        let provider = MockProvider::new(vec![
1106            multi_tool_call_response(N),
1107            text_response("Done!"),
1108        ]);
1109        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1110        let executor = StaggeredExecutor {
1111            total: N,
1112            per_call_ms: PER_CALL_MS,
1113        };
1114
1115        let start = std::time::Instant::now();
1116        let result = chain(
1117            &provider,
1118            "mock-model",
1119            prompt,
1120            None,
1121            false,
1122            &executor,
1123            5,
1124            &mut |_| {},
1125            None,
1126            None,
1127            ParallelConfig {
1128                enabled: true,
1129                max_concurrent: None,
1130            },
1131        )
1132        .await
1133        .unwrap();
1134        let elapsed = start.elapsed();
1135
1136        assert_eq!(result.tool_results.len(), N);
1137        for i in 0..N {
1138            assert_eq!(
1139                result.tool_results[i].tool_call_id.as_deref(),
1140                Some(format!("tc_{i}").as_str()),
1141                "result {i} out of order"
1142            );
1143        }
1144
1145        // Sequential total would be PER_CALL_MS * (N + N-1 + ... + 1) = 1500ms.
1146        // Parallel total should be dominated by the longest call (~500ms).
1147        // Give a generous ceiling to avoid flakiness.
1148        let sequential_sum_ms = PER_CALL_MS * (1..=N as u64).sum::<u64>();
1149        assert!(
1150            elapsed.as_millis() < (sequential_sum_ms as u128) / 2,
1151            "parallel dispatch took {elapsed:?}, expected << {sequential_sum_ms}ms"
1152        );
1153    }
1154
1155    /// Executor that tracks the maximum number of concurrent `execute()`
1156    /// calls in flight. Used to verify that `max_concurrent` caps actual
1157    /// parallelism.
1158    struct ConcurrencyProbe {
1159        live: Arc<AtomicUsize>,
1160        peak: Arc<AtomicUsize>,
1161        sleep_ms: u64,
1162    }
1163
1164    #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1165    #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1166    impl ToolExecutor for ConcurrencyProbe {
1167        async fn execute(&self, call: &ToolCall) -> ToolResult {
1168            let live_now = self.live.fetch_add(1, Ordering::SeqCst) + 1;
1169            self.peak.fetch_max(live_now, Ordering::SeqCst);
1170            tokio::time::sleep(std::time::Duration::from_millis(self.sleep_ms)).await;
1171            self.live.fetch_sub(1, Ordering::SeqCst);
1172            ToolResult {
1173                name: call.name.clone(),
1174                output: "ok".into(),
1175                tool_call_id: call.tool_call_id.clone(),
1176                error: None,
1177            }
1178        }
1179    }
1180
1181    #[tokio::test]
1182    async fn chain_parallel_bounded_concurrency() {
1183        const N: usize = 5;
1184        const CAP: usize = 2;
1185
1186        let provider = MockProvider::new(vec![
1187            multi_tool_call_response(N),
1188            text_response("Done!"),
1189        ]);
1190        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1191        let live = Arc::new(AtomicUsize::new(0));
1192        let peak = Arc::new(AtomicUsize::new(0));
1193        let executor = ConcurrencyProbe {
1194            live: live.clone(),
1195            peak: peak.clone(),
1196            sleep_ms: 50,
1197        };
1198
1199        let _ = chain(
1200            &provider,
1201            "mock-model",
1202            prompt,
1203            None,
1204            false,
1205            &executor,
1206            5,
1207            &mut |_| {},
1208            None,
1209            None,
1210            ParallelConfig {
1211                enabled: true,
1212                max_concurrent: Some(CAP),
1213            },
1214        )
1215        .await
1216        .unwrap();
1217
1218        assert_eq!(
1219            peak.load(Ordering::SeqCst),
1220            CAP,
1221            "expected peak concurrency == cap, peak saturation"
1222        );
1223    }
1224
1225    #[tokio::test]
1226    async fn chain_sequential_when_disabled() {
1227        let provider = MockProvider::new(vec![
1228            multi_tool_call_response(5),
1229            text_response("Done!"),
1230        ]);
1231        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1232        let live = Arc::new(AtomicUsize::new(0));
1233        let peak = Arc::new(AtomicUsize::new(0));
1234        let executor = ConcurrencyProbe {
1235            live: live.clone(),
1236            peak: peak.clone(),
1237            sleep_ms: 20,
1238        };
1239
1240        let _ = chain(
1241            &provider,
1242            "mock-model",
1243            prompt,
1244            None,
1245            false,
1246            &executor,
1247            5,
1248            &mut |_| {},
1249            None,
1250            None,
1251            ParallelConfig {
1252                enabled: false,
1253                max_concurrent: None,
1254            },
1255        )
1256        .await
1257        .unwrap();
1258
1259        assert_eq!(
1260            peak.load(Ordering::SeqCst),
1261            1,
1262            "expected peak == 1 when parallel dispatch is disabled"
1263        );
1264    }
1265
1266    #[tokio::test]
1267    async fn chain_single_call_is_sequential() {
1268        // One tool call: the single-call fast path should kick in and peak == 1.
1269        let provider = MockProvider::new(vec![
1270            tool_call_response("test_tool", "tc_0", "{}"),
1271            text_response("Done!"),
1272        ]);
1273        let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1274        let live = Arc::new(AtomicUsize::new(0));
1275        let peak = Arc::new(AtomicUsize::new(0));
1276        let executor = ConcurrencyProbe {
1277            live: live.clone(),
1278            peak: peak.clone(),
1279            sleep_ms: 20,
1280        };
1281
1282        let _ = chain(
1283            &provider,
1284            "mock-model",
1285            prompt,
1286            None,
1287            false,
1288            &executor,
1289            5,
1290            &mut |_| {},
1291            None,
1292            None,
1293            ParallelConfig {
1294                enabled: true,
1295                max_concurrent: None,
1296            },
1297        )
1298        .await
1299        .unwrap();
1300
1301        assert_eq!(
1302            peak.load(Ordering::SeqCst),
1303            1,
1304            "expected peak == 1 on single-call fast path"
1305        );
1306    }
1307}