1use std::collections::VecDeque;
7use std::pin::Pin;
8use std::sync::{Mutex, MutexGuard};
9
10use futures::Stream;
11
12use crate::error::PeError;
13use crate::llm::{LlmProvider, LlmResponse, StreamChunk, ToolSchema};
14use crate::message::{AiMessage, Message, MessageContent, ToolCall};
15
16#[derive(Debug, Clone)]
18enum MockResponse {
19 Text(String),
21 ToolCall {
23 tool_name: String,
24 args: serde_json::Value,
25 },
26 Error(PeError),
28}
29
30pub struct MockProvider {
54 responses: Mutex<VecDeque<MockResponse>>,
55 embed_response: Vec<f32>,
56}
57
58impl MockProvider {
59 fn responses_guard(&self) -> MutexGuard<'_, VecDeque<MockResponse>> {
60 match self.responses.lock() {
61 Ok(guard) => guard,
62 Err(poisoned) => poisoned.into_inner(),
63 }
64 }
65
66 pub fn new() -> Self {
68 Self {
69 responses: Mutex::new(VecDeque::new()),
70 embed_response: vec![0.0; 128], }
72 }
73
74 #[must_use = "builder methods return the modified builder"]
76 pub fn respond_with(self, text: impl Into<String>) -> Self {
77 self.responses_guard()
78 .push_back(MockResponse::Text(text.into()));
79 self
80 }
81
82 #[must_use = "builder methods return the modified builder"]
84 pub fn respond_with_tool_call(
85 self,
86 tool_name: impl Into<String>,
87 args: serde_json::Value,
88 ) -> Self {
89 self.responses_guard().push_back(MockResponse::ToolCall {
90 tool_name: tool_name.into(),
91 args,
92 });
93 self
94 }
95
96 #[must_use = "builder methods return the modified builder"]
98 pub fn respond_with_error(self, err: PeError) -> Self {
99 self.responses_guard().push_back(MockResponse::Error(err));
100 self
101 }
102
103 #[must_use = "builder methods return the modified builder"]
105 pub fn with_embedding(mut self, embedding: Vec<f32>) -> Self {
106 self.embed_response = embedding;
107 self
108 }
109
110 pub fn remaining(&self) -> usize {
112 self.responses_guard().len()
113 }
114
115 fn next_response(&self) -> Result<MockResponse, PeError> {
116 self.responses_guard()
117 .pop_front()
118 .ok_or(PeError::MockProviderExhausted)
119 }
120
121 fn mock_response_to_llm(resp: MockResponse) -> Result<LlmResponse, PeError> {
122 match resp {
123 MockResponse::Text(text) => Ok(LlmResponse {
124 message: AiMessage {
125 content: MessageContent::Text(text),
126 tool_calls: vec![],
127 invalid_tool_calls: vec![],
128 usage_metadata: None,
129 response_metadata: Default::default(),
130 id: None,
131 },
132 provider_metadata: Default::default(),
133 }),
134 MockResponse::ToolCall { tool_name, args } => Ok(LlmResponse {
135 message: AiMessage {
136 content: MessageContent::Text(String::new()),
137 tool_calls: vec![ToolCall {
138 id: format!("call_{}", tool_name),
139 name: tool_name,
140 args,
141 }],
142 invalid_tool_calls: vec![],
143 usage_metadata: None,
144 response_metadata: Default::default(),
145 id: None,
146 },
147 provider_metadata: Default::default(),
148 }),
149 MockResponse::Error(e) => Err(e),
150 }
151 }
152}
153
154impl Default for MockProvider {
155 fn default() -> Self {
156 Self::new()
157 }
158}
159
160impl LlmProvider for MockProvider {
161 fn complete(
162 &self,
163 _messages: &[Message],
164 _tools: &[ToolSchema],
165 ) -> Pin<Box<dyn std::future::Future<Output = Result<LlmResponse, PeError>> + Send + '_>> {
166 Box::pin(async move {
167 let resp = self.next_response()?;
168 Self::mock_response_to_llm(resp)
169 })
170 }
171
172 fn stream(&self, _messages: &[Message], _tools: &[ToolSchema]) -> crate::llm::StreamFuture<'_> {
173 Box::pin(async move {
174 let resp = self.next_response()?;
175 let llm_resp = Self::mock_response_to_llm(resp)?;
176
177 let text = llm_resp.message.content.as_text().unwrap_or("").to_string();
179 let chunks = vec![StreamChunk::Token(text), StreamChunk::Done(llm_resp)];
180
181 Ok(Box::pin(futures::stream::iter(chunks))
182 as Pin<Box<dyn Stream<Item = StreamChunk> + Send>>)
183 })
184 }
185
186 fn embed(
187 &self,
188 _text: &str,
189 ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<f32>, PeError>> + Send + '_>> {
190 let embedding = self.embed_response.clone();
191 Box::pin(async move { Ok(embedding) })
192 }
193
194 fn provider_name(&self) -> &'static str {
195 "mock"
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_text_response() {
        let provider = MockProvider::new().respond_with("Hello, world!");

        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("Hello, world!"));
        assert!(resp.message.tool_calls.is_empty());
    }

    #[tokio::test]
    async fn test_tool_call_response() {
        let args = serde_json::json!({ "query": "rust async" });
        let provider = MockProvider::new().respond_with_tool_call("web_search", args.clone());

        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.tool_calls.len(), 1);
        assert_eq!(resp.message.tool_calls[0].name, "web_search");
        // Pin the deterministic id format and argument pass-through.
        assert_eq!(resp.message.tool_calls[0].id, "call_web_search");
        assert_eq!(resp.message.tool_calls[0].args, args);
        assert_eq!(resp.message.content.as_text(), Some(""));
    }

    #[tokio::test]
    async fn test_multiple_responses_fifo() {
        let provider = MockProvider::new()
            .respond_with("first")
            .respond_with("second")
            .respond_with("third");

        let r1 = provider.complete(&[], &[]).await.unwrap();
        let r2 = provider.complete(&[], &[]).await.unwrap();
        let r3 = provider.complete(&[], &[]).await.unwrap();

        assert_eq!(r1.message.content.as_text(), Some("first"));
        assert_eq!(r2.message.content.as_text(), Some("second"));
        assert_eq!(r3.message.content.as_text(), Some("third"));
    }

    #[tokio::test]
    async fn test_exhausted_queue_returns_error() {
        let provider = MockProvider::new().respond_with("only one");

        let _ = provider.complete(&[], &[]).await.unwrap();
        let err = provider.complete(&[], &[]).await.unwrap_err();

        assert!(matches!(err, PeError::MockProviderExhausted));
    }

    #[tokio::test]
    async fn test_error_response() {
        let provider = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "rate limited".into(),
        });

        let err = provider.complete(&[], &[]).await.unwrap_err();
        assert!(matches!(err, PeError::LlmProvider { .. }));
    }

    #[tokio::test]
    async fn test_stream_emits_token_then_done() {
        use futures::StreamExt;

        let provider = MockProvider::new().respond_with("streamed");

        let stream = provider.stream(&[], &[]).await.unwrap();
        let chunks: Vec<StreamChunk> = stream.collect().await;

        assert_eq!(chunks.len(), 2);
        assert!(matches!(&chunks[0], StreamChunk::Token(t) if t == "streamed"));
        assert!(matches!(
            &chunks[1],
            StreamChunk::Done(resp) if resp.message.content.as_text() == Some("streamed")
        ));
        assert_eq!(provider.remaining(), 0);
    }

    #[tokio::test]
    async fn test_stream_propagates_queued_error() {
        let provider = MockProvider::new().respond_with_error(PeError::LlmProvider {
            details: "boom".into(),
        });

        // The stream type is not `Debug`, so match instead of `unwrap_err`.
        let result = provider.stream(&[], &[]).await;
        assert!(matches!(result, Err(PeError::LlmProvider { .. })));
    }

    #[tokio::test]
    async fn test_stream_on_empty_queue_is_exhausted() {
        let provider = MockProvider::new();

        let result = provider.stream(&[], &[]).await;
        assert!(matches!(result, Err(PeError::MockProviderExhausted)));
    }

    #[tokio::test]
    async fn test_embed_returns_configured_vector() {
        let provider = MockProvider::new().with_embedding(vec![1.0, 2.0, 3.0]);

        let embedding = provider.embed("test text").await.unwrap();
        assert_eq!(embedding, vec![1.0, 2.0, 3.0]);
    }

    #[tokio::test]
    async fn test_default_matches_new() {
        let provider = MockProvider::default();

        assert_eq!(provider.remaining(), 0);
        assert_eq!(provider.embed("anything").await.unwrap(), vec![0.0; 128]);
    }

    #[test]
    fn test_provider_name() {
        assert_eq!(MockProvider::new().provider_name(), "mock");
    }

    #[tokio::test]
    async fn test_remaining_count() {
        let provider = MockProvider::new().respond_with("a").respond_with("b");

        assert_eq!(provider.remaining(), 2);
        let _ = provider.complete(&[], &[]).await;
        assert_eq!(provider.remaining(), 1);
    }

    #[test]
    fn poisoned_queue_lock_is_recovered() {
        let provider = MockProvider::new().respond_with("hello");

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = provider.responses.lock().unwrap();
            panic!("poison mock provider");
        }));
        assert!(result.is_err());

        assert_eq!(provider.remaining(), 1);
    }

    #[tokio::test]
    async fn test_object_safety() {
        let provider: Box<dyn LlmProvider> = Box::new(MockProvider::new().respond_with("boxed"));
        let resp = provider.complete(&[], &[]).await.unwrap();
        assert_eq!(resp.message.content.as_text(), Some("boxed"));
    }
}