use std::collections::{HashMap, HashSet, VecDeque};
use std::path::Path;
use std::sync::Arc;
use crate::types::Tag;
/// Resolves user-supplied focus targets (file paths or symbol queries)
/// against the tags collected for a project.
#[derive(Debug, Clone)]
pub struct FocusResolver {
    /// Directory that relative file targets are joined onto when probing the
    /// filesystem.
    root: std::path::PathBuf,
}
impl FocusResolver {
    /// Creates a resolver that looks up relative file targets under `root`.
    pub fn new(root: impl AsRef<Path>) -> Self {
        Self {
            root: root.as_ref().to_path_buf(),
        }
    }

    /// Resolves focus targets into matched files and matched identifiers.
    ///
    /// Every entry of `focus_targets` may hold several comma-separated
    /// targets; surrounding whitespace is trimmed and empty pieces are
    /// skipped. Each target is first tried as a file. Only when that fails is
    /// it fuzzy-matched against the tag names in `tags_by_file`. A target that
    /// matches nothing is reported on stderr and otherwise ignored.
    ///
    /// Returns `(matched_files, matched_idents)`.
    pub fn resolve(
        &self,
        focus_targets: &[String],
        tags_by_file: &HashMap<String, Vec<Tag>>,
    ) -> (HashSet<String>, HashSet<String>) {
        let mut matched_files = HashSet::new();
        let mut matched_idents = HashSet::new();

        let targets = focus_targets
            .iter()
            .flat_map(|entry| entry.split(','))
            .map(str::trim)
            .filter(|target| !target.is_empty());

        for target in targets {
            if let Some(file) = self.try_match_file(target, tags_by_file) {
                matched_files.insert(file);
                continue;
            }

            let idents = self.fuzzy_match_symbols(target, tags_by_file);
            if idents.is_empty() {
                eprintln!("focus: no matches for '{}'", target);
            } else {
                matched_idents.extend(idents);
            }
        }

        (matched_files, matched_idents)
    }

    /// Interprets `target` as a file and returns the matching path, if any.
    ///
    /// A target containing neither `/` nor `.` is never treated as a file.
    /// Otherwise the candidates are tried in this order:
    ///
    /// 1. `target` itself, when it is an absolute path that exists on disk;
    /// 2. `target` joined onto the resolver root, when that path exists;
    /// 3. a key of `tags_by_file` that ends with `target` or contains
    ///    `/target`.
    ///
    /// Steps 1 and 2 consult the filesystem and return the probed path, which
    /// is not necessarily a key of `tags_by_file`. When several keys match in
    /// step 3 the shortest one wins, with ties broken lexicographically, so
    /// the result does not depend on `HashMap` iteration order.
    fn try_match_file(
        &self,
        target: &str,
        tags_by_file: &HashMap<String, Vec<Tag>>,
    ) -> Option<String> {
        if !target.contains('/') && !target.contains('.') {
            return None;
        }

        let candidate = Path::new(target);
        if candidate.is_absolute() && candidate.exists() {
            return Some(candidate.to_string_lossy().into_owned());
        }

        let under_root = self.root.join(target);
        if under_root.exists() {
            return Some(under_root.to_string_lossy().into_owned());
        }

        // Built once instead of once per key.
        let nested = format!("/{}", target);
        tags_by_file
            .keys()
            .filter(|path| path.ends_with(target) || path.contains(nested.as_str()))
            .min_by(|a, b| a.len().cmp(&b.len()).then_with(|| a.cmp(b)))
            .cloned()
    }

    /// Collects the names of all tags, across every file, that
    /// `matches_query` accepts for `query`.
    fn fuzzy_match_symbols(
        &self,
        query: &str,
        tags_by_file: &HashMap<String, Vec<Tag>>,
    ) -> HashSet<String> {
        tags_by_file
            .values()
            .flatten()
            .filter(|tag| matches_query(&tag.name, query))
            .map(|tag| tag.name.to_string())
            .collect()
    }

