use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot, Mutex};
use crate::error::{ClaudeError, Result};
use crate::types::{HookEvent, PermissionRequest, PermissionResult, RequestId};
/// Top-level envelope for every message on the control channel.
///
/// Internally tagged: the snake_case variant name (`request`, `response`,
/// `init`, `init_response`) is written to a `"type"` field next to the
/// fields of the wrapped payload.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ControlMessage {
    /// Wraps a [`ControlRequest`].
    Request(ControlRequest),
    /// Wraps a [`ControlResponse`].
    Response(ControlResponse),
    /// Wraps the handshake's [`InitRequest`].
    Init(InitRequest),
    /// Wraps the handshake's [`InitResponse`].
    InitResponse(InitResponse),
}
/// Requests carried inside a [`ControlMessage::Request`] envelope.
///
/// Adjacently tagged: the snake_case variant name goes in `"method"`
/// (`interrupt`, `send_message`, `hook_response`, `permission_response`)
/// and the variant's fields go in a `"params"` object.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "method", content = "params", rename_all = "snake_case")]
pub enum ControlRequest {
    /// Interrupt request identified by `id`.
    Interrupt { id: RequestId },
    /// Carries a text `content` payload.
    SendMessage { id: RequestId, content: String },
    /// Reply to the hook identified by `hook_id`.
    HookResponse {
        id: RequestId,
        hook_id: String,
        response: serde_json::Value,
    },
    /// Reply to the permission request identified by `request_id`; `id` is
    /// this message's own id.
    PermissionResponse {
        id: RequestId,
        request_id: RequestId,
        result: PermissionResult,
    },
}
/// Messages carried inside a [`ControlMessage::Response`] envelope.
///
/// Internally tagged: the snake_case variant name (`success`, `error`,
/// `hook`, `permission`) is written to a `"status"` field next to the
/// variant's own fields.
#[derive(Clone, Debug, Deserialize, Serialize)]
#[serde(tag = "status", rename_all = "snake_case")]
pub enum ControlResponse {
    /// Successful outcome for the request identified by `id`, with optional data.
    Success {
        id: RequestId,
        data: Option<serde_json::Value>,
    },
    /// Failed outcome for the request identified by `id`.
    Error {
        id: RequestId,
        message: String,
        code: Option<String>,
    },
    /// A hook event; note that `id` is a plain `String`, not a [`RequestId`].
    Hook { id: String, event: HookEvent },
    /// A permission request identified by `id`.
    Permission {
        id: RequestId,
        request: PermissionRequest,
    },
}
/// Handshake payload describing this SDK to the peer.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct InitRequest {
    /// Protocol version the client speaks.
    pub protocol_version: String,
    /// Version string of this SDK.
    pub sdk_version: String,
    /// Feature flags the client advertises.
    pub capabilities: ClientCapabilities,
}
/// Boolean feature flags sent in [`InitRequest::capabilities`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ClientCapabilities {
    /// `bidirectional` capability flag.
    pub bidirectional: bool,
    /// `hooks` capability flag.
    pub hooks: bool,
    /// `permissions` capability flag.
    pub permissions: bool,
    /// `interrupts` capability flag.
    pub interrupts: bool,
}
/// Handshake reply from the peer.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct InitResponse {
    /// Protocol version the peer reports; checked against the client's.
    pub protocol_version: String,
    /// Version string reported in the `cli_version` field.
    pub cli_version: String,
    /// Feature flags reported by the peer.
    pub capabilities: ServerCapabilities,
    /// Session identifier reported by the peer.
    pub session_id: String,
}
/// Boolean feature flags received in [`InitResponse::capabilities`].
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct ServerCapabilities {
    /// `streaming` capability flag.
    pub streaming: bool,
    /// `tools` capability flag.
    pub tools: bool,
    /// `mcp` capability flag.
    pub mcp: bool,
}
/// Bookkeeping for a request that is waiting for its `Success`/`Error` response.
struct PendingRequest {
    /// Completed by `ProtocolHandler::handle_response` when a response with the
    /// matching id arrives.
    response_tx: oneshot::Sender<ControlResponse>,
}
/// State for the control protocol: request-id allocation, the table of
/// requests awaiting a response, the init-handshake flag, and optional
/// channels that receive hook and permission messages.
///
/// NOTE(review): fields are `Arc`-wrapped but this type does not implement
/// `Clone` here — confirm whether shared ownership is used elsewhere.
pub struct ProtocolHandler {
    /// Monotonic counter used to mint ids of the form `req-N`.
    next_request_id: Arc<AtomicU64>,
    /// Requests registered via `send_request`, keyed by id, removed when their
    /// `Success`/`Error` response is handled.
    pending_requests: Arc<Mutex<HashMap<RequestId, PendingRequest>>>,
    /// True once the init response was accepted (or set via `set_initialized`).
    initialized: Arc<AtomicBool>,
    /// Destination for `ControlResponse::Hook` messages; `None` drops them.
    hook_tx: Option<mpsc::UnboundedSender<(String, HookEvent)>>,
    /// Destination for `ControlResponse::Permission` messages; `None` drops them.
    permission_tx: Option<mpsc::UnboundedSender<(RequestId, PermissionRequest)>>,
}
impl ProtocolHandler {
    /// Protocol version this handler advertises in its init request and
    /// requires the peer to report in its init response.
    const PROTOCOL_VERSION: &'static str = "1.0";

