use std::collections::HashMap;
use std::path::PathBuf;
use serde::Deserialize;
use thiserror::Error;
use crate::learnings::{
LearningCaptureConfig, LearningError, capture_failed_command, redact_secrets,
};
/// Kind of hook event being processed.
///
/// Selects which handler `process_hook_with_streams` runs on the payload.
/// Derives `clap::ValueEnum`, so the variants can be used as command-line
/// argument values.
// NOTE: the variants below are documented with plain `//` comments on purpose;
// `clap::ValueEnum` surfaces `///` docs on variants as CLI help text, which
// would change the program's output.
#[derive(Debug, Clone, Copy, PartialEq, clap::ValueEnum)]
pub enum LearnHookType {
// Handled by `process_pre_tool_use`, which may print a warning to stderr when
// a similar command is found among the stored learnings.
PreToolUse,
// Payload is parsed as `HookInput`; when `should_capture()` is true it is
// passed to `capture_from_hook`.
PostToolUse,
// Handled by `process_user_prompt_submit`, which looks for correction
// phrases in the `user_prompt` field.
UserPromptSubmit,
}
/// Records the failed command described by `input` as a learning.
///
/// The command text, the combined stdout/stderr from
/// [`HookInput::error_output`], and the exit code are forwarded to
/// `capture_failed_command` together with a default `LearningCaptureConfig`.
///
/// # Errors
///
/// Returns `LearningError::Ignored` when the payload carries no command;
/// otherwise any error produced by `capture_failed_command` is returned as-is.
pub fn capture_from_hook(input: &HookInput) -> Result<PathBuf, LearningError> {
    // Guard clause: payloads without a command have nothing to capture.
    let Some(command) = input.command() else {
        return Err(LearningError::Ignored("No command in input".to_string()));
    };

    let combined_output = input.error_output();
    let status = input.tool_result.exit_code;
    let capture_config = LearningCaptureConfig::default();

    capture_failed_command(command, &combined_output, status, &capture_config)
}
/// Processes a single hook payload using the process's stdin and stdout.
///
/// Thin wrapper that binds `process_hook_with_streams` to
/// `tokio::io::stdin()` and `tokio::io::stdout()`.
///
/// # Errors
///
/// Propagates any `HookError` returned by `process_hook_with_streams`.
pub async fn process_hook_input_with_type(hook_type: LearnHookType) -> Result<(), HookError> {
    let stdin = tokio::io::stdin();
    let stdout = tokio::io::stdout();
    process_hook_with_streams(hook_type, stdin, stdout).await
}
/// Reads one hook payload from `reader`, runs the handler selected by
/// `hook_type`, and echoes the payload to `writer` after passing it through
/// `redact_secrets`.
///
/// Handler failures never abort the pass-through (fail-open): parse and
/// capture problems are only logged at debug level, and the redacted payload
/// is still written.
///
/// The writer is flushed before returning. Buffered or deferred writers (such
/// as `tokio::io::stdout()`) may otherwise still hold the data when the caller
/// exits.
///
/// # Errors
///
/// Returns `HookError::Stdin` when reading the payload fails, and also when
/// writing or flushing the output fails (`HookError` has no separate variant
/// for output errors).
pub(crate) async fn process_hook_with_streams<R, W>(
    hook_type: LearnHookType,
    mut reader: R,
    mut writer: W,
) -> Result<(), HookError>
where
    R: tokio::io::AsyncRead + Unpin,
    W: tokio::io::AsyncWrite + Unpin,
{
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    // The whole payload is read up front; every handler works on the same text.
    let mut buffer = String::new();
    reader
        .read_to_string(&mut buffer)
        .await
        .map_err(HookError::Stdin)?;

    match hook_type {
        LearnHookType::PreToolUse => {
            process_pre_tool_use(&buffer);
        }
        LearnHookType::PostToolUse => {
            match HookInput::from_json(&buffer) {
                Ok(input) => {
                    if input.should_capture()
                        && let Err(e) = capture_from_hook(&input)
                    {
                        log::debug!("Hook capture failed: {}", e);
                    }
                }
                Err(e) => {
                    // Unparseable payloads are still passed through below.
                    log::debug!("Hook parse failed (fail-open): {}", e);
                }
            }
        }
        LearnHookType::UserPromptSubmit => {
            process_user_prompt_submit(&buffer);
        }
    }

    // Only the redacted form of the payload is ever written out.
    let output = redact_secrets(&buffer);
    writer
        .write_all(output.as_bytes())
        .await
        .map_err(HookError::Stdin)?;
    // `write_all` only hands the bytes to the writer; flush so they actually
    // reach the destination before this function reports success.
    writer.flush().await.map_err(HookError::Stdin)?;
    Ok(())
}
/// Pre-tool-use handler: warns on stderr when stored learnings match the
/// command that is about to run.
///
/// The payload is parsed as `HookInput`; the first whitespace-separated word
/// of its command is used to query the learnings kept under the default
/// `LearningCaptureConfig` storage location. A learning that carries a
/// correction is preferred; otherwise the first result is used.
///
/// Every failure (invalid JSON, missing command, query error, no results)
/// makes the function return silently.
fn process_pre_tool_use(json: &str) {
    let Ok(input) = HookInput::from_json(json) else {
        return;
    };
    let Some(command) = input.command() else {
        return;
    };

    let config = LearningCaptureConfig::default();
    let storage_dir = config.storage_location();

    // Query by the program name only; fall back to the whole command when it
    // contains no non-whitespace word.
    let base_cmd = match command.split_whitespace().next() {
        Some(first_word) => first_word,
        None => command,
    };

    let Ok(learnings) = crate::learnings::capture::query_learnings(&storage_dir, base_cmd, false)
    else {
        return;
    };

    // An empty result set yields `None` here and therefore no warning.
    let Some(learning) = learnings
        .iter()
        .find(|candidate| candidate.correction.is_some())
        .or_else(|| learnings.first())
    else {
        return;
    };

    let mut warning = format!(
        "Warning: similar command failed before (exit {}): {}",
        learning.exit_code, learning.command
    );
    if let Some(correction) = &learning.correction {
        warning += &format!("\n Suggested: {}", correction);
    }
    eprintln!("{}", warning);
}
/// User-prompt handler: captures a tool-preference correction when the prompt
/// contains one.
///
/// Reads the string field `user_prompt` from the JSON payload and runs it
/// through `parse_correction_pattern`. A match is stored through
/// `capture_correction` as `CorrectionType::ToolPreference`, using a default
/// `LearningCaptureConfig`.
///
/// Invalid JSON, a missing or non-string `user_prompt`, and prompts without a
/// recognised pattern are ignored silently; a failed capture is logged at
/// debug level.
fn process_user_prompt_submit(json: &str) {
    let Ok(value) = serde_json::from_str::<serde_json::Value>(json) else {
        return;
    };
    let Some(prompt) = value
        .get("user_prompt")
        .and_then(serde_json::Value::as_str)
    else {
        return;
    };
    let Some((original, corrected)) = parse_correction_pattern(prompt) else {
        return;
    };

    let config = LearningCaptureConfig::default();
    let context = format!("Auto-captured from user prompt: {}", prompt);
    let outcome = crate::learnings::capture_correction(
        crate::learnings::CorrectionType::ToolPreference,
        &original,
        &corrected,
        &context,
        &config,
    );
    if let Err(e) = outcome {
        log::debug!("User prompt correction capture failed: {}", e);
    }
}
/// Extracts an `(original, corrected)` pair from a user prompt that states a
/// tool preference.
///
/// Recognised phrasings; the prompt must start with the keyword (leading
/// whitespace is ignored) and keywords are matched ASCII case-insensitively:
///
/// * `use <corrected> instead of <original>`
/// * `use <corrected> not <original>`
/// * `prefer <corrected> over <original>`
///
/// Both parts are trimmed and keep their original casing; trailing `.`
/// characters are removed from `<original>`. Returns `None` when no phrasing
/// matches or when either part would be empty.
fn parse_correction_pattern(text: &str) -> Option<(String, String)> {
    // Trim the original text first so every offset below refers to the same
    // string that gets sliced.
    let text = text.trim_start();
    // ASCII-only lowercasing never changes byte lengths, so offsets found in
    // `lower` are valid char-boundary offsets in `text`. Full Unicode
    // lowercasing can grow or shrink the string and would break that.
    let lower = text.to_ascii_lowercase();

    correction_after_prefix(text, &lower, "use ", &[" instead of ", " not "])
        .or_else(|| correction_after_prefix(text, &lower, "prefer ", &[" over "]))
}