    /// Spreads focus from the matched identifiers to their graph neighbours.
    ///
    /// Each entry of `symbol_graph` is an edge
    /// `(from_file, from_symbol, to_file, to_symbol)` and is followed in both
    /// directions. Every edge endpoint whose symbol name is in
    /// `matched_idents` is a seed with weight `1.0`; a node first reached
    /// `h` hops away from a seed gets weight `decay.powi(h)`. Nodes more than
    /// `max_hops` hops away are left out, as are matched identifiers that do
    /// not appear in any edge.
    ///
    /// Returns a map from `(file, symbol)` to weight.
    pub fn expand_via_graph(
        &self,
        matched_idents: &HashSet<String>,
        symbol_graph: &[(Arc<str>, Arc<str>, Arc<str>, Arc<str>)],
        max_hops: usize,
        decay: f64,
    ) -> HashMap<(Arc<str>, Arc<str>), f64> {
        let mut expanded: HashMap<(Arc<str>, Arc<str>), f64> = HashMap::new();
        let mut adjacency: HashMap<(Arc<str>, Arc<str>), Vec<(Arc<str>, Arc<str>)>> =
            HashMap::new();
        let mut frontier: VecDeque<(Arc<str>, Arc<str>)> = VecDeque::new();

        // One pass over the edges both finds the seeds and builds an
        // undirected adjacency index, so the traversal below never has to
        // rescan the whole edge list for each visited node.
        for (from_file, from_sym, to_file, to_sym) in symbol_graph {
            let from = (Arc::clone(from_file), Arc::clone(from_sym));
            let to = (Arc::clone(to_file), Arc::clone(to_sym));

            for endpoint in [&from, &to] {
                if matched_idents.contains(&*endpoint.1) && !expanded.contains_key(endpoint) {
                    expanded.insert(endpoint.clone(), 1.0);
                    frontier.push_back(endpoint.clone());
                }
            }

            adjacency.entry(from.clone()).or_default().push(to.clone());
            adjacency.entry(to).or_default().push(from);
        }

        for hop in 1..=max_hops {
            if frontier.is_empty() {
                break;
            }

            // Saturate rather than wrap if `hop` ever exceeds `i32::MAX`.
            let weight = decay.powi(hop.min(i32::MAX as usize) as i32);

            // Only the nodes queued before this hop belong to the current
            // level; nodes pushed below are handled by the next hop.
            for _ in 0..frontier.len() {
                let node = match frontier.pop_front() {
                    Some(node) => node,
                    None => break,
                };

                if let Some(neighbors) = adjacency.get(&node) {
                    for neighbor in neighbors {
                        // A node already in `expanded` was reached at an
                        // equal or shorter distance and keeps its weight.
                        if !expanded.contains_key(neighbor) {
                            expanded.insert(neighbor.clone(), weight);
                            frontier.push_back(neighbor.clone());
                        }
                    }
                }
            }
        }

        expanded
    }
}
/// Decides whether the symbol `name` is a fuzzy match for `query`.
///
/// Comparison is case-insensitive. The rules are tried in order and the first
/// hit wins:
///
/// 1. `query` is a substring of `name` (this also covers equality);
/// 2. every word of `query` is a word of `name`;
/// 3. some word of `query` with a byte length of at least 3 is a substring of
///    `name`;
/// 4. a word of `query` and a word of `name` share a stem (see `get_stem`);
/// 5. `query` as a whole is within one edit of a word of `name`, when both
///    have a byte length of at least 4.
///
/// Words come from `split_identifier`.
fn matches_query(name: &str, query: &str) -> bool {
    let name_lower = name.to_lowercase();
    let query_lower = query.to_lowercase();

    if name_lower.contains(&query_lower) {
        return true;
    }

    // Split the original spellings, not the lower-cased ones:
    // `split_identifier` needs the case information to find camelCase
    // boundaries, and it lower-cases every word it returns anyway.
    let query_parts: HashSet<String> = split_identifier(query);
    let name_parts: HashSet<String> = split_identifier(name);

    if !query_parts.is_empty() && query_parts.is_subset(&name_parts) {
        return true;
    }

    if query_parts
        .iter()
        .any(|part| part.len() >= 3 && name_lower.contains(part.as_str()))
    {
        return true;
    }

    let query_stems: HashSet<_> = query_parts.iter().filter_map(|p| get_stem(p)).collect();
    let name_stems: HashSet<_> = name_parts.iter().filter_map(|p| get_stem(p)).collect();
    if !query_stems.is_empty() && !query_stems.is_disjoint(&name_stems) {
        return true;
    }

    // Typo tolerance is limited to longer words so that short queries do not
    // match almost everything.
    query_lower.len() >= 4
        && name_parts
            .iter()
            .any(|part| part.len() >= 4 && levenshtein(&query_lower, part) <= 1)
}
/// Breaks an identifier into its lower-cased words.
///
/// `_`, `-` and `.` separate words and are dropped. Inside a run of other
/// characters, a new word starts at an uppercase letter when that letter is
/// followed by a lowercase one or is not preceded by an uppercase one. This
/// splits `getUserName` into `get`/`user`/`name` and `HTTPServer` into
/// `http`/`server`. Duplicate words collapse because the result is a set.
fn split_identifier(s: &str) -> HashSet<String> {
    let mut words = HashSet::new();
    let mut word = String::new();
    let mut pending = s.chars().peekable();

    while let Some(c) = pending.next() {
        if matches!(c, '_' | '-' | '.') {
            if !word.is_empty() {
                words.insert(std::mem::take(&mut word).to_lowercase());
            }
            continue;
        }

        if c.is_uppercase() && !word.is_empty() {
            let lower_follows = pending.peek().map_or(false, |next| next.is_lowercase());
            let upper_precedes = word.chars().next_back().map_or(false, char::is_uppercase);
            // Stay inside an acronym (`HTTP`) until its last capital, which
            // begins the following word (`Server`).
            if lower_follows || !upper_precedes {
                words.insert(std::mem::take(&mut word).to_lowercase());
            }
        }

        word.push(c);
    }

    if !word.is_empty() {
        words.insert(word.to_lowercase());
    }
    words
}
/// Looks `word` up in `STEM_GROUPS` and returns the first entry of the group
/// that lists it, or `None` when no group does.
///
/// The lookup is an exact, case-sensitive comparison.
fn get_stem(word: &str) -> Option<&'static str> {
    for group in STEM_GROUPS {
        if group.iter().any(|member| *member == word) {
            return group.first().copied();
        }
    }
    None
}
/// Hand-maintained groups of related lower-case words used for stem matching.
///
/// `get_stem` returns the first entry of the group that contains a word, so
/// the first entry acts as the stem shared by every other word in the group.
///
/// NOTE(review): some groups tie together words that are related rather than
/// equivalent (`invalid` with `valid`, `authorize` with `authenticate`,
/// `query` with `request`). Words in one group are treated as sharing a stem;
/// confirm that this breadth is intended.
const STEM_GROUPS: &[&[&str]] = &[
&[
"auth",
"authenticate",
"authentication",
"authenticated",
"authorize",
"authorization",
"authorized",
],
&["parse", "parser", "parsing", "parsed"],
&[
"valid",
"validate",
"validation",
"validator",
"validated",
"invalid",
"invalidate",
"invalidated",
],
&[
"config",
"configure",
"configuration",
"configured",
"configurator",
],
&[
"init",
"initialize",
"initialization",
"initialized",
"initializer",
],
&["render", "renderer", "rendering", "rendered"],
&["cache", "caching", "cached"],
&["handle", "handler", "handling", "handled"],
&["exec", "execute", "execution", "executed", "executor"],
&["process", "processor", "processing", "processed"],
&[
"serial",
"serialize",
"serialization",
"serialized",
"serializer",
"deserialize",
"deserialized",
],
&[
"connect",
"connection",
"connected",
"connector",
"disconnect",
"disconnected",
],
&["transform", "transformer", "transformation", "transformed"],
&["compile", "compiler", "compilation", "compiled"],
&["eval", "evaluate", "evaluation", "evaluated", "evaluator"],
&["gen", "generate", "generation", "generated", "generator"],
&["register", "registration", "registered", "registry"],
// Request-like words share the stem "query".
&["query", "request", "req"],
&["response", "resp", "reply"],
];
/// Computes the Levenshtein edit distance between `a` and `b`.
///
/// The distance is the minimum number of single-character insertions,
/// deletions and substitutions that turn `a` into `b`. Characters are Unicode
/// scalar values (`char`), not bytes, so non-ASCII text is measured correctly.
///
/// Only two rows of the dynamic-programming table are kept at a time.
fn levenshtein(a: &str, b: &str) -> usize {
    // Row width must be the number of chars in `b`; `str::len` is a byte
    // count and disagrees with it for multi-byte characters.
    let b_chars: Vec<char> = b.chars().collect();

    if a.is_empty() {
        return b_chars.len();
    }
    if b_chars.is_empty() {
        return a.chars().count();
    }

    let mut prev_row: Vec<usize> = (0..=b_chars.len()).collect();
    let mut curr_row = vec![0; b_chars.len() + 1];

    for (i, a_char) in a.chars().enumerate() {
        curr_row[0] = i + 1;
        for (j, &b_char) in b_chars.iter().enumerate() {
            let cost = if a_char == b_char { 0 } else { 1 };
            curr_row[j + 1] = (curr_row[j] + 1)
                .min(prev_row[j + 1] + 1)
                .min(prev_row[j] + cost);
        }
        std::mem::swap(&mut prev_row, &mut curr_row);
    }

    // After the final swap the last computed row lives in `prev_row`.
    prev_row[b_chars.len()]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Collects string literals into an owned `HashSet<String>`.
    fn string_set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|item| item.to_string()).collect()
    }

    /// Builds a `(file, symbol)` graph node.
    fn node(file: &str, symbol: &str) -> (Arc<str>, Arc<str>) {
        (Arc::from(file), Arc::from(symbol))
    }

    /// Builds a `(from_file, from_symbol, to_file, to_symbol)` graph edge.
    fn edge(
        from_file: &str,
        from_sym: &str,
        to_file: &str,
        to_sym: &str,
    ) -> (Arc<str>, Arc<str>, Arc<str>, Arc<str>) {
        (
            Arc::from(from_file),
            Arc::from(from_sym),
            Arc::from(to_file),
            Arc::from(to_sym),
        )
    }

    /// Builds a function-definition `Tag` fixture.
    ///
    /// A macro rather than a function so that every field keeps being
    /// initialised from the same literal expressions as before.
    macro_rules! def_tag {
        ($rel_fname:expr, $fname:expr, $line:expr, $name:expr) => {
            Tag {
                rel_fname: $rel_fname.into(),
                fname: $fname.into(),
                line: $line,
                name: $name.into(),
                kind: crate::types::TagKind::Def,
                node_type: "function".into(),
                parent_name: None,
                parent_line: None,
                signature: None,
                fields: None,
                metadata: None,
            }
        };
    }

    #[test]
    fn test_split_identifier() {
        assert_eq!(
            split_identifier("get_user_name"),
            string_set(&["get", "user", "name"])
        );
        assert_eq!(
            split_identifier("getUserName"),
            string_set(&["get", "user", "name"])
        );
        assert_eq!(
            split_identifier("GetUserName"),
            string_set(&["get", "user", "name"])
        );
        // An acronym stays together and ends where the next word begins.
        assert_eq!(
            split_identifier("HTTPServer"),
            string_set(&["http", "server"])
        );
        assert_eq!(
            split_identifier("parse_JSONObject"),
            string_set(&["parse", "json", "object"])
        );
    }

    #[test]
    fn test_get_stem() {
        assert_eq!(get_stem("authenticate"), Some("auth"));
        assert_eq!(get_stem("authentication"), Some("auth"));
        assert_eq!(get_stem("authorized"), Some("auth"));
        assert_eq!(get_stem("auth"), Some("auth"));
        assert_eq!(get_stem("parser"), Some("parse"));
        assert_eq!(get_stem("parsing"), Some("parse"));
        assert_eq!(get_stem("parsed"), Some("parse"));
        assert_eq!(get_stem("validator"), Some("valid"));
        assert_eq!(get_stem("invalid"), Some("valid"));
        assert_eq!(get_stem("unknown_word"), None);
    }

    #[test]
    fn test_levenshtein() {
        assert_eq!(levenshtein("", ""), 0);
        assert_eq!(levenshtein("a", ""), 1);
        assert_eq!(levenshtein("", "b"), 1);
        assert_eq!(levenshtein("abc", "abc"), 0);
        assert_eq!(levenshtein("abc", "abd"), 1);
        assert_eq!(levenshtein("abc", "axc"), 1);
        assert_eq!(levenshtein("abc", "abcd"), 1);
        assert_eq!(levenshtein("parse", "parser"), 1);
        assert_eq!(levenshtein("parsr", "parser"), 1);
    }

    #[test]
    fn test_matches_query_exact() {
        assert!(matches_query("authenticate", "authenticate"));
        assert!(matches_query("Authenticate", "authenticate"));
        assert!(matches_query("AUTHENTICATE", "authenticate"));
    }

    #[test]
    fn test_matches_query_substring() {
        assert!(matches_query("authenticate", "auth"));
        assert!(matches_query("user_authentication", "auth"));
        assert!(matches_query("HTTPServer", "http"));
    }

    #[test]
    fn test_matches_query_word_parts() {
        assert!(matches_query("getUserName", "getuser"));
        assert!(matches_query("get_user_name", "user"));
        assert!(matches_query("parseHTTPRequest", "parse"));
    }

    #[test]
    fn test_matches_query_stem() {
        assert!(matches_query("authenticate", "auth"));
        assert!(matches_query("authentication", "auth"));
        assert!(matches_query("authorized", "auth"));
        assert!(matches_query("parser", "parse"));
        assert!(matches_query("parsing", "parse"));
        assert!(matches_query("validator", "valid"));
        assert!(matches_query("invalid", "valid"));
        // Neither word contains the other, so only the shared stem can match.
        assert!(matches_query("authentication", "authorize"));
        assert!(matches_query("serializer", "deserialize"));
    }

    #[test]
    fn test_matches_query_typo() {
        assert!(matches_query("parser", "parsr"));
        assert!(matches_query("authenticate", "authentcate"));
        assert!(!matches_query("parser", "xyzabc"));
    }

    #[test]
    fn test_matches_query_negative() {
        assert!(!matches_query("authenticate", "xyz"));
        assert!(!matches_query("getUserName", "setpassword"));
        assert!(!matches_query("parser", "compiler"));
    }

    #[test]
    fn test_matches_query_case_insensitive() {
        assert!(matches_query("AuthHandler", "authhandler"));
        assert!(matches_query("AUTH_HANDLER", "auth"));
        assert!(matches_query("parseJSON", "parsejson"));
    }

    #[test]
    fn test_split_identifier_edge_cases() {
        assert_eq!(split_identifier("foo"), string_set(&["foo"]));
        assert_eq!(split_identifier(""), HashSet::new());
        // Digits are neither separators nor case boundaries.
        assert_eq!(split_identifier("foo123bar"), string_set(&["foo123bar"]));
        assert_eq!(
            split_identifier("foo__bar--baz"),
            string_set(&["foo", "bar", "baz"])
        );
    }

    #[test]
    fn test_resolve_empty() {
        let resolver = FocusResolver::new("/tmp");
        let tags_by_file = HashMap::new();

        let (files, idents) = resolver.resolve(&[], &tags_by_file);

        assert!(files.is_empty());
        assert!(idents.is_empty());
    }

    #[test]
    fn test_resolve_symbols() {
        let resolver = FocusResolver::new("/tmp");
        let mut tags_by_file = HashMap::new();
        tags_by_file.insert(
            "/src/auth.rs".to_string(),
            vec![
                def_tag!("src/auth.rs", "/src/auth.rs", 10, "authenticate"),
                def_tag!("src/auth.rs", "/src/auth.rs", 20, "authorize"),
            ],
        );

        let (files, idents) = resolver.resolve(&["auth".to_string()], &tags_by_file);

        assert!(files.is_empty());
        assert_eq!(idents.len(), 2);
        assert!(idents.contains("authenticate"));
        assert!(idents.contains("authorize"));
    }

    #[test]
    fn test_resolve_comma_separated() {
        let resolver = FocusResolver::new("/tmp");
        let mut tags_by_file = HashMap::new();
        tags_by_file.insert(
            "/src/parser.rs".to_string(),
            vec![def_tag!("src/parser.rs", "/src/parser.rs", 10, "parse")],
        );
        tags_by_file.insert(
            "/src/auth.rs".to_string(),
            vec![def_tag!("src/auth.rs", "/src/auth.rs", 20, "authenticate")],
        );

        let (files, idents) = resolver.resolve(&["parse,auth".to_string()], &tags_by_file);

        assert!(files.is_empty());
        assert_eq!(idents.len(), 2);
        assert!(idents.contains("parse"));
        assert!(idents.contains("authenticate"));
    }

    #[test]
    fn test_resolve_file_with_extension() {
        let resolver = FocusResolver::new("/tmp");
        let mut tags_by_file = HashMap::new();
        tags_by_file.insert(
            "/src/auth.rs".to_string(),
            vec![def_tag!("src/auth.rs", "/src/auth.rs", 10, "authenticate")],
        );

        let (files, idents) = resolver.resolve(&["auth.rs".to_string()], &tags_by_file);

        assert_eq!(files.len(), 1);
        assert!(files.contains("/src/auth.rs"));
        assert!(idents.is_empty());
    }

    #[test]
    fn test_expand_via_graph_empty() {
        let resolver = FocusResolver::new("/tmp");
        let matched_idents = HashSet::new();
        let symbol_graph = vec![];

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 2, 0.5);

        assert!(expanded.is_empty());
    }

    #[test]
    fn test_expand_via_graph_seeds_only() {
        let resolver = FocusResolver::new("/tmp");
        let matched_idents = string_set(&["foo"]);
        let symbol_graph = vec![];

        // A matched identifier that appears in no edge produces no node.
        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 2, 0.5);

        assert!(expanded.is_empty());
    }

    #[test]
    fn test_expand_via_graph_one_hop() {
        let resolver = FocusResolver::new("/tmp");
        let matched_idents = string_set(&["foo"]);
        let symbol_graph = vec![edge("a.rs", "foo", "b.rs", "bar")];

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 1, 0.5);

        assert_eq!(expanded.len(), 2);
        assert_eq!(expanded.get(&node("a.rs", "foo")), Some(&1.0));
        assert_eq!(expanded.get(&node("b.rs", "bar")), Some(&0.5));
    }

    #[test]
    fn test_expand_via_graph_two_hops() {
        let resolver = FocusResolver::new("/tmp");
        let matched_idents = string_set(&["foo"]);
        let symbol_graph = vec![
            edge("a.rs", "foo", "b.rs", "bar"),
            edge("b.rs", "bar", "c.rs", "baz"),
        ];

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 2, 0.5);

        assert_eq!(expanded.len(), 3);
        assert_eq!(expanded.get(&node("a.rs", "foo")), Some(&1.0));
        assert_eq!(expanded.get(&node("b.rs", "bar")), Some(&0.5));
        assert_eq!(expanded.get(&node("c.rs", "baz")), Some(&0.25));
    }

    #[test]
    fn test_expand_via_graph_max_hops() {
        let resolver = FocusResolver::new("/tmp");
        let matched_idents = string_set(&["foo"]);
        let symbol_graph = vec![
            edge("a.rs", "foo", "b.rs", "bar"),
            edge("b.rs", "bar", "c.rs", "baz"),
            edge("c.rs", "baz", "d.rs", "qux"),
        ];

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 1, 0.5);
        assert_eq!(expanded.len(), 2);
        assert!(!expanded.contains_key(&node("c.rs", "baz")));

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 2, 0.5);
        assert_eq!(expanded.len(), 3);
        assert!(!expanded.contains_key(&node("d.rs", "qux")));
    }

    #[test]
    fn test_expand_via_graph_bidirectional() {
        let resolver = FocusResolver::new("/tmp");
        let matched_idents = string_set(&["bar"]);
        let symbol_graph = vec![
            edge("a.rs", "foo", "b.rs", "bar"),
            edge("c.rs", "baz", "b.rs", "bar"),
        ];

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 1, 0.5);

        assert_eq!(expanded.len(), 3);
        assert_eq!(expanded.get(&node("b.rs", "bar")), Some(&1.0));
        assert_eq!(expanded.get(&node("a.rs", "foo")), Some(&0.5));
        assert_eq!(expanded.get(&node("c.rs", "baz")), Some(&0.5));
    }

    #[test]
    fn test_expand_via_graph_custom_decay() {
        let resolver = FocusResolver::new("/tmp");
        let matched_idents = string_set(&["foo"]);
        let symbol_graph = vec![edge("a.rs", "foo", "b.rs", "bar")];

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 1, 0.75);
        assert_eq!(expanded.get(&node("b.rs", "bar")), Some(&0.75));

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 1, 0.25);
        assert_eq!(expanded.get(&node("b.rs", "bar")), Some(&0.25));
    }

    #[test]
    fn test_expand_via_graph_cycle() {
        let resolver = FocusResolver::new("/tmp");
        let matched_idents = string_set(&["foo"]);
        let symbol_graph = vec![
            edge("a.rs", "foo", "b.rs", "bar"),
            edge("b.rs", "bar", "a.rs", "foo"),
        ];

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 2, 0.5);

        // The cycle must neither loop forever nor overwrite the seed weight.
        assert_eq!(expanded.len(), 2);
        assert_eq!(expanded.get(&node("a.rs", "foo")), Some(&1.0));
        assert_eq!(expanded.get(&node("b.rs", "bar")), Some(&0.5));
    }

    #[test]
    fn test_expand_via_graph_multiple_seeds() {
        let resolver = FocusResolver::new("/tmp");
        let matched_idents = string_set(&["foo", "baz"]);
        let symbol_graph = vec![
            edge("a.rs", "foo", "b.rs", "bar"),
            edge("c.rs", "baz", "d.rs", "qux"),
        ];

        let expanded = resolver.expand_via_graph(&matched_idents, &symbol_graph, 1, 0.5);

        assert_eq!(expanded.len(), 4);
        assert_eq!(expanded.get(&node("a.rs", "foo")), Some(&1.0));
        assert_eq!(expanded.get(&node("b.rs", "bar")), Some(&0.5));
        assert_eq!(expanded.get(&node("c.rs", "baz")), Some(&1.0));
        assert_eq!(expanded.get(&node("d.rs", "qux")), Some(&0.5));
    }
}