use std::collections::HashMap;
use std::path::PathBuf;
use serde::Deserialize;
use thiserror::Error;
use crate::learnings::redaction::contains_secrets;
use crate::learnings::{
LearningCaptureConfig, LearningError, capture_failed_command, redact_secrets,
};
// Which agent hook event is being handled; `process_hook_input_with_type`
// dispatches on this to pick a handler.
// (Plain `//` comments on purpose: clap's ValueEnum derive would turn `///`
// variant docs into CLI help text.)
#[derive(Debug, Clone, Copy, PartialEq, clap::ValueEnum)]
pub enum LearnHookType {
    // Before a tool runs: warn on stderr about a similar command that failed before.
    PreToolUse,
    // After a tool runs: capture a failed Bash command as a learning.
    PostToolUse,
    // When the user submits a prompt: capture "use X instead of Y" /
    // "prefer X over Y" corrections.
    UserPromptSubmit,
}
// Which coding agent produced the hook payload.
// (Plain `//` comments on purpose: clap's ValueEnum derive would turn `///`
// variant docs into CLI help text.)
#[derive(Debug, Clone, Copy, PartialEq, clap::ValueEnum)]
// `process_hook_input_with_type` currently ignores its `_format` argument, so
// some variants may never be read outside the CLI parser.
#[allow(dead_code)]
pub enum AgentFormat {
    Claude,
    Codex,
    Opencode,
}
/// Records a learning for the failed command described by `input`.
///
/// The command, the combined stdout/stderr text (see [`HookInput::error_output`])
/// and the exit code are handed to `capture_failed_command` together with a
/// default [`LearningCaptureConfig`].
///
/// # Errors
///
/// Returns [`LearningError::Ignored`] when the payload carries no command;
/// otherwise propagates whatever `capture_failed_command` returns.
pub fn capture_from_hook(input: &HookInput) -> Result<PathBuf, LearningError> {
    match input.command() {
        None => Err(LearningError::Ignored("No command in input".to_string())),
        Some(cmd) => {
            let combined_output = input.error_output();
            let status = input.tool_result.exit_code;
            let capture_config = LearningCaptureConfig::default();
            capture_failed_command(cmd, &combined_output, status, &capture_config)
        }
    }
}
/// Reads one hook payload from stdin, runs the handler for `hook_type`, then echoes
/// the payload back to stdout with any detected secrets redacted.
///
/// The per-event handlers are fail-open: parse or capture problems are logged at
/// debug level and never become an error here, so a broken learning store cannot
/// block the agent.
///
/// `_format` is accepted for CLI compatibility but is not consulted yet.
///
/// # Errors
///
/// Returns [`HookError::StdinError`] if stdin cannot be read, and also if writing or
/// flushing the passthrough to stdout fails (the enum has no dedicated stdout variant).
pub async fn process_hook_input_with_type(
    _format: AgentFormat,
    hook_type: LearnHookType,
) -> Result<(), HookError> {
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    let mut buffer = String::new();
    tokio::io::stdin()
        .read_to_string(&mut buffer)
        .await
        .map_err(HookError::StdinError)?;

    match hook_type {
        LearnHookType::PreToolUse => process_pre_tool_use(&buffer),
        LearnHookType::PostToolUse => process_post_tool_use(&buffer),
        LearnHookType::UserPromptSubmit => process_user_prompt_submit(&buffer),
    }

    // The payload is passed through to the agent; never echo secrets back.
    let output = if contains_secrets(&buffer) {
        log::debug!("Hook passthrough: secrets detected, redacting before stdout");
        redact_secrets(&buffer)
    } else {
        buffer
    };

    let mut stdout = tokio::io::stdout();
    stdout
        .write_all(output.as_bytes())
        .await
        .map_err(HookError::StdinError)?;
    // tokio performs stdout writes on a blocking thread, so `write_all` can resolve
    // before the bytes are actually written. Flushing waits for completion and
    // surfaces a late write error instead of silently dropping it.
    stdout.flush().await.map_err(HookError::StdinError)?;
    Ok(())
}