/// Splits the part of `text` after `prefix` on the first of `separators` that
/// yields two non-empty sides, returning `(right, left)` as owned strings.
///
/// `lower` must be the ASCII-lowercased form of `text`; `prefix` and
/// `separators` must be lowercase ASCII.
fn correction_after_prefix(
    text: &str,
    lower: &str,
    prefix: &str,
    separators: &[&str],
) -> Option<(String, String)> {
    if !lower.starts_with(prefix) {
        return None;
    }
    let rest = &text[prefix.len()..];
    let rest_lower = &lower[prefix.len()..];

    // Separators are tried in order; one that matches but leaves an empty
    // side falls through to the next.
    separators.iter().find_map(|&separator| {
        let at = rest_lower.find(separator)?;
        let corrected = rest[..at].trim();
        let original = rest[at + separator.len()..]
            .trim()
            .trim_end_matches('.');
        if corrected.is_empty() || original.is_empty() {
            None
        } else {
            Some((original.to_string(), corrected.to_string()))
        }
    })
}
/// Error type returned by the hook processing entry points.
#[derive(Debug, Error)]
pub enum HookError {
/// I/O failure on a hook stream.
///
/// Despite its name and message, `process_hook_with_streams` maps failures
/// while writing the output to this variant as well as failures while
/// reading the input.
#[error("failed to read stdin: {0}")]
Stdin(#[from] std::io::Error),
}
/// JSON payload a hook receives on its input stream.
///
/// All three fields are required when deserializing.
// NOTE(review): because `tool_result` has no `#[serde(default)]`, a payload
// without it fails `HookInput::from_json`. Confirm that pre-tool-use payloads
// actually include a `tool_result` object.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
pub struct HookInput {
/// Name of the tool; `should_capture` compares it against `"Bash"`.
pub tool_name: String,
/// Arguments the tool was invoked with.
pub tool_input: ToolInput,
/// Exit status and captured output of the tool.
pub tool_result: ToolResult,
}
/// The `tool_input` object of a hook payload.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
pub struct ToolInput {
/// Command line of the tool call; `None` when the payload has no `command`
/// key.
pub command: Option<String>,
/// Every other key of the `tool_input` object, kept as raw JSON values.
#[serde(flatten)]
pub extra: HashMap<String, serde_json::Value>,
}
/// The `tool_result` object of a hook payload.
#[derive(Debug, Clone, Deserialize)]
#[allow(dead_code)]
pub struct ToolResult {
/// Exit code of the tool (required); `should_capture` treats any non-zero
/// value as a failure.
pub exit_code: i32,
/// Captured standard output; defaults to an empty string when absent.
#[serde(default)]
pub stdout: String,
/// Captured standard error; defaults to an empty string when absent.
#[serde(default)]
pub stderr: String,
}
#[allow(dead_code)]
impl HookInput {
pub fn from_json(json: &str) -> Result<Self, serde_json::Error> {
serde_json::from_str(json)
}
pub fn should_capture(&self) -> bool {
self.tool_name == "Bash" && self.tool_result.exit_code != 0
}
pub fn error_output(&self) -> String {
let mut output = String::new();
if !self.tool_result.stdout.is_empty() {
output.push_str(&self.tool_result.stdout);
}
if !self.tool_result.stderr.is_empty() {
if !output.is_empty() {
output.push('\n');
}
output.push_str(&self.tool_result.stderr);
}
output
}
pub fn command(&self) -> Option<&str> {
self.tool_input.command.as_deref()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `HookInput` with no extra `tool_input` keys.
    fn hook_input(
        tool_name: &str,
        command: Option<&str>,
        exit_code: i32,
        stdout: &str,
        stderr: &str,
    ) -> HookInput {
        HookInput {
            tool_name: tool_name.to_string(),
            tool_input: ToolInput {
                command: command.map(String::from),
                extra: HashMap::new(),
            },
            tool_result: ToolResult {
                exit_code,
                stdout: stdout.to_string(),
                stderr: stderr.to_string(),
            },
        }
    }

    /// Feeds `json` through `process_hook_with_streams` and returns what was
    /// written to the output stream.
    async fn run_hook(hook_type: LearnHookType, json: &str) -> String {
        let mut output_buf: Vec<u8> = Vec::new();
        process_hook_with_streams(hook_type, json.as_bytes(), &mut output_buf)
            .await
            .expect("process_hook_with_streams must not fail");
        String::from_utf8(output_buf).expect("output must be valid UTF-8")
    }

    // --- HookInput parsing and accessors ---

    #[test]
    fn test_hook_input_parse() {
        let json = r#"{
"tool_name": "Bash",
"tool_input": {"command": "git push -f"},
"tool_result": {"exit_code": 1, "stdout": "", "stderr": "rejected"}
}"#;
        let input = HookInput::from_json(json).unwrap();
        assert_eq!(input.tool_name, "Bash");
        assert_eq!(input.command(), Some("git push -f"));
        assert_eq!(input.tool_result.exit_code, 1);
        assert_eq!(input.tool_result.stdout, "");
        assert_eq!(input.tool_result.stderr, "rejected");
    }

