1use serde::Deserialize;
14
/// Gate schema version implemented by this binary.
///
/// [`GateProtocol::check_schema_env`] compares the version supplied by the
/// caller against this constant and reports it as `binary` in
/// [`GateError::SchemaMismatch`] when the two differ.
pub const GATE_SCHEMA_VERSION: u32 = 1;
18
/// Top-level payload deserialized by [`GateProtocol::parse`].
///
/// Only the fields declared here are read; the struct does not opt into
/// `deny_unknown_fields`, so additional keys in the JSON object are ignored.
#[derive(Debug, Deserialize, PartialEq, Eq)]
pub struct GateInput {
    /// Name of the tool being invoked (the tests use values such as `"Bash"`
    /// and `"Read"`). Required: deserialization fails if it is absent.
    pub tool_name: String,
    /// Tool-specific arguments. Required: deserialization fails if it is absent.
    pub tool_input: ToolInput,
}
25
/// The subset of a tool's arguments that the gate reads.
///
/// Unknown keys in the JSON object are ignored.
#[derive(Debug, Deserialize, PartialEq, Eq)]
pub struct ToolInput {
    /// Command string, when the payload carries one. `#[serde(default)]`
    /// makes a missing `command` key deserialize to `None` instead of
    /// failing.
    #[serde(default)]
    pub command: Option<String>,
}
34
/// Errors produced while reading the gate payload or validating the schema
/// version handed over by the hook script.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum GateError {
    /// The input could not be decoded. Carries the underlying error text.
    ///
    /// Produced by [`GateProtocol::parse`] for invalid JSON payloads, and by
    /// [`GateProtocol::read_schema_from_env`] when `KLASP_GATE_SCHEMA` is set
    /// but cannot be read or is not a valid `u32`.
    #[error("could not parse gate input as JSON: {0}")]
    Parse(String),
    /// The schema version supplied by the script (`script`) differs from the
    /// one compiled into this binary (`binary`).
    #[error(
        "klasp-gate: schema mismatch (script={script}, binary={binary}). \
         Re-run `klasp install` to update the hook."
    )]
    SchemaMismatch { script: u32, binary: u32 },
    /// The `KLASP_GATE_SCHEMA` environment variable is not present at all.
    #[error(
        "KLASP_GATE_SCHEMA is not set. Re-run `klasp install` to regenerate \
         the hook script."
    )]
    SchemaMissing,
}
50
51pub struct GateProtocol;
52
53impl GateProtocol {
54 pub fn parse(stdin: &str) -> Result<GateInput, GateError> {
56 serde_json::from_str(stdin).map_err(|e| GateError::Parse(e.to_string()))
57 }
58
59 pub fn read_schema_from_env() -> Result<u32, GateError> {
65 match std::env::var("KLASP_GATE_SCHEMA") {
66 Err(std::env::VarError::NotPresent) => Err(GateError::SchemaMissing),
67 Err(e) => Err(GateError::Parse(format!("KLASP_GATE_SCHEMA env var: {e}"))),
68 Ok(s) => s
69 .parse::<u32>()
70 .map_err(|e| GateError::Parse(format!("KLASP_GATE_SCHEMA = {s:?}: {e}"))),
71 }
72 }
73
74 pub fn check_schema_env(env_value: u32) -> Result<(), GateError> {
79 if env_value == GATE_SCHEMA_VERSION {
80 Ok(())
81 } else {
82 Err(GateError::SchemaMismatch {
83 script: env_value,
84 binary: GATE_SCHEMA_VERSION,
85 })
86 }
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// A payload with a tool name and a command deserializes into both fields.
    #[test]
    fn parses_minimal_claude_payload() {
        let payload = r#"{
            "tool_name": "Bash",
            "tool_input": { "command": "git commit -m 'wip'" }
        }"#;
        let parsed = GateProtocol::parse(payload).expect("should parse");
        let expected = GateInput {
            tool_name: "Bash".to_string(),
            tool_input: ToolInput {
                command: Some("git commit -m 'wip'".to_string()),
            },
        };
        assert_eq!(parsed, expected);
    }

    /// An empty `tool_input` object is accepted and leaves `command` unset.
    #[test]
    fn parses_payload_without_command() {
        let payload = r#"{ "tool_name": "Read", "tool_input": {} }"#;
        let parsed = GateProtocol::parse(payload).expect("should parse");
        let expected = GateInput {
            tool_name: "Read".to_string(),
            tool_input: ToolInput { command: None },
        };
        assert_eq!(parsed, expected);
    }

    /// Keys the structs do not declare are skipped rather than rejected.
    #[test]
    fn parses_payload_ignoring_extra_fields() {
        let payload = r#"{
            "tool_name": "Bash",
            "tool_input": { "command": "ls", "extra": 42 },
            "session_id": "abc"
        }"#;
        let parsed = GateProtocol::parse(payload).expect("should parse");
        let command = parsed.tool_input.command;
        assert_eq!(command.as_deref(), Some("ls"));
    }

    /// Text that is not JSON surfaces as `GateError::Parse`.
    #[test]
    fn fails_on_malformed_json() {
        let outcome = GateProtocol::parse("{ not json").expect_err("should fail");
        match outcome {
            GateError::Parse(_) => {}
            other => panic!("expected Parse, got {other:?}"),
        }
    }

    /// `tool_input` is mandatory; omitting it is a parse error.
    #[test]
    fn fails_on_missing_tool_input() {
        let payload = r#"{ "tool_name": "Bash" }"#;
        let outcome = GateProtocol::parse(payload).expect_err("should fail");
        match outcome {
            GateError::Parse(_) => {}
            other => panic!("expected Parse, got {other:?}"),
        }
    }

    /// The binary's own version always passes the check.
    #[test]
    fn schema_match_passes() {
        let verdict = GateProtocol::check_schema_env(GATE_SCHEMA_VERSION);
        assert!(verdict.is_ok());
    }

    /// A differing version is reported with both sides of the comparison.
    #[test]
    fn schema_mismatch_returns_error() {
        let newer = GATE_SCHEMA_VERSION + 1;
        let failure = GateProtocol::check_schema_env(newer).expect_err("mismatch");
        assert_eq!(
            failure,
            GateError::SchemaMismatch {
                script: GATE_SCHEMA_VERSION + 1,
                binary: GATE_SCHEMA_VERSION,
            }
        );
    }

    /// Version `0` is expected to be rejected as a mismatch.
    #[test]
    fn schema_zero_is_mismatch() {
        let failure = GateProtocol::check_schema_env(0).expect_err("zero should be mismatch");
        match failure {
            GateError::SchemaMismatch { .. } => {}
            other => panic!("expected SchemaMismatch, got {other:?}"),
        }
    }

    /// With the variable removed, the reader reports `SchemaMissing`.
    ///
    /// The previous value (if any) is restored before asserting so that a
    /// failing assertion does not leave the process environment altered.
    #[test]
    fn schema_missing_env_returns_schema_missing() {
        let previous = std::env::var("KLASP_GATE_SCHEMA").ok();

        // SAFETY: mutating the process environment is only sound when no other
        // thread accesses it concurrently. No other test in this module reads
        // or writes the environment; tests elsewhere in the crate are not
        // visible from here, so that remains an assumption to confirm.
        unsafe {
            std::env::remove_var("KLASP_GATE_SCHEMA");
        }

        let outcome = GateProtocol::read_schema_from_env();

        match previous {
            Some(original) => {
                // SAFETY: same single-accessor assumption as the removal above.
                unsafe {
                    std::env::set_var("KLASP_GATE_SCHEMA", original);
                }
            }
            None => {}
        }

        assert!(
            matches!(outcome, Err(GateError::SchemaMissing)),
            "expected SchemaMissing, got {outcome:?}",
        );
    }
}