
microscope_memory/embeddings.rs

#![allow(dead_code)]
// Embedding module for semantic vector search
// Supports OpenAI, HuggingFace, and custom embeddings

use std::collections::HashMap;

pub const EMBEDDING_DIM: usize = 1536; // OpenAI text-embedding-ada-002 dimension

/// Embedding provider trait
pub trait EmbeddingProvider: Send + Sync {
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError>;
    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError>;
    fn dimension(&self) -> usize;
}

#[derive(Debug)]
pub enum EmbeddingError {
    ApiError(String),
    InvalidDimension,
    NetworkError,
}

/// Cached embedding storage
pub struct EmbeddingCache {
    embeddings: HashMap<String, Vec<f32>>,
    dimension: usize,
}

impl EmbeddingCache {
    pub fn new(dimension: usize) -> Self {
        Self {
            embeddings: HashMap::new(),
            dimension,
        }
    }

    /// Inserts an embedding; vectors whose length does not match the cache
    /// dimension are silently dropped.
    pub fn insert(&mut self, text: String, embedding: Vec<f32>) {
        if embedding.len() == self.dimension {
            self.embeddings.insert(text, embedding);
        }
    }

    pub fn get(&self, text: &str) -> Option<&Vec<f32>> {
        self.embeddings.get(text)
    }

    pub fn contains(&self, text: &str) -> bool {
        self.embeddings.contains_key(text)
    }
}
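
// The cache and a provider compose naturally into a get-or-compute lookup. The helper
// below is an illustrative sketch, not part of the original API; the function name is
// hypothetical.
pub fn embed_with_cache(
    provider: &dyn EmbeddingProvider,
    cache: &mut EmbeddingCache,
    text: &str,
) -> Result<Vec<f32>, EmbeddingError> {
    if let Some(cached) = cache.get(text) {
        return Ok(cached.clone());
    }
    let embedding = provider.embed(text)?;
    cache.insert(text.to_string(), embedding.clone());
    Ok(embedding)
}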

/// Fast SIMD-accelerated cosine similarity
#[cfg(target_arch = "x86_64")]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    use std::arch::x86_64::*;

    if a.len() != b.len() {
        return 0.0;
    }

    // The AVX2/FMA path is only sound on CPUs that support those features;
    // otherwise fall back to the scalar implementation.
    if !is_x86_feature_detected!("avx2") || !is_x86_feature_detected!("fma") {
        return cosine_similarity_scalar(a, b);
    }

    unsafe {
        let mut dot_sum = _mm256_setzero_ps();
        let mut norm_a = _mm256_setzero_ps();
        let mut norm_b = _mm256_setzero_ps();

        let chunks = a.len() / 8;

        // Process 8 floats per iteration with fused multiply-adds
        for i in 0..chunks {
            let va = _mm256_loadu_ps(a.as_ptr().add(i * 8));
            let vb = _mm256_loadu_ps(b.as_ptr().add(i * 8));

            dot_sum = _mm256_fmadd_ps(va, vb, dot_sum);
            norm_a = _mm256_fmadd_ps(va, va, norm_a);
            norm_b = _mm256_fmadd_ps(vb, vb, norm_b);
        }

        // Reduce the vector lanes to scalar partial sums
        let mut dot = horizontal_sum_ps256(dot_sum);
        let mut na = horizontal_sum_ps256(norm_a);
        let mut nb = horizontal_sum_ps256(norm_b);

        // Handle remaining elements
        for i in (chunks * 8)..a.len() {
            dot += a[i] * b[i];
            na += a[i] * a[i];
            nb += b[i] * b[i];
        }

        // Take the square roots only after the remainder has been accumulated
        let denom = na.sqrt() * nb.sqrt();
        if denom == 0.0 {
            return 0.0;
        }
        dot / denom
    }
}

#[cfg(not(target_arch = "x86_64"))]
pub fn cosine_similarity_simd(a: &[f32], b: &[f32]) -> f32 {
    cosine_similarity_scalar(a, b)
}

/// Fallback scalar cosine similarity
pub fn cosine_similarity_scalar(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0;
    let mut norm_a = 0.0;
    let mut norm_b = 0.0;

    for i in 0..a.len() {
        dot += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }

    let denom = norm_a.sqrt() * norm_b.sqrt();
    if denom == 0.0 {
        return 0.0;
    }
    dot / denom
}
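
// Both the mock and the Candle providers below L2-normalize their output, so cosine
// similarity between two stored embeddings reduces to a plain dot product. A minimal
// sketch of that shortcut (a hypothetical helper, not part of the original API):
pub fn dot_product_unit(a: &[f32], b: &[f32]) -> f32 {
    debug_assert_eq!(a.len(), b.len());
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum()
}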

#[cfg(target_arch = "x86_64")]
unsafe fn horizontal_sum_ps256(v: std::arch::x86_64::__m256) -> f32 {
    use std::arch::x86_64::*;

    // Fold the upper 128-bit half onto the lower half, then reduce the
    // remaining four lanes pairwise.
    let high = _mm256_extractf128_ps(v, 1);
    let low = _mm256_castps256_ps128(v);
    let sum = _mm_add_ps(high, low);
    let shuf = _mm_shuffle_ps(sum, sum, 0x0E); // bring lanes 2 and 3 down to 0 and 1
    let sums = _mm_add_ps(sum, shuf); // lanes 0 and 1 now hold the two partial sums
    let shuf2 = _mm_shuffle_ps(sums, sums, 0x01); // bring lane 1 down to lane 0
    let result = _mm_add_ss(sums, shuf2);
    _mm_cvtss_f32(result)
}

/// Mock embedding provider for testing
pub struct MockEmbeddingProvider {
    dimension: usize,
}

impl MockEmbeddingProvider {
    pub fn new(dimension: usize) -> Self {
        Self { dimension }
    }
}

impl EmbeddingProvider for MockEmbeddingProvider {
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        // Simple hash-based embedding for testing
        let mut embedding = vec![0.0; self.dimension];
        let hash = text
            .bytes()
            .fold(0u64, |acc, b| acc.wrapping_mul(31).wrapping_add(b as u64));

        for (i, slot) in embedding.iter_mut().enumerate() {
            let val = ((hash.wrapping_mul(i as u64 + 1)) % 1000) as f32 / 1000.0;
            *slot = val * 2.0 - 1.0; // Map to [-1, 1]
        }

        // Normalize to unit vector
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }

        Ok(embedding)
    }

    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        texts.iter().map(|t| self.embed(t)).collect()
    }

    fn dimension(&self) -> usize {
        self.dimension
    }
}

/// Embedding-enhanced block header
#[repr(C, packed)]
pub struct EmbeddedBlockHeader {
    // Original fields
    pub x: f32,
    pub y: f32,
    pub z: f32,
    pub zoom: f32,
    pub depth: u8,
    pub layer_id: u8,
    pub data_offset: u32,
    pub data_len: u16,
    pub parent_idx: u32,
    pub child_count: u16,
    pub crc16: [u8; 2], // CRC16-CCITT (0x0000 = no checksum)

    // New embedding fields
    pub embedding_offset: u32, // Offset into embedding file
    pub has_embedding: bool,   // Whether this block has an embedding
}
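
// With #[repr(C, packed)] the header occupies exactly the sum of its field sizes
// (4 * 4 + 1 + 1 + 4 + 2 + 4 + 2 + 2 + 4 + 1 = 37 bytes). The compile-time guard
// below is an illustrative addition, not part of the original file; it catches
// accidental field or layout changes at build time.
const _: () = assert!(std::mem::size_of::<EmbeddedBlockHeader>() == 37);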

// ─── Candle-based real embedding provider ────────────
#[cfg(feature = "embeddings")]
pub struct CandleEmbeddingProvider {
    model: candle_transformers::models::bert::BertModel,
    tokenizer: tokenizers::Tokenizer,
    dim: usize,
    device: candle_core::Device,
}

#[cfg(feature = "embeddings")]
impl CandleEmbeddingProvider {
    pub fn new(model_id: &str) -> Result<Self, EmbeddingError> {
        use candle_core::Device;
        use hf_hub::api::sync::Api;

        let device = Device::Cpu;
        let api = Api::new().map_err(|e| EmbeddingError::ApiError(e.to_string()))?;
        let repo = api.model(model_id.to_string());

        // Load tokenizer
        let tokenizer_path = repo
            .get("tokenizer.json")
            .map_err(|e| EmbeddingError::ApiError(format!("tokenizer download: {}", e)))?;
        let tokenizer = tokenizers::Tokenizer::from_file(tokenizer_path)
            .map_err(|e| EmbeddingError::ApiError(format!("tokenizer load: {}", e)))?;

        // Load model weights
        let weights_path = repo
            .get("model.safetensors")
            .map_err(|e| EmbeddingError::ApiError(format!("weights download: {}", e)))?;
        // Safety: safetensors file is valid and will remain mapped for the lifetime of the model
        let vb = unsafe {
            candle_nn::VarBuilder::from_mmaped_safetensors(
                &[weights_path],
                candle_core::DType::F32,
                &device,
            )
        }
        .map_err(|e| EmbeddingError::ApiError(format!("varbuilder: {}", e)))?;

        // Config is hardcoded to BERT-base-uncased defaults (no config.json download)
        let config = candle_transformers::models::bert::Config {
            vocab_size: 30522,
            hidden_size: 768,
            num_hidden_layers: 12,
            num_attention_heads: 12,
            intermediate_size: 3072,
            hidden_act: candle_transformers::models::bert::Activation::Gelu,
            hidden_dropout_prob: 0.1,
            attention_probs_dropout_prob: 0.1,
            max_position_embeddings: 512,
            type_vocab_size: 2,
            initializer_range: 0.02,
            layer_norm_eps: 1e-12,
            pad_token_id: 0,
            model_type: Some("bert".to_string()),
        };
        let dim = config.hidden_size;

        let model = candle_transformers::models::bert::BertModel::load(vb, &config)
            .map_err(|e| EmbeddingError::ApiError(format!("model load: {}", e)))?;

        Ok(Self {
            model,
            tokenizer,
            dim,
            device,
        })
    }

    fn embed_inner(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        use candle_core::Tensor;

        let encoding = self
            .tokenizer
            .encode(text, true)
            .map_err(|e| EmbeddingError::ApiError(format!("tokenize: {}", e)))?;

        let ids = encoding.get_ids();
        let type_ids = encoding.get_type_ids();
        let len = ids.len();

        let input_ids = Tensor::new(ids, &self.device)
            .map_err(|e| EmbeddingError::ApiError(e.to_string()))?
            .reshape((1, len))
            .map_err(|e| EmbeddingError::ApiError(e.to_string()))?;
        let token_type_ids = Tensor::new(type_ids, &self.device)
            .map_err(|e| EmbeddingError::ApiError(e.to_string()))?
            .reshape((1, len))
            .map_err(|e| EmbeddingError::ApiError(e.to_string()))?;

        let output = self
            .model
            .forward(&input_ids, &token_type_ids, None)
            .map_err(|e| EmbeddingError::ApiError(format!("forward: {}", e)))?;

        // Mean pooling over sequence dimension
        let pooled = output
            .mean(1)
            .map_err(|e| EmbeddingError::ApiError(format!("mean pool: {}", e)))?
            .squeeze(0)
            .map_err(|e| EmbeddingError::ApiError(format!("squeeze: {}", e)))?;

        let mut embedding: Vec<f32> = pooled
            .to_vec1()
            .map_err(|e| EmbeddingError::ApiError(format!("to_vec: {}", e)))?;

        // L2 normalize
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        if norm > 0.0 {
            for val in &mut embedding {
                *val /= norm;
            }
        }

        Ok(embedding)
    }
}

#[cfg(feature = "embeddings")]
impl EmbeddingProvider for CandleEmbeddingProvider {
    fn embed(&self, text: &str) -> Result<Vec<f32>, EmbeddingError> {
        self.embed_inner(text)
    }

    fn embed_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, EmbeddingError> {
        texts.iter().map(|t| self.embed_inner(t)).collect()
    }

    fn dimension(&self) -> usize {
        self.dim
    }
}
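
// Usage sketch for the Candle-backed provider (an illustrative addition, not part of
// the original file). The model id is an assumption: any BERT-base checkpoint whose
// layout matches the hardcoded config above should work.
#[cfg(feature = "embeddings")]
fn _candle_usage_example() -> Result<(), EmbeddingError> {
    let provider = CandleEmbeddingProvider::new("bert-base-uncased")?;
    let query = provider.embed("semantic search query")?;
    assert_eq!(query.len(), provider.dimension());
    Ok(())
}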

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity_scalar(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity_scalar(&a, &c) - 0.0).abs() < 0.001);

        let d = vec![-1.0, 0.0, 0.0];
        assert!((cosine_similarity_scalar(&a, &d) - (-1.0)).abs() < 0.001);
    }
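
    // Added check (not in the original test suite): the SIMD path and the scalar
    // fallback should agree to within float tolerance. The length 37 is chosen so
    // both the 8-wide chunks and the scalar remainder are exercised.
    #[test]
    fn test_simd_matches_scalar() {
        let a: Vec<f32> = (0..37).map(|i| (i as f32 * 0.37).sin()).collect();
        let b: Vec<f32> = (0..37).map(|i| (i as f32 * 0.19).cos()).collect();
        let simd = cosine_similarity_simd(&a, &b);
        let scalar = cosine_similarity_scalar(&a, &b);
        assert!((simd - scalar).abs() < 1e-4);
    }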

    #[test]
    fn test_mock_embeddings() {
        let provider = MockEmbeddingProvider::new(128);
        let embedding = provider.embed("test text").unwrap();

        assert_eq!(embedding.len(), 128);

        // Check normalization
        let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
        assert!((norm - 1.0).abs() < 0.001);
    }

    #[test]
    fn test_embedding_cache() {
        let mut cache = EmbeddingCache::new(3);
        let embedding = vec![1.0, 0.0, 0.0];

        cache.insert("test".to_string(), embedding.clone());
        assert!(cache.contains("test"));
        assert_eq!(cache.get("test"), Some(&embedding));
    }
}