Skip to main content

infigraph_docs/
embed.rs

1use std::path::Path;
2use std::sync::Arc;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use infigraph_core::embed::{
8    build_hnsw_index, doc_embedder, invalidate_embeddings_cache, invalidate_hnsw_cache,
9    load_embeddings, save_embeddings,
10};
11
12use crate::chunk::Chunk;
13use crate::store::DocStore;
14
15pub fn update_doc_embeddings(
16    store: &DocStore,
17    root: &Path,
18    new_chunks: &[&Chunk],
19    changed_files: &[&str],
20) -> Result<usize> {
21    let tg_dir = root.join(".infigraph");
22    let emb_path = tg_dir.join("docs_embeddings.bin");
23
24    let mut existing: std::collections::HashMap<String, Vec<f32>> = load_embeddings(&emb_path)
25        .unwrap_or_default()
26        .into_iter()
27        .collect();
28
29    let changed_set: std::collections::HashSet<&str> = changed_files.iter().copied().collect();
30
31    // Remove embeddings for chunks belonging to changed files
32    if !changed_set.is_empty() {
33        existing.retain(|id, _| {
34            let file = id.split("::chunk_").next().unwrap_or("");
35            !changed_set.contains(file)
36        });
37    }
38
39    // Embed new chunks with file path context for better semantic matching
40    let to_embed: Vec<(&str, String)> = new_chunks
41        .iter()
42        .map(|c| {
43            let file_context = doc_path_context(&c.doc_file);
44            let text = match (&file_context, &c.heading) {
45                (Some(ctx), Some(h)) => format!("{} > {}: {}", ctx, h, c.text),
46                (Some(ctx), None) => format!("{}: {}", ctx, c.text),
47                (None, Some(h)) => format!("{}: {}", h, c.text),
48                (None, None) => c.text.clone(),
49            };
50            (c.id.as_str(), text)
51        })
52        .collect();
53
54    if !to_embed.is_empty() {
55        let embedder = doc_embedder();
56        const BATCH: usize = 256;
57        let results: Vec<Vec<(String, Vec<f32>)>> = to_embed
58            .par_chunks(BATCH)
59            .map(|chunk| {
60                let emb = Arc::clone(&embedder);
61                let texts: Vec<&str> = chunk.iter().map(|(_, t)| t.as_str()).collect();
62                let vecs = emb.embed_batch(&texts).unwrap_or_default();
63                chunk
64                    .iter()
65                    .enumerate()
66                    .filter_map(|(i, (id, _))| vecs.get(i).map(|v| (id.to_string(), v.clone())))
67                    .collect()
68            })
69            .collect();
70        for batch in results {
71            for (id, v) in batch {
72                existing.insert(id, v);
73            }
74        }
75    }
76
77    // Also keep existing chunks that are still in the store
78    let all_store_chunks = store.get_all_chunks().unwrap_or_default();
79    let valid_ids: std::collections::HashSet<String> =
80        all_store_chunks.into_iter().map(|(id, _)| id).collect();
81    existing.retain(|id, _| valid_ids.contains(id));
82
83    let embeddings: Vec<(String, Vec<f32>)> = existing.into_iter().collect();
84    let count = embeddings.len();
85    save_embeddings(&emb_path, &embeddings)?;
86
87    // Build HNSW if above threshold or existing index
88    const HNSW_THRESHOLD: usize = 200_000;
89    let hnsw_path = tg_dir.join("docs_hnsw_index.usearch");
90    if count >= HNSW_THRESHOLD || hnsw_path.exists() {
91        invalidate_hnsw_cache();
92        if let Err(e) = build_hnsw_index(&embeddings, &hnsw_path, &emb_path) {
93            eprintln!(
94                "warning: doc HNSW index build failed ({e}), vector search will use brute-force"
95            );
96        }
97    }
98
99    invalidate_embeddings_cache();
100    Ok(count)
101}
102
/// Derives a short, human-readable context label from a `/`-separated
/// document path, for prefixing chunk text before embedding.
///
/// The label is the file stem with `_` and `-` turned into spaces, preceded
/// by the directory components that are not generic container names (`src`,
/// `doc`, `docs`, `documentation`, `resources`, matched case-insensitively).
/// Empty and `.` components (from leading, trailing or doubled slashes and
/// `./` prefixes) are ignored.
///
/// Returns `None` when the path has no directory component, since the bare
/// file name adds no context.
///
/// Examples: `docs/api/auth-flow.md` -> `api/auth flow`,
/// `docs/getting_started.md` -> `getting started`, `README.md` -> `None`.
fn doc_path_context(file: &str) -> Option<String> {
    // Skip empty and "." segments so "/docs/a.md" or "api//a.md" do not leak
    // stray slashes into the label.
    let mut parts: Vec<&str> = file
        .split('/')
        .filter(|p| !p.is_empty() && *p != ".")
        .collect();
    if parts.len() <= 1 {
        return None;
    }

    let file_name = parts.pop()?;
    // Strip the last extension, but keep dotfiles such as ".env" whole rather
    // than reducing them to an empty name.
    let stem = match file_name.rsplit_once('.') {
        Some((s, _)) if !s.is_empty() => s,
        _ => file_name,
    };
    let name = stem.replace(['_', '-'], " ");

    // What remains in `parts` are directories; drop the generic ones.
    parts.retain(|dir| {
        !["src", "doc", "docs", "documentation", "resources"]
            .iter()
            .any(|generic| dir.eq_ignore_ascii_case(generic))
    });

    if parts.is_empty() {
        Some(name)
    } else {
        Some(format!("{}/{}", parts.join("/"), name))
    }
}
131}