use std::pin::Pin;
use tokio::sync::mpsc;
use tokio_stream::{Stream, StreamExt};
use crate::_internal::client::InternalClient;
use crate::errors::{ClaudeSDKError, Result};
use crate::types::*;
/// High-level async client for a Claude agent session.
///
/// Wraps an [`InternalClient`] and owns the receiving half of the message
/// channel that the internal client hands out on [`connect`](Self::connect).
pub struct ClaudeClient {
    internal: InternalClient,
    /// Receiver for messages produced by the internal client; `None` until
    /// `connect` obtains one, and reset to `None` by `disconnect`.
    message_rx: Option<mpsc::Receiver<Result<Message>>>,
}

impl ClaudeClient {
    /// Creates a new, not-yet-connected client.
    ///
    /// When `options` is `None`, `ClaudeAgentOptions::default()` is used.
    pub fn new(options: Option<ClaudeAgentOptions>) -> Self {
        Self {
            internal: InternalClient::new(options.unwrap_or_default()),
            message_rx: None,
        }
    }

    /// Connects the internal client and takes ownership of its message receiver.
    ///
    /// # Errors
    /// Propagates any error returned by the internal client's `connect`.
    pub async fn connect(&mut self) -> Result<()> {
        self.internal.connect().await?;
        // Only replace the stored receiver when the internal client actually hands
        // one out. If it was already taken (presumably a repeated `connect` on a live
        // session), overwriting with `None` would silently orphan the message stream.
        if let Some(rx) = self.internal.take_message_rx() {
            self.message_rx = Some(rx);
        }
        Ok(())
    }

    /// Sends `prompt` through the internal client's `send_message`.
    ///
    /// # Errors
    /// Propagates any error returned by the internal client.
    pub async fn query(&mut self, prompt: &str) -> Result<()> {
        self.internal.send_message(prompt).await
    }

    /// Returns a stream over messages received from the session.
    ///
    /// The stream ends immediately if no receiver is held (not connected, or
    /// after `disconnect`). Otherwise it ends once the channel is closed and drained.
    pub fn receive_messages(&mut self) -> impl Stream<Item = Result<Message>> + '_ {
        futures::stream::poll_fn(move |cx| match self.message_rx {
            Some(ref mut rx) => Pin::new(rx).poll_recv(cx),
            None => std::task::Poll::Ready(None),
        })
    }

    /// Reads messages until a `Message::Result` arrives.
    ///
    /// Returns the concatenated text of every assistant message seen before the
    /// result, together with the result itself. Messages other than assistant and
    /// result messages are skipped.
    ///
    /// # Errors
    /// - when no message receiver is held (the client is not connected);
    /// - the first `Err` item yielded by the message stream;
    /// - when the stream ends before any result message arrives.
    pub async fn receive_response(&mut self) -> Result<(String, ResultMessage)> {
        // Distinguish "never connected" from "connection closed mid-response";
        // both previously surfaced as the same misleading error.
        if self.message_rx.is_none() {
            return Err(ClaudeSDKError::internal(
                "Not connected: call connect() before receive_response()",
            ));
        }

        // Build the stream once instead of re-creating it on every iteration.
        let mut messages = self.receive_messages();
        let mut response = String::new();
        while let Some(msg) = messages.next().await {
            match msg? {
                // Appending an empty string is a no-op, so no emptiness check is needed.
                Message::Assistant(asst) => response.push_str(&asst.text()),
                Message::Result(result) => return Ok((response, result)),
                _ => {}
            }
        }
        Err(ClaudeSDKError::internal("Connection closed without result"))
    }

    /// Forwards an interrupt request to the internal client.
    pub async fn interrupt(&self) -> Result<()> {
        self.internal.interrupt().await
    }

    /// Forwards a permission-mode change to the internal client.
    pub async fn set_permission_mode(&self, mode: PermissionMode) -> Result<()> {
        self.internal.set_permission_mode(mode).await
    }

    /// Forwards a model change to the internal client.
    pub async fn set_model(&self, model: impl Into<String>) -> Result<()> {
        self.internal.set_model(model).await
    }

    /// Forwards a file-rewind request for `user_message_id` to the internal client.
    pub async fn rewind_files(&self, user_message_id: impl Into<String>) -> Result<()> {
        self.internal.rewind_files(user_message_id).await
    }

    /// Returns the server info held by the internal client, if any.
    pub async fn get_server_info(&self) -> Option<serde_json::Value> {
        self.internal.get_server_info().await
    }

    /// Queries the internal client for MCP status as raw JSON.
    pub async fn get_mcp_status(&self) -> Result<serde_json::Value> {
        self.internal.get_mcp_status().await
    }

    /// Drops the local message receiver, then disconnects the internal client.
    ///
    /// The receiver is dropped even if the internal disconnect fails.
    pub async fn disconnect(&mut self) -> Result<()> {
        self.message_rx = None;
        self.internal.disconnect().await
    }

    /// Reports the internal client's connection state.
    pub fn is_connected(&self) -> bool {
        self.internal.is_connected()
    }
}
/// Fluent builder for [`ClaudeClient`].
///
/// Each setter overwrites the corresponding field of the underlying
/// `ClaudeAgentOptions`. Call [`build`](Self::build) to obtain the client.
pub struct ClaudeClientBuilder {
    options: ClaudeAgentOptions,
}

