use std::io::{BufReader, BufWriter, Read, Write};
use std::path::{Path, PathBuf};
use std::sync::{Arc, Mutex, OnceLock};
use anyhow::{Context, Result};
use crate::model::Symbol;
/// Single-entry cache record for the most recently loaded embeddings file.
///
/// `load_embeddings_cached` reuses a record only when both `path` and
/// `modified` match the file being requested.
struct CachedEmbeddings {
/// Canonicalized path of the file (the path as given if canonicalization failed).
path: PathBuf,
/// Modification time observed just before the file was loaded.
modified: std::time::SystemTime,
/// Decoded `(id, vector)` pairs, in file order.
data: Vec<(String, Vec<f32>)>,
}
/// Process-wide slot holding the most recently loaded embeddings file, if any.
static EMBEDDINGS_CACHE: OnceLock<Mutex<Option<CachedEmbeddings>>> = OnceLock::new();

/// Returns the mutex guarding the embeddings cache, creating it (empty) on first use.
fn cache_lock() -> &'static Mutex<Option<CachedEmbeddings>> {
    // `Mutex<Option<_>>::default()` is a mutex wrapping `None`.
    EMBEDDINGS_CACHE.get_or_init(Mutex::default)
}
/// Loads the embeddings stored at `path`, reusing the in-memory copy when the
/// file's canonical path and modification time match the cached entry.
///
/// The cache holds a single file; loading a different file (or a modified
/// one) replaces the entry. The returned vector is always an owned clone.
///
/// # Errors
/// Fails when the file cannot be stat'ed or when `load_embeddings` fails.
pub fn load_embeddings_cached(path: &Path) -> Result<Vec<(String, Vec<f32>)>> {
    let meta = std::fs::metadata(path).context("stat embeddings file")?;
    let mtime = meta.modified().unwrap_or(std::time::UNIX_EPOCH);
    let canon = path.canonicalize().unwrap_or_else(|_| path.to_path_buf());

    {
        // The slot is only a cache, so a poisoned mutex is safe to recover:
        // at worst the entry is stale and is replaced below.
        let guard = cache_lock()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(cached) = guard.as_ref() {
            if cached.path == canon && cached.modified == mtime {
                return Ok(cached.data.clone());
            }
        }
        // Guard dropped here so the (slow) file read happens without the lock held.
    }

    let data = load_embeddings(path)?;
    let mut guard = cache_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *guard = Some(CachedEmbeddings {
        path: canon,
        modified: mtime,
        data: data.clone(),
    });
    Ok(data)
}
/// Drops any cached embeddings so the next cached load re-reads the file.
///
/// A poisoned mutex is recovered rather than ignored: skipping the reset
/// would leave stale data behind.
pub fn invalidate_embeddings_cache() {
    let mut guard = cache_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *guard = None;
}
/// Something that can turn text into fixed-size `f32` vectors.
///
/// Implementors must be `Send + Sync` so one provider can be shared across threads.
pub trait EmbedProvider: Send + Sync {
    /// Length of the vectors this provider reports producing.
    fn dimension(&self) -> usize;

    /// Embeds every entry of `texts`. The implementations in this file return
    /// one vector per input, in input order.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>>;

