Skip to main content

zeph_tools/
search_code.rs

1// SPDX-FileCopyrightText: 2026 Andrei G <bug-ops>
2// SPDX-License-Identifier: MIT OR Apache-2.0
3
4use std::collections::HashMap;
5use std::path::{Path, PathBuf};
6use std::pin::Pin;
7
8use schemars::JsonSchema;
9use serde::Deserialize;
10use tree_sitter::{Parser, QueryCursor, StreamingIterator};
11use zeph_index::languages::detect_language;
12
13use crate::executor::{ToolCall, ToolError, ToolExecutor, ToolOutput, deserialize_params};
14use crate::registry::{InvocationHint, ToolDef};
15
/// Identifies which search strategy produced a [`SearchCodeHit`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SearchCodeSource {
    /// Hit supplied by the semantic search backend.
    Semantic,
    /// Hit found by the in-process tree-sitter symbol scan.
    Structural,
    /// Hit supplied by the LSP backend's workspace-symbol lookup.
    LspSymbol,
    /// Hit supplied by the LSP backend's references lookup.
    LspReferences,
    /// Hit found by the line-by-line regex scan used when nothing else matched.
    GrepFallback,
}

impl SearchCodeSource {
    /// Human-readable name of the source, used in summaries and in the JSON response.
    fn label(self) -> &'static str {
        match self {
            Self::GrepFallback => "grep fallback",
            Self::LspReferences => "LSP references",
            Self::LspSymbol => "LSP symbol search",
            Self::Structural => "tree-sitter",
            Self::Semantic => "vector search",
        }
    }

    /// Baseline ranking score given to hits from this source.
    ///
    /// Structural hits rank highest and grep-fallback hits lowest.
    #[must_use]
    pub fn default_score(self) -> f32 {
        match self {
            Self::GrepFallback => 0.45,
            Self::Semantic => 0.75,
            Self::LspReferences => 0.90,
            Self::LspSymbol => 0.95,
            Self::Structural => 0.98,
        }
    }
}
47
/// A single code location returned by one of the search sources.
#[derive(Debug, Clone)]
pub struct SearchCodeHit {
    /// Path of the file containing the hit. Hits produced in this file store the
    /// canonicalized path when canonicalization succeeds.
    pub file_path: String,
    /// First line of the hit. The structural and grep scans in this file use 1-based lines.
    pub line_start: usize,
    /// Last line of the hit (equal to `line_start` for grep-fallback hits).
    pub line_end: usize,
    /// Source text shown for the hit; newlines are flattened to spaces in summaries.
    pub snippet: String,
    /// Which search strategy produced the hit.
    pub source: SearchCodeSource,
    /// Ranking score; higher scores sort first when results are merged.
    pub score: f32,
    /// Name of the matched symbol, when the source knows it.
    pub symbol_name: Option<String>,
}
58
/// Pluggable backend answering natural-language code queries.
///
/// The returned future is boxed, so the trait can be used as `dyn SemanticSearchBackend`.
pub trait SemanticSearchBackend: Send + Sync {
    /// Searches for code related to `query`.
    ///
    /// `file_pattern` is the optional glob supplied by the caller and `max_results`
    /// is the requested cap on returned hits. An `Err` returned here is propagated
    /// by `SearchCodeExecutor` and fails the whole tool call.
    fn search<'a>(
        &'a self,
        query: &'a str,
        file_pattern: Option<&'a str>,
        max_results: usize,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<SearchCodeHit>, ToolError>> + Send + 'a>>;
}
67
/// Pluggable backend answering symbol queries through a language server.
///
/// `SearchCodeExecutor` treats both lookups as best-effort: an `Err` from either
/// method is ignored and the remaining sources are still used.
pub trait LspSearchBackend: Send + Sync {
    /// Looks up locations for `symbol`, optionally restricted by the `file_pattern`
    /// glob, returning at most `max_results` hits.
    fn workspace_symbol<'a>(
        &'a self,
        symbol: &'a str,
        file_pattern: Option<&'a str>,
        max_results: usize,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<SearchCodeHit>, ToolError>> + Send + 'a>>;

