use serde::Deserialize;
/// Schema version of the hook-script/binary contract.
///
/// [`GateProtocol::check_schema_env`] compares the `KLASP_GATE_SCHEMA` value exported by the
/// installed hook script against this constant; on a mismatch the user is told to re-run
/// `klasp install`.
pub const GATE_SCHEMA_VERSION: u32 = 2;

/// Plugin protocol version.
///
/// NOTE(review): not referenced in this section; presumably versions the plugin wire
/// protocol — confirm against its users.
pub const PLUGIN_PROTOCOL_VERSION: u32 = 0;

/// Output schema version.
///
/// NOTE(review): not referenced in this section; presumably versions klasp's emitted output
/// format — confirm against its users.
pub const KLASP_OUTPUT_SCHEMA: u32 = 1;
/// Top-level JSON payload the gate receives on stdin, decoded by [`GateProtocol::parse`].
///
/// Only the fields below are read; any other top-level keys in the payload are ignored.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct GateInput {
    /// Name of the tool being invoked (for example `"Bash"` or `"Read"`).
    pub tool_name: String,
    /// Tool-specific arguments; this key must be present in the payload.
    pub tool_input: ToolInput,
}
/// The subset of a tool's arguments that the gate inspects.
///
/// Other keys inside `tool_input` are ignored during deserialization.
#[derive(Debug, PartialEq, Eq, Deserialize)]
pub struct ToolInput {
    /// Command line supplied to command-running tools such as `Bash`; `None` when the
    /// payload has no `command` key (e.g. a `Read` invocation).
    #[serde(default)]
    pub command: Option<String>,
}
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum GateError {
#[error("could not parse gate input as JSON: {0}")]
Parse(String),
#[error(
"klasp-gate: schema mismatch (script={script}, binary={binary}). \
Re-run `klasp install` to update the hook."
)]
SchemaMismatch { script: u32, binary: u32 },
#[error(
"KLASP_GATE_SCHEMA is not set. Re-run `klasp install` to regenerate \
the hook script."
)]
SchemaMissing,
}
/// Entry points for the gate's stdin payload and schema-version environment variable.
///
/// Stateless: all functionality is exposed as associated functions.
pub struct GateProtocol;

impl GateProtocol {
    /// Decodes the JSON payload read from stdin into a [`GateInput`].
    ///
    /// Unknown keys are ignored; `tool_name` and `tool_input` are required.
    ///
    /// # Errors
    ///
    /// Returns [`GateError::Parse`] with the `serde_json` error text when `stdin` is not
    /// valid JSON or does not match the [`GateInput`] shape.
    pub fn parse(stdin: &str) -> Result<GateInput, GateError> {
        serde_json::from_str(stdin).map_err(|e| GateError::Parse(e.to_string()))
    }

    /// Reads the hook script's schema version from the `KLASP_GATE_SCHEMA` env var.
    ///
    /// Surrounding whitespace is ignored. A value such as `"2\r"` therefore still parses;
    /// this can come from a hook script with CRLF line endings.
    ///
    /// # Errors
    ///
    /// - [`GateError::SchemaMissing`] if the variable is not set.
    /// - [`GateError::Parse`] if it is set but is not valid Unicode, or does not parse as a
    ///   `u32` after trimming.
    ///
    /// NOTE(review): `GateError::Parse` displays as "could not parse gate input as JSON: …",
    /// which is misleading for these env-var failures; a dedicated variant would read better.
    pub fn read_schema_from_env() -> Result<u32, GateError> {
        match std::env::var("KLASP_GATE_SCHEMA") {
            Err(std::env::VarError::NotPresent) => Err(GateError::SchemaMissing),
            Err(e) => Err(GateError::Parse(format!("KLASP_GATE_SCHEMA env var: {e}"))),
            // Parse the trimmed value, but report the raw one so stray whitespace is visible.
            Ok(s) => s
                .trim()
                .parse::<u32>()
                .map_err(|e| GateError::Parse(format!("KLASP_GATE_SCHEMA = {s:?}: {e}"))),
        }
    }