    /// Embeds a single text by delegating to [`EmbedProvider::embed_batch`].
    ///
    /// # Errors
    /// Propagates batch failures and reports an error when the batch comes back empty.
    fn embed(&self, text: &str) -> Result<Vec<f32>> {
        // The last vector of the batch is the one returned.
        match self.embed_batch(&[text])?.pop() {
            Some(vector) => Ok(vector),
            None => Err(anyhow::anyhow!("embedding returned no results")),
        }
    }
}
/// Builds the text embedded for a symbol: `"<kind> <name> <language>"`,
/// followed by `": <docstring>"` when a non-empty docstring is present.
pub fn symbol_text(sym: &Symbol) -> String {
    let header = format!("{} {} {}", sym.kind.as_str(), sym.name, sym.language);
    let Some(doc) = sym.docstring.as_ref().filter(|d| !d.is_empty()) else {
        return header;
    };
    let mut text = header;
    text.push_str(": ");
    text.push_str(doc);
    text
}
/// Convenience form of `rich_symbol_text_full` for callers without signature
/// information: parameters and return type are passed as empty strings.
pub fn rich_symbol_text(kind: &str, name: &str, file: &str, language: &str, doc: &str) -> String {
    let (params, ret) = ("", "");
    rich_symbol_text_full(kind, name, file, language, doc, params, ret)
}
/// Builds the descriptive text embedded for a symbol.
///
/// Layout: `"<kind> <name><params> -> <ret> in <path context> <language>: <doc>"`.
/// The `-> <ret>`, `<language>` and `: <doc>` parts are each omitted when the
/// corresponding argument is empty. `params` is appended directly after the
/// name with no separator. The path part is `path_to_context(file)`.
pub fn rich_symbol_text_full(
    kind: &str,
    name: &str,
    file: &str,
    language: &str,
    doc: &str,
    params: &str,
    ret: &str,
) -> String {
    let location = path_to_context(file);

    // Collect the pieces first, then concatenate once.
    let mut pieces: Vec<&str> = vec![kind, " ", name];
    if !params.is_empty() {
        pieces.push(params);
    }
    if !ret.is_empty() {
        pieces.extend([" -> ", ret]);
    }
    pieces.extend([" in ", location.as_str()]);
    if !language.is_empty() {
        pieces.extend([" ", language]);
    }
    if !doc.is_empty() {
        pieces.extend([": ", doc]);
    }
    pieces.concat()
}
/// Condenses a `/`-separated file path into a short context string.
///
/// Paths with three or fewer components are returned unchanged. Otherwise
/// directory components named (case-insensitively) `src`, `source`, `lib`,
/// `include`, `_h`, `test`, `tests` or `benchmark` are dropped, and the
/// result is the last four remaining components joined with `/`.
///
/// The file name (final component) is never filtered, so a file that happens
/// to be called e.g. `test` or `lib` still appears in the result.
pub fn path_to_context(file: &str) -> String {
    let parts: Vec<&str> = file.split('/').collect();
    if parts.len() <= 3 {
        return file.to_string();
    }
    let Some((filename, dirs)) = parts.split_last() else {
        return file.to_string();
    };

    // Only directories are subject to the noise filter.
    let mut kept: Vec<&str> = dirs
        .iter()
        .copied()
        .filter(|dir| {
            !matches!(
                dir.to_lowercase().as_str(),
                "src" | "source" | "lib" | "include" | "_h" | "test" | "tests" | "benchmark"
            )
        })
        .collect();
    kept.push(*filename);

    let start = kept.len().saturating_sub(4);
    kept[start..].join("/")
}
/// Returns the dot product of `a` and `b`.
///
/// No normalization is performed, so this equals the cosine similarity only
/// when both inputs are unit-length. If the slices differ in length, the
/// extra trailing elements of the longer one are ignored.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let products = std::iter::zip(a, b).map(|(lhs, rhs)| lhs * rhs);
    products.sum()
}
/// Embedder that hashes character trigrams and tokens into a fixed number of
/// buckets; needs no model files.
pub struct TrigramEmbedder {
    /// Length of every vector produced.
    dim: usize,
}

impl TrigramEmbedder {
    /// Creates an embedder producing vectors of length `dim`.
    pub fn new(dim: usize) -> Self {
        TrigramEmbedder { dim }
    }
}

impl Default for TrigramEmbedder {
    /// Uses 256-dimensional vectors.
    fn default() -> Self {
        TrigramEmbedder { dim: 256 }
    }
}
impl EmbedProvider for TrigramEmbedder {
    /// The configured vector length.
    fn dimension(&self) -> usize {
        self.dim
    }