    /// Looks up reference locations for `symbol`. Only called when the request sets
    /// `include_references`.
    fn references<'a>(
        &'a self,
        symbol: &'a str,
        file_pattern: Option<&'a str>,
        max_results: usize,
    ) -> Pin<Box<dyn std::future::Future<Output = Result<Vec<SearchCodeHit>, ToolError>> + Send + 'a>>;
}
83
// Parameters accepted by the `search_code` tool call.
//
// Every field is optional on the wire; `handle_search_code` rejects requests where
// both `query` and `symbol` are missing or blank.
//
// NOTE(review): the `///` comments on the fields below are presumably picked up by the
// `JsonSchema` derive as schema descriptions - confirm before rewording them. Plain `//`
// comments are used for this header so the generated schema is not affected.
#[derive(Deserialize, JsonSchema)]
struct SearchCodeParams {
    /// Natural-language query for semantic search.
    #[serde(default)]
    query: Option<String>,
    /// Exact or partial symbol name.
    #[serde(default)]
    symbol: Option<String>,
    /// Optional glob restricting files, for example `crates/zeph-tools/**`.
    #[serde(default)]
    file_pattern: Option<String>,
    /// Also return reference locations when `symbol` is provided.
    #[serde(default)]
    include_references: bool,
    /// Cap on returned locations.
    // Clamped to 1..=50 by `handle_search_code`; defaults to `default_max_results()`.
    #[serde(default = "default_max_results")]
    max_results: usize,
}
102
/// Serde default for `SearchCodeParams::max_results` when the caller omits it.
const fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 10;
    DEFAULT_MAX_RESULTS
}
106
/// Tool executor behind the `search_code` tool.
///
/// Combines an optional semantic backend, a built-in tree-sitter symbol scan, an
/// optional LSP backend, and a regex grep fallback over the allowed paths.
pub struct SearchCodeExecutor {
    /// Directory roots that the structural and grep scans walk. The first entry is
    /// also used to shorten paths in the textual summary.
    allowed_paths: Vec<PathBuf>,
    /// Backend for natural-language queries; `None` disables semantic search.
    semantic_backend: Option<std::sync::Arc<dyn SemanticSearchBackend>>,
    /// Backend for LSP symbol/reference lookups; `None` disables them.
    lsp_backend: Option<std::sync::Arc<dyn LspSearchBackend>>,
}
112
113impl std::fmt::Debug for SearchCodeExecutor {
114    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
115        f.debug_struct("SearchCodeExecutor")
116            .field("allowed_paths", &self.allowed_paths)
117            .field("has_semantic_backend", &self.semantic_backend.is_some())
118            .field("has_lsp_backend", &self.lsp_backend.is_some())
119            .finish()
120    }
121}
122
123impl SearchCodeExecutor {
124    #[must_use]
125    pub fn new(allowed_paths: Vec<PathBuf>) -> Self {
126        let paths = if allowed_paths.is_empty() {
127            vec![std::env::current_dir().unwrap_or_else(|_| PathBuf::from("."))]
128        } else {
129            allowed_paths
130        };
131        Self {
132            allowed_paths: paths
133                .into_iter()
134                .map(|p| p.canonicalize().unwrap_or(p))
135                .collect(),
136            semantic_backend: None,
137            lsp_backend: None,
138        }
139    }
140
141    #[must_use]
142    pub fn with_semantic_backend(
143        mut self,
144        backend: std::sync::Arc<dyn SemanticSearchBackend>,
145    ) -> Self {
146        self.semantic_backend = Some(backend);
147        self
148    }
149
150    #[must_use]
151    pub fn with_lsp_backend(mut self, backend: std::sync::Arc<dyn LspSearchBackend>) -> Self {
152        self.lsp_backend = Some(backend);
153        self
154    }
155
156    async fn handle_search_code(
157        &self,
158        params: &SearchCodeParams,
159    ) -> Result<Option<ToolOutput>, ToolError> {
160        let query = params
161            .query
162            .as_deref()
163            .map(str::trim)
164            .filter(|s| !s.is_empty());
165        let symbol = params
166            .symbol
167            .as_deref()
168            .map(str::trim)
169            .filter(|s| !s.is_empty());
170
171        if query.is_none() && symbol.is_none() {
172            return Err(ToolError::InvalidParams {
173                message: "at least one of `query` or `symbol` must be provided".into(),
174            });
175        }
176
177        let max_results = params.max_results.clamp(1, 50);
178        let mut hits = Vec::new();
179
180        if let Some(query) = query
181            && let Some(backend) = &self.semantic_backend
182        {
183            hits.extend(
184                backend
185                    .search(query, params.file_pattern.as_deref(), max_results)
186                    .await?,
187            );
188        }
189
190        if let Some(symbol) = symbol {
191            hits.extend(self.structural_search(
192                symbol,
193                params.file_pattern.as_deref(),
194                max_results,
195            )?);
196
197            if let Some(backend) = &self.lsp_backend {
198                if let Ok(lsp_hits) = backend
199                    .workspace_symbol(symbol, params.file_pattern.as_deref(), max_results)
200                    .await
201                {
202                    hits.extend(lsp_hits);
203                }
204                if params.include_references
205                    && let Ok(lsp_refs) = backend
206                        .references(symbol, params.file_pattern.as_deref(), max_results)
207                        .await
208                {
209                    hits.extend(lsp_refs);
210                }
211            }
212        }
213
214        if hits.is_empty() {
215            let fallback_term = symbol.or(query).unwrap_or_default();
216            hits.extend(self.grep_fallback(
217                fallback_term,
218                params.file_pattern.as_deref(),
219                max_results,
220            )?);
221        }
222
223        let merged = dedupe_hits(hits, max_results);
224        let root = self
225            .allowed_paths
226            .first()
227            .map_or(Path::new("."), PathBuf::as_path);
228        let summary = format_hits(&merged, root);
229        let locations = merged
230            .iter()
231            .map(|hit| hit.file_path.clone())
232            .collect::<Vec<_>>();
233        let raw_response = serde_json::json!({
234            "results": merged.iter().map(|hit| {
235                serde_json::json!({
236                    "file_path": hit.file_path,
237                    "line_start": hit.line_start,
238                    "line_end": hit.line_end,
239                    "snippet": hit.snippet,
240                    "source": hit.source.label(),
241                    "score": hit.score,
242                    "symbol_name": hit.symbol_name,
243                })
244            }).collect::<Vec<_>>()
245        });
246
247        Ok(Some(ToolOutput {
248            tool_name: "search_code".to_owned(),
249            summary,
250            blocks_executed: 1,
251            filter_stats: None,
252            diff: None,
253            streamed: false,
254            terminal_id: None,
255            locations: Some(locations),
256            raw_response: Some(raw_response),
257        }))
258    }
259
260    fn structural_search(
261        &self,
262        symbol: &str,
263        file_pattern: Option<&str>,
264        max_results: usize,
265    ) -> Result<Vec<SearchCodeHit>, ToolError> {
266        let matcher = file_pattern
267            .map(glob::Pattern::new)
268            .transpose()
269            .map_err(|e| ToolError::InvalidParams {
270                message: format!("invalid file_pattern: {e}"),
271            })?;
272        let mut hits = Vec::new();
273        let symbol_lower = symbol.to_lowercase();
274
275        for root in &self.allowed_paths {
276            collect_structural_hits(root, root, matcher.as_ref(), &symbol_lower, &mut hits)?;
277            if hits.len() >= max_results {
278                break;
279            }
280        }
281
282        Ok(hits)
283    }
284
285    fn grep_fallback(
286        &self,
287        pattern: &str,
288        file_pattern: Option<&str>,
289        max_results: usize,
290    ) -> Result<Vec<SearchCodeHit>, ToolError> {
291        let matcher = file_pattern
292            .map(glob::Pattern::new)
293            .transpose()
294            .map_err(|e| ToolError::InvalidParams {
295                message: format!("invalid file_pattern: {e}"),
296            })?;
297        let escaped = regex::escape(pattern);
298        let regex = regex::RegexBuilder::new(&escaped)
299            .case_insensitive(true)
300            .build()
301            .map_err(|e| ToolError::InvalidParams {
302                message: e.to_string(),
303            })?;
304        let mut hits = Vec::new();
305        for root in &self.allowed_paths {
306            collect_grep_hits(root, root, matcher.as_ref(), &regex, &mut hits, max_results)?;
307            if hits.len() >= max_results {
308                break;
309            }
310        }
311        Ok(hits)
312    }
313}
314
315impl ToolExecutor for SearchCodeExecutor {
316    async fn execute(&self, _response: &str) -> Result<Option<ToolOutput>, ToolError> {
317        Ok(None)
318    }
319
320    async fn execute_tool_call(&self, call: &ToolCall) -> Result<Option<ToolOutput>, ToolError> {
321        if call.tool_id != "search_code" {
322            return Ok(None);
323        }
324        let params: SearchCodeParams = deserialize_params(&call.params)?;
325        self.handle_search_code(&params).await
326    }
327
328    fn tool_definitions(&self) -> Vec<ToolDef> {
329        vec![ToolDef {
330            id: "search_code".into(),
331            description: "Search the codebase using semantic, structural, and LSP sources.\n\nParameters: query (string, optional) - natural language description to find semantically similar code; symbol (string, optional) - exact or partial symbol name for definition search; file_pattern (string, optional) - glob restricting files; include_references (boolean, optional) - also return symbol references when LSP is available; max_results (integer, optional) - cap results 1-50, default 10\nReturns: ranked code locations with file path, line range, snippet, source label, and score\nErrors: InvalidParams when both query and symbol are empty\nExample: {\"query\": \"where is retry backoff calculated\", \"symbol\": \"retry_backoff_ms\", \"include_references\": true}".into(),
332            schema: schemars::schema_for!(SearchCodeParams),
333            invocation: InvocationHint::ToolCall,
334        }]
335    }
336}
337
338fn dedupe_hits(mut hits: Vec<SearchCodeHit>, max_results: usize) -> Vec<SearchCodeHit> {
339    let mut merged: HashMap<(String, usize, usize), SearchCodeHit> = HashMap::new();
340    for hit in hits.drain(..) {
341        let key = (hit.file_path.clone(), hit.line_start, hit.line_end);
342        merged
343            .entry(key)
344            .and_modify(|existing| {
345                if hit.score > existing.score {
346                    existing.score = hit.score;
347                    existing.snippet.clone_from(&hit.snippet);
348                    existing.symbol_name = hit.symbol_name.clone().or(existing.symbol_name.clone());
349                }
350                if existing.source != hit.source {
351                    existing.source = if existing.score >= hit.score {
352                        existing.source
353                    } else {
354                        hit.source
355                    };
356                }
357            })
358            .or_insert(hit);
359    }
360
361    let mut merged = merged.into_values().collect::<Vec<_>>();
362    merged.sort_by(|a, b| {
363        b.score
364            .partial_cmp(&a.score)
365            .unwrap_or(std::cmp::Ordering::Equal)
366            .then_with(|| a.file_path.cmp(&b.file_path))
367            .then_with(|| a.line_start.cmp(&b.line_start))
368    });
369    merged.truncate(max_results);
370    merged
371}
372
373fn format_hits(hits: &[SearchCodeHit], root: &Path) -> String {
374    if hits.is_empty() {
375        return "No code matches found.".into();
376    }
377
378    hits.iter()
379        .enumerate()
380        .map(|(idx, hit)| {
381            let display_path = Path::new(&hit.file_path)
382                .strip_prefix(root)
383                .map_or_else(|_| hit.file_path.clone(), |p| p.display().to_string());
384            format!(
385                "[{}] {}:{}-{}\n    {}\n    source: {}\n    score: {:.2}",
386                idx + 1,
387                display_path,
388                hit.line_start,
389                hit.line_end,
390                hit.snippet.replace('\n', " "),
391                hit.source.label(),
392                hit.score,
393            )
394        })
395        .collect::<Vec<_>>()
396        .join("\n\n")
397}
398
399fn collect_structural_hits(
400    root: &Path,
401    current: &Path,
402    matcher: Option<&glob::Pattern>,
403    symbol_lower: &str,
404    hits: &mut Vec<SearchCodeHit>,
405) -> Result<(), ToolError> {
406    if should_skip_path(current) {
407        return Ok(());
408    }
409
410    let entries = std::fs::read_dir(current).map_err(ToolError::Execution)?;
411    for entry in entries {
412        let entry = entry.map_err(ToolError::Execution)?;
413        let path = entry.path();
414        if path.is_dir() {
415            collect_structural_hits(root, &path, matcher, symbol_lower, hits)?;
416            continue;
417        }
418        if !matches_pattern(root, &path, matcher) {
419            continue;
420        }
421        let Some(lang) = detect_language(&path) else {
422            continue;
423        };
424        let Some(grammar) = lang.grammar() else {
425            continue;
426        };
427        let Some(query) = lang.symbol_query() else {
428            continue;
429        };
430        let Ok(source) = std::fs::read_to_string(&path) else {
431            continue;
432        };
433        let mut parser = Parser::new();
434        if parser.set_language(&grammar).is_err() {
435            continue;
436        }
437        let Some(tree) = parser.parse(&source, None) else {
438            continue;
439        };
440        let mut cursor = QueryCursor::new();
441        let capture_names = query.capture_names();
442        let def_idx = capture_names.iter().position(|name| *name == "def");
443        let name_idx = capture_names.iter().position(|name| *name == "name");
444        let (Some(def_idx), Some(name_idx)) = (def_idx, name_idx) else {
445            continue;
446        };
447
448        let mut query_matches = cursor.matches(query, tree.root_node(), source.as_bytes());
449        while let Some(match_) = query_matches.next() {
450            let mut def_node = None;
451            let mut name = None;
452            for capture in match_.captures {
453                if capture.index as usize == def_idx {
454                    def_node = Some(capture.node);
455                }
456                if capture.index as usize == name_idx {
457                    name = Some(source[capture.node.byte_range()].to_string());
458                }
459            }
460            let Some(name) = name else {
461                continue;
462            };
463            if !name.to_lowercase().contains(symbol_lower) {
464                continue;
465            }
466            let Some(def_node) = def_node else {
467                continue;
468            };
469            hits.push(SearchCodeHit {
470                file_path: canonical_string(&path),
471                line_start: def_node.start_position().row + 1,
472                line_end: def_node.end_position().row + 1,
473                snippet: extract_snippet(&source, def_node.start_position().row + 1),
474                source: SearchCodeSource::Structural,
475                score: SearchCodeSource::Structural.default_score(),
476                symbol_name: Some(name),
477            });
478        }
479    }
480    Ok(())
481}
482
483fn collect_grep_hits(
484    root: &Path,
485    current: &Path,
486    matcher: Option<&glob::Pattern>,
487    regex: &regex::Regex,
488    hits: &mut Vec<SearchCodeHit>,
489    max_results: usize,
490) -> Result<(), ToolError> {
491    if hits.len() >= max_results || should_skip_path(current) {
492        return Ok(());
493    }
494
495    let entries = std::fs::read_dir(current).map_err(ToolError::Execution)?;
496    for entry in entries {
497        let entry = entry.map_err(ToolError::Execution)?;
498        let path = entry.path();
499        if path.is_dir() {
500            collect_grep_hits(root, &path, matcher, regex, hits, max_results)?;
501            continue;
502        }
503        if !matches_pattern(root, &path, matcher) {
504            continue;
505        }
506        let Ok(source) = std::fs::read_to_string(&path) else {
507            continue;
508        };
509        for (idx, line) in source.lines().enumerate() {
510            if regex.is_match(line) {
511                hits.push(SearchCodeHit {
512                    file_path: canonical_string(&path),
513                    line_start: idx + 1,
514                    line_end: idx + 1,
515                    snippet: line.trim().to_string(),
516                    source: SearchCodeSource::GrepFallback,
517                    score: SearchCodeSource::GrepFallback.default_score(),
518                    symbol_name: None,
519                });
520                if hits.len() >= max_results {
521                    return Ok(());
522                }
523            }
524        }
525    }
526    Ok(())
527}
528
529fn matches_pattern(root: &Path, path: &Path, matcher: Option<&glob::Pattern>) -> bool {
530    let Some(matcher) = matcher else {
531        return true;
532    };
533    let relative = path.strip_prefix(root).unwrap_or(path);
534    matcher.matches_path(relative)
535}
536
/// Tells whether the final component of `path` is one of the directory names that the
/// scans never enter (`.git`, `target`, `node_modules`, `.zeph`).
///
/// Paths without a final component, or whose name is not valid UTF-8, are not skipped.
fn should_skip_path(path: &Path) -> bool {
    const SKIPPED_NAMES: [&str; 4] = [".git", "target", "node_modules", ".zeph"];

    let Some(name) = path.file_name().and_then(|component| component.to_str()) else {
        return false;
    };
    SKIPPED_NAMES.contains(&name)
}
542
/// Returns the canonical form of `path` as display text, or the path as given when it
/// cannot be canonicalized (for example because it does not exist).
fn canonical_string(path: &Path) -> String {
    match path.canonicalize() {
        Ok(resolved) => resolved.display().to_string(),
        Err(_) => path.display().to_string(),
    }
}
549
/// Returns the trimmed text of the one-based `line_number` in `source`.
///
/// Line `0` is treated like line `1`; a line past the end yields an empty string.
fn extract_snippet(source: &str, line_number: usize) -> String {
    let zero_based = line_number.saturating_sub(1);
    source
        .lines()
        .nth(zero_based)
        .map_or_else(String::new, |line| line.trim().to_string())
}
558
#[cfg(test)]
mod tests {
    use super::*;

