Skip to main content

claude_code_rs/
client.rs

1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use serde_json::Value;
7use tokio::sync::{mpsc, Mutex};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::Stream;
10
11use crate::error::{Error, Result};
12use crate::mcp::SdkMcpServer;
13use crate::query::{McpMessageHandler, Query};
14use crate::transport::subprocess::SubprocessTransport;
15use crate::types::messages::Message;
16use crate::types::options::ClaudeAgentOptions;
17
18/// RAII guard that returns the receiver back to the client on drop.
19///
20/// Implements [`Stream`] by delegating to the inner [`ReceiverStream`],
21/// so `stream.next().await` works as expected.
22pub struct MessageStream<'a> {
23    inner: Option<ReceiverStream<Result<Message>>>,
24    slot: &'a mut Option<mpsc::Receiver<Result<Message>>>,
25}
26
27impl<'a> MessageStream<'a> {
28    fn new(
29        stream: ReceiverStream<Result<Message>>,
30        slot: &'a mut Option<mpsc::Receiver<Result<Message>>>,
31    ) -> Self {
32        Self {
33            inner: Some(stream),
34            slot,
35        }
36    }
37}
38
39impl Stream for MessageStream<'_> {
40    type Item = Result<Message>;
41
42    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
43        match self.inner.as_mut() {
44            Some(stream) => Pin::new(stream).poll_next(cx),
45            None => Poll::Ready(None),
46        }
47    }
48
49    fn size_hint(&self) -> (usize, Option<usize>) {
50        match &self.inner {
51            Some(stream) => stream.size_hint(),
52            None => (0, Some(0)),
53        }
54    }
55}
56
57impl Drop for MessageStream<'_> {
58    fn drop(&mut self) {
59        if let Some(stream) = self.inner.take() {
60            *self.slot = Some(stream.into_inner());
61        }
62    }
63}
64
/// A stateful client for multi-turn conversations with the Claude CLI.
///
/// Unlike `query()` which is one-shot, the client maintains a connection
/// and supports sending multiple queries, interrupts, and control commands.
///
/// Typical lifecycle: register any in-process MCP servers with
/// [`add_mcp_server()`](Self::add_mcp_server), call
/// [`connect()`](Self::connect), alternate [`query()`](Self::query) with
/// [`receive_messages()`](Self::receive_messages) /
/// [`receive_response()`](Self::receive_response), then
/// [`disconnect()`](Self::disconnect).
///
/// # Example
/// ```no_run
/// use claude_code_rs::{ClaudeAgentOptions, Message};
/// use claude_code_rs::client::ClaudeSDKClient;
/// use tokio_stream::StreamExt;
///
/// # async fn example() -> claude_code_rs::Result<()> {
/// let mut client = ClaudeSDKClient::new(ClaudeAgentOptions::default());
/// client.connect(None).await?;
///
/// // First query.
/// client.query("What is Rust?", None).await?;
/// {
///     let mut stream = client.receive_messages();
///     while let Some(msg) = stream.next().await {
///         let msg = msg?;
///         if msg.is_result() { break; }
///     }
/// } // Receiver auto-restores when `stream` drops.
///
/// // Follow-up query in same session.
/// client.query("How does ownership work?", None).await?;
/// // ...
///
/// client.disconnect().await?;
/// # Ok(())
/// # }
/// ```
pub struct ClaudeSDKClient {
    /// Configuration used to locate/spawn the CLI and to build the query
    /// (hooks, tool-permission callback, control timeout).
    options: ClaudeAgentOptions,
    /// Active query; `Some` exactly while connected (see `is_connected`).
    query: Option<Query>,
    /// Receiver for messages from the CLI. It is temporarily taken out while a
    /// `MessageStream` returned by `receive_messages` is alive, and put back
    /// when that stream drops.
    message_rx: Option<mpsc::Receiver<Result<Message>>>,
    /// In-process MCP servers keyed by name; cloned into the MCP handler on connect.
    mcp_servers: HashMap<String, Arc<Mutex<SdkMcpServer>>>,
}
104
105impl ClaudeSDKClient {
106    fn query_ref(&self) -> Result<&Query> {
107        self.query.as_ref().ok_or(Error::NotConnected)
108    }
109
110    #[must_use]
111    pub fn new(options: ClaudeAgentOptions) -> Self {
112        Self {
113            options,
114            query: None,
115            message_rx: None,
116            mcp_servers: HashMap::new(),
117        }
118    }
119
120    /// Register an in-process MCP server by name.
121    ///
122    /// Must be called **before** [`connect()`](Self::connect). Returns an error
123    /// if the client is already connected (servers are snapshot-cloned during connect).
124    pub fn add_mcp_server(
125        &mut self,
126        name: impl Into<String>,
127        server: SdkMcpServer,
128    ) -> Result<()> {
129        if self.is_connected() {
130            return Err(Error::AlreadyConnected);
131        }
132        self.mcp_servers
133            .insert(name.into(), Arc::new(Mutex::new(server)));
134        Ok(())
135    }
136
137    /// Connect to the Claude CLI. Optionally send an initial prompt.
138    pub async fn connect(&mut self, initial_prompt: Option<&str>) -> Result<()> {
139        if self.query.is_some() {
140            return Err(Error::AlreadyConnected);
141        }
142
143        let cli_path = self.options.resolve_cli_path()?;
144        let transport = SubprocessTransport::new(cli_path, &self.options);
145
146        let mcp_handler = self.build_mcp_handler();
147
148        let mut q = Query::new(
149            Box::new(transport),
150            self.options.hooks.clone(),
151            self.options.can_use_tool.clone(),
152            mcp_handler,
153            self.options.control_timeout,
154        );
155
156        let rx = q.connect().await?;
157        self.message_rx = Some(rx);
158        self.query = Some(q);
159
160        if let Some(prompt) = initial_prompt {
161            self.query_ref()?.send_message(prompt, None).await?;
162        }
163
164        Ok(())
165    }
166
167    /// Send a query/prompt. Optionally provide a session_id for resuming.
168    pub async fn query(&self, prompt: &str, session_id: Option<&str>) -> Result<()> {
169        self.query_ref()?.send_message(prompt, session_id).await
170    }
171
172    /// Get a stream of messages from the current query.
173    ///
174    /// Messages flow until a `ResultMessage` signals end of turn.
175    /// The receiver is automatically restored when the returned
176    /// [`MessageStream`] is dropped, so the client remains usable
177    /// for follow-up queries.
178    pub fn receive_messages(&mut self) -> MessageStream<'_> {
179        let rx = self.message_rx.take().unwrap_or_else(|| {
180            let (_tx, rx) = mpsc::channel(1);
181            rx
182        });
183        MessageStream::new(ReceiverStream::new(rx), &mut self.message_rx)
184    }
185
186    /// Collect all messages until the next ResultMessage.
187    pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
188        use crate::types::messages::collect_until_result;
189
190        let mut stream = self.receive_messages();
191        collect_until_result(&mut stream).await
192        // Receiver auto-restores when `stream` drops.
193    }
194
195    /// Send an interrupt command.
196    pub async fn interrupt(&self) -> Result<Value> {
197        self.query_ref()?.interrupt().await
198    }
199
200    /// Change the permission mode.
201    pub async fn set_permission_mode(&self, mode: &str) -> Result<Value> {
202        self.query_ref()?.set_permission_mode(mode).await
203    }
204
205    /// Change the model.
206    pub async fn set_model(&self, model: &str) -> Result<Value> {
207        self.query_ref()?.set_model(model).await
208    }
209
210    /// Rewind file changes to a specific user message.
211    pub async fn rewind_files(&self, user_message_id: &str) -> Result<Value> {
212        self.query_ref()?.rewind_files(user_message_id).await
213    }
214
215    /// Get MCP server status.
216    pub async fn get_mcp_status(&self) -> Result<Value> {
217        self.query_ref()?.get_mcp_status().await
218    }
219
220    /// Get server info from the init handshake.
221    pub async fn get_server_info(&self) -> Option<Value> {
222        match &self.query {
223            Some(q) => q.get_server_info().await,
224            None => None,
225        }
226    }
227
228    /// Disconnect from the CLI.
229    pub async fn disconnect(&mut self) -> Result<()> {
230        if let Some(mut q) = self.query.take() {
231            q.close().await?;
232        }
233        self.message_rx = None;
234        Ok(())
235    }
236
237    /// Check if connected.
238    pub fn is_connected(&self) -> bool {
239        self.query.is_some()
240    }
241
242    fn build_mcp_handler(&self) -> Option<McpMessageHandler> {
243        if self.mcp_servers.is_empty() {
244            return None;
245        }
246
247        let servers = self.mcp_servers.clone();
248        Some(Arc::new(move |server_name: String, message: Value| {
249            let servers = servers.clone();
250            Box::pin(async move {
251                if let Some(server) = servers.get(&server_name) {
252                    let srv = server.lock().await;
253                    srv.handle_message(message).await
254                } else {
255                    serde_json::json!({"error": format!("unknown MCP server: {server_name}")})
256                }
257            })
258        }))
259    }
260}