use serde_json::Value;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::task::{Context, Poll};
use tokio::sync::Mutex;
use tokio::sync::mpsc;
use tokio_stream::Stream;
use crate::control::ControlProtocol;
use crate::control::handlers::{CanUseToolHandler, HookHandler, McpMessageHandler};
use crate::error::ClawError;
use crate::messages::Message;
use crate::options::{ClaudeAgentOptions, PermissionMode};
use crate::transport::Transport;
/// Shared, lockable slot holding the sender for the turn that currently receives messages.
///
/// `send_message` installs a fresh sender here, and the message router forwards every
/// non-control message (and transport error) to whichever sender is present.
type CurrentTurnSender = Arc<Mutex<Option<mpsc::UnboundedSender<Result<Value, ClawError>>>>>;
/// Client that owns a transport plus its control protocol and exposes turn-based messaging.
///
/// A freshly constructed client holds no transport or control protocol; `connect`
/// populates them and `close` clears them again.
pub struct ClaudeClient {
/// Control protocol bound to the active transport; `None` until `connect` succeeds.
control: Option<Arc<ControlProtocol>>,
/// Active transport; `None` until `connect` succeeds and again after `close`.
transport: Option<Arc<dyn Transport>>,
/// Options used to build the CLI arguments and passed to `ControlProtocol::initialize`.
options: ClaudeAgentOptions,
/// Slot shared with the message router that holds the current turn's sender.
current_turn_tx: CurrentTurnSender,
/// Set to `true` at the end of `connect` and back to `false` by `close`.
is_initialized: Arc<AtomicBool>,
/// Transport supplied through `with_transport`; consumed by the next `connect` call.
pre_transport: Option<Box<dyn Transport>>,
/// MCP handler registered before a control protocol exists; installed during `connect`.
pending_mcp_handler: std::sync::Mutex<Option<Arc<dyn McpMessageHandler>>>,
}
impl ClaudeClient {
pub fn new(options: ClaudeAgentOptions) -> Result<Self, ClawError> {
Ok(Self {
control: None,
transport: None,
pre_transport: None,
options,
current_turn_tx: Arc::new(Mutex::new(None)),
is_initialized: Arc::new(AtomicBool::new(false)),
pending_mcp_handler: std::sync::Mutex::new(None),
})
}
pub fn with_transport(
options: ClaudeAgentOptions,
transport: Box<dyn Transport>,
) -> Result<Self, ClawError> {
Ok(Self {
control: None,
transport: None,
pre_transport: Some(transport),
options,
current_turn_tx: Arc::new(Mutex::new(None)),
is_initialized: Arc::new(AtomicBool::new(false)),
pending_mcp_handler: std::sync::Mutex::new(None),
})
}
/// Returns `true` when a transport is present, reports itself ready, and the
/// initialized flag is set.
pub fn is_connected(&self) -> bool {
    let transport_ready = match &self.transport {
        Some(transport) => transport.is_ready(),
        None => false,
    };
    // The flag is only consulted when the transport check passed (short-circuit).
    transport_ready && self.is_initialized.load(Ordering::SeqCst)
}
/// Connects the transport, starts the message router, and initializes the control protocol.
///
/// Uses the transport supplied via `with_transport` if one is pending; otherwise builds a
/// `SubprocessCLITransport` from the options (base CLI args plus
/// `--input-format stream-json`, and the optional `cwd` / `env`).
///
/// Any pending MCP handler and the options' permission handler are registered before
/// `initialize` is sent.
///
/// # Errors
///
/// Returns the error from `Transport::connect` or `ControlProtocol::initialize`. If
/// initialization fails, the already-connected transport is closed (best effort) so it
/// is not left running without a handle the caller could use to shut it down.
pub async fn connect(&mut self) -> Result<(), ClawError> {
    use crate::transport::SubprocessCLITransport;

    let mut transport: Box<dyn Transport> = if let Some(pre) = self.pre_transport.take() {
        pre
    } else {
        let mut cli_args = self.options.to_base_cli_args();
        cli_args.push("--input-format".to_string());
        cli_args.push("stream-json".to_string());

        let mut t = SubprocessCLITransport::new(self.options.cli_path.clone(), cli_args);
        if let Some(cwd) = &self.options.cwd {
            t.set_cwd(cwd.clone());
        }
        if !self.options.env.is_empty() {
            t.set_env(self.options.env.clone());
        }
        Box::new(t) as Box<dyn Transport>
    };

    transport.connect().await?;
    let message_rx = transport.messages();

    let transport_arc: Arc<dyn Transport> = Arc::from(transport);
    let control = Arc::new(ControlProtocol::new(transport_arc.clone()));
    Self::spawn_message_router(message_rx, control.clone(), self.current_turn_tx.clone());

    // Recover from a poisoned lock: the stored handler is still usable.
    let pending_mcp = self
        .pending_mcp_handler
        .lock()
        .unwrap_or_else(|e| e.into_inner())
        .take();
    if let Some(handler) = pending_mcp {
        let mut handlers = control.handlers().await;
        handlers.register_mcp_message(handler);
    }
    if let Some(handler) = self.options.permission_handler.clone() {
        let mut handlers = control.handlers().await;
        handlers.register_can_use_tool(handler);
    }

    let init_result = control.initialize(&self.options).await;
    if init_result.is_err() {
        // The client never stores this transport on failure, so close it here;
        // a secondary close error is ignored in favour of the initialize error.
        let _ = transport_arc.close().await;
    }
    init_result?;

    self.transport = Some(transport_arc);
    self.control = Some(control);
    self.is_initialized.store(true, Ordering::SeqCst);
    Ok(())
}
/// Closes the transport (if any) and resets the client to its disconnected state.
///
/// The current-turn sender, the initialized flag, the transport, and the control
/// protocol are cleared even when closing the transport fails, so a failed close
/// cannot leave the client half-connected.
///
/// # Errors
///
/// Returns the error reported by `Transport::close`, after the state has been reset.
pub async fn close(&mut self) -> Result<(), ClawError> {
    // Take the transport first so the client no longer references it whatever happens.
    let close_outcome = match self.transport.take() {
        Some(transport) => transport.close().await,
        None => Ok(()),
    };

    *self.current_turn_tx.lock().await = None;
    self.is_initialized.store(false, Ordering::SeqCst);
    self.control = None;

    close_outcome?;
    Ok(())
}
/// Alias for [`ClaudeClient::close`]: forwards the call and returns its result.
pub async fn disconnect(&mut self) -> Result<(), ClawError> {
self.close().await
}
/// Starts a new turn by writing `content` as a user message.
///
/// A fresh channel is installed as the current turn before the message is written;
/// the returned [`ResponseStream`] wraps the receiving end of that channel.
///
/// # Errors
///
/// Returns `ClawError::Connection` when no control protocol is present, or the error
/// produced while writing the message.
pub async fn send_message(
    &self,
    content: impl Into<String>,
) -> Result<ResponseStream, ClawError> {
    match self.control {
        None => Err(ClawError::Connection(
            "Not connected. Call connect() first.".to_string(),
        )),
        Some(_) => {
            let (turn_tx, turn_rx) = mpsc::unbounded_channel();
            {
                // Replace whatever sender the previous turn left behind.
                let mut slot = self.current_turn_tx.lock().await;
                *slot = Some(turn_tx);
            }
            let text: String = content.into();
            self.write_message(&text).await?;
            Ok(ResponseStream::new(turn_rx))
        }
    }
}
/// Serializes `content` as a newline-terminated JSON user message and writes it to the
/// transport.
///
/// # Errors
///
/// Returns `ClawError::Connection` when no transport is present or serialization fails,
/// or the error from the transport write.
async fn write_message(&self, content: &str) -> Result<(), ClawError> {
    use serde_json::json;

    let Some(transport) = self.transport.as_ref() else {
        return Err(ClawError::Connection(
            "Transport not available".to_string(),
        ));
    };

    let envelope = json!({
        "type": "user",
        "session_id": "",
        "message": {
            "role": "user",
            "content": content
        },
        "parent_tool_use_id": null
    });

    let mut line = match serde_json::to_vec(&envelope) {
        Ok(encoded) => encoded,
        Err(e) => {
            return Err(ClawError::Connection(format!(
                "Failed to serialize user message: {}",
                e
            )));
        }
    };
    // One JSON document per line.
    line.push(b'\n');

    transport.write(&line).await?;
    Ok(())
}
/// Spawns the background task that routes every message coming from the transport.
///
/// * `control_response` messages are parsed and handed to `ControlProtocol::handle_response`.
/// * `control_request` messages are parsed and handed to `ControlProtocol::handle_incoming`.
/// * Everything else, including transport errors, is forwarded to the current turn's
///   sender, if one is installed; otherwise it is dropped.
///
/// The task ends when `rx` yields `None`.
fn spawn_message_router(
    mut rx: mpsc::UnboundedReceiver<Result<Value, ClawError>>,
    control: Arc<ControlProtocol>,
    current_turn_tx: CurrentTurnSender,
) {
    use crate::control::messages::{ControlResponse, IncomingControlRequest};
    use tracing::{debug, warn};

    /// Sends `item` to the sender currently installed in `slot`, clearing the slot when
    /// that sender's receiver is gone.
    ///
    /// The send and the clear happen under a single lock acquisition. Sending on an
    /// unbounded channel never waits, so the lock is held only briefly, and a sender
    /// installed concurrently by `send_message` can no longer be wiped because an
    /// *older* sender failed.
    async fn forward_to_turn(slot: &CurrentTurnSender, item: Result<Value, ClawError>) {
        let mut guard = slot.lock().await;
        let receiver_gone = guard.as_ref().is_some_and(|tx| tx.send(item).is_err());
        if receiver_gone {
            *guard = None;
        }
    }

    tokio::spawn(async move {
        while let Some(msg) = rx.recv().await {
            let value = match msg {
                Ok(value) => value,
                Err(e) => {
                    forward_to_turn(&current_turn_tx, Err(e)).await;
                    continue;
                }
            };

            let msg_type = value.get("type").and_then(|v| v.as_str());
            match msg_type {
                Some("control_response") => {
                    // A missing request id falls back to the empty string.
                    let request_id = value
                        .get("response")
                        .and_then(|r| r.get("request_id"))
                        .and_then(|v| v.as_str())
                        .unwrap_or("")
                        .to_string();
                    debug!(
                        request_id = %request_id,
                        "Received control_response"
                    );
                    if let Some(response_val) = value.get("response") {
                        match serde_json::from_value::<ControlResponse>(response_val.clone()) {
                            Ok(response) => {
                                control.handle_response(&request_id, response).await;
                            }
                            Err(e) => {
                                warn!("Failed to parse control response: {}", e);
                            }
                        }
                    }
                }
                Some("control_request") => {
                    let request_id = value
                        .get("request_id")
                        .and_then(|v| v.as_str())
                        .unwrap_or("")
                        .to_string();
                    if let Some(request_val) = value.get("request") {
                        match serde_json::from_value::<IncomingControlRequest>(
                            request_val.clone(),
                        ) {
                            Ok(incoming) => {
                                control.handle_incoming(&request_id, incoming).await;
                            }
                            Err(e) => {
                                warn!(
                                    "Failed to parse incoming control request: {}",
                                    e
                                );
                            }
                        }
                    }
                }
                _ => forward_to_turn(&current_turn_tx, Ok(value)).await,
            }
        }
        debug!("Message routing task finished");
    });
}
/// Sends a `ControlRequest::Interrupt` and waits for the control response.
///
/// # Errors
///
/// Returns `ClawError::Connection` when not connected, the error from the request
/// itself, or `ClawError::ControlError` when the response is an error.
pub async fn interrupt(&self) -> Result<(), ClawError> {
    use crate::control::messages::{ControlRequest, ControlResponse};

    let Some(control) = self.control.as_ref() else {
        return Err(ClawError::Connection(
            "Not connected. Call connect() first.".to_string(),
        ));
    };

    match control.request(ControlRequest::Interrupt).await? {
        ControlResponse::Error { error, .. } => Err(ClawError::ControlError(format!(
            "Interrupt failed: {}",
            error
        ))),
        ControlResponse::Success { .. } => Ok(()),
    }
}
/// Sends a `ControlRequest::SetPermissionMode` carrying the CLI form of `mode`.
///
/// # Errors
///
/// Returns `ClawError::Connection` when not connected, the error from the request
/// itself, or `ClawError::ControlError` when the response is an error.
pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<(), ClawError> {
    use crate::control::messages::{ControlRequest, ControlResponse};

    let Some(control) = self.control.as_ref() else {
        return Err(ClawError::Connection(
            "Not connected. Call connect() first.".to_string(),
        ));
    };

    let request = ControlRequest::SetPermissionMode {
        mode: mode.to_cli_arg().to_string(),
    };
    match control.request(request).await? {
        ControlResponse::Error { error, .. } => Err(ClawError::ControlError(format!(
            "Set permission mode failed: {}",
            error
        ))),
        ControlResponse::Success { .. } => Ok(()),
    }
}
/// Sends a `ControlRequest::SetModel` with the given model name.
///
/// # Errors
///
/// Returns `ClawError::Connection` when not connected, the error from the request
/// itself, or `ClawError::ControlError` when the response is an error.
pub async fn set_model(&self, model: impl Into<String>) -> Result<(), ClawError> {
    use crate::control::messages::{ControlRequest, ControlResponse};

    let Some(control) = self.control.as_ref() else {
        return Err(ClawError::Connection(
            "Not connected. Call connect() first.".to_string(),
        ));
    };

    let request = ControlRequest::SetModel {
        model: model.into(),
    };
    match control.request(request).await? {
        ControlResponse::Error { error, .. } => Err(ClawError::ControlError(format!(
            "Set model failed: {}",
            error
        ))),
        ControlResponse::Success { .. } => Ok(()),
    }
}
/// Sends a `ControlRequest::McpStatus` and returns the `data` of a successful response.
///
/// # Errors
///
/// Returns `ClawError::Connection` when not connected, the error from the request
/// itself, or `ClawError::ControlError` when the response is an error.
pub async fn mcp_status(&self) -> Result<serde_json::Value, ClawError> {
    use crate::control::messages::{ControlRequest, ControlResponse};

    let Some(control) = self.control.as_ref() else {
        return Err(ClawError::Connection(
            "Not connected. Call connect() first.".to_string(),
        ));
    };

    match control.request(ControlRequest::McpStatus).await? {
        ControlResponse::Error { error, .. } => Err(ClawError::ControlError(format!(
            "MCP status query failed: {}",
            error
        ))),
        ControlResponse::Success { data } => Ok(data),
    }
}
/// Sends a `ControlRequest::RewindFiles` whose `user_message_id` is `message_id`.
///
/// # Errors
///
/// Returns `ClawError::Connection` when not connected, the error from the request
/// itself, or `ClawError::ControlError` when the response is an error.
pub async fn rewind_files(&self, message_id: impl Into<String>) -> Result<(), ClawError> {
    use crate::control::messages::{ControlRequest, ControlResponse};

    let Some(control) = self.control.as_ref() else {
        return Err(ClawError::Connection(
            "Not connected. Call connect() first.".to_string(),
        ));
    };

    let request = ControlRequest::RewindFiles {
        user_message_id: message_id.into(),
    };
    match control.request(request).await? {
        ControlResponse::Error { error, .. } => Err(ClawError::ControlError(format!(
            "Rewind files failed: {}",
            error
        ))),
        ControlResponse::Success { .. } => Ok(()),
    }
}
/// Sends a `ControlRequest::GetServerInfo` and returns the `data` of a successful response.
///
/// # Errors
///
/// Returns `ClawError::Connection` when not connected, the error from the request
/// itself, or `ClawError::ControlError` when the response is an error.
pub async fn get_server_info(&self) -> Result<serde_json::Value, ClawError> {
    use crate::control::messages::{ControlRequest, ControlResponse};

    let Some(control) = self.control.as_ref() else {
        return Err(ClawError::Connection(
            "Not connected. Call connect() first.".to_string(),
        ));
    };

    match control.request(ControlRequest::GetServerInfo).await? {
        ControlResponse::Error { error, .. } => Err(ClawError::ControlError(format!(
            "Get server info failed: {}",
            error
        ))),
        ControlResponse::Success { data } => Ok(data),
    }
}
/// Registers `handler` as the tool-permission handler on the control protocol.
///
/// Does nothing when no control protocol is present; the handler is dropped in that case.
pub async fn register_can_use_tool_handler(&self, handler: Arc<dyn CanUseToolHandler>) {
    let Some(control) = self.control.as_ref() else {
        return;
    };
    control.handlers().await.register_can_use_tool(handler);
}
/// Registers `handler` under `hook_id` on the control protocol.
///
/// Does nothing when no control protocol is present; the handler is dropped in that case.
pub async fn register_hook(&self, hook_id: String, handler: Arc<dyn HookHandler>) {
    let Some(control) = self.control.as_ref() else {
        return;
    };
    control.handlers().await.register_hook(hook_id, handler);
}
/// Registers `handler` for MCP messages.
///
/// When a control protocol is present the handler is registered on it immediately.
/// Otherwise it is stored as the pending handler, replacing any previously stored one,
/// so that `connect` can install it.
///
/// A poisoned pending-handler lock is recovered rather than silently discarding the
/// handler, matching how `connect` reads the same lock.
pub async fn register_mcp_message_handler(&self, handler: Arc<dyn McpMessageHandler>) {
    if let Some(control) = &self.control {
        let mut handlers = control.handlers().await;
        handlers.register_mcp_message(handler);
        return;
    }

    // Single statement: the std mutex guard is never held across an `.await`.
    *self
        .pending_mcp_handler
        .lock()
        .unwrap_or_else(|e| e.into_inner()) = Some(handler);
}
}
/// Runs `f` against a freshly connected client and always closes the client afterwards.
///
/// # Errors
///
/// Returns the construction or connection error if setup fails (in which case `f` is
/// never called). Otherwise returns the error from `f` if there is one, and the result
/// of closing the client if `f` succeeded.
pub async fn with_client<F, Fut>(options: ClaudeAgentOptions, f: F) -> Result<(), ClawError>
where
    F: FnOnce(&ClaudeClient) -> Fut,
    Fut: Future<Output = Result<(), ClawError>>,
{
    let mut client = ClaudeClient::new(options)?;
    client.connect().await?;

    let outcome = f(&client).await;
    // Close unconditionally; the user's error takes precedence over a close error.
    let shutdown = client.close().await;

    outcome.and(shutdown)
}
/// Stream of parsed [`Message`]s for a single turn, backed by an unbounded channel.
pub struct ResponseStream {
/// Receiving end of the turn channel; items are raw JSON values or errors.
rx: mpsc::UnboundedReceiver<Result<Value, ClawError>>,
/// Set once the channel has reported that it is closed and drained.
is_complete: bool,
}
impl ResponseStream {
    /// Wraps the receiving end of a turn channel in a not-yet-complete stream.
    fn new(rx: mpsc::UnboundedReceiver<Result<Value, ClawError>>) -> Self {
        let is_complete = false;
        Self { rx, is_complete }
    }

