llmsdk_provider/middleware/
language_model.rs1use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use crate::error::Result;
15use crate::language_model::{
16 CallOptions, GenerateResult, LanguageModel, StreamResult, SupportedUrls,
17};
18
/// Identifies which model entry point a call is travelling through.
///
/// Passed to [`LanguageModelMiddleware::transform_params`] so a single hook
/// can tell the two call paths apart.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum CallKind {
    /// The call is going through `do_generate`.
    Generate,
    /// The call is going through `do_stream`.
    Stream,
}
32
33#[async_trait]
53pub trait LanguageModelMiddleware: Send + Sync + std::fmt::Debug {
54 fn override_provider(&self, _inner: &dyn LanguageModel) -> Option<String> {
60 None
61 }
62
63 fn override_model_id(&self, _inner: &dyn LanguageModel) -> Option<String> {
67 None
68 }
69
70 async fn override_supported_urls(&self, _inner: &dyn LanguageModel) -> Option<SupportedUrls> {
76 None
77 }
78
79 async fn transform_params(
90 &self,
91 _kind: CallKind,
92 params: CallOptions,
93 _inner: &dyn LanguageModel,
94 ) -> Result<CallOptions> {
95 Ok(params)
96 }
97
98 async fn wrap_generate(
108 &self,
109 next: &dyn LanguageModel,
110 params: CallOptions,
111 ) -> Result<GenerateResult> {
112 next.do_generate(params).await
113 }
114
115 async fn wrap_stream(
126 &self,
127 next: &dyn LanguageModel,
128 params: CallOptions,
129 ) -> Result<StreamResult> {
130 next.do_stream(params).await
131 }
132}
133
134pub fn wrap_language_model<I>(
161 model: Arc<dyn LanguageModel>,
162 middleware: I,
163) -> Arc<dyn LanguageModel>
164where
165 I: IntoIterator<Item = Arc<dyn LanguageModelMiddleware>>,
166{
167 let mut layers: Vec<Arc<dyn LanguageModelMiddleware>> = middleware.into_iter().collect();
168 layers.reverse();
170 layers
171 .into_iter()
172 .fold(model, |inner, mw| Arc::new(Wrapped::new(inner, mw)))
173}
174
175struct Wrapped {
181 inner: Arc<dyn LanguageModel>,
182 middleware: Arc<dyn LanguageModelMiddleware>,
183 provider: String,
184 model_id: String,
185}
186
187impl Wrapped {
188 fn new(inner: Arc<dyn LanguageModel>, middleware: Arc<dyn LanguageModelMiddleware>) -> Self {
189 let provider = middleware
190 .override_provider(inner.as_ref())
191 .unwrap_or_else(|| inner.provider().to_owned());
192 let model_id = middleware
193 .override_model_id(inner.as_ref())
194 .unwrap_or_else(|| inner.model_id().to_owned());
195 Self {
196 inner,
197 middleware,
198 provider,
199 model_id,
200 }
201 }
202}
203
204impl std::fmt::Debug for Wrapped {
205 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206 f.debug_struct("Wrapped")
207 .field("provider", &self.provider)
208 .field("model_id", &self.model_id)
209 .field("middleware", &self.middleware)
210 .field("inner", &self.inner)
211 .finish()
212 }
213}
214
215#[async_trait]
216impl LanguageModel for Wrapped {
217 fn provider(&self) -> &str {
218 &self.provider
219 }
220
221 fn model_id(&self) -> &str {
222 &self.model_id
223 }
224
225 async fn supported_urls(&self) -> SupportedUrls {
226 if let Some(custom) = self
227 .middleware
228 .override_supported_urls(self.inner.as_ref())
229 .await
230 {
231 return custom;
232 }
233 self.inner.supported_urls().await
234 }
235
236 async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
237 let transformed = self
238 .middleware
239 .transform_params(CallKind::Generate, options, self.inner.as_ref())
240 .await?;
241 self.middleware
242 .wrap_generate(self.inner.as_ref(), transformed)
243 .await
244 }
245
246 async fn do_stream(&self, options: CallOptions) -> Result<StreamResult> {
247 let transformed = self
248 .middleware
249 .transform_params(CallKind::Stream, options, self.inner.as_ref())
250 .await?;
251 self.middleware
252 .wrap_stream(self.inner.as_ref(), transformed)
253 .await
254 }
255}
256
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures::StreamExt;
    use futures::stream;

    use crate::language_model::{FinishReason, FinishReasonKind, StreamPart, Usage};

    use super::*;

    /// Test double that counts calls and remembers the most recent options it
    /// was handed.
    #[derive(Debug, Default)]
    struct MockModel {
        provider: String,
        model_id: String,
        generate_calls: AtomicUsize,
        stream_calls: AtomicUsize,
        last_params: Mutex<Option<CallOptions>>,
    }

    impl MockModel {
        /// Creates a mock reporting the given identity, with zeroed counters.
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: provider.to_owned(),
                model_id: model_id.to_owned(),
                generate_calls: AtomicUsize::new(0),
                stream_calls: AtomicUsize::new(0),
                last_params: Mutex::new(None),
            }
        }

        /// Number of `do_generate` calls seen so far.
        fn generate_count(&self) -> usize {
            self.generate_calls.load(Ordering::SeqCst)
        }

        /// Number of `do_stream` calls seen so far.
        fn stream_count(&self) -> usize {
            self.stream_calls.load(Ordering::SeqCst)
        }

        /// Temperature from the most recent call, if any call carried one.
        fn last_temperature(&self) -> Option<f32> {
            self.last_params
                .lock()
                .expect("mock mutex poisoned")
                .as_ref()
                .and_then(|p| p.temperature)
        }
    }

    #[async_trait]
    impl LanguageModel for MockModel {
        fn provider(&self) -> &str {
            &self.provider
        }

        fn model_id(&self) -> &str {
            &self.model_id
        }

        async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
            self.generate_calls.fetch_add(1, Ordering::SeqCst);
            *self.last_params.lock().expect("mock mutex poisoned") = Some(options);
            Ok(GenerateResult {
                content: vec![],
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                usage: Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }

        async fn do_stream(&self, options: CallOptions) -> Result<StreamResult> {
            self.stream_calls.fetch_add(1, Ordering::SeqCst);
            *self.last_params.lock().expect("mock mutex poisoned") = Some(options);
            let parts = stream::iter(vec![
                Ok(StreamPart::StreamStart { warnings: vec![] }),
                Ok(StreamPart::Finish {
                    usage: Usage::default(),
                    finish_reason: FinishReason::new(FinishReasonKind::Stop),
                    provider_metadata: None,
                }),
            ]);
            Ok(StreamResult {
                stream: Box::pin(parts),
                request: None,
                response: None,
            })
        }
    }

    /// Middleware that replaces both identity strings and adds 1.0 to the
    /// temperature (treating a missing temperature as 0.0).
    #[derive(Debug)]
    struct OverrideAndTransform;

    #[async_trait]
    impl LanguageModelMiddleware for OverrideAndTransform {
        fn override_provider(&self, _inner: &dyn LanguageModel) -> Option<String> {
            Some("wrapped-provider".to_owned())
        }

        fn override_model_id(&self, _inner: &dyn LanguageModel) -> Option<String> {
            Some("wrapped-model".to_owned())
        }

        async fn transform_params(
            &self,
            _kind: CallKind,
            mut params: CallOptions,
            _inner: &dyn LanguageModel,
        ) -> Result<CallOptions> {
            params.temperature = Some(params.temperature.unwrap_or(0.0) + 1.0);
            Ok(params)
        }
    }

    /// Middleware that appends `<label>:enter` / `<label>:exit` to a shared log
    /// around the inner generate call.
    #[derive(Debug)]
    struct OrderRecorder {
        label: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    #[async_trait]
    impl LanguageModelMiddleware for OrderRecorder {
        async fn wrap_generate(
            &self,
            next: &dyn LanguageModel,
            params: CallOptions,
        ) -> Result<GenerateResult> {
            // Each guard is a statement-scoped temporary, so no lock is held
            // across the await below.
            self.log
                .lock()
                .expect("log mutex poisoned")
                .push(format!("{}:enter", self.label));
            let res = next.do_generate(params).await;
            self.log
                .lock()
                .expect("log mutex poisoned")
                .push(format!("{}:exit", self.label));
            res
        }
    }

    /// Middleware that answers a stream call by running `do_generate` on the
    /// next model and returning an empty stream.
    #[derive(Debug)]
    struct StreamFromGenerate;

    #[async_trait]
    impl LanguageModelMiddleware for StreamFromGenerate {
        async fn wrap_stream(
            &self,
            next: &dyn LanguageModel,
            params: CallOptions,
        ) -> Result<StreamResult> {
            let _ = next.do_generate(params).await?;
            // The generate result is discarded; only the call itself matters here.
            Ok(StreamResult {
                stream: Box::pin(stream::iter(vec![])),
                request: None,
                response: None,
            })
        }
    }

    #[tokio::test]
    async fn empty_middleware_returns_model_unchanged() {
        let model = Arc::new(MockModel::new("openai", "gpt-foo"));
        let wrapped: Arc<dyn LanguageModel> =
            wrap_language_model(Arc::clone(&model) as _, Vec::new());

        // Identity check: a pass-through wrapper would satisfy every assertion
        // below, so compare the data pointers to prove no layer was added.
        assert!(
            std::ptr::eq(
                Arc::as_ptr(&wrapped).cast::<()>(),
                Arc::as_ptr(&model).cast::<()>(),
            ),
            "an empty middleware list must hand back the original Arc",
        );

        assert_eq!(wrapped.provider(), "openai");
        assert_eq!(wrapped.model_id(), "gpt-foo");

        wrapped
            .do_generate(CallOptions::default())
            .await
            .expect("generate succeeded");
        assert_eq!(model.generate_count(), 1);
    }

    #[tokio::test]
    async fn overrides_replace_identity_and_transform_mutates_params() {
        let model = Arc::new(MockModel::new("openai", "gpt-foo"));
        let wrapped = wrap_language_model(
            Arc::clone(&model) as _,
            [Arc::new(OverrideAndTransform) as Arc<dyn LanguageModelMiddleware>],
        );

        assert_eq!(wrapped.provider(), "wrapped-provider");
        assert_eq!(wrapped.model_id(), "wrapped-model");

        wrapped
            .do_generate(CallOptions::default())
            .await
            .expect("generate succeeded");
        assert_eq!(model.last_temperature(), Some(1.0));
    }

    #[tokio::test]
    async fn wrap_order_runs_first_middleware_outermost() {
        let model = Arc::new(MockModel::new("openai", "gpt-foo"));
        let log: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let m1 = Arc::new(OrderRecorder {
            label: "m1",
            log: Arc::clone(&log),
        }) as Arc<dyn LanguageModelMiddleware>;
        let m2 = Arc::new(OrderRecorder {
            label: "m2",
            log: Arc::clone(&log),
        }) as Arc<dyn LanguageModelMiddleware>;

        let wrapped = wrap_language_model(model, [m1, m2]);
        wrapped
            .do_generate(CallOptions::default())
            .await
            .expect("generate succeeded");

        let entries = log.lock().expect("log mutex poisoned").clone();
        assert_eq!(
            entries,
            vec!["m1:enter", "m2:enter", "m2:exit", "m1:exit"],
            "first middleware must be outermost",
        );
    }

    #[tokio::test]
    async fn middleware_can_swap_call_kind_via_next() {
        let model = Arc::new(MockModel::new("openai", "gpt-foo"));
        let wrapped = wrap_language_model(
            Arc::clone(&model) as _,
            [Arc::new(StreamFromGenerate) as Arc<dyn LanguageModelMiddleware>],
        );

        let mut stream = wrapped
            .do_stream(CallOptions::default())
            .await
            .expect("stream succeeded")
            .stream;
        assert!(stream.next().await.is_none());

        // The middleware served the stream request through the generate path.
        assert_eq!(model.generate_count(), 1, "do_generate was used internally");
        assert_eq!(model.stream_count(), 0, "do_stream on inner was bypassed");
    }
}