    #[test]
    fn test_should_capture_failed_bash() {
        let input = hook_input("Bash", Some("cmd"), 1, "", "");
        assert!(input.should_capture());
    }

    #[test]
    fn test_should_not_capture_success() {
        let input = hook_input("Bash", Some("cmd"), 0, "", "");
        assert!(!input.should_capture());
    }

    #[test]
    fn test_should_not_capture_edit() {
        let input = hook_input("Edit", None, 0, "", "");
        assert!(!input.should_capture());
    }

    #[test]
    fn test_error_output_combining() {
        let input = hook_input("Bash", Some("cmd"), 1, "output line 1", "error line 1");
        assert_eq!(input.error_output(), "output line 1\nerror line 1");
    }

    #[test]
    fn test_command_extraction() {
        let input = hook_input("Bash", Some("git push origin main"), 0, "", "");
        assert_eq!(input.command(), Some("git push origin main"));
    }

    #[test]
    fn test_command_extraction_none() {
        let input = hook_input("Edit", None, 0, "", "");
        assert_eq!(input.command(), None);
    }

    #[test]
    fn test_parse_with_extra_fields() {
        let json = r#"{
"tool_name": "Write",
"tool_input": {
"path": "/tmp/test.txt",
"content": "hello world"
},
"tool_result": {"exit_code": 0, "stdout": "", "stderr": ""}
}"#;
        let input = HookInput::from_json(json).unwrap();
        assert_eq!(input.tool_name, "Write");
        assert!(input.tool_input.command.is_none());
        assert!(input.tool_input.extra.contains_key("path"));
        assert!(input.tool_input.extra.contains_key("content"));
    }

