agent_client_protocol_tokio/
acp_agent.rs1use std::path::PathBuf;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use agent_client_protocol::{Client, Conductor, Role};
11use tokio::process::Child;
12use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
13
/// Identifies which standard stream of the agent process a debug line belongs to.
///
/// Passed to the callback registered with `AcpAgent::with_debug`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum LineDirection {
    /// A line written to the agent's stdin (sent to the agent).
    Stdin,
    /// A line read from the agent's stdout (received from the agent).
    Stdout,
    /// A line read from the agent's stderr.
    Stderr,
}
24
25pub struct AcpAgent {
82 server: agent_client_protocol::schema::McpServer,
83 debug_callback: Option<Arc<dyn Fn(&str, LineDirection) + Send + Sync + 'static>>,
84}
85
86impl std::fmt::Debug for AcpAgent {
87 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88 f.debug_struct("AcpAgent")
89 .field("server", &self.server)
90 .field(
91 "debug_callback",
92 &self.debug_callback.as_ref().map(|_| "..."),
93 )
94 .finish()
95 }
96}
97
98impl AcpAgent {
99 #[must_use]
101 pub fn new(server: agent_client_protocol::schema::McpServer) -> Self {
102 Self {
103 server,
104 debug_callback: None,
105 }
106 }
107
108 #[must_use]
111 pub fn zed_claude_code() -> Self {
112 Self::from_str("npx -y @zed-industries/claude-code-acp@latest").expect("valid bash command")
113 }
114
115 #[must_use]
118 pub fn zed_codex() -> Self {
119 Self::from_str("npx -y @zed-industries/codex-acp@latest").expect("valid bash command")
120 }
121
122 #[must_use]
125 pub fn google_gemini() -> Self {
126 Self::from_str("npx -y -- @google/gemini-cli@latest --experimental-acp")
127 .expect("valid bash command")
128 }
129
130 #[must_use]
132 pub fn server(&self) -> &agent_client_protocol::schema::McpServer {
133 &self.server
134 }
135
136 #[must_use]
138 pub fn into_server(self) -> agent_client_protocol::schema::McpServer {
139 self.server
140 }
141
142 #[must_use]
159 pub fn with_debug<F>(mut self, callback: F) -> Self
160 where
161 F: Fn(&str, LineDirection) + Send + Sync + 'static,
162 {
163 self.debug_callback = Some(Arc::new(callback));
164 self
165 }
166
167 pub fn spawn_process(
170 &self,
171 ) -> Result<
172 (
173 tokio::process::ChildStdin,
174 tokio::process::ChildStdout,
175 tokio::process::ChildStderr,
176 Child,
177 ),
178 agent_client_protocol::Error,
179 > {
180 match &self.server {
181 agent_client_protocol::schema::McpServer::Stdio(stdio) => {
182 let mut cmd = tokio::process::Command::new(&stdio.command);
183 cmd.args(&stdio.args);
184 for env_var in &stdio.env {
185 cmd.env(&env_var.name, &env_var.value);
186 }
187 cmd.stdin(std::process::Stdio::piped())
188 .stdout(std::process::Stdio::piped())
189 .stderr(std::process::Stdio::piped());
190
191 let mut child = cmd
192 .spawn()
193 .map_err(agent_client_protocol::Error::into_internal_error)?;
194
195 let child_stdin = child.stdin.take().ok_or_else(|| {
196 agent_client_protocol::util::internal_error("Failed to open stdin")
197 })?;
198 let child_stdout = child.stdout.take().ok_or_else(|| {
199 agent_client_protocol::util::internal_error("Failed to open stdout")
200 })?;
201 let child_stderr = child.stderr.take().ok_or_else(|| {
202 agent_client_protocol::util::internal_error("Failed to open stderr")
203 })?;
204
205 Ok((child_stdin, child_stdout, child_stderr, child))
206 }
207 agent_client_protocol::schema::McpServer::Http(_) => {
208 Err(agent_client_protocol::util::internal_error(
209 "HTTP transport not yet supported by AcpAgent",
210 ))
211 }
212 agent_client_protocol::schema::McpServer::Sse(_) => {
213 Err(agent_client_protocol::util::internal_error(
214 "SSE transport not yet supported by AcpAgent",
215 ))
216 }
217 _ => Err(agent_client_protocol::util::internal_error(
218 "Unknown MCP server transport type",
219 )),
220 }
221 }
222}
223
224struct ChildGuard(Child);
226
227impl ChildGuard {
228 async fn wait(&mut self) -> std::io::Result<std::process::ExitStatus> {
229 self.0.wait().await
230 }
231}
232
233impl Drop for ChildGuard {
234 fn drop(&mut self) {
235 drop(self.0.start_kill());
236 }
237}
238
239async fn monitor_child(
244 child: Child,
245 stderr_rx: tokio::sync::oneshot::Receiver<String>,
246) -> Result<(), agent_client_protocol::Error> {
247 let mut guard = ChildGuard(child);
248
249 let status = guard.wait().await.map_err(|e| {
251 agent_client_protocol::util::internal_error(format!("Failed to wait for process: {e}"))
252 })?;
253
254 if status.success() {
255 Ok(())
256 } else {
257 let stderr = stderr_rx.await.unwrap_or_default();
259
260 let message = if stderr.is_empty() {
261 format!("Process exited with {status}")
262 } else {
263 format!("Process exited with {status}: {stderr}")
264 };
265
266 Err(agent_client_protocol::util::internal_error(message))
267 }
268}
269
/// Marker for roles that may sit on the other side of an [`AcpAgent`] connection.
///
/// [`AcpAgent`] implements `ConnectTo<Counterpart>` for every `Counterpart`
/// implementing this trait. In this file those are [`Client`] and [`Conductor`].
pub trait AcpAgentCounterpartRole: Role {}

