1use std::collections::VecDeque;
45use std::fmt;
46use std::sync::{Arc, Mutex};
47
48use crate::chat::ChatResponse;
49use crate::error::LlmError;
50use crate::provider::{ChatParams, Provider, ProviderMetadata};
51use crate::stream::{ChatStream, StreamEvent};
52
53pub struct MockProvider {
67 responses: Mutex<VecDeque<Result<ChatResponse, MockError>>>,
68 stream_responses: Mutex<VecDeque<Result<Vec<StreamEvent>, MockError>>>,
69 meta: ProviderMetadata,
70 calls: Arc<Mutex<Vec<ChatParams>>>,
71}
72
73#[derive(Debug, Clone)]
80pub enum MockError {
81 Http {
83 status: Option<http::StatusCode>,
85 message: String,
87 retryable: bool,
89 },
90 Auth(String),
92 InvalidRequest(String),
94 Provider {
96 code: String,
98 message: String,
100 retryable: bool,
102 },
103 Timeout {
105 elapsed_ms: u64,
107 },
108 ResponseFormat {
110 message: String,
112 raw: String,
114 },
115 SchemaValidation {
117 message: String,
119 schema: serde_json::Value,
121 actual: serde_json::Value,
123 },
124 RetryExhausted {
126 attempts: u32,
128 last_error_message: String,
130 },
131}
132
133impl MockError {
134 fn into_llm_error(self) -> LlmError {
135 match self {
136 Self::Http {
137 status,
138 message,
139 retryable,
140 } => LlmError::Http {
141 status,
142 message,
143 retryable,
144 },
145 Self::Auth(msg) => LlmError::Auth(msg),
146 Self::InvalidRequest(msg) => LlmError::InvalidRequest(msg),
147 Self::Provider {
148 code,
149 message,
150 retryable,
151 } => LlmError::Provider {
152 code,
153 message,
154 retryable,
155 },
156 Self::Timeout { elapsed_ms } => LlmError::Timeout { elapsed_ms },
157 Self::ResponseFormat { message, raw } => LlmError::ResponseFormat { message, raw },
158 Self::SchemaValidation {
159 message,
160 schema,
161 actual,
162 } => LlmError::SchemaValidation {
163 message,
164 schema,
165 actual,
166 },
167 Self::RetryExhausted {
168 attempts,
169 last_error_message,
170 } => LlmError::RetryExhausted {
171 attempts,
172 last_error: Box::new(LlmError::InvalidRequest(last_error_message)),
173 },
174 }
175 }
176}
177
178impl fmt::Debug for MockProvider {
179 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
180 let response_len = self.responses.lock().unwrap().len();
181 let stream_len = self.stream_responses.lock().unwrap().len();
182 let call_count = self.calls.lock().unwrap().len();
183 f.debug_struct("MockProvider")
184 .field("meta", &self.meta)
185 .field("queued_responses", &response_len)
186 .field("queued_streams", &stream_len)
187 .field("recorded_calls", &call_count)
188 .finish()
189 }
190}
191
192impl MockProvider {
193 pub fn new(meta: ProviderMetadata) -> Self {
195 Self {
196 responses: Mutex::new(VecDeque::new()),
197 stream_responses: Mutex::new(VecDeque::new()),
198 meta,
199 calls: Arc::new(Mutex::new(Vec::new())),
200 }
201 }
202
203 pub fn queue_response(&self, response: ChatResponse) -> &Self {
205 self.responses.lock().unwrap().push_back(Ok(response));
206 self
207 }
208
209 pub fn queue_error(&self, error: MockError) -> &Self {
211 self.responses.lock().unwrap().push_back(Err(error));
212 self
213 }
214
215 pub fn queue_stream(&self, events: Vec<StreamEvent>) -> &Self {
217 self.stream_responses.lock().unwrap().push_back(Ok(events));
218 self
219 }
220
221 pub fn queue_stream_error(&self, error: MockError) -> &Self {
227 self.stream_responses.lock().unwrap().push_back(Err(error));
228 self
229 }
230
231 pub fn recorded_calls(&self) -> Vec<ChatParams> {
234 self.calls.lock().unwrap().clone()
235 }
236
237 fn record_call(&self, params: &ChatParams) {
238 self.calls.lock().unwrap().push(params.clone());
239 }
240}
241
242fn response_to_stream_events(response: &ChatResponse) -> Vec<StreamEvent> {
248 use crate::chat::ContentBlock;
249
250 let mut events = Vec::new();
251 let mut tool_index = 0u32;
252
253 for block in &response.content {
254 match block {
255 ContentBlock::Text(text) => {
256 events.push(StreamEvent::TextDelta(text.clone()));
257 }
258 ContentBlock::ToolCall(call) => {
259 events.push(StreamEvent::ToolCallStart {
260 index: tool_index,
261 id: call.id.clone(),
262 name: call.name.clone(),
263 });
264 events.push(StreamEvent::ToolCallComplete {
265 index: tool_index,
266 call: call.clone(),
267 });
268 tool_index += 1;
269 }
270 _ => {}
271 }
272 }
273
274 events.push(StreamEvent::Usage(response.usage.clone()));
275
276 events.push(StreamEvent::Done {
277 stop_reason: response.stop_reason,
278 });
279
280 events
281}
282
283impl Provider for MockProvider {
284 async fn generate(&self, params: &ChatParams) -> Result<ChatResponse, LlmError> {
285 self.record_call(params);
286 let result = self
287 .responses
288 .lock()
289 .unwrap()
290 .pop_front()
291 .expect("MockProvider: no queued responses remaining");
292 result.map_err(MockError::into_llm_error)
293 }
294
295 async fn stream(&self, params: &ChatParams) -> Result<ChatStream, LlmError> {
296 self.record_call(params);
297
298 if let Some(result) = self.stream_responses.lock().unwrap().pop_front() {
301 let events = result.map_err(MockError::into_llm_error)?;
302 let stream = futures::stream::iter(events.into_iter().map(Ok));
303 return Ok(Box::pin(stream));
304 }
305
306 let result = self
307 .responses
308 .lock()
309 .unwrap()
310 .pop_front()
311 .expect("MockProvider: no queued responses or stream responses remaining");
312 let response = result.map_err(MockError::into_llm_error)?;
313 let events = response_to_stream_events(&response);
314 let stream = futures::stream::iter(events.into_iter().map(Ok));
315 Ok(Box::pin(stream))
316 }
317
318 fn metadata(&self) -> ProviderMetadata {
319 self.meta.clone()
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chat::{ContentBlock, StopReason};
    use crate::provider::{Capability, DynProvider};
    use crate::test_helpers::sample_response;
    use futures::StreamExt;
    use std::collections::HashSet;

    /// Metadata shared by every test in this module.
    fn test_metadata() -> ProviderMetadata {
        ProviderMetadata {
            name: "mock".into(),
            model: "test-model".into(),
            context_window: 128_000,
            capabilities: HashSet::from([Capability::Tools, Capability::StructuredOutput]),
        }
    }

    // --- generate -----------------------------------------------------------

    #[tokio::test]
    async fn test_mock_generate_returns_queued() {
        let mock = MockProvider::new(test_metadata());
        let resp = sample_response("test");
        mock.queue_response(resp.clone());

        let result = mock.generate(&ChatParams::default()).await.unwrap();
        assert_eq!(result, resp);
    }

    #[tokio::test]
    async fn test_mock_generate_multiple_queued() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("first"));
        mock.queue_response(sample_response("second"));
        mock.queue_response(sample_response("third"));

        let r1 = mock.generate(&ChatParams::default()).await.unwrap();
        let r2 = mock.generate(&ChatParams::default()).await.unwrap();
        let r3 = mock.generate(&ChatParams::default()).await.unwrap();

        // Responses come back in the order they were queued.
        assert_eq!(r1.content, vec![ContentBlock::Text("first".into())]);
        assert_eq!(r2.content, vec![ContentBlock::Text("second".into())]);
        assert_eq!(r3.content, vec![ContentBlock::Text("third".into())]);
    }

    #[tokio::test]
    async fn test_mock_generate_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Auth("bad key".into()));

        let result = mock.generate(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_generate_mixed_queue() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));
        mock.queue_error(MockError::Timeout { elapsed_ms: 5000 });
        mock.queue_response(sample_response("ok again"));

        let r1 = mock.generate(&ChatParams::default()).await;
        let r2 = mock.generate(&ChatParams::default()).await;
        let r3 = mock.generate(&ChatParams::default()).await;

        assert!(r1.is_ok());
        assert!(r2.is_err());
        assert!(r3.is_ok());
    }

    #[tokio::test]
    #[should_panic(expected = "no queued responses")]
    async fn test_mock_generate_empty_queue_panics() {
        let mock = MockProvider::new(test_metadata());
        let _ = mock.generate(&ChatParams::default()).await;
    }

    // --- stream -------------------------------------------------------------

    #[tokio::test]
    async fn test_mock_stream_returns_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hello".into()),
            StreamEvent::TextDelta(" world".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 3);
        assert!(events.iter().all(Result::is_ok));
    }

    #[tokio::test]
    async fn test_mock_stream_error() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream_error(MockError::Auth("bad token".into()));

        let result = mock.stream(&ChatParams::default()).await;
        assert!(result.is_err());
        let err = result.err().unwrap();
        assert!(matches!(err, LlmError::Auth(_)));
    }

    #[tokio::test]
    async fn test_mock_stream_empty_events() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![]);

        let stream = mock.stream(&ChatParams::default()).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert!(events.is_empty());
    }

    // --- call recording -----------------------------------------------------

    #[tokio::test]
    async fn test_mock_records_calls() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_response(sample_response("b"));
        mock.queue_response(sample_response("c"));

        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;
        let _ = mock.generate(&ChatParams::default()).await;

        assert_eq!(mock.recorded_calls().len(), 3);
    }

    #[tokio::test]
    async fn test_mock_records_params_accurately() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("ok"));

        let params = ChatParams {
            temperature: Some(0.5),
            system: Some("be nice".into()),
            ..Default::default()
        };
        let _ = mock.generate(&params).await;

        let recorded = mock.recorded_calls();
        assert_eq!(recorded[0].temperature, Some(0.5));
        assert_eq!(recorded[0].system, Some("be nice".into()));
    }

    #[test]
    fn test_mock_metadata_returns_configured() {
        let meta = test_metadata();
        let mock = MockProvider::new(meta.clone());
        assert_eq!(Provider::metadata(&mock), meta);
    }

    #[tokio::test]
    async fn test_mock_concurrent_access() {
        let mock = Arc::new(MockProvider::new(test_metadata()));
        for _ in 0..10 {
            mock.queue_response(sample_response("ok"));
        }

        let mut handles = Vec::new();
        for _ in 0..10 {
            let m = mock.clone();
            handles.push(tokio::spawn(async move {
                m.generate(&ChatParams::default()).await.unwrap()
            }));
        }

        for h in handles {
            h.await.unwrap();
        }

        assert_eq!(mock.recorded_calls().len(), 10);
    }

    // --- DynProvider blanket impl --------------------------------------------

    #[tokio::test]
    async fn test_dyn_provider_blanket_impl() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("hello"));

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let result = dyn_provider.generate_boxed(&params).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_dyn_provider_error_propagation() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_error(MockError::Http {
            status: Some(http::StatusCode::TOO_MANY_REQUESTS),
            message: "rate limited".into(),
            retryable: true,
        });

        let dyn_provider: &dyn DynProvider = &mock;
        let result = dyn_provider.generate_boxed(&ChatParams::default()).await;
        assert!(result.is_err());
        assert!(matches!(result.unwrap_err(), LlmError::Http { .. }));
    }

    #[tokio::test]
    async fn test_dyn_provider_stream_blanket() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_stream(vec![
            StreamEvent::TextDelta("hi".into()),
            StreamEvent::Done {
                stop_reason: StopReason::EndTurn,
            },
        ]);

        let dyn_provider: &dyn DynProvider = &mock;
        let params = ChatParams::default();
        let stream = dyn_provider.stream_boxed(&params).await.unwrap();
        let events: Vec<_> = stream.collect().await;
        assert_eq!(events.len(), 2);
    }

    #[tokio::test]
    async fn test_dyn_provider_metadata_matches() {
        let mock = MockProvider::new(test_metadata());
        let dyn_provider: &dyn DynProvider = &mock;
        assert_eq!(Provider::metadata(&mock), dyn_provider.metadata());
    }

    #[tokio::test]
    async fn test_dyn_provider_boxed_storage() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("from box"));

        let boxed: Box<dyn DynProvider> = Box::new(mock);
        let result = boxed.generate_boxed(&ChatParams::default()).await.unwrap();
        assert_eq!(result.content, vec![ContentBlock::Text("from box".into())]);
    }

    // --- Debug / object safety ------------------------------------------------

    #[test]
    fn test_mock_provider_debug() {
        let mock = MockProvider::new(test_metadata());
        mock.queue_response(sample_response("a"));
        mock.queue_stream(vec![StreamEvent::TextDelta("hi".into())]);

        let debug = format!("{mock:?}");
        assert!(debug.contains("MockProvider"));
        assert!(debug.contains("queued_responses: 1"));
        assert!(debug.contains("queued_streams: 1"));
        assert!(debug.contains("recorded_calls: 0"));
    }

    #[test]
    fn test_provider_is_object_safe() {
        // Compile-time check only: these signatures are rejected by the
        // compiler if `DynProvider` cannot be used as a trait object.
        let f1: fn(&dyn DynProvider) = |_| {};
        let f2: fn(Box<dyn DynProvider>) = |_| {};
        let _ = (f1, f2);
    }

    // --- MockError conversion -------------------------------------------------

    #[tokio::test]
    async fn test_mock_error_into_llm_error_all_variants() {
        // One entry per `MockError` variant, paired with the variant name
        // expected to appear in the resulting error's `Debug` output.
        let variants: Vec<(MockError, &str)> = vec![
            (
                MockError::Http {
                    status: Some(http::StatusCode::INTERNAL_SERVER_ERROR),
                    message: "boom".into(),
                    retryable: false,
                },
                "Http",
            ),
            (MockError::Auth("denied".into()), "Auth"),
            (MockError::InvalidRequest("bad".into()), "InvalidRequest"),
            (
                MockError::Provider {
                    code: "e1".into(),
                    message: "fail".into(),
                    retryable: false,
                },
                "Provider",
            ),
            (MockError::Timeout { elapsed_ms: 1 }, "Timeout"),
            (
                MockError::ResponseFormat {
                    message: "bad json".into(),
                    raw: "{}".into(),
                },
                "ResponseFormat",
            ),
            (
                MockError::SchemaValidation {
                    message: "missing field".into(),
                    schema: serde_json::json!({"type": "object"}),
                    actual: serde_json::json!(42),
                },
                "SchemaValidation",
            ),
            (
                MockError::RetryExhausted {
                    attempts: 3,
                    last_error_message: "timed out".into(),
                },
                "RetryExhausted",
            ),
        ];

        for (mock_err, label) in variants {
            let mock = MockProvider::new(test_metadata());
            mock.queue_error(mock_err);
            let result = mock.generate(&ChatParams::default()).await;
            assert!(result.is_err(), "{label} should produce error");
            let err = result.unwrap_err();
            let debug = format!("{err:?}");
            assert!(
                debug.contains(label),
                "expected {label} in error debug: {debug}"
            );
        }
    }
}