use async_trait::async_trait;
use std::collections::HashMap;
use std::sync::Arc;
use uira_memory::MemorySystem;
use super::types::{HookEvent, HookInput, HookOutput};
/// Context passed to [`Hook::execute`] alongside the event and its input.
///
/// Cloning shares the `memory_system` instance (it is an `Arc`), but copies
/// `session_id`, `directory` and the entire `data` map.
#[derive(Clone)]
pub struct HookContext {
/// Session identifier, when one is available.
pub session_id: Option<String>,
/// Directory path recorded for this context.
// NOTE(review): presumably the session's working directory — confirm against callers.
pub directory: String,
/// Optional shared handle to the memory system.
pub memory_system: Option<Arc<MemorySystem>>,
/// Arbitrary extra key/value data, e.g. added via `HookContext::with_data`.
pub data: HashMap<String, serde_json::Value>,
}
impl std::fmt::Debug for HookContext {
    /// Formats the context for diagnostics.
    ///
    /// The memory system handle itself is not formatted; only whether one is
    /// present is reported, under the `has_memory_system` field.
    // NOTE(review): presumably because `MemorySystem` does not implement `Debug` — confirm.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let has_memory_system = self.memory_system.is_some();

        let mut builder = f.debug_struct("HookContext");
        builder.field("session_id", &self.session_id);
        builder.field("directory", &self.directory);
        builder.field("has_memory_system", &has_memory_system);
        builder.field("data", &self.data);
        builder.finish()
    }
}
impl HookContext {
    /// Creates a context with an empty `data` map.
    ///
    /// # Arguments
    /// * `session_id` - session identifier, if one is available.
    /// * `directory` - directory path to record in the context.
    /// * `memory_system` - optional shared handle to the memory system.
    #[must_use]
    pub fn new(
        session_id: Option<String>,
        directory: String,
        memory_system: Option<Arc<MemorySystem>>,
    ) -> Self {
        Self {
            session_id,
            directory,
            memory_system,
            data: HashMap::new(),
        }
    }

    /// Returns this context with `value` stored under `key` in `data`.
    ///
    /// Builder-style: consumes `self` and hands it back so calls can be
    /// chained. If `key` is already present, its previous value is replaced.
    // `#[must_use]` because `self` is moved in: ignoring the return value
    // silently discards both the context and the inserted entry.
    #[must_use = "`with_data` consumes the context and returns the updated one"]
    pub fn with_data(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.data.insert(key.into(), value);
        self
    }
}
pub type HookResult = anyhow::Result<HookOutput>;
/// Behavior that runs in response to one or more [`HookEvent`]s.
///
/// Implementors must be `Send + Sync`. The trait is annotated with
/// `#[async_trait]`, which is what allows `execute` to be declared as an
/// `async fn` here.
#[async_trait]
pub trait Hook: Send + Sync {
/// Name of this hook.
// NOTE(review): presumably used by the caller for identification/logging — confirm.
fn name(&self) -> &str;
/// Events this hook declares interest in.
// NOTE(review): presumably the dispatcher only calls `execute` for these events — confirm.
fn events(&self) -> &[HookEvent];
/// Runs the hook for `event`.
///
/// # Arguments
/// * `event` - the event being handled.
/// * `input` - the hook input for this invocation.
/// * `context` - session id, directory, optional memory system and extra data.
///
/// # Errors
/// Returns `Err` (an `anyhow::Error`, via [`HookResult`]) if the hook fails.
async fn execute(
&self,
event: HookEvent,
input: &HookInput,
context: &HookContext,
) -> HookResult;
/// Priority of this hook; defaults to `0`.
// NOTE(review): whether higher or lower values run first is decided by the caller — confirm.
fn priority(&self) -> i32 {
0
}
/// Whether this hook is enabled; defaults to `true`.
fn is_enabled(&self) -> bool {
true
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal hook subscribed to `UserPromptSubmit` that always passes.
    struct TestHook;

    #[async_trait]
    impl Hook for TestHook {
        fn name(&self) -> &str {
            "test-hook"
        }

        fn events(&self) -> &[HookEvent] {
            &[HookEvent::UserPromptSubmit]
        }

        async fn execute(
            &self,
            _event: HookEvent,
            _input: &HookInput,
            _context: &HookContext,
        ) -> HookResult {
            Ok(HookOutput::pass())
        }
    }

    #[tokio::test]
    async fn test_hook_execution() {
        let hook = TestHook;
        let input = HookInput {
            session_id: Some("test-session".to_string()),
            prompt: Some("test prompt".to_string()),
            message: None,
            parts: None,
            tool_name: None,
            tool_input: None,
            tool_output: None,
            directory: None,
            stop_reason: None,
            user_requested: None,
            transcript_path: None,
            extra: HashMap::new(),
        };
        let context =
            HookContext::new(Some("test-session".to_string()), "/tmp".to_string(), None);

        // `expect` reports the underlying error on failure, which a bare
        // `assert!(result.is_ok())` followed by `unwrap()` would hide.
        let output = hook
            .execute(HookEvent::UserPromptSubmit, &input, &context)
            .await
            .expect("TestHook::execute should succeed");
        assert!(output.should_continue);
    }

    #[test]
    fn test_hook_metadata_and_defaults() {
        // Calling through `&dyn Hook` also checks that the trait stays object-safe.
        let hook: &dyn Hook = &TestHook;
        assert_eq!(hook.name(), "test-hook");
        assert!(matches!(hook.events(), [HookEvent::UserPromptSubmit]));
        assert_eq!(hook.priority(), 0);
        assert!(hook.is_enabled());
    }

    #[test]
    fn test_hook_context_with_data() {
        let context = HookContext::new(Some("session".to_string()), "/tmp".to_string(), None)
            .with_data("key", serde_json::json!("value"));
        assert_eq!(context.data.get("key"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_hook_context_with_data_overwrites_existing_key() {
        let context = HookContext::new(None, "/tmp".to_string(), None)
            .with_data("key", serde_json::json!(1))
            .with_data("key", serde_json::json!(2));
        assert_eq!(context.data.len(), 1);
        assert_eq!(context.data.get("key"), Some(&serde_json::json!(2)));
    }
}