1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6
7use crate::{CapabilitySupport, EmbeddingRequest, EmbeddingResponse, ProviderIdentity, Result};
8
9#[non_exhaustive]
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum EmbeddingCapability {
13 BatchInput,
15 OutputDimensions,
17}
18
19pub trait EmbeddingProvider: ProviderIdentity {
27 fn embed(
33 &self,
34 request: &EmbeddingRequest,
35 ) -> impl Future<Output = Result<EmbeddingResponse>> + Send;
36
37 fn embedding_capability(
39 &self,
40 _model: &str,
41 _capability: EmbeddingCapability,
42 ) -> CapabilitySupport {
43 CapabilitySupport::Unknown
44 }
45}
46
47impl<T> EmbeddingProvider for &T
48where
49 T: EmbeddingProvider + ?Sized,
50{
51 async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
52 T::embed(*self, request).await
53 }
54
55 fn embedding_capability(
56 &self,
57 model: &str,
58 capability: EmbeddingCapability,
59 ) -> CapabilitySupport {
60 T::embedding_capability(*self, model, capability)
61 }
62}
63
64impl<T> EmbeddingProvider for Box<T>
65where
66 T: EmbeddingProvider + ?Sized,
67{
68 async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
69 T::embed(self.as_ref(), request).await
70 }
71
72 fn embedding_capability(
73 &self,
74 model: &str,
75 capability: EmbeddingCapability,
76 ) -> CapabilitySupport {
77 T::embedding_capability(self.as_ref(), model, capability)
78 }
79}
80
81impl<T> EmbeddingProvider for Arc<T>
82where
83 T: EmbeddingProvider + ?Sized,
84{
85 async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
86 T::embed(self.as_ref(), request).await
87 }
88
89 fn embedding_capability(
90 &self,
91 model: &str,
92 capability: EmbeddingCapability,
93 ) -> CapabilitySupport {
94 T::embedding_capability(self.as_ref(), model, capability)
95 }
96}
97
98#[derive(Clone)]
103pub struct DynEmbeddingProvider(Arc<dyn EmbeddingProviderErased>);
104
105impl DynEmbeddingProvider {
106 #[must_use]
108 pub fn new<T>(provider: T) -> Self
109 where
110 T: EmbeddingProvider + 'static,
111 {
112 Self(Arc::new(provider))
113 }
114}
115
116impl<T> From<Arc<T>> for DynEmbeddingProvider
117where
118 T: EmbeddingProvider + 'static,
119{
120 fn from(provider: Arc<T>) -> Self {
121 Self(provider)
122 }
123}
124
125impl std::fmt::Debug for DynEmbeddingProvider {
126 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127 f.debug_struct("DynEmbeddingProvider")
128 .field("provider", &self.0.provider_name())
129 .finish()
130 }
131}
132
133impl ProviderIdentity for DynEmbeddingProvider {
134 fn provider_name(&self) -> &'static str {
135 self.0.provider_name()
136 }
137}
138
139impl EmbeddingProvider for DynEmbeddingProvider {
140 async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
141 self.0.embed_erased(request).await
142 }
143
144 fn embedding_capability(
145 &self,
146 model: &str,
147 capability: EmbeddingCapability,
148 ) -> CapabilitySupport {
149 self.0.embedding_capability_erased(model, capability)
150 }
151}
152
153trait EmbeddingProviderErased: ProviderIdentity {
157 fn embed_erased<'a>(
158 &'a self,
159 request: &'a EmbeddingRequest,
160 ) -> Pin<Box<dyn Future<Output = Result<EmbeddingResponse>> + Send + 'a>>;
161
162 fn embedding_capability_erased(
163 &self,
164 model: &str,
165 capability: EmbeddingCapability,
166 ) -> CapabilitySupport;
167}
168
169impl<T> EmbeddingProviderErased for T
170where
171 T: EmbeddingProvider,
172{
173 fn embed_erased<'a>(
174 &'a self,
175 request: &'a EmbeddingRequest,
176 ) -> Pin<Box<dyn Future<Output = Result<EmbeddingResponse>> + Send + 'a>> {
177 Box::pin(EmbeddingProvider::embed(self, request))
178 }
179
180 fn embedding_capability_erased(
181 &self,
182 model: &str,
183 capability: EmbeddingCapability,
184 ) -> CapabilitySupport {
185 EmbeddingProvider::embedding_capability(self, model, capability)
186 }
187}
188
189pub trait EmbeddingProviderExt: EmbeddingProvider {
191 fn embed_text(
200 &self,
201 model: &str,
202 input: impl Into<String>,
203 ) -> impl Future<Output = Result<Vec<f32>>> + Send {
204 let input = input.into();
205 let model = model.to_string();
206
207 async move {
208 let response = self
209 .embed(&EmbeddingRequest::new(model).input(input))
210 .await?;
211
212 response.embeddings.into_iter().next().ok_or_else(|| {
213 crate::Error::UnexpectedResponse(format!(
214 "provider '{}' returned no embeddings for embed_text()",
215 self.provider_name()
216 ))
217 })
218 }
219 }
220}
221
222impl<T: EmbeddingProvider> EmbeddingProviderExt for T {}
223
#[cfg(test)]
mod provider_tests {
    use super::*;
    use crate::{ProviderIdentity, Result};
    use std::sync::Arc;

