1use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use tokio::io::{self, AsyncRead, AsyncReadExt, AsyncWriteExt, ReadBuf};
7use tokio::net::TcpStream;
8
9use crate::error::Result;
10
// Packet identifiers. Every packet on the wire starts with one of these
// bytes, followed by a 4-byte little-endian payload length and the payload.

/// Packet id written by `ShellStream::write_stdin` for data sent to the peer.
const ID_STDIN: u8 = 0;
/// Packet id whose payload is appended to the stdout buffer when received.
const ID_STDOUT: u8 = 1;
/// Packet id whose payload is appended to the collected stderr when received.
const ID_STDERR: u8 = 2;
/// Packet id that ends the stream; the first payload byte is the exit code.
const ID_EXIT: u8 = 3;
/// Packet id written by `ShellStream::close_stdin` (always with an empty payload).
const ID_CLOSE_STDIN: u8 = 4;

/// Largest payload length (16 MiB) accepted from an incoming packet header;
/// anything larger is rejected with an `InvalidData` error before allocating.
const MAX_SHELL_PAYLOAD: usize = 16 * 1024 * 1024;
29
/// Everything gathered from a finished shell session: the raw bytes of both
/// output channels plus the exit code.
#[derive(Debug, Clone)]
pub struct ShellOutput {
    /// Raw bytes received on the stdout channel.
    pub stdout: Vec<u8>,
    /// Raw bytes received on the stderr channel.
    pub stderr: Vec<u8>,
    /// Exit code of the session.
    pub exit_code: u8,
}

impl ShellOutput {
    /// Decodes `raw` as UTF-8 (invalid sequences become U+FFFD) and strips
    /// leading and trailing whitespace.
    fn decode_trimmed(raw: &[u8]) -> String {
        let text = String::from_utf8_lossy(raw);
        text.trim().to_owned()
    }

    /// Returns stdout as lossily-decoded, whitespace-trimmed text.
    pub fn stdout_str(&self) -> String {
        Self::decode_trimmed(&self.stdout)
    }

    /// Returns stderr as lossily-decoded, whitespace-trimmed text.
    pub fn stderr_str(&self) -> String {
        Self::decode_trimmed(&self.stderr)
    }

