1use std::collections::HashMap;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::{Arc, Mutex};
4
5use tokio::spawn;
6use tokio::sync::{Notify, mpsc};
7use tokio_stream::wrappers::UnboundedReceiverStream;
8
9use crate::{Context, LlmError, LlmResponse, LlmResponseStream, StreamingModelProvider};
10
/// A scripted [`StreamingModelProvider`] for tests: replays pre-canned
/// response chunks turn by turn and records every `Context` it receives.
pub struct FakeLlmProvider {
    // One inner Vec of chunk results per expected call ("turn"), replayed
    // in call order by `stream_response`.
    responses: Vec<Vec<Result<LlmResponse, LlmError>>>,
    // turn index -> (chunk index -> notifier). After emitting that chunk,
    // the stream waits on the notifier, letting tests pause mid-turn.
    pauses: HashMap<usize, HashMap<usize, Arc<Notify>>>,
    // Number of `stream_response` calls so far; selects the turn to replay.
    call_count: AtomicUsize,
    // Clones of every Context passed to `stream_response`, for inspection.
    captured_contexts: Arc<Mutex<Vec<Context>>>,
    // Value returned by `display_name()`.
    display_name: String,
    // Value returned by `context_window()`.
    context_window: Option<u32>,
}
20
21impl FakeLlmProvider {
22 pub fn new(responses: Vec<Vec<LlmResponse>>) -> Self {
23 let wrapped = responses.into_iter().map(|turn| turn.into_iter().map(Ok).collect()).collect();
24 Self::from_results(wrapped)
25 }
26
27 pub fn with_single_response(chunks: Vec<LlmResponse>) -> Self {
28 Self::new(vec![chunks])
29 }
30
31 pub fn from_results(responses: Vec<Vec<Result<LlmResponse, LlmError>>>) -> Self {
32 Self {
33 responses,
34 pauses: HashMap::new(),
35 call_count: AtomicUsize::new(0),
36 captured_contexts: Arc::new(Mutex::new(Vec::new())),
37 display_name: "Fake LLM".to_string(),
38 context_window: None,
39 }
40 }
41
42 pub fn with_display_name(mut self, name: &str) -> Self {
43 self.display_name = name.to_string();
44 self
45 }
46
47 pub fn with_context_window(mut self, window: Option<u32>) -> Self {
48 self.context_window = window;
49 self
50 }
51
52 pub fn pause_turn_after(mut self, turn_index: usize, chunk_index: usize, notify: Arc<Notify>) -> Self {
56 self.pauses.entry(turn_index).or_default().insert(chunk_index, notify);
57 self
58 }
59
60 pub fn captured_contexts(&self) -> Arc<Mutex<Vec<Context>>> {
63 Arc::clone(&self.captured_contexts)
64 }
65}
66
67impl StreamingModelProvider for FakeLlmProvider {
68 fn stream_response(&self, context: &Context) -> LlmResponseStream {
69 if let Ok(mut contexts) = self.captured_contexts.lock() {
70 contexts.push(context.clone());
71 }
72
73 let current_call = self.call_count.fetch_add(1, Ordering::SeqCst);
74
75 let response = if current_call < self.responses.len() {
76 self.responses[current_call].clone()
77 } else {
78 vec![Ok(LlmResponse::done())]
79 };
80
81 let pauses = self.pauses.get(¤t_call).cloned().unwrap_or_default();
82 if pauses.is_empty() {
83 return Box::pin(tokio_stream::iter(response));
84 }
85
86 let (tx, rx) = mpsc::unbounded_channel();
87 spawn(async move {
88 for (index, chunk) in response.into_iter().enumerate() {
89 if tx.send(chunk).is_err() {
90 return;
91 }
92 if let Some(notify) = pauses.get(&index) {
93 notify.notified().await;
94 }
95 }
96 });
97 Box::pin(UnboundedReceiverStream::new(rx))
98 }
99
100 fn display_name(&self) -> String {
101 self.display_name.clone()
102 }
103
104 fn context_window(&self) -> Option<u32> {
105 self.context_window
106 }
107}