Skip to main content

acp_utils/client/
session.rs

1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use super::tokio_agent::TokioAcpAgent;
5use crate::notifications::{
6    AuthMethodsUpdatedParams, ContextClearedParams, ContextUsageParams, ElicitationParams, McpNotification, McpRequest,
7    SubAgentProgressParams,
8};
9use agent_client_protocol::schema::{
10    AuthMethod, AuthenticateRequest, CancelNotification, ConfigOptionUpdate, ContentBlock, InitializeRequest,
11    InitializeResponse, ListSessionsRequest, LoadSessionRequest, NewSessionRequest, NewSessionResponse,
12    PermissionOptionId, PermissionOptionKind, PromptCapabilities, PromptRequest, RequestPermissionOutcome,
13    RequestPermissionRequest, RequestPermissionResponse, SelectedPermissionOutcome, SessionCapabilities,
14    SessionConfigOption, SessionId, SessionNotification, SessionUpdate, SetSessionConfigOptionRequest, TextContent,
15};
16use agent_client_protocol::{self as acp, Client, ConnectionTo, JsonRpcRequest};
17use std::str::FromStr;
18use tokio::sync::mpsc;
19use tracing::info;
20
/// Result of the startup handshake, sent from the connection task back to
/// [`spawn_acp_session`]: the initialize response paired with the new-session
/// response on success, or the error that prevented the session from starting.
type InitializeResult = Result<(InitializeResponse, NewSessionResponse), AcpClientError>;
22
/// ACP session with all handles needed by the caller.
///
/// Built by [`spawn_acp_session`] from the agent's initialize and new-session
/// responses, together with the channel ends used to talk to the background
/// connection task.
pub struct AcpSession {
    /// Session identifier taken from the new-session response.
    pub session_id: SessionId,
    /// Display name for the agent: its title when present, otherwise its name,
    /// or the literal `"agent"` when the agent sent no info about itself.
    pub agent_name: String,
    /// Prompt capabilities copied from the initialize response.
    pub prompt_capabilities: PromptCapabilities,
    /// Session capabilities copied from the initialize response.
    pub session_capabilities: SessionCapabilities,
    /// Config options from the new-session response; empty when the agent sent none.
    pub config_options: Vec<SessionConfigOption>,
    /// Authentication methods listed in the initialize response.
    pub auth_methods: Vec<AuthMethod>,
    /// Receiving end of the event channel the connection task writes [`AcpEvent`]s to.
    pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
    /// Handle wrapping the sender used to submit commands to the connection task.
    pub prompt_handle: AcpPromptHandle,
}
34
35/// Spawn an agent subprocess and establish an ACP session.
36///
37/// The connection auto-approves permissions, forwards session notifications as
38/// [`AcpEvent`]s, and tunnels elicitation requests through the `_aether/elicitation`
39/// extension method.
40pub async fn spawn_acp_session(
41    agent_command: &str,
42    init_request: InitializeRequest,
43    new_session_request: NewSessionRequest,
44) -> Result<AcpSession, AcpClientError> {
45    let agent = TokioAcpAgent::from_str(agent_command).map_err(AcpClientError::InvalidAgentCommand)?;
46    let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
47    let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
48    let (init_tx, mut init_rx) = mpsc::unbounded_channel::<InitializeResult>();
49    tokio::spawn(run_client_connection(agent, event_tx, cmd_rx, init_tx, init_request, new_session_request));
50
51    let (init_resp, session_resp) = init_rx
52        .recv()
53        .await
54        .ok_or_else(|| AcpClientError::AgentCrashed("ACP task died during initialization".to_string()))??;
55
56    let agent_name = init_resp
57        .agent_info
58        .as_ref()
59        .map_or_else(|| "agent".to_string(), |info| info.title.as_deref().unwrap_or(&info.name).to_string());
60
61    Ok(AcpSession {
62        session_id: session_resp.session_id,
63        agent_name,
64        prompt_capabilities: init_resp.agent_capabilities.prompt_capabilities,
65        session_capabilities: init_resp.agent_capabilities.session_capabilities,
66        config_options: session_resp.config_options.unwrap_or_default(),
67        auth_methods: init_resp.auth_methods,
68        event_rx,
69        prompt_handle: AcpPromptHandle { cmd_tx },
70    })
71}
72
/// Body of the background task that owns the ACP connection to `agent`.
///
/// Registers the request and notification handlers, connects, and then runs
/// [`run_main`] as the connection's main routine. Almost every incoming message
/// is turned into an [`AcpEvent`] on `event_tx`; permission requests are the
/// exception and are answered directly.
///
/// When the connection ends with an error, that error is logged and also sent
/// on `init_tx` as [`AcpClientError::ConnectFailed`]. In every case an
/// [`AcpEvent::ConnectionClosed`] is sent last. Send failures on either channel
/// are ignored.
// The handler registrations below are repetitive but flat, hence the length.
#[allow(clippy::too_many_lines)]
async fn run_client_connection(
    agent: TokioAcpAgent,
    event_tx: mpsc::UnboundedSender<AcpEvent>,
    cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
    init_tx: mpsc::UnboundedSender<InitializeResult>,
    init_request: InitializeRequest,
    new_session_request: NewSessionRequest,
) {
    let connection_result = Client
        .builder()
        // Permission requests: answer immediately with the option picked by
        // `auto_approve_option`, without involving the caller.
        .on_receive_request(
            async move |req: RequestPermissionRequest, responder, _cx| {
                responder.respond(RequestPermissionResponse::new(RequestPermissionOutcome::Selected(
                    SelectedPermissionOutcome::new(auto_approve_option(&req)),
                )))
            },
            acp::on_receive_request!(),
        )
        // Elicitation requests: hand both the params and the responder to the
        // caller, who is then responsible for replying.
        .on_receive_request(
            {
                let event_tx = event_tx.clone();
                async move |params: ElicitationParams, responder, _cx| {
                    if let Err(send_err) = event_tx.send(AcpEvent::ElicitationRequest { params, responder }) {
                        // Recover the responder and reply with an error so the remote caller doesn't hang.
                        if let AcpEvent::ElicitationRequest { responder, .. } = send_err.0 {
                            return responder.respond_with_error(acp::Error::internal_error());
                        }
                    }
                    Ok(())
                }
            },
            acp::on_receive_request!(),
        )
        // Session updates: forwarded with the update boxed; a failed send is ignored.
        .on_receive_notification(
            {
                let event_tx = event_tx.clone();
                async move |SessionNotification { session_id, update, .. }: SessionNotification, _cx| {
                    let _ = event_tx.send(AcpEvent::SessionUpdate { session_id, update: Box::new(update) });
                    Ok(())
                }
            },
            acp::on_receive_notification!(),
        )
        // The remaining notification handlers all follow the same shape: wrap the
        // payload in the matching `AcpEvent` variant and ignore a failed send.
        .on_receive_notification(
            {
                let event_tx = event_tx.clone();
                async move |p: ContextUsageParams, _cx| {
                    let _ = event_tx.send(AcpEvent::ContextUsage(p));
                    Ok(())
                }
            },
            acp::on_receive_notification!(),
        )
        .on_receive_notification(
            {
                let event_tx = event_tx.clone();
                async move |p: ContextClearedParams, _cx| {
                    let _ = event_tx.send(AcpEvent::ContextCleared(p));
                    Ok(())
                }
            },
            acp::on_receive_notification!(),
        )
        .on_receive_notification(
            {
                let event_tx = event_tx.clone();
                async move |p: SubAgentProgressParams, _cx| {
                    let _ = event_tx.send(AcpEvent::SubAgentProgress(p));
                    Ok(())
                }
            },
            acp::on_receive_notification!(),
        )
        .on_receive_notification(
            {
                let event_tx = event_tx.clone();
                async move |p: AuthMethodsUpdatedParams, _cx| {
                    let _ = event_tx.send(AcpEvent::AuthMethodsUpdated(p));
                    Ok(())
                }
            },
            acp::on_receive_notification!(),
        )
        .on_receive_notification(
            {
                let event_tx = event_tx.clone();
                async move |n: McpNotification, _cx| {
                    let _ = event_tx.send(AcpEvent::McpNotification(n));
                    Ok(())
                }
            },
            acp::on_receive_notification!(),
        )
        // Connect to the agent and run the handshake + command loop as the
        // connection's main routine.
        .connect_with(agent, {
            let event_tx = event_tx.clone();
            let init_tx = init_tx.clone();
            async move |cx: ConnectionTo<acp::Agent>| {
                run_main(cx, event_tx, cmd_rx, init_tx, init_request, new_session_request).await;
                Ok(())
            }
        })
        .await;

    if let Err(e) = connection_result {
        tracing::warn!("ACP connection exited with error: {e:?}");
        // Only useful when the handshake has not been reported yet; the send
        // result is ignored either way.
        let _ = init_tx.send(Err(AcpClientError::ConnectFailed(e)));
    }
    // Always tell the caller the connection is gone, whether it ended cleanly or not.
    let _ = event_tx.send(AcpEvent::ConnectionClosed);
}
183
/// Main routine of the connection: perform the handshake, then serve commands.
///
/// 1. Sends `init_request`, then `new_session_request`. If either fails, the
///    error is reported on `init_tx` as [`AcpClientError::Protocol`] and the
///    function returns.
/// 2. Reports both responses on `init_tx`.
/// 3. Processes [`PromptCommand`]s from `cmd_rx` until the channel yields `None`,
///    publishing the outcome of each as an [`AcpEvent`] on `event_tx`.
///
/// Send failures on `init_tx` and `event_tx` are ignored.
// The command dispatch is one flat `match`, hence the length.
#[allow(clippy::too_many_lines)]
async fn run_main(
    cx: ConnectionTo<acp::Agent>,
    event_tx: mpsc::UnboundedSender<AcpEvent>,
    mut cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
    init_tx: mpsc::UnboundedSender<InitializeResult>,
    init_request: InitializeRequest,
    new_session_request: NewSessionRequest,
) {
    let init_resp = match cx.send_request(init_request).block_task().await {
        Ok(r) => r,
        Err(e) => {
            let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
            return;
        }
    };
    info!("ACP initialized: protocol={:?}, agent_info={:?}", init_resp.protocol_version, init_resp.agent_info);

    let session_resp = match cx.send_request(new_session_request).block_task().await {
        Ok(r) => r,
        Err(e) => {
            let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
            return;
        }
    };
    info!("ACP session created: {}", session_resp.session_id);

    let _ = init_tx.send(Ok((init_resp, session_resp)));

    while let Some(cmd) = cmd_rx.recv().await {
        match cmd {
            PromptCommand::Prompt { session_id, text, content } => {
                // The text block always comes first; any extra content blocks follow it.
                let mut prompt = vec![ContentBlock::Text(TextContent::new(text))];
                if let Some(extra_content) = content {
                    prompt.extend(extra_content);
                }
                // Pinned so the same future can be polled across loop iterations.
                let prompt_fut = cx.send_request(PromptRequest::new(session_id, prompt)).block_task();
                tokio::pin!(prompt_fut);

                // Keep accepting commands while the prompt is outstanding, so that
                // e.g. a Cancel can be delivered before the prompt finishes.
                loop {
                    tokio::select! {
                        result = &mut prompt_fut => {
                            let event = match result {
                                Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
                                Err(e) => AcpEvent::PromptError(e),
                            };
                            let _ = event_tx.send(event);
                            break;
                        }
                        // A closed command channel yields `None`, which does not match
                        // this pattern; the prompt future is then the only branch left.
                        // The side command is awaited to completion before the select
                        // is entered again.
                        Some(cmd) = cmd_rx.recv() => {
                            handle_side_command(&cx, &event_tx, cmd).await;
                        }
                    }
                }
            }
            PromptCommand::ListSessions => {
                // Failures for the session commands below are reported as `PromptError`.
                request_to_event(
                    &cx,
                    &event_tx,
                    ListSessionsRequest::new(),
                    |resp| AcpEvent::SessionsListed { sessions: resp.sessions },
                    AcpEvent::PromptError,
                )
                .await;
            }
            PromptCommand::LoadSession { session_id, cwd } => {
                request_to_event(
                    &cx,
                    &event_tx,
                    // Cloned because the success closure below also needs the id.
                    LoadSessionRequest::new(session_id.clone(), cwd),
                    |resp| AcpEvent::SessionLoaded {
                        session_id,
                        config_options: resp.config_options.unwrap_or_default(),
                    },
                    AcpEvent::PromptError,
                )
                .await;
            }
            PromptCommand::NewSession { cwd } => {
                request_to_event(
                    &cx,
                    &event_tx,
                    NewSessionRequest::new(cwd),
                    |resp| AcpEvent::NewSessionCreated {
                        session_id: resp.session_id,
                        config_options: resp.config_options.unwrap_or_default(),
                    },
                    AcpEvent::PromptError,
                )
                .await;
            }
            PromptCommand::MoveWorkspace(params) => {
                request_to_event(&cx, &event_tx, params, AcpEvent::WorkspaceMoved, |e| AcpEvent::WorkspaceMoveFailed {
                    error: format!("{e}"),
                })
                .await;
            }
            // Everything else behaves the same whether or not a prompt is in flight.
            cmd => handle_side_command(&cx, &event_tx, cmd).await,
        }
    }
}
285
286async fn handle_side_command(
287    cx: &ConnectionTo<acp::Agent>,
288    event_tx: &mpsc::UnboundedSender<AcpEvent>,
289    cmd: PromptCommand,
290) {
291    match cmd {
292        PromptCommand::Cancel { session_id } => {
293            let _ = cx.send_notification(CancelNotification::new(session_id));
294        }
295        PromptCommand::SetConfigOption { session_id, config_id, value } => {
296            let req = SetSessionConfigOptionRequest::new(session_id.clone(), config_id, value);
297            match cx.send_request(req).block_task().await {
298                Ok(resp) => {
299                    let update = ConfigOptionUpdate::new(resp.config_options);
300                    let _ = event_tx.send(AcpEvent::SessionUpdate {
301                        session_id,
302                        update: Box::new(SessionUpdate::ConfigOptionUpdate(update)),
303                    });
304                }
305                Err(e) => {
306                    tracing::warn!("set_session_config_option failed: {e:?}");
307                }
308            }
309        }
310        PromptCommand::Prompt { .. } => {
311            tracing::warn!("ignoring duplicate Prompt while one is in-flight");
312        }
313        PromptCommand::ListSessions => {
314            tracing::warn!("ignoring ListSessions while prompt is in-flight");
315        }
316        PromptCommand::LoadSession { .. } => {
317            tracing::warn!("ignoring LoadSession while prompt is in-flight");
318        }
319        PromptCommand::NewSession { .. } => {
320            tracing::warn!("ignoring NewSession while prompt is in-flight");
321        }
322        PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
323            let msg = McpRequest::Authenticate { session_id: session_id.0.to_string(), server_name };
324            if let Err(e) = cx.send_notification(msg) {
325                tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
326            }
327        }
328        PromptCommand::Authenticate { method_id } => {
329            match cx.send_request(AuthenticateRequest::new(method_id.clone())).block_task().await {
330                Ok(_) => {
331                    let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
332                }
333                Err(e) => {
334                    tracing::warn!("authenticate failed: {e:?}");
335                    let _ = event_tx.send(AcpEvent::AuthenticateFailed { method_id, error: format!("{e:?}") });
336                }
337            }
338        }
339        PromptCommand::SearchPrompts(params) => {
340            let query = params.query.clone();
341            request_to_event(cx, event_tx, params, AcpEvent::PromptSearchResults, |e| AcpEvent::PromptSearchFailed {
342                query,
343                error: format!("{e}"),
344            })
345            .await;
346        }
347        PromptCommand::SessionPreview(params) => {
348            let session_id = params.session_id.clone();
349            request_to_event(cx, event_tx, params, AcpEvent::SessionPreviewLoaded, |e| {
350                AcpEvent::SessionPreviewFailed { session_id, error: format!("{e}") }
351            })
352            .await;
353        }
354        PromptCommand::ListWorkspaces(params) => {
355            request_to_event(cx, event_tx, params, AcpEvent::WorkspacesListed, |e| AcpEvent::WorkspaceListFailed {
356                error: format!("{e}"),
357            })
358            .await;
359        }
360        PromptCommand::MoveWorkspace(_) => {
361            tracing::warn!("ignoring MoveWorkspace while prompt is in-flight");
362            let _ = event_tx.send(AcpEvent::WorkspaceMoveFailed { error: "a prompt is in flight".to_string() });
363        }
364    }
365}
366
367async fn request_to_event<T: JsonRpcRequest>(
368    cx: &ConnectionTo<acp::Agent>,
369    event_tx: &mpsc::UnboundedSender<AcpEvent>,
370    params: T,
371    ok: impl FnOnce(T::Response) -> AcpEvent,
372    err: impl FnOnce(acp::Error) -> AcpEvent,
373) {
374    let event = match cx.send_request(params).block_task().await {
375        Ok(resp) => ok(resp),
376        Err(e) => err(e),
377    };
378    let _ = event_tx.send(event);
379}
380
381fn auto_approve_option(req: &RequestPermissionRequest) -> PermissionOptionId {
382    debug_assert!(!req.options.is_empty(), "ACP guarantees at least one permission option");
383    req.options
384        .iter()
385        .find(|o| matches!(o.kind, PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways))
386        .map_or_else(|| req.options[0].option_id.clone(), |o| o.option_id.clone())
387}