impl AcpAgentCounterpartRole for Client {}

impl AcpAgentCounterpartRole for Conductor {}
276
277impl<Counterpart: AcpAgentCounterpartRole> agent_client_protocol::ConnectTo<Counterpart>
278 for AcpAgent
279{
280 async fn connect_to(
281 self,
282 client: impl agent_client_protocol::ConnectTo<Counterpart::Counterpart>,
283 ) -> Result<(), agent_client_protocol::Error> {
284 use futures::AsyncBufReadExt;
285 use futures::AsyncWriteExt;
286 use futures::StreamExt;
287 use futures::io::BufReader;
288
289 let (child_stdin, child_stdout, child_stderr, child) = self.spawn_process()?;
290
291 let (stderr_tx, stderr_rx) = tokio::sync::oneshot::channel::<String>();
293
294 let debug_callback = self.debug_callback.clone();
296 tokio::spawn(async move {
297 let stderr_reader = BufReader::new(child_stderr.compat());
298 let mut stderr_lines = stderr_reader.lines();
299 let mut collected = String::new();
300 while let Some(line_result) = stderr_lines.next().await {
301 if let Ok(line) = line_result {
302 if let Some(ref callback) = debug_callback {
304 callback(&line, LineDirection::Stderr);
305 }
306 if !collected.is_empty() {
308 collected.push('\n');
309 }
310 collected.push_str(&line);
311 }
312 }
313 drop(stderr_tx.send(collected));
314 });
315
316 let child_monitor = monitor_child(child, stderr_rx);
318
319 let incoming_lines = if let Some(callback) = self.debug_callback.clone() {
321 Box::pin(
322 BufReader::new(child_stdout.compat())
323 .lines()
324 .inspect(move |result| {
325 if let Ok(line) = result {
326 callback(line, LineDirection::Stdout);
327 }
328 }),
329 )
330 as std::pin::Pin<Box<dyn futures::Stream<Item = std::io::Result<String>> + Send>>
331 } else {
332 Box::pin(BufReader::new(child_stdout.compat()).lines())
333 };
334
335 let outgoing_sink = if let Some(callback) = self.debug_callback.clone() {
337 Box::pin(futures::sink::unfold(
338 (child_stdin.compat_write(), callback),
339 async move |(mut writer, callback), line: String| {
340 callback(&line, LineDirection::Stdin);
341 let mut bytes = line.into_bytes();
342 bytes.push(b'\n');
343 writer.write_all(&bytes).await?;
344 Ok::<_, std::io::Error>((writer, callback))
345 },
346 ))
347 as std::pin::Pin<Box<dyn futures::Sink<String, Error = std::io::Error> + Send>>
348 } else {
349 Box::pin(futures::sink::unfold(
350 child_stdin.compat_write(),
351 async move |mut writer, line: String| {
352 let mut bytes = line.into_bytes();
353 bytes.push(b'\n');
354 writer.write_all(&bytes).await?;
355 Ok::<_, std::io::Error>(writer)
356 },
357 ))
358 };
359
360 let protocol_future = agent_client_protocol::ConnectTo::<Counterpart>::connect_to(
363 agent_client_protocol::Lines::new(outgoing_sink, incoming_lines),
364 client,
365 );
366
367 tokio::select! {
368 result = protocol_future => result,
369 result = child_monitor => result,
370 }
371 }
372}
373
374impl AcpAgent {
375 pub fn from_args<I, T>(args: I) -> Result<Self, agent_client_protocol::Error>
393 where
394 I: IntoIterator<Item = T>,
395 T: ToString,
396 {
397 let args: Vec<String> = args.into_iter().map(|s| s.to_string()).collect();
398
399 if args.is_empty() {
400 return Err(agent_client_protocol::util::internal_error(
401 "Arguments cannot be empty",
402 ));
403 }
404
405 let mut env = vec![];
406 let mut command_idx = 0;
407
408 for (i, arg) in args.iter().enumerate() {
410 if let Some((name, value)) = parse_env_var(arg) {
411 env.push(agent_client_protocol::schema::EnvVariable::new(name, value));
412 command_idx = i + 1;
413 } else {
414 break;
415 }
416 }
417
418 if command_idx >= args.len() {
419 return Err(agent_client_protocol::util::internal_error(
420 "No command found (only environment variables provided)",
421 ));
422 }
423
424 let command = PathBuf::from(&args[command_idx]);
425 let cmd_args = args[command_idx + 1..].to_vec();
426
427 let name = command
429 .file_name()
430 .and_then(|n| n.to_str())
431 .unwrap_or("agent")
432 .to_string();
433
434 Ok(AcpAgent {
435 server: agent_client_protocol::schema::McpServer::Stdio(
436 agent_client_protocol::schema::McpServerStdio::new(name, command)
437 .args(cmd_args)
438 .env(env),
439 ),
440 debug_callback: None,
441 })
442 }
443}
444
/// Interprets `s` as a shell-style `NAME=value` assignment.
///
/// `NAME` must be non-empty, start with an ASCII letter or `_`, and contain
/// only ASCII alphanumerics and `_`. Everything after the first `=` is the
/// value, which may be empty or contain further `=` characters. Returns
/// `None` when `s` is not such an assignment.
fn parse_env_var(s: &str) -> Option<(String, String)> {
    let (name, value) = s.split_once('=')?;

    let mut name_chars = name.chars();
    // An empty name (input starting with `=`) has no first character.
    let head = name_chars.next()?;
    let head_ok = head == '_' || head.is_ascii_alphabetic();
    let tail_ok = name_chars.all(|c| c == '_' || c.is_ascii_alphanumeric());

    (head_ok && tail_ok).then(|| (name.to_owned(), value.to_owned()))
}
470
471impl FromStr for AcpAgent {
472 type Err = agent_client_protocol::Error;
473
474 fn from_str(s: &str) -> Result<Self, Self::Err> {
475 let trimmed = s.trim();
476
477 if trimmed.starts_with('{') {
479 let server: agent_client_protocol::schema::McpServer = serde_json::from_str(trimmed)
480 .map_err(|e| {
481 agent_client_protocol::util::internal_error(format!(
482 "Failed to parse JSON: {e}"
483 ))
484 })?;
485 return Ok(Self {
486 server,
487 debug_callback: None,
488 });
489 }
490
491 let parts = shell_words::split(trimmed).map_err(|e| {
493 agent_client_protocol::util::internal_error(format!("Failed to parse command: {e}"))
494 })?;
495
496 Self::from_args(parts)
497 }
498}
499
#[cfg(test)]
mod tests {
    use super::*;

