1use serde::{Deserialize, Serialize};
8
9#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
10pub enum HookStage {
11 PreEdit,
12 PostEdit,
13 PreCommand,
14 PostValidation,
15 PreAccept,
16}
17
18#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
19pub enum HookAction {
20 Allow,
21 Warn,
22 Deny,
23}
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
26pub struct HookContext {
27 pub stage: HookStage,
28 pub agent_name: String,
29 #[serde(default)]
30 pub edit_description: Option<String>,
31 #[serde(default)]
32 pub patch_bytes: usize,
33 #[serde(default)]
34 pub command: Option<String>,
35 #[serde(default)]
36 pub validation_passed: Option<bool>,
37 #[serde(default)]
38 pub score_delta: Option<f32>,
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct HookDecision {
43 pub stage: HookStage,
44 pub action: HookAction,
45 pub reason: String,
46}
47
48impl HookDecision {
49 pub fn allow(stage: HookStage, reason: impl Into<String>) -> Self {
50 Self {
51 stage,
52 action: HookAction::Allow,
53 reason: reason.into(),
54 }
55 }
56
57 pub fn warn(stage: HookStage, reason: impl Into<String>) -> Self {
58 Self {
59 stage,
60 action: HookAction::Warn,
61 reason: reason.into(),
62 }
63 }
64
65 pub fn deny(stage: HookStage, reason: impl Into<String>) -> Self {
66 Self {
67 stage,
68 action: HookAction::Deny,
69 reason: reason.into(),
70 }
71 }
72
73 pub fn denied(&self) -> bool {
74 self.action == HookAction::Deny
75 }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct HookPolicy {
80 pub max_patch_bytes: usize,
81 pub require_positive_delta: bool,
82}
83
84impl Default for HookPolicy {
85 fn default() -> Self {
86 Self {
87 max_patch_bytes: 32 * 1024,
88 require_positive_delta: true,
89 }
90 }
91}
92
93pub fn evaluate_builtin_hook(policy: &HookPolicy, context: &HookContext) -> HookDecision {
94 match context.stage {
95 HookStage::PreEdit if context.patch_bytes > policy.max_patch_bytes => HookDecision::deny(
96 HookStage::PreEdit,
97 format!(
98 "patch is too large: {} bytes exceeds {}",
99 context.patch_bytes, policy.max_patch_bytes
100 ),
101 ),
102 HookStage::PostValidation if context.validation_passed == Some(false) => {
103 HookDecision::deny(HookStage::PostValidation, "validation failed")
104 }
105 HookStage::PreAccept
106 if policy.require_positive_delta
107 && context.score_delta.is_some_and(|delta| delta <= 0.0) =>
108 {
109 HookDecision::deny(HookStage::PreAccept, "score delta is not positive")
110 }
111 HookStage::PreCommand => HookDecision::allow(
112 HookStage::PreCommand,
113 context
114 .command
115 .as_deref()
116 .map(|command| format!("command allowed: {command}"))
117 .unwrap_or_else(|| "no command supplied".to_string()),
118 ),
119 ref stage => HookDecision::allow(stage.clone(), "built-in policy allowed stage"),
120 }
121}
122
#[cfg(test)]
mod tests {
    use super::*;

    /// Baseline context for `stage` with every optional input empty; tests
    /// override only the fields they exercise via struct-update syntax.
    fn context(stage: HookStage) -> HookContext {
        HookContext {
            stage,
            agent_name: "agent".to_string(),
            edit_description: None,
            patch_bytes: 0,
            command: None,
            validation_passed: None,
            score_delta: None,
        }
    }

    /// Small limit so patch-size tests stay readable.
    fn small_policy() -> HookPolicy {
        HookPolicy {
            max_patch_bytes: 10,
            require_positive_delta: true,
        }
    }

    #[test]
    fn oversized_patch_is_denied() {
        let context = HookContext {
            patch_bytes: 99,
            ..context(HookStage::PreEdit)
        };

        let decision = evaluate_builtin_hook(&small_policy(), &context);

        assert!(decision.denied());
    }

    #[test]
    fn patch_at_limit_is_allowed() {
        // The limit is inclusive: only sizes strictly above it are denied.
        let context = HookContext {
            patch_bytes: 10,
            ..context(HookStage::PreEdit)
        };

        let decision = evaluate_builtin_hook(&small_policy(), &context);

        assert_eq!(decision.action, HookAction::Allow);
    }

    #[test]
    fn non_positive_acceptance_delta_is_denied() {
        let context = HookContext {
            score_delta: Some(0.0),
            ..context(HookStage::PreAccept)
        };

        let decision = evaluate_builtin_hook(&HookPolicy::default(), &context);

        assert_eq!(decision.action, HookAction::Deny);
    }

    #[test]
    fn positive_acceptance_delta_is_allowed() {
        let context = HookContext {
            score_delta: Some(0.5),
            ..context(HookStage::PreAccept)
        };

        let decision = evaluate_builtin_hook(&HookPolicy::default(), &context);

        assert_eq!(decision.action, HookAction::Allow);
    }

    #[test]
    fn failed_validation_is_denied() {
        let context = HookContext {
            validation_passed: Some(false),
            ..context(HookStage::PostValidation)
        };

        let decision = evaluate_builtin_hook(&HookPolicy::default(), &context);

        assert!(decision.denied());
    }

    #[test]
    fn pre_command_reason_echoes_command() {
        let context = HookContext {
            command: Some("cargo test".to_string()),
            ..context(HookStage::PreCommand)
        };

        let decision = evaluate_builtin_hook(&HookPolicy::default(), &context);

        assert_eq!(decision.action, HookAction::Allow);
        assert_eq!(decision.reason, "command allowed: cargo test");
    }
}