1use super::error::AcpClientError;
2use super::event::AcpEvent;
3use super::prompt_handle::{AcpPromptHandle, PromptCommand};
4use crate::notifications::{
5 AuthMethodsUpdatedParams, ContextClearedParams, ContextUsageParams, ElicitationParams, McpNotification, McpRequest,
6 SubAgentProgressParams,
7};
8use agent_client_protocol::schema::{
9 AuthMethod, AuthenticateRequest, CancelNotification, ConfigOptionUpdate, ContentBlock, InitializeRequest,
10 InitializeResponse, ListSessionsRequest, LoadSessionRequest, NewSessionRequest, NewSessionResponse,
11 PermissionOptionId, PermissionOptionKind, PromptCapabilities, PromptRequest, RequestPermissionOutcome,
12 RequestPermissionRequest, RequestPermissionResponse, SelectedPermissionOutcome, SessionConfigOption, SessionId,
13 SessionNotification, SessionUpdate, SetSessionConfigOptionRequest, TextContent,
14};
15use agent_client_protocol::{self as acp, AcpAgent, Client, ConnectionTo};
16use std::str::FromStr;
17use tokio::sync::mpsc;
18use tracing::info;
19
20type InitializeResult = Result<(InitializeResponse, NewSessionResponse), AcpClientError>;
21
22pub struct AcpSession {
24 pub session_id: SessionId,
25 pub agent_name: String,
26 pub prompt_capabilities: PromptCapabilities,
27 pub config_options: Vec<SessionConfigOption>,
28 pub auth_methods: Vec<AuthMethod>,
29 pub event_rx: mpsc::UnboundedReceiver<AcpEvent>,
30 pub prompt_handle: AcpPromptHandle,
31}
32
33pub async fn spawn_acp_session(
39 agent_command: &str,
40 init_request: InitializeRequest,
41 new_session_request: NewSessionRequest,
42) -> Result<AcpSession, AcpClientError> {
43 let agent = AcpAgent::from_str(agent_command).map_err(AcpClientError::InvalidAgentCommand)?;
44 let (event_tx, event_rx) = mpsc::unbounded_channel::<AcpEvent>();
45 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel::<PromptCommand>();
46 let (init_tx, mut init_rx) = mpsc::unbounded_channel::<InitializeResult>();
47 tokio::spawn(run_client_connection(agent, event_tx, cmd_rx, init_tx, init_request, new_session_request));
48
49 let (init_resp, session_resp) = init_rx
50 .recv()
51 .await
52 .ok_or_else(|| AcpClientError::AgentCrashed("ACP task died during initialization".to_string()))??;
53
54 let agent_name = init_resp
55 .agent_info
56 .as_ref()
57 .map_or_else(|| "agent".to_string(), |info| info.title.as_deref().unwrap_or(&info.name).to_string());
58
59 Ok(AcpSession {
60 session_id: session_resp.session_id,
61 agent_name,
62 prompt_capabilities: init_resp.agent_capabilities.prompt_capabilities,
63 config_options: session_resp.config_options.unwrap_or_default(),
64 auth_methods: init_resp.auth_methods,
65 event_rx,
66 prompt_handle: AcpPromptHandle { cmd_tx },
67 })
68}
69
70#[allow(clippy::too_many_lines)]
71async fn run_client_connection(
72 agent: AcpAgent,
73 event_tx: mpsc::UnboundedSender<AcpEvent>,
74 cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
75 init_tx: mpsc::UnboundedSender<InitializeResult>,
76 init_request: InitializeRequest,
77 new_session_request: NewSessionRequest,
78) {
79 let connection_result = Client
80 .builder()
81 .on_receive_request(
82 async move |req: RequestPermissionRequest, responder, _cx| {
83 responder.respond(RequestPermissionResponse::new(RequestPermissionOutcome::Selected(
84 SelectedPermissionOutcome::new(auto_approve_option(&req)),
85 )))
86 },
87 acp::on_receive_request!(),
88 )
89 .on_receive_request(
90 {
91 let event_tx = event_tx.clone();
92 async move |params: ElicitationParams, responder, _cx| {
93 if let Err(send_err) = event_tx.send(AcpEvent::ElicitationRequest { params, responder }) {
94 if let AcpEvent::ElicitationRequest { responder, .. } = send_err.0 {
96 return responder.respond_with_error(acp::Error::internal_error());
97 }
98 }
99 Ok(())
100 }
101 },
102 acp::on_receive_request!(),
103 )
104 .on_receive_notification(
105 {
106 let event_tx = event_tx.clone();
107 async move |SessionNotification { session_id, update, .. }: SessionNotification, _cx| {
108 let _ = event_tx.send(AcpEvent::SessionUpdate { session_id, update: Box::new(update) });
109 Ok(())
110 }
111 },
112 acp::on_receive_notification!(),
113 )
114 .on_receive_notification(
115 {
116 let event_tx = event_tx.clone();
117 async move |p: ContextUsageParams, _cx| {
118 let _ = event_tx.send(AcpEvent::ContextUsage(p));
119 Ok(())
120 }
121 },
122 acp::on_receive_notification!(),
123 )
124 .on_receive_notification(
125 {
126 let event_tx = event_tx.clone();
127 async move |p: ContextClearedParams, _cx| {
128 let _ = event_tx.send(AcpEvent::ContextCleared(p));
129 Ok(())
130 }
131 },
132 acp::on_receive_notification!(),
133 )
134 .on_receive_notification(
135 {
136 let event_tx = event_tx.clone();
137 async move |p: SubAgentProgressParams, _cx| {
138 let _ = event_tx.send(AcpEvent::SubAgentProgress(p));
139 Ok(())
140 }
141 },
142 acp::on_receive_notification!(),
143 )
144 .on_receive_notification(
145 {
146 let event_tx = event_tx.clone();
147 async move |p: AuthMethodsUpdatedParams, _cx| {
148 let _ = event_tx.send(AcpEvent::AuthMethodsUpdated(p));
149 Ok(())
150 }
151 },
152 acp::on_receive_notification!(),
153 )
154 .on_receive_notification(
155 {
156 let event_tx = event_tx.clone();
157 async move |n: McpNotification, _cx| {
158 let _ = event_tx.send(AcpEvent::McpNotification(n));
159 Ok(())
160 }
161 },
162 acp::on_receive_notification!(),
163 )
164 .connect_with(agent, {
165 let event_tx = event_tx.clone();
166 let init_tx = init_tx.clone();
167 async move |cx: ConnectionTo<acp::Agent>| {
168 run_main(cx, event_tx, cmd_rx, init_tx, init_request, new_session_request).await;
169 Ok(())
170 }
171 })
172 .await;
173
174 if let Err(e) = connection_result {
175 tracing::warn!("ACP connection exited with error: {e:?}");
176 let _ = init_tx.send(Err(AcpClientError::ConnectFailed(e)));
177 }
178 let _ = event_tx.send(AcpEvent::ConnectionClosed);
179}
180
181#[allow(clippy::too_many_lines)]
182async fn run_main(
183 cx: ConnectionTo<acp::Agent>,
184 event_tx: mpsc::UnboundedSender<AcpEvent>,
185 mut cmd_rx: mpsc::UnboundedReceiver<PromptCommand>,
186 init_tx: mpsc::UnboundedSender<InitializeResult>,
187 init_request: InitializeRequest,
188 new_session_request: NewSessionRequest,
189) {
190 let init_resp = match cx.send_request(init_request).block_task().await {
191 Ok(r) => r,
192 Err(e) => {
193 let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
194 return;
195 }
196 };
197 info!("ACP initialized: protocol={:?}, agent_info={:?}", init_resp.protocol_version, init_resp.agent_info);
198
199 let session_resp = match cx.send_request(new_session_request).block_task().await {
200 Ok(r) => r,
201 Err(e) => {
202 let _ = init_tx.send(Err(AcpClientError::Protocol(e)));
203 return;
204 }
205 };
206 info!("ACP session created: {}", session_resp.session_id);
207
208 let _ = init_tx.send(Ok((init_resp, session_resp)));
209
210 while let Some(cmd) = cmd_rx.recv().await {
211 match cmd {
212 PromptCommand::Prompt { session_id, text, content } => {
213 let mut prompt = vec![ContentBlock::Text(TextContent::new(text))];
214 if let Some(extra_content) = content {
215 prompt.extend(extra_content);
216 }
217 let prompt_fut = cx.send_request(PromptRequest::new(session_id, prompt)).block_task();
218 tokio::pin!(prompt_fut);
219
220 loop {
221 tokio::select! {
222 result = &mut prompt_fut => {
223 let event = match result {
224 Ok(resp) => AcpEvent::PromptDone(resp.stop_reason),
225 Err(e) => AcpEvent::PromptError(e),
226 };
227 let _ = event_tx.send(event);
228 break;
229 }
230 Some(cmd) = cmd_rx.recv() => {
231 handle_side_command(&cx, &event_tx, cmd).await;
232 }
233 }
234 }
235 }
236 PromptCommand::ListSessions => {
237 let req = ListSessionsRequest::new();
238 match cx.send_request(req).block_task().await {
239 Ok(resp) => {
240 let _ = event_tx.send(AcpEvent::SessionsListed { sessions: resp.sessions });
241 }
242 Err(e) => {
243 let _ = event_tx.send(AcpEvent::PromptError(e));
244 }
245 }
246 }
247 PromptCommand::LoadSession { session_id, cwd } => {
248 let req = LoadSessionRequest::new(session_id.clone(), cwd);
249 match cx.send_request(req).block_task().await {
250 Ok(resp) => {
251 let config_options = resp.config_options.unwrap_or_default();
252 let _ = event_tx.send(AcpEvent::SessionLoaded { session_id, config_options });
253 }
254 Err(e) => {
255 let _ = event_tx.send(AcpEvent::PromptError(e));
256 }
257 }
258 }
259 PromptCommand::NewSession { cwd } => {
260 let req = NewSessionRequest::new(cwd);
261 match cx.send_request(req).block_task().await {
262 Ok(resp) => {
263 let config_options = resp.config_options.unwrap_or_default();
264 let _ =
265 event_tx.send(AcpEvent::NewSessionCreated { session_id: resp.session_id, config_options });
266 }
267 Err(e) => {
268 let _ = event_tx.send(AcpEvent::PromptError(e));
269 }
270 }
271 }
272 cmd => handle_side_command(&cx, &event_tx, cmd).await,
273 }
274 }
275}
276
277async fn handle_side_command(
278 cx: &ConnectionTo<acp::Agent>,
279 event_tx: &mpsc::UnboundedSender<AcpEvent>,
280 cmd: PromptCommand,
281) {
282 match cmd {
283 PromptCommand::Cancel { session_id } => {
284 let _ = cx.send_notification(CancelNotification::new(session_id));
285 }
286 PromptCommand::SetConfigOption { session_id, config_id, value } => {
287 let req = SetSessionConfigOptionRequest::new(session_id.clone(), config_id, value);
288 match cx.send_request(req).block_task().await {
289 Ok(resp) => {
290 let update = ConfigOptionUpdate::new(resp.config_options);
291 let _ = event_tx.send(AcpEvent::SessionUpdate {
292 session_id,
293 update: Box::new(SessionUpdate::ConfigOptionUpdate(update)),
294 });
295 }
296 Err(e) => {
297 tracing::warn!("set_session_config_option failed: {e:?}");
298 }
299 }
300 }
301 PromptCommand::Prompt { .. } => {
302 tracing::warn!("ignoring duplicate Prompt while one is in-flight");
303 }
304 PromptCommand::ListSessions => {
305 tracing::warn!("ignoring ListSessions while prompt is in-flight");
306 }
307 PromptCommand::LoadSession { .. } => {
308 tracing::warn!("ignoring LoadSession while prompt is in-flight");
309 }
310 PromptCommand::NewSession { .. } => {
311 tracing::warn!("ignoring NewSession while prompt is in-flight");
312 }
313 PromptCommand::AuthenticateMcpServer { session_id, server_name } => {
314 let msg = McpRequest::Authenticate { session_id: session_id.0.to_string(), server_name };
315 if let Err(e) = cx.send_notification(msg) {
316 tracing::warn!("authenticate_mcp_server notification failed: {e:?}");
317 }
318 }
319 PromptCommand::Authenticate { method_id } => {
320 match cx.send_request(AuthenticateRequest::new(method_id.clone())).block_task().await {
321 Ok(_) => {
322 let _ = event_tx.send(AcpEvent::AuthenticateComplete { method_id });
323 }
324 Err(e) => {
325 tracing::warn!("authenticate failed: {e:?}");
326 let _ = event_tx.send(AcpEvent::AuthenticateFailed { method_id, error: format!("{e:?}") });
327 }
328 }
329 }
330 }
331}
332
333fn auto_approve_option(req: &RequestPermissionRequest) -> PermissionOptionId {
334 debug_assert!(!req.options.is_empty(), "ACP guarantees at least one permission option");
335 req.options
336 .iter()
337 .find(|o| matches!(o.kind, PermissionOptionKind::AllowOnce | PermissionOptionKind::AllowAlways))
338 .map_or_else(|| req.options[0].option_id.clone(), |o| o.option_id.clone())
339}