Skip to main content

llmsdk_provider/middleware/
provider.rs

//! Top-level [`Provider`] decoration.
//!
//! Mirrors `wrap-provider.ts`. Decorates the language and image models
//! returned by an inner [`Provider`] with the matching surface middleware
//! chain; embedding models are forwarded unchanged. Unlike the ai-sdk TS
//! variant, we surface the two middleware lists (language and image) as a
//! typed [`ProviderMiddlewareSet`] struct rather than a free-form options bag.
//!
//! Limitation (matches ai-sdk): the middleware chain is applied uniformly to
//! every model id; per-model-id routing must be implemented in a custom
//! [`Provider`].
// Rust guideline compliant 2026-02-21
12
13use std::sync::Arc;
14
15use crate::error::Result;
16use crate::image_model::ImageModel;
17use crate::language_model::LanguageModel;
18use crate::provider::{DynEmbeddingModel, DynImageModel, DynLanguageModel, Provider};
19
20use super::image_model::{ImageModelMiddleware, wrap_image_model};
21use super::language_model::{LanguageModelMiddleware, wrap_language_model};
22
23/// Two middleware chains: one for language models, one for image models.
24///
25/// Mirrors the upstream `wrapProvider({ languageModelMiddleware,
26/// imageModelMiddleware })` surface in
27/// `packages/ai/src/middleware/wrap-provider.ts:20-50`. Upstream
28/// **deliberately** does not expose embedding / reranking / video / speech /
29/// transcription middleware here — those model surfaces are forwarded
30/// verbatim and any wrapping happens via the per-model `wrap_*` helpers.
31///
32/// Passing an empty `Vec` for a surface leaves that surface untouched.
33#[derive(Default, Clone)]
34pub struct ProviderMiddlewareSet {
35    /// Middleware applied to every [`LanguageModel`] returned by
36    /// [`Provider::language_model`].
37    pub language: Vec<Arc<dyn LanguageModelMiddleware>>,
38    /// Middleware applied to every [`ImageModel`] returned by
39    /// [`Provider::image_model`].
40    pub image: Vec<Arc<dyn ImageModelMiddleware>>,
41}
42
43impl std::fmt::Debug for ProviderMiddlewareSet {
44    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
45        f.debug_struct("ProviderMiddlewareSet")
46            .field("language", &self.language.len())
47            .field("image", &self.image.len())
48            .finish()
49    }
50}
51
52/// Wrap a provider so every returned model is decorated with the matching
53/// middleware chain.
54///
55/// Each lookup (`language_model` / `embedding_model` / `image_model`)
56/// delegates to the inner provider and, on success, wraps the result with
57/// the configured middleware chain. Lookups for unsupported surfaces
58/// propagate the inner error unchanged.
59///
60/// Cloning the middleware set is shallow (each `Arc` is bumped); the cost
61/// per lookup is one `Vec::clone` plus the existing `Wrapped` allocations.
62pub fn wrap_provider(inner: Arc<dyn Provider>, set: ProviderMiddlewareSet) -> Arc<dyn Provider> {
63    Arc::new(WrappedProvider { inner, set })
64}
65
66struct WrappedProvider {
67    inner: Arc<dyn Provider>,
68    set: ProviderMiddlewareSet,
69}
70
71impl std::fmt::Debug for WrappedProvider {
72    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
73        f.debug_struct("WrappedProvider")
74            .field("inner", &self.inner)
75            .field("middleware", &self.set)
76            .finish()
77    }
78}
79
80impl Provider for WrappedProvider {
81    fn language_model(&self, model_id: &str) -> Result<DynLanguageModel> {
82        let dyn_model = self.inner.language_model(model_id)?;
83        if self.set.language.is_empty() {
84            return Ok(dyn_model);
85        }
86        let arc: Arc<dyn LanguageModel> = dyn_model.into_inner();
87        let wrapped = wrap_language_model(arc, self.set.language.iter().cloned());
88        Ok(DynLanguageModel::from_arc(wrapped))
89    }
90
91    fn embedding_model(&self, model_id: &str) -> Result<DynEmbeddingModel> {
92        // Mirror upstream `wrap-provider.ts:37` —
93        // `embeddingModel: providerV4.embeddingModel` (verbatim forward,
94        // no middleware). Callers needing per-call embedding middleware use
95        // `wrap_embedding_model` directly on a specific model handle.
96        self.inner.embedding_model(model_id)
97    }
98
99    fn image_model(&self, model_id: &str) -> Result<DynImageModel> {
100        let dyn_model = self.inner.image_model(model_id)?;
101        if self.set.image.is_empty() {
102            return Ok(dyn_model);
103        }
104        let arc: Arc<dyn ImageModel> = dyn_model.into_inner();
105        let wrapped = wrap_image_model(arc, self.set.image.iter().cloned());
106        Ok(DynImageModel::from_arc(wrapped))
107    }
108}
109
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use async_trait::async_trait;

    use super::*;
    use crate::embedding_model::{EmbedOptions, EmbedResult, EmbeddingModel};
    use crate::language_model::{
        CallOptions, FinishReason, FinishReasonKind, GenerateResult, StreamResult, Usage,
    };

    /// Language model stub that returns empty, successful results.
    #[derive(Debug, Default)]
    struct StubLang;

    #[async_trait]
    impl LanguageModel for StubLang {
        fn provider(&self) -> &'static str {
            "stub"
        }
        fn model_id(&self) -> &'static str {
            "stub-lm"
        }
        async fn do_generate(&self, _options: CallOptions) -> Result<GenerateResult> {
            Ok(GenerateResult {
                content: vec![],
                finish_reason: FinishReason::new(FinishReasonKind::Stop),
                usage: Usage::default(),
                provider_metadata: None,
                request: None,
                response: None,
                warnings: vec![],
            })
        }
        async fn do_stream(&self, _options: CallOptions) -> Result<StreamResult> {
            Ok(StreamResult {
                stream: Box::pin(futures::stream::iter(vec![])),
                request: None,
                response: None,
            })
        }
    }

    /// Embedding model stub that returns an empty, successful result.
    #[derive(Debug, Default)]
    struct StubEmbed;

    #[async_trait]
    impl EmbeddingModel for StubEmbed {
        fn provider(&self) -> &'static str {
            "stub"
        }
        fn model_id(&self) -> &'static str {
            "stub-em"
        }
        async fn do_embed(&self, _opts: EmbedOptions) -> Result<EmbedResult> {
            Ok(EmbedResult {
                embeddings: vec![],
                usage: None,
                provider_metadata: None,
                request: None,
                response: None,
            })
        }
    }

    /// Provider stub exposing language and embedding models only; the image
    /// surface is left to the trait's default implementation.
    #[derive(Debug, Default)]
    struct StubProvider;

    impl Provider for StubProvider {
        fn language_model(&self, _model_id: &str) -> Result<DynLanguageModel> {
            Ok(DynLanguageModel::new(StubLang))
        }
        fn embedding_model(&self, _model_id: &str) -> Result<DynEmbeddingModel> {
            Ok(DynEmbeddingModel::new(StubEmbed))
        }
    }

    /// Records what the language middleware observed.
    #[derive(Debug, Default)]
    struct Counter {
        /// Number of `transform_params` invocations.
        lang_calls: AtomicUsize,
        /// Temperature the middleware last wrote into the call options.
        last_temp: Mutex<Option<f32>>,
    }

    /// Language middleware that counts invocations and forces the
    /// temperature to `0.5`.
    #[derive(Debug)]
    struct CountingLang(Arc<Counter>);

    #[async_trait]
    impl LanguageModelMiddleware for CountingLang {
        async fn transform_params(
            &self,
            _kind: super::super::language_model::CallKind,
            mut params: CallOptions,
            _inner: &dyn LanguageModel,
        ) -> Result<CallOptions> {
            self.0.lang_calls.fetch_add(1, Ordering::SeqCst);
            params.temperature = Some(0.5);
            *self.0.last_temp.lock().expect("mutex") = params.temperature;
            Ok(params)
        }
    }

    #[tokio::test]
    async fn wraps_language_surface_only_embedding_passes_through() {
        // Mirrors upstream wrap-provider.ts:32-37 — language goes through
        // wrap_language_model, embedding is forwarded verbatim (no
        // wrapping). The middleware never observes the embedding call.
        let counter = Arc::new(Counter::default());
        let set = ProviderMiddlewareSet {
            language: vec![Arc::new(CountingLang(Arc::clone(&counter)))],
            image: vec![],
        };
        let wrapped = wrap_provider(Arc::new(StubProvider), set);

        let lm = wrapped.language_model("anything").expect("language");
        lm.do_generate(CallOptions::default())
            .await
            .expect("generate");
        assert_eq!(counter.lang_calls.load(Ordering::SeqCst), 1);
        assert_eq!(*counter.last_temp.lock().expect("mutex"), Some(0.5));

        // The embedding model is reachable through the wrapped provider, and
        // using it must not route through the language middleware: the
        // invocation count stays where the generate call left it.
        let em = wrapped.embedding_model("anything").expect("embed");
        em.do_embed(EmbedOptions::default()).await.expect("embed");
        assert_eq!(counter.lang_calls.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn empty_set_returns_usable_language_model() {
        // With no middleware configured the inner model is handed back
        // directly and must still be callable.
        let wrapped = wrap_provider(Arc::new(StubProvider), ProviderMiddlewareSet::default());
        let lm = wrapped.language_model("anything").expect("language");
        lm.do_generate(CallOptions::default())
            .await
            .expect("generate");
    }

    #[tokio::test]
    async fn unsupported_surface_propagates_inner_error() {
        // `StubProvider` does not override `image_model`, so the error seen
        // here is whatever the inner provider reports.
        let set = ProviderMiddlewareSet::default();
        let wrapped = wrap_provider(Arc::new(StubProvider), set);
        let err = wrapped.image_model("x").expect_err("inner unsupported");
        assert!(err.is_unsupported());
    }
}