Skip to main content

fabryk_vector/
embedding.rs

1//! Embedding provider trait and mock implementation.
2//!
3//! This module defines the `EmbeddingProvider` trait that abstracts over
4//! different embedding generation backends (fastembed, OpenAI, etc.).
5//!
6//! # Providers
7//!
8//! - `MockEmbeddingProvider`: Deterministic fixed-dimension vectors for testing
9//! - `FastEmbedProvider`: Local embedding via fastembed (requires `vector-fastembed` feature)
10
11use async_trait::async_trait;
12use fabryk_core::Result;
13
14/// Trait for generating text embeddings.
15///
16/// Implementations wrap specific embedding libraries (fastembed, OpenAI, etc.)
17/// and provide a uniform async interface. The trait requires `Send + Sync` to
18/// allow safe sharing across async tasks.
19///
20/// # Thread Safety
21///
22/// Implementations should handle internal synchronization (e.g., `Arc<Mutex<>>`)
23/// for thread-unsafe underlying libraries.
24#[async_trait]
25pub trait EmbeddingProvider: Send + Sync {
26    /// Generate an embedding for a single text.
27    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
28
29    /// Generate embeddings for a batch of texts.
30    ///
31    /// Default implementation calls `embed` for each text sequentially.
32    /// Backends that support native batching should override this.
33    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
34        let mut results = Vec::with_capacity(texts.len());
35        for text in texts {
36            results.push(self.embed(text).await?);
37        }
38        Ok(results)
39    }
40
41    /// The embedding dimension.
42    fn dimension(&self) -> usize;
43
44    /// The provider name for diagnostics.
45    fn name(&self) -> &str;
46}
47
/// A mock embedding provider for testing.
///
/// Produces deterministic vectors computed directly from the input's UTF-8
/// bytes (no hashing is involved). Before normalization, component `i` is
/// `((bytes[i % bytes.len()] + i) % 256) / 256`, so the same input always
/// yields the same embedding.
///
/// # Limitations
///
/// Only the first `dimension` bytes of the text influence the result, and the
/// bytes are read cyclically. Distinct texts can therefore collide: with
/// `dimension == 4`, `"abcdX"` and `"abcdY"` (or `"ab"` and `"abab"`) produce
/// identical vectors. This is fine for plumbing tests, but not for tests that
/// rely on meaningful similarity between texts.
#[derive(Debug, Clone)]
pub struct MockEmbeddingProvider {
    /// Number of components in every generated vector.
    dimension: usize,
}

