Skip to main content

llm_stack/
mock.rs

1//! Mock provider for testing.
2//!
3//! [`MockProvider`] is a queue-based fake that lets tests control
4//! exactly what responses and errors a provider returns, without
5//! touching the network. It implements [`Provider`],
6//! so it works anywhere a real provider does — including through
7//! [`DynProvider`](crate::DynProvider) via the blanket impl.
8//!
9//! # Usage
10//!
11//! ```rust,no_run
12//! use llm_stack::mock::{MockProvider, MockError};
13//! use llm_stack::{Provider, ChatParams, ChatResponse, ContentBlock, StopReason, Usage};
14//! use std::collections::{HashMap, HashSet};
15//!
16//! # async fn example() {
17//! let mock = MockProvider::new(llm_stack::ProviderMetadata {
18//!     name: "test".into(),
19//!     model: "test-model".into(),
20//!     context_window: 4096,
21//!     capabilities: HashSet::new(),
22//! });
23//!
24//! mock.queue_response(ChatResponse {
25//!     content: vec![ContentBlock::Text("Hello!".into())],
26//!     usage: Usage::default(),
27//!     stop_reason: StopReason::EndTurn,
28//!     model: "test-model".into(),
29//!     metadata: HashMap::new(),
30//! });
31//!
32//! let resp = mock.generate(&ChatParams::default()).await.unwrap();
33//! assert_eq!(mock.recorded_calls().len(), 1);
34//! # }
35//! ```
36//!
37//! # Why `MockError` instead of `LlmError`?
38//!
39//! [`LlmError`] contains `Box<dyn Error>` and is not
40//! `Clone`, so it can't be stored in a queue. [`MockError`] mirrors the
41//! common error variants in a cloneable form and converts to `LlmError`
42//! at dequeue time.
43
44use std::collections::VecDeque;
45use std::fmt;
46use std::sync::{Arc, Mutex};
47
48use crate::chat::ChatResponse;
49use crate::error::LlmError;
50use crate::provider::{ChatParams, Provider, ProviderMetadata};
51use crate::stream::{ChatStream, StreamEvent};
52
53/// A queue-based mock provider for unit and integration tests.
54///
55/// Push responses with [`queue_response`](Self::queue_response) and
56/// errors with [`queue_error`](Self::queue_error). Each call to
57/// `generate` or `stream` pops from the front of the respective queue.
58///
59/// Every call records its [`ChatParams`] for later assertion via
60/// [`recorded_calls`](Self::recorded_calls).
61///
62/// # Panics
63///
64/// [`generate`](Provider::generate) panics if the response queue is empty.
65/// [`stream`](Provider::stream) panics if the stream queue is empty.
66pub struct MockProvider {
67    responses: Mutex<VecDeque<Result<ChatResponse, MockError>>>,
68    stream_responses: Mutex<VecDeque<Result<Vec<StreamEvent>, MockError>>>,
69    meta: ProviderMetadata,
70    calls: Arc<Mutex<Vec<ChatParams>>>,
71}
72
73/// Cloneable error subset for mock queuing.
74///
75/// [`LlmError`] contains `Box<dyn Error>` and is not `Clone`, so it
76/// can't be queued directly. This type mirrors the common error
77/// variants. Use [`queue_error`](MockProvider::queue_error) to enqueue
78/// one — it is converted to `LlmError` when dequeued.
79#[derive(Debug, Clone)]
80pub enum MockError {
81    /// Maps to [`LlmError::Http`].
82    Http {
83        /// HTTP status code, if any.
84        status: Option<http::StatusCode>,
85        /// Error message.
86        message: String,
87        /// Whether the error is retryable.
88        retryable: bool,
89    },
90    /// Maps to [`LlmError::Auth`].
91    Auth(String),
92    /// Maps to [`LlmError::InvalidRequest`].
93    InvalidRequest(String),
94    /// Maps to [`LlmError::Provider`].
95    Provider {
96        /// Provider error code.
97        code: String,
98        /// Error message.
99        message: String,
100        /// Whether the error is retryable.
101        retryable: bool,
102    },
103    /// Maps to [`LlmError::Timeout`].
104    Timeout {
105        /// Elapsed milliseconds.
106        elapsed_ms: u64,
107    },
108    /// Maps to [`LlmError::ResponseFormat`].
109    ResponseFormat {
110        /// What went wrong during parsing.
111        message: String,
112        /// The raw response body.
113        raw: String,
114    },
115    /// Maps to [`LlmError::SchemaValidation`].
116    SchemaValidation {
117        /// Validation error messages.
118        message: String,
119        /// The schema that was violated.
120        schema: serde_json::Value,
121        /// The value that failed validation.
122        actual: serde_json::Value,
123    },
124    /// Maps to [`LlmError::RetryExhausted`].
125    RetryExhausted {
126        /// How many attempts were made.
127        attempts: u32,
128        /// Description of the last error.
129        last_error_message: String,
130    },
131}
132
133impl MockError {
134    fn into_llm_error(self) -> LlmError {
135        match self {
136            Self::Http {
137                status,
138                message,
139                retryable,
140            } => LlmError::Http {
141                status,
142                message,
143                retryable,
144            },
145            Self::Auth(msg) => LlmError::Auth(msg),
146            Self::InvalidRequest(msg) => LlmError::InvalidRequest(msg),
147            Self::Provider {
148                code,
149                message,
150                retryable,
151            } => LlmError::Provider {
152                code,
153                message,
154                retryable,
155            },
156            Self::Timeout { elapsed_ms } => LlmError::Timeout { elapsed_ms },
157            Self::ResponseFormat { message, raw } => LlmError::ResponseFormat { message, raw },
158            Self::SchemaValidation {
159                message,
160                schema,
161                actual,
162            } => LlmError::SchemaValidation {
163                message,
164                schema,
165                actual,
166            },
167            Self::RetryExhausted {
168                attempts,
169                last_error_message,
170            } => LlmError::RetryExhausted {
171                attempts,
172                last_error: Box::new(LlmError::InvalidRequest(last_error_message)),
173            },
174        }
175    }
176}
177
178impl fmt::Debug for MockProvider {
179    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180        let response_len = self.responses.lock().unwrap().len();
181        let stream_len = self.stream_responses.lock().unwrap().len();
182        let call_count = self.calls.lock().unwrap().len();
183        f.debug_struct("MockProvider")
184            .field("meta", &self.meta)
185            .field("queued_responses", &response_len)
186            .field("queued_streams", &stream_len)
187            .field("recorded_calls", &call_count)
188            .finish()
189    }
190}
191
192impl MockProvider {
193    /// Creates a new mock with the given metadata and empty queues.
194    pub fn new(meta: ProviderMetadata) -> Self {
195        Self {
196            responses: Mutex::new(VecDeque::new()),
197            stream_responses: Mutex::new(VecDeque::new()),
198            meta,
199            calls: Arc::new(Mutex::new(Vec::new())),
200        }
201    }
202
203    /// Enqueues a successful response for the next `generate` call.
204    pub fn queue_response(&self, response: ChatResponse) -> &Self {
205        self.responses.lock().unwrap().push_back(Ok(response));
206        self
207    }
208
209    /// Enqueues an error for the next `generate` call.
210    pub fn queue_error(&self, error: MockError) -> &Self {
211        self.responses.lock().unwrap().push_back(Err(error));
212        self
213    }
214
215    /// Enqueues stream events for the next `stream` call.
216    pub fn queue_stream(&self, events: Vec<StreamEvent>) -> &Self {
217        self.stream_responses.lock().unwrap().push_back(Ok(events));
218        self
219    }
220
221    /// Enqueues an error for the next `stream` call.
222    ///
223    /// The error is returned from `stream()` itself (before any events
224    /// are yielded), simulating failures like authentication errors or
225    /// network issues that prevent the stream from starting.
226    pub fn queue_stream_error(&self, error: MockError) -> &Self {
227        self.stream_responses.lock().unwrap().push_back(Err(error));
228        self
229    }
230
231    /// Returns a clone of all `ChatParams` passed to `generate` or
232    /// `stream`, in call order.
233    pub fn recorded_calls(&self) -> Vec<ChatParams> {
234        self.calls.lock().unwrap().clone()
235    }
236
237    fn record_call(&self, params: &ChatParams) {
238        self.calls.lock().unwrap().push(params.clone());
239    }
240}
241
242impl Provider for MockProvider {
243    async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
244        self.record_call(params);
245        let result = self
246            .responses
247            .lock()
248            .unwrap()
249            .pop_front()
250            .expect("MockProvider: no queued responses remaining");
251        result.map_err(MockError::into_llm_error)
252    }
253
254    async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
255        self.record_call(params);
256        let result = self
257            .stream_responses
258            .lock()
259            .unwrap()
260            .pop_front()
261            .expect("MockProvider: no queued stream responses remaining");
262        let events = result.map_err(MockError::into_llm_error)?;
263        let stream = futures::stream::iter(events.into_iter().map(Ok));
264        Ok(Box::pin(stream))
265    }
266
267    fn metadata(&self) -> ProviderMetadata {
268        self.meta.clone()
269    }
270}
271
272#[cfg(test)]
273mod tests {
274    use super::*;
275    use crate::chat::{ContentBlock, StopReason};
276    use crate::provider::{Capability, DynProvider};
277    use crate::test_helpers::sample_response;
278    use futures::StreamExt;
279    use std::collections::HashSet;
280
281    fn test_metadata() -> ProviderMetadata {
282        ProviderMetadata {
283            name: "mock".into(),
284            model: "test-model".into(),
285            context_window: 128_000,
286            capabilities: HashSet::from([Capability::Tools, Capability::StructuredOutput]),
287        }
288    }
289
290    #[tokio::test]
291    async fn test_mock_generate_returns_queued() {
292        let mock = MockProvider::new(test_metadata());
293        let resp = sample_response("test");
294        mock.queue_response(resp.clone());
295
296        let result = mock.generate(&ChatParams::default()).await.unwrap();
297        assert_eq!(result, resp);
298    }
299
300    #[tokio::test]
301    async fn test_mock_generate_multiple_queued() {
302        let mock = MockProvider::new(test_metadata());
303        mock.queue_response(sample_response("first"));
304        mock.queue_response(sample_response("second"));
305        mock.queue_response(sample_response("third"));
306
307        let r1 = mock.generate(&ChatParams::default()).await.unwrap();
308        let r2 = mock.generate(&ChatParams::default()).await.unwrap();
309        let r3 = mock.generate(&ChatParams::default()).await.unwrap();
310
311        assert_eq!(r1.content, vec![ContentBlock::Text("first".into())]);
312        assert_eq!(r2.content, vec![ContentBlock::Text("second".into())]);
313        assert_eq!(r3.content, vec![ContentBlock::Text("third".into())]);
314    }
315
316    #[tokio::test]
317    async fn test_mock_generate_error() {
318        let mock = MockProvider::new(test_metadata());
319        mock.queue_error(MockError::Auth("bad key".into()));
320
321        let result = mock.generate(&ChatParams::default()).await;
322        assert!(result.is_err());
323        assert!(matches!(result.unwrap_err(), LlmError::Auth(_)));
324    }
325
326    #[tokio::test]
327    async fn test_mock_generate_mixed_queue() {
328        let mock = MockProvider::new(test_metadata());
329        mock.queue_response(sample_response("ok"));
330        mock.queue_error(MockError::Timeout { elapsed_ms: 5000 });
331        mock.queue_response(sample_response("ok again"));
332
333        let r1 = mock.generate(&ChatParams::default()).await;
334        let r2 = mock.generate(&ChatParams::default()).await;
335        let r3 = mock.generate(&ChatParams::default()).await;
336
337        assert!(r1.is_ok());
338        assert!(r2.is_err());
339        assert!(r3.is_ok());
340    }
341
342    #[tokio::test]
343    #[should_panic(expected = "no queued responses")]
344    async fn test_mock_generate_empty_queue_panics() {
345        let mock = MockProvider::new(test_metadata());
346        let _ = mock.generate(&ChatParams::default()).await;
347    }
348
349    #[tokio::test]
350    async fn test_mock_stream_returns_events() {
351        let mock = MockProvider::new(test_metadata());
352        mock.queue_stream(vec![
353            StreamEvent::TextDelta("hello".into()),
354            StreamEvent::TextDelta(" world".into()),
355            StreamEvent::Done {
356                stop_reason: StopReason::EndTurn,
357            },
358        ]);
359
360        let stream = mock.stream(&ChatParams::default()).await.unwrap();
361        let events: Vec<_> = stream.collect().await;
362        assert_eq!(events.len(), 3);
363        assert!(events.iter().all(Result::is_ok));
364    }
365
366    #[tokio::test]
367    async fn test_mock_stream_error() {
368        let mock = MockProvider::new(test_metadata());
369        mock.queue_stream_error(MockError::Auth("bad token".into()));
370
371        let result = mock.stream(&ChatParams::default()).await;
372        assert!(result.is_err());
373        let err = result.err().unwrap();
374        assert!(matches!(err, LlmError::Auth(_)));
375    }
376
377    #[tokio::test]
378    async fn test_mock_stream_empty_events() {
379        let mock = MockProvider::new(test_metadata());
380        mock.queue_stream(vec![]);
381
382        let stream = mock.stream(&ChatParams::default()).await.unwrap();
383        let events: Vec<_> = stream.collect().await;
384        assert!(events.is_empty());
385    }
386
387    #[tokio::test]
388    async fn test_mock_records_calls() {
389        let mock = MockProvider::new(test_metadata());
390        mock.queue_response(sample_response("a"));
391        mock.queue_response(sample_response("b"));
392        mock.queue_response(sample_response("c"));
393
394        let _ = mock.generate(&ChatParams::default()).await;
395        let _ = mock.generate(&ChatParams::default()).await;
396        let _ = mock.generate(&ChatParams::default()).await;
397
398        assert_eq!(mock.recorded_calls().len(), 3);
399    }
400
401    #[tokio::test]
402    async fn test_mock_records_params_accurately() {
403        let mock = MockProvider::new(test_metadata());
404        mock.queue_response(sample_response("ok"));
405
406        let params = ChatParams {
407            temperature: Some(0.5),
408            system: Some("be nice".into()),
409            ..Default::default()
410        };
411        let _ = mock.generate(&params).await;
412
413        let recorded = mock.recorded_calls();
414        assert_eq!(recorded[0].temperature, Some(0.5));
415        assert_eq!(recorded[0].system, Some("be nice".into()));
416    }
417
418    #[test]
419    fn test_mock_metadata_returns_configured() {
420        let meta = test_metadata();
421        let mock = MockProvider::new(meta.clone());
422        assert_eq!(Provider::metadata(&mock), meta);
423    }
424
425    #[tokio::test]
426    async fn test_mock_concurrent_access() {
427        let mock = Arc::new(MockProvider::new(test_metadata()));
428        for _ in 0..10 {
429            mock.queue_response(sample_response("ok"));
430        }
431
432        let mut handles = Vec::new();
433        for _ in 0..10 {
434            let m = mock.clone();
435            handles.push(tokio::spawn(async move {
436                m.generate(&ChatParams::default()).await.unwrap()
437            }));
438        }
439
440        for h in handles {
441            h.await.unwrap();
442        }
443
444        assert_eq!(mock.recorded_calls().len(), 10);
445    }
446
447    // --- DynProvider tests (through mock) ---
448
449    #[tokio::test]
450    async fn test_dyn_provider_blanket_impl() {
451        let mock = MockProvider::new(test_metadata());
452        mock.queue_response(sample_response("hello"));
453
454        let dyn_provider: &dyn DynProvider = &mock;
455        let params = ChatParams::default();
456        let result = dyn_provider.generate_boxed(&params).await;
457        assert!(result.is_ok());
458    }
459
460    #[tokio::test]
461    async fn test_dyn_provider_error_propagation() {
462        let mock = MockProvider::new(test_metadata());
463        mock.queue_error(MockError::Http {
464            status: Some(http::StatusCode::TOO_MANY_REQUESTS),
465            message: "rate limited".into(),
466            retryable: true,
467        });
468
469        let dyn_provider: &dyn DynProvider = &mock;
470        let result = dyn_provider.generate_boxed(&ChatParams::default()).await;
471        assert!(result.is_err());
472        assert!(matches!(result.unwrap_err(), LlmError::Http { .. }));
473    }
474
475    #[tokio::test]
476    async fn test_dyn_provider_stream_blanket() {
477        let mock = MockProvider::new(test_metadata());
478        mock.queue_stream(vec![
479            StreamEvent::TextDelta("hi".into()),
480            StreamEvent::Done {
481                stop_reason: StopReason::EndTurn,
482            },
483        ]);
484
485        let dyn_provider: &dyn DynProvider = &mock;
486        let params = ChatParams::default();
487        let stream = dyn_provider.stream_boxed(&params).await.unwrap();
488        let events: Vec<_> = stream.collect().await;
489        assert_eq!(events.len(), 2);
490    }
491
492    #[tokio::test]
493    async fn test_dyn_provider_metadata_matches() {
494        let mock = MockProvider::new(test_metadata());
495        let dyn_provider: &dyn DynProvider = &mock;
496        assert_eq!(Provider::metadata(&mock), dyn_provider.metadata());
497    }
498
499    #[tokio::test]
500    async fn test_dyn_provider_boxed_storage() {
501        let mock = MockProvider::new(test_metadata());
502        mock.queue_response(sample_response("from box"));
503
504        let boxed: Box<dyn DynProvider> = Box::new(mock);
505        let result = boxed.generate_boxed(&ChatParams::default()).await.unwrap();
506        assert_eq!(result.content, vec![ContentBlock::Text("from box".into())]);
507    }
508
509    #[test]
510    fn test_mock_provider_debug() {
511        let mock = MockProvider::new(test_metadata());
512        mock.queue_response(sample_response("a"));
513        mock.queue_stream(vec![StreamEvent::TextDelta("hi".into())]);
514
515        let debug = format!("{mock:?}");
516        assert!(debug.contains("MockProvider"));
517        assert!(debug.contains("queued_responses: 1"));
518        assert!(debug.contains("queued_streams: 1"));
519        assert!(debug.contains("recorded_calls: 0"));
520    }
521
522    #[test]
523    fn test_provider_is_object_safe() {
524        let f1: fn(&dyn DynProvider) = |_| {};
525        let f2: fn(Box<dyn DynProvider>) = |_| {};
526        // Suppress unused variable warnings
527        let _ = (f1, f2);
528    }
529
530    #[tokio::test]
531    async fn test_mock_error_into_llm_error_all_variants() {
532        let variants: Vec<(MockError, &str)> = vec![
533            (MockError::InvalidRequest("bad".into()), "InvalidRequest"),
534            (
535                MockError::Provider {
536                    code: "e1".into(),
537                    message: "fail".into(),
538                    retryable: false,
539                },
540                "Provider",
541            ),
542            (
543                MockError::ResponseFormat {
544                    message: "bad json".into(),
545                    raw: "{}".into(),
546                },
547                "ResponseFormat",
548            ),
549            (
550                MockError::SchemaValidation {
551                    message: "missing field".into(),
552                    schema: serde_json::json!({"type": "object"}),
553                    actual: serde_json::json!(42),
554                },
555                "SchemaValidation",
556            ),
557            (
558                MockError::RetryExhausted {
559                    attempts: 3,
560                    last_error_message: "timed out".into(),
561                },
562                "RetryExhausted",
563            ),
564        ];
565
566        for (mock_err, label) in variants {
567            let mock = MockProvider::new(test_metadata());
568            mock.queue_error(mock_err);
569            let result = mock.generate(&ChatParams::default()).await;
570            assert!(result.is_err(), "{label} should produce error");
571            let err = result.unwrap_err();
572            let debug = format!("{err:?}");
573            assert!(
574                debug.contains(label),
575                "expected {label} in error debug: {debug}"
576            );
577        }
578    }
579}