halldyll_memory_model/embedding/mod.rs

//! Embedding generation with ONNX Runtime

use crate::core::{MemoryError, MemoryResult};
use ort::session::{builder::GraphOptimizationLevel, Session};
use std::collections::HashMap;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::RwLock;

/// Maximum sequence length for tokenization
const MAX_SEQUENCE_LENGTH: usize = 512;

/// Simple character-level tokenizer for embedding generation.
/// In production, use a proper subword tokenizer such as the HuggingFace
/// `tokenizers` crate, matched to the model's vocabulary.
#[derive(Debug, Clone)]
pub struct SimpleTokenizer {
    vocab: HashMap<String, i64>,
    unk_token_id: i64,
    pad_token_id: i64,
    cls_token_id: i64,
    sep_token_id: i64,
}

impl SimpleTokenizer {
    /// Create a new simple tokenizer with a basic vocabulary
    pub fn new() -> Self {
        let mut vocab = HashMap::new();
        // Special tokens
        vocab.insert("[PAD]".to_string(), 0);
        vocab.insert("[UNK]".to_string(), 1);
        vocab.insert("[CLS]".to_string(), 2);
        vocab.insert("[SEP]".to_string(), 3);

        // Basic ASCII characters as tokens (simplified)
        for (i, c) in ('a'..='z').enumerate() {
            vocab.insert(c.to_string(), 4 + i as i64);
        }
        for (i, c) in ('A'..='Z').enumerate() {
            vocab.insert(c.to_string(), 30 + i as i64);
        }
        for (i, c) in ('0'..='9').enumerate() {
            vocab.insert(c.to_string(), 56 + i as i64);
        }
        vocab.insert(" ".to_string(), 66);
        vocab.insert(".".to_string(), 67);
        vocab.insert(",".to_string(), 68);
        vocab.insert("!".to_string(), 69);
        vocab.insert("?".to_string(), 70);

        Self {
            vocab,
            unk_token_id: 1,
            pad_token_id: 0,
            cls_token_id: 2,
            sep_token_id: 3,
        }
    }

    /// Tokenize text into token IDs, returning `(input_ids, attention_mask)`,
    /// both padded or truncated to exactly `max_length`.
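    ///
    /// A minimal usage sketch (note this character-level scheme is a
    /// placeholder and is not compatible with any real model's vocabulary):
    ///
    /// ```ignore
    /// let tok = SimpleTokenizer::new();
    /// let (ids, mask) = tok.encode("hi", 8);
    /// assert_eq!(ids.len(), 8);                // [CLS] h i [SEP] + 4x [PAD]
    /// assert_eq!(mask.iter().sum::<i64>(), 4); // only real tokens attended
    /// ```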
    pub fn encode(&self, text: &str, max_length: usize) -> (Vec<i64>, Vec<i64>) {
        let mut input_ids = vec![self.cls_token_id];

        // Reserve two slots for [CLS] and [SEP]; saturate to avoid underflow
        // when max_length < 2.
        for c in text.chars().take(max_length.saturating_sub(2)) {
            let token_id = self
                .vocab
                .get(&c.to_string())
                .copied()
                .unwrap_or(self.unk_token_id);
            input_ids.push(token_id);
        }
        input_ids.push(self.sep_token_id);

        // Attention mask: 1 for real tokens, 0 for padding
        let mut attention_mask: Vec<i64> = vec![1; input_ids.len()];

        // Pad both sequences to max_length
        input_ids.resize(max_length, self.pad_token_id);
        attention_mask.resize(max_length, 0);

        (input_ids, attention_mask)
    }
}

impl Default for SimpleTokenizer {
    fn default() -> Self {
        Self::new()
    }
}

/// Embedding generator using ONNX Runtime
pub struct EmbeddingGenerator {
    session: Option<Arc<std::sync::Mutex<Session>>>,
    tokenizer: SimpleTokenizer,
    embedding_dim: usize,
    cache: Arc<RwLock<HashMap<String, Vec<f32>>>>,
    cache_size: usize,
}

impl EmbeddingGenerator {
    /// Create a new embedding generator without a model (uses fallback hashing).
    /// Defaults to 384-dimensional embeddings and a 1000-entry cache.
    pub fn new() -> Self {
        Self {
            session: None,
            tokenizer: SimpleTokenizer::new(),
            embedding_dim: 384,
            cache: Arc::new(RwLock::new(HashMap::new())),
            cache_size: 1000,
        }
    }

    /// Create embedding generator with ONNX model
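    ///
    /// Sketch of expected usage (the path and dimension are illustrative; a
    /// MiniLM-style sentence encoder exported to ONNX would use 384):
    ///
    /// ```ignore
    /// let gen = EmbeddingGenerator::with_model("models/encoder.onnx", 384)?;
    /// ```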
    pub fn with_model<P: AsRef<Path>>(model_path: P, embedding_dim: usize) -> MemoryResult<Self> {
        let session = Session::builder()
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to create session builder: {}", e)))?
            .with_optimization_level(GraphOptimizationLevel::Level3)
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to set optimization level: {}", e)))?
            .commit_from_file(model_path)
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to load model: {}", e)))?;

        Ok(Self {
            session: Some(Arc::new(std::sync::Mutex::new(session))),
            tokenizer: SimpleTokenizer::new(),
            embedding_dim,
            cache: Arc::new(RwLock::new(HashMap::new())),
            cache_size: 1000,
        })
    }

    /// Create embedding generator with custom cache size
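    ///
    /// Builder-style usage sketch:
    ///
    /// ```ignore
    /// let gen = EmbeddingGenerator::new().with_cache_size(5_000);
    /// ```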
    pub fn with_cache_size(mut self, size: usize) -> Self {
        self.cache_size = size;
        self
    }

    /// Generate embedding for text
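    ///
    /// Usage sketch (async; falls back to the hash-based embedding when no
    /// ONNX model is loaded):
    ///
    /// ```ignore
    /// let gen = EmbeddingGenerator::new();
    /// let vec = gen.generate("hello world").await?;
    /// assert_eq!(vec.len(), gen.dimension());
    /// ```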
    pub async fn generate(&self, text: &str) -> MemoryResult<Vec<f32>> {
        // Check cache first
        {
            let cache = self.cache.read().await;
            if let Some(embedding) = cache.get(text) {
                return Ok(embedding.clone());
            }
        }

        let embedding = if self.session.is_some() {
            self.generate_with_model(text)?
        } else {
            self.generate_fallback(text)
        };

        // Update cache
        {
            let mut cache = self.cache.write().await;
            if cache.len() >= self.cache_size {
                // Simple eviction: clear half the cache (arbitrary keys, since
                // HashMap iteration order is unspecified)
                let keys_to_remove: Vec<String> =
                    cache.keys().take(cache.len() / 2).cloned().collect();
                for key in keys_to_remove {
                    cache.remove(&key);
                }
            }
            cache.insert(text.to_string(), embedding.clone());
        }

        Ok(embedding)
    }

    /// Generate embedding using the ONNX model
    fn generate_with_model(&self, text: &str) -> MemoryResult<Vec<f32>> {
        let session_lock = self
            .session
            .as_ref()
            .ok_or_else(|| MemoryError::OnnxModel("No model loaded".to_string()))?;

        let mut session = session_lock
            .lock()
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to lock session: {}", e)))?;

        let (input_ids, attention_mask) = self.tokenizer.encode(text, MAX_SEQUENCE_LENGTH);

        // Create input tensors with shape (batch=1, seq_len) for ort 2.0
        let shape = vec![1, MAX_SEQUENCE_LENGTH];

        let input_ids_tensor = ort::value::Tensor::from_array((shape.clone(), input_ids))
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to create input_ids tensor: {}", e)))?;
        let attention_mask_tensor = ort::value::Tensor::from_array((shape, attention_mask))
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to create attention_mask tensor: {}", e)))?;

        // Run inference. Inputs here are provided positionally and must match
        // the model's declared input order; some BERT-style exports also
        // expect a third token_type_ids input.
        let outputs = session
            .run(ort::inputs![input_ids_tensor, attention_mask_tensor])
            .map_err(|e| MemoryError::OnnxModel(format!("Inference failed: {}", e)))?;

        // Extract embedding from the first output
        let output = outputs
            .iter()
            .next()
            .ok_or_else(|| MemoryError::OnnxModel("No output found".to_string()))?;

        let tensor_data = output
            .1
            .try_extract_tensor::<f32>()
            .map_err(|e| MemoryError::OnnxModel(format!("Failed to extract tensor: {}", e)))?;

        // Mean pooling over the sequence dimension
        let embedding = self.mean_pooling_from_raw(&tensor_data);

        Ok(embedding)
    }

    /// Mean pooling from raw tensor data
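    ///
    /// For a `(1, seq_len, dim)` output this computes, per dimension `j`:
    /// `e_j = (1 / seq_len) * sum_i h[i][j]`, then L2-normalizes the result.
    /// Note this simplified version averages over all positions, including
    /// [PAD] tokens; attention-mask-weighted pooling would be more faithful.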
    fn mean_pooling_from_raw(&self, data: &(&ort::tensor::Shape, &[f32])) -> Vec<f32> {
        let shape = data.0;
        let values = data.1;
        let dims: Vec<usize> = shape.iter().map(|&d| d as usize).collect();

        if dims.len() == 3 {
            // Shape is (1, seq_len, embedding_dim): average over the sequence
            let seq_len = dims[1];
            let embedding_dim = dims[2];
            let mut result = vec![0.0f32; embedding_dim];

            for i in 0..seq_len {
                for j in 0..embedding_dim {
                    result[j] += values[i * embedding_dim + j];
                }
            }

            for val in &mut result {
                *val /= seq_len as f32;
            }

            self.normalize(&mut result);
            result
        } else if dims.len() == 2 {
            // Already pooled: shape is (1, embedding_dim)
            let mut result: Vec<f32> = values.to_vec();
            self.normalize(&mut result);
            result
        } else {
            // Fallback: truncate to the expected dimension and normalize
            let mut result: Vec<f32> = values.iter().take(self.embedding_dim).copied().collect();
            self.normalize(&mut result);
            result
        }
    }

    /// Generate a fallback embedding using a deterministic hash-based approach
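    ///
    /// Two feature families are combined and then L2-normalized:
    /// 1. per-character positional features, spread across dimensions via a
    ///    sine of (codepoint * relative position * dimension index), and
    /// 2. hashed character n-grams (n = 2..=4), bucketed into dimensions with
    ///    a polynomial rolling hash (multiplier 31).
    /// This yields stable vectors that capture surface similarity only, for
    /// use when no model is available.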
    fn generate_fallback(&self, text: &str) -> Vec<f32> {
        let mut embedding = vec![0.0f32; self.embedding_dim];

        // Character-level positional features
        let chars: Vec<char> = text.chars().collect();
        let text_len = chars.len().max(1) as f32;

        for (i, c) in chars.iter().enumerate() {
            let char_val = (*c as u32) as f32;
            let position = i as f32 / text_len;

            // Distribute character influence across embedding dimensions
            for j in 0..self.embedding_dim {
                let idx = (char_val as usize + j) % self.embedding_dim;
                embedding[idx] += (char_val * position * (j as f32 + 1.0)).sin() * 0.1;
            }
        }

        // N-gram features: hash each 2- to 4-character window into a bucket
        for window_size in 2..=4 {
            if chars.len() >= window_size {
                for window in chars.windows(window_size) {
                    let hash: u32 = window
                        .iter()
                        .fold(0u32, |acc, &c| acc.wrapping_mul(31).wrapping_add(c as u32));
                    let idx = (hash as usize) % self.embedding_dim;
                    embedding[idx] += 0.05;
                }
            }
        }

        // Normalize the embedding to unit length
        self.normalize(&mut embedding);

        embedding
    }

    /// Normalize an embedding vector to unit length (no-op for near-zero norms)
    fn normalize(&self, embedding: &mut [f32]) {
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 1e-10 {
            for val in embedding.iter_mut() {
                *val /= norm;
            }
        }
    }

    /// Compute cosine similarity between two embeddings
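    ///
    /// Returns 0.0 on dimension mismatch or near-zero norms. Since generated
    /// embeddings are already unit-length, this equals their dot product:
    ///
    /// ```ignore
    /// let sim = EmbeddingGenerator::cosine_similarity(&a, &b); // in [-1, 1]
    /// ```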
    pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
        if a.len() != b.len() {
            return 0.0;
        }

        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();

        if norm_a > 1e-10 && norm_b > 1e-10 {
            dot / (norm_a * norm_b)
        } else {
            0.0
        }
    }

    /// Get embedding dimension
    pub fn dimension(&self) -> usize {
        self.embedding_dim
    }

    /// Clear the embedding cache
    pub async fn clear_cache(&self) {
        let mut cache = self.cache.write().await;
        cache.clear();
    }

    /// Get current cache size
    pub async fn cache_len(&self) -> usize {
        let cache = self.cache.read().await;
        cache.len()
    }
}

impl Default for EmbeddingGenerator {
    fn default() -> Self {
        Self::new()
    }
}

// Re-export SimpleTokenizer for external use
pub use self::SimpleTokenizer as Tokenizer;

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_tokenizer() {
        let tokenizer = SimpleTokenizer::new();
        let (ids, mask) = tokenizer.encode("hello", 10);
        assert_eq!(ids.len(), 10);
        assert_eq!(mask.len(), 10);
        assert_eq!(ids[0], 2); // CLS token
    }

    #[tokio::test]
    async fn test_embedding_generator_fallback() {
        let generator = EmbeddingGenerator::new();
        let embedding = generator.generate("Hello world").await.unwrap();
        assert_eq!(embedding.len(), 384);

        // Check normalization
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 1e-5);
    }

    #[tokio::test]
    async fn test_embedding_similarity() {
        let generator = EmbeddingGenerator::new();
        let emb1 = generator.generate("Hello world").await.unwrap();
        let emb2 = generator.generate("Hello world").await.unwrap();
        let emb3 = generator.generate("Completely different text").await.unwrap();

        let sim_same = EmbeddingGenerator::cosine_similarity(&emb1, &emb2);
        let sim_diff = EmbeddingGenerator::cosine_similarity(&emb1, &emb3);

        assert!((sim_same - 1.0).abs() < 1e-5); // Same text should have similarity 1.0
        assert!(sim_diff < sim_same); // Different text should have lower similarity
    }

    #[tokio::test]
    async fn test_embedding_cache() {
        let generator = EmbeddingGenerator::new();
        assert_eq!(generator.cache_len().await, 0);

        let _ = generator.generate("test text").await.unwrap();
        assert_eq!(generator.cache_len().await, 1);

        generator.clear_cache().await;
        assert_eq!(generator.cache_len().await, 0);
    }
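
    // Additional sketch: the fallback embedding should be deterministic for
    // the same input, since the hash path involves no randomness. The cache
    // is cleared between calls to force recomputation.
    #[tokio::test]
    async fn test_fallback_deterministic() {
        let generator = EmbeddingGenerator::new();
        let a = generator.generate("same input").await.unwrap();
        generator.clear_cache().await;
        let b = generator.generate("same input").await.unwrap();
        assert_eq!(a, b);
    }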
}