1use std::path::Path;
2use std::sync::Arc;
3
4use anyhow::Result;
5use rayon::prelude::*;
6
7use infigraph_core::embed::{
8 build_hnsw_index, doc_embedder, invalidate_embeddings_cache, invalidate_hnsw_cache,
9 load_embeddings, save_embeddings,
10};
11
12use crate::chunk::Chunk;
13use crate::store::DocStore;
14
15pub fn update_doc_embeddings(
16 store: &DocStore,
17 root: &Path,
18 new_chunks: &[&Chunk],
19 changed_files: &[&str],
20) -> Result<usize> {
21 let tg_dir = root.join(".infigraph");
22 let emb_path = tg_dir.join("docs_embeddings.bin");
23
24 let mut existing: std::collections::HashMap<String, Vec<f32>> = load_embeddings(&emb_path)
25 .unwrap_or_default()
26 .into_iter()
27 .collect();
28
29 let changed_set: std::collections::HashSet<&str> = changed_files.iter().copied().collect();
30
31 if !changed_set.is_empty() {
33 existing.retain(|id, _| {
34 let file = id.split("::chunk_").next().unwrap_or("");
35 !changed_set.contains(file)
36 });
37 }
38
39 let to_embed: Vec<(&str, String)> = new_chunks
41 .iter()
42 .map(|c| {
43 let file_context = doc_path_context(&c.doc_file);
44 let text = match (&file_context, &c.heading) {
45 (Some(ctx), Some(h)) => format!("{} > {}: {}", ctx, h, c.text),
46 (Some(ctx), None) => format!("{}: {}", ctx, c.text),
47 (None, Some(h)) => format!("{}: {}", h, c.text),
48 (None, None) => c.text.clone(),
49 };
50 (c.id.as_str(), text)
51 })
52 .collect();
53
54 if !to_embed.is_empty() {
55 let embedder = doc_embedder();
56 const BATCH: usize = 256;
57 let results: Vec<Vec<(String, Vec<f32>)>> = to_embed
58 .par_chunks(BATCH)
59 .map(|chunk| {
60 let emb = Arc::clone(&embedder);
61 let texts: Vec<&str> = chunk.iter().map(|(_, t)| t.as_str()).collect();
62 let vecs = emb.embed_batch(&texts).unwrap_or_default();
63 chunk
64 .iter()
65 .enumerate()
66 .filter_map(|(i, (id, _))| vecs.get(i).map(|v| (id.to_string(), v.clone())))
67 .collect()
68 })
69 .collect();
70 for batch in results {
71 for (id, v) in batch {
72 existing.insert(id, v);
73 }
74 }
75 }
76
77 let all_store_chunks = store.get_all_chunks().unwrap_or_default();
79 let valid_ids: std::collections::HashSet<String> =
80 all_store_chunks.into_iter().map(|(id, _)| id).collect();
81 existing.retain(|id, _| valid_ids.contains(id));
82
83 let embeddings: Vec<(String, Vec<f32>)> = existing.into_iter().collect();
84 let count = embeddings.len();
85 save_embeddings(&emb_path, &embeddings)?;
86
87 const HNSW_THRESHOLD: usize = 200_000;
89 let hnsw_path = tg_dir.join("docs_hnsw_index.usearch");
90 if count >= HNSW_THRESHOLD || hnsw_path.exists() {
91 invalidate_hnsw_cache();
92 if let Err(e) = build_hnsw_index(&embeddings, &hnsw_path, &emb_path) {
93 eprintln!(
94 "warning: doc HNSW index build failed ({e}), vector search will use brute-force"
95 );
96 }
97 }
98
99 invalidate_embeddings_cache();
100 Ok(count)
101}
102
/// Derives a short context label from a `/`-separated document path.
///
/// The label is the file name without its last extension, with `_` and `-`
/// replaced by spaces, prefixed by the parent directories joined with `/`.
/// Generic container directories (`src`, `doc`, `docs`, `documentation`,
/// `resources`, compared ASCII case-insensitively) are left out of the prefix.
///
/// Empty segments (leading, trailing or doubled `/`) and `.` segments are
/// ignored, so `./docs/api/intro.md` and `/docs//api/intro.md` both yield
/// `api/intro`. A dotfile such as `.env` keeps its full name instead of
/// collapsing to an empty one.
///
/// Returns `None` when the path has no directory component once normalised
/// (for example `README.md`, `./README.md` or the empty string).
fn doc_path_context(file: &str) -> Option<String> {
    const GENERIC_DIRS: [&str; 5] = ["src", "doc", "docs", "documentation", "resources"];

    // Normalise first so stray separators and `.` never leak into the label.
    let parts: Vec<&str> = file
        .split('/')
        .filter(|segment| !segment.is_empty() && *segment != ".")
        .collect();
    let (&file_name, parents) = parts.split_last()?;
    if parents.is_empty() {
        return None;
    }

    // Strip only the last extension; a leading-dot name has no stem to keep,
    // so fall back to the whole file name.
    let stem = match file_name.rsplit_once('.') {
        Some((stem, _)) if !stem.is_empty() => stem,
        _ => file_name,
    };
    let name = stem.replace(['_', '-'], " ");

    let dirs: Vec<&str> = parents
        .iter()
        .copied()
        .filter(|dir| {
            !GENERIC_DIRS
                .iter()
                .any(|&generic| dir.eq_ignore_ascii_case(generic))
        })
        .collect();

    if dirs.is_empty() {
        Some(name)
    } else {
        Some(format!("{}/{}", dirs.join("/"), name))
    }
}