Skip to main content

pulsehive_core/
embedding.rs

//! Embedding provider abstraction for domain-specific embedding models.
//!
//! [`EmbeddingProvider`] lets products plug in custom embedding models
//! (medical, code, multilingual) instead of PulseDB's built-in all-MiniLM-L6-v2.
//!
//! When a provider is set on HiveMind, PulseHive computes embeddings through it
//! and passes the resulting vectors to PulseDB in External mode. Products that
//! don't set a provider get PulseDB's default embeddings automatically.
//!
//! # Example
//! ```rust,ignore
//! struct OpenAIEmbeddings { client: reqwest::Client, api_key: String }
//!
//! #[async_trait]
//! impl EmbeddingProvider for OpenAIEmbeddings {
//!     async fn embed(&self, text: &str) -> Result<Vec<f32>> {
//!         // Call the OpenAI embeddings API
//!         todo!()
//!     }
//!     fn dimensions(&self) -> usize { 1536 } // text-embedding-3-small
//! }
//! ```
23
24use async_trait::async_trait;
25
26use crate::error::Result;
27
28/// Trait for domain-specific embedding model implementations.
29///
30/// Provides text-to-vector embeddings for semantic search and similarity
31/// computation. When registered with HiveMind, all experiences are embedded
32/// via this provider before storage in PulseDB (External mode).
33///
34/// Must be `Send + Sync` for concurrent use across Tokio tasks.
35///
36/// # Default batch implementation
37///
38/// `embed_batch` has a default implementation that calls `embed` sequentially.
39/// Override it for providers that support native batching (e.g., OpenAI, Cohere).
40#[async_trait]
41pub trait EmbeddingProvider: Send + Sync {
42    /// Embed a single text string into a vector.
43    async fn embed(&self, text: &str) -> Result<Vec<f32>>;
44
45    /// Embed a batch of text strings.
46    ///
47    /// Default implementation calls `embed` sequentially. Override for providers
48    /// that support native batch embedding (significantly faster for large batches).
49    async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
50        let mut results = Vec::with_capacity(texts.len());
51        for text in texts {
52            results.push(self.embed(text).await?);
53        }
54        Ok(results)
55    }
56
57    /// Return the dimensionality of embeddings produced by this provider.
58    ///
59    /// Must be constant for a given provider instance. Used to configure
60    /// PulseDB's HNSW index when opening in External mode.
61    fn dimensions(&self) -> usize;
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    #[test]
    fn test_embedding_provider_is_object_safe() {
        // Compile-time check only: these signatures stop type-checking if the
        // trait ever loses object safety, so nothing needs to run here.
        fn _assert_object_safe(_: &dyn EmbeddingProvider) {}
        fn _assert_boxable(_: Box<dyn EmbeddingProvider>) {}
        fn _assert_arcable(_: Arc<dyn EmbeddingProvider>) {}
    }

    /// Mock provider that returns the same constant vector of `dims` elements
    /// for every input.
    struct MockEmbeddingProvider {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(vec![0.1; self.dims])
        }

        fn dimensions(&self) -> usize {
            self.dims
        }
    }

    /// Mock provider whose vector is filled with the input's byte length, so a
    /// test can tell which output vector came from which input text.
    struct LengthEmbeddingProvider {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingProvider for LengthEmbeddingProvider {
        async fn embed(&self, text: &str) -> Result<Vec<f32>> {
            Ok(vec![text.len() as f32; self.dims])
        }

        fn dimensions(&self) -> usize {
            self.dims
        }
    }

    #[test]
    fn test_dimensions_returns_configured_value() {
        let provider = MockEmbeddingProvider { dims: 384 };
        assert_eq!(provider.dimensions(), 384);

        let provider = MockEmbeddingProvider { dims: 1536 };
        assert_eq!(provider.dimensions(), 1536);
    }

    #[tokio::test]
    async fn test_embed_returns_correct_length() {
        let provider = MockEmbeddingProvider { dims: 384 };
        let result = provider.embed("test text").await.unwrap();
        assert_eq!(result.len(), 384);
    }

    #[tokio::test]
    async fn test_embed_batch_default_impl() {
        let provider = MockEmbeddingProvider { dims: 384 };
        let texts = &["hello", "world", "test"];
        let results = provider.embed_batch(texts).await.unwrap();
        assert_eq!(results.len(), 3);
        for result in &results {
            assert_eq!(result.len(), 384);
        }
    }

    #[tokio::test]
    async fn test_embed_batch_default_impl_preserves_order() {
        // Inputs of distinct lengths, deliberately not sorted, so any
        // reordering by the default implementation would be detected.
        let provider = LengthEmbeddingProvider { dims: 4 };
        let texts = ["a", "abc", "ab"];
        let results = provider.embed_batch(&texts).await.unwrap();

        assert_eq!(results.len(), texts.len());
        for (text, vector) in texts.iter().zip(&results) {
            assert_eq!(vector.len(), 4);
            assert_eq!(vector[0], text.len() as f32);
        }
    }

    #[tokio::test]
    async fn test_embed_batch_empty_input() {
        let provider = MockEmbeddingProvider { dims: 384 };
        let results = provider.embed_batch(&[]).await.unwrap();
        assert!(results.is_empty());
    }
}