Skip to main content

agent_client_protocol_tokio/
acp_agent.rs

1//! Utilities for connecting to ACP agents and proxies.
2//!
3//! This module provides [`AcpAgent`], a convenient wrapper around [`agent_client_protocol::schema::McpServer`]
4//! that can be parsed from either a command string or JSON configuration.
5
6use std::path::PathBuf;
7use std::str::FromStr;
8use std::sync::Arc;
9
10use agent_client_protocol::{Client, Conductor, Role};
11use tokio::process::Child;
12use tokio_util::compat::{TokioAsyncReadCompatExt, TokioAsyncWriteCompatExt};
13
/// Identifies which stdio stream of the agent process a line travelled on.
///
/// This value is passed to the callback registered with [`AcpAgent::with_debug`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum LineDirection {
    /// Outbound: the line is written to the agent's stdin.
    Stdin,
    /// Inbound: the line was read from the agent's stdout.
    Stdout,
    /// Inbound: the line was read from the agent's stderr.
    Stderr,
}
24
/// A component representing an external ACP agent running in a separate process.
///
/// `AcpAgent` implements the [`agent_client_protocol::ConnectTo`] trait for spawning and communicating with
/// external agents or proxies via stdio. It handles process spawning, stream setup, and
/// newline-delimited message framing automatically. This is the primary way to connect to agents
/// that run as separate executables.
///
/// This is a wrapper around [`agent_client_protocol::schema::McpServer`] that provides convenient parsing
/// from command-line strings or JSON configurations. Only the stdio variant of that
/// configuration can currently be spawned. HTTP and SSE configurations parse successfully
/// but return an error when the agent is connected.
///
/// # Use Cases
///
/// - **External agents**: Connect to agents written in any language (Python, Node.js, Rust, etc.)
/// - **Proxy chains**: Spawn intermediate proxies that transform or intercept messages
/// - **Conductor components**: Use with [`agent_client_protocol_conductor::Conductor`] to build proxy chains
/// - **Process separation**: Run the agent in its own OS process. Note that this is *not* a
///   security sandbox: the child is started with the parent's privileges and inherits the
///   parent's environment, plus any configured variables.
///
/// # Examples
///
/// Parse from a command string:
/// ```
/// # use agent_client_protocol_tokio::AcpAgent;
/// # use std::str::FromStr;
/// let agent = AcpAgent::from_str("python my_agent.py --verbose").unwrap();
/// ```
///
/// Parse from JSON:
/// ```
/// # use agent_client_protocol_tokio::AcpAgent;
/// # use std::str::FromStr;
/// let agent = AcpAgent::from_str(r#"{"type": "stdio", "name": "my-agent", "command": "python", "args": ["my_agent.py"], "env": []}"#).unwrap();
/// ```
///
/// Use as a component to connect to an external agent:
/// ```ignore
/// use agent_client_protocol::{Client, Builder};
/// use agent_client_protocol_tokio::AcpAgent;
/// use std::str::FromStr;
///
/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
/// let agent = AcpAgent::from_str("python my_agent.py")?;
///
/// // The agent process will be spawned automatically when connected
/// Client.builder()
///     .connect_to(agent)
///     .await?
///     .connect_with(|cx| async move {
///         // Use the connection to communicate with the agent process
///         Ok(())
///     })
///     .await?;
/// # Ok(())
/// # }
/// ```
///
/// [`agent_client_protocol_conductor::Conductor`]: https://docs.rs/agent-client-protocol-conductor/latest/agent_client_protocol_conductor/struct.Conductor.html
pub struct AcpAgent {
    /// How to launch the agent. Only the `Stdio` variant is spawnable (see `spawn_process`).
    server: agent_client_protocol::schema::McpServer,
    /// Optional observer called with each stdin/stdout/stderr line. It is set via `with_debug`.
    debug_callback: Option<Arc<dyn Fn(&str, LineDirection) + Send + Sync + 'static>>,
}
85
86impl std::fmt::Debug for AcpAgent {
87    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
88        f.debug_struct("AcpAgent")
89            .field("server", &self.server)
90            .field(
91                "debug_callback",
92                &self.debug_callback.as_ref().map(|_| "..."),
93            )
94            .finish()
95    }
96}
97
98impl AcpAgent {
99    /// Create a new `AcpAgent` from an [`agent_client_protocol::schema::McpServer`] configuration.
100    #[must_use]
101    pub fn new(server: agent_client_protocol::schema::McpServer) -> Self {
102        Self {
103            server,
104            debug_callback: None,
105        }
106    }
107
108    /// Create an ACP agent for Zed Industries' Claude Code tool.
109    /// Just runs `npx -y @zed-industries/claude-code-acp@latest`.
110    #[must_use]
111    pub fn zed_claude_code() -> Self {
112        Self::from_str("npx -y @zed-industries/claude-code-acp@latest").expect("valid bash command")
113    }
114
115    /// Create an ACP agent for Zed Industries' Codex tool.
116    /// Just runs `npx -y @zed-industries/codex-acp@latest`.
117    #[must_use]
118    pub fn zed_codex() -> Self {
119        Self::from_str("npx -y @zed-industries/codex-acp@latest").expect("valid bash command")
120    }
121
122    /// Create an ACP agent for Google's Gemini CLI.
123    /// Just runs `npx -y -- @google/gemini-cli@latest --experimental-acp`.
124    #[must_use]
125    pub fn google_gemini() -> Self {
126        Self::from_str("npx -y -- @google/gemini-cli@latest --experimental-acp")
127            .expect("valid bash command")
128    }
129
130    /// Get the underlying [`agent_client_protocol::schema::McpServer`] configuration.
131    #[must_use]
132    pub fn server(&self) -> &agent_client_protocol::schema::McpServer {
133        &self.server
134    }
135
136    /// Convert into the underlying [`agent_client_protocol::schema::McpServer`] configuration.
137    #[must_use]
138    pub fn into_server(self) -> agent_client_protocol::schema::McpServer {
139        self.server
140    }
141
142    /// Add a debug callback that will be invoked for each line sent/received.
143    ///
144    /// The callback receives the line content and the direction (stdin/stdout/stderr).
145    /// This is useful for logging, debugging, or monitoring agent communication.
146    ///
147    /// # Example
148    ///
149    /// ```no_run
150    /// # use agent_client_protocol_tokio::{AcpAgent, LineDirection};
151    /// # use std::str::FromStr;
152    /// let agent = AcpAgent::from_str("python my_agent.py")
153    ///     .unwrap()
154    ///     .with_debug(|line, direction| {
155    ///         eprintln!("{:?}: {}", direction, line);
156    ///     });
157    /// ```
158    #[must_use]
159    pub fn with_debug<F>(mut self, callback: F) -> Self
160    where
161        F: Fn(&str, LineDirection) + Send + Sync + 'static,
162    {
163        self.debug_callback = Some(Arc::new(callback));
164        self
165    }
166
167    /// Spawn the process and get stdio streams.
168    /// Used internally by the Component trait implementation.
169    pub fn spawn_process(
170        &self,
171    ) -> Result<
172        (
173            tokio::process::ChildStdin,
174            tokio::process::ChildStdout,
175            tokio::process::ChildStderr,
176            Child,
177        ),
178        agent_client_protocol::Error,
179    > {
180        match &self.server {
181            agent_client_protocol::schema::McpServer::Stdio(stdio) => {
182                let mut cmd = tokio::process::Command::new(&stdio.command);
183                cmd.args(&stdio.args);
184                for env_var in &stdio.env {
185                    cmd.env(&env_var.name, &env_var.value);
186                }
187                cmd.stdin(std::process::Stdio::piped())
188                    .stdout(std::process::Stdio::piped())
189                    .stderr(std::process::Stdio::piped());
190
191                let mut child = cmd
192                    .spawn()
193                    .map_err(agent_client_protocol::Error::into_internal_error)?;
194
195                let child_stdin = child.stdin.take().ok_or_else(|| {
196                    agent_client_protocol::util::internal_error("Failed to open stdin")
197                })?;
198                let child_stdout = child.stdout.take().ok_or_else(|| {
199                    agent_client_protocol::util::internal_error("Failed to open stdout")
200                })?;
201                let child_stderr = child.stderr.take().ok_or_else(|| {
202                    agent_client_protocol::util::internal_error("Failed to open stderr")
203                })?;
204
205                Ok((child_stdin, child_stdout, child_stderr, child))
206            }
207            agent_client_protocol::schema::McpServer::Http(_) => {
208                Err(agent_client_protocol::util::internal_error(
209                    "HTTP transport not yet supported by AcpAgent",
210                ))
211            }
212            agent_client_protocol::schema::McpServer::Sse(_) => {
213                Err(agent_client_protocol::util::internal_error(
214                    "SSE transport not yet supported by AcpAgent",
215                ))
216            }
217            _ => Err(agent_client_protocol::util::internal_error(
218                "Unknown MCP server transport type",
219            )),
220        }
221    }
222}
223
224/// A wrapper around Child that kills the process when dropped.
225struct ChildGuard(Child);
226
227impl ChildGuard {
228    async fn wait(&mut self) -> std::io::Result<std::process::ExitStatus> {
229        self.0.wait().await
230    }
231}
232
233impl Drop for ChildGuard {
234    fn drop(&mut self) {
235        drop(self.0.start_kill());
236    }
237}
238
239/// Waits for a child process and returns an error if it exits with non-zero status.
240///
241/// The error message includes any stderr output collected by the background task.
242/// When dropped, the child process is killed.
243async fn monitor_child(
244    child: Child,
245    stderr_rx: tokio::sync::oneshot::Receiver<String>,
246) -> Result<(), agent_client_protocol::Error> {
247    let mut guard = ChildGuard(child);
248
249    // Wait for the child to exit
250    let status = guard.wait().await.map_err(|e| {
251        agent_client_protocol::util::internal_error(format!("Failed to wait for process: {e}"))
252    })?;
253
254    if status.success() {
255        Ok(())
256    } else {
257        // Get stderr content if available
258        let stderr = stderr_rx.await.unwrap_or_default();
259
260        let message = if stderr.is_empty() {
261            format!("Process exited with {status}")
262        } else {
263            format!("Process exited with {status}: {stderr}")
264        };
265
266        Err(agent_client_protocol::util::internal_error(message))
267    }
268}
269
270/// Roles that an ACP agent executable can potentially serve.
271pub trait AcpAgentCounterpartRole: Role {}
272
273impl AcpAgentCounterpartRole for Client {}
274
275impl AcpAgentCounterpartRole for Conductor {}
276
277impl<Counterpart: AcpAgentCounterpartRole> agent_client_protocol::ConnectTo<Counterpart>
278    for AcpAgent
279{
280    async fn connect_to(
281        self,
282        client: impl agent_client_protocol::ConnectTo<Counterpart::Counterpart>,
283    ) -> Result<(), agent_client_protocol::Error> {
284        use futures::AsyncBufReadExt;
285        use futures::AsyncWriteExt;
286        use futures::StreamExt;
287        use futures::io::BufReader;
288
289        let (child_stdin, child_stdout, child_stderr, child) = self.spawn_process()?;
290
291        // Create a channel to collect stderr for error reporting
292        let (stderr_tx, stderr_rx) = tokio::sync::oneshot::channel::<String>();
293
294        // Spawn a task to read stderr, optionally calling the debug callback
295        let debug_callback = self.debug_callback.clone();
296        tokio::spawn(async move {
297            let stderr_reader = BufReader::new(child_stderr.compat());
298            let mut stderr_lines = stderr_reader.lines();
299            let mut collected = String::new();
300            while let Some(line_result) = stderr_lines.next().await {
301                if let Ok(line) = line_result {
302                    // Call debug callback if present
303                    if let Some(ref callback) = debug_callback {
304                        callback(&line, LineDirection::Stderr);
305                    }
306                    // Always collect for error reporting
307                    if !collected.is_empty() {
308                        collected.push('\n');
309                    }
310                    collected.push_str(&line);
311                }
312            }
313            drop(stderr_tx.send(collected));
314        });
315
316        // Create a future that monitors the child process for early exit
317        let child_monitor = monitor_child(child, stderr_rx);
318
319        // Convert stdio to line streams with optional debug inspection
320        let incoming_lines = if let Some(callback) = self.debug_callback.clone() {
321            Box::pin(
322                BufReader::new(child_stdout.compat())
323                    .lines()
324                    .inspect(move |result| {
325                        if let Ok(line) = result {
326                            callback(line, LineDirection::Stdout);
327                        }
328                    }),
329            )
330                as std::pin::Pin<Box<dyn futures::Stream<Item = std::io::Result<String>> + Send>>
331        } else {
332            Box::pin(BufReader::new(child_stdout.compat()).lines())
333        };
334
335        // Create a sink that writes lines (with newlines) to stdin with optional debug logging
336        let outgoing_sink = if let Some(callback) = self.debug_callback.clone() {
337            Box::pin(futures::sink::unfold(
338                (child_stdin.compat_write(), callback),
339                async move |(mut writer, callback), line: String| {
340                    callback(&line, LineDirection::Stdin);
341                    let mut bytes = line.into_bytes();
342                    bytes.push(b'\n');
343                    writer.write_all(&bytes).await?;
344                    Ok::<_, std::io::Error>((writer, callback))
345                },
346            ))
347                as std::pin::Pin<Box<dyn futures::Sink<String, Error = std::io::Error> + Send>>
348        } else {
349            Box::pin(futures::sink::unfold(
350                child_stdin.compat_write(),
351                async move |mut writer, line: String| {
352                    let mut bytes = line.into_bytes();
353                    bytes.push(b'\n');
354                    writer.write_all(&bytes).await?;
355                    Ok::<_, std::io::Error>(writer)
356                },
357            ))
358        };
359
360        // Race the protocol against child process exit
361        // If the child exits early (e.g., with an error), we return that error
362        let protocol_future = agent_client_protocol::ConnectTo::<Counterpart>::connect_to(
363            agent_client_protocol::Lines::new(outgoing_sink, incoming_lines),
364            client,
365        );
366
367        tokio::select! {
368            result = protocol_future => result,
369            result = child_monitor => result,
370        }
371    }
372}
373
374impl AcpAgent {
375    /// Create an `AcpAgent` from an iterator of command-line arguments.
376    ///
377    /// Leading arguments of the form `NAME=value` are parsed as environment variables.
378    /// The first non-env argument is the command, and the rest are arguments.
379    ///
380    /// # Example
381    ///
382    /// ```
383    /// # use agent_client_protocol_tokio::AcpAgent;
384    /// let agent = AcpAgent::from_args([
385    ///     "RUST_LOG=debug",
386    ///     "cargo",
387    ///     "run",
388    ///     "-p",
389    ///     "my-crate",
390    /// ]).unwrap();
391    /// ```
392    pub fn from_args<I, T>(args: I) -> Result<Self, agent_client_protocol::Error>
393    where
394        I: IntoIterator<Item = T>,
395        T: ToString,
396    {
397        let args: Vec<String> = args.into_iter().map(|s| s.to_string()).collect();
398
399        if args.is_empty() {
400            return Err(agent_client_protocol::util::internal_error(
401                "Arguments cannot be empty",
402            ));
403        }
404
405        let mut env = vec![];
406        let mut command_idx = 0;
407
408        // Parse leading FOO=bar arguments as environment variables
409        for (i, arg) in args.iter().enumerate() {
410            if let Some((name, value)) = parse_env_var(arg) {
411                env.push(agent_client_protocol::schema::EnvVariable::new(name, value));
412                command_idx = i + 1;
413            } else {
414                break;
415            }
416        }
417
418        if command_idx >= args.len() {
419            return Err(agent_client_protocol::util::internal_error(
420                "No command found (only environment variables provided)",
421            ));
422        }
423
424        let command = PathBuf::from(&args[command_idx]);
425        let cmd_args = args[command_idx + 1..].to_vec();
426
427        // Generate a name from the command
428        let name = command
429            .file_name()
430            .and_then(|n| n.to_str())
431            .unwrap_or("agent")
432            .to_string();
433
434        Ok(AcpAgent {
435            server: agent_client_protocol::schema::McpServer::Stdio(
436                agent_client_protocol::schema::McpServerStdio::new(name, command)
437                    .args(cmd_args)
438                    .env(env),
439            ),
440            debug_callback: None,
441        })
442    }
443}
444
445/// Parse a string as an environment variable assignment (NAME=value).
446/// Returns None if it doesn't match the pattern.
447fn parse_env_var(s: &str) -> Option<(String, String)> {
448    // Must contain '=' and the part before must be a valid env var name
449    let eq_pos = s.find('=')?;
450    if eq_pos == 0 {
451        return None;
452    }
453
454    let name = &s[..eq_pos];
455    let value = &s[eq_pos + 1..];
456
457    // Env var names must start with a letter or underscore, and contain only
458    // alphanumeric characters and underscores
459    let mut chars = name.chars();
460    let first = chars.next()?;
461    if !first.is_ascii_alphabetic() && first != '_' {
462        return None;
463    }
464    if !chars.all(|c| c.is_ascii_alphanumeric() || c == '_') {
465        return None;
466    }
467
468    Some((name.to_string(), value.to_string()))
469}
470
471impl FromStr for AcpAgent {
472    type Err = agent_client_protocol::Error;
473
474    fn from_str(s: &str) -> Result<Self, Self::Err> {
475        let trimmed = s.trim();
476
477        // If it starts with '{', try to parse as JSON
478        if trimmed.starts_with('{') {
479            let server: agent_client_protocol::schema::McpServer = serde_json::from_str(trimmed)
480                .map_err(|e| {
481                    agent_client_protocol::util::internal_error(format!(
482                        "Failed to parse JSON: {e}"
483                    ))
484                })?;
485            return Ok(Self {
486                server,
487                debug_callback: None,
488            });
489        }
490
491        // Otherwise, parse as a command string
492        let parts = shell_words::split(trimmed).map_err(|e| {
493            agent_client_protocol::util::internal_error(format!("Failed to parse command: {e}"))
494        })?;
495
496        Self::from_args(parts)
497    }
498}
499
500#[cfg(test)]
501mod tests {
502    use super::*;
503
504    #[test]
505    fn test_parse_simple_command() {
506        let agent = AcpAgent::from_str("python agent.py").unwrap();
507        match agent.server {
508            agent_client_protocol::schema::McpServer::Stdio(stdio) => {
509                assert_eq!(stdio.name, "python");
510                assert_eq!(stdio.command, PathBuf::from("python"));
511                assert_eq!(stdio.args, vec!["agent.py"]);
512                assert!(stdio.env.is_empty());
513            }
514            _ => panic!("Expected Stdio variant"),
515        }
516    }
517
518    #[test]
519    fn test_parse_command_with_args() {
520        let agent = AcpAgent::from_str("node server.js --port 8080 --verbose").unwrap();
521        match agent.server {
522            agent_client_protocol::schema::McpServer::Stdio(stdio) => {
523                assert_eq!(stdio.name, "node");
524                assert_eq!(stdio.command, PathBuf::from("node"));
525                assert_eq!(stdio.args, vec!["server.js", "--port", "8080", "--verbose"]);
526                assert!(stdio.env.is_empty());
527            }
528            _ => panic!("Expected Stdio variant"),
529        }
530    }
531
532    #[test]
533    fn test_parse_command_with_quotes() {
534        let agent = AcpAgent::from_str(r#"python "my agent.py" --name "Test Agent""#).unwrap();
535        match agent.server {
536            agent_client_protocol::schema::McpServer::Stdio(stdio) => {
537                assert_eq!(stdio.name, "python");
538                assert_eq!(stdio.command, PathBuf::from("python"));
539                assert_eq!(stdio.args, vec!["my agent.py", "--name", "Test Agent"]);
540                assert!(stdio.env.is_empty());
541            }
542            _ => panic!("Expected Stdio variant"),
543        }
544    }
545
546    #[test]
547    fn test_parse_json_stdio() {
548        let json = r#"{
549            "type": "stdio",
550            "name": "my-agent",
551            "command": "/usr/bin/python",
552            "args": ["agent.py", "--verbose"],
553            "env": []
554        }"#;
555        let agent = AcpAgent::from_str(json).unwrap();
556        match agent.server {
557            agent_client_protocol::schema::McpServer::Stdio(stdio) => {
558                assert_eq!(stdio.name, "my-agent");
559                assert_eq!(stdio.command, PathBuf::from("/usr/bin/python"));
560                assert_eq!(stdio.args, vec!["agent.py", "--verbose"]);
561                assert!(stdio.env.is_empty());
562            }
563            _ => panic!("Expected Stdio variant"),
564        }
565    }
566
567    #[test]
568    fn test_parse_json_http() {
569        let json = r#"{
570            "type": "http",
571            "name": "remote-agent",
572            "url": "https://example.com/agent",
573            "headers": []
574        }"#;
575        let agent = AcpAgent::from_str(json).unwrap();
576        match agent.server {
577            agent_client_protocol::schema::McpServer::Http(http) => {
578                assert_eq!(http.name, "remote-agent");
579                assert_eq!(http.url, "https://example.com/agent");
580                assert!(http.headers.is_empty());
581            }
582            _ => panic!("Expected Http variant"),
583        }
584    }
585}