1use std::collections::HashMap;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::task::{Context, Poll};
5
6use serde_json::Value;
7use tokio::sync::{mpsc, Mutex};
8use tokio_stream::wrappers::ReceiverStream;
9use tokio_stream::Stream;
10
11use crate::error::{Error, Result};
12use crate::mcp::SdkMcpServer;
13use crate::query::{McpMessageHandler, Query};
14use crate::transport::subprocess::SubprocessTransport;
15use crate::types::messages::Message;
16use crate::types::options::ClaudeAgentOptions;
17
18pub struct MessageStream<'a> {
23 inner: Option<ReceiverStream<Result<Message>>>,
24 slot: &'a mut Option<mpsc::Receiver<Result<Message>>>,
25}
26
27impl<'a> MessageStream<'a> {
28 fn new(
29 stream: ReceiverStream<Result<Message>>,
30 slot: &'a mut Option<mpsc::Receiver<Result<Message>>>,
31 ) -> Self {
32 Self {
33 inner: Some(stream),
34 slot,
35 }
36 }
37}
38
39impl Stream for MessageStream<'_> {
40 type Item = Result<Message>;
41
42 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
43 match self.inner.as_mut() {
44 Some(stream) => Pin::new(stream).poll_next(cx),
45 None => Poll::Ready(None),
46 }
47 }
48
49 fn size_hint(&self) -> (usize, Option<usize>) {
50 match &self.inner {
51 Some(stream) => stream.size_hint(),
52 None => (0, Some(0)),
53 }
54 }
55}
56
57impl Drop for MessageStream<'_> {
58 fn drop(&mut self) {
59 if let Some(stream) = self.inner.take() {
60 *self.slot = Some(stream.into_inner());
61 }
62 }
63}
64
65pub struct ClaudeSDKClient {
99 options: ClaudeAgentOptions,
100 query: Option<Query>,
101 message_rx: Option<mpsc::Receiver<Result<Message>>>,
102 mcp_servers: HashMap<String, Arc<Mutex<SdkMcpServer>>>,
103}
104
105impl ClaudeSDKClient {
106 fn query_ref(&self) -> Result<&Query> {
107 self.query.as_ref().ok_or(Error::NotConnected)
108 }
109
110 #[must_use]
111 pub fn new(options: ClaudeAgentOptions) -> Self {
112 Self {
113 options,
114 query: None,
115 message_rx: None,
116 mcp_servers: HashMap::new(),
117 }
118 }
119
120 pub fn add_mcp_server(
125 &mut self,
126 name: impl Into<String>,
127 server: SdkMcpServer,
128 ) -> Result<()> {
129 if self.is_connected() {
130 return Err(Error::AlreadyConnected);
131 }
132 self.mcp_servers
133 .insert(name.into(), Arc::new(Mutex::new(server)));
134 Ok(())
135 }
136
137 pub async fn connect(&mut self, initial_prompt: Option<&str>) -> Result<()> {
139 if self.query.is_some() {
140 return Err(Error::AlreadyConnected);
141 }
142
143 let cli_path = self.options.resolve_cli_path()?;
144 let transport = SubprocessTransport::new(cli_path, &self.options);
145
146 let mcp_handler = self.build_mcp_handler();
147
148 let mut q = Query::new(
149 Box::new(transport),
150 self.options.hooks.clone(),
151 self.options.can_use_tool.clone(),
152 mcp_handler,
153 self.options.control_timeout,
154 );
155
156 let rx = q.connect().await?;
157 self.message_rx = Some(rx);
158 self.query = Some(q);
159
160 if let Some(prompt) = initial_prompt {
161 self.query_ref()?.send_message(prompt, None).await?;
162 }
163
164 Ok(())
165 }
166
167 pub async fn query(&self, prompt: &str, session_id: Option<&str>) -> Result<()> {
169 self.query_ref()?.send_message(prompt, session_id).await
170 }
171
172 pub fn receive_messages(&mut self) -> MessageStream<'_> {
179 let rx = self.message_rx.take().unwrap_or_else(|| {
180 let (_tx, rx) = mpsc::channel(1);
181 rx
182 });
183 MessageStream::new(ReceiverStream::new(rx), &mut self.message_rx)
184 }
185
186 pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
188 use crate::types::messages::collect_until_result;
189
190 let mut stream = self.receive_messages();
191 collect_until_result(&mut stream).await
192 }
194
195 pub async fn interrupt(&self) -> Result<Value> {
197 self.query_ref()?.interrupt().await
198 }
199
200 pub async fn set_permission_mode(&self, mode: &str) -> Result<Value> {
202 self.query_ref()?.set_permission_mode(mode).await
203 }
204
205 pub async fn set_model(&self, model: &str) -> Result<Value> {
207 self.query_ref()?.set_model(model).await
208 }
209
210 pub async fn rewind_files(&self, user_message_id: &str) -> Result<Value> {
212 self.query_ref()?.rewind_files(user_message_id).await
213 }
214
215 pub async fn get_mcp_status(&self) -> Result<Value> {
217 self.query_ref()?.get_mcp_status().await
218 }
219
220 pub async fn get_server_info(&self) -> Option<Value> {
222 match &self.query {
223 Some(q) => q.get_server_info().await,
224 None => None,
225 }
226 }
227
228 pub async fn disconnect(&mut self) -> Result<()> {
230 if let Some(mut q) = self.query.take() {
231 q.close().await?;
232 }
233 self.message_rx = None;
234 Ok(())
235 }
236
237 pub fn is_connected(&self) -> bool {
239 self.query.is_some()
240 }
241
242 fn build_mcp_handler(&self) -> Option<McpMessageHandler> {
243 if self.mcp_servers.is_empty() {
244 return None;
245 }
246
247 let servers = self.mcp_servers.clone();
248 Some(Arc::new(move |server_name: String, message: Value| {
249 let servers = servers.clone();
250 Box::pin(async move {
251 if let Some(server) = servers.get(&server_name) {
252 let srv = server.lock().await;
253 srv.handle_message(message).await
254 } else {
255 serde_json::json!({"error": format!("unknown MCP server: {server_name}")})
256 }
257 })
258 }))
259 }
260}