1use serde::Deserialize;
14
/// Version of the gate protocol understood by this binary.
///
/// The installed hook script advertises the version it was generated for via
/// the `KLASP_GATE_SCHEMA` environment variable; any value other than this
/// constant is reported as a schema mismatch, and the fix is to re-run
/// `klasp install` so the hook is regenerated.
pub const GATE_SCHEMA_VERSION: u32 = 2;
22
/// Payload the gate reads from stdin.
///
/// Only the fields the gate inspects are modelled. Any other top-level keys
/// (for example `session_id`) are ignored during deserialization rather than
/// rejected.
#[derive(Debug, Deserialize, PartialEq, Eq)]
pub struct GateInput {
    /// Name of the tool being invoked, e.g. `"Bash"` or `"Read"`.
    pub tool_name: String,
    /// Tool-specific arguments. This key is required: a payload without it
    /// fails to parse.
    pub tool_input: ToolInput,
}
29
/// Arguments of the tool named in [`GateInput::tool_name`].
///
/// Unknown keys are ignored, so inputs for tools that do not carry a
/// `command` (e.g. `Read`) still deserialize.
#[derive(Debug, Deserialize, PartialEq, Eq)]
pub struct ToolInput {
    /// Command line for shell-style tools such as `Bash`. `None` when the
    /// `command` key is absent.
    #[serde(default)]
    pub command: Option<String>,
}
38
/// Errors raised while reading the gate's inputs: the stdin payload and the
/// schema version advertised through the environment.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum GateError {
    /// Input could not be parsed; the payload is the underlying error text.
    ///
    /// Besides malformed stdin JSON, `GateProtocol::read_schema_from_env` also
    /// uses this variant when `KLASP_GATE_SCHEMA` is not valid Unicode or is
    /// not a `u32`, even though the message is worded for JSON.
    #[error("could not parse gate input as JSON: {0}")]
    Parse(String),
    /// The hook script and this binary disagree on the schema version.
    /// `script` is the version taken from the environment; `binary` is
    /// [`GATE_SCHEMA_VERSION`].
    #[error(
        "klasp-gate: schema mismatch (script={script}, binary={binary}). \
         Re-run `klasp install` to update the hook."
    )]
    SchemaMismatch { script: u32, binary: u32 },
    /// `KLASP_GATE_SCHEMA` is not present in the environment at all.
    #[error(
        "KLASP_GATE_SCHEMA is not set. Re-run `klasp install` to regenerate \
         the hook script."
    )]
    SchemaMissing,
}
54
/// Stateless holder for the gate's protocol helpers: parsing the stdin
/// payload and reading/checking the schema version from the environment.
pub struct GateProtocol;
56
57impl GateProtocol {
58 pub fn parse(stdin: &str) -> Result<GateInput, GateError> {
60 serde_json::from_str(stdin).map_err(|e| GateError::Parse(e.to_string()))
61 }
62
63 pub fn read_schema_from_env() -> Result<u32, GateError> {
69 match std::env::var("KLASP_GATE_SCHEMA") {
70 Err(std::env::VarError::NotPresent) => Err(GateError::SchemaMissing),
71 Err(e) => Err(GateError::Parse(format!("KLASP_GATE_SCHEMA env var: {e}"))),
72 Ok(s) => s
73 .parse::<u32>()
74 .map_err(|e| GateError::Parse(format!("KLASP_GATE_SCHEMA = {s:?}: {e}"))),
75 }
76 }
77
78 pub fn check_schema_env(env_value: u32) -> Result<(), GateError> {
83 if env_value == GATE_SCHEMA_VERSION {
84 Ok(())
85 } else {
86 Err(GateError::SchemaMismatch {
87 script: env_value,
88 binary: GATE_SCHEMA_VERSION,
89 })
90 }
91 }
92}
93
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal well-formed payload: a tool name plus a Bash command.
    #[test]
    fn parses_minimal_claude_payload() {
        let stdin = r#"{
            "tool_name": "Bash",
            "tool_input": { "command": "git commit -m 'wip'" }
        }"#;
        let input = GateProtocol::parse(stdin).expect("should parse");
        assert_eq!(input.tool_name, "Bash");
        assert_eq!(
            input.tool_input.command.as_deref(),
            Some("git commit -m 'wip'")
        );
    }

    // Tools whose input has no `command` key must still parse, yielding `None`.
    #[test]
    fn parses_payload_without_command() {
        let stdin = r#"{ "tool_name": "Read", "tool_input": {} }"#;
        let input = GateProtocol::parse(stdin).expect("should parse");
        assert_eq!(input.tool_name, "Read");
        assert!(input.tool_input.command.is_none());
    }

    // Unknown keys, both top-level and inside `tool_input`, are ignored
    // rather than rejected.
    #[test]
    fn parses_payload_ignoring_extra_fields() {
        let stdin = r#"{
            "tool_name": "Bash",
            "tool_input": { "command": "ls", "extra": 42 },
            "session_id": "abc"
        }"#;
        let input = GateProtocol::parse(stdin).expect("should parse");
        assert_eq!(input.tool_input.command.as_deref(), Some("ls"));
    }

    #[test]
    fn fails_on_malformed_json() {
        let err = GateProtocol::parse("{ not json").expect_err("should fail");
        assert!(matches!(err, GateError::Parse(_)));
    }

    // `tool_input` is a required field; its absence is a parse error.
    #[test]
    fn fails_on_missing_tool_input() {
        let err = GateProtocol::parse(r#"{ "tool_name": "Bash" }"#).expect_err("should fail");
        assert!(matches!(err, GateError::Parse(_)));
    }

    #[test]
    fn schema_match_passes() {
        assert!(GateProtocol::check_schema_env(GATE_SCHEMA_VERSION).is_ok());
    }

    // A newer script version is rejected, and the error reports both sides.
    #[test]
    fn schema_mismatch_returns_error() {
        let err = GateProtocol::check_schema_env(GATE_SCHEMA_VERSION + 1).expect_err("mismatch");
        match err {
            GateError::SchemaMismatch { script, binary } => {
                assert_eq!(script, GATE_SCHEMA_VERSION + 1);
                assert_eq!(binary, GATE_SCHEMA_VERSION);
            }
            other => panic!("expected SchemaMismatch, got {other:?}"),
        }
    }

    // 0 differs from GATE_SCHEMA_VERSION, so it is rejected like any other
    // mismatch rather than being treated specially.
    #[test]
    fn schema_zero_is_mismatch() {
        let err = GateProtocol::check_schema_env(0).expect_err("zero should be mismatch");
        assert!(matches!(err, GateError::SchemaMismatch { .. }));
    }

    #[test]
    fn schema_missing_env_returns_schema_missing() {
        // Remember any pre-existing value so the process environment is left
        // as it was found.
        let saved = std::env::var("KLASP_GATE_SCHEMA").ok();
        // SAFETY: `remove_var`/`set_var` are unsafe because another thread may
        // read or write the environment concurrently, and the test harness runs
        // tests on parallel threads. No other test in this module touches the
        // environment. NOTE(review): this assumes no test elsewhere in the crate
        // does so while this one runs; confirm, or serialize with a shared lock.
        unsafe {
            std::env::remove_var("KLASP_GATE_SCHEMA");
        }
        let result = GateProtocol::read_schema_from_env();
        // Restore before asserting, so a failing assertion below does not leak
        // the removal into the rest of the process.
        if let Some(v) = saved {
            // SAFETY: same reasoning as for the removal above.
            unsafe {
                std::env::set_var("KLASP_GATE_SCHEMA", v);
            }
        }
        assert!(
            matches!(result, Err(GateError::SchemaMissing)),
            "expected SchemaMissing, got {result:?}",
        );
    }
}