1mod transport;
2pub mod types;
3
4use transport::AcpTransport;
5pub use types::*;
6
7use anyhow::{Context, Result};
8
9pub enum AcpMessage {
10 Notification(SessionNotification),
11 IncomingRequest {
12 id: u64,
13 method: String,
14 params: serde_json::Value,
15 },
16 PromptComplete(PromptResponse),
17 Response {
18 id: u64,
19 result: std::result::Result<serde_json::Value, JsonRpcError>,
20 },
21}
22
23pub struct AcpClient {
24 transport: AcpTransport,
25 session_id: Option<SessionId>,
26 agent_info: Option<Implementation>,
27 agent_capabilities: Option<AgentCapabilities>,
28 modes: Option<SessionModeState>,
29 config_options: Option<Vec<SessionConfigOption>>,
30}
31
32impl AcpClient {
33 pub fn start(command: &str, args: &[String], env: &[(String, String)]) -> Result<Self> {
34 Ok(Self {
35 transport: AcpTransport::spawn(command, args, env)?,
36 session_id: None,
37 agent_info: None,
38 agent_capabilities: None,
39 modes: None,
40 config_options: None,
41 })
42 }
43
44 pub async fn initialize(&mut self) -> Result<InitializeResponse> {
45 let params = serde_json::to_value(InitializeRequest {
46 protocol_version: 1,
47 client_capabilities: ClientCapabilities {
48 fs: FsCapabilities {
49 read_text_file: true,
50 write_text_file: true,
51 },
52 terminal: true,
53 },
54 client_info: Some(Implementation {
55 name: "dot".into(),
56 title: Some("dot".into()),
57 version: Some(env!("CARGO_PKG_VERSION").into()),
58 }),
59 })
60 .context("serializing initialize request")?;
61
62 let raw = self.transport.send_request("initialize", params).await?;
63 let resp: InitializeResponse =
64 serde_json::from_value(raw).context("parsing initialize response")?;
65 self.agent_info = resp.agent_info.clone();
66 self.agent_capabilities = Some(resp.agent_capabilities.clone());
67 if let Some(ref info) = resp.agent_info {
68 tracing::info!(agent = %info.name, version = ?info.version, "ACP initialized");
69 }
70 Ok(resp)
71 }
72
73 pub async fn authenticate(&mut self, method_id: &str) -> Result<AuthenticateResponse> {
74 let params = serde_json::to_value(AuthenticateRequest {
75 method_id: method_id.into(),
76 })
77 .context("serializing authenticate request")?;
78 let raw = self.transport.send_request("authenticate", params).await?;
79 serde_json::from_value(raw).context("parsing authenticate response")
80 }
81
82 pub async fn new_session(
83 &mut self,
84 cwd: &str,
85 mcp_servers: Vec<McpServer>,
86 ) -> Result<NewSessionResponse> {
87 let params = serde_json::to_value(NewSessionRequest {
88 cwd: cwd.into(),
89 mcp_servers,
90 })
91 .context("serializing session/new request")?;
92 let raw = self.transport.send_request("session/new", params).await?;
93 let resp: NewSessionResponse =
94 serde_json::from_value(raw).context("parsing session/new response")?;
95 self.session_id = Some(resp.session_id.clone());
96 self.modes = resp.modes.clone();
97 self.config_options = resp.config_options.clone();
98 tracing::info!(session_id = %resp.session_id, "ACP session created");
99 Ok(resp)
100 }
101
102 pub async fn load_session(
103 &mut self,
104 session_id: &str,
105 cwd: &str,
106 mcp_servers: Vec<McpServer>,
107 ) -> Result<LoadSessionResponse> {
108 let params = serde_json::to_value(LoadSessionRequest {
109 session_id: session_id.into(),
110 cwd: cwd.into(),
111 mcp_servers,
112 })
113 .context("serializing session/load request")?;
114 let raw = self.transport.send_request("session/load", params).await?;
115 let resp: LoadSessionResponse =
116 serde_json::from_value(raw).context("parsing session/load response")?;
117 self.session_id = Some(session_id.into());
118 self.modes = resp.modes.clone();
119 self.config_options = resp.config_options.clone();
120 Ok(resp)
121 }
122
123 pub async fn send_prompt(&mut self, text: &str) -> Result<()> {
124 let sid = self
125 .session_id
126 .as_deref()
127 .context("no active session")?
128 .to_string();
129 let params = serde_json::to_value(PromptRequest {
130 session_id: sid,
131 prompt: vec![ContentBlock::Text { text: text.into() }],
132 })
133 .context("serializing session/prompt request")?;
134 let id = self.transport.next_id();
135 self.transport
136 .write_request(id, "session/prompt", params)
137 .await
138 }
139
140 pub async fn send_prompt_with_content(&mut self, content: Vec<ContentBlock>) -> Result<()> {
141 let sid = self
142 .session_id
143 .as_deref()
144 .context("no active session")?
145 .to_string();
146 let params = serde_json::to_value(PromptRequest {
147 session_id: sid,
148 prompt: content,
149 })
150 .context("serializing session/prompt request")?;
151 let id = self.transport.next_id();
152 self.transport
153 .write_request(id, "session/prompt", params)
154 .await
155 }
156
157 pub async fn read_next(&mut self) -> Result<AcpMessage> {
158 if let Some(n) = self.transport.buffered_notifications.pop_front()
159 && let Ok(sn) = serde_json::from_value::<SessionNotification>(n.params.clone())
160 {
161 return Ok(AcpMessage::Notification(sn));
162 }
163 if let Some(r) = self.transport.buffered_requests.pop_front() {
164 return Ok(AcpMessage::IncomingRequest {
165 id: r.id,
166 method: r.method,
167 params: r.params,
168 });
169 }
170 loop {
171 let msg = self.transport.read_message().await?;
172 match msg {
173 JsonRpcMessage::Notification(n) => {
174 if let Ok(sn) = serde_json::from_value::<SessionNotification>(n.params.clone())
175 {
176 return Ok(AcpMessage::Notification(sn));
177 }
178 }
179 JsonRpcMessage::Request(r) => {
180 return Ok(AcpMessage::IncomingRequest {
181 id: r.id,
182 method: r.method,
183 params: r.params,
184 });
185 }
186 JsonRpcMessage::Response(resp) => {
187 if let Some(err) = resp.error {
188 return Ok(AcpMessage::Response {
189 id: resp.id,
190 result: Err(err),
191 });
192 }
193 let result = resp.result.unwrap_or(serde_json::Value::Null);
194 if let Ok(pr) = serde_json::from_value::<PromptResponse>(result.clone()) {
195 return Ok(AcpMessage::PromptComplete(pr));
196 }
197 return Ok(AcpMessage::Response {
198 id: resp.id,
199 result: Ok(result),
200 });
201 }
202 }
203 }
204 }
205
206 pub async fn cancel(&mut self) -> Result<()> {
207 let sid = self
208 .session_id
209 .as_deref()
210 .context("no active session")?
211 .to_string();
212 let params = serde_json::to_value(CancelNotification { session_id: sid })
213 .context("serializing cancel")?;
214 self.transport
215 .send_notification("session/cancel", params)
216 .await
217 }
218
219 pub async fn set_mode(&mut self, mode_id: &str) -> Result<SetSessionModeResponse> {
220 let sid = self
221 .session_id
222 .as_deref()
223 .context("no active session")?
224 .to_string();
225 let params = serde_json::to_value(SetSessionModeRequest {
226 session_id: sid,
227 mode_id: mode_id.into(),
228 })
229 .context("serializing set_mode request")?;
230 let raw = self
231 .transport
232 .send_request("session/set_mode", params)
233 .await?;
234 serde_json::from_value(raw).context("parsing set_mode response")
235 }
236
237 pub fn drain_notifications(&mut self) -> Vec<SessionNotification> {
238 self.transport
239 .drain_notifications()
240 .into_iter()
241 .filter_map(|n| serde_json::from_value::<SessionNotification>(n.params).ok())
242 .collect()
243 }
244
245 pub fn drain_incoming_requests(&mut self) -> Vec<JsonRpcRequest> {
246 self.transport.drain_requests()
247 }
248
249 pub async fn respond(&mut self, id: u64, result: serde_json::Value) -> Result<()> {
250 self.transport.send_response(id, result).await
251 }
252
253 pub async fn respond_error(&mut self, id: u64, code: i32, message: &str) -> Result<()> {
254 self.transport.send_error_response(id, code, message).await
255 }
256
257 pub fn session_id(&self) -> Option<&str> {
258 self.session_id.as_deref()
259 }
260
261 pub fn agent_info(&self) -> Option<&Implementation> {
262 self.agent_info.as_ref()
263 }
264
265 pub fn current_mode(&self) -> Option<&str> {
266 self.modes.as_ref().map(|m| m.current_mode_id.as_str())
267 }
268
269 pub fn available_modes(&self) -> &[SessionMode] {
270 self.modes
271 .as_ref()
272 .map(|m| m.available_modes.as_slice())
273 .unwrap_or(&[])
274 }
275
276 pub fn set_current_mode(&mut self, mode_id: &str) {
277 if let Some(ref mut modes) = self.modes {
278 modes.current_mode_id = mode_id.to_string();
279 }
280 }
281
282 pub fn config_options(&self) -> &[SessionConfigOption] {
283 self.config_options.as_deref().unwrap_or(&[])
284 }
285
286 pub fn set_config_options(&mut self, options: Vec<SessionConfigOption>) {
287 self.config_options = Some(options);
288 }
289
290 pub fn kill(&mut self) -> Result<()> {
291 self.transport.kill()
292 }
293}