use std::sync::Arc;
use std::time::Duration;
use tokio::sync::Mutex;
use uuid::Uuid;
use crate::control::handlers::ControlHandlers;
use crate::control::messages::{ControlRequest, ControlResponse, IncomingControlRequest};
use crate::control::pending::PendingRequests;
use crate::error::ClawError;
use crate::options::ClaudeAgentOptions;
use crate::transport::Transport;
pub mod handlers;
pub mod messages;
pub mod pending;
pub struct ControlProtocol {
transport: Arc<dyn Transport>,
pending: PendingRequests,
handlers: Arc<Mutex<ControlHandlers>>,
}
impl ControlProtocol {
pub fn new(transport: Arc<dyn Transport>) -> Self {
Self {
transport,
pending: PendingRequests::new(),
handlers: Arc::new(Mutex::new(ControlHandlers::new())),
}
}
pub async fn handlers(&self) -> tokio::sync::MutexGuard<'_, ControlHandlers> {
self.handlers.lock().await
}
pub async fn initialize(&self, options: &ClaudeAgentOptions) -> Result<(), ClawError> {
let request = ControlRequest::Initialize {
hooks: options.hooks.clone(),
agents: options.agents.clone(),
sdk_mcp_servers: options
.sdk_mcp_servers
.iter()
.map(|s| s.name.clone())
.collect(),
};
match self.request(request).await? {
ControlResponse::Success { .. } => Ok(()),
ControlResponse::Error { error, .. } => Err(ClawError::ControlError(format!(
"Initialization failed: {}",
error
))),
}
}
pub async fn request(&self, request: ControlRequest) -> Result<ControlResponse, ClawError> {
let id = Uuid::new_v4().to_string();
let (tx, rx) = tokio::sync::oneshot::channel();
self.pending.insert(id.clone(), tx).await;
let msg = serde_json::json!({
"type": "control_request",
"request_id": id,
"request": request,
});
let mut bytes = match serde_json::to_vec(&msg) {
Ok(b) => b,
Err(e) => {
self.pending.cancel(&id).await;
return Err(e.into());
}
};
bytes.push(b'\n');
if let Err(e) = self.transport.write(&bytes).await {
self.pending.cancel(&id).await;
return Err(ClawError::Connection(format!(
"Failed to send control request: {}",
e
)));
}
match tokio::time::timeout(Duration::from_secs(60), rx).await {
Ok(Ok(response)) => Ok(response),
Ok(Err(_)) => Err(ClawError::ControlError(
"Response channel closed".to_string(),
)),
Err(_) => {
self.pending.cancel(&id).await;
Err(ClawError::ControlTimeout {
subtype: "control_request".to_string(),
})
}
}
}
pub async fn handle_response(&self, request_id: &str, response: ControlResponse) {
self.pending.complete(request_id, response).await;
}
pub async fn handle_incoming(&self, request_id: &str, request: IncomingControlRequest) {
use serde_json::json;
use tracing::error;
let response = match request {
IncomingControlRequest::CanUseTool {
tool_name,
tool_input,
} => {
let handler = {
let handlers = self.handlers.lock().await;
handlers.can_use_tool.clone()
};
if let Some(handler) = handler {
match handler.can_use_tool(&tool_name, &tool_input).await {
Ok(decision) => {
use crate::permissions::PermissionDecision;
match decision {
PermissionDecision::Allow { updated_input } => {
let mut data = json!({ "allowed": true });
if let Some(input) = updated_input {
data["updatedInput"] = input;
}
ControlResponse::Success { data }
}
PermissionDecision::Deny { interrupt } => {
ControlResponse::Success {
data: json!({
"allowed": false,
"interrupt": interrupt,
}),
}
}
}
}
Err(e) => ControlResponse::Error {
error: e.to_string(),
extra: json!({}),
},
}
} else {
ControlResponse::Success {
data: json!({ "allowed": true }),
}
}
}
IncomingControlRequest::HookCallback {
hook_id,
hook_event,
hook_input,
} => {
let handler = {
let handlers = self.handlers.lock().await;
handlers.hook_callbacks.get(&hook_id).cloned()
};
if let Some(handler) = handler {
match handler.call(hook_event, hook_input).await {
Ok(result) => ControlResponse::Success { data: result },
Err(e) => ControlResponse::Error {
error: e.to_string(),
extra: json!({}),
},
}
} else {
ControlResponse::Error {
error: format!("No handler registered for hook_id: {}", hook_id),
extra: json!({}),
}
}
}
IncomingControlRequest::McpMessage {
server_name,
message,
} => {
let handler = {
let handlers = self.handlers.lock().await;
handlers.mcp_message.clone()
};
if let Some(handler) = handler {
match handler.handle(&server_name, message).await {
Ok(result) => ControlResponse::Success {
data: json!({"mcp_response": result}),
},
Err(e) => ControlResponse::Error {
error: e.to_string(),
extra: json!({}),
},
}
} else {
ControlResponse::Error {
error: "No MCP message handler registered".to_string(),
extra: json!({}),
}
}
}
};
let msg = match response {
ControlResponse::Success { data } => {
json!({
"type": "control_response",
"response": {
"subtype": "success",
"request_id": request_id,
"response": data,
}
})
}
ControlResponse::Error { error, extra } => {
let mut resp = json!({
"type": "control_response",
"response": {
"subtype": "error",
"request_id": request_id,
"error": error,
}
});
if let (Some(resp_obj), Some(extra_obj)) =
(resp["response"].as_object_mut(), extra.as_object())
{
for (k, v) in extra_obj {
resp_obj.insert(k.clone(), v.clone());
}
}
resp
}
};
match serde_json::to_vec(&msg) {
Ok(mut bytes) => {
bytes.push(b'\n'); if let Err(e) = self.transport.write(&bytes).await {
error!("Failed to send control response: {}", e);
}
}
Err(e) => {
error!("Failed to serialize control response: {}", e);
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::control::handlers::{CanUseToolHandler, HookHandler, McpMessageHandler};
    use crate::options::HookEvent;
    use async_trait::async_trait;
    use serde_json::{Value, json};
    use std::sync::Arc;
    use tokio::sync::Mutex as TokioMutex;
    use tokio::sync::mpsc;

    /// Transport double that records every written frame.
    ///
    /// The message receiver's sender is dropped in `new`, so the channel is
    /// already closed when handed out.
    // The nested Arc/Mutex/Option receiver type is test-only plumbing; the
    // lint is silenced rather than introducing an alias used once.
    #[allow(clippy::type_complexity)]
    struct MockTransport {
        sent: Arc<TokioMutex<Vec<Vec<u8>>>>,
        receiver: Arc<TokioMutex<Option<mpsc::UnboundedReceiver<Result<Value, ClawError>>>>>,
    }

    impl MockTransport {
        fn new() -> Self {
            let (_sender, receiver) = mpsc::unbounded_channel();
            Self {
                sent: Arc::new(TokioMutex::new(Vec::new())),
                receiver: Arc::new(TokioMutex::new(Some(receiver))),
            }
        }

        /// Returns a copy of every frame written so far, in write order.
        async fn get_sent(&self) -> Vec<Vec<u8>> {
            self.sent.lock().await.clone()
        }
    }

    #[async_trait]
    impl Transport for MockTransport {
        async fn connect(&mut self) -> Result<(), ClawError> {
            Ok(())
        }

        async fn write(&self, data: &[u8]) -> Result<(), ClawError> {
            self.sent.lock().await.push(data.to_vec());
            Ok(())
        }

        fn messages(&self) -> mpsc::UnboundedReceiver<Result<Value, ClawError>> {
            // `blocking_lock` panics when called from inside a Tokio runtime,
            // which is where every test here runs; `try_lock` works anywhere.
            self.receiver
                .try_lock()
                .expect("MockTransport receiver lock is never contended")
                .take()
                .expect("MockTransport::messages may only be called once")
        }

        async fn end_input(&self) -> Result<(), ClawError> {
            Ok(())
        }

        async fn close(&self) -> Result<(), ClawError> {
            Ok(())
        }

        fn is_ready(&self) -> bool {
            true
        }
    }

    /// Allows `Read`, denies every other tool without interrupting.
    #[derive(Debug)]
    struct MockCanUseToolHandler;

    #[async_trait]
    impl CanUseToolHandler for MockCanUseToolHandler {
        async fn can_use_tool(
            &self,
            tool_name: &str,
            _tool_input: &Value,
        ) -> Result<crate::permissions::PermissionDecision, ClawError> {
            if tool_name == "Read" {
                Ok(crate::permissions::PermissionDecision::Allow {
                    updated_input: None,
                })
            } else {
                Ok(crate::permissions::PermissionDecision::Deny { interrupt: false })
            }
        }
    }

    /// Echoes the hook input back under `echo`.
    struct MockHookHandler;

    #[async_trait]
    impl HookHandler for MockHookHandler {
        async fn call(
            &self,
            _hook_event: HookEvent,
            hook_input: Value,
        ) -> Result<Value, ClawError> {
            Ok(json!({ "echo": hook_input }))
        }
    }

    /// Replies with the name of the server the message was addressed to.
    struct MockMcpHandler;

    #[async_trait]
    impl McpMessageHandler for MockMcpHandler {
        async fn handle(&self, server_name: &str, _message: Value) -> Result<Value, ClawError> {
            Ok(json!({ "server": server_name }))
        }
    }

    /// Spawns a task that waits for the first frame written to `transport`,
    /// reads its `request_id`, and completes that request with `response`.
    fn spawn_responder(
        transport: Arc<MockTransport>,
        control: Arc<ControlProtocol>,
        response: ControlResponse,
    ) {
        tokio::spawn(async move {
            // Poll rather than sleeping once: giving up early would leave the
            // caller blocked until the request timeout expires.
            let first = loop {
                if let Some(frame) = transport.get_sent().await.into_iter().next() {
                    break frame;
                }
                tokio::time::sleep(tokio::time::Duration::from_millis(5)).await;
            };
            let msg: Value = serde_json::from_slice(&first).unwrap();
            let request_id = msg["request_id"].as_str().unwrap().to_string();
            control.handle_response(&request_id, response).await;
        });
    }

    /// Asserts exactly one frame was written and returns it parsed as JSON.
    async fn only_frame(transport: &MockTransport) -> Value {
        let sent = transport.get_sent().await;
        assert_eq!(sent.len(), 1, "expected exactly one frame to be written");
        serde_json::from_slice(&sent[0]).unwrap()
    }

    #[tokio::test]
    async fn test_request_success() {
        let transport = Arc::new(MockTransport::new());
        let control = Arc::new(ControlProtocol::new(
            transport.clone() as Arc<dyn Transport>
        ));
        spawn_responder(
            transport.clone(),
            control.clone(),
            ControlResponse::Success {
                data: json!({ "result": "ok" }),
            },
        );

        let response = control
            .request(ControlRequest::Interrupt)
            .await
            .unwrap();
        match response {
            ControlResponse::Success { data } => assert_eq!(data["result"], "ok"),
            ControlResponse::Error { error, .. } => {
                panic!("Expected success response, got error: {}", error)
            }
        }
    }

    #[tokio::test]
    async fn test_initialize_success() {
        let transport = Arc::new(MockTransport::new());
        let control = Arc::new(ControlProtocol::new(
            transport.clone() as Arc<dyn Transport>
        ));
        spawn_responder(
            transport.clone(),
            control.clone(),
            ControlResponse::Success { data: json!({}) },
        );

        control
            .initialize(&ClaudeAgentOptions::default())
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_initialize_error() {
        let transport = Arc::new(MockTransport::new());
        let control = Arc::new(ControlProtocol::new(
            transport.clone() as Arc<dyn Transport>
        ));
        spawn_responder(
            transport.clone(),
            control.clone(),
            ControlResponse::Error {
                error: "Bad config".to_string(),
                extra: json!({}),
            },
        );

        let err = control
            .initialize(&ClaudeAgentOptions::default())
            .await
            .expect_err("initialize should surface the error response");
        assert!(err.to_string().contains("Initialization failed"));
    }

    #[tokio::test]
    async fn test_handle_incoming_can_use_tool_with_handler() {
        let transport = Arc::new(MockTransport::new());
        let control = ControlProtocol::new(transport.clone() as Arc<dyn Transport>);
        control
            .handlers()
            .await
            .register_can_use_tool(Arc::new(MockCanUseToolHandler));

        let request = IncomingControlRequest::CanUseTool {
            tool_name: "Read".to_string(),
            tool_input: json!({}),
        };
        control.handle_incoming("req_1", request).await;

        let msg = only_frame(&transport).await;
        assert_eq!(msg["type"], "control_response");
        assert!(
            msg.get("request_id").is_none(),
            "request_id should NOT be at top level"
        );
        assert_eq!(msg["response"]["subtype"], "success");
        assert_eq!(msg["response"]["request_id"], "req_1");
        assert_eq!(msg["response"]["response"]["allowed"], true);
    }

    #[tokio::test]
    async fn test_handle_incoming_can_use_tool_denied() {
        let transport = Arc::new(MockTransport::new());
        let control = ControlProtocol::new(transport.clone() as Arc<dyn Transport>);
        control
            .handlers()
            .await
            .register_can_use_tool(Arc::new(MockCanUseToolHandler));

        let request = IncomingControlRequest::CanUseTool {
            tool_name: "Bash".to_string(),
            tool_input: json!({}),
        };
        control.handle_incoming("req_1", request).await;

        let msg = only_frame(&transport).await;
        assert_eq!(msg["response"]["subtype"], "success");
        assert_eq!(msg["response"]["response"]["allowed"], false);
        assert_eq!(msg["response"]["response"]["interrupt"], false);
    }

    #[tokio::test]
    async fn test_handle_incoming_can_use_tool_default() {
        let transport = Arc::new(MockTransport::new());
        let control = ControlProtocol::new(transport.clone() as Arc<dyn Transport>);

        let request = IncomingControlRequest::CanUseTool {
            tool_name: "Bash".to_string(),
            tool_input: json!({}),
        };
        control.handle_incoming("req_1", request).await;

        let msg = only_frame(&transport).await;
        assert_eq!(msg["response"]["response"]["allowed"], true);
    }

    #[tokio::test]
    async fn test_handle_incoming_hook_callback() {
        let transport = Arc::new(MockTransport::new());
        let control = ControlProtocol::new(transport.clone() as Arc<dyn Transport>);
        control
            .handlers()
            .await
            .register_hook("hook1".to_string(), Arc::new(MockHookHandler));

        let request = IncomingControlRequest::HookCallback {
            hook_id: "hook1".to_string(),
            hook_event: HookEvent::PreToolUse,
            hook_input: json!({ "test": "data" }),
        };
        control.handle_incoming("req_1", request).await;

        let msg = only_frame(&transport).await;
        assert_eq!(msg["response"]["subtype"], "success");
        assert_eq!(msg["response"]["response"]["echo"]["test"], "data");
    }

    #[tokio::test]
    async fn test_handle_incoming_hook_callback_unknown_id() {
        let transport = Arc::new(MockTransport::new());
        let control = ControlProtocol::new(transport.clone() as Arc<dyn Transport>);

        let request = IncomingControlRequest::HookCallback {
            hook_id: "missing".to_string(),
            hook_event: HookEvent::PreToolUse,
            hook_input: json!({}),
        };
        control.handle_incoming("req_2", request).await;

        let msg = only_frame(&transport).await;
        assert_eq!(msg["response"]["subtype"], "error");
        assert_eq!(msg["response"]["request_id"], "req_2");
        assert_eq!(
            msg["response"]["error"],
            "No handler registered for hook_id: missing"
        );
    }

    #[tokio::test]
    async fn test_handle_incoming_mcp_message() {
        let transport = Arc::new(MockTransport::new());
        let control = ControlProtocol::new(transport.clone() as Arc<dyn Transport>);
        control
            .handlers()
            .await
            .register_mcp_message(Arc::new(MockMcpHandler));

        let request = IncomingControlRequest::McpMessage {
            server_name: "test_server".to_string(),
            message: json!({ "method": "test" }),
        };
        control.handle_incoming("req_1", request).await;

        let msg = only_frame(&transport).await;
        assert_eq!(msg["response"]["subtype"], "success");
        assert_eq!(msg["response"]["request_id"], "req_1");
        assert_eq!(
            msg["response"]["response"]["mcp_response"]["server"],
            "test_server"
        );
    }
}