//! lean_ctx/core/vector_index.rs — BM25 lexical search index over code chunks.
1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use md5::{Digest, Md5};
5use serde::{Deserialize, Serialize};
6
/// One indexed unit of source code: a function, type, impl block, or — as a
/// fallback — a whole-file "module" chunk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChunk {
    /// Source file path (relative to the indexed root when possible).
    pub file_path: String,
    /// Name of the symbol this chunk covers (the file path for module chunks).
    pub symbol_name: String,
    /// Syntactic category assigned by the chunk extractor.
    pub kind: ChunkKind,
    /// 1-based first line of the chunk in the source file.
    pub start_line: usize,
    /// 1-based last line of the chunk (inclusive).
    pub end_line: usize,
    /// Raw chunk text (may be truncated for whole-file fallback chunks).
    pub content: String,
    /// Tokens extracted from the chunk; feed the inverted index.
    pub tokens: Vec<String>,
    /// Cached `tokens.len()`; used as the BM25 document length.
    pub token_count: usize,
}
18
/// Syntactic category of a [`CodeChunk`], as guessed by the symbol detector.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ChunkKind {
    Function,
    Struct,
    Impl,
    /// Whole-file fallback chunk used when no symbols were detected.
    Module,
    Class,
    Method,
    Other,
}
29
/// In-memory Okapi BM25 inverted index over [`CodeChunk`]s.
///
/// Built with `add_chunk` + `finalize` (or `build_from_directory`), then
/// queried with `search`. Serializable so it can be cached on disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Index {
    /// All indexed chunks; posting lists refer to positions in this vector.
    pub chunks: Vec<CodeChunk>,
    /// term -> postings of (chunk index, term-frequency weight).
    pub inverted: HashMap<String, Vec<(usize, f64)>>,
    /// Mean token count across all chunks (BM25 length normalization).
    pub avg_doc_len: f64,
    /// Number of indexed chunks.
    pub doc_count: usize,
    /// term -> number of distinct chunks containing the term.
    pub doc_freqs: HashMap<String, usize>,
}
38
/// One ranked hit returned by [`BM25Index::search`].
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Index of the matching chunk in `BM25Index::chunks`.
    pub chunk_idx: usize,
    /// Accumulated BM25 score (higher is more relevant).
    pub score: f64,
    pub file_path: String,
    pub symbol_name: String,
    pub kind: ChunkKind,
    pub start_line: usize,
    pub end_line: usize,
    /// First few lines of the chunk, for display.
    pub snippet: String,
}
50
/// BM25 term-frequency saturation parameter (standard default).
const BM25_K1: f64 = 1.2;
/// BM25 document-length normalization strength (standard default).
const BM25_B: f64 = 0.75;
53
impl Default for BM25Index {
    /// An empty index; equivalent to [`BM25Index::new`].
    fn default() -> Self {
        Self::new()
    }
}
59
60impl BM25Index {
61    pub fn new() -> Self {
62        Self {
63            chunks: Vec::new(),
64            inverted: HashMap::new(),
65            avg_doc_len: 0.0,
66            doc_count: 0,
67            doc_freqs: HashMap::new(),
68        }
69    }
70
71    pub fn build_from_directory(root: &Path) -> Self {
72        let mut index = Self::new();
73        let walker = ignore::WalkBuilder::new(root)
74            .hidden(true)
75            .git_ignore(true)
76            .max_depth(Some(10))
77            .build();
78
79        let mut file_count = 0usize;
80        for entry in walker.flatten() {
81            if file_count >= 2000 {
82                break;
83            }
84            let path = entry.path();
85            if !path.is_file() {
86                continue;
87            }
88            if !is_code_file(path) {
89                continue;
90            }
91            if let Ok(content) = std::fs::read_to_string(path) {
92                let rel = path
93                    .strip_prefix(root)
94                    .unwrap_or(path)
95                    .to_string_lossy()
96                    .to_string();
97                let ext = path
98                    .extension()
99                    .and_then(|e| e.to_str())
100                    .unwrap_or("");
101                let chunks = super::chunks_ts::extract_chunks_ts(&rel, &content, ext)
102                    .unwrap_or_else(|| extract_chunks(&rel, &content));
103                for chunk in chunks {
104                    index.add_chunk(chunk);
105                }
106                file_count += 1;
107            }
108        }
109
110        index.finalize();
111        index
112    }
113
114    fn add_chunk(&mut self, chunk: CodeChunk) {
115        let idx = self.chunks.len();
116
117        for token in &chunk.tokens {
118            let lower = token.to_lowercase();
119            self.inverted.entry(lower).or_default().push((idx, 1.0));
120        }
121
122        self.chunks.push(chunk);
123    }
124
125    fn finalize(&mut self) {
126        self.doc_count = self.chunks.len();
127        if self.doc_count == 0 {
128            return;
129        }
130
131        let total_len: usize = self.chunks.iter().map(|c| c.token_count).sum();
132        self.avg_doc_len = total_len as f64 / self.doc_count as f64;
133
134        self.doc_freqs.clear();
135        for (term, postings) in &self.inverted {
136            let unique_docs: std::collections::HashSet<usize> =
137                postings.iter().map(|(idx, _)| *idx).collect();
138            self.doc_freqs.insert(term.clone(), unique_docs.len());
139        }
140    }
141
142    pub fn search(&self, query: &str, top_k: usize) -> Vec<SearchResult> {
143        let query_tokens = tokenize(query);
144        if query_tokens.is_empty() || self.doc_count == 0 {
145            return Vec::new();
146        }
147
148        let mut scores: HashMap<usize, f64> = HashMap::new();
149
150        for token in &query_tokens {
151            let lower = token.to_lowercase();
152            let df = *self.doc_freqs.get(&lower).unwrap_or(&0) as f64;
153            if df == 0.0 {
154                continue;
155            }
156
157            let idf = ((self.doc_count as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
158
159            if let Some(postings) = self.inverted.get(&lower) {
160                let mut doc_tfs: HashMap<usize, f64> = HashMap::new();
161                for (idx, weight) in postings {
162                    *doc_tfs.entry(*idx).or_insert(0.0) += weight;
163                }
164
165                for (doc_idx, tf) in &doc_tfs {
166                    let doc_len = self.chunks[*doc_idx].token_count as f64;
167                    let norm_len = doc_len / self.avg_doc_len.max(1.0);
168                    let bm25 = idf * (tf * (BM25_K1 + 1.0))
169                        / (tf + BM25_K1 * (1.0 - BM25_B + BM25_B * norm_len));
170
171                    *scores.entry(*doc_idx).or_insert(0.0) += bm25;
172                }
173            }
174        }
175
176        let mut results: Vec<SearchResult> = scores
177            .into_iter()
178            .map(|(idx, score)| {
179                let chunk = &self.chunks[idx];
180                let snippet = chunk.content.lines().take(5).collect::<Vec<_>>().join("\n");
181                SearchResult {
182                    chunk_idx: idx,
183                    score,
184                    file_path: chunk.file_path.clone(),
185                    symbol_name: chunk.symbol_name.clone(),
186                    kind: chunk.kind.clone(),
187                    start_line: chunk.start_line,
188                    end_line: chunk.end_line,
189                    snippet,
190                }
191            })
192            .collect();
193
194        results.sort_by(|a, b| {
195            b.score
196                .partial_cmp(&a.score)
197                .unwrap_or(std::cmp::Ordering::Equal)
198        });
199        results.truncate(top_k);
200        results
201    }
202
203    pub fn save(&self, root: &Path) -> std::io::Result<()> {
204        let dir = index_dir(root);
205        std::fs::create_dir_all(&dir)?;
206        let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
207        std::fs::write(dir.join("bm25_index.json"), data)?;
208        Ok(())
209    }
210
211    pub fn load(root: &Path) -> Option<Self> {
212        let path = index_dir(root).join("bm25_index.json");
213        let data = std::fs::read_to_string(path).ok()?;
214        serde_json::from_str(&data).ok()
215    }
216}
217
218fn index_dir(root: &Path) -> PathBuf {
219    let mut hasher = Md5::new();
220    hasher.update(root.to_string_lossy().as_bytes());
221    let hash = format!("{:x}", hasher.finalize());
222    dirs::home_dir()
223        .unwrap_or_else(|| PathBuf::from("."))
224        .join(".lean-ctx")
225        .join("vectors")
226        .join(hash)
227}
228
/// Whether `path` carries a file extension this indexer treats as source code.
fn is_code_file(path: &Path) -> bool {
    const CODE_EXTENSIONS: &[&str] = &[
        "rs", "ts", "tsx", "js", "jsx", "py", "go", "java", "c", "cpp", "h",
        "hpp", "rb", "cs", "kt", "swift", "php", "scala", "ex", "exs", "zig",
        "lua", "dart", "vue", "svelte",
    ];

    match path.extension().and_then(|e| e.to_str()) {
        Some(ext) => CODE_EXTENSIONS.contains(&ext),
        None => false,
    }
}
259
/// Public tokenization entry point for callers outside this module; applies
/// exactly the same pipeline used when the index is built, so query-side
/// and index-side tokens match.
pub fn tokenize_for_index(text: &str) -> Vec<String> {
    tokenize(text)
}
263
264fn tokenize(text: &str) -> Vec<String> {
265    let mut tokens = Vec::new();
266    let mut current = String::new();
267
268    for ch in text.chars() {
269        if ch.is_alphanumeric() || ch == '_' {
270            current.push(ch);
271        } else {
272            if current.len() >= 2 {
273                tokens.push(current.clone());
274            }
275            current.clear();
276        }
277    }
278    if current.len() >= 2 {
279        tokens.push(current);
280    }
281
282    let mut expanded = Vec::new();
283    for token in &tokens {
284        expanded.push(token.clone());
285        if token.contains('_') {
286            for part in token.split('_') {
287                if part.len() >= 2 {
288                    expanded.push(part.to_string());
289                }
290            }
291        }
292    }
293
294    split_camel_case_tokens(&expanded)
295}
296
/// Expands camelCase/PascalCase tokens into their humps, keeping each
/// original token first.
///
/// A hump boundary is an uppercase character that is NOT followed by
/// another uppercase one, so acronym runs stay together ("HTTPServer" ->
/// "HTTP" + "Server"). Pieces shorter than two bytes are dropped.
fn split_camel_case_tokens(tokens: &[String]) -> Vec<String> {
    let mut out = Vec::new();
    for token in tokens {
        out.push(token.clone());
        let chars: Vec<char> = token.chars().collect();

        let boundaries: Vec<usize> = (1..chars.len())
            .filter(|&i| {
                chars[i].is_uppercase()
                    && chars.get(i + 1).map_or(true, |c| !c.is_uppercase())
            })
            .collect();

        let mut begin = 0usize;
        for &b in &boundaries {
            let piece: String = chars[begin..b].iter().collect();
            if piece.len() >= 2 {
                out.push(piece);
            }
            begin = b;
        }
        // Tail hump, only when at least one boundary was found.
        if begin > 0 {
            let piece: String = chars[begin..].iter().collect();
            if piece.len() >= 2 {
                out.push(piece);
            }
        }
    }
    out
}
321
322fn extract_chunks(file_path: &str, content: &str) -> Vec<CodeChunk> {
323    let lines: Vec<&str> = content.lines().collect();
324    if lines.is_empty() {
325        return Vec::new();
326    }
327
328    let mut chunks = Vec::new();
329    let mut i = 0;
330
331    while i < lines.len() {
332        let trimmed = lines[i].trim();
333
334        if let Some((name, kind)) = detect_symbol(trimmed) {
335            let start = i;
336            let end = find_block_end(&lines, i);
337            let block: String = lines[start..=end.min(lines.len() - 1)].to_vec().join("\n");
338            let tokens = tokenize(&block);
339            let token_count = tokens.len();
340
341            chunks.push(CodeChunk {
342                file_path: file_path.to_string(),
343                symbol_name: name,
344                kind,
345                start_line: start + 1,
346                end_line: end + 1,
347                content: block,
348                tokens,
349                token_count,
350            });
351
352            i = end + 1;
353        } else {
354            i += 1;
355        }
356    }
357
358    if chunks.is_empty() && !content.is_empty() {
359        let tokens = tokenize(content);
360        let token_count = tokens.len();
361        let snippet = lines
362            .iter()
363            .take(50)
364            .copied()
365            .collect::<Vec<_>>()
366            .join("\n");
367        chunks.push(CodeChunk {
368            file_path: file_path.to_string(),
369            symbol_name: file_path.to_string(),
370            kind: ChunkKind::Module,
371            start_line: 1,
372            end_line: lines.len(),
373            content: snippet,
374            tokens,
375            token_count,
376        });
377    }
378
379    chunks
380}
381
382fn detect_symbol(line: &str) -> Option<(String, ChunkKind)> {
383    let trimmed = line.trim();
384
385    let patterns: &[(&str, ChunkKind)] = &[
386        ("pub async fn ", ChunkKind::Function),
387        ("async fn ", ChunkKind::Function),
388        ("pub fn ", ChunkKind::Function),
389        ("fn ", ChunkKind::Function),
390        ("pub struct ", ChunkKind::Struct),
391        ("struct ", ChunkKind::Struct),
392        ("pub enum ", ChunkKind::Struct),
393        ("enum ", ChunkKind::Struct),
394        ("impl ", ChunkKind::Impl),
395        ("pub trait ", ChunkKind::Struct),
396        ("trait ", ChunkKind::Struct),
397        ("export function ", ChunkKind::Function),
398        ("export async function ", ChunkKind::Function),
399        ("export default function ", ChunkKind::Function),
400        ("function ", ChunkKind::Function),
401        ("async function ", ChunkKind::Function),
402        ("export class ", ChunkKind::Class),
403        ("class ", ChunkKind::Class),
404        ("export interface ", ChunkKind::Struct),
405        ("interface ", ChunkKind::Struct),
406        ("def ", ChunkKind::Function),
407        ("async def ", ChunkKind::Function),
408        ("class ", ChunkKind::Class),
409        ("func ", ChunkKind::Function),
410    ];
411
412    for (prefix, kind) in patterns {
413        if let Some(rest) = trimmed.strip_prefix(prefix) {
414            let name: String = rest
415                .chars()
416                .take_while(|c| c.is_alphanumeric() || *c == '_' || *c == '<')
417                .take_while(|c| *c != '<')
418                .collect();
419            if !name.is_empty() {
420                return Some((name, kind.clone()));
421            }
422        }
423    }
424
425    None
426}
427
/// Finds the last line (0-based, inclusive) of the block that starts at
/// `lines[start]`.
///
/// Strategy: track `{`/`}` nesting and end the block on the brace that
/// returns the depth to zero. Parentheses are deliberately NOT counted:
/// the original version counted `(`/`)` too, so a signature such as
/// `fn foo(a: i32) {` opened and immediately closed on its own line and
/// every C-style function was truncated to a single line. For brace-less
/// (indentation-based) code, the block ends just before the first blank
/// or column-zero line found after a couple of lines. As a last resort,
/// the block is capped at 50 lines (clamped to the end of the input).
fn find_block_end(lines: &[&str], start: usize) -> usize {
    let mut depth = 0i32;
    let mut found_open = false;

    for (i, line) in lines.iter().enumerate().skip(start) {
        for ch in line.chars() {
            match ch {
                '{' => {
                    depth += 1;
                    found_open = true;
                }
                '}' if depth > 0 => {
                    depth -= 1;
                    if depth == 0 {
                        // Matching close brace for the block opener.
                        return i;
                    }
                }
                _ => {}
            }
        }

        if !found_open && i > start + 2 {
            // Dedent heuristic for brace-less blocks. Inspect the RAW
            // line: the original trimmed first, which made the leading-
            // whitespace test vacuously true and cut every brace-less
            // block at three lines.
            let raw = lines[i];
            if raw.trim().is_empty() || (!raw.starts_with(' ') && !raw.starts_with('\t')) {
                return i.saturating_sub(1);
            }
        }
    }

    (start + 50).min(lines.len().saturating_sub(1))
}
465
466pub fn format_search_results(results: &[SearchResult], compact: bool) -> String {
467    if results.is_empty() {
468        return "No results found.".to_string();
469    }
470
471    let mut out = String::new();
472    for (i, r) in results.iter().enumerate() {
473        if compact {
474            out.push_str(&format!(
475                "{}. {:.2} {}:{}-{} {:?} {}\n",
476                i + 1,
477                r.score,
478                r.file_path,
479                r.start_line,
480                r.end_line,
481                r.kind,
482                r.symbol_name,
483            ));
484        } else {
485            out.push_str(&format!(
486                "\n--- Result {} (score: {:.2}) ---\n{} :: {} [{:?}] (L{}-{})\n{}\n",
487                i + 1,
488                r.score,
489                r.file_path,
490                r.symbol_name,
491                r.kind,
492                r.start_line,
493                r.end_line,
494                r.snippet,
495            ));
496        }
497    }
498    out
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    // The tokenizer keeps identifiers, snake_case tokens, and type names
    // taken from a Rust signature.
    #[test]
    fn tokenize_splits_code() {
        let tokens = tokenize("fn calculate_total(items: Vec<Item>) -> f64");
        assert!(tokens.contains(&"calculate_total".to_string()));
        assert!(tokens.contains(&"items".to_string()));
        assert!(tokens.contains(&"Vec".to_string()));
    }

    // camelCase tokens expand into their humps while the original token
    // is preserved alongside them.
    #[test]
    fn camel_case_splitting() {
        let tokens = split_camel_case_tokens(&["calculateTotal".to_string()]);
        assert!(tokens.contains(&"calculateTotal".to_string()));
        assert!(tokens.contains(&"calculate".to_string()));
        assert!(tokens.contains(&"Total".to_string()));
    }

    // A `pub fn` line yields the bare function name and Function kind.
    #[test]
    fn detect_rust_function() {
        let (name, kind) =
            detect_symbol("pub fn process_request(req: Request) -> Response {").unwrap();
        assert_eq!(name, "process_request");
        assert_eq!(kind, ChunkKind::Function);
    }

    // End-to-end: the chunk sharing vocabulary with the query outranks an
    // unrelated chunk.
    #[test]
    fn bm25_search_finds_relevant() {
        let mut index = BM25Index::new();
        index.add_chunk(CodeChunk {
            file_path: "auth.rs".into(),
            symbol_name: "validate_token".into(),
            kind: ChunkKind::Function,
            start_line: 1,
            end_line: 10,
            content: "fn validate_token(token: &str) -> bool { check_jwt_expiry(token) }".into(),
            tokens: tokenize("fn validate_token token str bool check_jwt_expiry token"),
            token_count: 8,
        });
        index.add_chunk(CodeChunk {
            file_path: "db.rs".into(),
            symbol_name: "connect_database".into(),
            kind: ChunkKind::Function,
            start_line: 1,
            end_line: 5,
            content: "fn connect_database(url: &str) -> Pool { create_pool(url) }".into(),
            tokens: tokenize("fn connect_database url str Pool create_pool url"),
            token_count: 7,
        });
        index.finalize();

        let results = index.search("jwt token validation", 5);
        assert!(!results.is_empty());
        assert_eq!(results[0].symbol_name, "validate_token");
    }
}