1use std::io::{BufReader, BufWriter, Read, Write};
2use std::path::{Path, PathBuf};
3use std::sync::{Arc, Mutex, OnceLock};
4
5use anyhow::{Context, Result};
6
7use crate::model::Symbol;
8
/// One loaded embeddings file, remembered so repeated loads of an unchanged
/// file can skip disk I/O.
struct CachedEmbeddings {
    /// Canonicalised path of the file the data came from.
    path: PathBuf,
    /// File mtime observed when the data was loaded.
    modified: std::time::SystemTime,
    /// `(symbol id, vector)` pairs as returned by `load_embeddings`.
    data: Vec<(String, Vec<f32>)>,
}

/// Process-wide single-entry cache used by `load_embeddings_cached`.
static EMBEDDINGS_CACHE: OnceLock<Mutex<Option<CachedEmbeddings>>> = OnceLock::new();

/// Returns the embeddings cache mutex, creating it empty on first access.
fn cache_lock() -> &'static Mutex<Option<CachedEmbeddings>> {
    EMBEDDINGS_CACHE.get_or_init(Default::default)
}
20
21pub fn load_embeddings_cached(path: &Path) -> Result<Vec<(String, Vec<f32>)>> {
25 let meta = std::fs::metadata(path).context("stat embeddings file")?;
26 let mtime = meta.modified().unwrap_or(std::time::UNIX_EPOCH);
27 let canon = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());
28
29 let guard = cache_lock().lock().unwrap();
30 if let Some(cached) = guard.as_ref() {
31 if cached.path == canon && cached.modified == mtime {
32 return Ok(cached.data.clone());
33 }
34 }
35 drop(guard);
36
37 let data = load_embeddings(path)?;
38 let mut guard = cache_lock().lock().unwrap();
39 *guard = Some(CachedEmbeddings {
40 path: canon,
41 modified: mtime,
42 data: data.clone(),
43 });
44 Ok(data)
45}
46
47pub fn invalidate_embeddings_cache() {
49 if let Ok(mut guard) = cache_lock().lock() {
50 *guard = None;
51 }
52}
53
54pub trait EmbedProvider: Send + Sync {
56 fn dimension(&self) -> usize;
57 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;
58
59 fn embed(&self, text: &str) -> Result<Vec<f32>> {
60 let mut results = self.embed_batch(&[text])?;
61 results
62 .pop()
63 .ok_or_else(|| anyhow::anyhow!("embedding returned no results"))
64 }
65}
66
67pub fn symbol_text(sym: &Symbol) -> String {
69 let mut text = format!("{} {} {}", sym.kind.as_str(), sym.name, sym.language);
70 if let Some(doc) = &sym.docstring {
71 if !doc.is_empty() {
72 text.push_str(": ");
73 text.push_str(doc);
74 }
75 }
76 text
77}
78
79pub fn rich_symbol_text(kind: &str, name: &str, file: &str, language: &str, doc: &str) -> String {
81 rich_symbol_text_full(kind, name, file, language, doc, "", "")
82}
83
84pub fn rich_symbol_text_full(
86 kind: &str,
87 name: &str,
88 file: &str,
89 language: &str,
90 doc: &str,
91 params: &str,
92 ret: &str,
93) -> String {
94 let path_context = path_to_context(file);
95 let mut text = format!("{kind} {name}");
96 if !params.is_empty() {
97 text.push_str(params);
98 }
99 if !ret.is_empty() {
100 text.push_str(" -> ");
101 text.push_str(ret);
102 }
103 text.push_str(" in ");
104 text.push_str(&path_context);
105 if !language.is_empty() {
106 text.push(' ');
107 text.push_str(language);
108 }
109 if !doc.is_empty() {
110 text.push_str(": ");
111 text.push_str(doc);
112 }
113 text
114}
115
/// Shortens a `/`-separated file path into a compact context string for embedding text.
///
/// - Paths with at most three components are returned unchanged.
/// - Otherwise, boilerplate directory names are dropped (case-insensitively):
///   `src`, `source`, `lib`, `include`, `_h`, `test`, `tests`, `benchmark`.
/// - At most the last four remaining components are kept.
///
/// The final component (the file name) is always kept, even if it happens to
/// match one of the boilerplate names. Previously such a file name was dropped
/// when few directories remained, but kept when many did.
pub fn path_to_context(file: &str) -> String {
    let parts: Vec<&str> = file.split('/').collect();
    if parts.len() <= 3 {
        return file.to_string();
    }
    let Some((filename, dirs)) = parts.split_last() else {
        return file.to_string();
    };

    let mut meaningful: Vec<&str> = dirs
        .iter()
        .copied()
        .filter(|dir| {
            !matches!(
                dir.to_lowercase().as_str(),
                "src" | "source" | "lib" | "include" | "_h" | "test" | "tests" | "benchmark"
            )
        })
        .collect();
    meaningful.push(filename);

    let start = meaningful.len().saturating_sub(4);
    meaningful[start..].join("/")
}
145
/// Similarity score used for vector ranking, computed as the plain dot product
/// of `a` and `b`.
///
/// This equals cosine similarity only when both inputs are unit length.
/// `trigram_embed` normalises its output; other providers are assumed to do
/// the same (NOTE(review): confirm for Model2Vec). If the lengths differ, the
/// extra trailing elements of the longer slice are ignored.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let pairs = a.iter().zip(b);
    pairs.map(|(&x, &y)| x * y).sum::<f32>()
}
150
/// Dependency-free fallback embedder: hashes character trigrams and word tokens
/// into a fixed number of buckets.
pub struct TrigramEmbedder {
    /// Number of hash buckets, which is also the length of every produced vector.
    dim: usize,
}

