Skip to main content

modo/embed/
provider.rs

1use std::sync::Arc;
2
3use crate::error::{Error, Result};
4
5use super::backend::EmbeddingBackend;
6
7/// Concrete embedding provider — wraps any [`EmbeddingBackend`].
8///
9/// Cheap to clone (wraps `Arc` internally). Use as an axum service via
10/// `Service(embedder): Service<EmbeddingProvider>`.
11///
12/// # Example
13///
14/// ```rust,ignore
15/// let client = reqwest::Client::new();
16/// let embedder = EmbeddingProvider::new(
17///     OpenAIEmbedding::new(client, &config)?,
18/// );
19/// let blob = embedder.embed("hello world").await?;
20/// ```
21pub struct EmbeddingProvider(Arc<dyn EmbeddingBackend>);
22
23impl Clone for EmbeddingProvider {
24    fn clone(&self) -> Self {
25        Self(Arc::clone(&self.0))
26    }
27}
28
29impl EmbeddingProvider {
30    /// Wrap any backend. `Arc` is handled internally.
31    pub fn new(backend: impl EmbeddingBackend + 'static) -> Self {
32        Self(Arc::new(backend))
33    }
34
35    /// Embed text. Returns a little-endian f32 blob for libsql.
36    ///
37    /// # Errors
38    ///
39    /// Returns `Error::bad_request` if `input` is empty.
40    /// Propagates provider API errors.
41    pub async fn embed(&self, input: &str) -> Result<Vec<u8>> {
42        if input.is_empty() {
43            return Err(Error::bad_request("embedding input must not be empty"));
44        }
45        self.0.embed(input).await
46    }
47
48    /// Number of dimensions this provider/model produces.
49    pub fn dimensions(&self) -> usize {
50        self.0.dimensions()
51    }
52
53    /// Model identifier string.
54    pub fn model_name(&self) -> &str {
55        self.0.model_name()
56    }
57}
58
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embed::convert::from_f32_blob;
    use crate::embed::test::InMemoryBackend;

    /// Provider wrapping the in-memory test backend configured for `dims`
    /// dimensions.
    fn provider_with(dims: usize) -> EmbeddingProvider {
        EmbeddingProvider::new(InMemoryBackend::new(dims))
    }

    #[tokio::test]
    async fn embed_returns_blob_of_correct_length() {
        let dims = 128;
        let blob = provider_with(dims).embed("hello").await.unwrap();
        // One f32 (4 bytes) per dimension.
        assert_eq!(blob.len(), dims * std::mem::size_of::<f32>());
    }

    #[tokio::test]
    async fn embed_blob_roundtrips_to_floats() {
        let dims = 4;
        let blob = provider_with(dims).embed("test").await.unwrap();
        let decoded = from_f32_blob(&blob).unwrap();
        assert_eq!(decoded.len(), dims);
        // The expected vector is `dims` copies of 0.1.
        assert_eq!(decoded, vec![0.1_f32; dims]);
    }

    #[tokio::test]
    async fn embed_rejects_empty_input() {
        let outcome = provider_with(4).embed("").await;
        let err = outcome.unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn dimensions_delegated() {
        let embedder = provider_with(768);
        assert_eq!(embedder.dimensions(), 768);
    }

    #[test]
    fn model_name_delegated() {
        let embedder = provider_with(4);
        assert_eq!(embedder.model_name(), "test-embedding");
    }
}