/// Handles a `PostToolUse` payload: records a learning when a Bash command failed.
/// Fail-open: parse and capture errors are only logged at debug level.
fn process_post_tool_use(json: &str) {
    match HookInput::from_json(json) {
        Ok(input) if input.should_capture() => {
            if let Err(e) = capture_from_hook(&input) {
                log::debug!("Hook capture failed: {}", e);
            }
        }
        Ok(_) => {}
        Err(e) => log::debug!("Hook parse failed (fail-open): {}", e),
    }
}
fn process_pre_tool_use(json: &str) {
let input = match HookInput::from_json(json) {
Ok(i) => i,
Err(_) => return, };
let command = match input.command() {
Some(c) => c,
None => return, };
let config = LearningCaptureConfig::default();
let storage_dir = config.storage_location();
let base_cmd = command.split_whitespace().next().unwrap_or(command);
let learnings = match crate::learnings::capture::query_learnings(&storage_dir, base_cmd, false)
{
Ok(l) => l,
Err(_) => return,
};
if learnings.is_empty() {
return;
}
let best = learnings
.iter()
.find(|l| l.correction.is_some())
.or(learnings.first());
if let Some(learning) = best {
let mut warning = format!(
"Warning: similar command failed before (exit {}): {}",
learning.exit_code, learning.command
);
if let Some(ref correction) = learning.correction {
warning.push_str(&format!("\n Suggested: {}", correction));
}
eprintln!("{}", warning);
}
}
/// Handles a `UserPromptSubmit` payload: when the `user_prompt` text contains a
/// recognizable correction ("use X instead of Y" / "prefer X over Y"), records it
/// as a user-prompt correction.
///
/// Fail-open: invalid JSON, a missing/non-string `user_prompt`, or no recognized
/// pattern return silently; a capture failure is only logged at debug level.
fn process_user_prompt_submit(json: &str) {
    let parsed: Option<serde_json::Value> = serde_json::from_str(json).ok();
    let prompt_text = match parsed
        .as_ref()
        .and_then(|root| root.get("user_prompt"))
        .and_then(|field| field.as_str())
    {
        Some(text) => text,
        None => return,
    };

    let (original, corrected) = match parse_correction_pattern(prompt_text) {
        Some(pair) => pair,
        None => return,
    };

    let capture_config = LearningCaptureConfig::default();
    let kind = crate::learnings::CorrectionType::Other("user-prompt".to_string());
    let note = format!("Auto-captured from user prompt: {}", prompt_text);
    let outcome =
        crate::learnings::capture_correction(kind, &original, &corrected, &note, &capture_config);
    if let Err(err) = outcome {
        log::debug!("User prompt correction capture failed: {}", err);
    }
}
/// Extracts a correction from free-form prompt text.
///
/// Two phrasings are recognized, matched ASCII-case-insensitively and tried in order:
/// - `use <corrected> instead of <original>`
/// - `prefer <corrected> over <original>`
///
/// Returns `Some((original, corrected))`, with both parts trimmed and trailing
/// periods removed from `original`. Returns `None` if neither phrasing yields two
/// non-empty parts. The input is untrusted user text, so this never panics,
/// whatever the keyword order or character set.
fn parse_correction_pattern(text: &str) -> Option<(String, String)> {
    // `to_ascii_lowercase` keeps every byte offset identical to `text`. `to_lowercase`
    // does not: it can change the byte length of non-ASCII characters, which would
    // make indices found in the lowered copy slice the wrong bytes of `text`.
    let lower = text.to_ascii_lowercase();
    split_correction(text, &lower, "use ", " instead of ")
        .or_else(|| split_correction(text, &lower, "prefer ", " over "))
}

/// Splits `text` around `<lead> ... <sep> ...`, returning `(after sep, between lead and sep)`.
/// `lower` must be `text.to_ascii_lowercase()`; `lead` and `sep` must be lowercase ASCII.
fn split_correction(text: &str, lower: &str, lead: &str, sep: &str) -> Option<(String, String)> {
    let start = find_at_word_start(lower, lead)? + lead.len();
    // Search for the separator only after the lead keyword. A separator that occurs
    // earlier ("instead of X, use Y") would otherwise give an inverted slice range.
    let sep_idx = start + lower[start..].find(sep)?;
    let corrected = text[start..sep_idx].trim();
    let original = text[sep_idx + sep.len()..].trim().trim_end_matches('.');
    if corrected.is_empty() || original.is_empty() {
        return None;
    }
    Some((original.to_string(), corrected.to_string()))
}

