1use crate::types::{SessionId, StopReason, Usage};
4use async_trait::async_trait;
5use chrono::{DateTime, Utc};
6use schemars::JsonSchema;
7use serde::{Deserialize, Serialize};
8use serde_json::Value;
9
10#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
12#[serde(transparent)]
13pub struct HookId(pub String);
14
15impl HookId {
16 pub fn new(id: impl Into<String>) -> Self {
17 Self(id.into())
18 }
19}
20
21impl std::fmt::Display for HookId {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 self.0.fmt(f)
24 }
25}
26
27impl From<&str> for HookId {
28 fn from(value: &str) -> Self {
29 Self::new(value)
30 }
31}
32
33impl From<String> for HookId {
34 fn from(value: String) -> Self {
35 Self::new(value)
36 }
37}
38
39#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
41#[serde(rename_all = "snake_case")]
42pub enum HookPoint {
43 RunStarted,
44 RunCompleted,
45 RunFailed,
46 PreLlmRequest,
47 PostLlmResponse,
48 PreToolExecution,
49 PostToolExecution,
50 TurnBoundary,
51}
52
53impl HookPoint {
54 pub fn is_pre(self) -> bool {
55 matches!(
56 self,
57 Self::RunStarted | Self::PreLlmRequest | Self::PreToolExecution | Self::TurnBoundary
58 )
59 }
60
61 pub fn is_post(self) -> bool {
62 matches!(
63 self,
64 Self::PostLlmResponse | Self::PostToolExecution | Self::RunCompleted | Self::RunFailed
65 )
66 }
67}
68
/// Whether a hook runs in the foreground or the background.
///
/// Serialized as `"foreground"` / `"background"`.
/// NOTE(review): presumably background hooks do not block the run; the actual
/// semantics are implemented by the engine, not here — confirm.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum HookExecutionMode {
    Foreground,
    Background,
}
76
/// What a hook is allowed to do.
///
/// Serialized in `snake_case`. The default failure policy depends on the
/// capability (see `default_failure_policy`): `Observe` defaults to
/// `FailOpen`, while `Guardrail` and `Rewrite` default to `FailClosed`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum HookCapability {
    /// NOTE(review): presumably read-only observation — confirm.
    Observe,
    /// NOTE(review): presumably may deny via `HookDecision::Deny` — confirm.
    Guardrail,
    /// NOTE(review): presumably may return `HookPatch`es — confirm.
    Rewrite,
}
85
/// Policy for handling a failing hook.
///
/// Serialized as `"fail_open"` / `"fail_closed"`.
/// NOTE(review): presumably `FailOpen` lets the run continue when the hook
/// errors, and `FailClosed` treats the failure as blocking — confirm in the engine.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum HookFailurePolicy {
    FailOpen,
    FailClosed,
}
93
/// Machine-readable reason category, carried by `HookDecision::Deny`.
///
/// Serialized in `snake_case` (e.g. `"policy_violation"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
#[serde(rename_all = "snake_case")]
pub enum HookReasonCode {
    PolicyViolation,
    SafetyViolation,
    SchemaViolation,
    /// NOTE(review): presumably used when a hook exceeds its time budget — confirm.
    Timeout,
    /// NOTE(review): presumably used when a hook itself errors — confirm.
    RuntimeError,
}
104
105#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
107#[serde(tag = "decision", rename_all = "snake_case")]
108pub enum HookDecision {
109 Allow,
110 Deny {
111 hook_id: HookId,
112 reason_code: HookReasonCode,
113 message: String,
114 #[serde(default, skip_serializing_if = "Option::is_none")]
115 payload: Option<Value>,
116 },
117}
118
119impl HookDecision {
120 pub fn deny(
121 hook_id: HookId,
122 reason_code: HookReasonCode,
123 message: impl Into<String>,
124 payload: Option<Value>,
125 ) -> Self {
126 Self::Deny {
127 hook_id,
128 reason_code,
129 message: message.into(),
130 payload,
131 }
132 }
133}
134
/// A modification a hook asks to apply.
///
/// Internally tagged on `"patch_type"` with `snake_case` variant names
/// (e.g. `{"patch_type":"tool_args","args":{...}}`). `Option` fields are
/// omitted from the output when `None` and default to `None` when absent.
/// A `JsonSchema` impl is derived only with the `schema` feature enabled.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(tag = "patch_type", rename_all = "snake_case")]
pub enum HookPatch {
    /// Overrides for LLM request parameters; `None` fields leave the value as-is
    /// (NOTE(review): presumed from the optional shape — confirm with the applier).
    LlmRequest {
        #[serde(default, skip_serializing_if = "Option::is_none")]
        max_tokens: Option<u32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        temperature: Option<f32>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        provider_params: Option<Value>,
    },
    /// Replacement assistant text.
    AssistantText { text: String },
    /// Replacement tool arguments as arbitrary JSON.
    ToolArgs { args: Value },
    /// Replacement tool result text and optional error flag; this has the same
    /// shape as the inputs of `apply_tool_result_patch`.
    ToolResult {
        content: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        is_error: Option<bool>,
    },
    /// Replacement final run text.
    RunResult { text: String },
}
162
/// Revision number attached to a published patch (`HookPatchEnvelope::revision`).
///
/// Serialized transparently as a bare integer.
/// NOTE(review): presumably monotonically increasing — confirm with the publisher.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
#[serde(transparent)]
pub struct HookRevision(pub u64);
168
/// A [`HookPatch`] together with its publication metadata: revision,
/// originating hook, hook point and UTC timestamp.
#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct HookPatchEnvelope {
    pub revision: HookRevision,
    pub hook_id: HookId,
    pub point: HookPoint,
    pub patch: HookPatch,
    // Under the `schema` feature the generated JSON schema describes this
    // timestamp as a plain string.
    #[cfg_attr(feature = "schema", schemars(with = "String"))]
    pub published_at: DateTime<Utc>,
}
181
/// Summary of an LLM request exposed to hooks via `HookInvocation::llm_request`.
///
/// Only the message count is included, not the messages themselves.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct HookLlmRequest {
    pub max_tokens: u32,
    /// Omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f32>,
    /// Provider-specific parameters as arbitrary JSON; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub provider_params: Option<Value>,
    pub message_count: usize,
}
193
/// Summary of an LLM response exposed to hooks via `HookInvocation::llm_response`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct HookLlmResponse {
    pub assistant_text: String,
    /// Names of tools called in the response; empty when absent from the input.
    #[serde(default)]
    pub tool_call_names: Vec<String>,
    /// Omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub stop_reason: Option<StopReason>,
    /// Omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage: Option<Usage>,
}
206
/// Tool call exposed to hooks via `HookInvocation::tool_call`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct HookToolCall {
    pub tool_use_id: String,
    /// Tool name.
    pub name: String,
    /// Tool arguments as arbitrary JSON.
    pub args: Value,
}
215
/// Tool result exposed to hooks via `HookInvocation::tool_result`.
///
/// `content` is a plain-text projection. The tests in this module build it
/// with `ToolResult::text_content()`, which renders image blocks as
/// placeholders such as `[image: image/png]`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct HookToolResult {
    /// NOTE(review): presumably matches `HookToolCall::tool_use_id` — confirm.
    pub tool_use_id: String,
    /// Tool name.
    pub name: String,
    /// Text projection of the result content.
    pub content: String,
    pub is_error: bool,
    /// `true` when the underlying result contained image blocks, whose data is
    /// not carried in `content`. Defaults to `false` when the field is absent,
    /// so payloads without it still deserialize.
    #[serde(default)]
    pub has_images: bool,
}
235
/// Input passed to a [`HookEngine`]: the hook point, the session, and optional
/// context for that point.
///
/// Only `point` and `session_id` are always present. Every other field is
/// optional, omitted from the output when `None`, and defaults to `None` when
/// absent. NOTE(review): which fields are populated for which `point` (e.g.
/// `tool_call` for `PreToolExecution`) is decided by the caller — confirm there.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct HookInvocation {
    pub point: HookPoint,
    pub session_id: SessionId,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub turn_number: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub prompt: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub llm_request: Option<HookLlmRequest>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub llm_response: Option<HookLlmResponse>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call: Option<HookToolCall>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<HookToolResult>,
}
257
/// Result of running a single hook for one invocation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
#[serde(rename_all = "snake_case")]
pub struct HookOutcome {
    pub hook_id: HookId,
    pub point: HookPoint,
    /// NOTE(review): presumably used to order hooks; direction not visible here.
    pub priority: i32,
    /// NOTE(review): presumably registration order, used as a tie-breaker — confirm.
    pub registration_index: usize,
    /// Decision returned by the hook, if any; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub decision: Option<HookDecision>,
    /// Patches returned by the hook; empty when absent from the input.
    #[serde(default)]
    pub patches: Vec<HookPatch>,
    /// Patches with publication metadata; empty when absent from the input.
    #[serde(default)]
    pub published_patches: Vec<HookPatchEnvelope>,
    /// Error message if the hook failed; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
    /// Duration in milliseconds, if recorded; omitted when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub duration_ms: Option<u64>,
}
277
278#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
280#[serde(rename_all = "snake_case")]
281pub struct HookExecutionReport {
282 #[serde(default)]
283 pub outcomes: Vec<HookOutcome>,
284 #[serde(default, skip_serializing_if = "Option::is_none")]
285 pub decision: Option<HookDecision>,
286 #[serde(default)]
287 pub patches: Vec<HookPatch>,
288 #[serde(default)]
289 pub published_patches: Vec<HookPatchEnvelope>,
290}
291
292impl HookExecutionReport {
293 pub fn empty() -> Self {
294 Self::default()
295 }
296}
297
298pub fn default_failure_policy(capability: HookCapability) -> HookFailurePolicy {
299 match capability {
300 HookCapability::Observe => HookFailurePolicy::FailOpen,
301 HookCapability::Guardrail | HookCapability::Rewrite => HookFailurePolicy::FailClosed,
302 }
303}
304
305pub fn apply_tool_result_patch(
312 tool_result: &mut crate::types::ToolResult,
313 patched_text: String,
314 is_error: Option<bool>,
315) {
316 use crate::types::ContentBlock;
317
318 let image_blocks: Vec<ContentBlock> = tool_result
319 .content
320 .iter()
321 .filter(|b| matches!(b, ContentBlock::Image { .. }))
322 .cloned()
323 .collect();
324 let mut new_content = vec![ContentBlock::Text { text: patched_text }];
325 new_content.extend(image_blocks);
326 tool_result.content = new_content;
327 if let Some(value) = is_error {
328 tool_result.is_error = value;
329 }
330}
331
/// Errors a [`HookEngine`] can report.
#[derive(Debug, Clone, thiserror::Error)]
pub enum HookEngineError {
    /// The hook configuration was rejected; the payload explains why.
    #[error("Hook configuration invalid: {0}")]
    InvalidConfiguration(String),
    /// Running the named hook failed with the given reason.
    #[error("Hook runtime execution failed for '{hook_id}': {reason}")]
    ExecutionFailed { hook_id: HookId, reason: String },
    /// The named hook did not finish within `timeout_ms` milliseconds.
    #[error("Hook '{hook_id}' timed out after {timeout_ms}ms")]
    Timeout { hook_id: HookId, timeout_ms: u64 },
}
342
/// Engine that runs hooks for a [`HookInvocation`].
///
/// On `wasm32` the async methods are generated with `async_trait(?Send)`, so
/// the returned futures need not be `Send`. On other targets plain
/// `async_trait` is used. Implementors must be `Send + Sync` on all targets.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait HookEngine: Send + Sync {
    /// Returns the ids of hooks relevant to `invocation`, optionally adjusted
    /// by `overrides`.
    ///
    /// The default implementation returns `Ok` with an empty list.
    /// NOTE(review): presumably "the hooks that `execute` would run" — confirm
    /// with implementors.
    fn matching_hooks(
        &self,
        _invocation: &HookInvocation,
        _overrides: Option<&crate::config::HookRunOverrides>,
    ) -> Result<Vec<HookId>, HookEngineError> {
        Ok(Vec::new())
    }

    /// Runs hooks for `invocation`, optionally adjusted by `overrides`, and
    /// returns a [`HookExecutionReport`].
    ///
    /// # Errors
    ///
    /// Returns a [`HookEngineError`]. Which failures surface as errors rather
    /// than as per-hook `HookOutcome::error` entries is implementation-defined.
    async fn execute(
        &self,
        invocation: HookInvocation,
        overrides: Option<&crate::config::HookRunOverrides>,
    ) -> Result<HookExecutionReport, HookEngineError>;
}
361
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, ToolResult};

    /// Builds a `ContentBlock::Text` from a string slice.
    fn text_block(s: &str) -> ContentBlock {
        ContentBlock::Text {
            text: s.to_string(),
        }
    }

    /// Builds a `ContentBlock::Image`; `data` is converted via `Into`, and the
    /// assertions below match the result as `ImageData::Inline`.
    fn image_block(media_type: &str, data: &str) -> ContentBlock {
        ContentBlock::Image {
            media_type: media_type.to_string(),
            data: data.into(),
        }
    }

    // A multimodal result is projected to text with an image placeholder, and
    // `has_images` records that images were present.
    #[test]
    fn hook_result_from_multimodal_uses_text_projection() {
        let tr = ToolResult::with_blocks(
            "tc_1".into(),
            vec![text_block("hello"), image_block("image/png", "AAAA")],
            false,
        );
        let hook_result = HookToolResult {
            tool_use_id: tr.tool_use_id.clone(),
            name: "test_tool".into(),
            content: tr.text_content(),
            is_error: tr.is_error,
            has_images: tr.has_images(),
        };
        // Text blocks and image placeholders are newline-separated.
        assert_eq!(hook_result.content, "hello\n[image: image/png]");
        assert!(hook_result.has_images);
    }

    // A text-only result passes its text through unchanged with `has_images == false`.
    #[test]
    fn hook_result_text_only_has_images_false() {
        let tr = ToolResult::new("tc_1".into(), "just text".into(), false);
        let hook_result = HookToolResult {
            tool_use_id: tr.tool_use_id.clone(),
            name: "test_tool".into(),
            content: tr.text_content(),
            is_error: tr.is_error,
            has_images: tr.has_images(),
        };
        assert_eq!(hook_result.content, "just text");
        assert!(!hook_result.has_images);
    }

    // Patching replaces the text but keeps every image, in order.
    #[test]
    fn hook_patch_replaces_text_preserves_images() {
        let mut tr = ToolResult::with_blocks(
            "tc_1".into(),
            vec![
                text_block("original text"),
                image_block("image/png", "AAAA"),
                image_block("image/jpeg", "BBBB"),
            ],
            false,
        );
        apply_tool_result_patch(&mut tr, "patched text".into(), None);
        assert_eq!(tr.content.len(), 3);
        assert_eq!(
            tr.content[0],
            ContentBlock::Text {
                text: "patched text".into()
            }
        );
        assert!(
            matches!(&tr.content[1], ContentBlock::Image { media_type, data, .. }
                if media_type == "image/png"
                && matches!(data, crate::types::ImageData::Inline { data } if data == "AAAA"))
        );
        assert!(
            matches!(&tr.content[2], ContentBlock::Image { media_type, data, .. }
                if media_type == "image/jpeg"
                && matches!(data, crate::types::ImageData::Inline { data } if data == "BBBB"))
        );
    }

    // For a text-only result, the block count and `is_error` stay unchanged;
    // only the text is replaced.
    #[test]
    fn hook_patch_text_only_unchanged() {
        let mut tr = ToolResult::new("tc_1".into(), "original".into(), false);
        apply_tool_result_patch(&mut tr, "patched".into(), None);
        assert_eq!(tr.content.len(), 1);
        assert_eq!(tr.text_content(), "patched");
        assert!(!tr.is_error);
    }

    // An image-only result gains a leading text block.
    #[test]
    fn hook_patch_image_only_result_prepends_text() {
        let mut tr =
            ToolResult::with_blocks("tc_1".into(), vec![image_block("image/png", "AAAA")], false);
        apply_tool_result_patch(&mut tr, "added text".into(), None);
        assert_eq!(tr.content.len(), 2);
        assert_eq!(
            tr.content[0],
            ContentBlock::Text {
                text: "added text".into()
            }
        );
        assert!(matches!(&tr.content[1], ContentBlock::Image { .. }));
    }

    // Interleaved text/image content collapses to one text block followed by
    // the images in their original relative order; old text blocks are dropped.
    #[test]
    fn hook_patch_interleaved_reorders_text_before_images() {
        let mut tr = ToolResult::with_blocks(
            "tc_1".into(),
            vec![
                text_block("a"),
                image_block("image/png", "X"),
                text_block("b"),
                image_block("image/jpeg", "Y"),
            ],
            false,
        );
        apply_tool_result_patch(&mut tr, "c".into(), None);
        assert_eq!(tr.content.len(), 3);
        assert_eq!(tr.content[0], ContentBlock::Text { text: "c".into() });
        assert!(
            matches!(&tr.content[1], ContentBlock::Image { media_type, data, .. }
                if media_type == "image/png"
                && matches!(data, crate::types::ImageData::Inline { data } if data == "X"))
        );
        assert!(
            matches!(&tr.content[2], ContentBlock::Image { media_type, data, .. }
                if media_type == "image/jpeg"
                && matches!(data, crate::types::ImageData::Inline { data } if data == "Y"))
        );
    }

    // `Some(true)` overrides the error flag.
    #[test]
    fn hook_patch_sets_is_error() {
        let mut tr = ToolResult::new("tc_1".into(), "ok".into(), false);
        apply_tool_result_patch(&mut tr, "error".into(), Some(true));
        assert!(tr.is_error);
        assert_eq!(tr.text_content(), "error");
    }

    // JSON without `has_images` still deserializes, with the flag defaulting to false.
    #[test]
    fn hook_tool_result_has_images_serde_default() {
        let json = r#"{
            "tool_use_id": "tc_1",
            "name": "test",
            "content": "hello",
            "is_error": false
        }"#;
        let result: HookToolResult =
            serde_json::from_str(json).expect("should deserialize without has_images");
        assert!(!result.has_images);
    }

    // `has_images == true` survives a JSON round trip.
    #[test]
    fn hook_tool_result_has_images_roundtrip() {
        let result = HookToolResult {
            tool_use_id: "tc_1".into(),
            name: "tool".into(),
            content: "text".into(),
            is_error: false,
            has_images: true,
        };
        let json = serde_json::to_string(&result).expect("should serialize");
        let decoded: HookToolResult = serde_json::from_str(&json).expect("should deserialize");
        assert!(decoded.has_images);
    }
}