impl ClaudeClientBuilder {
    /// Starts a builder from `ClaudeAgentOptions::new()`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            options: ClaudeAgentOptions::new(),
        }
    }

    /// Sets the model name.
    #[must_use]
    pub fn model(mut self, model: impl Into<String>) -> Self {
        self.options.model = Some(model.into());
        self
    }

    /// Sets a plain-text system prompt (`SystemPromptConfig::Text`).
    #[must_use]
    pub fn system_prompt(mut self, prompt: impl Into<String>) -> Self {
        self.options.system_prompt = Some(SystemPromptConfig::Text(prompt.into()));
        self
    }

    /// Sets the permission mode.
    #[must_use]
    pub fn permission_mode(mut self, mode: PermissionMode) -> Self {
        self.options.permission_mode = Some(mode);
        self
    }

    /// Sets the maximum number of turns.
    #[must_use]
    pub fn max_turns(mut self, turns: u32) -> Self {
        self.options.max_turns = Some(turns);
        self
    }

    /// Sets the spending budget, in US dollars. The value is stored unvalidated.
    #[must_use]
    pub fn max_budget_usd(mut self, budget: f64) -> Self {
        self.options.max_budget_usd = Some(budget);
        self
    }

    /// Sets the working directory.
    #[must_use]
    pub fn cwd(mut self, path: impl Into<std::path::PathBuf>) -> Self {
        self.options.cwd = Some(path.into());
        self
    }

    /// Installs a tool-permission callback via `ClaudeAgentOptions::with_can_use_tool`.
    ///
    /// The callback receives a `String`, a JSON value and a `ToolPermissionContext`,
    /// and resolves to a `PermissionResult`. The arguments are presumably the tool
    /// name and tool input; confirm against `ClaudeAgentOptions`.
    #[must_use]
    pub fn can_use_tool<F, Fut>(mut self, callback: F) -> Self
    where
        F: Fn(String, serde_json::Value, ToolPermissionContext) -> Fut + Send + Sync + 'static,
        Fut: std::future::Future<Output = PermissionResult> + Send + 'static,
    {
        self.options = self.options.with_can_use_tool(callback);
        self
    }

    /// Enables `include_partial_messages`.
    #[must_use]
    pub fn include_partial_messages(mut self) -> Self {
        self.options.include_partial_messages = true;
        self
    }

    /// Enables `enable_file_checkpointing`.
    #[must_use]
    pub fn enable_file_checkpointing(mut self) -> Self {
        self.options.enable_file_checkpointing = true;
        self
    }

    /// Replaces (does not extend) the allowed-tools list.
    #[must_use]
    pub fn allowed_tools(mut self, tools: Vec<String>) -> Self {
        self.options.allowed_tools = tools;
        self
    }

    /// Replaces (does not extend) the disallowed-tools list.
    #[must_use]
    pub fn disallowed_tools(mut self, tools: Vec<String>) -> Self {
        self.options.disallowed_tools = tools;
        self
    }

    /// Consumes the builder and creates a (not yet connected) [`ClaudeClient`].
    #[must_use]
    pub fn build(self) -> ClaudeClient {
        ClaudeClient::new(Some(self.options))
    }
}

