Skip to main content

anyllm/embedding/
provider.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::Arc;
4
5use serde::{Deserialize, Serialize};
6
7use crate::{CapabilitySupport, EmbeddingRequest, EmbeddingResponse, ProviderIdentity, Result};
8
9/// Portable embedding features that a provider/model may support.
10#[non_exhaustive]
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
12pub enum EmbeddingCapability {
13    /// Accepts more than one input in a single request.
14    BatchInput,
15    /// Honors the [`EmbeddingRequest::dimensions`] output-size request field.
16    OutputDimensions,
17}
18
/// Core trait for providers that expose a text embedding API.
///
/// Implementations are batch-oriented: a single request resolves to an
/// ordered collection of vectors. Callers that have a single input should
/// use [`EmbeddingProviderExt::embed_text`] instead of building a request by
/// hand.
///
/// Methods return `impl Future<…> + Send` so wrappers and dyn dispatch can
/// rely on `Send` futures, matching [`crate::ChatProvider`].
pub trait EmbeddingProvider: ProviderIdentity {
    /// Send an embedding request and return ordered vectors.
    ///
    /// # Errors
    ///
    /// Returns [`crate::Error`] on provider communication or decoding failures.
    fn embed(
        &self,
        request: &EmbeddingRequest,
    ) -> impl Future<Output = Result<EmbeddingResponse>> + Send;

