1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use super::tokio_agent::TokioAcpAgent;
5use crate::notifications::{
6 AuthMethodsUpdatedParams, ContextClearedParams, ContextUsageParams, ElicitationParams, McpNotification, McpRequest,
7 SubAgentProgressParams,
8};
9use agent_client_protocol::schema::{
10 AuthMethod, AuthenticateRequest, CancelNotification, ConfigOptionUpdate, ContentBlock, InitializeRequest,
11 InitializeResponse, ListSessionsRequest, LoadSessionRequest, NewSessionRequest, NewSessionResponse,
12 PermissionOptionId, PermissionOptionKind, PromptCapabilities, PromptRequest, RequestPermissionOutcome,
13 RequestPermissionRequest, RequestPermissionResponse, SelectedPermissionOutcome, SessionCapabilities,
14 SessionConfigOption, SessionId, SessionNotification, SessionUpdate, SetSessionConfigOptionRequest, TextContent,
15};
16use agent_client_protocol::{self as acp, Client, ConnectionTo, JsonRpcRequest};
17use std::str::FromStr;
18use tokio::sync::mpsc;
19use tracing::info;
20
/// Outcome of the startup handshake, sent from the connection task back to
/// `spawn_acp_session`: the `InitializeResponse` and `NewSessionResponse` on success,
/// or the error that stopped the handshake (or the connection) first.
type InitializeResult = Result<(InitializeResponse, NewSessionResponse), AcpClientError>;
22
/// A live ACP session returned by `spawn_acp_session`.
///
/// The connection to the agent runs on a spawned background task; this struct holds the
/// caller's side of it (the event receiver and the command handle) plus data captured
/// from the handshake responses.
pub struct AcpSession {
    /// Id of the session created by the `NewSessionRequest` sent during the handshake.
    pub session_id: SessionId,
    /// Display name: the agent's `title`, else its `name`, else `"agent"` when the
    /// initialize response carried no agent info.
    pub agent_name: String,
    /// Prompt capabilities advertised in the initialize response.
    pub prompt_capabilities: PromptCapabilities,
    /// Session capabilities advertised in the initialize response.
    pub session_capabilities: SessionCapabilities,
    /// Config options returned when the session was created (empty if the agent sent none).
    pub config_options: Vec<SessionConfigOption>,
    /// Authentication methods advertised in the initialize response.
    pub auth_methods: Vec<AuthMethod>,
    /// Events forwarded from the connection task (agent notifications, request results,
    /// and a final `ConnectionClosed` when the connection ends).
    pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
    /// Sends `PromptCommand`s (prompts, cancellation, session management, ...) to the
    /// connection task.
    pub prompt_handle: AcpPromptHandle,
}
34
35pub async fn spawn_acp_session(
41 agent_command: &str,
42 init_request: InitializeRequest,
43 new_session_request: NewSessionRequest,
44) -> Result<AcpSession, AcpClientError> {
45 let agent = TokioAcpAgent::from_str(agent_command).map_err(AcpClientError::InvalidAgentCommand)?;
46 let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
47 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
48 let (init_tx, mut init_rx) = mpsc::unbounded_channel::<InitializeResult>();
49 tokio::spawn(run_client_connection(agent, event_tx, cmd_rx, init_tx, init_request, new_session_request));
50
51 let (init_resp, session_resp) = init_rx
52 .recv()
53 .await
54 .ok_or_else(|| AcpClientError::AgentCrashed("ACP task died during initialization".to_string()))??;
55
56 let agent_name = init_resp
57 .agent_info
58 .as_ref()
59 .map_or_else(|| "agent".to_string(), |info| info.title.as_deref().unwrap_or(&info.name).to_string());
60
61 Ok(AcpSession {
62 session_id: session_resp.session_id,
63 agent_name,
64 prompt_capabilities: init_resp.agent_capabilities.prompt_capabilities,
65 session_capabilities: init_resp.agent_capabilities.session_capabilities,
66 config_options: session_resp.config_options.unwrap_or_default(),
67 auth_methods: init_resp.auth_methods,
68 event_rx,
69 prompt_handle: AcpPromptHandle { cmd_tx },
70 })
71}
72
73#[allow(clippy::too_many_lines)]
74async fn run_client_connection(
75 agent: TokioAcpAgent,
76 event_tx: mpsc::UnboundedSender<AcpEvent>,
77 cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
78 init_tx: mpsc::UnboundedSender<InitializeResult>,
79 init_request: InitializeRequest,
80 new_session_request: NewSessionRequest,
81) {
82 let connection_result = Client
83 .builder()
84 .on_receive_request(
85 async move |req: RequestPermissionRequest, responder, _cx| {
86 responder.respond(RequestPermissionResponse::new(RequestPermissionOutcome::Selected(
87 SelectedPermissionOutcome::new(auto_approve_option(&req)),
88 )))
89 },
90 acp::on_receive_request!(),
91 )
92 .on_receive_request(
93 {
94 let event_tx = event_tx.clone();
95 async move |params: ElicitationParams, responder, _cx| {
96 if let Err(send_err) = event_tx.send(AcpEvent::ElicitationRequest { params, responder }) {
97 if let AcpEvent::ElicitationRequest { responder, .. } = send_err.0 {
99 return responder.respond_with_error(acp::Error::internal_error());
100 }
101 }
102 Ok(())
103 }
104 },
105 acp::on_receive_request!(),
106 )
107 .on_receive_notification(
108 {
109 let event_tx = event_tx.clone();
110 async move |SessionNotification { session_id, update, .. }: SessionNotification, _cx| {
111 let _ = event_tx.send(AcpEvent::SessionUpdate { session_id, update: Box::new(update) });
112 Ok(())
113 }
114 },
115 acp::on_receive_notification!(),
116 )
117 .on_receive_notification(
118 {
119 let event_tx = event_tx.clone();
120 async move |p: ContextUsageParams, _cx| {
121 let _ = event_tx.send(AcpEvent::ContextUsage(p));
122 Ok(())
123 }
124 },
125 acp::on_receive_notification!(),
126 )
127 .on_receive_notification(
128 {
129 let event_tx = event_tx.clone();
130 async move |p: ContextClearedParams, _cx| {
131 let _ = event_tx.send(AcpEvent::ContextCleared(p));
132 Ok(())
133 }
134 },
135 acp::on_receive_notification!(),
136 )
137 .on_receive_notification(
138 {
139 let event_tx = event_tx.clone();
140 async move |p: SubAgentProgressParams, _cx| {
141 let _ = event_tx.send(AcpEvent::SubAgentProgress(p));
142 Ok(())
143 }
144 },
145 acp::on_receive_notification!(),
146 )
147 .on_receive_notification(
148 {
149 let event_tx = event_tx.clone();
150 async move |p: AuthMethodsUpdatedParams, _cx| {
151 let _ = event_tx.send(AcpEvent::AuthMethodsUpdated(p));
152 Ok(())
153 }
154 },
155 acp::on_receive_notification!(),
156 )
157 .on_receive_notification(
158 {
159 let event_tx = event_tx.clone();
160 async move |n: McpNotification, _cx| {
161 let _ = event_tx.send(AcpEvent::McpNotification(n));
162 Ok(())
163 }
164 },
165 acp::on_receive_notification!(),
166 )
167 .connect_with(agent, {
168 let event_tx = event_tx.clone();
169 let init_tx = init_tx.clone();
170 async move |cx: ConnectionTo<acp::Agent>| {
171 run_main(cx, event_tx, cmd_rx, init_tx, init_request, new_session_request).await;
172 Ok(())
173 }
174 })
175 .await;
176
177 if let Err(e) = connection_result {
178 tracing::warn!("ACP connection exited with error: {e:?}");
179 let _ = init_tx.send(Err(AcpClientError::ConnectFailed(e)));
180 }
181 let _ = event_tx.send(AcpEvent::ConnectionClosed);
182}
183
184#[allow(clippy::too_many_lines)]
185async fn run_main(
186 cx: ConnectionTo<acp::Agent>,
187 event_tx: mpsc::UnboundedSender<AcpEvent>,
188 mut cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
189 init_tx: mpsc::UnboundedSender<InitializeResult>,
190 init_request: InitializeRequest,
191 new_session_request: NewSessionRequest,
192) {
193 let init_resp = match cx.send_request(init_request).block_task().await {
194 Ok(r) => r,
195 Err(e) => {
196 let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
197 return;
198 }
199 };
200 info!("ACP initialized: protocol={:?}, agent_info={:?}", init_resp.protocol_version, init_resp.agent_info);
201
202 let session_resp = match cx.send_request(new_session_request).block_task().await {
203 Ok(r) => r,
204 Err(e) => {
205 let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
206 return;
207 }
208 };
209 info!("ACP session created: {}", session_resp.session_id);
210
211 let _ = init_tx.send(Ok((init_resp, session_resp)));
212
213 while let Some(cmd) = cmd_rx.recv().await {
214 match cmd {
215 PromptCommand::Prompt { session_id, text, content } => {
216 let mut prompt = vec![ContentBlock::Text(TextContent::new(text))];
217 if let Some(extra_content) = content {
218 prompt.extend(extra_content);
219 }
220 let prompt_fut = cx.send_request(PromptRequest::new(session_id, prompt)).block_task();
221 tokio::pin!(prompt_fut);
222
223 loop {
224 tokio::select! {
225 result = &mut prompt_fut => {
226 let event = match result {
227 Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
228 Err(e) => AcpEvent::PromptError(e),
229 };
230 let _ = event_tx.send(event);
231 break;
232 }
233 Some(cmd) = cmd_rx.recv() => {
234 handle_side_command(&cx, &event_tx, cmd).await;
235 }
236 }
237 }
238 }
239 PromptCommand::ListSessions => {
240 request_to_event(
241 &cx,
242 &event_tx,
243 ListSessionsRequest::new(),
244 |resp| AcpEvent::SessionsListed { sessions: resp.sessions },
245 AcpEvent::PromptError,
246 )
247 .await;
248 }
249 PromptCommand::LoadSession { session_id, cwd } => {
250 request_to_event(
251 &cx,
252 &event_tx,
253 LoadSessionRequest::new(session_id.clone(), cwd),
254 |resp| AcpEvent::SessionLoaded {
255 session_id,
256 config_options: resp.config_options.unwrap_or_default(),
257 },
258 AcpEvent::PromptError,
259 )
260 .await;
261 }
262 PromptCommand::NewSession { cwd } => {
263 request_to_event(
264 &cx,
265 &event_tx,
266 NewSessionRequest::new(cwd),
267 |resp| AcpEvent::NewSessionCreated {
268 session_id: resp.session_id,
269 config_options: resp.config_options.unwrap_or_default(),
270 },
271 AcpEvent::PromptError,
272 )
273 .await;
274 }
275 PromptCommand::MoveWorkspace(params) => {
276 request_to_event(&cx, &event_tx, params, AcpEvent::WorkspaceMoved, |e| AcpEvent::WorkspaceMoveFailed {
277 error: format!("{e}"),
278 })
279 .await;
280 }
281 cmd => handle_side_command(&cx, &event_tx, cmd).await,
282 }
283 }
284}
285
286async fn handle_side_command(
287 cx: &ConnectionTo<acp::Agent>,
288 event_tx: &mpsc::UnboundedSender<AcpEvent>,
289 cmd: PromptCommand,
290) {
291 match cmd {
292 PromptCommand::Cancel { session_id } => {
293 let _ = cx.send_notification(CancelNotification::new(session_id));
294 }
295 PromptCommand::SetConfigOption { session_id, config_id, value } => {
296 let req = SetSessionConfigOptionRequest::new(session_id.clone(), config_id, value);
297 match cx.send_request(req).block_task().await {
298 Ok(resp) => {
299 let update = ConfigOptionUpdate::new(resp.config_options);
300 let _ = event_tx.send(AcpEvent::SessionUpdate {
301 session_id,
302 update: Box::new(SessionUpdate::ConfigOptionUpdate(update)),
303 });
304 }
305 Err(e) => {
306 tracing::warn!("set_session_config_option failed: {e:?}");
307 }
308 }
309 }
310 PromptCommand::Prompt { .. } => {
311 tracing::warn!("ignoring duplicate Prompt while one is in-flight");
312 }
313 PromptCommand::ListSessions => {
314 tracing::warn!("ignoring ListSessions while prompt is in-flight");
315 }
316 PromptCommand::LoadSession { .. } => {
317 tracing::warn!("ignoring LoadSession while prompt is in-flight");
318 }
319 PromptCommand::NewSession { .. } => {
320 tracing::warn!("ignoring NewSession while prompt is in-flight");
321 }
322 PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
323 let msg = McpRequest::Authenticate { session_id: session_id.0.to_string(), server_name };
324 if let Err(e) = cx.send_notification(msg) {
325 tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
326 }
327 }
328 PromptCommand::Authenticate { method_id } => {
329 match cx.send_request(AuthenticateRequest::new(method_id.clone())).block_task().await {
330 Ok(_) => {
331 let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
332 }
333 Err(e) => {
334 tracing::warn!("authenticate failed: {e:?}");
335 let _ = event_tx.send(AcpEvent::AuthenticateFailed { method_id, error: format!("{e:?}") });
336 }
337 }
338 }
339 PromptCommand::SearchPrompts(params) => {
340 let query = params.query.clone();
341 request_to_event(cx, event_tx, params, AcpEvent::PromptSearchResults, |e| AcpEvent::PromptSearchFailed {
342 query,
343 error: format!("{e}"),
344 })
345 .await;
346 }
347 PromptCommand::SessionPreview(params) => {
348 let session_id = params.session_id.clone();
349 request_to_event(cx, event_tx, params, AcpEvent::SessionPreviewLoaded, |e| {
350 AcpEvent::SessionPreviewFailed { session_id, error: format!("{e}") }
351 })
352 .await;
353 }
354 PromptCommand::ListWorkspaces(params) => {
355 request_to_event(cx, event_tx, params, AcpEvent::WorkspacesListed, |e| AcpEvent::WorkspaceListFailed {
356 error: format!("{e}"),
357 })
358 .await;
359 }
360 PromptCommand::MoveWorkspace(_) => {
361 tracing::warn!("ignoring MoveWorkspace while prompt is in-flight");
362 let _ = event_tx.send(AcpEvent::WorkspaceMoveFailed { error: "a prompt is in flight".to_string() });
363 }
364 }
365}
366
367async fn request_to_event<T: JsonRpcRequest>(
368 cx: &ConnectionTo<acp::Agent>,
369 event_tx: &mpsc::UnboundedSender<AcpEvent>,
370 params: T,
371 ok: impl FnOnce(T::Response) -> AcpEvent,
372 err: impl FnOnce(acp::Error) -> AcpEvent,
373) {
374 let event = match cx.send_request(params).block_task().await {
375 Ok(resp) => ok(resp),
376 Err(e) => err(e),
377 };
378 let _ = event_tx.send(event);
379}
380
381fn auto_approve_option(req: &RequestPermissionRequest) -> PermissionOptionId {
382 debug_assert!(!req.options.is_empty(), "ACP guarantees at least one permission option");
383 req.options
384 .iter()
385 .find(|o| matches!(o.kind, PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways))
386 .map_or_else(|| req.options[0].option_id.clone(), |o| o.option_id.clone())
387}