1use std::path::PathBuf;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use async_process::Child;
11use std::pin::pin;
12
13use crate::schema::v1::{EnvVariable, McpServer as SchemaMcpServer, McpServerStdio};
14use crate::{Client, Conductor, Role};
15
/// Tags a line handed to the debug callback with the child-process stream it
/// belongs to.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LineDirection {
    /// A line written to the child's stdin (an outgoing message).
    Stdin,
    /// A line read from the child's stdout (an incoming message).
    Stdout,
    /// A line read from the child's stderr (diagnostic output).
    Stderr,
}
26
27pub struct AcpAgent {
60 server: SchemaMcpServer,
61 debug_callback: Option<Arc<dyn Fn(&str, LineDirection) + Send + Sync + 'static>>,
62}
63
64impl std::fmt::Debug for AcpAgent {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("AcpAgent")
67 .field("server", &self.server)
68 .field(
69 "debug_callback",
70 &self.debug_callback.as_ref().map(|_| "..."),
71 )
72 .finish()
73 }
74}
75
76impl AcpAgent {
77 #[must_use]
79 pub fn new(server: SchemaMcpServer) -> Self {
80 Self {
81 server,
82 debug_callback: None,
83 }
84 }
85
86 #[must_use]
89 pub fn zed_claude_code() -> Self {
90 Self::from_str("npx -y @zed-industries/claude-code-acp@latest").expect("valid bash command")
91 }
92
93 #[must_use]
96 pub fn zed_codex() -> Self {
97 Self::from_str("npx -y @zed-industries/codex-acp@latest").expect("valid bash command")
98 }
99
100 #[must_use]
103 pub fn google_gemini() -> Self {
104 Self::from_str("npx -y -- @google/gemini-cli@latest --experimental-acp")
105 .expect("valid bash command")
106 }
107
108 #[must_use]
110 pub fn server(&self) -> &SchemaMcpServer {
111 &self.server
112 }
113
114 #[must_use]
116 pub fn into_server(self) -> SchemaMcpServer {
117 self.server
118 }
119
120 #[must_use]
137 pub fn with_debug<F>(mut self, callback: F) -> Self
138 where
139 F: Fn(&str, LineDirection) + Send + Sync + 'static,
140 {
141 self.debug_callback = Some(Arc::new(callback));
142 self
143 }
144
145 pub fn spawn_process(
148 &self,
149 ) -> Result<
150 (
151 async_process::ChildStdin,
152 async_process::ChildStdout,
153 async_process::ChildStderr,
154 Child,
155 ),
156 crate::Error,
157 > {
158 match &self.server {
159 SchemaMcpServer::Stdio(stdio) => {
160 let mut cmd = async_process::Command::new(&stdio.command);
161 cmd.args(&stdio.args);
162 for env_var in &stdio.env {
163 cmd.env(&env_var.name, &env_var.value);
164 }
165 cmd.stdin(std::process::Stdio::piped())
166 .stdout(std::process::Stdio::piped())
167 .stderr(std::process::Stdio::piped());
168
169 let mut child = cmd.spawn().map_err(crate::Error::into_internal_error)?;
170
171 let child_stdin = child
172 .stdin
173 .take()
174 .ok_or_else(|| crate::util::internal_error("Failed to open stdin"))?;
175 let child_stdout = child
176 .stdout
177 .take()
178 .ok_or_else(|| crate::util::internal_error("Failed to open stdout"))?;
179 let child_stderr = child
180 .stderr
181 .take()
182 .ok_or_else(|| crate::util::internal_error("Failed to open stderr"))?;
183
184 Ok((child_stdin, child_stdout, child_stderr, child))
185 }
186 SchemaMcpServer::Http(_) => Err(crate::util::internal_error(
187 "HTTP transport not yet supported by AcpAgent",
188 )),
189 SchemaMcpServer::Sse(_) => Err(crate::util::internal_error(
190 "SSE transport not yet supported by AcpAgent",
191 )),
192 _ => Err(crate::util::internal_error(
193 "Unknown MCP server transport type",
194 )),
195 }
196 }
197}
198
199struct ChildGuard(Child);
201
202impl ChildGuard {
203 async fn wait(&mut self) -> std::io::Result<std::process::ExitStatus> {
204 self.0.status().await
205 }
206}
207
208impl Drop for ChildGuard {
209 fn drop(&mut self) {
210 drop(self.0.kill());
211 }
212}
213
214async fn monitor_child(
219 child: Child,
220 stderr_rx: futures::channel::oneshot::Receiver<String>,
221) -> Result<(), crate::Error> {
222 let mut guard = ChildGuard(child);
223
224 let status = guard
225 .wait()
226 .await
227 .map_err(|e| crate::util::internal_error(format!("Failed to wait for process: {e}")))?;
228
229 if status.success() {
230 Ok(())
231 } else {
232 let stderr = stderr_rx.await.unwrap_or_default();
233
234 let message = if stderr.is_empty() {
235 format!("Process exited with {status}")
236 } else {
237 format!("Process exited with {status}: {stderr}")
238 };
239
240 Err(crate::util::internal_error(message))
241 }
242}
243
244pub trait AcpAgentCounterpartRole: Role {}
246
247impl AcpAgentCounterpartRole for Client {}
248
249impl AcpAgentCounterpartRole for Conductor {}
250
251impl<Counterpart: AcpAgentCounterpartRole> crate::ConnectTo<Counterpart> for AcpAgent {
252 async fn connect_to(
253 self,
254 client: impl crate::ConnectTo<Counterpart::Counterpart>,
255 ) -> Result<(), crate::Error> {
256 use futures::io::BufReader;
257 use futures::{AsyncBufReadExt, AsyncWriteExt, StreamExt};
258
259 let (child_stdin, child_stdout, child_stderr, child) = self.spawn_process()?;
260
261 let (stderr_tx, stderr_rx) = futures::channel::oneshot::channel::<String>();
263
264 let debug_callback = self.debug_callback.clone();
268 let stderr_future = async move {
269 let stderr_reader = BufReader::new(child_stderr);
270 let mut stderr_lines = stderr_reader.lines();
271 let mut collected = String::new();
272 while let Some(line_result) = stderr_lines.next().await {
273 if let Ok(line) = line_result {
274 if let Some(ref callback) = debug_callback {
275 callback(&line, LineDirection::Stderr);
276 }
277 if !collected.is_empty() {
278 collected.push('\n');
279 }
280 collected.push_str(&line);
281 }
282 }
283 drop(stderr_tx.send(collected));
284 };
285
286 let child_monitor = monitor_child(child, stderr_rx);
288
289 let incoming_lines: std::pin::Pin<
291 Box<dyn futures::Stream<Item = std::io::Result<String>> + Send>,
292 > = if let Some(callback) = self.debug_callback.clone() {
293 Box::pin(BufReader::new(child_stdout).lines().inspect(move |result| {
294 if let Ok(line) = result {
295 callback(line, LineDirection::Stdout);
296 }
297 }))
298 } else {
299 Box::pin(BufReader::new(child_stdout).lines())
300 };
301
302 let outgoing_sink: std::pin::Pin<
304 Box<dyn futures::Sink<String, Error = std::io::Error> + Send>,
305 > = if let Some(callback) = self.debug_callback.clone() {
306 Box::pin(futures::sink::unfold(
307 (child_stdin, callback),
308 async move |(mut writer, callback), line: String| {
309 callback(&line, LineDirection::Stdin);
310 let mut bytes = line.into_bytes();
311 bytes.push(b'\n');
312 writer.write_all(&bytes).await?;
313 Ok::<_, std::io::Error>((writer, callback))
314 },
315 ))
316 } else {
317 Box::pin(futures::sink::unfold(
318 child_stdin,
319 async move |mut writer, line: String| {
320 let mut bytes = line.into_bytes();
321 bytes.push(b'\n');
322 writer.write_all(&bytes).await?;
323 Ok::<_, std::io::Error>(writer)
324 },
325 ))
326 };
327
328 let protocol_future = crate::ConnectTo::<Counterpart>::connect_to(
331 crate::Lines::new(outgoing_sink, incoming_lines),
332 client,
333 );
334
335 let stderr_future = pin!(stderr_future);
336 let protocol_future = pin!(protocol_future);
337 let child_monitor = pin!(child_monitor);
338
339 let main_race = async {
341 match futures::future::select(protocol_future, child_monitor).await {
342 futures::future::Either::Left((result, _))
343 | futures::future::Either::Right((result, _)) => result,
344 }
345 };
346
347 let main_race = pin!(main_race);
350 match futures::future::select(main_race, stderr_future).await {
351 futures::future::Either::Left((result, _)) => result,
352 futures::future::Either::Right(((), protocol)) => protocol.await,
353 }
354 }
355}
356
357impl AcpAgent {
358 pub fn from_args<I, T>(args: I) -> Result<Self, crate::Error>
376 where
377 I: IntoIterator<Item = T>,
378 T: ToString,
379 {
380 let args: Vec<String> = args.into_iter().map(|s| s.to_string()).collect();
381
382 if args.is_empty() {
383 return Err(crate::util::internal_error("Arguments cannot be empty"));
384 }
385
386 let mut env = vec![];
387 let mut command_idx = 0;
388
389 for (i, arg) in args.iter().enumerate() {
390 if let Some((name, value)) = parse_env_var(arg) {
391 env.push(EnvVariable::new(name, value));
392 command_idx = i + 1;
393 } else {
394 break;
395 }
396 }
397
398 if command_idx >= args.len() {
399 return Err(crate::util::internal_error(
400 "No command found (only environment variables provided)",
401 ));
402 }
403
404 let command = PathBuf::from(&args[command_idx]);
405 let cmd_args = args[command_idx + 1..].to_vec();
406
407 let name = command
408 .file_name()
409 .and_then(|n| n.to_str())
410 .unwrap_or("agent")
411 .to_string();
412
413 Ok(AcpAgent {
414 server: SchemaMcpServer::Stdio(
415 McpServerStdio::new(name, command).args(cmd_args).env(env),
416 ),
417 debug_callback: None,
418 })
419 }
420}
421
/// Splits a shell-style `NAME=value` assignment into its name and value.
///
/// The split happens at the first `=`, so the value may itself contain `=`.
/// The name must be non-empty, start with an ASCII letter or `_`, and contain
/// only ASCII alphanumerics and `_`. Returns `None` for anything else,
/// including words with no `=` at all.
fn parse_env_var(s: &str) -> Option<(String, String)> {
    let (name, value) = s.split_once('=')?;
    let mut chars = name.chars();
    let head = chars.next()?;
    let valid = (head == '_' || head.is_ascii_alphabetic())
        && chars.all(|c| c == '_' || c.is_ascii_alphanumeric());
    valid.then(|| (name.to_owned(), value.to_owned()))
}
443
444impl FromStr for AcpAgent {
445 type Err = crate::Error;
446
447 fn from_str(s: &str) -> Result<Self, Self::Err> {
448 let trimmed = s.trim();
449
450 if trimmed.starts_with('{') {
451 let server: SchemaMcpServer = serde_json::from_str(trimmed)
452 .map_err(|e| crate::util::internal_error(format!("Failed to parse JSON: {e}")))?;
453 return Ok(Self {
454 server,
455 debug_callback: None,
456 });
457 }
458
459 let parts = shell_words::split(trimmed)
460 .map_err(|e| crate::util::internal_error(format!("Failed to parse command: {e}")))?;
461
462 Self::from_args(parts)
463 }
464}
465
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps the stdio configuration of `agent`, failing the test for any
    /// other transport.
    fn expect_stdio(agent: AcpAgent) -> McpServerStdio {
        match agent.into_server() {
            SchemaMcpServer::Stdio(stdio) => stdio,
            other => panic!("Expected Stdio variant, got {other:?}"),
        }
    }

    #[test]
    fn test_parse_simple_command() {
        let agent = AcpAgent::from_str("python agent.py").unwrap();
        match agent.server {
            SchemaMcpServer::Stdio(stdio) => {
                assert_eq!(stdio.name, "python");
                assert_eq!(stdio.command, PathBuf::from("python"));
                assert_eq!(stdio.args, vec!["agent.py"]);
                assert!(stdio.env.is_empty());
            }
            _ => panic!("Expected Stdio variant"),
        }
    }

    #[test]
    fn test_parse_command_with_args() {
        let agent = AcpAgent::from_str("node server.js --port 8080 --verbose").unwrap();
        match agent.server {
            SchemaMcpServer::Stdio(stdio) => {
                assert_eq!(stdio.name, "node");
                assert_eq!(stdio.command, PathBuf::from("node"));
                assert_eq!(stdio.args, vec!["server.js", "--port", "8080", "--verbose"]);
                assert!(stdio.env.is_empty());
            }
            _ => panic!("Expected Stdio variant"),
        }
    }

    #[test]
    fn test_parse_command_with_quotes() {
        let agent = AcpAgent::from_str(r#"python "my agent.py" --name "Test Agent""#).unwrap();
        match agent.server {
            SchemaMcpServer::Stdio(stdio) => {
                assert_eq!(stdio.name, "python");
                assert_eq!(stdio.command, PathBuf::from("python"));
                assert_eq!(stdio.args, vec!["my agent.py", "--name", "Test Agent"]);
                assert!(stdio.env.is_empty());
            }
            _ => panic!("Expected Stdio variant"),
        }
    }

    #[test]
    fn test_parse_env_prefix() {
        let stdio = expect_stdio(AcpAgent::from_str("FOO=1 _BAR=a=b python agent.py").unwrap());
        assert_eq!(stdio.name, "python");
        assert_eq!(stdio.command, PathBuf::from("python"));
        assert_eq!(stdio.args, vec!["agent.py"]);
        assert_eq!(stdio.env.len(), 2);
        assert_eq!(stdio.env[0].name, "FOO");
        assert_eq!(stdio.env[0].value, "1");
        assert_eq!(stdio.env[1].name, "_BAR");
        assert_eq!(stdio.env[1].value, "a=b");
    }

    #[test]
    fn test_env_like_words_after_command_are_args() {
        let stdio = expect_stdio(AcpAgent::from_str("python FOO=1").unwrap());
        assert_eq!(stdio.args, vec!["FOO=1"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_name_is_command_file_name() {
        let stdio = expect_stdio(AcpAgent::from_str("/usr/local/bin/node server.js").unwrap());
        assert_eq!(stdio.name, "node");
        assert_eq!(stdio.command, PathBuf::from("/usr/local/bin/node"));
        assert_eq!(stdio.args, vec!["server.js"]);
    }

    #[test]
    fn test_from_args() {
        let stdio = expect_stdio(AcpAgent::from_args(["X=1", "python", "agent.py"]).unwrap());
        assert_eq!(stdio.command, PathBuf::from("python"));
        assert_eq!(stdio.args, vec!["agent.py"]);
        assert_eq!(stdio.env.len(), 1);
        assert_eq!(stdio.env[0].name, "X");
        assert_eq!(stdio.env[0].value, "1");
        assert!(AcpAgent::from_args(Vec::<String>::new()).is_err());
    }

    #[test]
    fn test_parse_errors() {
        assert!(AcpAgent::from_str("").is_err());
        assert!(AcpAgent::from_str("   ").is_err());
        assert!(AcpAgent::from_str("FOO=1 BAR=2").is_err());
        assert!(AcpAgent::from_str("python 'unterminated").is_err());
        assert!(AcpAgent::from_str("{ not json").is_err());
    }

    #[test]
    fn test_parse_env_var() {
        assert_eq!(parse_env_var("FOO=bar"), Some(("FOO".to_string(), "bar".to_string())));
        assert_eq!(parse_env_var("A=b=c"), Some(("A".to_string(), "b=c".to_string())));
        assert_eq!(parse_env_var("_x1="), Some(("_x1".to_string(), String::new())));
        assert_eq!(parse_env_var("=value"), None);
        assert_eq!(parse_env_var("1A=b"), None);
        assert_eq!(parse_env_var("A-B=c"), None);
        assert_eq!(parse_env_var("--flag=value"), None);
        assert_eq!(parse_env_var("no_equals"), None);
    }

    #[test]
    fn test_debug_hides_callback() {
        let plain = AcpAgent::from_str("python").unwrap();
        assert!(format!("{plain:?}").contains("debug_callback: None"));
        let tapped = plain.with_debug(|_, _| {});
        assert!(format!("{tapped:?}").contains(r#"debug_callback: Some("...")"#));
    }

    #[test]
    fn test_spawn_rejects_http_transport() {
        let json = r#"{"type":"http","name":"r","url":"https://example.com","headers":[]}"#;
        let agent = AcpAgent::from_str(json).unwrap();
        assert!(agent.spawn_process().is_err());
    }

    #[test]
    fn test_parse_json_stdio() {
        let json = r#"{
            "type": "stdio",
            "name": "my-agent",
            "command": "/usr/bin/python",
            "args": ["agent.py", "--verbose"],
            "env": []
        }"#;
        let agent = AcpAgent::from_str(json).unwrap();
        match agent.server {
            SchemaMcpServer::Stdio(stdio) => {
                assert_eq!(stdio.name, "my-agent");
                assert_eq!(stdio.command, PathBuf::from("/usr/bin/python"));
                assert_eq!(stdio.args, vec!["agent.py", "--verbose"]);
                assert!(stdio.env.is_empty());
            }
            _ => panic!("Expected Stdio variant"),
        }
    }

    #[test]
    fn test_parse_json_http() {
        let json = r#"{
            "type": "http",
            "name": "remote-agent",
            "url": "https://example.com/agent",
            "headers": []
        }"#;
        let agent = AcpAgent::from_str(json).unwrap();
        match agent.server {
            SchemaMcpServer::Http(http) => {
                assert_eq!(http.name, "remote-agent");
                assert_eq!(http.url, "https://example.com/agent");
                assert!(http.headers.is_empty());
            }
            _ => panic!("Expected Http variant"),
        }
    }
}