use std::sync::Arc;
use tokio::sync::{mpsc, Mutex};
use crate::control::{ControlMessage, ControlRequest, ProtocolHandler};
use crate::error::{ClaudeError, Result};
use crate::hooks::HookManager;
use crate::message::parse_message;
use crate::permissions::PermissionManager;
use crate::transport::{PromptInput, SubprocessTransport, Transport};
use crate::types::{
ClaudeAgentOptions, HookContext, HookEvent, Message, PermissionRequest, RequestId,
};
/// Client for a streaming session with the Claude CLI subprocess.
///
/// Construct it with [`ClaudeSDKClient::new`]. Messages read from the CLI by a
/// background task are surfaced through [`ClaudeSDKClient::next_message`].
pub struct ClaudeSDKClient {
/// Subprocess transport, shared with the background reader/writer tasks.
transport: Arc<Mutex<SubprocessTransport>>,
/// Control-protocol state, shared with the background tasks.
protocol: Arc<Mutex<ProtocolHandler>>,
/// Parsed messages (or errors) forwarded by the reader task.
message_rx: mpsc::UnboundedReceiver<Result<Message>>,
/// Queue of control requests consumed by the control writer task.
control_tx: mpsc::UnboundedSender<ControlRequest>,
/// Hook-event receiver, handed out at most once via `take_hook_receiver`.
hook_rx: Option<mpsc::UnboundedReceiver<(String, HookEvent)>>,
/// Permission-request receiver, handed out at most once via
/// `take_permission_receiver`.
permission_rx: Option<mpsc::UnboundedReceiver<(RequestId, PermissionRequest)>>,
// Stored so the client co-owns the manager's `Arc`; the field is not read
// anywhere in this file after construction, hence the lint allowance.
#[allow(dead_code)]
hook_manager: Option<Arc<Mutex<HookManager>>>,
// Same rationale as `hook_manager`: kept alive with the client, never read here.
#[allow(dead_code)]
permission_manager: Option<Arc<Mutex<PermissionManager>>>,
}
impl ClaudeSDKClient {
pub async fn new(
options: ClaudeAgentOptions,
cli_path: Option<std::path::PathBuf>,
) -> Result<Self> {
let (hook_manager, hook_rx) = if let Some(ref hooks_config) = options.hooks {
let mut manager = HookManager::new();
for matchers in hooks_config.values() {
for matcher in matchers {
manager.register(matcher.clone());
}
}
(Some(Arc::new(Mutex::new(manager))), None)
} else {
(None, Some(mpsc::unbounded_channel().1))
};
let (permission_manager, permission_rx) = if options.can_use_tool.is_some() {
let mut manager = PermissionManager::new();
if let Some(callback) = options.can_use_tool.clone() {
manager.set_callback(callback);
}
manager.set_allowed_tools(Some(options.allowed_tools.clone()));
manager.set_disallowed_tools(options.disallowed_tools.clone());
(Some(Arc::new(Mutex::new(manager))), None)
} else {
(None, Some(mpsc::unbounded_channel().1))
};
let prompt_input = PromptInput::Stream;
let mut transport = SubprocessTransport::new(prompt_input, options, cli_path)?;
transport.connect().await?;
let mut protocol = ProtocolHandler::new();
let (hook_tx, hook_rx_internal) = mpsc::unbounded_channel();
let (permission_tx, permission_rx_internal) = mpsc::unbounded_channel();
protocol.set_hook_channel(hook_tx);
protocol.set_permission_channel(permission_tx);
let (message_tx, message_rx) = mpsc::unbounded_channel();
let (control_tx, control_rx) = mpsc::unbounded_channel();
protocol.set_initialized(true);
let transport = Arc::new(Mutex::new(transport));
let protocol = Arc::new(Mutex::new(protocol));
let transport_clone = transport.clone();
let protocol_clone = protocol.clone();
let message_tx_clone = message_tx;
tokio::spawn(async move {
Self::message_reader_task(transport_clone, protocol_clone, message_tx_clone).await;
});
let transport_clone = transport.clone();
let protocol_clone = protocol.clone();
tokio::spawn(async move {
Self::control_writer_task(transport_clone, protocol_clone, control_rx).await;
});
if let Some(ref manager) = hook_manager {
let manager_clone = manager.clone();
let protocol_clone = protocol.clone();
tokio::spawn(async move {
Self::hook_handler_task(manager_clone, protocol_clone, hook_rx_internal).await;
});
}
if let Some(ref manager) = permission_manager {
let manager_clone = manager.clone();
let protocol_clone = protocol.clone();
tokio::spawn(async move {
Self::permission_handler_task(
manager_clone,
protocol_clone,
permission_rx_internal,
)
.await;
});
}
Ok(Self {
transport,
protocol,
message_rx,
control_tx,
hook_rx,
permission_rx,
hook_manager,
permission_manager,
})
}
/// Background task that drains the transport's message stream.
///
/// Each value is first offered to the protocol handler. If it deserializes as a
/// control message, it is consumed here: init responses and control responses go
/// to the protocol, while inbound requests and init messages are ignored. Any
/// other value is parsed with `parse_message` and forwarded to the client.
///
/// The task stops when:
/// * the stream ends,
/// * a transport error arrives (forwarded first),
/// * an init response is rejected (forwarded first), or
/// * forwarding a parsed message fails because the receiver is gone.
async fn message_reader_task(
    transport: Arc<Mutex<SubprocessTransport>>,
    protocol: Arc<Mutex<ProtocolHandler>>,
    message_tx: mpsc::UnboundedSender<Result<Message>>,
) {
    // The transport lock is only needed to obtain the stream, not to consume it.
    let mut incoming = transport.lock().await.read_messages();

    while let Some(item) = incoming.recv().await {
        let value = match item {
            Ok(value) => value,
            Err(err) => {
                let _ = message_tx.send(Err(err));
                break;
            }
        };

        let guard = protocol.lock().await;
        let raw = serde_json::to_string(&value).unwrap_or_default();
        if let Ok(control) = guard.deserialize_message(&raw) {
            match control {
                ControlMessage::InitResponse(init) => {
                    if let Err(err) = guard.handle_init_response(init) {
                        let _ = message_tx.send(Err(err));
                        break;
                    }
                }
                ControlMessage::Response(resp) => {
                    if let Err(err) = guard.handle_response(resp).await {
                        let _ = message_tx.send(Err(err));
                    }
                }
                // Not acted upon by this task.
                ControlMessage::Request(_) | ControlMessage::Init(_) => {}
            }
            // Control traffic is never forwarded as a user-facing message.
            continue;
        }
        drop(guard);

        match parse_message(value) {
            Ok(msg) => {
                if message_tx.send(Ok(msg)).is_err() {
                    break;
                }
            }
            Err(err) => {
                let _ = message_tx.send(Err(err));
            }
        }
    }
}
/// Background task that serializes queued [`ControlRequest`]s and writes them to
/// the transport as newline-delimited JSON.
///
/// Only `Interrupt` and `SendMessage` have a wire encoding here. Any other
/// variant is skipped with a diagnostic instead of being dropped silently; this
/// presumably includes the requests queued by `respond_to_hook` /
/// `respond_to_permission`.
///
/// The task exits when the control channel closes, or (after logging) when
/// serialization or a transport write fails.
async fn control_writer_task(
    transport: Arc<Mutex<SubprocessTransport>>,
    _protocol: Arc<Mutex<ProtocolHandler>>,
    mut control_rx: mpsc::UnboundedReceiver<ControlRequest>,
) {
    while let Some(request) = control_rx.recv().await {
        let control_json = match request {
            ControlRequest::Interrupt { .. } => serde_json::json!({
                "type": "control",
                "method": "interrupt"
            }),
            ControlRequest::SendMessage { content, .. } => serde_json::json!({
                "type": "user",
                "message": {
                    "role": "user",
                    "content": content
                }
            }),
            _ => {
                // NOTE(review): these requests never reach the CLI. Callers such as
                // `respond_to_hook` still report success, so surface the drop.
                #[cfg(feature = "tracing-support")]
                tracing::warn!("Dropping control request with no wire encoding");
                #[cfg(all(debug_assertions, not(feature = "tracing-support")))]
                eprintln!("Dropping control request with no wire encoding");
                continue;
            }
        };

        let json_str = match serde_json::to_string(&control_json) {
            Ok(json_str) => json_str,
            Err(_e) => {
                #[cfg(feature = "tracing-support")]
                tracing::error!(error = %_e, "Control request serialization error");
                #[cfg(all(debug_assertions, not(feature = "tracing-support")))]
                eprintln!("Control request serialization error: {_e}");
                break;
            }
        };

        let message_line = format!("{json_str}\n");
        let mut transport_guard = transport.lock().await;
        if let Err(_e) = transport_guard.write(&message_line).await {
            // Once this task exits, later sends on the control channel fail with
            // "Control channel closed"; log the root cause here.
            #[cfg(feature = "tracing-support")]
            tracing::error!(error = %_e, "Control writer transport error");
            #[cfg(all(debug_assertions, not(feature = "tracing-support")))]
            eprintln!("Control writer transport error: {_e}");
            break;
        }
    }
}
/// Background task that runs the in-process [`HookManager`] for each hook event
/// forwarded by the protocol.
///
/// For every event, the manager is invoked with an empty JSON object, a `None`
/// second argument, and an empty [`HookContext`]. The event itself is only used
/// for diagnostics. Runs until the hook channel closes.
async fn hook_handler_task(
    manager: Arc<Mutex<HookManager>>,
    protocol: Arc<Mutex<ProtocolHandler>>,
    mut hook_rx: mpsc::UnboundedReceiver<(String, HookEvent)>,
) {
    while let Some((callback_id, _event)) = hook_rx.recv().await {
        let manager_lock = manager.lock().await;
        let invocation = manager_lock
            .invoke(serde_json::json!({}), None, HookContext {})
            .await;
        // Release the manager before touching the protocol; the two locks are
        // never held together.
        drop(manager_lock);

        match invocation {
            Err(_e) => {
                #[cfg(feature = "tracing-support")]
                tracing::error!(error = %_e, "Hook processing error");
                #[cfg(all(debug_assertions, not(feature = "tracing-support")))]
                eprintln!("Hook processing error: {_e}");
            }
            Ok(output) => {
                let payload = serde_json::to_value(&output).unwrap_or_default();
                // NOTE(review): this response request is built but never sent (it
                // is not queued on the control channel or written to the
                // transport), so the CLI never sees the hook result.
                let _discarded_response = protocol
                    .lock()
                    .await
                    .create_hook_response(callback_id, payload);
                #[cfg(feature = "tracing-support")]
                tracing::debug!(event = ?_event, "Hook processed");
                #[cfg(all(debug_assertions, not(feature = "tracing-support")))]
                eprintln!("Hook processed for event {_event:?}");
            }
        }
    }
}
/// Background task that evaluates each forwarded permission request with the
/// in-process [`PermissionManager`].
///
/// The tool name, tool input and context of each request are cloned into
/// `can_use_tool`. Runs until the permission channel closes.
async fn permission_handler_task(
    manager: Arc<Mutex<PermissionManager>>,
    protocol: Arc<Mutex<ProtocolHandler>>,
    mut permission_rx: mpsc::UnboundedReceiver<(RequestId, PermissionRequest)>,
) {
    while let Some((req_id, req)) = permission_rx.recv().await {
        let manager_lock = manager.lock().await;
        let verdict = manager_lock
            .can_use_tool(
                req.tool_name.clone(),
                req.tool_input.clone(),
                req.context.clone(),
            )
            .await;
        // The manager is not needed past this point.
        drop(manager_lock);

        match verdict {
            Err(_e) => {
                #[cfg(feature = "tracing-support")]
                tracing::error!(error = %_e, "Permission processing error");
                #[cfg(all(debug_assertions, not(feature = "tracing-support")))]
                eprintln!("Permission processing error: {_e}");
            }
            Ok(result) => {
                // NOTE(review): the response request built here is never sent to
                // the CLI; it is dropped at the end of this arm.
                let _discarded_response = protocol
                    .lock()
                    .await
                    .create_permission_response(req_id.clone(), result.clone());
                #[cfg(feature = "tracing-support")]
                tracing::debug!(request_id = %req_id.as_str(), result = ?result, "Permission processed");
                #[cfg(all(debug_assertions, not(feature = "tracing-support")))]
                eprintln!("Permission {} processed: {:?}", req_id.as_str(), result);
            }
        }
    }
}
/// Writes a user message to the CLI as one line of JSON.
///
/// The payload has the shape
/// `{"type":"user","message":{"role":"user","content":<content>}}`. It is written
/// directly through the transport, not via the control channel.
///
/// # Errors
/// Returns an error if serialization or the transport write fails.
pub async fn send_message(&mut self, content: impl Into<String>) -> Result<()> {
    let text: String = content.into();
    let payload = serde_json::json!({
        "type": "user",
        "message": {
            "role": "user",
            "content": text
        }
    });
    let mut line = serde_json::to_string(&payload)?;
    line.push('\n');
    let mut transport = self.transport.lock().await;
    transport.write(&line).await
}
/// Queues an interrupt request for the control writer task.
///
/// # Errors
/// Returns a transport error if the control channel has been closed.
pub async fn interrupt(&mut self) -> Result<()> {
    let interrupt_request = {
        let guard = self.protocol.lock().await;
        guard.create_interrupt_request()
    };
    match self.control_tx.send(interrupt_request) {
        Ok(()) => Ok(()),
        Err(_) => Err(ClaudeError::transport("Control channel closed")),
    }
}
pub async fn next_message(&mut self) -> Option<Result<Message>> {
self.message_rx.recv().await
}
/// Moves the stored hook-event receiver out of the client, leaving `None`.
///
/// Returns `None` if no receiver is stored or it was already taken.
pub fn take_hook_receiver(&mut self) -> Option<mpsc::UnboundedReceiver<(String, HookEvent)>> {
    std::mem::take(&mut self.hook_rx)
}
/// Moves the stored permission-request receiver out of the client, leaving `None`.
///
/// Returns `None` if no receiver is stored or it was already taken.
pub fn take_permission_receiver(
    &mut self,
) -> Option<mpsc::UnboundedReceiver<(RequestId, PermissionRequest)>> {
    Option::take(&mut self.permission_rx)
}
/// Builds a hook response for `hook_id` and queues it on the control channel.
///
/// NOTE(review): the control writer task as written only encodes `Interrupt`
/// and `SendMessage` requests. Confirm this response is actually written to
/// the CLI.
///
/// # Errors
/// Returns a transport error if the control channel has been closed.
pub async fn respond_to_hook(
    &mut self,
    hook_id: String,
    response: serde_json::Value,
) -> Result<()> {
    let reply = {
        let guard = self.protocol.lock().await;
        guard.create_hook_response(hook_id, response)
    };
    if self.control_tx.send(reply).is_err() {
        return Err(ClaudeError::transport("Control channel closed"));
    }
    Ok(())
}
/// Builds a permission response for `request_id` and queues it on the control
/// channel.
///
/// NOTE(review): the control writer task as written only encodes `Interrupt`
/// and `SendMessage` requests. Confirm this response is actually written to
/// the CLI.
///
/// # Errors
/// Returns a transport error if the control channel has been closed.
pub async fn respond_to_permission(
    &mut self,
    request_id: RequestId,
    result: crate::types::PermissionResult,
) -> Result<()> {
    let reply = {
        let guard = self.protocol.lock().await;
        guard.create_permission_response(request_id, result)
    };
    match self.control_tx.send(reply) {
        Ok(()) => Ok(()),
        Err(_) => Err(ClaudeError::transport("Control channel closed")),
    }
}
/// Closes the underlying transport, returning whatever `close` reports.
pub async fn close(&mut self) -> Result<()> {
    self.transport.lock().await.close().await
}
}
// NOTE(review): this `Drop` impl is empty and performs no cleanup.
// The reader/writer tasks spawned in `new` hold their own `Arc` clones of the
// transport, so dropping the client does not drop the transport while those
// tasks are still running; call `close()` for an explicit shutdown.
// Any `Drop` impl, even an empty one, also forbids moving fields out of
// `ClaudeSDKClient` by destructuring.
impl Drop for ClaudeSDKClient {
fn drop(&mut self) {
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Smoke test: construction either connects to a locally available CLI or
    /// returns an error (for example when no CLI is installed). On success the
    /// client is closed again so no subprocess is left running.
    #[tokio::test]
    async fn test_client_creation() {
        let options = ClaudeAgentOptions::default();
        if let Ok(mut client) = ClaudeSDKClient::new(options, None).await {
            // Best-effort cleanup; the process may already have exited.
            let _ = client.close().await;
        }
    }

    /// An explicit CLI path that does not exist must make construction fail
    /// rather than yield a client.
    #[tokio::test]
    async fn test_client_creation_fails_for_missing_cli() {
        let missing = std::env::temp_dir().join("claude-cli-that-does-not-exist-7f3a9c");
        let result = ClaudeSDKClient::new(ClaudeAgentOptions::default(), Some(missing)).await;
        assert!(
            result.is_err(),
            "expected an error for a nonexistent CLI path"
        );
    }
}