1use std::collections::HashMap;
2use std::path::{Path, PathBuf};
3
4use md5::{Digest, Md5};
5use serde::{Deserialize, Serialize};
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8pub struct CodeChunk {
9 pub file_path: String,
10 pub symbol_name: String,
11 pub kind: ChunkKind,
12 pub start_line: usize,
13 pub end_line: usize,
14 pub content: String,
15 pub tokens: Vec<String>,
16 pub token_count: usize,
17}
18
19#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
20pub enum ChunkKind {
21 Function,
22 Struct,
23 Impl,
24 Module,
25 Class,
26 Method,
27 Other,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct BM25Index {
32 pub chunks: Vec<CodeChunk>,
33 pub inverted: HashMap<String, Vec<(usize, f64)>>,
34 pub avg_doc_len: f64,
35 pub doc_count: usize,
36 pub doc_freqs: HashMap<String, usize>,
37}
38
39#[derive(Debug, Clone)]
40pub struct SearchResult {
41 pub chunk_idx: usize,
42 pub score: f64,
43 pub file_path: String,
44 pub symbol_name: String,
45 pub kind: ChunkKind,
46 pub start_line: usize,
47 pub end_line: usize,
48 pub snippet: String,
49}
50
/// BM25 `k1`: controls how quickly repeated occurrences of a term saturate.
const BM25_K1: f64 = 1.2;
/// BM25 `b`: strength of the document-length normalisation (0 = none, 1 = full).
const BM25_B: f64 = 0.75;
53
54impl Default for BM25Index {
55 fn default() -> Self {
56 Self::new()
57 }
58}
59
60impl BM25Index {
61 pub fn new() -> Self {
62 Self {
63 chunks: Vec::new(),
64 inverted: HashMap::new(),
65 avg_doc_len: 0.0,
66 doc_count: 0,
67 doc_freqs: HashMap::new(),
68 }
69 }
70
71 pub fn build_from_directory(root: &Path) -> Self {
72 let mut index = Self::new();
73 let walker = ignore::WalkBuilder::new(root)
74 .hidden(true)
75 .git_ignore(true)
76 .max_depth(Some(10))
77 .build();
78
79 let mut file_count = 0usize;
80 for entry in walker.flatten() {
81 if file_count >= 2000 {
82 break;
83 }
84 let path = entry.path();
85 if !path.is_file() {
86 continue;
87 }
88 if !is_code_file(path) {
89 continue;
90 }
91 if let Ok(content) = std::fs::read_to_string(path) {
92 let rel = path
93 .strip_prefix(root)
94 .unwrap_or(path)
95 .to_string_lossy()
96 .to_string();
97 let chunks = extract_chunks(&rel, &content);
98 for chunk in chunks {
99 index.add_chunk(chunk);
100 }
101 file_count += 1;
102 }
103 }
104
105 index.finalize();
106 index
107 }
108
109 fn add_chunk(&mut self, chunk: CodeChunk) {
110 let idx = self.chunks.len();
111
112 for token in &chunk.tokens {
113 let lower = token.to_lowercase();
114 self.inverted.entry(lower).or_default().push((idx, 1.0));
115 }
116
117 self.chunks.push(chunk);
118 }
119
120 fn finalize(&mut self) {
121 self.doc_count = self.chunks.len();
122 if self.doc_count == 0 {
123 return;
124 }
125
126 let total_len: usize = self.chunks.iter().map(|c| c.token_count).sum();
127 self.avg_doc_len = total_len as f64 / self.doc_count as f64;
128
129 self.doc_freqs.clear();
130 for (term, postings) in &self.inverted {
131 let unique_docs: std::collections::HashSet<usize> =
132 postings.iter().map(|(idx, _)| *idx).collect();
133 self.doc_freqs.insert(term.clone(), unique_docs.len());
134 }
135 }
136
137 pub fn search(&self, query: &str, top_k: usize) -> Vec<SearchResult> {
138 let query_tokens = tokenize(query);
139 if query_tokens.is_empty() || self.doc_count == 0 {
140 return Vec::new();
141 }
142
143 let mut scores: HashMap<usize, f64> = HashMap::new();
144
145 for token in &query_tokens {
146 let lower = token.to_lowercase();
147 let df = *self.doc_freqs.get(&lower).unwrap_or(&0) as f64;
148 if df == 0.0 {
149 continue;
150 }
151
152 let idf = ((self.doc_count as f64 - df + 0.5) / (df + 0.5) + 1.0).ln();
153
154 if let Some(postings) = self.inverted.get(&lower) {
155 let mut doc_tfs: HashMap<usize, f64> = HashMap::new();
156 for (idx, weight) in postings {
157 *doc_tfs.entry(*idx).or_insert(0.0) += weight;
158 }
159
160 for (doc_idx, tf) in &doc_tfs {
161 let doc_len = self.chunks[*doc_idx].token_count as f64;
162 let norm_len = doc_len / self.avg_doc_len.max(1.0);
163 let bm25 = idf * (tf * (BM25_K1 + 1.0))
164 / (tf + BM25_K1 * (1.0 - BM25_B + BM25_B * norm_len));
165
166 *scores.entry(*doc_idx).or_insert(0.0) += bm25;
167 }
168 }
169 }
170
171 let mut results: Vec<SearchResult> = scores
172 .into_iter()
173 .map(|(idx, score)| {
174 let chunk = &self.chunks[idx];
175 let snippet = chunk.content.lines().take(5).collect::<Vec<_>>().join("\n");
176 SearchResult {
177 chunk_idx: idx,
178 score,
179 file_path: chunk.file_path.clone(),
180 symbol_name: chunk.symbol_name.clone(),
181 kind: chunk.kind.clone(),
182 start_line: chunk.start_line,
183 end_line: chunk.end_line,
184 snippet,
185 }
186 })
187 .collect();
188
189 results.sort_by(|a, b| {
190 b.score
191 .partial_cmp(&a.score)
192 .unwrap_or(std::cmp::Ordering::Equal)
193 });
194 results.truncate(top_k);
195 results
196 }
197
198 pub fn save(&self, root: &Path) -> std::io::Result<()> {
199 let dir = index_dir(root);
200 std::fs::create_dir_all(&dir)?;
201 let data = serde_json::to_string(self).map_err(std::io::Error::other)?;
202 std::fs::write(dir.join("bm25_index.json"), data)?;
203 Ok(())
204 }
205
206 pub fn load(root: &Path) -> Option<Self> {
207 let path = index_dir(root).join("bm25_index.json");
208 let data = std::fs::read_to_string(path).ok()?;
209 serde_json::from_str(&data).ok()
210 }
211
212 pub fn load_or_build(root: &Path) -> Self {
213 if let Some(idx) = Self::load(root) {
214 if !vector_index_looks_stale(&idx, root) {
215 return idx;
216 }
217 tracing::warn!(
218 "[vector_index: stale index detected for {}; rebuilding]",
219 root.display()
220 );
221 }
222
223 let built = Self::build_from_directory(root);
224 let _ = built.save(root);
225 built
226 }
227
228 pub fn index_file_path(root: &Path) -> PathBuf {
229 index_dir(root).join("bm25_index.json")
230 }
231}
232
233fn vector_index_looks_stale(index: &BM25Index, root: &Path) -> bool {
234 if index.chunks.is_empty() {
235 return false;
236 }
237
238 let mut seen = std::collections::HashSet::<&str>::new();
239 for chunk in &index.chunks {
240 let rel = chunk.file_path.trim_start_matches(['/', '\\']);
241 if rel.is_empty() {
242 continue;
243 }
244 if !seen.insert(rel) {
245 continue;
246 }
247 if !root.join(rel).exists() {
248 return true;
249 }
250 }
251
252 false
253}
254
255fn index_dir(root: &Path) -> PathBuf {
256 let mut hasher = Md5::new();
257 hasher.update(root.to_string_lossy().as_bytes());
258 let hash = format!("{:x}", hasher.finalize());
259 crate::core::data_dir::lean_ctx_data_dir()
260 .unwrap_or_else(|_| PathBuf::from("."))
261 .join("vectors")
262 .join(hash)
263}
264
/// Returns `true` when `path` has one of the source-code extensions the
/// indexer understands. The comparison is case-sensitive; paths without a
/// UTF-8 extension are rejected.
pub(crate) fn is_code_file(path: &Path) -> bool {
    const CODE_EXTENSIONS: &[&str] = &[
        "rs", "ts", "tsx", "js", "jsx", "py", "go", "java", "c", "cpp", "h", "hpp", "rb", "cs",
        "kt", "swift", "php", "scala", "ex", "exs", "zig", "lua", "dart", "vue", "svelte",
    ];

    path.extension()
        .and_then(|ext| ext.to_str())
        .is_some_and(|ext| CODE_EXTENSIONS.contains(&ext))
}
295
296fn tokenize(text: &str) -> Vec<String> {
297 let mut tokens = Vec::new();
298 let mut current = String::new();
299
300 for ch in text.chars() {
301 if ch.is_alphanumeric() || ch == '_' {
302 current.push(ch);
303 } else {
304 if current.len() >= 2 {
305 tokens.push(current.clone());
306 }
307 current.clear();
308 }
309 }
310 if current.len() >= 2 {
311 tokens.push(current);
312 }
313
314 split_camel_case_tokens(&tokens)
315}
316
317pub(crate) fn tokenize_for_index(text: &str) -> Vec<String> {
318 tokenize(text)
319}
320
/// Expands every token with its camel-case words.
///
/// Each input token is always emitted unchanged first. If it contains word
/// boundaries, the individual words follow in order; words shorter than two
/// bytes are dropped. A boundary lies in front of an uppercase character when
///
/// * the previous character is not uppercase (`calculate|Total`, `get|ID`,
///   `base64|Decode`, `MAX_|SIZE`), or
/// * the previous character is uppercase and the next one is lowercase, i.e.
///   the character starts a word right after an acronym (`HTTP|Server`).
///
/// Runs of capitals are therefore kept together, so `parseHTTP` yields
/// `parse` and `HTTP` rather than being cut inside the acronym.
fn split_camel_case_tokens(tokens: &[String]) -> Vec<String> {
    let mut result = Vec::with_capacity(tokens.len());

    for token in tokens {
        result.push(token.clone());

        // Byte offsets are kept next to the characters so words can be sliced
        // straight out of the token.
        let chars: Vec<(usize, char)> = token.char_indices().collect();
        let mut word_start = 0usize;

        for i in 1..chars.len() {
            let (offset, ch) = chars[i];
            if !ch.is_uppercase() {
                continue;
            }
            let prev_is_upper = chars[i - 1].1.is_uppercase();
            let next_is_lower = chars.get(i + 1).is_some_and(|&(_, c)| c.is_lowercase());

            if !prev_is_upper || next_is_lower {
                let word = &token[word_start..offset];
                if word.len() >= 2 {
                    result.push(word.to_string());
                }
                word_start = offset;
            }
        }

        // Only emit the tail when at least one boundary was found; otherwise
        // it would just duplicate the whole token.
        if word_start > 0 {
            let word = &token[word_start..];
            if word.len() >= 2 {
                result.push(word.to_string());
            }
        }
    }

    result
}
345
346fn extract_chunks(file_path: &str, content: &str) -> Vec<CodeChunk> {
347 #[cfg(feature = "tree-sitter")]
348 {
349 let ext = std::path::Path::new(file_path)
350 .extension()
351 .and_then(|e| e.to_str())
352 .unwrap_or("");
353 if let Some(chunks) = crate::core::chunks_ts::extract_chunks_ts(file_path, content, ext) {
354 return chunks;
355 }
356 }
357
358 let lines: Vec<&str> = content.lines().collect();
359 if lines.is_empty() {
360 return Vec::new();
361 }
362
363 let mut chunks = Vec::new();
364 let mut i = 0;
365
366 while i < lines.len() {
367 let trimmed = lines[i].trim();
368
369 if let Some((name, kind)) = detect_symbol(trimmed) {
370 let start = i;
371 let end = find_block_end(&lines, i);
372 let block: String = lines[start..=end.min(lines.len() - 1)].to_vec().join("\n");
373 let tokens = tokenize(&block);
374 let token_count = tokens.len();
375
376 chunks.push(CodeChunk {
377 file_path: file_path.to_string(),
378 symbol_name: name,
379 kind,
380 start_line: start + 1,
381 end_line: end + 1,
382 content: block,
383 tokens,
384 token_count,
385 });
386
387 i = end + 1;
388 } else {
389 i += 1;
390 }
391 }
392
393 if chunks.is_empty() && !content.is_empty() {
394 let bytes = content.as_bytes();
399 let rk_chunks = crate::core::rabin_karp::chunk(content);
400 if !rk_chunks.is_empty() && rk_chunks.len() <= 200 {
401 for (idx, c) in rk_chunks.into_iter().take(50).enumerate() {
402 let end = (c.offset + c.length).min(bytes.len());
403 let slice = &bytes[c.offset..end];
404 let chunk_text = String::from_utf8_lossy(slice).into_owned();
405 let tokens = tokenize(&chunk_text);
406 let token_count = tokens.len();
407 let start_line = 1 + bytecount::count(&bytes[..c.offset], b'\n');
408 let end_line = start_line + bytecount::count(slice, b'\n');
409 chunks.push(CodeChunk {
410 file_path: file_path.to_string(),
411 symbol_name: format!("{file_path}#chunk-{idx}"),
412 kind: ChunkKind::Module,
413 start_line,
414 end_line: end_line.max(start_line),
415 content: chunk_text,
416 tokens,
417 token_count,
418 });
419 }
420 } else {
421 let tokens = tokenize(content);
422 let token_count = tokens.len();
423 let snippet = lines
424 .iter()
425 .take(50)
426 .copied()
427 .collect::<Vec<_>>()
428 .join("\n");
429 chunks.push(CodeChunk {
430 file_path: file_path.to_string(),
431 symbol_name: file_path.to_string(),
432 kind: ChunkKind::Module,
433 start_line: 1,
434 end_line: lines.len(),
435 content: snippet,
436 tokens,
437 token_count,
438 });
439 }
440 }
441
442 chunks
443}
444
445fn detect_symbol(line: &str) -> Option<(String, ChunkKind)> {
446 let trimmed = line.trim();
447
448 let patterns: &[(&str, ChunkKind)] = &[
449 ("pub async fn ", ChunkKind::Function),
450 ("async fn ", ChunkKind::Function),
451 ("pub fn ", ChunkKind::Function),
452 ("fn ", ChunkKind::Function),
453 ("pub struct ", ChunkKind::Struct),
454 ("struct ", ChunkKind::Struct),
455 ("pub enum ", ChunkKind::Struct),
456 ("enum ", ChunkKind::Struct),
457 ("impl ", ChunkKind::Impl),
458 ("pub trait ", ChunkKind::Struct),
459 ("trait ", ChunkKind::Struct),
460 ("export function ", ChunkKind::Function),
461 ("export async function ", ChunkKind::Function),
462 ("export default function ", ChunkKind::Function),
463 ("function ", ChunkKind::Function),
464 ("async function ", ChunkKind::Function),
465 ("export class ", ChunkKind::Class),
466 ("class ", ChunkKind::Class),
467 ("export interface ", ChunkKind::Struct),
468 ("interface ", ChunkKind::Struct),
469 ("def ", ChunkKind::Function),
470 ("async def ", ChunkKind::Function),
471 ("class ", ChunkKind::Class),
472 ("func ", ChunkKind::Function),
473 ];
474
475 for (prefix, kind) in patterns {
476 if let Some(rest) = trimmed.strip_prefix(prefix) {
477 let name: String = rest
478 .chars()
479 .take_while(|c| c.is_alphanumeric() || *c == '_' || *c == '<')
480 .take_while(|c| *c != '<')
481 .collect();
482 if !name.is_empty() {
483 return Some((name, kind.clone()));
484 }
485 }
486 }
487
488 None
489}
490
/// Estimates the 0-based index of the last line of the block whose
/// declaration starts at `lines[start]`.
///
/// This is a purely textual heuristic (it does not understand strings or
/// comments) and works in three stages:
///
/// 1. **Brace-delimited blocks** – the block ends on the line where the first
///    `{` opened at or after `start` is closed again. Parentheses and square
///    brackets are tracked only as nesting; they never open or close a block,
///    so `fn f(a: u8) {` is not ended by the `)` of its own parameter list.
/// 2. **Body-less declarations** – a `;` seen before any `{`, outside of
///    parentheses/brackets, ends the block on that line (`struct Id(u32);`).
/// 3. **Indentation-delimited blocks** – while no `{` has been seen, and once
///    at least three lines belong to the block, the block ends just before
///    the first line that is blank or starts in column 0.
///
/// If none of these fire, the block is capped at 50 lines past `start` or at
/// the last line of the file, whichever comes first.
fn find_block_end(lines: &[&str], start: usize) -> usize {
    let mut brace_depth = 0usize;
    // Nesting of `(`/`[`; only used to ignore `;` inside signatures and types
    // such as `[u8; 4]`.
    let mut group_depth = 0usize;
    let mut found_open = false;

    for (i, line) in lines.iter().enumerate().skip(start) {
        for ch in line.chars() {
            match ch {
                '(' | '[' => group_depth += 1,
                ')' | ']' => group_depth = group_depth.saturating_sub(1),
                '{' => {
                    brace_depth += 1;
                    found_open = true;
                }
                '}' if brace_depth > 0 => {
                    brace_depth -= 1;
                    if brace_depth == 0 {
                        return i;
                    }
                }
                ';' if !found_open && group_depth == 0 => return i,
                _ => {}
            }
        }

        if !found_open && i > start + 2 {
            // Look at the raw line: indentation is exactly what a trimmed
            // copy would have thrown away.
            let is_blank = line.trim().is_empty();
            let is_indented = line.starts_with(' ') || line.starts_with('\t');
            if is_blank || !is_indented {
                return i - 1;
            }
        }
    }

    (start + 50).min(lines.len().saturating_sub(1))
}
528
529pub fn format_search_results(results: &[SearchResult], compact: bool) -> String {
530 if results.is_empty() {
531 return "No results found.".to_string();
532 }
533
534 let mut out = String::new();
535 for (i, r) in results.iter().enumerate() {
536 if compact {
537 out.push_str(&format!(
538 "{}. {:.2} {}:{}-{} {:?} {}\n",
539 i + 1,
540 r.score,
541 r.file_path,
542 r.start_line,
543 r.end_line,
544 r.kind,
545 r.symbol_name,
546 ));
547 } else {
548 out.push_str(&format!(
549 "\n--- Result {} (score: {:.2}) ---\n{} :: {} [{:?}] (L{}-{})\n{}\n",
550 i + 1,
551 r.score,
552 r.file_path,
553 r.symbol_name,
554 r.kind,
555 r.start_line,
556 r.end_line,
557 r.snippet,
558 ));
559 }
560 }
561 out
562}
563
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Builds a function chunk whose `token_count` is derived from its token
    /// list, so the two can never disagree.
    fn function_chunk(
        file_path: &str,
        symbol_name: &str,
        end_line: usize,
        content: &str,
        token_source: &str,
    ) -> CodeChunk {
        let tokens = tokenize(token_source);
        CodeChunk {
            file_path: file_path.into(),
            symbol_name: symbol_name.into(),
            kind: ChunkKind::Function,
            start_line: 1,
            end_line,
            content: content.into(),
            token_count: tokens.len(),
            tokens,
        }
    }

    #[test]
    fn tokenize_splits_code() {
        let tokens = tokenize("fn calculate_total(items: Vec<Item>) -> f64");
        assert!(tokens.contains(&"calculate_total".to_string()));
        assert!(tokens.contains(&"items".to_string()));
        assert!(tokens.contains(&"Vec".to_string()));
    }

    #[test]
    fn camel_case_splitting() {
        let tokens = split_camel_case_tokens(&["calculateTotal".to_string()]);
        assert!(tokens.contains(&"calculateTotal".to_string()));
        assert!(tokens.contains(&"calculate".to_string()));
        assert!(tokens.contains(&"Total".to_string()));
    }

    #[test]
    fn detect_rust_function() {
        let (name, kind) =
            detect_symbol("pub fn process_request(req: Request) -> Response {").unwrap();
        assert_eq!(name, "process_request");
        assert_eq!(kind, ChunkKind::Function);
    }

    #[test]
    fn is_code_file_checks_the_extension() {
        assert!(is_code_file(Path::new("src/main.rs")));
        assert!(is_code_file(Path::new("app.tsx")));
        assert!(!is_code_file(Path::new("README.md")));
        assert!(!is_code_file(Path::new("Makefile")));
    }

    #[test]
    fn bm25_search_finds_relevant() {
        let mut index = BM25Index::new();
        index.add_chunk(function_chunk(
            "auth.rs",
            "validate_token",
            10,
            "fn validate_token(token: &str) -> bool { check_jwt_expiry(token) }",
            "fn validate_token token str bool check_jwt_expiry token",
        ));
        index.add_chunk(function_chunk(
            "db.rs",
            "connect_database",
            5,
            "fn connect_database(url: &str) -> Pool { create_pool(url) }",
            "fn connect_database url str Pool create_pool url",
        ));
        index.finalize();

        let results = index.search("jwt token validation", 5);
        assert!(!results.is_empty());
        assert_eq!(results[0].symbol_name, "validate_token");
    }

    #[test]
    fn search_on_empty_index_returns_nothing() {
        let index = BM25Index::new();
        assert!(index.search("anything", 5).is_empty());
    }

    #[test]
    fn format_search_results_reports_empty_input() {
        assert_eq!(format_search_results(&[], true), "No results found.");
        assert_eq!(format_search_results(&[], false), "No results found.");
    }

    #[test]
    fn vector_index_is_stale_when_any_indexed_file_is_missing() {
        let td = tempdir().expect("tempdir");
        let root = td.path();
        std::fs::write(root.join("a.rs"), "pub fn a() {}\n").expect("write a.rs");

        let idx = BM25Index::build_from_directory(root);
        assert!(!vector_index_looks_stale(&idx, root));

        std::fs::remove_file(root.join("a.rs")).expect("remove a.rs");
        assert!(vector_index_looks_stale(&idx, root));
    }
}