Skip to main content

adb_wire/
shell.rs

1//! Shell protocol (v2) stream types and packet demuxing.
2
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
7use tokio::net::TcpStream;
8
9use crate::error::Result;
10
// Shell v2 packets:
//   [1 byte id] [4 byte LE length] [payload]
//
// Stream IDs:
//   0 = stdin  (host → device)
//   1 = stdout (device → host)
//   2 = stderr (device → host)
//   3 = exit   (device → host, 1-byte payload = exit code)
//   4 = close stdin (host → device, empty payload, asks the device to close the command's stdin)
20
/// Packet id for stdin data (host → device).
const ID_STDIN: u8 = 0;
/// Packet id for stdout data (device → host).
const ID_STDOUT: u8 = 1;
/// Packet id for stderr data (device → host).
const ID_STDERR: u8 = 2;
/// Packet id for the exit packet (device → host); the first payload byte is the exit code.
const ID_EXIT: u8 = 3;
/// Packet id the host sends, with an empty payload, to close the command's stdin.
const ID_CLOSE_STDIN: u8 = 4;

/// Maximum payload size for a single shell packet (16 MiB).
///
/// Incoming packets whose header announces a larger payload are rejected
/// with an `InvalidData` error instead of being buffered.
const MAX_SHELL_PAYLOAD: usize = 16 * 1024 * 1024;
29
/// Captured result of a shell command: stdout and stderr kept apart, plus
/// the exit code.
#[derive(Debug, Clone)]
pub struct ShellOutput {
    /// Raw bytes the command wrote to standard output.
    pub stdout: Vec<u8>,
    /// Raw bytes the command wrote to standard error.
    pub stderr: Vec<u8>,
    /// Exit code reported for the command; `0` means success.
    pub exit_code: u8,
}

impl ShellOutput {
    /// Decode stdout as UTF-8 (invalid sequences become U+FFFD) and strip
    /// leading and trailing whitespace.
    pub fn stdout_str(&self) -> String {
        let decoded = String::from_utf8_lossy(&self.stdout);
        decoded.trim().to_owned()
    }

    /// Decode stderr as UTF-8 (invalid sequences become U+FFFD) and strip
    /// leading and trailing whitespace.
    pub fn stderr_str(&self) -> String {
        let decoded = String::from_utf8_lossy(&self.stderr);
        decoded.trim().to_owned()
    }

