1use crate::{PermissionLevel, Tool, ToolCategory, ToolContext, ToolResult};
9use async_trait::async_trait;
10use cersei_embeddings::{EmbeddingProvider, Metric, VectorIndex};
11use once_cell::sync::Lazy;
12use serde::Deserialize;
13use std::collections::HashMap;
14use std::path::{Path, PathBuf};
15use std::sync::{Arc, Mutex};
16use tantivy::collector::TopDocs;
17use tantivy::query::QueryParser;
18use tantivy::schema::*;
19use tantivy::{doc, Index, IndexReader, ReloadPolicy, TantivyDocument};
20
/// Maximum number of lines in one indexed chunk.
const CHUNK_LINES: usize = 50;
/// Lines shared by consecutive chunks; chunk starts advance by `CHUNK_LINES - CHUNK_OVERLAP`.
const CHUNK_OVERLAP: usize = 10;
/// Number of BM25 hits fetched per query before fusion with vector hits.
const BM25_CANDIDATES: usize = 20;
/// Number of nearest-neighbour hits fetched per query before fusion with BM25 hits.
const VECTOR_CANDIDATES: usize = 20;
/// Number of results returned when the caller does not pass `limit`.
const DEFAULT_RESULTS: usize = 10;
/// Only this many leading characters of each chunk are sent to the embedding provider.
const CHUNK_EMBED_CHARS: usize = 500;