    /// Creates a handler with no pending requests, no hook or permission
    /// channels, and the initialized flag cleared.
    ///
    /// Generated request ids start at `req-1`.
    pub fn new() -> Self {
        Self {
            next_request_id: Arc::new(AtomicU64::new(1)),
            pending_requests: Arc::new(Mutex::new(HashMap::new())),
            initialized: Arc::new(AtomicBool::new(false)),
            hook_tx: None,
            permission_tx: None,
        }
    }

    /// Installs the channel that receives `(hook_id, event)` pairs taken from
    /// [`ControlResponse::Hook`] messages, replacing any previous channel.
    pub fn set_hook_channel(&mut self, tx: mpsc::UnboundedSender<(String, HookEvent)>) {
        self.hook_tx = Some(tx);
    }

    /// Installs the channel that receives `(id, request)` pairs taken from
    /// [`ControlResponse::Permission`] messages, replacing any previous channel.
    pub fn set_permission_channel(
        &mut self,
        tx: mpsc::UnboundedSender<(RequestId, PermissionRequest)>,
    ) {
        self.permission_tx = Some(tx);
    }

    /// Returns whether the handler is marked initialized.
    pub fn is_initialized(&self) -> bool {
        self.initialized.load(Ordering::SeqCst)
    }

    /// Sets the initialized flag directly, bypassing the version check in
    /// [`handle_init_response`](Self::handle_init_response).
    pub fn set_initialized(&self, value: bool) {
        self.initialized.store(value, Ordering::SeqCst);
    }

    /// Allocates the next request id (`req-1`, `req-2`, ...).
    ///
    /// `fetch_add` is atomic, so concurrent callers receive distinct ids.
    fn next_id(&self) -> RequestId {
        let id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
        RequestId::new(format!("req-{id}"))
    }

    /// Builds the handshake request advertising this SDK's protocol version,
    /// crate version, and capabilities (all enabled).
    pub fn create_init_request(&self) -> InitRequest {
        InitRequest {
            protocol_version: Self::PROTOCOL_VERSION.to_string(),
            sdk_version: crate::VERSION.to_string(),
            capabilities: ClientCapabilities {
                bidirectional: true,
                hooks: true,
                permissions: true,
                interrupts: true,
            },
        }
    }

    /// Validates the peer's handshake reply and marks the handler initialized.
    ///
    /// # Errors
    ///
    /// Returns a protocol error, leaving the initialized flag unchanged, if the
    /// peer reports a protocol version other than `1.0`.
    pub fn handle_init_response(&self, response: InitResponse) -> Result<()> {
        if response.protocol_version != Self::PROTOCOL_VERSION {
            return Err(ClaudeError::protocol_error(format!(
                "Unsupported protocol version: {}",
                response.protocol_version
            )));
        }
        self.initialized.store(true, Ordering::SeqCst);
        Ok(())
    }