impl TrigramEmbedder {
    /// Creates an embedder that produces `dim`-length vectors.
    pub fn new(dim: usize) -> Self {
        TrigramEmbedder { dim }
    }
}

impl Default for TrigramEmbedder {
    /// Uses 256 buckets.
    fn default() -> Self {
        const DEFAULT_DIM: usize = 256;
        TrigramEmbedder::new(DEFAULT_DIM)
    }
}
172
173impl EmbedProvider for TrigramEmbedder {
174 fn dimension(&self) -> usize {
175 self.dim
176 }
177
178 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
179 Ok(texts.iter().map(|t| trigram_embed(t, self.dim)).collect())
180 }
181}
182
183fn trigram_embed(text: &str, dim: usize) -> Vec<f32> {
185 let mut vec = vec![0.0f32; dim];
186 let lower = text.to_lowercase();
187 let chars: Vec<char> = lower.chars().collect();
188
189 if chars.len() < 3 {
190 for c in &chars {
192 let h = fnv1a(&[*c as u8]) as usize % dim;
193 vec[h] += 1.0;
194 }
195 if chars.len() == 2 {
196 let bigram = format!("{}{}", chars[0], chars[1]);
197 let h = fnv1a(bigram.as_bytes()) as usize % dim;
198 vec[h] += 1.0;
199 }
200 } else {
201 for window in chars.windows(3) {
202 let trigram: String = window.iter().collect();
203 let h = fnv1a(trigram.as_bytes()) as usize % dim;
204 vec[h] += 1.0;
205 }
206 }
207
208 for token in lower.split(|c: char| !c.is_alphanumeric() && c != '_') {
210 if token.len() > 1 {
211 let h = fnv1a(token.as_bytes()) as usize % dim;
212 vec[h] += 0.5; }
214 }
215
216 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
218 if norm > 0.0 {
219 for v in &mut vec {
220 *v /= norm;
221 }
222 }
223
224 vec
225}
226
/// Embedding provider backed by a `model2vec_rs` static model, which
/// `Model2VecEmbedder::new` loads from a `potion-base-8M` model directory.
pub struct Model2VecEmbedder {
    // Loaded model; `embed_batch` calls its `encode`.
    model: model2vec_rs::model::StaticModel,
}
235
236impl Model2VecEmbedder {
237 pub fn new() -> Result<Self> {
239 let model_dir = Self::find_model_dir()?;
241 let model = model2vec_rs::model::StaticModel::from_pretrained(model_dir, None, None, None)?;
242 Ok(Self { model })
243 }
244
245 fn find_model_dir() -> Result<std::path::PathBuf> {
246 if let Ok(p) = std::env::var("INFIGRAPH_MODEL_DIR") {
248 let pb = std::path::PathBuf::from(p);
249 if pb.exists() {
250 return Ok(pb);
251 }
252 }
253 if let Some(home) = dirs_next::home_dir() {
255 let installed = home
256 .join(".infigraph")
257 .join("models")
258 .join("potion-base-8M");
259 if installed.join("model.safetensors").exists() {
260 return Ok(installed);
261 }
262 }
263 let start =
265 std::env::current_exe().unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
266 let mut dir = start.as_path();
267 loop {
268 let candidate = dir.join("models/potion-base-8M");
269 if candidate.join("model.safetensors").exists() {
270 return Ok(candidate);
271 }
272 match dir.parent() {
273 Some(p) => dir = p,
274 None => break,
275 }
276 }
277 let cwd = std::env::current_dir()?;
279 let mut dir = cwd.as_path();
280 loop {
281 let candidate = dir.join("models/potion-base-8M");
282 if candidate.join("model.safetensors").exists() {
283 return Ok(candidate);
284 }
285 match dir.parent() {
286 Some(p) => dir = p,
287 None => break,
288 }
289 }
290 anyhow::bail!(
291 "models/potion-base-8M not found; set INFIGRAPH_MODEL_DIR or run from repo root"
292 )
293 }
294}
295
296impl EmbedProvider for Model2VecEmbedder {
297 fn dimension(&self) -> usize {
298 256 }
300
301 fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
302 let owned: Vec<String> = texts.iter().map(|s| s.to_string()).collect();
303 Ok(self.model.encode(&owned))
304 }
305}
306
307static CODE_EMBEDDER: OnceLock<Arc<dyn EmbedProvider>> = OnceLock::new();
308static DOC_EMBEDDER: OnceLock<Arc<dyn EmbedProvider>> = OnceLock::new();
309
310pub fn init_embedder() -> Arc<dyn EmbedProvider> {
312 match Model2VecEmbedder::new() {
313 Ok(m) => Arc::new(m),
314 Err(e) => {
315 eprintln!("warning: Model2Vec unavailable ({e}), using trigram fallback");
316 Arc::new(TrigramEmbedder::default())
317 }
318 }
319}
320
321pub fn code_embedder() -> Arc<dyn EmbedProvider> {
323 Arc::clone(CODE_EMBEDDER.get_or_init(init_embedder))
324}
325
326pub fn doc_embedder() -> Arc<dyn EmbedProvider> {
328 Arc::clone(DOC_EMBEDDER.get_or_init(init_embedder))
329}
330
331pub fn best_embedder() -> Box<dyn EmbedProvider> {
333 match Model2VecEmbedder::new() {
334 Ok(m) => Box::new(m),
335 Err(e) => {
336 eprintln!("warning: Model2Vec unavailable ({e}), using trigram fallback");
337 Box::new(TrigramEmbedder::default())
338 }
339 }
340}
341
/// Returns the entry count stored in the header of `<root>/.infigraph/embeddings.bin`.
///
/// Only the leading little-endian `u32` is read. If the file is missing or
/// shorter than four bytes, the result is 0.
pub fn embedding_count(root: &Path) -> usize {
    let bin = root.join(".infigraph").join("embeddings.bin");
    let mut header = [0u8; 4];
    let header_ok = std::fs::File::open(bin)
        .map(BufReader::new)
        .and_then(|mut reader| reader.read_exact(&mut header))
        .is_ok();
    if header_ok {
        u32::from_le_bytes(header) as usize
    } else {
        0
    }
}
355
356pub fn save_embeddings(path: &Path, embeddings: &[(String, Vec<f32>)]) -> Result<()> {
359 let file = std::fs::File::create(path).context("create embeddings file")?;
360 let mut w = BufWriter::new(file);
361 w.write_all(&(embeddings.len() as u32).to_le_bytes())?;
362 for (id, vec) in embeddings {
363 let id_bytes = id.as_bytes();
364 w.write_all(&(id_bytes.len() as u32).to_le_bytes())?;
365 w.write_all(id_bytes)?;
366 w.write_all(&(vec.len() as u32).to_le_bytes())?;
367 for &v in vec {
368 w.write_all(&v.to_le_bytes())?;
369 }
370 }
371 drop(w);
372 invalidate_embeddings_cache();
373 Ok(())
374}
375
376pub fn load_embeddings(path: &Path) -> Result<Vec<(String, Vec<f32>)>> {
378 let file = std::fs::File::open(path).context("open embeddings file")?;
379 let mut r = BufReader::new(file);
380 let mut buf4 = [0u8; 4];
381 r.read_exact(&mut buf4)?;
382 let count = u32::from_le_bytes(buf4) as usize;
383 let mut result = Vec::with_capacity(count);
384 for _ in 0..count {
385 r.read_exact(&mut buf4)?;
386 let id_len = u32::from_le_bytes(buf4) as usize;
387 let mut id_buf = vec![0u8; id_len];
388 r.read_exact(&mut id_buf)?;
389 let id = String::from_utf8(id_buf).context("invalid utf8 in embedding id")?;
390 r.read_exact(&mut buf4)?;
391 let dim = u32::from_le_bytes(buf4) as usize;
392 let mut vec = Vec::with_capacity(dim);
393 for _ in 0..dim {
394 r.read_exact(&mut buf4)?;
395 vec.push(f32::from_le_bytes(buf4));
396 }
397 result.push((id, vec));
398 }
399 Ok(result)
400}
401
402pub fn update_embeddings(
407 store: &crate::graph::GraphStore,
408 root: &Path,
409 changed_files: &[&str],
410) -> Result<usize> {
411 use rayon::prelude::*;
412 use std::sync::Arc;
413
414 let conn = store.connection()?;
415 let gq = crate::graph::GraphQuery::new(&conn);
416 let rows = gq.raw_query("MATCH (s:Symbol) RETURN s.id, s.name, s.kind, s.file, s.docstring, s.language, s.parameters, s.return_type")?;
417 if rows.is_empty() {
418 return Ok(0);
419 }
420
421 let emb_path = root.join(".infigraph").join("embeddings.bin");
422 let mut existing: std::collections::HashMap<String, Vec<f32>> = load_embeddings(&emb_path)
423 .unwrap_or_default()
424 .into_iter()
425 .collect();
426
427 let changed_set: std::collections::HashSet<&str> = changed_files.iter().copied().collect();
428
429 let to_embed: Vec<(String, String)> = rows
430 .iter()
431 .filter_map(|row| {
432 let id = &row[0];
433 let file = row.get(3).map(|s| s.as_str()).unwrap_or("");
434 if !changed_set.is_empty() && !changed_set.contains(file) && existing.contains_key(id) {
435 return None;
436 }
437 let name = &row[1];
438 let kind = &row[2];
439 let doc = row.get(4).map(|s| s.as_str()).unwrap_or("");
440 let lang = row.get(5).map(|s| s.as_str()).unwrap_or("");
441 let params = row.get(6).map(|s| s.as_str()).unwrap_or("");
442 let ret = row.get(7).map(|s| s.as_str()).unwrap_or("");
443 let text = rich_symbol_text_full(kind, name, file, lang, doc, params, ret);
444 Some((id.clone(), text))
445 })
446 .collect();
447
448 if !to_embed.is_empty() {
449 let embedder: Arc<Box<dyn EmbedProvider>> = Arc::new(best_embedder());
450 const BATCH: usize = 256;
451 let results: Vec<Vec<(String, Vec<f32>)>> = to_embed
452 .par_chunks(BATCH)
453 .map(|chunk| {
454 let emb = Arc::clone(&embedder);
455 let texts: Vec<&str> = chunk.iter().map(|(_, t)| t.as_str()).collect();
456 let vecs = emb.embed_batch(&texts).unwrap_or_default();
457 chunk
458 .iter()
459 .enumerate()
460 .filter_map(|(i, (id, _))| vecs.get(i).map(|v| (id.clone(), v.clone())))
461 .collect()
462 })
463 .collect();
464 for batch in results {
465 for (id, v) in batch {
466 existing.insert(id, v);
467 }
468 }
469 }
470
471 let all_ids: std::collections::HashSet<String> = rows.iter().map(|r| r[0].clone()).collect();
472 existing.retain(|id, _| all_ids.contains(id));
473
474 let symbol_embeddings: Vec<(String, Vec<f32>)> = existing.into_iter().collect();
475 let count = symbol_embeddings.len();
476 save_embeddings(&emb_path, &symbol_embeddings)?;
477
478 const HNSW_THRESHOLD: usize = 200_000;
482 let hnsw_path = root.join(".infigraph").join("hnsw_index.usearch");
483 let should_build = count >= HNSW_THRESHOLD || hnsw_path.exists();
484 if should_build {
485 invalidate_hnsw_cache();
486 if let Err(e) = build_hnsw_index(&symbol_embeddings, &hnsw_path, &emb_path) {
487 eprintln!("warning: HNSW index build failed ({e}), vector search will use brute-force");
488 }
489 }
490
491 Ok(count)
492}
493
494use usearch::{Index as UsearchIndex, IndexOptions, MetricKind, ScalarKind};
499
500const HNSW_CONNECTIVITY: usize = 32;
501const HNSW_EXPANSION_ADD: usize = 200;
502const HNSW_EXPANSION_SEARCH: usize = 256;
503const HNSW_OVERSAMPLE: usize = 20;
504
505static HNSW_CACHE: OnceLock<Mutex<Option<CachedHnsw>>> = OnceLock::new();
506
507struct CachedHnsw {
508 path: PathBuf,
509 modified: std::time::SystemTime,
510 index: UsearchIndex,
511 id_map: Vec<String>,
512}
513
514fn hnsw_cache_lock() -> &'static Mutex<Option<CachedHnsw>> {
515 HNSW_CACHE.get_or_init(|| Mutex::new(None))
516}
517
518fn hnsw_opts(dim: usize) -> IndexOptions {
519 IndexOptions {
520 dimensions: dim,
521 metric: MetricKind::IP,
522 quantization: ScalarKind::F32,
523 connectivity: HNSW_CONNECTIVITY,
524 expansion_add: HNSW_EXPANSION_ADD,
525 ..IndexOptions::default()
526 }
527}
528
529pub fn build_hnsw_index(
532 embeddings: &[(String, Vec<f32>)],
533 index_path: &Path,
534 embeddings_path: &Path,
535) -> Result<usize> {
536 if embeddings.is_empty() {
537 return Ok(0);
538 }
539
540 let dim = embeddings[0].1.len();
541 let n = embeddings.len();
542 let threads = std::thread::available_parallelism()
543 .map(|t| t.get())
544 .unwrap_or(4);
545
546 let index =
547 UsearchIndex::new(&hnsw_opts(dim)).map_err(|e| anyhow::anyhow!("usearch create: {e}"))?;
548 index
549 .reserve(n)
550 .map_err(|e| anyhow::anyhow!("usearch reserve: {e}"))?;
551
552 let index = std::sync::Arc::new(index);
553 let chunk_size = n.div_ceil(threads);
554 std::thread::scope(|s| {
555 for (chunk_idx, chunk) in embeddings.chunks(chunk_size).enumerate() {
556 let idx = std::sync::Arc::clone(&index);
557 let offset = chunk_idx * chunk_size;
558 s.spawn(move || {
559 for (i, (_, v)) in chunk.iter().enumerate() {
560 let _ = idx.add((offset + i) as u64, v);
561 }
562 });
563 }
564 });
565
566 let path_str = index_path
567 .to_str()
568 .ok_or_else(|| anyhow::anyhow!("non-utf8 index path"))?;
569 index
570 .save(path_str)
571 .map_err(|e| anyhow::anyhow!("usearch save: {e}"))?;
572
573 let emb_mtime = std::fs::metadata(embeddings_path)
574 .and_then(|m| m.modified())
575 .unwrap_or(std::time::UNIX_EPOCH);
576 let sidecar_path = index_path.with_extension("meta");
577 let ids: Vec<&str> = embeddings.iter().map(|(id, _)| id.as_str()).collect();
578 let sidecar = serde_json::json!({
579 "emb_mtime_secs": emb_mtime.duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_secs(),
580 "count": n,
581 "dim": dim,
582 "ids": ids,
583 });
584 std::fs::write(&sidecar_path, serde_json::to_vec(&sidecar)?).context("write hnsw sidecar")?;
585
586 invalidate_hnsw_cache();
587 Ok(n)
588}
589
590pub fn invalidate_hnsw_cache() {
592 if let Ok(mut guard) = hnsw_cache_lock().lock() {
593 *guard = None;
594 }
595}
596
/// One approximate-nearest-neighbour hit returned by `search_hnsw`.
///
/// Derives `Debug`, `Clone` and `PartialEq` so callers can log, copy and
/// compare results.
#[derive(Debug, Clone, PartialEq)]
pub struct HnswResult {
    /// Symbol id of the matched embedding.
    pub id: String,
    /// `1.0 - distance` as reported by usearch for the inner-product metric.
    pub score: f32,
}
602
603fn query_index(
604 index: &UsearchIndex,
605 id_map: &[String],
606 query: &[f32],
607 top_k: usize,
608) -> Result<Vec<HnswResult>> {
609 let fetch_k = top_k * HNSW_OVERSAMPLE;
610 let results = index
611 .search(query, fetch_k)
612 .map_err(|e| anyhow::anyhow!("usearch search: {e}"))?;
613 let out: Vec<HnswResult> = results
614 .keys
615 .iter()
616 .zip(results.distances.iter())
617 .filter_map(|(&key, &dist)| {
618 let idx = key as usize;
619 id_map.get(idx).map(|id| HnswResult {
620 id: id.clone(),
621 score: 1.0 - dist,
622 })
623 })
624 .collect();
625 Ok(out)
626}
627
628pub fn search_hnsw(
631 index_path: &Path,
632 embeddings_path: &Path,
633 query: &[f32],
634 top_k: usize,
635) -> Result<Option<Vec<HnswResult>>> {
636 let sidecar_path = index_path.with_extension("meta");
637 if !index_path.exists() || !sidecar_path.exists() {
638 return Ok(None);
639 }
640
641 let emb_mtime_secs = std::fs::metadata(embeddings_path)
642 .and_then(|m| m.modified())
643 .unwrap_or(std::time::UNIX_EPOCH)
644 .duration_since(std::time::UNIX_EPOCH)
645 .unwrap_or_default()
646 .as_secs();
647
648 let canon = index_path
649 .canonicalize()
650 .unwrap_or_else(|_| index_path.to_path_buf());
651 let idx_mtime = std::fs::metadata(index_path)
652 .and_then(|m| m.modified())
653 .unwrap_or(std::time::UNIX_EPOCH);
654
655 let guard = hnsw_cache_lock().lock().unwrap();
657 if let Some(cached) = guard.as_ref() {
658 if cached.path == canon && cached.modified == idx_mtime {
659 return Ok(Some(query_index(
660 &cached.index,
661 &cached.id_map,
662 query,
663 top_k,
664 )?));
665 }
666 }
667 drop(guard);
668
669 let sidecar_bytes = std::fs::read(&sidecar_path).context("read hnsw sidecar")?;
671 let sidecar: serde_json::Value =
672 serde_json::from_slice(&sidecar_bytes).context("parse hnsw sidecar")?;
673 let stored_mtime = sidecar["emb_mtime_secs"].as_u64().unwrap_or(0);
674 if stored_mtime != emb_mtime_secs {
675 return Ok(None);
676 }
677
678 let id_map: Vec<String> = sidecar["ids"]
679 .as_array()
680 .map(|arr| {
681 arr.iter()
682 .filter_map(|v| v.as_str().map(String::from))
683 .collect()
684 })
685 .unwrap_or_default();
686
687 let dim = sidecar["dim"].as_u64().unwrap_or(256) as usize;
688 let path_str = index_path
689 .to_str()
690 .ok_or_else(|| anyhow::anyhow!("non-utf8 index path"))?;
691 let index = UsearchIndex::new(&hnsw_opts(dim))
692 .map_err(|e| anyhow::anyhow!("usearch create for load: {e}"))?;
693 index
694 .view(path_str)
695 .map_err(|e| anyhow::anyhow!("usearch view: {e}"))?;
696 index.change_expansion_search(HNSW_EXPANSION_SEARCH);
697
698 let out = query_index(&index, &id_map, query, top_k)?;
699
700 let mut guard = hnsw_cache_lock().lock().unwrap();
701 *guard = Some(CachedHnsw {
702 path: canon,
703 modified: idx_mtime,
704 index,
705 id_map,
706 });
707
708 Ok(Some(out))
709}
710
/// 64-bit FNV-1a hash of `data`, used to pick feature buckets.
fn fnv1a(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf2_9ce4_8422_2325;
    const PRIME: u64 = 0x0000_0100_0000_01b3;
    data.iter()
        .fold(OFFSET_BASIS, |hash, &byte| (hash ^ u64::from(byte)).wrapping_mul(PRIME))
}