use super::*;
use std::collections::HashSet;
/// Hook that reports execution and tool lifecycle events through `tracing`.
///
/// Each event category can be switched off individually with the `without_*`
/// builder methods; routine messages are emitted at [`LogLevel`] `log_level`.
#[derive(Debug, Clone)]
pub struct LoggingHook {
    /// Log a line before each task runs.
    pub log_pre_execution: bool,
    /// Log a line after each task finishes.
    pub log_post_execution: bool,
    /// Log a line before and after every tool invocation.
    pub log_tool_events: bool,
    /// Level at which this hook's routine messages are emitted.
    pub log_level: LogLevel,
}

/// Verbosity at which [`LoggingHook`] emits its routine messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LogLevel {
    /// Emit via `tracing::debug!`.
    Debug,
    /// Emit via `tracing::info!` (what [`LoggingHook::new`] uses).
    Info,
    /// Emit via `tracing::warn!`.
    Warn,
}

impl Default for LoggingHook {
    fn default() -> Self {
        Self::new()
    }
}

impl LoggingHook {
    /// Creates a hook with every event category enabled, logging at `Info`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            log_pre_execution: true,
            log_post_execution: true,
            log_tool_events: true,
            log_level: LogLevel::Info,
        }
    }

    /// Returns the hook configured to log at `level`.
    #[must_use]
    pub fn with_level(mut self, level: LogLevel) -> Self {
        self.log_level = level;
        self
    }

    /// Returns the hook with pre-execution logging disabled.
    #[must_use]
    pub fn without_pre_execution(mut self) -> Self {
        self.log_pre_execution = false;
        self
    }

    /// Returns the hook with post-execution logging disabled.
    #[must_use]
    pub fn without_post_execution(mut self) -> Self {
        self.log_post_execution = false;
        self
    }

    /// Returns the hook with tool-use logging disabled.
    #[must_use]
    pub fn without_tool_events(mut self) -> Self {
        self.log_tool_events = false;
        self
    }

    /// Emits `message` at the configured `log_level`.
    fn log(&self, message: &str) {
        match self.log_level {
            LogLevel::Debug => tracing::debug!("{}", message),
            LogLevel::Info => tracing::info!("{}", message),
            LogLevel::Warn => tracing::warn!("{}", message),
        }
    }
}
#[async_trait]
impl ExecutionHooks for LoggingHook {
    /// Records the task that is about to run when `log_pre_execution` is set.
    /// Never blocks execution.
    async fn pre_execution(&self, input: PreExecutionInput, ctx: HookContext) -> HookResult {
        if !self.log_pre_execution {
            return HookResult::Continue;
        }
        let line = format!(
            "[{}] Pre-execution: task='{}' type='{}' priority='{}'",
            ctx.agent_id, input.task_description, input.task_type, input.priority
        );
        self.log(&line);
        HookResult::Continue
    }

    /// Records the outcome and reported duration of a finished task when
    /// `log_post_execution` is set. Never blocks execution.
    async fn post_execution(&self, input: PostExecutionInput, ctx: HookContext) -> HookResult {
        if !self.log_post_execution {
            return HookResult::Continue;
        }
        let line = format!(
            "[{}] Post-execution: task='{}' success={} duration={}ms",
            ctx.agent_id, input.task_description, input.success, input.duration_ms
        );
        self.log(&line);
        HookResult::Continue
    }

    /// Emits every error at `error` level with structured fields, independent of
    /// `log_level` and the enable flags. Never alters control flow.
    async fn on_error(&self, input: OnErrorInput, ctx: HookContext) -> HookResult {
        tracing::error!(
            agent_id = %ctx.agent_id,
            error_type = %input.error_type,
            recoverable = input.is_recoverable,
            "Error occurred: {}",
            input.error_message
        );
        HookResult::Continue
    }

    fn name(&self) -> &str {
        "logging"
    }