    /// Semantic backend stub that never returns any hits.
    struct EmptySemantic;

    impl SemanticSearchBackend for EmptySemantic {
        fn search<'a>(
            &'a self,
            _query: &'a str,
            _file_pattern: Option<&'a str>,
            _max_results: usize,
        ) -> Pin<
            Box<
                dyn std::future::Future<Output = Result<Vec<SearchCodeHit>, ToolError>> + Send + 'a,
            >,
        > {
            Box::pin(async move { Ok(vec![]) })
        }
    }

    /// Builds an executor restricted to the given temporary directory.
    fn executor_for(dir: &tempfile::TempDir) -> SearchCodeExecutor {
        SearchCodeExecutor::new(vec![dir.path().to_path_buf()])
    }

    /// Builds a `search_code` tool call from a JSON object of parameters.
    fn search_call(params: serde_json::Value) -> ToolCall {
        let serde_json::Value::Object(params) = params else {
            panic!("search_code params must be a JSON object");
        };
        ToolCall {
            tool_id: "search_code".into(),
            params,
        }
    }

    #[tokio::test]
    async fn search_code_requires_query_or_symbol() {
        let dir = tempfile::tempdir().unwrap();
        let exec = executor_for(&dir);

        let result = exec
            .execute_tool_call(&search_call(serde_json::json!({})))
            .await;

        let err = result.unwrap_err();
        assert!(matches!(err, ToolError::InvalidParams { .. }));
    }

    #[tokio::test]
    async fn search_code_finds_structural_symbol() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(
            dir.path().join("lib.rs"),
            "pub fn retry_backoff_ms() -> u64 { 0 }\n",
        )
        .unwrap();
        let exec = executor_for(&dir);

        let call = search_call(serde_json::json!({ "symbol": "retry_backoff_ms" }));
        let out = exec.execute_tool_call(&call).await.unwrap().unwrap();

        assert!(out.summary.contains("retry_backoff_ms"));
        assert!(out.summary.contains("tree-sitter"));
        assert_eq!(out.tool_name, "search_code");
    }

    #[tokio::test]
    async fn search_code_uses_grep_fallback() {
        let dir = tempfile::tempdir().unwrap();
        std::fs::write(dir.path().join("mod.rs"), "let retry_backoff_ms = 5;\n").unwrap();
        let exec = executor_for(&dir);

        let call = search_call(serde_json::json!({ "query": "retry_backoff_ms" }));
        let out = exec.execute_tool_call(&call).await.unwrap().unwrap();

        assert!(out.summary.contains("grep fallback"));
    }

    #[test]
    fn tool_definitions_include_search_code() {
        let backend = std::sync::Arc::new(EmptySemantic);
        let exec = SearchCodeExecutor::new(vec![]).with_semantic_backend(backend);

        let defs = exec.tool_definitions();

        assert_eq!(defs.len(), 1);
        assert_eq!(defs[0].id.as_ref(), "search_code");
    }

    #[test]
    fn format_hits_strips_root_prefix() {
        let hit = SearchCodeHit {
            file_path: "/tmp/myproject/crates/foo/src/lib.rs".to_owned(),
            line_start: 10,
            line_end: 15,
            snippet: "pub fn example() {}".to_owned(),
            source: SearchCodeSource::GrepFallback,
            score: 0.45,
            symbol_name: None,
        };

        let output = format_hits(&[hit], Path::new("/tmp/myproject"));

        assert!(
            output.contains("crates/foo/src/lib.rs"),
            "expected relative path in output, got: {output}"
        );
        assert!(
            !output.contains("/tmp/myproject"),
            "absolute path must not appear in output, got: {output}"
        );
    }
}