    /// Embeds each text independently with `trigram_embed`; never fails.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut vectors = Vec::with_capacity(texts.len());
        for &text in texts {
            vectors.push(trigram_embed(text, self.dim));
        }
        Ok(vectors)
    }
}
fn trigram_embed(text: &str, dim: usize) -> Vec<f32> {
let mut vec = vec![0.0f32; dim];
let lower = text.to_lowercase();
let chars: Vec<char> = lower.chars().collect();
if chars.len() < 3 {
for c in &chars {
let h = fnv1a(&[*c as u8]) as usize % dim;
vec[h] += 1.0;
}
if chars.len() == 2 {
let bigram = format!("{}{}", chars[0], chars[1]);
let h = fnv1a(bigram.as_bytes()) as usize % dim;
vec[h] += 1.0;
}
} else {
for window in chars.windows(3) {
let trigram: String = window.iter().collect();
let h = fnv1a(trigram.as_bytes()) as usize % dim;
vec[h] += 1.0;
}
}
for token in lower.split(|c: char| !c.is_alphanumeric() && c != '_') {
if token.len() > 1 {
let h = fnv1a(token.as_bytes()) as usize % dim;
vec[h] += 0.5; }
}
let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for v in &mut vec {
*v /= norm;
}
}
vec
}
/// Embedding provider backed by a `model2vec_rs` static model.
pub struct Model2VecEmbedder {
/// Loaded model; `embed_batch` forwards texts to its `encode` method.
model: model2vec_rs::model::StaticModel,
}
impl Model2VecEmbedder {
    /// Locates the model directory and loads the static model from it.
    ///
    /// # Errors
    /// Fails when no model directory is found or the model cannot be loaded.
    pub fn new() -> Result<Self> {
        let model_dir = Self::find_model_dir()?;
        let model = model2vec_rs::model::StaticModel::from_pretrained(model_dir, None, None, None)?;
        Ok(Self { model })
    }

    /// Walks from `start` up through its ancestors and returns the first
    /// `<ancestor>/models/potion-base-8M` that contains `model.safetensors`.
    fn find_in_ancestors(start: &std::path::Path) -> Option<std::path::PathBuf> {
        start
            .ancestors()
            .map(|dir| dir.join("models/potion-base-8M"))
            .find(|candidate| candidate.join("model.safetensors").exists())
    }

    /// Resolves the model directory, trying in order:
    /// 1. `INFIGRAPH_MODEL_DIR`, if set and the path exists;
    /// 2. `~/.infigraph/models/potion-base-8M`, if it holds `model.safetensors`;
    /// 3. `models/potion-base-8M` under the executable path or any ancestor
    ///    (falling back to the working directory if the executable path is unavailable);
    /// 4. the same search starting from the current working directory.
    fn find_model_dir() -> Result<std::path::PathBuf> {
        if let Ok(p) = std::env::var("INFIGRAPH_MODEL_DIR") {
            let pb = std::path::PathBuf::from(p);
            if pb.exists() {
                return Ok(pb);
            }
        }
        if let Some(home) = dirs_next::home_dir() {
            let installed = home
                .join(".infigraph")
                .join("models")
                .join("potion-base-8M");
            if installed.join("model.safetensors").exists() {
                return Ok(installed);
            }
        }
        let exe_start =
            std::env::current_exe().unwrap_or_else(|_| std::env::current_dir().unwrap_or_default());
        if let Some(found) = Self::find_in_ancestors(&exe_start) {
            return Ok(found);
        }
        let cwd = std::env::current_dir()?;
        if let Some(found) = Self::find_in_ancestors(&cwd) {
            return Ok(found);
        }
        anyhow::bail!(
            "models/potion-base-8M not found; set INFIGRAPH_MODEL_DIR or run from repo root"
        )
    }
}
impl EmbedProvider for Model2VecEmbedder {
    /// Always reports 256.
    // NOTE(review): hard-coded; confirm it matches the loaded model's output size.
    fn dimension(&self) -> usize {
        256
    }

