1use std::{
4 collections::VecDeque,
5 sync::{Arc, Mutex, MutexGuard},
6};
7
8use crate::{
9 OneOrMany,
10 completion::{
11 AssistantContent, CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
12 Usage,
13 },
14 message::{ToolCall, ToolFunction},
15 streaming::{StreamingCompletionResponse, StreamingResult},
16};
17
18use super::streaming::{MockResponse, MockStreamEvent};
19
/// Scripted failure that a [`MockTurn`] reports instead of a successful response.
///
/// When the turn is played back, `Provider` is mapped to `CompletionError::ProviderError`
/// and `Request` to `CompletionError::RequestError`.
///
/// The type also implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// used with `?`, boxed as `dyn Error`, or printed on its own.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum MockError {
    /// Failure reported by the (mock) provider; carries the error message.
    Provider(String),
    /// Failure attributed to the request itself; carries the error message.
    Request(String),
}

impl std::fmt::Display for MockError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Provider(message) => write!(f, "mock provider error: {message}"),
            Self::Request(message) => write!(f, "mock request error: {message}"),
        }
    }
}

// No underlying cause is stored, so the default `source()` (returning `None`) is correct.
impl std::error::Error for MockError {}
28
29impl MockError {
30 pub fn provider(message: impl Into<String>) -> Self {
32 Self::Provider(message.into())
33 }
34
35 pub fn request(message: impl Into<String>) -> Self {
37 Self::Request(message.into())
38 }
39
40 pub(crate) fn into_completion_error(self) -> CompletionError {
41 match self {
42 Self::Provider(message) => CompletionError::ProviderError(message),
43 Self::Request(message) => CompletionError::RequestError(message.into()),
44 }
45 }
46}
47
48#[derive(Clone, Debug)]
50pub struct MockTurn {
51 response: Result<MockTurnResponse, MockError>,
52}
53
54#[derive(Clone, Debug)]
55struct MockTurnResponse {
56 choice: OneOrMany<AssistantContent>,
57 usage: Usage,
58 message_id: Option<String>,
59}
60
61impl MockTurn {
62 pub fn text(text: impl Into<String>) -> Self {
64 Self::from_content(AssistantContent::text(text.into()))
65 }
66
67 pub fn tool_call(
69 id: impl Into<String>,
70 name: impl Into<String>,
71 arguments: serde_json::Value,
72 ) -> Self {
73 Self::from_content(AssistantContent::ToolCall(ToolCall::new(
74 id.into(),
75 ToolFunction::new(name.into(), arguments),
76 )))
77 }
78
79 pub fn error(message: impl Into<String>) -> Self {
81 Self {
82 response: Err(MockError::provider(message)),
83 }
84 }
85
86 pub fn request_error(message: impl Into<String>) -> Self {
88 Self {
89 response: Err(MockError::request(message)),
90 }
91 }
92
93 pub fn from_content(content: AssistantContent) -> Self {
95 Self {
96 response: Ok(MockTurnResponse {
97 choice: OneOrMany::one(content),
98 usage: Usage::new(),
99 message_id: None,
100 }),
101 }
102 }
103
104 pub fn from_contents(
106 content: impl IntoIterator<Item = AssistantContent>,
107 ) -> Result<Self, crate::one_or_many::EmptyListError> {
108 Ok(Self {
109 response: Ok(MockTurnResponse {
110 choice: OneOrMany::many(content)?,
111 usage: Usage::new(),
112 message_id: None,
113 }),
114 })
115 }
116
117 pub fn with_call_id(mut self, call_id: impl Into<String>) -> Self {
119 let call_id = call_id.into();
120 if let Ok(response) = &mut self.response {
121 for content in response.choice.iter_mut() {
122 if let AssistantContent::ToolCall(tool_call) = content {
123 tool_call.call_id = Some(call_id);
124 break;
125 }
126 }
127 }
128 self
129 }
130
131 pub fn with_usage(mut self, usage: Usage) -> Self {
133 if let Ok(response) = &mut self.response {
134 response.usage = usage;
135 }
136 self
137 }
138
139 pub fn with_message_id(mut self, message_id: impl Into<String>) -> Self {
141 if let Ok(response) = &mut self.response {
142 response.message_id = Some(message_id.into());
143 }
144 self
145 }
146
147 fn into_completion_response(self) -> Result<CompletionResponse<MockResponse>, CompletionError> {
148 let response = self.response.map_err(MockError::into_completion_error)?;
149 Ok(CompletionResponse {
150 choice: response.choice,
151 usage: response.usage,
152 raw_response: MockResponse::with_usage(response.usage),
153 message_id: response.message_id,
154 })
155 }
156}
157
158#[derive(Default)]
159struct MockCompletionModelState {
160 turns: Mutex<VecDeque<MockTurn>>,
161 stream_turns: Mutex<VecDeque<Vec<MockStreamEvent>>>,
162 requests: Mutex<Vec<CompletionRequest>>,
163}
164
165#[derive(Clone, Default)]
171pub struct MockCompletionModel {
172 state: Arc<MockCompletionModelState>,
173}
174
175impl MockCompletionModel {
176 pub fn new(turns: impl IntoIterator<Item = MockTurn>) -> Self {
178 Self::from_turns(turns)
179 }
180
181 pub fn text(text: impl Into<String>) -> Self {
183 Self::from_turns([MockTurn::text(text)])
184 }
185
186 pub fn from_turns(turns: impl IntoIterator<Item = MockTurn>) -> Self {
188 Self {
189 state: Arc::new(MockCompletionModelState {
190 turns: Mutex::new(turns.into_iter().collect()),
191 stream_turns: Mutex::new(VecDeque::new()),
192 requests: Mutex::new(Vec::new()),
193 }),
194 }
195 }
196
197 pub fn from_stream_turns(
199 stream_turns: impl IntoIterator<Item = impl IntoIterator<Item = MockStreamEvent>>,
200 ) -> Self {
201 Self {
202 state: Arc::new(MockCompletionModelState {
203 turns: Mutex::new(VecDeque::new()),
204 stream_turns: Mutex::new(
205 stream_turns
206 .into_iter()
207 .map(|turn| turn.into_iter().collect())
208 .collect(),
209 ),
210 requests: Mutex::new(Vec::new()),
211 }),
212 }
213 }
214
215 pub fn requests(&self) -> Vec<CompletionRequest> {
217 self.requests_guard().clone()
218 }
219
220 pub fn request_count(&self) -> usize {
222 self.requests_guard().len()
223 }
224
225 fn record_request(&self, request: CompletionRequest) {
226 self.requests_guard().push(request);
227 }
228
229 fn next_turn(&self) -> Option<MockTurn> {
230 self.turns_guard().pop_front()
231 }
232
233 fn next_stream_turn(&self) -> Option<Vec<MockStreamEvent>> {
234 self.stream_turns_guard().pop_front()
235 }
236
237 fn turns_guard(&self) -> MutexGuard<'_, VecDeque<MockTurn>> {
238 match self.state.turns.lock() {
239 Ok(guard) => guard,
240 Err(poisoned) => poisoned.into_inner(),
241 }
242 }
243
244 fn stream_turns_guard(&self) -> MutexGuard<'_, VecDeque<Vec<MockStreamEvent>>> {
245 match self.state.stream_turns.lock() {
246 Ok(guard) => guard,
247 Err(poisoned) => poisoned.into_inner(),
248 }
249 }
250
251 fn requests_guard(&self) -> MutexGuard<'_, Vec<CompletionRequest>> {
252 match self.state.requests.lock() {
253 Ok(guard) => guard,
254 Err(poisoned) => poisoned.into_inner(),
255 }
256 }
257}
258
259impl CompletionModel for MockCompletionModel {
260 type Response = MockResponse;
261 type StreamingResponse = MockResponse;
262 type Client = ();
263
264 fn make(_: &Self::Client, _: impl Into<String>) -> Self {
265 Self::default()
266 }
267
268 async fn completion(
269 &self,
270 request: CompletionRequest,
271 ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
272 self.record_request(request);
273 let Some(turn) = self.next_turn() else {
274 return Err(CompletionError::ProviderError(
275 "mock completion model has no scripted completion turn".to_string(),
276 ));
277 };
278
279 turn.into_completion_response()
280 }
281
282 async fn stream(
283 &self,
284 request: CompletionRequest,
285 ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
286 self.record_request(request);
287 let Some(events) = self.next_stream_turn() else {
288 return Err(CompletionError::ProviderError(
289 "mock completion model has no scripted streaming turn".to_string(),
290 ));
291 };
292
293 let stream = async_stream::stream! {
294 for event in events {
295 yield event.into_raw_choice();
296 }
297 };
298 let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream);
299 Ok(StreamingCompletionResponse::stream(stream))
300 }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        completion::GetTokenUsage,
        message::Message,
        streaming::{StreamedAssistantContent, ToolCallDeltaContent},
    };
    use futures::StreamExt;

    /// Builds a minimal request whose only chat message is the user `prompt`.
    fn request(prompt: &str) -> CompletionRequest {
        CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(Message::user(prompt)),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    #[tokio::test]
    async fn completion_consumes_scripted_turns_and_records_requests() {
        let model = MockCompletionModel::new([
            MockTurn::text("first").with_message_id("msg_1"),
            MockTurn::tool_call("tool_1", "calculator", serde_json::json!({"x": 1}))
                .with_call_id("call_1"),
        ]);

        let first = model
            .completion(request("hello"))
            .await
            .expect("first scripted turn should succeed");
        assert_eq!(first.message_id.as_deref(), Some("msg_1"));
        assert!(matches!(
            first.choice.first(),
            AssistantContent::Text(text) if text.text == "first"
        ));

        let second = model
            .completion(request("use a tool"))
            .await
            .expect("second scripted turn should succeed");
        assert!(matches!(
            second.choice.first(),
            AssistantContent::ToolCall(tool_call)
                if tool_call.id == "tool_1"
                    && tool_call.call_id.as_deref() == Some("call_1")
        ));

        assert_eq!(model.request_count(), 2);
        assert_eq!(model.requests().len(), 2);
    }

    #[tokio::test]
    async fn missing_completion_turn_returns_provider_error() {
        let model = MockCompletionModel::default();

        let err = model
            .completion(request("hello"))
            .await
            .expect_err("missing turn should error");

        assert!(matches!(
            err,
            CompletionError::ProviderError(message)
                if message.contains("no scripted completion turn")
        ));
    }

    #[tokio::test]
    async fn error_turn_returns_provider_error_and_records_request() {
        let model = MockCompletionModel::new([MockTurn::error("scripted failure")]);

        let err = model
            .completion(request("hello"))
            .await
            .expect_err("scripted error turn should fail");

        assert!(matches!(
            err,
            CompletionError::ProviderError(message) if message == "scripted failure"
        ));
        // The request is recorded even though the turn fails.
        assert_eq!(model.request_count(), 1);
    }

    #[tokio::test]
    async fn request_error_turn_returns_request_error() {
        let model = MockCompletionModel::new([MockTurn::request_error("bad request")]);

        let err = model
            .completion(request("hello"))
            .await
            .expect_err("scripted request error should fail");

        assert!(matches!(
            err,
            CompletionError::RequestError(source) if source.to_string() == "bad request"
        ));
    }

    #[tokio::test]
    async fn text_model_scripts_exactly_one_turn() {
        let model = MockCompletionModel::text("only");

        let reply = model
            .completion(request("first"))
            .await
            .expect("the single scripted turn should succeed");
        assert!(matches!(
            reply.choice.first(),
            AssistantContent::Text(text) if text.text == "only"
        ));

        assert!(model.completion(request("second")).await.is_err());
        assert_eq!(model.request_count(), 2);
    }

    #[tokio::test]
    async fn clones_share_turns_and_request_log() {
        let model = MockCompletionModel::new([MockTurn::text("shared")]);
        let clone = model.clone();

        clone
            .completion(request("via clone"))
            .await
            .expect("clone should consume the shared turn");

        assert_eq!(model.request_count(), 1);
        // The clone consumed the only turn, so the original has none left.
        assert!(model.completion(request("again")).await.is_err());
        assert_eq!(clone.request_count(), 2);
    }

    #[test]
    fn from_contents_rejects_empty_input() {
        assert!(MockTurn::from_contents(Vec::<AssistantContent>::new()).is_err());
    }

    #[tokio::test]
    async fn stream_yields_scripted_events_and_records_requests() {
        let model = MockCompletionModel::from_stream_turns([[
            MockStreamEvent::message_id("msg_stream"),
            MockStreamEvent::text("hel"),
            MockStreamEvent::text("lo"),
            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "calculator"),
            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
            MockStreamEvent::tool_call("tool_1", "calculator", serde_json::json!({"x": 1}))
                .with_call_id("call_1"),
            MockStreamEvent::final_response_with_total_tokens(7),
        ]]);

        let mut stream = model
            .stream(request("stream"))
            .await
            .expect("stream should be created");

        let mut text = String::new();
        let mut saw_name_delta = false;
        let mut saw_arguments_delta = false;
        let mut saw_tool_call = false;
        let mut saw_final = false;

        while let Some(item) = stream.next().await {
            match item.expect("stream event should succeed") {
                StreamedAssistantContent::Text(chunk) => text.push_str(&chunk.text),
                StreamedAssistantContent::ToolCallDelta { content, .. } => match content {
                    ToolCallDeltaContent::Name(name) => {
                        saw_name_delta = name == "calculator";
                    }
                    ToolCallDeltaContent::Delta(arguments) => {
                        saw_arguments_delta = arguments == "{\"x\":1}";
                    }
                },
                StreamedAssistantContent::ToolCall { tool_call, .. } => {
                    saw_tool_call = tool_call.call_id.as_deref() == Some("call_1");
                }
                StreamedAssistantContent::Final(response) => {
                    saw_final = matches!(
                        response.token_usage(),
                        Some(Usage {
                            total_tokens: 7,
                            ..
                        })
                    );
                }
                _ => {}
            }
        }

        assert_eq!(text, "hello");
        assert!(saw_name_delta);
        assert!(saw_arguments_delta);
        assert!(saw_tool_call);
        assert!(saw_final);
        assert_eq!(stream.message_id.as_deref(), Some("msg_stream"));
        assert_eq!(model.request_count(), 1);
    }

    #[tokio::test]
    async fn missing_stream_turn_returns_provider_error() {
        let model = MockCompletionModel::default();

        let result = model.stream(request("stream")).await;

        assert!(matches!(
            result,
            Err(CompletionError::ProviderError(message))
                if message.contains("no scripted streaming turn")
        ));
        assert_eq!(model.request_count(), 1);
    }

    #[tokio::test]
    async fn stream_error_event_is_returned() {
        let model = MockCompletionModel::from_stream_turns([[MockStreamEvent::error("boom")]]);
        let mut stream = model
            .stream(request("stream"))
            .await
            .expect("stream should be created");

        let err = stream
            .next()
            .await
            .expect("stream should yield one event")
            .expect_err("scripted event should error");

        assert!(matches!(
            err,
            CompletionError::ProviderError(message) if message == "boom"
        ));
    }
}