1use std::fmt;
24
25use serde::{Deserialize, Serialize};
26
27#[derive(Debug, thiserror::Error)]
31pub enum EmbeddingError {
32 #[error("embedding provider error: {0}")]
34 ProviderError(String),
35
36 #[error("failed to parse embedding response: {0}")]
38 ParseError(String),
39
40 #[error("expected {expected} embedding vectors, got {got}")]
42 CountMismatch { expected: usize, got: usize },
43
44 #[error("expected {expected}-dimensional embedding, got {got} dimensions")]
46 DimensionMismatch { expected: usize, got: usize },
47
48 #[error("embedding configuration error: {0}")]
50 ConfigError(String),
51}
52
53pub trait EmbeddingProvider: Send + Sync + fmt::Debug {
60 fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
64
65 fn dimension(&self) -> usize;
67}
68
69#[derive(Debug, Clone, Serialize, Deserialize)]
78#[serde(default, rename_all = "snake_case")]
79pub struct EmbeddingConfig {
80 pub model: String,
82 pub dimension: usize,
84 pub batch_size: usize,
86}
87
88impl Default for EmbeddingConfig {
89 fn default() -> Self {
90 Self {
91 model: String::new(), dimension: 0, batch_size: 32,
94 }
95 }
96}
97
98impl fmt::Display for EmbeddingConfig {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 write!(
101 f,
102 "model={}, dimension={}, batch_size={}",
103 if self.model.is_empty() {
104 "(default)"
105 } else {
106 &self.model
107 },
108 if self.dimension == 0 {
109 "(default)".to_owned()
110 } else {
111 self.dimension.to_string()
112 },
113 self.batch_size,
114 )
115 }
116}
117
118pub fn create_provider(
129 config: &EmbeddingConfig,
130) -> Result<Box<dyn EmbeddingProvider>, EmbeddingError> {
131 if config.batch_size == 0 {
133 return Err(EmbeddingError::ConfigError(
134 "batch_size must be at least 1".to_owned(),
135 ));
136 }
137
138 #[cfg(feature = "builtin-embeddings")]
139 {
140 builtin::create_builtin_provider(config)
141 }
142
143 #[cfg(not(feature = "builtin-embeddings"))]
144 {
145 Err(EmbeddingError::ConfigError(
146 "embedding support is not compiled in — rebuild with the \
147 'builtin-embeddings' feature (enabled by default)"
148 .to_owned(),
149 ))
150 }
151}
152
/// In-process embedding backend built on the `fastembed` crate.
///
/// Only compiled when the `builtin-embeddings` feature is enabled.
#[cfg(feature = "builtin-embeddings")]
mod builtin {
    use std::sync::Mutex;

    use super::*;

    /// Model loaded when [`EmbeddingConfig::model`] is empty.
    pub const DEFAULT_MODEL: &str = "all-MiniLM-L6-v2";

    /// Vector length used when [`EmbeddingConfig::dimension`] is `0`.
    /// It is also the output size of [`DEFAULT_MODEL`].
    pub const DEFAULT_DIMENSION: usize = 384;

    /// Builds a [`BuiltinProvider`] from `config`.
    ///
    /// The config is applied as follows:
    /// - an empty model name selects [`DEFAULT_MODEL`];
    /// - a zero dimension selects [`DEFAULT_DIMENSION`];
    /// - `config.batch_size` is forwarded to the model on every `embed` call.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`BuiltinProvider::new`].
    pub fn create_builtin_provider(
        config: &EmbeddingConfig,
    ) -> Result<Box<dyn EmbeddingProvider>, EmbeddingError> {
        let model = if config.model.is_empty() {
            DEFAULT_MODEL.to_owned()
        } else {
            config.model.clone()
        };
        let dimension = if config.dimension == 0 {
            DEFAULT_DIMENSION
        } else {
            config.dimension
        };
        // Previously the validated batch size was dropped here and fastembed
        // always ran with its own default. Carry it through instead.
        let provider =
            BuiltinProvider::new(model, dimension)?.with_batch_size(config.batch_size);
        Ok(Box::new(provider))
    }

    /// [`EmbeddingProvider`] backed by a locally loaded fastembed model.
    pub struct BuiltinProvider {
        model_name: String,
        dimension: usize,
        /// Batch size handed to fastembed on each `embed` call. `None` lets
        /// fastembed choose.
        batch_size: Option<usize>,
        /// `embed` drives the model through a mutable guard. The mutex is what
        /// makes that possible from the `&self` trait method.
        inner: Mutex<fastembed::TextEmbedding>,
    }

