Skip to main content

zeph_tools/
search_code.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6
7use async_trait::async_trait;
8use schemars::JsonSchema;
9use serde::Deserialize;
10use tree_sitter::{Parser, QueryCursor, StreamingIterator};
11use zeph_index::languages::detect_language;
12
13use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
14use crate::registry::{InvocationHint, ToolDef};
15
/// Origin of a [`SearchCodeHit`].
///
/// The origin decides the label shown in summaries and, through
/// [`SearchCodeSource::default_score`], the score the local sources assign to their hits.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SearchCodeSource {
    /// Match produced by the semantic (query) backend.
    Semantic,
    /// Definition found by parsing files with tree-sitter.
    Structural,
    /// Definition reported by the LSP workspace-symbol lookup.
    LspSymbol,
    /// Usage reported by the LSP references lookup.
    LspReferences,
    /// Case-insensitive literal text match, used only when nothing else matched.
    GrepFallback,
}

impl SearchCodeSource {
    /// Returns `(label, default score)` from a single table so the two never drift apart.
    fn profile(self) -> (&'static str, f32) {
        match self {
            Self::Structural => ("tree-sitter", 0.98),
            Self::LspSymbol => ("LSP symbol search", 0.95),
            Self::LspReferences => ("LSP references", 0.90),
            Self::Semantic => ("vector search", 0.75),
            Self::GrepFallback => ("grep fallback", 0.45),
        }
    }