    /// NOTE(review): presumably lower values run later, so logging sees the
    /// effect of higher-priority hooks — confirm against the hook registry.
    fn priority(&self) -> i32 {
        -100
    }
}
#[async_trait]
impl ToolHooks for LoggingHook {
    /// Logs the tool name and its JSON-serialised arguments before the call when
    /// `log_tool_events` is set. Never blocks the call.
    ///
    /// The serialised arguments are capped at 512 characters because argument
    /// payloads are unbounded (a write-style tool presumably carries whole file
    /// contents) and would otherwise be dumped into a single log line.
    async fn pre_tool_use(&self, input: PreToolUseInput, ctx: HookContext) -> HookResult {
        if self.log_tool_events {
            const MAX_ARG_CHARS: usize = 512;
            let mut args = serde_json::to_string(&input.arguments).unwrap_or_default();
            // `char_indices` keeps the cut on a UTF-8 boundary so `truncate` can't panic.
            let cut = args.char_indices().nth(MAX_ARG_CHARS).map(|(idx, _)| idx);
            if let Some(cut) = cut {
                let total = args.chars().count();
                args.truncate(cut);
                args.push_str(&format!("...<truncated, {} chars total>", total));
            }
            self.log(&format!(
                "[{}] Pre-tool-use: tool='{}' args={}",
                ctx.agent_id, input.tool_name, args
            ));
        }
        HookResult::Continue
    }

    /// Logs the tool name, outcome and reported duration after the call when
    /// `log_tool_events` is set. Never blocks.
    async fn post_tool_use(&self, input: PostToolUseInput, ctx: HookContext) -> HookResult {
        if self.log_tool_events {
            self.log(&format!(
                "[{}] Post-tool-use: tool='{}' success={} duration={}ms",
                ctx.agent_id, input.tool_name, input.success, input.duration_ms
            ));
        }
        HookResult::Continue
    }

    fn name(&self) -> &str {
        "logging"
    }

    fn priority(&self) -> i32 {
        -100
    }
}
/// Guard-rail hook that denies tool calls touching protected files or running
/// blocked shell commands.
///
/// Out of the box it protects `.env`, `*.key`, `*.pem`, `credentials.json`,
/// `secrets.yaml` and `.git/config`, and blocks `rm -rf /`, `rm -rf /*` and the
/// `:(){:|:&};:` fork bomb. It also rejects destructive git operations unless
/// [`SecurityHook::allow_destructive_git`] is called.
#[derive(Debug, Clone)]
pub struct SecurityHook {
    /// Path patterns. A leading `*` means "path ends with the rest"; otherwise
    /// the pattern matches anywhere in the path.
    protected_patterns: HashSet<String>,
    /// Command fragments that cause a shell command to be denied when they
    /// appear anywhere in it.
    blocked_commands: HashSet<String>,
    /// When `false`, force-push / hard-reset / forced-clean git commands are denied.
    allow_destructive_git: bool,
}

impl Default for SecurityHook {
    fn default() -> Self {
        Self::new()
    }
}

impl SecurityHook {
    /// Git fragments rejected unless `allow_destructive_git` is set; matched
    /// against the whitespace-normalised command line.
    const DESTRUCTIVE_GIT: [&'static str; 4] = [
        "git push --force",
        "git push -f",
        "git reset --hard",
        "git clean -f",
    ];

