use crate::{
errors::{Result, SdkError},
transport::{InputMessage, SubprocessTransport, Transport},
types::{ClaudeCodeOptions, ControlRequest, Message},
};
use crate::token_tracker::{BudgetLimit, BudgetManager, BudgetWarningCallback, TokenUsageTracker};
use futures::stream::StreamExt;
use std::collections::VecDeque;
use std::sync::Arc;
use tokio::sync::{RwLock, Semaphore, mpsc};
use tokio::time::{Duration, timeout};
use tracing::{debug, error, info, warn};
/// Operating mode of an `OptimizedClient`.
///
/// The mode determines how many pooled connections the client may keep and
/// which operations it accepts (interactive sessions vs. batch processing).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ClientMode {
    /// Independent request/response queries.
    OneShot,
    /// A session driven through the interactive send/receive methods.
    Interactive,
    /// Concurrent processing of many prompts.
    Batch {
        /// Upper bound on the number of prompts processed at the same time.
        max_concurrent: usize,
    },
}
/// Pool of reusable transports, shared between an `OptimizedClient` and its clones.
struct ConnectionPool {
/// Idle transports waiting for reuse: `acquire` pops from the front,
/// `release` pushes to the back.
idle_connections: Arc<RwLock<VecDeque<Box<dyn Transport + Send>>>>,
/// Maximum number of idle transports kept in `idle_connections`; also the
/// initial permit count of `connection_semaphore`.
max_connections: usize,
/// Semaphore from which `acquire` takes a permit before creating a new transport.
connection_semaphore: Arc<Semaphore>,
/// Options cloned into every newly created `SubprocessTransport`.
base_options: ClaudeCodeOptions,
}
impl ConnectionPool {
    /// Creates an empty pool that keeps at most `max_connections` idle
    /// transports, each created from a clone of `base_options`.
    fn new(base_options: ClaudeCodeOptions, max_connections: usize) -> Self {
        Self {
            idle_connections: Arc::new(RwLock::new(VecDeque::new())),
            max_connections,
            connection_semaphore: Arc::new(Semaphore::new(max_connections)),
            base_options,
        }
    }

    /// Hands out a connected transport, reusing an idle one when possible.
    ///
    /// Idle transports that are no longer connected are discarded until a
    /// live one is found; only when none is left is a new subprocess
    /// transport created and connected.
    ///
    /// # Errors
    ///
    /// Returns `SdkError::InvalidState` if the connection semaphore is closed,
    /// or whatever error creating/connecting the new transport produces.
    async fn acquire(&self) -> Result<Box<dyn Transport + Send>> {
        {
            let mut idle = self.idle_connections.write().await;
            // Keep popping past stale entries: a single dead transport at the
            // front must not force a brand-new connection while live ones
            // are still queued behind it.
            while let Some(transport) = idle.pop_front() {
                if transport.is_connected() {
                    debug!("Reusing existing connection from pool");
                    return Ok(transport);
                }
                debug!("Discarding stale pooled connection");
            }
        }

        // NOTE(review): the permit is released when this function returns, so
        // the semaphore only bounds concurrent connection *setup*, not the
        // number of live transports. Tying the permit to the transport's
        // lifetime would need changes in every caller; left as is.
        let _permit =
            self.connection_semaphore
                .acquire()
                .await
                .map_err(|_| SdkError::InvalidState {
                    message: "Failed to acquire connection permit".into(),
                })?;

        let mut transport: Box<dyn Transport + Send> =
            Box::new(SubprocessTransport::new(self.base_options.clone())?);
        transport.connect().await?;
        debug!("Created new connection");
        Ok(transport)
    }