    /// Registers `request` as awaiting a reply and returns the receiver that
    /// [`handle_response`](Self::handle_response) completes when a `Success` or
    /// `Error` response with the same id arrives.
    ///
    /// This only records the request; it does not write anything to a
    /// transport.
    ///
    /// # Errors
    ///
    /// Returns a protocol error if the handler is not initialized, or if a
    /// request with the same id is already pending with a live receiver.
    /// Previously the old entry was silently overwritten, dropping the earlier
    /// waiter's sender. A stale entry whose receiver was dropped is replaced.
    pub async fn send_request(
        &self,
        request: ControlRequest,
    ) -> Result<oneshot::Receiver<ControlResponse>> {
        if !self.is_initialized() {
            return Err(ClaudeError::protocol_error(
                "Protocol not initialized - call init first",
            ));
        }
        let id = self.get_request_id(&request);
        let (response_tx, response_rx) = oneshot::channel();
        let mut pending_requests = self.pending_requests.lock().await;
        if matches!(
            pending_requests.get(&id),
            Some(existing) if !existing.response_tx.is_closed()
        ) {
            return Err(ClaudeError::protocol_error(format!(
                "Request {} is already pending",
                id.as_str()
            )));
        }
        pending_requests.insert(id, PendingRequest { response_tx });
        Ok(response_rx)
    }

    /// Returns a clone of the `id` field common to every request variant.
    fn get_request_id(&self, request: &ControlRequest) -> RequestId {
        match request {
            ControlRequest::Interrupt { id }
            | ControlRequest::SendMessage { id, .. }
            | ControlRequest::HookResponse { id, .. }
            | ControlRequest::PermissionResponse { id, .. } => id.clone(),
        }
    }

    /// Dispatches an incoming [`ControlResponse`].
    ///
    /// * `Success` / `Error` complete the pending request with the same id.
    ///   Responses for unknown ids, or whose receiver was dropped, are ignored.
    /// * `Hook` is forwarded to the hook channel if one is installed.
    /// * `Permission` is forwarded to the permission channel if one is installed.
    ///
    /// Hook and permission messages are dropped when no channel is installed.
    ///
    /// # Errors
    ///
    /// Returns a protocol error if the relevant hook or permission channel is
    /// installed but its receiver has been dropped.
    pub async fn handle_response(&self, response: ControlResponse) -> Result<()> {
        match response {
            ControlResponse::Success { ref id, .. } | ControlResponse::Error { ref id, .. } => {
                // Take the entry out and release the lock before completing the
                // oneshot, so the map is not held locked longer than needed.
                let pending = self.pending_requests.lock().await.remove(id);
                if let Some(pending) = pending {
                    // The requester may have stopped waiting; that is not an error.
                    let _ = pending.response_tx.send(response);
                }
                Ok(())
            }
            ControlResponse::Hook { id, event } => {
                // Move the payload instead of copying out of a borrow, so this
                // does not require `HookEvent: Copy`.
                if let Some(tx) = &self.hook_tx {
                    tx.send((id, event))
                        .map_err(|_| ClaudeError::protocol_error("Hook channel closed"))?;
                }
                Ok(())
            }
            ControlResponse::Permission { id, request } => {
                if let Some(tx) = &self.permission_tx {
                    tx.send((id, request))
                        .map_err(|_| ClaudeError::protocol_error("Permission channel closed"))?;
                }
                Ok(())
            }
        }
    }

    /// Builds an `Interrupt` request with a freshly allocated id.
    pub fn create_interrupt_request(&self) -> ControlRequest {
        ControlRequest::Interrupt {
            id: self.next_id(),
        }
    }

    /// Builds a `SendMessage` request carrying `content`, with a fresh id.
    pub fn create_send_message_request(&self, content: String) -> ControlRequest {
        ControlRequest::SendMessage {
            id: self.next_id(),
            content,
        }
    }