    /// Unwraps the stdio configuration of `agent`, failing the test for any other transport.
    fn expect_stdio(agent: AcpAgent) -> agent_client_protocol::schema::McpServerStdio {
        match agent.into_server() {
            agent_client_protocol::schema::McpServer::Stdio(stdio) => stdio,
            _ => panic!("Expected Stdio variant"),
        }
    }

    #[test]
    fn test_parse_simple_command() {
        let stdio = expect_stdio(AcpAgent::from_str("python agent.py").unwrap());
        assert_eq!(stdio.name, "python");
        assert_eq!(stdio.command, PathBuf::from("python"));
        assert_eq!(stdio.args, vec!["agent.py"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_parse_command_with_args() {
        let stdio =
            expect_stdio(AcpAgent::from_str("node server.js --port 8080 --verbose").unwrap());
        assert_eq!(stdio.name, "node");
        assert_eq!(stdio.command, PathBuf::from("node"));
        assert_eq!(stdio.args, vec!["server.js", "--port", "8080", "--verbose"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_parse_command_with_quotes() {
        let stdio = expect_stdio(
            AcpAgent::from_str(r#"python "my agent.py" --name "Test Agent""#).unwrap(),
        );
        assert_eq!(stdio.name, "python");
        assert_eq!(stdio.command, PathBuf::from("python"));
        assert_eq!(stdio.args, vec!["my agent.py", "--name", "Test Agent"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_parse_env_prefix() {
        let stdio =
            expect_stdio(AcpAgent::from_str("FOO=bar _X1=a=b node server.js").unwrap());
        assert_eq!(stdio.name, "node");
        assert_eq!(stdio.command, PathBuf::from("node"));
        assert_eq!(stdio.args, vec!["server.js"]);
        assert_eq!(stdio.env.len(), 2);
        assert_eq!(stdio.env[0].name, "FOO");
        assert_eq!(stdio.env[0].value, "bar");
        assert_eq!(stdio.env[1].name, "_X1");
        assert_eq!(stdio.env[1].value, "a=b");
    }

    #[test]
    fn test_assignment_after_command_is_an_argument() {
        let stdio = expect_stdio(AcpAgent::from_str("python agent.py FOO=bar").unwrap());
        assert!(stdio.env.is_empty());
        assert_eq!(stdio.args, vec!["agent.py", "FOO=bar"]);
    }

    #[test]
    fn test_name_is_file_name_of_command_path() {
        let stdio = expect_stdio(AcpAgent::from_str("/usr/local/bin/my-agent --stdio").unwrap());
        assert_eq!(stdio.name, "my-agent");
        assert_eq!(stdio.command, PathBuf::from("/usr/local/bin/my-agent"));
        assert_eq!(stdio.args, vec!["--stdio"]);
    }

    #[test]
    fn test_parse_errors() {
        assert!(AcpAgent::from_str("").is_err());
        assert!(AcpAgent::from_str("   ").is_err());
        assert!(AcpAgent::from_str("FOO=bar BAZ=qux").is_err());
        assert!(AcpAgent::from_str(r#"python "unterminated"#).is_err());
        assert!(AcpAgent::from_str("{ not json").is_err());
    }

    #[test]
    fn test_parse_env_var() {
        assert_eq!(
            parse_env_var("FOO=bar"),
            Some(("FOO".to_string(), "bar".to_string()))
        );
        assert_eq!(
            parse_env_var("EMPTY="),
            Some(("EMPTY".to_string(), String::new()))
        );
        assert_eq!(parse_env_var("=value"), None);
        assert_eq!(parse_env_var("1ABC=x"), None);
        assert_eq!(parse_env_var("--port=8080"), None);
        assert_eq!(parse_env_var("python"), None);
    }

    #[test]
    fn test_parse_json_stdio() {
        let json = r#"{
            "type": "stdio",
            "name": "my-agent",
            "command": "/usr/bin/python",
            "args": ["agent.py", "--verbose"],
            "env": []
        }"#;
        let stdio = expect_stdio(AcpAgent::from_str(json).unwrap());
        assert_eq!(stdio.name, "my-agent");
        assert_eq!(stdio.command, PathBuf::from("/usr/bin/python"));
        assert_eq!(stdio.args, vec!["agent.py", "--verbose"]);
        assert!(stdio.env.is_empty());
    }

    #[test]
    fn test_parse_json_http() {
        let json = r#"{
            "type": "http",
            "name": "remote-agent",
            "url": "https://example.com/agent",
            "headers": []
        }"#;
        let agent = AcpAgent::from_str(json).unwrap();
        match agent.server {
            agent_client_protocol::schema::McpServer::Http(http) => {
                assert_eq!(http.name, "remote-agent");
                assert_eq!(http.url, "https://example.com/agent");
                assert!(http.headers.is_empty());
            }
            _ => panic!("Expected Http variant"),
        }
    }

    #[test]
    fn test_spawn_rejects_http_transport() {
        let json = r#"{"type":"http","name":"remote-agent","url":"https://example.com/agent","headers":[]}"#;
        let agent = AcpAgent::from_str(json).unwrap();
        assert!(agent.spawn_process().is_err());
    }

    #[test]
    fn test_debug_output_reflects_callback() {
        let agent = AcpAgent::from_str("python agent.py").unwrap();
        assert!(format!("{agent:?}").contains("debug_callback: None"));
        let agent = agent.with_debug(|_, _| {});
        assert!(format!("{agent:?}").contains(r#"debug_callback: Some("...")"#));
    }
}