// brainwires_core/embedding.rs

//! Unified embedding abstraction
//!
//! Provides the `EmbeddingProvider` trait for pluggable text embedding backends.
//! Implementations live in downstream crates (storage, rag) — this trait enables
//! consumers to accept any embedding backend without coupling to a specific one.

use anyhow::Result;

/// Trait for text embedding generation.
///
/// Implementations should be thread-safe and reusable across concurrent contexts;
/// the `Send + Sync` supertraits allow a provider to be shared between threads.
///
/// # Example
///
/// ```ignore
/// use brainwires_core::EmbeddingProvider;
///
/// fn search(provider: &dyn EmbeddingProvider, query: &str) -> anyhow::Result<Vec<f32>> {
///     provider.embed(query)
/// }
/// ```
pub trait EmbeddingProvider: Send + Sync {
    /// Generate an embedding for a single text.
    ///
    /// # Errors
    ///
    /// Returns whatever error the implementation reports when it cannot
    /// produce an embedding for `text`.
    fn embed(&self, text: &str) -> Result<Vec<f32>>;

    /// Generate embeddings for a batch of texts.
    ///
    /// Default implementation calls `embed` once per text, in order, so the
    /// returned vectors line up index-for-index with `texts`. An empty slice
    /// yields an empty `Vec`. Backends that support native batching should
    /// override this for better performance.
    ///
    /// # Errors
    ///
    /// The default implementation stops at the first text for which `embed`
    /// fails and returns that error; later texts are not embedded.
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
        // Collecting into `Result<Vec<_>, _>` short-circuits on the first `Err`.
        texts.iter().map(|t| self.embed(t)).collect()
    }

    /// Get the dimensionality of the embedding vectors.
    fn dimension(&self) -> usize;

    /// Get the model name (e.g. "all-MiniLM-L6-v2").
    fn model_name(&self) -> &str;
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal provider that returns the same 3-element vector for any input.
    struct MockEmbedding;

    impl EmbeddingProvider for MockEmbedding {
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn dimension(&self) -> usize {
            3
        }

        fn model_name(&self) -> &str {
            "mock-model"
        }
    }

    #[test]
    fn test_embed_single() {
        let provider = MockEmbedding;
        let embedding = provider.embed("test").unwrap();
        assert_eq!(embedding.len(), 3);
    }

    /// Exercises the trait's default `embed_batch`, which the mock does not override.
    #[test]
    fn test_embed_batch_default() {
        let provider = MockEmbedding;
        let texts = vec!["a".to_string(), "b".to_string()];
        let embeddings = provider.embed_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 2);
        // One vector per input text, each with the mock's declared dimension.
        assert!(embeddings.iter().all(|e| e.len() == provider.dimension()));
    }

    #[test]
    fn test_dimension() {
        let provider = MockEmbedding;
        assert_eq!(provider.dimension(), 3);
    }

    #[test]
    fn test_model_name() {
        let provider = MockEmbedding;
        assert_eq!(provider.model_name(), "mock-model");
    }
}