1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use crate::notifications::{ELICITATION_METHOD, ElicitationParams, McpRequest};
5use agent_client_protocol::{
6 self as acp, Agent, Client, ConfigOptionUpdate, ExtNotification, ExtRequest, ExtResponse,
7 InitializeRequest, PermissionOptionKind, RequestPermissionOutcome, RequestPermissionRequest,
8 RequestPermissionResponse, SelectedPermissionOutcome, SessionConfigOption, SessionId,
9 SessionNotification, SessionUpdate, SetSessionConfigOptionRequest,
10};
11use serde_json::value::RawValue;
12use std::process::Stdio;
13use std::sync::Arc;
14use std::thread::spawn;
15use tokio::process::Command;
16use tokio::sync::{mpsc, oneshot};
17use tokio::task::{LocalSet, spawn_local};
18use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
19use tracing::info;
20
21pub struct AcpSession {
23 pub session_id: SessionId,
24 pub agent_name: String,
25 pub prompt_capabilities: acp::PromptCapabilities,
26 pub config_options: Vec<SessionConfigOption>,
27 pub auth_methods: Vec<acp::AuthMethod>,
28 pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
29 pub prompt_handle: AcpPromptHandle,
30}
31
32pub struct AutoApproveClient {
35 event_tx: mpsc::UnboundedSender<AcpEvent>,
36}
37
38impl AutoApproveClient {
39 pub fn new(event_tx: mpsc::UnboundedSender<AcpEvent>) -> Self {
40 Self { event_tx }
41 }
42}
43
44#[async_trait::async_trait(?Send)]
45impl Client for AutoApproveClient {
46 async fn request_permission(
47 &self,
48 args: RequestPermissionRequest,
49 ) -> acp::Result<RequestPermissionResponse> {
50 let option_id = args
51 .options
52 .iter()
53 .find(|o| {
54 matches!(
55 o.kind,
56 PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways
57 )
58 })
59 .map_or_else(
60 || args.options[0].option_id.clone(),
61 |o| o.option_id.clone(),
62 );
63
64 Ok(RequestPermissionResponse::new(
65 RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(option_id)),
66 ))
67 }
68
69 async fn session_notification(&self, args: SessionNotification) -> acp::Result<()> {
70 let _ = self
71 .event_tx
72 .send(AcpEvent::SessionUpdate(Box::new(args.update)));
73
74 Ok(())
75 }
76
77 async fn ext_notification(&self, args: ExtNotification) -> acp::Result<()> {
78 let _ = self.event_tx.send(AcpEvent::ExtNotification(args));
79 Ok(())
80 }
81
82 async fn ext_method(&self, args: ExtRequest) -> acp::Result<ExtResponse> {
83 if args.method.as_ref() == ELICITATION_METHOD {
84 return handle_elicitation_ext_method(&self.event_tx, args).await;
85 }
86
87 let null_raw: Arc<RawValue> = serde_json::from_str("null").expect("null is valid JSON");
89 Ok(ExtResponse::new(null_raw))
90 }
91}
92
93async fn handle_elicitation_ext_method(
94 event_tx: &mpsc::UnboundedSender<AcpEvent>,
95 args: ExtRequest,
96) -> acp::Result<ExtResponse> {
97 let params: ElicitationParams =
98 serde_json::from_str(args.params.get()).map_err(|_| acp::Error::invalid_params())?;
99
100 let (response_tx, response_rx) = oneshot::channel();
101 event_tx
102 .send(AcpEvent::ElicitationRequest {
103 params,
104 response_tx,
105 })
106 .map_err(|_| acp::Error::internal_error())?;
107
108 let response = response_rx
109 .await
110 .map_err(|_| acp::Error::internal_error())?;
111
112 let raw =
113 serde_json::value::to_raw_value(&response).map_err(|_| acp::Error::internal_error())?;
114
115 Ok(ExtResponse::new(Arc::from(raw)))
116}
117
118pub async fn spawn_acp_session<F, C>(
129 agent_command: &str,
130 init_request: InitializeRequest,
131 new_session_request: acp::NewSessionRequest,
132 client_factory: F,
133) -> Result<AcpSession, AcpClientError>
134where
135 F: FnOnce(mpsc::UnboundedSender<AcpEvent>) -> C + Send + 'static,
136 C: acp::Client + 'static,
137{
138 let parts: Vec<&str> = agent_command.split_whitespace().collect();
139 let (program, args) = parts
140 .split_first()
141 .ok_or_else(|| AcpClientError::AgentCrashed("empty agent command".to_string()))?;
142
143 let mut child = Command::new(program)
144 .args(args)
145 .stdin(Stdio::piped())
146 .stdout(Stdio::piped())
147 .stderr(Stdio::inherit())
148 .spawn()
149 .map_err(AcpClientError::SpawnFailed)?;
150
151 let child_stdin = child
152 .stdin
153 .take()
154 .ok_or_else(|| AcpClientError::AgentCrashed("no stdin on child".to_string()))?;
155
156 let child_stdout = child
157 .stdout
158 .take()
159 .ok_or_else(|| AcpClientError::AgentCrashed("no stdout on child".to_string()))?;
160
161 let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
162 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
163 let (session_tx, session_rx) = oneshot::channel::<HandshakeResult>();
164 let thread_ctx = AcpThreadContext {
165 child_stdin,
166 child_stdout,
167 event_tx,
168 cmd_rx,
169 session_tx,
170 client_factory,
171 init_request,
172 new_session_request,
173 };
174
175 spawn(move || {
176 let rt = tokio::runtime::Builder::new_current_thread()
177 .enable_all()
178 .build()
179 .expect("failed to build tokio runtime for ACP");
180
181 LocalSet::new().block_on(&rt, async move {
182 run_acp_thread(thread_ctx).await;
183 });
184 });
185
186 let handshake = session_rx.await.map_err(|_| {
187 AcpClientError::AgentCrashed("ACP thread died during handshake".to_string())
188 })??;
189
190 Ok(AcpSession {
191 session_id: handshake.session_id,
192 agent_name: handshake.agent_name,
193 prompt_capabilities: handshake.prompt_capabilities,
194 config_options: handshake.config_options,
195 auth_methods: handshake.auth_methods,
196 event_rx,
197 prompt_handle: AcpPromptHandle { cmd_tx },
198 })
199}
200
201struct HandshakeData {
202 session_id: acp::SessionId,
203 agent_name: String,
204 prompt_capabilities: acp::PromptCapabilities,
205 config_options: Vec<acp::SessionConfigOption>,
206 auth_methods: Vec<acp::AuthMethod>,
207}
208
209type HandshakeResult = Result<HandshakeData, AcpClientError>;
210
211struct AcpThreadContext<F> {
212 child_stdin: tokio::process::ChildStdin,
213 child_stdout: tokio::process::ChildStdout,
214 event_tx: mpsc::UnboundedSender<AcpEvent>,
215 cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
216 session_tx: oneshot::Sender<HandshakeResult>,
217 client_factory: F,
218 init_request: InitializeRequest,
219 new_session_request: acp::NewSessionRequest,
220}
221
222#[allow(clippy::too_many_lines)]
223async fn run_acp_thread<F, C>(ctx: AcpThreadContext<F>)
224where
225 F: FnOnce(mpsc::UnboundedSender<AcpEvent>) -> C,
226 C: Client + 'static,
227{
228 let AcpThreadContext {
229 child_stdin,
230 child_stdout,
231 event_tx,
232 mut cmd_rx,
233 session_tx,
234 client_factory,
235 init_request,
236 new_session_request,
237 } = ctx;
238
239 let client = client_factory(event_tx.clone());
240 let outgoing = child_stdin.compat_write();
241 let incoming = child_stdout.compat();
242 let (conn, handle_io) = acp::ClientSideConnection::new(client, outgoing, incoming, |fut| {
243 spawn_local(fut);
244 });
245
246 spawn_local(async move {
247 let _ = handle_io.await;
248 });
249
250 let init_resp = match conn.initialize(init_request).await {
251 Ok(r) => r,
252 Err(e) => {
253 let _ = session_tx.send(Err(AcpClientError::Protocol(e)));
254 return;
255 }
256 };
257
258 let agent_name = init_resp.agent_info.as_ref().map_or_else(
259 || "agent".to_string(),
260 |info| info.title.as_deref().unwrap_or(&info.name).to_string(),
261 );
262 let prompt_capabilities = init_resp.agent_capabilities.prompt_capabilities.clone();
263
264 info!(
265 "ACP initialized: protocol={:?}, agent_info={:?}",
266 init_resp.protocol_version, init_resp.agent_info
267 );
268
269 let auth_methods = init_resp.auth_methods;
270
271 let session_resp = match conn.new_session(new_session_request).await {
272 Ok(r) => r,
273 Err(e) => {
274 let _ = session_tx.send(Err(AcpClientError::Protocol(e)));
275 return;
276 }
277 };
278
279 let session_id = session_resp.session_id;
280 info!("ACP session created: {session_id}");
281
282 let config_options = session_resp.config_options.unwrap_or_default();
283 let _ = session_tx.send(Ok(HandshakeData {
284 session_id,
285 agent_name,
286 prompt_capabilities,
287 config_options,
288 auth_methods,
289 }));
290
291 while let Some(cmd) = cmd_rx.recv().await {
292 match cmd {
293 PromptCommand::Prompt {
294 session_id,
295 text,
296 content,
297 } => {
298 let mut prompt = vec![acp::ContentBlock::Text(acp::TextContent::new(text))];
299 if let Some(extra_content) = content {
300 prompt.extend(extra_content);
301 }
302 let prompt_fut = conn.prompt(acp::PromptRequest::new(session_id, prompt));
303 tokio::pin!(prompt_fut);
304
305 loop {
307 tokio::select! {
308 result = &mut prompt_fut => {
309 let event = match result {
310 Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
311 Err(e) => AcpEvent::PromptError(e),
312 };
313 let _ = event_tx.send(event);
314 break;
315 }
316 Some(cmd) = cmd_rx.recv() => {
317 handle_side_command(&conn, &event_tx, cmd).await;
318 }
319 }
320 }
321 }
322 PromptCommand::ListSessions => {
323 let req = acp::ListSessionsRequest::new();
324 match conn.list_sessions(req).await {
325 Ok(resp) => {
326 let _ = event_tx.send(AcpEvent::SessionsListed {
327 sessions: resp.sessions,
328 });
329 }
330 Err(e) => {
331 let _ = event_tx.send(AcpEvent::PromptError(e));
332 }
333 }
334 }
335 PromptCommand::LoadSession { session_id, cwd } => {
336 let req = acp::LoadSessionRequest::new(session_id.clone(), cwd);
337 match conn.load_session(req).await {
338 Ok(resp) => {
339 let config_options = resp.config_options.unwrap_or_default();
340 let _ = event_tx.send(AcpEvent::SessionLoaded {
341 session_id,
342 config_options,
343 });
344 }
345 Err(e) => {
346 let _ = event_tx.send(AcpEvent::PromptError(e));
347 }
348 }
349 }
350 PromptCommand::NewSession { cwd } => {
351 let req = acp::NewSessionRequest::new(cwd);
352 match conn.new_session(req).await {
353 Ok(resp) => {
354 let config_options = resp.config_options.unwrap_or_default();
355 let _ = event_tx.send(AcpEvent::NewSessionCreated {
356 session_id: resp.session_id,
357 config_options,
358 });
359 }
360 Err(e) => {
361 let _ = event_tx.send(AcpEvent::PromptError(e));
362 }
363 }
364 }
365 cmd => handle_side_command(&conn, &event_tx, cmd).await,
366 }
367 }
368
369 let _ = event_tx.send(AcpEvent::ConnectionClosed);
370}
371
372async fn handle_side_command(
373 conn: &acp::ClientSideConnection,
374 event_tx: &mpsc::UnboundedSender<AcpEvent>,
375 cmd: PromptCommand,
376) {
377 match cmd {
378 PromptCommand::Cancel { session_id } => {
379 let _ = conn.cancel(acp::CancelNotification::new(session_id)).await;
380 }
381 PromptCommand::SetConfigOption {
382 session_id,
383 config_id,
384 value,
385 } => {
386 let req = SetSessionConfigOptionRequest::new(session_id, config_id, value);
387 match conn.set_session_config_option(req).await {
388 Ok(resp) => {
389 let update = ConfigOptionUpdate::new(resp.config_options);
390 let _ = event_tx.send(AcpEvent::SessionUpdate(Box::new(
391 SessionUpdate::ConfigOptionUpdate(update),
392 )));
393 }
394 Err(e) => {
395 tracing::warn!("set_session_config_option failed: {e:?}");
396 }
397 }
398 }
399 PromptCommand::Prompt { .. } => {
400 tracing::warn!("ignoring duplicate Prompt while one is in-flight");
401 }
402 PromptCommand::ListSessions => {
403 tracing::warn!("ignoring ListSessions while prompt is in-flight");
404 }
405 PromptCommand::LoadSession { .. } => {
406 tracing::warn!("ignoring LoadSession while prompt is in-flight");
407 }
408 PromptCommand::NewSession { .. } => {
409 tracing::warn!("ignoring NewSession while prompt is in-flight");
410 }
411 PromptCommand::AuthenticateMcpServer {
412 session_id,
413 server_name,
414 } => {
415 let msg = McpRequest::Authenticate {
416 session_id: session_id.0.to_string(),
417 server_name,
418 };
419 if let Err(e) = conn.ext_notification(msg.into()).await {
420 tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
421 }
422 }
423 PromptCommand::Authenticate {
424 session_id: _,
425 method_id,
426 } => {
427 match conn
428 .authenticate(acp::AuthenticateRequest::new(method_id.clone()))
429 .await
430 {
431 Ok(_) => {
432 let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
433 }
434 Err(e) => {
435 tracing::warn!("authenticate failed: {e:?}");
436 let _ = event_tx.send(AcpEvent::AuthenticateFailed {
437 method_id,
438 error: format!("{e:?}"),
439 });
440 }
441 }
442 }
443 }
444}