llmsdk_provider/middleware/
embedding_model.rs1use std::sync::Arc;
12
13use async_trait::async_trait;
14
15use crate::embedding_model::{EmbedOptions, EmbedResult, EmbeddingModel};
16use crate::error::Result;
17
18#[async_trait]
24pub trait EmbeddingModelMiddleware: Send + Sync + std::fmt::Debug {
25 fn override_provider(&self, _inner: &dyn EmbeddingModel) -> Option<String> {
27 None
28 }
29
30 fn override_model_id(&self, _inner: &dyn EmbeddingModel) -> Option<String> {
32 None
33 }
34
35 async fn override_max_embeddings_per_call(
39 &self,
40 _inner: &dyn EmbeddingModel,
41 ) -> Option<Option<u32>> {
42 None
43 }
44
45 async fn override_supports_parallel_calls(&self, _inner: &dyn EmbeddingModel) -> Option<bool> {
47 None
48 }
49
50 async fn transform_params(
57 &self,
58 params: EmbedOptions,
59 _inner: &dyn EmbeddingModel,
60 ) -> Result<EmbedOptions> {
61 Ok(params)
62 }
63
64 async fn wrap_embed(
74 &self,
75 next: &dyn EmbeddingModel,
76 params: EmbedOptions,
77 ) -> Result<EmbedResult> {
78 next.do_embed(params).await
79 }
80}
81
82pub fn wrap_embedding_model<I>(
91 model: Arc<dyn EmbeddingModel>,
92 middleware: I,
93) -> Arc<dyn EmbeddingModel>
94where
95 I: IntoIterator<Item = Arc<dyn EmbeddingModelMiddleware>>,
96{
97 let mut layers: Vec<Arc<dyn EmbeddingModelMiddleware>> = middleware.into_iter().collect();
98 layers.reverse();
99 layers
100 .into_iter()
101 .fold(model, |inner, mw| Arc::new(Wrapped::new(inner, mw)))
102}
103
104struct Wrapped {
106 inner: Arc<dyn EmbeddingModel>,
107 middleware: Arc<dyn EmbeddingModelMiddleware>,
108 provider: String,
109 model_id: String,
110}
111
112impl Wrapped {
113 fn new(inner: Arc<dyn EmbeddingModel>, middleware: Arc<dyn EmbeddingModelMiddleware>) -> Self {
114 let provider = middleware
115 .override_provider(inner.as_ref())
116 .unwrap_or_else(|| inner.provider().to_owned());
117 let model_id = middleware
118 .override_model_id(inner.as_ref())
119 .unwrap_or_else(|| inner.model_id().to_owned());
120 Self {
121 inner,
122 middleware,
123 provider,
124 model_id,
125 }
126 }
127}
128
129impl std::fmt::Debug for Wrapped {
130 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
131 f.debug_struct("Wrapped")
132 .field("provider", &self.provider)
133 .field("model_id", &self.model_id)
134 .field("middleware", &self.middleware)
135 .field("inner", &self.inner)
136 .finish()
137 }
138}
139
140#[async_trait]
141impl EmbeddingModel for Wrapped {
142 fn provider(&self) -> &str {
143 &self.provider
144 }
145
146 fn model_id(&self) -> &str {
147 &self.model_id
148 }
149
150 async fn max_embeddings_per_call(&self) -> Option<u32> {
151 if let Some(custom) = self
152 .middleware
153 .override_max_embeddings_per_call(self.inner.as_ref())
154 .await
155 {
156 return custom;
157 }
158 self.inner.max_embeddings_per_call().await
159 }
160
161 async fn supports_parallel_calls(&self) -> bool {
162 if let Some(custom) = self
163 .middleware
164 .override_supports_parallel_calls(self.inner.as_ref())
165 .await
166 {
167 return custom;
168 }
169 self.inner.supports_parallel_calls().await
170 }
171
172 async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult> {
173 let transformed = self
174 .middleware
175 .transform_params(options, self.inner.as_ref())
176 .await?;
177 self.middleware
178 .wrap_embed(self.inner.as_ref(), transformed)
179 .await
180 }
181}
182
#[cfg(test)]
mod tests {
    use std::sync::Mutex;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use super::*;

