1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use crate::notifications::{
5 AuthMethodsUpdatedParams, ContextClearedParams, ContextUsageParams, ElicitationParams, McpNotification, McpRequest,
6 SubAgentProgressParams,
7};
8use agent_client_protocol::schema::{
9 AuthMethod, AuthenticateRequest, CancelNotification, ConfigOptionUpdate, ContentBlock, InitializeRequest,
10 InitializeResponse, ListSessionsRequest, LoadSessionRequest, NewSessionRequest, NewSessionResponse,
11 PermissionOptionId, PermissionOptionKind, PromptCapabilities, PromptRequest, RequestPermissionOutcome,
12 RequestPermissionRequest, RequestPermissionResponse, SelectedPermissionOutcome, SessionConfigOption, SessionId,
13 SessionNotification, SetSessionConfigOptionRequest, TextContent,
14};
15use agent_client_protocol::{self as acp, Client, ConnectionTo};
16use agent_client_protocol_tokio::AcpAgent;
17use std::str::FromStr;
18use tokio::sync::mpsc;
19use tracing::info;
20
/// Outcome of the startup handshake reported by the background connection task:
/// the `initialize` response paired with the `new session` response on success,
/// or the error that aborted the handshake.
type InitializeResult = Result<(InitializeResponse, NewSessionResponse), AcpClientError>;
22
/// Everything a caller needs to talk to a freshly spawned ACP agent session.
///
/// Produced by `spawn_acp_session` once the initialize / new-session handshake
/// has completed.
pub struct AcpSession {
    /// Session identifier taken from the agent's new-session response.
    pub session_id: SessionId,
    /// Display name for the agent: the advertised title if present, otherwise
    /// the advertised name, or `"agent"` when the agent sent no info at all.
    pub agent_name: String,
    /// Prompt capabilities taken from the agent capabilities in the initialize response.
    pub prompt_capabilities: PromptCapabilities,
    /// Config options from the new-session response (empty when the agent sent none).
    pub config_options: Vec<SessionConfigOption>,
    /// Authentication methods listed in the initialize response.
    pub auth_methods: Vec<AuthMethod>,
    /// Receiving end of the unbounded channel on which the background
    /// connection task publishes `AcpEvent`s.
    pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
    /// Handle wrapping the command sender used to submit `PromptCommand`s to
    /// the background connection task.
    pub prompt_handle: AcpPromptHandle,
}
33
34pub async fn spawn_acp_session(
40 agent_command: &str,
41 init_request: InitializeRequest,
42 new_session_request: NewSessionRequest,
43) -> Result<AcpSession, AcpClientError> {
44 let agent = AcpAgent::from_str(agent_command).map_err(AcpClientError::InvalidAgentCommand)?;
45 let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
46 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
47 let (init_tx, mut init_rx) = mpsc::unbounded_channel::<InitializeResult>();
48 tokio::spawn(run_client_connection(agent, event_tx, cmd_rx, init_tx, init_request, new_session_request));
49
50 let (init_resp, session_resp) = init_rx
51 .recv()
52 .await
53 .ok_or_else(|| AcpClientError::AgentCrashed("ACP task died during initialization".to_string()))??;
54
55 let agent_name = init_resp
56 .agent_info
57 .as_ref()
58 .map_or_else(|| "agent".to_string(), |info| info.title.as_deref().unwrap_or(&info.name).to_string());
59
60 Ok(AcpSession {
61 session_id: session_resp.session_id,
62 agent_name,
63 prompt_capabilities: init_resp.agent_capabilities.prompt_capabilities,
64 config_options: session_resp.config_options.unwrap_or_default(),
65 auth_methods: init_resp.auth_methods,
66 event_rx,
67 prompt_handle: AcpPromptHandle { cmd_tx },
68 })
69}
70
71#[allow(clippy::too_many_lines)]
72async fn run_client_connection(
73 agent: AcpAgent,
74 event_tx: mpsc::UnboundedSender<AcpEvent>,
75 cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
76 init_tx: mpsc::UnboundedSender<InitializeResult>,
77 init_request: InitializeRequest,
78 new_session_request: NewSessionRequest,
79) {
80 let connection_result = Client
81 .builder()
82 .on_receive_request(
83 async move |req: RequestPermissionRequest, responder, _cx| {
84 responder.respond(RequestPermissionResponse::new(RequestPermissionOutcome::Selected(
85 SelectedPermissionOutcome::new(auto_approve_option(&req)),
86 )))
87 },
88 acp::on_receive_request!(),
89 )
90 .on_receive_request(
91 {
92 let event_tx = event_tx.clone();
93 async move |params: ElicitationParams, responder, _cx| {
94 if let Err(send_err) = event_tx.send(AcpEvent::ElicitationRequest { params, responder }) {
95 if let AcpEvent::ElicitationRequest { responder, .. } = send_err.0 {
97 return responder.respond_with_error(acp::Error::internal_error());
98 }
99 }
100 Ok(())
101 }
102 },
103 acp::on_receive_request!(),
104 )
105 .on_receive_notification(
106 {
107 let event_tx = event_tx.clone();
108 async move |notif: SessionNotification, _cx| {
109 let _ = event_tx.send(AcpEvent::SessionUpdate(Box::new(notif.update)));
110 Ok(())
111 }
112 },
113 acp::on_receive_notification!(),
114 )
115 .on_receive_notification(
116 {
117 let event_tx = event_tx.clone();
118 async move |p: ContextUsageParams, _cx| {
119 let _ = event_tx.send(AcpEvent::ContextUsage(p));
120 Ok(())
121 }
122 },
123 acp::on_receive_notification!(),
124 )
125 .on_receive_notification(
126 {
127 let event_tx = event_tx.clone();
128 async move |p: ContextClearedParams, _cx| {
129 let _ = event_tx.send(AcpEvent::ContextCleared(p));
130 Ok(())
131 }
132 },
133 acp::on_receive_notification!(),
134 )
135 .on_receive_notification(
136 {
137 let event_tx = event_tx.clone();
138 async move |p: SubAgentProgressParams, _cx| {
139 let _ = event_tx.send(AcpEvent::SubAgentProgress(p));
140 Ok(())
141 }
142 },
143 acp::on_receive_notification!(),
144 )
145 .on_receive_notification(
146 {
147 let event_tx = event_tx.clone();
148 async move |p: AuthMethodsUpdatedParams, _cx| {
149 let _ = event_tx.send(AcpEvent::AuthMethodsUpdated(p));
150 Ok(())
151 }
152 },
153 acp::on_receive_notification!(),
154 )
155 .on_receive_notification(
156 {
157 let event_tx = event_tx.clone();
158 async move |n: McpNotification, _cx| {
159 let _ = event_tx.send(AcpEvent::McpNotification(n));
160 Ok(())
161 }
162 },
163 acp::on_receive_notification!(),
164 )
165 .connect_with(agent, {
166 let event_tx = event_tx.clone();
167 let init_tx = init_tx.clone();
168 async move |cx: ConnectionTo<acp::Agent>| {
169 run_main(cx, event_tx, cmd_rx, init_tx, init_request, new_session_request).await;
170 Ok(())
171 }
172 })
173 .await;
174
175 if let Err(e) = connection_result {
176 tracing::warn!("ACP connection exited with error: {e:?}");
177 let _ = init_tx.send(Err(AcpClientError::ConnectFailed(e)));
178 }
179 let _ = event_tx.send(AcpEvent::ConnectionClosed);
180}
181
182#[allow(clippy::too_many_lines)]
183async fn run_main(
184 cx: ConnectionTo<acp::Agent>,
185 event_tx: mpsc::UnboundedSender<AcpEvent>,
186 mut cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
187 init_tx: mpsc::UnboundedSender<InitializeResult>,
188 init_request: InitializeRequest,
189 new_session_request: NewSessionRequest,
190) {
191 let init_resp = match cx.send_request(init_request).block_task().await {
192 Ok(r) => r,
193 Err(e) => {
194 let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
195 return;
196 }
197 };
198 info!("ACP initialized: protocol={:?}, agent_info={:?}", init_resp.protocol_version, init_resp.agent_info);
199
200 let session_resp = match cx.send_request(new_session_request).block_task().await {
201 Ok(r) => r,
202 Err(e) => {
203 let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
204 return;
205 }
206 };
207 info!("ACP session created: {}", session_resp.session_id);
208
209 let _ = init_tx.send(Ok((init_resp, session_resp)));
210
211 while let Some(cmd) = cmd_rx.recv().await {
212 match cmd {
213 PromptCommand::Prompt { session_id, text, content } => {
214 let mut prompt = vec![ContentBlock::Text(TextContent::new(text))];
215 if let Some(extra_content) = content {
216 prompt.extend(extra_content);
217 }
218 let prompt_fut = cx.send_request(PromptRequest::new(session_id, prompt)).block_task();
219 tokio::pin!(prompt_fut);
220
221 loop {
222 tokio::select! {
223 result = &mut prompt_fut => {
224 let event = match result {
225 Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
226 Err(e) => AcpEvent::PromptError(e),
227 };
228 let _ = event_tx.send(event);
229 break;
230 }
231 Some(cmd) = cmd_rx.recv() => {
232 handle_side_command(&cx, &event_tx, cmd).await;
233 }
234 }
235 }
236 }
237 PromptCommand::ListSessions => {
238 let req = ListSessionsRequest::new();
239 match cx.send_request(req).block_task().await {
240 Ok(resp) => {
241 let _ = event_tx.send(AcpEvent::SessionsListed { sessions: resp.sessions });
242 }
243 Err(e) => {
244 let _ = event_tx.send(AcpEvent::PromptError(e));
245 }
246 }
247 }
248 PromptCommand::LoadSession { session_id, cwd } => {
249 let req = LoadSessionRequest::new(session_id.clone(), cwd);
250 match cx.send_request(req).block_task().await {
251 Ok(resp) => {
252 let config_options = resp.config_options.unwrap_or_default();
253 let _ = event_tx.send(AcpEvent::SessionLoaded { session_id, config_options });
254 }
255 Err(e) => {
256 let _ = event_tx.send(AcpEvent::PromptError(e));
257 }
258 }
259 }
260 PromptCommand::NewSession { cwd } => {
261 let req = NewSessionRequest::new(cwd);
262 match cx.send_request(req).block_task().await {
263 Ok(resp) => {
264 let config_options = resp.config_options.unwrap_or_default();
265 let _ =
266 event_tx.send(AcpEvent::NewSessionCreated { session_id: resp.session_id, config_options });
267 }
268 Err(e) => {
269 let _ = event_tx.send(AcpEvent::PromptError(e));
270 }
271 }
272 }
273 cmd => handle_side_command(&cx, &event_tx, cmd).await,
274 }
275 }
276}
277
278async fn handle_side_command(
279 cx: &ConnectionTo<acp::Agent>,
280 event_tx: &mpsc::UnboundedSender<AcpEvent>,
281 cmd: PromptCommand,
282) {
283 match cmd {
284 PromptCommand::Cancel { session_id } => {
285 let _ = cx.send_notification(CancelNotification::new(session_id));
286 }
287 PromptCommand::SetConfigOption { session_id, config_id, value } => {
288 let req = SetSessionConfigOptionRequest::new(session_id, config_id, value);
289 match cx.send_request(req).block_task().await {
290 Ok(resp) => {
291 let update = ConfigOptionUpdate::new(resp.config_options);
292 let _ = event_tx.send(AcpEvent::SessionUpdate(Box::new(
293 acp::schema::SessionUpdate::ConfigOptionUpdate(update),
294 )));
295 }
296 Err(e) => {
297 tracing::warn!("set_session_config_option failed: {e:?}");
298 }
299 }
300 }
301 PromptCommand::Prompt { .. } => {
302 tracing::warn!("ignoring duplicate Prompt while one is in-flight");
303 }
304 PromptCommand::ListSessions => {
305 tracing::warn!("ignoring ListSessions while prompt is in-flight");
306 }
307 PromptCommand::LoadSession { .. } => {
308 tracing::warn!("ignoring LoadSession while prompt is in-flight");
309 }
310 PromptCommand::NewSession { .. } => {
311 tracing::warn!("ignoring NewSession while prompt is in-flight");
312 }
313 PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
314 let msg = McpRequest::Authenticate { session_id: session_id.0.to_string(), server_name };
315 if let Err(e) = cx.send_notification(msg) {
316 tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
317 }
318 }
319 PromptCommand::Authenticate { method_id } => {
320 match cx.send_request(AuthenticateRequest::new(method_id.clone())).block_task().await {
321 Ok(_) => {
322 let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
323 }
324 Err(e) => {
325 tracing::warn!("authenticate failed: {e:?}");
326 let _ = event_tx.send(AcpEvent::AuthenticateFailed { method_id, error: format!("{e:?}") });
327 }
328 }
329 }
330 }
331}
332
333fn auto_approve_option(req: &RequestPermissionRequest) -> PermissionOptionId {
334 debug_assert!(!req.options.is_empty(), "ACP guarantees at least one permission option");
335 req.options
336 .iter()
337 .find(|o| matches!(o.kind, PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways))
338 .map_or_else(|| req.options[0].option_id.clone(), |o| o.option_id.clone())
339}