// ironrdp_async/framed.rs

1use std::io;
2
3use bytes::{Bytes, BytesMut};
4use ironrdp_connector::{ConnectorResult, Sequence, Written};
5use ironrdp_core::WriteBuf;
6use ironrdp_pdu::PduHint;
7use tracing::{debug, trace};
8
9// TODO: investigate if we could use static async fn / return position impl trait in traits when stabilized:
10// https://github.com/rust-lang/rust/issues/91611
11
12pub trait FramedRead {
13    type ReadFut<'read>: core::future::Future<Output = io::Result<usize>> + 'read
14    where
15        Self: 'read;
16
17    /// Reads from stream and fills internal buffer
18    ///
19    /// # Cancel safety
20    ///
21    /// This method is cancel safe. If you use it as the event in a
22    /// `tokio::select!` statement and some other branch
23    /// completes first, then it is guaranteed that no data was read.
24    fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a>;
25}
26
/// Byte sink that whole buffers can be written into.
pub trait FramedWrite {
    /// Future returned by [`FramedWrite::write_all`]; it may borrow both `self` and the source buffer.
    type WriteAllFut<'write>: std::future::Future<Output = io::Result<()>> + 'write
    where
        Self: 'write;

    /// Writes the whole of `buf` into this stream.
    ///
    /// # Cancel safety
    ///
    /// This method is not cancellation safe. If it is used as the event
    /// in a `tokio::select!` statement and some other
    /// branch completes first, then the provided buffer may have been
    /// partially written, but future calls to `write_all` will start over
    /// from the beginning of the buffer.
    fn write_all<'s>(&'s mut self, buf: &'s [u8]) -> Self::WriteAllFut<'s>;
}
43
/// Adapter that can be built from, and unwrapped back into, an inner stream type.
pub trait StreamWrapper: Sized {
    /// The stream type being wrapped.
    type InnerStream;

    /// Wraps `stream` in the adapter.
    fn from_inner(stream: Self::InnerStream) -> Self;

    /// Consumes the adapter and returns the wrapped stream.
    fn into_inner(self) -> Self::InnerStream;

    /// Borrows the wrapped stream.
    fn get_inner(&self) -> &Self::InnerStream;

