pulsehive_core/embedding.rs
1//! Embedding provider abstraction for domain-specific embedding models.
2//!
3//! [`EmbeddingProvider`] enables products to use custom embedding models
4//! (medical, code, multilingual) instead of PulseDB's built-in all-MiniLM-L6-v2.
5//!
6//! When set on HiveMind, PulseHive computes embeddings via the provider and
7//! passes vectors to PulseDB in External mode. Products that don't set a provider
8//! get PulseDB's default embeddings automatically.
9//!
//! # Example
//!
//! ```rust,ignore
//! struct OpenAIEmbeddings { client: reqwest::Client, api_key: String }
//!
//! #[async_trait]
//! impl EmbeddingProvider for OpenAIEmbeddings {
//!     async fn embed(&self, text: &str) -> Result<Vec<f32>> {
//!         // Send `text` to the OpenAI embeddings API and return the resulting vector.
//!         todo!()
//!     }
//!
//!     fn dimensions(&self) -> usize {
//!         1536 // text-embedding-3-small
//!     }
//! }
//! ```
23
24use async_trait::async_trait;
25
26use crate::error::Result;
27
28/// Trait for domain-specific embedding model implementations.
29///
30/// Provides text-to-vector embeddings for semantic search and similarity
31/// computation. When registered with HiveMind, all experiences are embedded
32/// via this provider before storage in PulseDB (External mode).
33///
34/// Must be `Send + Sync` for concurrent use across Tokio tasks.
35///
36/// # Default batch implementation
37///
38/// `embed_batch` has a default implementation that calls `embed` sequentially.
39/// Override it for providers that support native batching (e.g., OpenAI, Cohere).
40#[async_trait]
41pub trait EmbeddingProvider: Send + Sync {
42 /// Embed a single text string into a vector.
43 async fn embed(&self, text: &str) -> Result<Vec<f32>>;
44
45 /// Embed a batch of text strings.
46 ///
47 /// Default implementation calls `embed` sequentially. Override for providers
48 /// that support native batch embedding (significantly faster for large batches).
49 async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
50 let mut results = Vec::with_capacity(texts.len());
51 for text in texts {
52 results.push(self.embed(text).await?);
53 }
54 Ok(results)
55 }
56
57 /// Return the dimensionality of embeddings produced by this provider.
58 ///
59 /// Must be constant for a given provider instance. Used to configure
60 /// PulseDB's HNSW index when opening in External mode.
61 fn dimensions(&self) -> usize;
62}
63
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    /// Compile-time check that the trait is usable as `&dyn`, `Box<dyn>` and
    /// `Arc<dyn>`. The helper functions only need to type-check; they are never called.
    #[test]
    fn test_embedding_provider_is_object_safe() {
        fn _assert_object_safe(_: &dyn EmbeddingProvider) {}
        fn _assert_boxable(_: Box<dyn EmbeddingProvider>) {}
        fn _assert_arcable(_: Arc<dyn EmbeddingProvider>) {}
    }

    /// Mock provider that returns a constant vector of `dims` elements for any input.
    struct MockEmbeddingProvider {
        dims: usize,
    }

    #[async_trait]
    impl EmbeddingProvider for MockEmbeddingProvider {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(vec![0.1; self.dims])
        }

        fn dimensions(&self) -> usize {
            self.dims
        }
    }

    /// Mock provider whose output depends on its input: every element equals
    /// the text's byte length. Tests can therefore tell which input produced
    /// which vector, which the constant mock above cannot.
    struct LengthEmbeddingProvider;

    #[async_trait]
    impl EmbeddingProvider for LengthEmbeddingProvider {
        async fn embed(&self, text: &str) -> Result<Vec<f32>> {
            Ok(vec![text.len() as f32; 2])
        }

        fn dimensions(&self) -> usize {
            2
        }
    }

    #[test]
    fn test_dimensions_returns_configured_value() {
        let provider = MockEmbeddingProvider { dims: 384 };
        assert_eq!(provider.dimensions(), 384);

        let provider = MockEmbeddingProvider { dims: 1536 };
        assert_eq!(provider.dimensions(), 1536);
    }

    #[tokio::test]
    async fn test_embed_returns_correct_length() {
        let provider = MockEmbeddingProvider { dims: 384 };
        let result = provider.embed("test text").await.unwrap();
        assert_eq!(result.len(), 384);
    }

    #[tokio::test]
    async fn test_embed_batch_default_impl() {
        let provider = MockEmbeddingProvider { dims: 384 };
        let texts = &["hello", "world", "test"];
        let results = provider.embed_batch(texts).await.unwrap();
        assert_eq!(results.len(), 3);
        for result in &results {
            assert_eq!(result.len(), 384);
        }
    }

    /// The default `embed_batch` must return one vector per input, in input order.
    #[tokio::test]
    async fn test_embed_batch_preserves_input_order() {
        let provider = LengthEmbeddingProvider;
        let results = provider.embed_batch(&["a", "abc", "ab"]).await.unwrap();
        assert_eq!(
            results,
            vec![vec![1.0, 1.0], vec![3.0, 3.0], vec![2.0, 2.0]]
        );
    }

    /// The provided `embed_batch` must also work through a trait object,
    /// which is how HiveMind is expected to hold providers.
    #[tokio::test]
    async fn test_embed_batch_via_trait_object() {
        let provider: Arc<dyn EmbeddingProvider> = Arc::new(MockEmbeddingProvider { dims: 8 });
        let results = provider.embed_batch(&["x", "y"]).await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results.iter().all(|v| v.len() == provider.dimensions()));
    }

    #[tokio::test]
    async fn test_embed_batch_empty_input() {
        let provider = MockEmbeddingProvider { dims: 384 };
        let results = provider.embed_batch(&[]).await.unwrap();
        assert!(results.is_empty());
    }
}
125}