Skip to main content

rmcp_soddygo/transport/
child_process.rs

1use std::process::Stdio;
2
3use futures::future::Future;
4use process_wrap::tokio::{ChildWrapper, CommandWrap};
5use tokio::{
6    io::AsyncRead,
7    process::{ChildStderr, ChildStdin, ChildStdout},
8};
9
10use super::{RxJsonRpcMessage, Transport, TxJsonRpcMessage, async_rw::AsyncRwTransport};
11use crate::RoleClient;
12
/// Number of seconds `TokioChildProcess::graceful_shutdown` waits for the
/// child to exit on its own before killing it.
///
/// Despite the name, `ChildWithCleanup`'s `Drop` impl in this file does not
/// consult it; the unit tests also use it to size their wait.
const MAX_WAIT_ON_DROP_SECS: u64 = 3;
/// The parts of a child process.
///
/// In order: the owned child handle, its stdout pipe, its stdin pipe, and its
/// stderr pipe (present only when stderr was spawned as `Stdio::piped()`).
type ChildProcessParts = (
    Box<dyn ChildWrapper>,
    ChildStdout,
    ChildStdin,
    Option<ChildStderr>,
);
21
22/// Extract the stdio handles from a spawned child.
23/// Returns `(child, stdout, stdin, stderr)` where `stderr` is `Some` only
24/// if the process was spawned with `Stdio::piped()`.
25#[inline]
26fn child_process(mut child: Box<dyn ChildWrapper>) -> std::io::Result<ChildProcessParts> {
27    let child_stdin = match child.inner_mut().stdin().take() {
28        Some(stdin) => stdin,
29        None => return Err(std::io::Error::other("stdin was already taken")),
30    };
31    let child_stdout = match child.inner_mut().stdout().take() {
32        Some(stdout) => stdout,
33        None => return Err(std::io::Error::other("stdout was already taken")),
34    };
35    let child_stderr = child.inner_mut().stderr().take();
36    Ok((child, child_stdout, child_stdin, child_stderr))
37}
38
/// Client-side transport that talks to a server running as a child process.
///
/// Outgoing messages are written to the child's stdin and incoming messages
/// are read from its stdout. The child itself is held by a
/// [`ChildWithCleanup`] guard, so dropping this value without calling
/// `graceful_shutdown` still schedules the child to be killed.
pub struct TokioChildProcess {
    // Owning guard for the child process handle.
    child: ChildWithCleanup,
    // JSON-RPC transport reading from the child's stdout and writing to its stdin.
    transport: AsyncRwTransport<RoleClient, ChildStdout, ChildStdin>,
}
43
/// Owning guard around a spawned child process.
///
/// If it still holds a child when dropped, its `Drop` impl tries to kill
/// that child. `inner` becomes `None` once the child has been handed off,
/// e.g. by `TokioChildProcess::graceful_shutdown` or
/// `TokioChildProcess::into_inner`.
pub struct ChildWithCleanup {
    // `Some` while this guard is responsible for the child's cleanup.
    inner: Option<Box<dyn ChildWrapper>>,
}
47
48impl Drop for ChildWithCleanup {
49    fn drop(&mut self) {
50        // We should not use start_kill(), instead we should use kill() to avoid zombies
51        if let Some(mut inner) = self.inner.take() {
52            // We don't care about the result, just try to kill it
53            tokio::spawn(async move {
54                if let Err(e) = Box::into_pin(inner.kill()).await {
55                    tracing::warn!("Error killing child process: {}", e);
56                }
57            });
58        }
59    }
60}
61
// We keep the child process together with its stdout because that makes it
// easy to implement `AsyncRead`: reads go to the pipe while the cleanup guard
// lives exactly as long as the reader.
pin_project_lite::pin_project! {
    /// Reader over a child process's stdout that also owns the child's
    /// cleanup guard.
    pub struct TokioChildProcessOut {
        // Cleanup guard for the child; not pinned, never read through `poll_read`.
        child: ChildWithCleanup,
        // The child's stdout pipe; pinned so `poll_read` can be forwarded to it.
        #[pin]
        child_stdout: ChildStdout,
    }
}
70
71impl TokioChildProcessOut {
72    /// Get the process ID of the child process.
73    pub fn id(&self) -> Option<u32> {
74        self.child.inner.as_ref()?.id()
75    }
76}
77
78impl AsyncRead for TokioChildProcessOut {
79    fn poll_read(
80        self: std::pin::Pin<&mut Self>,
81        cx: &mut std::task::Context<'_>,
82        buf: &mut tokio::io::ReadBuf<'_>,
83    ) -> std::task::Poll<std::io::Result<()>> {
84        self.project().child_stdout.poll_read(cx, buf)
85    }
86}
87
88impl TokioChildProcess {
89    /// Convenience: spawn with default `piped` stdio
90    pub fn new(command: impl Into<CommandWrap>) -> std::io::Result<Self> {
91        let (proc, _ignored) = TokioChildProcessBuilder::new(command).spawn()?;
92        Ok(proc)
93    }
94
95    /// Builder entry-point allowing fine-grained stdio control.
96    pub fn builder(command: impl Into<CommandWrap>) -> TokioChildProcessBuilder {
97        TokioChildProcessBuilder::new(command)
98    }
99
100    /// Get the process ID of the child process.
101    pub fn id(&self) -> Option<u32> {
102        self.child.inner.as_ref()?.id()
103    }
104
105    /// Gracefully shutdown the child process
106    ///
107    /// This will first close the transport to the child process (the server),
108    /// and wait for the child process to exit normally with a timeout.
109    /// If the child process doesn't exit within the timeout, it will be killed.
110    pub async fn graceful_shutdown(&mut self) -> std::io::Result<()> {
111        if let Some(mut child) = self.child.inner.take() {
112            self.transport.close().await?;
113
114            let wait_fut = child.wait();
115            tokio::select! {
116                _ = tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS)) => {
117                    if let Err(e) = Box::into_pin(child.kill()).await {
118                        tracing::warn!("Error killing child: {e}");
119                        return Err(e);
120                    }
121                },
122                res = wait_fut => {
123                    match res {
124                        Ok(status) => {
125                            tracing::info!("Child exited gracefully {}", status);
126                        }
127                        Err(e) => {
128                            tracing::warn!("Error waiting for child: {e}");
129                            return Err(e);
130                        }
131                    }
132                }
133            }
134        }
135        Ok(())
136    }
137
138    /// Take ownership of the inner child process
139    pub fn into_inner(mut self) -> Option<Box<dyn ChildWrapper>> {
140        self.child.inner.take()
141    }
142
143    /// Split this helper into a reader (stdout) and writer (stdin).
144    #[deprecated(
145        since = "0.5.0",
146        note = "use the Transport trait implementation instead"
147    )]
148    pub fn split(self) -> (TokioChildProcessOut, ChildStdin) {
149        unimplemented!("This method is deprecated, use the Transport trait implementation instead");
150    }
151}
152
/// Builder for `TokioChildProcess` allowing custom `Stdio` configuration.
///
/// Obtained from `TokioChildProcess::builder`. By default stdin and stdout
/// are piped and stderr is inherited from the parent. The transport is built
/// on the child's stdin and stdout handles, so `spawn` fails if those are not
/// piped.
pub struct TokioChildProcessBuilder {
    // The wrapped command that `spawn` will launch.
    cmd: CommandWrap,
    // Stdio configuration applied to the child's stdin at spawn time.
    stdin: Stdio,
    // Stdio configuration applied to the child's stdout at spawn time.
    stdout: Stdio,
    // Stdio configuration applied to the child's stderr at spawn time.
    stderr: Stdio,
}
160
161impl TokioChildProcessBuilder {
162    fn new(cmd: impl Into<CommandWrap>) -> Self {
163        Self {
164            cmd: cmd.into(),
165            stdin: Stdio::piped(),
166            stdout: Stdio::piped(),
167            stderr: Stdio::inherit(),
168        }
169    }
170
171    /// Override the child stdin configuration.
172    pub fn stdin(mut self, io: impl Into<Stdio>) -> Self {
173        self.stdin = io.into();
174        self
175    }
176    /// Override the child stdout configuration.
177    pub fn stdout(mut self, io: impl Into<Stdio>) -> Self {
178        self.stdout = io.into();
179        self
180    }
181    /// Override the child stderr configuration.
182    pub fn stderr(mut self, io: impl Into<Stdio>) -> Self {
183        self.stderr = io.into();
184        self
185    }
186
187    /// Spawn the child process. Returns the transport plus an optional captured stderr handle.
188    pub fn spawn(mut self) -> std::io::Result<(TokioChildProcess, Option<ChildStderr>)> {
189        self.cmd
190            .command_mut()
191            .stdin(self.stdin)
192            .stdout(self.stdout)
193            .stderr(self.stderr);
194
195        let (child, stdout, stdin, stderr_opt) = child_process(self.cmd.spawn()?)?;
196
197        let transport = AsyncRwTransport::new(stdout, stdin);
198        let proc = TokioChildProcess {
199            child: ChildWithCleanup { inner: Some(child) },
200            transport,
201        };
202        Ok((proc, stderr_opt))
203    }
204}
205
/// Client-role [`Transport`] over the child's stdin/stdout.
///
/// `send` and `receive` delegate to the inner `AsyncRwTransport`. `close`
/// runs `TokioChildProcess::graceful_shutdown`, so while the child is still
/// owned, closing the transport also waits for the child to exit (killing it
/// after a timeout).
impl Transport<RoleClient> for TokioChildProcess {
    type Error = std::io::Error;

