Skip to main content

llm_stack/
mock.rs

1//! Mock provider for testing.
2//!
3//! [`MockProvider`] is a queue-based fake that lets tests control
4//! exactly what responses and errors a provider returns, without
5//! touching the network. It implements [`Provider`],
6//! so it works anywhere a real provider does — including through
7//! [`DynProvider`](crate::DynProvider) via the blanket impl.
8//!
9//! # Usage
10//!
11//! ```rust,no_run
12//! use llm_stack::mock::{MockProvider, MockError};
13//! use llm_stack::{Provider, ChatParams, ChatResponse, ContentBlock, StopReason, Usage};
14//! use std::collections::{HashMap, HashSet};
15//!
16//! # async fn example() {
17//! let mock = MockProvider::new(llm_stack::ProviderMetadata {
18//!     name: "test".into(),
19//!     model: "test-model".into(),
20//!     context_window: 4096,
21//!     capabilities: HashSet::new(),
22//! });
23//!
24//! mock.queue_response(ChatResponse {
25//!     content: vec![ContentBlock::Text("Hello!".into())],
26//!     usage: Usage::default(),
27//!     stop_reason: StopReason::EndTurn,
28//!     model: "test-model".into(),
29//!     metadata: HashMap::new(),
30//! });
31//!
32//! let resp = mock.generate(&ChatParams::default()).await.unwrap();
33//! assert_eq!(mock.recorded_calls().len(), 1);
34//! # }
35//! ```
36//!
37//! # Why `MockError` instead of `LlmError`?
38//!
39//! [`LlmError`] contains `Box<dyn Error>` and is not
40//! `Clone`, so it can't be stored in a queue. [`MockError`] mirrors the
41//! common error variants in a cloneable form and converts to `LlmError`
42//! at dequeue time.
43
44use std::collections::VecDeque;
45use std::fmt;
46use std::sync::{Arc, Mutex};
47
48use crate::chat::ChatResponse;
49use crate::error::LlmError;
50use crate::provider::{ChatParams, Provider, ProviderMetadata};
51use crate::stream::{ChatStream, StreamEvent};
52
53/// A queue-based mock provider for unit and integration tests.
54///
55/// Push responses with [`queue_response`](Self::queue_response) and
56/// errors with [`queue_error`](Self::queue_error). Each call to
57/// `generate` or `stream` pops from the front of the respective queue.
58///
59/// Every call records its [`ChatParams`] for later assertion via
60/// [`recorded_calls`](Self::recorded_calls).
61///
62/// # Panics
63///
64/// [`generate`](Provider::generate) panics if the response queue is empty.
65/// [`stream`](Provider::stream) panics if the stream queue is empty.
66pub struct MockProvider {
67    responses: Mutex<VecDeque<Result<ChatResponse, MockError>>>,
68    stream_responses: Mutex<VecDeque<Result<Vec<StreamEvent>, MockError>>>,
69    meta: ProviderMetadata,
70    calls: Arc<Mutex<Vec<ChatParams>>>,
71}
72
73/// Cloneable error subset for mock queuing.
74///
75/// [`LlmError`] contains `Box<dyn Error>` and is not `Clone`, so it
76/// can't be queued directly. This type mirrors the common error
77/// variants. Use [`queue_error`](MockProvider::queue_error) to enqueue
78/// one — it is converted to `LlmError` when dequeued.
79#[derive(Debug, Clone)]
80pub enum MockError {
81    /// Maps to [`LlmError::Http`].
82    Http {
83        /// HTTP status code, if any.
84        status: Option<http::StatusCode>,
85        /// Error message.
86        message: String,
87        /// Whether the error is retryable.
88        retryable: bool,
89    },
90    /// Maps to [`LlmError::Auth`].
91    Auth(String),
92    /// Maps to [`LlmError::InvalidRequest`].
93    InvalidRequest(String),
94    /// Maps to [`LlmError::Provider`].
95    Provider {
96        /// Provider error code.
97        code: String,
98        /// Error message.
99        message: String,
100        /// Whether the error is retryable.
101        retryable: bool,
102    },
103    /// Maps to [`LlmError::Timeout`].
104    Timeout {
105        /// Elapsed milliseconds.
106        elapsed_ms: u64,
107    },
108    /// Maps to [`LlmError::ResponseFormat`].
109    ResponseFormat {
110        /// What went wrong during parsing.
111        message: String,
112        /// The raw response body.
113        raw: String,
114    },
115    /// Maps to [`LlmError::SchemaValidation`].
116    SchemaValidation {
117        /// Validation error messages.
118        message: String,
119        /// The schema that was violated.
120        schema: serde_json::Value,
121        /// The value that failed validation.
122        actual: serde_json::Value,
123    },
124    /// Maps to [`LlmError::RetryExhausted`].
125    RetryExhausted {
126        /// How many attempts were made.
127        attempts: u32,
128        /// Description of the last error.
129        last_error_message: String,
130    },
131}
132
133impl MockError {
134    fn into_llm_error(self) -> LlmError {
135        match self {
136            Self::Http {
137                status,
138                message,
139                retryable,
140            } => LlmError::Http {
141                status,
142                message,
143                retryable,
144            },
145            Self::Auth(msg) => LlmError::Auth(msg),
146            Self::InvalidRequest(msg) => LlmError::InvalidRequest(msg),
147            Self::Provider {
148                code,
149                message,
150                retryable,
151            } => LlmError::Provider {
152                code,
153                message,
154                retryable,
155            },
156            Self::Timeout { elapsed_ms } => LlmError::Timeout { elapsed_ms },
157            Self::ResponseFormat { message, raw } => LlmError::ResponseFormat { message, raw },
158            Self::SchemaValidation {
159                message,
160                schema,
161                actual,
162            } => LlmError::SchemaValidation {
163                message,
164                schema,
165                actual,
166            },
167            Self::RetryExhausted {
168                attempts,
169                last_error_message,
170            } => LlmError::RetryExhausted {
171                attempts,
172                last_error: Box::new(LlmError::InvalidRequest(last_error_message)),
173            },
174        }
175    }
176}
177
178impl fmt::Debug for MockProvider {
179    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180        let response_len = self.responses.lock().unwrap().len();
181        let stream_len = self.stream_responses.lock().unwrap().len();
182        let call_count = self.calls.lock().unwrap().len();
183        f.debug_struct("MockProvider")
184            .field("meta", &self.meta)
185            .field("queued_responses", &response_len)
186            .field("queued_streams", &stream_len)
187            .field("recorded_calls", &call_count)
188            .finish()
189    }
190}
191
192impl MockProvider {
193    /// Creates a new mock with the given metadata and empty queues.
194    pub fn new(meta: ProviderMetadata) -> Self {
195        Self {
196            responses: Mutex::new(VecDeque::new()),
197            stream_responses: Mutex::new(VecDeque::new()),
198            meta,
199            calls: Arc::new(Mutex::new(Vec::new())),
200        }
201    }
202
203    /// Enqueues a successful response for the next `generate` call.
204    pub fn queue_response(&self, response: ChatResponse) -> &Self {
205        self.responses.lock().unwrap().push_back(Ok(response));
206        self
207    }
208
209    /// Enqueues an error for the next `generate` call.
210    pub fn queue_error(&self, error: MockError) -> &Self {
211        self.responses.lock().unwrap().push_back(Err(error));
212        self
213    }
214
215    /// Enqueues stream events for the next `stream` call.
216    pub fn queue_stream(&self, events: Vec<StreamEvent>) -> &Self {
217        self.stream_responses.lock().unwrap().push_back(Ok(events));
218        self
219    }
220
221    /// Enqueues an error for the next `stream` call.
222    ///
223    /// The error is returned from `stream()` itself (before any events
224    /// are yielded), simulating failures like authentication errors or
225    /// network issues that prevent the stream from starting.
226    pub fn queue_stream_error(&self, error: MockError) -> &Self {
227        self.stream_responses.lock().unwrap().push_back(Err(error));
228        self
229    }
230
231    /// Returns a clone of all `ChatParams` passed to `generate` or
232    /// `stream`, in call order.
233    pub fn recorded_calls(&self) -> Vec<ChatParams> {
234        self.calls.lock().unwrap().clone()
235    }
236
237    fn record_call(&self, params: &ChatParams) {
238        self.calls.lock().unwrap().push(params.clone());
239    }
240}
241
242/// Convert a `ChatResponse` into equivalent `StreamEvent`s.
243///
244/// Used by `MockProvider::stream()` when the stream queue is empty but
245/// the response queue has entries — enables existing tests that use
246/// `queue_response` to work transparently with `stream_boxed()`.
247fn response_to_stream_events(response: &ChatResponse) -> Vec<StreamEvent> {
248    use crate::chat::ContentBlock;
249
250    let mut events = Vec::new();
251    let mut tool_index = 0u32;
252
253    for block in &response.content {
254        match block {
255            ContentBlock::Text(text) => {
256                events.push(StreamEvent::TextDelta(text.clone()));
257            }
258            ContentBlock::ToolCall(call) => {
259                events.push(StreamEvent::ToolCallStart {
260                    index: tool_index,
261                    id: call.id.clone(),
262                    name: call.name.clone(),
263                });
264                events.push(StreamEvent::ToolCallComplete {
265                    index: tool_index,
266                    call: call.clone(),
267                });
268                tool_index += 1;
269            }
270            _ => {}
271        }
272    }
273
274    events.push(StreamEvent::Usage(response.usage.clone()));
275
276    events.push(StreamEvent::Done {
277        stop_reason: response.stop_reason,
278    });
279
280    events
281}
282
283impl Provider for MockProvider {
284    async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
285        self.record_call(params);
286        let result = self
287            .responses
288            .lock()
289            .unwrap()
290            .pop_front()
291            .expect("MockProvider: no queued responses remaining");
292        result.map_err(MockError::into_llm_error)
293    }
294
295    async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
296        self.record_call(params);
297
298        // Try the stream queue first; fall back to the response queue
299        // (auto-converting a ChatResponse into StreamEvents).
300        if let Some(result) = self.stream_responses.lock().unwrap().pop_front() {
301            let events = result.map_err(MockError::into_llm_error)?;
302            let stream = futures::stream::iter(events.into_iter().map(Ok));
303            return Ok(Box::pin(stream));
304        }
305
306        let result = self
307            .responses
308            .lock()
309            .unwrap()
310            .pop_front()
311            .expect("MockProvider: no queued responses or stream responses remaining");
312        let response = result.map_err(MockError::into_llm_error)?;
313        let events = response_to_stream_events(&response);
314        let stream = futures::stream::iter(events.into_iter().map(Ok));
315        Ok(Box::pin(stream))
316    }
317
318    fn metadata(&self) -> ProviderMetadata {
319        self.meta.clone()
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ContentBlock, StopReason};
    use crate::provider::{Capability, DynProvider};
    use crate::test_helpers::sample_response;
    use futures::StreamExt;
    use std::collections::HashSet;

    /// Metadata shared by every mock built in these tests.
    fn test_metadata() -> ProviderMetadata {
        ProviderMetadata {
            name: "mock".into(),
            model: "test-model".into(),
            context_window: 128_000,
            capabilities: HashSet::from([Capability::Tools, Capability::StructuredOutput]),
        }
    }

    #[tokio::test]
    async fn test_mock_generate_returns_queued() {
        let mock = MockProvider::new(test_metadata());
        let resp = sample_response("test");
        mock.queue_response(resp.clone());

        let result = mock.generate(&ChatParams::default()).await.unwrap();
        assert_eq!(result, resp);
    }

    #[tokio::test]
    async fn test_mock_generate_multiple_queued() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("first"));
        mock.queue_response(sample_response("second"));
        mock.queue_response(sample_response("third"));

        let r1 = mock.generate(&ChatParams::default()).await.unwrap();
        let r2 = mock.generate(&ChatParams::default()).await.unwrap();
        let r3 = mock.generate(&ChatParams::default()).await.unwrap();

        assert_eq!(r1.content, vec![ContentBlock::Text("first".into())]);
        assert_eq!(r2.content, vec![ContentBlock::Text("second".into())]);
        assert_eq!(r3.content, vec![ContentBlock::Text("third".into())]);
    }

    #[tokio::test]
    async fn test_mock_generate_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Auth("bad key".into()));

        let result = mock.generate(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_generate_mixed_queue() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));
        mock.queue_error(MockError::Timeout { elapsed_ms: 5000 });
        mock.queue_response(sample_response("ok again"));

        let r1 = mock.generate(&ChatParams::default()).await;
        let r2 = mock.generate(&ChatParams::default()).await;
        let r3 = mock.generate(&ChatParams::default()).await;

        assert!(r1.is_ok());
        assert!(r2.is_err());
        assert!(r3.is_ok());
    }

    #[tokio::test]
    #[should_panic(expected = "no queued responses")]
    async fn test_mock_generate_empty_queue_panics() {
        let mock = MockProvider::new(test_metadata());
        let _ = mock.generate(&ChatParams::default()).await;
    }

    #[tokio::test]
    async fn test_mock_stream_returns_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hello".into()),
            StreamEvent::TextDelta(" world".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 3);
        assert!(events.iter().all(Result::is_ok));
    }

    #[tokio::test]
    async fn test_mock_stream_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream_error(MockError::Auth("bad token".into()));

        let result = mock.stream(&ChatParams::default()).await;
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_stream_empty_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(events.is_empty());
    }

    // --- stream fallback to the response queue ---

    #[tokio::test]
    async fn test_mock_stream_falls_back_to_response_queue() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("fallback"));

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;

        // One text block converts to: TextDelta, Usage, Done.
        assert_eq!(events.len(), 3);
        assert!(matches!(&events[0], Ok(StreamEvent::TextDelta(t)) if t == "fallback"));
        assert!(matches!(&events[1], Ok(StreamEvent::Usage(_))));
        assert!(matches!(&events[2], Ok(StreamEvent::Done { .. })));
        assert_eq!(mock.recorded_calls().len(), 1);
    }

    #[tokio::test]
    async fn test_mock_stream_prefers_stream_queue() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("from response queue"));
        mock.queue_stream(vec![StreamEvent::TextDelta("from stream queue".into())]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 1);

        // The response queue was left untouched and still serves `generate`.
        let resp = mock.generate(&ChatParams::default()).await.unwrap();
        assert_eq!(
            resp.content,
            vec![ContentBlock::Text("from response queue".into())]
        );
    }

    #[tokio::test]
    async fn test_mock_stream_fallback_propagates_queued_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Timeout { elapsed_ms: 10 });

        let result = mock.stream(&ChatParams::default()).await;
        assert!(matches!(result.err(), Some(LlmError::Timeout { .. })));
    }

    #[tokio::test]
    #[should_panic(expected = "no queued responses or stream responses remaining")]
    async fn test_mock_stream_empty_queues_panics() {
        let mock = MockProvider::new(test_metadata());
        let _ = mock.stream(&ChatParams::default()).await;
    }

    #[tokio::test]
    async fn test_mock_records_calls() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_response(sample_response("b"));
        mock.queue_response(sample_response("c"));

        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;

        assert_eq!(mock.recorded_calls().len(), 3);
    }

    #[tokio::test]
    async fn test_mock_records_params_accurately() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));

        let params = ChatParams {
            temperature: Some(0.5),
            system: Some("be nice".into()),
            ..Default::default()
        };
        let _ = mock.generate(&params).await;

        let recorded = mock.recorded_calls();
        assert_eq!(recorded[0].temperature, Some(0.5));
        assert_eq!(recorded[0].system, Some("be nice".into()));
    }

    #[test]
    fn test_mock_metadata_returns_configured() {
        let meta = test_metadata();
        let mock = MockProvider::new(meta.clone());
        assert_eq!(Provider::metadata(&mock), meta);
    }

    #[tokio::test]
    async fn test_mock_concurrent_access() {
        let mock = Arc::new(MockProvider::new(test_metadata()));
        for _ in 0..10 {
            mock.queue_response(sample_response("ok"));
        }

        let mut handles = Vec::new();
        for _ in 0..10 {
            let m = mock.clone();
            handles.push(tokio::spawn(async move {
                m.generate(&ChatParams::default()).await.unwrap()
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(mock.recorded_calls().len(), 10);
    }

    // --- DynProvider tests (through mock) ---

    #[tokio::test]
    async fn test_dyn_provider_blanket_impl() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("hello"));

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let result = dyn_provider.generate_boxed(&params).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dyn_provider_error_propagation() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Http {
            status: Some(http::StatusCode::TOO_MANY_REQUESTS),
            message: "rate limited".into(),
            retryable: true,
        });

        let dyn_provider: &dyn DynProvider = &mock;
        let result = dyn_provider.generate_boxed(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Http { .. }));
    }

    #[tokio::test]
    async fn test_dyn_provider_stream_blanket() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hi".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let stream = dyn_provider.stream_boxed(&params).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 2);
    }

    #[tokio::test]
    async fn test_dyn_provider_metadata_matches() {
        let mock = MockProvider::new(test_metadata());
        let dyn_provider: &dyn DynProvider = &mock;
        assert_eq!(Provider::metadata(&mock), dyn_provider.metadata());
    }

    #[tokio::test]
    async fn test_dyn_provider_boxed_storage() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("from box"));

        let boxed: Box<dyn DynProvider> = Box::new(mock);
        let result = boxed.generate_boxed(&ChatParams::default()).await.unwrap();
        assert_eq!(result.content, vec![ContentBlock::Text("from box".into())]);
    }

    #[test]
    fn test_mock_provider_debug() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_stream(vec![StreamEvent::TextDelta("hi".into())]);

        let debug = format!("{mock:?}");
        assert!(debug.contains("MockProvider"));
        assert!(debug.contains("queued_responses: 1"));
        assert!(debug.contains("queued_streams: 1"));
        assert!(debug.contains("recorded_calls: 0"));
    }

    #[test]
    fn test_provider_is_object_safe() {
        let f1: fn(&dyn DynProvider) = |_| {};
        let f2: fn(Box<dyn DynProvider>) = |_| {};
        // Suppress unused variable warnings
        let _ = (f1, f2);
    }

    #[tokio::test]
    async fn test_mock_error_into_llm_error_all_variants() {
        let variants: Vec<(MockError, &str)> = vec![
            (MockError::InvalidRequest("bad".into()), "InvalidRequest"),
            (
                MockError::Provider {
                    code: "e1".into(),
                    message: "fail".into(),
                    retryable: false,
                },
                "Provider",
            ),
            (
                MockError::ResponseFormat {
                    message: "bad json".into(),
                    raw: "{}".into(),
                },
                "ResponseFormat",
            ),
            (
                MockError::SchemaValidation {
                    message: "missing field".into(),
                    schema: serde_json::json!({"type": "object"}),
                    actual: serde_json::json!(42),
                },
                "SchemaValidation",
            ),
            (
                MockError::RetryExhausted {
                    attempts: 3,
                    last_error_message: "timed out".into(),
                },
                "RetryExhausted",
            ),
        ];

        for (mock_err, label) in variants {
            let mock = MockProvider::new(test_metadata());
            mock.queue_error(mock_err);
            let result = mock.generate(&ChatParams::default()).await;
            assert!(result.is_err(), "{label} should produce error");
            let err = result.unwrap_err();
            let debug = format!("{err:?}");
            assert!(
                debug.contains(label),
                "expected {label} in error debug: {debug}"
            );
        }
    }
}