halldyll_memory_model/embedding/mod.rs

use crate::core::{MemoryError, MemoryResult};
use ort::session::{builder::GraphOptimizationLevel, Session};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;

const MAX_SEQUENCE_LENGTH: usize = 512;

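/// A minimal character-level tokenizer with a BERT-style special-token
/// layout ([PAD]=0, [UNK]=1, [CLS]=2, [SEP]=3). It is a stand-in; a real
/// ONNX model should be paired with its matching pretrained tokenizer.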
#[derive(Debug, Clone)]
pub struct SimpleTokenizer {
    vocab: HashMap<String, i64>,
    unk_token_id: i64,
    pad_token_id: i64,
    cls_token_id: i64,
    sep_token_id: i64,
}

impl SimpleTokenizer {
    pub fn new() -> Self {
        let mut vocab = HashMap::new();
        vocab.insert("[PAD]".to_string(), 0);
        vocab.insert("[UNK]".to_string(), 1);
        vocab.insert("[CLS]".to_string(), 2);
        vocab.insert("[SEP]".to_string(), 3);

        // Lowercase letters: ids 4..=29.
        for (i, c) in ('a'..='z').enumerate() {
            vocab.insert(c.to_string(), 4 + i as i64);
        }
        // Uppercase letters: ids 30..=55.
        for (i, c) in ('A'..='Z').enumerate() {
            vocab.insert(c.to_string(), 30 + i as i64);
        }
        // Digits: ids 56..=65.
        for (i, c) in ('0'..='9').enumerate() {
            vocab.insert(c.to_string(), 56 + i as i64);
        }
        // Common whitespace and punctuation.
        vocab.insert(" ".to_string(), 66);
        vocab.insert(".".to_string(), 67);
        vocab.insert(",".to_string(), 68);
        vocab.insert("!".to_string(), 69);
        vocab.insert("?".to_string(), 70);

        Self {
            vocab,
            unk_token_id: 1,
            pad_token_id: 0,
            cls_token_id: 2,
            sep_token_id: 3,
        }
    }

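    /// Encodes `text` as `(input_ids, attention_mask)`, each exactly
    /// `max_length` long: `[CLS]`, one id per character, `[SEP]`, then
    /// `[PAD]` up to `max_length`. A hedged example under the default vocab
    /// built in `new()`: `encode("hi", 6)` yields ids `[2, 11, 12, 3, 0, 0]`
    /// and mask `[1, 1, 1, 1, 0, 0]`.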
    pub fn encode(&self, text: &str, max_length: usize) -> (Vec<i64>, Vec<i64>) {
        let mut input_ids = vec![self.cls_token_id];
        let chars: Vec<char> = text.chars().collect();

        // Reserve two slots for [CLS] and [SEP]; saturating_sub guards
        // against underflow when max_length < 2.
        for c in chars.iter().take(max_length.saturating_sub(2)) {
            let token_id = self
                .vocab
                .get(&c.to_string())
                .copied()
                .unwrap_or(self.unk_token_id);
            input_ids.push(token_id);
        }
        input_ids.push(self.sep_token_id);

        // Real tokens get mask 1; the padding added below gets mask 0.
        let attention_mask: Vec<i64> = vec![1; input_ids.len()];

        while input_ids.len() < max_length {
            input_ids.push(self.pad_token_id);
        }

        let mut padded_attention_mask = attention_mask;
        while padded_attention_mask.len() < max_length {
            padded_attention_mask.push(0);
        }

        (input_ids, padded_attention_mask)
    }
}

impl Default for SimpleTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

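/// Produces fixed-dimension text embeddings: via a loaded ONNX model when
/// one is available, otherwise via a deterministic hash-based fallback.
/// Results are memoized in an in-memory cache keyed by the exact input text.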
pub struct EmbeddingGenerator {
    session: Option<Arc<std::sync::Mutex<Session>>>,
    tokenizer: SimpleTokenizer,
    embedding_dim: usize,
    cache: Arc<RwLock<HashMap<String, Vec<f32>>>>,
    cache_size: usize,
}

impl EmbeddingGenerator {
    pub fn new() -> Self {
        Self {
            session: None,
            tokenizer: SimpleTokenizer::new(),
            // 384 matches common sentence-transformer models
            // (e.g. all-MiniLM-L6-v2).
            embedding_dim: 384,
            cache: Arc::new(RwLock::new(HashMap::new())),
            cache_size: 1000,
        }
    }

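    /// Loads an ONNX embedding model from disk. `embedding_dim` must match
    /// the model's output width (e.g. 384 for all-MiniLM-L6-v2).
    ///
    /// A hedged usage sketch; the path below is illustrative only:
    ///
    /// ```ignore
    /// let generator = EmbeddingGenerator::with_model("models/encoder.onnx", 384)?
    ///     .with_cache_size(10_000);
    /// ```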
    pub fn with_model<P: AsRef<Path>>(model_path: P, embedding_dim: usize) -> MemoryResult<Self> {
        let session = Session::builder()
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to create session builder: {}", e)))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to set optimization level: {}", e)))?
            .commit_from_file(model_path)
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to load model: {}", e)))?;

        Ok(Self {
            session: Some(Arc::new(std::sync::Mutex::new(session))),
            tokenizer: SimpleTokenizer::new(),
            embedding_dim,
            cache: Arc::new(RwLock::new(HashMap::new())),
            cache_size: 1000,
        })
    }

    pub fn with_cache_size(mut self, size: usize) -> Self {
        self.cache_size = size;
        self
    }

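    /// Returns the embedding for `text`, serving repeat inputs from cache.
    /// The fallback path always yields `self.dimension()` values; the model
    /// path yields whatever width the model emits.
    ///
    /// A hedged usage sketch:
    ///
    /// ```ignore
    /// let generator = EmbeddingGenerator::new();
    /// let embedding = generator.generate("hello world").await?;
    /// assert_eq!(embedding.len(), generator.dimension());
    /// ```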
    pub async fn generate(&self, text: &str) -> MemoryResult<Vec<f32>> {
        // Fast path: serve from cache under a read lock.
        {
            let cache = self.cache.read().await;
            if let Some(embedding) = cache.get(text) {
                return Ok(embedding.clone());
            }
        }

        let embedding = if self.session.is_some() {
            self.generate_with_model(text)?
        } else {
            self.generate_fallback(text)
        };

        {
            let mut cache = self.cache.write().await;
            // Crude eviction: when full, drop half the entries in arbitrary
            // HashMap order. Cheap and bounded, but not LRU.
            if cache.len() >= self.cache_size {
                let keys_to_remove: Vec<String> =
                    cache.keys().take(cache.len() / 2).cloned().collect();
                for key in keys_to_remove {
                    cache.remove(&key);
                }
            }
            cache.insert(text.to_string(), embedding.clone());
        }

        Ok(embedding)
    }

    fn generate_with_model(&self, text: &str) -> MemoryResult<Vec<f32>> {
        let session_lock = self
            .session
            .as_ref()
            .ok_or_else(|| MemoryError::OnnxModel("No model loaded".to_string()))?;

        let mut session = session_lock
            .lock()
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to lock session: {}", e)))?;

        let (input_ids, attention_mask) = self.tokenizer.encode(text, MAX_SEQUENCE_LENGTH);

        // Both inputs are shaped [batch = 1, seq_len = MAX_SEQUENCE_LENGTH].
        let shape = vec![1, MAX_SEQUENCE_LENGTH];

        let input_ids_tensor = ort::value::Tensor::from_array((shape.clone(), input_ids))
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to create input_ids tensor: {}", e)))?;
        let attention_mask_tensor = ort::value::Tensor::from_array((shape, attention_mask))
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to create attention_mask tensor: {}", e)))?;

        let outputs = session
            .run(ort::inputs![input_ids_tensor, attention_mask_tensor])
            .map_err(|e| MemoryError::OnnxModel(format!("Inference failed: {}", e)))?;

        // Take the first output (the last hidden state for BERT-style models).
        let output = outputs
            .iter()
            .next()
            .ok_or_else(|| MemoryError::OnnxModel("No output found".to_string()))?;

        let tensor_data = output
            .1
            .try_extract_tensor::<f32>()
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to extract tensor: {}", e)))?;

        let embedding = self.mean_pooling_from_raw(&tensor_data);

        Ok(embedding)
    }

    /// Mean-pools token embeddings into a single vector and L2-normalizes it.
    /// Note: padding positions are included in the average, since the
    /// attention mask is not consulted here; for short texts this dilutes
    /// the signal.
    fn mean_pooling_from_raw(&self, data: &(&ort::tensor::Shape, &[f32])) -> Vec<f32> {
        let shape = data.0;
        let values = data.1;
        let dims: Vec<usize> = shape.iter().map(|&d| d as usize).collect();

        if dims.len() == 3 {
            // [batch, seq_len, dim]; batch is assumed to be 1.
            let seq_len = dims[1];
            let embedding_dim = dims[2];
            let mut result = vec![0.0f32; embedding_dim];

            for i in 0..seq_len {
                for j in 0..embedding_dim {
                    result[j] += values[i * embedding_dim + j];
                }
            }

            for val in &mut result {
                *val /= seq_len as f32;
            }

            self.normalize(&mut result);
            result
        } else if dims.len() == 2 {
            // Already pooled ([batch, dim]); just normalize.
            let mut result: Vec<f32> = values.to_vec();
            self.normalize(&mut result);
            result
        } else {
            // Unexpected rank: fall back to the first `embedding_dim` values.
            let mut result: Vec<f32> = values.iter().take(self.embedding_dim).copied().collect();
            self.normalize(&mut result);
            result
        }
    }

    /// Deterministic fallback used when no ONNX model is loaded: mixes
    /// per-character positional features with hashed character n-grams,
    /// then L2-normalizes. Not semantically meaningful, but stable and cheap.
    fn generate_fallback(&self, text: &str) -> Vec<f32> {
        let mut embedding = vec![0.0f32; self.embedding_dim];

        let chars: Vec<char> = text.chars().collect();
        let text_len = chars.len().max(1) as f32;

        // Per-character signal, spread across dimensions by character value.
        for (i, c) in chars.iter().enumerate() {
            let char_val = (*c as u32) as f32;
            let position = i as f32 / text_len;

            for j in 0..self.embedding_dim {
                let idx = (char_val as usize + j) % self.embedding_dim;
                embedding[idx] += (char_val * position * (j as f32 + 1.0)).sin() * 0.1;
            }
        }

        // Character n-grams (n = 2..=4) hashed into buckets, so texts that
        // share substrings land near each other.
        for window_size in 2..=4 {
            if chars.len() >= window_size {
                for window in chars.windows(window_size) {
                    let hash: u32 = window
                        .iter()
                        .fold(0u32, |acc, &c| acc.wrapping_mul(31).wrapping_add(c as u32));
                    let idx = (hash as usize) % self.embedding_dim;
                    embedding[idx] += 0.05;
                }
            }
        }

        self.normalize(&mut embedding);

        embedding
    }

    /// Scales `embedding` to unit L2 norm; near-zero vectors are left as-is
    /// to avoid dividing by (almost) zero.
    fn normalize(&self, embedding: &mut [f32]) {
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for val in embedding.iter_mut() {
                *val /= norm;
            }
        }
    }

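    /// Cosine similarity in [-1.0, 1.0]. Returns 0.0 for mismatched lengths
    /// or near-zero inputs. For unit vectors (as produced by `generate`),
    /// this is just their dot product.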
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a > 1e-10 && norm_b > 1e-10 {
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }

    pub fn dimension(&self) -> usize {
        self.embedding_dim
    }

    pub async fn clear_cache(&self) {
        let mut cache = self.cache.write().await;
        cache.clear();
    }

    pub async fn cache_len(&self) -> usize {
        let cache = self.cache.read().await;
        cache.len()
    }
}

impl Default for EmbeddingGenerator {
    fn default() -> Self {
        Self::new()
    }
}

pub use self::SimpleTokenizer as Tokenizer;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_tokenizer() {
        let tokenizer = SimpleTokenizer::new();
        let (ids, mask) = tokenizer.encode("hello", 10);
        assert_eq!(ids.len(), 10);
        assert_eq!(mask.len(), 10);
        assert_eq!(ids[0], 2); // [CLS]
    }

    #[tokio::test]
    async fn test_embedding_generator_fallback() {
        let generator = EmbeddingGenerator::new();
        let embedding = generator.generate("Hello world").await.unwrap();
        assert_eq!(embedding.len(), 384);

        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn test_embedding_similarity() {
        let generator = EmbeddingGenerator::new();
        let emb1 = generator.generate("Hello world").await.unwrap();
        let emb2 = generator.generate("Hello world").await.unwrap();
        let emb3 = generator.generate("Completely different text").await.unwrap();

        let sim_same = EmbeddingGenerator::cosine_similarity(&emb1, &emb2);
        let sim_diff = EmbeddingGenerator::cosine_similarity(&emb1, &emb3);

        assert!((sim_same - 1.0).abs() < 1e-5);
        assert!(sim_diff < sim_same);
    }

    #[tokio::test]
    async fn test_embedding_cache() {
        let generator = EmbeddingGenerator::new();
        assert_eq!(generator.cache_len().await, 0);

        let _ = generator.generate("test text").await.unwrap();
        assert_eq!(generator.cache_len().await, 1);

        generator.clear_cache().await;
        assert_eq!(generator.cache_len().await, 0);
    }
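
    // Hedged additions: these tests exercise behavior documented above
    // (eviction bound, truncation) under the same defaults as the code.
    #[tokio::test]
    async fn test_cache_eviction_bound() {
        // With capacity 4, each insert beyond capacity first drops half the
        // entries, so the cache never grows past its configured size.
        let generator = EmbeddingGenerator::new().with_cache_size(4);
        for i in 0..10 {
            let _ = generator.generate(&format!("text {}", i)).await.unwrap();
        }
        assert!(generator.cache_len().await <= 4);
    }

    #[test]
    fn test_tokenizer_truncation() {
        let tokenizer = SimpleTokenizer::new();
        // A 100-character input encoded to max_length 8 keeps [CLS], six
        // characters, and [SEP]; the mask is all ones (no padding).
        let long_text = "a".repeat(100);
        let (ids, mask) = tokenizer.encode(&long_text, 8);
        assert_eq!(ids.len(), 8);
        assert_eq!(ids[0], 2); // [CLS]
        assert_eq!(ids[7], 3); // [SEP]
        assert_eq!(mask, vec![1; 8]);
    }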
}