    // Write one outgoing message to the child's stdin.
    fn send(
        &mut self,
        item: TxJsonRpcMessage<RoleClient>,
    ) -> impl Future<Output = Result<(), Self::Error>> + Send + 'static {
        self.transport.send(item)
    }

    // Read the next incoming message from the child's stdout.
    fn receive(&mut self) -> impl Future<Output = Option<RxJsonRpcMessage<RoleClient>>> + Send {
        self.transport.receive()
    }

    // Close the transport and shut the child down (see `graceful_shutdown`).
    fn close(&mut self) -> impl Future<Output = Result<(), Self::Error>> + Send {
        self.graceful_shutdown()
    }
}
224
225pub trait ConfigureCommandExt {
226    fn configure(self, f: impl FnOnce(&mut Self)) -> Self;
227}
228
229impl ConfigureCommandExt for tokio::process::Command {
230    fn configure(mut self, f: impl FnOnce(&mut Self)) -> Self {
231        f(&mut self);
232        self
233    }
234}
235
/// Resolve the absolute path to an executable using the system `PATH`,
/// then return a [`tokio::process::Command`] pointing at it.
///
/// This is especially useful on Windows where `.cmd` / `.exe` shim scripts
/// (e.g. `npx.cmd`) are not reliably found by [`tokio::process::Command`]
/// without a fully-qualified path.
///
/// # Errors
///
/// Any lookup failure is reported as an error of kind
/// [`std::io::ErrorKind::NotFound`], wrapping the underlying lookup error.
///
/// # Example
/// ```rust,no_run
/// use rmcp::transport::{which_command, ConfigureCommandExt};
///
/// # fn example() -> std::io::Result<()> {
/// let cmd = which_command("npx")?
///     .configure(|cmd| {
///         cmd.arg("-y").arg("@modelcontextprotocol/server-everything");
///     });
/// # Ok(())
/// # }
/// ```
#[cfg(feature = "which-command")]
pub fn which_command(
    name: impl AsRef<std::ffi::OsStr>,
) -> std::io::Result<tokio::process::Command> {
    match which::which(name.as_ref()) {
        Ok(path) => Ok(tokio::process::Command::new(path)),
        Err(lookup_err) => Err(std::io::Error::new(
            std::io::ErrorKind::NotFound,
            lookup_err,
        )),
    }
}
263
#[cfg(feature = "which-command")]
#[cfg(test)]
mod tests_which {
    use super::which_command;

