// dot/acp/mod.rs

1mod transport;
2pub mod types;
3
4use transport::AcpTransport;
5pub use types::*;
6
7use anyhow::{Context, Result};
8
9pub enum AcpMessage {
10    Notification(SessionNotification),
11    IncomingRequest {
12        id: u64,
13        method: String,
14        params: serde_json::Value,
15    },
16    PromptComplete(PromptResponse),
17    Response {
18        id: u64,
19        result: std::result::Result<serde_json::Value, JsonRpcError>,
20    },
21}
22
23pub struct AcpClient {
24    transport: AcpTransport,
25    session_id: Option<SessionId>,
26    agent_info: Option<Implementation>,
27    agent_capabilities: Option<AgentCapabilities>,
28    modes: Option<SessionModeState>,
29    config_options: Option<Vec<SessionConfigOption>>,
30}
31
32impl AcpClient {
33    pub fn start(command: &str, args: &[String], env: &[(String, String)]) -> Result<Self> {
34        Ok(Self {
35            transport: AcpTransport::spawn(command, args, env)?,
36            session_id: None,
37            agent_info: None,
38            agent_capabilities: None,
39            modes: None,
40            config_options: None,
41        })
42    }
43
44    pub async fn initialize(&mut self) -> Result<InitializeResponse> {
45        let params = serde_json::to_value(InitializeRequest {
46            protocol_version: 1,
47            client_capabilities: ClientCapabilities {
48                fs: FsCapabilities {
49                    read_text_file: true,
50                    write_text_file: true,
51                },
52                terminal: true,
53            },
54            client_info: Some(Implementation {
55                name: "dot".into(),
56                title: Some("dot".into()),
57                version: Some(env!("CARGO_PKG_VERSION").into()),
58            }),
59        })
60        .context("serializing initialize request")?;
61
62        let raw = self.transport.send_request("initialize", params).await?;
63        let resp: InitializeResponse =
64            serde_json::from_value(raw).context("parsing initialize response")?;
65        self.agent_info = resp.agent_info.clone();
66        self.agent_capabilities = Some(resp.agent_capabilities.clone());
67        if let Some(ref info) = resp.agent_info {
68            tracing::info!(agent = %info.name, version = ?info.version, "ACP initialized");
69        }
70        Ok(resp)
71    }
72
73    pub async fn authenticate(&mut self, method_id: &str) -> Result<AuthenticateResponse> {
74        let params = serde_json::to_value(AuthenticateRequest {
75            method_id: method_id.into(),
76        })
77        .context("serializing authenticate request")?;
78        let raw = self.transport.send_request("authenticate", params).await?;
79        serde_json::from_value(raw).context("parsing authenticate response")
80    }
81
82    pub async fn new_session(
83        &mut self,
84        cwd: &str,
85        mcp_servers: Vec<McpServer>,
86    ) -> Result<NewSessionResponse> {
87        let params = serde_json::to_value(NewSessionRequest {
88            cwd: cwd.into(),
89            mcp_servers,
90        })
91        .context("serializing session/new request")?;
92        let raw = self.transport.send_request("session/new", params).await?;
93        let resp: NewSessionResponse =
94            serde_json::from_value(raw).context("parsing session/new response")?;
95        self.session_id = Some(resp.session_id.clone());
96        self.modes = resp.modes.clone();
97        self.config_options = resp.config_options.clone();
98        tracing::info!(session_id = %resp.session_id, "ACP session created");
99        Ok(resp)
100    }
101
102    pub async fn load_session(
103        &mut self,
104        session_id: &str,
105        cwd: &str,
106        mcp_servers: Vec<McpServer>,
107    ) -> Result<LoadSessionResponse> {
108        let params = serde_json::to_value(LoadSessionRequest {
109            session_id: session_id.into(),
110            cwd: cwd.into(),
111            mcp_servers,
112        })
113        .context("serializing session/load request")?;
114        let raw = self.transport.send_request("session/load", params).await?;
115        let resp: LoadSessionResponse =
116            serde_json::from_value(raw).context("parsing session/load response")?;
117        self.session_id = Some(session_id.into());
118        self.modes = resp.modes.clone();
119        self.config_options = resp.config_options.clone();
120        Ok(resp)
121    }
122
123    pub async fn send_prompt(&mut self, text: &str) -> Result<()> {
124        let sid = self
125            .session_id
126            .as_deref()
127            .context("no active session")?
128            .to_string();
129        let params = serde_json::to_value(PromptRequest {
130            session_id: sid,
131            prompt: vec![ContentBlock::Text { text: text.into() }],
132        })
133        .context("serializing session/prompt request")?;
134        let id = self.transport.next_id();
135        self.transport
136            .write_request(id, "session/prompt", params)
137            .await
138    }
139
140    pub async fn send_prompt_with_content(&mut self, content: Vec<ContentBlock>) -> Result<()> {
141        let sid = self
142            .session_id
143            .as_deref()
144            .context("no active session")?
145            .to_string();
146        let params = serde_json::to_value(PromptRequest {
147            session_id: sid,
148            prompt: content,
149        })
150        .context("serializing session/prompt request")?;
151        let id = self.transport.next_id();
152        self.transport
153            .write_request(id, "session/prompt", params)
154            .await
155    }
156
157    pub async fn read_next(&mut self) -> Result<AcpMessage> {
158        if let Some(n) = self.transport.buffered_notifications.pop_front()
159            && let Ok(sn) = serde_json::from_value::<SessionNotification>(n.params.clone())
160        {
161            return Ok(AcpMessage::Notification(sn));
162        }
163        if let Some(r) = self.transport.buffered_requests.pop_front() {
164            return Ok(AcpMessage::IncomingRequest {
165                id: r.id,
166                method: r.method,
167                params: r.params,
168            });
169        }
170        loop {
171            let msg = self.transport.read_message().await?;
172            match msg {
173                JsonRpcMessage::Notification(n) => {
174                    if let Ok(sn) = serde_json::from_value::<SessionNotification>(n.params.clone())
175                    {
176                        return Ok(AcpMessage::Notification(sn));
177                    }
178                }
179                JsonRpcMessage::Request(r) => {
180                    return Ok(AcpMessage::IncomingRequest {
181                        id: r.id,
182                        method: r.method,
183                        params: r.params,
184                    });
185                }
186                JsonRpcMessage::Response(resp) => {
187                    if let Some(err) = resp.error {
188                        return Ok(AcpMessage::Response {
189                            id: resp.id,
190                            result: Err(err),
191                        });
192                    }
193                    let result = resp.result.unwrap_or(serde_json::Value::Null);
194                    if let Ok(pr) = serde_json::from_value::<PromptResponse>(result.clone()) {
195                        return Ok(AcpMessage::PromptComplete(pr));
196                    }
197                    return Ok(AcpMessage::Response {
198                        id: resp.id,
199                        result: Ok(result),
200                    });
201                }
202            }
203        }
204    }
205
206    pub async fn cancel(&mut self) -> Result<()> {
207        let sid = self
208            .session_id
209            .as_deref()
210            .context("no active session")?
211            .to_string();
212        let params = serde_json::to_value(CancelNotification { session_id: sid })
213            .context("serializing cancel")?;
214        self.transport
215            .send_notification("session/cancel", params)
216            .await
217    }
218
219    pub async fn set_mode(&mut self, mode_id: &str) -> Result<SetSessionModeResponse> {
220        let sid = self
221            .session_id
222            .as_deref()
223            .context("no active session")?
224            .to_string();
225        let params = serde_json::to_value(SetSessionModeRequest {
226            session_id: sid,
227            mode_id: mode_id.into(),
228        })
229        .context("serializing set_mode request")?;
230        let raw = self
231            .transport
232            .send_request("session/set_mode", params)
233            .await?;
234        serde_json::from_value(raw).context("parsing set_mode response")
235    }
236
237    pub fn drain_notifications(&mut self) -> Vec<SessionNotification> {
238        self.transport
239            .drain_notifications()
240            .into_iter()
241            .filter_map(|n| serde_json::from_value::<SessionNotification>(n.params).ok())
242            .collect()
243    }
244
245    pub fn drain_incoming_requests(&mut self) -> Vec<JsonRpcRequest> {
246        self.transport.drain_requests()
247    }
248
249    pub async fn respond(&mut self, id: u64, result: serde_json::Value) -> Result<()> {
250        self.transport.send_response(id, result).await
251    }
252
253    pub async fn respond_error(&mut self, id: u64, code: i32, message: &str) -> Result<()> {
254        self.transport.send_error_response(id, code, message).await
255    }
256
257    pub fn session_id(&self) -> Option<&str> {
258        self.session_id.as_deref()
259    }
260
261    pub fn agent_info(&self) -> Option<&Implementation> {
262        self.agent_info.as_ref()
263    }
264
265    pub fn current_mode(&self) -> Option<&str> {
266        self.modes.as_ref().map(|m| m.current_mode_id.as_str())
267    }
268
269    pub fn available_modes(&self) -> &[SessionMode] {
270        self.modes
271            .as_ref()
272            .map(|m| m.available_modes.as_slice())
273            .unwrap_or(&[])
274    }
275
276    pub fn set_current_mode(&mut self, mode_id: &str) {
277        if let Some(ref mut modes) = self.modes {
278            modes.current_mode_id = mode_id.to_string();
279        }
280    }
281
282    pub fn config_options(&self) -> &[SessionConfigOption] {
283        self.config_options.as_deref().unwrap_or(&[])
284    }
285
286    pub fn set_config_options(&mut self, options: Vec<SessionConfigOption>) {
287        self.config_options = Some(options);
288    }
289
290    pub fn kill(&mut self) -> Result<()> {
291        self.transport.kill()
292    }
293}