Skip to main content

synwire_agent/middleware/
narrowing.rs

1//! Hierarchical narrowing middleware.
2//!
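//! Implements a three-phase progressive-disclosure strategy:
//! 1. `tree` — build a directory map of the project and rank files by keyword
//!    overlap between their paths and the query
//! 2. `skeleton` — extract signatures from the top-ranked candidate files
//! 3. match — score each signature against the query and return the best
//!    locations, so a caller can then do a targeted read of only the relevant
//!    function/range instead of the whole file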
7//!
8//! This reduces token usage versus reading entire files by ~75%.
9
10use std::path::PathBuf;
11use std::sync::Arc;
12
13use synwire_core::vfs::protocol::Vfs;
14use synwire_core::vfs::types::{TreeEntry, TreeOptions};
15
/// Parameters controlling a hierarchical-narrowing search.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct NarrowingQuery {
    /// Natural-language description of the code being searched for.
    pub description: String,
    /// Upper bound on how many candidate files have their skeletons inspected.
    pub top_k_files: usize,
    /// Upper bound on how many results are returned.
    pub top_k_symbols: usize,
}

impl NarrowingQuery {
    /// Default value for [`Self::top_k_files`].
    const DEFAULT_TOP_K_FILES: usize = 5;
    /// Default value for [`Self::top_k_symbols`].
    const DEFAULT_TOP_K_SYMBOLS: usize = 3;

    /// Build a query for `description` using the default limits
    /// (`top_k_files = 5`, `top_k_symbols = 3`).
    #[must_use]
    pub fn new(description: impl Into<String>) -> Self {
        let description = description.into();
        Self {
            top_k_files: Self::DEFAULT_TOP_K_FILES,
            top_k_symbols: Self::DEFAULT_TOP_K_SYMBOLS,
            description,
        }
    }
}
39
/// One location produced by a hierarchical-narrowing search.
#[non_exhaustive]
#[derive(Debug, Clone)]
pub struct NarrowingResult {
    /// File in which the match was found.
    pub file: PathBuf,
    /// Name of the matched symbol, or `None` when no name could be identified.
    pub symbol: Option<String>,
    /// Relevance score within `[0.0, 1.0]`; higher means more relevant.
    pub score: f32,
    /// The skeleton line (signature) that produced this match, used as context.
    pub context: String,
}
53
54/// Errors produced by [`HierarchicalNarrowing`].
55#[derive(Debug, thiserror::Error)]
56#[non_exhaustive]
57pub enum NarrowingError {
58    /// A VFS operation failed.
59    #[error("VFS error: {0}")]
60    Vfs(String),
61    /// No results matched the query.
62    #[error("no results found for the given query")]
63    NoResults,
64}
65
66/// Hierarchical narrowing engine.
67///
68/// Uses a three-phase heuristic search (tree → skeleton → match) to locate
69/// code relevant to a natural-language description without calling an LLM.
70#[derive(Debug, Default)]
71pub struct HierarchicalNarrowing;
72
73impl HierarchicalNarrowing {
74    /// Create a new narrowing engine.
75    #[must_use]
76    pub const fn new() -> Self {
77        Self
78    }
79
80    /// Locate code relevant to `query` within the VFS rooted at `.`.
81    ///
82    /// # Errors
83    ///
84    /// Returns [`NarrowingError::Vfs`] if a VFS operation fails, or
85    /// [`NarrowingError::NoResults`] if nothing matches.
86    pub async fn narrow(
87        &self,
88        vfs: &Arc<dyn Vfs>,
89        query: &NarrowingQuery,
90    ) -> Result<Vec<NarrowingResult>, NarrowingError> {
91        // Phase 1: collect all file paths via tree walk.
92        let tree = vfs
93            .tree(".", TreeOptions::default())
94            .await
95            .map_err(|e| NarrowingError::Vfs(e.to_string()))?;
96
97        let mut all_files: Vec<String> = Vec::new();
98        collect_files(&tree, &mut all_files);
99
100        // Phase 2: score each file by keyword overlap with the description.
101        let query_words = tokenise(&query.description);
102        let mut scored_files: Vec<(String, f32)> = all_files
103            .iter()
104            .map(|path| {
105                let score = file_score(path, &query_words);
106                (path.clone(), score)
107            })
108            .collect();
109
110        // Sort descending by score, keep top-k.
111        scored_files.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
112        scored_files.truncate(query.top_k_files);
113
114        // Discard zero-score files only when higher-scoring ones exist.
115        let any_positive = scored_files.iter().any(|(_, s)| *s > 0.0);
116        if any_positive {
117            scored_files.retain(|(_, s)| *s > 0.0);
118        }
119
120        // Phase 3: for each candidate file, get the skeleton and score symbols.
121        let mut results: Vec<NarrowingResult> = Vec::new();
122
123        for (file_path, file_score) in &scored_files {
124            let Ok(skeleton) = vfs.skeleton(file_path).await else {
125                continue;
126            };
127
128            // Score each non-empty skeleton line as a candidate symbol.
129            let mut sym_candidates: Vec<(String, f32, String)> = skeleton
130                .lines()
131                .filter(|line| !line.trim().is_empty())
132                .map(|line| {
133                    let sym_name = extract_symbol_name(line);
134                    let sym_score = symbol_score(line, &query_words);
135                    let combined = sym_score.mul_add(0.6, file_score * 0.4);
136                    (sym_name, combined, line.to_owned())
137                })
138                .collect();
139
140            sym_candidates
141                .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
142
143            // Take the best symbol from this file.
144            if let Some((sym_name, score, context)) = sym_candidates.into_iter().next() {
145                results.push(NarrowingResult {
146                    file: PathBuf::from(file_path),
147                    symbol: if sym_name.is_empty() {
148                        None
149                    } else {
150                        Some(sym_name)
151                    },
152                    score,
153                    context,
154                });
155            } else {
156                // File matched but had no skeleton lines — still include it.
157                results.push(NarrowingResult {
158                    file: PathBuf::from(file_path),
159                    symbol: None,
160                    score: *file_score,
161                    context: String::new(),
162                });
163            }
164        }
165
166        results.sort_by(|a, b| {
167            b.score
168                .partial_cmp(&a.score)
169                .unwrap_or(std::cmp::Ordering::Equal)
170        });
171        results.truncate(query.top_k_symbols);
172
173        if results.is_empty() {
174            return Err(NarrowingError::NoResults);
175        }
176
177        Ok(results)
178    }
179}
180
181// ── Internal helpers ──────────────────────────────────────────────────────────
182
183/// Recursively collect all non-directory file paths from a `TreeEntry`.
184fn collect_files(entry: &TreeEntry, out: &mut Vec<String>) {
185    if !entry.is_dir {
186        out.push(entry.path.clone());
187    }
188    for child in &entry.children {
189        collect_files(child, out);
190    }
191}
192
/// Break `text` into lowercase words, treating every non-alphanumeric
/// character (including `_`) as a separator and dropping empty pieces.
fn tokenise(text: &str) -> Vec<String> {
    let mut words = Vec::new();
    for piece in text.split(|ch: char| !ch.is_alphanumeric()) {
        if piece.is_empty() {
            continue;
        }
        words.push(piece.to_lowercase());
    }
    words
}
200
/// Score a file path by the proportion of query words that appear in it.
///
/// A query word matches when the lowercased path contains it as a substring
/// **or** when the query word starts with one of the path's tokens (prefix
/// match), allowing "auth" in a path to match the query word "authentication".
/// Only path tokens of at least three characters take part in prefix matching;
/// shorter tokens (a directory named `a`, the `rs` extension) would otherwise
/// match every query word beginning with the same letters.
///
/// `query_words` are expected to be lowercase, as produced by `tokenise`.
///
/// Returns a value in `[0.0, 1.0]`; `0.0` when `query_words` is empty.
#[allow(clippy::cast_precision_loss)]
pub fn file_score(path: &str, query_words: &[String]) -> f32 {
    // Shortest path token allowed to act as a prefix of a query word.
    const MIN_PREFIX_LEN: usize = 3;

    if query_words.is_empty() {
        return 0.0;
    }
    let path_lower = path.to_lowercase();
    // `path_lower` is already lowercase, so its pieces need no re-lowercasing.
    let prefix_tokens: Vec<&str> = path_lower
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| token.chars().count() >= MIN_PREFIX_LEN)
        .collect();
    let matches = query_words
        .iter()
        .filter(|qw| {
            path_lower.contains(qw.as_str())
                || prefix_tokens.iter().any(|&token| qw.starts_with(token))
        })
        .count();
    matches as f32 / query_words.len() as f32
}
228
/// Fraction of `query_words` found as substrings of the lowercased `line`.
///
/// Returns a value in `[0.0, 1.0]`; `0.0` when `query_words` is empty.
#[allow(clippy::cast_precision_loss)]
pub fn symbol_score(line: &str, query_words: &[String]) -> f32 {
    let total = query_words.len();
    if total == 0 {
        return 0.0;
    }
    let haystack = line.to_lowercase();
    let hits = query_words.iter().fold(0_usize, |acc, word| {
        if haystack.contains(word.as_str()) {
            acc + 1
        } else {
            acc
        }
    });
    hits as f32 / total as f32
}
244
/// Extract the name declared on a skeleton line.
///
/// Scans the whitespace-separated words of `line`, trims each to its leading
/// run of identifier characters (letters, digits, `_`), and returns the first
/// one that starts like an identifier (a letter or `_`, never a digit), is not
/// the bare wildcard `_`, and is not a declaration keyword. Returns an empty
/// string when no word qualifies.
fn extract_symbol_name(line: &str) -> String {
    // Words that introduce or qualify a declaration rather than name it. Apart
    // from `macro_rules`, all are reserved Rust keywords, so none of them can
    // be the declared name itself.
    const SKIPPED: &[&str] = &[
        "pub", "fn", "async", "struct", "enum", "impl", "trait", "mod", "use", "type", "const",
        "static", "let", "for", "if", "while", "return", "unsafe", "extern", "crate", "super",
        "self", "mut", "dyn", "ref", "where", "macro_rules",
    ];

    for word in line.split_whitespace() {
        // Cut at the first non-identifier character, e.g. `name(` -> `name`.
        let ident: String = word
            .chars()
            .take_while(|c| c.is_alphanumeric() || *c == '_')
            .collect();
        // Identifiers may begin with `_` (e.g. `_helper`), but not with a digit.
        let starts_like_ident = ident
            .chars()
            .next()
            .is_some_and(|c| c.is_alphabetic() || c == '_');
        if starts_like_ident && ident != "_" && !SKIPPED.contains(&ident.as_str()) {
            return ident;
        }
    }
    String::new()
}
267
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::float_cmp,
    clippy::needless_collect,
    clippy::useless_vec
)]
mod tests {
    use super::*;

    // Tests for the pure keyword-scoring helpers — no VFS required.

    /// Small file set shared by the ranking tests.
    const FILES: [&str; 3] = ["src/auth.rs", "src/database.rs", "src/routes.rs"];

    /// Rank `files` for `query` the way the file-ranking phase does (score by
    /// path, stable descending sort) and return the paths best-first.
    fn rank_paths<'a>(files: &[&'a str], query: &str) -> Vec<&'a str> {
        let words = tokenise(query);
        let mut scored: Vec<(&'a str, f32)> =
            files.iter().map(|&f| (f, file_score(f, &words))).collect();
        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        scored.into_iter().map(|(f, _)| f).collect()
    }

    #[test]
    fn file_score_exact_match() {
        let words = tokenise("authentication logic");
        // "auth" is a prefix of "authentication"; "logic" does not match.
        assert_eq!(file_score("src/auth.rs", &words), 0.5);
    }

    #[test]
    fn file_score_no_match() {
        let words = tokenise("authentication logic");
        assert_eq!(file_score("src/routes.rs", &words), 0.0);
    }

    #[test]
    fn file_score_database_match() {
        let words = tokenise("database connection");
        assert_eq!(file_score("src/database.rs", &words), 0.5);
    }

    #[test]
    fn symbol_score_counts_overlapping_words() {
        let words = tokenise("authenticate user");
        let score = symbol_score("pub fn authenticate(user: &User) -> Result<Token>", &words);
        assert_eq!(score, 1.0);
    }

    #[test]
    fn symbol_score_zero_when_no_overlap() {
        let words = tokenise("unrelated concept");
        let score = symbol_score("pub fn authenticate(user: &User) -> Result<Token>", &words);
        assert_eq!(score, 0.0);
    }

    /// Simulate the file-ranking phase without a real VFS.
    #[test]
    fn narrowing_ranks_auth_file_for_authentication_query() {
        let ranked = rank_paths(&FILES, "authentication logic");
        // With only three files a "top-3" check always passes; require first place.
        assert_eq!(
            ranked.first(),
            Some(&"src/auth.rs"),
            "auth.rs should rank first: {ranked:?}"
        );
    }

    /// Simulate the file-ranking phase for the database query.
    #[test]
    fn narrowing_ranks_database_file_for_database_query() {
        let ranked = rank_paths(&FILES, "database connection");
        assert_eq!(
            ranked.first(),
            Some(&"src/database.rs"),
            "database.rs should rank first: {ranked:?}"
        );
    }

    #[test]
    fn extract_symbol_name_skips_keywords() {
        assert_eq!(
            extract_symbol_name("pub fn authenticate(user: &User)"),
            "authenticate"
        );
        assert_eq!(extract_symbol_name("pub struct AuthToken {"), "AuthToken");
        assert_eq!(extract_symbol_name("pub(crate) fn helper() -> u8"), "helper");
        assert_eq!(extract_symbol_name("  "), "");
    }

    #[test]
    fn tokenise_splits_on_non_alphanumeric() {
        assert_eq!(
            tokenise("hello-world_foo bar"),
            vec!["hello", "world", "foo", "bar"]
        );
        assert_eq!(tokenise("Auth TOKEN"), vec!["auth", "token"]);
        assert!(tokenise(" -- ").is_empty());
    }
}