Skip to main content

rig_core/test_utils/
completion.rs

1//! Completion helpers for deterministic agent-loop tests.
2
3use std::{
4    collections::VecDeque,
5    sync::{Arc, Mutex, MutexGuard},
6};
7
8use crate::{
9    OneOrMany,
10    completion::{
11        AssistantContent, CompletionError, CompletionModel, CompletionRequest, CompletionResponse,
12        Usage,
13    },
14    message::{ToolCall, ToolFunction},
15    streaming::{StreamingCompletionResponse, StreamingResult},
16};
17
18use super::streaming::{MockResponse, MockStreamEvent};
19
/// Scripted error returned by [`MockCompletionModel`].
///
/// Each variant carries the message text that is forwarded into the matching
/// [`CompletionError`] variant when the scripted turn is consumed.
#[derive(Clone, Debug)]
pub enum MockError {
    /// Provider error; surfaced as [`CompletionError::ProviderError`].
    Provider(String),
    /// Request construction error; surfaced as [`CompletionError::RequestError`].
    Request(String),
}
28
29impl MockError {
30    /// Create a provider error.
31    pub fn provider(message: impl Into<String>) -> Self {
32        Self::Provider(message.into())
33    }
34
35    /// Create a request error.
36    pub fn request(message: impl Into<String>) -> Self {
37        Self::Request(message.into())
38    }
39
40    pub(crate) fn into_completion_error(self) -> CompletionError {
41        match self {
42            Self::Provider(message) => CompletionError::ProviderError(message),
43            Self::Request(message) => CompletionError::RequestError(message.into()),
44        }
45    }
46}
47
/// A scripted non-streaming mock completion turn.
///
/// Holds either the successful response payload or the [`MockError`] that the
/// mock model returns when this turn is consumed.
#[derive(Clone, Debug)]
pub struct MockTurn {
    // `Ok` carries the scripted response; `Err` carries the scripted failure.
    response: Result<MockTurnResponse, MockError>,
}
53
/// Successful payload of a [`MockTurn`].
#[derive(Clone, Debug)]
struct MockTurnResponse {
    /// Assistant content items returned as the completion choice.
    choice: OneOrMany<AssistantContent>,
    /// Token usage reported for the turn.
    usage: Usage,
    /// Assistant message ID; `None` unless set through `MockTurn::with_message_id`.
    message_id: Option<String>,
}
60
61impl MockTurn {
62    /// Create a text response turn.
63    pub fn text(text: impl Into<String>) -> Self {
64        Self::from_content(AssistantContent::text(text.into()))
65    }
66
67    /// Create a tool-call response turn.
68    pub fn tool_call(
69        id: impl Into<String>,
70        name: impl Into<String>,
71        arguments: serde_json::Value,
72    ) -> Self {
73        Self::from_content(AssistantContent::ToolCall(ToolCall::new(
74            id.into(),
75            ToolFunction::new(name.into(), arguments),
76        )))
77    }
78
79    /// Create a provider-error response turn.
80    pub fn error(message: impl Into<String>) -> Self {
81        Self {
82            response: Err(MockError::provider(message)),
83        }
84    }
85
86    /// Create a request-error response turn.
87    pub fn request_error(message: impl Into<String>) -> Self {
88        Self {
89            response: Err(MockError::request(message)),
90        }
91    }
92
93    /// Create a response turn from one assistant content item.
94    pub fn from_content(content: AssistantContent) -> Self {
95        Self {
96            response: Ok(MockTurnResponse {
97                choice: OneOrMany::one(content),
98                usage: Usage::new(),
99                message_id: None,
100            }),
101        }
102    }
103
104    /// Create a response turn from multiple assistant content items.
105    pub fn from_contents(
106        content: impl IntoIterator<Item = AssistantContent>,
107    ) -> Result<Self, crate::one_or_many::EmptyListError> {
108        Ok(Self {
109            response: Ok(MockTurnResponse {
110                choice: OneOrMany::many(content)?,
111                usage: Usage::new(),
112                message_id: None,
113            }),
114        })
115    }
116
117    /// Attach a provider-specific call ID to a tool-call response turn.
118    pub fn with_call_id(mut self, call_id: impl Into<String>) -> Self {
119        let call_id = call_id.into();
120        if let Ok(response) = &mut self.response {
121            for content in response.choice.iter_mut() {
122                if let AssistantContent::ToolCall(tool_call) = content {
123                    tool_call.call_id = Some(call_id);
124                    break;
125                }
126            }
127        }
128        self
129    }
130
131    /// Override usage for this turn.
132    pub fn with_usage(mut self, usage: Usage) -> Self {
133        if let Ok(response) = &mut self.response {
134            response.usage = usage;
135        }
136        self
137    }
138
139    /// Set a provider-assigned assistant message ID for this turn.
140    pub fn with_message_id(mut self, message_id: impl Into<String>) -> Self {
141        if let Ok(response) = &mut self.response {
142            response.message_id = Some(message_id.into());
143        }
144        self
145    }
146
147    fn into_completion_response(self) -> Result<CompletionResponse<MockResponse>, CompletionError> {
148        let response = self.response.map_err(MockError::into_completion_error)?;
149        Ok(CompletionResponse {
150            choice: response.choice,
151            usage: response.usage,
152            raw_response: MockResponse::with_usage(response.usage),
153            message_id: response.message_id,
154        })
155    }
156}
157
/// Interior state held behind the `Arc` in [`MockCompletionModel`].
#[derive(Default)]
struct MockCompletionModelState {
    /// Remaining scripted non-streaming turns, consumed from the front.
    turns: Mutex<VecDeque<MockTurn>>,
    /// Remaining scripted streaming turns (one event list per turn), consumed from the front.
    stream_turns: Mutex<VecDeque<Vec<MockStreamEvent>>>,
    /// Requests recorded by `completion` and `stream`, in the order they were pushed.
    requests: Mutex<Vec<CompletionRequest>>,
}
164
/// A cloneable scripted [`CompletionModel`] for tests.
///
/// Each completion or stream call consumes exactly one scripted turn. If no turn
/// is available, the model returns [`CompletionError::ProviderError`] with a
/// clear message instead of repeating previous responses.
///
/// The state lives behind an [`Arc`], so clones share the same remaining turns
/// and the same recorded requests.
#[derive(Clone, Default)]
pub struct MockCompletionModel {
    // Shared by every clone of this model.
    state: Arc<MockCompletionModelState>,
}
174
175impl MockCompletionModel {
176    /// Create a mock model from scripted non-streaming turns.
177    pub fn new(turns: impl IntoIterator<Item = MockTurn>) -> Self {
178        Self::from_turns(turns)
179    }
180
181    /// Create a mock model that returns one text completion.
182    pub fn text(text: impl Into<String>) -> Self {
183        Self::from_turns([MockTurn::text(text)])
184    }
185
186    /// Create a mock model from scripted non-streaming turns.
187    pub fn from_turns(turns: impl IntoIterator<Item = MockTurn>) -> Self {
188        Self {
189            state: Arc::new(MockCompletionModelState {
190                turns: Mutex::new(turns.into_iter().collect()),
191                stream_turns: Mutex::new(VecDeque::new()),
192                requests: Mutex::new(Vec::new()),
193            }),
194        }
195    }
196
197    /// Create a mock model from scripted streaming turns.
198    pub fn from_stream_turns(
199        stream_turns: impl IntoIterator<Item = impl IntoIterator<Item = MockStreamEvent>>,
200    ) -> Self {
201        Self {
202            state: Arc::new(MockCompletionModelState {
203                turns: Mutex::new(VecDeque::new()),
204                stream_turns: Mutex::new(
205                    stream_turns
206                        .into_iter()
207                        .map(|turn| turn.into_iter().collect())
208                        .collect(),
209                ),
210                requests: Mutex::new(Vec::new()),
211            }),
212        }
213    }
214
215    /// Return cloned requests received by this model.
216    pub fn requests(&self) -> Vec<CompletionRequest> {
217        self.requests_guard().clone()
218    }
219
220    /// Return the number of requests received by this model.
221    pub fn request_count(&self) -> usize {
222        self.requests_guard().len()
223    }
224
225    fn record_request(&self, request: CompletionRequest) {
226        self.requests_guard().push(request);
227    }
228
229    fn next_turn(&self) -> Option<MockTurn> {
230        self.turns_guard().pop_front()
231    }
232
233    fn next_stream_turn(&self) -> Option<Vec<MockStreamEvent>> {
234        self.stream_turns_guard().pop_front()
235    }
236
237    fn turns_guard(&self) -> MutexGuard<'_, VecDeque<MockTurn>> {
238        match self.state.turns.lock() {
239            Ok(guard) => guard,
240            Err(poisoned) => poisoned.into_inner(),
241        }
242    }
243
244    fn stream_turns_guard(&self) -> MutexGuard<'_, VecDeque<Vec<MockStreamEvent>>> {
245        match self.state.stream_turns.lock() {
246            Ok(guard) => guard,
247            Err(poisoned) => poisoned.into_inner(),
248        }
249    }
250
251    fn requests_guard(&self) -> MutexGuard<'_, Vec<CompletionRequest>> {
252        match self.state.requests.lock() {
253            Ok(guard) => guard,
254            Err(poisoned) => poisoned.into_inner(),
255        }
256    }
257}
258
259impl CompletionModel for MockCompletionModel {
260    type Response = MockResponse;
261    type StreamingResponse = MockResponse;
262    type Client = ();
263
264    fn make(_: &Self::Client, _: impl Into<String>) -> Self {
265        Self::default()
266    }
267
268    async fn completion(
269        &self,
270        request: CompletionRequest,
271    ) -> Result<CompletionResponse<Self::Response>, CompletionError> {
272        self.record_request(request);
273        let Some(turn) = self.next_turn() else {
274            return Err(CompletionError::ProviderError(
275                "mock completion model has no scripted completion turn".to_string(),
276            ));
277        };
278
279        turn.into_completion_response()
280    }
281
282    async fn stream(
283        &self,
284        request: CompletionRequest,
285    ) -> Result<StreamingCompletionResponse<Self::StreamingResponse>, CompletionError> {
286        self.record_request(request);
287        let Some(events) = self.next_stream_turn() else {
288            return Err(CompletionError::ProviderError(
289                "mock completion model has no scripted streaming turn".to_string(),
290            ));
291        };
292
293        let stream = async_stream::stream! {
294            for event in events {
295                yield event.into_raw_choice();
296            }
297        };
298        let stream: StreamingResult<Self::StreamingResponse> = Box::pin(stream);
299        Ok(StreamingCompletionResponse::stream(stream))
300    }
301}
302
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{
        completion::GetTokenUsage,
        message::Message,
        streaming::{StreamedAssistantContent, ToolCallDeltaContent},
    };
    use futures::StreamExt;

    /// Build a minimal request whose chat history is a single user message.
    fn request(prompt: &str) -> CompletionRequest {
        CompletionRequest {
            model: None,
            preamble: None,
            chat_history: OneOrMany::one(Message::user(prompt)),
            documents: Vec::new(),
            tools: Vec::new(),
            temperature: None,
            max_tokens: None,
            tool_choice: None,
            additional_params: None,
            output_schema: None,
        }
    }

    /// Scripted turns are replayed in order and every call is recorded.
    #[tokio::test]
    async fn completion_consumes_scripted_turns_and_records_requests() {
        let model = MockCompletionModel::new([
            MockTurn::text("first").with_message_id("msg_1"),
            MockTurn::tool_call("tool_1", "calculator", serde_json::json!({"x": 1}))
                .with_call_id("call_1"),
        ]);

        let first = model
            .completion(request("hello"))
            .await
            .expect("first scripted turn should succeed");
        assert_eq!(first.message_id.as_deref(), Some("msg_1"));
        assert!(matches!(
            first.choice.first(),
            AssistantContent::Text(text) if text.text == "first"
        ));

        let second = model
            .completion(request("use a tool"))
            .await
            .expect("second scripted turn should succeed");
        assert!(matches!(
            second.choice.first(),
            AssistantContent::ToolCall(tool_call)
                if tool_call.id == "tool_1"
                    && tool_call.call_id.as_deref() == Some("call_1")
        ));

        assert_eq!(model.request_count(), 2);
        assert_eq!(model.requests().len(), 2);
    }

    /// A completion call with no scripted turn yields the dedicated provider error.
    #[tokio::test]
    async fn missing_completion_turn_returns_provider_error() {
        let model = MockCompletionModel::default();

        let err = model
            .completion(request("hello"))
            .await
            .expect_err("missing turn should error");

        assert!(matches!(
            err,
            CompletionError::ProviderError(message)
                if message.contains("no scripted completion turn")
        ));
    }

    /// A stream call with no scripted streaming turn yields the dedicated
    /// provider error, and the request is still recorded.
    #[tokio::test]
    async fn missing_stream_turn_returns_provider_error() {
        let model = MockCompletionModel::default();

        // Matched directly instead of via `expect_err`, which would require the
        // success type to implement `Debug`.
        let result = model.stream(request("stream")).await;

        assert!(matches!(
            result,
            Err(CompletionError::ProviderError(message))
                if message.contains("no scripted streaming turn")
        ));
        assert_eq!(model.request_count(), 1);
    }

    /// Every scripted stream event is yielded and the request is recorded.
    #[tokio::test]
    async fn stream_yields_scripted_events_and_records_requests() {
        let model = MockCompletionModel::from_stream_turns([[
            MockStreamEvent::message_id("msg_stream"),
            MockStreamEvent::text("hel"),
            MockStreamEvent::text("lo"),
            MockStreamEvent::tool_call_name_delta("tool_1", "internal_1", "calculator"),
            MockStreamEvent::tool_call_arguments_delta("tool_1", "internal_1", "{\"x\":1}"),
            MockStreamEvent::tool_call("tool_1", "calculator", serde_json::json!({"x": 1}))
                .with_call_id("call_1"),
            MockStreamEvent::final_response_with_total_tokens(7),
        ]]);

        let mut stream = model
            .stream(request("stream"))
            .await
            .expect("stream should be created");

        let mut text = String::new();
        let mut saw_name_delta = false;
        let mut saw_arguments_delta = false;
        let mut saw_tool_call = false;
        let mut saw_final = false;

        // Flags accumulate with `|=` so each one means "a matching event was
        // seen", rather than reflecting only the last event of its kind.
        while let Some(item) = stream.next().await {
            match item.expect("stream event should succeed") {
                StreamedAssistantContent::Text(chunk) => text.push_str(&chunk.text),
                StreamedAssistantContent::ToolCallDelta { content, .. } => match content {
                    ToolCallDeltaContent::Name(name) => {
                        saw_name_delta |= name == "calculator";
                    }
                    ToolCallDeltaContent::Delta(arguments) => {
                        saw_arguments_delta |= arguments == "{\"x\":1}";
                    }
                },
                StreamedAssistantContent::ToolCall { tool_call, .. } => {
                    saw_tool_call |= tool_call.call_id.as_deref() == Some("call_1");
                }
                StreamedAssistantContent::Final(response) => {
                    saw_final |= matches!(
                        response.token_usage(),
                        Some(Usage {
                            total_tokens: 7,
                            ..
                        })
                    );
                }
                _ => {}
            }
        }

        assert_eq!(text, "hello");
        assert!(saw_name_delta);
        assert!(saw_arguments_delta);
        assert!(saw_tool_call);
        assert!(saw_final);
        assert_eq!(stream.message_id.as_deref(), Some("msg_stream"));
        assert_eq!(model.request_count(), 1);
    }

    /// A scripted error event surfaces as an `Err` item on the stream.
    #[tokio::test]
    async fn stream_error_event_is_returned() {
        let model = MockCompletionModel::from_stream_turns([[MockStreamEvent::error("boom")]]);
        let mut stream = model
            .stream(request("stream"))
            .await
            .expect("stream should be created");

        let err = stream
            .next()
            .await
            .expect("stream should yield one event")
            .expect_err("scripted event should error");

        assert!(matches!(
            err,
            CompletionError::ProviderError(message) if message == "boom"
        ));
    }
}
456}