use crate::constants::{HOOK_DEFAULT_LLM_TIMEOUT_SECS, HOOK_DEFAULT_TIMEOUT_SECS};
use crate::storage::ChatMessage;
use serde::{Deserialize, Serialize};
/// Maximum duration allowed for a chain of hooks; the `_SECS` suffix marks the unit as seconds.
/// NOTE(review): the code that enforces this limit is not in this file — presumably it bounds
/// the total time spent running every hook for one event; confirm at the call site.
pub(crate) const MAX_CHAIN_DURATION_SECS: u64 = 30;
/// The kind of event a [`HookContext`] describes (its `event` field).
///
/// Because of `rename_all = "snake_case"`, serde (de)serializes each variant under its
/// snake_case name, e.g. `PreToolExecution` <-> `"pre_tool_execution"`. The same names are
/// returned by `HookEvent::as_str` and accepted by the `FromStr` impl / `HookEvent::parse`.
///
/// NOTE(review): the compiler checks `as_str` for exhaustiveness, but the list in `all()`
/// (and any hand-written name table) must be kept in sync manually when adding a variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookEvent {
PreSendMessage,
PostSendMessage,
PreLlmRequest,
PostLlmResponse,
PreToolExecution,
PostToolExecution,
PostToolExecutionFailure,
Stop,
PreMicroCompact,
PostMicroCompact,
PreAutoCompact,
PostAutoCompact,
SessionStart,
SessionEnd,
}
impl std::str::FromStr for HookEvent {
type Err = ();
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"pre_send_message" => Ok(HookEvent::PreSendMessage),
"post_send_message" => Ok(HookEvent::PostSendMessage),
"pre_llm_request" => Ok(HookEvent::PreLlmRequest),
"post_llm_response" => Ok(HookEvent::PostLlmResponse),
"pre_tool_execution" => Ok(HookEvent::PreToolExecution),
"post_tool_execution" => Ok(HookEvent::PostToolExecution),
"post_tool_execution_failure" => Ok(HookEvent::PostToolExecutionFailure),
"stop" => Ok(HookEvent::Stop),
"pre_micro_compact" => Ok(HookEvent::PreMicroCompact),
"post_micro_compact" => Ok(HookEvent::PostMicroCompact),
"pre_auto_compact" => Ok(HookEvent::PreAutoCompact),
"post_auto_compact" => Ok(HookEvent::PostAutoCompact),
"session_start" => Ok(HookEvent::SessionStart),
"session_end" => Ok(HookEvent::SessionEnd),
_ => Err(()),
}
}
}
impl HookEvent {
    /// Returns the canonical snake_case name of this event.
    ///
    /// These names match the serde representation (`rename_all = "snake_case"`) and are
    /// the strings accepted by [`HookEvent::parse`].
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::PreSendMessage => "pre_send_message",
            Self::PostSendMessage => "post_send_message",
            Self::PreLlmRequest => "pre_llm_request",
            Self::PostLlmResponse => "post_llm_response",
            Self::PreToolExecution => "pre_tool_execution",
            Self::PostToolExecution => "post_tool_execution",
            Self::PostToolExecutionFailure => "post_tool_execution_failure",
            Self::Stop => "stop",
            Self::PreMicroCompact => "pre_micro_compact",
            Self::PostMicroCompact => "post_micro_compact",
            Self::PreAutoCompact => "pre_auto_compact",
            Self::PostAutoCompact => "post_auto_compact",
            Self::SessionStart => "session_start",
            Self::SessionEnd => "session_end",
        }
    }

    /// Lists every event, in declaration order.
    pub fn all() -> &'static [HookEvent] {
        // A named `static` makes the `'static` lifetime of the returned slice explicit.
        static EVERY_EVENT: [HookEvent; 14] = [
            HookEvent::PreSendMessage,
            HookEvent::PostSendMessage,
            HookEvent::PreLlmRequest,
            HookEvent::PostLlmResponse,
            HookEvent::PreToolExecution,
            HookEvent::PostToolExecution,
            HookEvent::PostToolExecutionFailure,
            HookEvent::Stop,
            HookEvent::PreMicroCompact,
            HookEvent::PostMicroCompact,
            HookEvent::PreAutoCompact,
            HookEvent::PostAutoCompact,
            HookEvent::SessionStart,
            HookEvent::SessionEnd,
        ];
        &EVERY_EVENT
    }

    /// Parses a snake_case event name, returning `None` when it is not recognised.
    ///
    /// Convenience wrapper over the `FromStr` implementation that discards its `()` error.
    pub fn parse(s: &str) -> Option<HookEvent> {
        s.parse::<Self>().ok()
    }
}
/// Error-handling policy, judging by the name; `Skip` is the `Default` value.
///
/// (De)serialized from the snake_case strings `"skip"` and `"stop"`.
/// NOTE(review): the code that acts on this value is outside this file — confirm the exact
/// semantics of `Skip` vs `Stop` there.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq)]
#[serde(rename_all = "snake_case")]
pub enum OnError {
#[default]
Skip,
Stop,
}
/// Optional criteria restricting which [`HookContext`]s a hook applies to (see `matches`).
///
/// Every field may be absent on input (`serde(default)`) and is omitted from serialized
/// output when `None`. A filter with no field set accepts every context (`is_empty`).
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct HookFilter {
/// Exact tool name the context must carry. When set, `tool_matcher` is ignored.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>,
/// `|`-separated alternatives; each is trimmed and compared exactly with the context's tool name.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub tool_matcher: Option<String>,
/// Prefix that the context's model name must start with.
#[serde(default, skip_serializing_if = "Option::is_none")]
pub model_prefix: Option<String>,
}
impl HookFilter {
pub fn is_empty(&self) -> bool {
self.tool_name.is_none() && self.tool_matcher.is_none() && self.model_prefix.is_none()
}
pub fn matches(&self, context: &HookContext) -> bool {
if let Some(ref expected_tool) = self.tool_name {
match &context.tool_name {
Some(actual) if actual == expected_tool => {}
Some(_) => return false,
None => return false,
}
} else if let Some(ref pattern) = self.tool_matcher {
let actual = match &context.tool_name {
Some(a) => a,
None => return false,
};
let matched = pattern.split('|').any(|p| p.trim() == actual);
if !matched {
return false;
}
}
if let Some(ref prefix) = self.model_prefix {
match &context.model {
Some(actual) if actual.starts_with(prefix.as_str()) => {}
Some(_) => return false,
None => return false,
}
}
true
}
}
/// Implementation kind of a hook: `Bash` (the `Default` value) or `Llm`.
///
/// (De)serialized as the snake_case strings `"bash"` / `"llm"`; `Display` prints the same names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum HookType {
#[default]
Bash,
Llm,
}
impl std::fmt::Display for HookType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
HookType::Bash => write!(f, "bash"),
HookType::Llm => write!(f, "llm"),
}
}
}
/// State describing one hook invocation; it is also the input to `HookFilter::matches`.
///
/// Only `Serialize` is derived. Every `Option` field is skipped when `None`, so the
/// serialized form contains only the fields that were set, plus `event` and `cwd`, which
/// are always present.
/// NOTE(review): the serialization format and how the hook receives it are chosen outside
/// this file — presumably JSON handed to the hook process; confirm.
#[derive(Debug, Serialize)]
pub struct HookContext {
pub event: HookEvent,
#[serde(skip_serializing_if = "Option::is_none")]
pub messages: Option<Vec<ChatMessage>>,
#[serde(skip_serializing_if = "Option::is_none")]
pub system_prompt: Option<String>,
/// Model name; `HookFilter::model_prefix` is matched against it.
#[serde(skip_serializing_if = "Option::is_none")]
pub model: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub user_input: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub assistant_output: Option<String>,
/// Tool name; `HookFilter::tool_name` / `tool_matcher` are matched against it.
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_name: Option<String>,
/// Tool arguments as one string (NOTE(review): presumably JSON-encoded — confirm).
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_arguments: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_result: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub tool_error: Option<String>,
#[serde(skip_serializing_if = "Option::is_none")]
pub session_id: Option<String>,
/// Working directory, always serialized; `Default` fills it from the process's current directory.
pub cwd: String,
}
impl Default for HookContext {
    /// Builds an otherwise-empty context for [`HookEvent::SessionStart`].
    ///
    /// Every optional field is `None`. `cwd` is the process's current working directory,
    /// rendered with `Path::display`, or `"."` if the directory cannot be read.
    fn default() -> Self {
        let cwd = match std::env::current_dir() {
            Ok(dir) => dir.display().to_string(),
            Err(_) => String::from("."),
        };

        Self {
            event: HookEvent::SessionStart,
            cwd,
            session_id: None,
            messages: None,
            system_prompt: None,
            model: None,
            user_input: None,
            assistant_output: None,
            tool_name: None,
            tool_arguments: None,
            tool_result: None,
            tool_error: None,
        }
    }
}
/// Directive a hook may return in `HookResult::action`; (de)serialized as `"stop"` / `"skip"`.
/// NOTE(review): how the caller reacts to each directive is decided outside this file.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum HookAction {
Stop,
Skip,
}
/// Response parsed from a hook (only `Deserialize` is derived, plus `Debug` and `Default`).
///
/// Every field is an `Option` marked `#[serde(default)]`, so a hook may supply any subset of
/// fields and the missing ones become `None`. `action` is inspected by `is_stop` / `is_skip` /
/// `is_halt`.
/// NOTE(review): the fields mirroring `HookContext` (`messages`, `system_prompt`, `user_input`,
/// `assistant_output`, `tool_arguments`, `tool_result`, `tool_error`) presumably replace the
/// corresponding context values, while `inject_messages`, `retry_feedback`,
/// `additional_context` and `system_message` add new content. This is inferred from the names
/// only; confirm against the code that applies a `HookResult`.
#[derive(Debug, Deserialize, Default)]
pub struct HookResult {
#[serde(default)]
pub messages: Option<Vec<ChatMessage>>,
#[serde(default)]
pub system_prompt: Option<String>,
#[serde(default)]
pub user_input: Option<String>,
#[serde(default)]
pub assistant_output: Option<String>,
#[serde(default)]
pub tool_arguments: Option<String>,
#[serde(default)]
pub tool_result: Option<String>,
#[serde(default)]
pub tool_error: Option<String>,
#[serde(default)]
pub inject_messages: Option<Vec<ChatMessage>>,
#[serde(default)]
pub retry_feedback: Option<String>,
#[serde(default)]
pub additional_context: Option<String>,
#[serde(default)]
pub system_message: Option<String>,
#[serde(default)]
pub action: Option<HookAction>,
}
impl HookResult {
    /// Returns `true` when the hook's `action` is `Some(HookAction::Stop)`.
    pub fn is_stop(&self) -> bool {
        matches!(self.action, Some(HookAction::Stop))
    }