    /// Creates a hook with the default protected patterns and blocked commands.
    #[must_use]
    pub fn new() -> Self {
        let protected_patterns = [
            ".env",
            "*.key",
            "*.pem",
            "credentials.json",
            "secrets.yaml",
            ".git/config",
        ]
        .iter()
        .map(|s| s.to_string())
        .collect();
        // NOTE(review): matching is by substring, so `rm -rf /` also denies e.g.
        // `rm -rf /tmp/build` — confirm this breadth is intended.
        let blocked_commands = ["rm -rf /", "rm -rf /*", ":(){:|:&};:"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        Self {
            protected_patterns,
            blocked_commands,
            allow_destructive_git: false,
        }
    }

    /// Adds a protected path pattern (`*suffix` or a plain substring).
    #[must_use]
    pub fn protect_pattern(mut self, pattern: impl Into<String>) -> Self {
        self.protected_patterns.insert(pattern.into());
        self
    }

    /// Adds a command fragment whose presence denies a shell command.
    #[must_use]
    pub fn block_command(mut self, command: impl Into<String>) -> Self {
        self.blocked_commands.insert(command.into());
        self
    }

    /// Permits force-push, hard-reset and forced-clean git commands.
    #[must_use]
    pub fn allow_destructive_git(mut self) -> Self {
        self.allow_destructive_git = true;
        self
    }

    /// Maps `\` to `/` and lower-cases ASCII letters. This is a per-character
    /// mapping, so any containment or suffix relation that held before still holds.
    fn normalize_path(s: &str) -> String {
        s.replace('\\', "/").to_ascii_lowercase()
    }

    /// Collapses every whitespace run into a single space and trims the ends.
    fn normalize_whitespace(s: &str) -> String {
        s.split_whitespace().collect::<Vec<_>>().join(" ")
    }

    /// Returns `true` if `path` matches any protected pattern.
    ///
    /// Path and pattern are compared after [`Self::normalize_path`], so Windows
    /// separators and case variants (which reach the same file on case-insensitive
    /// filesystems) cannot slip past a pattern.
    fn is_protected(&self, path: &str) -> bool {
        let path = Self::normalize_path(path);
        self.protected_patterns.iter().any(|pattern| {
            let pattern = Self::normalize_path(pattern);
            match pattern.strip_prefix('*') {
                Some(suffix) => path.ends_with(suffix),
                None => path.contains(pattern.as_str()),
            }
        })
    }

    /// Returns why `command` must be rejected, or `None` if it may run.
    ///
    /// A configured pattern matches if it occurs in the raw command, or (when
    /// non-blank) in the whitespace-normalised command. When several patterns
    /// match, the longest one is reported, with ties broken lexicographically.
    /// This keeps the message stable despite `HashSet`'s unspecified order.
    /// Unless destructive git is allowed, the fragments in `DESTRUCTIVE_GIT` are
    /// then checked, and a match yields `"destructive git operation"`.
    fn is_blocked_command(&self, command: &str) -> Option<&str> {
        let normalized = Self::normalize_whitespace(command);

        let blocked = self
            .blocked_commands
            .iter()
            .filter(|pattern| {
                let pattern_norm = Self::normalize_whitespace(pattern);
                command.contains(pattern.as_str())
                    || (!pattern_norm.is_empty() && normalized.contains(pattern_norm.as_str()))
            })
            .max_by(|a, b| a.len().cmp(&b.len()).then_with(|| b.cmp(a)));
        if let Some(pattern) = blocked {
            return Some(pattern.as_str());
        }

        if !self.allow_destructive_git
            && Self::DESTRUCTIVE_GIT.iter().any(|op| normalized.contains(*op))
        {
            return Some("destructive git operation");
        }
        None
    }
}
#[async_trait]
impl ExecutionHooks for SecurityHook {
    /// Denies a task whose description mentions a bulk-destructive phrase
    /// (case-insensitively); every other task continues.
    ///
    /// NOTE(review): this is plain substring matching, so benign requests such as
    /// "remove all unused imports" are denied too — confirm that is intended.
    async fn pre_execution(&self, input: PreExecutionInput, _ctx: HookContext) -> HookResult {
        const DESTRUCTIVE_PHRASES: [&str; 3] = ["delete all", "remove all", "drop database"];

        let description = input.task_description.to_lowercase();
        let destructive = DESTRUCTIVE_PHRASES
            .iter()
            .any(|&phrase| description.contains(phrase));

        if destructive {
            HookResult::Deny {
                reason: "Task description contains potentially destructive operation".to_string(),
            }
        } else {
            HookResult::Continue
        }
    }