    /// Mutably borrows the wrapped stream.
    fn get_inner_mut(&mut self) -> &mut Self::InnerStream;
}
55
56pub struct Framed<S> {
57    stream: S,
58    buf: BytesMut,
59}
60
61impl<S> Framed<S> {
62    pub fn peek(&self) -> &[u8] {
63        &self.buf
64    }
65}
66
67impl<S> Framed<S>
68where
69    S: StreamWrapper,
70{
71    pub fn new(stream: S::InnerStream) -> Self {
72        Self::new_with_leftover(stream, BytesMut::new())
73    }
74
75    pub fn new_with_leftover(stream: S::InnerStream, leftover: BytesMut) -> Self {
76        Self {
77            stream: S::from_inner(stream),
78            buf: leftover,
79        }
80    }
81
82    pub fn into_inner(self) -> (S::InnerStream, BytesMut) {
83        (self.stream.into_inner(), self.buf)
84    }
85
86    pub fn into_inner_no_leftover(self) -> S::InnerStream {
87        let (stream, leftover) = self.into_inner();
88        debug_assert_eq!(leftover.len(), 0, "unexpected leftover");
89        stream
90    }
91
92    pub fn get_inner(&self) -> (&S::InnerStream, &BytesMut) {
93        (self.stream.get_inner(), &self.buf)
94    }
95
96    pub fn get_inner_mut(&mut self) -> (&mut S::InnerStream, &mut BytesMut) {
97        (self.stream.get_inner_mut(), &mut self.buf)
98    }
99}
100
101impl<S> Framed<S>
102where
103    S: FramedRead,
104{
105    /// Accumulates at least `length` bytes and returns exactly `length` bytes, keeping the leftover in the internal buffer.
106    ///
107    /// # Cancel safety
108    ///
109    /// This method is cancel safe. If you use it as the event in a
110    /// `tokio::select!` statement and some other branch
111    /// completes first, then it is safe to drop the future and re-create it later.
112    /// Data may have been read, but it will be stored in the internal buffer.
113    pub async fn read_exact(&mut self, length: usize) -> io::Result<BytesMut> {
114        loop {
115            if self.buf.len() >= length {
116                return Ok(self.buf.split_to(length));
117            } else {
118                self.buf
119                    .reserve(length.checked_sub(self.buf.len()).expect("length > self.buf.len()"));
120            }
121
122            let len = self.read().await?;
123
124            // Handle EOF
125            if len == 0 {
126                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
127            }
128        }
129    }
130
131    /// Reads a standard RDP PDU frame.
132    ///
133    /// # Cancel safety
134    ///
135    /// This method is cancel safe. If you use it as the event in a
136    /// `tokio::select!` statement and some other branch
137    /// completes first, then it is safe to drop the future and re-create it later.
138    /// Data may have been read, but it will be stored in the internal buffer.
139    pub async fn read_pdu(&mut self) -> io::Result<(ironrdp_pdu::Action, BytesMut)> {
140        loop {
141            // Try decoding and see if a frame has been received already
142            match ironrdp_pdu::find_size(self.peek()) {
143                Ok(Some(pdu_info)) => {
144                    let frame = self.read_exact(pdu_info.length).await?;
145
146                    return Ok((pdu_info.action, frame));
147                }
148                Ok(None) => {
149                    let len = self.read().await?;
150
151                    // Handle EOF
152                    if len == 0 {
153                        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
154                    }
155                }
156                Err(e) => return Err(io::Error::other(e)),
157            };
158        }
159    }
160
161    /// Reads a frame using the provided PduHint.
162    ///
163    /// # Cancel safety
164    ///
165    /// This method is cancel safe. If you use it as the event in a
166    /// `tokio::select!` statement and some other branch
167    /// completes first, then it is safe to drop the future and re-create it later.
168    /// Data may have been read, but it will be stored in the internal buffer.
169    pub async fn read_by_hint(&mut self, hint: &dyn PduHint) -> io::Result<Bytes> {
170        loop {
171            match hint.find_size(self.peek()).map_err(io::Error::other)? {
172                Some((matched, length)) => {
173                    let bytes = self.read_exact(length).await?.freeze();
174                    if matched {
175                        return Ok(bytes);
176                    } else {
177                        debug!("Received and lost an unexpected PDU");
178                    }
179                }
180                None => {
181                    let len = self.read().await?;
182
183                    // Handle EOF
184                    if len == 0 {
185                        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
186                    }
187                }
188            };
189        }
190    }
191
192    /// Reads from stream and fills internal buffer, returning how many bytes were read.
193    ///
194    /// # Cancel safety
195    ///
196    /// This method is cancel safe. If you use it as the event in a
197    /// `tokio::select!` statement and some other branch
198    /// completes first, then it is guaranteed that no data was read.
199    async fn read(&mut self) -> io::Result<usize> {
200        self.stream.read(&mut self.buf).await
201    }
202}
203
204impl<S> FramedWrite for Framed<S>
205where
206    S: FramedWrite,
207{
208    type WriteAllFut<'write>
209        = S::WriteAllFut<'write>
210    where
211        Self: 'write;
212
213    /// Attempts to write an entire buffer into this `Framed`’s stream.
214    ///
215    /// # Cancel safety
216    ///
217    /// This method is not cancellation safe. If it is used as the event
218    /// in a `tokio::select!` statement and some other
219    /// branch completes first, then the provided buffer may have been
220    /// partially written, but future calls to `write_all` will start over
221    /// from the beginning of the buffer.
222    fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
223        self.stream.write_all(buf)
224    }
225}
226
227pub async fn single_sequence_step<S>(
228    framed: &mut Framed<S>,
229    sequence: &mut dyn Sequence,
230    buf: &mut WriteBuf,
231) -> ConnectorResult<()>
232where
233    S: FramedWrite + FramedRead,
234{
235    buf.clear();
236    let written = single_sequence_step_read(framed, sequence, buf).await?;
237    single_sequence_step_write(framed, buf, written).await
238}
239
240pub async fn single_sequence_step_read<S>(
241    framed: &mut Framed<S>,
242    sequence: &mut dyn Sequence,
243    buf: &mut WriteBuf,
244) -> ConnectorResult<Written>
245where
246    S: FramedRead,
247{
248    buf.clear();
249
250    if let Some(next_pdu_hint) = sequence.next_pdu_hint() {
251        debug!(
252            connector.state = sequence.state().name(),
253            hint = ?next_pdu_hint,
254            "Wait for PDU"
255        );
256
257        let pdu = framed
258            .read_by_hint(next_pdu_hint)
259            .await
260            .map_err(|e| ironrdp_connector::custom_err!("read frame by hint", e))?;
261
262        trace!(length = pdu.len(), "PDU received");
263
264        sequence.step(&pdu, buf)
265    } else {
266        sequence.step_no_input(buf)
267    }
268}
269
270async fn single_sequence_step_write<S>(
271    framed: &mut Framed<S>,
272    buf: &mut WriteBuf,
273    written: Written,
274) -> ConnectorResult<()>
275where
276    S: FramedWrite,
277{
278    if let Some(response_len) = written.size() {
279        debug_assert_eq!(buf.filled_len(), response_len);
280        let response = buf.filled();
281        trace!(response_len, "Send response");
282        framed
283            .write_all(response)
284            .await
285            .map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
286    }
287
288    Ok(())
289}