1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use md5::{Digest, Md5};
5use serde::{Deserialize, Serialize};
6
/// A searchable unit of source code: one extracted symbol (function, struct,
/// class, …) or a whole-file fallback chunk, plus its pre-computed tokens.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CodeChunk {
    /// Path of the containing file, relative to the indexed root when possible.
    pub file_path: String,
    /// Name of the extracted symbol (the file path itself for whole-file chunks).
    pub symbol_name: String,
    /// Coarse classification of the symbol.
    pub kind: ChunkKind,
    /// 1-based first line of the chunk in the source file.
    pub start_line: usize,
    /// 1-based last line of the chunk in the source file.
    pub end_line: usize,
    /// Raw chunk text (truncated to the first 50 lines for whole-file chunks).
    pub content: String,
    /// Tokens produced at index time; feed the inverted index.
    pub tokens: Vec<String>,
    /// Cached `tokens.len()` — the BM25 document length.
    pub token_count: usize,
}
18
/// Coarse symbol categories. Several source constructs share a variant
/// (enums, traits and interfaces are all indexed as `Struct`).
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum ChunkKind {
    /// `fn` / `def` / `func` / `function` definitions.
    Function,
    /// `struct`, `enum`, `trait`, `interface` definitions.
    Struct,
    /// Rust `impl` blocks.
    Impl,
    /// Whole-file fallback chunk.
    Module,
    /// `class` definitions.
    Class,
    // NOTE(review): Method/Other are not produced by the line-based chunker
    // in this file — presumably emitted by the tree-sitter chunker; confirm.
    Method,
    Other,
}
29
/// Serializable BM25 inverted index over [`CodeChunk`]s.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Index {
    /// All indexed chunks; posting lists refer to positions in this vector.
    pub chunks: Vec<CodeChunk>,
    /// Lower-cased term -> list of (chunk index, term weight) postings.
    pub inverted: HashMap<String, Vec<(usize, f64)>>,
    /// Average chunk length in tokens; computed by `finalize`.
    pub avg_doc_len: f64,
    /// Number of indexed chunks; computed by `finalize`.
    pub doc_count: usize,
    /// Term -> number of distinct chunks containing it; computed by `finalize`.
    pub doc_freqs: HashMap<String, usize>,
}
38
/// One ranked hit returned by [`BM25Index::search`].
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// Index of the matching chunk in `BM25Index::chunks`.
    pub chunk_idx: usize,
    /// Accumulated BM25 score (higher is better).
    pub score: f64,
    pub file_path: String,
    pub symbol_name: String,
    pub kind: ChunkKind,
    pub start_line: usize,
    pub end_line: usize,
    /// First five lines of the chunk's content, for display.
    pub snippet: String,
}
50
/// BM25 term-frequency saturation parameter (standard default).
const BM25_K1: f64 = 1.2;
/// BM25 document-length normalization strength (standard default).
const BM25_B: f64 = 0.75;
53
impl Default for BM25Index {
    /// An empty index; identical to [`BM25Index::new`].
    fn default() -> Self {
        Self::new()
    }
}
59
60impl BM25Index {
61 pub fn new() -> Self {
62 Self {
63 chunks: Vec::new(),
64 inverted: HashMap::new(),
65 avg_doc_len: 0.0,
66 doc_count: 0,
67 doc_freqs: HashMap::new(),
68 }
69 }
70
71 pub fn build_from_directory(root: &Path) -> Self {
72 let mut index = Self::new();
73 let walker = ignore::WalkBuilder::new(root)
74 .hidden(true)
75 .git_ignore(true)
76 .max_depth(Some(10))
77 .build();
78
79 let mut file_count = 0usize;
80 for entry in walker.flatten() {
81 if file_count >= 2000 {
82 break;
83 }
84 let path = entry.path();
85 if !path.is_file() {
86 continue;
87 }
88 if !is_code_file(path) {
89 continue;
90 }
91 if let Ok(content) = std::fs::read_to_string(path) {
92 let rel = path
93 .strip_prefix(root)
94 .unwrap_or(path)
95 .to_string_lossy()
96 .to_string();
97 let ext = path
98 .extension()
99 .and_then(|e| e.to_str())
100 .unwrap_or("");
101 let chunks = super::chunks_ts::extract_chunks_ts(&rel, &content, ext)
102 .unwrap_or_else(|| extract_chunks(&rel, &content));
103 for chunk in chunks {
104 index.add_chunk(chunk);
105 }
106 file_count += 1;
107 }
108 }
109
110 index.finalize();
111 index
112 }
113
114 fn add_chunk(&mut self, chunk: CodeChunk) {
115 let idx = self.chunks.len();
116
117 for token in &chunk.tokens {
118 let lower = token.to_lowercase();
119 self.inverted.entry(lower).or_default().push((idx, 1.0));
120 }
121
122 self.chunks.push(chunk);
123 }
124
125 fn finalize(&mut self) {
126 self.doc_count = self.chunks.len();
127 if self.doc_count == 0 {
128 return;
129 }
130
131 let total_len: usize = self.chunks.iter().map(|c| c.token_count).sum();
132 self.avg_doc_len = total_len as f64 / self.doc_count as f64;
133
134 self.doc_freqs.clear();
135 for (term, postings) in &self.inverted {
136 let unique_docs: std::collections::HashSet<usize> =
137 postings.iter().map(|(idx, _)| *idx).collect();
138 self.doc_freqs.insert(term.clone(), unique_docs.len());
139 }
140 }
141
142 pub fn search(&self, query: &str, top_k: usize) -> Vec<SearchResult> {
143 let query_tokens = tokenize(query);
144 if query_tokens.is_empty() || self.doc_count == 0 {
145 return Vec::new();
146 }
147
148 let mut scores: HashMap<usize, f64> = HashMap::new();
149
150 for token in &query_tokens {
151 let lower = token.to_lowercase();
152 let df = *self.doc_freqs.get(&lower).unwrap_or(&0) as f64;
153 if df == 0.0 {
154 continue;
155 }
156
157 let idf = ((self.doc_count as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
158
159 if let Some(postings) = self.inverted.get(&lower) {
160 let mut doc_tfs: HashMap<usize, f64> = HashMap::new();
161 for (idx, weight) in postings {
162 *doc_tfs.entry(*idx).or_insert(0.0) += weight;
163 }
164
165 for (doc_idx, tf) in &doc_tfs {
166 let doc_len = self.chunks[*doc_idx].token_count as f64;
167 let norm_len = doc_len / self.avg_doc_len.max(1.0);
168 let bm25 = idf * (tf * (BM25_K1 + 1.0))
169 / (tf + BM25_K1 * (1.0 - BM25_B + BM25_B * norm_len));
170
171 *scores.entry(*doc_idx).or_insert(0.0) += bm25;
172 }
173 }
174 }
175
176 let mut results: Vec<SearchResult> = scores
177 .into_iter()
178 .map(|(idx, score)| {
179 let chunk = &self.chunks[idx];
180 let snippet = chunk.content.lines().take(5).collect::<Vec<_>>().join("\n");
181 SearchResult {
182 chunk_idx: idx,
183 score,
184 file_path: chunk.file_path.clone(),
185 symbol_name: chunk.symbol_name.clone(),
186 kind: chunk.kind.clone(),
187 start_line: chunk.start_line,
188 end_line: chunk.end_line,
189 snippet,
190 }
191 })
192 .collect();
193
194 results.sort_by(|a, b| {
195 b.score
196 .partial_cmp(&a.score)
197 .unwrap_or(std::cmp::Ordering::Equal)
198 });
199 results.truncate(top_k);
200 results
201 }
202
203 pub fn save(&self, root: &Path) -> std::io::Result<()> {
204 let dir = index_dir(root);
205 std::fs::create_dir_all(&dir)?;
206 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
207 std::fs::write(dir.join("bm25_index.json"), data)?;
208 Ok(())
209 }
210
211 pub fn load(root: &Path) -> Option<Self> {
212 let path = index_dir(root).join("bm25_index.json");
213 let data = std::fs::read_to_string(path).ok()?;
214 serde_json::from_str(&data).ok()
215 }
216}
217
218fn index_dir(root: &Path) -> PathBuf {
219 let mut hasher = Md5::new();
220 hasher.update(root.to_string_lossy().as_bytes());
221 let hash = format!("{:x}", hasher.finalize());
222 dirs::home_dir()
223 .unwrap_or_else(|| PathBuf::from("."))
224 .join(".lean-ctx")
225 .join("vectors")
226 .join(hash)
227}
228
/// Returns `true` when `path` has a file extension we index as source code.
/// Files without an extension (or with a non-UTF-8 one) are rejected.
fn is_code_file(path: &Path) -> bool {
    const CODE_EXTENSIONS: &[&str] = &[
        "rs", "ts", "tsx", "js", "jsx", "py", "go", "java", "c", "cpp", "h",
        "hpp", "rb", "cs", "kt", "swift", "php", "scala", "ex", "exs", "zig",
        "lua", "dart", "vue", "svelte",
    ];

    match path.extension().and_then(|e| e.to_str()) {
        Some(ext) => CODE_EXTENSIONS.contains(&ext),
        None => false,
    }
}
259
/// Public wrapper around the internal tokenizer so sibling modules can
/// produce tokens with rules identical to the ones used at search time.
pub fn tokenize_for_index(text: &str) -> Vec<String> {
    tokenize(text)
}
263
/// Splits `text` into search tokens.
///
/// Words are maximal runs of alphanumerics/underscores; runs shorter than
/// two bytes are dropped as noise. Each snake_case word is additionally
/// indexed as its underscore-separated parts, and finally every token is run
/// through camelCase splitting.
fn tokenize(text: &str) -> Vec<String> {
    let words: Vec<String> = text
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|w| w.len() >= 2)
        .map(ToString::to_string)
        .collect();

    // Keep each word and append its snake_case components right after it.
    let mut expanded = Vec::new();
    for word in &words {
        expanded.push(word.clone());
        if word.contains('_') {
            expanded.extend(
                word.split('_')
                    .filter(|part| part.len() >= 2)
                    .map(ToString::to_string),
            );
        }
    }

    split_camel_case_tokens(&expanded)
}

/// For every token, keeps the original and appends its camelCase segments
/// (each at least two bytes long). Acronym runs stay together: a segment
/// starts at an uppercase letter that is *not* followed by another uppercase
/// letter, so "parseHTTPResponse" yields "parseHTTP" and "Response".
fn split_camel_case_tokens(tokens: &[String]) -> Vec<String> {
    let mut result = Vec::new();
    for token in tokens {
        result.push(token.clone());
        let chars: Vec<char> = token.chars().collect();
        let mut seg_start = 0;
        for i in 1..chars.len() {
            let boundary = chars[i].is_uppercase()
                && chars.get(i + 1).map_or(true, |next| !next.is_uppercase());
            if boundary {
                let piece: String = chars[seg_start..i].iter().collect();
                if piece.len() >= 2 {
                    result.push(piece);
                }
                seg_start = i;
            }
        }
        // Emit the trailing segment only when at least one split happened.
        if seg_start > 0 {
            let piece: String = chars[seg_start..].iter().collect();
            if piece.len() >= 2 {
                result.push(piece);
            }
        }
    }
    result
}
321
322fn extract_chunks(file_path: &str, content: &str) -> Vec<CodeChunk> {
323 let lines: Vec<&str> = content.lines().collect();
324 if lines.is_empty() {
325 return Vec::new();
326 }
327
328 let mut chunks = Vec::new();
329 let mut i = 0;
330
331 while i < lines.len() {
332 let trimmed = lines[i].trim();
333
334 if let Some((name, kind)) = detect_symbol(trimmed) {
335 let start = i;
336 let end = find_block_end(&lines, i);
337 let block: String = lines[start..=end.min(lines.len() - 1)].to_vec().join("\n");
338 let tokens = tokenize(&block);
339 let token_count = tokens.len();
340
341 chunks.push(CodeChunk {
342 file_path: file_path.to_string(),
343 symbol_name: name,
344 kind,
345 start_line: start + 1,
346 end_line: end + 1,
347 content: block,
348 tokens,
349 token_count,
350 });
351
352 i = end + 1;
353 } else {
354 i += 1;
355 }
356 }
357
358 if chunks.is_empty() && !content.is_empty() {
359 let tokens = tokenize(content);
360 let token_count = tokens.len();
361 let snippet = lines
362 .iter()
363 .take(50)
364 .copied()
365 .collect::<Vec<_>>()
366 .join("\n");
367 chunks.push(CodeChunk {
368 file_path: file_path.to_string(),
369 symbol_name: file_path.to_string(),
370 kind: ChunkKind::Module,
371 start_line: 1,
372 end_line: lines.len(),
373 content: snippet,
374 tokens,
375 token_count,
376 });
377 }
378
379 chunks
380}
381
382fn detect_symbol(line: &str) -> Option<(String, ChunkKind)> {
383 let trimmed = line.trim();
384
385 let patterns: &[(&str, ChunkKind)] = &[
386 ("pub async fn ", ChunkKind::Function),
387 ("async fn ", ChunkKind::Function),
388 ("pub fn ", ChunkKind::Function),
389 ("fn ", ChunkKind::Function),
390 ("pub struct ", ChunkKind::Struct),
391 ("struct ", ChunkKind::Struct),
392 ("pub enum ", ChunkKind::Struct),
393 ("enum ", ChunkKind::Struct),
394 ("impl ", ChunkKind::Impl),
395 ("pub trait ", ChunkKind::Struct),
396 ("trait ", ChunkKind::Struct),
397 ("export function ", ChunkKind::Function),
398 ("export async function ", ChunkKind::Function),
399 ("export default function ", ChunkKind::Function),
400 ("function ", ChunkKind::Function),
401 ("async function ", ChunkKind::Function),
402 ("export class ", ChunkKind::Class),
403 ("class ", ChunkKind::Class),
404 ("export interface ", ChunkKind::Struct),
405 ("interface ", ChunkKind::Struct),
406 ("def ", ChunkKind::Function),
407 ("async def ", ChunkKind::Function),
408 ("class ", ChunkKind::Class),
409 ("func ", ChunkKind::Function),
410 ];
411
412 for (prefix, kind) in patterns {
413 if let Some(rest) = trimmed.strip_prefix(prefix) {
414 let name: String = rest
415 .chars()
416 .take_while(|c| c.is_alphanumeric() || *c == '_' || *c == '<')
417 .take_while(|c| *c != '<')
418 .collect();
419 if !name.is_empty() {
420 return Some((name, kind.clone()));
421 }
422 }
423 }
424
425 None
426}
427
/// Returns the index (into `lines`) of the last line of the block whose
/// definition starts at `start`.
///
/// Fixes over the previous version:
/// * Parentheses no longer participate in brace matching. They shared the
///   depth counter with `{`/`}`, so any signature like `fn f(a: u8) {`
///   closed depth back to zero at `)` and the function ended on its own
///   signature line.
/// * The indentation fallback for brace-less languages now inspects the raw
///   line. It previously tested `starts_with(' ')` on an already-trimmed
///   string (always false when non-empty), so it fired on the first line
///   after `start + 2` regardless of indentation.
fn find_block_end(lines: &[&str], start: usize) -> usize {
    let mut depth = 0i32;
    let mut found_open = false;

    for (i, line) in lines.iter().enumerate().skip(start) {
        for ch in line.chars() {
            match ch {
                '{' => {
                    depth += 1;
                    found_open = true;
                }
                // Guard `depth > 0` ignores stray closers before any opener.
                '}' if depth > 0 => {
                    depth -= 1;
                    if depth == 0 {
                        return i;
                    }
                }
                _ => {}
            }
        }

        // Indentation fallback (e.g. Python): once past a minimal block of
        // three lines, a blank line or a column-zero line ends the block.
        if !found_open && i > start + 2 {
            let raw = lines[i];
            if raw.trim().is_empty() || !(raw.starts_with(' ') || raw.starts_with('\t')) {
                return i.saturating_sub(1);
            }
        }
    }

    // Unterminated block: cap at 50 lines (or end of file).
    (start + 50).min(lines.len().saturating_sub(1))
}
465
466pub fn format_search_results(results: &[SearchResult], compact: bool) -> String {
467 if results.is_empty() {
468 return "No results found.".to_string();
469 }
470
471 let mut out = String::new();
472 for (i, r) in results.iter().enumerate() {
473 if compact {
474 out.push_str(&format!(
475 "{}. {:.2} {}:{}-{} {:?} {}\n",
476 i + 1,
477 r.score,
478 r.file_path,
479 r.start_line,
480 r.end_line,
481 r.kind,
482 r.symbol_name,
483 ));
484 } else {
485 out.push_str(&format!(
486 "\n--- Result {} (score: {:.2}) ---\n{} :: {} [{:?}] (L{}-{})\n{}\n",
487 i + 1,
488 r.score,
489 r.file_path,
490 r.symbol_name,
491 r.kind,
492 r.start_line,
493 r.end_line,
494 r.snippet,
495 ));
496 }
497 }
498 out
499}
500
#[cfg(test)]
mod tests {
    use super::*;

    // Tokenizer keeps identifiers (including snake_case parts) and drops
    // punctuation and operators.
    #[test]
    fn tokenize_splits_code() {
        let tokens = tokenize("fn calculate_total(items: Vec<Item>) -> f64");
        assert!(tokens.contains(&"calculate_total".to_string()));
        assert!(tokens.contains(&"items".to_string()));
        assert!(tokens.contains(&"Vec".to_string()));
    }

    // camelCase identifiers are indexed whole and as their segments.
    #[test]
    fn camel_case_splitting() {
        let tokens = split_camel_case_tokens(&["calculateTotal".to_string()]);
        assert!(tokens.contains(&"calculateTotal".to_string()));
        assert!(tokens.contains(&"calculate".to_string()));
        assert!(tokens.contains(&"Total".to_string()));
    }

    // The prefix-based detector recognizes a Rust fn signature and strips
    // the parameter list.
    #[test]
    fn detect_rust_function() {
        let (name, kind) =
            detect_symbol("pub fn process_request(req: Request) -> Response {").unwrap();
        assert_eq!(name, "process_request");
        assert_eq!(kind, ChunkKind::Function);
    }

    // End to end: index two chunks, search, and expect the jwt-related
    // chunk to rank first for a jwt query.
    #[test]
    fn bm25_search_finds_relevant() {
        let mut index = BM25Index::new();
        index.add_chunk(CodeChunk {
            file_path: "auth.rs".into(),
            symbol_name: "validate_token".into(),
            kind: ChunkKind::Function,
            start_line: 1,
            end_line: 10,
            content: "fn validate_token(token: &str) -> bool { check_jwt_expiry(token) }".into(),
            tokens: tokenize("fn validate_token token str bool check_jwt_expiry token"),
            token_count: 8,
        });
        index.add_chunk(CodeChunk {
            file_path: "db.rs".into(),
            symbol_name: "connect_database".into(),
            kind: ChunkKind::Function,
            start_line: 1,
            end_line: 5,
            content: "fn connect_database(url: &str) -> Pool { create_pool(url) }".into(),
            tokens: tokenize("fn connect_database url str Pool create_pool url"),
            token_count: 7,
        });
        index.finalize();

        let results = index.search("jwt token validation", 5);
        assert!(!results.is_empty());
        assert_eq!(results[0].symbol_name, "validate_token");
    }
}