use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use crate::options::{AgentDefinition, HookEvent, HookMatcher};
/// A control request, encoded as a JSON object whose `"subtype"` key holds the
/// snake_case name of the variant (e.g. `SetModel` becomes `"set_model"`).
///
/// Field names are used unchanged as JSON keys unless a field carries an explicit
/// `rename`. Judging by the name, these are the requests this side originates;
/// compare [`IncomingControlRequest`] (confirm against the transport code).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
pub enum ControlRequest {
    /// `"initialize"`. Every field is optional on the wire: empty collections are
    /// left out when serializing and default to empty when the key is absent.
    Initialize {
        /// Hook matchers, grouped by the hook event they apply to.
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        hooks: HashMap<HookEvent, Vec<HookMatcher>>,
        /// Agent definitions keyed by a string id (presumably the agent name).
        #[serde(default, skip_serializing_if = "HashMap::is_empty")]
        agents: HashMap<String, AgentDefinition>,
        /// Unlike the other fields, this one goes out under the camelCase key
        /// `"sdkMcpServers"`. It presumably lists SDK-hosted MCP server names
        /// (verify with the caller).
        #[serde(default, rename = "sdkMcpServers", skip_serializing_if = "Vec::is_empty")]
        sdk_mcp_servers: Vec<String>,
    },
    /// `"interrupt"`; no payload.
    Interrupt,
    /// `"set_permission_mode"`. `mode` is sent verbatim; nothing validates it here.
    SetPermissionMode { mode: String },
    /// `"set_model"`. `model` is sent verbatim.
    SetModel { model: String },
    /// `"mcp_status"`; no payload.
    McpStatus,
    /// `"rewind_files"`; the id is serialized under the key `"user_message_id"`.
    RewindFiles { user_message_id: String },
    /// `"get_server_info"`; no payload.
    GetServerInfo,
}
/// A response to a control request, tagged by `"subtype"` (`"success"` or `"error"`).
///
/// Both variants flatten a [`Value`] into the top-level JSON object. Serde can only
/// flatten map-like values, so that value should be a JSON object. Arrays, strings,
/// numbers and booleans make serialization fail. When deserializing, every key the
/// variant does not name explicitly (everything except `"subtype"`, plus `"error"`
/// for the `Error` variant) is gathered into the flattened value as a JSON object.
/// A `Value::Null` payload therefore comes back as an empty object.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
pub enum ControlResponse {
    /// `"success"`: the keys of `data` sit directly beside `"subtype"`.
    Success {
        #[serde(flatten)]
        data: Value,
    },
    /// `"error"`: `error` carries the message, and the keys of `extra` sit beside it.
    Error {
        error: String,
        #[serde(flatten)]
        extra: Value,
    },
}
/// A control request tagged by `"subtype"` in snake_case (`"can_use_tool"`,
/// `"hook_callback"`, `"mcp_message"`). Judging by the name, these arrive from the
/// other side of the connection. Field names are used unchanged as JSON keys.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "subtype", rename_all = "snake_case")]
pub enum IncomingControlRequest {
    /// `"can_use_tool"`: a tool name plus the tool's input, kept as raw JSON.
    CanUseTool { tool_name: String, tool_input: Value },
    /// `"hook_callback"`: a hook id and event, with the hook input kept as raw JSON.
    HookCallback {
        hook_id: String,
        hook_event: HookEvent,
        hook_input: Value,
    },
    /// `"mcp_message"`: a message for the named server, kept as raw JSON.
    /// The payload is presumably JSON-RPC, but nothing here checks that.
    McpMessage { server_name: String, message: Value },
}
#[cfg(test)]
mod tests {
    //! Wire-format tests for the control-protocol message types.
    //!
    //! These tests cover both directions: serializing each variant, and decoding
    //! raw JSON payloads. Decoding matters most for `IncomingControlRequest`.

    use super::*;
    use serde_json::json;

