io_tubes/tubes/tube.rs

1use std::{
2    ffi::OsStr,
3    io,
4    pin::Pin,
5    task::{Context, Poll},
6    time::Duration,
7};
8
9use log::debug;
10use pretty_hex::PrettyHex;
11use tokio::{
12    io::{
13        AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,
14        BufReader, ReadBuf,
15    },
16    net::{TcpStream, ToSocketAddrs},
17    time,
18};
19
20use crate::utils::{Interactive, RecvUntil};
21
22use super::ProcessTube;
23
24/// A wrapper to provide extra methods. Note that the API from this crate is different from pwntools.
25#[derive(Debug)]
26pub struct Tube<T>
27where
28    T: AsyncBufRead + AsyncWrite + Unpin,
29{
30    /// The inner struct, usually a BufReader containing the original struct.
31    pub inner: T,
32
33    /// This field is only used by methods directly provided by this struct and not methods from
34    /// traits like [`AsyncRead`].
35    ///
36    /// This is due to the fact that during the polling, there is no way to keep track of the
37    /// futures involved. If 2 calls to the poll functions occurs, there is not enough
38    /// information in the argument to deduce whether it come from the same future or the previous
39    /// future is dropped and another future has started polling. As a result, the API will be
40    /// producing inconsistent timeout if it is implemented.
41    ///
42    /// Luckily, [`tokio::time::timeout`] provides an easy way to add timeout to a future (which is
43    /// how timeout is implemented in this library) so you can still have timeout behaviour on
44    /// functions that doesn't support them.
45    ///
46    /// Hence, timeout can only be reliably implemented for async fn (which returns a future under
47    /// the hood) or fn that return a future.
48    pub timeout: Duration,
49
50    read_buf_logged: usize,
51}
52
/// Line terminator used by the line-oriented methods (`recv_line`, `send_line`).
const NEW_LINE: u8 = b'\n';
54
55impl<T> Tube<BufReader<T>>
56where
57    T: AsyncRead + AsyncWrite + Unpin,
58{
59    /// Construct a new `Tube<T>`.
60    pub fn new(inner: T) -> Self {
61        Self {
62            inner: BufReader::new(inner),
63            timeout: Duration::MAX,
64            read_buf_logged: 0,
65        }
66    }
67
68    /// Construct a new `Tube<T>` with the supplied timeout argument. Note that timeout is only
69    /// implemented for methods directly provided by this struct and not methods from traits.
70    ///
71    /// ```rust
72    /// use io_tubes::tubes::{ProcessTube, Tube};
73    /// use std::{io, time::Duration};
74    ///
75    /// #[tokio::main]
76    /// async fn create_with_timeout() -> io::Result<()> {
77    ///     let mut p = Tube::process("/usr/bin/cat")?;
78    ///     p.timeout = Duration::from_millis(50);
79    ///     // Equivalent to
80    ///     let mut p =
81    ///         Tube::with_timeout(ProcessTube::new("/usr/bin/cat")?, Duration::from_millis(50));
82    ///     Ok(())
83    /// }
84    ///
85    /// create_with_timeout();
86    /// ```
87    pub fn with_timeout(inner: T, timeout: Duration) -> Self {
88        Self {
89            inner: BufReader::new(inner),
90            timeout,
91            read_buf_logged: 0,
92        }
93    }
94}
95
96impl Tube<BufReader<ProcessTube>> {
97    /// Create a process with supplied path to program.
98    /// ```rust
99    /// use io_tubes::tubes::Tube;
100    /// use std::io;
101    ///
102    /// #[tokio::main]
103    /// async fn create_process() -> io::Result<()> {
104    ///     let mut p = Tube::process("/usr/bin/cat")?;
105    ///     p.send("abcdHi!").await?;
106    ///     let result = p.recv_until("Hi").await?;
107    ///     assert_eq!(result, b"abcdHi");
108    ///     Ok(())
109    /// }
110    ///
111    /// create_process();
112    /// ```
113    pub fn process<S: AsRef<OsStr>>(program: S) -> io::Result<Self> {
114        Ok(Self::new(ProcessTube::new(program)?))
115    }
116}
117
118impl Tube<BufReader<TcpStream>> {
119    /// Create a tube by connecting to the remote address.
120    /// ```rust
121    /// use io_tubes::tubes::{Listener, Tube};
122    /// use std::{
123    ///     io,
124    ///     net::{IpAddr, Ipv4Addr, SocketAddr},
125    /// };
126    ///
127    /// #[tokio::main]
128    /// async fn create_remote() -> io::Result<()> {
129    ///     let l = Listener::listen().await?;
130    ///     let mut p =
131    ///         Tube::remote(SocketAddr::new(IpAddr::V4(Ipv4Addr::LOCALHOST), l.port()?)).await?;
132    ///     let mut server = l.accept().await?;
133    ///     p.send("Client Hello").await?;
134    ///     server.send("Server Hello").await?;
135    ///     assert_eq!(p.recv_until("Hello").await?, b"Server Hello");
136    ///     assert_eq!(server.recv_until("Hello").await?, b"Client Hello");
137    ///     Ok(())
138    /// }
139    ///
140    /// create_remote();
141    /// ```
142    pub async fn remote<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
143        Ok(Self::new(TcpStream::connect(addr).await?))
144    }
145}
146
147impl<T> Tube<T>
148where
149    T: AsyncBufRead + AsyncWrite + Unpin,
150{
151    /// Construct a tube from any custom buffered type.
152    pub fn from_buffered(inner: T) -> Self {
153        Self {
154            inner,
155            timeout: Duration::MAX,
156            read_buf_logged: 0,
157        }
158    }
159
160    /// Receive up to `len` bytes.
161    pub async fn recv(&mut self, len: usize) -> io::Result<Vec<u8>> {
162        let mut buf = vec![0; len];
163        let len = time::timeout(self.timeout, self.read(&mut buf[..]))
164            .await
165            .unwrap_or(Ok(0))?;
166        buf.truncate(len);
167        Ok(buf)
168    }
169
170    /// Receive until new line (0xA byte) is reached or EOF is reached.
171    pub async fn recv_line(&mut self) -> io::Result<Vec<u8>> {
172        let mut buf = Vec::new();
173        time::timeout(self.timeout, self.read_until(NEW_LINE, &mut buf))
174            .await
175            .unwrap_or(Ok(0))?;
176        Ok(buf)
177    }
178
179    /// Receive until the delims are found or EOF is reached.
180    ///
181    /// A lookup table will be built to enable efficient matching of long patterns.
182    pub async fn recv_until<A: AsRef<[u8]>>(&mut self, delims: A) -> io::Result<Vec<u8>> {
183        let mut buf = Vec::new();
184        time::timeout(
185            self.timeout,
186            RecvUntil::new(self, delims.as_ref(), &mut buf),
187        )
188        .await
189        .unwrap_or(Ok(()))?;
190        Ok(buf)
191    }
192
193    /// Send data and flush.
194    pub async fn send<A: AsRef<[u8]>>(&mut self, data: A) -> io::Result<()> {
195        self.write_all(data.as_ref()).await?;
196        self.flush().await
197    }
198
199    /// Same as send, but add new line (0xA byte).
200    pub async fn send_line<A: AsRef<[u8]>>(&mut self, data: A) -> io::Result<()> {
201        self.write_all(data.as_ref()).await?;
202        self.write_all(&[NEW_LINE]).await?;
203        self.flush().await
204    }
205
206    /// Send line after receiving the pattern from read.
207    /// ```rust
208    /// use io_tubes::tubes::Tube;
209    /// use std::io;
210    ///
211    /// #[tokio::main]
212    /// async fn send_line_after() -> io::Result<()> {
213    ///     let mut p = Tube::process("/usr/bin/cat")?;
214    ///
215    ///     p.send("Hello, what's your name? ").await?;
216    ///     assert_eq!(
217    ///         p.send_line_after("name", "test").await?,
218    ///         b"Hello, what's your name"
219    ///     );
220    ///     assert_eq!(p.recv_line().await?, b"? test\n");
221    ///
222    ///     Ok(())
223    /// }
224    ///
225    /// send_line_after();
226    /// ```
227    pub async fn send_line_after<A: AsRef<[u8]>, B: AsRef<[u8]>>(
228        &mut self,
229        pattern: A,
230        data: B,
231    ) -> io::Result<Vec<u8>> {
232        let result = self.recv_until(pattern).await?;
233        self.send_line(data).await?;
234        Ok(result)
235    }
236
237    /// Connect the tube to stdin and stdout so you can interact with it directly.
238    pub async fn interactive(&mut self) -> io::Result<()> {
239        Interactive::new(self).await
240    }
241
242    /// Consume the tube to get back the underlying BufReader
243    pub fn into_inner(self) -> T {
244        self.inner
245    }
246}
247
248impl<T> AsyncRead for Tube<T>
249where
250    T: AsyncBufRead + AsyncWrite + Unpin,
251{
252    fn poll_read(
253        self: Pin<&mut Self>,
254        cx: &mut Context,
255        buf: &mut ReadBuf,
256    ) -> Poll<io::Result<()>> {
257        let olen = buf.filled().len();
258
259        if Pin::new(&mut self.get_mut().inner)
260            .poll_read(cx, buf)?
261            .is_pending()
262        {
263            return Poll::Pending;
264        }
265
266        debug!(target: "Tube::recv", "Received {:?}", buf.filled()[olen..].hex_dump());
267
268        Poll::Ready(Ok(()))
269    }
270}
271
272impl<T> AsyncWrite for Tube<T>
273where
274    T: AsyncBufRead + AsyncWrite + Unpin,
275{
276    fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
277        let numb = match Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)? {
278            Poll::Ready(numb) => numb,
279            Poll::Pending => return Poll::Pending,
280        };
281
282        debug!(target: "Tube::send", "Sent {:?}", buf[..numb].hex_dump());
283
284        Poll::Ready(Ok(numb))
285    }
286
287    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
288        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
289    }
290
291    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
292        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
293    }
294
295    fn poll_write_vectored(
296        self: Pin<&mut Self>,
297        cx: &mut Context,
298        bufs: &[io::IoSlice],
299    ) -> Poll<Result<usize, io::Error>> {
300        let numb = match Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)? {
301            Poll::Ready(numb) => numb,
302            Poll::Pending => return Poll::Pending,
303        };
304
305        let mut to_log = numb;
306        for buf in bufs {
307            if to_log == 0 {
308                break;
309            }
310            debug!(target: "Tube::send", "Send {:?}", buf[..to_log].hex_dump());
311            to_log = to_log.saturating_sub(buf.len());
312        }
313
314        Poll::Ready(Ok(numb))
315    }
316
317    fn is_write_vectored(&self) -> bool {
318        self.inner.is_write_vectored()
319    }
320}
321
322impl<T> AsyncBufRead for Tube<T>
323where
324    T: AsyncBufRead + AsyncWrite + Unpin,
325{
326    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
327        let Self {
328            inner,
329            timeout: _,
330            read_buf_logged,
331        } = self.get_mut();
332
333        let buf = match Pin::new(inner).poll_fill_buf(cx)? {
334            Poll::Ready(buf) => buf,
335            Poll::Pending => return Poll::Pending,
336        };
337
338        if buf.len() > *read_buf_logged {
339            debug!(target: "Tube::recv", "Recevied {:?}", buf[*read_buf_logged..].hex_dump());
340            *read_buf_logged = buf.len();
341        }
342
343        Poll::Ready(Ok(buf))
344    }
345
346    fn consume(mut self: Pin<&mut Self>, amt: usize) {
347        self.read_buf_logged -= amt;
348        Pin::new(&mut self.get_mut().inner).consume(amt);
349    }
350}
351
352impl<T> From<Tube<BufReader<T>>> for BufReader<T>
353where
354    T: AsyncRead + AsyncWrite + Unpin,
355{
356    fn from(tube: Tube<BufReader<T>>) -> Self {
357        tube.into_inner()
358    }
359}
360
361impl<T> From<T> for Tube<T>
362where
363    T: AsyncBufRead + AsyncWrite + Unpin,
364{
365    fn from(tube_like: T) -> Self {
366        Self {
367            inner: tube_like,
368            timeout: Duration::MAX,
369            read_buf_logged: 0,
370        }
371    }
372}