1use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use std::pin::Pin;
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use tree_sitter::{Parser, QueryCursor, StreamingIterator};
11use zeph_index::languages::detect_language;
12
13use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
14use crate::registry::{InvocationHint, ToolDef};
15
/// Identifies which search strategy produced a [`SearchCodeHit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchCodeSource {
    /// Hit returned by the semantic backend (displayed as "vector search").
    Semantic,
    /// Definition found by the tree-sitter symbol scan.
    Structural,
    /// Hit returned by an LSP workspace-symbol request.
    LspSymbol,
    /// Hit returned by an LSP references request.
    LspReferences,
    /// Line match produced by the grep fallback.
    GrepFallback,
}

impl SearchCodeSource {
    /// Single table pairing every variant with its display label and its
    /// default ranking score, so the two can never drift apart.
    const fn profile(self) -> (&'static str, f32) {
        match self {
            Self::Semantic => ("vector search", 0.75),
            Self::Structural => ("tree-sitter", 0.98),
            Self::LspSymbol => ("LSP symbol search", 0.95),
            Self::LspReferences => ("LSP references", 0.90),
            Self::GrepFallback => ("grep fallback", 0.45),
        }
    }

    /// Human-readable name of the source, as shown in summaries and JSON output.
    fn label(self) -> &'static str {
        self.profile().0
    }

    /// Score assigned to hits of this source when the producer has no score of its own.
    #[must_use]
    pub fn default_score(self) -> f32 {
        self.profile().1
    }
}
47
48#[derive(Debug, Clone)]
49pub struct SearchCodeHit {
50 pub file_path: String,
51 pub line_start: usize,
52 pub line_end: usize,
53 pub snippet: String,
54 pub source: SearchCodeSource,
55 pub score: f32,
56 pub symbol_name: Option<String>,
57}
58
59pub trait SemanticSearchBackend: Send + Sync {
60 fn search<'a>(
61 &'a self,
62 query: &'a str,
63 file_pattern: Option<&'a str>,
64 max_results: usize,
65 ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<SearchCodeHit>, ToolError>> + Send + 'a>>;
66}
67
68pub trait LspSearchBackend: Send + Sync {
69 fn workspace_symbol<'a>(
70 &'a self,
71 symbol: &'a str,
72 file_pattern: Option<&'a str>,
73 max_results: usize,
74 ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<SearchCodeHit>, ToolError>> + Send + 'a>>;
75
76 fn references<'a>(
77 &'a self,
78 symbol: &'a str,
79 file_pattern: Option<&'a str>,
80 max_results: usize,
81 ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<SearchCodeHit>, ToolError>> + Send + 'a>>;
82}
83
84#[derive(Deserialize, JsonSchema)]
85struct SearchCodeParams {
86 #[serde(default)]
88 query: Option<String>,
89 #[serde(default)]
91 symbol: Option<String>,
92 #[serde(default)]
94 file_pattern: Option<String>,
95 #[serde(default)]
97 include_references: bool,
98 #[serde(default = "default_max_results")]
100 max_results: usize,
101}
102
/// Serde default for `max_results` when the tool call omits it.
const fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 10;
    DEFAULT_MAX_RESULTS
}
106
107pub struct SearchCodeExecutor {
108 allowed_paths: Vec<PathBuf>,
109 semantic_backend: Option<std::sync::Arc<dyn SemanticSearchBackend>>,
110 lsp_backend: Option<std::sync::Arc<dyn LspSearchBackend>>,
111}
112
113impl std::fmt::Debug for SearchCodeExecutor {
114 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115 f.debug_struct("SearchCodeExecutor")
116 .field("allowed_paths", &self.allowed_paths)
117 .field("has_semantic_backend", &self.semantic_backend.is_some())
118 .field("has_lsp_backend", &self.lsp_backend.is_some())
119 .finish()
120 }
121}
122
123impl SearchCodeExecutor {
124 #[must_use]
125 pub fn new(allowed_paths: Vec<PathBuf>) -> Self {
126 let paths = if allowed_paths.is_empty() {
127 vec![std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))]
128 } else {
129 allowed_paths
130 };
131 Self {
132 allowed_paths: paths
133 .into_iter()
134 .map(|p| p.canonicalize().unwrap_or(p))
135 .collect(),
136 semantic_backend: None,
137 lsp_backend: None,
138 }
139 }
140
141 #[must_use]
142 pub fn with_semantic_backend(
143 mut self,
144 backend: std::sync::Arc<dyn SemanticSearchBackend>,
145 ) -> Self {
146 self.semantic_backend = Some(backend);
147 self
148 }
149
150 #[must_use]
151 pub fn with_lsp_backend(mut self, backend: std::sync::Arc<dyn LspSearchBackend>) -> Self {
152 self.lsp_backend = Some(backend);
153 self
154 }
155
156 async fn handle_search_code(
157 &self,
158 params: &SearchCodeParams,
159 ) -> Result<Option<ToolOutput>, ToolError> {
160 let query = params
161 .query
162 .as_deref()
163 .map(str::trim)
164 .filter(|s| !s.is_empty());
165 let symbol = params
166 .symbol
167 .as_deref()
168 .map(str::trim)
169 .filter(|s| !s.is_empty());
170
171 if query.is_none() && symbol.is_none() {
172 return Err(ToolError::InvalidParams {
173 message: "at least one of `query` or `symbol` must be provided".into(),
174 });
175 }
176
177 let max_results = params.max_results.clamp(1, 50);
178 let mut hits = Vec::new();
179
180 if let Some(query) = query
181 && let Some(backend) = &self.semantic_backend
182 {
183 hits.extend(
184 backend
185 .search(query, params.file_pattern.as_deref(), max_results)
186 .await?,
187 );
188 }
189
190 if let Some(symbol) = symbol {
191 hits.extend(self.structural_search(
192 symbol,
193 params.file_pattern.as_deref(),
194 max_results,
195 )?);
196
197 if let Some(backend) = &self.lsp_backend {
198 if let Ok(lsp_hits) = backend
199 .workspace_symbol(symbol, params.file_pattern.as_deref(), max_results)
200 .await
201 {
202 hits.extend(lsp_hits);
203 }
204 if params.include_references
205 && let Ok(lsp_refs) = backend
206 .references(symbol, params.file_pattern.as_deref(), max_results)
207 .await
208 {
209 hits.extend(lsp_refs);
210 }
211 }
212 }
213
214 if hits.is_empty() {
215 let fallback_term = symbol.or(query).unwrap_or_default();
216 hits.extend(self.grep_fallback(
217 fallback_term,
218 params.file_pattern.as_deref(),
219 max_results,
220 )?);
221 }
222
223 let merged = dedupe_hits(hits, max_results);
224 let root = self
225 .allowed_paths
226 .first()
227 .map_or(Path::new("."), PathBuf::as_path);
228 let summary = format_hits(&merged, root);
229 let locations = merged
230 .iter()
231 .map(|hit| hit.file_path.clone())
232 .collect::<Vec<_>>();
233 let raw_response = serde_json::json!({
234 "results": merged.iter().map(|hit| {
235 serde_json::json!({
236 "file_path": hit.file_path,
237 "line_start": hit.line_start,
238 "line_end": hit.line_end,
239 "snippet": hit.snippet,
240 "source": hit.source.label(),
241 "score": hit.score,
242 "symbol_name": hit.symbol_name,
243 })
244 }).collect::<Vec<_>>()
245 });
246
247 Ok(Some(ToolOutput {
248 tool_name: "search_code".to_owned(),
249 summary,
250 blocks_executed: 1,
251 filter_stats: None,
252 diff: None,
253 streamed: false,
254 terminal_id: None,
255 locations: Some(locations),
256 raw_response: Some(raw_response),
257 }))
258 }
259
260 fn structural_search(
261 &self,
262 symbol: &str,
263 file_pattern: Option<&str>,
264 max_results: usize,
265 ) -> Result<Vec<SearchCodeHit>, ToolError> {
266 let matcher = file_pattern
267 .map(glob::Pattern::new)
268 .transpose()
269 .map_err(|e| ToolError::InvalidParams {
270 message: format!("invalid file_pattern: {e}"),
271 })?;
272 let mut hits = Vec::new();
273 let symbol_lower = symbol.to_lowercase();
274
275 for root in &self.allowed_paths {
276 collect_structural_hits(root, root, matcher.as_ref(), &symbol_lower, &mut hits)?;
277 if hits.len() >= max_results {
278 break;
279 }
280 }
281
282 Ok(hits)
283 }
284
285 fn grep_fallback(
286 &self,
287 pattern: &str,
288 file_pattern: Option<&str>,
289 max_results: usize,
290 ) -> Result<Vec<SearchCodeHit>, ToolError> {
291 let matcher = file_pattern
292 .map(glob::Pattern::new)
293 .transpose()
294 .map_err(|e| ToolError::InvalidParams {
295 message: format!("invalid file_pattern: {e}"),
296 })?;
297 let escaped = regex::escape(pattern);
298 let regex = regex::RegexBuilder::new(&escaped)
299 .case_insensitive(true)
300 .build()
301 .map_err(|e| ToolError::InvalidParams {
302 message: e.to_string(),
303 })?;
304 let mut hits = Vec::new();
305 for root in &self.allowed_paths {
306 collect_grep_hits(root, root, matcher.as_ref(), ®ex, &mut hits, max_results)?;
307 if hits.len() >= max_results {
308 break;
309 }
310 }
311 Ok(hits)
312 }
313}
314
315impl ToolExecutor for SearchCodeExecutor {
316 async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
317 Ok(None)
318 }
319
320 async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
321 if call.tool_id != "search_code" {
322 return Ok(None);
323 }
324 let params: SearchCodeParams = deserialize_params(&call.params)?;
325 self.handle_search_code(¶ms).await
326 }
327
328 fn tool_definitions(&self) -> Vec<ToolDef> {
329 vec![ToolDef {
330 id: "search_code".into(),
331 description: "Search the codebase using semantic, structural, and LSP sources.\n\nParameters: query (string, optional) - natural language description to find semantically similar code; symbol (string, optional) - exact or partial symbol name for definition search; file_pattern (string, optional) - glob restricting files; include_references (boolean, optional) - also return symbol references when LSP is available; max_results (integer, optional) - cap results 1-50, default 10\nReturns: ranked code locations with file path, line range, snippet, source label, and score\nErrors: InvalidParams when both query and symbol are empty\nExample: {\"query\": \"where is retry backoff calculated\", \"symbol\": \"retry_backoff_ms\", \"include_references\": true}".into(),
332 schema: schemars::schema_for!(SearchCodeParams),
333 invocation: InvocationHint::ToolCall,
334 }]
335 }
336}
337
338fn dedupe_hits(mut hits: Vec<SearchCodeHit>, max_results: usize) -> Vec<SearchCodeHit> {
339 let mut merged: HashMap<(String, usize, usize), SearchCodeHit> = HashMap::new();
340 for hit in hits.drain(..) {
341 let key = (hit.file_path.clone(), hit.line_start, hit.line_end);
342 merged
343 .entry(key)
344 .and_modify(|existing| {
345 if hit.score > existing.score {
346 existing.score = hit.score;
347 existing.snippet.clone_from(&hit.snippet);
348 existing.symbol_name = hit.symbol_name.clone().or(existing.symbol_name.clone());
349 }
350 if existing.source != hit.source {
351 existing.source = if existing.score >= hit.score {
352 existing.source
353 } else {
354 hit.source
355 };
356 }
357 })
358 .or_insert(hit);
359 }
360
361 let mut merged = merged.into_values().collect::<Vec<_>>();
362 merged.sort_by(|a, b| {
363 b.score
364 .partial_cmp(&a.score)
365 .unwrap_or(std::cmp::Ordering::Equal)
366 .then_with(|| a.file_path.cmp(&b.file_path))
367 .then_with(|| a.line_start.cmp(&b.line_start))
368 });
369 merged.truncate(max_results);
370 merged
371}
372
373fn format_hits(hits: &[SearchCodeHit], root: &Path) -> String {
374 if hits.is_empty() {
375 return "No code matches found.".into();
376 }
377
378 hits.iter()
379 .enumerate()
380 .map(|(idx, hit)| {
381 let display_path = Path::new(&hit.file_path)
382 .strip_prefix(root)
383 .map_or_else(|_| hit.file_path.clone(), |p| p.display().to_string());
384 format!(
385 "[{}] {}:{}-{}\n {}\n source: {}\n score: {:.2}",
386 idx + 1,
387 display_path,
388 hit.line_start,
389 hit.line_end,
390 hit.snippet.replace('\n', " "),
391 hit.source.label(),
392 hit.score,
393 )
394 })
395 .collect::<Vec<_>>()
396 .join("\n\n")
397}
398
399fn collect_structural_hits(
400 root: &Path,
401 current: &Path,
402 matcher: Option<&glob::Pattern>,
403 symbol_lower: &str,
404 hits: &mut Vec<SearchCodeHit>,
405) -> Result<(), ToolError> {
406 if should_skip_path(current) {
407 return Ok(());
408 }
409
410 let entries = std::fs::read_dir(current).map_err(ToolError::Execution)?;
411 for entry in entries {
412 let entry = entry.map_err(ToolError::Execution)?;
413 let path = entry.path();
414 if path.is_dir() {
415 collect_structural_hits(root, &path, matcher, symbol_lower, hits)?;
416 continue;
417 }
418 if !matches_pattern(root, &path, matcher) {
419 continue;
420 }
421 let Some(lang) = detect_language(&path) else {
422 continue;
423 };
424 let Some(grammar) = lang.grammar() else {
425 continue;
426 };
427 let Some(query) = lang.symbol_query() else {
428 continue;
429 };
430 let Ok(source) = std::fs::read_to_string(&path) else {
431 continue;
432 };
433 let mut parser = Parser::new();
434 if parser.set_language(&grammar).is_err() {
435 continue;
436 }
437 let Some(tree) = parser.parse(&source, None) else {
438 continue;
439 };
440 let mut cursor = QueryCursor::new();
441 let capture_names = query.capture_names();
442 let def_idx = capture_names.iter().position(|name| *name == "def");
443 let name_idx = capture_names.iter().position(|name| *name == "name");
444 let (Some(def_idx), Some(name_idx)) = (def_idx, name_idx) else {
445 continue;
446 };
447
448 let mut query_matches = cursor.matches(query, tree.root_node(), source.as_bytes());
449 while let Some(match_) = query_matches.next() {
450 let mut def_node = None;
451 let mut name = None;
452 for capture in match_.captures {
453 if capture.index as usize == def_idx {
454 def_node = Some(capture.node);
455 }
456 if capture.index as usize == name_idx {
457 name = Some(source[capture.node.byte_range()].to_string());
458 }
459 }
460 let Some(name) = name else {
461 continue;
462 };
463 if !name.to_lowercase().contains(symbol_lower) {
464 continue;
465 }
466 let Some(def_node) = def_node else {
467 continue;
468 };
469 hits.push(SearchCodeHit {
470 file_path: canonical_string(&path),
471 line_start: def_node.start_position().row + 1,
472 line_end: def_node.end_position().row + 1,
473 snippet: extract_snippet(&source, def_node.start_position().row + 1),
474 source: SearchCodeSource::Structural,
475 score: SearchCodeSource::Structural.default_score(),
476 symbol_name: Some(name),
477 });
478 }
479 }
480 Ok(())
481}
482
483fn collect_grep_hits(
484 root: &Path,
485 current: &Path,
486 matcher: Option<&glob::Pattern>,
487 regex: ®ex::Regex,
488 hits: &mut Vec<SearchCodeHit>,
489 max_results: usize,
490) -> Result<(), ToolError> {
491 if hits.len() >= max_results || should_skip_path(current) {
492 return Ok(());
493 }
494
495 let entries = std::fs::read_dir(current).map_err(ToolError::Execution)?;
496 for entry in entries {
497 let entry = entry.map_err(ToolError::Execution)?;
498 let path = entry.path();
499 if path.is_dir() {
500 collect_grep_hits(root, &path, matcher, regex, hits, max_results)?;
501 continue;
502 }
503 if !matches_pattern(root, &path, matcher) {
504 continue;
505 }
506 let Ok(source) = std::fs::read_to_string(&path) else {
507 continue;
508 };
509 for (idx, line) in source.lines().enumerate() {
510 if regex.is_match(line) {
511 hits.push(SearchCodeHit {
512 file_path: canonical_string(&path),
513 line_start: idx + 1,
514 line_end: idx + 1,
515 snippet: line.trim().to_string(),
516 source: SearchCodeSource::GrepFallback,
517 score: SearchCodeSource::GrepFallback.default_score(),
518 symbol_name: None,
519 });
520 if hits.len() >= max_results {
521 return Ok(());
522 }
523 }
524 }
525 }
526 Ok(())
527}
528
529fn matches_pattern(root: &Path, path: &Path, matcher: Option<&glob::Pattern>) -> bool {
530 let Some(matcher) = matcher else {
531 return true;
532 };
533 let relative = path.strip_prefix(root).unwrap_or(path);
534 matcher.matches_path(relative)
535}
536
/// Tells whether the final component of `path` names a directory the scans
/// never enter (`.git`, `target`, `node_modules`, `.zeph`).
///
/// Paths without a final component, or with a non-UTF-8 one, are not skipped.
fn should_skip_path(path: &Path) -> bool {
    const SKIPPED_NAMES: [&str; 4] = [".git", "target", "node_modules", ".zeph"];

    match path.file_name().and_then(|name| name.to_str()) {
        Some(name) => SKIPPED_NAMES.contains(&name),
        None => false,
    }
}
542
/// Returns the canonical form of `path` as a display string, or `path` itself
/// (unchanged) when it cannot be canonicalized.
fn canonical_string(path: &Path) -> String {
    let resolved = match path.canonicalize() {
        Ok(canonical) => canonical,
        Err(_) => path.to_path_buf(),
    };
    resolved.display().to_string()
}
549
/// Returns the trimmed text of the 1-based line `line_number` of `source`.
///
/// Line `0` is treated like line `1`; a line beyond the end of `source`
/// yields an empty string.
fn extract_snippet(source: &str, line_number: usize) -> String {
    let index = line_number.saturating_sub(1);
    match source.lines().nth(index) {
        Some(line) => line.trim().to_owned(),
        None => String::new(),
    }
}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// Semantic backend that never finds anything.
    struct EmptySemantic;

    impl SemanticSearchBackend for EmptySemantic {
        fn search<'a>(
            &'a self,
            _query: &'a str,
            _file_pattern: Option<&'a str>,
            _max_results: usize,
        ) -> Pin<
            Box<
                dyn std::future::Future<Output = Result<Vec<SearchCodeHit>, ToolError>> + Send + 'a,
            >,
        > {
            Box::pin(async move { Ok(vec![]) })
        }
    }

    /// Builds a `search_code` tool call from a JSON object literal.
    fn search_call(params: serde_json::Value) -> ToolCall {
        let serde_json::Value::Object(params) = params else {
            panic!("tool call parameters must be a JSON object");
        };
        ToolCall {
            tool_id: "search_code".into(),
            params,
        }
    }

    /// Creates a temporary workspace holding a single file, plus an executor rooted at it.
    fn workspace_with_file(name: &str, contents: &str) -> (tempfile::TempDir, SearchCodeExecutor) {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join(name), contents).unwrap();
        let exec = SearchCodeExecutor::new(vec![dir.path().to_path_buf()]);
        (dir, exec)
    }

    #[tokio::test]
    async fn search_code_requires_query_or_symbol() {
        let dir = tempfile::tempdir().unwrap();
        let exec = SearchCodeExecutor::new(vec![dir.path().to_path_buf()]);

        let outcome = exec.execute_tool_call(&search_call(serde_json::json!({}))).await;

        let err = outcome.unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[tokio::test]
    async fn search_code_finds_structural_symbol() {
        // `_dir` keeps the temporary directory alive for the duration of the test.
        let (_dir, exec) =
            workspace_with_file("lib.rs", "pub fn retry_backoff_ms() -> u64 { 0 }\n");
        let call = search_call(serde_json::json!({ "symbol": "retry_backoff_ms" }));

        let out = exec.execute_tool_call(&call).await.unwrap().unwrap();

        assert!(out.summary.contains("retry_backoff_ms"));
        assert!(out.summary.contains("tree-sitter"));
        assert_eq!(out.tool_name, "search_code");
    }

    #[tokio::test]
    async fn search_code_uses_grep_fallback() {
        let (_dir, exec) = workspace_with_file("mod.rs", "let retry_backoff_ms = 5;\n");
        let call = search_call(serde_json::json!({ "query": "retry_backoff_ms" }));

        let out = exec.execute_tool_call(&call).await.unwrap().unwrap();

        assert!(out.summary.contains("grep fallback"));
    }

    #[test]
    fn tool_definitions_include_search_code() {
        let backend = std::sync::Arc::new(EmptySemantic);
        let exec = SearchCodeExecutor::new(vec![]).with_semantic_backend(backend);

        let defs = exec.tool_definitions();

        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id.as_ref(), "search_code");
    }

    #[test]
    fn format_hits_strips_root_prefix() {
        let root = Path::new("/tmp/myproject");
        let hit = SearchCodeHit {
            file_path: "/tmp/myproject/crates/foo/src/lib.rs".to_owned(),
            line_start: 10,
            line_end: 15,
            snippet: "pub fn example() {}".to_owned(),
            source: SearchCodeSource::GrepFallback,
            score: 0.45,
            symbol_name: None,
        };

        let output = format_hits(&[hit], root);

        assert!(
            output.contains("crates/foo/src/lib.rs"),
            "expected relative path in output, got: {output}"
        );
        assert!(
            !output.contains("/tmp/myproject"),
            "absolute path must not appear in output, got: {output}"
        );
    }
}