    /// Test double that records how it was called and returns dummy vectors.
    #[derive(Debug, Default)]
    struct MockEmbed {
        provider: String,
        model_id: String,
        /// Number of `do_embed` invocations seen so far.
        calls: AtomicUsize,
        /// Length of `values` in the most recent `do_embed` call.
        last_input_len: Mutex<usize>,
    }

    impl MockEmbed {
        /// Creates a mock with the given identity and zeroed counters.
        fn new(provider: &str, model_id: &str) -> Self {
            Self {
                provider: String::from(provider),
                model_id: String::from(model_id),
                ..Self::default()
            }
        }
    }

    #[async_trait]
    impl EmbeddingModel for MockEmbed {
        fn provider(&self) -> &str {
            self.provider.as_str()
        }

        fn model_id(&self) -> &str {
            self.model_id.as_str()
        }

        /// Records the call and returns one three-element zero vector per input.
        async fn do_embed(&self, options: EmbedOptions) -> Result<EmbedResult> {
            let input_len = options.values.len();
            self.calls.fetch_add(1, Ordering::SeqCst);
            *self.last_input_len.lock().expect("mutex") = input_len;
            Ok(EmbedResult {
                embeddings: (0..input_len).map(|_| vec![0.0; 3]).collect(),
                usage: None,
                provider_metadata: None,
                request: None,
                response: None,
            })
        }
    }

    /// Middleware that renames the provider, pins the per-call limit to 42 and
    /// appends a copy of the inputs to themselves.
    #[derive(Debug)]
    struct OverrideAndDoubleInputs;

    #[async_trait]
    impl EmbeddingModelMiddleware for OverrideAndDoubleInputs {
        fn override_provider(&self, _inner: &dyn EmbeddingModel) -> Option<String> {
            Some(String::from("wrapped"))
        }

        async fn override_max_embeddings_per_call(
            &self,
            _inner: &dyn EmbeddingModel,
        ) -> Option<Option<u32>> {
            Some(Some(42))
        }

        async fn transform_params(
            &self,
            mut params: EmbedOptions,
            _inner: &dyn EmbeddingModel,
        ) -> Result<EmbedOptions> {
            let duplicate = params.values.clone();
            params.values.extend(duplicate);
            Ok(params)
        }
    }

    #[tokio::test]
    async fn empty_middleware_returns_unchanged() {
        let model = Arc::new(MockEmbed::new("p", "m"));
        let no_layers: Vec<Arc<dyn EmbeddingModelMiddleware>> = Vec::new();
        let wrapped = wrap_embedding_model(Arc::clone(&model) as _, no_layers);

        assert_eq!(wrapped.provider(), "p");
        assert_eq!(wrapped.model_id(), "m");
    }

    #[tokio::test]
    async fn overrides_and_transform_run() {
        let model = Arc::new(MockEmbed::new("p", "m"));
        let layer: Arc<dyn EmbeddingModelMiddleware> = Arc::new(OverrideAndDoubleInputs);
        let wrapped = wrap_embedding_model(Arc::clone(&model) as _, [layer]);

        assert_eq!(wrapped.provider(), "wrapped");
        assert_eq!(wrapped.max_embeddings_per_call().await, Some(42));

        let options = EmbedOptions {
            values: vec!["a".into(), "b".into()],
            ..Default::default()
        };
        wrapped.do_embed(options).await.expect("embed");

        // Two inputs went in; the middleware doubled them before the mock saw them.
        assert_eq!(model.calls.load(Ordering::SeqCst), 1);
        assert_eq!(*model.last_input_len.lock().expect("mutex"), 4);
    }
}