    #[test]
    fn test_control_request_initialize_minimal() {
        let req = ControlRequest::Initialize {
            hooks: HashMap::new(),
            agents: HashMap::new(),
            sdk_mcp_servers: vec![],
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "initialize");
        assert!(json.get("hooks").is_none());
        assert!(json.get("agents").is_none());
        assert!(json.get("sdkMcpServers").is_none());
    }

    #[test]
    fn test_control_request_initialize_roundtrip() {
        let req = ControlRequest::Initialize {
            hooks: HashMap::new(),
            agents: HashMap::new(),
            sdk_mcp_servers: vec![],
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: ControlRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            ControlRequest::Initialize {
                hooks,
                agents,
                sdk_mcp_servers,
            } => {
                assert!(hooks.is_empty());
                assert!(agents.is_empty());
                assert!(sdk_mcp_servers.is_empty());
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_control_request_initialize_with_sdk_mcp_servers() {
        let req = ControlRequest::Initialize {
            hooks: HashMap::new(),
            agents: HashMap::new(),
            sdk_mcp_servers: vec!["calculator".to_string()],
        };
        let json = serde_json::to_value(&req).unwrap();
        // This field is the only camelCase key in the payload.
        assert_eq!(json["sdkMcpServers"], json!(["calculator"]));
        assert!(json.get("sdk_mcp_servers").is_none());

        let parsed: ControlRequest = serde_json::from_value(json).unwrap();
        match parsed {
            ControlRequest::Initialize {
                sdk_mcp_servers, ..
            } => assert_eq!(sdk_mcp_servers, vec!["calculator".to_string()]),
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_control_request_interrupt() {
        let req = ControlRequest::Interrupt;
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "interrupt");
    }

    #[test]
    fn test_control_request_unit_variants_from_wire() {
        let interrupt: ControlRequest =
            serde_json::from_value(json!({ "subtype": "interrupt" })).unwrap();
        assert!(matches!(interrupt, ControlRequest::Interrupt));

        let status: ControlRequest =
            serde_json::from_value(json!({ "subtype": "mcp_status" })).unwrap();
        assert!(matches!(status, ControlRequest::McpStatus));
    }

    #[test]
    fn test_control_request_set_permission_mode() {
        let req = ControlRequest::SetPermissionMode {
            mode: "accept_edits".to_string(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "set_permission_mode");
        assert_eq!(json["mode"], "accept_edits");
    }

    #[test]
    fn test_control_request_set_model() {
        let req = ControlRequest::SetModel {
            model: "claude-sonnet-4".to_string(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "set_model");
        assert_eq!(json["model"], "claude-sonnet-4");
    }

    #[test]
    fn test_control_request_mcp_status() {
        let req = ControlRequest::McpStatus;
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "mcp_status");
    }

    #[test]
    fn test_control_request_rewind_files() {
        let req = ControlRequest::RewindFiles {
            user_message_id: "msg_123".to_string(),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "rewind_files");
        assert_eq!(json["user_message_id"], "msg_123");
    }

    #[test]
    fn test_control_request_get_server_info() {
        let req = ControlRequest::GetServerInfo;
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "get_server_info");
    }

    #[test]
    fn test_control_response_success() {
        let resp = ControlResponse::Success {
            data: json!({ "allowed": true }),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["subtype"], "success");
        assert_eq!(json["allowed"], true);
    }

    #[test]
    fn test_control_response_error() {
        let resp = ControlResponse::Error {
            error: "Tool not found".to_string(),
            extra: json!({ "code": "tool_not_found" }),
        };
        let json = serde_json::to_value(&resp).unwrap();
        assert_eq!(json["subtype"], "error");
        assert_eq!(json["error"], "Tool not found");
        assert_eq!(json["code"], "tool_not_found");
    }

    #[test]
    fn test_control_response_roundtrip() {
        let resp = ControlResponse::Success {
            data: json!({ "foo": "bar", "count": 42 }),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: ControlResponse = serde_json::from_str(&json).unwrap();
        // `ControlResponse` derives `PartialEq`, so compare the whole value rather
        // than picking fields out of one variant.
        assert_eq!(parsed, resp);
    }

    #[test]
    fn test_control_response_error_roundtrip() {
        let resp = ControlResponse::Error {
            error: "Tool not found".to_string(),
            extra: json!({ "code": "tool_not_found" }),
        };
        let json = serde_json::to_string(&resp).unwrap();
        let parsed: ControlResponse = serde_json::from_str(&json).unwrap();
        // `error` must be claimed by the named field, not swept into `extra`.
        assert_eq!(parsed, resp);
    }

    #[test]
    fn test_control_response_non_object_data_fails_to_serialize() {
        // `#[serde(flatten)]` only supports map-like values; pin that limitation
        // so a change in representation is noticed.
        let resp = ControlResponse::Success {
            data: json!([1, 2, 3]),
        };
        assert!(serde_json::to_value(&resp).is_err());
    }

    #[test]
    fn test_incoming_control_request_can_use_tool() {
        let req = IncomingControlRequest::CanUseTool {
            tool_name: "Bash".to_string(),
            tool_input: json!({ "command": "ls -la" }),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "can_use_tool");
        assert_eq!(json["tool_name"], "Bash");
        assert_eq!(json["tool_input"]["command"], "ls -la");
    }

    #[test]
    fn test_incoming_control_request_hook_callback() {
        let req = IncomingControlRequest::HookCallback {
            hook_id: "pre_commit".to_string(),
            hook_event: HookEvent::PreToolUse,
            hook_input: json!({ "tool": "Bash" }),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "hook_callback");
        assert_eq!(json["hook_id"], "pre_commit");
        assert!(json.get("hook_event").is_some());
        assert_eq!(json["hook_input"]["tool"], "Bash");
    }

    #[test]
    fn test_incoming_control_request_mcp_message() {
        let req = IncomingControlRequest::McpMessage {
            server_name: "my_server".to_string(),
            message: json!({
                "jsonrpc": "2.0",
                "id": 1,
                "method": "tools/call"
            }),
        };
        let json = serde_json::to_value(&req).unwrap();
        assert_eq!(json["subtype"], "mcp_message");
        assert_eq!(json["server_name"], "my_server");
        assert_eq!(json["message"]["method"], "tools/call");
    }

    #[test]
    fn test_incoming_control_request_roundtrip() {
        let req = IncomingControlRequest::CanUseTool {
            tool_name: "Read".to_string(),
            tool_input: json!({ "file_path": "/tmp/test.txt" }),
        };
        let json = serde_json::to_string(&req).unwrap();
        let parsed: IncomingControlRequest = serde_json::from_str(&json).unwrap();
        match parsed {
            IncomingControlRequest::CanUseTool {
                tool_name,
                tool_input,
            } => {
                assert_eq!(tool_name, "Read");
                assert_eq!(tool_input["file_path"], "/tmp/test.txt");
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_incoming_control_request_from_wire() {
        // Decode a hand-written payload, not one produced by our own serializer.
        let raw = json!({
            "subtype": "mcp_message",
            "server_name": "my_server",
            "message": { "jsonrpc": "2.0", "id": 1, "method": "tools/list" }
        });
        let parsed: IncomingControlRequest = serde_json::from_value(raw).unwrap();
        match parsed {
            IncomingControlRequest::McpMessage {
                server_name,
                message,
            } => {
                assert_eq!(server_name, "my_server");
                assert_eq!(message["method"], "tools/list");
                assert_eq!(message["id"], 1);
            }
            _ => panic!("Wrong variant"),
        }
    }

    #[test]
    fn test_incoming_control_request_rejects_unknown_or_missing_subtype() {
        let unknown: Result<IncomingControlRequest, _> =
            serde_json::from_value(json!({ "subtype": "bogus" }));
        assert!(unknown.is_err());

        let missing: Result<IncomingControlRequest, _> =
            serde_json::from_value(json!({ "tool_name": "Bash" }));
        assert!(missing.is_err());
    }
}