Skip to main content

llmsdk_provider/middleware/
embedding_model.rs

1//! `EmbeddingModelMiddleware` trait and the `wrap_embedding_model` combinator.
2//!
3//! Mirrors `embedding-model-v4-middleware.ts` (trait surface) and
4//! `wrap-embedding-model.ts` (combinator). Structurally identical to
5//! [`super::language_model`]'s combinator; the only differences are the
6//! callable surface (`do_embed` instead of `do_generate` / `do_stream`) and
7//! two embedding-specific identity overrides (`max_embeddings_per_call`,
8//! `supports_parallel_calls`).
9// Rust guideline compliant 2026-02-21
10
11use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use crate::embedding_model::{EmbedOptions, EmbedResult, EmbeddingModel};
16use crate::error::Result;
17
18/// Contract for middleware that decorates an [`EmbeddingModel`].
19///
20/// Every method has a sensible default; an implementor only overrides the
21/// hooks it cares about. The combinator [`wrap_embedding_model`] composes any
22/// number of middlewares into a fresh `EmbeddingModel` instance.
23#[async_trait]
24pub trait EmbeddingModelMiddleware: Send + Sync + std::fmt::Debug {
25    /// Override the provider id exposed by the wrapped model.
26    fn override_provider(&self, _inner: &dyn EmbeddingModel) -> Option<String> {
27        None
28    }
29
30    /// Override the model id exposed by the wrapped model.
31    fn override_model_id(&self, _inner: &dyn EmbeddingModel) -> Option<String> {
32        None
33    }
34
35    /// Override [`EmbeddingModel::max_embeddings_per_call`].
36    ///
37    /// Returns `None` to defer to the inner model.
38    async fn override_max_embeddings_per_call(
39        &self,
40        _inner: &dyn EmbeddingModel,
41    ) -> Option<Option<u32>> {
42        None
43    }
44
45    /// Override [`EmbeddingModel::supports_parallel_calls`].
46    async fn override_supports_parallel_calls(&self, _inner: &dyn EmbeddingModel) -> Option<bool> {
47        None
48    }
49
50    /// Transform the embed options before they reach the inner model.
51    ///
52    /// # Errors
53    ///
54    /// Return a [`crate::ProviderError`] to fail the call without invoking
55    /// the model.
56    async fn transform_params(
57        &self,
58        params: EmbedOptions,
59        _inner: &dyn EmbeddingModel,
60    ) -> Result<EmbedOptions> {
61        Ok(params)
62    }
63
64    /// Wrap an embedding call.
65    ///
66    /// Default: forwards to `next.do_embed(params)`. Override to add retry,
67    /// caching, telemetry, etc.
68    ///
69    /// # Errors
70    ///
71    /// Returns whatever error `next` returns, or a middleware-introduced
72    /// failure.
73    async fn wrap_embed(
74        &self,
75        next: &dyn EmbeddingModel,
76        params: EmbedOptions,
77    ) -> Result<EmbedResult> {
78        next.do_embed(params).await
79    }
80}
81
82/// Compose an embedding model with one or more middlewares.
83///
84/// The returned `Arc<dyn EmbeddingModel>` runs middleware in outer-to-inner
85/// order on the way in (`m[0].transform_params` first) and in inner-to-outer
86/// order on the way out (`m[0].wrap_embed` is the outermost wrap), matching
87/// the convention from `wrap_language_model`.
88///
89/// Passing an empty middleware iterator returns the model unchanged.
90pub fn wrap_embedding_model<I>(
91    model: Arc<dyn EmbeddingModel>,
92    middleware: I,
93) -> Arc<dyn EmbeddingModel>
94where
95    I: IntoIterator<Item = Arc<dyn EmbeddingModelMiddleware>>,
96{
97    let mut layers: Vec<Arc<dyn EmbeddingModelMiddleware>> = middleware.into_iter().collect();
98    layers.reverse();
99    layers
100        .into_iter()
101        .fold(model, |inner, mw| Arc::new(Wrapped::new(inner, mw)))
102}
103
104/// Internal one-layer wrapper.
105struct Wrapped {
106    inner: Arc<dyn EmbeddingModel>,
107    middleware: Arc<dyn EmbeddingModelMiddleware>,
108    provider: String,
109    model_id: String,
110}
111
112impl Wrapped {
113    fn new(inner: Arc<dyn EmbeddingModel>, middleware: Arc<dyn EmbeddingModelMiddleware>) -> Self {
114        let provider = middleware
115            .override_provider(inner.as_ref())
116            .unwrap_or_else(|| inner.provider().to_owned());
117        let model_id = middleware
118            .override_model_id(inner.as_ref())
119            .unwrap_or_else(|| inner.model_id().to_owned());
120        Self {
121            inner,
122            middleware,
123            provider,
124            model_id,
125        }
126    }
127}
128
129impl std::fmt::Debug for Wrapped {
130    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131        f.debug_struct("Wrapped")
132            .field("provider", &self.provider)
133            .field("model_id", &self.model_id)
134            .field("middleware", &self.middleware)
135            .field("inner", &self.inner)
136            .finish()
137    }
138}
139
140#[async_trait]
141impl EmbeddingModel for Wrapped {
142    fn provider(&self) -> &str {
143        &self.provider
144    }
145
146    fn model_id(&self) -> &str {
147        &self.model_id
148    }
149
150    async fn max_embeddings_per_call(&self) -> Option<u32> {
151        if let Some(custom) = self
152            .middleware
153            .override_max_embeddings_per_call(self.inner.as_ref())
154            .await
155        {
156            return custom;
157        }
158        self.inner.max_embeddings_per_call().await
159    }
160
161    async fn supports_parallel_calls(&self) -> bool {
162        if let Some(custom) = self
163            .middleware
164            .override_supports_parallel_calls(self.inner.as_ref())
165            .await
166        {
167            return custom;
168        }
169        self.inner.supports_parallel_calls().await
170    }
171
172    async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult> {
173        let transformed = self
174            .middleware
175            .transform_params(options, self.inner.as_ref())
176            .await?;
177        self.middleware
178            .wrap_embed(self.inner.as_ref(), transformed)
179            .await
180    }
181}
182
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Minimal base model: counts calls and records the size of the last batch.
    #[derive(Debug, Default)]
    struct MockEmbed {
        provider: String,
        model_id: String,
        calls: AtomicUsize,
        last_input_len: Mutex<usize>,
    }

    impl MockEmbed {
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: provider.to_owned(),
                model_id: model_id.to_owned(),
                calls: AtomicUsize::new(0),
                last_input_len: Mutex::new(0),
            }
        }
    }

    #[async_trait]
    impl EmbeddingModel for MockEmbed {
        fn provider(&self) -> &str {
            &self.provider
        }
        fn model_id(&self) -> &str {
            &self.model_id
        }
        async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            *self.last_input_len.lock().expect("mutex") = options.values.len();
            Ok(EmbedResult {
                embeddings: options.values.iter().map(|_| vec![0.0; 3]).collect(),
                usage: None,
                provider_metadata: None,
                request: None,
                response: None,
            })
        }
    }

    /// Overrides the provider and the batch limit, and duplicates every input.
    #[derive(Debug)]
    struct OverrideAndDoubleInputs;

    #[async_trait]
    impl EmbeddingModelMiddleware for OverrideAndDoubleInputs {
        fn override_provider(&self, _inner: &dyn EmbeddingModel) -> Option<String> {
            Some("wrapped".to_owned())
        }

        async fn override_max_embeddings_per_call(
            &self,
            _inner: &dyn EmbeddingModel,
        ) -> Option<Option<u32>> {
            Some(Some(42))
        }

        async fn transform_params(
            &self,
            mut params: EmbedOptions,
            _inner: &dyn EmbeddingModel,
        ) -> Result<EmbedOptions> {
            let original = params.values.clone();
            params.values.extend(original);
            Ok(params)
        }
    }

    /// Appends `<name>:<event>` entries to a shared log so tests can observe
    /// the order in which stacked layers run.
    #[derive(Debug)]
    struct Recorder {
        name: &'static str,
        log: Arc<Mutex<Vec<String>>>,
    }

    impl Recorder {
        fn new(name: &'static str, log: &Arc<Mutex<Vec<String>>>) -> Self {
            Self {
                name,
                log: Arc::clone(log),
            }
        }

        // Synchronous on purpose: the mutex guard is dropped before any `.await`.
        fn record(&self, event: &str) {
            self.log
                .lock()
                .expect("mutex")
                .push(format!("{}:{}", self.name, event));
        }
    }

    #[async_trait]
    impl EmbeddingModelMiddleware for Recorder {
        async fn transform_params(
            &self,
            params: EmbedOptions,
            _inner: &dyn EmbeddingModel,
        ) -> Result<EmbedOptions> {
            self.record("transform");
            Ok(params)
        }

        async fn wrap_embed(
            &self,
            next: &dyn EmbeddingModel,
            params: EmbedOptions,
        ) -> Result<EmbedResult> {
            self.record("wrap:before");
            let result = next.do_embed(params).await;
            self.record("wrap:after");
            result
        }
    }

    #[tokio::test]
    async fn empty_middleware_returns_unchanged() {
        let model: Arc<dyn EmbeddingModel> = Arc::new(MockEmbed::new("p", "m"));
        let wrapped = wrap_embedding_model(Arc::clone(&model), Vec::new());
        // Not just equal ids: the very same allocation must come back.
        assert!(Arc::ptr_eq(&model, &wrapped));
        assert_eq!(wrapped.provider(), "p");
        assert_eq!(wrapped.model_id(), "m");
    }

    #[tokio::test]
    async fn overrides_and_transform_run() {
        let model = Arc::new(MockEmbed::new("p", "m"));
        let wrapped = wrap_embedding_model(
            Arc::clone(&model) as _,
            [Arc::new(OverrideAndDoubleInputs) as Arc<dyn EmbeddingModelMiddleware>],
        );

        assert_eq!(wrapped.provider(), "wrapped");
        assert_eq!(wrapped.max_embeddings_per_call().await, Some(42));
        // Hooks the middleware leaves at their defaults defer to the inner model.
        assert_eq!(wrapped.model_id(), "m");
        assert_eq!(
            wrapped.supports_parallel_calls().await,
            model.supports_parallel_calls().await
        );

        wrapped
            .do_embed(EmbedOptions {
                values: vec!["a".into(), "b".into()],
                ..Default::default()
            })
            .await
            .expect("embed");

        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(*model.last_input_len.lock().expect("mutex"), 4);
    }

    #[tokio::test]
    async fn first_middleware_is_outermost_layer() {
        let log = Arc::new(Mutex::new(Vec::new()));
        let model = Arc::new(MockEmbed::new("p", "m"));
        let wrapped = wrap_embedding_model(
            Arc::clone(&model) as _,
            [
                Arc::new(Recorder::new("outer", &log)) as Arc<dyn EmbeddingModelMiddleware>,
                Arc::new(Recorder::new("inner", &log)) as Arc<dyn EmbeddingModelMiddleware>,
            ],
        );

        wrapped
            .do_embed(EmbedOptions {
                values: vec!["a".into()],
                ..Default::default()
            })
            .await
            .expect("embed");

        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(
            *log.lock().expect("mutex"),
            [
                "outer:transform",
                "outer:wrap:before",
                "inner:transform",
                "inner:wrap:before",
                "inner:wrap:after",
                "outer:wrap:after",
            ]
        );
    }
}