synwire_agent/middleware/
narrowing.rs1use std::path::PathBuf;
11use std::sync::Arc;
12
13use synwire_core::vfs::protocol::Vfs;
14use synwire_core::vfs::types::{TreeEntry, TreeOptions};
15
/// A free-text request describing what to look for, plus limits on how much
/// of the VFS survives each narrowing pass.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct NarrowingQuery {
    /// Natural-language description; it is split into lower-case words for scoring.
    pub description: String,
    /// How many of the best-scoring files are kept after the file-level pass.
    pub top_k_files: usize,
    /// Upper bound on the number of results returned by narrowing.
    pub top_k_symbols: usize,
}

impl NarrowingQuery {
    /// File budget used by [`NarrowingQuery::new`].
    const DEFAULT_TOP_K_FILES: usize = 5;
    /// Result budget used by [`NarrowingQuery::new`].
    const DEFAULT_TOP_K_SYMBOLS: usize = 3;

    /// Builds a query from `description`, keeping the 5 best files and
    /// returning at most 3 results.
    #[must_use]
    pub fn new(description: impl Into<String>) -> Self {
        let description: String = description.into();
        Self {
            top_k_files: Self::DEFAULT_TOP_K_FILES,
            top_k_symbols: Self::DEFAULT_TOP_K_SYMBOLS,
            description,
        }
    }
}
39
/// One ranked hit produced by narrowing: a file, optionally the symbol inside
/// it that best matched the query, and the skeleton line that symbol came from.
#[non_exhaustive]
#[derive(Clone, Debug)]
pub struct NarrowingResult {
    /// Path of the matched file, as listed by the VFS tree.
    pub file: PathBuf,
    /// Identifier taken from the best-matching skeleton line; `None` when no
    /// identifier could be extracted or the file's skeleton had no lines.
    pub symbol: Option<String>,
    /// Relevance score; higher is better.
    pub score: f32,
    /// The skeleton line the result was taken from; empty when there was none.
    pub context: String,
}
53
54#[derive(Debug, thiserror::Error)]
56#[non_exhaustive]
57pub enum NarrowingError {
58 #[error("VFS error: {0}")]
60 Vfs(String),
61 #[error("no results found for the given query")]
63 NoResults,
64}
65
66#[derive(Debug, Default)]
71pub struct HierarchicalNarrowing;
72
73impl HierarchicalNarrowing {
74 #[must_use]
76 pub const fn new() -> Self {
77 Self
78 }
79
80 pub async fn narrow(
87 &self,
88 vfs: &Arc<dyn Vfs>,
89 query: &NarrowingQuery,
90 ) -> Result<Vec<NarrowingResult>, NarrowingError> {
91 let tree = vfs
93 .tree(".", TreeOptions::default())
94 .await
95 .map_err(|e| NarrowingError::Vfs(e.to_string()))?;
96
97 let mut all_files: Vec<String> = Vec::new();
98 collect_files(&tree, &mut all_files);
99
100 let query_words = tokenise(&query.description);
102 let mut scored_files: Vec<(String, f32)> = all_files
103 .iter()
104 .map(|path| {
105 let score = file_score(path, &query_words);
106 (path.clone(), score)
107 })
108 .collect();
109
110 scored_files.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
112 scored_files.truncate(query.top_k_files);
113
114 let any_positive = scored_files.iter().any(|(_, s)| *s > 0.0);
116 if any_positive {
117 scored_files.retain(|(_, s)| *s > 0.0);
118 }
119
120 let mut results: Vec<NarrowingResult> = Vec::new();
122
123 for (file_path, file_score) in &scored_files {
124 let Ok(skeleton) = vfs.skeleton(file_path).await else {
125 continue;
126 };
127
128 let mut sym_candidates: Vec<(String, f32, String)> = skeleton
130 .lines()
131 .filter(|line| !line.trim().is_empty())
132 .map(|line| {
133 let sym_name = extract_symbol_name(line);
134 let sym_score = symbol_score(line, &query_words);
135 let combined = sym_score.mul_add(0.6, file_score * 0.4);
136 (sym_name, combined, line.to_owned())
137 })
138 .collect();
139
140 sym_candidates
141 .sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
142
143 if let Some((sym_name, score, context)) = sym_candidates.into_iter().next() {
145 results.push(NarrowingResult {
146 file: PathBuf::from(file_path),
147 symbol: if sym_name.is_empty() {
148 None
149 } else {
150 Some(sym_name)
151 },
152 score,
153 context,
154 });
155 } else {
156 results.push(NarrowingResult {
158 file: PathBuf::from(file_path),
159 symbol: None,
160 score: *file_score,
161 context: String::new(),
162 });
163 }
164 }
165
166 results.sort_by(|a, b| {
167 b.score
168 .partial_cmp(&a.score)
169 .unwrap_or(std::cmp::Ordering::Equal)
170 });
171 results.truncate(query.top_k_symbols);
172
173 if results.is_empty() {
174 return Err(NarrowingError::NoResults);
175 }
176
177 Ok(results)
178 }
179}
180
181fn collect_files(entry: &TreeEntry, out: &mut Vec<String>) {
185 if !entry.is_dir {
186 out.push(entry.path.clone());
187 }
188 for child in &entry.children {
189 collect_files(child, out);
190 }
191}
192
/// Splits `text` into lower-case words.
///
/// Every non-alphanumeric character (including `_`, `-`, `/` and whitespace)
/// acts as a separator, and empty pieces are discarded.
fn tokenise(text: &str) -> Vec<String> {
    let mut words = Vec::new();
    for piece in text.split(|c: char| !c.is_alphanumeric()) {
        if piece.is_empty() {
            continue;
        }
        words.push(piece.to_lowercase());
    }
    words
}
200
/// Scores how well a file `path` matches `query_words`, as a fraction in `[0.0, 1.0]`.
///
/// A query word counts as matched when either:
/// - it occurs as a substring of the lower-cased path (`"database"` matches
///   `src/database.rs`), or
/// - it starts with one of the path's alphanumeric tokens, so abbreviated
///   components still match (`"authentication"` matches `src/auth.rs`).
///
/// Only path tokens of at least three characters take part in the prefix
/// rule. Otherwise single-letter components (`a/`, `x.rs`) or the `rs`
/// extension would "match" every query word that merely begins with the same
/// letters (e.g. `"rss"` would match every `.rs` file).
///
/// Matching assumes `query_words` are already lower-case, as `tokenise`
/// produces them. Returns `0.0` when `query_words` is empty.
#[allow(clippy::cast_precision_loss)] // small word counts, converted only to form a ratio
pub fn file_score(path: &str, query_words: &[String]) -> f32 {
    /// Minimum length (in chars) a path token needs to act as a prefix.
    const MIN_PREFIX_LEN: usize = 3;

    if query_words.is_empty() {
        return 0.0;
    }
    let path_lower = path.to_lowercase();
    // `path_lower` is already lower-case, so its pieces can be borrowed directly.
    let prefix_tokens: Vec<&str> = path_lower
        .split(|c: char| !c.is_alphanumeric())
        .filter(|token| token.chars().count() >= MIN_PREFIX_LEN)
        .collect();
    let matches = query_words
        .iter()
        .filter(|qw| {
            path_lower.contains(qw.as_str())
                || prefix_tokens.iter().any(|token| qw.starts_with(*token))
        })
        .count();
    matches as f32 / query_words.len() as f32
}
228
/// Fraction of `query_words` that occur as substrings of the lower-cased `line`.
///
/// The result lies in `[0.0, 1.0]`. It is `0.0` when `query_words` is empty.
#[allow(clippy::cast_precision_loss)] // small word counts, converted only to form a ratio
pub fn symbol_score(line: &str, query_words: &[String]) -> f32 {
    let total = query_words.len();
    if total == 0 {
        return 0.0;
    }
    let haystack = line.to_lowercase();
    let hits = query_words.iter().fold(0_usize, |acc, word| {
        acc + usize::from(haystack.contains(word.as_str()))
    });
    hits as f32 / total as f32
}
244
/// Returns the first identifier on a skeleton `line` that is not a declaration keyword.
///
/// Each whitespace-separated word is cut down to its leading run of
/// alphanumeric and `_` characters, so `authenticate(user:` yields
/// `authenticate` and `pub(crate)` yields `pub`. Some words are skipped:
/// - words whose first character is not alphabetic (attributes like
///   `#[derive(..)]`, references like `&self`, numbers);
/// - the keywords in `SKIPPED_WORDS`.
///
/// Returns an empty string when nothing qualifies.
fn extract_symbol_name(line: &str) -> String {
    /// Words that introduce or qualify a declaration rather than name it.
    ///
    /// `mut`, `dyn`, `where` and `macro_rules` are included so that lines such
    /// as `static mut COUNTER`, `impl dyn Trait` or `macro_rules! name` report
    /// the real name instead of the keyword. None of them can be an ordinary
    /// identifier, so skipping them never hides a genuine symbol.
    const SKIPPED_WORDS: &[&str] = &[
        "pub", "fn", "async", "struct", "enum", "impl", "trait", "mod", "use", "type",
        "const", "static", "let", "for", "if", "while", "return", "unsafe", "extern",
        "crate", "super", "self", "mut", "dyn", "where", "macro_rules",
    ];

    line.split_whitespace()
        .map(|word| {
            word.chars()
                .take_while(|c| c.is_alphanumeric() || *c == '_')
                .collect::<String>()
        })
        .find(|ident| {
            ident.chars().next().is_some_and(char::is_alphabetic)
                && !SKIPPED_WORDS.contains(&ident.as_str())
        })
        .unwrap_or_default()
}
267
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::float_cmp,
    clippy::needless_collect,
    clippy::useless_vec
)]
mod tests {
    use super::*;

