1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use crate::notifications::{ELICITATION_METHOD, ElicitationParams, McpRequest};
5use agent_client_protocol::{
6 self as acp, Agent, Client, ConfigOptionUpdate, ExtNotification, ExtRequest, ExtResponse, InitializeRequest,
7 PermissionOptionKind, RequestPermissionOutcome, RequestPermissionRequest, RequestPermissionResponse,
8 SelectedPermissionOutcome, SessionConfigOption, SessionId, SessionNotification, SessionUpdate,
9 SetSessionConfigOptionRequest,
10};
11use serde_json::value::RawValue;
12use std::process::Stdio;
13use std::sync::Arc;
14use std::thread::spawn;
15use tokio::process::Command;
16use tokio::sync::{mpsc, oneshot};
17use tokio::task::{LocalSet, spawn_local};
18use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
19use tracing::info;
20
21pub struct AcpSession {
23 pub session_id: SessionId,
24 pub agent_name: String,
25 pub prompt_capabilities: acp::PromptCapabilities,
26 pub config_options: Vec<SessionConfigOption>,
27 pub auth_methods: Vec<acp::AuthMethod>,
28 pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
29 pub prompt_handle: AcpPromptHandle,
30}
31
32pub struct AutoApproveClient {
35 event_tx: mpsc::UnboundedSender<AcpEvent>,
36}
37
38impl AutoApproveClient {
39 pub fn new(event_tx: mpsc::UnboundedSender<AcpEvent>) -> Self {
40 Self { event_tx }
41 }
42}
43
44#[async_trait::async_trait(?Send)]
45impl Client for AutoApproveClient {
46 async fn request_permission(&self, args: RequestPermissionRequest) -> acp::Result<RequestPermissionResponse> {
47 let option_id = args
48 .options
49 .iter()
50 .find(|o| matches!(o.kind, PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways))
51 .map_or_else(|| args.options[0].option_id.clone(), |o| o.option_id.clone());
52
53 Ok(RequestPermissionResponse::new(RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(
54 option_id,
55 ))))
56 }
57
58 async fn session_notification(&self, args: SessionNotification) -> acp::Result<()> {
59 let _ = self.event_tx.send(AcpEvent::SessionUpdate(Box::new(args.update)));
60
61 Ok(())
62 }
63
64 async fn ext_notification(&self, args: ExtNotification) -> acp::Result<()> {
65 let _ = self.event_tx.send(AcpEvent::ExtNotification(args));
66 Ok(())
67 }
68
69 async fn ext_method(&self, args: ExtRequest) -> acp::Result<ExtResponse> {
70 if args.method.as_ref() == ELICITATION_METHOD {
71 return handle_elicitation_ext_method(&self.event_tx, args).await;
72 }
73
74 let null_raw: Arc<RawValue> = serde_json::from_str("null").expect("null is valid JSON");
76 Ok(ExtResponse::new(null_raw))
77 }
78}
79
80async fn handle_elicitation_ext_method(
81 event_tx: &mpsc::UnboundedSender<AcpEvent>,
82 args: ExtRequest,
83) -> acp::Result<ExtResponse> {
84 let params: ElicitationParams =
85 serde_json::from_str(args.params.get()).map_err(|_| acp::Error::invalid_params())?;
86
87 let (response_tx, response_rx) = oneshot::channel();
88 event_tx.send(AcpEvent::ElicitationRequest { params, response_tx }).map_err(|_| acp::Error::internal_error())?;
89
90 let response = response_rx.await.map_err(|_| acp::Error::internal_error())?;
91
92 let raw = serde_json::value::to_raw_value(&response).map_err(|_| acp::Error::internal_error())?;
93
94 Ok(ExtResponse::new(Arc::from(raw)))
95}
96
97pub async fn spawn_acp_session<F, C>(
108 agent_command: &str,
109 init_request: InitializeRequest,
110 new_session_request: acp::NewSessionRequest,
111 client_factory: F,
112) -> Result<AcpSession, AcpClientError>
113where
114 F: FnOnce(mpsc::UnboundedSender<AcpEvent>) -> C + Send + 'static,
115 C: acp::Client + 'static,
116{
117 let parts: Vec<&str> = agent_command.split_whitespace().collect();
118 let (program, args) =
119 parts.split_first().ok_or_else(|| AcpClientError::AgentCrashed("empty agent command".to_string()))?;
120
121 let mut child = Command::new(program)
122 .args(args)
123 .stdin(Stdio::piped())
124 .stdout(Stdio::piped())
125 .stderr(Stdio::inherit())
126 .spawn()
127 .map_err(AcpClientError::SpawnFailed)?;
128
129 let child_stdin =
130 child.stdin.take().ok_or_else(|| AcpClientError::AgentCrashed("no stdin on child".to_string()))?;
131
132 let child_stdout =
133 child.stdout.take().ok_or_else(|| AcpClientError::AgentCrashed("no stdout on child".to_string()))?;
134
135 let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
136 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
137 let (session_tx, session_rx) = oneshot::channel::<HandshakeResult>();
138 let thread_ctx = AcpThreadContext {
139 child_stdin,
140 child_stdout,
141 event_tx,
142 cmd_rx,
143 session_tx,
144 client_factory,
145 init_request,
146 new_session_request,
147 };
148
149 spawn(move || {
150 let rt = tokio::runtime::Builder::new_current_thread()
151 .enable_all()
152 .build()
153 .expect("failed to build tokio runtime for ACP");
154
155 LocalSet::new().block_on(&rt, async move {
156 run_acp_thread(thread_ctx).await;
157 });
158 });
159
160 let handshake = session_rx
161 .await
162 .map_err(|_| AcpClientError::AgentCrashed("ACP thread died during handshake".to_string()))??;
163
164 Ok(AcpSession {
165 session_id: handshake.session_id,
166 agent_name: handshake.agent_name,
167 prompt_capabilities: handshake.prompt_capabilities,
168 config_options: handshake.config_options,
169 auth_methods: handshake.auth_methods,
170 event_rx,
171 prompt_handle: AcpPromptHandle { cmd_tx },
172 })
173}
174
175struct HandshakeData {
176 session_id: acp::SessionId,
177 agent_name: String,
178 prompt_capabilities: acp::PromptCapabilities,
179 config_options: Vec<acp::SessionConfigOption>,
180 auth_methods: Vec<acp::AuthMethod>,
181}
182
183type HandshakeResult = Result<HandshakeData, AcpClientError>;
184
185struct AcpThreadContext<F> {
186 child_stdin: tokio::process::ChildStdin,
187 child_stdout: tokio::process::ChildStdout,
188 event_tx: mpsc::UnboundedSender<AcpEvent>,
189 cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
190 session_tx: oneshot::Sender<HandshakeResult>,
191 client_factory: F,
192 init_request: InitializeRequest,
193 new_session_request: acp::NewSessionRequest,
194}
195
196#[allow(clippy::too_many_lines)]
197async fn run_acp_thread<F, C>(ctx: AcpThreadContext<F>)
198where
199 F: FnOnce(mpsc::UnboundedSender<AcpEvent>) -> C,
200 C: Client + 'static,
201{
202 let AcpThreadContext {
203 child_stdin,
204 child_stdout,
205 event_tx,
206 mut cmd_rx,
207 session_tx,
208 client_factory,
209 init_request,
210 new_session_request,
211 } = ctx;
212
213 let client = client_factory(event_tx.clone());
214 let outgoing = child_stdin.compat_write();
215 let incoming = child_stdout.compat();
216 let (conn, handle_io) = acp::ClientSideConnection::new(client, outgoing, incoming, |fut| {
217 spawn_local(fut);
218 });
219
220 spawn_local(async move {
221 let _ = handle_io.await;
222 });
223
224 let init_resp = match conn.initialize(init_request).await {
225 Ok(r) => r,
226 Err(e) => {
227 let _ = session_tx.send(Err(AcpClientError::Protocol(e)));
228 return;
229 }
230 };
231
232 let agent_name = init_resp
233 .agent_info
234 .as_ref()
235 .map_or_else(|| "agent".to_string(), |info| info.title.as_deref().unwrap_or(&info.name).to_string());
236 let prompt_capabilities = init_resp.agent_capabilities.prompt_capabilities.clone();
237
238 info!("ACP initialized: protocol={:?}, agent_info={:?}", init_resp.protocol_version, init_resp.agent_info);
239
240 let auth_methods = init_resp.auth_methods;
241
242 let session_resp = match conn.new_session(new_session_request).await {
243 Ok(r) => r,
244 Err(e) => {
245 let _ = session_tx.send(Err(AcpClientError::Protocol(e)));
246 return;
247 }
248 };
249
250 let session_id = session_resp.session_id;
251 info!("ACP session created: {session_id}");
252
253 let config_options = session_resp.config_options.unwrap_or_default();
254 let _ = session_tx.send(Ok(HandshakeData {
255 session_id,
256 agent_name,
257 prompt_capabilities,
258 config_options,
259 auth_methods,
260 }));
261
262 while let Some(cmd) = cmd_rx.recv().await {
263 match cmd {
264 PromptCommand::Prompt { session_id, text, content } => {
265 let mut prompt = vec![acp::ContentBlock::Text(acp::TextContent::new(text))];
266 if let Some(extra_content) = content {
267 prompt.extend(extra_content);
268 }
269 let prompt_fut = conn.prompt(acp::PromptRequest::new(session_id, prompt));
270 tokio::pin!(prompt_fut);
271
272 loop {
274 tokio::select! {
275 result = &mut prompt_fut => {
276 let event = match result {
277 Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
278 Err(e) => AcpEvent::PromptError(e),
279 };
280 let _ = event_tx.send(event);
281 break;
282 }
283 Some(cmd) = cmd_rx.recv() => {
284 handle_side_command(&conn, &event_tx, cmd).await;
285 }
286 }
287 }
288 }
289 PromptCommand::ListSessions => {
290 let req = acp::ListSessionsRequest::new();
291 match conn.list_sessions(req).await {
292 Ok(resp) => {
293 let _ = event_tx.send(AcpEvent::SessionsListed { sessions: resp.sessions });
294 }
295 Err(e) => {
296 let _ = event_tx.send(AcpEvent::PromptError(e));
297 }
298 }
299 }
300 PromptCommand::LoadSession { session_id, cwd } => {
301 let req = acp::LoadSessionRequest::new(session_id.clone(), cwd);
302 match conn.load_session(req).await {
303 Ok(resp) => {
304 let config_options = resp.config_options.unwrap_or_default();
305 let _ = event_tx.send(AcpEvent::SessionLoaded { session_id, config_options });
306 }
307 Err(e) => {
308 let _ = event_tx.send(AcpEvent::PromptError(e));
309 }
310 }
311 }
312 PromptCommand::NewSession { cwd } => {
313 let req = acp::NewSessionRequest::new(cwd);
314 match conn.new_session(req).await {
315 Ok(resp) => {
316 let config_options = resp.config_options.unwrap_or_default();
317 let _ =
318 event_tx.send(AcpEvent::NewSessionCreated { session_id: resp.session_id, config_options });
319 }
320 Err(e) => {
321 let _ = event_tx.send(AcpEvent::PromptError(e));
322 }
323 }
324 }
325 cmd => handle_side_command(&conn, &event_tx, cmd).await,
326 }
327 }
328
329 let _ = event_tx.send(AcpEvent::ConnectionClosed);
330}
331
332async fn handle_side_command(
333 conn: &acp::ClientSideConnection,
334 event_tx: &mpsc::UnboundedSender<AcpEvent>,
335 cmd: PromptCommand,
336) {
337 match cmd {
338 PromptCommand::Cancel { session_id } => {
339 let _ = conn.cancel(acp::CancelNotification::new(session_id)).await;
340 }
341 PromptCommand::SetConfigOption { session_id, config_id, value } => {
342 let req = SetSessionConfigOptionRequest::new(session_id, config_id, value);
343 match conn.set_session_config_option(req).await {
344 Ok(resp) => {
345 let update = ConfigOptionUpdate::new(resp.config_options);
346 let _ = event_tx.send(AcpEvent::SessionUpdate(Box::new(SessionUpdate::ConfigOptionUpdate(update))));
347 }
348 Err(e) => {
349 tracing::warn!("set_session_config_option failed: {e:?}");
350 }
351 }
352 }
353 PromptCommand::Prompt { .. } => {
354 tracing::warn!("ignoring duplicate Prompt while one is in-flight");
355 }
356 PromptCommand::ListSessions => {
357 tracing::warn!("ignoring ListSessions while prompt is in-flight");
358 }
359 PromptCommand::LoadSession { .. } => {
360 tracing::warn!("ignoring LoadSession while prompt is in-flight");
361 }
362 PromptCommand::NewSession { .. } => {
363 tracing::warn!("ignoring NewSession while prompt is in-flight");
364 }
365 PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
366 let msg = McpRequest::Authenticate { session_id: session_id.0.to_string(), server_name };
367 if let Err(e) = conn.ext_notification(msg.into()).await {
368 tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
369 }
370 }
371 PromptCommand::Authenticate { session_id: _, method_id } => {
372 match conn.authenticate(acp::AuthenticateRequest::new(method_id.clone())).await {
373 Ok(_) => {
374 let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
375 }
376 Err(e) => {
377 tracing::warn!("authenticate failed: {e:?}");
378 let _ = event_tx.send(AcpEvent::AuthenticateFailed { method_id, error: format!("{e:?}") });
379 }
380 }
381 }
382 }
383}