Skip to main content

acp_utils/client/
session.rs

1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use crate::notifications::{ELICITATION_METHOD, ElicitationParams, McpRequest};
5use agent_client_protocol::{
6    self as acp, Agent, Client, ConfigOptionUpdate, ExtNotification, ExtRequest, ExtResponse, InitializeRequest,
7    PermissionOptionKind, RequestPermissionOutcome, RequestPermissionRequest, RequestPermissionResponse,
8    SelectedPermissionOutcome, SessionConfigOption, SessionId, SessionNotification, SessionUpdate,
9    SetSessionConfigOptionRequest,
10};
11use serde_json::value::RawValue;
12use std::process::Stdio;
13use std::sync::Arc;
14use std::thread::spawn;
15use tokio::process::Command;
16use tokio::sync::{mpsc, oneshot};
17use tokio::task::{LocalSet, spawn_local};
18use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
19use tracing::info;
20
21/// ACP session with all handles needed by the caller.
22pub struct AcpSession {
23    pub session_id: SessionId,
24    pub agent_name: String,
25    pub prompt_capabilities: acp::PromptCapabilities,
26    pub config_options: Vec<SessionConfigOption>,
27    pub auth_methods: Vec<acp::AuthMethod>,
28    pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
29    pub prompt_handle: AcpPromptHandle,
30}
31
32/// A built-in ACP client that auto-approves permissions and forwards session
33/// notifications as [`AcpEvent`]s.
34pub struct AutoApproveClient {
35    event_tx: mpsc::UnboundedSender<AcpEvent>,
36}
37
38impl AutoApproveClient {
39    pub fn new(event_tx: mpsc::UnboundedSender<AcpEvent>) -> Self {
40        Self { event_tx }
41    }
42}
43
44#[async_trait::async_trait(?Send)]
45impl Client for AutoApproveClient {
46    async fn request_permission(&self, args: RequestPermissionRequest) -> acp::Result<RequestPermissionResponse> {
47        let option_id = args
48            .options
49            .iter()
50            .find(|o| matches!(o.kind, PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways))
51            .map_or_else(|| args.options[0].option_id.clone(), |o| o.option_id.clone());
52
53        Ok(RequestPermissionResponse::new(RequestPermissionOutcome::Selected(SelectedPermissionOutcome::new(
54            option_id,
55        ))))
56    }
57
58    async fn session_notification(&self, args: SessionNotification) -> acp::Result<()> {
59        let _ = self.event_tx.send(AcpEvent::SessionUpdate(Box::new(args.update)));
60
61        Ok(())
62    }
63
64    async fn ext_notification(&self, args: ExtNotification) -> acp::Result<()> {
65        let _ = self.event_tx.send(AcpEvent::ExtNotification(args));
66        Ok(())
67    }
68
69    async fn ext_method(&self, args: ExtRequest) -> acp::Result<ExtResponse> {
70        if args.method.as_ref() == ELICITATION_METHOD {
71            return handle_elicitation_ext_method(&self.event_tx, args).await;
72        }
73
74        // Unknown ext_method — return null (default behavior)
75        let null_raw: Arc<RawValue> = serde_json::from_str("null").expect("null is valid JSON");
76        Ok(ExtResponse::new(null_raw))
77    }
78}
79
80async fn handle_elicitation_ext_method(
81    event_tx: &mpsc::UnboundedSender<AcpEvent>,
82    args: ExtRequest,
83) -> acp::Result<ExtResponse> {
84    let params: ElicitationParams =
85        serde_json::from_str(args.params.get()).map_err(|_| acp::Error::invalid_params())?;
86
87    let (response_tx, response_rx) = oneshot::channel();
88    event_tx.send(AcpEvent::ElicitationRequest { params, response_tx }).map_err(|_| acp::Error::internal_error())?;
89
90    let response = response_rx.await.map_err(|_| acp::Error::internal_error())?;
91
92    let raw = serde_json::value::to_raw_value(&response).map_err(|_| acp::Error::internal_error())?;
93
94    Ok(ExtResponse::new(Arc::from(raw)))
95}
96
97/// Spawn an agent subprocess and establish an ACP session.
98///
99/// The handshake (initialize + `new_session`) runs on a dedicated !Send thread.
100/// `client_factory` creates the ACP [`Client`] implementation,
101/// receiving the event sender so it can forward protocol events.
102///
103/// For the common auto-approve case, use `AutoApproveClient::new`:
104/// ```ignore
105/// spawn_acp_session("my-agent", init_req, session_req, AutoApproveClient::new).await
106/// ```
107pub async fn spawn_acp_session<F, C>(
108    agent_command: &str,
109    init_request: InitializeRequest,
110    new_session_request: acp::NewSessionRequest,
111    client_factory: F,
112) -> Result<AcpSession, AcpClientError>
113where
114    F: FnOnce(mpsc::UnboundedSender<AcpEvent>) -> C + Send + 'static,
115    C: acp::Client + 'static,
116{
117    let parts: Vec<&str> = agent_command.split_whitespace().collect();
118    let (program, args) =
119        parts.split_first().ok_or_else(|| AcpClientError::AgentCrashed("empty agent command".to_string()))?;
120
121    let mut child = Command::new(program)
122        .args(args)
123        .stdin(Stdio::piped())
124        .stdout(Stdio::piped())
125        .stderr(Stdio::inherit())
126        .spawn()
127        .map_err(AcpClientError::SpawnFailed)?;
128
129    let child_stdin =
130        child.stdin.take().ok_or_else(|| AcpClientError::AgentCrashed("no stdin on child".to_string()))?;
131
132    let child_stdout =
133        child.stdout.take().ok_or_else(|| AcpClientError::AgentCrashed("no stdout on child".to_string()))?;
134
135    let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
136    let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
137    let (session_tx, session_rx) = oneshot::channel::<HandshakeResult>();
138    let thread_ctx = AcpThreadContext {
139        child_stdin,
140        child_stdout,
141        event_tx,
142        cmd_rx,
143        session_tx,
144        client_factory,
145        init_request,
146        new_session_request,
147    };
148
149    spawn(move || {
150        let rt = tokio::runtime::Builder::new_current_thread()
151            .enable_all()
152            .build()
153            .expect("failed to build tokio runtime for ACP");
154
155        LocalSet::new().block_on(&rt, async move {
156            run_acp_thread(thread_ctx).await;
157        });
158    });
159
160    let handshake = session_rx
161        .await
162        .map_err(|_| AcpClientError::AgentCrashed("ACP thread died during handshake".to_string()))??;
163
164    Ok(AcpSession {
165        session_id: handshake.session_id,
166        agent_name: handshake.agent_name,
167        prompt_capabilities: handshake.prompt_capabilities,
168        config_options: handshake.config_options,
169        auth_methods: handshake.auth_methods,
170        event_rx,
171        prompt_handle: AcpPromptHandle { cmd_tx },
172    })
173}
174
175struct HandshakeData {
176    session_id: acp::SessionId,
177    agent_name: String,
178    prompt_capabilities: acp::PromptCapabilities,
179    config_options: Vec<acp::SessionConfigOption>,
180    auth_methods: Vec<acp::AuthMethod>,
181}
182
183type HandshakeResult = Result<HandshakeData, AcpClientError>;
184
185struct AcpThreadContext<F> {
186    child_stdin: tokio::process::ChildStdin,
187    child_stdout: tokio::process::ChildStdout,
188    event_tx: mpsc::UnboundedSender<AcpEvent>,
189    cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
190    session_tx: oneshot::Sender<HandshakeResult>,
191    client_factory: F,
192    init_request: InitializeRequest,
193    new_session_request: acp::NewSessionRequest,
194}
195
196#[allow(clippy::too_many_lines)]
197async fn run_acp_thread<F, C>(ctx: AcpThreadContext<F>)
198where
199    F: FnOnce(mpsc::UnboundedSender<AcpEvent>) -> C,
200    C: Client + 'static,
201{
202    let AcpThreadContext {
203        child_stdin,
204        child_stdout,
205        event_tx,
206        mut cmd_rx,
207        session_tx,
208        client_factory,
209        init_request,
210        new_session_request,
211    } = ctx;
212
213    let client = client_factory(event_tx.clone());
214    let outgoing = child_stdin.compat_write();
215    let incoming = child_stdout.compat();
216    let (conn, handle_io) = acp::ClientSideConnection::new(client, outgoing, incoming, |fut| {
217        spawn_local(fut);
218    });
219
220    spawn_local(async move {
221        let _ = handle_io.await;
222    });
223
224    let init_resp = match conn.initialize(init_request).await {
225        Ok(r) => r,
226        Err(e) => {
227            let _ = session_tx.send(Err(AcpClientError::Protocol(e)));
228            return;
229        }
230    };
231
232    let agent_name = init_resp
233        .agent_info
234        .as_ref()
235        .map_or_else(|| "agent".to_string(), |info| info.title.as_deref().unwrap_or(&info.name).to_string());
236    let prompt_capabilities = init_resp.agent_capabilities.prompt_capabilities.clone();
237
238    info!("ACP initialized: protocol={:?}, agent_info={:?}", init_resp.protocol_version, init_resp.agent_info);
239
240    let auth_methods = init_resp.auth_methods;
241
242    let session_resp = match conn.new_session(new_session_request).await {
243        Ok(r) => r,
244        Err(e) => {
245            let _ = session_tx.send(Err(AcpClientError::Protocol(e)));
246            return;
247        }
248    };
249
250    let session_id = session_resp.session_id;
251    info!("ACP session created: {session_id}");
252
253    let config_options = session_resp.config_options.unwrap_or_default();
254    let _ = session_tx.send(Ok(HandshakeData {
255        session_id,
256        agent_name,
257        prompt_capabilities,
258        config_options,
259        auth_methods,
260    }));
261
262    while let Some(cmd) = cmd_rx.recv().await {
263        match cmd {
264            PromptCommand::Prompt { session_id, text, content } => {
265                let mut prompt = vec![acp::ContentBlock::Text(acp::TextContent::new(text))];
266                if let Some(extra_content) = content {
267                    prompt.extend(extra_content);
268                }
269                let prompt_fut = conn.prompt(acp::PromptRequest::new(session_id, prompt));
270                tokio::pin!(prompt_fut);
271
272                // Process cancel/config commands while the prompt is in-flight
273                loop {
274                    tokio::select! {
275                        result = &mut prompt_fut => {
276                            let event = match result {
277                                Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
278                                Err(e) => AcpEvent::PromptError(e),
279                            };
280                            let _ = event_tx.send(event);
281                            break;
282                        }
283                        Some(cmd) = cmd_rx.recv() => {
284                            handle_side_command(&conn, &event_tx, cmd).await;
285                        }
286                    }
287                }
288            }
289            PromptCommand::ListSessions => {
290                let req = acp::ListSessionsRequest::new();
291                match conn.list_sessions(req).await {
292                    Ok(resp) => {
293                        let _ = event_tx.send(AcpEvent::SessionsListed { sessions: resp.sessions });
294                    }
295                    Err(e) => {
296                        let _ = event_tx.send(AcpEvent::PromptError(e));
297                    }
298                }
299            }
300            PromptCommand::LoadSession { session_id, cwd } => {
301                let req = acp::LoadSessionRequest::new(session_id.clone(), cwd);
302                match conn.load_session(req).await {
303                    Ok(resp) => {
304                        let config_options = resp.config_options.unwrap_or_default();
305                        let _ = event_tx.send(AcpEvent::SessionLoaded { session_id, config_options });
306                    }
307                    Err(e) => {
308                        let _ = event_tx.send(AcpEvent::PromptError(e));
309                    }
310                }
311            }
312            PromptCommand::NewSession { cwd } => {
313                let req = acp::NewSessionRequest::new(cwd);
314                match conn.new_session(req).await {
315                    Ok(resp) => {
316                        let config_options = resp.config_options.unwrap_or_default();
317                        let _ =
318                            event_tx.send(AcpEvent::NewSessionCreated { session_id: resp.session_id, config_options });
319                    }
320                    Err(e) => {
321                        let _ = event_tx.send(AcpEvent::PromptError(e));
322                    }
323                }
324            }
325            cmd => handle_side_command(&conn, &event_tx, cmd).await,
326        }
327    }
328
329    let _ = event_tx.send(AcpEvent::ConnectionClosed);
330}
331
332async fn handle_side_command(
333    conn: &acp::ClientSideConnection,
334    event_tx: &mpsc::UnboundedSender<AcpEvent>,
335    cmd: PromptCommand,
336) {
337    match cmd {
338        PromptCommand::Cancel { session_id } => {
339            let _ = conn.cancel(acp::CancelNotification::new(session_id)).await;
340        }
341        PromptCommand::SetConfigOption { session_id, config_id, value } => {
342            let req = SetSessionConfigOptionRequest::new(session_id, config_id, value);
343            match conn.set_session_config_option(req).await {
344                Ok(resp) => {
345                    let update = ConfigOptionUpdate::new(resp.config_options);
346                    let _ = event_tx.send(AcpEvent::SessionUpdate(Box::new(SessionUpdate::ConfigOptionUpdate(update))));
347                }
348                Err(e) => {
349                    tracing::warn!("set_session_config_option failed: {e:?}");
350                }
351            }
352        }
353        PromptCommand::Prompt { .. } => {
354            tracing::warn!("ignoring duplicate Prompt while one is in-flight");
355        }
356        PromptCommand::ListSessions => {
357            tracing::warn!("ignoring ListSessions while prompt is in-flight");
358        }
359        PromptCommand::LoadSession { .. } => {
360            tracing::warn!("ignoring LoadSession while prompt is in-flight");
361        }
362        PromptCommand::NewSession { .. } => {
363            tracing::warn!("ignoring NewSession while prompt is in-flight");
364        }
365        PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
366            let msg = McpRequest::Authenticate { session_id: session_id.0.to_string(), server_name };
367            if let Err(e) = conn.ext_notification(msg.into()).await {
368                tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
369            }
370        }
371        PromptCommand::Authenticate { session_id: _, method_id } => {
372            match conn.authenticate(acp::AuthenticateRequest::new(method_id.clone())).await {
373                Ok(_) => {
374                    let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
375                }
376                Err(e) => {
377                    tracing::warn!("authenticate failed: {e:?}");
378                    let _ = event_tx.send(AcpEvent::AuthenticateFailed { method_id, error: format!("{e:?}") });
379                }
380            }
381        }
382    }
383}