    #[test]
    fn test_error_output_stdout_only() {
        let input = hook_input("Bash", Some("cmd"), 1, "some output", "");
        assert_eq!(input.error_output(), "some output");
    }

    #[test]
    fn test_error_output_stderr_only() {
        let input = hook_input("Bash", Some("cmd"), 1, "", "some error");
        assert_eq!(input.error_output(), "some error");
    }

    #[test]
    fn test_error_output_empty() {
        let input = hook_input("Bash", Some("cmd"), 1, "", "");
        assert_eq!(input.error_output(), "");
    }

    #[test]
    fn test_parse_invalid_json() {
        let json = "not valid json";
        let result = HookInput::from_json(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_should_not_capture_bash_with_exit_zero() {
        let input = hook_input("Bash", Some("echo hello"), 0, "hello", "");
        assert!(!input.should_capture());
    }

    #[test]
    fn test_should_capture_bash_with_negative_exit_code() {
        let input = hook_input("Bash", Some("kill -9 $$"), -1, "", "Killed");
        assert!(input.should_capture());
    }

    #[test]
    fn test_should_not_capture_non_bash_even_if_failed() {
        let input = hook_input("Write", None, 1, "", "Permission denied");
        assert!(!input.should_capture());
    }

    // --- capture_from_hook ---

    #[test]
    fn test_capture_from_hook_success() {
        let input = hook_input("Bash", Some("git push"), 1, "", "rejected");
        let result = capture_from_hook(&input);
        // Capture itself may fail in the test environment; the only thing
        // checked is that the command was found in the input.
        if let Err(LearningError::Ignored(msg)) = &result {
            assert_ne!(msg, "No command in input");
        }
    }

    #[test]
    fn test_capture_from_hook_no_command() {
        let input = hook_input("Edit", None, 0, "", "");
        let result = capture_from_hook(&input);
        assert!(result.is_err());
        match result.unwrap_err() {
            LearningError::Ignored(msg) => assert_eq!(msg, "No command in input"),
            _ => panic!("Expected Ignored error"),
        }
    }

    // --- stream pass-through and redaction ---

    #[tokio::test]
    async fn test_process_hook_with_streams_strips_secrets_from_output() {
        let aws_key = format!("AKIA{}", "IOSFODNN7EXAMPLE");
        let json = format!(
            r#"{{"tool_name":"Bash","tool_input":{{"command":"aws s3 ls"}},"tool_result":{{"exit_code":1,"stdout":"","stderr":"Unable to locate credentials {aws_key}"}}}}"#,
        );
        let output = run_hook(LearnHookType::PostToolUse, &json).await;
        assert!(
            !output.contains(&aws_key),
            "AWS key must not appear in stdout output; got: {output}"
        );
        assert!(
            output.contains("[AWS_KEY_REDACTED]"),
            "Redacted placeholder must appear in output; got: {output}"
        );
    }

    #[tokio::test]
    async fn test_process_hook_with_streams_clean_input_unchanged() {
        let json = r#"{"tool_name":"Bash","tool_input":{"command":"cargo build"},"tool_result":{"exit_code":0,"stdout":"Compiling","stderr":""}}"#;
        let output = run_hook(LearnHookType::PostToolUse, json).await;
        assert_eq!(output, json, "Clean input must pass through unchanged");
    }

    #[tokio::test]
    async fn test_process_hook_with_streams_pre_tool_use_also_redacts() {
        let aws_key = format!("AKIA{}", "IOSFODNN7EXAMPLE");
        let json = format!(
            r#"{{"tool_name":"Bash","tool_input":{{"command":"export AWS_ACCESS_KEY_ID={aws_key}"}},"tool_result":{{"exit_code":0,"stdout":"","stderr":""}}}}"#,
        );
        let output = run_hook(LearnHookType::PreToolUse, &json).await;
        assert!(
            !output.contains(&aws_key),
            "AWS key must not appear in pre-tool-use stdout output; got: {output}"
        );
    }

    #[test]
    fn test_hook_passthrough_redacts_aws_key_in_error() {
        use crate::learnings::redact_secrets;
        let aws_key = format!("AKIA{}", "IOSFODNN7EXAMPLE");
        let json = format!(
            r#"{{
"tool_name": "Bash",
"tool_input": {{"command": "aws s3 ls"}},
"tool_result": {{
"exit_code": 1,
"stdout": "",
"stderr": "Unable to locate credentials {}"
}}
}}"#,
            aws_key
        );
        let redacted = redact_secrets(&json);
        assert!(!redacted.contains(&aws_key));
        assert!(redacted.contains("[AWS_KEY_REDACTED]"));
        // The redacted payload must still be parseable.
        let parsed = HookInput::from_json(&redacted).unwrap();
        assert_eq!(parsed.tool_name, "Bash");
        assert_eq!(parsed.tool_result.exit_code, 1);
        assert!(parsed.tool_result.stderr.contains("[AWS_KEY_REDACTED]"));
    }