    /// Fixture paths shared by the ranking tests.
    const FILES: [&str; 3] = ["src/auth.rs", "src/database.rs", "src/routes.rs"];

    /// Ranks `files` by [`file_score`] against `query` (best first) and keeps three.
    fn top3<'a>(files: &[&'a str], query: &str) -> Vec<&'a str> {
        let words = tokenise(query);
        let mut ranked: Vec<(&'a str, f32)> = files
            .iter()
            .map(|&path| (path, file_score(path, &words)))
            .collect();
        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        ranked.into_iter().take(3).map(|(path, _)| path).collect()
    }

    #[test]
    fn file_score_exact_match() {
        let score = file_score("src/auth.rs", &tokenise("authentication logic"));
        assert!(score > 0.0);
    }

    #[test]
    fn file_score_no_match() {
        let score = file_score("src/routes.rs", &tokenise("authentication logic"));
        assert_eq!(score, 0.0);
    }

    #[test]
    fn file_score_database_match() {
        let score = file_score("src/database.rs", &tokenise("database connection"));
        assert!(score > 0.0);
    }

    #[test]
    fn symbol_score_counts_overlapping_words() {
        let query = tokenise("authenticate user");
        let score = symbol_score("pub fn authenticate(user: &User) -> Result<Token>", &query);
        assert!(score > 0.0);
    }

    #[test]
    fn symbol_score_zero_when_no_overlap() {
        let query = tokenise("unrelated concept");
        let score = symbol_score("pub fn authenticate(user: &User) -> Result<Token>", &query);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn narrowing_ranks_auth_file_for_authentication_query() {
        let best = top3(&FILES, "authentication logic");
        assert!(best.contains(&"src/auth.rs"), "auth.rs should be in top-3");
    }

    #[test]
    fn narrowing_ranks_database_file_for_database_query() {
        let best = top3(&FILES, "database connection");
        assert!(
            best.contains(&"src/database.rs"),
            "database.rs should be in top-3"
        );
    }

    #[test]
    fn extract_symbol_name_skips_keywords() {
        let cases = [
            ("pub fn authenticate(user: &User)", "authenticate"),
            ("pub struct AuthToken {", "AuthToken"),
            ("   ", ""),
        ];
        for (line, expected) in cases {
            assert_eq!(extract_symbol_name(line), expected);
        }
    }

    #[test]
    fn tokenise_splits_on_non_alphanumeric() {
        let words = tokenise("hello-world_foo bar");
        for expected in ["hello", "world", "foo", "bar"] {
            assert!(words.iter().any(|w| w == expected));
        }
    }
}