    /// Builds a `HookResponse` replying to `hook_id`, with a fresh id.
    pub fn create_hook_response(
        &self,
        hook_id: String,
        response: serde_json::Value,
    ) -> ControlRequest {
        ControlRequest::HookResponse {
            id: self.next_id(),
            hook_id,
            response,
        }
    }

    /// Builds a `PermissionResponse` answering `request_id` with `result`.
    /// The message itself gets a fresh id.
    pub fn create_permission_response(
        &self,
        request_id: RequestId,
        result: PermissionResult,
    ) -> ControlRequest {
        ControlRequest::PermissionResponse {
            id: self.next_id(),
            request_id,
            result,
        }
    }

    /// Serializes `message` to JSON followed by a single `\n`.
    ///
    /// # Errors
    ///
    /// Returns a JSON-encode error if serialization fails.
    pub fn serialize_message(&self, message: &ControlMessage) -> Result<String> {
        serde_json::to_string(message)
            .map(|s| format!("{s}\n"))
            .map_err(|e| ClaudeError::json_encode(format!("Failed to serialize message: {e}")))
    }

    /// Parses one JSON document into a [`ControlMessage`].
    ///
    /// # Errors
    ///
    /// Returns a JSON-decode error for malformed JSON or an unrecognized or
    /// incomplete message shape.
    pub fn deserialize_message(&self, json: &str) -> Result<ControlMessage> {
        serde_json::from_str(json)
            .map_err(|e| ClaudeError::json_decode(format!("Failed to deserialize message: {e}")))
    }
}
impl Default for ProtocolHandler {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{PermissionResultAllow, ToolName, ToolPermissionContext};

    /// Minimal permission request for tests that only need *some* request.
    fn sample_permission_request() -> PermissionRequest {
        PermissionRequest {
            tool_name: ToolName::new("test"),
            tool_input: serde_json::json!({}),
            context: ToolPermissionContext {
                suggestions: vec![],
            },
        }
    }

    /// An "allow" result that updates neither the input nor the permissions.
    fn allow_result() -> PermissionResult {
        PermissionResult::Allow(PermissionResultAllow {
            updated_input: None,
            updated_permissions: None,
        })
    }

    /// Init response that varies only in the reported protocol version.
    fn init_response(protocol_version: &str) -> InitResponse {
        InitResponse {
            protocol_version: protocol_version.to_string(),
            cli_version: "1.0.0".to_string(),
            capabilities: ServerCapabilities {
                streaming: true,
                tools: true,
                mcp: true,
            },
            session_id: "test".to_string(),
        }
    }

    /// Serializes through the handler and re-parses as a generic JSON value.
    fn to_json(handler: &ProtocolHandler, message: &ControlMessage) -> serde_json::Value {
        let serialized = handler.serialize_message(message).unwrap();
        serde_json::from_str(serialized.trim()).unwrap()
    }

    // ---- id allocation & handshake -------------------------------------

    #[test]
    fn test_request_id_generation() {
        let handler = ProtocolHandler::new();
        let id1 = handler.next_id();
        let id2 = handler.next_id();
        assert_ne!(id1, id2);
    }

    #[test]
    fn test_init_request_creation() {
        let handler = ProtocolHandler::new();
        let init_req = handler.create_init_request();
        assert_eq!(init_req.protocol_version, "1.0");
        assert!(init_req.capabilities.bidirectional);
    }

    #[test]
    fn test_init_response_with_wrong_version() {
        let handler = ProtocolHandler::new();
        let result = handler.handle_init_response(init_response("999.0"));
        assert!(result.is_err());
        assert!(!handler.is_initialized());
    }

    #[test]
    fn test_init_response_with_supported_version() {
        let handler = ProtocolHandler::new();
        assert!(handler.handle_init_response(init_response("1.0")).is_ok());
        assert!(handler.is_initialized());
    }

    // ---- (de)serialization ---------------------------------------------

    #[test]
    fn test_serialize_deserialize() {
        let handler = ProtocolHandler::new();
        let request = handler.create_interrupt_request();
        let message = ControlMessage::Request(request);
        let serialized = handler.serialize_message(&message).unwrap();
        let deserialized = handler.deserialize_message(serialized.trim()).unwrap();
        match deserialized {
            ControlMessage::Request(ControlRequest::Interrupt { .. }) => {}
            _ => panic!("Wrong message type"),
        }
    }