// The chunker advances by `CHUNK_LINES - CHUNK_OVERLAP` lines per window. A zero stride would
// loop forever, and a negative one would underflow, so reject such settings at compile time.
const _: () = assert!(
    CHUNK_OVERLAP < CHUNK_LINES,
    "CHUNK_OVERLAP must be smaller than CHUNK_LINES"
);
29
/// File extensions whose files are read and chunked for search.
///
/// Entries are lowercase; a file's extension is lowercased before it is looked up here.
/// Files with any other extension, or with none, are ignored.
const INDEXED_EXTENSIONS: &[&str] = &[
    // C, C++, C#
    "c", "h", "cc", "cpp", "hh", "hpp", "cs",
    // JVM and Apple
    "java", "kt", "swift",
    // Go, Rust, Python, Ruby, Lua
    "go", "rs", "py", "rb", "lua",
    // Shell
    "sh", "bash", "zsh",
    // JavaScript / TypeScript
    "js", "jsx", "mjs", "cjs", "ts", "tsx",
    // Markup and styles
    "htm", "html", "xml", "css", "scss", "sass",
    // Data, schemas, queries and config
    "json", "jsonc", "toml", "yaml", "yml", "proto", "graphql", "gql", "sql",
    // Documentation and plain text
    "md", "txt",
    // OCaml (NOTE(review): "ocaml" is not a conventional extension; .ml/.mli are the usual ones)
    "ml", "mli", "ocaml",
    // Fortran and COBOL (NOTE(review): "cobol" is unusual; common .cob/.cpy are not listed)
    "f90", "f95", "cobol", "cbl",
];
36
/// A contiguous window of lines taken from one source file.
#[derive(Debug, Clone)]
struct ChunkMeta {
    /// Path of the file the window came from, in displayed (lossy) form.
    path: String,
    /// First line of the window, 1-based.
    start_line: usize,
    /// Last line of the window, 1-based and inclusive.
    end_line: usize,
    /// The window's lines joined with `\n`.
    content: String,
}
46
47struct CachedIndex {
50 working_dir: PathBuf,
51 bm25_index: Index,
53 reader: IndexReader,
54 path_field: Field,
55 content_field: Field,
56 lines_field: Field,
57 vector_index: Option<VectorIndex>,
59 chunks: Vec<ChunkMeta>, }
61
62static INDEX_CACHE: Lazy<Mutex<Option<CachedIndex>>> = Lazy::new(|| Mutex::new(None));
63
64fn should_index(path: &Path) -> bool {
67 path.extension()
68 .and_then(|e| e.to_str())
69 .map(|ext| INDEXED_EXTENSIONS.contains(&ext.to_lowercase().as_str()))
70 .unwrap_or(false)
71}
72
73fn chunk_file(path: &Path, content: &str) -> Vec<ChunkMeta> {
74 let lines: Vec<&str> = content.lines().collect();
75 if lines.is_empty() {
76 return vec![];
77 }
78 let path_str = path.display().to_string();
79 let mut chunks = Vec::new();
80 let mut start = 0;
81 while start < lines.len() {
82 let end = (start + CHUNK_LINES).min(lines.len());
83 let chunk_content = lines[start..end].join("\n");
84 if !chunk_content.trim().is_empty() {
85 chunks.push(ChunkMeta {
86 path: path_str.clone(),
87 content: chunk_content,
88 start_line: start + 1,
89 end_line: end,
90 });
91 }
92 if end >= lines.len() {
93 break;
94 }
95 start += CHUNK_LINES - CHUNK_OVERLAP;
96 }
97 chunks
98}
99
100fn build_bm25_index(
101 chunks: &[ChunkMeta],
102) -> Result<(Index, IndexReader, Field, Field, Field), String> {
103 let mut schema_builder = Schema::builder();
104 let path_field = schema_builder.add_text_field("path", STRING | STORED);
105 let content_field = schema_builder.add_text_field("content", TEXT | STORED);
106 let lines_field = schema_builder.add_text_field("lines", STRING | STORED);
107 let schema = schema_builder.build();
108
109 let index = Index::create_in_ram(schema);
110 let mut writer = index
111 .writer(50_000_000)
112 .map_err(|e| format!("Writer error: {e}"))?;
113
114 for chunk in chunks {
115 writer
116 .add_document(doc!(
117 path_field => chunk.path.clone(),
118 content_field => chunk.content.clone(),
119 lines_field => format!("{}:{}", chunk.start_line, chunk.end_line),
120 ))
121 .map_err(|e| format!("Add doc error: {e}"))?;
122 }
123
124 writer.commit().map_err(|e| format!("Commit error: {e}"))?;
125
126 let reader = index
127 .reader_builder()
128 .reload_policy(ReloadPolicy::Manual)
129 .try_into()
130 .map_err(|e| format!("Reader error: {e}"))?;
131
132 Ok((index, reader, path_field, content_field, lines_field))
133}
134
135fn collect_chunks(working_dir: &Path) -> Vec<ChunkMeta> {
136 let mut all_chunks = Vec::new();
137 for entry in walkdir::WalkDir::new(working_dir)
138 .follow_links(false)
139 .into_iter()
140 .filter_entry(|e| {
141 let name = e.file_name().to_str().unwrap_or("");
142 !name.starts_with('.')
143 && name != "node_modules"
144 && name != "target"
145 && name != "__pycache__"
146 && name != ".venv"
147 && name != "venv"
148 })
149 {
150 let entry = match entry {
151 Ok(e) => e,
152 Err(_) => continue,
153 };
154 if !entry.file_type().is_file() || !should_index(entry.path()) {
155 continue;
156 }
157 if let Ok(meta) = entry.path().metadata() {
158 if meta.len() > 500_000 {
159 continue;
160 }
161 }
162 if let Ok(content) = std::fs::read_to_string(entry.path()) {
163 all_chunks.extend(chunk_file(entry.path(), &content));
164 }
165 }
166 all_chunks
167}
168
169fn build_index(
170 working_dir: &Path,
171 embeddings: Option<Vec<Vec<f32>>>,
172) -> Result<CachedIndex, String> {
173 let chunks = collect_chunks(working_dir);
174 let file_count = chunks
175 .iter()
176 .map(|c| &c.path)
177 .collect::<std::collections::HashSet<_>>()
178 .len();
179 tracing::info!(
180 "CodeSearch: indexed {file_count} files, {} chunks",
181 chunks.len()
182 );
183
184 let (bm25_index, reader, path_field, content_field, lines_field) = build_bm25_index(&chunks)?;
185
186 let vector_index = if let Some(embs) = embeddings {
187 if !embs.is_empty() && !embs[0].is_empty() {
188 match VectorIndex::from_vectors(&embs, Metric::Cosine) {
189 Ok(idx) => Some(idx),
190 Err(e) => {
191 tracing::warn!("Vector index failed, BM25 only: {e}");
192 None
193 }
194 }
195 } else {
196 None
197 }
198 } else {
199 None
200 };
201
202 Ok(CachedIndex {
203 working_dir: working_dir.to_path_buf(),
204 bm25_index,
205 reader,
206 path_field,
207 content_field,
208 lines_field,
209 vector_index,
210 chunks,
211 })
212}
213
/// One ranked chunk produced by a search.
#[derive(Debug, Clone)]
struct SearchResult {
    /// Path of the file containing the chunk.
    path: String,
    /// Full text of the chunk.
    content: String,
    /// First line of the chunk, 1-based.
    start_line: usize,
    /// Last line of the chunk, 1-based and inclusive.
    end_line: usize,
    /// BM25 score; 0 for hits found only by vector search.
    bm25_score: f32,
    /// Embedding similarity; 0 for hits found only by BM25.
    vector_score: f32,
    /// Score used for ranking and shown to the user.
    final_score: f32,
}
226
227fn bm25_search(
228 cached: &CachedIndex,
229 query: &str,
230 limit: usize,
231) -> Result<Vec<SearchResult>, String> {
232 let searcher = cached.reader.searcher();
233 let qp = QueryParser::for_index(&cached.bm25_index, vec![cached.content_field]);
234 let parsed = qp
235 .parse_query(query)
236 .map_err(|e| format!("Query parse: {e}"))?;
237 let top = searcher
238 .search(&parsed, &TopDocs::with_limit(limit))
239 .map_err(|e| format!("Search: {e}"))?;
240
241 let mut results = Vec::new();
242 for (score, addr) in top {
243 let doc: TantivyDocument = searcher.doc(addr).map_err(|e| format!("Doc: {e}"))?;
244 let path = doc
245 .get_first(cached.path_field)
246 .and_then(|v| v.as_str())
247 .unwrap_or("")
248 .to_string();
249 let content = doc
250 .get_first(cached.content_field)
251 .and_then(|v| v.as_str())
252 .unwrap_or("")
253 .to_string();
254 let lines = doc
255 .get_first(cached.lines_field)
256 .and_then(|v| v.as_str())
257 .unwrap_or("0:0")
258 .to_string();
259 let (start, end) = lines
260 .split_once(':')
261 .map(|(s, e)| (s.parse().unwrap_or(0), e.parse().unwrap_or(0)))
262 .unwrap_or((0, 0));
263 results.push(SearchResult {
264 path,
265 content,
266 start_line: start,
267 end_line: end,
268 bm25_score: score,
269 vector_score: 0.0,
270 final_score: score,
271 });
272 }
273 Ok(results)
274}
275
276fn vector_search(
277 cached: &CachedIndex,
278 query_embedding: &[f32],
279 limit: usize,
280) -> Result<Vec<SearchResult>, String> {
281 let vi = cached.vector_index.as_ref().ok_or("No vector index")?;
282 let hits = vi
283 .search(query_embedding, limit)
284 .map_err(|e| format!("Vector search: {e}"))?;
285
286 let mut results = Vec::new();
287 for hit in hits {
288 let key = hit.key as usize;
289 if key < cached.chunks.len() {
290 let chunk = &cached.chunks[key];
291 results.push(SearchResult {
292 path: chunk.path.clone(),
293 content: chunk.content.clone(),
294 start_line: chunk.start_line,
295 end_line: chunk.end_line,
296 bm25_score: 0.0,
297 vector_score: hit.similarity,
298 final_score: hit.similarity * 100.0,
299 });
300 }
301 }
302 Ok(results)
303}
304
305fn merge_results(
306 bm25: Vec<SearchResult>,
307 vector: Vec<SearchResult>,
308 limit: usize,
309) -> Vec<SearchResult> {
310 let mut merged: HashMap<String, SearchResult> = HashMap::new();
311
312 let max_bm25 = bm25
314 .iter()
315 .map(|r| r.bm25_score)
316 .fold(0.0f32, f32::max)
317 .max(1.0);
318
319 for mut r in bm25 {
320 let key = format!("{}:{}:{}", r.path, r.start_line, r.end_line);
321 r.bm25_score /= max_bm25; merged.insert(key, r);
323 }
324
325 for r in vector {
326 let key = format!("{}:{}:{}", r.path, r.start_line, r.end_line);
327 if let Some(existing) = merged.get_mut(&key) {
328 existing.vector_score = r.vector_score;
329 } else {
330 merged.insert(key, r);
331 }
332 }
333
334 let mut results: Vec<SearchResult> = merged
336 .into_values()
337 .map(|mut r| {
338 r.final_score = r.bm25_score * 0.6 + r.vector_score * 0.4;
339 r
340 })
341 .collect();
342
343 results.sort_by(|a, b| {
344 b.final_score
345 .partial_cmp(&a.final_score)
346 .unwrap_or(std::cmp::Ordering::Equal)
347 });
348 results.truncate(limit);
349 results
350}
351
352pub struct CodeSearchTool {
355 embedding_provider: Option<Arc<dyn EmbeddingProvider>>,
356}
357
358impl CodeSearchTool {
359 pub fn new() -> Self {
361 Self {
362 embedding_provider: None,
363 }
364 }
365
366 pub fn with_embeddings(provider: Arc<dyn EmbeddingProvider>) -> Self {
372 Self {
373 embedding_provider: Some(provider),
374 }
375 }
376}
377
378impl Default for CodeSearchTool {
379 fn default() -> Self {
380 Self::new()
381 }
382}
383
384#[async_trait]
385impl Tool for CodeSearchTool {
386 fn name(&self) -> &str {
387 "CodeSearch"
388 }
389
390 fn description(&self) -> &str {
391 "Semantic code search across the codebase. Use natural language queries about behavior, \
392 patterns, or concepts. Returns relevant code snippets with file paths and line numbers. \
393 This is your DEFAULT tool for discovering code — use it before Grep when you need to \
394 understand how something works rather than find an exact string."
395 }
396
397 fn permission_level(&self) -> PermissionLevel {
398 PermissionLevel::ReadOnly
399 }
400 fn category(&self) -> ToolCategory {
401 ToolCategory::FileSystem
402 }
403
404 fn input_schema(&self) -> serde_json::Value {
405 serde_json::json!({
406 "type": "object",
407 "properties": {
408 "query": {
409 "type": "string",
410 "description": "Natural language search query about code behavior, patterns, or concepts."
411 },
412 "path": { "type": "string", "description": "Directory to search in." },
413 "limit": { "type": "integer", "description": "Max results (default: 10)." }
414 },
415 "required": ["query"]
416 })
417 }
418
419 async fn execute(&self, input: serde_json::Value, ctx: &ToolContext) -> ToolResult {
420 #[derive(Deserialize)]
421 struct Input {
422 query: String,
423 path: Option<String>,
424 limit: Option<usize>,
425 }
426
427 let input: Input = match serde_json::from_value(input) {
428 Ok(i) => i,
429 Err(e) => return ToolResult::error(format!("Invalid input: {e}")),
430 };
431
432 let search_dir = input
433 .path
434 .map(PathBuf::from)
435 .unwrap_or_else(|| ctx.working_dir.clone());
436 let limit = input.limit.unwrap_or(DEFAULT_RESULTS);
437
438 let needs_build = {
440 let cache = INDEX_CACHE.lock().unwrap();
441 cache
442 .as_ref()
443 .map(|c| c.working_dir != search_dir)
444 .unwrap_or(true)
445 };
446
447 if needs_build {
448 let chunks = collect_chunks(&search_dir);
450 let chunk_texts: Vec<String> = chunks
451 .iter()
452 .map(|c| c.content.chars().take(CHUNK_EMBED_CHARS).collect())
453 .collect();
454
455 let embeddings = if let Some(provider) = &self.embedding_provider {
457 if chunk_texts.is_empty() {
458 None
459 } else {
460 match provider.embed_batch(&chunk_texts).await {
461 Ok(embs) => Some(embs),
462 Err(e) => {
463 tracing::warn!("Embedding failed, BM25 only: {e}");
464 None
465 }
466 }
467 }
468 } else {
469 None
470 };
471
472 match build_index(&search_dir, embeddings) {
473 Ok(idx) => {
474 *INDEX_CACHE.lock().unwrap() = Some(idx);
475 }
476 Err(e) => return ToolResult::error(format!("Index error: {e}")),
477 }
478 }
479
480 let (bm25_results, has_vector) = {
482 let cache = INDEX_CACHE.lock().unwrap();
483 let cached = match cache.as_ref() {
484 Some(c) => c,
485 None => return ToolResult::error("No index available"),
486 };
487 let bm25 = match bm25_search(cached, &input.query, BM25_CANDIDATES) {
488 Ok(r) => r,
489 Err(e) => return ToolResult::error(format!("BM25 error: {e}")),
490 };
491 (bm25, cached.vector_index.is_some())
492 }; let results = if has_vector {
496 if let Some(provider) = &self.embedding_provider {
497 match provider.embed(&input.query).await {
498 Ok(query_emb) => {
499 let cache = INDEX_CACHE.lock().unwrap();
500 let cached = cache.as_ref().unwrap();
501 let vec_results = vector_search(cached, &query_emb, VECTOR_CANDIDATES)
502 .unwrap_or_default();
503 drop(cache);
504 merge_results(bm25_results, vec_results, limit)
505 }
506 Err(e) => {
507 tracing::warn!("Query embedding failed, BM25 only: {e}");
508 let mut r = bm25_results;
509 r.truncate(limit);
510 r
511 }
512 }
513 } else {
514 let mut r = bm25_results;
515 r.truncate(limit);
516 r
517 }
518 } else {
519 let mut r = bm25_results;
520 r.truncate(limit);
521 r
522 };
523
524 if results.is_empty() {
525 return ToolResult::success(
526 "No results found. Try different search terms or use Grep for exact patterns.",
527 );
528 }
529
530 let mut output = String::new();
531 for (i, r) in results.iter().enumerate() {
532 output.push_str(&format!(
533 "── Result {} ── {}:{}-{} (score: {:.2})\n{}\n\n",
534 i + 1,
535 r.path,
536 r.start_line,
537 r.end_line,
538 r.final_score,
539 r.content
540 ));
541 }
542 ToolResult::success(output)
543 }
544}