microscope_memory/
embedding_index.rs1use std::fs;
7use std::path::Path;
8
9use rayon::prelude::*;
10
11use crate::embeddings::{cosine_similarity_simd, EmbeddingProvider};
12
13#[allow(dead_code)]
15pub struct EmbeddingIndex {
16 data: memmap2::Mmap,
17 block_count: u32,
18 dim: u32,
19 max_depth: u32,
20}
21
22const HEADER_SIZE: usize = 12; impl EmbeddingIndex {
25 pub fn open(path: &Path) -> Option<Self> {
27 if !path.exists() {
28 return None;
29 }
30 let file = fs::File::open(path).ok()?;
31 let data = unsafe { memmap2::Mmap::map(&file).ok()? };
32 if data.len() < HEADER_SIZE {
33 return None;
34 }
35
36 let block_count = u32::from_le_bytes(data[0..4].try_into().unwrap());
37 let dim = u32::from_le_bytes(data[4..8].try_into().unwrap());
38 let max_depth = u32::from_le_bytes(data[8..12].try_into().unwrap());
39
40 let expected = HEADER_SIZE + block_count as usize * dim as usize * 4;
41 if data.len() < expected {
42 return None;
43 }
44
45 Some(EmbeddingIndex {
46 data,
47 block_count,
48 dim,
49 max_depth,
50 })
51 }
52
53 pub fn embedding(&self, block_idx: usize) -> Option<&[f32]> {
55 if block_idx >= self.block_count as usize {
56 return None;
57 }
58 let offset = HEADER_SIZE + block_idx * self.dim as usize * 4;
59 let end = offset + self.dim as usize * 4;
60 if end > self.data.len() {
61 return None;
62 }
63 let ptr = self.data[offset..end].as_ptr() as *const f32;
65 Some(unsafe { std::slice::from_raw_parts(ptr, self.dim as usize) })
66 }
67
68 pub fn block_count(&self) -> usize {
70 self.block_count as usize
71 }
72
73 pub fn dim(&self) -> usize {
75 self.dim as usize
76 }
77
78 #[allow(dead_code)]
80 pub fn max_depth(&self) -> u8 {
81 self.max_depth as u8
82 }
83
84 pub fn search(&self, query_emb: &[f32], k: usize) -> Vec<(f32, usize)> {
87 if query_emb.len() != self.dim as usize {
88 return vec![];
89 }
90
91 let mut results: Vec<(f32, usize)> = (0..self.block_count as usize)
92 .into_par_iter()
93 .filter_map(|i| {
94 let emb = self.embedding(i)?;
95 let is_zero = emb.iter().all(|&v| v == 0.0);
97 if is_zero {
98 return None;
99 }
100 let sim = cosine_similarity_simd(query_emb, emb);
101 if sim > 0.3 {
102 Some((sim, i))
103 } else {
104 None
105 }
106 })
107 .collect();
108
109 results.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap());
110 results.truncate(k);
111 results
112 }
113}
114
115pub fn build_embedding_index(
118 provider: &dyn EmbeddingProvider,
119 reader: &crate::MicroscopeReader,
120 max_depth: u8,
121 output_path: &Path,
122) -> Result<(), String> {
123 let dim = provider.dimension();
124
125 let mut embed_count = 0usize;
127 for d in 0..=max_depth as usize {
128 if d < reader.depth_ranges.len() {
129 embed_count += reader.depth_ranges[d].1 as usize;
130 }
131 }
132
133 println!(
134 " Embedding {} blocks (D0-D{}, dim={})...",
135 embed_count, max_depth, dim
136 );
137
138 let total_blocks = reader.block_count;
141 let mut buf = Vec::with_capacity(HEADER_SIZE + total_blocks * dim * 4);
142
143 buf.extend_from_slice(&(total_blocks as u32).to_le_bytes());
145 buf.extend_from_slice(&(dim as u32).to_le_bytes());
146 buf.extend_from_slice(&(max_depth as u32).to_le_bytes());
147
148 let zero_vec = vec![0.0f32; dim];
150 let mut embedded = 0usize;
151
152 for i in 0..total_blocks {
153 let h = reader.header(i);
154 if h.depth <= max_depth {
155 let text = reader.text(i);
156 match provider.embed(text) {
157 Ok(emb) => {
158 for &v in &emb {
159 buf.extend_from_slice(&v.to_le_bytes());
160 }
161 embedded += 1;
162 if embedded.is_multiple_of(1000) {
163 eprint!("\r Embedded {}/{}", embedded, embed_count);
164 }
165 }
166 Err(_) => {
167 for &v in &zero_vec {
168 buf.extend_from_slice(&v.to_le_bytes());
169 }
170 }
171 }
172 } else {
173 for &v in &zero_vec {
174 buf.extend_from_slice(&v.to_le_bytes());
175 }
176 }
177 }
178 eprintln!("\r Embedded {}/{}", embedded, embed_count);
179
180 fs::write(output_path, &buf).map_err(|e| format!("write embeddings.bin: {}", e))?;
181 let size_kb = buf.len() as f64 / 1024.0;
182 println!(
183 " embeddings.bin: {:.1} KB ({} blocks, {} dim)",
184 size_kb, total_blocks, dim
185 );
186
187 Ok(())
188}
189
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes an `embeddings.bin` image: the 12-byte header followed by
    /// `values`, all little-endian.
    fn encode(block_count: u32, dim: u32, max_depth: u32, values: &[f32]) -> Vec<u8> {
        let mut buf = Vec::with_capacity(HEADER_SIZE + values.len() * 4);
        for field in [block_count, dim, max_depth] {
            buf.extend_from_slice(&field.to_le_bytes());
        }
        for v in values {
            buf.extend_from_slice(&v.to_le_bytes());
        }
        buf
    }

    /// Creates a scratch directory unique to this process and test. This keeps
    /// parallel test threads and concurrent `cargo test` runs from sharing files.
    fn scratch_dir(test: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "mscope_emb_test_{}_{}",
            std::process::id(),
            test
        ));
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[test]
    fn test_embedding_index_roundtrip() {
        let dir = scratch_dir("roundtrip");
        let path = dir.join("embeddings.bin");
        let rows = [
            1.0f32, 0.0, 0.0, 0.0, // block 0
            0.0, 1.0, 0.0, 0.0, // block 1
            0.0, 0.0, 0.0, 0.0, // block 2: zero row = not embedded
        ];
        fs::write(&path, encode(3, 4, 2, &rows)).unwrap();

        {
            let idx = EmbeddingIndex::open(&path).unwrap();
            assert_eq!(idx.block_count(), 3);
            assert_eq!(idx.dim(), 4);
            assert_eq!(idx.max_depth(), 2);

            assert_eq!(idx.embedding(0).unwrap(), &[1.0, 0.0, 0.0, 0.0]);
            assert_eq!(idx.embedding(2).unwrap(), &[0.0, 0.0, 0.0, 0.0]);
            assert!(idx.embedding(3).is_none());

            let results = idx.search(&[1.0, 0.0, 0.0, 0.0], 2);
            // Block 1 is orthogonal (below the threshold); block 2 is all zeros.
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].1, 0);

            // A query of the wrong dimension matches nothing.
            assert!(idx.search(&[1.0, 0.0], 2).is_empty());
        } // Drop (unmap) the index before deleting: some platforms, e.g. Windows,
          // refuse to remove files that are still mapped.

        let _ = fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_open_rejects_missing_or_truncated_files() {
        let dir = scratch_dir("truncated");

        assert!(EmbeddingIndex::open(&dir.join("does_not_exist.bin")).is_none());

        let short_header = dir.join("short_header.bin");
        fs::write(&short_header, [0u8; HEADER_SIZE - 1]).unwrap();
        assert!(EmbeddingIndex::open(&short_header).is_none());

        // Header promises two 4-dim rows, but only one is present.
        let short_body = dir.join("short_body.bin");
        fs::write(&short_body, encode(2, 4, 0, &[1.0, 0.0, 0.0, 0.0])).unwrap();
        assert!(EmbeddingIndex::open(&short_body).is_none());

        let _ = fs::remove_dir_all(&dir);
    }
}