Skip to main content

lean_ctx/core/
vector_index.rs

1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use md5::{Digest, Md5};
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct CodeChunk {
9    pub file_path: String,
10    pub symbol_name: String,
11    pub kind: ChunkKind,
12    pub start_line: usize,
13    pub end_line: usize,
14    pub content: String,
15    pub tokens: Vec<String>,
16    pub token_count: usize,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
20pub enum ChunkKind {
21    Function,
22    Struct,
23    Impl,
24    Module,
25    Class,
26    Method,
27    Other,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct BM25Index {
32    pub chunks: Vec<CodeChunk>,
33    pub inverted: HashMap<String, Vec<(usize, f64)>>,
34    pub avg_doc_len: f64,
35    pub doc_count: usize,
36    pub doc_freqs: HashMap<String, usize>,
37}
38
39#[derive(Debug, Clone)]
40pub struct SearchResult {
41    pub chunk_idx: usize,
42    pub score: f64,
43    pub file_path: String,
44    pub symbol_name: String,
45    pub kind: ChunkKind,
46    pub start_line: usize,
47    pub end_line: usize,
48    pub snippet: String,
49}
50
/// BM25 term-frequency saturation parameter (`k1`): larger values let repeated terms keep
/// adding score for longer before saturating.
const BM25_K1: f64 = 1.2_f64;
/// BM25 length-normalisation strength (`b`): 0 ignores chunk length, 1 normalises fully by
/// the chunk's length relative to the average.
const BM25_B: f64 = 0.75_f64;
53
54impl Default for BM25Index {
55    fn default() -> Self {
56        Self::new()
57    }
58}
59
60impl BM25Index {
61    pub fn new() -> Self {
62        Self {
63            chunks: Vec::new(),
64            inverted: HashMap::new(),
65            avg_doc_len: 0.0,
66            doc_count: 0,
67            doc_freqs: HashMap::new(),
68        }
69    }
70
71    pub fn build_from_directory(root: &Path) -> Self {
72        let mut index = Self::new();
73        let walker = ignore::WalkBuilder::new(root)
74            .hidden(true)
75            .git_ignore(true)
76            .max_depth(Some(10))
77            .build();
78
79        let mut file_count = 0usize;
80        for entry in walker.flatten() {
81            if file_count >= 2000 {
82                break;
83            }
84            let path = entry.path();
85            if !path.is_file() {
86                continue;
87            }
88            if !is_code_file(path) {
89                continue;
90            }
91            if let Ok(content) = std::fs::read_to_string(path) {
92                let rel = path
93                    .strip_prefix(root)
94                    .unwrap_or(path)
95                    .to_string_lossy()
96                    .to_string();
97                let chunks = extract_chunks(&rel, &content);
98                for chunk in chunks {
99                    index.add_chunk(chunk);
100                }
101                file_count += 1;
102            }
103        }
104
105        index.finalize();
106        index
107    }
108
109    fn add_chunk(&mut self, chunk: CodeChunk) {
110        let idx = self.chunks.len();
111
112        for token in &chunk.tokens {
113            let lower = token.to_lowercase();
114            self.inverted.entry(lower).or_default().push((idx, 1.0));
115        }
116
117        self.chunks.push(chunk);
118    }
119
120    fn finalize(&mut self) {
121        self.doc_count = self.chunks.len();
122        if self.doc_count == 0 {
123            return;
124        }
125
126        let total_len: usize = self.chunks.iter().map(|c| c.token_count).sum();
127        self.avg_doc_len = total_len as f64 / self.doc_count as f64;
128
129        self.doc_freqs.clear();
130        for (term, postings) in &self.inverted {
131            let unique_docs: std::collections::HashSet<usize> =
132                postings.iter().map(|(idx, _)| *idx).collect();
133            self.doc_freqs.insert(term.clone(), unique_docs.len());
134        }
135    }
136
137    pub fn search(&self, query: &str, top_k: usize) -> Vec<SearchResult> {
138        let query_tokens = tokenize(query);
139        if query_tokens.is_empty() || self.doc_count == 0 {
140            return Vec::new();
141        }
142
143        let mut scores: HashMap<usize, f64> = HashMap::new();
144
145        for token in &query_tokens {
146            let lower = token.to_lowercase();
147            let df = *self.doc_freqs.get(&lower).unwrap_or(&0) as f64;
148            if df == 0.0 {
149                continue;
150            }
151
152            let idf = ((self.doc_count as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
153
154            if let Some(postings) = self.inverted.get(&lower) {
155                let mut doc_tfs: HashMap<usize, f64> = HashMap::new();
156                for (idx, weight) in postings {
157                    *doc_tfs.entry(*idx).or_insert(0.0) += weight;
158                }
159
160                for (doc_idx, tf) in &doc_tfs {
161                    let doc_len = self.chunks[*doc_idx].token_count as f64;
162                    let norm_len = doc_len / self.avg_doc_len.max(1.0);
163                    let bm25 = idf * (tf * (BM25_K1 + 1.0))
164                        / (tf + BM25_K1 * (1.0 - BM25_B + BM25_B * norm_len));
165
166                    *scores.entry(*doc_idx).or_insert(0.0) += bm25;
167                }
168            }
169        }
170
171        let mut results: Vec<SearchResult> = scores
172            .into_iter()
173            .map(|(idx, score)| {
174                let chunk = &self.chunks[idx];
175                let snippet = chunk.content.lines().take(5).collect::<Vec<_>>().join("\n");
176                SearchResult {
177                    chunk_idx: idx,
178                    score,
179                    file_path: chunk.file_path.clone(),
180                    symbol_name: chunk.symbol_name.clone(),
181                    kind: chunk.kind.clone(),
182                    start_line: chunk.start_line,
183                    end_line: chunk.end_line,
184                    snippet,
185                }
186            })
187            .collect();
188
189        results.sort_by(|a, b| {
190            b.score
191                .partial_cmp(&a.score)
192                .unwrap_or(std::cmp::Ordering::Equal)
193        });
194        results.truncate(top_k);
195        results
196    }
197
198    pub fn save(&self, root: &Path) -> std::io::Result<()> {
199        let dir = index_dir(root);
200        std::fs::create_dir_all(&dir)?;
201        let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
202        std::fs::write(dir.join("bm25_index.json"), data)?;
203        Ok(())
204    }
205
206    pub fn load(root: &Path) -> Option<Self> {
207        let path = index_dir(root).join("bm25_index.json");
208        let data = std::fs::read_to_string(path).ok()?;
209        serde_json::from_str(&data).ok()
210    }
211
212    pub fn load_or_build(root: &Path) -> Self {
213        if let Some(idx) = Self::load(root) {
214            if !vector_index_looks_stale(&idx, root) {
215                return idx;
216            }
217            tracing::warn!(
218                "[vector_index: stale index detected for {}; rebuilding]",
219                root.display()
220            );
221        }
222
223        let built = Self::build_from_directory(root);
224        let _ = built.save(root);
225        built
226    }
227
228    pub fn index_file_path(root: &Path) -> PathBuf {
229        index_dir(root).join("bm25_index.json")
230    }
231}
232
233fn vector_index_looks_stale(index: &BM25Index, root: &Path) -> bool {
234    if index.chunks.is_empty() {
235        return false;
236    }
237
238    let mut seen = std::collections::HashSet::<&str>::new();
239    for chunk in &index.chunks {
240        let rel = chunk.file_path.trim_start_matches(['/', '\\']);
241        if rel.is_empty() {
242            continue;
243        }
244        if !seen.insert(rel) {
245            continue;
246        }
247        if !root.join(rel).exists() {
248            return true;
249        }
250    }
251
252    false
253}
254
255fn index_dir(root: &Path) -> PathBuf {
256    let mut hasher = Md5::new();
257    hasher.update(root.to_string_lossy().as_bytes());
258    let hash = format!("{:x}", hasher.finalize());
259    crate::core::data_dir::lean_ctx_data_dir()
260        .unwrap_or_else(|_| PathBuf::from("."))
261        .join("vectors")
262        .join(hash)
263}
264
/// Returns `true` when `path` has one of the recognised source-code extensions.
///
/// The comparison is case-sensitive. Paths without a UTF-8 extension are rejected.
pub(crate) fn is_code_file(path: &Path) -> bool {
    const CODE_EXTENSIONS: &[&str] = &[
        "rs", "ts", "tsx", "js", "jsx", "py", "go", "java", "c", "cpp", "h", "hpp", "rb", "cs",
        "kt", "swift", "php", "scala", "ex", "exs", "zig", "lua", "dart", "vue", "svelte",
    ];

    path.extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| CODE_EXTENSIONS.contains(&ext))
}
295
/// Splits `text` into identifier-like tokens, used both to build and to query the index.
///
/// A token is a maximal run of alphanumeric characters and `_` that is at least two
/// *characters* long. Length is counted in `char`s, not UTF-8 bytes, so a lone non-ASCII
/// letter such as `é` is dropped exactly like a lone `e`. Each token is followed by its
/// camelCase parts (see `split_camel_case_tokens`). Case is preserved.
fn tokenize(text: &str) -> Vec<String> {
    let words: Vec<String> = text
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|word| word.chars().count() >= 2)
        .map(str::to_owned)
        .collect();
    split_camel_case_tokens(&words)
}

/// Crate-visible wrapper around `tokenize`; returns exactly what `tokenize` returns.
pub(crate) fn tokenize_for_index(text: &str) -> Vec<String> {
    tokenize(text)
}

/// Returns each input token followed by its camelCase / PascalCase parts.
///
/// A boundary is placed before an uppercase character in two cases:
/// - it follows a non-uppercase character (`fooBar` -> `foo`, `Bar`; `v2API` -> `v2`, `API`);
/// - it is the last capital of an acronym that is directly followed by a lowercase letter
///   (`HTTPServer` -> `HTTP`, `Server`).
///
/// A trailing acronym therefore stays whole (`parseHTTP` -> `parse`, `HTTP`). Parts shorter
/// than two characters are dropped. Tokens without an interior boundary contribute only
/// themselves.
fn split_camel_case_tokens(tokens: &[String]) -> Vec<String> {
    // Pushes `part` as a token if it is at least two characters long.
    fn push_part(out: &mut Vec<String>, part: &[char]) {
        if part.len() >= 2 {
            out.push(part.iter().collect());
        }
    }

    let mut result = Vec::with_capacity(tokens.len());
    for token in tokens {
        result.push(token.clone());

        let chars: Vec<char> = token.chars().collect();
        let mut part_start = 0;
        for i in 1..chars.len() {
            if !chars[i].is_uppercase() {
                continue;
            }
            let after_non_upper = !chars[i - 1].is_uppercase();
            let ends_acronym = chars.get(i + 1).is_some_and(|c| c.is_lowercase());
            if after_non_upper || ends_acronym {
                push_part(&mut result, &chars[part_start..i]);
                part_start = i;
            }
        }
        if part_start > 0 {
            push_part(&mut result, &chars[part_start..]);
        }
    }
    result
}
345
346fn extract_chunks(file_path: &str, content: &str) -> Vec<CodeChunk> {
347    #[cfg(feature = "tree-sitter")]
348    {
349        let ext = std::path::Path::new(file_path)
350            .extension()
351            .and_then(|e| e.to_str())
352            .unwrap_or("");
353        if let Some(chunks) = crate::core::chunks_ts::extract_chunks_ts(file_path, content, ext) {
354            return chunks;
355        }
356    }
357
358    let lines: Vec<&str> = content.lines().collect();
359    if lines.is_empty() {
360        return Vec::new();
361    }
362
363    let mut chunks = Vec::new();
364    let mut i = 0;
365
366    while i < lines.len() {
367        let trimmed = lines[i].trim();
368
369        if let Some((name, kind)) = detect_symbol(trimmed) {
370            let start = i;
371            let end = find_block_end(&lines, i);
372            let block: String = lines[start..=end.min(lines.len() - 1)].to_vec().join("\n");
373            let tokens = tokenize(&block);
374            let token_count = tokens.len();
375
376            chunks.push(CodeChunk {
377                file_path: file_path.to_string(),
378                symbol_name: name,
379                kind,
380                start_line: start + 1,
381                end_line: end + 1,
382                content: block,
383                tokens,
384                token_count,
385            });
386
387            i = end + 1;
388        } else {
389            i += 1;
390        }
391    }
392
393    if chunks.is_empty() && !content.is_empty() {
394        // Fallback: when no symbols are detected, chunk the file into stable, content-defined
395        // segments (rolling-hash) to enable meaningful semantic search over non-code assets.
396        //
397        // Safety note: rabin_karp uses byte offsets; we must slice bytes and decode safely.
398        let bytes = content.as_bytes();
399        let rk_chunks = crate::core::rabin_karp::chunk(content);
400        if !rk_chunks.is_empty() && rk_chunks.len() <= 200 {
401            for (idx, c) in rk_chunks.into_iter().take(50).enumerate() {
402                let end = (c.offset + c.length).min(bytes.len());
403                let slice = &bytes[c.offset..end];
404                let chunk_text = String::from_utf8_lossy(slice).into_owned();
405                let tokens = tokenize(&chunk_text);
406                let token_count = tokens.len();
407                let start_line = 1 + bytecount::count(&bytes[..c.offset], b'\n');
408                let end_line = start_line + bytecount::count(slice, b'\n');
409                chunks.push(CodeChunk {
410                    file_path: file_path.to_string(),
411                    symbol_name: format!("{file_path}#chunk-{idx}"),
412                    kind: ChunkKind::Module,
413                    start_line,
414                    end_line: end_line.max(start_line),
415                    content: chunk_text,
416                    tokens,
417                    token_count,
418                });
419            }
420        } else {
421            let tokens = tokenize(content);
422            let token_count = tokens.len();
423            let snippet = lines
424                .iter()
425                .take(50)
426                .copied()
427                .collect::<Vec<_>>()
428                .join("\n");
429            chunks.push(CodeChunk {
430                file_path: file_path.to_string(),
431                symbol_name: file_path.to_string(),
432                kind: ChunkKind::Module,
433                start_line: 1,
434                end_line: lines.len(),
435                content: snippet,
436                tokens,
437                token_count,
438            });
439        }
440    }
441
442    chunks
443}
444
445fn detect_symbol(line: &str) -> Option<(String, ChunkKind)> {
446    let trimmed = line.trim();
447
448    let patterns: &[(&str, ChunkKind)] = &[
449        ("pub async fn ", ChunkKind::Function),
450        ("async fn ", ChunkKind::Function),
451        ("pub fn ", ChunkKind::Function),
452        ("fn ", ChunkKind::Function),
453        ("pub struct ", ChunkKind::Struct),
454        ("struct ", ChunkKind::Struct),
455        ("pub enum ", ChunkKind::Struct),
456        ("enum ", ChunkKind::Struct),
457        ("impl ", ChunkKind::Impl),
458        ("pub trait ", ChunkKind::Struct),
459        ("trait ", ChunkKind::Struct),
460        ("export function ", ChunkKind::Function),
461        ("export async function ", ChunkKind::Function),
462        ("export default function ", ChunkKind::Function),
463        ("function ", ChunkKind::Function),
464        ("async function ", ChunkKind::Function),
465        ("export class ", ChunkKind::Class),
466        ("class ", ChunkKind::Class),
467        ("export interface ", ChunkKind::Struct),
468        ("interface ", ChunkKind::Struct),
469        ("def ", ChunkKind::Function),
470        ("async def ", ChunkKind::Function),
471        ("class ", ChunkKind::Class),
472        ("func ", ChunkKind::Function),
473    ];
474
475    for (prefix, kind) in patterns {
476        if let Some(rest) = trimmed.strip_prefix(prefix) {
477            let name: String = rest
478                .chars()
479                .take_while(|c| c.is_alphanumeric() || *c == '_' || *c == '<')
480                .take_while(|c| *c != '<')
481                .collect();
482            if !name.is_empty() {
483                return Some((name, kind.clone()));
484            }
485        }
486    }
487
488    None
489}
490
/// Returns the 0-based index of the last line of the symbol block whose header is
/// `lines[start]`.
///
/// Heuristics:
/// - `(`/`)` and `[`/`]` are tracked only so that a multi-line signature is not mistaken for
///   the end of the header. They never open or close a block.
/// - A `{` outside the signature's parentheses opens a brace-delimited body. The block ends
///   on the line where the brace depth returns to zero. If that never happens, the block is
///   capped at 50 lines past `start` (bounded by the last line).
/// - If the header completes without a `{`:
///   - a header ending in `;` is a body-less declaration (`struct Unit;`, a trait method
///     signature) and ends on that line;
///   - a header ending in `:` opens an indentation-delimited body (Python). Braces are
///     ignored there, so dict/set literals cannot end the block.
/// - Until a brace body opens, the first non-blank line indented no deeper than the header
///   ends the block on the preceding non-blank line. Exempt from this are lines starting with
///   `{`, a closing delimiter, or a `where` clause. Reaching the end of input ends the block
///   on the last non-blank line.
///
/// Braces inside string literals or comments are not special-cased.
fn find_block_end(lines: &[&str], start: usize) -> usize {
    // Width of a line's leading whitespace, in bytes.
    fn indent_of(line: &str) -> usize {
        line.len() - line.trim_start().len()
    }

    let last = lines.len().saturating_sub(1);
    if start >= last {
        return last;
    }

    let base_indent = indent_of(lines[start]);
    let mut parens = 0usize;
    let mut braces = 0usize;
    let mut brace_body = false;
    let mut indent_body = false;
    let mut last_content = start;

    for (i, &line) in lines.iter().enumerate().skip(start) {
        let trimmed = line.trim();
        let header_continues = parens > 0;

        // Dedenting back to the header's level ends a block that has no brace body.
        if i > start
            && !brace_body
            && !header_continues
            && !trimmed.is_empty()
            && indent_of(line) <= base_indent
            && !trimmed.starts_with(['{', ')', ']', '>'])
            && trimmed != "where"
            && !trimmed.starts_with("where ")
        {
            return last_content;
        }

        if !indent_body {
            for ch in line.chars() {
                match ch {
                    '(' | '[' => parens += 1,
                    ')' | ']' => parens = parens.saturating_sub(1),
                    // A `{` inside the signature's parentheses (e.g. JS destructuring) is not
                    // the body.
                    '{' if brace_body || parens == 0 => {
                        braces += 1;
                        brace_body = true;
                    }
                    '}' if braces > 0 => {
                        braces -= 1;
                        if braces == 0 {
                            return i;
                        }
                    }
                    _ => {}
                }
            }
            if !brace_body && parens == 0 {
                if trimmed.ends_with(';') {
                    return i;
                }
                indent_body = trimmed.ends_with(':');
            }
        }

        if !trimmed.is_empty() {
            last_content = i;
        }
    }

    if brace_body {
        (start + 50).min(last)
    } else {
        last_content
    }
}
528
529pub fn format_search_results(results: &[SearchResult], compact: bool) -> String {
530    if results.is_empty() {
531        return "No results found.".to_string();
532    }
533
534    let mut out = String::new();
535    for (i, r) in results.iter().enumerate() {
536        if compact {
537            out.push_str(&format!(
538                "{}. {:.2} {}:{}-{} {:?} {}\n",
539                i + 1,
540                r.score,
541                r.file_path,
542                r.start_line,
543                r.end_line,
544                r.kind,
545                r.symbol_name,
546            ));
547        } else {
548            out.push_str(&format!(
549                "\n--- Result {} (score: {:.2}) ---\n{} :: {} [{:?}] (L{}-{})\n{}\n",
550                i + 1,
551                r.score,
552                r.file_path,
553                r.symbol_name,
554                r.kind,
555                r.start_line,
556                r.end_line,
557                r.snippet,
558            ));
559        }
560    }
561    out
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// True when `tokens` contains `wanted` verbatim.
    fn contains_token(tokens: &[String], wanted: &str) -> bool {
        tokens.iter().any(|token| token == wanted)
    }

    #[test]
    fn tokenize_splits_code() {
        let tokens = tokenize("fn calculate_total(items: Vec<Item>) -> f64");
        for expected in ["calculate_total", "items", "Vec"] {
            assert!(contains_token(&tokens, expected));
        }
    }

    #[test]
    fn camel_case_splitting() {
        let tokens = split_camel_case_tokens(&["calculateTotal".to_string()]);
        for expected in ["calculateTotal", "calculate", "Total"] {
            assert!(contains_token(&tokens, expected));
        }
    }

    #[test]
    fn detect_rust_function() {
        let header = "pub fn process_request(req: Request) -> Response {";
        let (name, kind) = detect_symbol(header).unwrap();
        assert_eq!(name, "process_request");
        assert_eq!(kind, ChunkKind::Function);
    }

    #[test]
    fn bm25_search_finds_relevant() {
        let auth = CodeChunk {
            file_path: "auth.rs".into(),
            symbol_name: "validate_token".into(),
            kind: ChunkKind::Function,
            start_line: 1,
            end_line: 10,
            content: "fn validate_token(token: &str) -> bool { check_jwt_expiry(token) }".into(),
            tokens: tokenize("fn validate_token token str bool check_jwt_expiry token"),
            token_count: 8,
        };
        let db = CodeChunk {
            file_path: "db.rs".into(),
            symbol_name: "connect_database".into(),
            kind: ChunkKind::Function,
            start_line: 1,
            end_line: 5,
            content: "fn connect_database(url: &str) -> Pool { create_pool(url) }".into(),
            tokens: tokenize("fn connect_database url str Pool create_pool url"),
            token_count: 7,
        };

        let mut index = BM25Index::new();
        for chunk in [auth, db] {
            index.add_chunk(chunk);
        }
        index.finalize();

        let results = index.search("jwt token validation", 5);
        assert!(!results.is_empty());
        let best = &results[0];
        assert_eq!(best.symbol_name, "validate_token");
    }

    #[test]
    fn vector_index_is_stale_when_any_indexed_file_is_missing() {
        let td = tempdir().expect("tempdir");
        let root = td.path();
        let source = root.join("a.rs");
        std::fs::write(&source, "pub fn a() {}\n").expect("write a.rs");

        let idx = BM25Index::build_from_directory(root);
        assert!(!vector_index_looks_stale(&idx, root));

        std::fs::remove_file(&source).expect("remove a.rs");
        assert!(vector_index_looks_stale(&idx, root));
    }
}