    /// Returns support information for a provider/model embedding capability.
    ///
    /// The default implementation ignores both arguments and reports
    /// [`CapabilitySupport::Unknown`]; implementors override it when they can
    /// give a definite answer.
    fn embedding_capability(
        &self,
        _model: &str,
        _capability: EmbeddingCapability,
    ) -> CapabilitySupport {
        CapabilitySupport::Unknown
    }
}
46
47impl<T> EmbeddingProvider for &T
48where
49    T: EmbeddingProvider + ?Sized,
50{
51    async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
52        T::embed(*self, request).await
53    }
54
55    fn embedding_capability(
56        &self,
57        model: &str,
58        capability: EmbeddingCapability,
59    ) -> CapabilitySupport {
60        T::embedding_capability(*self, model, capability)
61    }
62}
63
64impl<T> EmbeddingProvider for Box<T>
65where
66    T: EmbeddingProvider + ?Sized,
67{
68    async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
69        T::embed(self.as_ref(), request).await
70    }
71
72    fn embedding_capability(
73        &self,
74        model: &str,
75        capability: EmbeddingCapability,
76    ) -> CapabilitySupport {
77        T::embedding_capability(self.as_ref(), model, capability)
78    }
79}
80
81impl<T> EmbeddingProvider for Arc<T>
82where
83    T: EmbeddingProvider + ?Sized,
84{
85    async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
86        T::embed(self.as_ref(), request).await
87    }
88
89    fn embedding_capability(
90        &self,
91        model: &str,
92        capability: EmbeddingCapability,
93    ) -> CapabilitySupport {
94        T::embedding_capability(self.as_ref(), model, capability)
95    }
96}
97
98/// A type-erased embedding provider for dynamic dispatch.
99///
100/// Wraps any `T: EmbeddingProvider + 'static` behind a vtable, boxing the
101/// async method future. Mirrors [`crate::DynChatProvider`].
102#[derive(Clone)]
103pub struct DynEmbeddingProvider(Arc<dyn EmbeddingProviderErased>);
104
105impl DynEmbeddingProvider {
106    /// Erase a concrete provider into a `DynEmbeddingProvider`.
107    #[must_use]
108    pub fn new<T>(provider: T) -> Self
109    where
110        T: EmbeddingProvider + 'static,
111    {
112        Self(Arc::new(provider))
113    }
114}
115
116impl<T> From<Arc<T>> for DynEmbeddingProvider
117where
118    T: EmbeddingProvider + 'static,
119{
120    fn from(provider: Arc<T>) -> Self {
121        Self(provider)
122    }
123}
124
125impl std::fmt::Debug for DynEmbeddingProvider {
126    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
127        f.debug_struct("DynEmbeddingProvider")
128            .field("provider", &self.0.provider_name())
129            .finish()
130    }
131}
132
133impl ProviderIdentity for DynEmbeddingProvider {
134    fn provider_name(&self) -> &'static str {
135        self.0.provider_name()
136    }
137}
138
139impl EmbeddingProvider for DynEmbeddingProvider {
140    async fn embed(&self, request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
141        self.0.embed_erased(request).await
142    }
143
144    fn embedding_capability(
145        &self,
146        model: &str,
147        capability: EmbeddingCapability,
148    ) -> CapabilitySupport {
149        self.0.embedding_capability_erased(model, capability)
150    }
151}
152
153/// Object-safe internal trait that manually boxes the async `embed` future.
154///
155/// Sealed by the blanket impl for `T: EmbeddingProvider`.
156trait EmbeddingProviderErased: ProviderIdentity {
157    fn embed_erased<'a>(
158        &'a self,
159        request: &'a EmbeddingRequest,
160    ) -> Pin<Box<dyn Future<Output = Result<EmbeddingResponse>> + Send + 'a>>;
161
162    fn embedding_capability_erased(
163        &self,
164        model: &str,
165        capability: EmbeddingCapability,
166    ) -> CapabilitySupport;
167}
168
169impl<T> EmbeddingProviderErased for T
170where
171    T: EmbeddingProvider,
172{
173    fn embed_erased<'a>(
174        &'a self,
175        request: &'a EmbeddingRequest,
176    ) -> Pin<Box<dyn Future<Output = Result<EmbeddingResponse>> + Send + 'a>> {
177        Box::pin(EmbeddingProvider::embed(self, request))
178    }
179
180    fn embedding_capability_erased(
181        &self,
182        model: &str,
183        capability: EmbeddingCapability,
184    ) -> CapabilitySupport {
185        EmbeddingProvider::embedding_capability(self, model, capability)
186    }
187}
188
189/// Convenience extension methods for [`EmbeddingProvider`] implementors.
190pub trait EmbeddingProviderExt: EmbeddingProvider {
191    /// Quick one-shot embedding for a single input.
192    ///
193    /// # Errors
194    ///
195    /// Propagates any [`crate::Error`] from the underlying
196    /// [`EmbeddingProvider::embed`] call, and returns
197    /// [`crate::Error::UnexpectedResponse`] if the provider response contains
198    /// no embeddings.
199    fn embed_text(
200        &self,
201        model: &str,
202        input: impl Into<String>,
203    ) -> impl Future<Output = Result<Vec<f32>>> + Send {
204        let input = input.into();
205        let model = model.to_string();
206
207        async move {
208            let response = self
209                .embed(&EmbeddingRequest::new(model).input(input))
210                .await?;
211
212            response.embeddings.into_iter().next().ok_or_else(|| {
213                crate::Error::UnexpectedResponse(format!(
214                    "provider '{}' returned no embeddings for embed_text()",
215                    self.provider_name()
216                ))
217            })
218        }
219    }
220}
221
222impl<T: EmbeddingProvider> EmbeddingProviderExt for T {}
223
#[cfg(test)]
mod provider_tests {
    use super::*;
    use crate::{ProviderIdentity, Result};
    use std::sync::Arc;

    /// Fixture that answers every request with a clone of a canned response.
    struct StaticEmbeddingProvider {
        response: EmbeddingResponse,
    }

