Skip to main content

acp_utils/client/
session.rs

1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use crate::notifications::{
5    AuthMethodsUpdatedParams, ContextClearedParams, ContextUsageParams, ElicitationParams, McpNotification, McpRequest,
6    SubAgentProgressParams,
7};
8use agent_client_protocol::schema::{
9    AuthMethod, AuthenticateRequest, CancelNotification, ConfigOptionUpdate, ContentBlock, InitializeRequest,
10    InitializeResponse, ListSessionsRequest, LoadSessionRequest, NewSessionRequest, NewSessionResponse,
11    PermissionOptionId, PermissionOptionKind, PromptCapabilities, PromptRequest, RequestPermissionOutcome,
12    RequestPermissionRequest, RequestPermissionResponse, SelectedPermissionOutcome, SessionConfigOption, SessionId,
13    SessionNotification, SessionUpdate, SetSessionConfigOptionRequest, TextContent,
14};
15use agent_client_protocol::{self as acp, AcpAgent, Client, ConnectionTo};
16use std::str::FromStr;
17use tokio::sync::mpsc;
18use tracing::info;
19
20type InitializeResult = Result<(InitializeResponse, NewSessionResponse), AcpClientError>;
21
22/// ACP session with all handles needed by the caller.
23pub struct AcpSession {
24    pub session_id: SessionId,
25    pub agent_name: String,
26    pub prompt_capabilities: PromptCapabilities,
27    pub config_options: Vec<SessionConfigOption>,
28    pub auth_methods: Vec<AuthMethod>,
29    pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
30    pub prompt_handle: AcpPromptHandle,
31}
32
33/// Spawn an agent subprocess and establish an ACP session.
34///
35/// The connection auto-approves permissions, forwards session notifications as
36/// [`AcpEvent`]s, and tunnels elicitation requests through the `_aether/elicitation`
37/// extension method.
38pub async fn spawn_acp_session(
39    agent_command: &str,
40    init_request: InitializeRequest,
41    new_session_request: NewSessionRequest,
42) -> Result<AcpSession, AcpClientError> {
43    let agent = AcpAgent::from_str(agent_command).map_err(AcpClientError::InvalidAgentCommand)?;
44    let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
45    let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
46    let (init_tx, mut init_rx) = mpsc::unbounded_channel::<InitializeResult>();
47    tokio::spawn(run_client_connection(agent, event_tx, cmd_rx, init_tx, init_request, new_session_request));
48
49    let (init_resp, session_resp) = init_rx
50        .recv()
51        .await
52        .ok_or_else(|| AcpClientError::AgentCrashed("ACP task died during initialization".to_string()))??;
53
54    let agent_name = init_resp
55        .agent_info
56        .as_ref()
57        .map_or_else(|| "agent".to_string(), |info| info.title.as_deref().unwrap_or(&info.name).to_string());
58
59    Ok(AcpSession {
60        session_id: session_resp.session_id,
61        agent_name,
62        prompt_capabilities: init_resp.agent_capabilities.prompt_capabilities,
63        config_options: session_resp.config_options.unwrap_or_default(),
64        auth_methods: init_resp.auth_methods,
65        event_rx,
66        prompt_handle: AcpPromptHandle { cmd_tx },
67    })
68}
69
70#[allow(clippy::too_many_lines)]
71async fn run_client_connection(
72    agent: AcpAgent,
73    event_tx: mpsc::UnboundedSender<AcpEvent>,
74    cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
75    init_tx: mpsc::UnboundedSender<InitializeResult>,
76    init_request: InitializeRequest,
77    new_session_request: NewSessionRequest,
78) {
79    let connection_result = Client
80        .builder()
81        .on_receive_request(
82            async move |req: RequestPermissionRequest, responder, _cx| {
83                responder.respond(RequestPermissionResponse::new(RequestPermissionOutcome::Selected(
84                    SelectedPermissionOutcome::new(auto_approve_option(&req)),
85                )))
86            },
87            acp::on_receive_request!(),
88        )
89        .on_receive_request(
90            {
91                let event_tx = event_tx.clone();
92                async move |params: ElicitationParams, responder, _cx| {
93                    if let Err(send_err) = event_tx.send(AcpEvent::ElicitationRequest { params, responder }) {
94                        // Recover the responder and reply with an error so the remote caller doesn't hang.
95                        if let AcpEvent::ElicitationRequest { responder, .. } = send_err.0 {
96                            return responder.respond_with_error(acp::Error::internal_error());
97                        }
98                    }
99                    Ok(())
100                }
101            },
102            acp::on_receive_request!(),
103        )
104        .on_receive_notification(
105            {
106                let event_tx = event_tx.clone();
107                async move |SessionNotification { session_id, update, .. }: SessionNotification, _cx| {
108                    let _ = event_tx.send(AcpEvent::SessionUpdate { session_id, update: Box::new(update) });
109                    Ok(())
110                }
111            },
112            acp::on_receive_notification!(),
113        )
114        .on_receive_notification(
115            {
116                let event_tx = event_tx.clone();
117                async move |p: ContextUsageParams, _cx| {
118                    let _ = event_tx.send(AcpEvent::ContextUsage(p));
119                    Ok(())
120                }
121            },
122            acp::on_receive_notification!(),
123        )
124        .on_receive_notification(
125            {
126                let event_tx = event_tx.clone();
127                async move |p: ContextClearedParams, _cx| {
128                    let _ = event_tx.send(AcpEvent::ContextCleared(p));
129                    Ok(())
130                }
131            },
132            acp::on_receive_notification!(),
133        )
134        .on_receive_notification(
135            {
136                let event_tx = event_tx.clone();
137                async move |p: SubAgentProgressParams, _cx| {
138                    let _ = event_tx.send(AcpEvent::SubAgentProgress(p));
139                    Ok(())
140                }
141            },
142            acp::on_receive_notification!(),
143        )
144        .on_receive_notification(
145            {
146                let event_tx = event_tx.clone();
147                async move |p: AuthMethodsUpdatedParams, _cx| {
148                    let _ = event_tx.send(AcpEvent::AuthMethodsUpdated(p));
149                    Ok(())
150                }
151            },
152            acp::on_receive_notification!(),
153        )
154        .on_receive_notification(
155            {
156                let event_tx = event_tx.clone();
157                async move |n: McpNotification, _cx| {
158                    let _ = event_tx.send(AcpEvent::McpNotification(n));
159                    Ok(())
160                }
161            },
162            acp::on_receive_notification!(),
163        )
164        .connect_with(agent, {
165            let event_tx = event_tx.clone();
166            let init_tx = init_tx.clone();
167            async move |cx: ConnectionTo<acp::Agent>| {
168                run_main(cx, event_tx, cmd_rx, init_tx, init_request, new_session_request).await;
169                Ok(())
170            }
171        })
172        .await;
173
174    if let Err(e) = connection_result {
175        tracing::warn!("ACP connection exited with error: {e:?}");
176        let _ = init_tx.send(Err(AcpClientError::ConnectFailed(e)));
177    }
178    let _ = event_tx.send(AcpEvent::ConnectionClosed);
179}
180
181#[allow(clippy::too_many_lines)]
182async fn run_main(
183    cx: ConnectionTo<acp::Agent>,
184    event_tx: mpsc::UnboundedSender<AcpEvent>,
185    mut cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
186    init_tx: mpsc::UnboundedSender<InitializeResult>,
187    init_request: InitializeRequest,
188    new_session_request: NewSessionRequest,
189) {
190    let init_resp = match cx.send_request(init_request).block_task().await {
191        Ok(r) => r,
192        Err(e) => {
193            let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
194            return;
195        }
196    };
197    info!("ACP initialized: protocol={:?}, agent_info={:?}", init_resp.protocol_version, init_resp.agent_info);
198
199    let session_resp = match cx.send_request(new_session_request).block_task().await {
200        Ok(r) => r,
201        Err(e) => {
202            let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
203            return;
204        }
205    };
206    info!("ACP session created: {}", session_resp.session_id);
207
208    let _ = init_tx.send(Ok((init_resp, session_resp)));
209
210    while let Some(cmd) = cmd_rx.recv().await {
211        match cmd {
212            PromptCommand::Prompt { session_id, text, content } => {
213                let mut prompt = vec![ContentBlock::Text(TextContent::new(text))];
214                if let Some(extra_content) = content {
215                    prompt.extend(extra_content);
216                }
217                let prompt_fut = cx.send_request(PromptRequest::new(session_id, prompt)).block_task();
218                tokio::pin!(prompt_fut);
219
220                loop {
221                    tokio::select! {
222                        result = &mut prompt_fut => {
223                            let event = match result {
224                                Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
225                                Err(e) => AcpEvent::PromptError(e),
226                            };
227                            let _ = event_tx.send(event);
228                            break;
229                        }
230                        Some(cmd) = cmd_rx.recv() => {
231                            handle_side_command(&cx, &event_tx, cmd).await;
232                        }
233                    }
234                }
235            }
236            PromptCommand::ListSessions => {
237                let req = ListSessionsRequest::new();
238                match cx.send_request(req).block_task().await {
239                    Ok(resp) => {
240                        let _ = event_tx.send(AcpEvent::SessionsListed { sessions: resp.sessions });
241                    }
242                    Err(e) => {
243                        let _ = event_tx.send(AcpEvent::PromptError(e));
244                    }
245                }
246            }
247            PromptCommand::LoadSession { session_id, cwd } => {
248                let req = LoadSessionRequest::new(session_id.clone(), cwd);
249                match cx.send_request(req).block_task().await {
250                    Ok(resp) => {
251                        let config_options = resp.config_options.unwrap_or_default();
252                        let _ = event_tx.send(AcpEvent::SessionLoaded { session_id, config_options });
253                    }
254                    Err(e) => {
255                        let _ = event_tx.send(AcpEvent::PromptError(e));
256                    }
257                }
258            }
259            PromptCommand::NewSession { cwd } => {
260                let req = NewSessionRequest::new(cwd);
261                match cx.send_request(req).block_task().await {
262                    Ok(resp) => {
263                        let config_options = resp.config_options.unwrap_or_default();
264                        let _ =
265                            event_tx.send(AcpEvent::NewSessionCreated { session_id: resp.session_id, config_options });
266                    }
267                    Err(e) => {
268                        let _ = event_tx.send(AcpEvent::PromptError(e));
269                    }
270                }
271            }
272            cmd => handle_side_command(&cx, &event_tx, cmd).await,
273        }
274    }
275}
276
277async fn handle_side_command(
278    cx: &ConnectionTo<acp::Agent>,
279    event_tx: &mpsc::UnboundedSender<AcpEvent>,
280    cmd: PromptCommand,
281) {
282    match cmd {
283        PromptCommand::Cancel { session_id } => {
284            let _ = cx.send_notification(CancelNotification::new(session_id));
285        }
286        PromptCommand::SetConfigOption { session_id, config_id, value } => {
287            let req = SetSessionConfigOptionRequest::new(session_id.clone(), config_id, value);
288            match cx.send_request(req).block_task().await {
289                Ok(resp) => {
290                    let update = ConfigOptionUpdate::new(resp.config_options);
291                    let _ = event_tx.send(AcpEvent::SessionUpdate {
292                        session_id,
293                        update: Box::new(SessionUpdate::ConfigOptionUpdate(update)),
294                    });
295                }
296                Err(e) => {
297                    tracing::warn!("set_session_config_option failed: {e:?}");
298                }
299            }
300        }
301        PromptCommand::Prompt { .. } => {
302            tracing::warn!("ignoring duplicate Prompt while one is in-flight");
303        }
304        PromptCommand::ListSessions => {
305            tracing::warn!("ignoring ListSessions while prompt is in-flight");
306        }
307        PromptCommand::LoadSession { .. } => {
308            tracing::warn!("ignoring LoadSession while prompt is in-flight");
309        }
310        PromptCommand::NewSession { .. } => {
311            tracing::warn!("ignoring NewSession while prompt is in-flight");
312        }
313        PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
314            let msg = McpRequest::Authenticate { session_id: session_id.0.to_string(), server_name };
315            if let Err(e) = cx.send_notification(msg) {
316                tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
317            }
318        }
319        PromptCommand::Authenticate { method_id } => {
320            match cx.send_request(AuthenticateRequest::new(method_id.clone())).block_task().await {
321                Ok(_) => {
322                    let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
323                }
324                Err(e) => {
325                    tracing::warn!("authenticate failed: {e:?}");
326                    let _ = event_tx.send(AcpEvent::AuthenticateFailed { method_id, error: format!("{e:?}") });
327                }
328            }
329        }
330    }
331}
332
333fn auto_approve_option(req: &RequestPermissionRequest) -> PermissionOptionId {
334    debug_assert!(!req.options.is_empty(), "ACP guarantees at least one permission option");
335    req.options
336        .iter()
337        .find(|o| matches!(o.kind, PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways))
338        .map_or_else(|| req.options[0].option_id.clone(), |o| o.option_id.clone())
339}