// modo/embed/provider.rs

use std::fmt;
use std::sync::Arc;

use crate::error::{Error, Result};

use super::backend::EmbeddingBackend;

7/// Concrete embedding provider — wraps any [`EmbeddingBackend`].
8///
9/// Cheap to clone (wraps `Arc` internally). Use as an axum service via
10/// `Service(embedder): Service<EmbeddingProvider>` where `embedder` is
11/// `Arc<EmbeddingProvider>`; `Arc<T>` derefs to `T` so calling `.embed()`
12/// directly on `embedder` works without extra unwrapping.
13///
14/// # Example
15///
16/// ```rust,ignore
17/// let client = reqwest::Client::new();
18/// let embedder = EmbeddingProvider::new(
19///     OpenAIEmbedding::new(client, &config)?,
20/// );
21/// let blob = embedder.embed("hello world").await?;
22/// ```
23pub struct EmbeddingProvider(Arc<dyn EmbeddingBackend>);
24
25impl Clone for EmbeddingProvider {
26    fn clone(&self) -> Self {
27        Self(Arc::clone(&self.0))
28    }
29}
30
31impl EmbeddingProvider {
32    /// Wrap any backend. `Arc` is handled internally.
33    pub fn new(backend: impl EmbeddingBackend + 'static) -> Self {
34        Self(Arc::new(backend))
35    }
36
37    /// Embed text. Returns a little-endian f32 blob for libsql.
38    ///
39    /// # Errors
40    ///
41    /// Returns `Error::bad_request` if `input` is empty.
42    /// Propagates provider API errors.
43    pub async fn embed(&self, input: &str) -> Result<Vec<u8>> {
44        if input.is_empty() {
45            return Err(Error::bad_request("embedding input must not be empty"));
46        }
47        self.0.embed(input).await
48    }
49
50    /// Number of dimensions this provider/model produces.
51    pub fn dimensions(&self) -> usize {
52        self.0.dimensions()
53    }
54
55    /// Model identifier string.
56    pub fn model_name(&self) -> &str {
57        self.0.model_name()
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embed::convert::from_f32_blob;
    use crate::embed::test::InMemoryBackend;

    /// Builds a provider over the in-memory test backend with `dims` components.
    fn in_memory(dims: usize) -> EmbeddingProvider {
        EmbeddingProvider::new(InMemoryBackend::new(dims))
    }

    #[tokio::test]
    async fn embed_returns_blob_of_correct_length() {
        const DIMS: usize = 128;
        let provider = in_memory(DIMS);
        let bytes = provider.embed("hello").await.unwrap();
        // Each f32 component occupies four bytes.
        assert_eq!(bytes.len(), DIMS * 4);
    }

    #[tokio::test]
    async fn embed_blob_roundtrips_to_floats() {
        const DIMS: usize = 4;
        let provider = in_memory(DIMS);
        let bytes = provider.embed("test").await.unwrap();
        let decoded = from_f32_blob(&bytes).unwrap();
        assert_eq!(decoded.len(), DIMS);
        // The in-memory backend is expected to fill every component with 0.1.
        assert_eq!(decoded, vec![0.1_f32; DIMS]);
    }

    #[tokio::test]
    async fn embed_rejects_empty_input() {
        let provider = in_memory(4);
        let failure = provider.embed("").await.unwrap_err();
        assert_eq!(failure.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn dimensions_delegated() {
        assert_eq!(in_memory(768).dimensions(), 768);
    }

    #[test]
    fn model_name_delegated() {
        assert_eq!(in_memory(4).model_name(), "test-embedding");
    }
}