pulsedb/embedding/mod.rs
1//! Embedding service abstractions for PulseDB.
2//!
3//! This module provides the trait and implementations for embedding generation.
4//! Embeddings are dense vector representations of text used for semantic search.
5//!
6//! # Providers
7//!
8//! - [`ExternalEmbedding`] - For pre-computed embeddings (e.g., OpenAI, Cohere)
9//! - `OnnxEmbedding` - Built-in ONNX model (requires `builtin-embeddings` feature)
10//!
11//! # Example
12//!
13//! ```rust
14//! use pulsedb::embedding::{EmbeddingService, ExternalEmbedding};
15//!
16//! // External mode - user provides embeddings
17//! let service = ExternalEmbedding::new(384);
18//! assert_eq!(service.dimension(), 384);
19//!
20//! // Validation only - cannot generate embeddings
21//! let result = service.embed("hello");
22//! assert!(result.is_err());
23//! ```
24
25#[cfg(feature = "builtin-embeddings")]
26#[cfg_attr(docsrs, doc(cfg(feature = "builtin-embeddings")))]
27pub mod onnx;
28
29use crate::error::{PulseDBError, Result};
30use crate::types::Embedding;
31
32/// Embedding service trait for generating vector representations of text.
33///
34/// This trait defines the contract for any embedding provider. Implementations
35/// must be thread-safe (`Send + Sync`) to allow concurrent embedding operations.
36///
37/// # Implementing a Custom Provider
38///
39/// ```rust,no_run
40/// use pulsedb::embedding::EmbeddingService;
41/// use pulsedb::{Embedding, Result};
42///
43/// # struct MyApiClient;
44/// # impl MyApiClient {
45/// # fn get_embedding(&self, _: &str) -> Result<Embedding> { Ok(vec![0.0; 384]) }
46/// # fn get_embeddings(&self, _: &[&str]) -> Result<Vec<Embedding>> { Ok(vec![]) }
47/// # }
48/// struct MyEmbeddingService {
49/// client: MyApiClient,
50/// dimension: usize,
51/// }
52///
53/// impl EmbeddingService for MyEmbeddingService {
54/// fn embed(&self, text: &str) -> Result<Embedding> {
55/// Ok(self.client.get_embedding(text)?)
56/// }
57///
58/// fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>> {
59/// Ok(self.client.get_embeddings(texts)?)
60/// }
61///
62/// fn dimension(&self) -> usize {
63/// self.dimension
64/// }
65/// }
66/// ```
67pub trait EmbeddingService: Send + Sync {
68 /// Generates an embedding for a single text.
69 ///
70 /// # Arguments
71 ///
72 /// * `text` - The text to embed
73 ///
74 /// # Returns
75 ///
76 /// A vector of f32 values with length equal to `dimension()`.
77 ///
78 /// # Errors
79 ///
80 /// Returns `PulseDBError::Embedding` if embedding generation fails.
81 fn embed(&self, text: &str) -> Result<Embedding>;
82
83 /// Generates embeddings for multiple texts in a batch.
84 ///
85 /// Batch processing is typically more efficient than individual calls
86 /// due to reduced API overhead and better GPU utilization.
87 ///
88 /// # Arguments
89 ///
90 /// * `texts` - Slice of texts to embed
91 ///
92 /// # Returns
93 ///
94 /// A vector of embeddings in the same order as the input texts.
95 ///
96 /// # Errors
97 ///
98 /// Returns `PulseDBError::Embedding` if any embedding generation fails.
99 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Embedding>>;
100
101 /// Returns the dimension of embeddings produced by this service.
102 ///
103 /// All embeddings from this service will have exactly this many dimensions.
104 fn dimension(&self) -> usize;
105
106 /// Validates that an embedding has the correct dimension.
107 ///
108 /// # Errors
109 ///
110 /// Returns `ValidationError::DimensionMismatch` if dimensions don't match.
111 fn validate_embedding(&self, embedding: &Embedding) -> Result<()> {
112 let expected = self.dimension();
113 let actual = embedding.len();
114
115 if actual != expected {
116 return Err(PulseDBError::Validation(
117 crate::error::ValidationError::dimension_mismatch(expected, actual),
118 ));
119 }
120
121 Ok(())
122 }
123}
124
/// External embedding provider.
///
/// This provider is used when embeddings are generated externally (e.g., by
/// OpenAI, Cohere, or a custom service). It only records the expected
/// embedding dimension; it cannot generate embeddings itself.
///
/// # Usage
///
/// When using `ExternalEmbedding`, you must provide pre-computed embedding
/// vectors when recording experiences. Attempting to call `embed()` or
/// `embed_batch()` will return an error.
///
/// # Example
///
/// ```rust
/// use pulsedb::embedding::{EmbeddingService, ExternalEmbedding};
///
/// // Create for OpenAI ada-002 (1536 dimensions)
/// let service = ExternalEmbedding::new(1536);
/// assert_eq!(service.dimension(), 1536);
///
/// // Providers compare equal when they expect the same dimension
/// assert_eq!(service, ExternalEmbedding::new(1536));
/// ```
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ExternalEmbedding {
    /// Number of components every caller-supplied embedding must have.
    dimension: usize,
}

