1use std::collections::VecDeque;
45use std::fmt;
46use std::sync::{Arc, Mutex};
47
48use crate::chat::ChatResponse;
49use crate::error::LlmError;
50use crate::provider::{ChatParams, Provider, ProviderMetadata};
51use crate::stream::{ChatStream, StreamEvent};
52
53pub struct MockProvider {
67 responses: Mutex<VecDeque<Result<ChatResponse, MockError>>>,
68 stream_responses: Mutex<VecDeque<Result<Vec<StreamEvent>, MockError>>>,
69 meta: ProviderMetadata,
70 calls: Arc<Mutex<Vec<ChatParams>>>,
71}
72
73#[derive(Debug, Clone)]
80pub enum MockError {
81 Http {
83 status: Option<http::StatusCode>,
85 message: String,
87 retryable: bool,
89 },
90 Auth(String),
92 InvalidRequest(String),
94 Provider {
96 code: String,
98 message: String,
100 retryable: bool,
102 },
103 Timeout {
105 elapsed_ms: u64,
107 },
108 ResponseFormat {
110 message: String,
112 raw: String,
114 },
115 SchemaValidation {
117 message: String,
119 schema: serde_json::Value,
121 actual: serde_json::Value,
123 },
124 RetryExhausted {
126 attempts: u32,
128 last_error_message: String,
130 },
131}
132
133impl MockError {
134 fn into_llm_error(self) -> LlmError {
135 match self {
136 Self::Http {
137 status,
138 message,
139 retryable,
140 } => LlmError::Http {
141 status,
142 message,
143 retryable,
144 },
145 Self::Auth(msg) => LlmError::Auth(msg),
146 Self::InvalidRequest(msg) => LlmError::InvalidRequest(msg),
147 Self::Provider {
148 code,
149 message,
150 retryable,
151 } => LlmError::Provider {
152 code,
153 message,
154 retryable,
155 },
156 Self::Timeout { elapsed_ms } => LlmError::Timeout { elapsed_ms },
157 Self::ResponseFormat { message, raw } => LlmError::ResponseFormat { message, raw },
158 Self::SchemaValidation {
159 message,
160 schema,
161 actual,
162 } => LlmError::SchemaValidation {
163 message,
164 schema,
165 actual,
166 },
167 Self::RetryExhausted {
168 attempts,
169 last_error_message,
170 } => LlmError::RetryExhausted {
171 attempts,
172 last_error: Box::new(LlmError::InvalidRequest(last_error_message)),
173 },
174 }
175 }
176}
177
178impl fmt::Debug for MockProvider {
179 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180 let response_len = self.responses.lock().unwrap().len();
181 let stream_len = self.stream_responses.lock().unwrap().len();
182 let call_count = self.calls.lock().unwrap().len();
183 f.debug_struct("MockProvider")
184 .field("meta", &self.meta)
185 .field("queued_responses", &response_len)
186 .field("queued_streams", &stream_len)
187 .field("recorded_calls", &call_count)
188 .finish()
189 }
190}
191
192impl MockProvider {
193 pub fn new(meta: ProviderMetadata) -> Self {
195 Self {
196 responses: Mutex::new(VecDeque::new()),
197 stream_responses: Mutex::new(VecDeque::new()),
198 meta,
199 calls: Arc::new(Mutex::new(Vec::new())),
200 }
201 }
202
203 pub fn queue_response(&self, response: ChatResponse) -> &Self {
205 self.responses.lock().unwrap().push_back(Ok(response));
206 self
207 }
208
209 pub fn queue_error(&self, error: MockError) -> &Self {
211 self.responses.lock().unwrap().push_back(Err(error));
212 self
213 }
214
215 pub fn queue_stream(&self, events: Vec<StreamEvent>) -> &Self {
217 self.stream_responses.lock().unwrap().push_back(Ok(events));
218 self
219 }
220
221 pub fn queue_stream_error(&self, error: MockError) -> &Self {
227 self.stream_responses.lock().unwrap().push_back(Err(error));
228 self
229 }
230
231 pub fn recorded_calls(&self) -> Vec<ChatParams> {
234 self.calls.lock().unwrap().clone()
235 }
236
237 fn record_call(&self, params: &ChatParams) {
238 self.calls.lock().unwrap().push(params.clone());
239 }
240}
241
242impl Provider for MockProvider {
243 async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
244 self.record_call(params);
245 let result = self
246 .responses
247 .lock()
248 .unwrap()
249 .pop_front()
250 .expect("MockProvider: no queued responses remaining");
251 result.map_err(MockError::into_llm_error)
252 }
253
254 async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
255 self.record_call(params);
256 let result = self
257 .stream_responses
258 .lock()
259 .unwrap()
260 .pop_front()
261 .expect("MockProvider: no queued stream responses remaining");
262 let events = result.map_err(MockError::into_llm_error)?;
263 let stream = futures::stream::iter(events.into_iter().map(Ok));
264 Ok(Box::pin(stream))
265 }
266
267 fn metadata(&self) -> ProviderMetadata {
268 self.meta.clone()
269 }
270}
271
#[cfg(test)]
mod tests {
    // Unit tests for `MockProvider`: FIFO replay of queued responses and
    // streams, error conversion, call recording, and use through the
    // `DynProvider` trait object.
    use super::*;
    use crate::chat::{ContentBlock, StopReason};
    use crate::provider::{Capability, DynProvider};
    use crate::test_helpers::sample_response;
    use futures::StreamExt;
    use std::collections::HashSet;