    impl ProviderIdentity for StaticEmbeddingProvider {
        fn provider_name(&self) -> &'static str {
            "static-embed"
        }
    }

    impl EmbeddingProvider for StaticEmbeddingProvider {
        async fn embed(&self, _request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
            Ok(self.response.clone())
        }

        fn embedding_capability(
            &self,
            _model: &str,
            capability: EmbeddingCapability,
        ) -> CapabilitySupport {
            match capability {
                EmbeddingCapability::BatchInput => CapabilitySupport::Supported,
                EmbeddingCapability::OutputDimensions => CapabilitySupport::Unsupported,
            }
        }
    }

    fn demo_provider() -> StaticEmbeddingProvider {
        StaticEmbeddingProvider {
            response: EmbeddingResponse::new(vec![vec![0.1, 0.2]]).model("demo"),
        }
    }

    // The helpers below call through a generic `P: EmbeddingProvider` bound so
    // that the impl for `P` itself is selected. Plain method syntax on a
    // `&StaticEmbeddingProvider` receiver resolves to the concrete type's own
    // impl and would bypass the `&T` forwarding impl these tests target.

    /// Calls `embed` on `P`'s own [`EmbeddingProvider`] impl.
    async fn embed_via<P: EmbeddingProvider>(
        provider: &P,
        request: &EmbeddingRequest,
    ) -> Result<EmbeddingResponse> {
        provider.embed(request).await
    }

    /// Calls `embedding_capability` on `P`'s own [`EmbeddingProvider`] impl.
    fn capability_via<P: EmbeddingProvider>(
        provider: &P,
        capability: EmbeddingCapability,
    ) -> CapabilitySupport {
        provider.embedding_capability("demo", capability)
    }

    /// Calls `provider_name` on `P`'s own [`ProviderIdentity`] impl.
    fn name_via<P: EmbeddingProvider>(provider: &P) -> &'static str {
        provider.provider_name()
    }

    #[tokio::test]
    async fn direct_impl_returns_response() {
        let provider = demo_provider();
        let request = EmbeddingRequest::new("demo").input("hello");
        let response = provider.embed(&request).await.unwrap();
        assert_eq!(response.embeddings, vec![vec![0.1, 0.2]]);
    }

    #[tokio::test]
    async fn ref_forwards_embed() {
        let provider = demo_provider();
        let borrowed: &StaticEmbeddingProvider = &provider;
        let request = EmbeddingRequest::new("demo").input("hello");
        // Here `P` is `&StaticEmbeddingProvider`, so the `&T` impl is exercised.
        assert_eq!(
            embed_via(&borrowed, &request).await.unwrap().embeddings,
            vec![vec![0.1, 0.2]]
        );
        assert_eq!(name_via(&borrowed), "static-embed");
    }

    #[test]
    fn ref_forwards_capability() {
        let provider = demo_provider();
        let borrowed: &StaticEmbeddingProvider = &provider;
        assert_eq!(
            capability_via(&borrowed, EmbeddingCapability::BatchInput),
            CapabilitySupport::Supported
        );
        assert_eq!(
            capability_via(&borrowed, EmbeddingCapability::OutputDimensions),
            CapabilitySupport::Unsupported
        );
    }

    #[tokio::test]
    async fn box_forwards_embed() {
        let boxed: Box<StaticEmbeddingProvider> = Box::new(demo_provider());
        let request = EmbeddingRequest::new("demo").input("hello");
        assert_eq!(
            embed_via(&boxed, &request).await.unwrap().embeddings,
            vec![vec![0.1, 0.2]]
        );
    }

    #[test]
    fn box_forwards_capability() {
        let boxed: Box<StaticEmbeddingProvider> = Box::new(demo_provider());
        assert_eq!(
            capability_via(&boxed, EmbeddingCapability::BatchInput),
            CapabilitySupport::Supported
        );
        assert_eq!(
            capability_via(&boxed, EmbeddingCapability::OutputDimensions),
            CapabilitySupport::Unsupported
        );
    }

    #[tokio::test]
    async fn arc_forwards_embed_and_capability() {
        let arced: Arc<StaticEmbeddingProvider> = Arc::new(demo_provider());
        let request = EmbeddingRequest::new("demo").input("hello");
        assert_eq!(
            embed_via(&arced, &request).await.unwrap().embeddings,
            vec![vec![0.1, 0.2]]
        );
        assert_eq!(
            capability_via(&arced, EmbeddingCapability::BatchInput),
            CapabilitySupport::Supported
        );
        assert_eq!(
            capability_via(&arced, EmbeddingCapability::OutputDimensions),
            CapabilitySupport::Unsupported
        );
    }

    // Nothing is awaited here, so a plain synchronous test is enough.
    #[test]
    fn default_capability_method_returns_unknown() {
        struct Minimal;
        impl ProviderIdentity for Minimal {}
        impl EmbeddingProvider for Minimal {
            async fn embed(&self, _request: &EmbeddingRequest) -> Result<EmbeddingResponse> {
                Ok(EmbeddingResponse::default())
            }
        }

        assert_eq!(
            Minimal.embedding_capability("any", EmbeddingCapability::BatchInput),
            CapabilitySupport::Unknown
        );
    }
}
327
#[cfg(test)]
mod dyn_tests {
    use super::*;
    use crate::ProviderIdentity;
    use std::sync::Arc;

