llm/testing/fake_llm.rs

use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

use tokio::spawn;
use tokio::sync::{mpsc, Notify};
use tokio_stream::wrappers::UnboundedReceiverStream;

use crate::{Context, LlmError, LlmResponse, LlmResponseStream, StreamingModelProvider};

/// A scripted `StreamingModelProvider` for tests. Each inner `Vec` of
/// `responses` is one turn's worth of stream chunks: call `n` to
/// `stream_response` replays turn `n`, and calls past the end of the script
/// yield a single `LlmResponse::done()`.
pub struct FakeLlmProvider {
    /// Scripted chunks, one `Vec` per expected call to `stream_response`.
    responses: Vec<Vec<Result<LlmResponse, LlmError>>>,
    /// Turn index -> (chunk index -> gate): after emitting that chunk, the
    /// stream waits until the corresponding `Notify` fires.
    pauses: HashMap<usize, HashMap<usize, Arc<Notify>>>,
    call_count: AtomicUsize,
    /// Captured contexts from each call to `stream_response`.
    captured_contexts: Arc<Mutex<Vec<Context>>>,
    display_name: String,
    context_window: Option<u32>,
}

impl FakeLlmProvider {
    /// Build a provider where every scripted chunk succeeds.
    pub fn new(responses: Vec<Vec<LlmResponse>>) -> Self {
        let wrapped = responses
            .into_iter()
            .map(|turn| turn.into_iter().map(Ok).collect())
            .collect();
        Self::from_results(wrapped)
    }

    /// Convenience constructor for a provider that is only called once.
    pub fn with_single_response(chunks: Vec<LlmResponse>) -> Self {
        Self::new(vec![chunks])
    }

    /// Build a provider whose scripted chunks may include `Err` values, for
    /// exercising error handling in the consumer.
    pub fn from_results(responses: Vec<Vec<Result<LlmResponse, LlmError>>>) -> Self {
        Self {
            responses,
            pauses: HashMap::new(),
            call_count: AtomicUsize::new(0),
            captured_contexts: Arc::new(Mutex::new(Vec::new())),
            display_name: "Fake LLM".to_string(),
            context_window: None,
        }
    }

    pub fn with_display_name(mut self, name: &str) -> Self {
        self.display_name = name.to_string();
        self
    }

    pub fn with_context_window(mut self, window: Option<u32>) -> Self {
        self.context_window = window;
        self
    }

    /// Pause the stream for `turn_index` after emitting the chunk at
    /// `chunk_index` until the given `Notify` is notified. Enables
    /// deterministic mid-stream tests (e.g. send a user message after the
    /// provider has emitted some text).
    pub fn pause_turn_after(
        mut self,
        turn_index: usize,
        chunk_index: usize,
        notify: Arc<Notify>,
    ) -> Self {
        self.pauses.entry(turn_index).or_default().insert(chunk_index, notify);
        self
    }

    /// Returns a handle to the captured contexts that can be used to verify
    /// what contexts were passed to the LLM.
    pub fn captured_contexts(&self) -> Arc<Mutex<Vec<Context>>> {
        Arc::clone(&self.captured_contexts)
    }
}

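// A minimal usage sketch for context capture, under one assumption: that
// `Context` implements `Default` (substitute however tests in this crate
// actually build a `Context`). Everything else uses only the API in this file.
#[cfg(test)]
mod capture_tests {
    use super::*;
    use tokio_stream::StreamExt;

    #[tokio::test]
    async fn records_the_context_of_each_call() {
        let provider = FakeLlmProvider::with_single_response(vec![LlmResponse::done()])
            .with_display_name("Test Model")
            .with_context_window(Some(8192));
        let captured = provider.captured_contexts();

        // Drain the single scripted turn.
        let mut stream = provider.stream_response(&Context::default());
        while stream.next().await.is_some() {}

        assert_eq!(captured.lock().unwrap().len(), 1);
        assert_eq!(provider.display_name(), "Test Model");
        assert_eq!(provider.context_window(), Some(8192));
    }
}
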
impl StreamingModelProvider for FakeLlmProvider {
    fn stream_response(&self, context: &Context) -> LlmResponseStream {
        if let Ok(mut contexts) = self.captured_contexts.lock() {
            contexts.push(context.clone());
        }

        let current_call = self.call_count.fetch_add(1, Ordering::SeqCst);

        // Replay the scripted turn for this call; calls past the end of the
        // script produce a single terminal chunk.
        let chunks = if current_call < self.responses.len() {
            self.responses[current_call].clone()
        } else {
            vec![Ok(LlmResponse::done())]
        };

        let pauses = self.pauses.get(&current_call).cloned().unwrap_or_default();
        if pauses.is_empty() {
            // No pauses registered: a plain in-memory stream suffices.
            return Box::pin(tokio_stream::iter(chunks));
        }

        // With pauses, feed chunks through a channel from a spawned task so
        // the stream can stall between chunks without blocking the caller.
        let (tx, rx) = mpsc::unbounded_channel();
        spawn(async move {
            for (index, chunk) in chunks.into_iter().enumerate() {
                if tx.send(chunk).is_err() {
                    // Receiver dropped; stop producing.
                    return;
                }
                if let Some(notify) = pauses.get(&index) {
                    notify.notified().await;
                }
            }
        });
        Box::pin(UnboundedReceiverStream::new(rx))
    }

    fn display_name(&self) -> String {
        self.display_name.clone()
    }

    fn context_window(&self) -> Option<u32> {
        self.context_window
    }
}
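
// A sketch of a deterministic mid-stream test built on `pause_turn_after`.
// As above, `Context::default()` is an assumed stand-in; `LlmResponse::done()`
// is used for both chunks only because it is the one constructor visible in
// this file.
#[cfg(test)]
mod pause_tests {
    use super::*;
    use tokio_stream::StreamExt;

    #[tokio::test]
    async fn holds_the_stream_until_notified() {
        let gate = Arc::new(Notify::new());
        let provider = FakeLlmProvider::new(vec![vec![
            LlmResponse::done(), // chunk 0: emitted, then the stream pauses
            LlmResponse::done(), // chunk 1: emitted only after `gate` fires
        ]])
        .pause_turn_after(0, 0, Arc::clone(&gate));

        let mut stream = provider.stream_response(&Context::default());

        // Chunk 0 arrives immediately.
        assert!(stream.next().await.is_some());

        // `notify_one` stores a permit even if the producer task is not yet
        // waiting, so this wake is race-free.
        gate.notify_one();
        assert!(stream.next().await.is_some());
        assert!(stream.next().await.is_none());
    }
}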