1use std::{
2 ffi::OsStr,
3 io,
4 pin::Pin,
5 task::{Context, Poll},
6 time::Duration,
7};
8
9use log::debug;
10use pretty_hex::PrettyHex;
11use tokio::{
12 io::{
13 AsyncBufRead, AsyncBufReadExt, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt,
14 BufReader, ReadBuf,
15 },
16 net::{TcpStream, ToSocketAddrs},
17 time,
18};
19
20use crate::utils::{Interactive, RecvUntil};
21
22use super::ProcessTube;
23
/// A bidirectional byte stream that logs all traffic at `debug` level and
/// applies a configurable timeout to its `recv*` helpers.
///
/// The struct itself places no bounds on `T`; the `impl` blocks that need
/// `AsyncBufRead + AsyncWrite + Unpin` state those bounds themselves.
#[derive(Debug)]
pub struct Tube<T> {
    /// The wrapped stream that all reads and writes are forwarded to.
    pub inner: T,

    /// How long each `recv`, `recv_line` and `recv_until` call waits before
    /// giving up. Sending is not subject to this timeout.
    pub timeout: Duration,

    /// Number of bytes at the front of `inner`'s read buffer that have already
    /// been written to the debug log, so they are not logged a second time.
    read_buf_logged: usize,
}
52
/// Line terminator (`\n`, 0x0A) appended by `send_line` and searched for by `recv_line`.
const NEW_LINE: u8 = b'\n';
54
55impl<T> Tube<BufReader<T>>
56where
57 T: AsyncRead + AsyncWrite + Unpin,
58{
59 pub fn new(inner: T) -> Self {
61 Self {
62 inner: BufReader::new(inner),
63 timeout: Duration::MAX,
64 read_buf_logged: 0,
65 }
66 }
67
68 pub fn with_timeout(inner: T, timeout: Duration) -> Self {
88 Self {
89 inner: BufReader::new(inner),
90 timeout,
91 read_buf_logged: 0,
92 }
93 }
94}
95
96impl Tube<BufReader<ProcessTube>> {
97 pub fn process<S: AsRef<OsStr>>(program: S) -> io::Result<Self> {
114 Ok(Self::new(ProcessTube::new(program)?))
115 }
116}
117
118impl Tube<BufReader<TcpStream>> {
119 pub async fn remote<A: ToSocketAddrs>(addr: A) -> io::Result<Self> {
143 Ok(Self::new(TcpStream::connect(addr).await?))
144 }
145}
146
147impl<T> Tube<T>
148where
149 T: AsyncBufRead + AsyncWrite + Unpin,
150{
151 pub fn from_buffered(inner: T) -> Self {
153 Self {
154 inner,
155 timeout: Duration::MAX,
156 read_buf_logged: 0,
157 }
158 }
159
160 pub async fn recv(&mut self, len: usize) -> io::Result<Vec<u8>> {
162 let mut buf = vec![0; len];
163 let len = time::timeout(self.timeout, self.read(&mut buf[..]))
164 .await
165 .unwrap_or(Ok(0))?;
166 buf.truncate(len);
167 Ok(buf)
168 }
169
170 pub async fn recv_line(&mut self) -> io::Result<Vec<u8>> {
172 let mut buf = Vec::new();
173 time::timeout(self.timeout, self.read_until(NEW_LINE, &mut buf))
174 .await
175 .unwrap_or(Ok(0))?;
176 Ok(buf)
177 }
178
179 pub async fn recv_until<A: AsRef<[u8]>>(&mut self, delims: A) -> io::Result<Vec<u8>> {
183 let mut buf = Vec::new();
184 time::timeout(
185 self.timeout,
186 RecvUntil::new(self, delims.as_ref(), &mut buf),
187 )
188 .await
189 .unwrap_or(Ok(()))?;
190 Ok(buf)
191 }
192
193 pub async fn send<A: AsRef<[u8]>>(&mut self, data: A) -> io::Result<()> {
195 self.write_all(data.as_ref()).await?;
196 self.flush().await
197 }
198
199 pub async fn send_line<A: AsRef<[u8]>>(&mut self, data: A) -> io::Result<()> {
201 self.write_all(data.as_ref()).await?;
202 self.write_all(&[NEW_LINE]).await?;
203 self.flush().await
204 }
205
206 pub async fn send_line_after<A: AsRef<[u8]>, B: AsRef<[u8]>>(
228 &mut self,
229 pattern: A,
230 data: B,
231 ) -> io::Result<Vec<u8>> {
232 let result = self.recv_until(pattern).await?;
233 self.send_line(data).await?;
234 Ok(result)
235 }
236
237 pub async fn interactive(&mut self) -> io::Result<()> {
239 Interactive::new(self).await
240 }
241
242 pub fn into_inner(self) -> T {
244 self.inner
245 }
246}
247
248impl<T> AsyncRead for Tube<T>
249where
250 T: AsyncBufRead + AsyncWrite + Unpin,
251{
252 fn poll_read(
253 self: Pin<&mut Self>,
254 cx: &mut Context,
255 buf: &mut ReadBuf,
256 ) -> Poll<io::Result<()>> {
257 let olen = buf.filled().len();
258
259 if Pin::new(&mut self.get_mut().inner)
260 .poll_read(cx, buf)?
261 .is_pending()
262 {
263 return Poll::Pending;
264 }
265
266 debug!(target: "Tube::recv", "Received {:?}", buf.filled()[olen..].hex_dump());
267
268 Poll::Ready(Ok(()))
269 }
270}
271
272impl<T> AsyncWrite for Tube<T>
273where
274 T: AsyncBufRead + AsyncWrite + Unpin,
275{
276 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
277 let numb = match Pin::new(&mut self.get_mut().inner).poll_write(cx, buf)? {
278 Poll::Ready(numb) => numb,
279 Poll::Pending => return Poll::Pending,
280 };
281
282 debug!(target: "Tube::send", "Sent {:?}", buf[..numb].hex_dump());
283
284 Poll::Ready(Ok(numb))
285 }
286
287 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
288 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
289 }
290
291 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
292 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
293 }
294
295 fn poll_write_vectored(
296 self: Pin<&mut Self>,
297 cx: &mut Context,
298 bufs: &[io::IoSlice],
299 ) -> Poll<Result<usize, io::Error>> {
300 let numb = match Pin::new(&mut self.get_mut().inner).poll_write_vectored(cx, bufs)? {
301 Poll::Ready(numb) => numb,
302 Poll::Pending => return Poll::Pending,
303 };
304
305 let mut to_log = numb;
306 for buf in bufs {
307 if to_log == 0 {
308 break;
309 }
310 debug!(target: "Tube::send", "Send {:?}", buf[..to_log].hex_dump());
311 to_log = to_log.saturating_sub(buf.len());
312 }
313
314 Poll::Ready(Ok(numb))
315 }
316
317 fn is_write_vectored(&self) -> bool {
318 self.inner.is_write_vectored()
319 }
320}
321
322impl<T> AsyncBufRead for Tube<T>
323where
324 T: AsyncBufRead + AsyncWrite + Unpin,
325{
326 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<&[u8]>> {
327 let Self {
328 inner,
329 timeout: _,
330 read_buf_logged,
331 } = self.get_mut();
332
333 let buf = match Pin::new(inner).poll_fill_buf(cx)? {
334 Poll::Ready(buf) => buf,
335 Poll::Pending => return Poll::Pending,
336 };
337
338 if buf.len() > *read_buf_logged {
339 debug!(target: "Tube::recv", "Recevied {:?}", buf[*read_buf_logged..].hex_dump());
340 *read_buf_logged = buf.len();
341 }
342
343 Poll::Ready(Ok(buf))
344 }
345
346 fn consume(mut self: Pin<&mut Self>, amt: usize) {
347 self.read_buf_logged -= amt;
348 Pin::new(&mut self.get_mut().inner).consume(amt);
349 }
350}
351
352impl<T> From<Tube<BufReader<T>>> for BufReader<T>
353where
354 T: AsyncRead + AsyncWrite + Unpin,
355{
356 fn from(tube: Tube<BufReader<T>>) -> Self {
357 tube.into_inner()
358 }
359}
360
361impl<T> From<T> for Tube<T>
362where
363 T: AsyncBufRead + AsyncWrite + Unpin,
364{
365 fn from(tube_like: T) -> Self {
366 Self {
367 inner: tube_like,
368 timeout: Duration::MAX,
369 read_buf_logged: 0,
370 }
371 }
372}