Skip to main content

acp_utils/client/
session.rs

1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use crate::notifications::{
5    AuthMethodsUpdatedParams, ContextClearedParams, ContextUsageParams, ElicitationParams, McpNotification, McpRequest,
6    SubAgentProgressParams,
7};
8use agent_client_protocol::schema::{
9    AuthMethod, AuthenticateRequest, CancelNotification, ConfigOptionUpdate, ContentBlock, InitializeRequest,
10    InitializeResponse, ListSessionsRequest, LoadSessionRequest, NewSessionRequest, NewSessionResponse,
11    PermissionOptionId, PermissionOptionKind, PromptCapabilities, PromptRequest, RequestPermissionOutcome,
12    RequestPermissionRequest, RequestPermissionResponse, SelectedPermissionOutcome, SessionConfigOption, SessionId,
13    SessionNotification, SetSessionConfigOptionRequest, TextContent,
14};
15use agent_client_protocol::{self as acp, Client, ConnectionTo};
16use agent_client_protocol_tokio::AcpAgent;
17use std::str::FromStr;
18use tokio::sync::mpsc;
19use tracing::info;
20
/// Outcome of the startup handshake: the agent's [`InitializeResponse`] together with the
/// [`NewSessionResponse`] for the first session, or the error that prevented the handshake
/// from completing.
type InitializeResult = Result<(InitializeResponse, NewSessionResponse), AcpClientError>;
22
/// ACP session with all handles needed by the caller.
///
/// Produced by [`spawn_acp_session`] once the agent has answered both the initialize request
/// and the new-session request.
pub struct AcpSession {
    /// Identifier of the session created during startup.
    pub session_id: SessionId,
    /// Display name for the agent: its advertised title, else its name, or `"agent"` when the
    /// agent reported no info at all.
    pub agent_name: String,
    /// Prompt capabilities the agent advertised in its initialize response.
    pub prompt_capabilities: PromptCapabilities,
    /// Configuration options reported for the new session (empty when none were reported).
    pub config_options: Vec<SessionConfigOption>,
    /// Authentication methods the agent advertised in its initialize response.
    pub auth_methods: Vec<AuthMethod>,
    /// Events forwarded by the background connection task; it sends
    /// `AcpEvent::ConnectionClosed` once the connection finishes.
    pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
    /// Handle used to send prompts and other commands to the background connection task.
    pub prompt_handle: AcpPromptHandle,
}
33
34/// Spawn an agent subprocess and establish an ACP session.
35///
36/// The connection auto-approves permissions, forwards session notifications as
37/// [`AcpEvent`]s, and tunnels elicitation requests through the `_aether/elicitation`
38/// extension method.
39pub async fn spawn_acp_session(
40    agent_command: &str,
41    init_request: InitializeRequest,
42    new_session_request: NewSessionRequest,
43) -> Result<AcpSession, AcpClientError> {
44    let agent = AcpAgent::from_str(agent_command).map_err(AcpClientError::InvalidAgentCommand)?;
45    let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
46    let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
47    let (init_tx, mut init_rx) = mpsc::unbounded_channel::<InitializeResult>();
48    tokio::spawn(run_client_connection(agent, event_tx, cmd_rx, init_tx, init_request, new_session_request));
49
50    let (init_resp, session_resp) = init_rx
51        .recv()
52        .await
53        .ok_or_else(|| AcpClientError::AgentCrashed("ACP task died during initialization".to_string()))??;
54
55    let agent_name = init_resp
56        .agent_info
57        .as_ref()
58        .map_or_else(|| "agent".to_string(), |info| info.title.as_deref().unwrap_or(&info.name).to_string());
59
60    Ok(AcpSession {
61        session_id: session_resp.session_id,
62        agent_name,
63        prompt_capabilities: init_resp.agent_capabilities.prompt_capabilities,
64        config_options: session_resp.config_options.unwrap_or_default(),
65        auth_methods: init_resp.auth_methods,
66        event_rx,
67        prompt_handle: AcpPromptHandle { cmd_tx },
68    })
69}
70
71#[allow(clippy::too_many_lines)]
72async fn run_client_connection(
73    agent: AcpAgent,
74    event_tx: mpsc::UnboundedSender<AcpEvent>,
75    cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
76    init_tx: mpsc::UnboundedSender<InitializeResult>,
77    init_request: InitializeRequest,
78    new_session_request: NewSessionRequest,
79) {
80    let connection_result = Client
81        .builder()
82        .on_receive_request(
83            async move |req: RequestPermissionRequest, responder, _cx| {
84                responder.respond(RequestPermissionResponse::new(RequestPermissionOutcome::Selected(
85                    SelectedPermissionOutcome::new(auto_approve_option(&req)),
86                )))
87            },
88            acp::on_receive_request!(),
89        )
90        .on_receive_request(
91            {
92                let event_tx = event_tx.clone();
93                async move |params: ElicitationParams, responder, _cx| {
94                    if let Err(send_err) = event_tx.send(AcpEvent::ElicitationRequest { params, responder }) {
95                        // Recover the responder and reply with an error so the remote caller doesn't hang.
96                        if let AcpEvent::ElicitationRequest { responder, .. } = send_err.0 {
97                            return responder.respond_with_error(acp::Error::internal_error());
98                        }
99                    }
100                    Ok(())
101                }
102            },
103            acp::on_receive_request!(),
104        )
105        .on_receive_notification(
106            {
107                let event_tx = event_tx.clone();
108                async move |notif: SessionNotification, _cx| {
109                    let _ = event_tx.send(AcpEvent::SessionUpdate(Box::new(notif.update)));
110                    Ok(())
111                }
112            },
113            acp::on_receive_notification!(),
114        )
115        .on_receive_notification(
116            {
117                let event_tx = event_tx.clone();
118                async move |p: ContextUsageParams, _cx| {
119                    let _ = event_tx.send(AcpEvent::ContextUsage(p));
120                    Ok(())
121                }
122            },
123            acp::on_receive_notification!(),
124        )
125        .on_receive_notification(
126            {
127                let event_tx = event_tx.clone();
128                async move |p: ContextClearedParams, _cx| {
129                    let _ = event_tx.send(AcpEvent::ContextCleared(p));
130                    Ok(())
131                }
132            },
133            acp::on_receive_notification!(),
134        )
135        .on_receive_notification(
136            {
137                let event_tx = event_tx.clone();
138                async move |p: SubAgentProgressParams, _cx| {
139                    let _ = event_tx.send(AcpEvent::SubAgentProgress(p));
140                    Ok(())
141                }
142            },
143            acp::on_receive_notification!(),
144        )
145        .on_receive_notification(
146            {
147                let event_tx = event_tx.clone();
148                async move |p: AuthMethodsUpdatedParams, _cx| {
149                    let _ = event_tx.send(AcpEvent::AuthMethodsUpdated(p));
150                    Ok(())
151                }
152            },
153            acp::on_receive_notification!(),
154        )
155        .on_receive_notification(
156            {
157                let event_tx = event_tx.clone();
158                async move |n: McpNotification, _cx| {
159                    let _ = event_tx.send(AcpEvent::McpNotification(n));
160                    Ok(())
161                }
162            },
163            acp::on_receive_notification!(),
164        )
165        .connect_with(agent, {
166            let event_tx = event_tx.clone();
167            let init_tx = init_tx.clone();
168            async move |cx: ConnectionTo<acp::Agent>| {
169                run_main(cx, event_tx, cmd_rx, init_tx, init_request, new_session_request).await;
170                Ok(())
171            }
172        })
173        .await;
174
175    if let Err(e) = connection_result {
176        tracing::warn!("ACP connection exited with error: {e:?}");
177        let _ = init_tx.send(Err(AcpClientError::ConnectFailed(e)));
178    }
179    let _ = event_tx.send(AcpEvent::ConnectionClosed);
180}
181
182#[allow(clippy::too_many_lines)]
183async fn run_main(
184    cx: ConnectionTo<acp::Agent>,
185    event_tx: mpsc::UnboundedSender<AcpEvent>,
186    mut cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
187    init_tx: mpsc::UnboundedSender<InitializeResult>,
188    init_request: InitializeRequest,
189    new_session_request: NewSessionRequest,
190) {
191    let init_resp = match cx.send_request(init_request).block_task().await {
192        Ok(r) => r,
193        Err(e) => {
194            let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
195            return;
196        }
197    };
198    info!("ACP initialized: protocol={:?}, agent_info={:?}", init_resp.protocol_version, init_resp.agent_info);
199
200    let session_resp = match cx.send_request(new_session_request).block_task().await {
201        Ok(r) => r,
202        Err(e) => {
203            let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
204            return;
205        }
206    };
207    info!("ACP session created: {}", session_resp.session_id);
208
209    let _ = init_tx.send(Ok((init_resp, session_resp)));
210
211    while let Some(cmd) = cmd_rx.recv().await {
212        match cmd {
213            PromptCommand::Prompt { session_id, text, content } => {
214                let mut prompt = vec![ContentBlock::Text(TextContent::new(text))];
215                if let Some(extra_content) = content {
216                    prompt.extend(extra_content);
217                }
218                let prompt_fut = cx.send_request(PromptRequest::new(session_id, prompt)).block_task();
219                tokio::pin!(prompt_fut);
220
221                loop {
222                    tokio::select! {
223                        result = &mut prompt_fut => {
224                            let event = match result {
225                                Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
226                                Err(e) => AcpEvent::PromptError(e),
227                            };
228                            let _ = event_tx.send(event);
229                            break;
230                        }
231                        Some(cmd) = cmd_rx.recv() => {
232                            handle_side_command(&cx, &event_tx, cmd).await;
233                        }
234                    }
235                }
236            }
237            PromptCommand::ListSessions => {
238                let req = ListSessionsRequest::new();
239                match cx.send_request(req).block_task().await {
240                    Ok(resp) => {
241                        let _ = event_tx.send(AcpEvent::SessionsListed { sessions: resp.sessions });
242                    }
243                    Err(e) => {
244                        let _ = event_tx.send(AcpEvent::PromptError(e));
245                    }
246                }
247            }
248            PromptCommand::LoadSession { session_id, cwd } => {
249                let req = LoadSessionRequest::new(session_id.clone(), cwd);
250                match cx.send_request(req).block_task().await {
251                    Ok(resp) => {
252                        let config_options = resp.config_options.unwrap_or_default();
253                        let _ = event_tx.send(AcpEvent::SessionLoaded { session_id, config_options });
254                    }
255                    Err(e) => {
256                        let _ = event_tx.send(AcpEvent::PromptError(e));
257                    }
258                }
259            }
260            PromptCommand::NewSession { cwd } => {
261                let req = NewSessionRequest::new(cwd);
262                match cx.send_request(req).block_task().await {
263                    Ok(resp) => {
264                        let config_options = resp.config_options.unwrap_or_default();
265                        let _ =
266                            event_tx.send(AcpEvent::NewSessionCreated { session_id: resp.session_id, config_options });
267                    }
268                    Err(e) => {
269                        let _ = event_tx.send(AcpEvent::PromptError(e));
270                    }
271                }
272            }
273            cmd => handle_side_command(&cx, &event_tx, cmd).await,
274        }
275    }
276}
277
278async fn handle_side_command(
279    cx: &ConnectionTo<acp::Agent>,
280    event_tx: &mpsc::UnboundedSender<AcpEvent>,
281    cmd: PromptCommand,
282) {
283    match cmd {
284        PromptCommand::Cancel { session_id } => {
285            let _ = cx.send_notification(CancelNotification::new(session_id));
286        }
287        PromptCommand::SetConfigOption { session_id, config_id, value } => {
288            let req = SetSessionConfigOptionRequest::new(session_id, config_id, value);
289            match cx.send_request(req).block_task().await {
290                Ok(resp) => {
291                    let update = ConfigOptionUpdate::new(resp.config_options);
292                    let _ = event_tx.send(AcpEvent::SessionUpdate(Box::new(
293                        acp::schema::SessionUpdate::ConfigOptionUpdate(update),
294                    )));
295                }
296                Err(e) => {
297                    tracing::warn!("set_session_config_option failed: {e:?}");
298                }
299            }
300        }
301        PromptCommand::Prompt { .. } => {
302            tracing::warn!("ignoring duplicate Prompt while one is in-flight");
303        }
304        PromptCommand::ListSessions => {
305            tracing::warn!("ignoring ListSessions while prompt is in-flight");
306        }
307        PromptCommand::LoadSession { .. } => {
308            tracing::warn!("ignoring LoadSession while prompt is in-flight");
309        }
310        PromptCommand::NewSession { .. } => {
311            tracing::warn!("ignoring NewSession while prompt is in-flight");
312        }
313        PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
314            let msg = McpRequest::Authenticate { session_id: session_id.0.to_string(), server_name };
315            if let Err(e) = cx.send_notification(msg) {
316                tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
317            }
318        }
319        PromptCommand::Authenticate { method_id } => {
320            match cx.send_request(AuthenticateRequest::new(method_id.clone())).block_task().await {
321                Ok(_) => {
322                    let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
323                }
324                Err(e) => {
325                    tracing::warn!("authenticate failed: {e:?}");
326                    let _ = event_tx.send(AcpEvent::AuthenticateFailed { method_id, error: format!("{e:?}") });
327                }
328            }
329        }
330    }
331}
332
333fn auto_approve_option(req: &RequestPermissionRequest) -> PermissionOptionId {
334    debug_assert!(!req.options.is_empty(), "ACP guarantees at least one permission option");
335    req.options
336        .iter()
337        .find(|o| matches!(o.kind, PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways))
338        .map_or_else(|| req.options[0].option_id.clone(), |o| o.option_id.clone())
339}