1use std::collections::VecDeque;
47use std::fmt;
48use std::sync::{Arc, Mutex};
49
50use crate::chat::ChatResponse;
51use crate::error::LlmError;
52use crate::provider::{ChatParams, Provider, ProviderMetadata};
53use crate::stream::{ChatStream, StreamEvent};
54
55pub struct MockProvider {
69 responses: Mutex<VecDeque<Result<ChatResponse, MockError>>>,
70 stream_responses: Mutex<VecDeque<Result<Vec<StreamEvent>, MockError>>>,
71 meta: ProviderMetadata,
72 calls: Arc<Mutex<Vec<ChatParams>>>,
73}
74
75#[derive(Debug, Clone)]
82pub enum MockError {
83 Http {
85 status: Option<http::StatusCode>,
87 message: String,
89 retryable: bool,
91 },
92 Auth(String),
94 InvalidRequest(String),
96 Provider {
98 code: String,
100 message: String,
102 retryable: bool,
104 },
105 Timeout {
107 elapsed_ms: u64,
109 },
110 ResponseFormat {
112 message: String,
114 raw: String,
116 },
117 SchemaValidation {
119 message: String,
121 schema: serde_json::Value,
123 actual: serde_json::Value,
125 },
126 RetryExhausted {
128 attempts: u32,
130 last_error_message: String,
132 },
133}
134
135impl MockError {
136 fn into_llm_error(self) -> LlmError {
137 match self {
138 Self::Http {
139 status,
140 message,
141 retryable,
142 } => LlmError::Http {
143 status,
144 message,
145 retryable,
146 },
147 Self::Auth(msg) => LlmError::Auth(msg),
148 Self::InvalidRequest(msg) => LlmError::InvalidRequest(msg),
149 Self::Provider {
150 code,
151 message,
152 retryable,
153 } => LlmError::Provider {
154 code,
155 message,
156 retryable,
157 },
158 Self::Timeout { elapsed_ms } => LlmError::Timeout { elapsed_ms },
159 Self::ResponseFormat { message, raw } => LlmError::ResponseFormat { message, raw },
160 Self::SchemaValidation {
161 message,
162 schema,
163 actual,
164 } => LlmError::SchemaValidation {
165 message,
166 schema,
167 actual,
168 },
169 Self::RetryExhausted {
170 attempts,
171 last_error_message,
172 } => LlmError::RetryExhausted {
173 attempts,
174 last_error: Box::new(LlmError::InvalidRequest(last_error_message)),
175 },
176 }
177 }
178}
179
180impl fmt::Debug for MockProvider {
181 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
182 let response_len = self.responses.lock().unwrap().len();
183 let stream_len = self.stream_responses.lock().unwrap().len();
184 let call_count = self.calls.lock().unwrap().len();
185 f.debug_struct("MockProvider")
186 .field("meta", &self.meta)
187 .field("queued_responses", &response_len)
188 .field("queued_streams", &stream_len)
189 .field("recorded_calls", &call_count)
190 .finish()
191 }
192}
193
194impl MockProvider {
195 pub fn new(meta: ProviderMetadata) -> Self {
197 Self {
198 responses: Mutex::new(VecDeque::new()),
199 stream_responses: Mutex::new(VecDeque::new()),
200 meta,
201 calls: Arc::new(Mutex::new(Vec::new())),
202 }
203 }
204
205 pub fn queue_response(&self, response: ChatResponse) -> &Self {
207 self.responses.lock().unwrap().push_back(Ok(response));
208 self
209 }
210
211 pub fn queue_error(&self, error: MockError) -> &Self {
213 self.responses.lock().unwrap().push_back(Err(error));
214 self
215 }
216
217 pub fn queue_stream(&self, events: Vec<StreamEvent>) -> &Self {
219 self.stream_responses.lock().unwrap().push_back(Ok(events));
220 self
221 }
222
223 pub fn queue_stream_error(&self, error: MockError) -> &Self {
229 self.stream_responses.lock().unwrap().push_back(Err(error));
230 self
231 }
232
233 pub fn recorded_calls(&self) -> Vec<ChatParams> {
236 self.calls.lock().unwrap().clone()
237 }
238
239 fn record_call(&self, params: &ChatParams) {
240 self.calls.lock().unwrap().push(params.clone());
241 }
242}
243
244fn response_to_stream_events(response: &ChatResponse) -> Vec<StreamEvent> {
250 use crate::chat::ContentBlock;
251
252 let mut events = Vec::new();
253 let mut tool_index = 0u32;
254
255 for block in &response.content {
256 match block {
257 ContentBlock::Text(text) => {
258 events.push(StreamEvent::TextDelta(text.clone()));
259 }
260 ContentBlock::ToolCall(call) => {
261 events.push(StreamEvent::ToolCallStart {
262 index: tool_index,
263 id: call.id.clone(),
264 name: call.name.clone(),
265 });
266 events.push(StreamEvent::ToolCallComplete {
267 index: tool_index,
268 call: call.clone(),
269 });
270 tool_index += 1;
271 }
272 _ => {}
273 }
274 }
275
276 events.push(StreamEvent::Usage(response.usage.clone()));
277
278 events.push(StreamEvent::Done {
279 stop_reason: response.stop_reason,
280 });
281
282 events
283}
284
285impl Provider for MockProvider {
286 async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
287 self.record_call(params);
288 let result = self
289 .responses
290 .lock()
291 .unwrap()
292 .pop_front()
293 .expect("MockProvider: no queued responses remaining");
294 result.map_err(MockError::into_llm_error)
295 }
296
297 async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
298 self.record_call(params);
299
300 if let Some(result) = self.stream_responses.lock().unwrap().pop_front() {
303 let events = result.map_err(MockError::into_llm_error)?;
304 let stream = futures::stream::iter(events.into_iter().map(Ok));
305 return Ok(Box::pin(stream));
306 }
307
308 let result = self
309 .responses
310 .lock()
311 .unwrap()
312 .pop_front()
313 .expect("MockProvider: no queued responses or stream responses remaining");
314 let response = result.map_err(MockError::into_llm_error)?;
315 let events = response_to_stream_events(&response);
316 let stream = futures::stream::iter(events.into_iter().map(Ok));
317 Ok(Box::pin(stream))
318 }
319
320 fn metadata(&self) -> ProviderMetadata {
321 self.meta.clone()
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ContentBlock, StopReason};
    use crate::provider::{Capability, DynProvider};
    use crate::test_helpers::sample_response;
    use futures::StreamExt;
    use std::collections::HashSet;

    /// Metadata shared by every test mock.
    fn test_metadata() -> ProviderMetadata {
        ProviderMetadata {
            name: "mock".into(),
            model: "test-model".into(),
            context_window: 128_000,
            capabilities: HashSet::from([Capability::Tools, Capability::StructuredOutput]),
        }
    }

    #[tokio::test]
    async fn test_mock_generate_returns_queued() {
        let mock = MockProvider::new(test_metadata());
        let resp = sample_response("test");
        mock.queue_response(resp.clone());

        let result = mock.generate(&ChatParams::default()).await.unwrap();
        assert_eq!(result, resp);
    }

    #[tokio::test]
    async fn test_mock_generate_multiple_queued() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("first"));
        mock.queue_response(sample_response("second"));
        mock.queue_response(sample_response("third"));

        let r1 = mock.generate(&ChatParams::default()).await.unwrap();
        let r2 = mock.generate(&ChatParams::default()).await.unwrap();
        let r3 = mock.generate(&ChatParams::default()).await.unwrap();

        assert_eq!(r1.content, vec![ContentBlock::Text("first".into())]);
        assert_eq!(r2.content, vec![ContentBlock::Text("second".into())]);
        assert_eq!(r3.content, vec![ContentBlock::Text("third".into())]);
    }

    #[tokio::test]
    async fn test_mock_generate_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Auth("bad key".into()));

        let result = mock.generate(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_generate_mixed_queue() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));
        mock.queue_error(MockError::Timeout { elapsed_ms: 5000 });
        mock.queue_response(sample_response("ok again"));

        let r1 = mock.generate(&ChatParams::default()).await;
        let r2 = mock.generate(&ChatParams::default()).await;
        let r3 = mock.generate(&ChatParams::default()).await;

        assert!(r1.is_ok());
        assert!(r2.is_err());
        assert!(r3.is_ok());
    }

    #[tokio::test]
    #[should_panic(expected = "no queued responses")]
    async fn test_mock_generate_empty_queue_panics() {
        let mock = MockProvider::new(test_metadata());
        let _ = mock.generate(&ChatParams::default()).await;
    }

    #[tokio::test]
    async fn test_mock_stream_returns_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hello".into()),
            StreamEvent::TextDelta(" world".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 3);
        assert!(events.iter().all(Result::is_ok));
    }

    #[tokio::test]
    async fn test_mock_stream_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream_error(MockError::Auth("bad token".into()));

        let result = mock.stream(&ChatParams::default()).await;
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_stream_empty_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(events.is_empty());
    }

    #[tokio::test]
    async fn test_mock_stream_falls_back_to_queued_response() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("hi"));

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;

        // A single text block replays as TextDelta, then Usage, then Done.
        assert_eq!(events.len(), 3);
        assert!(matches!(&events[0], Ok(StreamEvent::TextDelta(_))));
        assert!(matches!(&events[1], Ok(StreamEvent::Usage(_))));
        assert!(matches!(&events[2], Ok(StreamEvent::Done { .. })));
        assert_eq!(mock.recorded_calls().len(), 1);
    }

    #[tokio::test]
    async fn test_mock_stream_prefers_queued_stream() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("for generate"));
        mock.queue_stream(vec![StreamEvent::TextDelta("scripted".into())]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 1);

        // The queued response must still be available to the next generate call.
        let resp = mock.generate(&ChatParams::default()).await.unwrap();
        assert_eq!(resp.content, vec![ContentBlock::Text("for generate".into())]);
    }

    #[tokio::test]
    async fn test_mock_records_calls() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_response(sample_response("b"));
        mock.queue_response(sample_response("c"));

        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;

        assert_eq!(mock.recorded_calls().len(), 3);
    }

    #[tokio::test]
    async fn test_mock_records_params_accurately() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));

        let params = ChatParams {
            temperature: Some(0.5),
            system: Some("be nice".into()),
            ..Default::default()
        };
        let _ = mock.generate(&params).await;

        let recorded = mock.recorded_calls();
        assert_eq!(recorded[0].temperature, Some(0.5));
        assert_eq!(recorded[0].system, Some("be nice".into()));
    }

    #[test]
    fn test_mock_metadata_returns_configured() {
        let meta = test_metadata();
        let mock = MockProvider::new(meta.clone());
        assert_eq!(Provider::metadata(&mock), meta);
    }

    #[tokio::test]
    async fn test_mock_concurrent_access() {
        let mock = Arc::new(MockProvider::new(test_metadata()));
        for _ in 0..10 {
            mock.queue_response(sample_response("ok"));
        }

        let mut handles = Vec::new();
        for _ in 0..10 {
            let m = mock.clone();
            handles.push(tokio::spawn(async move {
                m.generate(&ChatParams::default()).await.unwrap()
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(mock.recorded_calls().len(), 10);
    }

    #[tokio::test]
    async fn test_dyn_provider_blanket_impl() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("hello"));

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let result = dyn_provider.generate_boxed(&params).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dyn_provider_error_propagation() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Http {
            status: Some(http::StatusCode::TOO_MANY_REQUESTS),
            message: "rate limited".into(),
            retryable: true,
        });

        let dyn_provider: &dyn DynProvider = &mock;
        let result = dyn_provider.generate_boxed(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Http { .. }));
    }

    #[tokio::test]
    async fn test_dyn_provider_stream_blanket() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hi".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let stream = dyn_provider.stream_boxed(&params).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 2);
    }

    // Purely synchronous: no runtime needed.
    #[test]
    fn test_dyn_provider_metadata_matches() {
        let mock = MockProvider::new(test_metadata());
        let dyn_provider: &dyn DynProvider = &mock;
        assert_eq!(Provider::metadata(&mock), dyn_provider.metadata());
    }

    #[tokio::test]
    async fn test_dyn_provider_boxed_storage() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("from box"));

        let boxed: Box<dyn DynProvider> = Box::new(mock);
        let result = boxed.generate_boxed(&ChatParams::default()).await.unwrap();
        assert_eq!(result.content, vec![ContentBlock::Text("from box".into())]);
    }

    #[test]
    fn test_mock_provider_debug() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_stream(vec![StreamEvent::TextDelta("hi".into())]);

        let debug = format!("{mock:?}");
        assert!(debug.contains("MockProvider"));
        assert!(debug.contains("queued_responses: 1"));
        assert!(debug.contains("queued_streams: 1"));
        assert!(debug.contains("recorded_calls: 0"));
    }

    #[test]
    fn test_provider_is_object_safe() {
        let f1: fn(&dyn DynProvider) = |_| {};
        let f2: fn(Box<dyn DynProvider>) = |_| {};
        let _ = (f1, f2);
    }

    #[tokio::test]
    async fn test_mock_error_into_llm_error_all_variants() {
        let variants: Vec<(MockError, &str)> = vec![
            (
                MockError::Http {
                    status: Some(http::StatusCode::BAD_GATEWAY),
                    message: "upstream".into(),
                    retryable: true,
                },
                "Http",
            ),
            (MockError::Auth("nope".into()), "Auth"),
            (MockError::InvalidRequest("bad".into()), "InvalidRequest"),
            (
                MockError::Provider {
                    code: "e1".into(),
                    message: "fail".into(),
                    retryable: false,
                },
                "Provider",
            ),
            (MockError::Timeout { elapsed_ms: 1_000 }, "Timeout"),
            (
                MockError::ResponseFormat {
                    message: "bad json".into(),
                    raw: "{}".into(),
                },
                "ResponseFormat",
            ),
            (
                MockError::SchemaValidation {
                    message: "missing field".into(),
                    schema: serde_json::json!({"type": "object"}),
                    actual: serde_json::json!(42),
                },
                "SchemaValidation",
            ),
            (
                MockError::RetryExhausted {
                    attempts: 3,
                    last_error_message: "timed out".into(),
                },
                "RetryExhausted",
            ),
        ];

        for (mock_err, label) in variants {
            let mock = MockProvider::new(test_metadata());
            mock.queue_error(mock_err);
            let result = mock.generate(&ChatParams::default()).await;
            assert!(result.is_err(), "{label} should produce error");
            let err = result.unwrap_err();
            let debug = format!("{err:?}");
            assert!(
                debug.contains(label),
                "expected {label} in error debug: {debug}"
            );
        }
    }
}
632}