Skip to main content

acp_utils/client/
session.rs

1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use crate::notifications::{ELICITATION_METHOD, ElicitationParams, McpRequest};
5use agent_client_protocol::{
6    self as acp, Agent, Client, ConfigOptionUpdate, ExtNotification, ExtRequest, ExtResponse,
7    InitializeRequest, PermissionOptionKind, RequestPermissionOutcome, RequestPermissionRequest,
8    RequestPermissionResponse, SelectedPermissionOutcome, SessionConfigOption, SessionId,
9    SessionNotification, SessionUpdate, SetSessionConfigOptionRequest,
10};
11use serde_json::value::RawValue;
12use std::process::Stdio;
13use std::sync::Arc;
14use std::thread::spawn;
15use tokio::process::Command;
16use tokio::sync::{mpsc, oneshot};
17use tokio::task::{LocalSet, spawn_local};
18use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
19use tracing::info;
20
/// ACP session with all handles needed by the caller.
///
/// Produced by [`spawn_acp_session`] once the handshake has succeeded.
pub struct AcpSession {
    /// Identifier of the session created during the handshake.
    pub session_id: SessionId,
    /// Display name of the agent: its title when one was reported, otherwise
    /// its name, or `"agent"` when the agent sent no info at all.
    pub agent_name: String,
    /// Prompt capabilities taken from the agent's initialize response.
    pub prompt_capabilities: acp::PromptCapabilities,
    /// Config options from the new-session response (empty when none were sent).
    pub config_options: Vec<SessionConfigOption>,
    /// Authentication methods taken from the agent's initialize response.
    pub auth_methods: Vec<acp::AuthMethod>,
    /// Receiving end of the channel on which [`AcpEvent`]s are delivered.
    pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
    /// Handle wrapping the sender used to queue commands for the ACP thread.
    pub prompt_handle: AcpPromptHandle,
}
31
/// A built-in ACP client that auto-approves permissions and forwards session
/// notifications as [`AcpEvent`]s.
pub struct AutoApproveClient {
    /// Sender used to forward protocol events as [`AcpEvent`]s.
    event_tx: mpsc::UnboundedSender<AcpEvent>,
}
37
impl AutoApproveClient {
    /// Creates a client that forwards its events through `event_tx`.
    pub fn new(event_tx: mpsc::UnboundedSender<AcpEvent>) -> Self {
        Self { event_tx }
    }
}
43
44#[async_trait::async_trait(?Send)]
45impl Client for AutoApproveClient {
46    async fn request_permission(
47        &self,
48        args: RequestPermissionRequest,
49    ) -> acp::Result<RequestPermissionResponse> {
50        let option_id = args
51            .options
52            .iter()
53            .find(|o| {
54                matches!(
55                    o.kind,
56                    PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways
57                )
58            })
59            .map_or_else(
60                || args.options[0].option_id.clone(),
61                |o| o.option_id.clone(),
62            );
63
64        Ok(RequestPermissionResponse::new(
65            RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(option_id)),
66        ))
67    }
68
69    async fn session_notification(&self, args: SessionNotification) -> acp::Result<()> {
70        let _ = self
71            .event_tx
72            .send(AcpEvent::SessionUpdate(Box::new(args.update)));
73
74        Ok(())
75    }
76
77    async fn ext_notification(&self, args: ExtNotification) -> acp::Result<()> {
78        let _ = self.event_tx.send(AcpEvent::ExtNotification(args));
79        Ok(())
80    }
81
82    async fn ext_method(&self, args: ExtRequest) -> acp::Result<ExtResponse> {
83        if args.method.as_ref() == ELICITATION_METHOD {
84            return handle_elicitation_ext_method(&self.event_tx, args).await;
85        }
86
87        // Unknown ext_method — return null (default behavior)
88        let null_raw: Arc<RawValue> = serde_json::from_str("null").expect("null is valid JSON");
89        Ok(ExtResponse::new(null_raw))
90    }
91}
92
93async fn handle_elicitation_ext_method(
94    event_tx: &mpsc::UnboundedSender<AcpEvent>,
95    args: ExtRequest,
96) -> acp::Result<ExtResponse> {
97    let params: ElicitationParams =
98        serde_json::from_str(args.params.get()).map_err(|_| acp::Error::invalid_params())?;
99
100    let (response_tx, response_rx) = oneshot::channel();
101    event_tx
102        .send(AcpEvent::ElicitationRequest {
103            params,
104            response_tx,
105        })
106        .map_err(|_| acp::Error::internal_error())?;
107
108    let response = response_rx
109        .await
110        .map_err(|_| acp::Error::internal_error())?;
111
112    let raw =
113        serde_json::value::to_raw_value(&response).map_err(|_| acp::Error::internal_error())?;
114
115    Ok(ExtResponse::new(Arc::from(raw)))
116}
117
118/// Spawn an agent subprocess and establish an ACP session.
119///
120/// The handshake (initialize + `new_session`) runs on a dedicated !Send thread.
121/// `client_factory` creates the ACP [`Client`](acp::Client) implementation,
122/// receiving the event sender so it can forward protocol events.
123///
124/// For the common auto-approve case, use `AutoApproveClient::new`:
125/// ```ignore
126/// spawn_acp_session("my-agent", init_req, session_req, AutoApproveClient::new).await
127/// ```
128pub async fn spawn_acp_session<F, C>(
129    agent_command: &str,
130    init_request: InitializeRequest,
131    new_session_request: acp::NewSessionRequest,
132    client_factory: F,
133) -> Result<AcpSession, AcpClientError>
134where
135    F: FnOnce(mpsc::UnboundedSender<AcpEvent>) -> C + Send + 'static,
136    C: acp::Client + 'static,
137{
138    let parts: Vec<&str> = agent_command.split_whitespace().collect();
139    let (program, args) = parts
140        .split_first()
141        .ok_or_else(|| AcpClientError::AgentCrashed("empty agent command".to_string()))?;
142
143    let mut child = Command::new(program)
144        .args(args)
145        .stdin(Stdio::piped())
146        .stdout(Stdio::piped())
147        .stderr(Stdio::inherit())
148        .spawn()
149        .map_err(AcpClientError::SpawnFailed)?;
150
151    let child_stdin = child
152        .stdin
153        .take()
154        .ok_or_else(|| AcpClientError::AgentCrashed("no stdin on child".to_string()))?;
155
156    let child_stdout = child
157        .stdout
158        .take()
159        .ok_or_else(|| AcpClientError::AgentCrashed("no stdout on child".to_string()))?;
160
161    let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
162    let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
163    let (session_tx, session_rx) = oneshot::channel::<HandshakeResult>();
164    let thread_ctx = AcpThreadContext {
165        child_stdin,
166        child_stdout,
167        event_tx,
168        cmd_rx,
169        session_tx,
170        client_factory,
171        init_request,
172        new_session_request,
173    };
174
175    spawn(move || {
176        let rt = tokio::runtime::Builder::new_current_thread()
177            .enable_all()
178            .build()
179            .expect("failed to build tokio runtime for ACP");
180
181        LocalSet::new().block_on(&rt, async move {
182            run_acp_thread(thread_ctx).await;
183        });
184    });
185
186    let handshake = session_rx.await.map_err(|_| {
187        AcpClientError::AgentCrashed("ACP thread died during handshake".to_string())
188    })??;
189
190    Ok(AcpSession {
191        session_id: handshake.session_id,
192        agent_name: handshake.agent_name,
193        prompt_capabilities: handshake.prompt_capabilities,
194        config_options: handshake.config_options,
195        auth_methods: handshake.auth_methods,
196        event_rx,
197        prompt_handle: AcpPromptHandle { cmd_tx },
198    })
199}
200
/// Values gathered by the ACP thread during a successful handshake and handed
/// back to [`spawn_acp_session`], which copies them into [`AcpSession`].
struct HandshakeData {
    /// Identifier from the new-session response.
    session_id: acp::SessionId,
    /// Agent title, falling back to its name, or `"agent"` when no info was sent.
    agent_name: String,
    /// Prompt capabilities from the initialize response.
    prompt_capabilities: acp::PromptCapabilities,
    /// Config options from the new-session response (empty when none were sent).
    config_options: Vec<acp::SessionConfigOption>,
    /// Authentication methods from the initialize response.
    auth_methods: Vec<acp::AuthMethod>,
}
208
/// Outcome of the handshake as reported by the ACP thread over the one-shot channel.
type HandshakeResult = Result<HandshakeData, AcpClientError>;
210
/// Everything moved onto the dedicated ACP thread by [`spawn_acp_session`].
struct AcpThreadContext<F> {
    /// Agent process stdin; becomes the outgoing side of the connection.
    child_stdin: tokio::process::ChildStdin,
    /// Agent process stdout; becomes the incoming side of the connection.
    child_stdout: tokio::process::ChildStdout,
    /// Sender for [`AcpEvent`]s; a clone is handed to `client_factory`.
    event_tx: mpsc::UnboundedSender<AcpEvent>,
    /// Receiver for commands queued through the prompt handle.
    cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
    /// One-shot used to report the handshake outcome exactly once.
    session_tx: oneshot::Sender<HandshakeResult>,
    /// Builds the ACP client implementation from an event sender.
    client_factory: F,
    /// Request sent first in the handshake.
    init_request: InitializeRequest,
    /// Request sent after a successful initialize to create the session.
    new_session_request: acp::NewSessionRequest,
}
221
222#[allow(clippy::too_many_lines)]
223async fn run_acp_thread<F, C>(ctx: AcpThreadContext<F>)
224where
225    F: FnOnce(mpsc::UnboundedSender<AcpEvent>) -> C,
226    C: Client + 'static,
227{
228    let AcpThreadContext {
229        child_stdin,
230        child_stdout,
231        event_tx,
232        mut cmd_rx,
233        session_tx,
234        client_factory,
235        init_request,
236        new_session_request,
237    } = ctx;
238
239    let client = client_factory(event_tx.clone());
240    let outgoing = child_stdin.compat_write();
241    let incoming = child_stdout.compat();
242    let (conn, handle_io) = acp::ClientSideConnection::new(client, outgoing, incoming, |fut| {
243        spawn_local(fut);
244    });
245
246    spawn_local(async move {
247        let _ = handle_io.await;
248    });
249
250    let init_resp = match conn.initialize(init_request).await {
251        Ok(r) => r,
252        Err(e) => {
253            let _ = session_tx.send(Err(AcpClientError::Protocol(e)));
254            return;
255        }
256    };
257
258    let agent_name = init_resp.agent_info.as_ref().map_or_else(
259        || "agent".to_string(),
260        |info| info.title.as_deref().unwrap_or(&info.name).to_string(),
261    );
262    let prompt_capabilities = init_resp.agent_capabilities.prompt_capabilities.clone();
263
264    info!(
265        "ACP initialized: protocol={:?}, agent_info={:?}",
266        init_resp.protocol_version, init_resp.agent_info
267    );
268
269    let auth_methods = init_resp.auth_methods;
270
271    let session_resp = match conn.new_session(new_session_request).await {
272        Ok(r) => r,
273        Err(e) => {
274            let _ = session_tx.send(Err(AcpClientError::Protocol(e)));
275            return;
276        }
277    };
278
279    let session_id = session_resp.session_id;
280    info!("ACP session created: {session_id}");
281
282    let config_options = session_resp.config_options.unwrap_or_default();
283    let _ = session_tx.send(Ok(HandshakeData {
284        session_id,
285        agent_name,
286        prompt_capabilities,
287        config_options,
288        auth_methods,
289    }));
290
291    while let Some(cmd) = cmd_rx.recv().await {
292        match cmd {
293            PromptCommand::Prompt {
294                session_id,
295                text,
296                content,
297            } => {
298                let mut prompt = vec![acp::ContentBlock::Text(acp::TextContent::new(text))];
299                if let Some(extra_content) = content {
300                    prompt.extend(extra_content);
301                }
302                let prompt_fut = conn.prompt(acp::PromptRequest::new(session_id, prompt));
303                tokio::pin!(prompt_fut);
304
305                // Process cancel/config commands while the prompt is in-flight
306                loop {
307                    tokio::select! {
308                        result = &mut prompt_fut => {
309                            let event = match result {
310                                Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
311                                Err(e) => AcpEvent::PromptError(e),
312                            };
313                            let _ = event_tx.send(event);
314                            break;
315                        }
316                        Some(cmd) = cmd_rx.recv() => {
317                            handle_side_command(&conn, &event_tx, cmd).await;
318                        }
319                    }
320                }
321            }
322            PromptCommand::ListSessions => {
323                let req = acp::ListSessionsRequest::new();
324                match conn.list_sessions(req).await {
325                    Ok(resp) => {
326                        let _ = event_tx.send(AcpEvent::SessionsListed {
327                            sessions: resp.sessions,
328                        });
329                    }
330                    Err(e) => {
331                        let _ = event_tx.send(AcpEvent::PromptError(e));
332                    }
333                }
334            }
335            PromptCommand::LoadSession { session_id, cwd } => {
336                let req = acp::LoadSessionRequest::new(session_id.clone(), cwd);
337                match conn.load_session(req).await {
338                    Ok(resp) => {
339                        let config_options = resp.config_options.unwrap_or_default();
340                        let _ = event_tx.send(AcpEvent::SessionLoaded {
341                            session_id,
342                            config_options,
343                        });
344                    }
345                    Err(e) => {
346                        let _ = event_tx.send(AcpEvent::PromptError(e));
347                    }
348                }
349            }
350            PromptCommand::NewSession { cwd } => {
351                let req = acp::NewSessionRequest::new(cwd);
352                match conn.new_session(req).await {
353                    Ok(resp) => {
354                        let config_options = resp.config_options.unwrap_or_default();
355                        let _ = event_tx.send(AcpEvent::NewSessionCreated {
356                            session_id: resp.session_id,
357                            config_options,
358                        });
359                    }
360                    Err(e) => {
361                        let _ = event_tx.send(AcpEvent::PromptError(e));
362                    }
363                }
364            }
365            cmd => handle_side_command(&conn, &event_tx, cmd).await,
366        }
367    }
368
369    let _ = event_tx.send(AcpEvent::ConnectionClosed);
370}
371
372async fn handle_side_command(
373    conn: &acp::ClientSideConnection,
374    event_tx: &mpsc::UnboundedSender<AcpEvent>,
375    cmd: PromptCommand,
376) {
377    match cmd {
378        PromptCommand::Cancel { session_id } => {
379            let _ = conn.cancel(acp::CancelNotification::new(session_id)).await;
380        }
381        PromptCommand::SetConfigOption {
382            session_id,
383            config_id,
384            value,
385        } => {
386            let req = SetSessionConfigOptionRequest::new(session_id, config_id, value);
387            match conn.set_session_config_option(req).await {
388                Ok(resp) => {
389                    let update = ConfigOptionUpdate::new(resp.config_options);
390                    let _ = event_tx.send(AcpEvent::SessionUpdate(Box::new(
391                        SessionUpdate::ConfigOptionUpdate(update),
392                    )));
393                }
394                Err(e) => {
395                    tracing::warn!("set_session_config_option failed: {e:?}");
396                }
397            }
398        }
399        PromptCommand::Prompt { .. } => {
400            tracing::warn!("ignoring duplicate Prompt while one is in-flight");
401        }
402        PromptCommand::ListSessions => {
403            tracing::warn!("ignoring ListSessions while prompt is in-flight");
404        }
405        PromptCommand::LoadSession { .. } => {
406            tracing::warn!("ignoring LoadSession while prompt is in-flight");
407        }
408        PromptCommand::NewSession { .. } => {
409            tracing::warn!("ignoring NewSession while prompt is in-flight");
410        }
411        PromptCommand::AuthenticateMcpServer {
412            session_id,
413            server_name,
414        } => {
415            let msg = McpRequest::Authenticate {
416                session_id: session_id.0.to_string(),
417                server_name,
418            };
419            if let Err(e) = conn.ext_notification(msg.into()).await {
420                tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
421            }
422        }
423        PromptCommand::Authenticate {
424            session_id: _,
425            method_id,
426        } => {
427            match conn
428                .authenticate(acp::AuthenticateRequest::new(method_id.clone()))
429                .await
430            {
431                Ok(_) => {
432                    let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
433                }
434                Err(e) => {
435                    tracing::warn!("authenticate failed: {e:?}");
436                    let _ = event_tx.send(AcpEvent::AuthenticateFailed {
437                        method_id,
438                        error: format!("{e:?}"),
439                    });
440                }
441            }
442        }
443    }
444}