Skip to main content

gitcortex_mcp/
embeddings.rs

//! Semantic search via local embeddings (AllMiniLM-L6-v2, 384 dims).
//!
//! The model is downloaded from HuggingFace on first use (~23 MB, cached in
//! `$XDG_CACHE_HOME/huggingface/hub`); subsequent starts load it from the cache.
//!
//! The vector index is persisted per branch at
//! `~/.local/share/gitcortex/{repo_id}/embeddings_{branch}.bin`.
//!
//! The background indexer (`index_missing`) embeds nodes that don't yet have a
//! vector; call it once after `gcx serve` opens the store. Search stays
//! text-only while the indexer runs and automatically starts using semantic
//! hits once at least one vector is loaded.
13
14use std::collections::HashMap;
15use std::io::{BufWriter, Write};
16use std::path::{Path, PathBuf};
17
18use fastembed::{EmbeddingModel, InitOptions, TextEmbedding};
19use gitcortex_core::graph::Node;
20
/// Width of every embedding vector produced by AllMiniLM-L6-v2.
const DIM: usize = 384;
/// Cosine-similarity floor below which a stored vector is not reported as a
/// semantic hit.
const SIMILARITY_THRESHOLD: f32 = 0.50;

// On-disk header fields: magic, then version, dim, count, then the records.
const MAGIC: &[u8; 4] = b"GCXV";
const FORMAT_VERSION: u32 = 1;
28
// ── Vector index ──────────────────────────────────────────────────────────────
30
31pub struct SemanticIndex {
32    /// node_id → unit-normalised embedding
33    vectors: HashMap<String, Vec<f32>>,
34    path: PathBuf,
35}
36
37impl SemanticIndex {
38    pub fn load_or_create(path: &Path) -> Self {
39        let vectors = load_bin(path).unwrap_or_default();
40        if !vectors.is_empty() {
41            tracing::info!(
42                "semantic index loaded: {} vectors from {}",
43                vectors.len(),
44                path.display()
45            );
46        }
47        Self {
48            vectors,
49            path: path.to_owned(),
50        }
51    }
52
53    pub fn has(&self, node_id: &str) -> bool {
54        self.vectors.contains_key(node_id)
55    }
56
57    pub fn insert(&mut self, node_id: String, vec: Vec<f32>) {
58        self.vectors.insert(node_id, unit_normalise(vec));
59    }
60
61    pub fn len(&self) -> usize {
62        self.vectors.len()
63    }
64
65    pub fn is_empty(&self) -> bool {
66        self.vectors.is_empty()
67    }
68
69    /// Drop vectors whose node ID is not in `live_ids`. Node UUIDs regenerate
70    /// on every re-index, so without pruning the index file grows with
71    /// orphaned vectors that can still surface as (unresolvable) hits.
72    /// Returns the number of vectors removed.
73    pub fn retain_ids(&mut self, live_ids: &std::collections::HashSet<String>) -> usize {
74        let before = self.vectors.len();
75        self.vectors.retain(|id, _| live_ids.contains(id));
76        before - self.vectors.len()
77    }
78
79    pub fn save(&self) {
80        if let Err(e) = save_bin(&self.path, &self.vectors) {
81            tracing::warn!("failed to save semantic index: {e}");
82        }
83    }
84
85    /// Return up to `k` node IDs with cosine similarity ≥ SIMILARITY_THRESHOLD.
86    /// Query vector need not be pre-normalised — normalised internally.
87    pub fn top_k(&self, query_vec: &[f32], k: usize) -> Vec<String> {
88        let q = unit_normalise(query_vec.to_vec());
89        let mut scores: Vec<(&String, f32)> = self
90            .vectors
91            .iter()
92            .map(|(id, v)| (id, dot(&q, v)))
93            .filter(|(_, s)| *s >= SIMILARITY_THRESHOLD)
94            .collect();
95        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
96        scores
97            .into_iter()
98            .take(k)
99            .map(|(id, _)| id.clone())
100            .collect()
101    }
102}
103
// ── Embedder ──────────────────────────────────────────────────────────────────
105
106pub struct Embedder {
107    model: TextEmbedding,
108}
109
110impl Embedder {
111    /// Download (first run) or load (cached) AllMiniLM-L6-v2.
112    pub fn new() -> anyhow::Result<Self> {
113        tracing::info!("initialising semantic embedder (AllMiniLM-L6-v2) …");
114        let model = TextEmbedding::try_new(
115            InitOptions::new(EmbeddingModel::AllMiniLML6V2).with_show_download_progress(false),
116        )?;
117        tracing::info!("semantic embedder ready");
118        Ok(Self { model })
119    }
120
121    pub fn embed_one(&self, text: &str) -> anyhow::Result<Vec<f32>> {
122        let mut out = self.model.embed(vec![text.to_owned()], None)?;
123        out.pop()
124            .ok_or_else(|| anyhow::anyhow!("embedder returned no vectors"))
125    }
126
127    /// Embed a batch of texts. Returns one vector per input in order.
128    pub fn embed_batch(&self, texts: Vec<String>) -> anyhow::Result<Vec<Vec<f32>>> {
129        self.model.embed(texts, None)
130    }
131}
132
// ── Text representation for a node ────────────────────────────────────────────
134
135/// Build the text string that gets embedded for a node.
136/// Format: `"{kind} {qualified_name} {signature} {doc_comment}"`
137pub fn node_text(n: &Node) -> String {
138    let kind = n.kind.to_string();
139    let sig = &n.metadata.definition.signature;
140    let doc = n.metadata.definition.doc_comment.as_deref().unwrap_or("");
141    if sig.is_empty() && doc.is_empty() {
142        format!("{kind} {}", n.qualified_name)
143    } else if doc.is_empty() {
144        format!("{kind} {} {sig}", n.qualified_name)
145    } else {
146        format!("{kind} {} {sig} {doc}", n.qualified_name)
147    }
148}
149
// ── Math helpers ──────────────────────────────────────────────────────────────
151
/// Dot product of `a` and `b`; for unit vectors this is their cosine similarity.
///
/// Elements are paired positionally and iteration stops at the end of the
/// shorter slice. Mismatched lengths are not reported.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&x, &y)| x * y);
    products.sum()
}
155
/// Scale `v` to unit (L2) length in place and return it.
///
/// Some vectors are returned unchanged because they have no usable direction:
/// those whose length is at most `f32::EPSILON` (including the zero vector),
/// and those whose length is NaN.
fn unit_normalise(mut v: Vec<f32>) -> Vec<f32> {
    let length = v.iter().map(|&c| c * c).sum::<f32>().sqrt();
    let has_direction = length > f32::EPSILON;
    if has_direction {
        v.iter_mut().for_each(|c| *c /= length);
    }
    v
}
165
// ── Binary storage ────────────────────────────────────────────────────────────
//
// Layout (all integers little-endian):
//   [4]  magic "GCXV"
//   [4]  format version (u32)
//   [4]  embedding dimension (u32)
//   [4]  record count (u32)
//   per record:
//     [4]       id_len (u32)
//     [id_len]  node_id (UTF-8)
//     [dim × 4] f32 values
178fn load_bin(path: &Path) -> Option<HashMap<String, Vec<f32>>> {
179    let data = std::fs::read(path).ok()?;
180    let mut p = 0usize;
181
182    macro_rules! read_u32 {
183        () => {{
184            let b: [u8; 4] = data.get(p..p + 4)?.try_into().ok()?;
185            p += 4;
186            u32::from_le_bytes(b)
187        }};
188    }
189
190    if data.get(p..p + 4)? != MAGIC {
191        return None;
192    }
193    p += 4;
194
195    let _ver = read_u32!();
196    let dim = read_u32!() as usize;
197    let count = read_u32!() as usize;
198
199    let mut map = HashMap::with_capacity(count);
200    for _ in 0..count {
201        let id_len = read_u32!() as usize;
202        let id = String::from_utf8(data.get(p..p + id_len)?.to_vec()).ok()?;
203        p += id_len;
204        let end = p + dim * 4;
205        let vec: Vec<f32> = data
206            .get(p..end)?
207            .chunks_exact(4)
208            .map(|b| f32::from_le_bytes(b.try_into().unwrap()))
209            .collect();
210        p = end;
211        map.insert(id, vec);
212    }
213    Some(map)
214}
215
216fn save_bin(path: &Path, vectors: &HashMap<String, Vec<f32>>) -> std::io::Result<()> {
217    if let Some(parent) = path.parent() {
218        std::fs::create_dir_all(parent)?;
219    }
220    let tmp = path.with_extension("tmp");
221    {
222        let f = std::fs::File::create(&tmp)?;
223        let mut w = BufWriter::new(f);
224        w.write_all(MAGIC)?;
225        w.write_all(&FORMAT_VERSION.to_le_bytes())?;
226        w.write_all(&(DIM as u32).to_le_bytes())?;
227        w.write_all(&(vectors.len() as u32).to_le_bytes())?;
228        for (id, vec) in vectors {
229            let id_b = id.as_bytes();
230            w.write_all(&(id_b.len() as u32).to_le_bytes())?;
231            w.write_all(id_b)?;
232            for &v in vec {
233                w.write_all(&v.to_le_bytes())?;
234            }
235        }
236        w.flush()?;
237    }
238    std::fs::rename(&tmp, path)?;
239    Ok(())
240}