brainwires_core/
embedding.rs1use anyhow::Result;
8
9pub trait EmbeddingProvider: Send + Sync {
23 fn embed(&self, text: &str) -> Result<Vec<f32>>;
25
26 fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>> {
31 texts.iter().map(|t| self.embed(t)).collect()
32 }
33
34 fn dimension(&self) -> usize;
36
37 fn model_name(&self) -> &str;
39}
40
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Provider that returns the same fixed three-component vector for any input.
    struct MockEmbedding;

    impl EmbeddingProvider for MockEmbedding {
        fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn dimension(&self) -> usize {
            3
        }

        fn model_name(&self) -> &str {
            "mock-model"
        }
    }

    /// Provider whose single-component output is the input's byte length, so
    /// batch ordering is observable. It fails on the literal input `"fail"` and
    /// counts every call to `embed`, so short-circuiting can be asserted.
    ///
    /// An `AtomicUsize` (not `Cell`) keeps the type `Send + Sync`, as the trait requires.
    struct CountingEmbedding {
        calls: AtomicUsize,
    }

    impl CountingEmbedding {
        fn new() -> Self {
            Self {
                calls: AtomicUsize::new(0),
            }
        }

        /// Number of times `embed` has been invoked so far.
        fn calls(&self) -> usize {
            self.calls.load(Ordering::SeqCst)
        }
    }

    impl EmbeddingProvider for CountingEmbedding {
        fn embed(&self, text: &str) -> Result<Vec<f32>> {
            self.calls.fetch_add(1, Ordering::SeqCst);
            if text == "fail" {
                anyhow::bail!("refusing to embed {:?}", text);
            }
            Ok(vec![text.len() as f32])
        }

        fn dimension(&self) -> usize {
            1
        }

        fn model_name(&self) -> &str {
            "counting-model"
        }
    }

    #[test]
    fn test_embed_single() {
        let provider = MockEmbedding;
        let embedding = provider.embed("test").unwrap();
        assert_eq!(embedding.len(), 3);
        assert_eq!(embedding.len(), provider.dimension());
        assert_eq!(embedding, vec![0.1, 0.2, 0.3]);
    }

    #[test]
    fn test_embed_batch_default() {
        let provider = MockEmbedding;
        let texts = vec!["a".to_string(), "b".to_string()];
        let embeddings = provider.embed_batch(&texts).unwrap();
        assert_eq!(embeddings.len(), 2);
        assert_eq!(embeddings, vec![vec![0.1, 0.2, 0.3]; 2]);
    }

    #[test]
    fn test_embed_batch_preserves_order() {
        let provider = CountingEmbedding::new();
        let texts = vec!["a".to_string(), "abcd".to_string(), "ab".to_string()];
        let embeddings = provider.embed_batch(&texts).unwrap();
        assert_eq!(embeddings, vec![vec![1.0], vec![4.0], vec![2.0]]);
        assert_eq!(provider.calls(), 3);
    }

    #[test]
    fn test_embed_batch_empty() {
        let provider = CountingEmbedding::new();
        let texts: Vec<String> = Vec::new();
        let embeddings = provider.embed_batch(&texts).unwrap();
        assert!(embeddings.is_empty());
        assert_eq!(provider.calls(), 0);
    }

    #[test]
    fn test_embed_batch_stops_at_first_error() {
        let provider = CountingEmbedding::new();
        let texts = vec!["ok".to_string(), "fail".to_string(), "never".to_string()];
        let err = provider.embed_batch(&texts).unwrap_err();
        assert!(err.to_string().contains("fail"));
        // "never" must not be embedded once "fail" has errored.
        assert_eq!(provider.calls(), 2);
    }

    #[test]
    fn test_dimension() {
        let provider = MockEmbedding;
        assert_eq!(provider.dimension(), 3);
    }

    #[test]
    fn test_model_name() {
        let provider = MockEmbedding;
        assert_eq!(provider.model_name(), "mock-model");
    }
}