    /// Fixture whose reported provider name is chosen per test.
    struct DynDemo {
        tag: &'static str,
    }

    impl ProviderIdentity for DynDemo {
        fn provider_name(&self) -> &'static str {
            self.tag
        }
    }

    impl EmbeddingProvider for DynDemo {
        async fn embed(&self, request: &EmbeddingRequest) -> crate::Result<EmbeddingResponse> {
            // One four-element zero vector per requested input.
            let zero_vectors = vec![vec![0.0; 4]; request.inputs.len()];
            Ok(EmbeddingResponse::new(zero_vectors))
        }

        fn embedding_capability(
            &self,
            _model: &str,
            capability: EmbeddingCapability,
        ) -> CapabilitySupport {
            match capability {
                EmbeddingCapability::OutputDimensions => CapabilitySupport::Unsupported,
                EmbeddingCapability::BatchInput => CapabilitySupport::Supported,
            }
        }
    }

    #[tokio::test]
    async fn dyn_provider_from_concrete_forwards_calls() {
        let erased = DynEmbeddingProvider::new(DynDemo { tag: "dyn-embed" });
        let two_inputs = EmbeddingRequest::new("demo").inputs(["a", "b"]);

        let reply = erased.embed(&two_inputs).await.unwrap();
        assert_eq!(reply.embeddings.len(), 2);

        assert_eq!(erased.provider_name(), "dyn-embed");

        let batch = erased.embedding_capability("demo", EmbeddingCapability::BatchInput);
        assert_eq!(batch, CapabilitySupport::Supported);
    }

    #[tokio::test]
    async fn dyn_provider_from_arc_is_cloneable() {
        let original = DynEmbeddingProvider::from(Arc::new(DynDemo { tag: "arc-embed" }));
        let copy = original.clone();
        let single_input = EmbeddingRequest::new("demo").input("x");

        let reply = copy.embed(&single_input).await.unwrap();
        assert_eq!(reply.embeddings.len(), 1);
        assert_eq!(copy.provider_name(), "arc-embed");
    }

    #[test]
    fn dyn_provider_debug_includes_provider_name() {
        let erased = DynEmbeddingProvider::new(DynDemo { tag: "debug-embed" });
        let rendered = format!("{erased:?}");
        for needle in ["DynEmbeddingProvider", "debug-embed"] {
            assert!(rendered.contains(needle));
        }
    }
}
392
#[cfg(test)]
mod ext_tests {
    use super::*;
    use crate::{Error, ProviderIdentity};
    use std::sync::Mutex;

    /// Fixture that stores the inputs of the most recent request and replies
    /// with a clone of a canned response.
    struct RecordingProvider {
        response: EmbeddingResponse,
        last_inputs: Mutex<Option<Vec<String>>>,
    }

    impl ProviderIdentity for RecordingProvider {
        fn provider_name(&self) -> &'static str {
            "recording"
        }
    }

    impl EmbeddingProvider for RecordingProvider {
        async fn embed(&self, request: &EmbeddingRequest) -> crate::Result<EmbeddingResponse> {
            let seen = request.inputs.clone();
            *self.last_inputs.lock().unwrap() = Some(seen);
            Ok(self.response.clone())
        }
    }

    #[tokio::test]
    async fn embed_text_sends_single_input_and_returns_vector() {
        let recorder = RecordingProvider {
            response: EmbeddingResponse::new(vec![vec![0.5, 0.5]]),
            last_inputs: Mutex::new(None),
        };

        let embedded = recorder.embed_text("model", "hello").await.unwrap();
        assert_eq!(embedded, vec![0.5, 0.5]);

        let recorded = recorder.last_inputs.lock().unwrap().clone();
        assert_eq!(recorded, Some(vec!["hello".to_string()]));
    }

    #[tokio::test]
    async fn embed_text_errors_when_response_has_no_vectors() {
        let recorder = RecordingProvider {
            response: EmbeddingResponse::new(Vec::new()),
            last_inputs: Mutex::new(None),
        };

        let failure = recorder.embed_text("model", "hello").await.unwrap_err();
        let message = match failure {
            Error::UnexpectedResponse(message) => message,
            other => panic!("expected UnexpectedResponse, got {other:?}"),
        };
        assert!(
            message.contains("recording"),
            "expected provider name in error, got: {message}"
        );
    }
}