    #[test]
    fn test_learn_hook_type_variants() {
        assert_ne!(LearnHookType::PreToolUse, LearnHookType::PostToolUse);
        assert_ne!(LearnHookType::PostToolUse, LearnHookType::UserPromptSubmit);
        assert_ne!(LearnHookType::PreToolUse, LearnHookType::UserPromptSubmit);
    }

    // --- parse_correction_pattern ---

    #[test]
    fn test_parse_correction_pattern_use_instead_of() {
        let result = parse_correction_pattern("use bun instead of npm");
        assert_eq!(result, Some(("npm".to_string(), "bun".to_string())));
    }

    #[test]
    fn test_parse_correction_pattern_prefer_over() {
        let result = parse_correction_pattern("prefer cargo over make");
        assert_eq!(result, Some(("make".to_string(), "cargo".to_string())));
    }

    #[test]
    fn test_parse_correction_pattern_with_trailing_period() {
        let result = parse_correction_pattern("use Result<T> instead of unwrap().");
        assert_eq!(
            result,
            Some(("unwrap()".to_string(), "Result<T>".to_string()))
        );
    }

    #[test]
    fn test_parse_correction_pattern_use_not() {
        let result = parse_correction_pattern("use uv not pip");
        assert_eq!(result, Some(("pip".to_string(), "uv".to_string())));
    }