    /// Provider that returns a clone of a canned response. Its fixed capability
    /// matrix is: batch input supported, output dimensions unsupported.
    struct StaticEmbeddingProvider {
        response: EmbeddingResponse,
    }

    impl ProviderIdentity for StaticEmbeddingProvider {
        fn provider_name(&self) -> &'static str {
            "static-embed"
        }
    }

    impl EmbeddingProvider for StaticEmbeddingProvider {
        async fn embed(&self, _request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
            Ok(self.response.clone())
        }

        fn embedding_capability(
            &self,
            _model: &str,
            capability: EmbeddingCapability,
        ) -> CapabilitySupport {
            match capability {
                EmbeddingCapability::BatchInput => CapabilitySupport::Supported,
                EmbeddingCapability::OutputDimensions => CapabilitySupport::Unsupported,
            }
        }
    }

    fn demo_provider() -> StaticEmbeddingProvider {
        StaticEmbeddingProvider {
            response: EmbeddingResponse::new(vec![vec![0.1, 0.2]]).model("demo"),
        }
    }

    /// Asserts that `provider` (typically a wrapper around
    /// [`StaticEmbeddingProvider`]) reports the wrapped capability matrix.
    /// In other words, it checks that capability queries are forwarded.
    fn assert_demo_capabilities<P: EmbeddingProvider>(provider: &P) {
        assert_eq!(
            provider.embedding_capability("demo", EmbeddingCapability::BatchInput),
            CapabilitySupport::Supported
        );
        assert_eq!(
            provider.embedding_capability("demo", EmbeddingCapability::OutputDimensions),
            CapabilitySupport::Unsupported
        );
    }

    #[tokio::test]
    async fn direct_impl_returns_response() {
        let provider = demo_provider();
        let request = EmbeddingRequest::new("demo").input("hello");
        let response = provider.embed(&request).await.unwrap();
        assert_eq!(response.embeddings, vec![vec![0.1, 0.2]]);
    }

    #[tokio::test]
    async fn ref_forwards_embed() {
        let provider = demo_provider();
        let borrowed: &StaticEmbeddingProvider = &provider;
        let request = EmbeddingRequest::new("demo").input("hello");
        assert_eq!(
            borrowed.embed(&request).await.unwrap().embeddings,
            vec![vec![0.1, 0.2]]
        );
        assert_eq!(borrowed.provider_name(), "static-embed");
        // `P = &StaticEmbeddingProvider` here, so this goes through the `&T` impl.
        assert_demo_capabilities(&borrowed);
    }

    #[tokio::test]
    async fn box_forwards_embed() {
        let boxed: Box<StaticEmbeddingProvider> = Box::new(demo_provider());
        let request = EmbeddingRequest::new("demo").input("hello");
        assert_eq!(
            boxed.embed(&request).await.unwrap().embeddings,
            vec![vec![0.1, 0.2]]
        );
        assert_demo_capabilities(&boxed);
    }

    #[tokio::test]
    async fn arc_forwards_embed_and_capability() {
        let arced: Arc<StaticEmbeddingProvider> = Arc::new(demo_provider());
        let request = EmbeddingRequest::new("demo").input("hello");
        assert_eq!(
            arced.embed(&request).await.unwrap().embeddings,
            vec![vec![0.1, 0.2]]
        );
        assert_demo_capabilities(&arced);
    }

    // Nothing here is awaited, so a plain `#[test]` avoids starting a runtime.
    #[test]
    fn default_capability_method_returns_unknown() {
        struct Minimal;
        impl ProviderIdentity for Minimal {}
        impl EmbeddingProvider for Minimal {
            async fn embed(&self, _request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
                Ok(EmbeddingResponse::default())
            }
        }

        assert_eq!(
            Minimal.embedding_capability("any", EmbeddingCapability::BatchInput),
            CapabilitySupport::Unknown
        );
        assert_eq!(
            Minimal.embedding_capability("any", EmbeddingCapability::OutputDimensions),
            CapabilitySupport::Unknown
        );
    }
}
327
#[cfg(test)]
mod dyn_tests {
    use super::*;
    use crate::ProviderIdentity;
    use std::sync::Arc;

    /// Provider with a configurable name. It returns one zero vector (length 4)
    /// per request input.
    struct DynDemo {
        tag: &'static str,
    }