    /// Returns a transport to the pool, or drops it when it is disconnected
    /// or the pool already holds `max_connections` idle transports.
    async fn release(&self, transport: Box<dyn Transport + Send>) {
        if !transport.is_connected() {
            debug!("Dropping connection");
            return;
        }

        // Check capacity and push under one write lock; checking under a
        // read lock first would let concurrent releases overfill the pool.
        let mut idle = self.idle_connections.write().await;
        if idle.len() < self.max_connections {
            idle.push_back(transport);
            debug!("Returned connection to pool");
        } else {
            debug!("Dropping connection");
        }
    }
}
/// Client that runs queries over pooled transports and tracks token usage.
pub struct OptimizedClient {
/// Operating mode chosen at construction; gates interactive and batch APIs.
mode: ClientMode,
/// Connection pool; clones of the client share the same pool.
pool: Arc<ConnectionPool>,
/// Receiving end of the interactive message channel; `None` when no
/// interactive session is active.
message_rx: Arc<RwLock<Option<mpsc::Receiver<Message>>>>,
/// Transport used by the active interactive session; `None` otherwise.
current_transport: Arc<RwLock<Option<Box<dyn Transport + Send>>>>,
/// Records usage reported by `Message::Result` messages and holds the budget limit.
budget_manager: BudgetManager,
}
impl OptimizedClient {
/// Creates a client for the given `mode`.
///
/// The connection pool is sized to `max_concurrent` in batch mode and to a
/// single connection otherwise. A `max_concurrent` of `0` is treated as `1`,
/// because a zero-permit pool could never hand out a connection and every
/// query would wait forever.
///
/// As a side effect, the `CLAUDE_CODE_ENTRYPOINT` environment variable of the
/// current process is set to `sdk-rust`.
///
/// # Errors
///
/// Currently always returns `Ok`; the `Result` is kept for interface stability.
pub fn new(options: ClaudeCodeOptions, mode: ClientMode) -> Result<Self> {
    // SAFETY: NOTE(review) - `set_var` is only sound while no other thread
    // reads or writes the process environment concurrently. Nothing here
    // enforces that, so clients should be constructed before worker threads
    // touch the environment. TODO confirm against callers.
    unsafe {
        std::env::set_var("CLAUDE_CODE_ENTRYPOINT", "sdk-rust");
    }

    let max_connections = match mode {
        // Clamp so the pool's semaphore always has at least one permit.
        ClientMode::Batch { max_concurrent } => max_concurrent.max(1),
        ClientMode::OneShot | ClientMode::Interactive => 1,
    };

    Ok(Self {
        mode,
        pool: Arc::new(ConnectionPool::new(options, max_connections)),
        message_rx: Arc::new(RwLock::new(None)),
        current_transport: Arc::new(RwLock::new(None)),
        budget_manager: BudgetManager::new(),
    })
}
/// Sends `prompt` and returns the messages of the response.
///
/// Convenience wrapper around `query_with_retry` using up to three retries
/// and an initial back-off of 100 ms.
pub async fn query(&self, prompt: String) -> Result<Vec<Message>> {
    const MAX_RETRIES: u32 = 3;
    const INITIAL_BACKOFF: Duration = Duration::from_millis(100);

    self.query_with_retry(prompt, MAX_RETRIES, INITIAL_BACKOFF)
        .await
}
/// Sends `prompt`, retrying failed attempts with exponential back-off.
///
/// After a failure the method sleeps for the current delay and tries again,
/// doubling the delay each time, until `max_retries` retries have been used.
/// Every error is retried, whatever its kind. The delay saturates at
/// `Duration::MAX` instead of panicking on overflow.
///
/// # Errors
///
/// Returns the error of the last attempt once all retries are exhausted.
pub async fn query_with_retry(
    &self,
    prompt: String,
    max_retries: u32,
    initial_delay: Duration,
) -> Result<Vec<Message>> {
    let mut retries: u32 = 0;
    let mut delay = initial_delay;
    loop {
        match self.execute_query(&prompt).await {
            Ok(messages) => return Ok(messages),
            Err(e) if retries < max_retries => {
                warn!("Query failed, retrying in {:?}: {}", delay, e);
                tokio::time::sleep(delay).await;
                retries += 1;
                // `delay *= 2` would panic once the duration overflows.
                delay = delay.saturating_mul(2);
            }
            Err(e) => return Err(e),
        }
    }
}
/// Runs a single query attempt on a pooled transport.
///
/// The transport goes back to the pool only after a successful exchange; on
/// a send error, receive error or timeout it is dropped instead.
///
/// # Errors
///
/// Returns `SdkError::Timeout` if the response is not complete within
/// 120 seconds, or any error raised while acquiring, sending or receiving.
async fn execute_query(&self, prompt: &str) -> Result<Vec<Message>> {
    let mut transport = self.pool.acquire().await?;

    let request = InputMessage::user(prompt.to_string(), "default".to_string());
    transport.send_message(request).await?;

    let deadline = Duration::from_secs(120);
    let collected = match timeout(deadline, self.collect_messages(&mut *transport)).await {
        Ok(outcome) => outcome?,
        Err(_elapsed) => return Err(SdkError::Timeout { seconds: 120 }),
    };

    self.pool.release(transport).await;
    Ok(collected)
}
/// Reads messages from `transport` until a `Message::Result` arrives or the
/// stream ends, returning everything received (the result message included).
///
/// For a result message, the `input_tokens` / `output_tokens` entries of its
/// usage data and its total cost are reported to the budget manager; missing
/// values count as zero.
///
/// # Errors
///
/// Returns the first error yielded by the transport's message stream.
async fn collect_messages<T: Transport + Send + ?Sized>(&self, transport: &mut T) -> Result<Vec<Message>> {
    let mut collected = Vec::new();
    let mut incoming = transport.receive_messages();

    while let Some(next) = incoming.next().await {
        let msg = next?;
        debug!("Received: {:?}", msg);

        let reached_result = if let Message::Result { usage, total_cost_usd, .. } = &msg {
            let (input_tokens, output_tokens) = {
                // Looks up one counter in the optional usage data, defaulting to 0.
                let token_count = |key: &str| -> u64 {
                    usage
                        .as_ref()
                        .and_then(|u| u.get(key))
                        .and_then(|v| v.as_u64())
                        .unwrap_or(0)
                };
                (token_count("input_tokens"), token_count("output_tokens"))
            };
            let cost = total_cost_usd.unwrap_or(0.0);
            self.budget_manager
                .update_usage(input_tokens, output_tokens, cost)
                .await;
            true
        } else {
            false
        };

        collected.push(msg);
        if reached_result {
            break;
        }
    }

    Ok(collected)
}
/// Returns the usage currently reported by the client's budget manager.
pub async fn get_usage_stats(&self) -> TokenUsageTracker {
self.budget_manager.get_usage().await
}
/// Installs `limit` on the budget manager.
///
/// When `on_warning` is `Some`, the callback is registered as the budget
/// manager's warning callback; when it is `None`, no callback is registered
/// by this call.
pub async fn set_budget_limit(
    &self,
    limit: BudgetLimit,
    on_warning: Option<BudgetWarningCallback>,
) {
    let budget = &self.budget_manager;
    budget.set_limit(limit).await;

    if let Some(callback) = on_warning {
        budget.set_warning_callback(callback).await;
    }
}
/// Clears the limit held by the budget manager.
pub async fn clear_budget_limit(&self) {
self.budget_manager.clear_limit().await;
}
/// Resets the usage recorded by the budget manager.
pub async fn reset_usage_stats(&self) {
self.budget_manager.reset_usage().await;
}
/// Reports whether the budget manager considers its limit exceeded.
pub async fn is_budget_exceeded(&self) -> bool {
self.budget_manager.is_exceeded().await
}
/// Starts an interactive session: acquires a transport, stores it as the
/// current transport, creates the message channel (capacity 100) and spawns
/// the background message processor.
///
/// # Errors
///
/// Returns `SdkError::InvalidState` if the client is not in interactive mode
/// or a session is already active, and propagates any error from acquiring
/// the transport.
pub async fn start_interactive_session(&self) -> Result<()> {
    if !matches!(self.mode, ClientMode::Interactive) {
        return Err(SdkError::InvalidState {
            message: "Client not in interactive mode".into(),
        });
    }

    // Hold the slot's write lock from the check to the assignment so two
    // concurrent callers cannot both start a session.
    let mut current = self.current_transport.write().await;
    if current.is_some() {
        // Overwriting the slot would drop the live transport without
        // releasing it and leave two processors polling the same slot.
        return Err(SdkError::InvalidState {
            message: "Interactive session already active".into(),
        });
    }

    let transport = self.pool.acquire().await?;
    let (tx, rx) = mpsc::channel::<Message>(100);
    *current = Some(transport);
    drop(current);

    *self.message_rx.write().await = Some(rx);
    self.start_message_processor(tx).await;
    info!("Interactive session started");
    Ok(())
}
async fn start_message_processor(&self, tx: mpsc::Sender<Message>) {
let transport_ref = self.current_transport.clone();
tokio::spawn(async move {
loop {
let msg_result = {
let mut transport_guard = transport_ref.write().await;
if let Some(transport) = transport_guard.as_mut() {
let mut stream = transport.receive_messages();
stream.next().await
} else {
break;
}
};
if let Some(result) = msg_result {
match result {
Ok(msg) => {
if tx.send(msg).await.is_err() {
error!("Failed to send message to channel");
break;
}
}
Err(e) => {
error!("Error receiving message: {}", e);
break;
}
}
}
}
});
}
/// Sends `prompt` as a user message on the active interactive session.
///
/// The transport slot is locked once for writing. Probing it with a read
/// lock first would add a second wait and leave a window in which the
/// session could end between the check and the send.
///
/// # Errors
///
/// Returns `SdkError::InvalidState` when no interactive session is active,
/// or the error produced by the transport while sending.
pub async fn send_interactive(&self, prompt: String) -> Result<()> {
    let mut transport_guard = self.current_transport.write().await;
    let transport = transport_guard
        .as_mut()
        .ok_or_else(|| SdkError::InvalidState {
            message: "No active interactive session".into(),
        })?;

    let message = InputMessage::user(prompt, "default".to_string());
    transport.send_message(message).await?;
    Ok(())
}
/// Collects messages of the interactive session up to and including the next
/// `Message::Result`, or until the channel is closed.
///
/// The receiver lock is held for the whole call.
/// NOTE(review): unlike `collect_messages`, usage from the result message is
/// not reported to the budget manager here - confirm whether that is intended.
///
/// # Errors
///
/// Returns `SdkError::InvalidState` when no interactive session is active.
pub async fn receive_interactive(&self) -> Result<Vec<Message>> {
    let mut rx_slot = self.message_rx.write().await;
    let rx = match rx_slot.as_mut() {
        Some(rx) => rx,
        None => {
            return Err(SdkError::InvalidState {
                message: "No active interactive session".into(),
            });
        }
    };

    let mut batch = Vec::new();
    while let Some(msg) = rx.recv().await {
        let done = matches!(msg, Message::Result { .. });
        batch.push(msg);
        if done {
            break;
        }
    }
    Ok(batch)
}
pub async fn process_batch(&self, prompts: Vec<String>) -> Result<Vec<Result<Vec<Message>>>> {
let max_concurrent = match self.mode {
ClientMode::Batch { max_concurrent } => max_concurrent,
_ => {
return Err(SdkError::InvalidState {
message: "Client not in batch mode".into(),
});
}
};
let semaphore = Arc::new(Semaphore::new(max_concurrent));
let mut handles = Vec::new();
for prompt in prompts {
let permit = semaphore.clone().acquire_owned().await.unwrap();
let client = self.clone();
let handle = tokio::spawn(async move {
let result = client.query(prompt).await;
drop(permit);
result
});
handles.push(handle);
}
let mut results = Vec::new();
for handle in handles {
match handle.await {
Ok(result) => results.push(result),
Err(e) => {
results.push(Err(SdkError::TransportError(format!("Task failed: {e}"))))
}
}
}
Ok(results)
}
/// Sends an interrupt control request, tagged with a fresh UUID, on the
/// active session's transport.
///
/// The transport slot is locked once for writing instead of being probed
/// with a read lock first, which removes a redundant lock round-trip and
/// the window in which the session could end between check and send.
///
/// # Errors
///
/// Returns `SdkError::InvalidState` when no session is active, or the error
/// produced by the transport while sending the request.
pub async fn interrupt(&self) -> Result<()> {
    let mut transport_guard = self.current_transport.write().await;
    let transport = transport_guard
        .as_mut()
        .ok_or_else(|| SdkError::InvalidState {
            message: "No active session".into(),
        })?;

    let request = ControlRequest::Interrupt {
        request_id: uuid::Uuid::new_v4().to_string(),
    };
    transport.send_control_request(request).await?;
    info!("Interrupt sent");
    Ok(())
}
/// Ends the interactive session, if any: the current transport is handed
/// back to the pool and the message receiver is discarded.
///
/// Always returns `Ok(())`, even when no session was active.
pub async fn end_interactive_session(&self) -> Result<()> {
    let mut session_slot = self.current_transport.write().await;
    if let Some(transport) = session_slot.take() {
        self.pool.release(transport).await;
    }
    drop(session_slot);

    *self.message_rx.write().await = None;
    info!("Interactive session ended");
    Ok(())
}
}
impl Clone for OptimizedClient {
    /// Creates a client with the same mode that shares this client's
    /// connection pool and holds a clone of its budget manager.
    ///
    /// Interactive-session state is not carried over: the clone starts with
    /// no current transport and no message receiver.
    fn clone(&self) -> Self {
        let pool = Arc::clone(&self.pool);
        let budget_manager = self.budget_manager.clone();

        Self {
            mode: self.mode,
            pool,
            message_rx: Arc::new(RwLock::new(None)),
            current_transport: Arc::new(RwLock::new(None)),
            budget_manager,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A client can be constructed in every mode.
    #[test]
    fn test_client_mode_creation() {
        let options = ClaudeCodeOptions::builder().build();
        let modes = [
            ClientMode::OneShot,
            ClientMode::Interactive,
            ClientMode::Batch { max_concurrent: 5 },
        ];

        for mode in modes {
            let client = OptimizedClient::new(options.clone(), mode);
            assert!(client.is_ok());
        }
    }

    /// The pool remembers the connection limit it was created with.
    #[test]
    fn test_connection_pool_creation() {
        let pool = ConnectionPool::new(ClaudeCodeOptions::builder().build(), 10);
        assert_eq!(pool.max_connections, 10);
    }

    /// Cloning a client keeps its mode.
    #[tokio::test]
    async fn test_client_cloning() {
        let options = ClaudeCodeOptions::builder().build();
        let original = OptimizedClient::new(options, ClientMode::OneShot).unwrap();
        let duplicate = original.clone();

        assert!(
            matches!(
                (original.mode, duplicate.mode),
                (ClientMode::OneShot, ClientMode::OneShot)
            ),
            "Mode not preserved during cloning"
        );
    }
}