    // A binary that is always on `PATH`: `ls` on Unix, `cmd` on Windows.
    #[cfg(unix)]
    const ALWAYS_PRESENT: &str = "ls";
    #[cfg(windows)]
    const ALWAYS_PRESENT: &str = "cmd";

    #[test]
    fn which_command_resolves_known_binary() {
        assert!(which_command(ALWAYS_PRESENT).is_ok());
    }

    #[test]
    fn which_command_fails_for_nonexistent() {
        let lookup = which_command("this_binary_definitely_does_not_exist_12345");
        assert!(lookup.is_err());
        let err = lookup.unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::NotFound);
    }
}
285
#[cfg(unix)]
#[cfg(test)]
mod tests {
    use tokio::process::Command;

    use super::*;

    /// Spawn `sleep 30` through the default constructor and return it with
    /// its PID. Must be called from inside a Tokio runtime.
    fn spawn_sleeper() -> (TokioChildProcess, u32) {
        let spawned = TokioChildProcess::new(Command::new("sleep").configure(|cmd| {
            cmd.arg("30");
        }));
        assert!(spawned.is_ok());
        let process = spawned.unwrap();
        let pid = process.id();
        assert!(pid.is_some());
        (process, pid.unwrap())
    }

    /// Give cleanup time to run, then assert via `ps -p` that `pid` is gone.
    async fn assert_gone_after_grace_period(pid: u32) {
        tokio::time::sleep(std::time::Duration::from_secs(MAX_WAIT_ON_DROP_SECS + 1)).await;
        let probe = Command::new("ps")
            .arg("-p")
            .arg(pid.to_string())
            .status()
            .await;
        match probe {
            Ok(status) => assert!(
                !status.success(),
                "Process with PID {} is still running",
                pid
            ),
            Err(e) => panic!("Failed to check process status: {}", e),
        }
    }

    #[tokio::test]
    async fn test_tokio_child_process_drop() {
        let (child_process, id) = spawn_sleeper();
        drop(child_process);
        assert_gone_after_grace_period(id).await;
    }

    #[tokio::test]
    async fn test_tokio_child_process_graceful_shutdown() {
        let (mut child_process, id) = spawn_sleeper();
        child_process.graceful_shutdown().await.unwrap();
        assert_gone_after_grace_period(id).await;
    }
}