rama_net/stream/
read.rs

1use bytes::Bytes;
2use pin_project_lite::pin_project;
3use std::{
4    fmt,
5    io::Cursor,
6    pin::Pin,
7    task::{Context, Poll, ready},
8};
9use tokio::io::{self, AsyncBufRead, AsyncRead, ReadBuf};
10
11pin_project! {
12    /// Reader for reading from a heap-allocated bytes buffer.
13    #[derive(Debug, Clone)]
14    pub struct HeapReader {
15        #[pin]
16        inner: Cursor<Vec<u8>>,
17    }
18}
19
20impl HeapReader {
21    /// Creates a new `HeapReader` with the specified bytes data.
22    pub const fn new(data: Vec<u8>) -> Self {
23        Self {
24            inner: Cursor::new(data),
25        }
26    }
27}
28
29impl From<Vec<u8>> for HeapReader {
30    fn from(data: Vec<u8>) -> Self {
31        Self::new(data)
32    }
33}
34
35impl From<&[u8]> for HeapReader {
36    fn from(data: &[u8]) -> Self {
37        Self::new(data.to_vec())
38    }
39}
40
41impl From<&str> for HeapReader {
42    fn from(data: &str) -> Self {
43        Self::new(data.as_bytes().to_vec())
44    }
45}
46
47impl Default for HeapReader {
48    fn default() -> Self {
49        Self::new(Vec::new())
50    }
51}
52
53impl From<Bytes> for HeapReader {
54    fn from(data: Bytes) -> Self {
55        Self::new(data.to_vec())
56    }
57}
58
59impl AsyncRead for HeapReader {
60    fn poll_read(
61        self: Pin<&mut Self>,
62        cx: &mut Context<'_>,
63        buf: &mut ReadBuf<'_>,
64    ) -> Poll<io::Result<()>> {
65        self.project().inner.poll_read(cx, buf)
66    }
67}
68
69pin_project! {
70    /// Reader that can be used to chain two readers together.
71    #[must_use = "streams do nothing unless polled"]
72    #[cfg_attr(docsrs, doc(cfg(feature = "io-util")))]
73    pub struct ChainReader<T, U> {
74        #[pin]
75        first: T,
76        #[pin]
77        second: U,
78        done_first: bool,
79    }
80}
81
82impl<T, U> ChainReader<T, U>
83where
84    T: AsyncRead,
85    U: AsyncRead,
86{
87    /// Creates a new `ChainReader` with the specified readers.
88    pub const fn new(first: T, second: U) -> Self {
89        Self {
90            first,
91            second,
92            done_first: false,
93        }
94    }
95
96    /// Gets references to the underlying readers in this `ChainReader`.
97    pub fn get_ref(&self) -> (&T, &U) {
98        (&self.first, &self.second)
99    }
100
101    /// Gets mutable references to the underlying readers in this `ChainReader`.
102    ///
103    /// Care should be taken to avoid modifying the internal I/O state of the
104    /// underlying readers as doing so may corrupt the internal state of this
105    /// `ChainReader`.
106    pub fn get_mut(&mut self) -> (&mut T, &mut U) {
107        (&mut self.first, &mut self.second)
108    }
109
110    /// Gets pinned mutable references to the underlying readers in this `ChainReader`.
111    ///
112    /// Care should be taken to avoid modifying the internal I/O state of the
113    /// underlying readers as doing so may corrupt the internal state of this
114    /// `ChainReader`.
115    pub fn get_pin_mut(self: Pin<&mut Self>) -> (Pin<&mut T>, Pin<&mut U>) {
116        let me = self.project();
117        (me.first, me.second)
118    }
119
120    /// Consumes the `ChainReader`, returning the wrapped readers.
121    pub fn into_inner(self) -> (T, U) {
122        (self.first, self.second)
123    }
124}
125
126impl<T, U> fmt::Debug for ChainReader<T, U>
127where
128    T: fmt::Debug,
129    U: fmt::Debug,
130{
131    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
132        f.debug_struct("ChainReader")
133            .field("t", &self.first)
134            .field("u", &self.second)
135            .finish()
136    }
137}
138
139impl<T, U> AsyncRead for ChainReader<T, U>
140where
141    T: AsyncRead,
142    U: AsyncRead,
143{
144    fn poll_read(
145        self: Pin<&mut Self>,
146        cx: &mut Context<'_>,
147        buf: &mut ReadBuf<'_>,
148    ) -> Poll<io::Result<()>> {
149        let me = self.project();
150
151        if !*me.done_first {
152            let rem = buf.remaining();
153            ready!(me.first.poll_read(cx, buf))?;
154            if buf.remaining() == rem {
155                *me.done_first = true;
156            } else {
157                return Poll::Ready(Ok(()));
158            }
159        }
160        me.second.poll_read(cx, buf)
161    }
162}
163
164impl<T, U> AsyncBufRead for ChainReader<T, U>
165where
166    T: AsyncBufRead,
167    U: AsyncBufRead,
168{
169    fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
170        let me = self.project();
171
172        if !*me.done_first {
173            match ready!(me.first.poll_fill_buf(cx)?) {
174                [] => {
175                    *me.done_first = true;
176                }
177                buf => return Poll::Ready(Ok(buf)),
178            }
179        }
180        me.second.poll_fill_buf(cx)
181    }
182
183    fn consume(self: Pin<&mut Self>, amt: usize) {
184        let me = self.project();
185        if !*me.done_first {
186            me.first.consume(amt)
187        } else {
188            me.second.consume(amt)
189        }
190    }
191}