impl Default for ClaudeClientBuilder {
    /// Equivalent to [`ClaudeClientBuilder::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Owns a [`ClaudeClient`] and schedules its disconnect when dropped.
///
/// `Drop` cannot be async, so the disconnect is spawned as a detached task on a
/// Tokio runtime. It prefers the runtime that was current when the guard was
/// created and falls back to the one current at drop time. If neither exists,
/// the disconnect is skipped and a warning is logged.
pub struct ClientGuard {
    client: Option<ClaudeClient>,
    /// Runtime handle captured at construction, if one was available.
    runtime: Option<tokio::runtime::Handle>,
}

impl ClientGuard {
    /// Wraps `client`, capturing the current Tokio runtime handle if there is one.
    #[must_use]
    pub fn new(client: ClaudeClient) -> Self {
        Self {
            client: Some(client),
            runtime: tokio::runtime::Handle::try_current().ok(),
        }
    }

    /// Borrows the wrapped client.
    ///
    /// # Panics
    /// Panics if the client has been taken. Through this API that cannot happen:
    /// the only taker, `into_inner`, consumes the guard.
    pub fn client(&self) -> &ClaudeClient {
        self.client.as_ref().expect("Client already taken")
    }

    /// Mutably borrows the wrapped client.
    ///
    /// # Panics
    /// Same invariant as [`client`](Self::client).
    pub fn client_mut(&mut self) -> &mut ClaudeClient {
        self.client.as_mut().expect("Client already taken")
    }

    /// Consumes the guard and returns the client without disconnecting it.
    ///
    /// # Panics
    /// Same invariant as [`client`](Self::client).
    pub fn into_inner(mut self) -> ClaudeClient {
        // Taking the client leaves `None`, so the subsequent `Drop` is a no-op.
        self.client.take().expect("Client already taken")
    }
}

impl Drop for ClientGuard {
    fn drop(&mut self) {
        if let Some(mut client) = self.client.take() {
            // Prefer the runtime captured at construction. Otherwise fall back to the
            // runtime current now, so a guard created outside Tokio but dropped inside
            // a runtime still disconnects.
            let runtime = self
                .runtime
                .take()
                .or_else(|| tokio::runtime::Handle::try_current().ok());
            match runtime {
                Some(handle) => {
                    // Detached task: Drop cannot await, so the disconnect result is
                    // ignored. It may not complete if the runtime shuts down first.
                    handle.spawn(async move {
                        let _ = client.disconnect().await;
                    });
                }
                None => {
                    tracing::warn!(
                        "ClientGuard dropped without Tokio runtime - skipping async disconnect"
                    );
                }
            }
        }
    }
}
impl ClaudeClient {
    /// Moves this client into a [`ClientGuard`], which schedules an asynchronous
    /// disconnect when it is dropped (when a Tokio runtime is available).
    ///
    /// Bind the returned guard for as long as the client is in use. Dropping it
    /// immediately would schedule the disconnect right away.
    #[must_use = "dropping the guard immediately schedules a disconnect of the client"]
    pub fn into_guard(self) -> ClientGuard {
        ClientGuard::new(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A client produced by the builder starts out disconnected.
    #[test]
    fn test_client_builder() {
        let builder = ClaudeClientBuilder::new()
            .model("claude-3-sonnet")
            .max_turns(5);
        let built = builder.build();

        let connected = built.is_connected();
        assert!(!connected);
    }
}