    #[test]
    fn test_serialized_message_is_newline_terminated() {
        let handler = ProtocolHandler::new();
        let message = ControlMessage::Request(handler.create_interrupt_request());
        let serialized = handler.serialize_message(&message).unwrap();
        assert!(serialized.ends_with('\n'));
        assert_eq!(serialized.matches('\n').count(), 1);
    }

    #[test]
    fn test_serialized_message_wire_format() {
        let handler = ProtocolHandler::new();

        let request = ControlMessage::Request(ControlRequest::SendMessage {
            id: RequestId::new("req-9"),
            content: "hello".to_string(),
        });
        let value = to_json(&handler, &request);
        assert_eq!(value["type"], "request");
        assert_eq!(value["method"], "send_message");
        assert_eq!(value["params"]["content"], "hello");

        let response = ControlMessage::Response(ControlResponse::Error {
            id: RequestId::new("req-9"),
            message: "boom".to_string(),
            code: None,
        });
        let value = to_json(&handler, &response);
        assert_eq!(value["type"], "response");
        assert_eq!(value["status"], "error");
        assert_eq!(value["message"], "boom");

        let init = ControlMessage::InitResponse(init_response("1.0"));
        assert_eq!(to_json(&handler, &init)["type"], "init_response");
    }

    #[test]
    fn test_deserialize_invalid_json() {
        let handler = ProtocolHandler::new();
        let result = handler.deserialize_message("not valid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_invalid_message_structure() {
        let handler = ProtocolHandler::new();
        let invalid = r#"{"type":"unknown_type"}"#;
        let result = handler.deserialize_message(invalid);
        assert!(result.is_err());
    }

    #[test]
    fn test_deserialize_missing_fields() {
        let handler = ProtocolHandler::new();
        let missing = r#"{"type":"request"}"#;
        let result = handler.deserialize_message(missing);
        assert!(result.is_err());
    }

    #[test]
    fn test_serialize_all_request_types() {
        let handler = ProtocolHandler::new();
        let requests = [
            handler.create_interrupt_request(),
            handler.create_send_message_request("test".to_string()),
            handler.create_hook_response("hook-1".to_string(), serde_json::json!({})),
            handler.create_permission_response(RequestId::new("req-1"), allow_result()),
        ];
        for req in requests {
            let msg = ControlMessage::Request(req);
            assert!(handler.serialize_message(&msg).is_ok());
        }
    }

    #[test]
    fn test_serialize_all_response_types() {
        let handler = ProtocolHandler::new();
        let responses = [
            ControlResponse::Success {
                id: RequestId::new("req-1"),
                data: Some(serde_json::json!({"result": "ok"})),
            },
            ControlResponse::Error {
                id: RequestId::new("req-1"),
                message: "test error".to_string(),
                code: Some("ERR_TEST".to_string()),
            },
            ControlResponse::Hook {
                id: "hook-1".to_string(),
                event: HookEvent::PreToolUse,
            },
            ControlResponse::Permission {
                id: RequestId::new("perm-1"),
                request: sample_permission_request(),
            },
        ];
        for resp in responses {
            let msg = ControlMessage::Response(resp);
            assert!(handler.serialize_message(&msg).is_ok());
        }
    }

    #[test]
    fn test_get_request_id() {
        let handler = ProtocolHandler::new();
        let interrupt = ControlRequest::Interrupt {
            id: RequestId::new("id1"),
        };
        assert_eq!(handler.get_request_id(&interrupt).as_str(), "id1");
        let send_msg = ControlRequest::SendMessage {
            id: RequestId::new("id2"),
            content: "test".to_string(),
        };
        assert_eq!(handler.get_request_id(&send_msg).as_str(), "id2");
        let hook_resp = ControlRequest::HookResponse {
            id: RequestId::new("id3"),
            hook_id: "hook".to_string(),
            response: serde_json::json!({}),
        };
        assert_eq!(handler.get_request_id(&hook_resp).as_str(), "id3");
        let perm_resp = ControlRequest::PermissionResponse {
            id: RequestId::new("id4"),
            request_id: RequestId::new("perm"),
            result: allow_result(),
        };
        assert_eq!(handler.get_request_id(&perm_resp).as_str(), "id4");
    }

    // ---- request/response routing --------------------------------------

    #[tokio::test]
    async fn test_send_request_without_init() {
        let handler = ProtocolHandler::new();
        assert!(!handler.is_initialized());
        let request = handler.create_interrupt_request();
        let result = handler.send_request(request).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_success_response_delivered_to_pending_request() {
        let handler = ProtocolHandler::new();
        handler.set_initialized(true);
        let request = handler.create_interrupt_request();
        let id = handler.get_request_id(&request);
        let rx = handler.send_request(request).await.unwrap();

        let response = ControlResponse::Success {
            id: id.clone(),
            data: None,
        };
        handler.handle_response(response).await.unwrap();

        match rx.await.expect("sender dropped without a response") {
            ControlResponse::Success { id: got, data } => {
                assert_eq!(got, id);
                assert!(data.is_none());
            }
            other => panic!("unexpected response: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_error_response_delivered_to_pending_request() {
        let handler = ProtocolHandler::new();
        handler.set_initialized(true);
        let request = handler.create_send_message_request("hi".to_string());
        let id = handler.get_request_id(&request);
        let rx = handler.send_request(request).await.unwrap();

        let response = ControlResponse::Error {
            id: id.clone(),
            message: "boom".to_string(),
            code: None,
        };
        handler.handle_response(response).await.unwrap();

        match rx.await.expect("sender dropped without a response") {
            ControlResponse::Error { id: got, message, .. } => {
                assert_eq!(got, id);
                assert_eq!(message, "boom");
            }
            other => panic!("unexpected response: {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_handle_response_with_missing_pending_request() {
        let handler = ProtocolHandler::new();
        handler.set_initialized(true);
        let response = ControlResponse::Success {
            id: RequestId::new("non-existent-req"),
            data: None,
        };
        let result = handler.handle_response(response).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_hook_response_without_channel() {
        let handler = ProtocolHandler::new();
        let response = ControlResponse::Hook {
            id: "hook-1".to_string(),
            event: HookEvent::PreToolUse,
        };
        let result = handler.handle_response(response).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_hook_event_forwarded_to_channel() {
        let mut handler = ProtocolHandler::new();
        let (tx, mut rx) = mpsc::unbounded_channel();
        handler.set_hook_channel(tx);

        let response = ControlResponse::Hook {
            id: "hook-7".to_string(),
            event: HookEvent::PreToolUse,
        };
        handler.handle_response(response).await.unwrap();

        let (id, event) = rx.recv().await.expect("hook was not forwarded");
        assert_eq!(id, "hook-7");
        assert!(matches!(event, HookEvent::PreToolUse));
    }

    #[tokio::test]
    async fn test_hook_channel_closed_is_error() {
        let mut handler = ProtocolHandler::new();
        let (tx, rx) = mpsc::unbounded_channel();
        handler.set_hook_channel(tx);
        drop(rx);

        let response = ControlResponse::Hook {
            id: "hook-1".to_string(),
            event: HookEvent::PreToolUse,
        };
        assert!(handler.handle_response(response).await.is_err());
    }

    #[tokio::test]
    async fn test_permission_response_without_channel() {
        let handler = ProtocolHandler::new();
        let response = ControlResponse::Permission {
            id: RequestId::new("perm-1"),
            request: sample_permission_request(),
        };
        let result = handler.handle_response(response).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_permission_request_forwarded_to_channel() {
        let mut handler = ProtocolHandler::new();
        let (tx, mut rx) = mpsc::unbounded_channel();
        handler.set_permission_channel(tx);

        let response = ControlResponse::Permission {
            id: RequestId::new("perm-1"),
            request: sample_permission_request(),
        };
        handler.handle_response(response).await.unwrap();

        let (id, _request) = rx.recv().await.expect("permission was not forwarded");
        assert_eq!(id.as_str(), "perm-1");
    }
}