Skip to main content

llm_stack/
mock.rs

//! Mock provider for testing.
//!
//! [`MockProvider`] is a queue-based fake that lets tests control
//! exactly what responses and errors a provider returns, without
//! touching the network. It implements [`Provider`],
//! so it works anywhere a real provider does — including through
//! [`DynProvider`](crate::DynProvider) via the blanket impl.
//!
//! # Usage
//!
//! ```rust,no_run
//! use llm_stack::mock::{MockProvider, MockError};
//! use llm_stack::{Provider, ChatParams, ChatResponse, ContentBlock};
//! use llm_stack::chat::StopReason;
//! use llm_stack::usage::Usage;
//! use std::collections::{HashMap, HashSet};
//!
//! # async fn example() {
//! let mock = MockProvider::new(llm_stack::provider::ProviderMetadata {
//!     name: "test".into(),
//!     model: "test-model".into(),
//!     context_window: 4096,
//!     capabilities: HashSet::new(),
//! });
//!
//! mock.queue_response(ChatResponse {
//!     content: vec![ContentBlock::Text("Hello!".into())],
//!     usage: Usage::default(),
//!     stop_reason: StopReason::EndTurn,
//!     model: "test-model".into(),
//!     metadata: HashMap::new(),
//! });
//!
//! let resp = mock.generate(&ChatParams::default()).await.unwrap();
//! assert_eq!(mock.recorded_calls().len(), 1);
//! # }
//! ```
//!
//! # Why `MockError` instead of `LlmError`?
//!
//! [`LlmError`] contains `Box<dyn Error>` and is not
//! `Clone`, so it can't be stored in a queue. [`MockError`] mirrors the
//! common error variants in a cloneable form and converts to `LlmError`
//! at dequeue time.
45
46use std::collections::VecDeque;
47use std::fmt;
48use std::sync::{Arc, Mutex};
49
50use crate::chat::ChatResponse;
51use crate::error::LlmError;
52use crate::provider::{ChatParams, Provider, ProviderMetadata};
53use crate::stream::{ChatStream, StreamEvent};
54
55/// A queue-based mock provider for unit and integration tests.
56///
57/// Push responses with [`queue_response`](Self::queue_response) and
58/// errors with [`queue_error`](Self::queue_error). Each call to
59/// `generate` or `stream` pops from the front of the respective queue.
60///
61/// Every call records its [`ChatParams`] for later assertion via
62/// [`recorded_calls`](Self::recorded_calls).
63///
64/// # Panics
65///
66/// [`generate`](Provider::generate) panics if the response queue is empty.
67/// [`stream`](Provider::stream) panics if the stream queue is empty.
68pub struct MockProvider {
69    responses: Mutex<VecDeque<Result<ChatResponse, MockError>>>,
70    stream_responses: Mutex<VecDeque<Result<Vec<StreamEvent>, MockError>>>,
71    meta: ProviderMetadata,
72    calls: Arc<Mutex<Vec<ChatParams>>>,
73}
74
75/// Cloneable error subset for mock queuing.
76///
77/// [`LlmError`] contains `Box<dyn Error>` and is not `Clone`, so it
78/// can't be queued directly. This type mirrors the common error
79/// variants. Use [`queue_error`](MockProvider::queue_error) to enqueue
80/// one — it is converted to `LlmError` when dequeued.
81#[derive(Debug, Clone)]
82pub enum MockError {
83    /// Maps to [`LlmError::Http`].
84    Http {
85        /// HTTP status code, if any.
86        status: Option<http::StatusCode>,
87        /// Error message.
88        message: String,
89        /// Whether the error is retryable.
90        retryable: bool,
91    },
92    /// Maps to [`LlmError::Auth`].
93    Auth(String),
94    /// Maps to [`LlmError::InvalidRequest`].
95    InvalidRequest(String),
96    /// Maps to [`LlmError::Provider`].
97    Provider {
98        /// Provider error code.
99        code: String,
100        /// Error message.
101        message: String,
102        /// Whether the error is retryable.
103        retryable: bool,
104    },
105    /// Maps to [`LlmError::Timeout`].
106    Timeout {
107        /// Elapsed milliseconds.
108        elapsed_ms: u64,
109    },
110    /// Maps to [`LlmError::ResponseFormat`].
111    ResponseFormat {
112        /// What went wrong during parsing.
113        message: String,
114        /// The raw response body.
115        raw: String,
116    },
117    /// Maps to [`LlmError::SchemaValidation`].
118    SchemaValidation {
119        /// Validation error messages.
120        message: String,
121        /// The schema that was violated.
122        schema: serde_json::Value,
123        /// The value that failed validation.
124        actual: serde_json::Value,
125    },
126    /// Maps to [`LlmError::RetryExhausted`].
127    RetryExhausted {
128        /// How many attempts were made.
129        attempts: u32,
130        /// Description of the last error.
131        last_error_message: String,
132    },
133}
134
135impl MockError {
136    fn into_llm_error(self) -> LlmError {
137        match self {
138            Self::Http {
139                status,
140                message,
141                retryable,
142            } => LlmError::Http {
143                status,
144                message,
145                retryable,
146            },
147            Self::Auth(msg) => LlmError::Auth(msg),
148            Self::InvalidRequest(msg) => LlmError::InvalidRequest(msg),
149            Self::Provider {
150                code,
151                message,
152                retryable,
153            } => LlmError::Provider {
154                code,
155                message,
156                retryable,
157            },
158            Self::Timeout { elapsed_ms } => LlmError::Timeout { elapsed_ms },
159            Self::ResponseFormat { message, raw } => LlmError::ResponseFormat { message, raw },
160            Self::SchemaValidation {
161                message,
162                schema,
163                actual,
164            } => LlmError::SchemaValidation {
165                message,
166                schema,
167                actual,
168            },
169            Self::RetryExhausted {
170                attempts,
171                last_error_message,
172            } => LlmError::RetryExhausted {
173                attempts,
174                last_error: Box::new(LlmError::InvalidRequest(last_error_message)),
175            },
176        }
177    }
178}
179
180impl fmt::Debug for MockProvider {
181    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182        let response_len = self.responses.lock().unwrap().len();
183        let stream_len = self.stream_responses.lock().unwrap().len();
184        let call_count = self.calls.lock().unwrap().len();
185        f.debug_struct("MockProvider")
186            .field("meta", &self.meta)
187            .field("queued_responses", &response_len)
188            .field("queued_streams", &stream_len)
189            .field("recorded_calls", &call_count)
190            .finish()
191    }
192}
193
194impl MockProvider {
195    /// Creates a new mock with the given metadata and empty queues.
196    pub fn new(meta: ProviderMetadata) -> Self {
197        Self {
198            responses: Mutex::new(VecDeque::new()),
199            stream_responses: Mutex::new(VecDeque::new()),
200            meta,
201            calls: Arc::new(Mutex::new(Vec::new())),
202        }
203    }
204
205    /// Enqueues a successful response for the next `generate` call.
206    pub fn queue_response(&self, response: ChatResponse) -> &Self {
207        self.responses.lock().unwrap().push_back(Ok(response));
208        self
209    }
210
211    /// Enqueues an error for the next `generate` call.
212    pub fn queue_error(&self, error: MockError) -> &Self {
213        self.responses.lock().unwrap().push_back(Err(error));
214        self
215    }
216
217    /// Enqueues stream events for the next `stream` call.
218    pub fn queue_stream(&self, events: Vec<StreamEvent>) -> &Self {
219        self.stream_responses.lock().unwrap().push_back(Ok(events));
220        self
221    }
222
223    /// Enqueues an error for the next `stream` call.
224    ///
225    /// The error is returned from `stream()` itself (before any events
226    /// are yielded), simulating failures like authentication errors or
227    /// network issues that prevent the stream from starting.
228    pub fn queue_stream_error(&self, error: MockError) -> &Self {
229        self.stream_responses.lock().unwrap().push_back(Err(error));
230        self
231    }
232
233    /// Returns a clone of all `ChatParams` passed to `generate` or
234    /// `stream`, in call order.
235    pub fn recorded_calls(&self) -> Vec<ChatParams> {
236        self.calls.lock().unwrap().clone()
237    }
238
239    fn record_call(&self, params: &ChatParams) {
240        self.calls.lock().unwrap().push(params.clone());
241    }
242}
243
244/// Convert a `ChatResponse` into equivalent `StreamEvent`s.
245///
246/// Used by `MockProvider::stream()` when the stream queue is empty but
247/// the response queue has entries — enables existing tests that use
248/// `queue_response` to work transparently with `stream_boxed()`.
249fn response_to_stream_events(response: &ChatResponse) -> Vec<StreamEvent> {
250    use crate::chat::ContentBlock;
251
252    let mut events = Vec::new();
253    let mut tool_index = 0u32;
254
255    for block in &response.content {
256        match block {
257            ContentBlock::Text(text) => {
258                events.push(StreamEvent::TextDelta(text.clone()));
259            }
260            ContentBlock::ToolCall(call) => {
261                events.push(StreamEvent::ToolCallStart {
262                    index: tool_index,
263                    id: call.id.clone(),
264                    name: call.name.clone(),
265                });
266                events.push(StreamEvent::ToolCallComplete {
267                    index: tool_index,
268                    call: call.clone(),
269                });
270                tool_index += 1;
271            }
272            _ => {}
273        }
274    }
275
276    events.push(StreamEvent::Usage(response.usage.clone()));
277
278    events.push(StreamEvent::Done {
279        stop_reason: response.stop_reason,
280    });
281
282    events
283}
284
285impl Provider for MockProvider {
286    async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
287        self.record_call(params);
288        let result = self
289            .responses
290            .lock()
291            .unwrap()
292            .pop_front()
293            .expect("MockProvider: no queued responses remaining");
294        result.map_err(MockError::into_llm_error)
295    }
296
297    async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
298        self.record_call(params);
299
300        // Try the stream queue first; fall back to the response queue
301        // (auto-converting a ChatResponse into StreamEvents).
302        if let Some(result) = self.stream_responses.lock().unwrap().pop_front() {
303            let events = result.map_err(MockError::into_llm_error)?;
304            let stream = futures::stream::iter(events.into_iter().map(Ok));
305            return Ok(Box::pin(stream));
306        }
307
308        let result = self
309            .responses
310            .lock()
311            .unwrap()
312            .pop_front()
313            .expect("MockProvider: no queued responses or stream responses remaining");
314        let response = result.map_err(MockError::into_llm_error)?;
315        let events = response_to_stream_events(&response);
316        let stream = futures::stream::iter(events.into_iter().map(Ok));
317        Ok(Box::pin(stream))
318    }
319
320    fn metadata(&self) -> ProviderMetadata {
321        self.meta.clone()
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ContentBlock, StopReason};
    use crate::provider::{Capability, DynProvider};
    use crate::test_helpers::sample_response;
    use futures::StreamExt;
    use std::collections::HashSet;

    fn test_metadata() -> ProviderMetadata {
        ProviderMetadata {
            name: "mock".into(),
            model: "test-model".into(),
            context_window: 128_000,
            capabilities: HashSet::from([Capability::Tools, Capability::StructuredOutput]),
        }
    }

    #[tokio::test]
    async fn test_mock_generate_returns_queued() {
        let mock = MockProvider::new(test_metadata());
        let resp = sample_response("test");
        mock.queue_response(resp.clone());

        let result = mock.generate(&ChatParams::default()).await.unwrap();
        assert_eq!(result, resp);
    }

    #[tokio::test]
    async fn test_mock_generate_multiple_queued() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("first"));
        mock.queue_response(sample_response("second"));
        mock.queue_response(sample_response("third"));

        let r1 = mock.generate(&ChatParams::default()).await.unwrap();
        let r2 = mock.generate(&ChatParams::default()).await.unwrap();
        let r3 = mock.generate(&ChatParams::default()).await.unwrap();

        assert_eq!(r1.content, vec![ContentBlock::Text("first".into())]);
        assert_eq!(r2.content, vec![ContentBlock::Text("second".into())]);
        assert_eq!(r3.content, vec![ContentBlock::Text("third".into())]);
    }

    #[tokio::test]
    async fn test_mock_generate_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Auth("bad key".into()));

        let result = mock.generate(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_generate_mixed_queue() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));
        mock.queue_error(MockError::Timeout { elapsed_ms: 5000 });
        mock.queue_response(sample_response("ok again"));

        let r1 = mock.generate(&ChatParams::default()).await;
        let r2 = mock.generate(&ChatParams::default()).await;
        let r3 = mock.generate(&ChatParams::default()).await;

        assert!(r1.is_ok());
        assert!(r2.is_err());
        assert!(r3.is_ok());
    }

    #[tokio::test]
    #[should_panic(expected = "no queued responses")]
    async fn test_mock_generate_empty_queue_panics() {
        let mock = MockProvider::new(test_metadata());
        let _ = mock.generate(&ChatParams::default()).await;
    }

    #[tokio::test]
    async fn test_mock_stream_returns_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hello".into()),
            StreamEvent::TextDelta(" world".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 3);
        assert!(events.iter().all(Result::is_ok));
    }

    #[tokio::test]
    async fn test_mock_stream_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream_error(MockError::Auth("bad token".into()));

        let result = mock.stream(&ChatParams::default()).await;
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_stream_empty_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(events.is_empty());
    }

    // --- stream() fallback to the response queue ---

    #[tokio::test]
    async fn test_mock_stream_falls_back_to_response_queue() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("hi"));

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;

        // One text block -> TextDelta, then Usage, then Done.
        assert_eq!(events.len(), 3);
        assert!(matches!(&events[0], Ok(StreamEvent::TextDelta(t)) if t == "hi"));
        assert!(matches!(&events[1], Ok(StreamEvent::Usage(_))));
        assert!(matches!(&events[2], Ok(StreamEvent::Done { .. })));
        assert_eq!(mock.recorded_calls().len(), 1);
    }

    #[tokio::test]
    async fn test_mock_stream_fallback_propagates_queued_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Auth("bad key".into()));

        let result = mock.stream(&ChatParams::default()).await;
        assert!(matches!(result, Err(LlmError::Auth(_))));
    }

    #[tokio::test]
    async fn test_mock_stream_queue_takes_priority() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("from response"));
        mock.queue_stream(vec![StreamEvent::TextDelta("from stream".into())]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 1);
        assert!(matches!(&events[0], Ok(StreamEvent::TextDelta(t)) if t == "from stream"));

        // The response queue must be left untouched by the stream call.
        let resp = mock.generate(&ChatParams::default()).await.unwrap();
        assert_eq!(resp.content, vec![ContentBlock::Text("from response".into())]);
    }

    #[tokio::test]
    #[should_panic(expected = "no queued responses or stream responses remaining")]
    async fn test_mock_stream_empty_queues_panics() {
        let mock = MockProvider::new(test_metadata());
        let _ = mock.stream(&ChatParams::default()).await;
    }

    #[tokio::test]
    async fn test_mock_records_calls() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_response(sample_response("b"));
        mock.queue_response(sample_response("c"));

        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;

        assert_eq!(mock.recorded_calls().len(), 3);
    }

    #[tokio::test]
    async fn test_mock_records_params_accurately() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));

        let params = ChatParams {
            temperature: Some(0.5),
            system: Some("be nice".into()),
            ..Default::default()
        };
        let _ = mock.generate(&params).await;

        let recorded = mock.recorded_calls();
        assert_eq!(recorded[0].temperature, Some(0.5));
        assert_eq!(recorded[0].system, Some("be nice".into()));
    }

    #[test]
    fn test_mock_metadata_returns_configured() {
        let meta = test_metadata();
        let mock = MockProvider::new(meta.clone());
        assert_eq!(Provider::metadata(&mock), meta);
    }

    #[tokio::test]
    async fn test_mock_concurrent_access() {
        let mock = Arc::new(MockProvider::new(test_metadata()));
        for _ in 0..10 {
            mock.queue_response(sample_response("ok"));
        }

        let mut handles = Vec::new();
        for _ in 0..10 {
            let m = mock.clone();
            handles.push(tokio::spawn(async move {
                m.generate(&ChatParams::default()).await.unwrap()
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(mock.recorded_calls().len(), 10);
    }

    // --- DynProvider tests (through mock) ---

    #[tokio::test]
    async fn test_dyn_provider_blanket_impl() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("hello"));

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let result = dyn_provider.generate_boxed(&params).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dyn_provider_error_propagation() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Http {
            status: Some(http::StatusCode::TOO_MANY_REQUESTS),
            message: "rate limited".into(),
            retryable: true,
        });

        let dyn_provider: &dyn DynProvider = &mock;
        let result = dyn_provider.generate_boxed(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Http { .. }));
    }

    #[tokio::test]
    async fn test_dyn_provider_stream_blanket() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hi".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let stream = dyn_provider.stream_boxed(&params).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 2);
    }

    #[tokio::test]
    async fn test_dyn_provider_metadata_matches() {
        let mock = MockProvider::new(test_metadata());
        let dyn_provider: &dyn DynProvider = &mock;
        assert_eq!(Provider::metadata(&mock), dyn_provider.metadata());
    }

    #[tokio::test]
    async fn test_dyn_provider_boxed_storage() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("from box"));

        let boxed: Box<dyn DynProvider> = Box::new(mock);
        let result = boxed.generate_boxed(&ChatParams::default()).await.unwrap();
        assert_eq!(result.content, vec![ContentBlock::Text("from box".into())]);
    }

    #[test]
    fn test_mock_provider_debug() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_stream(vec![StreamEvent::TextDelta("hi".into())]);

        let debug = format!("{mock:?}");
        assert!(debug.contains("MockProvider"));
        assert!(debug.contains("queued_responses: 1"));
        assert!(debug.contains("queued_streams: 1"));
        assert!(debug.contains("recorded_calls: 0"));
    }

    #[test]
    fn test_provider_is_object_safe() {
        let f1: fn(&dyn DynProvider) = |_| {};
        let f2: fn(Box<dyn DynProvider>) = |_| {};
        // Suppress unused variable warnings
        let _ = (f1, f2);
    }

    #[tokio::test]
    async fn test_mock_error_into_llm_error_all_variants() {
        let variants: Vec<(MockError, &str)> = vec![
            (MockError::InvalidRequest("bad".into()), "InvalidRequest"),
            (
                MockError::Provider {
                    code: "e1".into(),
                    message: "fail".into(),
                    retryable: false,
                },
                "Provider",
            ),
            (
                MockError::ResponseFormat {
                    message: "bad json".into(),
                    raw: "{}".into(),
                },
                "ResponseFormat",
            ),
            (
                MockError::SchemaValidation {
                    message: "missing field".into(),
                    schema: serde_json::json!({"type": "object"}),
                    actual: serde_json::json!(42),
                },
                "SchemaValidation",
            ),
            (
                MockError::RetryExhausted {
                    attempts: 3,
                    last_error_message: "timed out".into(),
                },
                "RetryExhausted",
            ),
        ];

        for (mock_err, label) in variants {
            let mock = MockProvider::new(test_metadata());
            mock.queue_error(mock_err);
            let result = mock.generate(&ChatParams::default()).await;
            assert!(result.is_err(), "{label} should produce error");
            let err = result.unwrap_err();
            let debug = format!("{err:?}");
            assert!(
                debug.contains(label),
                "expected {label} in error debug: {debug}"
            );
        }
    }
}