1use std::path::PathBuf;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use async_process::Child;
11use std::pin::pin;
12
13use crate::schema::v1::{EnvVariable, McpServer as SchemaMcpServer, McpServerStdio};
14use crate::{Client, Conductor, Role};
15
/// Identifies which stdio stream of the agent subprocess a line travelled over.
///
/// This is the second argument handed to the callback registered with
/// [`AcpAgent::with_debug`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LineDirection {
    /// A line written to the agent's stdin (an outgoing message).
    Stdin,
    /// A line read from the agent's stdout (an incoming message).
    Stdout,
    /// A line read from the agent's stderr (diagnostic output).
    Stderr,
}
26
27pub struct AcpAgent {
60 server: SchemaMcpServer,
61 debug_callback: Option<Arc<dyn Fn(&str, LineDirection) + Send + Sync + 'static>>,
62}
63
64impl std::fmt::Debug for AcpAgent {
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("AcpAgent")
67 .field("server", &self.server)
68 .field(
69 "debug_callback",
70 &self.debug_callback.as_ref().map(|_| "..."),
71 )
72 .finish()
73 }
74}
75
76impl AcpAgent {
77 #[must_use]
79 pub fn new(server: SchemaMcpServer) -> Self {
80 Self {
81 server,
82 debug_callback: None,
83 }
84 }
85
86 #[must_use]
89 pub fn zed_claude_code() -> Self {
90 Self::from_str("npx -y @zed-industries/claude-code-acp@latest").expect("valid bash command")
91 }
92
93 #[must_use]
96 pub fn zed_codex() -> Self {
97 Self::from_str("npx -y @zed-industries/codex-acp@latest").expect("valid bash command")
98 }
99
100 #[must_use]
103 pub fn google_gemini() -> Self {
104 Self::from_str("npx -y -- @google/gemini-cli@latest --experimental-acp")
105 .expect("valid bash command")
106 }
107
108 #[must_use]
110 pub fn server(&self) -> &SchemaMcpServer {
111 &self.server
112 }
113
114 #[must_use]
116 pub fn into_server(self) -> SchemaMcpServer {
117 self.server
118 }
119
120 #[must_use]
137 pub fn with_debug<F>(mut self, callback: F) -> Self
138 where
139 F: Fn(&str, LineDirection) + Send + Sync + 'static,
140 {
141 self.debug_callback = Some(Arc::new(callback));
142 self
143 }
144
145 pub fn spawn_process(
148 &self,
149 ) -> Result<
150 (
151 async_process::ChildStdin,
152 async_process::ChildStdout,
153 async_process::ChildStderr,
154 Child,
155 ),
156 crate::Error,
157 > {
158 match &self.server {
159 SchemaMcpServer::Stdio(stdio) => {
160 let mut cmd = async_process::Command::new(&stdio.command);
161 cmd.args(&stdio.args);
162 for env_var in &stdio.env {
163 cmd.env(&env_var.name, &env_var.value);
164 }
165 #[cfg(windows)]
166 {
167 use async_process::windows::CommandExt as _;
168
169 cmd.creation_flags(windows_sys::Win32::System::Threading::CREATE_NO_WINDOW);
170 }
171 cmd.stdin(std::process::Stdio::piped())
172 .stdout(std::process::Stdio::piped())
173 .stderr(std::process::Stdio::piped());
174
175 let mut child = cmd.spawn().map_err(crate::Error::into_internal_error)?;
176
177 let child_stdin = child
178 .stdin
179 .take()
180 .ok_or_else(|| crate::util::internal_error("Failed to open stdin"))?;
181 let child_stdout = child
182 .stdout
183 .take()
184 .ok_or_else(|| crate::util::internal_error("Failed to open stdout"))?;
185 let child_stderr = child
186 .stderr
187 .take()
188 .ok_or_else(|| crate::util::internal_error("Failed to open stderr"))?;
189
190 Ok((child_stdin, child_stdout, child_stderr, child))
191 }
192 SchemaMcpServer::Http(_) => Err(crate::util::internal_error(
193 "HTTP transport not yet supported by AcpAgent",
194 )),
195 SchemaMcpServer::Sse(_) => Err(crate::util::internal_error(
196 "SSE transport not yet supported by AcpAgent",
197 )),
198 _ => Err(crate::util::internal_error(
199 "Unknown MCP server transport type",
200 )),
201 }
202 }
203}
204
205struct ChildGuard(Child);
207
208impl ChildGuard {
209 async fn wait(&mut self) -> std::io::Result<std::process::ExitStatus> {
210 self.0.status().await
211 }
212}
213
214impl Drop for ChildGuard {
215 fn drop(&mut self) {
216 drop(self.0.kill());
217 }
218}
219
220async fn monitor_child(
225 child: Child,
226 stderr_rx: futures::channel::oneshot::Receiver<String>,
227) -> Result<(), crate::Error> {
228 let mut guard = ChildGuard(child);
229
230 let status = guard
231 .wait()
232 .await
233 .map_err(|e| crate::util::internal_error(format!("Failed to wait for process: {e}")))?;
234
235 if status.success() {
236 Ok(())
237 } else {
238 let stderr = stderr_rx.await.unwrap_or_default();
239
240 let message = if stderr.is_empty() {
241 format!("Process exited with {status}")
242 } else {
243 format!("Process exited with {status}: {stderr}")
244 };
245
246 Err(crate::util::internal_error(message))
247 }
248}
249
250pub trait AcpAgentCounterpartRole: Role {}
252
253impl AcpAgentCounterpartRole for Client {}
254
255impl AcpAgentCounterpartRole for Conductor {}
256
257impl<Counterpart: AcpAgentCounterpartRole> crate::ConnectTo<Counterpart> for AcpAgent {
258 async fn connect_to(
259 self,
260 client: impl crate::ConnectTo<Counterpart::Counterpart>,
261 ) -> Result<(), crate::Error> {
262 use futures::io::BufReader;
263 use futures::{AsyncBufReadExt, AsyncWriteExt, StreamExt};
264
265 let (child_stdin, child_stdout, child_stderr, child) = self.spawn_process()?;
266
267 let (stderr_tx, stderr_rx) = futures::channel::oneshot::channel::<String>();
269
270 let debug_callback = self.debug_callback.clone();
274 let stderr_future = async move {
275 let stderr_reader = BufReader::new(child_stderr);
276 let mut stderr_lines = stderr_reader.lines();
277 let mut collected = String::new();
278 while let Some(line_result) = stderr_lines.next().await {
279 if let Ok(line) = line_result {
280 if let Some(ref callback) = debug_callback {
281 callback(&line, LineDirection::Stderr);
282 }
283 if !collected.is_empty() {
284 collected.push('\n');
285 }
286 collected.push_str(&line);
287 }
288 }
289 drop(stderr_tx.send(collected));
290 };
291
292 let child_monitor = monitor_child(child, stderr_rx);
294
295 let incoming_lines: std::pin::Pin<
297 Box<dyn futures::Stream<Item = std::io::Result<String>> + Send>,
298 > = if let Some(callback) = self.debug_callback.clone() {
299 Box::pin(BufReader::new(child_stdout).lines().inspect(move |result| {
300 if let Ok(line) = result {
301 callback(line, LineDirection::Stdout);
302 }
303 }))
304 } else {
305 Box::pin(BufReader::new(child_stdout).lines())
306 };
307
308 let outgoing_sink: std::pin::Pin<
310 Box<dyn futures::Sink<String, Error = std::io::Error> + Send>,
311 > = if let Some(callback) = self.debug_callback.clone() {
312 Box::pin(futures::sink::unfold(
313 (child_stdin, callback),
314 async move |(mut writer, callback), line: String| {
315 callback(&line, LineDirection::Stdin);
316 let mut bytes = line.into_bytes();
317 bytes.push(b'\n');
318 writer.write_all(&bytes).await?;
319 Ok::<_, std::io::Error>((writer, callback))
320 },
321 ))
322 } else {
323 Box::pin(futures::sink::unfold(
324 child_stdin,
325 async move |mut writer, line: String| {
326 let mut bytes = line.into_bytes();
327 bytes.push(b'\n');
328 writer.write_all(&bytes).await?;
329 Ok::<_, std::io::Error>(writer)
330 },
331 ))
332 };
333
334 let protocol_future = crate::ConnectTo::<Counterpart>::connect_to(
337 crate::Lines::new(outgoing_sink, incoming_lines),
338 client,
339 );
340
341 let stderr_future = pin!(stderr_future);
342 let protocol_future = pin!(protocol_future);
343 let child_monitor = pin!(child_monitor);
344
345 let main_race = async {
347 match futures::future::select(protocol_future, child_monitor).await {
348 futures::future::Either::Left((result, _))
349 | futures::future::Either::Right((result, _)) => result,
350 }
351 };
352
353 let main_race = pin!(main_race);
356 match futures::future::select(main_race, stderr_future).await {
357 futures::future::Either::Left((result, _)) => result,
358 futures::future::Either::Right(((), protocol)) => protocol.await,
359 }
360 }
361}
362
363impl AcpAgent {
364 pub fn from_args<I, T>(args: I) -> Result<Self, crate::Error>
382 where
383 I: IntoIterator<Item = T>,
384 T: ToString,
385 {
386 let args: Vec<String> = args.into_iter().map(|s| s.to_string()).collect();
387
388 if args.is_empty() {
389 return Err(crate::util::internal_error("Arguments cannot be empty"));
390 }
391
392 let mut env = vec![];
393 let mut command_idx = 0;
394
395 for (i, arg) in args.iter().enumerate() {
396 if let Some((name, value)) = parse_env_var(arg) {
397 env.push(EnvVariable::new(name, value));
398 command_idx = i + 1;
399 } else {
400 break;
401 }
402 }
403
404 if command_idx >= args.len() {
405 return Err(crate::util::internal_error(
406 "No command found (only environment variables provided)",
407 ));
408 }
409
410 let command = PathBuf::from(&args[command_idx]);
411 let cmd_args = args[command_idx + 1..].to_vec();
412
413 let name = command
414 .file_name()
415 .and_then(|n| n.to_str())
416 .unwrap_or("agent")
417 .to_string();
418
419 Ok(AcpAgent {
420 server: SchemaMcpServer::Stdio(
421 McpServerStdio::new(name, command).args(cmd_args).env(env),
422 ),
423 debug_callback: None,
424 })
425 }
426}
427
/// Splits a shell-style `NAME=value` assignment into its name and value.
///
/// The split happens at the first `=`, so the value may itself contain `=` or be empty.
/// Returns `None` when there is no `=`, or when the name is empty or is not a portable
/// identifier (an ASCII letter or `_`, followed by ASCII letters, digits, or `_`).
fn parse_env_var(s: &str) -> Option<(String, String)> {
    let (name, value) = s.split_once('=')?;

    let mut chars = name.chars();
    // An empty name (input starting with `=`) yields `None` here.
    let leading = chars.next()?;
    let leading_ok = leading == '_' || leading.is_ascii_alphabetic();
    let rest_ok = chars.all(|c| c == '_' || c.is_ascii_alphanumeric());

    (leading_ok && rest_ok).then(|| (name.to_owned(), value.to_owned()))
}
449
450impl FromStr for AcpAgent {
451 type Err = crate::Error;
452
453 fn from_str(s: &str) -> Result<Self, Self::Err> {
454 let trimmed = s.trim();
455
456 if trimmed.starts_with('{') {
457 let server: SchemaMcpServer = serde_json::from_str(trimmed)
458 .map_err(|e| crate::util::internal_error(format!("Failed to parse JSON: {e}")))?;
459 return Ok(Self {
460 server,
461 debug_callback: None,
462 });
463 }
464
465 let parts = shell_words::split(trimmed)
466 .map_err(|e| crate::util::internal_error(format!("Failed to parse command: {e}")))?;
467
468 Self::from_args(parts)
469 }
470}
471
#[cfg(test)]
mod tests {
    use super::*;

