1use crate::constants::{HOOK_DEFAULT_LLM_TIMEOUT_SECS, HOOK_DEFAULT_TIMEOUT_SECS};
2use crate::storage::ChatMessage;
3use serde::{Deserialize, Serialize};
4use std::env;
5
/// Upper time bound for a hook chain, expressed in seconds (per the `_SECS` suffix).
///
/// NOTE(review): presumably caps the total wall-clock time spent running all hooks
/// registered for one event; confirm against the executor that reads it.
pub(crate) const MAX_CHAIN_DURATION_SECS: u64 = 30;
10
11#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
38#[serde(rename_all = "snake_case")]
39pub enum HookEvent {
40 PreSendMessage,
41 PostSendMessage,
42 PreLlmRequest,
43 PostLlmResponse,
44 PreToolExecution,
45 PostToolExecution,
46 PostToolExecutionFailure,
47 Stop,
48 PreMicroCompact,
49 PostMicroCompact,
50 PreAutoCompact,
51 PostAutoCompact,
52 SessionStart,
53 SessionEnd,
54}
55
56impl std::str::FromStr for HookEvent {
57 type Err = ();
58
59 fn from_str(s: &str) -> Result<Self, Self::Err> {
60 match s {
61 "pre_send_message" => Ok(HookEvent::PreSendMessage),
62 "post_send_message" => Ok(HookEvent::PostSendMessage),
63 "pre_llm_request" => Ok(HookEvent::PreLlmRequest),
64 "post_llm_response" => Ok(HookEvent::PostLlmResponse),
65 "pre_tool_execution" => Ok(HookEvent::PreToolExecution),
66 "post_tool_execution" => Ok(HookEvent::PostToolExecution),
67 "post_tool_execution_failure" => Ok(HookEvent::PostToolExecutionFailure),
68 "stop" => Ok(HookEvent::Stop),
69 "pre_micro_compact" => Ok(HookEvent::PreMicroCompact),
70 "post_micro_compact" => Ok(HookEvent::PostMicroCompact),
71 "pre_auto_compact" => Ok(HookEvent::PreAutoCompact),
72 "post_auto_compact" => Ok(HookEvent::PostAutoCompact),
73 "session_start" => Ok(HookEvent::SessionStart),
74 "session_end" => Ok(HookEvent::SessionEnd),
75 _ => Err(()),
76 }
77 }
78}
79
80impl HookEvent {
81 pub fn as_str(&self) -> &'static str {
83 match self {
84 HookEvent::PreSendMessage => "pre_send_message",
85 HookEvent::PostSendMessage => "post_send_message",
86 HookEvent::PreLlmRequest => "pre_llm_request",
87 HookEvent::PostLlmResponse => "post_llm_response",
88 HookEvent::PreToolExecution => "pre_tool_execution",
89 HookEvent::PostToolExecution => "post_tool_execution",
90 HookEvent::PostToolExecutionFailure => "post_tool_execution_failure",
91 HookEvent::Stop => "stop",
92 HookEvent::PreMicroCompact => "pre_micro_compact",
93 HookEvent::PostMicroCompact => "post_micro_compact",
94 HookEvent::PreAutoCompact => "pre_auto_compact",
95 HookEvent::PostAutoCompact => "post_auto_compact",
96 HookEvent::SessionStart => "session_start",
97 HookEvent::SessionEnd => "session_end",
98 }
99 }
100
101 pub fn all() -> &'static [HookEvent] {
103 &[
104 HookEvent::PreSendMessage,
105 HookEvent::PostSendMessage,
106 HookEvent::PreLlmRequest,
107 HookEvent::PostLlmResponse,
108 HookEvent::PreToolExecution,
109 HookEvent::PostToolExecution,
110 HookEvent::PostToolExecutionFailure,
111 HookEvent::Stop,
112 HookEvent::PreMicroCompact,
113 HookEvent::PostMicroCompact,
114 HookEvent::PreAutoCompact,
115 HookEvent::PostAutoCompact,
116 HookEvent::SessionStart,
117 HookEvent::SessionEnd,
118 ]
119 }
120
121 pub fn parse(s: &str) -> Option<HookEvent> {
123 s.parse().ok()
124 }
125}
126
127#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default, PartialEq)]
131#[serde(rename_all = "snake_case")]
132pub enum OnError {
133 #[default]
135 Skip,
136 Stop,
138}
139
140#[derive(Debug, Clone, Serialize, Deserialize, Default)]
147pub struct HookFilter {
148 #[serde(default, skip_serializing_if = "Option::is_none")]
150 pub tool_name: Option<String>,
151 #[serde(default, skip_serializing_if = "Option::is_none")]
154 pub tool_matcher: Option<String>,
155 #[serde(default, skip_serializing_if = "Option::is_none")]
157 pub model_prefix: Option<String>,
158}
159
160impl HookFilter {
161 pub fn is_empty(&self) -> bool {
163 self.tool_name.is_none() && self.tool_matcher.is_none() && self.model_prefix.is_none()
164 }
165
166 pub fn matches(&self, context: &HookContext) -> bool {
168 if let Some(ref expected_tool) = self.tool_name {
170 match &context.tool_name {
171 Some(actual) if actual == expected_tool => {}
172 Some(_) => return false,
173 None => return false,
174 }
175 } else if let Some(ref pattern) = self.tool_matcher {
176 let actual = match &context.tool_name {
178 Some(a) => a,
179 None => return false,
180 };
181 let matched = pattern.split('|').any(|p| p.trim() == actual);
182 if !matched {
183 return false;
184 }
185 }
186 if let Some(ref prefix) = self.model_prefix {
187 match &context.model {
188 Some(actual) if actual.starts_with(prefix.as_str()) => {}
189 Some(_) => return false,
190 None => return false,
191 }
192 }
193 true
194 }
195}
196
197#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize, Default)]
201#[serde(rename_all = "snake_case")]
202pub enum HookType {
203 #[default]
205 Bash,
206 Llm,
208}
209
210impl std::fmt::Display for HookType {
211 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
212 match self {
213 HookType::Bash => write!(f, "bash"),
214 HookType::Llm => write!(f, "llm"),
215 }
216 }
217}
218
219#[derive(Debug, Serialize)]
226pub struct HookContext {
227 pub event: HookEvent,
229
230 #[serde(skip_serializing_if = "Option::is_none")]
233 pub messages: Option<Vec<ChatMessage>>,
234
235 #[serde(skip_serializing_if = "Option::is_none")]
238 pub system_prompt: Option<String>,
239
240 #[serde(skip_serializing_if = "Option::is_none")]
243 pub model: Option<String>,
244
245 #[serde(skip_serializing_if = "Option::is_none")]
248 pub user_input: Option<String>,
249
250 #[serde(skip_serializing_if = "Option::is_none")]
253 pub assistant_output: Option<String>,
254
255 #[serde(skip_serializing_if = "Option::is_none")]
258 pub tool_name: Option<String>,
259
260 #[serde(skip_serializing_if = "Option::is_none")]
263 pub tool_arguments: Option<String>,
264
265 #[serde(skip_serializing_if = "Option::is_none")]
268 pub tool_result: Option<String>,
269
270 #[serde(skip_serializing_if = "Option::is_none")]
273 pub tool_error: Option<String>,
274
275 #[serde(skip_serializing_if = "Option::is_none")]
278 pub session_id: Option<String>,
279
280 pub cwd: String,
282}
283
284impl Default for HookContext {
285 fn default() -> Self {
286 Self {
287 event: HookEvent::SessionStart,
288 messages: None,
289 system_prompt: None,
290 model: None,
291 user_input: None,
292 assistant_output: None,
293 tool_name: None,
294 tool_arguments: None,
295 tool_result: None,
296 tool_error: None,
297 session_id: None,
298 cwd: env::current_dir()
299 .map(|p| p.display().to_string())
300 .unwrap_or_else(|_| ".".to_string()),
301 }
302 }
303}
304
305#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
313#[serde(rename_all = "snake_case")]
314pub enum HookAction {
315 Stop,
317 Skip,
319}
320
321#[derive(Debug, Deserialize, Default)]
323pub struct HookResult {
324 #[serde(default)]
326 pub messages: Option<Vec<ChatMessage>>,
327 #[serde(default)]
329 pub system_prompt: Option<String>,
330 #[serde(default)]
332 pub user_input: Option<String>,
333 #[serde(default)]
335 pub assistant_output: Option<String>,
336 #[serde(default)]
338 pub tool_arguments: Option<String>,
339 #[serde(default)]
341 pub tool_result: Option<String>,
342 #[serde(default)]
344 pub tool_error: Option<String>,
345 #[serde(default)]
347 pub inject_messages: Option<Vec<ChatMessage>>,
348 #[serde(default)]
350 pub retry_feedback: Option<String>,
351 #[serde(default)]
353 pub additional_context: Option<String>,
354 #[serde(default)]
356 pub system_message: Option<String>,
357 #[serde(default)]
359 pub action: Option<HookAction>,
360}
361
362impl HookResult {
363 pub fn is_stop(&self) -> bool {
365 self.action == Some(HookAction::Stop)
366 }
367
368 pub fn is_skip(&self) -> bool {
370 self.action == Some(HookAction::Skip)
371 }
372
373 pub fn is_halt(&self) -> bool {
375 self.is_stop() || self.is_skip()
376 }
377}
378
379#[derive(Debug)]
385#[allow(dead_code, clippy::large_enum_variant)]
386pub(crate) enum HookOutcome {
387 Success(HookResult),
388 Retry {
389 error: String,
390 #[allow(dead_code)]
391 attempts_left: u32,
392 },
393 Err(String),
394}
395
396pub(crate) fn default_timeout() -> u64 {
399 HOOK_DEFAULT_TIMEOUT_SECS
400}
401
402pub(crate) fn default_llm_timeout() -> u64 {
403 HOOK_DEFAULT_LLM_TIMEOUT_SECS
404}