use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use futures::{Stream, StreamExt};
use serde_json::Value;
use tokio::task::JoinHandle;
use tracing::warn;
use crate::errors::{CLIConnectionError, Error, Result};
use crate::query::{Query, build_hooks_config};
use crate::sdk_mcp::McpSdkServer;
use crate::transport::subprocess_cli::{Prompt as TransportPrompt, SubprocessCliTransport};
use crate::transport::{Transport, TransportFactory};
use crate::types::{
ClaudeAgentOptions, McpServerConfig, McpServersOption, McpStatusResponse, Message,
};
/// Prompt accepted by [`ClaudeSdkClient::connect`] and [`ClaudeSdkClient::query`].
///
/// `Text` carries a single plain-text prompt. `Messages` carries pre-built JSON
/// messages that are sent into the session as raw values.
#[derive(Clone, Debug, PartialEq)]
pub enum InputPrompt {
    /// A plain-text prompt.
    ///
    /// `connect` rejects it when a `can_use_tool` callback is configured.
    Text(String),
    /// Raw JSON messages.
    ///
    /// `query` adds a `session_id` to object messages that do not already have one.
    Messages(Vec<Value>),
}
/// Session-oriented client that talks to the CLI through a [`Transport`].
///
/// A client starts disconnected. `connect` (or `connect_with_messages`) opens a
/// session, and `disconnect` closes it. Every other operation fails with a
/// "Not connected" error while no session is active.
pub struct ClaudeSdkClient {
    // Base options. `connect` works on a clone and, when `can_use_tool` is set,
    // overrides `permission_prompt_tool_name` on that clone only.
    options: ClaudeAgentOptions,
    // When `Some`, `connect` obtains its transport here. When `None`, it builds a
    // `SubprocessCliTransport` from the options.
    transport_factory: Option<Box<dyn TransportFactory>>,
    // Active session; `None` while disconnected.
    query: Option<Query>,
    // Task spawned by `connect_with_messages` to forward the caller's stream. It is
    // aborted (if still running) on disconnect, on reconnect, and on drop.
    initial_message_stream_task: Option<JoinHandle<Result<()>>>,
}
struct SingleUseTransportFactory(std::sync::Mutex<Option<Box<dyn Transport>>>);
impl TransportFactory for SingleUseTransportFactory {
fn create_transport(&self) -> Result<Box<dyn Transport>> {
self.0
.lock()
.map_err(|_| Error::Other("Transport factory lock poisoned".to_string()))?
.take()
.ok_or_else(|| {
Error::Other(
"Single-use transport already consumed. Use a TransportFactory for reconnect support."
.to_string(),
)
})
}
}
impl ClaudeSdkClient {
pub fn new(
options: Option<ClaudeAgentOptions>,
transport_factory: Option<Box<dyn TransportFactory>>,
) -> Self {
Self {
options: options.unwrap_or_default(),
transport_factory,
query: None,
initial_message_stream_task: None,
}
}
pub fn new_with_transport(
options: Option<ClaudeAgentOptions>,
transport: Box<dyn Transport>,
) -> Self {
Self {
options: options.unwrap_or_default(),
transport_factory: Some(Box::new(SingleUseTransportFactory(std::sync::Mutex::new(
Some(transport),
)))),
query: None,
initial_message_stream_task: None,
}
}
async fn handle_initial_message_stream_task(&mut self, abort_running: bool) -> Result<()> {
let Some(task) = self.initial_message_stream_task.take() else {
return Ok(());
};
if abort_running && !task.is_finished() {
task.abort();
}
match task.await {
Ok(Ok(())) => Ok(()),
Ok(Err(err)) => {
if abort_running {
warn!("Initial message stream task ended with error during shutdown: {err}");
Ok(())
} else {
Err(err)
}
}
Err(join_err) => {
if join_err.is_cancelled() {
Ok(())
} else {
let message = format!("Initial message stream task panicked: {join_err}");
if abort_running {
warn!("{message}");
Ok(())
} else {
Err(Error::Other(message))
}
}
}
}
}
fn initialize_timeout() -> Duration {
let timeout_ms = std::env::var("CLAUDE_CODE_STREAM_CLOSE_TIMEOUT")
.ok()
.and_then(|value| value.parse::<u64>().ok())
.unwrap_or(60_000);
Duration::from_secs_f64((timeout_ms as f64 / 1000.0).max(60.0))
}
fn extract_sdk_mcp_servers(options: &ClaudeAgentOptions) -> HashMap<String, Arc<McpSdkServer>> {
let mut servers = HashMap::new();
if let McpServersOption::Servers(configs) = &options.mcp_servers {
for (name, config) in configs {
if let McpServerConfig::Sdk(sdk_config) = config {
servers.insert(name.clone(), sdk_config.instance.clone());
}
}
}
servers
}
pub async fn connect(&mut self, prompt: Option<InputPrompt>) -> Result<()> {
self.handle_initial_message_stream_task(true).await?;
if self.query.is_some() {
self.disconnect().await?;
}
if self.options.can_use_tool.is_some() {
if matches!(prompt, Some(InputPrompt::Text(_))) {
return Err(Error::Other(
"can_use_tool callback requires streaming mode. Please provide prompt as messages."
.to_string(),
));
}
if self.options.permission_prompt_tool_name.is_some() {
return Err(Error::Other(
"can_use_tool callback cannot be used with permission_prompt_tool_name."
.to_string(),
));
}
}
let mut configured_options = self.options.clone();
if configured_options.can_use_tool.is_some() {
configured_options.permission_prompt_tool_name = Some("stdio".to_string());
}
let transport_prompt = match &prompt {
Some(InputPrompt::Text(text)) => TransportPrompt::Text(text.clone()),
_ => TransportPrompt::Messages,
};
let mut transport: Box<dyn Transport> = if let Some(factory) = &self.transport_factory {
factory.create_transport()?
} else {
Box::new(SubprocessCliTransport::new(
transport_prompt,
configured_options.clone(),
)?)
};
transport.connect().await?;
let hooks = configured_options.hooks.clone().unwrap_or_default();
let sdk_mcp_servers = Self::extract_sdk_mcp_servers(&configured_options);
let (hooks_config, hook_callbacks) = build_hooks_config(&hooks);
let (reader, writer, close_handle) = transport.into_split()?;
let mut query = Query::start(
reader,
writer,
close_handle,
true,
configured_options.can_use_tool.clone(),
hook_callbacks,
sdk_mcp_servers,
configured_options.agents.clone(),
Self::initialize_timeout(),
);
query.initialize(hooks_config).await?;
if let Some(InputPrompt::Messages(messages)) = prompt {
query.send_input_messages(messages).await?;
}
self.query = Some(query);
Ok(())
}
pub async fn connect_with_messages<S>(&mut self, prompt: S) -> Result<()>
where
S: Stream<Item = Value> + Send + Unpin + 'static,
{
self.connect(None).await?;
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
self.initial_message_stream_task = Some(query.spawn_input_from_stream(prompt)?);
Ok(())
}
pub async fn wait_for_initial_messages(&mut self) -> Result<()> {
self.handle_initial_message_stream_task(false).await
}
pub async fn query(&self, prompt: InputPrompt, session_id: &str) -> Result<()> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
match prompt {
InputPrompt::Text(text) => {
query.send_user_message(&text, session_id).await?;
}
InputPrompt::Messages(messages) => {
for mut message in messages {
if let Value::Object(ref mut obj) = message
&& !obj.contains_key("session_id")
{
obj.insert(
"session_id".to_string(),
Value::String(session_id.to_string()),
);
}
query.send_raw_message(message).await?;
}
}
}
Ok(())
}
pub async fn query_stream<S>(&self, prompt: S, session_id: &str) -> Result<()>
where
S: Stream<Item = Value> + Unpin,
{
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
let session_id = session_id.to_string();
let mapped = prompt.map(move |mut message| {
if let Value::Object(ref mut obj) = message
&& !obj.contains_key("session_id")
{
obj.insert("session_id".to_string(), Value::String(session_id.clone()));
}
message
});
query.send_input_from_stream(mapped).await
}
pub async fn receive_message(&mut self) -> Result<Option<Message>> {
let query = self.query.as_mut().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
query.receive_next_message().await
}
pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
let mut messages = Vec::new();
while let Some(message) = self.receive_message().await? {
let is_result = matches!(message, Message::Result(_));
messages.push(message);
if is_result {
break;
}
}
Ok(messages)
}
pub async fn interrupt(&self) -> Result<()> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
query.interrupt().await
}
pub async fn set_permission_mode(&self, mode: &str) -> Result<()> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
query.set_permission_mode(mode).await
}
pub async fn set_model(&self, model: Option<&str>) -> Result<()> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
query.set_model(model).await
}
pub async fn rewind_files(&self, user_message_id: &str) -> Result<()> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
query.rewind_files(user_message_id).await
}
pub async fn get_mcp_status(&self) -> Result<McpStatusResponse> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
query.get_mcp_status().await
}
pub async fn reconnect_mcp_server(&self, server_name: &str) -> Result<()> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
query.reconnect_mcp_server(server_name).await
}
pub async fn toggle_mcp_server(&self, server_name: &str, enabled: bool) -> Result<()> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
query.toggle_mcp_server(server_name, enabled).await
}
pub async fn stop_task(&self, task_id: &str) -> Result<()> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
query.stop_task(task_id).await
}
pub fn get_server_info(&self) -> Result<Option<Value>> {
let query = self.query.as_ref().ok_or_else(|| {
Error::CLIConnection(CLIConnectionError::new(
"Not connected. Call connect() first.",
))
})?;
Ok(query.initialization_result())
}
pub async fn disconnect(&mut self) -> Result<()> {
self.handle_initial_message_stream_task(true).await?;
if let Some(query) = self.query.take() {
query.close().await?;
}
Ok(())
}
}
impl Drop for ClaudeSdkClient {
    /// Requests cancellation of the background input task, if one is still held.
    ///
    /// Dropping a tokio `JoinHandle` alone would only detach the task, so it is
    /// aborted explicitly. The active query is not closed here; it is simply
    /// dropped along with the other fields.
    fn drop(&mut self) {
        let Some(handle) = self.initial_message_stream_task.take() else {
            return;
        };
        handle.abort();
    }
}