    /// Human-readable source name used in the text summary and the `source` JSON field.
    fn label(self) -> &'static str {
        self.profile().0
    }

    /// Baseline score for hits from this source.
    ///
    /// Structural and LSP matches outrank semantic similarity; grep matches rank last.
    #[must_use]
    pub fn default_score(self) -> f32 {
        self.profile().1
    }
}
47
48#[derive(Debug, Clone)]
49pub struct SearchCodeHit {
50    pub file_path: String,
51    pub line_start: usize,
52    pub line_end: usize,
53    pub snippet: String,
54    pub source: SearchCodeSource,
55    pub score: f32,
56    pub symbol_name: Option<String>,
57}
58
59#[async_trait]
60pub trait SemanticSearchBackend: Send + Sync {
61    async fn search(
62        &self,
63        query: &str,
64        file_pattern: Option<&str>,
65        max_results: usize,
66    ) -> Result<Vec<SearchCodeHit>, ToolError>;
67}
68
69#[async_trait]
70pub trait LspSearchBackend: Send + Sync {
71    async fn workspace_symbol(
72        &self,
73        symbol: &str,
74        file_pattern: Option<&str>,
75        max_results: usize,
76    ) -> Result<Vec<SearchCodeHit>, ToolError>;
77
78    async fn references(
79        &self,
80        symbol: &str,
81        file_pattern: Option<&str>,
82        max_results: usize,
83    ) -> Result<Vec<SearchCodeHit>, ToolError>;
84}
85
86#[derive(Deserialize, JsonSchema)]
87struct SearchCodeParams {
88    /// Natural-language query for semantic search.
89    #[serde(default)]
90    query: Option<String>,
91    /// Exact or partial symbol name.
92    #[serde(default)]
93    symbol: Option<String>,
94    /// Optional glob restricting files, for example `crates/zeph-tools/**`.
95    #[serde(default)]
96    file_pattern: Option<String>,
97    /// Also return reference locations when `symbol` is provided.
98    #[serde(default)]
99    include_references: bool,
100    /// Cap on returned locations.
101    #[serde(default = "default_max_results")]
102    max_results: usize,
103}
104
/// Serde default for `max_results` when the caller omits it.
const fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 10;
    DEFAULT_MAX_RESULTS
}
108
109pub struct SearchCodeExecutor {
110    allowed_paths: Vec<PathBuf>,
111    semantic_backend: Option<std::sync::Arc<dyn SemanticSearchBackend>>,
112    lsp_backend: Option<std::sync::Arc<dyn LspSearchBackend>>,
113}
114
115impl std::fmt::Debug for SearchCodeExecutor {
116    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
117        f.debug_struct("SearchCodeExecutor")
118            .field("allowed_paths", &self.allowed_paths)
119            .field("has_semantic_backend", &self.semantic_backend.is_some())
120            .field("has_lsp_backend", &self.lsp_backend.is_some())
121            .finish()
122    }
123}
124
125impl SearchCodeExecutor {
126    #[must_use]
127    pub fn new(allowed_paths: Vec<PathBuf>) -> Self {
128        let paths = if allowed_paths.is_empty() {
129            vec![std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))]
130        } else {
131            allowed_paths
132        };
133        Self {
134            allowed_paths: paths
135                .into_iter()
136                .map(|p| p.canonicalize().unwrap_or(p))
137                .collect(),
138            semantic_backend: None,
139            lsp_backend: None,
140        }
141    }
142
143    #[must_use]
144    pub fn with_semantic_backend(
145        mut self,
146        backend: std::sync::Arc<dyn SemanticSearchBackend>,
147    ) -> Self {
148        self.semantic_backend = Some(backend);
149        self
150    }
151
152    #[must_use]
153    pub fn with_lsp_backend(mut self, backend: std::sync::Arc<dyn LspSearchBackend>) -> Self {
154        self.lsp_backend = Some(backend);
155        self
156    }
157
158    async fn handle_search_code(
159        &self,
160        params: &SearchCodeParams,
161    ) -> Result<Option<ToolOutput>, ToolError> {
162        let query = params
163            .query
164            .as_deref()
165            .map(str::trim)
166            .filter(|s| !s.is_empty());
167        let symbol = params
168            .symbol
169            .as_deref()
170            .map(str::trim)
171            .filter(|s| !s.is_empty());
172
173        if query.is_none() && symbol.is_none() {
174            return Err(ToolError::InvalidParams {
175                message: "at least one of `query` or `symbol` must be provided".into(),
176            });
177        }
178
179        let max_results = params.max_results.clamp(1, 50);
180        let mut hits = Vec::new();
181
182        if let Some(query) = query
183            && let Some(backend) = &self.semantic_backend
184        {
185            hits.extend(
186                backend
187                    .search(query, params.file_pattern.as_deref(), max_results)
188                    .await?,
189            );
190        }
191
192        if let Some(symbol) = symbol {
193            hits.extend(self.structural_search(
194                symbol,
195                params.file_pattern.as_deref(),
196                max_results,
197            )?);
198
199            if let Some(backend) = &self.lsp_backend {
200                if let Ok(lsp_hits) = backend
201                    .workspace_symbol(symbol, params.file_pattern.as_deref(), max_results)
202                    .await
203                {
204                    hits.extend(lsp_hits);
205                }
206                if params.include_references
207                    && let Ok(lsp_refs) = backend
208                        .references(symbol, params.file_pattern.as_deref(), max_results)
209                        .await
210                {
211                    hits.extend(lsp_refs);
212                }
213            }
214        }
215
216        if hits.is_empty() {
217            let fallback_term = symbol.or(query).unwrap_or_default();
218            hits.extend(self.grep_fallback(
219                fallback_term,
220                params.file_pattern.as_deref(),
221                max_results,
222            )?);
223        }
224
225        let merged = dedupe_hits(hits, max_results);
226        let root = self
227            .allowed_paths
228            .first()
229            .map_or(Path::new("."), PathBuf::as_path);
230        let summary = format_hits(&merged, root);
231        let locations = merged
232            .iter()
233            .map(|hit| hit.file_path.clone())
234            .collect::<Vec<_>>();
235        let raw_response = serde_json::json!({
236            "results": merged.iter().map(|hit| {
237                serde_json::json!({
238                    "file_path": hit.file_path,
239                    "line_start": hit.line_start,
240                    "line_end": hit.line_end,
241                    "snippet": hit.snippet,
242                    "source": hit.source.label(),
243                    "score": hit.score,
244                    "symbol_name": hit.symbol_name,
245                })
246            }).collect::<Vec<_>>()
247        });
248
249        Ok(Some(ToolOutput {
250            tool_name: "search_code".to_owned(),
251            summary,
252            blocks_executed: 1,
253            filter_stats: None,
254            diff: None,
255            streamed: false,
256            terminal_id: None,
257            locations: Some(locations),
258            raw_response: Some(raw_response),
259        }))
260    }
261
262    fn structural_search(
263        &self,
264        symbol: &str,
265        file_pattern: Option<&str>,
266        max_results: usize,
267    ) -> Result<Vec<SearchCodeHit>, ToolError> {
268        let matcher = file_pattern
269            .map(glob::Pattern::new)
270            .transpose()
271            .map_err(|e| ToolError::InvalidParams {
272                message: format!("invalid file_pattern: {e}"),
273            })?;
274        let mut hits = Vec::new();
275        let symbol_lower = symbol.to_lowercase();
276
277        for root in &self.allowed_paths {
278            collect_structural_hits(root, root, matcher.as_ref(), &symbol_lower, &mut hits)?;
279            if hits.len() >= max_results {
280                break;
281            }
282        }
283
284        Ok(hits)
285    }
286
287    fn grep_fallback(
288        &self,
289        pattern: &str,
290        file_pattern: Option<&str>,
291        max_results: usize,
292    ) -> Result<Vec<SearchCodeHit>, ToolError> {
293        let matcher = file_pattern
294            .map(glob::Pattern::new)
295            .transpose()
296            .map_err(|e| ToolError::InvalidParams {
297                message: format!("invalid file_pattern: {e}"),
298            })?;
299        let escaped = regex::escape(pattern);
300        let regex = regex::RegexBuilder::new(&escaped)
301            .case_insensitive(true)
302            .build()
303            .map_err(|e| ToolError::InvalidParams {
304                message: e.to_string(),
305            })?;
306        let mut hits = Vec::new();
307        for root in &self.allowed_paths {
308            collect_grep_hits(root, root, matcher.as_ref(), &regex, &mut hits, max_results)?;
309            if hits.len() >= max_results {
310                break;
311            }
312        }
313        Ok(hits)
314    }
315}
316
317impl ToolExecutor for SearchCodeExecutor {
318    async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
319        Ok(None)
320    }
321
322    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
323        if call.tool_id != "search_code" {
324            return Ok(None);
325        }
326        let params: SearchCodeParams = deserialize_params(&call.params)?;
327        self.handle_search_code(&params).await
328    }
329
330    fn tool_definitions(&self) -> Vec<ToolDef> {
331        vec![ToolDef {
332            id: "search_code".into(),
333            description: "Search the codebase using semantic, structural, and LSP sources.\n\nParameters: query (string, optional) - natural language description to find semantically similar code; symbol (string, optional) - exact or partial symbol name for definition search; file_pattern (string, optional) - glob restricting files; include_references (boolean, optional) - also return symbol references when LSP is available; max_results (integer, optional) - cap results 1-50, default 10\nReturns: ranked code locations with file path, line range, snippet, source label, and score\nErrors: InvalidParams when both query and symbol are empty\nExample: {\"query\": \"where is retry backoff calculated\", \"symbol\": \"retry_backoff_ms\", \"include_references\": true}".into(),
334            schema: schemars::schema_for!(SearchCodeParams),
335            invocation: InvocationHint::ToolCall,
336        }]
337    }
338}
339
340fn dedupe_hits(mut hits: Vec<SearchCodeHit>, max_results: usize) -> Vec<SearchCodeHit> {
341    let mut merged: HashMap<(String, usize, usize), SearchCodeHit> = HashMap::new();
342    for hit in hits.drain(..) {
343        let key = (hit.file_path.clone(), hit.line_start, hit.line_end);
344        merged
345            .entry(key)
346            .and_modify(|existing| {
347                if hit.score > existing.score {
348                    existing.score = hit.score;
349                    existing.snippet.clone_from(&hit.snippet);
350                    existing.symbol_name = hit.symbol_name.clone().or(existing.symbol_name.clone());
351                }
352                if existing.source != hit.source {
353                    existing.source = if existing.score >= hit.score {
354                        existing.source
355                    } else {
356                        hit.source
357                    };
358                }
359            })
360            .or_insert(hit);
361    }
362
363    let mut merged = merged.into_values().collect::<Vec<_>>();
364    merged.sort_by(|a, b| {
365        b.score
366            .partial_cmp(&a.score)
367            .unwrap_or(std::cmp::Ordering::Equal)
368            .then_with(|| a.file_path.cmp(&b.file_path))
369            .then_with(|| a.line_start.cmp(&b.line_start))
370    });
371    merged.truncate(max_results);
372    merged
373}
374
375fn format_hits(hits: &[SearchCodeHit], root: &Path) -> String {
376    if hits.is_empty() {
377        return "No code matches found.".into();
378    }
379
380    hits.iter()
381        .enumerate()
382        .map(|(idx, hit)| {
383            let display_path = Path::new(&hit.file_path)
384                .strip_prefix(root)
385                .map_or_else(|_| hit.file_path.clone(), |p| p.display().to_string());
386            format!(
387                "[{}] {}:{}-{}\n    {}\n    source: {}\n    score: {:.2}",
388                idx + 1,
389                display_path,
390                hit.line_start,
391                hit.line_end,
392                hit.snippet.replace('\n', " "),
393                hit.source.label(),
394                hit.score,
395            )
396        })
397        .collect::<Vec<_>>()
398        .join("\n\n")
399}
400
401fn collect_structural_hits(
402    root: &Path,
403    current: &Path,
404    matcher: Option<&glob::Pattern>,
405    symbol_lower: &str,
406    hits: &mut Vec<SearchCodeHit>,
407) -> Result<(), ToolError> {
408    if should_skip_path(current) {
409        return Ok(());
410    }
411
412    let entries = std::fs::read_dir(current).map_err(ToolError::Execution)?;
413    for entry in entries {
414        let entry = entry.map_err(ToolError::Execution)?;
415        let path = entry.path();
416        if path.is_dir() {
417            collect_structural_hits(root, &path, matcher, symbol_lower, hits)?;
418            continue;
419        }
420        if !matches_pattern(root, &path, matcher) {
421            continue;
422        }
423        let Some(lang) = detect_language(&path) else {
424            continue;
425        };
426        let Some(grammar) = lang.grammar() else {
427            continue;
428        };
429        let Some(query) = lang.symbol_query() else {
430            continue;
431        };
432        let Ok(source) = std::fs::read_to_string(&path) else {
433            continue;
434        };
435        let mut parser = Parser::new();
436        if parser.set_language(&grammar).is_err() {
437            continue;
438        }
439        let Some(tree) = parser.parse(&source, None) else {
440            continue;
441        };
442        let mut cursor = QueryCursor::new();
443        let capture_names = query.capture_names();
444        let def_idx = capture_names.iter().position(|name| *name == "def");
445        let name_idx = capture_names.iter().position(|name| *name == "name");
446        let (Some(def_idx), Some(name_idx)) = (def_idx, name_idx) else {
447            continue;
448        };
449
450        let mut query_matches = cursor.matches(query, tree.root_node(), source.as_bytes());
451        while let Some(match_) = query_matches.next() {
452            let mut def_node = None;
453            let mut name = None;
454            for capture in match_.captures {
455                if capture.index as usize == def_idx {
456                    def_node = Some(capture.node);
457                }
458                if capture.index as usize == name_idx {
459                    name = Some(source[capture.node.byte_range()].to_string());
460                }
461            }
462            let Some(name) = name else {
463                continue;
464            };
465            if !name.to_lowercase().contains(symbol_lower) {
466                continue;
467            }
468            let Some(def_node) = def_node else {
469                continue;
470            };
471            hits.push(SearchCodeHit {
472                file_path: canonical_string(&path),
473                line_start: def_node.start_position().row + 1,
474                line_end: def_node.end_position().row + 1,
475                snippet: extract_snippet(&source, def_node.start_position().row + 1),
476                source: SearchCodeSource::Structural,
477                score: SearchCodeSource::Structural.default_score(),
478                symbol_name: Some(name),
479            });
480        }
481    }
482    Ok(())
483}
484
485fn collect_grep_hits(
486    root: &Path,
487    current: &Path,
488    matcher: Option<&glob::Pattern>,
489    regex: &regex::Regex,
490    hits: &mut Vec<SearchCodeHit>,
491    max_results: usize,
492) -> Result<(), ToolError> {
493    if hits.len() >= max_results || should_skip_path(current) {
494        return Ok(());
495    }
496
497    let entries = std::fs::read_dir(current).map_err(ToolError::Execution)?;
498    for entry in entries {
499        let entry = entry.map_err(ToolError::Execution)?;
500        let path = entry.path();
501        if path.is_dir() {
502            collect_grep_hits(root, &path, matcher, regex, hits, max_results)?;
503            continue;
504        }
505        if !matches_pattern(root, &path, matcher) {
506            continue;
507        }
508        let Ok(source) = std::fs::read_to_string(&path) else {
509            continue;
510        };
511        for (idx, line) in source.lines().enumerate() {
512            if regex.is_match(line) {
513                hits.push(SearchCodeHit {
514                    file_path: canonical_string(&path),
515                    line_start: idx + 1,
516                    line_end: idx + 1,
517                    snippet: line.trim().to_string(),
518                    source: SearchCodeSource::GrepFallback,
519                    score: SearchCodeSource::GrepFallback.default_score(),
520                    symbol_name: None,
521                });
522                if hits.len() >= max_results {
523                    return Ok(());
524                }
525            }
526        }
527    }
528    Ok(())
529}
530
531fn matches_pattern(root: &Path, path: &Path, matcher: Option<&glob::Pattern>) -> bool {
532    let Some(matcher) = matcher else {
533        return true;
534    };
535    let relative = path.strip_prefix(root).unwrap_or(path);
536    matcher.matches_path(relative)
537}
538
/// Returns `true` when the final component of `path` is a directory name that the walkers
/// never enter (VCS metadata, build output, dependencies, zeph state).
///
/// Paths with no final component, or one that is not valid UTF-8, are not skipped.
fn should_skip_path(path: &Path) -> bool {
    const SKIPPED_NAMES: [&str; 4] = [".git", "target", "node_modules", ".zeph"];
    let Some(name) = path.file_name().and_then(|name| name.to_str()) else {
        return false;
    };
    SKIPPED_NAMES.contains(&name)
}
544
/// Returns the canonical form of `path` as a display string, or `path` itself (displayed)
/// when it cannot be canonicalized (e.g. it does not exist).
fn canonical_string(path: &Path) -> String {
    match path.canonicalize() {
        Ok(resolved) => resolved.display().to_string(),
        Err(_) => path.display().to_string(),
    }
}
551
/// Returns the trimmed text of 1-based line `line_number` in `source`, or an empty string if
/// the line does not exist. A `line_number` of 0 is treated as line 1.
fn extract_snippet(source: &str, line_number: usize) -> String {
    let index = line_number.saturating_sub(1);
    match source.lines().nth(index) {
        Some(line) => line.trim().to_owned(),
        None => String::new(),
    }
}
560
#[cfg(test)]
mod tests {
    use super::*;

