use crate::config::load_config;
use crate::directory::detect_directory_references;
use crate::history;
use crate::types::{Config, HookInput, HookOutput};
use anyhow::{Context, Result};
use once_cell::sync::Lazy;
use regex::Regex;
use std::collections::HashMap;
use std::io::{self, Read};
use std::path::PathBuf;
use std::sync::Mutex;
/// Process-wide memo of compiled regexes, keyed by the regex source string.
/// Starts empty and is created on first access; the `Mutex` serializes lookups and inserts.
static REGEX_CACHE: Lazy<Mutex<HashMap<String, Regex>>> = Lazy::new(Default::default);
/// Runs a single hook invocation.
///
/// Loads the config from `config_path`, reads a JSON [`HookInput`] from stdin and
/// dispatches on its `hook_event_name`. `replace_mode` is forwarded to the
/// `PreToolUse` handler. An unrecognized event name only prints a warning to stderr.
///
/// # Errors
/// Fails if the config cannot be loaded, stdin cannot be read, the input is not
/// valid hook JSON, or the selected handler returns an error.
pub fn run_as_hook(config_path: &str, replace_mode: bool) -> Result<()> {
    let config = load_config(config_path)?;

    let raw_input = {
        let mut text = String::new();
        io::stdin().read_to_string(&mut text)?;
        text
    };
    let hook_input: HookInput =
        serde_json::from_str(&raw_input).context("Failed to parse hook input JSON")?;

    match hook_input.hook_event_name.as_str() {
        "PreToolUse" => handle_pre_tool_use(&config, &hook_input, replace_mode),
        "UserPromptSubmit" => handle_user_prompt_submit(&config, &hook_input),
        "PostToolUse" => handle_post_tool_use(&config, &hook_input),
        _ => {
            eprintln!("Warning: Unknown hook event type: {}", hook_input.hook_event_name);
            Ok(())
        }
    }
}
/// Handles a `PreToolUse` event for the `Bash` tool.
///
/// If the Bash command matches a configured mapping, prints a JSON [`HookOutput`] to
/// stdout and terminates the process with exit status 0:
/// - with `replace_mode`, the decision is `"replace"` and carries the rewritten command;
/// - otherwise the decision is `"block"` with the mapping suggestion as the reason.
///
/// Returns `Ok(())` without output for non-Bash tools, missing input/command, or when
/// no mapping matches.
///
/// # Errors
/// Propagates errors from the mapping check or from JSON serialization.
fn handle_pre_tool_use(config: &Config, hook_input: &HookInput, replace_mode: bool) -> Result<()> {
    if hook_input.tool_name.as_deref() != Some("Bash") {
        return Ok(());
    }
    let command = match hook_input
        .tool_input
        .as_ref()
        .and_then(|input| input.command.as_ref())
    {
        Some(cmd) => cmd,
        None => return Ok(()),
    };

    let (suggestion, replacement_cmd) = match check_command_mappings(config, command)? {
        Some(mapping) => mapping,
        None => return Ok(()),
    };

    let (decision, reason, replacement_command) = if replace_mode {
        (
            "replace",
            format!("Command mapped: using '{replacement_cmd}' instead"),
            Some(replacement_cmd),
        )
    } else {
        ("block", suggestion, None)
    };
    let output = HookOutput {
        decision: decision.to_string(),
        reason,
        replacement_command,
    };

    let json = serde_json::to_string(&output)?;
    println!("{json}");
    // The decision has been emitted; end the process here rather than returning to the caller.
    std::process::exit(0)
}
/// Handles a `UserPromptSubmit` event.
///
/// For every directory reference detected in the prompt, prints the alias used and
/// its resolved canonical path, followed by any substituted variables. Does nothing
/// when the event carries no prompt.
fn handle_user_prompt_submit(config: &Config, hook_input: &HookInput) -> Result<()> {
    if let Some(prompt) = &hook_input.prompt {
        for resolution in detect_directory_references(config, prompt) {
            println!(
                "Directory reference '{}' resolved to: {}",
                resolution.alias_used, resolution.canonical_path
            );
            if !resolution.variables_substituted.is_empty() {
                println!(" Variables substituted: {:?}", resolution.variables_substituted);
            }
        }
    }
    Ok(())
}
/// Handles a `PostToolUse` event by logging the executed Bash command to the
/// command-history database.
///
/// Returns `Ok(())` without doing anything unless:
/// - the tool is `Bash`;
/// - a tool response and a command are present;
/// - `command_history` is configured with `enabled` set.
///
/// # Errors
/// Fails if the configured log path cannot be expanded, or if the history database
/// cannot be initialized or written to.
fn handle_post_tool_use(config: &Config, hook_input: &HookInput) -> Result<()> {
    if hook_input.tool_name.as_deref() != Some("Bash") {
        return Ok(());
    }
    let Some(tool_response) = &hook_input.tool_response else {
        return Ok(());
    };
    let history_config = match &config.command_history {
        Some(cfg) if cfg.enabled => cfg,
        _ => return Ok(()),
    };
    let Some(command) = hook_input
        .tool_input
        .as_ref()
        .and_then(|input| input.command.as_ref())
    else {
        return Ok(());
    };

    let log_path = expand_tilde(&history_config.log_file)?;
    let conn = history::init_database(&log_path)
        .context("Failed to initialize command history database")?;

    // This handler has no information about whether PreToolUse rewrote the command,
    // so it is always recorded as not replaced. Running `check_command_mappings` on
    // the executed command would not answer that: it only says whether the command
    // *would* be mapped, not whether it was.
    let was_replaced = false;
    let original_command = None;

    let record = history::create_record(
        &hook_input.session_id,
        command,
        tool_response.exit_code,
        hook_input.cwd.as_deref(),
        was_replaced,
        original_command,
    );
    history::log_command(&conn, &record).context("Failed to log command to history")?;
    Ok(())
}
/// Expands a leading `~` in `path` to the value of the `HOME` environment variable.
///
/// Both a bare `~` and a `~/…` prefix are expanded. Anything else is returned unchanged,
/// including `~user/…` forms and paths containing `~` elsewhere. For `~/…`, the result
/// is `HOME` followed by the rest of `path` (starting with its `/`).
///
/// # Errors
/// Returns an error if expansion is needed and `HOME` is unset or not valid Unicode.
fn expand_tilde(path: &str) -> Result<PathBuf> {
    let suffix = match path.strip_prefix('~') {
        // `~` alone or `~/...` refers to the current user's home directory.
        Some(suffix) if suffix.is_empty() || suffix.starts_with('/') => suffix,
        _ => return Ok(PathBuf::from(path)),
    };
    let mut expanded = std::env::var("HOME").context("HOME environment variable not set")?;
    expanded.push_str(suffix);
    Ok(PathBuf::from(expanded))
}
/// Returns a compiled [`Regex`] for `pattern`, reusing a previously compiled one when
/// available.
///
/// Successfully compiled regexes are stored in `REGEX_CACHE`; callers receive a clone
/// of the cached value. Patterns that fail to compile are not cached.
///
/// # Errors
/// Returns an error if `pattern` is not a valid regular expression.
///
/// # Panics
/// Panics if the cache mutex is poisoned.
fn get_cached_regex(pattern: &str) -> Result<Regex> {
    let mut guard = REGEX_CACHE
        .lock()
        .expect("regex cache mutex should not be poisoned");
    match guard.get(pattern) {
        Some(cached) => Ok(cached.clone()),
        None => {
            let compiled = Regex::new(pattern)?;
            guard.insert(pattern.to_owned(), compiled.clone());
            Ok(compiled)
        }
    }
}
/// Checks `command` against the configured command mappings.
///
/// A pattern matches only as the leading word(s) of the command. The command must
/// start with the pattern (matched literally) followed by whitespace or end of input.
/// For pattern `npm`, `"npm install"` matches, but `"npmc install"`,
/// `"my-npm-tool install"` and `"run npm"` do not.
///
/// When several patterns match, the longest one wins, so the most specific mapping
/// is applied and the result does not depend on map iteration order. This can happen
/// when a pattern contains spaces, e.g. both `npm` and `npm run`.
///
/// Returns `Some((suggestion, rewritten_command))` for the winning mapping, or `None`
/// if nothing matches. `suggestion` is a human-readable hint. `rewritten_command` is
/// `command` with the matched pattern replaced by its mapping.
///
/// # Errors
/// Returns an error if a generated pattern regex fails to compile. The pattern text
/// is escaped first, so this should not normally happen.
pub fn check_command_mappings(config: &Config, command: &str) -> Result<Option<(String, String)>> {
    // `config.commands` has unspecified iteration order, so sort: longest pattern first,
    // ties broken lexicographically, making the scan order fully deterministic.
    let mut mappings: Vec<_> = config.commands.iter().collect();
    mappings.sort_by(|a, b| b.0.len().cmp(&a.0.len()).then_with(|| a.0.cmp(b.0)));

    for (pattern, replacement) in mappings {
        // Group 1: the literal pattern. Group 2: the single whitespace character after
        // it, or the empty string at end of input. `^` pins the match to the start.
        let regex_pattern = format!(r"^({})(\s|$)", regex::escape(pattern));
        let regex = get_cached_regex(&regex_pattern)?;
        if !regex.is_match(command) {
            continue;
        }
        // The regex is anchored at `^`, so there is at most one match to replace.
        let suggested_command = regex.replace(command, |caps: &regex::Captures| {
            format!("{}{}", replacement, &caps[2])
        });
        let suggestion = format!(
            "Command '{pattern}' is mapped to use '{replacement}' instead. Try: {suggested_command}"
        );
        return Ok(Some((suggestion, suggested_command.to_string())));
    }
    Ok(None)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `Config` whose only populated field is the given command mappings.
    fn config_with(mappings: &[(&str, &str)]) -> Config {
        Config {
            commands: mappings
                .iter()
                .map(|&(from, to)| (from.to_string(), to.to_string()))
                .collect(),
            semantic_directories: HashMap::new(),
            command_history: None,
        }
    }

    /// Asserts that `command` is rewritten to `expected`, returning the suggestion text.
    fn assert_mapped(config: &Config, command: &str, expected: &str) -> String {
        let (suggestion, replacement) = check_command_mappings(config, command)
            .unwrap()
            .unwrap_or_else(|| panic!("expected '{command}' to match a mapping"));
        assert_eq!(replacement, expected);
        suggestion
    }

    /// Asserts that no mapping applies to `command`, failing with `why` otherwise.
    fn assert_unmapped(config: &Config, command: &str, why: &str) {
        let result = check_command_mappings(config, command).unwrap();
        assert!(result.is_none(), "{why}");
    }

    #[test]
    fn test_command_mapping() {
        let config = config_with(&[("npm", "bun"), ("yarn", "bun"), ("npx", "bunx")]);

        let suggestion = assert_mapped(&config, "npm install", "bun install");
        assert!(suggestion.contains("bun install"));

        let suggestion = assert_mapped(&config, "yarn start", "bun start");
        assert!(suggestion.contains("bun start"));
    }

    #[test]
    fn test_command_mapping_edge_cases() {
        let config = config_with(&[("npm", "bun")]);

        assert_unmapped(&config, "my-npm-tool install", "npm in 'my-npm-tool' should NOT match");
        assert!(check_command_mappings(&config, "").unwrap().is_none());
        assert_mapped(&config, "npm install --verbose", "bun install --verbose");
        assert_unmapped(
            &config,
            "run npm",
            "'npm' in 'run npm' should NOT match (not primary command)",
        );
        assert_unmapped(&config, "npmc install", "'npmc' should NOT match 'npm'");
        assert_mapped(&config, "npm", "bun");
    }

    #[test]
    fn test_command_mapping_prevents_false_positives() {
        let config = config_with(&[("RM", "rm -i")]);

        assert_mapped(&config, "RM file.txt", "rm -i file.txt");

        let rejected: &[(&str, &str)] = &[
            ("RMm file.txt", "'RMm' should NOT match 'RM'"),
            ("gitRM file.txt", "'gitRM' should NOT match 'RM'"),
            ("git RM file.txt", "'RM' in 'git RM' should NOT match (not primary command)"),
            ("git-RM file.txt", "'git-RM' should NOT match 'RM'"),
            ("RM-tool file.txt", "'RM-tool' should NOT match 'RM'"),
        ];
        for &(command, why) in rejected {
            assert_unmapped(&config, command, why);
        }
    }

    #[test]
    fn test_hook_output_serialization() {
        let with_replacement = HookOutput {
            decision: "block".to_string(),
            reason: "Test reason".to_string(),
            replacement_command: Some("test command".to_string()),
        };
        let json = serde_json::to_string(&with_replacement).unwrap();
        assert!(json.contains(r#""decision":"block""#));
        assert!(json.contains(r#""reason":"Test reason""#));
        assert!(json.contains(r#""replacement_command":"test command""#));

        let without_replacement = HookOutput {
            decision: "allow".to_string(),
            reason: "No mapping found".to_string(),
            replacement_command: None,
        };
        let json = serde_json::to_string(&without_replacement).unwrap();
        assert!(json.contains(r#""decision":"allow""#));
        assert!(json.contains(r#""reason":"No mapping found""#));
        // A `None` replacement must be omitted from the JSON entirely.
        assert!(!json.contains("replacement_command"));
    }
}