russh_process/
child.rs

1
2use core::pin::Pin;
3use core::task::{Context, Poll};
4use russh::{Channel, ChannelId, ChannelMsg};
5use std::io;
6use tokio::io::{AsyncRead, AsyncWrite, ReadBuf, ReadHalf, SimplexStream, WriteHalf};
7use tokio::task::JoinHandle;
8use crate::{ExitStatus, ExitStatusImp};
9
10/// Represents the standard input (stdin) of a child process.
11/// Implements `AsyncWrite` for non-blocking I/O.
12#[derive(Debug)]
13pub struct ChildStdin {
14    pub(crate) inner: WriteHalf<SimplexStream>,
15}
16
17impl AsyncWrite for ChildStdin {
18    fn poll_write(
19        self: Pin<&mut Self>,
20        cx: &mut Context<'_>,
21        buf: &[u8],
22    ) -> Poll<Result<usize, io::Error>> {
23        let this = self.get_mut();
24        Pin::new(&mut this.inner).poll_write(cx, buf)
25    }
26
27    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
28        let this = self.get_mut();
29        Pin::new(&mut this.inner).poll_flush(cx)
30    }
31
32    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), io::Error>> {
33        let this = self.get_mut();
34        Pin::new(&mut this.inner).poll_shutdown(cx)
35    }
36}
37
38/// Represents the standard output (stdout) of a child process.
39/// Implements `AsyncRead` for non-blocking I/O.
40#[derive(Debug)]
41pub struct ChildStdout {
42    pub(crate) inner: ReadHalf<SimplexStream>,
43}
44
45impl AsyncRead for ChildStdout {
46    fn poll_read(
47        self: Pin<&mut Self>,
48        cx: &mut Context,
49        buf: &mut ReadBuf,
50    ) -> Poll<Result<(), io::Error>> {
51        let this = self.get_mut();
52        Pin::new(&mut this.inner).poll_read(cx, buf)
53    }
54}
55
56/// Represents the standard error (stderr) of a child process.
57/// Implements `AsyncRead` for non-blocking I/O.
58#[derive(Debug)]
59pub struct ChildStderr {
60    pub(crate) inner: ReadHalf<SimplexStream>,
61}
62
63impl AsyncRead for ChildStderr {
64    fn poll_read(
65        self: Pin<&mut Self>,
66        cx: &mut Context,
67        buf: &mut ReadBuf,
68    ) -> Poll<Result<(), io::Error>> {
69        let this = self.get_mut();
70        Pin::new(&mut this.inner).poll_read(cx, buf)
71    }
72}
73
74/// A running child process, providing access to its standard I/O streams
75/// and the ability to wait for its completion.
76#[derive(Debug)]
77pub struct Child {
78    pub stdin: Option<ChildStdin>,
79    pub stdout: Option<ChildStdout>,
80    pub stderr: Option<ChildStderr>,
81    pub(crate) handle: JoinHandle<Result<ExitStatus, io::Error>>,
82}
83
84#[derive(Debug)]
85pub(crate) struct ChildImp<S>
86where
87    S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static,
88{
89    pub(crate) channel: Channel<S>,
90    pub(crate) stdin_rx: ReadHalf<SimplexStream>,
91    pub(crate) stdout_tx: WriteHalf<SimplexStream>,
92    pub(crate) stderr_tx: WriteHalf<SimplexStream>,
93}
94
95impl<S> ChildImp<S>
96where
97    S: From<(ChannelId, ChannelMsg)> + Send + Sync + 'static,
98{
99    pub async fn wait(mut self) -> Result<ExitStatus, io::Error> {
100        use tokio::io::AsyncWriteExt;
101
102        let mut code = ExitStatusImp::Processing;
103
104        let mut writer = self.channel.make_writer_ext(None);
105        let mut stdin_rx = self.stdin_rx;
106        tokio::spawn(async move {
107            let _ = tokio::io::copy(&mut stdin_rx, &mut writer).await; // TODO: handle error
108        });
109
110        loop {
111            let Some(msg) = self.channel.wait().await else {
112                break;
113            };
114            match msg {
115                ChannelMsg::ExitStatus { exit_status } => {
116                    // Do not return here, we need to read all the data
117                    code = ExitStatusImp::Code(exit_status);
118                }
119                ChannelMsg::Data { ref data } => {
120                    self.stdout_tx.write_all(data).await?;
121                }
122                ChannelMsg::ExtendedData { ref data, ext: 1 } => {
123                    self.stderr_tx.write_all(data).await?;
124                }
125                _ => {}
126            }
127        }
128        tokio::try_join!(self.stdout_tx.shutdown(), self.stderr_tx.shutdown())?;
129        Ok(ExitStatus { inner: code })
130    }
131}
132
133impl Child {
134    /// Waits for the child process to exit, returning the exit status.
135    pub async fn wait(self) -> Result<ExitStatus, io::Error> {
136        self.handle.await?
137    }
138
139    /// Waits for the child process to exit, returning the exit status, stdout, and stderr.
140    pub async fn wait_with_output(mut self) -> Result<Output, io::Error> {
141        async fn read_to_end<A: AsyncRead + Unpin>(io: &mut Option<A>) -> io::Result<Vec<u8>> {
142            use tokio::io::AsyncReadExt;
143            let mut vec = Vec::new();
144            if let Some(io) = io.as_mut() {
145                io.read_to_end(&mut vec).await?;
146            }
147            Ok(vec)
148        }
149
150        let mut stdout_pipe = self.stdout.take();
151        let mut stderr_pipe = self.stderr.take();
152
153        let stdout_fut = read_to_end(&mut stdout_pipe);
154        let stderr_fut = read_to_end(&mut stderr_pipe);
155
156        let (status, stdout, stderr) = tokio::try_join!(self.wait(), stdout_fut, stderr_fut)?;
157
158        drop(stdout_pipe);
159        drop(stderr_pipe);
160
161        Ok(Output {
162            status,
163            stdout,
164            stderr,
165        })
166    }
167}
168
169/// The result of a completed command, including exit status,
170/// standard output, and standard error.
171#[derive(Debug)]
172pub struct Output {
173    pub status: ExitStatus,
174    pub stdout: Vec<u8>,
175    pub stderr: Vec<u8>,
176}