    /// Copies the inputs into owned strings and passes them to the model's `encode`.
    fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>> {
        let mut owned: Vec<String> = Vec::with_capacity(texts.len());
        owned.extend(texts.iter().map(|&text| String::from(text)));
        let vectors = self.model.encode(&owned);
        Ok(vectors)
    }
}
/// Lazily initialized embedder handed out by `code_embedder`.
static CODE_EMBEDDER: OnceLock<Arc<dyn EmbedProvider>> = OnceLock::new();
/// Lazily initialized embedder handed out by `doc_embedder`.
static DOC_EMBEDDER: OnceLock<Arc<dyn EmbedProvider>> = OnceLock::new();
/// Creates a shareable embedder: the Model2Vec model when it loads, otherwise
/// a default `TrigramEmbedder` (a warning is printed to stderr in that case).
pub fn init_embedder() -> Arc<dyn EmbedProvider> {
    let model = match Model2VecEmbedder::new() {
        Ok(model) => model,
        Err(e) => {
            eprintln!("warning: Model2Vec unavailable ({e}), using trigram fallback");
            return Arc::new(TrigramEmbedder::default());
        }
    };
    Arc::new(model)
}
/// Returns a handle to the shared code embedder, initializing it with
/// `init_embedder` on first use.
pub fn code_embedder() -> Arc<dyn EmbedProvider> {
    let shared = CODE_EMBEDDER.get_or_init(init_embedder);
    shared.clone()
}
/// Returns a handle to the shared documentation embedder, initializing it
/// with `init_embedder` on first use.
pub fn doc_embedder() -> Arc<dyn EmbedProvider> {
    let shared = DOC_EMBEDDER.get_or_init(init_embedder);
    shared.clone()
}
/// Creates a fresh, uniquely owned embedder: the Model2Vec model when it
/// loads, otherwise a default `TrigramEmbedder` (with a warning on stderr).
pub fn best_embedder() -> Box<dyn EmbedProvider> {
    let model = match Model2VecEmbedder::new() {
        Ok(model) => model,
        Err(e) => {
            eprintln!("warning: Model2Vec unavailable ({e}), using trigram fallback");
            return Box::new(TrigramEmbedder::default());
        }
    };
    Box::new(model)
}
/// Reads the record count from the header of `<root>/.infigraph/embeddings.bin`.
///
/// Only the first four bytes (a little-endian `u32`) are read. Returns 0 when
/// the file cannot be opened or is shorter than four bytes.
pub fn embedding_count(root: &Path) -> usize {
    let header = std::fs::File::open(root.join(".infigraph").join("embeddings.bin"))
        .ok()
        .and_then(|file| {
            let mut bytes = [0u8; 4];
            BufReader::new(file)
                .read_exact(&mut bytes)
                .ok()
                .map(|()| bytes)
        });
    match header {
        Some(bytes) => u32::from_le_bytes(bytes) as usize,
        None => 0,
    }
}
/// Writes `embeddings` to `path` in the binary format read by `load_embeddings`
/// and then invalidates the in-memory embeddings cache.
///
/// Format (all integers little-endian `u32`, floats little-endian `f32`):
/// record count, then per record: id byte length, id bytes, vector length,
/// vector values.
///
/// # Errors
/// Fails when the file cannot be created or written (including the final
/// flush), or when a count or length does not fit in a `u32`.
pub fn save_embeddings(path: &Path, embeddings: &[(String, Vec<f32>)]) -> Result<()> {
    // Validate the header value before truncating any existing file.
    let count = u32::try_from(embeddings.len()).context("too many embeddings for u32 header")?;
    let file = std::fs::File::create(path).context("create embeddings file")?;
    let mut w = BufWriter::new(file);
    w.write_all(&count.to_le_bytes())?;
    for (id, vec) in embeddings {
        let id_bytes = id.as_bytes();
        let id_len = u32::try_from(id_bytes.len()).context("embedding id too long")?;
        w.write_all(&id_len.to_le_bytes())?;
        w.write_all(id_bytes)?;
        let dim = u32::try_from(vec.len()).context("embedding vector too long")?;
        w.write_all(&dim.to_le_bytes())?;
        for &v in vec {
            w.write_all(&v.to_le_bytes())?;
        }
    }
    // Dropping a BufWriter discards flush errors; flush explicitly so a
    // short write is reported instead of leaving a truncated file unnoticed.
    w.flush().context("flush embeddings file")?;
    drop(w);
    invalidate_embeddings_cache();
    Ok(())
}
/// Reads the `(id, vector)` records stored at `path`.
///
/// Expected format (integers little-endian `u32`, floats little-endian
/// `f32`): record count, then per record: id byte length, id bytes (UTF-8),
/// vector length, vector values.
///
/// Length fields are checked against the file size before any buffer is
/// allocated, so a corrupt header produces an error rather than a huge
/// allocation.
///
/// # Errors
/// Fails when the file cannot be opened, ends early, declares a length
/// larger than the file, or contains an id that is not valid UTF-8.
pub fn load_embeddings(path: &Path) -> Result<Vec<(String, Vec<f32>)>> {
    fn corrupt(msg: &'static str) -> std::io::Error {
        std::io::Error::new(std::io::ErrorKind::InvalidData, msg)
    }

    let file = std::fs::File::open(path).context("open embeddings file")?;
    // No field can legitimately describe more bytes than the file holds.
    let file_len = file.metadata().map(|m| m.len()).unwrap_or(u64::MAX);
    let mut r = BufReader::new(file);
    let mut buf4 = [0u8; 4];
    r.read_exact(&mut buf4)?;
    let count = u32::from_le_bytes(buf4) as usize;

    // Every record occupies at least 8 bytes (two length fields).
    let max_records = usize::try_from(file_len / 8).unwrap_or(usize::MAX);
    let mut result = Vec::with_capacity(count.min(max_records));

    for _ in 0..count {
        r.read_exact(&mut buf4)?;
        let id_len = u64::from(u32::from_le_bytes(buf4));
        if id_len > file_len {
            return Err(corrupt("embedding id length exceeds file size").into());
        }
        let id_len =
            usize::try_from(id_len).map_err(|_| corrupt("embedding id length overflows usize"))?;
        let mut id_buf = vec![0u8; id_len];
        r.read_exact(&mut id_buf)?;
        let id = String::from_utf8(id_buf).context("invalid utf8 in embedding id")?;

        r.read_exact(&mut buf4)?;
        let byte_len = u64::from(u32::from_le_bytes(buf4)) * 4;
        if byte_len > file_len {
            return Err(corrupt("embedding vector length exceeds file size").into());
        }
        let byte_len = usize::try_from(byte_len)
            .map_err(|_| corrupt("embedding vector length overflows usize"))?;
        // Read the whole vector at once, then decode four bytes at a time.
        let mut raw = vec![0u8; byte_len];
        r.read_exact(&mut raw)?;
        let vec: Vec<f32> = raw
            .chunks_exact(4)
            .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
            .collect();

        result.push((id, vec));
    }
    Ok(result)
}
/// Brings `<root>/.infigraph/embeddings.bin` in line with the symbols in the graph.
///
/// Steps:
/// 1. Query every `Symbol` node; return `Ok(0)` if there are none.
/// 2. Load the existing embeddings (a missing or unreadable file counts as empty).
/// 3. Re-embed a symbol when `changed_files` is empty, when its file is in
///    `changed_files`, or when it has no stored embedding yet.
/// 4. Drop embeddings whose symbol no longer exists and save the file.
/// 5. Rebuild the HNSW index when there are at least 200 000 embeddings or an
///    index file already exists; a failed rebuild only prints a warning.
///
/// A batch whose embedding call fails is skipped with a warning on stderr;
/// its symbols keep any previously stored embedding.
///
/// Returns the number of embeddings written.
///
/// # Errors
/// Fails when the graph connection or query fails, or when saving fails.
pub fn update_embeddings(
    store: &crate::graph::GraphStore,
    root: &Path,
    changed_files: &[&str],
) -> Result<usize> {
    use rayon::prelude::*;
    let conn = store.connection()?;
    let gq = crate::graph::GraphQuery::new(&conn);
    let rows = gq.raw_query("MATCH (s:Symbol) RETURN s.id, s.name, s.kind, s.file, s.docstring, s.language, s.parameters, s.return_type")?;
    if rows.is_empty() {
        return Ok(0);
    }
    let emb_path = root.join(".infigraph").join("embeddings.bin");
    let mut existing: std::collections::HashMap<String, Vec<f32>> = load_embeddings(&emb_path)
        .unwrap_or_default()
        .into_iter()
        .collect();
    let changed_set: std::collections::HashSet<&str> = changed_files.iter().copied().collect();
    let to_embed: Vec<(String, String)> = rows
        .iter()
        .filter_map(|row| {
            let id = &row[0];
            let file = row.get(3).map(|s| s.as_str()).unwrap_or("");
            // Skip only symbols that are untouched AND already embedded.
            if !changed_set.is_empty() && !changed_set.contains(file) && existing.contains_key(id) {
                return None;
            }
            let name = &row[1];
            let kind = &row[2];
            let doc = row.get(4).map(|s| s.as_str()).unwrap_or("");
            let lang = row.get(5).map(|s| s.as_str()).unwrap_or("");
            let params = row.get(6).map(|s| s.as_str()).unwrap_or("");
            let ret = row.get(7).map(|s| s.as_str()).unwrap_or("");
            let text = rich_symbol_text_full(kind, name, file, lang, doc, params, ret);
            Some((id.clone(), text))
        })
        .collect();
    if !to_embed.is_empty() {
        // The provider is `Send + Sync`, so worker threads can borrow it directly.
        let embedder = best_embedder();
        const BATCH: usize = 256;
        let results: Vec<Vec<(String, Vec<f32>)>> = to_embed
            .par_chunks(BATCH)
            .map(|chunk| {
                let texts: Vec<&str> = chunk.iter().map(|(_, t)| t.as_str()).collect();
                let vecs = match embedder.embed_batch(&texts) {
                    Ok(vecs) => vecs,
                    Err(e) => {
                        eprintln!(
                            "warning: embedding batch failed ({e}), skipping {} symbols",
                            chunk.len()
                        );
                        Vec::new()
                    }
                };
                // Pair ids with vectors by position; extra ids (if the
                // provider returned fewer vectors) are left without an update.
                chunk
                    .iter()
                    .zip(vecs)
                    .map(|((id, _), v)| (id.clone(), v))
                    .collect()
            })
            .collect();
        existing.extend(results.into_iter().flatten());
    }
    let all_ids: std::collections::HashSet<String> = rows.iter().map(|r| r[0].clone()).collect();
    existing.retain(|id, _| all_ids.contains(id));
    let symbol_embeddings: Vec<(String, Vec<f32>)> = existing.into_iter().collect();
    let count = symbol_embeddings.len();
    save_embeddings(&emb_path, &symbol_embeddings)?;
    const HNSW_THRESHOLD: usize = 200_000;
    let hnsw_path = root.join(".infigraph").join("hnsw_index.usearch");
    let should_build = count >= HNSW_THRESHOLD || hnsw_path.exists();
    if should_build {
        invalidate_hnsw_cache();
        if let Err(e) = build_hnsw_index(&symbol_embeddings, &hnsw_path, &emb_path) {
            eprintln!("warning: HNSW index build failed ({e}), vector search will use brute-force");
        }
    }
    Ok(count)
}
use usearch::{Index as UsearchIndex, IndexOptions, MetricKind, ScalarKind};
/// Value used for `connectivity` in the options built by `hnsw_opts`.
const HNSW_CONNECTIVITY: usize = 32;
/// Value used for `expansion_add` in the options built by `hnsw_opts`.
const HNSW_EXPANSION_ADD: usize = 200;
/// Passed to `change_expansion_search` on an index loaded by `search_hnsw`.
const HNSW_EXPANSION_SEARCH: usize = 256;
/// `query_index` asks the index for `top_k * HNSW_OVERSAMPLE` neighbours.
const HNSW_OVERSAMPLE: usize = 20;
/// Process-wide slot holding the most recently loaded HNSW index, if any.
static HNSW_CACHE: OnceLock<Mutex<Option<CachedHnsw>>> = OnceLock::new();
/// A loaded HNSW index together with the data `search_hnsw` uses to decide
/// whether it can be reused.
struct CachedHnsw {
/// Canonicalized index path (the path as given if canonicalization failed).
path: PathBuf,
/// Modification time of the index file when it was loaded.
modified: std::time::SystemTime,
/// The loaded usearch index.
index: UsearchIndex,
/// Maps an index key (used as a position) to the embedding id.
id_map: Vec<String>,
}
/// Returns the mutex guarding the HNSW cache, creating it (empty) on first use.
fn hnsw_cache_lock() -> &'static Mutex<Option<CachedHnsw>> {
    // `Mutex<Option<_>>::default()` is a mutex wrapping `None`.
    HNSW_CACHE.get_or_init(Mutex::default)
}
/// Index options shared by index construction and loading: `dim`-dimensional
/// `f32` vectors with the `IP` metric, plus this module's connectivity and
/// add-expansion constants. All other fields keep their defaults.
fn hnsw_opts(dim: usize) -> IndexOptions {
    let defaults = IndexOptions::default();
    IndexOptions {
        expansion_add: HNSW_EXPANSION_ADD,
        connectivity: HNSW_CONNECTIVITY,
        quantization: ScalarKind::F32,
        metric: MetricKind::IP,
        dimensions: dim,
        ..defaults
    }
}
/// Builds an HNSW index over `embeddings`, saves it to `index_path`, and
/// writes a JSON sidecar next to it (same path with extension `meta`).
///
/// Each vector is inserted under its position in `embeddings` as the key;
/// the sidecar stores the ids in that order, the count, the dimension (taken
/// from the first vector) and the embeddings file's mtime in whole seconds.
/// Insertion is spread over one scoped thread per chunk. The in-memory HNSW
/// cache is invalidated once the files are written.
///
/// Returns the number of vectors indexed (`0` for empty input, in which case
/// nothing is written).
///
/// # Errors
/// Fails when the index cannot be created, reserved, filled or saved, when
/// `index_path` is not valid UTF-8, or when the sidecar cannot be written.
/// A vector that cannot be added aborts the build, so an incomplete index is
/// never saved.
pub fn build_hnsw_index(
    embeddings: &[(String, Vec<f32>)],
    index_path: &Path,
    embeddings_path: &Path,
) -> Result<usize> {
    if embeddings.is_empty() {
        return Ok(0);
    }
    let dim = embeddings[0].1.len();
    let n = embeddings.len();
    let threads = std::thread::available_parallelism()
        .map(|t| t.get())
        .unwrap_or(4);
    let index =
        UsearchIndex::new(&hnsw_opts(dim)).map_err(|e| anyhow::anyhow!("usearch create: {e}"))?;
    index
        .reserve(n)
        .map_err(|e| anyhow::anyhow!("usearch reserve: {e}"))?;

    let chunk_size = n.div_ceil(threads);
    let first_failure: Option<String> = std::thread::scope(|s| {
        // Spawn every worker before joining any, so chunks run concurrently.
        let workers: Vec<_> = embeddings
            .chunks(chunk_size)
            .enumerate()
            .map(|(chunk_idx, chunk)| {
                let idx = &index;
                let offset = chunk_idx * chunk_size;
                s.spawn(move || -> std::result::Result<(), String> {
                    for (i, (_, v)) in chunk.iter().enumerate() {
                        idx.add((offset + i) as u64, v).map_err(|e| e.to_string())?;
                    }
                    Ok(())
                })
            })
            .collect();

        let mut failure = None;
        for worker in workers {
            let outcome = worker
                .join()
                .unwrap_or_else(|_| Err("worker thread panicked".to_string()));
            if let Err(msg) = outcome {
                failure.get_or_insert(msg);
            }
        }
        failure
    });
    if let Some(msg) = first_failure {
        return Err(anyhow::anyhow!("usearch add: {msg}"));
    }

    let path_str = index_path
        .to_str()
        .ok_or_else(|| anyhow::anyhow!("non-utf8 index path"))?;
    index
        .save(path_str)
        .map_err(|e| anyhow::anyhow!("usearch save: {e}"))?;
    let emb_mtime = std::fs::metadata(embeddings_path)
        .and_then(|m| m.modified())
        .unwrap_or(std::time::UNIX_EPOCH);
    let sidecar_path = index_path.with_extension("meta");
    let ids: Vec<&str> = embeddings.iter().map(|(id, _)| id.as_str()).collect();
    let sidecar = serde_json::json!({
        "emb_mtime_secs": emb_mtime.duration_since(std::time::UNIX_EPOCH).unwrap_or_default().as_secs(),
        "count": n,
        "dim": dim,
        "ids": ids,
    });
    std::fs::write(&sidecar_path, serde_json::to_vec(&sidecar)?).context("write hnsw sidecar")?;
    invalidate_hnsw_cache();
    Ok(n)
}
/// Drops any cached HNSW index so the next search reloads it from disk.
///
/// A poisoned mutex is recovered rather than ignored: skipping the reset
/// would leave a stale index in the cache.
pub fn invalidate_hnsw_cache() {
    let mut guard = hnsw_cache_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *guard = None;
}
/// One hit returned from an HNSW search.
pub struct HnswResult {
/// Embedding id looked up from the index key.
pub id: String,
/// `1.0 - distance` as reported by the index for this hit.
pub score: f32,
}
/// Searches `index` for the neighbours of `query` and maps the returned keys
/// to ids through `id_map`.
///
/// `top_k * HNSW_OVERSAMPLE` candidates are requested and all of them are
/// returned (not truncated to `top_k`), each scored as `1.0 - distance`.
/// Keys with no entry in `id_map` are skipped.
///
/// # Errors
/// Fails when the underlying index search fails.
fn query_index(
    index: &UsearchIndex,
    id_map: &[String],
    query: &[f32],
    top_k: usize,
) -> Result<Vec<HnswResult>> {
    let matches = index
        .search(query, top_k * HNSW_OVERSAMPLE)
        .map_err(|e| anyhow::anyhow!("usearch search: {e}"))?;

    let mut hits = Vec::with_capacity(matches.keys.len());
    for (key, dist) in matches.keys.iter().zip(matches.distances.iter()) {
        if let Some(id) = id_map.get(*key as usize) {
            hits.push(HnswResult {
                id: id.clone(),
                score: 1.0 - *dist,
            });
        }
    }
    Ok(hits)
}
/// Searches the on-disk HNSW index for the neighbours of `query`.
///
/// Returns `Ok(None)` when the index or its `.meta` sidecar is missing, or
/// when the sidecar's recorded embeddings mtime (whole seconds) differs from
/// the current mtime of `embeddings_path` — callers are expected to fall
/// back to another search strategy in that case.
///
/// A loaded index is kept in a process-wide single-entry cache keyed by the
/// canonical index path and the index file's mtime.
// NOTE(review): a cache hit does not re-check the sidecar's embeddings mtime;
// it relies on the index file's own mtime changing whenever it is rebuilt.
///
/// # Errors
/// Fails when the sidecar cannot be read or parsed, the index path is not
/// valid UTF-8, or the index cannot be created, opened or searched.
pub fn search_hnsw(
    index_path: &Path,
    embeddings_path: &Path,
    query: &[f32],
    top_k: usize,
) -> Result<Option<Vec<HnswResult>>> {
    let sidecar_path = index_path.with_extension("meta");
    if !index_path.exists() || !sidecar_path.exists() {
        return Ok(None);
    }
    let canon = index_path
        .canonicalize()
        .unwrap_or_else(|_| index_path.to_path_buf());
    let idx_mtime = std::fs::metadata(index_path)
        .and_then(|m| m.modified())
        .unwrap_or(std::time::UNIX_EPOCH);

    {
        // The slot is only a cache, so recover from a poisoned mutex instead
        // of panicking on every later search.
        let guard = hnsw_cache_lock()
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        if let Some(cached) = guard.as_ref() {
            if cached.path == canon && cached.modified == idx_mtime {
                return Ok(Some(query_index(
                    &cached.index,
                    &cached.id_map,
                    query,
                    top_k,
                )?));
            }
        }
        // Guard dropped here so loading the index does not hold the lock.
    }

    // Only the miss path needs the embeddings mtime, for the staleness check.
    let emb_mtime_secs = std::fs::metadata(embeddings_path)
        .and_then(|m| m.modified())
        .unwrap_or(std::time::UNIX_EPOCH)
        .duration_since(std::time::UNIX_EPOCH)
        .unwrap_or_default()
        .as_secs();
    let sidecar_bytes = std::fs::read(&sidecar_path).context("read hnsw sidecar")?;
    let sidecar: serde_json::Value =
        serde_json::from_slice(&sidecar_bytes).context("parse hnsw sidecar")?;
    let stored_mtime = sidecar["emb_mtime_secs"].as_u64().unwrap_or(0);
    if stored_mtime != emb_mtime_secs {
        return Ok(None);
    }
    let id_map: Vec<String> = sidecar["ids"]
        .as_array()
        .map(|arr| {
            arr.iter()
                .filter_map(|v| v.as_str().map(String::from))
                .collect()
        })
        .unwrap_or_default();
    let dim = sidecar["dim"].as_u64().unwrap_or(256) as usize;
    let path_str = index_path
        .to_str()
        .ok_or_else(|| anyhow::anyhow!("non-utf8 index path"))?;
    let index = UsearchIndex::new(&hnsw_opts(dim))
        .map_err(|e| anyhow::anyhow!("usearch create for load: {e}"))?;
    index
        .view(path_str)
        .map_err(|e| anyhow::anyhow!("usearch view: {e}"))?;
    index.change_expansion_search(HNSW_EXPANSION_SEARCH);
    let out = query_index(&index, &id_map, query, top_k)?;

    let mut guard = hnsw_cache_lock()
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    *guard = Some(CachedHnsw {
        path: canon,
        modified: idx_mtime,
        index,
        id_map,
    });
    Ok(Some(out))
}
/// 64-bit FNV-1a hash of `data`: starting from the offset basis, each byte is
/// XOR-ed into the hash, which is then multiplied (wrapping) by the FNV prime.
fn fnv1a(data: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 0xcbf29ce484222325;
    const PRIME: u64 = 0x100000001b3;
    data.iter().fold(OFFSET_BASIS, |hash, &byte| {
        (hash ^ u64::from(byte)).wrapping_mul(PRIME)
    })
}