    /// Returns `true` once the underlying channel has been observed as closed.
    pub fn is_complete(&self) -> bool {
        self.is_complete
    }

    /// Collects messages until a `Message::Result` arrives or the stream ends.
    ///
    /// The `Message::Result`, when present, is the last element of the returned vector.
    ///
    /// # Errors
    ///
    /// Returns the first error yielded by the stream; messages collected before it are
    /// discarded.
    pub async fn receive_response(mut self) -> Result<Vec<Message>, ClawError> {
        use tokio_stream::StreamExt;

        let mut collected = Vec::new();
        loop {
            match self.next().await {
                None => break,
                Some(Err(e)) => return Err(e),
                Some(Ok(message)) => {
                    let reached_result = matches!(message, Message::Result(_));
                    collected.push(message);
                    if reached_result {
                        break;
                    }
                }
            }
        }
        Ok(collected)
    }
}
impl Stream for ResponseStream {
    type Item = Result<Message, ClawError>;

    /// Polls the channel and converts each raw JSON value into a [`Message`].
    ///
    /// Values that fail to deserialize are yielded as `ClawError::MessageParse` carrying
    /// the raw JSON text. Once the channel reports closure the stream is marked complete
    /// and keeps returning `None`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        if this.is_complete {
            return Poll::Ready(None);
        }

        let Poll::Ready(next) = this.rx.poll_recv(cx) else {
            return Poll::Pending;
        };

        let item = match next {
            None => {
                this.is_complete = true;
                return Poll::Ready(None);
            }
            Some(Err(e)) => Err(e),
            // The value is cloned for parsing so the original stays available for the
            // error's `raw` field.
            Some(Ok(value)) => serde_json::from_value::<Message>(value.clone()).map_err(|e| {
                ClawError::MessageParse {
                    reason: format!("Failed to parse message: {}", e),
                    raw: value.to_string(),
                }
            }),
        };
        Poll::Ready(Some(item))
    }
}
/// Alternative name for [`ClaudeClient`]; both names refer to the same type.
pub type ClaudeSDKClient = ClaudeClient;
// Unit tests for the client and response stream. None of them call `connect()`;
// they exercise construction, the not-connected error paths, trait bounds, and the
// channel plumbing directly.
#[cfg(test)]
mod tests {
use super::*;
// Construction with default options succeeds.
#[test]
fn test_new_client() {
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options);
assert!(client.is_ok());
}
// A new client reports itself as not connected.
#[test]
fn test_not_connected_initially() {
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options).unwrap();
assert!(!client.is_connected());
}
// A new stream is not marked complete while its sender is still alive.
#[test]
fn test_response_stream_not_complete_initially() {
let (_tx, rx) = mpsc::unbounded_channel();
let stream = ResponseStream::new(rx);
assert!(!stream.is_complete());
}
// Each of the following `*_without_connect` tests checks that the call fails with
// `ClawError::Connection` when `connect()` was never called.
#[tokio::test]
async fn test_send_message_without_connect() {
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options).unwrap();
let result = client.send_message("test").await;
assert!(result.is_err());
if let Err(e) = result {
assert!(matches!(e, ClawError::Connection(_)));
}
}
#[tokio::test]
async fn test_interrupt_without_connect() {
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options).unwrap();
let result = client.interrupt().await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), ClawError::Connection(_)));
}
#[tokio::test]
async fn test_set_permission_mode_without_connect() {
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options).unwrap();
let result = client.set_permission_mode(PermissionMode::Ask).await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), ClawError::Connection(_)));
}
#[tokio::test]
async fn test_set_model_without_connect() {
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options).unwrap();
let result = client.set_model("claude-sonnet-4-5").await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), ClawError::Connection(_)));
}
#[tokio::test]
async fn test_mcp_status_without_connect() {
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options).unwrap();
let result = client.mcp_status().await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), ClawError::Connection(_)));
}
#[tokio::test]
async fn test_rewind_files_without_connect() {
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options).unwrap();
let result = client.rewind_files("msg_123").await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), ClawError::Connection(_)));
}
#[tokio::test]
async fn test_get_server_info_without_connect() {
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options).unwrap();
let result = client.get_server_info().await;
assert!(result.is_err());
assert!(matches!(result.unwrap_err(), ClawError::Connection(_)));
}
// Compile-time checks: these tests pass as long as the trait bounds hold.
#[test]
fn test_client_is_send() {
fn assert_send<T: Send>() {}
assert_send::<ClaudeClient>();
}
#[test]
fn test_client_is_sync() {
fn assert_sync<T: Sync>() {}
assert_sync::<ClaudeClient>();
}
#[test]
fn test_response_stream_is_send() {
fn assert_send<T: Send>() {}
assert_send::<ResponseStream>();
}
#[test]
fn test_response_stream_is_unpin() {
fn assert_unpin<T: Unpin>() {}
assert_unpin::<ResponseStream>();
}
// Construction also succeeds with options produced by the builder.
#[test]
fn test_client_with_custom_options() {
let options = ClaudeAgentOptions::builder()
.max_turns(10)
.permission_mode(PermissionMode::AcceptEdits)
.model("claude-sonnet-4-5".to_string())
.build();
let client = ClaudeClient::new(options);
assert!(client.is_ok());
}
// Two independently constructed clients both start out disconnected.
#[test]
fn test_multiple_clients() {
let options1 = ClaudeAgentOptions::default();
let options2 = ClaudeAgentOptions::default();
let client1 = ClaudeClient::new(options1).unwrap();
let client2 = ClaudeClient::new(options2).unwrap();
assert!(!client1.is_connected());
assert!(!client2.is_connected());
}
// Registering handlers on a disconnected client must complete without panicking;
// the test makes no assertion beyond that.
#[tokio::test]
async fn test_register_handlers_without_connect() {
use crate::control::handlers::{CanUseToolHandler, HookHandler, McpMessageHandler};
use crate::options::HookEvent;
use async_trait::async_trait;
use serde_json::{Value, json};
// Permission handler that allows every tool call.
#[derive(Debug)]
struct TestPermHandler;
#[async_trait]
impl CanUseToolHandler for TestPermHandler {
async fn can_use_tool(
&self,
_tool_name: &str,
_tool_input: &serde_json::Value,
) -> Result<crate::permissions::PermissionDecision, ClawError> {
Ok(crate::permissions::PermissionDecision::Allow {
updated_input: None,
})
}
}
// Hook handler that echoes its input back.
struct TestHookHandler;
#[async_trait]
impl HookHandler for TestHookHandler {
async fn call(
&self,
_hook_event: HookEvent,
hook_input: Value,
) -> Result<Value, ClawError> {
Ok(json!({ "echo": hook_input }))
}
}
// MCP handler that always returns a fixed result.
struct TestMcpHandler;
#[async_trait]
impl McpMessageHandler for TestMcpHandler {
async fn handle(
&self,
_server_name: &str,
_message: Value,
) -> Result<Value, ClawError> {
Ok(json!({"result": "ok"}))
}
}
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::new(options).unwrap();
client
.register_can_use_tool_handler(Arc::new(TestPermHandler))
.await;
client
.register_hook("test".to_string(), Arc::new(TestHookHandler))
.await;
client
.register_mcp_message_handler(Arc::new(TestMcpHandler))
.await;
}
// Exercises the shared sender slot directly (no client involved): installing a second
// sender replaces the first, and a message sent through the slot reaches the second
// receiver.
#[tokio::test]
async fn test_send_message_multiple_turns() {
let current_turn_tx: CurrentTurnSender = Arc::new(Mutex::new(None));
let (tx1, rx1) = mpsc::unbounded_channel::<Result<Value, ClawError>>();
*current_turn_tx.lock().await = Some(tx1);
let (tx2, rx2) = mpsc::unbounded_channel::<Result<Value, ClawError>>();
*current_turn_tx.lock().await = Some(tx2);
let _ = rx1;
let slot_has_sender = current_turn_tx.lock().await.is_some();
assert!(slot_has_sender, "Current turn sender should be set");
{
let guard = current_turn_tx.lock().await;
if let Some(tx) = guard.as_ref() {
tx.send(Ok(serde_json::json!({"type": "system"}))).unwrap();
}
}
let mut rx2 = rx2;
let received = rx2.try_recv().unwrap();
assert!(received.is_ok());
}
// `receive_response` stops after the result message: the third (system) message that
// was queued behind it is not part of the returned vector.
#[tokio::test]
async fn test_receive_response_collects_until_result() {
use crate::messages::Message;
let (tx, rx) = mpsc::unbounded_channel();
let assistant_json = serde_json::json!({
"type": "assistant",
"session_id": "test",
"message": {
"id": "msg_1",
"role": "assistant",
"content": [{"type": "text", "text": "Hello!"}],
"model": "claude-opus-4",
"stop_reason": null,
"stop_sequence": null,
"usage": {"input_tokens": 10, "output_tokens": 5, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}
}
});
let result_json = serde_json::json!({
"type": "result",
"subtype": "success",
"session_id": "test",
"result": "done",
"is_error": false,
"num_turns": 1,
"usage": {"input_tokens": 10, "output_tokens": 5, "cache_creation_input_tokens": 0, "cache_read_input_tokens": 0}
});
tx.send(Ok(assistant_json)).unwrap();
tx.send(Ok(result_json)).unwrap();
tx.send(Ok(serde_json::json!({"type": "system", "subtype": "init", "session_id": "x", "tools": [], "mcp_servers": []}))).unwrap();
let stream = ResponseStream::new(rx);
let messages = stream.receive_response().await.unwrap();
assert_eq!(messages.len(), 2);
assert!(matches!(messages[0], Message::Assistant(_)));
assert!(matches!(messages[1], Message::Result(_)));
}
// Type-check only: the nested function is never called, and the closure it builds is
// never passed to `with_client`.
#[test]
fn test_with_client_type_signature() {
fn _assert_types() {
let _f = |client: &ClaudeClient| {
let _ = client.is_connected();
async { Ok::<(), ClawError>(()) }
};
}
}
// A client built around an injected transport starts out disconnected.
#[test]
fn test_with_transport_constructor() {
use crate::transport::SubprocessCLITransport;
let transport = SubprocessCLITransport::new(None, vec![]);
let options = ClaudeAgentOptions::default();
let client = ClaudeClient::with_transport(options, Box::new(transport));
assert!(client.is_ok());
let client = client.unwrap();
assert!(!client.is_connected());
}
}