Skip to main content

agentic_memory/v3/
embeddings.rs

//! Embedding generation for semantic search.
//! Supports pluggable backends: local, API, or none (which falls back to text search).
3
4use std::collections::HashMap;
5use std::sync::Arc;
6
/// A dense embedding: one `f32` component per dimension.
///
/// The length is chosen by whichever provider produced it (see
/// `EmbeddingProvider::dimension`); for example the TF-IDF provider uses the
/// dimension it was constructed with.
pub type Embedding = Vec<f32>;
9
10/// Trait for embedding providers
11pub trait EmbeddingProvider: Send + Sync {
12    /// Generate embedding for text
13    fn embed(&self, text: &str) -> Option<Embedding>;
14
15    /// Generate embeddings for multiple texts (batched)
16    fn embed_batch(&self, texts: &[&str]) -> Vec<Option<Embedding>> {
17        texts.iter().map(|t| self.embed(t)).collect()
18    }
19
20    /// Get embedding dimension
21    fn dimension(&self) -> usize;
22
23    /// Provider name
24    fn name(&self) -> &str;
25}
26
27/// No-op provider (fallback to text search)
28pub struct NoOpEmbedding;
29
30impl EmbeddingProvider for NoOpEmbedding {
31    fn embed(&self, _text: &str) -> Option<Embedding> {
32        None
33    }
34
35    fn dimension(&self) -> usize {
36        0
37    }
38
39    fn name(&self) -> &str {
40        "none"
41    }
42}
43
/// Bag-of-words term-frequency embedding (no ML, fast, deterministic).
///
/// Despite the name, no inverse-document-frequency weighting is applied: each
/// vector slot corresponds to one vocabulary word and holds that word's
/// frequency in the text, L2-normalised. The vocabulary is learned by
/// [`fit`](TfIdfEmbedding::fit); until `fit` has been called the vocabulary is
/// empty and every text embeds to the zero vector.
#[derive(Debug, Clone)]
pub struct TfIdfEmbedding {
    /// Lower-cased word -> vector slot. Every slot is `< dimension`.
    vocabulary: HashMap<String, usize>,
    /// Length of every embedding produced, and the maximum vocabulary size.
    dimension: usize,
}

impl TfIdfEmbedding {
    /// Creates an unfitted embedder that produces `dimension`-length vectors.
    ///
    /// Call [`fit`](Self::fit) before use; with an empty vocabulary every text
    /// embeds to the zero vector.
    pub fn new(dimension: usize) -> Self {
        Self {
            vocabulary: HashMap::new(),
            dimension,
        }
    }

    /// Builds the vocabulary from `texts`, replacing any previous vocabulary.
    ///
    /// Words are whitespace-separated tokens, lower-cased (punctuation is kept,
    /// so `"cat."` and `"cat"` are distinct words). The `dimension` most
    /// frequent words are kept, with slot 0 for the most frequent. Ties are
    /// broken alphabetically, so the same corpus always yields the same slots.
    pub fn fit(&mut self, texts: &[&str]) {
        let mut word_counts: HashMap<String, usize> = HashMap::new();
        for word in texts.iter().flat_map(|text| text.split_whitespace()) {
            *word_counts.entry(word.to_lowercase()).or_default() += 1;
        }

        let mut words: Vec<(String, usize)> = word_counts.into_iter().collect();
        // Highest count first, then the word itself. Without the tie-breaker,
        // equal counts keep HashMap's randomised iteration order, so slot
        // indices (and which words survive the `dimension` cut-off) would vary
        // between runs. Words are unique, so an unstable sort is still fully
        // deterministic here.
        words.sort_unstable_by(|(word_a, count_a), (word_b, count_b)| {
            count_b.cmp(count_a).then_with(|| word_a.cmp(word_b))
        });

        self.vocabulary = words
            .into_iter()
            .take(self.dimension)
            .enumerate()
            .map(|(slot, (word, _))| (word, slot))
            .collect();
    }
}
81
82impl EmbeddingProvider for TfIdfEmbedding {
83    fn embed(&self, text: &str) -> Option<Embedding> {
84        let mut embedding = vec![0.0f32; self.dimension];
85        let words: Vec<_> = text.split_whitespace().collect();
86        let total = words.len() as f32;
87
88        if total == 0.0 {
89            return Some(embedding);
90        }
91
92        for word in words {
93            let word = word.to_lowercase();
94            if let Some(&idx) = self.vocabulary.get(&word) {
95                embedding[idx] += 1.0 / total;
96            }
97        }
98
99        // Normalize
100        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
101        if norm > 0.0 {
102            for x in &mut embedding {
103                *x /= norm;
104            }
105        }
106
107        Some(embedding)
108    }
109
110    fn dimension(&self) -> usize {
111        self.dimension
112    }
113
114    fn name(&self) -> &str {
115        "tfidf"
116    }
117}
118
119/// Embedding manager that handles provider selection
120pub struct EmbeddingManager {
121    provider: Arc<dyn EmbeddingProvider>,
122}
123
124impl EmbeddingManager {
125    pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
126        Self { provider }
127    }
128
129    pub fn with_tfidf(dimension: usize) -> Self {
130        Self {
131            provider: Arc::new(TfIdfEmbedding::new(dimension)),
132        }
133    }
134
135    pub fn none() -> Self {
136        Self {
137            provider: Arc::new(NoOpEmbedding),
138        }
139    }
140
141    pub fn embed(&self, text: &str) -> Option<Embedding> {
142        self.provider.embed(text)
143    }
144
145    pub fn dimension(&self) -> usize {
146        self.provider.dimension()
147    }
148
149    pub fn name(&self) -> &str {
150        self.provider.name()
151    }
152}