    /// Returns `true` when the hook's `action` is `Some(HookAction::Skip)`.
    pub fn is_skip(&self) -> bool {
        matches!(self.action, Some(HookAction::Skip))
    }

    /// Returns `true` when the hook requested either `Stop` or `Skip`.
    pub fn is_halt(&self) -> bool {
        matches!(self.action, Some(HookAction::Stop | HookAction::Skip))
    }
}
/// Outcome of running a hook.
///
/// * `Success`: the parsed [`HookResult`].
/// * `Retry`: an error message plus `attempts_left`. Presumably the caller may re-run the
///   hook (inferred from the names — confirm).
/// * `Err`: an error message.
///
/// `dead_code` is allowed because some variants or fields are apparently not constructed or
/// read elsewhere (NOTE(review): confirm it is still needed). `clippy::large_enum_variant` is
/// allowed rather than boxing the comparatively large `HookResult` inside `Success`.
/// NOTE(review): the field-level `#[allow(dead_code)]` on `attempts_left` is redundant, since
/// the enum-level `allow` already applies to everything nested inside the enum.
#[derive(Debug)]
#[allow(dead_code, clippy::large_enum_variant)]
pub(crate) enum HookOutcome {
Success(HookResult),
Retry {
error: String,
#[allow(dead_code)]
attempts_left: u32,
},
Err(String),
}
/// Default hook timeout: returns `HOOK_DEFAULT_TIMEOUT_SECS` (seconds, per the constant's name).
/// NOTE(review): presumably a function rather than a bare constant so it can be named in a
/// `#[serde(default = "...")]` attribute elsewhere — confirm.
pub(crate) fn default_timeout() -> u64 {
HOOK_DEFAULT_TIMEOUT_SECS
}
/// Default LLM-hook timeout: returns `HOOK_DEFAULT_LLM_TIMEOUT_SECS` (seconds, per the name).
/// NOTE(review): presumably used for `HookType::Llm` hooks and referenced from a
/// `#[serde(default = "...")]` attribute elsewhere — confirm.
pub(crate) fn default_llm_timeout() -> u64 {
HOOK_DEFAULT_LLM_TIMEOUT_SECS
}