1use async_trait::async_trait;
2use futures::StreamExt;
3use serde::{Deserialize, Serialize};
4
5use crate::provider::Provider;
6use crate::stream::{Chunk, collect_text, collect_tool_calls, collect_usage};
7use crate::types::{Message, Prompt, ToolCall, ToolResult, Usage};
8
/// Executes the tool calls a model requests during a [`chain`] run.
///
/// This definition is compiled for every target except `wasm32`; implementors
/// must be `Send + Sync`.
#[cfg(not(target_arch = "wasm32"))]
#[async_trait]
pub trait ToolExecutor: Send + Sync {
    /// Runs a single tool call and returns its result.
    ///
    /// The signature is infallible: a failing tool reports the failure inside
    /// the returned [`ToolResult`] instead of through a `Result`.
    async fn execute(&self, call: &ToolCall) -> ToolResult;
}
18
/// Executes the tool calls a model requests during a [`chain`] run.
///
/// This definition is compiled only for `wasm32`; it carries no `Send`/`Sync`
/// supertraits and its futures are declared `?Send`.
#[cfg(target_arch = "wasm32")]
#[async_trait(?Send)]
pub trait ToolExecutor {
    /// Runs a single tool call and returns its result.
    ///
    /// The signature is infallible: a failing tool reports the failure inside
    /// the returned [`ToolResult`] instead of through a `Result`.
    async fn execute(&self, call: &ToolCall) -> ToolResult;
}
24
/// Controls how the tool calls of a single iteration are dispatched.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelConfig {
    /// When `false`, tool calls are executed one after another.
    pub enabled: bool,
    /// Upper bound on tool calls in flight at once. `None` (and `Some(0)`)
    /// mean "no bound": all calls of the iteration are awaited together.
    pub max_concurrent: Option<usize>,
}
39
40impl Default for ParallelConfig {
41 fn default() -> Self {
42 Self {
43 enabled: true,
44 max_concurrent: None,
45 }
46 }
47}
48
49async fn dispatch_tools(
51 executor: &dyn ToolExecutor,
52 calls: &[ToolCall],
53 parallel: &ParallelConfig,
54) -> Vec<ToolResult> {
55 if !parallel.enabled || calls.len() <= 1 {
57 let mut out = Vec::with_capacity(calls.len());
58 for call in calls {
59 out.push(executor.execute(call).await);
60 }
61 return out;
62 }
63
64 let futs: Vec<_> = calls.iter().map(|c| executor.execute(c)).collect();
70 match parallel.max_concurrent {
71 Some(n) if n > 0 => {
72 futures::stream::iter(futs)
73 .buffered(n)
74 .collect::<Vec<_>>()
75 .await
76 }
77 _ => futures::future::join_all(futs).await,
78 }
79}
80
/// Progress notifications delivered to the `on_event` callback of [`chain`].
#[derive(Debug, Clone)]
pub enum ChainEvent {
    /// Emitted at the top of every iteration, before the provider is called.
    IterationStart {
        /// 1-based iteration number.
        iteration: usize,
        /// The `chain_limit` the chain was started with.
        limit: usize,
        /// Snapshot of the conversation that is about to be sent.
        messages: Vec<Message>,
    },
    /// Emitted once an iteration's response stream has been fully consumed.
    IterationEnd {
        /// 1-based iteration number.
        iteration: usize,
        /// Usage collected from this iteration's chunks, if any was reported.
        usage: Option<Usage>,
        /// Running usage total over all iterations so far; `None` while no
        /// iteration has reported usage.
        cumulative_usage: Option<Usage>,
        /// Tool calls collected from this iteration's chunks.
        tool_calls: Vec<ToolCall>,
    },
    /// Emitted when a budget is set and, after an iteration that requested
    /// tool calls, the cumulative usage total has reached or passed it. The
    /// chain stops without executing those tool calls.
    BudgetExhausted {
        /// Usage accumulated up to and including the last iteration.
        cumulative_usage: Usage,
        /// The budget that was reached.
        budget: u64,
    },
}
112
/// Everything a finished [`chain`] run produced.
pub struct ChainResult {
    /// All chunks received from the provider, across every iteration, in order.
    pub chunks: Vec<Chunk>,
    /// Results of every tool call that was executed, in execution-batch order.
    pub tool_results: Vec<ToolResult>,
    /// Sum of the usage reported by all iterations; `None` if none reported any.
    pub total_usage: Option<Usage>,
    /// `true` when the run stopped because the usage budget was reached.
    pub budget_exhausted: bool,
    /// Final conversation: the initial messages plus every assistant and
    /// tool-result message appended during the run.
    pub messages: Vec<Message>,
}
127
128#[allow(clippy::too_many_arguments)]
141pub async fn chain(
142 provider: &dyn Provider,
143 model: &str,
144 initial_prompt: Prompt,
145 key: Option<&str>,
146 stream: bool,
147 executor: &dyn ToolExecutor,
148 chain_limit: usize,
149 on_chunk: &mut dyn FnMut(&Chunk),
150 on_event: Option<&mut dyn FnMut(&ChainEvent)>,
151 budget: Option<u64>,
152 parallel: ParallelConfig,
153) -> crate::Result<ChainResult> {
154 let mut all_chunks = Vec::new();
155 let mut all_tool_results = Vec::new();
156 let mut on_event = on_event;
157 let mut cumulative_usage: Option<Usage> = None;
158 let mut budget_exhausted = false;
159
160 let mut messages: Vec<Message> = if initial_prompt.messages.is_empty() {
162 vec![Message::user(&initial_prompt.text)]
163 } else {
164 initial_prompt.messages.clone()
165 };
166
167 for iteration in 1..=chain_limit {
168 if let Some(cb) = &mut on_event {
169 cb(&ChainEvent::IterationStart {
170 iteration,
171 limit: chain_limit,
172 messages: messages.clone(),
173 });
174 }
175
176 let mut prompt = Prompt::new(&initial_prompt.text)
178 .with_tools(initial_prompt.tools.clone())
179 .with_messages(messages.clone());
180 if let Some(system) = &initial_prompt.system {
181 prompt = prompt.with_system(system);
182 }
183 if let Some(schema) = &initial_prompt.schema {
184 prompt = prompt.with_schema(schema.clone());
185 }
186
187 let response_stream = provider.execute(model, &prompt, key, stream).await?;
188
189 let mut iteration_chunks = Vec::new();
190 let mut pinned = std::pin::pin!(response_stream);
191
192 while let Some(result) = pinned.next().await {
193 let chunk = result?;
194 on_chunk(&chunk);
195 iteration_chunks.push(chunk);
196 }
197
198 let tool_calls = collect_tool_calls(&iteration_chunks);
199 let usage = collect_usage(&iteration_chunks);
200 let text = collect_text(&iteration_chunks);
201
202 cumulative_usage = match (&cumulative_usage, &usage) {
204 (Some(cum), Some(iter_usage)) => Some(cum.add(iter_usage)),
205 (None, Some(iter_usage)) => Some(iter_usage.clone()),
206 (cum, None) => cum.clone(),
207 };
208
209 if let Some(cb) = &mut on_event {
210 cb(&ChainEvent::IterationEnd {
211 iteration,
212 usage: usage.clone(),
213 cumulative_usage: cumulative_usage.clone(),
214 tool_calls: tool_calls.clone(),
215 });
216 }
217
218 all_chunks.extend(iteration_chunks);
219
220 messages.push(Message::assistant_with_tool_calls(&text, tool_calls.clone()));
222
223 if tool_calls.is_empty() {
224 break;
225 }
226
227 if let (Some(b), Some(cum)) = (budget, &cumulative_usage)
229 && cum.total() >= b
230 {
231 budget_exhausted = true;
232 if let Some(cb) = &mut on_event {
233 cb(&ChainEvent::BudgetExhausted {
234 cumulative_usage: cum.clone(),
235 budget: b,
236 });
237 }
238 break;
239 }
240
241 let tool_results = dispatch_tools(executor, &tool_calls, ¶llel).await;
243
244 all_tool_results.extend(tool_results.clone());
245
246 messages.push(Message::tool_results(tool_results));
248 }
249
250 Ok(ChainResult {
251 chunks: all_chunks,
252 tool_results: all_tool_results,
253 total_usage: cumulative_usage,
254 budget_exhausted,
255 messages,
256 })
257}
258
259#[cfg(test)]
260mod tests {
261 use super::*;
262 use crate::error::LlmError;
263 use crate::stream::ResponseStream;
264 use crate::types::{ModelInfo, Tool};
265 use std::sync::atomic::{AtomicUsize, Ordering};
266 use std::sync::{Arc, Mutex};
267
268 struct MockProvider {
270 responses: Vec<Vec<Chunk>>,
271 call_count: AtomicUsize,
272 captured_prompts: Arc<Mutex<Vec<Prompt>>>,
273 }
274
275 impl MockProvider {
276 fn new(responses: Vec<Vec<Chunk>>) -> Self {
277 Self {
278 responses,
279 call_count: AtomicUsize::new(0),
280 captured_prompts: Arc::new(Mutex::new(Vec::new())),
281 }
282 }
283 }
284
285 #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
286 #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
287 impl Provider for MockProvider {
288 fn id(&self) -> &str {
289 "mock"
290 }
291 fn models(&self) -> Vec<ModelInfo> {
292 vec![ModelInfo::new("mock-model")]
293 }
294 async fn execute(
295 &self,
296 _model: &str,
297 prompt: &Prompt,
298 _key: Option<&str>,
299 _stream: bool,
300 ) -> crate::Result<ResponseStream> {
301 self.captured_prompts.lock().unwrap().push(prompt.clone());
302 let idx = self.call_count.fetch_add(1, Ordering::SeqCst);
303 let chunks = if idx < self.responses.len() {
304 self.responses[idx].clone()
305 } else {
306 self.responses.last().cloned().unwrap_or_default()
308 };
309 let items: Vec<Result<Chunk, LlmError>> = chunks.into_iter().map(Ok).collect();
310 Ok(Box::pin(futures::stream::iter(items)))
311 }
312 }
313
314 struct MockExecutor;
316
317 #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
318 #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
319 impl ToolExecutor for MockExecutor {
320 async fn execute(&self, call: &ToolCall) -> ToolResult {
321 ToolResult {
322 name: call.name.clone(),
323 output: format!("result for {}", call.name),
324 tool_call_id: call.tool_call_id.clone(),
325 error: None,
326 }
327 }
328 }
329
330 struct ErrorExecutor;
331
332 #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
333 #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
334 impl ToolExecutor for ErrorExecutor {
335 async fn execute(&self, call: &ToolCall) -> ToolResult {
336 ToolResult {
337 name: call.name.clone(),
338 output: String::new(),
339 tool_call_id: call.tool_call_id.clone(),
340 error: Some("tool failed".into()),
341 }
342 }
343 }
344
345 fn text_response(text: &str) -> Vec<Chunk> {
346 vec![Chunk::Text(text.into()), Chunk::Done]
347 }
348
349 fn tool_call_response(name: &str, id: &str, args: &str) -> Vec<Chunk> {
350 vec![
351 Chunk::ToolCallStart {
352 name: name.into(),
353 id: Some(id.into()),
354 },
355 Chunk::ToolCallDelta {
356 content: args.into(),
357 },
358 Chunk::Done,
359 ]
360 }
361
362 fn make_tool() -> Tool {
363 Tool {
364 name: "test_tool".into(),
365 description: "A test".into(),
366 input_schema: serde_json::json!({"type": "object"}),
367 }
368 }
369
370 #[tokio::test]
371 async fn chain_no_tool_calls_single_iteration() {
372 let provider = MockProvider::new(vec![text_response("Hello!")]);
373 let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
374 let mut callback_count = 0;
375
376 let result = chain(
377 &provider,
378 "mock-model",
379 prompt,
380 None,
381 false,
382 &MockExecutor,
383 5,
384 &mut |_| callback_count += 1,
385 None,
386 None,
387 ParallelConfig::default(),
388 )
389 .await
390 .unwrap();
391
392 assert_eq!(crate::collect_text(&result.chunks), "Hello!");
393 assert!(result.tool_results.is_empty());
394 assert_eq!(callback_count, 2); assert_eq!(provider.call_count.load(Ordering::SeqCst), 1);
396 }
397
398 #[tokio::test]
399 async fn chain_single_tool_call_two_iterations() {
400 let provider = MockProvider::new(vec![
401 tool_call_response("test_tool", "tc_1", "{}"),
402 text_response("Done!"),
403 ]);
404 let prompt = Prompt::new("Do something").with_tools(vec![make_tool()]);
405
406 let result = chain(
407 &provider,
408 "mock-model",
409 prompt,
410 None,
411 false,
412 &MockExecutor,
413 5,
414 &mut |_| {},
415 None,
416 None,
417 ParallelConfig::default(),
418 )
419 .await
420 .unwrap();
421
422 assert_eq!(crate::collect_text(&result.chunks), "Done!");
423 assert_eq!(result.tool_results.len(), 1);
424 assert_eq!(result.tool_results[0].name, "test_tool");
425 assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
426 }
427
428 #[tokio::test]
429 async fn chain_limit_stops_loop() {
430 let provider = MockProvider::new(vec![
432 tool_call_response("test_tool", "tc_1", "{}"),
433 ]);
434 let prompt = Prompt::new("Loop").with_tools(vec![make_tool()]);
435
436 let result = chain(
437 &provider,
438 "mock-model",
439 prompt,
440 None,
441 false,
442 &MockExecutor,
443 3,
444 &mut |_| {},
445 None,
446 None,
447 ParallelConfig::default(),
448 )
449 .await
450 .unwrap();
451
452 assert_eq!(provider.call_count.load(Ordering::SeqCst), 3);
453 assert_eq!(result.tool_results.len(), 3);
454 }
455
456 #[tokio::test]
457 async fn chain_multiple_tool_calls() {
458 let response = vec![
459 Chunk::ToolCallStart {
460 name: "tool_a".into(),
461 id: Some("tc_1".into()),
462 },
463 Chunk::ToolCallDelta {
464 content: "{}".into(),
465 },
466 Chunk::ToolCallStart {
467 name: "tool_b".into(),
468 id: Some("tc_2".into()),
469 },
470 Chunk::ToolCallDelta {
471 content: "{}".into(),
472 },
473 Chunk::Done,
474 ];
475
476 let provider = MockProvider::new(vec![response, text_response("All done")]);
477 let prompt = Prompt::new("Do both").with_tools(vec![make_tool()]);
478
479 let result = chain(
480 &provider,
481 "mock-model",
482 prompt,
483 None,
484 false,
485 &MockExecutor,
486 5,
487 &mut |_| {},
488 None,
489 None,
490 ParallelConfig::default(),
491 )
492 .await
493 .unwrap();
494
495 assert_eq!(crate::collect_text(&result.chunks), "All done");
496 assert_eq!(result.tool_results.len(), 2);
497 assert_eq!(result.tool_results[0].name, "tool_a");
498 assert_eq!(result.tool_results[1].name, "tool_b");
499 assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
500 }
501
502 #[tokio::test]
503 async fn chain_tool_error_continues() {
504 let provider = MockProvider::new(vec![
505 tool_call_response("test_tool", "tc_1", "{}"),
506 text_response("Handled error"),
507 ]);
508 let prompt = Prompt::new("Try").with_tools(vec![make_tool()]);
509
510 let result = chain(
511 &provider,
512 "mock-model",
513 prompt,
514 None,
515 false,
516 &ErrorExecutor,
517 5,
518 &mut |_| {},
519 None,
520 None,
521 ParallelConfig::default(),
522 )
523 .await
524 .unwrap();
525
526 assert_eq!(crate::collect_text(&result.chunks), "Handled error");
527 assert_eq!(result.tool_results.len(), 1);
528 assert!(result.tool_results[0].error.is_some());
529 assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
530 }
531
532 #[tokio::test]
533 async fn chain_callback_receives_chunks() {
534 let provider = MockProvider::new(vec![text_response("Hi")]);
535 let prompt = Prompt::new("Hello").with_tools(vec![make_tool()]);
536 let received = Arc::new(std::sync::Mutex::new(Vec::new()));
537 let received_clone = received.clone();
538
539 let _ = chain(
540 &provider,
541 "mock-model",
542 prompt,
543 None,
544 false,
545 &MockExecutor,
546 5,
547 &mut |chunk| received_clone.lock().unwrap().push(chunk.clone()),
548 None,
549 None,
550 ParallelConfig::default(),
551 )
552 .await
553 .unwrap();
554
555 let chunks = received.lock().unwrap();
556 assert_eq!(chunks.len(), 2);
557 assert!(matches!(&chunks[0], Chunk::Text(t) if t == "Hi"));
558 assert!(matches!(&chunks[1], Chunk::Done));
559 }
560
561 #[tokio::test]
562 async fn chain_accumulates_messages_across_turns() {
563 let provider = MockProvider::new(vec![
565 tool_call_response("test_tool", "tc_1", "{}"),
566 tool_call_response("test_tool", "tc_2", "{}"),
567 text_response("Done!"),
568 ]);
569 let prompt = Prompt::new("Do it").with_tools(vec![make_tool()]);
570
571 let _ = chain(
572 &provider, "mock-model", prompt, None, false,
573 &MockExecutor, 5, &mut |_| {}, None, None,
574 ParallelConfig::default(),
575 ).await.unwrap();
576
577 let prompts = provider.captured_prompts.lock().unwrap();
578 assert_eq!(prompts.len(), 3);
579
580 assert_eq!(prompts[0].messages.len(), 1);
582 assert_eq!(prompts[0].messages[0].role, crate::Role::User);
583
584 assert_eq!(prompts[1].messages.len(), 3);
586 assert_eq!(prompts[1].messages[0].role, crate::Role::User);
587 assert_eq!(prompts[1].messages[1].role, crate::Role::Assistant);
588 assert!(!prompts[1].messages[1].tool_calls.is_empty());
589 assert_eq!(prompts[1].messages[2].role, crate::Role::Tool);
590
591 assert_eq!(prompts[2].messages.len(), 5);
593 }
594
595 #[tokio::test]
596 async fn chain_preserves_initial_messages() {
597 let initial = vec![
598 Message::user("Earlier question"),
599 Message::assistant("Earlier answer"),
600 ];
601 let provider = MockProvider::new(vec![text_response("Follow up done")]);
602 let prompt = Prompt::new("Follow up")
603 .with_tools(vec![make_tool()])
604 .with_messages(initial);
605
606 let _ = chain(
607 &provider, "mock-model", prompt, None, false,
608 &MockExecutor, 5, &mut |_| {}, None, None,
609 ParallelConfig::default(),
610 ).await.unwrap();
611
612 let prompts = provider.captured_prompts.lock().unwrap();
613 assert_eq!(prompts[0].messages.len(), 2);
615 assert_eq!(prompts[0].messages[0].content, "Earlier question");
616 assert_eq!(prompts[0].messages[1].content, "Earlier answer");
617 }
618
619 #[tokio::test]
620 async fn chain_captures_assistant_text_in_history() {
621 let response1 = vec![
623 Chunk::Text("Let me check. ".into()),
624 Chunk::ToolCallStart { name: "test_tool".into(), id: Some("tc_1".into()) },
625 Chunk::ToolCallDelta { content: "{}".into() },
626 Chunk::Done,
627 ];
628 let provider = MockProvider::new(vec![response1, text_response("All done")]);
629 let prompt = Prompt::new("Do it").with_tools(vec![make_tool()]);
630
631 let _ = chain(
632 &provider, "mock-model", prompt, None, false,
633 &MockExecutor, 5, &mut |_| {}, None, None,
634 ParallelConfig::default(),
635 ).await.unwrap();
636
637 let prompts = provider.captured_prompts.lock().unwrap();
638 assert_eq!(prompts.len(), 2);
639 let assistant = &prompts[1].messages[1];
641 assert_eq!(assistant.role, crate::Role::Assistant);
642 assert_eq!(assistant.content, "Let me check. ");
643 assert_eq!(assistant.tool_calls.len(), 1);
644 assert_eq!(assistant.tool_calls[0].name, "test_tool");
645 }
646
647 #[tokio::test]
648 async fn chain_emits_iteration_start_event() {
649 let provider = MockProvider::new(vec![text_response("Hello!")]);
650 let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
651 let mut events = Vec::new();
652
653 let _ = chain(
654 &provider, "mock-model", prompt, None, false,
655 &MockExecutor, 5, &mut |_| {},
656 Some(&mut |e: &ChainEvent| events.push(e.clone())),
657 None,
658 ParallelConfig::default(),
659 ).await.unwrap();
660
661 assert_eq!(events.len(), 2); match &events[0] {
663 ChainEvent::IterationStart { iteration, limit, messages } => {
664 assert_eq!(*iteration, 1);
665 assert_eq!(*limit, 5);
666 assert_eq!(messages.len(), 1);
667 assert_eq!(messages[0].role, crate::Role::User);
668 }
669 _ => panic!("expected IterationStart"),
670 }
671 match &events[1] {
672 ChainEvent::IterationEnd { iteration, usage, cumulative_usage, tool_calls } => {
673 assert_eq!(*iteration, 1);
674 assert!(usage.is_none());
675 assert!(cumulative_usage.is_none());
676 assert!(tool_calls.is_empty());
677 }
678 _ => panic!("expected IterationEnd"),
679 }
680 }
681
682 #[tokio::test]
683 async fn chain_emits_per_iteration_usage() {
684 let response1 = vec![
685 Chunk::ToolCallStart { name: "test_tool".into(), id: Some("tc_1".into()) },
686 Chunk::ToolCallDelta { content: "{}".into() },
687 Chunk::Usage(Usage { input: Some(10), output: Some(5), details: None }),
688 Chunk::Done,
689 ];
690 let response2 = vec![
691 Chunk::Text("Done".into()),
692 Chunk::Usage(Usage { input: Some(20), output: Some(10), details: None }),
693 Chunk::Done,
694 ];
695 let provider = MockProvider::new(vec![response1, response2]);
696 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
697 let mut events = Vec::new();
698
699 let _ = chain(
700 &provider, "mock-model", prompt, None, false,
701 &MockExecutor, 5, &mut |_| {},
702 Some(&mut |e: &ChainEvent| events.push(e.clone())),
703 None,
704 ParallelConfig::default(),
705 ).await.unwrap();
706
707 assert_eq!(events.len(), 4);
709 match &events[1] {
710 ChainEvent::IterationEnd { iteration, usage, cumulative_usage, tool_calls } => {
711 assert_eq!(*iteration, 1);
712 let u = usage.as_ref().unwrap();
713 assert_eq!(u.input, Some(10));
714 assert_eq!(u.output, Some(5));
715 let cum = cumulative_usage.as_ref().unwrap();
716 assert_eq!(cum.input, Some(10));
717 assert_eq!(cum.output, Some(5));
718 assert_eq!(tool_calls.len(), 1);
719 }
720 _ => panic!("expected IterationEnd"),
721 }
722 match &events[3] {
723 ChainEvent::IterationEnd { iteration, usage, cumulative_usage, tool_calls } => {
724 assert_eq!(*iteration, 2);
725 let u = usage.as_ref().unwrap();
726 assert_eq!(u.input, Some(20));
727 assert_eq!(u.output, Some(10));
728 let cum = cumulative_usage.as_ref().unwrap();
729 assert_eq!(cum.input, Some(30));
730 assert_eq!(cum.output, Some(15));
731 assert!(tool_calls.is_empty());
732 }
733 _ => panic!("expected IterationEnd"),
734 }
735 }
736
737 #[tokio::test]
738 async fn chain_events_correct_sequence() {
739 let provider = MockProvider::new(vec![
741 tool_call_response("test_tool", "tc_1", "{}"),
742 tool_call_response("test_tool", "tc_2", "{}"),
743 text_response("Done!"),
744 ]);
745 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
746 let mut events = Vec::new();
747
748 let _ = chain(
749 &provider, "mock-model", prompt, None, false,
750 &MockExecutor, 5, &mut |_| {},
751 Some(&mut |e: &ChainEvent| events.push(e.clone())),
752 None,
753 ParallelConfig::default(),
754 ).await.unwrap();
755
756 assert_eq!(events.len(), 6);
757 assert!(matches!(&events[0], ChainEvent::IterationStart { iteration: 1, .. }));
758 assert!(matches!(&events[1], ChainEvent::IterationEnd { iteration: 1, .. }));
759 assert!(matches!(&events[2], ChainEvent::IterationStart { iteration: 2, .. }));
760 assert!(matches!(&events[3], ChainEvent::IterationEnd { iteration: 2, .. }));
761 assert!(matches!(&events[4], ChainEvent::IterationStart { iteration: 3, .. }));
762 assert!(matches!(&events[5], ChainEvent::IterationEnd { iteration: 3, .. }));
763
764 if let ChainEvent::IterationEnd { tool_calls, cumulative_usage, .. } = &events[1] {
766 assert_eq!(tool_calls.len(), 1);
767 assert!(cumulative_usage.is_none()); }
769 if let ChainEvent::IterationEnd { tool_calls, .. } = &events[5] {
770 assert!(tool_calls.is_empty());
771 }
772
773 if let ChainEvent::IterationStart { messages, .. } = &events[0] {
775 assert_eq!(messages.len(), 1); }
777 if let ChainEvent::IterationStart { messages, .. } = &events[2] {
778 assert_eq!(messages.len(), 3); }
780 if let ChainEvent::IterationStart { messages, .. } = &events[4] {
781 assert_eq!(messages.len(), 5); }
783 }
784
785 #[tokio::test]
786 async fn chain_none_on_event_works() {
787 let provider = MockProvider::new(vec![
788 tool_call_response("test_tool", "tc_1", "{}"),
789 text_response("Done!"),
790 ]);
791 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
792
793 let result = chain(
794 &provider, "mock-model", prompt, None, false,
795 &MockExecutor, 5, &mut |_| {}, None, None,
796 ParallelConfig::default(),
797 ).await.unwrap();
798
799 assert_eq!(crate::collect_text(&result.chunks), "Done!");
800 assert_eq!(result.tool_results.len(), 1);
801 }
802
803 fn text_response_with_usage(text: &str, input: u64, output: u64) -> Vec<Chunk> {
806 vec![
807 Chunk::Text(text.into()),
808 Chunk::Usage(Usage { input: Some(input), output: Some(output), details: None }),
809 Chunk::Done,
810 ]
811 }
812
813 fn tool_call_response_with_usage(name: &str, id: &str, args: &str, input: u64, output: u64) -> Vec<Chunk> {
814 vec![
815 Chunk::ToolCallStart { name: name.into(), id: Some(id.into()) },
816 Chunk::ToolCallDelta { content: args.into() },
817 Chunk::Usage(Usage { input: Some(input), output: Some(output), details: None }),
818 Chunk::Done,
819 ]
820 }
821
822 #[tokio::test]
823 async fn chain_result_total_usage_single_iteration() {
824 let provider = MockProvider::new(vec![
825 text_response_with_usage("Hello!", 10, 5),
826 ]);
827 let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
828
829 let result = chain(
830 &provider, "mock-model", prompt, None, false,
831 &MockExecutor, 5, &mut |_| {}, None, None,
832 ParallelConfig::default(),
833 ).await.unwrap();
834
835 let usage = result.total_usage.unwrap();
836 assert_eq!(usage.input, Some(10));
837 assert_eq!(usage.output, Some(5));
838 assert!(!result.budget_exhausted);
839 }
840
841 #[tokio::test]
842 async fn chain_result_total_usage_multi_iteration() {
843 let provider = MockProvider::new(vec![
844 tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 5),
845 text_response_with_usage("Done!", 20, 10),
846 ]);
847 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
848
849 let result = chain(
850 &provider, "mock-model", prompt, None, false,
851 &MockExecutor, 5, &mut |_| {}, None, None,
852 ParallelConfig::default(),
853 ).await.unwrap();
854
855 let usage = result.total_usage.unwrap();
856 assert_eq!(usage.input, Some(30));
857 assert_eq!(usage.output, Some(15));
858 }
859
860 #[tokio::test]
861 async fn chain_result_total_usage_none() {
862 let provider = MockProvider::new(vec![text_response("Hello!")]);
863 let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
864
865 let result = chain(
866 &provider, "mock-model", prompt, None, false,
867 &MockExecutor, 5, &mut |_| {}, None, None,
868 ParallelConfig::default(),
869 ).await.unwrap();
870
871 assert!(result.total_usage.is_none());
872 }
873
874 #[tokio::test]
877 async fn chain_event_cumulative_usage() {
878 let provider = MockProvider::new(vec![
879 tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 5),
880 text_response_with_usage("Done!", 20, 10),
881 ]);
882 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
883 let mut events = Vec::new();
884
885 let _ = chain(
886 &provider, "mock-model", prompt, None, false,
887 &MockExecutor, 5, &mut |_| {},
888 Some(&mut |e: &ChainEvent| events.push(e.clone())),
889 None,
890 ParallelConfig::default(),
891 ).await.unwrap();
892
893 assert_eq!(events.len(), 4);
895
896 if let ChainEvent::IterationEnd { cumulative_usage, .. } = &events[1] {
898 let cum = cumulative_usage.as_ref().unwrap();
899 assert_eq!(cum.input, Some(10));
900 assert_eq!(cum.output, Some(5));
901 } else {
902 panic!("expected IterationEnd");
903 }
904
905 if let ChainEvent::IterationEnd { cumulative_usage, .. } = &events[3] {
907 let cum = cumulative_usage.as_ref().unwrap();
908 assert_eq!(cum.input, Some(30));
909 assert_eq!(cum.output, Some(15));
910 } else {
911 panic!("expected IterationEnd");
912 }
913 }
914
915 #[tokio::test]
918 async fn chain_budget_stops_when_exceeded() {
919 let provider = MockProvider::new(vec![
921 tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 20),
922 text_response_with_usage("Should not reach", 10, 10),
923 ]);
924 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
925
926 let result = chain(
927 &provider, "mock-model", prompt, None, false,
928 &MockExecutor, 5, &mut |_| {}, None, Some(25),
929 ParallelConfig::default(),
930 ).await.unwrap();
931
932 assert!(result.budget_exhausted);
933 assert_eq!(provider.call_count.load(Ordering::SeqCst), 1);
934 let usage = result.total_usage.unwrap();
935 assert_eq!(usage.total(), 30);
936 }
937
938 #[tokio::test]
939 async fn chain_budget_allows_under() {
940 let provider = MockProvider::new(vec![
942 text_response_with_usage("Hello!", 10, 5),
943 ]);
944 let prompt = Prompt::new("Hi").with_tools(vec![make_tool()]);
945
946 let result = chain(
947 &provider, "mock-model", prompt, None, false,
948 &MockExecutor, 5, &mut |_| {}, None, Some(100),
949 ParallelConfig::default(),
950 ).await.unwrap();
951
952 assert!(!result.budget_exhausted);
953 assert_eq!(provider.call_count.load(Ordering::SeqCst), 1);
954 }
955
956 #[tokio::test]
957 async fn chain_budget_multi_iteration_accumulates() {
958 let provider = MockProvider::new(vec![
960 tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 5),
961 tool_call_response_with_usage("test_tool", "tc_2", "{}", 10, 5),
962 text_response_with_usage("Should not reach", 10, 5),
963 ]);
964 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
965
966 let result = chain(
967 &provider, "mock-model", prompt, None, false,
968 &MockExecutor, 5, &mut |_| {}, None, Some(40),
969 ParallelConfig::default(),
970 ).await.unwrap();
971
972 assert!(!result.budget_exhausted);
975 assert_eq!(provider.call_count.load(Ordering::SeqCst), 3);
979 }
980
981 #[tokio::test]
982 async fn chain_budget_multi_iteration_stops() {
983 let provider = MockProvider::new(vec![
985 tool_call_response_with_usage("test_tool", "tc_1", "{}", 10, 5),
986 tool_call_response_with_usage("test_tool", "tc_2", "{}", 10, 5),
987 text_response_with_usage("Should not reach", 10, 5),
988 ]);
989 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
990
991 let result = chain(
992 &provider, "mock-model", prompt, None, false,
993 &MockExecutor, 5, &mut |_| {}, None, Some(25),
994 ParallelConfig::default(),
995 ).await.unwrap();
996
997 assert!(result.budget_exhausted);
998 assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
999 let usage = result.total_usage.unwrap();
1000 assert_eq!(usage.total(), 30);
1001 }
1002
1003 #[tokio::test]
1004 async fn chain_budget_none_no_enforcement() {
1005 let provider = MockProvider::new(vec![
1006 tool_call_response_with_usage("test_tool", "tc_1", "{}", 100, 100),
1007 text_response_with_usage("Done!", 100, 100),
1008 ]);
1009 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1010
1011 let result = chain(
1012 &provider, "mock-model", prompt, None, false,
1013 &MockExecutor, 5, &mut |_| {}, None, None,
1014 ParallelConfig::default(),
1015 ).await.unwrap();
1016
1017 assert!(!result.budget_exhausted);
1018 assert_eq!(provider.call_count.load(Ordering::SeqCst), 2);
1019 }
1020
1021 #[tokio::test]
1022 async fn chain_budget_emits_event() {
1023 let provider = MockProvider::new(vec![
1024 tool_call_response_with_usage("test_tool", "tc_1", "{}", 20, 15),
1025 text_response_with_usage("Should not reach", 10, 10),
1026 ]);
1027 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1028 let mut events = Vec::new();
1029
1030 let _ = chain(
1031 &provider, "mock-model", prompt, None, false,
1032 &MockExecutor, 5, &mut |_| {},
1033 Some(&mut |e: &ChainEvent| events.push(e.clone())),
1034 Some(30),
1035 ParallelConfig::default(),
1036 ).await.unwrap();
1037
1038 assert_eq!(events.len(), 3);
1040 match &events[2] {
1041 ChainEvent::BudgetExhausted { cumulative_usage, budget } => {
1042 assert_eq!(*budget, 30);
1043 assert_eq!(cumulative_usage.total(), 35);
1044 }
1045 _ => panic!("expected BudgetExhausted, got {:?}", events[2]),
1046 }
1047 }
1048
1049 struct StaggeredExecutor {
1056 total: usize,
1057 per_call_ms: u64,
1058 }
1059
1060 #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1061 #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1062 impl ToolExecutor for StaggeredExecutor {
1063 async fn execute(&self, call: &ToolCall) -> ToolResult {
1064 let idx: usize = call
1066 .tool_call_id
1067 .as_deref()
1068 .and_then(|s| s.strip_prefix("tc_"))
1069 .and_then(|s| s.parse().ok())
1070 .unwrap_or(0);
1071 let sleep_ms = self.per_call_ms * (self.total as u64 - idx as u64);
1073 tokio::time::sleep(std::time::Duration::from_millis(sleep_ms)).await;
1074 ToolResult {
1075 name: call.name.clone(),
1076 output: format!("result for {}", call.tool_call_id.as_deref().unwrap_or("?")),
1077 tool_call_id: call.tool_call_id.clone(),
1078 error: None,
1079 }
1080 }
1081 }
1082
1083 fn multi_tool_call_response(n: usize) -> Vec<Chunk> {
1086 let mut chunks = Vec::new();
1087 for i in 0..n {
1088 chunks.push(Chunk::ToolCallStart {
1089 name: "test_tool".into(),
1090 id: Some(format!("tc_{i}")),
1091 });
1092 chunks.push(Chunk::ToolCallDelta {
1093 content: "{}".into(),
1094 });
1095 }
1096 chunks.push(Chunk::Done);
1097 chunks
1098 }
1099
1100 #[tokio::test]
1101 async fn chain_parallel_preserves_tool_call_order() {
1102 const N: usize = 5;
1103 const PER_CALL_MS: u64 = 100;
1104
1105 let provider = MockProvider::new(vec![
1106 multi_tool_call_response(N),
1107 text_response("Done!"),
1108 ]);
1109 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1110 let executor = StaggeredExecutor {
1111 total: N,
1112 per_call_ms: PER_CALL_MS,
1113 };
1114
1115 let start = std::time::Instant::now();
1116 let result = chain(
1117 &provider,
1118 "mock-model",
1119 prompt,
1120 None,
1121 false,
1122 &executor,
1123 5,
1124 &mut |_| {},
1125 None,
1126 None,
1127 ParallelConfig {
1128 enabled: true,
1129 max_concurrent: None,
1130 },
1131 )
1132 .await
1133 .unwrap();
1134 let elapsed = start.elapsed();
1135
1136 assert_eq!(result.tool_results.len(), N);
1137 for i in 0..N {
1138 assert_eq!(
1139 result.tool_results[i].tool_call_id.as_deref(),
1140 Some(format!("tc_{i}").as_str()),
1141 "result {i} out of order"
1142 );
1143 }
1144
1145 let sequential_sum_ms = PER_CALL_MS * (1..=N as u64).sum::<u64>();
1149 assert!(
1150 elapsed.as_millis() < (sequential_sum_ms as u128) / 2,
1151 "parallel dispatch took {elapsed:?}, expected << {sequential_sum_ms}ms"
1152 );
1153 }
1154
1155 struct ConcurrencyProbe {
1159 live: Arc<AtomicUsize>,
1160 peak: Arc<AtomicUsize>,
1161 sleep_ms: u64,
1162 }
1163
1164 #[cfg_attr(not(target_arch = "wasm32"), async_trait)]
1165 #[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
1166 impl ToolExecutor for ConcurrencyProbe {
1167 async fn execute(&self, call: &ToolCall) -> ToolResult {
1168 let live_now = self.live.fetch_add(1, Ordering::SeqCst) + 1;
1169 self.peak.fetch_max(live_now, Ordering::SeqCst);
1170 tokio::time::sleep(std::time::Duration::from_millis(self.sleep_ms)).await;
1171 self.live.fetch_sub(1, Ordering::SeqCst);
1172 ToolResult {
1173 name: call.name.clone(),
1174 output: "ok".into(),
1175 tool_call_id: call.tool_call_id.clone(),
1176 error: None,
1177 }
1178 }
1179 }
1180
1181 #[tokio::test]
1182 async fn chain_parallel_bounded_concurrency() {
1183 const N: usize = 5;
1184 const CAP: usize = 2;
1185
1186 let provider = MockProvider::new(vec![
1187 multi_tool_call_response(N),
1188 text_response("Done!"),
1189 ]);
1190 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1191 let live = Arc::new(AtomicUsize::new(0));
1192 let peak = Arc::new(AtomicUsize::new(0));
1193 let executor = ConcurrencyProbe {
1194 live: live.clone(),
1195 peak: peak.clone(),
1196 sleep_ms: 50,
1197 };
1198
1199 let _ = chain(
1200 &provider,
1201 "mock-model",
1202 prompt,
1203 None,
1204 false,
1205 &executor,
1206 5,
1207 &mut |_| {},
1208 None,
1209 None,
1210 ParallelConfig {
1211 enabled: true,
1212 max_concurrent: Some(CAP),
1213 },
1214 )
1215 .await
1216 .unwrap();
1217
1218 assert_eq!(
1219 peak.load(Ordering::SeqCst),
1220 CAP,
1221 "expected peak concurrency == cap, peak saturation"
1222 );
1223 }
1224
1225 #[tokio::test]
1226 async fn chain_sequential_when_disabled() {
1227 let provider = MockProvider::new(vec![
1228 multi_tool_call_response(5),
1229 text_response("Done!"),
1230 ]);
1231 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1232 let live = Arc::new(AtomicUsize::new(0));
1233 let peak = Arc::new(AtomicUsize::new(0));
1234 let executor = ConcurrencyProbe {
1235 live: live.clone(),
1236 peak: peak.clone(),
1237 sleep_ms: 20,
1238 };
1239
1240 let _ = chain(
1241 &provider,
1242 "mock-model",
1243 prompt,
1244 None,
1245 false,
1246 &executor,
1247 5,
1248 &mut |_| {},
1249 None,
1250 None,
1251 ParallelConfig {
1252 enabled: false,
1253 max_concurrent: None,
1254 },
1255 )
1256 .await
1257 .unwrap();
1258
1259 assert_eq!(
1260 peak.load(Ordering::SeqCst),
1261 1,
1262 "expected peak == 1 when parallel dispatch is disabled"
1263 );
1264 }
1265
1266 #[tokio::test]
1267 async fn chain_single_call_is_sequential() {
1268 let provider = MockProvider::new(vec![
1270 tool_call_response("test_tool", "tc_0", "{}"),
1271 text_response("Done!"),
1272 ]);
1273 let prompt = Prompt::new("Go").with_tools(vec![make_tool()]);
1274 let live = Arc::new(AtomicUsize::new(0));
1275 let peak = Arc::new(AtomicUsize::new(0));
1276 let executor = ConcurrencyProbe {
1277 live: live.clone(),
1278 peak: peak.clone(),
1279 sleep_ms: 20,
1280 };
1281
1282 let _ = chain(
1283 &provider,
1284 "mock-model",
1285 prompt,
1286 None,
1287 false,
1288 &executor,
1289 5,
1290 &mut |_| {},
1291 None,
1292 None,
1293 ParallelConfig {
1294 enabled: true,
1295 max_concurrent: None,
1296 },
1297 )
1298 .await
1299 .unwrap();
1300
1301 assert_eq!(
1302 peak.load(Ordering::SeqCst),
1303 1,
1304 "expected peak == 1 on single-call fast path"
1305 );
1306 }
1307}