impl MockEmbeddingProvider {
    /// Create a new mock provider that emits vectors of length `dimension`.
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }

    /// Generate a deterministic embedding from `text`.
    ///
    /// The result is scaled to unit (L2) length. An all-zero raw vector is
    /// returned as-is rather than normalized (normalizing it would divide by
    /// zero). This happens for an empty `text` with `dimension == 1`, and for
    /// `dimension == 0`, which yields an empty vector.
    ///
    /// An empty `text` is *not* generally a zero vector: every byte reads as
    /// `0`, but the position term `i` still contributes `i / 256`.
    fn deterministic_embedding(&self, text: &str) -> Vec<f32> {
        let bytes = text.as_bytes();

        let mut embedding: Vec<f32> = (0..self.dimension)
            .map(|i| {
                // Cycle through the text bytes; empty text contributes 0 everywhere.
                let byte_val = if bytes.is_empty() {
                    0u8
                } else {
                    bytes[i % bytes.len()]
                };
                // Mixing in the position keeps repeated bytes from producing
                // flat vectors; `% 256` keeps every raw component in [0, 1).
                ((f32::from(byte_val) + i as f32) % 256.0) / 256.0
            })
            .collect();

        let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            embedding.iter_mut().for_each(|val| *val /= norm);
        }

        embedding
    }
}
90
91#[async_trait]
92impl EmbeddingProvider for MockEmbeddingProvider {
93    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
94        Ok(self.deterministic_embedding(text))
95    }
96
97    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
98        Ok(texts
99            .iter()
100            .map(|t| self.deterministic_embedding(t))
101            .collect())
102    }
103
104    fn dimension(&self) -> usize {
105        self.dimension
106    }
107
108    fn name(&self) -> &str {
109        "mock"
110    }
111}
112
113// ============================================================================
114// Tests
115// ============================================================================
116
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean (L2) length of a vector, used by the normalization checks.
    fn l2_norm(v: &[f32]) -> f32 {
        v.iter().map(|x| x * x).sum::<f32>().sqrt()
    }

    #[test]
    fn test_mock_provider_creation() {
        let provider = MockEmbeddingProvider::new(384);
        assert_eq!(provider.dimension(), 384);
        assert_eq!(provider.name(), "mock");
    }

    #[tokio::test]
    async fn test_mock_embed_single() {
        let provider = MockEmbeddingProvider::new(8);
        let embedding = provider.embed("hello world").await.unwrap();

        assert_eq!(embedding.len(), 8);

        // Verify unit-normalized
        assert!((l2_norm(&embedding) - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn test_mock_embed_deterministic() {
        let provider = MockEmbeddingProvider::new(16);
        let e1 = provider.embed("same text").await.unwrap();
        let e2 = provider.embed("same text").await.unwrap();

        assert_eq!(e1, e2);
    }

    #[tokio::test]
    async fn test_mock_embed_different_texts() {
        let provider = MockEmbeddingProvider::new(16);
        let e1 = provider.embed("text one").await.unwrap();
        let e2 = provider.embed("text two").await.unwrap();

        assert_ne!(e1, e2);
    }

    #[tokio::test]
    async fn test_mock_embed_batch() {
        let provider = MockEmbeddingProvider::new(8);
        let texts = ["hello", "world", "test"];
        let embeddings = provider.embed_batch(&texts).await.unwrap();

        assert_eq!(embeddings.len(), 3);
        for emb in &embeddings {
            assert_eq!(emb.len(), 8);
        }
    }

    #[tokio::test]
    async fn test_mock_embed_batch_matches_single() {
        // The batch path must agree element-wise with single `embed` calls.
        let provider = MockEmbeddingProvider::new(8);
        let texts = ["alpha", "beta", ""];
        let batch = provider.embed_batch(&texts).await.unwrap();

        for (text, from_batch) in texts.iter().zip(&batch) {
            let single = provider.embed(text).await.unwrap();
            assert_eq!(&single, from_batch);
        }
    }

    #[tokio::test]
    async fn test_mock_embed_empty_text() {
        let provider = MockEmbeddingProvider::new(4);
        let embedding = provider.embed("").await.unwrap();

        assert_eq!(embedding.len(), 4);
        // Empty text is NOT a zero vector: every byte reads as 0, but the
        // position term still contributes i / 256, so the raw vector is
        // [0, 1, 2, 3] / 256, which normalizes to unit length.
        assert_eq!(embedding[0], 0.0);
        assert!((l2_norm(&embedding) - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn test_mock_embed_zero_vector_not_normalized() {
        // With dimension 1 and empty text the raw vector is all zeros; it must
        // be returned unchanged rather than divided by a zero norm (NaN).
        let provider = MockEmbeddingProvider::new(1);
        let embedding = provider.embed("").await.unwrap();

        assert_eq!(embedding, vec![0.0]);
    }

    #[tokio::test]
    async fn test_mock_embed_batch_empty() {
        let provider = MockEmbeddingProvider::new(4);
        let texts: [&str; 0] = [];
        let embeddings = provider.embed_batch(&texts).await.unwrap();

        assert!(embeddings.is_empty());
    }

    #[test]
    fn test_trait_object_safety() {
        // Verify EmbeddingProvider can be used as a trait object
        fn _assert_object_safe(_: &dyn EmbeddingProvider) {}
    }
}