pub mod cache;
pub mod types;
use std::path::{Path, PathBuf};
use anyhow::Context;
use include_dir::{Dir, DirEntry, include_dir};
use types::{CommandPattern, FilterConfig};
static STDLIB: Dir<'static> = include_dir!("$CARGO_MANIFEST_DIR/filters");
pub fn get_embedded_filter(relative_path: &Path) -> Option<&'static str> {
STDLIB.get_file(relative_path)?.contents_utf8()
}
pub fn default_search_dirs() -> Vec<PathBuf> {
let mut dirs = Vec::new();
if let Ok(cwd) = std::env::current_dir() {
dirs.push(cwd.join(".tokf/filters"));
}
if let Some(config) = dirs::config_dir() {
dirs.push(config.join("tokf/filters"));
}
dirs
}
pub fn try_load_filter(path: &Path) -> anyhow::Result<Option<FilterConfig>> {
let content = match std::fs::read_to_string(path) {
Ok(c) => c,
Err(e) if e.kind() == std::io::ErrorKind::NotFound => return Ok(None),
Err(e) => {
return Err(anyhow::Error::new(e)
.context(format!("failed to read filter file: {}", path.display())));
}
};
let config: FilterConfig = toml::from_str(&content)
.with_context(|| format!("failed to parse filter file: {}", path.display()))?;
Ok(Some(config))
}
pub fn pattern_specificity(pattern: &str) -> usize {
pattern.split_whitespace().filter(|w| *w != "*").count()
}
pub fn pattern_matches_prefix(pattern: &str, words: &[&str]) -> Option<usize> {
let pattern_words: Vec<&str> = pattern.split_whitespace().collect();
if pattern_words.is_empty() || words.len() < pattern_words.len() {
return None;
}
for (i, pword) in pattern_words.iter().enumerate() {
if *pword == "*" {
if words[i].is_empty() {
return None;
}
} else if words[i] != *pword {
return None;
}
}
Some(pattern_words.len())
}
pub fn discover_filter_files(dir: &Path) -> Vec<PathBuf> {
let mut files = Vec::new();
collect_filter_files(dir, &mut files);
files.sort();
files
}
fn collect_filter_files(dir: &Path, files: &mut Vec<PathBuf>) {
let Ok(entries) = std::fs::read_dir(dir) else {
return;
};
let mut entries: Vec<_> = entries.filter_map(Result::ok).collect();
entries.sort_by_key(std::fs::DirEntry::file_name);
for entry in entries {
let path = entry.path();
let name = entry.file_name();
let name_str = name.to_string_lossy();
if name_str.starts_with('.') {
continue;
}
if path.is_dir() {
collect_filter_files(&path, files);
} else if path.extension().is_some_and(|e| e == "toml") {
files.push(path);
}
}
}
pub struct ResolvedFilter {
pub config: FilterConfig,
pub source_path: PathBuf,
pub relative_path: PathBuf,
pub priority: u8,
}
impl ResolvedFilter {
pub fn matches(&self, words: &[&str]) -> Option<usize> {
for pattern in self.config.command.patterns() {
if let Some(consumed) = pattern_matches_prefix(pattern, words) {
return Some(consumed);
}
}
None
}
pub fn specificity(&self) -> usize {
self.config
.command
.patterns()
.iter()
.map(|p| pattern_specificity(p))
.max()
.unwrap_or(0)
}
pub const fn priority_label(&self) -> &'static str {
match self.priority {
0 => "local",
1 => "user",
_ => "built-in",
}
}
}
pub fn discover_all_filters(search_dirs: &[PathBuf]) -> anyhow::Result<Vec<ResolvedFilter>> {
let mut all_filters: Vec<ResolvedFilter> = Vec::new();
for (priority, dir) in search_dirs.iter().enumerate() {
let files = discover_filter_files(dir);
for path in files {
let Ok(Some(config)) = try_load_filter(&path) else {
continue;
};
let relative_path = path.strip_prefix(dir).unwrap_or(&path).to_path_buf();
all_filters.push(ResolvedFilter {
config,
source_path: path,
relative_path,
priority: u8::try_from(priority).unwrap_or(u8::MAX),
});
}
}
let stdlib_priority = u8::MAX;
if let Ok(entries) = STDLIB.find("**/*.toml") {
for entry in entries {
if let DirEntry::File(file) = entry {
let content = file.contents_utf8().unwrap_or("");
let Ok(config) = toml::from_str::<FilterConfig>(content) else {
continue; };
let rel = file.path().to_path_buf();
all_filters.push(ResolvedFilter {
config,
source_path: PathBuf::from("<built-in>").join(&rel),
relative_path: rel,
priority: stdlib_priority,
});
}
}
}
all_filters.sort_by(|a, b| {
a.priority
.cmp(&b.priority)
.then_with(|| b.specificity().cmp(&a.specificity()))
});
let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
all_filters.retain(|f| seen.insert(f.config.command.first().to_string()));
Ok(all_filters)
}
pub fn command_pattern_to_regex(pattern: &str) -> String {
let escaped_words: Vec<String> = pattern
.split_whitespace()
.map(|w| {
if w == "*" {
r"\S+".to_string()
} else {
regex::escape(w)
}
})
.collect();
format!("^{}(\\s.*)?$", escaped_words.join(r"\ "))
}
pub fn command_pattern_regexes(command: &CommandPattern) -> Vec<(String, String)> {
command
.patterns()
.iter()
.map(|p| (p.clone(), command_pattern_to_regex(p)))
.collect()
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
use std::fs;
use tempfile::TempDir;
use super::*;
#[test]
fn specificity_two_literals() {
assert_eq!(pattern_specificity("git push"), 2);
}
#[test]
fn specificity_wildcard_counts_less() {
assert_eq!(pattern_specificity("git *"), 1);
assert_eq!(pattern_specificity("* push"), 1);
}
#[test]
fn specificity_all_wildcards() {
assert_eq!(pattern_specificity("* *"), 0);
}
#[test]
fn specificity_ordering() {
assert!(pattern_specificity("git push") > pattern_specificity("git *"));
assert!(pattern_specificity("git *") == pattern_specificity("* push"));
}
#[test]
fn matches_exact() {
let words = ["git", "push"];
assert_eq!(pattern_matches_prefix("git push", &words), Some(2));
}
#[test]
fn matches_prefix_with_trailing_args() {
let words = ["git", "push", "origin", "main"];
assert_eq!(pattern_matches_prefix("git push", &words), Some(2));
}
#[test]
fn matches_wildcard() {
let words = ["npm", "run", "build"];
assert_eq!(pattern_matches_prefix("npm run *", &words), Some(3));
}
#[test]
fn no_match_different_command() {
let words = ["cargo", "test"];
assert_eq!(pattern_matches_prefix("git push", &words), None);
}
#[test]
fn no_match_too_short() {
let words = ["git"];
assert_eq!(pattern_matches_prefix("git push", &words), None);
}
#[test]
fn empty_pattern_returns_none() {
let words = ["git", "push"];
assert_eq!(pattern_matches_prefix("", &words), None);
}
#[test]
fn empty_words_returns_none() {
assert_eq!(pattern_matches_prefix("git push", &[]), None);
}
#[test]
fn single_word_pattern_prefix_match() {
assert_eq!(pattern_matches_prefix("echo", &["echo"]), Some(1));
assert_eq!(pattern_matches_prefix("echo", &["echo", "hello"]), Some(1));
assert_eq!(pattern_matches_prefix("echo", &["ls"]), None);
}
#[test]
fn wildcard_rejects_empty_token() {
assert_eq!(pattern_matches_prefix("git *", &["git", ""]), None);
}
#[test]
fn wildcard_at_start() {
let words = ["my-tool", "subcommand"];
assert_eq!(pattern_matches_prefix("* subcommand", &words), Some(2));
}
#[test]
fn hyphenated_tool_not_ambiguous() {
let words = ["golangci-lint", "run"];
assert_eq!(pattern_matches_prefix("golangci-lint run", &words), Some(2));
assert_eq!(pattern_matches_prefix("golangci-lint", &words), Some(1));
}
#[test]
fn discover_flat_dir() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("a.toml"), "").unwrap();
fs::write(dir.path().join("b.toml"), "").unwrap();
fs::write(dir.path().join("not-toml.txt"), "").unwrap();
let files = discover_filter_files(dir.path());
assert_eq!(files.len(), 2);
assert!(files[0].ends_with("a.toml"));
assert!(files[1].ends_with("b.toml"));
}
#[test]
fn discover_nested_dirs() {
let dir = TempDir::new().unwrap();
let sub = dir.path().join("git");
fs::create_dir_all(&sub).unwrap();
fs::write(sub.join("push.toml"), "").unwrap();
fs::write(sub.join("status.toml"), "").unwrap();
fs::write(dir.path().join("root.toml"), "").unwrap();
let files = discover_filter_files(dir.path());
assert_eq!(files.len(), 3);
assert!(files[0].ends_with("git/push.toml"));
assert!(files[1].ends_with("git/status.toml"));
assert!(files[2].ends_with("root.toml"));
}
#[test]
fn discover_skips_hidden_entries() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join(".hidden.toml"), "").unwrap();
fs::write(dir.path().join("visible.toml"), "").unwrap();
let hidden_dir = dir.path().join(".hiddendir");
fs::create_dir_all(&hidden_dir).unwrap();
fs::write(hidden_dir.join("inside.toml"), "").unwrap();
let files = discover_filter_files(dir.path());
assert_eq!(files.len(), 1);
assert!(files[0].ends_with("visible.toml"));
}
#[test]
fn discover_nonexistent_dir_returns_empty() {
let files = discover_filter_files(Path::new("/no/such/directory/ever"));
assert!(files.is_empty());
}
#[test]
fn discover_all_priority_ordering() {
let dir1 = TempDir::new().unwrap();
let dir2 = TempDir::new().unwrap();
fs::write(
dir1.path().join("my-cmd.toml"),
"command = \"my cmd local\"",
)
.unwrap();
fs::write(dir2.path().join("my-cmd.toml"), "command = \"my cmd user\"").unwrap();
let dirs = vec![dir1.path().to_path_buf(), dir2.path().to_path_buf()];
let filters = discover_all_filters(&dirs).unwrap();
assert!(filters.len() >= 2);
assert_eq!(filters[0].config.command.first(), "my cmd local");
assert_eq!(filters[0].priority, 0);
}
#[test]
fn discover_all_dedup_same_command() {
let dir1 = TempDir::new().unwrap();
let dir2 = TempDir::new().unwrap();
fs::write(dir1.path().join("a.toml"), "command = \"git push\"").unwrap();
fs::write(dir2.path().join("b.toml"), "command = \"git push\"").unwrap();
let dirs = vec![dir1.path().to_path_buf(), dir2.path().to_path_buf()];
let filters = discover_all_filters(&dirs).unwrap();
let push_entries: Vec<_> = filters
.iter()
.filter(|f| f.config.command.first() == "git push")
.collect();
assert_eq!(push_entries.len(), 1);
assert_eq!(push_entries[0].priority, 0);
}
#[test]
fn discover_all_specificity_ordering() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("a.toml"), "command = \"git *\"").unwrap();
fs::write(dir.path().join("b.toml"), "command = \"git push\"").unwrap();
let dirs = vec![dir.path().to_path_buf()];
let filters = discover_all_filters(&dirs).unwrap();
assert_eq!(filters[0].config.command.first(), "git push");
assert_eq!(filters[1].config.command.first(), "git *");
}
#[test]
fn discover_all_skips_invalid_toml() {
let dir = TempDir::new().unwrap();
fs::write(dir.path().join("bad.toml"), "not valid [[[").unwrap();
fs::write(dir.path().join("good.toml"), "command = \"my tool\"").unwrap();
let filters = discover_all_filters(&[dir.path().to_path_buf()]).unwrap();
let my_tool: Vec<_> = filters
.iter()
.filter(|f| f.config.command.first() == "my tool")
.collect();
assert_eq!(my_tool.len(), 1);
}
#[test]
fn discover_all_hyphenated_tool_not_ambiguous() {
let dir = TempDir::new().unwrap();
fs::write(
dir.path().join("golangci-lint.toml"),
"command = \"golangci-lint run\"",
)
.unwrap();
let filters = discover_all_filters(&[dir.path().to_path_buf()]).unwrap();
let golangci: Vec<_> = filters
.iter()
.filter(|f| f.config.command.first() == "golangci-lint run")
.collect();
assert_eq!(golangci.len(), 1);
let words = ["golangci-lint", "run"];
assert_eq!(golangci[0].matches(&words), Some(2));
let words_no_match = ["golangci", "lint", "run"];
assert_eq!(golangci[0].matches(&words_no_match), None);
}
#[test]
fn embedded_stdlib_non_empty() {
let entries: Vec<_> = STDLIB.find("**/*.toml").unwrap().collect();
assert!(
entries.len() >= 10,
"expected at least 10 embedded filters, got {}",
entries.len()
);
}
#[test]
fn all_embedded_toml_parse() {
for entry in STDLIB.find("**/*.toml").unwrap() {
if let DirEntry::File(file) = entry {
let content = file.contents_utf8().unwrap_or("");
assert!(
toml::from_str::<FilterConfig>(content).is_ok(),
"failed to parse embedded filter: {}",
file.path().display()
);
}
}
}
#[test]
fn embedded_filters_in_discover_with_no_dirs() {
let filters = discover_all_filters(&[]).unwrap();
assert!(
!filters.is_empty(),
"expected embedded stdlib filters with no search dirs"
);
let has_git_push = filters
.iter()
.any(|f| f.config.command.first() == "git push");
assert!(has_git_push, "expected git push in embedded stdlib");
}
#[test]
fn local_filter_shadows_embedded() {
let dir = TempDir::new().unwrap();
fs::write(
dir.path().join("push.toml"),
"command = \"git push\"\n# local override",
)
.unwrap();
let dirs = vec![dir.path().to_path_buf()];
let filters = discover_all_filters(&dirs).unwrap();
let push_entries: Vec<_> = filters
.iter()
.filter(|f| f.config.command.first() == "git push")
.collect();
assert_eq!(push_entries.len(), 1);
assert_eq!(push_entries[0].priority, 0); }
#[test]
fn test_load_valid_toml() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("test.toml");
fs::write(&path, "command = \"echo hello\"").unwrap();
let config = try_load_filter(&path).unwrap().unwrap();
assert_eq!(config.command.first(), "echo hello");
}
#[test]
fn test_load_invalid_toml() {
let dir = TempDir::new().unwrap();
let path = dir.path().join("bad.toml");
fs::write(&path, "not valid toml [[[").unwrap();
assert!(try_load_filter(&path).is_err());
}
#[test]
fn test_load_nonexistent_returns_none() {
let path = PathBuf::from("/tmp/nonexistent-tokf-test-file.toml");
assert!(try_load_filter(&path).unwrap().is_none());
}
#[test]
fn test_load_real_stdlib_filter() {
let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("filters/git/push.toml");
let config = try_load_filter(&path).unwrap().unwrap();
assert_eq!(config.command.first(), "git push");
}
#[test]
fn test_default_search_dirs_non_empty_and_starts_with_local() {
let dirs = default_search_dirs();
assert!(!dirs.is_empty());
assert!(
dirs[0].is_absolute(),
"first dir should be absolute, got: {:?}",
dirs[0]
);
assert!(
dirs[0].ends_with(".tokf/filters"),
"first dir should end with .tokf/filters, got: {:?}",
dirs[0]
);
}
#[test]
fn test_default_search_dirs_only_local_and_user() {
let dirs = default_search_dirs();
assert!(
dirs.len() <= 2,
"expected at most 2 search dirs (local + user), got {}: {:?}",
dirs.len(),
dirs
);
}
#[test]
fn regex_from_literal_pattern() {
let r = command_pattern_to_regex("git push");
let re = regex::Regex::new(&r).unwrap();
assert!(re.is_match("git push"));
assert!(re.is_match("git push origin main"));
assert!(!re.is_match("git status"));
}
#[test]
fn regex_from_wildcard_pattern() {
let r = command_pattern_to_regex("npm run *");
let re = regex::Regex::new(&r).unwrap();
assert!(re.is_match("npm run build"));
assert!(re.is_match("npm run test --watch"));
assert!(!re.is_match("npm run"));
assert!(!re.is_match("npm install"));
}
}