impl ExternalEmbedding {
    /// Creates a new external embedding provider with the given dimension.
    ///
    /// The value is stored as given; no range check is performed here.
    /// Being a `const fn`, the constructor can also be used to initialize
    /// constants and statics.
    ///
    /// # Arguments
    ///
    /// * `dimension` - The expected embedding dimension
    ///
    /// # Example
    ///
    /// ```rust
    /// use pulsedb::embedding::ExternalEmbedding;
    ///
    /// // all-MiniLM-L6-v2
    /// let service = ExternalEmbedding::new(384);
    ///
    /// // OpenAI text-embedding-3-small
    /// let service = ExternalEmbedding::new(1536);
    ///
    /// // Usable in constant contexts
    /// const MINI_LM: ExternalEmbedding = ExternalEmbedding::new(384);
    /// ```
    #[must_use]
    pub const fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}
173
174impl EmbeddingService for ExternalEmbedding {
175 fn embed(&self, _text: &str) -> Result<Embedding> {
176 Err(PulseDBError::embedding(
177 "External embedding mode: embeddings must be provided by the caller",
178 ))
179 }
180
181 fn embed_batch(&self, _texts: &[&str]) -> Result<Vec<Embedding>> {
182 Err(PulseDBError::embedding(
183 "External embedding mode: embeddings must be provided by the caller",
184 ))
185 }
186
187 fn dimension(&self) -> usize {
188 self.dimension
189 }
190}
191
192/// Creates an embedding service based on the configuration.
193///
194/// # Arguments
195///
196/// * `config` - Database configuration specifying the embedding provider
197///
198/// # Returns
199///
200/// A boxed embedding service ready for use.
201///
202/// # Errors
203///
204/// Returns an error if:
205/// - Builtin embeddings requested but feature not enabled
206/// - ONNX model loading fails (for builtin provider)
207pub fn create_embedding_service(
208 config: &crate::config::Config,
209) -> Result<Box<dyn EmbeddingService>> {
210 use crate::config::EmbeddingProvider;
211
212 match &config.embedding_provider {
213 EmbeddingProvider::External => {
214 let dimension = config.embedding_dimension.size();
215 Ok(Box::new(ExternalEmbedding::new(dimension)))
216 }
217
218 #[cfg(feature = "builtin-embeddings")]
219 EmbeddingProvider::Builtin { model_path } => {
220 let dim = config.embedding_dimension.size();
221 match onnx::OnnxEmbedding::with_dimension(model_path.clone(), dim) {
222 Ok(service) => Ok(Box::new(service)),
223 Err(ref e) if e.to_string().contains("Model not found") => {
224 tracing::info!(
225 "Builtin embedding model not found, downloading (dimension: {dim})..."
226 );
227 let _path = onnx::OnnxEmbedding::download_default_model(dim)?;
228 let service = onnx::OnnxEmbedding::with_dimension(model_path.clone(), dim)?;
229 Ok(Box::new(service))
230 }
231 Err(e) => Err(e),
232 }
233 }
234
235 #[cfg(not(feature = "builtin-embeddings"))]
236 EmbeddingProvider::Builtin { .. } => Err(PulseDBError::embedding(
237 "Builtin embeddings require the 'builtin-embeddings' feature",
238 )),
239 }
240}
241
#[cfg(test)]
mod tests {
    //! Unit tests for [`ExternalEmbedding`], the provided
    //! `validate_embedding` method, and `create_embedding_service`.

    use super::*;

    #[test]
    fn test_external_embedding_dimension() {
        let service = ExternalEmbedding::new(384);
        assert_eq!(service.dimension(), 384);
    }

    #[test]
    fn test_external_embedding_embed_returns_error() {
        let service = ExternalEmbedding::new(384);
        let result = service.embed("hello world");
        assert!(result.is_err());
    }

    #[test]
    fn test_external_embedding_embed_batch_returns_error() {
        let service = ExternalEmbedding::new(384);
        let result = service.embed_batch(&["hello", "world"]);
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_embedding_correct_dimension() {
        let service = ExternalEmbedding::new(3);
        let embedding = vec![1.0, 2.0, 3.0];
        assert!(service.validate_embedding(&embedding).is_ok());
    }

    #[test]
    fn test_validate_embedding_wrong_dimension() {
        let service = ExternalEmbedding::new(3);

        // Too short, too long, and empty inputs must all be rejected, and
        // specifically as validation errors rather than any other error kind.
        for embedding in [vec![1.0, 2.0], vec![1.0, 2.0, 3.0, 4.0], Vec::new()] {
            let result = service.validate_embedding(&embedding);
            assert!(
                matches!(result, Err(PulseDBError::Validation(_))),
                "embedding of length {} should fail validation",
                embedding.len()
            );
        }
    }

    #[test]
    fn test_external_embedding_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<ExternalEmbedding>();
        // The trait's supertraits must carry over to boxed trait objects,
        // which is the form `create_embedding_service` returns.
        assert_send_sync::<Box<dyn EmbeddingService>>();
    }

    #[test]
    fn test_create_embedding_service_external() {
        let config = crate::config::Config::default();
        let service = create_embedding_service(&config).unwrap();
        assert_eq!(service.dimension(), 384);
    }

    #[test]
    #[ignore] // Requires network access for auto-download
    fn test_create_embedding_service_builtin_auto_downloads() {
        let config = crate::config::Config::with_builtin_embeddings();
        let result = create_embedding_service(&config);
        // With auto-download, this should succeed if network is available
        assert!(result.is_ok());
    }
}