// agentic_memory/v3/embeddings.rs
use std::collections::HashMap;
use std::sync::Arc;
6
/// Dense vector representation of a piece of text.
pub type Embedding = Vec<f32>;

/// A source of text embeddings.
///
/// Implementors must be `Send + Sync`.
pub trait EmbeddingProvider: Send + Sync {
    /// Embeds one text. `None` means this provider produced no vector for it.
    fn embed(&self, text: &str) -> Option<Embedding>;

    /// Embeds every entry of `texts`, returning results in the same order.
    ///
    /// The default implementation calls [`EmbeddingProvider::embed`] once per entry;
    /// providers with a cheaper bulk path may override it.
    fn embed_batch(&self, texts: &[&str]) -> Vec<Option<Embedding>> {
        let mut results = Vec::with_capacity(texts.len());
        for &text in texts {
            results.push(self.embed(text));
        }
        results
    }

    /// Length of the vectors this provider produces.
    fn dimension(&self) -> usize;

    /// Short identifier for this provider.
    fn name(&self) -> &str;
}
26
/// Provider that disables embeddings: its `embed` never produces a vector.
///
/// Being a stateless unit type, it is trivially `Copy` and `Default`.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct NoOpEmbedding;
29
30impl EmbeddingProvider for NoOpEmbedding {
31 fn embed(&self, _text: &str) -> Option<Embedding> {
32 None
33 }
34
35 fn dimension(&self) -> usize {
36 0
37 }
38
39 fn name(&self) -> &str {
40 "none"
41 }
42}
43
/// Bag-of-words embedder over a fixed vocabulary of the most frequent corpus words.
///
/// Despite the name, no inverse-document-frequency weighting is applied: [`fit`] keeps
/// the `dimension` most frequent words by raw corpus count, and the provider
/// implementation produces an L2-normalised term-frequency vector over those words.
/// Until `fit` is called the vocabulary is empty, so every embedding is all zeros.
///
/// [`fit`]: TfIdfEmbedding::fit
#[derive(Debug, Clone)]
pub struct TfIdfEmbedding {
    /// Lowercased word -> vector slot. Every slot is `< dimension`.
    vocabulary: HashMap<String, usize>,
    /// Length of every produced vector, and the maximum vocabulary size.
    dimension: usize,
}

impl TfIdfEmbedding {
    /// Creates an unfitted embedder that produces vectors of length `dimension`.
    pub fn new(dimension: usize) -> Self {
        Self {
            vocabulary: HashMap::new(),
            dimension,
        }
    }

    /// Builds the vocabulary from `texts`, replacing any previous vocabulary.
    ///
    /// Tokens are whitespace-separated and lowercased. At most `dimension` words are
    /// kept, ranked by descending count, so the most frequent word gets slot 0. Ties are
    /// broken alphabetically, so the same corpus always yields the same word -> slot
    /// mapping. An empty corpus (or `dimension == 0`) yields an empty vocabulary.
    pub fn fit(&mut self, texts: &[&str]) {
        let mut word_counts: HashMap<String, usize> = HashMap::new();
        for word in texts.iter().flat_map(|&text| text.split_whitespace()) {
            *word_counts.entry(word.to_lowercase()).or_default() += 1;
        }

        let mut ranked: Vec<(String, usize)> = word_counts.into_iter().collect();
        // `HashMap` iteration order is randomised, so ranking by count alone would give
        // tied words arbitrary slots (and pick arbitrary survivors of the `dimension`
        // cut), making embeddings irreproducible. Words are unique, so this is a total
        // order and an unstable sort is safe.
        ranked.sort_unstable_by(|(word_a, count_a), (word_b, count_b)| {
            count_b.cmp(count_a).then_with(|| word_a.cmp(word_b))
        });

        self.vocabulary = ranked
            .into_iter()
            .take(self.dimension)
            .enumerate()
            .map(|(slot, (word, _))| (word, slot))
            .collect();
    }
}
81
82impl EmbeddingProvider for TfIdfEmbedding {
83 fn embed(&self, text: &str) -> Option<Embedding> {
84 let mut embedding = vec![0.0f32; self.dimension];
85 let words: Vec<_> = text.split_whitespace().collect();
86 let total = words.len() as f32;
87
88 if total == 0.0 {
89 return Some(embedding);
90 }
91
92 for word in words {
93 let word = word.to_lowercase();
94 if let Some(&idx) = self.vocabulary.get(&word) {
95 embedding[idx] += 1.0 / total;
96 }
97 }
98
99 let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
101 if norm > 0.0 {
102 for x in &mut embedding {
103 *x /= norm;
104 }
105 }
106
107 Some(embedding)
108 }
109
110 fn dimension(&self) -> usize {
111 self.dimension
112 }
113
114 fn name(&self) -> &str {
115 "tfidf"
116 }
117}
118
119pub struct EmbeddingManager {
121 provider: Arc<dyn EmbeddingProvider>,
122}
123
124impl EmbeddingManager {
125 pub fn new(provider: Arc<dyn EmbeddingProvider>) -> Self {
126 Self { provider }
127 }
128
129 pub fn with_tfidf(dimension: usize) -> Self {
130 Self {
131 provider: Arc::new(TfIdfEmbedding::new(dimension)),
132 }
133 }
134
135 pub fn none() -> Self {
136 Self {
137 provider: Arc::new(NoOpEmbedding),
138 }
139 }
140
141 pub fn embed(&self, text: &str) -> Option<Embedding> {
142 self.provider.embed(text)
143 }
144
145 pub fn dimension(&self) -> usize {
146 self.provider.dimension()
147 }
148
149 pub fn name(&self) -> &str {
150 self.provider.name()
151 }
152}