    /// Checks that the hook script's schema version matches [`GATE_SCHEMA_VERSION`].
    ///
    /// # Errors
    ///
    /// Returns [`GateError::SchemaMismatch`] carrying both versions when they differ.
    pub fn check_schema_env(env_value: u32) -> Result<(), GateError> {
        if env_value == GATE_SCHEMA_VERSION {
            Ok(())
        } else {
            Err(GateError::SchemaMismatch {
                script: env_value,
                binary: GATE_SCHEMA_VERSION,
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Environment variable read by [`GateProtocol::read_schema_from_env`].
    const SCHEMA_VAR: &str = "KLASP_GATE_SCHEMA";

    /// Serializes the tests in this module that mutate `KLASP_GATE_SCHEMA`. Without it,
    /// the default multi-threaded test harness could run them concurrently.
    static ENV_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Sets `KLASP_GATE_SCHEMA` to `value`, or unsets it for `None`. It then calls
    /// [`GateProtocol::read_schema_from_env`] and restores the previous value before
    /// returning the result.
    fn read_schema_with_env(value: Option<&str>) -> Result<u32, GateError> {
        // A panic in another env test only poisons the lock; the guarded `()` has no state
        // that could be left inconsistent, so recover the guard and continue.
        let _guard = ENV_LOCK
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // `var_os` (not `var`) so a previous non-UTF-8 value is restored as well.
        let saved = std::env::var_os(SCHEMA_VAR);
        set_schema_var(value.map(std::ffi::OsStr::new));
        let result = GateProtocol::read_schema_from_env();
        set_schema_var(saved.as_deref());
        result
    }

    /// Sets (`Some`) or removes (`None`) `KLASP_GATE_SCHEMA`. Callers must hold `ENV_LOCK`.
    fn set_schema_var(value: Option<&std::ffi::OsStr>) {
        // SAFETY: `set_var`/`remove_var` are unsafe because another thread may read the
        // environment concurrently. `ENV_LOCK` serializes every env-touching test in this
        // module. NOTE(review): tests elsewhere in the crate that read the environment are
        // not covered by this lock — confirm none can run concurrently with these.
        unsafe {
            match value {
                Some(v) => std::env::set_var(SCHEMA_VAR, v),
                None => std::env::remove_var(SCHEMA_VAR),
            }
        }
    }

    #[test]
    fn parses_minimal_claude_payload() {
        let stdin = r#"{
            "tool_name": "Bash",
            "tool_input": { "command": "git commit -m 'wip'" }
        }"#;
        let input = GateProtocol::parse(stdin).expect("should parse");
        assert_eq!(input.tool_name, "Bash");
        assert_eq!(
            input.tool_input.command.as_deref(),
            Some("git commit -m 'wip'")
        );
    }

    #[test]
    fn parses_payload_without_command() {
        let stdin = r#"{ "tool_name": "Read", "tool_input": {} }"#;
        let input = GateProtocol::parse(stdin).expect("should parse");
        assert_eq!(input.tool_name, "Read");
        assert!(input.tool_input.command.is_none());
    }

    #[test]
    fn parses_payload_ignoring_extra_fields() {
        let stdin = r#"{
            "tool_name": "Bash",
            "tool_input": { "command": "ls", "extra": 42 },
            "session_id": "abc"
        }"#;
        let input = GateProtocol::parse(stdin).expect("should parse");
        assert_eq!(input.tool_input.command.as_deref(), Some("ls"));
    }

    #[test]
    fn fails_on_malformed_json() {
        let err = GateProtocol::parse("{ not json").expect_err("should fail");
        assert!(matches!(err, GateError::Parse(_)));
    }

    #[test]
    fn fails_on_missing_tool_input() {
        let err = GateProtocol::parse(r#"{ "tool_name": "Bash" }"#).expect_err("should fail");
        assert!(matches!(err, GateError::Parse(_)));
    }

    #[test]
    fn schema_match_passes() {
        assert!(GateProtocol::check_schema_env(GATE_SCHEMA_VERSION).is_ok());
    }

    #[test]
    fn schema_mismatch_returns_error() {
        let err = GateProtocol::check_schema_env(GATE_SCHEMA_VERSION + 1).expect_err("mismatch");
        match err {
            GateError::SchemaMismatch { script, binary } => {
                assert_eq!(script, GATE_SCHEMA_VERSION + 1);
                assert_eq!(binary, GATE_SCHEMA_VERSION);
            }
            other => panic!("expected SchemaMismatch, got {other:?}"),
        }
    }

    #[test]
    fn schema_zero_is_mismatch() {
        let err = GateProtocol::check_schema_env(0).expect_err("zero should be mismatch");
        assert!(matches!(err, GateError::SchemaMismatch { .. }));
    }

    #[test]
    fn schema_missing_env_returns_schema_missing() {
        let result = read_schema_with_env(None);
        assert!(
            matches!(result, Err(GateError::SchemaMissing)),
            "expected SchemaMissing, got {result:?}",
        );
    }

    #[test]
    fn schema_env_valid_value_parses() {
        let raw = GATE_SCHEMA_VERSION.to_string();
        assert_eq!(read_schema_with_env(Some(&raw)), Ok(GATE_SCHEMA_VERSION));
    }

    #[test]
    fn schema_env_non_numeric_is_parse_error() {
        let result = read_schema_with_env(Some("not-a-number"));
        assert!(
            matches!(result, Err(GateError::Parse(_))),
            "expected Parse, got {result:?}",
        );
    }
}