async_ssh2_russh/
read_stream.rs

1use std::io::{BufRead, Read};
2use std::pin::Pin;
3use std::task::{ready, Context, Poll};
4
5use russh::CryptoVec;
6use tokio::io::{AsyncBufRead, AsyncRead, ReadBuf};
7use tokio::sync::mpsc;
8
9/// Read byte data from an SSH channel stream.
10///
11/// Implements [`AsyncRead`], [`AsyncBufRead`], [`Read`], and [`BufRead`].
12pub struct ReadStream {
13    recv: mpsc::UnboundedReceiver<CryptoVec>,
14    buffer: Option<(CryptoVec, usize)>,
15}
16impl ReadStream {
17    pub(crate) fn from_recv(recv: mpsc::UnboundedReceiver<CryptoVec>) -> Self {
18        Self { recv, buffer: None }
19    }
20
21    fn consume_internal(&mut self, amt: usize) {
22        if let Some((buf, offset)) = &mut self.buffer {
23            *offset += amt;
24            debug_assert!(*offset <= buf.len());
25            if *offset == buf.len() {
26                self.buffer = None;
27            }
28        } else {
29            debug_assert!(amt == 0);
30        }
31    }
32}
33impl AsyncRead for ReadStream {
34    fn poll_read(mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>) -> Poll<std::io::Result<()>> {
35        // Defer to `AsyncBufRead`.
36        let read_buf = ready!(self.as_mut().poll_fill_buf(cx))?;
37        let amt = std::cmp::min(read_buf.len(), buf.capacity());
38        buf.put_slice(&read_buf[..amt]);
39        self.consume(amt);
40        Poll::Ready(Ok(()))
41    }
42}
43impl AsyncBufRead for ReadStream {
44    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<&[u8]>> {
45        let this = self.get_mut();
46
47        if this.buffer.is_none() {
48            let opt_data = ready!(this.recv.poll_recv(cx));
49            this.buffer = opt_data.map(|data| (data, 0));
50        }
51
52        Poll::Ready(Ok(this
53            .buffer
54            .as_ref()
55            .map(|(buf, offset)| &buf[*offset..])
56            .unwrap_or(&[])))
57    }
58
59    fn consume(self: Pin<&mut Self>, amt: usize) {
60        let this = self.get_mut();
61        this.consume_internal(amt)
62    }
63}
64impl Read for ReadStream {
65    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
66        // Defer to `BufRead`.
67        let read_buf = self.fill_buf()?;
68        let amt = std::cmp::min(read_buf.len(), buf.len());
69        buf.copy_from_slice(&read_buf[..amt]);
70        self.consume(amt);
71        Ok(amt)
72    }
73}
74impl BufRead for ReadStream {
75    fn fill_buf(&mut self) -> std::io::Result<&[u8]> {
76        if self.buffer.is_none() {
77            let opt_data = self.recv.blocking_recv();
78            self.buffer = opt_data.map(|data| (data, 0));
79        }
80
81        Ok(self.buffer.as_ref().map(|(buf, offset)| &buf[*offset..]).unwrap_or(&[]))
82    }
83
84    fn consume(&mut self, amt: usize) {
85        self.consume_internal(amt)
86    }
87}