    /// Returns `true` when the exit code is zero.
    #[must_use]
    pub fn success(&self) -> bool {
        matches!(self.exit_code, 0)
    }
}
58
/// A shell session over a TCP connection that speaks a framed packet format:
/// one id byte, a 4-byte little-endian payload length, then the payload.
///
/// Stdout payloads are exposed to readers, stderr payloads are accumulated
/// on the side, and the exit code is recorded when the exit packet arrives.
pub struct ShellStream {
    /// Underlying TCP connection carrying the packets.
    inner: TcpStream,
    /// Stdout bytes that have been received but not yet handed to the reader.
    stdout_buf: Vec<u8>,
    /// Offset of the first unread byte in `stdout_buf`.
    stdout_pos: usize,
    /// All stderr bytes received so far.
    stderr: Vec<u8>,
    /// First payload byte of the exit packet, once one has been received.
    exit_code: Option<u8>,
    /// Set once no further packets will be read: an exit packet arrived, the
    /// connection reached end-of-file, or an oversized payload was announced.
    done: bool,
    /// Storage for the 5-byte packet header while it is being assembled.
    header_buf: [u8; 5],
    /// Number of header bytes already stored in `header_buf`.
    header_pos: usize,
    /// Storage for the payload of the packet currently being received.
    payload_buf: Vec<u8>,
    /// Number of payload bytes already stored in `payload_buf`.
    payload_pos: usize,
}
77
78impl ShellStream {
79 pub(crate) fn new(stream: TcpStream) -> Self {
80 Self {
81 inner: stream,
82 stdout_buf: Vec::new(),
83 stdout_pos: 0,
84 stderr: Vec::new(),
85 exit_code: None,
86 done: false,
87 header_buf: [0u8; 5],
88 header_pos: 0,
89 payload_buf: Vec::new(),
90 payload_pos: 0,
91 }
92 }
93
94 pub async fn collect_output(mut self) -> Result<ShellOutput> {
96 let mut stdout = Vec::new();
97 loop {
98 let more = self.read_next_packet().await?;
99 if self.stdout_pos < self.stdout_buf.len() {
100 stdout.extend_from_slice(&self.stdout_buf[self.stdout_pos..]);
101 self.stdout_buf.clear();
102 self.stdout_pos = 0;
103 }
104 if !more {
105 break;
106 }
107 }
108 Ok(ShellOutput {
109 stdout,
110 stderr: self.stderr,
111 exit_code: self.exit_code.unwrap_or(255),
112 })
113 }
114
115 pub fn stderr(&self) -> &[u8] {
117 &self.stderr
118 }
119
120 pub fn exit_code(&self) -> Option<u8> {
122 self.exit_code
123 }
124
125 pub fn as_tcp_stream(&self) -> &TcpStream {
127 &self.inner
128 }
129
130 pub async fn write_stdin(&mut self, data: &[u8]) -> Result<()> {
132 let mut pkt = Vec::with_capacity(5 + data.len());
133 pkt.push(ID_STDIN);
134 pkt.extend_from_slice(&(data.len() as u32).to_le_bytes());
135 pkt.extend_from_slice(data);
136 self.inner.write_all(&pkt).await?;
137 self.inner.flush().await?;
138 Ok(())
139 }
140
141 pub async fn close_stdin(&mut self) -> Result<()> {
147 let pkt: [u8; 5] = [ID_CLOSE_STDIN, 0, 0, 0, 0];
148 self.inner.write_all(&pkt).await?;
149 self.inner.flush().await?;
150 Ok(())
151 }
152
153 async fn read_next_packet(&mut self) -> io::Result<bool> {
156 let mut header = [0u8; 5];
157 match self.inner.read_exact(&mut header).await {
158 Ok(_) => {}
159 Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
160 self.done = true;
161 return Ok(false);
162 }
163 Err(e) => return Err(e),
164 }
165
166 let id = header[0];
167 let len = u32::from_le_bytes(header[1..5].try_into().unwrap()) as usize;
168
169 if len > MAX_SHELL_PAYLOAD {
170 self.done = true;
171 return Err(io::Error::new(
172 io::ErrorKind::InvalidData,
173 format!("shell payload too large: {len} bytes"),
174 ));
175 }
176
177 let mut payload = vec![0u8; len];
178 if len > 0 {
179 self.inner.read_exact(&mut payload).await?;
180 }
181
182 match id {
183 ID_STDOUT => self.stdout_buf.extend_from_slice(&payload),
184 ID_STDERR => self.stderr.extend_from_slice(&payload),
185 ID_EXIT => {
186 self.exit_code = payload.first().copied();
187 self.done = true;
188 return Ok(false);
189 }
190 _ => {}
191 }
192
193 Ok(true)
194 }
195}
196
197impl AsyncRead for ShellStream {
198 fn poll_read(
199 self: Pin<&mut Self>,
200 cx: &mut Context<'_>,
201 buf: &mut ReadBuf<'_>,
202 ) -> Poll<io::Result<()>> {
203 let this = self.get_mut();
204
205 loop {
206 if this.stdout_pos < this.stdout_buf.len() {
207 let available = &this.stdout_buf[this.stdout_pos..];
208 let n = available.len().min(buf.remaining());
209 buf.put_slice(&available[..n]);
210 this.stdout_pos += n;
211 if this.stdout_pos == this.stdout_buf.len() {
212 this.stdout_buf.clear();
213 this.stdout_pos = 0;
214 }
215 return Poll::Ready(Ok(()));
216 }
217
218 if this.done {
219 return Poll::Ready(Ok(()));
220 }
221
222 while this.header_pos < 5 {
223 let mut tmp = ReadBuf::new(&mut this.header_buf[this.header_pos..]);
224 match Pin::new(&mut this.inner).poll_read(cx, &mut tmp) {
225 Poll::Ready(Ok(())) => {
226 let n = tmp.filled().len();
227 if n == 0 {
228 this.done = true;
229 return Poll::Ready(Ok(()));
230 }
231 this.header_pos += n;
232 }
233 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
234 Poll::Pending => return Poll::Pending,
235 }
236 }
237
238 if this.payload_buf.is_empty() && this.payload_pos == 0 {
239 let len = u32::from_le_bytes(
240 this.header_buf[1..5].try_into().unwrap(),
241 ) as usize;
242
243 if len > MAX_SHELL_PAYLOAD {
244 this.done = true;
245 return Poll::Ready(Err(io::Error::new(
246 io::ErrorKind::InvalidData,
247 format!("shell payload too large: {len} bytes"),
248 )));
249 }
250
251 if len > 0 {
252 this.payload_buf.resize(len, 0);
253 }
254 }
255
256 while this.payload_pos < this.payload_buf.len() {
257 let mut tmp =
258 ReadBuf::new(&mut this.payload_buf[this.payload_pos..]);
259 match Pin::new(&mut this.inner).poll_read(cx, &mut tmp) {
260 Poll::Ready(Ok(())) => {
261 let n = tmp.filled().len();
262 if n == 0 {
263 this.done = true;
264 return Poll::Ready(Ok(()));
265 }
266 this.payload_pos += n;
267 }
268 Poll::Ready(Err(e)) => return Poll::Ready(Err(e)),
269 Poll::Pending => return Poll::Pending,
270 }
271 }
272
273 let id = this.header_buf[0];
274 let payload = std::mem::take(&mut this.payload_buf);
275 this.header_pos = 0;
276 this.payload_pos = 0;
277
278 match id {
279 ID_STDOUT => {
280 this.stdout_buf = payload;
281 this.stdout_pos = 0;
282 }
283 ID_STDERR => this.stderr.extend_from_slice(&payload),
284 ID_EXIT => {
285 this.exit_code = payload.first().copied();
286 this.done = true;
287 return Poll::Ready(Ok(()));
288 }
289 _ => {}
290 }
291 }
292 }
293}
294
295pub(crate) async fn read_shell(stream: TcpStream) -> Result<ShellOutput> {
297 ShellStream::new(stream).collect_output().await
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ShellOutput` from borrowed byte slices.
    fn output(stdout: &[u8], stderr: &[u8], exit_code: u8) -> ShellOutput {
        ShellOutput {
            stdout: stdout.to_vec(),
            stderr: stderr.to_vec(),
            exit_code,
        }
    }

    /// Text accessors trim surrounding whitespace and exit code 0 is success.
    #[test]
    fn shell_output_methods() {
        let result = output(b" hello\n", b"warn\n", 0);
        assert_eq!(result.stdout_str(), "hello");
        assert_eq!(result.stderr_str(), "warn");
        assert!(result.success());
    }

    /// A non-zero exit code is not a success.
    #[test]
    fn shell_output_failure() {
        let result = output(b"", b"error", 1);
        assert!(!result.success());
    }
}
325}