Skip to main content

claude_code_rs/
client.rs

1use serde_json::Value;
2use tokio::sync::mpsc;
3use tokio_stream::wrappers::ReceiverStream;
4
5use crate::error::{Error, Result};
6use crate::mcp::SdkMcpServer;
7use crate::query::{McpMessageHandler, Query};
8use crate::transport::cli_discovery;
9use crate::transport::subprocess::SubprocessTransport;
10use crate::types::messages::Message;
11use crate::types::options::ClaudeAgentOptions;
12
13/// A stateful client for multi-turn conversations with the Claude CLI.
14///
15/// Unlike `query()` which is one-shot, the client maintains a connection
16/// and supports sending multiple queries, interrupts, and control commands.
17///
18/// # Example
19/// ```no_run
20/// use claude_code_rs::{ClaudeAgentOptions, Message};
21/// use claude_code_rs::client::ClaudeSDKClient;
22/// use tokio_stream::StreamExt;
23///
24/// # async fn example() -> claude_code_rs::Result<()> {
25/// let mut client = ClaudeSDKClient::new(ClaudeAgentOptions::default());
26/// client.connect(None).await?;
27///
28/// // First query.
29/// client.query("What is Rust?", None).await?;
30/// let mut stream = client.receive_messages();
31/// while let Some(msg) = stream.next().await {
32///     let msg = msg?;
33///     if msg.is_result() { break; }
34/// }
35///
36/// // Follow-up query in same session.
37/// client.query("How does ownership work?", None).await?;
38/// // ...
39///
40/// client.disconnect().await?;
41/// # Ok(())
42/// # }
43/// ```
44pub struct ClaudeSDKClient {
45    options: ClaudeAgentOptions,
46    query: Option<Query>,
47    message_rx: Option<mpsc::Receiver<Result<Message>>>,
48    mcp_servers: Vec<SdkMcpServer>,
49}
50
51impl ClaudeSDKClient {
52    pub fn new(options: ClaudeAgentOptions) -> Self {
53        Self {
54            options,
55            query: None,
56            message_rx: None,
57            mcp_servers: Vec::new(),
58        }
59    }
60
61    /// Register an in-process MCP server.
62    pub fn add_mcp_server(&mut self, server: SdkMcpServer) {
63        self.mcp_servers.push(server);
64    }
65
66    /// Connect to the Claude CLI. Optionally send an initial prompt.
67    pub async fn connect(&mut self, initial_prompt: Option<&str>) -> Result<()> {
68        if self.query.is_some() {
69            return Err(Error::AlreadyConnected);
70        }
71
72        let cli_path = match self.options.cli_path {
73            Some(ref p) => p.clone(),
74            None => cli_discovery::find_cli()?,
75        };
76
77        let transport = SubprocessTransport::new(cli_path, &self.options);
78
79        let mcp_handler = self.build_mcp_handler();
80
81        let mut q = Query::new(
82            Box::new(transport),
83            std::mem::take(&mut self.options.hooks),
84            self.options.can_use_tool.take(),
85            mcp_handler,
86            self.options.control_timeout,
87        );
88
89        let rx = q.connect().await?;
90        self.message_rx = Some(rx);
91        self.query = Some(q);
92
93        if let Some(prompt) = initial_prompt {
94            self.query
95                .as_ref()
96                .unwrap()
97                .send_message(prompt, None)
98                .await?;
99        }
100
101        Ok(())
102    }
103
104    /// Send a query/prompt. Optionally provide a session_id for resuming.
105    pub async fn query(&self, prompt: &str, session_id: Option<&str>) -> Result<()> {
106        let q = self.query.as_ref().ok_or(Error::NotConnected)?;
107        q.send_message(prompt, session_id).await
108    }
109
110    /// Get a stream of messages from the current query.
111    ///
112    /// Messages flow until a `ResultMessage` signals end of turn.
113    pub fn receive_messages(&mut self) -> ReceiverStream<Result<Message>> {
114        let rx = self.message_rx.take().unwrap_or_else(|| {
115            let (_tx, rx) = mpsc::channel(1);
116            rx
117        });
118        ReceiverStream::new(rx)
119    }
120
121    /// Collect all messages until the next ResultMessage.
122    pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
123        use tokio_stream::StreamExt;
124
125        let mut stream = self.receive_messages();
126        let mut messages = Vec::new();
127
128        while let Some(msg) = stream.next().await {
129            let msg = msg?;
130            let is_result = msg.is_result();
131            messages.push(msg);
132            if is_result {
133                break;
134            }
135        }
136
137        // Put the receiver back (the stream may have more messages).
138        self.message_rx = Some(stream.into_inner());
139
140        Ok(messages)
141    }
142
143    /// Send an interrupt command.
144    pub async fn interrupt(&self) -> Result<Value> {
145        let q = self.query.as_ref().ok_or(Error::NotConnected)?;
146        q.interrupt().await
147    }
148
149    /// Change the permission mode.
150    pub async fn set_permission_mode(&self, mode: &str) -> Result<Value> {
151        let q = self.query.as_ref().ok_or(Error::NotConnected)?;
152        q.set_permission_mode(mode).await
153    }
154
155    /// Change the model.
156    pub async fn set_model(&self, model: &str) -> Result<Value> {
157        let q = self.query.as_ref().ok_or(Error::NotConnected)?;
158        q.set_model(model).await
159    }
160
161    /// Rewind file changes to a specific user message.
162    pub async fn rewind_files(&self, user_message_id: &str) -> Result<Value> {
163        let q = self.query.as_ref().ok_or(Error::NotConnected)?;
164        q.rewind_files(user_message_id).await
165    }
166
167    /// Get MCP server status.
168    pub async fn get_mcp_status(&self) -> Result<Value> {
169        let q = self.query.as_ref().ok_or(Error::NotConnected)?;
170        q.get_mcp_status().await
171    }
172
173    /// Get server info from the init handshake.
174    pub async fn get_server_info(&self) -> Option<Value> {
175        match &self.query {
176            Some(q) => q.get_server_info().await,
177            None => None,
178        }
179    }
180
181    /// Disconnect from the CLI.
182    pub async fn disconnect(&mut self) -> Result<()> {
183        if let Some(mut q) = self.query.take() {
184            q.close().await?;
185        }
186        self.message_rx = None;
187        Ok(())
188    }
189
190    /// Check if connected.
191    pub fn is_connected(&self) -> bool {
192        self.query.is_some()
193    }
194
195    fn build_mcp_handler(&self) -> Option<McpMessageHandler> {
196        if self.mcp_servers.is_empty() {
197            return None;
198        }
199
200        // For simplicity, we handle only the first MCP server.
201        // A full implementation would look up by server_name.
202        // TODO: support multiple named MCP servers.
203        None
204        // MCP handler will be wired when we have named server lookup.
205        // The control protocol provides the server_name in the request.
206    }
207}