use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;
use tokio_stream::Stream;
use tracing::{debug, info};
use super::query::Query;
use super::transport::{SubprocessTransport, Transport};
use crate::errors::{ClaudeSDKError, Result};
use crate::types::*;
/// Stream of messages returned by `InternalClient::process_query`.
///
/// Items are read from the query's message channel until that channel closes.
pub struct ClientStream {
    // Never read; owning the client keeps it (and therefore its `query`) from being
    // dropped while the stream is still being consumed, hence the `dead_code` allowance.
    #[allow(dead_code)]
    client: InternalClient,
    // Stream adapter over the receiver taken from the client via `take_message_rx`.
    receiver: tokio_stream::wrappers::ReceiverStream<Result<Message>>,
}
impl ClientStream {
    /// Wraps `rx` as a stream while keeping `client` alive alongside it.
    fn new(client: InternalClient, rx: mpsc::Receiver<Result<Message>>) -> Self {
        let receiver = tokio_stream::wrappers::ReceiverStream::new(rx);
        Self { client, receiver }
    }
}
impl Stream for ClientStream {
    type Item = Result<Message>;

    /// Forwards polling to the wrapped channel receiver.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let this = self.get_mut();
        Pin::new(&mut this.receiver).poll_next(cx)
    }
}
/// Low-level client that owns a Claude CLI query and the receiver for its messages.
///
/// Created disconnected by `new`; `connect` populates `query` and `message_rx`.
pub struct InternalClient {
    // Control handle for the CLI session; set by `connect`, cleared by `disconnect`.
    query: Option<Query>,
    // Receiver for messages produced by the query; set by `connect`, removed by
    // `take_message_rx`, cleared by `disconnect`.
    message_rx: Option<mpsc::Receiver<Result<Message>>>,
    // Options used to build the transport and query.
    options: ClaudeAgentOptions,
    // True only after `connect` completed the initialize handshake; reset by `disconnect`.
    connected: bool,
}
impl InternalClient {
pub fn new(options: ClaudeAgentOptions) -> Self {
Self {
query: None,
message_rx: None,
options,
connected: false,
}
}
fn validate_options(&self) -> Result<()> {
if self.options.can_use_tool.is_some() && self.options.permission_prompt_tool_name.is_some()
{
return Err(ClaudeSDKError::configuration(
"Cannot specify both 'can_use_tool' and 'permission_prompt_tool_name'",
));
}
Ok(())
}
fn build_agents_dict(
options: &ClaudeAgentOptions,
) -> Option<std::collections::HashMap<String, serde_json::Value>> {
options.agents.as_ref().map(|agents| {
agents
.iter()
.map(|(name, def)| {
let value = serde_json::to_value(def).unwrap_or(serde_json::Value::Null);
(name.clone(), value)
})
.collect()
})
}
pub async fn connect(&mut self) -> Result<()> {
if self.connected {
return Ok(());
}
self.validate_options()?;
let agents_dict = Self::build_agents_dict(&self.options);
let mut transport = SubprocessTransport::new(&self.options)?;
transport.connect().await?;
let (query, message_rx) = Query::new(transport, &self.options, agents_dict);
self.message_rx = Some(message_rx);
self.query = Some(query);
if let Some(ref mut q) = self.query {
q.start().await?;
let response = q.initialize().await?;
debug!("CLI initialized: {:?}", response);
}
self.connected = true;
info!("Connected to Claude CLI");
Ok(())
}
pub async fn process_query(
options: ClaudeAgentOptions,
prompt: &str,
) -> Result<Pin<Box<dyn Stream<Item = Result<Message>> + Send>>> {
if options.can_use_tool.is_some() && options.permission_prompt_tool_name.is_some() {
return Err(ClaudeSDKError::configuration(
"Cannot specify both 'can_use_tool' and 'permission_prompt_tool_name'",
));
}
let has_hooks_or_callbacks = options.can_use_tool.is_some() || options.hooks.is_some();
let mut client = InternalClient::new(options);
client.connect().await?;
client.send_message(prompt).await?;
if has_hooks_or_callbacks {
client.set_close_stdin_on_result(true);
} else {
client.end_input().await?;
}
let rx = client
.take_message_rx()
.ok_or_else(|| ClaudeSDKError::internal("Message receiver not available"))?;
Ok(Box::pin(ClientStream::new(client, rx)))
}
pub async fn send_message(&mut self, message: &str) -> Result<()> {
let query = self
.query
.as_ref()
.ok_or_else(|| ClaudeSDKError::cli_connection("Client not connected"))?;
query.send_message(message).await
}
fn set_close_stdin_on_result(&self, value: bool) {
if let Some(ref q) = self.query {
q.set_close_stdin_on_result(value);
}
}
pub async fn end_input(&self) -> Result<()> {
let query = self
.query
.as_ref()
.ok_or_else(|| ClaudeSDKError::cli_connection("Client not connected"))?;
query.end_input().await
}
pub fn take_message_rx(&mut self) -> Option<mpsc::Receiver<Result<Message>>> {
self.message_rx.take()
}
pub async fn interrupt(&self) -> Result<()> {
let query = self
.query
.as_ref()
.ok_or_else(|| ClaudeSDKError::cli_connection("Client not connected"))?;
query.interrupt().await
}
pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
let query = self
.query
.as_ref()
.ok_or_else(|| ClaudeSDKError::cli_connection("Client not connected"))?;
query.set_permission_mode(mode).await
}
pub async fn set_model(&self, model: impl Into<String>) -> Result<()> {
let query = self
.query
.as_ref()
.ok_or_else(|| ClaudeSDKError::cli_connection("Client not connected"))?;
query.set_model(model).await
}
pub async fn rewind_files(&self, user_message_id: impl Into<String>) -> Result<()> {
let query = self
.query
.as_ref()
.ok_or_else(|| ClaudeSDKError::cli_connection("Client not connected"))?;
query.rewind_files(user_message_id).await
}
pub async fn get_server_info(&self) -> Option<serde_json::Value> {
let query = self.query.as_ref()?;
query.get_server_info().await
}
pub async fn get_mcp_status(&self) -> Result<serde_json::Value> {
let query = self
.query
.as_ref()
.ok_or_else(|| ClaudeSDKError::cli_connection("Client not connected"))?;
query.get_mcp_status().await
}
pub async fn disconnect(&mut self) -> Result<()> {
if !self.connected {
return Ok(());
}
if let Some(ref mut query) = self.query {
query.stop().await?;
}
self.query = None;
self.message_rx = None;
self.connected = false;
info!("Disconnected from Claude CLI");
Ok(())
}
pub fn is_connected(&self) -> bool {
self.connected
}
}
impl Drop for InternalClient {
fn drop(&mut self) {
}
}
/// Runs `<cli> --version` and returns the version string it reports.
///
/// `cli_path` defaults to `claude` (resolved via `PATH`). The version is taken from the
/// first line of stdout: the first `MAJOR.MINOR.PATCH`-shaped token (so both
/// `1.2.3 (Claude Code)` and `claude 1.2.3` work), else the last token, else `"unknown"`.
/// A warning is logged when it parses as semver and is older than
/// `crate::MIN_CLI_VERSION`.
///
/// # Errors
/// - a timeout error if the command does not finish within 2 seconds;
/// - a CLI-not-found error if the executable does not exist;
/// - a CLI connection error for any other spawn/IO failure.
pub async fn check_cli_version(cli_path: Option<&std::path::Path>) -> Result<String> {
    use std::process::Stdio;
    use tokio::process::Command;

    /// Picks the version token out of `--version` output (see the outer doc comment).
    fn extract_version(stdout: &str) -> String {
        let first_line = stdout.lines().next().unwrap_or("");
        let is_semver_like = |token: &str| {
            // Ignore any `-prerelease` / `+build` suffix; require three numeric parts.
            let core = token
                .split(|c: char| c == '-' || c == '+')
                .next()
                .unwrap_or(token);
            let parts: Vec<&str> = core.split('.').collect();
            parts.len() == 3
                && parts
                    .iter()
                    .all(|p| !p.is_empty() && p.bytes().all(|b| b.is_ascii_digit()))
        };
        first_line
            .split_whitespace()
            .map(|tok| tok.trim_matches(|c: char| !c.is_ascii_alphanumeric()))
            .find(|&tok| is_semver_like(tok))
            .or_else(|| first_line.split_whitespace().last())
            .unwrap_or("unknown")
            .to_string()
    }

    let path = cli_path
        .map(|p| p.to_path_buf())
        .unwrap_or_else(|| std::path::PathBuf::from("claude"));
    // The 2-second limit below and the 2000 ms reported in the timeout error must agree.
    let output = tokio::time::timeout(
        std::time::Duration::from_secs(2),
        Command::new(&path)
            .arg("--version")
            .stdout(Stdio::piped())
            .stderr(Stdio::null())
            // When the timeout elapses the `output()` future is dropped; without this the
            // child process would keep running.
            .kill_on_drop(true)
            .output(),
    )
    .await
    .map_err(|_| ClaudeSDKError::timeout(2000))?
    .map_err(|e| {
        if e.kind() == std::io::ErrorKind::NotFound {
            ClaudeSDKError::cli_not_found(format!("CLI not found at {}", path.display()))
        } else {
            ClaudeSDKError::cli_connection_with_source("Failed to run CLI version check", e)
        }
    })?;
    let version = extract_version(&String::from_utf8_lossy(&output.stdout));
    if let (Ok(found), Ok(required)) = (
        semver::Version::parse(&version),
        semver::Version::parse(crate::MIN_CLI_VERSION),
    ) {
        if found < required {
            tracing::warn!(
                "CLI version {} is below minimum required version {}",
                version,
                crate::MIN_CLI_VERSION
            );
        }
    }
    Ok(version)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_internal_client_new() {
        let options = ClaudeAgentOptions::new();
        let client = InternalClient::new(options);
        assert!(!client.is_connected());
    }

    #[test]
    fn test_validate_options_conflict() {
        use std::sync::Arc;
        let mut options = ClaudeAgentOptions::new();
        options.can_use_tool = Some(Arc::new(|_, _, _| {
            Box::pin(async { PermissionResult::allow() })
        }));
        options.permission_prompt_tool_name = Some("test".to_string());
        let client = InternalClient::new(options);
        assert!(client.validate_options().is_err());
    }

    /// Default options contain no conflicting permission settings.
    #[test]
    fn test_validate_options_default_ok() {
        let client = InternalClient::new(ClaudeAgentOptions::new());
        assert!(client.validate_options().is_ok());
    }

    /// Only the combination is rejected; a prompt tool name on its own is accepted.
    #[test]
    fn test_validate_options_prompt_tool_alone_ok() {
        let mut options = ClaudeAgentOptions::new();
        options.permission_prompt_tool_name = Some("test".to_string());
        let client = InternalClient::new(options);
        assert!(client.validate_options().is_ok());
    }
}