    #[test]
    fn test_parse_correction_pattern_no_match() {
        assert!(parse_correction_pattern("hello world").is_none());
        assert!(parse_correction_pattern("this is fine").is_none());
        assert!(parse_correction_pattern("I prefer tea over coffee").is_none());
    }

    // --- handlers must tolerate unexpected payloads ---

    #[test]
    fn test_pre_tool_use_no_crash_on_non_bash() {
        let json = r#"{
"tool_name": "Edit",
"tool_input": {"path": "/tmp/test.txt"},
"tool_result": {"exit_code": 0, "stdout": "", "stderr": ""}
}"#;
        process_pre_tool_use(json);
    }

    #[test]
    fn test_pre_tool_use_no_crash_on_invalid_json() {
        process_pre_tool_use("not valid json");
    }

    #[test]
    fn test_user_prompt_submit_no_crash_on_empty() {
        process_user_prompt_submit("{}");
    }

    #[test]
    fn test_user_prompt_submit_no_crash_on_invalid_json() {
        process_user_prompt_submit("invalid");
    }

    // --- redaction of other secret kinds ---

    #[tokio::test]
    async fn test_process_hook_github_pat_is_redacted() {
        let pat = format!("ghp_{}", "A".repeat(36));
        let json = format!(
            r#"{{"tool_name":"Bash","tool_input":{{"command":"git push"}},"tool_result":{{"exit_code":1,"stdout":"","stderr":"remote: invalid credentials {pat}"}}}}"#,
        );
        let output = run_hook(LearnHookType::PostToolUse, &json).await;
        assert!(
            !output.contains(&pat),
            "GitHub PAT must not appear in stdout output; got: {output}"
        );
        assert!(
            output.contains("[GITHUB_TOKEN_REDACTED]"),
            "Redacted placeholder must appear in output; got: {output}"
        );
    }

    #[tokio::test]
    async fn test_process_hook_slack_token_is_redacted() {
        let slack_token = format!(
            "xoxb-{}-{}-{}",
            "FAKE_TEST_ID_A", "FAKE_TEST_ID_B", "FAKE_TEST_SECRET"
        );
        let json = format!(
            r#"{{"tool_name":"Bash","tool_input":{{"command":"curl -H 'Authorization: Bearer {slack_token}' https://slack.com/api/chat.postMessage"}},"tool_result":{{"exit_code":0,"stdout":"","stderr":""}}}}"#,
        );
        let output = run_hook(LearnHookType::PostToolUse, &json).await;
        assert!(
            !output.contains(&slack_token),
            "Slack token must not appear in stdout output; got: {output}"
        );
        assert!(
            output.contains("[SLACK_TOKEN_REDACTED]"),
            "Redacted placeholder must appear in output; got: {output}"
        );
    }

    #[tokio::test]
    async fn test_process_hook_connection_string_is_redacted() {
        let conn = "postgresql://dbuser:s3cr3tpassword@prod-db.internal:5432/appdb";
        let json = format!(
            r#"{{"tool_name":"Bash","tool_input":{{"command":"psql {conn}"}},"tool_result":{{"exit_code":1,"stdout":"","stderr":"connection refused"}}}}"#,
        );
        let output = run_hook(LearnHookType::PostToolUse, &json).await;
        assert!(
            !output.contains("s3cr3tpassword"),
            "Connection string password must not appear in stdout output; got: {output}"
        );
        assert!(
            output.contains("postgresql://[REDACTED]@"),
            "Redacted connection string must appear in output; got: {output}"
        );
    }
}