use crate::{
errors::{Result, SdkError},
transport::{InputMessage, SubprocessTransport, Transport},
types::{ClaudeCodeOptions, ControlRequest, Message},
};
use futures::{Stream, StreamExt};
use std::sync::Arc;
use tokio::sync::Mutex;
use tokio_stream::wrappers::ReceiverStream;
use tracing::{debug, info};
/// Client for an interactive session with the Claude CLI, keeping one transport open
/// across several request/response exchanges.
pub struct InteractiveClient {
    /// Set to `true` by a successful `connect` and back to `false` by a successful
    /// `disconnect`.
    connected: bool,
    /// Transport behind an async mutex; the `Arc` lets background tasks share it.
    transport: Arc<Mutex<Box<dyn Transport + Send>>>,
}
impl InteractiveClient {
pub fn new(options: ClaudeCodeOptions) -> Result<Self> {
unsafe {
std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
}
let transport: Box<dyn Transport + Send> = Box::new(SubprocessTransport::new(options)?);
Ok(Self {
transport: Arc::new(Mutex::new(transport)),
connected: false,
})
}
pub async fn connect(&mut self) -> Result<()> {
if self.connected {
return Ok(());
}
let mut transport = self.transport.lock().await;
transport.connect().await?;
drop(transport);
self.connected = true;
info!("Connected to Claude CLI");
Ok(())
}
pub async fn send_and_receive(&mut self, prompt: String) -> Result<Vec<Message>> {
if !self.connected {
return Err(SdkError::InvalidState {
message: "Not connected".into(),
});
}
{
let mut transport = self.transport.lock().await;
let message = InputMessage::user(prompt, "default".to_string());
transport.send_message(message).await?;
}
debug!("Message sent, waiting for response");
let mut messages = Vec::new();
loop {
let msg_result = {
let mut transport = self.transport.lock().await;
let mut stream = transport.receive_messages();
stream.next().await
};
if let Some(result) = msg_result {
match result {
Ok(msg) => {
debug!("Received: {:?}", msg);
let is_result = matches!(msg, Message::Result { .. });
messages.push(msg);
if is_result {
break;
}
}
Err(e) => return Err(e),
}
} else {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
}
Ok(messages)
}
pub async fn send_message(&mut self, prompt: String) -> Result<()> {
if !self.connected {
return Err(SdkError::InvalidState {
message: "Not connected".into(),
});
}
let mut transport = self.transport.lock().await;
let message = InputMessage::user(prompt, "default".to_string());
transport.send_message(message).await?;
drop(transport);
debug!("Message sent");
Ok(())
}
pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
if !self.connected {
return Err(SdkError::InvalidState {
message: "Not connected".into(),
});
}
let mut messages = Vec::new();
loop {
let msg_result = {
let mut transport = self.transport.lock().await;
let mut stream = transport.receive_messages();
stream.next().await
};
if let Some(result) = msg_result {
match result {
Ok(msg) => {
debug!("Received: {:?}", msg);
let is_result = matches!(msg, Message::Result { .. });
messages.push(msg);
if is_result {
break;
}
}
Err(e) => return Err(e),
}
} else {
tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
}
}
Ok(messages)
}
pub async fn receive_messages_stream(&mut self) -> impl Stream<Item = Result<Message>> + '_ {
let (tx, rx) = tokio::sync::mpsc::channel(100);
let transport = self.transport.clone();
tokio::spawn(async move {
let mut transport = transport.lock().await;
let mut stream = transport.receive_messages();
while let Some(result) = stream.next().await {
if tx.send(result).await.is_err() {
break;
}
}
});
ReceiverStream::new(rx)
}
pub async fn receive_response_stream(&mut self) -> impl Stream<Item = Result<Message>> + '_ {
async_stream::stream! {
let mut stream = self.receive_messages_stream().await;
while let Some(result) = stream.next().await {
match &result {
Ok(msg) => {
let is_result = matches!(msg, Message::Result { .. });
yield result;
if is_result {
break;
}
}
Err(_) => {
yield result;
break;
}
}
}
}
}
pub async fn interrupt(&mut self) -> Result<()> {
if !self.connected {
return Err(SdkError::InvalidState {
message: "Not connected".into(),
});
}
let mut transport = self.transport.lock().await;
let request = ControlRequest::Interrupt {
request_id: uuid::Uuid::new_v4().to_string(),
};
transport.send_control_request(request).await?;
drop(transport);
info!("Interrupt sent");
Ok(())
}
pub async fn get_mcp_status(&mut self) -> Result<Vec<crate::types::McpServerStatus>> {
if !self.connected {
return Err(SdkError::InvalidState {
message: "Not connected".into(),
});
}
Ok(vec![])
}
pub async fn add_mcp_server(
&mut self,
name: &str,
config: crate::types::McpServerConfig,
) -> Result<()> {
if !self.connected {
return Err(SdkError::InvalidState {
message: "Not connected".into(),
});
}
let config_json = serde_json::to_value(&config)
.map_err(|e| SdkError::TransportError(format!("Failed to serialize MCP config: {e}")))?;
let mcp_msg = crate::types::SDKControlMcpMessageRequest {
subtype: "mcp_message".to_string(),
mcp_server_name: name.to_string(),
message: serde_json::json!({
"action": "add",
"config": config_json
}),
};
let mut transport = self.transport.lock().await;
let request = crate::types::SDKControlRequest::McpMessage(mcp_msg);
let json = serde_json::to_value(&request)
.map_err(|e| SdkError::TransportError(format!("Failed to serialize: {e}")))?;
let input = crate::transport::InputMessage {
r#type: "sdk_control".to_string(),
message: json,
parent_tool_use_id: None,
session_id: String::new(),
};
transport.send_message(input).await
}
pub async fn remove_mcp_server(&mut self, name: &str) -> Result<()> {
if !self.connected {
return Err(SdkError::InvalidState {
message: "Not connected".into(),
});
}
let mcp_msg = crate::types::SDKControlMcpMessageRequest {
subtype: "mcp_message".to_string(),
mcp_server_name: name.to_string(),
message: serde_json::json!({ "action": "remove" }),
};
let mut transport = self.transport.lock().await;
let request = crate::types::SDKControlRequest::McpMessage(mcp_msg);
let json = serde_json::to_value(&request)
.map_err(|e| SdkError::TransportError(format!("Failed to serialize: {e}")))?;
let input = crate::transport::InputMessage {
r#type: "sdk_control".to_string(),
message: json,
parent_tool_use_id: None,
session_id: String::new(),
};
transport.send_message(input).await
}
pub async fn reconnect_mcp_server(&mut self, name: &str) -> Result<()> {
if !self.connected {
return Err(SdkError::InvalidState {
message: "Not connected".into(),
});
}
let mcp_msg = crate::types::SDKControlMcpMessageRequest {
subtype: "mcp_message".to_string(),
mcp_server_name: name.to_string(),
message: serde_json::json!({ "action": "reconnect" }),
};
let mut transport = self.transport.lock().await;
let request = crate::types::SDKControlRequest::McpMessage(mcp_msg);
let json = serde_json::to_value(&request)
.map_err(|e| SdkError::TransportError(format!("Failed to serialize: {e}")))?;
let input = crate::transport::InputMessage {
r#type: "sdk_control".to_string(),
message: json,
parent_tool_use_id: None,
session_id: String::new(),
};
transport.send_message(input).await
}
pub async fn toggle_mcp_server(&mut self, name: &str, enabled: bool) -> Result<()> {
if !self.connected {
return Err(SdkError::InvalidState {
message: "Not connected".into(),
});
}
let mcp_msg = crate::types::SDKControlMcpMessageRequest {
subtype: "mcp_message".to_string(),
mcp_server_name: name.to_string(),
message: serde_json::json!({ "action": "toggle", "enabled": enabled }),
};
let mut transport = self.transport.lock().await;
let request = crate::types::SDKControlRequest::McpMessage(mcp_msg);
let json = serde_json::to_value(&request)
.map_err(|e| SdkError::TransportError(format!("Failed to serialize: {e}")))?;
let input = crate::transport::InputMessage {
r#type: "sdk_control".to_string(),
message: json,
parent_tool_use_id: None,
session_id: String::new(),
};
transport.send_message(input).await
}
pub async fn list_sessions(
&self,
directory: Option<&str>,
limit: Option<usize>,
include_worktrees: bool,
) -> Result<Vec<crate::sessions::SessionInfo>> {
crate::sessions::list_sessions(directory, limit, include_worktrees).await
}
pub async fn get_session_messages(
&self,
session_id: &str,
directory: Option<&str>,
limit: Option<usize>,
offset: usize,
) -> Result<Vec<crate::sessions::SessionMessage>> {
crate::sessions::get_session_messages(session_id, directory, limit, offset).await
}
pub async fn rename_session(&self, session_id: &str, title: &str) -> Result<()> {
crate::sessions::rename_session(session_id, title).await
}
pub async fn tag_session(&self, session_id: &str, tag: Option<&str>) -> Result<()> {
crate::sessions::tag_session(session_id, tag).await
}
pub async fn disconnect(&mut self) -> Result<()> {
if !self.connected {
return Ok(());
}
let mut transport = self.transport.lock().await;
transport.disconnect().await?;
drop(transport);
self.connected = false;
info!("Disconnected from Claude CLI");
Ok(())
}
}