    fn test_metadata() -> ProviderMetadata {
        ProviderMetadata {
            name: "mock".into(),
            model: "test-model".into(),
            context_window: 128_000,
            capabilities: HashSet::from([Capability::Tools, Capability::StructuredOutput]),
        }
    }

    #[tokio::test]
    async fn test_mock_generate_returns_queued() {
        let mock = MockProvider::new(test_metadata());
        let resp = sample_response("test");
        mock.queue_response(resp.clone());

        let result = mock.generate(&ChatParams::default()).await.unwrap();
        assert_eq!(result, resp);
    }

    #[tokio::test]
    async fn test_mock_generate_multiple_queued() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("first"));
        mock.queue_response(sample_response("second"));
        mock.queue_response(sample_response("third"));

        let r1 = mock.generate(&ChatParams::default()).await.unwrap();
        let r2 = mock.generate(&ChatParams::default()).await.unwrap();
        let r3 = mock.generate(&ChatParams::default()).await.unwrap();

        assert_eq!(r1.content, vec![ContentBlock::Text("first".into())]);
        assert_eq!(r2.content, vec![ContentBlock::Text("second".into())]);
        assert_eq!(r3.content, vec![ContentBlock::Text("third".into())]);
    }

    #[tokio::test]
    async fn test_mock_generate_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Auth("bad key".into()));

        let result = mock.generate(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_generate_mixed_queue() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));
        mock.queue_error(MockError::Timeout { elapsed_ms: 5000 });
        mock.queue_response(sample_response("ok again"));

        let r1 = mock.generate(&ChatParams::default()).await;
        let r2 = mock.generate(&ChatParams::default()).await;
        let r3 = mock.generate(&ChatParams::default()).await;

        // Outcomes come back in queue order, and the error keeps its payload.
        assert_eq!(r1.unwrap().content, vec![ContentBlock::Text("ok".into())]);
        assert!(matches!(r2, Err(LlmError::Timeout { elapsed_ms: 5000 })));
        assert_eq!(r3.unwrap().content, vec![ContentBlock::Text("ok again".into())]);
    }

    #[tokio::test]
    #[should_panic(expected = "no queued responses")]
    async fn test_mock_generate_empty_queue_panics() {
        let mock = MockProvider::new(test_metadata());
        let _ = mock.generate(&ChatParams::default()).await;
    }

    #[tokio::test]
    #[should_panic(expected = "no queued stream responses")]
    async fn test_mock_stream_empty_queue_panics() {
        let mock = MockProvider::new(test_metadata());
        let _ = mock.stream(&ChatParams::default()).await;
    }

    #[tokio::test]
    async fn test_mock_stream_returns_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hello".into()),
            StreamEvent::TextDelta(" world".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 3);
        assert!(events.iter().all(Result::is_ok));
    }

    #[tokio::test]
    async fn test_mock_stream_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream_error(MockError::Auth("bad token".into()));

        let result = mock.stream(&ChatParams::default()).await;
        assert!(result.is_err());
        // `.err().unwrap()` rather than `.unwrap_err()`: the Ok side is a stream.
        let err = result.err().unwrap();
        assert!(matches!(err, LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_stream_empty_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(events.is_empty());
    }

    #[tokio::test]
    async fn test_mock_records_calls() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_response(sample_response("b"));
        mock.queue_response(sample_response("c"));

        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;

        assert_eq!(mock.recorded_calls().len(), 3);
    }

    #[tokio::test]
    async fn test_mock_records_params_accurately() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));

        let params = ChatParams {
            temperature: Some(0.5),
            system: Some("be nice".into()),
            ..Default::default()
        };
        let _ = mock.generate(&params).await;

        let recorded = mock.recorded_calls();
        assert_eq!(recorded[0].temperature, Some(0.5));
        assert_eq!(recorded[0].system, Some("be nice".into()));
    }

    #[test]
    fn test_mock_metadata_returns_configured() {
        let meta = test_metadata();
        let mock = MockProvider::new(meta.clone());
        assert_eq!(Provider::metadata(&mock), meta);
    }

    #[tokio::test]
    async fn test_mock_concurrent_access() {
        let mock = Arc::new(MockProvider::new(test_metadata()));
        for _ in 0..10 {
            mock.queue_response(sample_response("ok"));
        }

        let mut handles = Vec::new();
        for _ in 0..10 {
            let m = mock.clone();
            handles.push(tokio::spawn(async move {
                m.generate(&ChatParams::default()).await.unwrap()
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(mock.recorded_calls().len(), 10);
    }

    #[tokio::test]
    async fn test_dyn_provider_blanket_impl() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("hello"));

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let result = dyn_provider.generate_boxed(&params).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dyn_provider_error_propagation() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Http {
            status: Some(http::StatusCode::TOO_MANY_REQUESTS),
            message: "rate limited".into(),
            retryable: true,
        });

        let dyn_provider: &dyn DynProvider = &mock;
        let result = dyn_provider.generate_boxed(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Http { .. }));
    }

    #[tokio::test]
    async fn test_dyn_provider_stream_blanket() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hi".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let stream = dyn_provider.stream_boxed(&params).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 2);
    }

    // Purely synchronous, so no async runtime is needed.
    #[test]
    fn test_dyn_provider_metadata_matches() {
        let mock = MockProvider::new(test_metadata());
        let dyn_provider: &dyn DynProvider = &mock;
        assert_eq!(Provider::metadata(&mock), dyn_provider.metadata());
    }

    #[tokio::test]
    async fn test_dyn_provider_boxed_storage() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("from box"));

        let boxed: Box<dyn DynProvider> = Box::new(mock);
        let result = boxed.generate_boxed(&ChatParams::default()).await.unwrap();
        assert_eq!(result.content, vec![ContentBlock::Text("from box".into())]);
    }

    #[test]
    fn test_mock_provider_debug() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_stream(vec![StreamEvent::TextDelta("hi".into())]);

        let debug = format!("{mock:?}");
        assert!(debug.contains("MockProvider"));
        assert!(debug.contains("queued_responses: 1"));
        assert!(debug.contains("queued_streams: 1"));
        assert!(debug.contains("recorded_calls: 0"));
    }

    #[test]
    fn test_provider_is_object_safe() {
        // Compiles only if `DynProvider` can be used as a trait object.
        let f1: fn(&dyn DynProvider) = |_| {};
        let f2: fn(Box<dyn DynProvider>) = |_| {};
        let _ = (f1, f2);
    }

    #[tokio::test]
    async fn test_mock_error_into_llm_error_all_variants() {
        let variants: Vec<(MockError, &str)> = vec![
            (MockError::InvalidRequest("bad".into()), "InvalidRequest"),
            (
                MockError::Provider {
                    code: "e1".into(),
                    message: "fail".into(),
                    retryable: false,
                },
                "Provider",
            ),
            (
                MockError::ResponseFormat {
                    message: "bad json".into(),
                    raw: "{}".into(),
                },
                "ResponseFormat",
            ),
            (
                MockError::SchemaValidation {
                    message: "missing field".into(),
                    schema: serde_json::json!({"type": "object"}),
                    actual: serde_json::json!(42),
                },
                "SchemaValidation",
            ),
            (
                MockError::RetryExhausted {
                    attempts: 3,
                    last_error_message: "timed out".into(),
                },
                "RetryExhausted",
            ),
        ];

        for (mock_err, label) in variants {
            let mock = MockProvider::new(test_metadata());
            mock.queue_error(mock_err);
            let result = mock.generate(&ChatParams::default()).await;
            assert!(result.is_err(), "{label} should produce error");
            let err = result.unwrap_err();
            let debug = format!("{err:?}");
            assert!(
                debug.contains(label),
                "expected {label} in error debug: {debug}"
            );
        }
    }
}