    /// No post-execution checks; always continues.
    async fn post_execution(&self, _input: PostExecutionInput, _ctx: HookContext) -> HookResult {
        HookResult::Continue
    }

    /// Errors are not inspected; always continues.
    async fn on_error(&self, _input: OnErrorInput, _ctx: HookContext) -> HookResult {
        HookResult::Continue
    }

    fn name(&self) -> &str {
        "security"
    }

    /// NOTE(review): presumably higher values run earlier, so security vetoes
    /// before other hooks act — confirm against the hook registry.
    fn priority(&self) -> i32 {
        100
    }
}
#[async_trait]
impl ToolHooks for SecurityHook {
    /// Vets a tool call before it runs.
    ///
    /// `Write`/`Edit` calls are denied when their `file_path` argument is a
    /// protected path. `Bash` calls are denied when their `command` argument
    /// contains a blocked command. Any other tool, or a missing or non-string
    /// argument, continues.
    async fn pre_tool_use(&self, input: PreToolUseInput, _ctx: HookContext) -> HookResult {
        let denial = match input.tool_name.as_str() {
            "Write" | "Edit" => input
                .arguments
                .get("file_path")
                .and_then(|v| v.as_str())
                .filter(|path| self.is_protected(path))
                .map(|path| format!("Cannot modify protected file: {}", path)),
            "Bash" => input
                .arguments
                .get("command")
                .and_then(|v| v.as_str())
                .and_then(|command| self.is_blocked_command(command))
                .map(|blocked| format!("Blocked command detected: {}", blocked)),
            _ => None,
        };

        match denial {
            Some(reason) => HookResult::Deny { reason },
            None => HookResult::Continue,
        }
    }

    /// No post-tool checks; always continues.
    async fn post_tool_use(&self, _input: PostToolUseInput, _ctx: HookContext) -> HookResult {
        HookResult::Continue
    }

    fn name(&self) -> &str {
        "security"
    }

    fn priority(&self) -> i32 {
        100
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a tool name and its JSON arguments into a `PreToolUseInput` with
    /// no description.
    fn tool_call(tool_name: &str, arguments: serde_json::Value) -> PreToolUseInput {
        PreToolUseInput {
            tool_name: tool_name.to_string(),
            arguments,
            description: None,
        }
    }

    #[tokio::test]
    async fn test_security_hook_protected_files() {
        let hook = SecurityHook::new();
        let ctx = HookContext::new("test-agent");

        // Protected targets are rejected whichever editing tool is used.
        let write_env = tool_call("Write", serde_json::json!({ "file_path": "/project/.env" }));
        assert!(hook.pre_tool_use(write_env, ctx.clone()).await.is_denied());

        let edit_key = tool_call(
            "Edit",
            serde_json::json!({ "file_path": "/home/user/private.key" }),
        );
        assert!(hook.pre_tool_use(edit_key, ctx.clone()).await.is_denied());

        // Ordinary source files pass through.
        let write_src = tool_call(
            "Write",
            serde_json::json!({ "file_path": "/project/src/main.rs" }),
        );
        assert!(hook.pre_tool_use(write_src, ctx).await.should_continue());
    }

    #[tokio::test]
    async fn test_security_hook_destructive_git() {
        let hook = SecurityHook::new();
        let force_push = tool_call(
            "Bash",
            serde_json::json!({ "command": "git push --force origin main" }),
        );
        let verdict = hook
            .pre_tool_use(force_push, HookContext::new("test-agent"))
            .await;
        assert!(verdict.is_denied());
    }

    #[tokio::test]
    async fn test_security_hook_allowed_with_flag() {
        let hook = SecurityHook::new().allow_destructive_git();
        let force_push = tool_call(
            "Bash",
            serde_json::json!({ "command": "git push --force origin main" }),
        );
        let verdict = hook
            .pre_tool_use(force_push, HookContext::new("test-agent"))
            .await;
        assert!(verdict.should_continue());
    }
}