use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};

use crate::{Context, LlmError, LlmResponse, LlmResponseStream, StreamingModelProvider};

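/// A scripted test double for [`StreamingModelProvider`].
///
/// Each inner `Vec` holds the chunk sequence streamed for one call to
/// `stream_response`; calls beyond the script stream a single
/// `LlmResponse::done()`. Every incoming [`Context`] is captured for later
/// assertions.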
pub struct FakeLlmProvider {
    responses: Vec<Vec<Result<LlmResponse, LlmError>>>,
    call_count: AtomicUsize,
    captured_contexts: Arc<Mutex<Vec<Context>>>,
    display_name: String,
    context_window: Option<u32>,
}

impl FakeLlmProvider {
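    /// Builds a provider from per-call chunk sequences, wrapping each chunk
    /// in `Ok` before delegating to [`Self::from_results`].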
    pub fn new(responses: Vec<Vec<LlmResponse>>) -> Self {
        let wrapped = responses
            .into_iter()
            .map(|turn| turn.into_iter().map(Ok).collect())
            .collect();
        Self::from_results(wrapped)
    }

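    /// Convenience constructor for a provider scripted with a single call's
    /// worth of chunks.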
    pub fn with_single_response(chunks: Vec<LlmResponse>) -> Self {
        Self::new(vec![chunks])
    }

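    /// Builds a provider from per-call results, which lets a script inject
    /// an `Err(LlmError)` mid-stream.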
    pub fn from_results(responses: Vec<Vec<Result<LlmResponse, LlmError>>>) -> Self {
        Self {
            responses,
            call_count: AtomicUsize::new(0),
            captured_contexts: Arc::new(Mutex::new(Vec::new())),
            display_name: "Fake LLM".to_string(),
            context_window: None,
        }
    }

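    /// Overrides the default `"Fake LLM"` display name.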
    pub fn with_display_name(mut self, name: &str) -> Self {
        self.display_name = name.to_string();
        self
    }

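    /// Sets the advertised context window (`None` means unreported).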
    pub fn with_context_window(mut self, window: Option<u32>) -> Self {
        self.context_window = window;
        self
    }

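    /// Returns a shared handle to the captured contexts; lock it in tests to
    /// assert on what the provider was called with.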
    pub fn captured_contexts(&self) -> Arc<Mutex<Vec<Context>>> {
        Arc::clone(&self.captured_contexts)
    }
}

impl StreamingModelProvider for FakeLlmProvider {
    fn stream_response(&self, context: &Context) -> LlmResponseStream {
        // Record the incoming context; a poisoned lock is silently ignored,
        // which is acceptable in a test double.
        if let Ok(mut contexts) = self.captured_contexts.lock() {
            contexts.push(context.clone());
        }

        let current_call = self.call_count.fetch_add(1, Ordering::SeqCst);

        // Replay the scripted chunks for this call, or fall back to a bare
        // `done()` once the script is exhausted.
        let response = if current_call < self.responses.len() {
            self.responses[current_call].clone()
        } else {
            vec![Ok(LlmResponse::done())]
        };

        Box::pin(tokio_stream::iter(response))
    }

    fn display_name(&self) -> String {
        self.display_name.clone()
    }

    fn context_window(&self) -> Option<u32> {
        self.context_window
    }
}
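
// A minimal usage sketch, not part of the provider itself. It assumes
// `Context` implements `Default` (hypothetical; substitute the crate's real
// constructor) and that a `tokio` dev-dependency with the `macros` and `rt`
// features is available alongside `tokio-stream`.
#[cfg(test)]
mod fake_llm_provider_example {
    use super::*;
    use tokio_stream::StreamExt;

    #[tokio::test]
    async fn replays_script_then_falls_back_to_done() {
        let provider = FakeLlmProvider::with_single_response(vec![LlmResponse::done()])
            .with_display_name("scripted");

        // First call replays the one scripted turn.
        let first: Vec<_> = provider
            .stream_response(&Context::default()) // `Context::default()` is an assumption
            .collect()
            .await;
        assert_eq!(first.len(), 1);

        // The script is exhausted, so the second call streams a lone `done()`.
        let second: Vec<_> = provider.stream_response(&Context::default()).collect().await;
        assert_eq!(second.len(), 1);

        // Both contexts were captured, in call order.
        assert_eq!(provider.captured_contexts().lock().unwrap().len(), 2);
    }
}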