    /// Returns `true` when the exit code is 0.
    #[must_use]
    pub fn success(&self) -> bool {
        matches!(self.exit_code, 0)
    }
}
58
/// Streaming reader for shell command output.
///
/// Demultiplexes stdout, stderr, and exit code from the shell v2 packet
/// framing. Implements [`AsyncRead`] which yields only stdout bytes.
/// After the stream is fully consumed, call [`exit_code`](Self::exit_code)
/// and [`stderr`](Self::stderr) to inspect the results.
pub struct ShellStream {
    /// Connection carrying the shell v2 packet stream.
    inner: TcpStream,
    /// Stdout payload bytes received but not yet handed to a reader.
    stdout_buf: Vec<u8>,
    /// Offset into `stdout_buf` of the first byte not yet handed out.
    stdout_pos: usize,
    /// All stderr payload bytes received so far.
    stderr: Vec<u8>,
    /// Exit code taken from the exit packet, once one has been received.
    exit_code: Option<u8>,
    /// Set once reading has finished: after the exit packet, at end of
    /// stream, or when an oversized packet was rejected.
    done: bool,
    /// Header (1-byte id + 4-byte little-endian length) of the packet
    /// currently being read; may be only partially filled between polls.
    header_buf: [u8; 5],
    /// Number of bytes of `header_buf` filled so far.
    header_pos: usize,
    /// Payload of the packet currently being read, sized from its header.
    payload_buf: Vec<u8>,
    /// Number of bytes of `payload_buf` filled so far.
    payload_pos: usize,
}
77
78impl ShellStream {
79    pub(crate) fn new(stream: TcpStream) -> Self {
80        Self {
81            inner: stream,
82            stdout_buf: Vec::new(),
83            stdout_pos: 0,
84            stderr: Vec::new(),
85            exit_code: None,
86            done: false,
87            header_buf: [0u8; 5],
88            header_pos: 0,
89            payload_buf: Vec::new(),
90            payload_pos: 0,
91        }
92    }
93
94    /// Consume the stream, collecting all output.
95    pub async fn collect_output(mut self) -> Result<ShellOutput> {
96        let mut stdout = Vec::new();
97        loop {
98            let more = self.read_next_packet().await?;
99            if self.stdout_pos < self.stdout_buf.len() {
100                stdout.extend_from_slice(&self.stdout_buf[self.stdout_pos..]);
101                self.stdout_buf.clear();
102                self.stdout_pos = 0;
103            }
104            if !more {
105                break;
106            }
107        }
108        Ok(ShellOutput {
109            stdout,
110            stderr: self.stderr,
111            exit_code: self.exit_code.unwrap_or(255),
112        })
113    }
114
115    /// Get accumulated stderr bytes (available as data arrives).
116    pub fn stderr(&self) -> &[u8] {
117        &self.stderr
118    }
119
120    /// Get the exit code, if received.
121    pub fn exit_code(&self) -> Option<u8> {
122        self.exit_code
123    }
124
125    /// Access the underlying [`TcpStream`].
126    pub fn as_tcp_stream(&self) -> &TcpStream {
127        &self.inner
128    }
129
130    /// Send data to the command's stdin.
131    pub async fn write_stdin(&mut self, data: &[u8]) -> Result<()> {
132        let mut pkt = Vec::with_capacity(5 + data.len());
133        pkt.push(ID_STDIN);
134        pkt.extend_from_slice(&(data.len() as u32).to_le_bytes());
135        pkt.extend_from_slice(data);
136        self.inner.write_all(&pkt).await?;
137        self.inner.flush().await?;
138        Ok(())
139    }
140
141    /// Close the command's stdin.
142    ///
143    /// Signals EOF to the remote process, causing reads from stdin to
144    /// return zero. Required for commands that read until EOF (e.g. `cat`,
145    /// `base64 -d`) to terminate.
146    pub async fn close_stdin(&mut self) -> Result<()> {
147        let pkt: [u8; 5] = [ID_CLOSE_STDIN, 0, 0, 0, 0];
148        self.inner.write_all(&pkt).await?;
149        self.inner.flush().await?;
150        Ok(())
151    }
152
153    /// Read the next packet, buffering stdout and stderr.
154    /// Returns true if more data may follow, false on exit/EOF.
155    async fn read_next_packet(&mut self) -> io::Result<bool> {
156        let mut header = [0u8; 5];
157        match self.inner.read_exact(&mut header).await {
158            Ok(_) => {}
159            Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
160                self.done = true;
161                return Ok(false);
162            }
163            Err(e) => return Err(e),
164        }
165
166        let id = header[0];
167        let len = u32::from_le_bytes(header[1..5].try_into().unwrap()) as usize;
168
169        if len > MAX_SHELL_PAYLOAD {
170            self.done = true;
171            return Err(io::Error::new(
172                io::ErrorKind::InvalidData,
173                format!("shell payload too large: {len} bytes"),
174            ));
175        }
176
177        let mut payload = vec![0u8; len];
178        if len > 0 {
179            self.inner.read_exact(&mut payload).await?;
180        }
181
182        match id {
183            ID_STDOUT => self.stdout_buf.extend_from_slice(&payload),
184            ID_STDERR => self.stderr.extend_from_slice(&payload),
185            ID_EXIT => {
186                self.exit_code = payload.first().copied();
187                self.done = true;
188                return Ok(false);
189            }
190            _ => {}
191        }
192
193        Ok(true)
194    }
195}
196
197impl AsyncRead for ShellStream {
198    fn poll_read(
199        self: Pin<&mut Self>,
200        cx: &mut Context<'_>,
201        buf: &mut ReadBuf<'_>,
202    ) -> Poll<io::Result<()>> {
203        let this = self.get_mut();
204
205        loop {
206            if this.stdout_pos < this.stdout_buf.len() {
207                let available = &this.stdout_buf[this.stdout_pos..];
208                let n = available.len().min(buf.remaining());
209                buf.put_slice(&available[..n]);
210                this.stdout_pos += n;
211                if this.stdout_pos == this.stdout_buf.len() {
212                    this.stdout_buf.clear();
213                    this.stdout_pos = 0;
214                }
215                return Poll::Ready(Ok(()));
216            }
217
218            if this.done {
219                return Poll::Ready(Ok(()));
220            }
221
222            while this.header_pos < 5 {
223                let mut tmp = ReadBuf::new(&mut this.header_buf[this.header_pos..]);
224                match Pin::new(&mut this.inner).poll_read(cx, &mut tmp) {
225                    Poll::Ready(Ok(())) => {
226                        let n = tmp.filled().len();
227                        if n == 0 {
228                            this.done = true;
229                            return Poll::Ready(Ok(()));
230                        }
231                        this.header_pos += n;
232                    }
233                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
234                    Poll::Pending => return Poll::Pending,
235                }
236            }
237
238            if this.payload_buf.is_empty() && this.payload_pos == 0 {
239                let len = u32::from_le_bytes(
240                    this.header_buf[1..5].try_into().unwrap(),
241                ) as usize;
242
243                if len > MAX_SHELL_PAYLOAD {
244                    this.done = true;
245                    return Poll::Ready(Err(io::Error::new(
246                        io::ErrorKind::InvalidData,
247                        format!("shell payload too large: {len} bytes"),
248                    )));
249                }
250
251                if len > 0 {
252                    this.payload_buf.resize(len, 0);
253                }
254            }
255
256            while this.payload_pos < this.payload_buf.len() {
257                let mut tmp =
258                    ReadBuf::new(&mut this.payload_buf[this.payload_pos..]);
259                match Pin::new(&mut this.inner).poll_read(cx, &mut tmp) {
260                    Poll::Ready(Ok(())) => {
261                        let n = tmp.filled().len();
262                        if n == 0 {
263                            this.done = true;
264                            return Poll::Ready(Ok(()));
265                        }
266                        this.payload_pos += n;
267                    }
268                    Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
269                    Poll::Pending => return Poll::Pending,
270                }
271            }
272
273            let id = this.header_buf[0];
274            let payload = std::mem::take(&mut this.payload_buf);
275            this.header_pos = 0;
276            this.payload_pos = 0;
277
278            match id {
279                ID_STDOUT => {
280                    this.stdout_buf = payload;
281                    this.stdout_pos = 0;
282                }
283                ID_STDERR => this.stderr.extend_from_slice(&payload),
284                ID_EXIT => {
285                    this.exit_code = payload.first().copied();
286                    this.done = true;
287                    return Poll::Ready(Ok(()));
288                }
289                _ => {}
290            }
291        }
292    }
293}
294
295/// Read all shell packets from a stream, returning the collected output.
296pub(crate) async fn read_shell(stream: TcpStream) -> Result<ShellOutput> {
297    ShellStream::new(stream).collect_output().await
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `ShellOutput` fixture from borrowed byte slices.
    fn output(stdout: &[u8], stderr: &[u8], exit_code: u8) -> ShellOutput {
        ShellOutput {
            stdout: stdout.to_vec(),
            stderr: stderr.to_vec(),
            exit_code,
        }
    }

    /// The string accessors trim whitespace and exit code 0 counts as success.
    #[test]
    fn shell_output_methods() {
        let result = output(b"  hello\n", b"warn\n", 0);
        assert_eq!(result.stdout_str(), "hello");
        assert_eq!(result.stderr_str(), "warn");
        assert!(result.success());
    }

    /// A non-zero exit code is not a success.
    #[test]
    fn shell_output_failure() {
        let result = output(b"", b"error", 1);
        assert!(!result.success());
    }
}