llm_sdk/llm_sdk_test/
model.rs1use crate::{
2 boxed_stream::BoxedStream,
3 errors::{LanguageModelError, LanguageModelResult},
4 language_model::{LanguageModel, LanguageModelMetadata, LanguageModelStream},
5 LanguageModelInput, ModelResponse, PartialModelResponse,
6};
7use futures::{future::BoxFuture, stream};
8use std::{collections::VecDeque, sync::Mutex};
9
10pub enum MockGenerateResult {
13 Response(ModelResponse),
14 Error(LanguageModelError),
15}
16
17impl MockGenerateResult {
18 #[must_use]
20 pub fn response(response: ModelResponse) -> Self {
21 Self::Response(response)
22 }
23
24 #[must_use]
26 pub fn error(error: LanguageModelError) -> Self {
27 Self::Error(error)
28 }
29}
30
31impl From<ModelResponse> for MockGenerateResult {
32 fn from(response: ModelResponse) -> Self {
33 Self::response(response)
34 }
35}
36
37impl From<LanguageModelResult<ModelResponse>> for MockGenerateResult {
38 fn from(result: LanguageModelResult<ModelResponse>) -> Self {
39 match result {
40 Ok(response) => Self::Response(response),
41 Err(error) => Self::Error(error),
42 }
43 }
44}
45
46pub enum MockStreamResult {
49 Partials(Vec<PartialModelResponse>),
50 Error(LanguageModelError),
51}
52
53impl MockStreamResult {
54 #[must_use]
56 pub fn partials(partials: Vec<PartialModelResponse>) -> Self {
57 Self::Partials(partials)
58 }
59
60 #[must_use]
62 pub fn error(error: LanguageModelError) -> Self {
63 Self::Error(error)
64 }
65}
66
67impl From<Vec<PartialModelResponse>> for MockStreamResult {
68 fn from(partials: Vec<PartialModelResponse>) -> Self {
69 Self::partials(partials)
70 }
71}
72
73impl From<PartialModelResponse> for MockStreamResult {
74 fn from(partial: PartialModelResponse) -> Self {
75 Self::partials(vec![partial])
76 }
77}
78
79impl From<LanguageModelResult<Vec<PartialModelResponse>>> for MockStreamResult {
80 fn from(result: LanguageModelResult<Vec<PartialModelResponse>>) -> Self {
81 match result {
82 Ok(partials) => Self::Partials(partials),
83 Err(error) => Self::Error(error),
84 }
85 }
86}
87
88#[derive(Default)]
89struct MockLanguageModelState {
90 mocked_generate_results: VecDeque<MockGenerateResult>,
91 mocked_stream_results: VecDeque<MockStreamResult>,
92 tracked_generate_inputs: Vec<LanguageModelInput>,
93 tracked_stream_inputs: Vec<LanguageModelInput>,
94}
95
96impl MockLanguageModelState {
97 fn enqueue_generate_result(&mut self, result: MockGenerateResult) {
98 self.mocked_generate_results.push_back(result);
99 }
100
101 fn enqueue_stream_result(&mut self, result: MockStreamResult) {
102 self.mocked_stream_results.push_back(result);
103 }
104
105 fn reset(&mut self) {
106 self.tracked_generate_inputs.clear();
107 self.tracked_stream_inputs.clear();
108 }
109
110 fn restore(&mut self) {
111 self.mocked_generate_results.clear();
112 self.mocked_stream_results.clear();
113 self.reset();
114 }
115}
116
117pub struct MockLanguageModel {
120 provider: &'static str,
121 model_id: String,
122 metadata: Option<LanguageModelMetadata>,
123 state: Mutex<MockLanguageModelState>,
124}
125
126impl Default for MockLanguageModel {
127 fn default() -> Self {
128 Self {
129 provider: "mock",
130 model_id: "mock-model".to_string(),
131 metadata: None,
132 state: Mutex::new(MockLanguageModelState::default()),
133 }
134 }
135}
136
137impl MockLanguageModel {
138 #[must_use]
140 pub fn new() -> Self {
141 Self::default()
142 }
143
144 pub fn set_provider(&mut self, provider: &'static str) {
146 self.provider = provider;
147 }
148
149 pub fn set_model_id<S: Into<String>>(&mut self, model_id: S) {
151 self.model_id = model_id.into();
152 }
153
154 pub fn set_metadata(&mut self, metadata: Option<LanguageModelMetadata>) {
156 self.metadata = metadata;
157 }
158
159 pub fn enqueue_generate_results<I>(&self, results: I) -> &Self
163 where
164 I: IntoIterator<Item = MockGenerateResult>,
165 {
166 let mut state = self.state.lock().expect("mock state poisoned");
167 for result in results {
168 state.enqueue_generate_result(result);
169 }
170 drop(state);
171 self
172 }
173
174 pub fn enqueue_generate<R>(&self, result: R) -> &Self
176 where
177 R: Into<MockGenerateResult>,
178 {
179 self.enqueue_generate_results(std::iter::once(result.into()))
180 }
181
182 pub fn enqueue_stream_results<I>(&self, results: I) -> &Self
186 where
187 I: IntoIterator<Item = MockStreamResult>,
188 {
189 let mut state = self.state.lock().expect("mock state poisoned");
190 for result in results {
191 state.enqueue_stream_result(result);
192 }
193 drop(state);
194 self
195 }
196
197 pub fn enqueue_stream<R>(&self, result: R) -> &Self
199 where
200 R: Into<MockStreamResult>,
201 {
202 self.enqueue_stream_results(std::iter::once(result.into()))
203 }
204
205 pub fn tracked_generate_inputs(&self) -> Vec<LanguageModelInput> {
209 let state = self.state.lock().expect("mock state poisoned");
210 state.tracked_generate_inputs.clone()
211 }
212
213 pub fn tracked_stream_inputs(&self) -> Vec<LanguageModelInput> {
217 let state = self.state.lock().expect("mock state poisoned");
218 state.tracked_stream_inputs.clone()
219 }
220
221 pub fn reset(&self) {
225 let mut state = self.state.lock().expect("mock state poisoned");
226 state.reset();
227 }
228
229 pub fn restore(&self) {
233 let mut state = self.state.lock().expect("mock state poisoned");
234 state.restore();
235 }
236}
237
238impl LanguageModel for MockLanguageModel {
239 fn provider(&self) -> &'static str {
240 self.provider
241 }
242
243 fn model_id(&self) -> String {
244 self.model_id.clone()
245 }
246
247 fn metadata(&self) -> Option<&LanguageModelMetadata> {
248 self.metadata.as_ref()
249 }
250
251 fn generate(
252 &self,
253 input: LanguageModelInput,
254 ) -> BoxFuture<'_, LanguageModelResult<ModelResponse>> {
255 Box::pin(async move {
256 let mut state = self.state.lock().expect("mock state poisoned");
257 state.tracked_generate_inputs.push(input.clone());
258
259 let result = state.mocked_generate_results.pop_front().ok_or_else(|| {
260 LanguageModelError::Invariant(
261 self.provider,
262 "no mocked generate results available".into(),
263 )
264 })?;
265
266 match result {
267 MockGenerateResult::Response(response) => Ok(response),
268 MockGenerateResult::Error(error) => Err(error),
269 }
270 })
271 }
272
273 fn stream(
274 &self,
275 input: LanguageModelInput,
276 ) -> BoxFuture<'_, LanguageModelResult<LanguageModelStream>> {
277 Box::pin(async move {
278 let mut state = self.state.lock().expect("mock state poisoned");
279
280 let result = state.mocked_stream_results.pop_front().ok_or_else(|| {
281 LanguageModelError::Invariant(
282 self.provider,
283 "no mocked stream results available".into(),
284 )
285 })?;
286
287 state.tracked_stream_inputs.push(input.clone());
288
289 match result {
290 MockStreamResult::Error(error) => Err(error),
291 MockStreamResult::Partials(partials) => {
292 let stream = stream_from_partials(partials);
293 Ok(stream)
294 }
295 }
296 })
297 }
298}
299
300fn stream_from_partials(partials: Vec<PartialModelResponse>) -> LanguageModelStream {
301 let iter = stream::iter(partials.into_iter().map(Ok));
302 BoxedStream::from_stream(iter)
303}