llm_sdk/llm_sdk_test/model.rs

1use crate::{
2    boxed_stream::BoxedStream,
3    errors::{LanguageModelError, LanguageModelResult},
4    language_model::{LanguageModel, LanguageModelMetadata, LanguageModelStream},
5    LanguageModelInput, ModelResponse, PartialModelResponse,
6};
7use futures::{future::BoxFuture, stream};
8use std::{collections::VecDeque, sync::Mutex};
9
10/// Result for a mocked `generate` call.
11/// It can either be a full response or an error to return.
12pub enum MockGenerateResult {
13    Response(ModelResponse),
14    Error(LanguageModelError),
15}
16
17impl MockGenerateResult {
18    /// Construct a result that yields the provided response.
19    #[must_use]
20    pub fn response(response: ModelResponse) -> Self {
21        Self::Response(response)
22    }
23
24    /// Construct a result that yields the provided error.
25    #[must_use]
26    pub fn error(error: LanguageModelError) -> Self {
27        Self::Error(error)
28    }
29}
30
31impl From<ModelResponse> for MockGenerateResult {
32    fn from(response: ModelResponse) -> Self {
33        Self::response(response)
34    }
35}
36
37impl From<LanguageModelResult<ModelResponse>> for MockGenerateResult {
38    fn from(result: LanguageModelResult<ModelResponse>) -> Self {
39        match result {
40            Ok(response) => Self::Response(response),
41            Err(error) => Self::Error(error),
42        }
43    }
44}
45
46/// Result for a mocked `stream` call.
47/// It can either be a set of partial responses or an error to return.
48pub enum MockStreamResult {
49    Partials(Vec<PartialModelResponse>),
50    Error(LanguageModelError),
51}
52
53impl MockStreamResult {
54    /// Construct a result that yields the provided partial responses.
55    #[must_use]
56    pub fn partials(partials: Vec<PartialModelResponse>) -> Self {
57        Self::Partials(partials)
58    }
59
60    /// Construct a result that yields the provided error.
61    #[must_use]
62    pub fn error(error: LanguageModelError) -> Self {
63        Self::Error(error)
64    }
65}
66
67impl From<Vec<PartialModelResponse>> for MockStreamResult {
68    fn from(partials: Vec<PartialModelResponse>) -> Self {
69        Self::partials(partials)
70    }
71}
72
73impl From<PartialModelResponse> for MockStreamResult {
74    fn from(partial: PartialModelResponse) -> Self {
75        Self::partials(vec![partial])
76    }
77}
78
79impl From<LanguageModelResult<Vec<PartialModelResponse>>> for MockStreamResult {
80    fn from(result: LanguageModelResult<Vec<PartialModelResponse>>) -> Self {
81        match result {
82            Ok(partials) => Self::Partials(partials),
83            Err(error) => Self::Error(error),
84        }
85    }
86}
87
88#[derive(Default)]
89struct MockLanguageModelState {
90    mocked_generate_results: VecDeque<MockGenerateResult>,
91    mocked_stream_results: VecDeque<MockStreamResult>,
92    tracked_generate_inputs: Vec<LanguageModelInput>,
93    tracked_stream_inputs: Vec<LanguageModelInput>,
94}
95
96impl MockLanguageModelState {
97    fn enqueue_generate_result(&mut self, result: MockGenerateResult) {
98        self.mocked_generate_results.push_back(result);
99    }
100
101    fn enqueue_stream_result(&mut self, result: MockStreamResult) {
102        self.mocked_stream_results.push_back(result);
103    }
104
105    fn reset(&mut self) {
106        self.tracked_generate_inputs.clear();
107        self.tracked_stream_inputs.clear();
108    }
109
110    fn restore(&mut self) {
111        self.mocked_generate_results.clear();
112        self.mocked_stream_results.clear();
113        self.reset();
114    }
115}
116
117/// A mock language model for testing that tracks inputs and yields predefined
118/// outputs.
119pub struct MockLanguageModel {
120    provider: &'static str,
121    model_id: String,
122    metadata: Option<LanguageModelMetadata>,
123    state: Mutex<MockLanguageModelState>,
124}
125
126impl Default for MockLanguageModel {
127    fn default() -> Self {
128        Self {
129            provider: "mock",
130            model_id: "mock-model".to_string(),
131            metadata: None,
132            state: Mutex::new(MockLanguageModelState::default()),
133        }
134    }
135}
136
137impl MockLanguageModel {
138    /// Construct a new mock language model instance.
139    #[must_use]
140    pub fn new() -> Self {
141        Self::default()
142    }
143
144    /// Override the provider identifier returned by the mock.
145    pub fn set_provider(&mut self, provider: &'static str) {
146        self.provider = provider;
147    }
148
149    /// Override the model identifier returned by the mock.
150    pub fn set_model_id<S: Into<String>>(&mut self, model_id: S) {
151        self.model_id = model_id.into();
152    }
153
154    /// Override the metadata returned by the mock.
155    pub fn set_metadata(&mut self, metadata: Option<LanguageModelMetadata>) {
156        self.metadata = metadata;
157    }
158
159    /// Enqueue one or more mocked generate results.
160    /// # Panics
161    /// Panics if the internal state mutex is poisoned.
162    pub fn enqueue_generate_results<I>(&self, results: I) -> &Self
163    where
164        I: IntoIterator<Item = MockGenerateResult>,
165    {
166        let mut state = self.state.lock().expect("mock state poisoned");
167        for result in results {
168            state.enqueue_generate_result(result);
169        }
170        drop(state);
171        self
172    }
173
174    /// Convenience to enqueue a single mocked generate result.
175    pub fn enqueue_generate<R>(&self, result: R) -> &Self
176    where
177        R: Into<MockGenerateResult>,
178    {
179        self.enqueue_generate_results(std::iter::once(result.into()))
180    }
181
182    /// Enqueue one or more mocked stream results.
183    /// # Panics
184    /// Panics if the internal state mutex is poisoned.
185    pub fn enqueue_stream_results<I>(&self, results: I) -> &Self
186    where
187        I: IntoIterator<Item = MockStreamResult>,
188    {
189        let mut state = self.state.lock().expect("mock state poisoned");
190        for result in results {
191            state.enqueue_stream_result(result);
192        }
193        drop(state);
194        self
195    }
196
197    /// Convenience to enqueue a single mocked stream result.
198    pub fn enqueue_stream<R>(&self, result: R) -> &Self
199    where
200        R: Into<MockStreamResult>,
201    {
202        self.enqueue_stream_results(std::iter::once(result.into()))
203    }
204
205    /// Retrieve the tracked generate inputs accumulated so far.
206    /// # Panics
207    /// Panics if the internal state mutex is poisoned.
208    pub fn tracked_generate_inputs(&self) -> Vec<LanguageModelInput> {
209        let state = self.state.lock().expect("mock state poisoned");
210        state.tracked_generate_inputs.clone()
211    }
212
213    /// Retrieve the tracked stream inputs accumulated so far.
214    /// # Panics
215    /// Panics if the internal state mutex is poisoned.
216    pub fn tracked_stream_inputs(&self) -> Vec<LanguageModelInput> {
217        let state = self.state.lock().expect("mock state poisoned");
218        state.tracked_stream_inputs.clone()
219    }
220
221    /// Reset tracked inputs without touching enqueued results.
222    /// # Panics
223    /// Panics if the internal state mutex is poisoned.
224    pub fn reset(&self) {
225        let mut state = self.state.lock().expect("mock state poisoned");
226        state.reset();
227    }
228
229    /// Clear both tracked inputs and enqueued results.
230    /// # Panics
231    /// Panics if the internal state mutex is poisoned.
232    pub fn restore(&self) {
233        let mut state = self.state.lock().expect("mock state poisoned");
234        state.restore();
235    }
236}
237
238impl LanguageModel for MockLanguageModel {
239    fn provider(&self) -> &'static str {
240        self.provider
241    }
242
243    fn model_id(&self) -> String {
244        self.model_id.clone()
245    }
246
247    fn metadata(&self) -> Option<&LanguageModelMetadata> {
248        self.metadata.as_ref()
249    }
250
251    fn generate(
252        &self,
253        input: LanguageModelInput,
254    ) -> BoxFuture<'_, LanguageModelResult<ModelResponse>> {
255        Box::pin(async move {
256            let mut state = self.state.lock().expect("mock state poisoned");
257            state.tracked_generate_inputs.push(input.clone());
258
259            let result = state.mocked_generate_results.pop_front().ok_or_else(|| {
260                LanguageModelError::Invariant(
261                    self.provider,
262                    "no mocked generate results available".into(),
263                )
264            })?;
265
266            match result {
267                MockGenerateResult::Response(response) => Ok(response),
268                MockGenerateResult::Error(error) => Err(error),
269            }
270        })
271    }
272
273    fn stream(
274        &self,
275        input: LanguageModelInput,
276    ) -> BoxFuture<'_, LanguageModelResult<LanguageModelStream>> {
277        Box::pin(async move {
278            let mut state = self.state.lock().expect("mock state poisoned");
279
280            let result = state.mocked_stream_results.pop_front().ok_or_else(|| {
281                LanguageModelError::Invariant(
282                    self.provider,
283                    "no mocked stream results available".into(),
284                )
285            })?;
286
287            state.tracked_stream_inputs.push(input.clone());
288
289            match result {
290                MockStreamResult::Error(error) => Err(error),
291                MockStreamResult::Partials(partials) => {
292                    let stream = stream_from_partials(partials);
293                    Ok(stream)
294                }
295            }
296        })
297    }
298}
299
300fn stream_from_partials(partials: Vec<PartialModelResponse>) -> LanguageModelStream {
301    let iter = stream::iter(partials.into_iter().map(Ok));
302    BoxedStream::from_stream(iter)
303}