    /// Extracts the stdio configuration of a parsed agent, failing the test otherwise.
    fn expect_stdio(agent: AcpAgent) -> McpServerStdio {
        match agent.server {
            SchemaMcpServer::Stdio(stdio) => stdio,
            _ => panic!("Expected Stdio variant"),
        }
    }

    #[test]
    fn test_parse_simple_command() {
        let stdio = expect_stdio(AcpAgent::from_str("python agent.py").unwrap());
        assert_eq!(stdio.name, "python");
        assert_eq!(stdio.command, PathBuf::from("python"));
        assert_eq!(stdio.args, vec!["agent.py"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_parse_command_with_args() {
        let stdio = expect_stdio(AcpAgent::from_str("node server.js --port 8080 --verbose").unwrap());
        assert_eq!(stdio.name, "node");
        assert_eq!(stdio.command, PathBuf::from("node"));
        assert_eq!(stdio.args, vec!["server.js", "--port", "8080", "--verbose"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_parse_command_with_quotes() {
        let stdio = expect_stdio(
            AcpAgent::from_str(r#"python "my agent.py" --name "Test Agent""#).unwrap(),
        );
        assert_eq!(stdio.name, "python");
        assert_eq!(stdio.command, PathBuf::from("python"));
        assert_eq!(stdio.args, vec!["my agent.py", "--name", "Test Agent"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_parse_leading_env_vars() {
        let stdio = expect_stdio(
            AcpAgent::from_str("FOO=bar EMPTY= URL=a=b python agent.py").unwrap(),
        );
        assert_eq!(stdio.name, "python");
        assert_eq!(stdio.command, PathBuf::from("python"));
        assert_eq!(stdio.args, vec!["agent.py"]);
        assert_eq!(stdio.env.len(), 3);
        assert_eq!(stdio.env[0].name, "FOO");
        assert_eq!(stdio.env[0].value, "bar");
        assert_eq!(stdio.env[1].name, "EMPTY");
        assert_eq!(stdio.env[1].value, "");
        assert_eq!(stdio.env[2].name, "URL");
        assert_eq!(stdio.env[2].value, "a=b");
    }

    #[test]
    fn test_env_like_args_after_command_are_args() {
        let stdio = expect_stdio(AcpAgent::from_str("python FOO=bar").unwrap());
        assert_eq!(stdio.args, vec!["FOO=bar"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_name_is_derived_from_command_file_name() {
        let stdio = expect_stdio(AcpAgent::from_str("/usr/local/bin/node server.js").unwrap());
        assert_eq!(stdio.name, "node");
        assert_eq!(stdio.command, PathBuf::from("/usr/local/bin/node"));
        assert_eq!(stdio.args, vec!["server.js"]);
    }

    #[test]
    fn test_rejects_invalid_input() {
        assert!(AcpAgent::from_str("").is_err());
        assert!(AcpAgent::from_str("   ").is_err());
        assert!(AcpAgent::from_str("FOO=bar BAZ=qux").is_err());
        assert!(AcpAgent::from_str(r#"python "unterminated"#).is_err());
        assert!(AcpAgent::from_str("{ not json").is_err());
        assert!(AcpAgent::from_args(Vec::<String>::new()).is_err());
    }

    #[test]
    fn test_parse_env_var() {
        assert_eq!(parse_env_var("FOO=bar"), Some(("FOO".to_string(), "bar".to_string())));
        assert_eq!(parse_env_var("_X1="), Some(("_X1".to_string(), String::new())));
        assert_eq!(parse_env_var("A=b=c"), Some(("A".to_string(), "b=c".to_string())));
        assert_eq!(parse_env_var("=value"), None);
        assert_eq!(parse_env_var("1FOO=bar"), None);
        assert_eq!(parse_env_var("FOO-BAR=baz"), None);
        assert_eq!(parse_env_var("python"), None);
    }

    #[test]
    fn test_debug_hides_callback() {
        let agent = AcpAgent::from_str("python agent.py").unwrap();
        assert!(format!("{agent:?}").contains("debug_callback: None"));
        let agent = agent.with_debug(|_, _| {});
        assert!(format!("{agent:?}").contains(r#"debug_callback: Some("...")"#));
    }

    #[test]
    fn test_parse_json_stdio() {
        let json = r#"{
            "type": "stdio",
            "name": "my-agent",
            "command": "/usr/bin/python",
            "args": ["agent.py", "--verbose"],
            "env": []
        }"#;
        let stdio = expect_stdio(AcpAgent::from_str(json).unwrap());
        assert_eq!(stdio.name, "my-agent");
        assert_eq!(stdio.command, PathBuf::from("/usr/bin/python"));
        assert_eq!(stdio.args, vec!["agent.py", "--verbose"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_parse_json_http() {
        let json = r#"{
            "type": "http",
            "name": "remote-agent",
            "url": "https://example.com/agent",
            "headers": []
        }"#;
        let agent = AcpAgent::from_str(json).unwrap();
        match agent.server {
            SchemaMcpServer::Http(http) => {
                assert_eq!(http.name, "remote-agent");
                assert_eq!(http.url, "https://example.com/agent");
                assert!(http.headers.is_empty());
            }
            _ => panic!("Expected Http variant"),
        }
    }
}