1use serde_json::Value;
2use tokio::sync::mpsc;
3use tokio_stream::wrappers::ReceiverStream;
4
5use crate::error::{Error, Result};
6use crate::mcp::SdkMcpServer;
7use crate::query::{McpMessageHandler, Query};
8use crate::transport::cli_discovery;
9use crate::transport::subprocess::SubprocessTransport;
10use crate::types::messages::Message;
11use crate::types::options::ClaudeAgentOptions;
12
13pub struct ClaudeSDKClient {
45 options: ClaudeAgentOptions,
46 query: Option<Query>,
47 message_rx: Option<mpsc::Receiver<Result<Message>>>,
48 mcp_servers: Vec<SdkMcpServer>,
49}
50
51impl ClaudeSDKClient {
52 pub fn new(options: ClaudeAgentOptions) -> Self {
53 Self {
54 options,
55 query: None,
56 message_rx: None,
57 mcp_servers: Vec::new(),
58 }
59 }
60
61 pub fn add_mcp_server(&mut self, server: SdkMcpServer) {
63 self.mcp_servers.push(server);
64 }
65
66 pub async fn connect(&mut self, initial_prompt: Option<&str>) -> Result<()> {
68 if self.query.is_some() {
69 return Err(Error::AlreadyConnected);
70 }
71
72 let cli_path = match self.options.cli_path {
73 Some(ref p) => p.clone(),
74 None => cli_discovery::find_cli()?,
75 };
76
77 let transport = SubprocessTransport::new(cli_path, &self.options);
78
79 let mcp_handler = self.build_mcp_handler();
80
81 let mut q = Query::new(
82 Box::new(transport),
83 std::mem::take(&mut self.options.hooks),
84 self.options.can_use_tool.take(),
85 mcp_handler,
86 self.options.control_timeout,
87 );
88
89 let rx = q.connect().await?;
90 self.message_rx = Some(rx);
91 self.query = Some(q);
92
93 if let Some(prompt) = initial_prompt {
94 self.query
95 .as_ref()
96 .unwrap()
97 .send_message(prompt, None)
98 .await?;
99 }
100
101 Ok(())
102 }
103
104 pub async fn query(&self, prompt: &str, session_id: Option<&str>) -> Result<()> {
106 let q = self.query.as_ref().ok_or(Error::NotConnected)?;
107 q.send_message(prompt, session_id).await
108 }
109
110 pub fn receive_messages(&mut self) -> ReceiverStream<Result<Message>> {
114 let rx = self.message_rx.take().unwrap_or_else(|| {
115 let (_tx, rx) = mpsc::channel(1);
116 rx
117 });
118 ReceiverStream::new(rx)
119 }
120
121 pub async fn receive_response(&mut self) -> Result<Vec<Message>> {
123 use tokio_stream::StreamExt;
124
125 let mut stream = self.receive_messages();
126 let mut messages = Vec::new();
127
128 while let Some(msg) = stream.next().await {
129 let msg = msg?;
130 let is_result = msg.is_result();
131 messages.push(msg);
132 if is_result {
133 break;
134 }
135 }
136
137 self.message_rx = Some(stream.into_inner());
139
140 Ok(messages)
141 }
142
143 pub async fn interrupt(&self) -> Result<Value> {
145 let q = self.query.as_ref().ok_or(Error::NotConnected)?;
146 q.interrupt().await
147 }
148
149 pub async fn set_permission_mode(&self, mode: &str) -> Result<Value> {
151 let q = self.query.as_ref().ok_or(Error::NotConnected)?;
152 q.set_permission_mode(mode).await
153 }
154
155 pub async fn set_model(&self, model: &str) -> Result<Value> {
157 let q = self.query.as_ref().ok_or(Error::NotConnected)?;
158 q.set_model(model).await
159 }
160
161 pub async fn rewind_files(&self, user_message_id: &str) -> Result<Value> {
163 let q = self.query.as_ref().ok_or(Error::NotConnected)?;
164 q.rewind_files(user_message_id).await
165 }
166
167 pub async fn get_mcp_status(&self) -> Result<Value> {
169 let q = self.query.as_ref().ok_or(Error::NotConnected)?;
170 q.get_mcp_status().await
171 }
172
173 pub async fn get_server_info(&self) -> Option<Value> {
175 match &self.query {
176 Some(q) => q.get_server_info().await,
177 None => None,
178 }
179 }
180
181 pub async fn disconnect(&mut self) -> Result<()> {
183 if let Some(mut q) = self.query.take() {
184 q.close().await?;
185 }
186 self.message_rx = None;
187 Ok(())
188 }
189
190 pub fn is_connected(&self) -> bool {
192 self.query.is_some()
193 }
194
195 fn build_mcp_handler(&self) -> Option<McpMessageHandler> {
196 if self.mcp_servers.is_empty() {
197 return None;
198 }
199
200 None
204 }
207}