Skip to main content

pulsedb/embedding/
mod.rs

//! Embedding service abstractions for PulseDB.
//!
//! This module provides the trait and implementations for embedding generation.
//! Embeddings are dense vector representations of text used for semantic search.
//!
//! # Providers
//!
//! - [`ExternalEmbedding`] - For pre-computed embeddings (e.g., OpenAI, Cohere)
//! - `OnnxEmbedding` - Built-in ONNX model (requires the `builtin-embeddings` feature)
//!
//! # Example
//!
//! ```rust
//! use pulsedb::embedding::{EmbeddingService, ExternalEmbedding};
//!
//! // External mode - the caller provides embeddings
//! let service = ExternalEmbedding::new(384);
//! assert_eq!(service.dimension(), 384);
//!
//! // Validation only - cannot generate embeddings
//! let result = service.embed("hello");
//! assert!(result.is_err());
//! ```

25#[cfg(feature = "builtin-embeddings")]
26#[cfg_attr(docsrs, doc(cfg(feature = "builtin-embeddings")))]
27pub mod onnx;
28
29use crate::error::{PulseDBError, Result};
30use crate::types::Embedding;
31
32/// Embedding service trait for generating vector representations of text.
33///
34/// This trait defines the contract for any embedding provider. Implementations
35/// must be thread-safe (`Send + Sync`) to allow concurrent embedding operations.
36///
37/// # Implementing a Custom Provider
38///
39/// ```rust,no_run
40/// use pulsedb::embedding::EmbeddingService;
41/// use pulsedb::{Embedding, Result};
42///
43/// # struct MyApiClient;
44/// # impl MyApiClient {
45/// #     fn get_embedding(&self, _: &str) -> Result<Embedding> { Ok(vec![0.0; 384]) }
46/// #     fn get_embeddings(&self, _: &[&str]) -> Result<Vec<Embedding>> { Ok(vec![]) }
47/// # }
48/// struct MyEmbeddingService {
49///     client: MyApiClient,
50///     dimension: usize,
51/// }
52///
53/// impl EmbeddingService for MyEmbeddingService {
54///     fn embed(&self, text: &str) -> Result<Embedding> {
55///         Ok(self.client.get_embedding(text)?)
56///     }
57///
58///     fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
59///         Ok(self.client.get_embeddings(texts)?)
60///     }
61///
62///     fn dimension(&self) -> usize {
63///         self.dimension
64///     }
65/// }
66/// ```
67pub trait EmbeddingService: Send + Sync {
68    /// Generates an embedding for a single text.
69    ///
70    /// # Arguments
71    ///
72    /// * `text` - The text to embed
73    ///
74    /// # Returns
75    ///
76    /// A vector of f32 values with length equal to `dimension()`.
77    ///
78    /// # Errors
79    ///
80    /// Returns `PulseDBError::Embedding` if embedding generation fails.
81    fn embed(&self, text: &str) -> Result<Embedding>;
82
83    /// Generates embeddings for multiple texts in a batch.
84    ///
85    /// Batch processing is typically more efficient than individual calls
86    /// due to reduced API overhead and better GPU utilization.
87    ///
88    /// # Arguments
89    ///
90    /// * `texts` - Slice of texts to embed
91    ///
92    /// # Returns
93    ///
94    /// A vector of embeddings in the same order as the input texts.
95    ///
96    /// # Errors
97    ///
98    /// Returns `PulseDBError::Embedding` if any embedding generation fails.
99    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>>;
100
101    /// Returns the dimension of embeddings produced by this service.
102    ///
103    /// All embeddings from this service will have exactly this many dimensions.
104    fn dimension(&self) -> usize;
105
106    /// Validates that an embedding has the correct dimension.
107    ///
108    /// # Errors
109    ///
110    /// Returns `ValidationError::DimensionMismatch` if dimensions don't match.
111    fn validate_embedding(&self, embedding: &Embedding) -> Result<()> {
112        let expected = self.dimension();
113        let actual = embedding.len();
114
115        if actual != expected {
116            return Err(PulseDBError::Validation(
117                crate::error::ValidationError::dimension_mismatch(expected, actual),
118            ));
119        }
120
121        Ok(())
122    }
123}
124
/// External embedding provider.
///
/// This provider is used when embeddings are generated outside PulseDB (e.g.
/// by OpenAI, Cohere, or a custom service). It records the expected embedding
/// dimension so vectors can be validated, but it cannot generate embeddings
/// itself.
///
/// # Usage
///
/// When using `ExternalEmbedding`, you must provide pre-computed embedding
/// vectors when recording experiences. Calling `embed()` or `embed_batch()`
/// returns an error.
///
/// # Example
///
/// ```rust
/// use pulsedb::embedding::{EmbeddingService, ExternalEmbedding};
///
/// // Create for OpenAI text-embedding-ada-002 (1536 dimensions)
/// let service = ExternalEmbedding::new(1536);
/// assert_eq!(service.dimension(), 1536);
/// assert_eq!(service, ExternalEmbedding::new(1536));
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ExternalEmbedding {
    /// Expected length of every embedding vector handled by this provider.
    dimension: usize,
}

