Skip to main content

cognate_core/
mock.rs

//! Mock provider for use in tests.
//!
//! [`MockProvider`] implements [`Provider`] using queued responses (or a
//! built-in stub when nothing is queued) so unit tests can run without real
//! API calls.
//!
//! # Example
//!
//! ```rust
//! use cognate_core::{MockProvider, Provider, Request};
//!
//! # #[tokio::main(flavor = "current_thread")]
//! # async fn main() {
//! let provider = MockProvider::new();
//! let req = Request::new().with_model("test");
//! let response = provider.complete(req).await.unwrap();
//! assert_eq!(response.content(), "Mock response");
//! # }
//! ```
19
use std::collections::VecDeque;
use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use futures::stream::{self, BoxStream, StreamExt};

use crate::{Chunk, Choice, Delta, Message, Provider, Request, Response, Result};
24
25/// A mock [`Provider`] for testing.
26///
27/// Responses are consumed from an internal queue in FIFO order.
28/// If the queue is empty a default stub response/chunk is returned.
29#[derive(Debug, Clone)]
30pub struct MockProvider {
31    responses: Arc<Mutex<Vec<Response>>>,
32    chunks: Arc<Mutex<Vec<Vec<Chunk>>>>,
33}
34
35impl Default for MockProvider {
36    fn default() -> Self {
37        Self::new()
38    }
39}
40
41impl MockProvider {
42    /// Create a new, empty mock provider.
43    pub fn new() -> Self {
44        Self {
45            responses: Arc::new(Mutex::new(Vec::new())),
46            chunks: Arc::new(Mutex::new(Vec::new())),
47        }
48    }
49
50    /// Enqueue a response to be returned by the next call to [`complete`](Self::complete).
51    pub fn push_response(&self, response: Response) {
52        self.responses.lock().unwrap().push(response);
53    }
54
55    /// Enqueue a sequence of chunks to be emitted by the next call to [`stream`](Self::stream).
56    pub fn push_stream(&self, stream_chunks: Vec<Chunk>) {
57        self.chunks.lock().unwrap().push(stream_chunks);
58    }
59}
60
61#[async_trait]
62impl Provider for MockProvider {
63    async fn complete(&self, _req: Request) -> Result<Response> {
64        let mut responses = self.responses.lock().unwrap();
65        if responses.is_empty() {
66            Ok(Response {
67                id: "mock-id".to_string(),
68                model: "mock".to_string(),
69                choices: vec![Choice {
70                    index: 0,
71                    message: Message::assistant("Mock response"),
72                    finish_reason: Some("stop".to_string()),
73                }],
74                usage: None,
75                created: None,
76            })
77        } else {
78            Ok(responses.remove(0))
79        }
80    }
81
82    async fn stream(&self, _req: Request) -> Result<BoxStream<'static, Result<Chunk>>> {
83        let mut chunks = self.chunks.lock().unwrap();
84        let stream_chunks = if chunks.is_empty() {
85            vec![Chunk {
86                id: "mock-id".to_string(),
87                model: "mock".to_string(),
88                delta: Delta {
89                    role: None,
90                    content: "Mock chunk".to_string(),
91                },
92                finish_reason: Some("stop".to_string()),
93            }]
94        } else {
95            chunks.remove(0)
96        };
97
98        let s = stream::iter(stream_chunks.into_iter().map(Ok)).boxed();
99        Ok(s)
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Build a single-choice response whose content is `text`.
    fn response(id: &str, text: &str) -> Response {
        Response {
            id: id.to_string(),
            model: "test".to_string(),
            choices: vec![Choice {
                index: 0,
                message: Message::assistant(text),
                finish_reason: Some("stop".to_string()),
            }],
            usage: None,
            created: None,
        }
    }

    /// Build a content chunk carrying `text`.
    fn chunk(text: &str) -> Chunk {
        Chunk {
            id: "c".to_string(),
            model: "test".to_string(),
            delta: Delta {
                role: None,
                content: text.to_string(),
            },
            finish_reason: None,
        }
    }

    /// Run one `stream` call to completion and collect each chunk's content.
    async fn stream_contents(provider: &MockProvider) -> Vec<String> {
        let req = Request::new().with_model("test");
        provider
            .stream(req)
            .await
            .unwrap()
            .map(|c| c.unwrap().content().to_string())
            .collect::<Vec<String>>()
            .await
    }

    #[tokio::test]
    async fn test_mock_provider_default_complete() {
        let provider = MockProvider::new();
        let req = Request::new().with_model("test");
        let response = provider.complete(req).await.unwrap();
        assert_eq!(response.content(), "Mock response");
    }

    #[tokio::test]
    async fn test_mock_provider_queued_complete() {
        let provider = MockProvider::new();
        provider.push_response(response("r1", "queued"));
        let req = Request::new().with_model("test");
        let response = provider.complete(req).await.unwrap();
        assert_eq!(response.content(), "queued");
    }

    #[tokio::test]
    async fn test_mock_provider_complete_fifo_then_default() {
        let provider = MockProvider::new();
        provider.push_response(response("r1", "first"));
        provider.push_response(response("r2", "second"));
        // Queued responses come back oldest-first, then the stub once drained.
        for expected in ["first", "second", "Mock response"] {
            let req = Request::new().with_model("test");
            let got = provider.complete(req).await.unwrap();
            assert_eq!(got.content(), expected);
        }
    }

    #[tokio::test]
    async fn test_mock_provider_stream() {
        let provider = MockProvider::new();
        let req = Request::new().with_model("test");
        let mut stream = provider.stream(req).await.unwrap();
        let chunk = stream.next().await.unwrap().unwrap();
        assert_eq!(chunk.content(), "Mock chunk");
    }

    #[tokio::test]
    async fn test_mock_provider_queued_streams_fifo_then_default() {
        let provider = MockProvider::new();
        provider.push_stream(vec![chunk("a"), chunk("b")]);
        provider.push_stream(vec![chunk("c")]);
        assert_eq!(stream_contents(&provider).await, ["a", "b"]);
        assert_eq!(stream_contents(&provider).await, ["c"]);
        assert_eq!(stream_contents(&provider).await, ["Mock chunk"]);
    }

    #[tokio::test]
    async fn test_mock_provider_clones_share_queue() {
        let provider = MockProvider::new();
        let handle = provider.clone();
        handle.push_response(response("r1", "shared"));
        let req = Request::new().with_model("test");
        let got = provider.complete(req).await.unwrap();
        assert_eq!(got.content(), "shared");
    }
}