1use async_trait::async_trait;
21use crate::{Chunk, Choice, Delta, Message, Provider, Request, Response, Result};
22use futures::stream::{self, BoxStream, StreamExt};
23use std::sync::{Arc, Mutex};
24
25#[derive(Debug, Clone)]
30pub struct MockProvider {
31 responses: Arc<Mutex<Vec<Response>>>,
32 chunks: Arc<Mutex<Vec<Vec<Chunk>>>>,
33}
34
35impl Default for MockProvider {
36 fn default() -> Self {
37 Self::new()
38 }
39}
40
41impl MockProvider {
42 pub fn new() -> Self {
44 Self {
45 responses: Arc::new(Mutex::new(Vec::new())),
46 chunks: Arc::new(Mutex::new(Vec::new())),
47 }
48 }
49
50 pub fn push_response(&self, response: Response) {
52 self.responses.lock().unwrap().push(response);
53 }
54
55 pub fn push_stream(&self, stream_chunks: Vec<Chunk>) {
57 self.chunks.lock().unwrap().push(stream_chunks);
58 }
59}
60
61#[async_trait]
62impl Provider for MockProvider {
63 async fn complete(&self, _req: Request) -> Result<Response> {
64 let mut responses = self.responses.lock().unwrap();
65 if responses.is_empty() {
66 Ok(Response {
67 id: "mock-id".to_string(),
68 model: "mock".to_string(),
69 choices: vec![Choice {
70 index: 0,
71 message: Message::assistant("Mock response"),
72 finish_reason: Some("stop".to_string()),
73 }],
74 usage: None,
75 created: None,
76 })
77 } else {
78 Ok(responses.remove(0))
79 }
80 }
81
82 async fn stream(&self, _req: Request) -> Result<BoxStream<'static, Result<Chunk>>> {
83 let mut chunks = self.chunks.lock().unwrap();
84 let stream_chunks = if chunks.is_empty() {
85 vec![Chunk {
86 id: "mock-id".to_string(),
87 model: "mock".to_string(),
88 delta: Delta {
89 role: None,
90 content: "Mock chunk".to_string(),
91 },
92 finish_reason: Some("stop".to_string()),
93 }]
94 } else {
95 chunks.remove(0)
96 };
97
98 let s = stream::iter(stream_chunks.into_iter().map(Ok)).boxed();
99 Ok(s)
100 }
101}
102
#[cfg(test)]
mod tests {
    use super::*;
    use futures::StreamExt;

    /// Builds the request used by every test. The mock never reads its
    /// request, so the contents only need to be well-formed.
    fn test_request() -> Request {
        Request::new().with_model("test")
    }

    /// With nothing queued, `complete` falls back to the canned response.
    #[tokio::test]
    async fn test_mock_provider_default_complete() {
        let mock = MockProvider::new();
        let reply = mock.complete(test_request()).await.unwrap();
        assert_eq!(reply.content(), "Mock response");
    }

    /// A response pushed beforehand is returned instead of the fallback.
    #[tokio::test]
    async fn test_mock_provider_queued_complete() {
        let mock = MockProvider::new();
        let scripted_choice = Choice {
            index: 0,
            message: Message::assistant("queued"),
            finish_reason: Some("stop".to_string()),
        };
        mock.push_response(Response {
            id: "r1".to_string(),
            model: "test".to_string(),
            choices: vec![scripted_choice],
            usage: None,
            created: None,
        });

        let reply = mock.complete(test_request()).await.unwrap();
        assert_eq!(reply.content(), "queued");
    }

    /// With nothing queued, `stream` yields the canned chunk first.
    #[tokio::test]
    async fn test_mock_provider_stream() {
        let mock = MockProvider::new();
        let mut chunks = mock.stream(test_request()).await.unwrap();
        let first = chunks.next().await.unwrap().unwrap();
        assert_eq!(first.content(), "Mock chunk");
    }
}