/// Byte offset of the first `needle` in `haystack` that begins a word, meaning it is at
/// the start of the string or follows a non-alphanumeric character. This stops
/// "because " from being mistaken for "use ".
fn find_at_word_start(haystack: &str, needle: &str) -> Option<usize> {
    haystack
        .match_indices(needle)
        .map(|(idx, _)| idx)
        .find(|&idx| {
            !matches!(haystack[..idx].chars().next_back(), Some(c) if c.is_alphanumeric())
        })
}
/// Errors surfaced by the hook entry point.
///
/// Per-event handlers are fail-open and only log, so in the code visible here only
/// `StdinError` is actually returned.
#[derive(Debug, Error)]
// NOTE(review): presumably allowed because `ParseError` / `CaptureError` are not
// constructed in this file; confirm before removing.
#[allow(dead_code)]
pub enum HookError {
    /// An I/O failure. Besides stdin reads, `process_hook_input_with_type` also maps
    /// stdout write failures to this variant (there is no separate stdout variant).
    #[error("failed to read stdin: {0}")]
    StdinError(#[from] std::io::Error),
    /// The hook payload was not valid JSON for the expected shape.
    #[error("failed to parse hook input: {0}")]
    ParseError(#[from] serde_json::Error),
    /// Recording a learning failed.
    #[error("capture failed: {0}")]
    CaptureError(#[from] LearningError),
}
/// A tool-event payload as delivered to the hook on stdin (JSON).
///
/// All three fields are required during deserialization (none has a serde default),
/// so a payload without `tool_result` fails to parse.
#[derive(Debug, Clone, Deserialize)]
// NOTE(review): every field is read by the `impl HookInput` methods; the reason for
// this allow is not visible here — confirm whether it is still needed.
#[allow(dead_code)]
pub struct HookInput {
    /// Name of the tool that ran; `should_capture` only accepts `"Bash"`.
    pub tool_name: String,
    /// The tool's arguments (`command` for Bash; other keys land in `extra`).
    pub tool_input: ToolInput,
    /// Exit code and captured output of the tool.
    pub tool_result: ToolResult,
}
/// Arguments passed to the tool.
#[derive(Debug, Clone, Deserialize)]
// `extra` is only read in tests in this file, which presumably motivates this allow.
#[allow(dead_code)]
pub struct ToolInput {
    /// Shell command line; `None` for tools that do not run a command (e.g. Edit/Write).
    pub command: Option<String>,
    /// Every other key of `tool_input`, kept verbatim (e.g. `path`, `content`).
    #[serde(flatten)]
    pub extra: HashMap<String, serde_json::Value>,
}
/// Outcome of a tool invocation.
#[derive(Debug, Clone, Deserialize)]
// NOTE(review): all fields are read by `HookInput` methods; reason for this allow
// is not visible here — confirm.
#[allow(dead_code)]
pub struct ToolResult {
    /// Process exit code (required); any non-zero value, including negative, counts as failure.
    pub exit_code: i32,
    /// Captured standard output; empty string when absent from the payload.
    #[serde(default)]
    pub stdout: String,
    /// Captured standard error; empty string when absent from the payload.
    #[serde(default)]
    pub stderr: String,
}
// Some of these helpers are only exercised from tests, hence the allow.
#[allow(dead_code)]
impl HookInput {
    /// Deserializes a hook payload from its JSON text.
    pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
        serde_json::from_str::<Self>(json)
    }

    /// True for a `Bash` call that exited with any non-zero code (negative included).
    pub fn should_capture(&self) -> bool {
        let failed = self.tool_result.exit_code != 0;
        failed && self.tool_name == "Bash"
    }

    /// The tool's stdout followed by its stderr, newline-separated; an empty stream
    /// is skipped, so no stray separator appears.
    pub fn error_output(&self) -> String {
        let streams = [
            self.tool_result.stdout.as_str(),
            self.tool_result.stderr.as_str(),
        ];
        streams
            .iter()
            .copied()
            .filter(|stream| !stream.is_empty())
            .collect::<Vec<&str>>()
            .join("\n")
    }

    /// The shell command line, if the tool received one.
    pub fn command(&self) -> Option<&str> {
        self.tool_input.command.as_deref()
    }
}
// Unit tests for payload parsing, capture decisions, passthrough redaction,
// correction-pattern extraction, and fail-open behavior of the event handlers.
#[cfg(test)]
mod tests {
    use super::*;
    // A complete payload deserializes into HookInput with every field populated.
    #[test]
    fn test_hook_input_parse() {
        let json = r#"{
"tool_name": "Bash",
"tool_input": {"command": "git push -f"},
"tool_result": {"exit_code": 1, "stdout": "", "stderr": "rejected"}
}"#;
        let input = HookInput::from_json(json).unwrap();
        assert_eq!(input.tool_name, "Bash");
        assert_eq!(input.command(), Some("git push -f"));
        assert_eq!(input.tool_result.exit_code, 1);
        assert_eq!(input.tool_result.stdout, "");
        assert_eq!(input.tool_result.stderr, "rejected");
    }
    #[test]
    fn test_should_capture_failed_bash() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("cmd".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 1,
                stdout: String::new(),
                stderr: String::new(),
            },
        };
        assert!(input.should_capture());
    }
    #[test]
    fn test_should_not_capture_success() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("cmd".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 0,
                stdout: String::new(),
                stderr: String::new(),
            },
        };
        assert!(!input.should_capture());
    }
    #[test]
    fn test_should_not_capture_edit() {
        let input = HookInput {
            tool_name: "Edit".to_string(),
            tool_input: ToolInput {
                command: None,
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 0,
                stdout: String::new(),
                stderr: String::new(),
            },
        };
        assert!(!input.should_capture());
    }
    // stdout comes first, stderr second, joined by a single newline.
    #[test]
    fn test_error_output_combining() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("cmd".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 1,
                stdout: "output line 1".to_string(),
                stderr: "error line 1".to_string(),
            },
        };
        assert_eq!(input.error_output(), "output line 1\nerror line 1");
    }
    #[test]
    fn test_command_extraction() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("git push origin main".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 0,
                stdout: String::new(),
                stderr: String::new(),
            },
        };
        assert_eq!(input.command(), Some("git push origin main"));
    }
    #[test]
    fn test_command_extraction_none() {
        let input = HookInput {
            tool_name: "Edit".to_string(),
            tool_input: ToolInput {
                command: None,
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 0,
                stdout: String::new(),
                stderr: String::new(),
            },
        };
        assert_eq!(input.command(), None);
    }
    // Non-command tool arguments are collected by the flattened `extra` map.
    #[test]
    fn test_parse_with_extra_fields() {
        let json = r#"{
"tool_name": "Write",
"tool_input": {
"path": "/tmp/test.txt",
"content": "hello world"
},
"tool_result": {"exit_code": 0, "stdout": "", "stderr": ""}
}"#;
        let input = HookInput::from_json(json).unwrap();
        assert_eq!(input.tool_name, "Write");
        assert!(input.tool_input.command.is_none());
        assert!(input.tool_input.extra.contains_key("path"));
        assert!(input.tool_input.extra.contains_key("content"));
    }
    #[test]
    fn test_error_output_stdout_only() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("cmd".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 1,
                stdout: "some output".to_string(),
                stderr: String::new(),
            },
        };
        assert_eq!(input.error_output(), "some output");
    }
    #[test]
    fn test_error_output_stderr_only() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("cmd".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 1,
                stdout: String::new(),
                stderr: "some error".to_string(),
            },
        };
        assert_eq!(input.error_output(), "some error");
    }
    #[test]
    fn test_error_output_empty() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("cmd".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 1,
                stdout: String::new(),
                stderr: String::new(),
            },
        };
        assert_eq!(input.error_output(), "");
    }
    #[test]
    fn test_parse_invalid_json() {
        let json = "not valid json";
        let result = HookInput::from_json(json);
        assert!(result.is_err());
    }
    #[test]
    fn test_should_not_capture_bash_with_exit_zero() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("echo hello".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 0,
                stdout: "hello".to_string(),
                stderr: String::new(),
            },
        };
        assert!(!input.should_capture());
    }
    // Negative exit codes (e.g. a killed process) still count as failures.
    #[test]
    fn test_should_capture_bash_with_negative_exit_code() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("kill -9 $$".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: -1,
                stdout: String::new(),
                stderr: "Killed".to_string(),
            },
        };
        assert!(input.should_capture());
    }
    #[test]
    fn test_should_not_capture_non_bash_even_if_failed() {
        let input = HookInput {
            tool_name: "Write".to_string(),
            tool_input: ToolInput {
                command: None,
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 1,
                stdout: String::new(),
                stderr: "Permission denied".to_string(),
            },
        };
        assert!(!input.should_capture());
    }
    // Only checks that a present command is not reported as missing; any other
    // outcome (Ok or a different error) passes.
    // NOTE(review): capture_from_hook uses LearningCaptureConfig::default(), so this
    // may write to the real learnings store — consider isolating it in a temp dir.
    #[test]
    fn test_capture_from_hook_success() {
        let input = HookInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("git push".to_string()),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 1,
                stdout: String::new(),
                stderr: "rejected".to_string(),
            },
        };
        let result = capture_from_hook(&input);
        if let Err(LearningError::Ignored(msg)) = &result {
            assert_ne!(msg, "No command in input");
        }
    }
    #[test]
    fn test_capture_from_hook_no_command() {
        let input = HookInput {
            tool_name: "Edit".to_string(),
            tool_input: ToolInput {
                command: None,
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code: 0,
                stdout: String::new(),
                stderr: String::new(),
            },
        };
        let result = capture_from_hook(&input);
        assert!(result.is_err());
        match result.unwrap_err() {
            LearningError::Ignored(msg) => assert_eq!(msg, "No command in input"),
            _ => panic!("Expected Ignored error"),
        }
    }
    #[test]
    fn test_agent_format_variants() {
        assert_ne!(AgentFormat::Claude, AgentFormat::Codex);
        assert_ne!(AgentFormat::Claude, AgentFormat::Opencode);
        assert_ne!(AgentFormat::Codex, AgentFormat::Opencode);
    }
    // Mirrors the passthrough step of process_hook_input_with_type (which reads the
    // real stdin): redaction removes the key and the redacted JSON still parses.
    #[test]
    fn test_hook_passthrough_redacts_aws_key_in_error() {
        use crate::learnings::redact_secrets;
        use crate::learnings::redaction::contains_secrets;
        // Presumably split so the full example key never appears literally in source
        // (e.g. to avoid tripping secret scanners).
        let aws_key = format!("AKIA{}", "IOSFODNN7EXAMPLE");
        let json = format!(
            r#"{{
"tool_name": "Bash",
"tool_input": {{"command": "aws s3 ls"}},
"tool_result": {{
"exit_code": 1,
"stdout": "",
"stderr": "Unable to locate credentials {}"
}}
}}"#,
            aws_key
        );
        assert!(contains_secrets(&json));
        let redacted = redact_secrets(&json);
        assert!(!redacted.contains(&aws_key));
        assert!(redacted.contains("[AWS_KEY_REDACTED]"));
        let parsed = HookInput::from_json(&redacted).unwrap();
        assert_eq!(parsed.tool_name, "Bash");
        assert_eq!(parsed.tool_result.exit_code, 1);
        assert!(parsed.tool_result.stderr.contains("[AWS_KEY_REDACTED]"));
    }
    #[test]
    fn test_learn_hook_type_variants() {
        assert_ne!(LearnHookType::PreToolUse, LearnHookType::PostToolUse);
        assert_ne!(LearnHookType::PostToolUse, LearnHookType::UserPromptSubmit);
        assert_ne!(LearnHookType::PreToolUse, LearnHookType::UserPromptSubmit);
    }
    #[test]
    fn test_parse_correction_pattern_use_instead_of() {
        let result = parse_correction_pattern("use bun instead of npm");
        assert_eq!(result, Some(("npm".to_string(), "bun".to_string())));
    }
    #[test]
    fn test_parse_correction_pattern_prefer_over() {
        let result = parse_correction_pattern("prefer cargo over make");
        assert_eq!(result, Some(("make".to_string(), "cargo".to_string())));
    }
    // Trailing periods are stripped from the `original` part only.
    #[test]
    fn test_parse_correction_pattern_with_trailing_period() {
        let result = parse_correction_pattern("use Result<T> instead of unwrap().");
        assert_eq!(
            result,
            Some(("unwrap()".to_string(), "Result<T>".to_string()))
        );
    }
    #[test]
    fn test_parse_correction_pattern_no_match() {
        assert!(parse_correction_pattern("hello world").is_none());
        assert!(parse_correction_pattern("this is fine").is_none());
    }
    // The handlers below are fail-open and return (); these only assert "no panic".
    // NOTE(review): a payload that reaches the learnings lookup would read the default
    // storage location.
    #[test]
    fn test_pre_tool_use_no_crash_on_non_bash() {
        let json = r#"{
"tool_name": "Edit",
"tool_input": {"path": "/tmp/test.txt"},
"tool_result": {"exit_code": 0, "stdout": "", "stderr": ""}
}"#;
        process_pre_tool_use(json);
    }
    #[test]
    fn test_pre_tool_use_no_crash_on_invalid_json() {
        process_pre_tool_use("not valid json");
    }
    #[test]
    fn test_user_prompt_submit_no_crash_on_empty() {
        process_user_prompt_submit("{}");
    }
    #[test]
    fn test_user_prompt_submit_no_crash_on_invalid_json() {
        process_user_prompt_submit("invalid");
    }
}