Skip to main content

microscope_memory/
embedding_index.rs

//! Embedding index: mmap-backed pre-computed embedding vectors.
//!
//! Format: [u32 block_count][u32 dim][u32 max_depth][f32 × dim × block_count]
//! All fields are little-endian. Every block has a vector slot; only blocks at
//! depth 0..=max_depth are embedded, the remaining slots hold all-zero vectors.
5
6use std::fs;
7use std::path::Path;
8
9use rayon::prelude::*;
10
11use crate::embeddings::{cosine_similarity_simd, EmbeddingProvider};
12
13/// Mmap-backed embedding index for fast semantic lookup.
14#[allow(dead_code)]
15pub struct EmbeddingIndex {
16    data: memmap2::Mmap,
17    block_count: u32,
18    dim: u32,
19    max_depth: u32,
20}
21
22const HEADER_SIZE: usize = 12; // 3 × u32
23
24impl EmbeddingIndex {
25    /// Open an existing embeddings.bin file.
26    pub fn open(path: &Path) -> Option<Self> {
27        if !path.exists() {
28            return None;
29        }
30        let file = fs::File::open(path).ok()?;
31        let data = unsafe { memmap2::Mmap::map(&file).ok()? };
32        if data.len() < HEADER_SIZE {
33            return None;
34        }
35
36        let block_count = u32::from_le_bytes(data[0..4].try_into().unwrap());
37        let dim = u32::from_le_bytes(data[4..8].try_into().unwrap());
38        let max_depth = u32::from_le_bytes(data[8..12].try_into().unwrap());
39
40        let expected = HEADER_SIZE + block_count as usize * dim as usize * 4;
41        if data.len() < expected {
42            return None;
43        }
44
45        Some(EmbeddingIndex {
46            data,
47            block_count,
48            dim,
49            max_depth,
50        })
51    }
52
53    /// Get embedding for block at index (zero-copy mmap access).
54    pub fn embedding(&self, block_idx: usize) -> Option<&[f32]> {
55        if block_idx >= self.block_count as usize {
56            return None;
57        }
58        let offset = HEADER_SIZE + block_idx * self.dim as usize * 4;
59        let end = offset + self.dim as usize * 4;
60        if end > self.data.len() {
61            return None;
62        }
63        // Safety: data is aligned to f32 by construction during build
64        let ptr = self.data[offset..end].as_ptr() as *const f32;
65        Some(unsafe { std::slice::from_raw_parts(ptr, self.dim as usize) })
66    }
67
68    /// Number of embedded blocks.
69    pub fn block_count(&self) -> usize {
70        self.block_count as usize
71    }
72
73    /// Embedding dimension.
74    pub fn dim(&self) -> usize {
75        self.dim as usize
76    }
77
78    /// Max depth that was embedded.
79    #[allow(dead_code)]
80    pub fn max_depth(&self) -> u8 {
81        self.max_depth as u8
82    }
83
84    /// Search for top-k most similar blocks to query embedding.
85    /// Returns Vec<(similarity, block_index)> sorted descending.
86    pub fn search(&self, query_emb: &[f32], k: usize) -> Vec<(f32, usize)> {
87        if query_emb.len() != self.dim as usize {
88            return vec![];
89        }
90
91        let mut results: Vec<(f32, usize)> = (0..self.block_count as usize)
92            .into_par_iter()
93            .filter_map(|i| {
94                let emb = self.embedding(i)?;
95                // Check for zero embedding (unembedded block placeholder)
96                let is_zero = emb.iter().all(|&v| v == 0.0);
97                if is_zero {
98                    return None;
99                }
100                let sim = cosine_similarity_simd(query_emb, emb);
101                if sim > 0.3 {
102                    Some((sim, i))
103                } else {
104                    None
105                }
106            })
107            .collect();
108
109        results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
110        results.truncate(k);
111        results
112    }
113}
114
115/// Build embedding index file from a provider and reader.
116/// Only embeds blocks at depth 0..=max_depth.
117pub fn build_embedding_index(
118    provider: &dyn EmbeddingProvider,
119    reader: &crate::MicroscopeReader,
120    max_depth: u8,
121    output_path: &Path,
122) -> Result<(), String> {
123    let dim = provider.dimension();
124
125    // Count blocks to embed (depth 0..=max_depth)
126    let mut embed_count = 0usize;
127    for d in 0..=max_depth as usize {
128        if d < reader.depth_ranges.len() {
129            embed_count += reader.depth_ranges[d].1 as usize;
130        }
131    }
132
133    println!(
134        "  Embedding {} blocks (D0-D{}, dim={})...",
135        embed_count, max_depth, dim
136    );
137
138    // Build embeddings buffer: header + flat f32 vectors
139    // Blocks outside max_depth get zero vectors
140    let total_blocks = reader.block_count;
141    let mut buf = Vec::with_capacity(HEADER_SIZE + total_blocks * dim * 4);
142
143    // Header
144    buf.extend_from_slice(&(total_blocks as u32).to_le_bytes());
145    buf.extend_from_slice(&(dim as u32).to_le_bytes());
146    buf.extend_from_slice(&(max_depth as u32).to_le_bytes());
147
148    // Embed blocks
149    let zero_vec = vec![0.0f32; dim];
150    let mut embedded = 0usize;
151
152    for i in 0..total_blocks {
153        let h = reader.header(i);
154        if h.depth <= max_depth {
155            let text = reader.text(i);
156            match provider.embed(text) {
157                Ok(emb) => {
158                    for &v in &emb {
159                        buf.extend_from_slice(&v.to_le_bytes());
160                    }
161                    embedded += 1;
162                    if embedded.is_multiple_of(1000) {
163                        eprint!("\r  Embedded {}/{}", embedded, embed_count);
164                    }
165                }
166                Err(_) => {
167                    for &v in &zero_vec {
168                        buf.extend_from_slice(&v.to_le_bytes());
169                    }
170                }
171            }
172        } else {
173            for &v in &zero_vec {
174                buf.extend_from_slice(&v.to_le_bytes());
175            }
176        }
177    }
178    eprintln!("\r  Embedded {}/{}", embedded, embed_count);
179
180    fs::write(output_path, &buf).map_err(|e| format!("write embeddings.bin: {}", e))?;
181    let size_kb = buf.len() as f64 / 1024.0;
182    println!(
183        "  embeddings.bin: {:.1} KB ({} blocks, {} dim)",
184        size_kb, total_blocks, dim
185    );
186
187    Ok(())
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Writes a hand-built three-block index, reopens it, and checks the
    /// header accessors, raw vector access, and nearest-neighbour search.
    #[test]
    fn test_embedding_index_roundtrip() {
        let dir = std::env::temp_dir().join("mscope_emb_test");
        let _ = fs::create_dir_all(&dir);
        let path = dir.join("embeddings.bin");

        // Header fields in file order: block_count, dim, max_depth.
        let header: [u32; 3] = [3, 4, 2];
        // Two unit vectors followed by an all-zero "not embedded" placeholder.
        let vectors: [[f32; 4]; 3] = [
            [1.0, 0.0, 0.0, 0.0],
            [0.0, 1.0, 0.0, 0.0],
            [0.0, 0.0, 0.0, 0.0],
        ];

        let mut bytes: Vec<u8> = Vec::new();
        for field in header.iter() {
            bytes.extend_from_slice(&field.to_le_bytes());
        }
        for component in vectors.iter().flatten() {
            bytes.extend_from_slice(&component.to_le_bytes());
        }
        fs::write(&path, &bytes).unwrap();

        let idx = EmbeddingIndex::open(&path).unwrap();
        assert_eq!(idx.block_count(), 3);
        assert_eq!(idx.dim(), 4);
        assert_eq!(idx.max_depth(), 2);

        let first = idx.embedding(0).unwrap();
        assert_eq!(first, &[1.0, 0.0, 0.0, 0.0]);

        // Querying with block 0's own vector must rank block 0 first.
        let hits = idx.search(&[1.0, 0.0, 0.0, 0.0], 2);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].1, 0);

        let _ = fs::remove_dir_all(&dir);
    }
}