    impl ProviderIdentity for DynDemo {
        fn provider_name(&self) -> &'static str {
            self.tag
        }
    }

    impl EmbeddingProvider for DynDemo {
        async fn embed(&self, request: &EmbeddingRequest) -> crate::Result<EmbeddingResponse> {
            let inputs = request.inputs.len();
            Ok(EmbeddingResponse::new(vec![vec![0.0; 4]; inputs]))
        }

        fn embedding_capability(
            &self,
            _model: &str,
            capability: EmbeddingCapability,
        ) -> CapabilitySupport {
            match capability {
                EmbeddingCapability::BatchInput => CapabilitySupport::Supported,
                EmbeddingCapability::OutputDimensions => CapabilitySupport::Unsupported,
            }
        }
    }

    /// Compile-time guarantee that the erased wrapper can be cloned and shared
    /// across threads and tasks.
    ///
    /// `Arc<dyn EmbeddingProviderErased>` is only `Send + Sync` because the
    /// trait's supertraits imply it. That property is easy to lose in a
    /// refactor, so it is pinned here.
    #[test]
    fn dyn_provider_is_send_sync_clone_and_static() {
        fn assert_shareable<T: Send + Sync + Clone + 'static>() {}
        assert_shareable::<DynEmbeddingProvider>();
    }

    #[tokio::test]
    async fn dyn_provider_from_concrete_forwards_calls() {
        let provider = DynEmbeddingProvider::new(DynDemo { tag: "dyn-embed" });
        let request = EmbeddingRequest::new("demo").inputs(["a", "b"]);
        let response = provider.embed(&request).await.unwrap();
        assert_eq!(response.embeddings.len(), 2);
        assert_eq!(provider.provider_name(), "dyn-embed");
        assert_eq!(
            provider.embedding_capability("demo", EmbeddingCapability::BatchInput),
            CapabilitySupport::Supported
        );
        assert_eq!(
            provider.embedding_capability("demo", EmbeddingCapability::OutputDimensions),
            CapabilitySupport::Unsupported
        );
    }

    #[tokio::test]
    async fn dyn_provider_from_arc_is_cloneable() {
        let provider: DynEmbeddingProvider = Arc::new(DynDemo { tag: "arc-embed" }).into();
        let cloned = provider.clone();
        let request = EmbeddingRequest::new("demo").input("x");
        assert_eq!(cloned.embed(&request).await.unwrap().embeddings.len(), 1);
        assert_eq!(cloned.provider_name(), "arc-embed");
    }

    #[test]
    fn dyn_provider_debug_includes_provider_name() {
        let provider = DynEmbeddingProvider::new(DynDemo { tag: "debug-embed" });
        let debug = format!("{provider:?}");
        assert!(debug.contains("DynEmbeddingProvider"));
        assert!(debug.contains("debug-embed"));
    }
}
392
#[cfg(test)]
mod ext_tests {
    use super::*;
    use crate::{Error, ProviderIdentity};
    use std::sync::{Arc, Mutex};

    /// Provider that returns a canned response and records the inputs of the
    /// most recent request it received.
    struct RecordingProvider {
        response: EmbeddingResponse,
        last_inputs: Mutex<Option<Vec<String>>>,
    }

    impl RecordingProvider {
        /// Builds a provider that answers every request with `embeddings`.
        fn returning(embeddings: Vec<Vec<f32>>) -> Self {
            Self {
                response: EmbeddingResponse::new(embeddings),
                last_inputs: Mutex::new(None),
            }
        }
    }

    impl ProviderIdentity for RecordingProvider {
        fn provider_name(&self) -> &'static str {
            "recording"
        }
    }

    impl EmbeddingProvider for RecordingProvider {
        async fn embed(&self, request: &EmbeddingRequest) -> crate::Result<EmbeddingResponse> {
            *self.last_inputs.lock().unwrap() = Some(request.inputs.clone());
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn embed_text_sends_single_input_and_returns_vector() {
        let provider = RecordingProvider::returning(vec![vec![0.5, 0.5]]);
        let vector = provider.embed_text("model", "hello").await.unwrap();
        assert_eq!(vector, vec![0.5, 0.5]);
        assert_eq!(
            provider.last_inputs.lock().unwrap().clone(),
            Some(vec!["hello".to_string()])
        );
    }

    #[tokio::test]
    async fn embed_text_errors_when_response_has_no_vectors() {
        let provider = RecordingProvider::returning(Vec::new());
        let err = provider.embed_text("model", "hello").await.unwrap_err();
        match err {
            Error::UnexpectedResponse(message) => assert!(
                message.contains("recording"),
                "expected provider name in error, got: {message}"
            ),
            other => panic!("expected UnexpectedResponse, got {other:?}"),
        }
    }

    /// Pins the current contract: extra vectors are ignored and the first one
    /// is returned.
    #[tokio::test]
    async fn embed_text_returns_first_vector_when_several_are_returned() {
        let provider = RecordingProvider::returning(vec![vec![1.0], vec![2.0]]);
        let vector = provider.embed_text("model", "hello").await.unwrap();
        assert_eq!(vector, vec![1.0]);
    }

    /// The blanket `EmbeddingProviderExt` impl must also cover the erased wrapper.
    #[tokio::test]
    async fn embed_text_is_available_on_dyn_provider() {
        let shared = Arc::new(RecordingProvider::returning(vec![vec![1.0, 2.0]]));
        let provider: DynEmbeddingProvider = Arc::clone(&shared).into();
        let vector = provider.embed_text("model", "via-dyn").await.unwrap();
        assert_eq!(vector, vec![1.0, 2.0]);
        assert_eq!(
            shared.last_inputs.lock().unwrap().clone(),
            Some(vec!["via-dyn".to_string()])
        );
    }
}
446}