Skip to main content

llmsdk_provider/middleware/
language_model.rs

//! `LanguageModelMiddleware` trait and the `wrap_language_model` combinator.
//!
//! Mirrors `language-model-v4-middleware.ts` (trait surface) and
//! `wrap-language-model.ts` (combinator). Where ai-sdk hands each middleware a
//! `doGenerate` + `doStream` closure pair, this combinator passes a single
//! `next: &dyn LanguageModel` argument; middleware that wants to swap call
//! kinds (e.g. a future simulate-streaming middleware) just calls the other
//! method on `next`.
// Rust guideline compliant 2026-02-21
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13
14use crate::error::Result;
15use crate::language_model::{
16    CallOptions, GenerateResult, LanguageModel, StreamResult, SupportedUrls,
17};
18
/// Identifies which entry point of the wrapped model a call came through.
///
/// The wrapper hands this value to
/// [`LanguageModelMiddleware::transform_params`] so a middleware can tell a
/// streaming call from a non-streaming one. It plays the role of ai-sdk's
/// `type: 'generate' | 'stream'` discriminator.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CallKind {
    /// The call entered through [`LanguageModel::do_generate`]; the
    /// transformed options go on to
    /// [`LanguageModelMiddleware::wrap_generate`].
    Generate,
    /// The call entered through [`LanguageModel::do_stream`]; the
    /// transformed options go on to
    /// [`LanguageModelMiddleware::wrap_stream`].
    Stream,
}
32
33/// Contract for middleware that decorates a [`LanguageModel`].
34///
35/// Every method has a sensible default, so an implementor only overrides the
36/// hooks it cares about. The combinator [`wrap_language_model`] composes any
37/// number of middlewares into a fresh `LanguageModel` instance.
38///
39/// # Semantics
40///
41/// - `override_*`: replace the corresponding identity / metadata accessor.
42/// - [`transform_params`](Self::transform_params): mutate the call options
43///   before they reach the underlying model. Runs once per call, before
44///   `wrap_*`.
45/// - [`wrap_generate`](Self::wrap_generate) / [`wrap_stream`](Self::wrap_stream):
46///   intercept the actual call. The default implementation simply forwards
47///   to `next`; overrides may add retry, caching, instrumentation, swap
48///   between generate/stream, etc.
49///
50/// `next` is the *next layer* (which may itself be a wrapped model or the
51/// original provider model), not necessarily the underlying provider model.
52#[async_trait]
53pub trait LanguageModelMiddleware: Send + Sync + std::fmt::Debug {
54    /// Override the provider id exposed by the wrapped model.
55    ///
56    /// Return `None` to keep `inner.provider()`. The override is read once
57    /// when [`wrap_language_model`] runs, so it must not depend on call-time
58    /// state.
59    fn override_provider(&self, _inner: &dyn LanguageModel) -> Option<String> {
60        None
61    }
62
63    /// Override the model id exposed by the wrapped model.
64    ///
65    /// Same caching semantics as [`Self::override_provider`].
66    fn override_model_id(&self, _inner: &dyn LanguageModel) -> Option<String> {
67        None
68    }
69
70    /// Override the supported-URL map exposed by the wrapped model.
71    ///
72    /// Unlike the identity overrides, this is re-evaluated on every
73    /// [`LanguageModel::supported_urls`] call so middleware can reflect
74    /// dynamic provider state.
75    async fn override_supported_urls(&self, _inner: &dyn LanguageModel) -> Option<SupportedUrls> {
76        None
77    }
78
79    /// Transform the call options before they reach the inner model.
80    ///
81    /// Runs once per call, before [`Self::wrap_generate`] / [`Self::wrap_stream`].
82    /// The returned options are passed to both the next middleware layer's
83    /// `transform_params` and the eventual underlying call.
84    ///
85    /// # Errors
86    ///
87    /// Return a [`crate::ProviderError`] to fail the call without invoking
88    /// the model.
89    async fn transform_params(
90        &self,
91        _kind: CallKind,
92        params: CallOptions,
93        _inner: &dyn LanguageModel,
94    ) -> Result<CallOptions> {
95        Ok(params)
96    }
97
98    /// Wrap a non-streaming generation.
99    ///
100    /// Default: forwards to `next.do_generate(params)`. Override to add
101    /// retry, caching, telemetry, or to dispatch to `next.do_stream` instead.
102    ///
103    /// # Errors
104    ///
105    /// Returns whatever error `next` returns, or a middleware-introduced
106    /// failure.
107    async fn wrap_generate(
108        &self,
109        next: &dyn LanguageModel,
110        params: CallOptions,
111    ) -> Result<GenerateResult> {
112        next.do_generate(params).await
113    }
114
115    /// Wrap a streaming generation.
116    ///
117    /// Default: forwards to `next.do_stream(params)`. Override to add
118    /// retry, caching, telemetry, or to simulate streaming on top of
119    /// `next.do_generate`.
120    ///
121    /// # Errors
122    ///
123    /// Returns whatever error `next` returns, or a middleware-introduced
124    /// failure.
125    async fn wrap_stream(
126        &self,
127        next: &dyn LanguageModel,
128        params: CallOptions,
129    ) -> Result<StreamResult> {
130        next.do_stream(params).await
131    }
132}
133
134/// Compose a model with one or more middlewares.
135///
136/// The returned `Arc<dyn LanguageModel>` runs middleware in *outer-to-inner*
137/// order on the way in (`m[0].transform_params` first) and in *inner-to-outer*
138/// order on the way out (`m[0].wrap_generate` is the outermost wrap). This
139/// matches the convention used by `@ai-sdk/ai`'s `wrapLanguageModel`.
140///
141/// Passing an empty middleware iterator returns the model unchanged.
142///
143/// # Examples
144///
145/// Stacking two middlewares (the first is the outermost):
146///
147/// ```ignore
148/// use std::sync::Arc;
149/// use llmsdk_provider::{wrap_language_model, LanguageModel, LanguageModelMiddleware};
150///
151/// fn stack(
152///     model: Arc<dyn LanguageModel>,
153///     retry: Arc<dyn LanguageModelMiddleware>,
154///     log: Arc<dyn LanguageModelMiddleware>,
155/// ) -> Arc<dyn LanguageModel> {
156///     // `log` wraps `retry` wraps `model`. Logs see every retry attempt.
157///     wrap_language_model(model, [log, retry])
158/// }
159/// ```
160pub fn wrap_language_model<I>(
161    model: Arc<dyn LanguageModel>,
162    middleware: I,
163) -> Arc<dyn LanguageModel>
164where
165    I: IntoIterator<Item = Arc<dyn LanguageModelMiddleware>>,
166{
167    let mut layers: Vec<Arc<dyn LanguageModelMiddleware>> = middleware.into_iter().collect();
168    // Apply right-most middleware first so list head ends up outermost.
169    layers.reverse();
170    layers
171        .into_iter()
172        .fold(model, |inner, mw| Arc::new(Wrapped::new(inner, mw)))
173}
174
175/// Internal one-layer wrapper that pairs a model with a single middleware.
176///
177/// Each call to [`wrap_language_model`] produces a stack of these. We cache
178/// the identity overrides at construction time because the trait accessors
179/// return `&str` while the middleware returns `Option<String>`.
180struct Wrapped {
181    inner: Arc<dyn LanguageModel>,
182    middleware: Arc<dyn LanguageModelMiddleware>,
183    provider: String,
184    model_id: String,
185}
186
187impl Wrapped {
188    fn new(inner: Arc<dyn LanguageModel>, middleware: Arc<dyn LanguageModelMiddleware>) -> Self {
189        let provider = middleware
190            .override_provider(inner.as_ref())
191            .unwrap_or_else(|| inner.provider().to_owned());
192        let model_id = middleware
193            .override_model_id(inner.as_ref())
194            .unwrap_or_else(|| inner.model_id().to_owned());
195        Self {
196            inner,
197            middleware,
198            provider,
199            model_id,
200        }
201    }
202}
203
204impl std::fmt::Debug for Wrapped {
205    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
206        f.debug_struct("Wrapped")
207            .field("provider", &self.provider)
208            .field("model_id", &self.model_id)
209            .field("middleware", &self.middleware)
210            .field("inner", &self.inner)
211            .finish()
212    }
213}
214
215#[async_trait]
216impl LanguageModel for Wrapped {
217    fn provider(&self) -> &str {
218        &self.provider
219    }
220
221    fn model_id(&self) -> &str {
222        &self.model_id
223    }
224
225    async fn supported_urls(&self) -> SupportedUrls {
226        if let Some(custom) = self
227            .middleware
228            .override_supported_urls(self.inner.as_ref())
229            .await
230        {
231            return custom;
232        }
233        self.inner.supported_urls().await
234    }
235
236    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
237        let transformed = self
238            .middleware
239            .transform_params(CallKind::Generate, options, self.inner.as_ref())
240            .await?;
241        self.middleware
242            .wrap_generate(self.inner.as_ref(), transformed)
243            .await
244    }
245
246    async fn do_stream(&self, options: CallOptions) -> Result<StreamResult> {
247        let transformed = self
248            .middleware
249            .transform_params(CallKind::Stream, options, self.inner.as_ref())
250            .await?;
251        self.middleware
252            .wrap_stream(self.inner.as_ref(), transformed)
253            .await
254    }
255}
256
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use futures::StreamExt;
    use futures::stream;

    use crate::language_model::{FinishReason, FinishReasonKind, StreamPart, Usage};

    use super::*;

    /// Ordered event log shared between several [`OrderRecorder`]s.
    type EventLog = Arc<Mutex<Vec<String>>>;

    /// Test double for a provider model: counts `do_generate` / `do_stream`
    /// calls and keeps the options of the most recent one.
    #[derive(Debug, Default)]
    struct MockModel {
        provider: String,
        model_id: String,
        generate_calls: AtomicUsize,
        stream_calls: AtomicUsize,
        last_params: Mutex<Option<CallOptions>>,
    }

    impl MockModel {
        /// Fresh mock with zeroed counters and no recorded options.
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: provider.to_owned(),
                model_id: model_id.to_owned(),
                ..Self::default()
            }
        }

        /// Number of `do_generate` calls seen so far.
        fn generate_count(&self) -> usize {
            self.generate_calls.load(Ordering::SeqCst)
        }

        /// Number of `do_stream` calls seen so far.
        fn stream_count(&self) -> usize {
            self.stream_calls.load(Ordering::SeqCst)
        }

        /// Temperature of the most recent call, if any call set one.
        fn last_temperature(&self) -> Option<f32> {
            let guard = self.last_params.lock().expect("mock mutex poisoned");
            guard.as_ref().and_then(|seen| seen.temperature)
        }

        /// Stores `options` as the most recent call's parameters.
        fn remember(&self, options: CallOptions) {
            let mut slot = self.last_params.lock().expect("mock mutex poisoned");
            *slot = Some(options);
        }
    }

    #[async_trait]
    impl LanguageModel for MockModel {
        fn provider(&self) -> &str {
            self.provider.as_str()
        }

        fn model_id(&self) -> &str {
            self.model_id.as_str()
        }

        async fn do_generate(&self, options: CallOptions) -> Result<GenerateResult> {
            self.generate_calls.fetch_add(1, Ordering::SeqCst);
            self.remember(options);
            let canned = GenerateResult {
                content: vec![],
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                usage: Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            };
            Ok(canned)
        }

        async fn do_stream(&self, options: CallOptions) -> Result<StreamResult> {
            self.stream_calls.fetch_add(1, Ordering::SeqCst);
            self.remember(options);
            // Minimal well-formed stream: a start marker followed by a finish.
            let start = StreamPart::StreamStart { warnings: vec![] };
            let finish = StreamPart::Finish {
                usage: Usage::default(),
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                provider_metadata: None,
            };
            let parts = stream::iter(vec![Ok(start), Ok(finish)]);
            Ok(StreamResult {
                stream: Box::pin(parts),
                request: None,
                response: None,
            })
        }
    }

    /// Middleware that replaces both identity strings and adds 1.0 to the
    /// temperature (treating an unset temperature as 0.0).
    #[derive(Debug)]
    struct OverrideAndTransform;

    #[async_trait]
    impl LanguageModelMiddleware for OverrideAndTransform {
        fn override_provider(&self, _inner: &dyn LanguageModel) -> Option<String> {
            Some(String::from("wrapped-provider"))
        }

        fn override_model_id(&self, _inner: &dyn LanguageModel) -> Option<String> {
            Some(String::from("wrapped-model"))
        }

        async fn transform_params(
            &self,
            _kind: CallKind,
            mut params: CallOptions,
            _inner: &dyn LanguageModel,
        ) -> Result<CallOptions> {
            let bumped = params.temperature.unwrap_or(0.0) + 1.0;
            params.temperature = Some(bumped);
            Ok(params)
        }
    }

    /// Middleware that logs `<label>:enter` / `<label>:exit` around the call
    /// it forwards, exposing the nesting order of a stack.
    #[derive(Debug)]
    struct OrderRecorder {
        label: &'static str,
        log: EventLog,
    }

    impl OrderRecorder {
        /// Appends `<label>:<phase>` to the shared log.
        fn note(&self, phase: &str) {
            self.log
                .lock()
                .expect("log mutex poisoned")
                .push(format!("{}:{}", self.label, phase));
        }
    }

    #[async_trait]
    impl LanguageModelMiddleware for OrderRecorder {
        async fn wrap_generate(
            &self,
            next: &dyn LanguageModel,
            params: CallOptions,
        ) -> Result<GenerateResult> {
            self.note("enter");
            let outcome = next.do_generate(params).await;
            self.note("exit");
            outcome
        }
    }

    /// Middleware that answers `do_stream` by calling `do_generate` on the
    /// layer beneath and returning an empty stream.
    #[derive(Debug)]
    struct StreamFromGenerate;

    #[async_trait]
    impl LanguageModelMiddleware for StreamFromGenerate {
        async fn wrap_stream(
            &self,
            next: &dyn LanguageModel,
            params: CallOptions,
        ) -> Result<StreamResult> {
            // The generate result itself is irrelevant; only the fact that
            // `next.do_generate` was reachable from a stream hook matters.
            let _ = next.do_generate(params).await?;
            Ok(StreamResult {
                stream: Box::pin(stream::iter(Vec::new())),
                request: None,
                response: None,
            })
        }
    }

    /// The mock every test starts from.
    fn mock_model() -> Arc<MockModel> {
        Arc::new(MockModel::new("openai", "gpt-foo"))
    }

    /// Second handle to `mock`, type-erased for `wrap_language_model`.
    fn as_model(mock: &Arc<MockModel>) -> Arc<dyn LanguageModel> {
        Arc::clone(mock) as Arc<dyn LanguageModel>
    }

    /// Type-erases a concrete middleware.
    fn layer(mw: impl LanguageModelMiddleware + 'static) -> Arc<dyn LanguageModelMiddleware> {
        Arc::new(mw)
    }

    #[tokio::test]
    async fn empty_middleware_returns_model_unchanged() {
        let model = mock_model();
        let wrapped: Arc<dyn LanguageModel> = wrap_language_model(as_model(&model), Vec::new());
        assert_eq!(wrapped.provider(), "openai");
        assert_eq!(wrapped.model_id(), "gpt-foo");

        wrapped
            .do_generate(CallOptions::default())
            .await
            .expect("generate succeeded");
        assert_eq!(model.generate_count(), 1);
    }

    #[tokio::test]
    async fn overrides_replace_identity_and_transform_mutates_params() {
        let model = mock_model();
        let wrapped = wrap_language_model(as_model(&model), [layer(OverrideAndTransform)]);

        assert_eq!(wrapped.provider(), "wrapped-provider");
        assert_eq!(wrapped.model_id(), "wrapped-model");

        wrapped
            .do_generate(CallOptions::default())
            .await
            .expect("generate succeeded");
        assert_eq!(model.last_temperature(), Some(1.0));
    }

    #[tokio::test]
    async fn wrap_order_runs_first_middleware_outermost() {
        let log: EventLog = Arc::new(Mutex::new(Vec::new()));
        let recorder = |label: &'static str| {
            layer(OrderRecorder {
                label,
                log: Arc::clone(&log),
            })
        };
        let stack = [recorder("m1"), recorder("m2")];

        let wrapped = wrap_language_model(mock_model(), stack);
        wrapped
            .do_generate(CallOptions::default())
            .await
            .expect("generate succeeded");

        let entries = log.lock().expect("log mutex poisoned").clone();
        assert_eq!(
            entries,
            vec!["m1:enter", "m2:enter", "m2:exit", "m1:exit"],
            "first middleware must be outermost",
        );
    }

    #[tokio::test]
    async fn middleware_can_swap_call_kind_via_next() {
        let model = mock_model();
        let wrapped = wrap_language_model(as_model(&model), [layer(StreamFromGenerate)]);

        let result = wrapped
            .do_stream(CallOptions::default())
            .await
            .expect("stream succeeded");
        let mut stream = result.stream;
        // Drain the (empty) stream to satisfy `must_use`.
        assert!(stream.next().await.is_none());

        assert_eq!(model.generate_count(), 1, "do_generate was used internally");
        assert_eq!(model.stream_count(), 0, "do_stream on inner was bypassed");
    }
}
503}