    impl fmt::Debug for BuiltinProvider {
        fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
            f.debug_struct("BuiltinProvider")
                .field("model_name", &self.model_name)
                .field("dimension", &self.dimension)
                .field("batch_size", &self.batch_size)
                .finish_non_exhaustive()
        }
    }

    impl BuiltinProvider {
        /// Loads the named built-in model. This may download it on first use.
        ///
        /// `dimension` is the vector length that `embed` will enforce. `0`
        /// disables the length check.
        ///
        /// # Errors
        ///
        /// Returns [`EmbeddingError::ConfigError`] in two cases:
        /// - `model_name` is not a supported built-in model;
        /// - a non-zero `dimension` differs from the model's output size.
        ///
        /// Returns [`EmbeddingError::ProviderError`] if loading the model fails.
        pub fn new(model_name: String, dimension: usize) -> Result<Self, EmbeddingError> {
            use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};

            // Map the user-facing name to fastembed's model enum, together
            // with the vector length that model produces.
            let (model, native_dimension) = match model_name.as_str() {
                "all-MiniLM-L6-v2" => (EmbeddingModel::AllMiniLML6V2, DEFAULT_DIMENSION),
                other => {
                    return Err(EmbeddingError::ConfigError(format!(
                        "unknown built-in model '{other}'. \
                         Supported: all-MiniLM-L6-v2"
                    )));
                }
            };

            // A dimension the model can never produce would make every later
            // `embed` call fail with DimensionMismatch. Reject it now, before
            // paying for a (possibly downloading) model load.
            if dimension != 0 && dimension != native_dimension {
                return Err(EmbeddingError::ConfigError(format!(
                    "model '{model_name}' produces {native_dimension}-dimensional \
                     embeddings, but dimension = {dimension} was configured"
                )));
            }

            let init_opts = InitOptions::new(model).with_show_download_progress(false);

            tracing::info!(model = %model_name, "Loading built-in embedding model (may download on first run)");

            let inner = TextEmbedding::try_new(init_opts)
                .map_err(|e| EmbeddingError::ProviderError(format!("failed to load model: {e}")))?;

            Ok(Self {
                model_name,
                dimension,
                batch_size: None,
                inner: Mutex::new(inner),
            })
        }

        /// Sets the batch size forwarded to fastembed on each `embed` call.
        ///
        /// `0` is treated as "unset", leaving fastembed's own default in place.
        #[must_use]
        pub fn with_batch_size(mut self, batch_size: usize) -> Self {
            self.batch_size = if batch_size == 0 {
                None
            } else {
                Some(batch_size)
            };
            self
        }
    }

    impl EmbeddingProvider for BuiltinProvider {
        fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            if texts.is_empty() {
                return Ok(Vec::new());
            }

            // Scope the lock so the guard is released before validation.
            let embeddings = {
                let mut model = self.inner.lock().map_err(|e| {
                    EmbeddingError::ProviderError(format!("model lock poisoned: {e}"))
                })?;
                model
                    .embed(texts, self.batch_size)
                    .map_err(|e| EmbeddingError::ProviderError(e.to_string()))?
            };

            if embeddings.len() != texts.len() {
                return Err(EmbeddingError::CountMismatch {
                    expected: texts.len(),
                    got: embeddings.len(),
                });
            }

            for (i, vec) in embeddings.iter().enumerate() {
                if vec.is_empty() {
                    return Err(EmbeddingError::ParseError(format!(
                        "embedding at index {i} is empty"
                    )));
                }
                if self.dimension > 0 && vec.len() != self.dimension {
                    return Err(EmbeddingError::DimensionMismatch {
                        expected: self.dimension,
                        got: vec.len(),
                    });
                }
                if let Some(val) = vec.iter().find(|v| !v.is_finite()) {
                    return Err(EmbeddingError::ParseError(format!(
                        "embedding at index {i} contains non-finite value: {val}"
                    )));
                }
            }

            Ok(embeddings)
        }

        fn dimension(&self) -> usize {
            self.dimension
        }
    }
}
291
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic in-memory provider. The `i`-th input maps to `dim` copies
    /// of `i / 10`. If `error` is set, every call fails with it instead.
    #[derive(Debug)]
    struct MockProvider {
        dim: usize,
        error: Option<String>,
    }

    impl MockProvider {
        fn new(dim: usize) -> Self {
            Self { dim, error: None }
        }

        fn with_error(dim: usize, msg: &str) -> Self {
            Self {
                dim,
                error: Some(msg.to_owned()),
            }
        }
    }

    impl EmbeddingProvider for MockProvider {
        fn embed(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
            if let Some(msg) = &self.error {
                return Err(EmbeddingError::ProviderError(msg.clone()));
            }
            Ok((0..texts.len())
                .map(|i| vec![i as f32 / 10.0; self.dim])
                .collect())
        }

        fn dimension(&self) -> usize {
            self.dim
        }
    }

    #[test]
    fn mock_provider_returns_expected_embeddings() {
        let provider = MockProvider::new(384);
        let texts = vec!["hello".to_owned(), "world".to_owned()];
        let result = provider.embed(&texts).unwrap();
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].len(), 384);
        assert!((result[0][0] - 0.0).abs() < f32::EPSILON);
        assert!((result[1][0] - 0.1).abs() < f32::EPSILON);
    }

    #[test]
    fn mock_provider_empty_input() {
        let provider = MockProvider::new(384);
        let result = provider.embed(&[]).unwrap();
        assert!(result.is_empty());
    }

    #[test]
    fn mock_provider_dimension() {
        let provider = MockProvider::new(1536);
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn mock_provider_error() {
        let provider = MockProvider::with_error(384, "model load failed");
        let err = provider.embed(&["test".to_owned()]).unwrap_err();
        assert!(matches!(err, EmbeddingError::ProviderError(ref m) if m == "model load failed"));
    }

    #[test]
    fn config_default() {
        let cfg = EmbeddingConfig::default();
        assert!(cfg.model.is_empty());
        assert_eq!(cfg.dimension, 0);
        assert_eq!(cfg.batch_size, 32);
    }

    #[test]
    fn config_parse_minimal() {
        let toml_str = r#"batch_size = 16"#;
        let cfg: EmbeddingConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(cfg.batch_size, 16);
        assert!(cfg.model.is_empty());
        assert_eq!(cfg.dimension, 0);
    }

    #[test]
    fn config_parse_full() {
        let toml_str = r#"
model = "all-MiniLM-L6-v2"
dimension = 384
batch_size = 64
"#;
        let cfg: EmbeddingConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(cfg.model, "all-MiniLM-L6-v2");
        assert_eq!(cfg.dimension, 384);
        assert_eq!(cfg.batch_size, 64);
    }

    #[test]
    fn config_parse_empty_uses_defaults() {
        let cfg: EmbeddingConfig = toml::from_str("").unwrap();
        assert!(cfg.model.is_empty());
        assert_eq!(cfg.dimension, 0);
        assert_eq!(cfg.batch_size, 32);
    }

    #[test]
    fn config_display_with_values() {
        let cfg = EmbeddingConfig {
            model: "all-MiniLM-L6-v2".to_owned(),
            dimension: 384,
            batch_size: 32,
        };
        assert_eq!(
            format!("{cfg}"),
            "model=all-MiniLM-L6-v2, dimension=384, batch_size=32"
        );
    }

    #[test]
    fn config_display_defaults() {
        let cfg = EmbeddingConfig::default();
        assert_eq!(
            format!("{cfg}"),
            "model=(default), dimension=(default), batch_size=32"
        );
    }

    #[test]
    fn create_provider_batch_size_zero_returns_error() {
        let cfg = EmbeddingConfig {
            batch_size: 0,
            ..Default::default()
        };
        let err = create_provider(&cfg).unwrap_err();
        assert!(matches!(err, EmbeddingError::ConfigError(_)));
        assert!(err.to_string().contains("batch_size"));
    }

    #[test]
    fn error_display_messages() {
        let err = EmbeddingError::ProviderError("model load failed".to_owned());
        assert!(err.to_string().contains("model load failed"));

        let err = EmbeddingError::ParseError("bad data".to_owned());
        assert!(err.to_string().contains("bad data"));

        let err = EmbeddingError::CountMismatch {
            expected: 3,
            got: 1,
        };
        assert!(err.to_string().contains("3"));
        assert!(err.to_string().contains("embedding vectors"));

        let err = EmbeddingError::DimensionMismatch {
            expected: 384,
            got: 1536,
        };
        assert!(err.to_string().contains("384"));
        assert!(err.to_string().contains("1536"));

        let err = EmbeddingError::ConfigError("bad config".to_owned());
        assert!(err.to_string().contains("bad config"));
    }

    #[test]
    fn provider_as_trait_object() {
        let provider: Box<dyn EmbeddingProvider> = Box::new(MockProvider::new(384));
        assert_eq!(provider.dimension(), 384);
        let result = provider.embed(&["test".to_owned()]).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 384);
    }

    #[test]
    fn provider_send_sync() {
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<MockProvider>();
        // The trait's supertraits must make boxed trait objects shareable too.
        assert_send_sync::<dyn EmbeddingProvider>();
        assert_send_sync::<Box<dyn EmbeddingProvider>>();
    }

    #[test]
    fn config_display_custom_model() {
        let cfg = EmbeddingConfig {
            model: "custom-model".to_owned(),
            dimension: 768,
            batch_size: 64,
        };
        assert_eq!(
            format!("{cfg}"),
            "model=custom-model, dimension=768, batch_size=64"
        );
    }

    #[test]
    fn config_display_zero_dimension() {
        let cfg = EmbeddingConfig {
            dimension: 0,
            ..Default::default()
        };
        let s = format!("{cfg}");
        assert!(s.contains("dimension=(default)"));
    }

    #[test]
    fn mock_provider_debug() {
        let provider = MockProvider::new(128);
        let dbg = format!("{provider:?}");
        assert!(dbg.contains("MockProvider"));
    }

    #[test]
    fn error_display_count_mismatch() {
        let err = EmbeddingError::CountMismatch {
            expected: 10,
            got: 5,
        };
        assert_eq!(err.to_string(), "expected 10 embedding vectors, got 5");
    }

    #[test]
    fn error_display_dimension_mismatch() {
        let err = EmbeddingError::DimensionMismatch {
            expected: 512,
            got: 384,
        };
        assert_eq!(
            err.to_string(),
            "expected 512-dimensional embedding, got 384 dimensions"
        );
    }

    #[test]
    fn mock_provider_zero_dimension() {
        let provider = MockProvider::new(0);
        assert_eq!(provider.dimension(), 0);
        let result = provider.embed(&["test".to_owned()]).unwrap();
        assert_eq!(result.len(), 1);
        assert_eq!(result[0].len(), 0);
    }

    #[test]
    fn mock_provider_embedding_values() {
        let provider = MockProvider::new(3);
        let texts = vec!["a".to_owned(), "b".to_owned(), "c".to_owned()];
        let result = provider.embed(&texts).unwrap();
        assert_eq!(result.len(), 3);
        assert!((result[0][0] - 0.0).abs() < f32::EPSILON);
        assert!((result[1][0] - 0.1).abs() < f32::EPSILON);
        assert!((result[2][0] - 0.2).abs() < f32::EPSILON);
    }

    /// Without the backend compiled in, a valid config must still be rejected
    /// with a clear, actionable configuration error.
    #[test]
    #[cfg(not(feature = "builtin-embeddings"))]
    fn create_provider_valid_config() {
        let cfg = EmbeddingConfig::default();
        let err = create_provider(&cfg).unwrap_err();
        assert!(matches!(err, EmbeddingError::ConfigError(_)));
        assert!(err.to_string().contains("builtin-embeddings"));
    }

    /// With the backend compiled in, the default config builds a real
    /// provider. Ignored by default because it loads (and may download) the
    /// model, which unit tests should not do implicitly.
    #[test]
    #[cfg(feature = "builtin-embeddings")]
    #[ignore = "loads (and may download) the built-in embedding model"]
    fn create_provider_valid_config() {
        let cfg = EmbeddingConfig::default();
        let provider = create_provider(&cfg).expect("default config should build a provider");
        assert_eq!(provider.dimension(), 384);
    }

    /// An unknown model name is a configuration error, whether or not the
    /// backend is compiled in. The name is checked before any model is loaded.
    #[test]
    fn create_provider_unknown_model_returns_config_error() {
        let cfg = EmbeddingConfig {
            model: "no-such-model".to_owned(),
            ..Default::default()
        };
        let err = create_provider(&cfg).unwrap_err();
        assert!(matches!(err, EmbeddingError::ConfigError(_)));
    }

    #[test]
    fn embedding_error_display_parse_error() {
        let err = EmbeddingError::ParseError("json malformed".to_owned());
        assert!(err.to_string().contains("json malformed"));
    }

    #[test]
    fn embedding_error_display_config_error() {
        let err = EmbeddingError::ConfigError("unsupported provider".to_owned());
        assert!(err.to_string().contains("unsupported provider"));
    }
}