    /// Semantic backend stub that always answers with no hits.
    struct EmptySemantic;

    #[async_trait]
    impl SemanticSearchBackend for EmptySemantic {
        async fn search(
            &self,
            _query: &str,
            _file_pattern: Option<&str>,
            _max_results: usize,
        ) -> Result<Vec<SearchCodeHit>, ToolError> {
            Ok(Vec::new())
        }
    }

    /// Wraps a JSON object literal into a `search_code` tool call.
    fn search_call(params: serde_json::Value) -> ToolCall {
        ToolCall {
            tool_id: "search_code".into(),
            params: params
                .as_object()
                .cloned()
                .expect("search_code params must be a JSON object"),
        }
    }

    #[tokio::test]
    async fn search_code_requires_query_or_symbol() {
        let dir = tempfile::tempdir().unwrap();
        let exec = SearchCodeExecutor::new(vec![dir.path().to_path_buf()]);
        let err = exec
            .execute_tool_call(&search_call(serde_json::json!({})))
            .await
            .unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[tokio::test]
    async fn search_code_finds_structural_symbol() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("lib.rs"),
            "pub fn retry_backoff_ms() -> u64 { 0 }\n",
        )
        .unwrap();
        let exec = SearchCodeExecutor::new(vec![dir.path().to_path_buf()]);
        let out = exec
            .execute_tool_call(&search_call(
                serde_json::json!({ "symbol": "retry_backoff_ms" }),
            ))
            .await
            .unwrap()
            .unwrap();
        assert!(out.summary.contains("retry_backoff_ms"));
        assert!(out.summary.contains("tree-sitter"));
        assert_eq!(out.tool_name, "search_code");
    }

    #[tokio::test]
    async fn search_code_uses_grep_fallback() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("mod.rs"), "let retry_backoff_ms = 5;\n").unwrap();
        let exec = SearchCodeExecutor::new(vec![dir.path().to_path_buf()]);
        let out = exec
            .execute_tool_call(&search_call(
                serde_json::json!({ "query": "retry_backoff_ms" }),
            ))
            .await
            .unwrap()
            .unwrap();
        assert!(out.summary.contains("grep fallback"));
    }

    #[test]
    fn tool_definitions_include_search_code() {
        let exec = SearchCodeExecutor::new(vec![])
            .with_semantic_backend(std::sync::Arc::new(EmptySemantic));
        let defs = exec.tool_definitions();
        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id.as_ref(), "search_code");
    }

    #[test]
    fn format_hits_strips_root_prefix() {
        let root = Path::new("/tmp/myproject");
        let hit = SearchCodeHit {
            file_path: "/tmp/myproject/crates/foo/src/lib.rs".to_owned(),
            line_start: 10,
            line_end: 15,
            snippet: "pub fn example() {}".to_owned(),
            source: SearchCodeSource::GrepFallback,
            score: 0.45,
            symbol_name: None,
        };
        let output = format_hits(&[hit], root);
        assert!(
            output.contains("crates/foo/src/lib.rs"),
            "expected relative path in output, got: {output}"
        );
        assert!(
            !output.contains("/tmp/myproject"),
            "absolute path must not appear in output, got: {output}"
        );
    }
}