Skip to main content

cersei_tools/
code_search.rs

//! CodeSearch tool: hybrid BM25 + vector semantic search.
//!
//! Two modes:
//! 1. BM25 only (default): tantivy full-text search. No API calls.
//! 2. BM25 + Vector: BM25 candidates merged with HNSW vector k-NN search
//!    backed by the `cersei-embeddings` crate (Gemini / OpenAI / custom).
7
8use crate::{PermissionLevel, Tool, ToolCategory, ToolContext, ToolResult};
9use async_trait::async_trait;
10use cersei_embeddings::{EmbeddingProvider, Metric, VectorIndex};
11use once_cell::sync::Lazy;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15use std::sync::{Arc, Mutex};
16use tantivy::collector::TopDocs;
17use tantivy::query::QueryParser;
18use tantivy::schema::*;
19use tantivy::{doc, Index, IndexReader, ReloadPolicy, TantivyDocument};
20
21// ─── Config ────────────────────────────────────────────────────────────────
22
/// Maximum number of source lines packed into a single indexed chunk.
const CHUNK_LINES: usize = 50;
/// Number of lines shared between two consecutive chunks of the same file.
const CHUNK_OVERLAP: usize = 10;
/// How many candidates the BM25 search is asked for before merging.
const BM25_CANDIDATES: usize = 20;
/// How many candidates the vector search is asked for before merging.
const VECTOR_CANDIDATES: usize = 20;
/// Number of results returned when the caller does not supply `limit`.
const DEFAULT_RESULTS: usize = 10;
/// Number of leading characters of each chunk handed to the embedding provider.
const CHUNK_EMBED_CHARS: usize = 500;
29
/// Lower-case file extensions (without the dot) whose files are indexed.
const INDEXED_EXTENSIONS: &[&str] = &[
    "bash", "c", "cc", "cpp", "cs", "css", "go", "h", "hh", "hpp", "htm", "html",
    "java", "js", "json", "jsx", "kt", "lua", "md", "mjs", "proto", "py", "rb", "rs",
    "sass", "scss", "sh", "sql", "swift", "toml", "ts", "tsx", "txt", "xml", "yaml", "yml",
    "zsh", "cjs", "graphql", "gql", "jsonc", "ml", "mli", "f90", "f95", "cobol", "cbl", "ocaml",
];
36
37// ─── Chunk metadata ────────────────────────────────────────────────────────
38
/// One contiguous window of lines taken from an indexed file.
#[derive(Debug, Clone)]
struct ChunkMeta {
    /// Display form of the file path the chunk was read from.
    path: String,
    /// 1-based number of the first line in the chunk.
    start_line: usize,
    /// 1-based number of the last line in the chunk (inclusive).
    end_line: usize,
    /// The chunk's lines joined with `\n`.
    content: String,
}
46
47// ─── Cached index ──────────────────────────────────────────────────────────
48
49struct CachedIndex {
50    working_dir: PathBuf,
51    // BM25
52    bm25_index: Index,
53    reader: IndexReader,
54    path_field: Field,
55    content_field: Field,
56    lines_field: Field,
57    // Vector (optional)
58    vector_index: Option<VectorIndex>,
59    chunks: Vec<ChunkMeta>, // chunk_id (vector key) → metadata
60}
61
62static INDEX_CACHE: Lazy<Mutex<Option<CachedIndex>>> = Lazy::new(|| Mutex::new(None));
63
64// ─── Indexing ──────────────────────────────────────────────────────────────
65
66fn should_index(path: &Path) -> bool {
67    path.extension()
68        .and_then(|e| e.to_str())
69        .map(|ext| INDEXED_EXTENSIONS.contains(&ext.to_lowercase().as_str()))
70        .unwrap_or(false)
71}
72
73fn chunk_file(path: &Path, content: &str) -> Vec<ChunkMeta> {
74    let lines: Vec<&str> = content.lines().collect();
75    if lines.is_empty() {
76        return vec![];
77    }
78    let path_str = path.display().to_string();
79    let mut chunks = Vec::new();
80    let mut start = 0;
81    while start < lines.len() {
82        let end = (start + CHUNK_LINES).min(lines.len());
83        let chunk_content = lines[start..end].join("\n");
84        if !chunk_content.trim().is_empty() {
85            chunks.push(ChunkMeta {
86                path: path_str.clone(),
87                content: chunk_content,
88                start_line: start + 1,
89                end_line: end,
90            });
91        }
92        if end >= lines.len() {
93            break;
94        }
95        start += CHUNK_LINES - CHUNK_OVERLAP;
96    }
97    chunks
98}
99
100fn build_bm25_index(
101    chunks: &[ChunkMeta],
102) -> Result<(Index, IndexReader, Field, Field, Field), String> {
103    let mut schema_builder = Schema::builder();
104    let path_field = schema_builder.add_text_field("path", STRING | STORED);
105    let content_field = schema_builder.add_text_field("content", TEXT | STORED);
106    let lines_field = schema_builder.add_text_field("lines", STRING | STORED);
107    let schema = schema_builder.build();
108
109    let index = Index::create_in_ram(schema);
110    let mut writer = index
111        .writer(50_000_000)
112        .map_err(|e| format!("Writer error: {e}"))?;
113
114    for chunk in chunks {
115        writer
116            .add_document(doc!(
117                path_field => chunk.path.clone(),
118                content_field => chunk.content.clone(),
119                lines_field => format!("{}:{}", chunk.start_line, chunk.end_line),
120            ))
121            .map_err(|e| format!("Add doc error: {e}"))?;
122    }
123
124    writer.commit().map_err(|e| format!("Commit error: {e}"))?;
125
126    let reader = index
127        .reader_builder()
128        .reload_policy(ReloadPolicy::Manual)
129        .try_into()
130        .map_err(|e| format!("Reader error: {e}"))?;
131
132    Ok((index, reader, path_field, content_field, lines_field))
133}
134
135fn collect_chunks(working_dir: &Path) -> Vec<ChunkMeta> {
136    let mut all_chunks = Vec::new();
137    for entry in walkdir::WalkDir::new(working_dir)
138        .follow_links(false)
139        .into_iter()
140        .filter_entry(|e| {
141            let name = e.file_name().to_str().unwrap_or("");
142            !name.starts_with('.')
143                && name != "node_modules"
144                && name != "target"
145                && name != "__pycache__"
146                && name != ".venv"
147                && name != "venv"
148        })
149    {
150        let entry = match entry {
151            Ok(e) => e,
152            Err(_) => continue,
153        };
154        if !entry.file_type().is_file() || !should_index(entry.path()) {
155            continue;
156        }
157        if let Ok(meta) = entry.path().metadata() {
158            if meta.len() > 500_000 {
159                continue;
160            }
161        }
162        if let Ok(content) = std::fs::read_to_string(entry.path()) {
163            all_chunks.extend(chunk_file(entry.path(), &content));
164        }
165    }
166    all_chunks
167}
168
169fn build_index(
170    working_dir: &Path,
171    embeddings: Option<Vec<Vec<f32>>>,
172) -> Result<CachedIndex, String> {
173    let chunks = collect_chunks(working_dir);
174    let file_count = chunks
175        .iter()
176        .map(|c| &c.path)
177        .collect::<std::collections::HashSet<_>>()
178        .len();
179    tracing::info!(
180        "CodeSearch: indexed {file_count} files, {} chunks",
181        chunks.len()
182    );
183
184    let (bm25_index, reader, path_field, content_field, lines_field) = build_bm25_index(&chunks)?;
185
186    let vector_index = if let Some(embs) = embeddings {
187        if !embs.is_empty() && !embs[0].is_empty() {
188            match VectorIndex::from_vectors(&embs, Metric::Cosine) {
189                Ok(idx) => Some(idx),
190                Err(e) => {
191                    tracing::warn!("Vector index failed, BM25 only: {e}");
192                    None
193                }
194            }
195        } else {
196            None
197        }
198    } else {
199        None
200    };
201
202    Ok(CachedIndex {
203        working_dir: working_dir.to_path_buf(),
204        bm25_index,
205        reader,
206        path_field,
207        content_field,
208        lines_field,
209        vector_index,
210        chunks,
211    })
212}
213
214// ─── Search ────────────────────────────────────────────────────────────────
215
/// A single search hit together with the scores that ranked it.
#[derive(Debug, Clone)]
struct SearchResult {
    /// Path of the file the hit comes from.
    path: String,
    /// Text of the matching chunk.
    content: String,
    /// 1-based first line of the chunk.
    start_line: usize,
    /// 1-based last line of the chunk (inclusive).
    end_line: usize,
    /// BM25 relevance; `0.0` for hits found only by vector search.
    bm25_score: f32,
    /// Vector similarity; `0.0` for hits found only by BM25.
    vector_score: f32,
    /// Score used for the final ordering and shown to the caller.
    final_score: f32,
}
226
227fn bm25_search(
228    cached: &CachedIndex,
229    query: &str,
230    limit: usize,
231) -> Result<Vec<SearchResult>, String> {
232    let searcher = cached.reader.searcher();
233    let qp = QueryParser::for_index(&cached.bm25_index, vec![cached.content_field]);
234    let parsed = qp
235        .parse_query(query)
236        .map_err(|e| format!("Query parse: {e}"))?;
237    let top = searcher
238        .search(&parsed, &TopDocs::with_limit(limit))
239        .map_err(|e| format!("Search: {e}"))?;
240
241    let mut results = Vec::new();
242    for (score, addr) in top {
243        let doc: TantivyDocument = searcher.doc(addr).map_err(|e| format!("Doc: {e}"))?;
244        let path = doc
245            .get_first(cached.path_field)
246            .and_then(|v| v.as_str())
247            .unwrap_or("")
248            .to_string();
249        let content = doc
250            .get_first(cached.content_field)
251            .and_then(|v| v.as_str())
252            .unwrap_or("")
253            .to_string();
254        let lines = doc
255            .get_first(cached.lines_field)
256            .and_then(|v| v.as_str())
257            .unwrap_or("0:0")
258            .to_string();
259        let (start, end) = lines
260            .split_once(':')
261            .map(|(s, e)| (s.parse().unwrap_or(0), e.parse().unwrap_or(0)))
262            .unwrap_or((0, 0));
263        results.push(SearchResult {
264            path,
265            content,
266            start_line: start,
267            end_line: end,
268            bm25_score: score,
269            vector_score: 0.0,
270            final_score: score,
271        });
272    }
273    Ok(results)
274}
275
276fn vector_search(
277    cached: &CachedIndex,
278    query_embedding: &[f32],
279    limit: usize,
280) -> Result<Vec<SearchResult>, String> {
281    let vi = cached.vector_index.as_ref().ok_or("No vector index")?;
282    let hits = vi
283        .search(query_embedding, limit)
284        .map_err(|e| format!("Vector search: {e}"))?;
285
286    let mut results = Vec::new();
287    for hit in hits {
288        let key = hit.key as usize;
289        if key < cached.chunks.len() {
290            let chunk = &cached.chunks[key];
291            results.push(SearchResult {
292                path: chunk.path.clone(),
293                content: chunk.content.clone(),
294                start_line: chunk.start_line,
295                end_line: chunk.end_line,
296                bm25_score: 0.0,
297                vector_score: hit.similarity,
298                final_score: hit.similarity * 100.0,
299            });
300        }
301    }
302    Ok(results)
303}
304
305fn merge_results(
306    bm25: Vec<SearchResult>,
307    vector: Vec<SearchResult>,
308    limit: usize,
309) -> Vec<SearchResult> {
310    let mut merged: HashMap<String, SearchResult> = HashMap::new();
311
312    // Normalize BM25 scores to 0-1 range
313    let max_bm25 = bm25
314        .iter()
315        .map(|r| r.bm25_score)
316        .fold(0.0f32, f32::max)
317        .max(1.0);
318
319    for mut r in bm25 {
320        let key = format!("{}:{}:{}", r.path, r.start_line, r.end_line);
321        r.bm25_score /= max_bm25; // normalize to 0-1
322        merged.insert(key, r);
323    }
324
325    for r in vector {
326        let key = format!("{}:{}:{}", r.path, r.start_line, r.end_line);
327        if let Some(existing) = merged.get_mut(&key) {
328            existing.vector_score = r.vector_score;
329        } else {
330            merged.insert(key, r);
331        }
332    }
333
334    // Blend: 60% BM25 + 40% vector
335    let mut results: Vec<SearchResult> = merged
336        .into_values()
337        .map(|mut r| {
338            r.final_score = r.bm25_score * 0.6 + r.vector_score * 0.4;
339            r
340        })
341        .collect();
342
343    results.sort_by(|a, b| {
344        b.final_score
345            .partial_cmp(&a.final_score)
346            .unwrap_or(std::cmp::Ordering::Equal)
347    });
348    results.truncate(limit);
349    results
350}
351
352// ─── Tool implementation ───────────────────────────────────────────────────
353
354pub struct CodeSearchTool {
355    embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
356}
357
358impl CodeSearchTool {
359    /// BM25-only search. No network, no API key required.
360    pub fn new() -> Self {
361        Self {
362            embedding_provider: None,
363        }
364    }
365
366    /// Enable hybrid BM25 + vector search using the given embedding provider.
367    ///
368    /// Use [`cersei_embeddings::auto_from_model`] to construct a provider
369    /// from an LLM model string, or build one explicitly with
370    /// [`cersei_embeddings::GeminiEmbeddings`] / [`cersei_embeddings::OpenAiEmbeddings`].
371    pub fn with_embeddings(provider: Arc<dyn EmbeddingProvider>) -> Self {
372        Self {
373            embedding_provider: Some(provider),
374        }
375    }
376}
377
378impl Default for CodeSearchTool {
379    fn default() -> Self {
380        Self::new()
381    }
382}
383
384#[async_trait]
385impl Tool for CodeSearchTool {
386    fn name(&self) -> &str {
387        "CodeSearch"
388    }
389
390    fn description(&self) -> &str {
391        "Semantic code search across the codebase. Use natural language queries about behavior, \
392         patterns, or concepts. Returns relevant code snippets with file paths and line numbers. \
393         This is your DEFAULT tool for discovering code — use it before Grep when you need to \
394         understand how something works rather than find an exact string."
395    }
396
397    fn permission_level(&self) -> PermissionLevel {
398        PermissionLevel::ReadOnly
399    }
400    fn category(&self) -> ToolCategory {
401        ToolCategory::FileSystem
402    }
403
404    fn input_schema(&self) -> serde_json::Value {
405        serde_json::json!({
406            "type": "object",
407            "properties": {
408                "query": {
409                    "type": "string",
410                    "description": "Natural language search query about code behavior, patterns, or concepts."
411                },
412                "path": { "type": "string", "description": "Directory to search in." },
413                "limit": { "type": "integer", "description": "Max results (default: 10)." }
414            },
415            "required": ["query"]
416        })
417    }
418
419    async fn execute(&self, input: serde_json::Value, ctx: &ToolContext) -> ToolResult {
420        #[derive(Deserialize)]
421        struct Input {
422            query: String,
423            path: Option<String>,
424            limit: Option<usize>,
425        }
426
427        let input: Input = match serde_json::from_value(input) {
428            Ok(i) => i,
429            Err(e) => return ToolResult::error(format!("Invalid input: {e}")),
430        };
431
432        let search_dir = input
433            .path
434            .map(PathBuf::from)
435            .unwrap_or_else(|| ctx.working_dir.clone());
436        let limit = input.limit.unwrap_or(DEFAULT_RESULTS);
437
438        // Build or retrieve index
439        let needs_build = {
440            let cache = INDEX_CACHE.lock().unwrap();
441            cache
442                .as_ref()
443                .map(|c| c.working_dir != search_dir)
444                .unwrap_or(true)
445        };
446
447        if needs_build {
448            // Collect chunks first
449            let chunks = collect_chunks(&search_dir);
450            let chunk_texts: Vec<String> = chunks
451                .iter()
452                .map(|c| c.content.chars().take(CHUNK_EMBED_CHARS).collect())
453                .collect();
454
455            // Optionally embed all chunks via the configured provider
456            let embeddings = if let Some(provider) = &self.embedding_provider {
457                if chunk_texts.is_empty() {
458                    None
459                } else {
460                    match provider.embed_batch(&chunk_texts).await {
461                        Ok(embs) => Some(embs),
462                        Err(e) => {
463                            tracing::warn!("Embedding failed, BM25 only: {e}");
464                            None
465                        }
466                    }
467                }
468            } else {
469                None
470            };
471
472            match build_index(&search_dir, embeddings) {
473                Ok(idx) => {
474                    *INDEX_CACHE.lock().unwrap() = Some(idx);
475                }
476                Err(e) => return ToolResult::error(format!("Index error: {e}")),
477            }
478        }
479
480        // Search BM25 and vector (release lock before any await)
481        let (bm25_results, has_vector) = {
482            let cache = INDEX_CACHE.lock().unwrap();
483            let cached = match cache.as_ref() {
484                Some(c) => c,
485                None => return ToolResult::error("No index available"),
486            };
487            let bm25 = match bm25_search(cached, &input.query, BM25_CANDIDATES) {
488                Ok(r) => r,
489                Err(e) => return ToolResult::error(format!("BM25 error: {e}")),
490            };
491            (bm25, cached.vector_index.is_some())
492        }; // lock released here
493
494        // Vector search needs async embedding call, then re-acquires lock briefly
495        let results = if has_vector {
496            if let Some(provider) = &self.embedding_provider {
497                match provider.embed(&input.query).await {
498                    Ok(query_emb) => {
499                        let cache = INDEX_CACHE.lock().unwrap();
500                        let cached = cache.as_ref().unwrap();
501                        let vec_results = vector_search(cached, &query_emb, VECTOR_CANDIDATES)
502                            .unwrap_or_default();
503                        drop(cache);
504                        merge_results(bm25_results, vec_results, limit)
505                    }
506                    Err(e) => {
507                        tracing::warn!("Query embedding failed, BM25 only: {e}");
508                        let mut r = bm25_results;
509                        r.truncate(limit);
510                        r
511                    }
512                }
513            } else {
514                let mut r = bm25_results;
515                r.truncate(limit);
516                r
517            }
518        } else {
519            let mut r = bm25_results;
520            r.truncate(limit);
521            r
522        };
523
524        if results.is_empty() {
525            return ToolResult::success(
526                "No results found. Try different search terms or use Grep for exact patterns.",
527            );
528        }
529
530        let mut output = String::new();
531        for (i, r) in results.iter().enumerate() {
532            output.push_str(&format!(
533                "── Result {} ── {}:{}-{} (score: {:.2})\n{}\n\n",
534                i + 1,
535                r.path,
536                r.start_line,
537                r.end_line,
538                r.final_score,
539                r.content
540            ));
541        }
542        ToolResult::success(output)
543    }
544}