Skip to main content

llm/testing/
fake_llm.rs

1use std::sync::atomic::{AtomicUsize, Ordering};
2use std::sync::{Arc, Mutex};
3
4use crate::{Context, LlmError, LlmResponse, LlmResponseStream, StreamingModelProvider};
5
6pub struct FakeLlmProvider {
7    responses: Vec<Vec<Result<LlmResponse, LlmError>>>,
8    call_count: AtomicUsize,
9    /// Captured contexts from each call to `stream_response`
10    captured_contexts: Arc<Mutex<Vec<Context>>>,
11    display_name: String,
12    context_window: Option<u32>,
13}
14
15impl FakeLlmProvider {
16    pub fn new(responses: Vec<Vec<LlmResponse>>) -> Self {
17        let wrapped = responses.into_iter().map(|turn| turn.into_iter().map(Ok).collect()).collect();
18        Self::from_results(wrapped)
19    }
20
21    pub fn with_single_response(chunks: Vec<LlmResponse>) -> Self {
22        Self::new(vec![chunks])
23    }
24
25    pub fn from_results(responses: Vec<Vec<Result<LlmResponse, LlmError>>>) -> Self {
26        Self {
27            responses,
28            call_count: AtomicUsize::new(0),
29            captured_contexts: Arc::new(Mutex::new(Vec::new())),
30            display_name: "Fake LLM".to_string(),
31            context_window: None,
32        }
33    }
34
35    pub fn with_display_name(mut self, name: &str) -> Self {
36        self.display_name = name.to_string();
37        self
38    }
39
40    pub fn with_context_window(mut self, window: Option<u32>) -> Self {
41        self.context_window = window;
42        self
43    }
44
45    /// Returns a handle to the captured contexts that can be used to verify
46    /// what contexts were passed to the LLM.
47    pub fn captured_contexts(&self) -> Arc<Mutex<Vec<Context>>> {
48        Arc::clone(&self.captured_contexts)
49    }
50}
51
52impl StreamingModelProvider for FakeLlmProvider {
53    fn stream_response(&self, context: &Context) -> LlmResponseStream {
54        // Capture the context for later verification
55        if let Ok(mut contexts) = self.captured_contexts.lock() {
56            contexts.push(context.clone());
57        }
58
59        let current_call = self.call_count.fetch_add(1, Ordering::SeqCst);
60
61        let response = if current_call < self.responses.len() {
62            self.responses[current_call].clone()
63        } else {
64            vec![Ok(LlmResponse::done())]
65        };
66
67        Box::pin(tokio_stream::iter(response))
68    }
69
70    fn display_name(&self) -> String {
71        self.display_name.clone()
72    }
73
74    fn context_window(&self) -> Option<u32> {
75        self.context_window
76    }
77}