1use crate::error::AgentError;
4use crate::event::{AgentErrorClass, AgentErrorReport, ToolCallArguments};
5use crate::types::{ContentBlock, ContentInput, SessionId, StopReason, ToolResult, Usage};
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use schemars::JsonSchema;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11
12#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
14#[serde(transparent)]
15pub struct HookId(pub String);
16
17impl HookId {
18 pub fn new(id: impl Into<String>) -> Self {
19 Self(id.into())
20 }
21}
22
23impl std::fmt::Display for HookId {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 self.0.fmt(f)
26 }
27}
28
29impl From<&str> for HookId {
30 fn from(value: &str) -> Self {
31 Self::new(value)
32 }
33}
34
35impl From<String> for HookId {
36 fn from(value: String) -> Self {
37 Self::new(value)
38 }
39}
40
41#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
43#[serde(rename_all = "snake_case")]
44pub enum HookPoint {
45 RunStarted,
46 RunCompleted,
47 RunFailed,
48 PreLlmRequest,
49 PostLlmResponse,
50 PreToolExecution,
51 PostToolExecution,
52 TurnBoundary,
53}
54
55impl HookPoint {
56 pub fn is_pre(self) -> bool {
57 matches!(
58 self,
59 Self::RunStarted | Self::PreLlmRequest | Self::PreToolExecution | Self::TurnBoundary
60 )
61 }
62
63 pub fn is_post(self) -> bool {
64 matches!(
65 self,
66 Self::PostLlmResponse | Self::PostToolExecution | Self::RunCompleted | Self::RunFailed
67 )
68 }
69}
70
71#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
73#[serde(rename_all = "snake_case")]
74pub enum HookExecutionMode {
75 Foreground,
76 Background,
77}
78
79#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
81#[serde(rename_all = "snake_case")]
82pub enum HookCapability {
83 Observe,
84 Guardrail,
85 Rewrite,
88}
89
90#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
92#[serde(rename_all = "snake_case")]
93pub enum HookFailurePolicy {
94 FailOpen,
95 FailClosed,
96}
97
98#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq, Hash)]
100#[serde(rename_all = "snake_case")]
101pub enum HookReasonCode {
102 PolicyViolation,
103 SafetyViolation,
104 SchemaViolation,
105 Timeout,
106 RuntimeError,
107}
108
109#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
111#[serde(tag = "decision", rename_all = "snake_case")]
112pub enum HookDecision {
113 Allow,
114 Deny {
115 hook_id: HookId,
116 reason_code: HookReasonCode,
117 message: String,
118 #[serde(default, skip_serializing_if = "Option::is_none")]
119 payload: Option<Value>,
120 },
121}
122
123impl HookDecision {
124 pub fn deny(
125 hook_id: HookId,
126 reason_code: HookReasonCode,
127 message: impl Into<String>,
128 payload: Option<Value>,
129 ) -> Self {
130 Self::Deny {
131 hook_id,
132 reason_code,
133 message: message.into(),
134 payload,
135 }
136 }
137}
138
139#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
147#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
148#[serde(tag = "patch_type", rename_all = "snake_case")]
149pub enum HookPatch {}
150
151#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
153#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
154#[serde(transparent)]
155pub struct HookRevision(pub u64);
156
157#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))]
159#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
160#[serde(rename_all = "snake_case")]
161pub struct HookPatchEnvelope {
162 pub revision: HookRevision,
163 pub hook_id: HookId,
164 pub point: HookPoint,
165 pub patch: HookPatch,
166 #[cfg_attr(feature = "schema", schemars(with = "String"))]
167 pub published_at: DateTime<Utc>,
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
172#[serde(rename_all = "snake_case")]
173pub struct HookLlmRequest {
174 pub max_tokens: u32,
175 #[serde(default, skip_serializing_if = "Option::is_none")]
176 pub temperature: Option<f32>,
177 #[serde(default, skip_serializing_if = "Option::is_none")]
178 pub provider_params: Option<Value>,
179 pub message_count: usize,
180}
181
182#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
184#[serde(rename_all = "snake_case")]
185pub struct HookLlmResponse {
186 pub assistant_text: String,
187 #[serde(default)]
188 pub tool_call_names: Vec<String>,
189 #[serde(default, skip_serializing_if = "Option::is_none")]
190 pub stop_reason: Option<StopReason>,
191 #[serde(default, skip_serializing_if = "Option::is_none")]
192 pub usage: Option<Usage>,
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
197#[serde(rename_all = "snake_case")]
198pub struct HookToolCall {
199 pub tool_use_id: String,
200 pub name: String,
201 pub args: ToolCallArguments,
202}
203
204#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
209#[serde(rename_all = "snake_case")]
210pub struct HookToolResult {
211 pub tool_use_id: String,
212 pub name: String,
213 pub content: String,
215 #[serde(default, skip_serializing_if = "Vec::is_empty")]
217 pub content_blocks: Vec<ContentBlock>,
218 pub is_error: bool,
219 #[serde(default, skip_serializing)]
222 pub has_images: bool,
223}
224
225impl HookToolResult {
226 pub fn from_tool_result(name: impl Into<String>, result: &ToolResult) -> Self {
227 Self::from_tool_result_with_id(result.tool_use_id.clone(), name, result)
228 }
229
230 pub fn from_tool_result_with_id(
231 tool_use_id: impl Into<String>,
232 name: impl Into<String>,
233 result: &ToolResult,
234 ) -> Self {
235 Self {
236 tool_use_id: tool_use_id.into(),
237 name: name.into(),
238 content: result.text_content(),
239 content_blocks: result.content.clone(),
240 is_error: result.is_error,
241 has_images: result.has_images(),
242 }
243 }
244}
245
246#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
248#[serde(rename_all = "snake_case")]
249pub struct HookInvocation {
250 pub point: HookPoint,
251 pub session_id: SessionId,
252 #[serde(default, skip_serializing_if = "Option::is_none")]
253 pub turn_number: Option<u32>,
254 #[serde(default, skip_serializing_if = "Option::is_none")]
255 pub prompt_input: Option<ContentInput>,
256 #[serde(default, skip_serializing_if = "Option::is_none")]
258 pub prompt: Option<String>,
259 #[serde(default, skip_serializing_if = "Option::is_none")]
260 pub error_report: Option<AgentErrorReport>,
261 #[serde(default, skip_serializing_if = "Option::is_none")]
262 pub error_class: Option<AgentErrorClass>,
263 #[serde(default, skip_serializing_if = "Option::is_none")]
265 pub error: Option<String>,
266 #[serde(default, skip_serializing_if = "Option::is_none")]
267 pub llm_request: Option<HookLlmRequest>,
268 #[serde(default, skip_serializing_if = "Option::is_none")]
269 pub llm_response: Option<HookLlmResponse>,
270 #[serde(default, skip_serializing_if = "Option::is_none")]
271 pub tool_call: Option<HookToolCall>,
272 #[serde(default, skip_serializing_if = "Option::is_none")]
273 pub tool_result: Option<HookToolResult>,
274}
275
276impl HookInvocation {
277 pub fn new(point: HookPoint, session_id: SessionId) -> Self {
278 Self {
279 point,
280 session_id,
281 turn_number: None,
282 prompt_input: None,
283 prompt: None,
284 error_report: None,
285 error_class: None,
286 error: None,
287 llm_request: None,
288 llm_response: None,
289 tool_call: None,
290 tool_result: None,
291 }
292 }
293
294 pub fn run_started(session_id: SessionId, prompt_input: ContentInput) -> Self {
295 let prompt = prompt_input.text_content();
296 Self {
297 prompt_input: Some(prompt_input),
298 prompt: Some(prompt),
299 ..Self::new(HookPoint::RunStarted, session_id)
300 }
301 }
302
303 pub fn run_completed(session_id: SessionId, turn_number: u32) -> Self {
304 Self {
305 turn_number: Some(turn_number),
306 ..Self::new(HookPoint::RunCompleted, session_id)
307 }
308 }
309
310 pub fn run_failed(session_id: SessionId, error: &AgentError) -> Self {
311 let error_report = AgentErrorReport::from_agent_error(error);
312 Self {
313 error_class: Some(error_report.class),
314 error: Some(error_report.message.clone()),
315 error_report: Some(error_report),
316 ..Self::new(HookPoint::RunFailed, session_id)
317 }
318 }
319}
320
321#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
323#[serde(rename_all = "snake_case")]
324pub struct HookOutcome {
325 pub hook_id: HookId,
326 pub point: HookPoint,
327 pub priority: i32,
328 pub registration_index: usize,
329 #[serde(default, skip_serializing_if = "Option::is_none")]
330 pub decision: Option<HookDecision>,
331 #[serde(default)]
332 pub patches: Vec<HookPatch>,
333 #[serde(default)]
334 pub published_patches: Vec<HookPatchEnvelope>,
335 #[serde(default, skip_serializing_if = "Option::is_none")]
336 pub error: Option<String>,
337 #[serde(default, skip_serializing_if = "Option::is_none")]
338 pub duration_ms: Option<u64>,
339}
340
341#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
343#[serde(rename_all = "snake_case")]
344pub struct HookExecutionReport {
345 #[serde(default)]
346 pub outcomes: Vec<HookOutcome>,
347 #[serde(default, skip_serializing_if = "Option::is_none")]
348 pub decision: Option<HookDecision>,
349 #[serde(default)]
350 pub patches: Vec<HookPatch>,
351 #[serde(default)]
352 pub published_patches: Vec<HookPatchEnvelope>,
353}
354
355impl HookExecutionReport {
356 pub fn empty() -> Self {
357 Self::default()
358 }
359
360 pub fn denial_error(&self, point: HookPoint) -> Option<AgentError> {
366 match self.decision.as_ref()? {
367 HookDecision::Deny {
368 hook_id,
369 reason_code,
370 message,
371 payload,
372 } => Some(AgentError::HookDenied {
373 hook_id: hook_id.clone(),
374 point,
375 reason_code: *reason_code,
376 message: message.clone(),
377 payload: payload.clone(),
378 }),
379 HookDecision::Allow => None,
380 }
381 }
382}
383
384pub fn default_failure_policy(capability: HookCapability) -> HookFailurePolicy {
385 match capability {
386 HookCapability::Observe => HookFailurePolicy::FailOpen,
387 HookCapability::Guardrail | HookCapability::Rewrite => HookFailurePolicy::FailClosed,
388 }
389}
390
391#[derive(Debug, Clone, thiserror::Error)]
393pub enum HookEngineError {
394 #[error("Hook configuration invalid: {0}")]
395 InvalidConfiguration(String),
396 #[error("Hook runtime execution failed for '{hook_id}': {reason}")]
397 ExecutionFailed { hook_id: HookId, reason: String },
398 #[error("Hook '{hook_id}' timed out after {timeout_ms}ms")]
399 Timeout { hook_id: HookId, timeout_ms: u64 },
400}
401
402impl HookEngineError {
403 pub fn hook_id(&self) -> Option<&HookId> {
404 match self {
405 Self::InvalidConfiguration(_) => None,
406 Self::ExecutionFailed { hook_id, .. } | Self::Timeout { hook_id, .. } => Some(hook_id),
407 }
408 }
409
410 pub fn into_agent_error(self) -> AgentError {
411 match self {
412 Self::InvalidConfiguration(reason) => AgentError::HookConfigInvalid { reason },
413 Self::Timeout {
414 hook_id,
415 timeout_ms,
416 } => AgentError::HookTimeout {
417 hook_id,
418 timeout_ms,
419 },
420 Self::ExecutionFailed { hook_id, reason } => {
421 AgentError::HookExecutionFailed { hook_id, reason }
422 }
423 }
424 }
425}
426
427#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
429#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
430pub trait HookEngine: Send + Sync {
431 fn matching_hooks(
432 &self,
433 _invocation: &HookInvocation,
434 _overrides: Option<&crate::config::HookRunOverrides>,
435 ) -> Result<Vec<HookId>, HookEngineError> {
436 Ok(Vec::new())
437 }
438
439 async fn execute(
440 &self,
441 invocation: HookInvocation,
442 overrides: Option<&crate::config::HookRunOverrides>,
443 ) -> Result<HookExecutionReport, HookEngineError>;
444
445 async fn drain_published_patches(
447 &self,
448 _session_id: &SessionId,
449 ) -> Result<Vec<HookPatchEnvelope>, HookEngineError> {
450 Ok(Vec::new())
451 }
452}
453
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;
    use crate::types::{ContentBlock, ToolResult};

    /// Builds a text content block holding `s`.
    fn text_block(s: &str) -> ContentBlock {
        let text = s.to_string();
        ContentBlock::Text { text }
    }

    /// Builds an image content block from a media type and its data.
    fn image_block(media_type: &str, data: &str) -> ContentBlock {
        let media_type = media_type.to_string();
        ContentBlock::Image {
            media_type,
            data: data.into(),
        }
    }

    /// A string in `args` must fail to deserialize.
    #[test]
    fn hook_tool_call_rejects_string_args_on_deserialize() {
        let payload = serde_json::json!({
            "tool_use_id": "tc_1",
            "name": "search",
            "args": "{\"query\":"
        });

        let outcome = serde_json::from_value::<HookToolCall>(payload);
        let err = outcome.expect_err("hook surface must reject string-success tool args");
        let rendered = err.to_string();
        assert!(
            rendered.contains("JSON object, got string"),
            "unexpected error: {err}"
        );
    }

    /// Text plus image: `content` is the text rendering and the image flag
    /// is set.
    #[test]
    fn hook_result_from_multimodal_uses_text_projection() {
        let blocks = vec![text_block("hello"), image_block("image/png", "AAAA")];
        let tool_result = ToolResult::with_blocks("tc_1".into(), blocks, false);

        let projected = HookToolResult {
            tool_use_id: tool_result.tool_use_id.clone(),
            name: "test_tool".into(),
            content: tool_result.text_content(),
            content_blocks: tool_result.content.clone(),
            is_error: tool_result.is_error,
            has_images: tool_result.has_images(),
        };

        assert_eq!(projected.content, "hello\n[image: image/png]");
        assert!(projected.has_images);
    }

    /// Text-only result: a single text block and no image flag.
    #[test]
    fn hook_result_text_only_has_images_false() {
        let tool_result = ToolResult::new("tc_1".into(), "just text".into(), false);

        let projected = HookToolResult {
            tool_use_id: tool_result.tool_use_id.clone(),
            name: "test_tool".into(),
            content: tool_result.text_content(),
            content_blocks: tool_result.content.clone(),
            is_error: tool_result.is_error,
            has_images: tool_result.has_images(),
        };

        assert_eq!(projected.content, "just text");
        assert_eq!(projected.content_blocks, vec![text_block("just text")]);
        assert!(!projected.has_images);
    }

    /// Text-only result serializes its typed blocks and omits `has_images`.
    #[test]
    fn hook_result_text_only_serializes_typed_content_blocks() {
        let tool_result = ToolResult::new("tc_1".into(), "just text".into(), false);
        let projected = HookToolResult::from_tool_result("test_tool", &tool_result);

        assert_eq!(projected.content, "just text");
        assert_eq!(projected.content_blocks, vec![text_block("just text")]);

        let json = serde_json::to_value(&projected).expect("serialize hook tool result");
        let expected_blocks = serde_json::json!([{"type": "text", "text": "just text"}]);
        assert_eq!(json["content_blocks"], expected_blocks);
        assert!(
            json.get("has_images").is_none(),
            "typed content blocks should replace the image side flag on the hook surface"
        );
    }

    /// Image-only result serializes its typed blocks and omits `has_images`.
    #[test]
    fn hook_result_image_only_serializes_typed_content_blocks() {
        let blocks = vec![image_block("image/png", "AAAA")];
        let tool_result = ToolResult::with_blocks("tc_1".into(), blocks, false);
        let projected = HookToolResult::from_tool_result("view_image", &tool_result);

        assert_eq!(projected.content, "[image: image/png]");
        assert_eq!(
            projected.content_blocks,
            vec![image_block("image/png", "AAAA")]
        );

        let json = serde_json::to_value(&projected).expect("serialize hook tool result");
        let expected_blocks = serde_json::json!([{
            "type": "image",
            "media_type": "image/png",
            "source": "inline",
            "data": "AAAA"
        }]);
        assert_eq!(json["content_blocks"], expected_blocks);
        assert!(
            json.get("has_images").is_none(),
            "typed content blocks should replace the image side flag on the hook surface"
        );
    }

    /// Mixed content keeps block order in both renderings.
    #[test]
    fn hook_result_mixed_content_preserves_block_order() {
        let blocks = vec![
            text_block("before"),
            image_block("image/png", "AAAA"),
            text_block("after"),
        ];
        let tool_result = ToolResult::with_blocks("tc_1".into(), blocks, false);
        let projected = HookToolResult::from_tool_result("mixed_tool", &tool_result);

        assert_eq!(projected.content, "before\n[image: image/png]\nafter");
        assert_eq!(projected.content_blocks, tool_result.content);
    }

    /// The explicit id passed to the constructor wins over the one stored on
    /// the tool result.
    #[test]
    fn hook_result_can_use_authoritative_tool_call_id() {
        let tool_result = ToolResult::new("stale_tool_id".into(), "ok".into(), false);
        let projected =
            HookToolResult::from_tool_result_with_id("active_tool_id", "test_tool", &tool_result);

        assert_eq!(projected.tool_use_id, "active_tool_id");
        assert_eq!(projected.content_blocks, vec![text_block("ok")]);
    }

    /// A payload without `has_images` deserializes with the flag unset.
    #[test]
    fn hook_tool_result_has_images_serde_default() {
        let raw = r#"{
            "tool_use_id": "tc_1",
            "name": "test",
            "content": "hello",
            "is_error": false
        }"#;

        let decoded: HookToolResult =
            serde_json::from_str(raw).expect("should deserialize without has_images");
        assert!(!decoded.has_images);
    }

    /// `has_images` is read on input but never written on output.
    #[test]
    fn hook_tool_result_has_images_is_deserialize_only_legacy_flag() {
        let flagged = HookToolResult {
            tool_use_id: "tc_1".into(),
            name: "tool".into(),
            content: "text".into(),
            content_blocks: vec![text_block("text")],
            is_error: false,
            has_images: true,
        };
        let json = serde_json::to_value(&flagged).expect("should serialize");
        assert!(
            json.get("has_images").is_none(),
            "new hook payloads carry content_blocks instead of has_images"
        );

        let legacy_payload = serde_json::json!({
            "tool_use_id": "tc_1",
            "name": "tool",
            "content": "text",
            "content_blocks": [{"type": "text", "text": "text"}],
            "is_error": false,
            "has_images": true
        });
        let decoded: HookToolResult =
            serde_json::from_value(legacy_payload).expect("should deserialize");
        assert!(decoded.has_images);
    }

    /// Every legacy `patch_type` payload must fail to deserialize.
    #[test]
    fn legacy_semantic_hook_patches_are_rejected_on_deserialize() {
        let legacy_payloads = [
            serde_json::json!({
                "patch_type": "llm_request",
                "max_tokens": 1,
                "temperature": 0.1,
                "provider_params": {"reasoning_effort": "low"}
            }),
            serde_json::json!({
                "patch_type": "assistant_text",
                "text": "patched"
            }),
            serde_json::json!({
                "patch_type": "tool_args",
                "args": {"value": "patched"}
            }),
            serde_json::json!({
                "patch_type": "tool_result",
                "content": "patched",
                "is_error": false
            }),
            serde_json::json!({
                "patch_type": "run_result",
                "text": "patched"
            }),
        ];

        for value in &legacy_payloads {
            let attempt = serde_json::from_value::<HookPatch>(value.clone());
            assert!(
                attempt.is_err(),
                "legacy semantic hook patch payload must be rejected: {value}"
            );
        }
    }
}