151impl ExternalEmbedding {
152    /// Creates a new external embedding provider with the given dimension.
153    ///
154    /// # Arguments
155    ///
156    /// * `dimension` - The expected embedding dimension
157    ///
158    /// # Example
159    ///
160    /// ```rust
161    /// use pulsedb::embedding::ExternalEmbedding;
162    ///
163    /// // all-MiniLM-L6-v2
164    /// let service = ExternalEmbedding::new(384);
165    ///
166    /// // OpenAI text-embedding-3-small
167    /// let service = ExternalEmbedding::new(1536);
168    /// ```
169    pub fn new(dimension: usize) -> Self {
170        Self { dimension }
171    }
172}
173
174impl EmbeddingService for ExternalEmbedding {
175    fn embed(&self, _text: &str) -> Result<Embedding> {
176        Err(PulseDBError::embedding(
177            "External embedding mode: embeddings must be provided by the caller",
178        ))
179    }
180
181    fn embed_batch(&self, _texts: &[&str]) -> Result<Vec<Embedding>> {
182        Err(PulseDBError::embedding(
183            "External embedding mode: embeddings must be provided by the caller",
184        ))
185    }
186
187    fn dimension(&self) -> usize {
188        self.dimension
189    }
190}
191
192/// Creates an embedding service based on the configuration.
193///
194/// # Arguments
195///
196/// * `config` - Database configuration specifying the embedding provider
197///
198/// # Returns
199///
200/// A boxed embedding service ready for use.
201///
202/// # Errors
203///
204/// Returns an error if:
205/// - Builtin embeddings requested but feature not enabled
206/// - ONNX model loading fails (for builtin provider)
207pub fn create_embedding_service(
208    config: &crate::config::Config,
209) -> Result<Box<dyn EmbeddingService>> {
210    use crate::config::EmbeddingProvider;
211
212    match &config.embedding_provider {
213        EmbeddingProvider::External => {
214            let dimension = config.embedding_dimension.size();
215            Ok(Box::new(ExternalEmbedding::new(dimension)))
216        }
217
218        #[cfg(feature = "builtin-embeddings")]
219        EmbeddingProvider::Builtin { model_path } => {
220            let dim = config.embedding_dimension.size();
221            match onnx::OnnxEmbedding::with_dimension(model_path.clone(), dim) {
222                Ok(service) => Ok(Box::new(service)),
223                Err(ref e) if e.to_string().contains("Model not found") => {
224                    tracing::info!(
225                        "Builtin embedding model not found, downloading (dimension: {dim})..."
226                    );
227                    let _path = onnx::OnnxEmbedding::download_default_model(dim)?;
228                    let service = onnx::OnnxEmbedding::with_dimension(model_path.clone(), dim)?;
229                    Ok(Box::new(service))
230                }
231                Err(e) => Err(e),
232            }
233        }
234
235        #[cfg(not(feature = "builtin-embeddings"))]
236        EmbeddingProvider::Builtin { .. } => Err(PulseDBError::embedding(
237            "Builtin embeddings require the 'builtin-embeddings' feature",
238        )),
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_external_embedding_dimension() {
        let service = ExternalEmbedding::new(384);
        assert_eq!(service.dimension(), 384);
    }

    #[test]
    fn test_external_embedding_embed_returns_error() {
        let service = ExternalEmbedding::new(384);
        let result = service.embed("hello world");
        assert!(result.is_err());
    }

    #[test]
    fn test_external_embedding_embed_batch_returns_error() {
        let service = ExternalEmbedding::new(384);
        assert!(service.embed_batch(&["hello", "world"]).is_err());
        // External mode refuses to generate even for an empty batch.
        assert!(service.embed_batch(&[]).is_err());
    }

    #[test]
    fn test_validate_embedding_correct_dimension() {
        let service = ExternalEmbedding::new(3);
        let embedding = vec![1.0, 2.0, 3.0];
        assert!(service.validate_embedding(&embedding).is_ok());
    }

    #[test]
    fn test_validate_embedding_wrong_dimension() {
        let service = ExternalEmbedding::new(3);
        let too_short = vec![1.0, 2.0];
        let too_long = vec![1.0, 2.0, 3.0, 4.0];
        // Pin the error category, not just "some error".
        assert!(matches!(
            service.validate_embedding(&too_short),
            Err(PulseDBError::Validation(_))
        ));
        assert!(matches!(
            service.validate_embedding(&too_long),
            Err(PulseDBError::Validation(_))
        ));
    }

    #[test]
    fn test_external_embedding_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ExternalEmbedding>();
    }

    #[test]
    fn test_external_embedding_as_trait_object() {
        let service: Box<dyn EmbeddingService> = Box::new(ExternalEmbedding::new(8));
        assert_eq!(service.dimension(), 8);
        assert!(service.validate_embedding(&vec![0.0; 8]).is_ok());
        assert!(service.embed("text").is_err());
    }

    #[test]
    fn test_create_embedding_service_external() {
        let config = crate::config::Config::default();
        let service = create_embedding_service(&config).unwrap();
        assert_eq!(service.dimension(), 384);
    }

    #[test]
    #[ignore] // Requires network access for auto-download
    fn test_create_embedding_service_builtin_auto_downloads() {
        let config = crate::config::Config::with_builtin_embeddings();
        let result = create_embedding_service(&config);
        // With auto-download, this should succeed if network is available
        assert!(result.is_ok());
    }
}