ironrdp_async/
framed.rs

1use std::io;
2
3use bytes::{Bytes, BytesMut};
4use ironrdp_connector::{ConnectorResult, Sequence, Written};
5use ironrdp_core::WriteBuf;
6use ironrdp_pdu::PduHint;
7
// TODO: investigate replacing the GAT-based future types below with `async fn` / return-position
// `impl Trait` in traits, which were stabilized in Rust 1.75: https://github.com/rust-lang/rust/issues/91611
10
11pub trait FramedRead {
12    type ReadFut<'read>: core::future::Future<Output = io::Result<usize>> + 'read
13    where
14        Self: 'read;
15
16    /// Reads from stream and fills internal buffer
17    ///
18    /// # Cancel safety
19    ///
20    /// This method is cancel safe. If you use it as the event in a
21    /// `tokio::select!` statement and some other branch
22    /// completes first, then it is guaranteed that no data was read.
23    fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a>;
24}
25
/// Write half of a transport that can be driven by [`Framed`].
pub trait FramedWrite {
    /// Future produced by [`FramedWrite::write_all`].
    ///
    /// It may borrow both the writer and the source buffer for `'write`.
    type WriteAllFut<'write>: std::future::Future<Output = io::Result<()>> + 'write
    where
        Self: 'write;

    /// Writes the whole of `buf` into this stream.
    ///
    /// # Cancel safety
    ///
    /// This method is NOT cancellation safe. If the future is dropped before it completes
    /// (for example because another branch of a `tokio::select!` finished first), a prefix
    /// of `buf` may already have been written; a later call to `write_all` starts over from
    /// the first byte of the buffer it is given.
    fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a>;
}
42
/// Adapter between a concrete stream type and [`Framed`].
///
/// An implementor wraps an `InnerStream` and gives back owned, shared and exclusive access
/// to it.
pub trait StreamWrapper: Sized {
    /// The wrapped stream type.
    type InnerStream;

    /// Wraps `stream`.
    fn from_inner(stream: Self::InnerStream) -> Self;

    /// Consumes the wrapper and returns the inner stream by value.
    fn into_inner(self) -> Self::InnerStream;

    /// Borrows the inner stream.
    fn get_inner(&self) -> &Self::InnerStream;

    /// Mutably borrows the inner stream.
    fn get_inner_mut(&mut self) -> &mut Self::InnerStream;
}
54
55pub struct Framed<S> {
56    stream: S,
57    buf: BytesMut,
58}
59
60impl<S> Framed<S> {
61    pub fn peek(&self) -> &[u8] {
62        &self.buf
63    }
64}
65
66impl<S> Framed<S>
67where
68    S: StreamWrapper,
69{
70    pub fn new(stream: S::InnerStream) -> Self {
71        Self::new_with_leftover(stream, BytesMut::new())
72    }
73
74    pub fn new_with_leftover(stream: S::InnerStream, leftover: BytesMut) -> Self {
75        Self {
76            stream: S::from_inner(stream),
77            buf: leftover,
78        }
79    }
80
81    pub fn into_inner(self) -> (S::InnerStream, BytesMut) {
82        (self.stream.into_inner(), self.buf)
83    }
84
85    pub fn into_inner_no_leftover(self) -> S::InnerStream {
86        let (stream, leftover) = self.into_inner();
87        debug_assert_eq!(leftover.len(), 0, "unexpected leftover");
88        stream
89    }
90
91    pub fn get_inner(&self) -> (&S::InnerStream, &BytesMut) {
92        (self.stream.get_inner(), &self.buf)
93    }
94
95    pub fn get_inner_mut(&mut self) -> (&mut S::InnerStream, &mut BytesMut) {
96        (self.stream.get_inner_mut(), &mut self.buf)
97    }
98}
99
100impl<S> Framed<S>
101where
102    S: FramedRead,
103{
104    /// Accumulates at least `length` bytes and returns exactly `length` bytes, keeping the leftover in the internal buffer.
105    ///
106    /// # Cancel safety
107    ///
108    /// This method is cancel safe. If you use it as the event in a
109    /// `tokio::select!` statement and some other branch
110    /// completes first, then it is safe to drop the future and re-create it later.
111    /// Data may have been read, but it will be stored in the internal buffer.
112    pub async fn read_exact(&mut self, length: usize) -> io::Result<BytesMut> {
113        loop {
114            if self.buf.len() >= length {
115                return Ok(self.buf.split_to(length));
116            } else {
117                self.buf
118                    .reserve(length.checked_sub(self.buf.len()).expect("length > self.buf.len()"));
119            }
120
121            let len = self.read().await?;
122
123            // Handle EOF
124            if len == 0 {
125                return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
126            }
127        }
128    }
129
130    /// Reads a standard RDP PDU frame.
131    ///
132    /// # Cancel safety
133    ///
134    /// This method is cancel safe. If you use it as the event in a
135    /// `tokio::select!` statement and some other branch
136    /// completes first, then it is safe to drop the future and re-create it later.
137    /// Data may have been read, but it will be stored in the internal buffer.
138    pub async fn read_pdu(&mut self) -> io::Result<(ironrdp_pdu::Action, BytesMut)> {
139        loop {
140            // Try decoding and see if a frame has been received already
141            match ironrdp_pdu::find_size(self.peek()) {
142                Ok(Some(pdu_info)) => {
143                    let frame = self.read_exact(pdu_info.length).await?;
144
145                    return Ok((pdu_info.action, frame));
146                }
147                Ok(None) => {
148                    let len = self.read().await?;
149
150                    // Handle EOF
151                    if len == 0 {
152                        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
153                    }
154                }
155                Err(e) => return Err(io::Error::other(e)),
156            };
157        }
158    }
159
160    /// Reads a frame using the provided PduHint.
161    ///
162    /// # Cancel safety
163    ///
164    /// This method is cancel safe. If you use it as the event in a
165    /// `tokio::select!` statement and some other branch
166    /// completes first, then it is safe to drop the future and re-create it later.
167    /// Data may have been read, but it will be stored in the internal buffer.
168    pub async fn read_by_hint(&mut self, hint: &dyn PduHint) -> io::Result<Bytes> {
169        loop {
170            match hint.find_size(self.peek()).map_err(io::Error::other)? {
171                Some((matched, length)) => {
172                    let bytes = self.read_exact(length).await?.freeze();
173                    if matched {
174                        return Ok(bytes);
175                    } else {
176                        debug!("Received and lost an unexpected PDU");
177                    }
178                }
179                None => {
180                    let len = self.read().await?;
181
182                    // Handle EOF
183                    if len == 0 {
184                        return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
185                    }
186                }
187            };
188        }
189    }
190
191    /// Reads from stream and fills internal buffer, returning how many bytes were read.
192    ///
193    /// # Cancel safety
194    ///
195    /// This method is cancel safe. If you use it as the event in a
196    /// `tokio::select!` statement and some other branch
197    /// completes first, then it is guaranteed that no data was read.
198    async fn read(&mut self) -> io::Result<usize> {
199        self.stream.read(&mut self.buf).await
200    }
201}
202
203impl<S> FramedWrite for Framed<S>
204where
205    S: FramedWrite,
206{
207    type WriteAllFut<'write>
208        = S::WriteAllFut<'write>
209    where
210        Self: 'write;
211
212    /// Attempts to write an entire buffer into this `Framed`’s stream.
213    ///
214    /// # Cancel safety
215    ///
216    /// This method is not cancellation safe. If it is used as the event
217    /// in a `tokio::select!` statement and some other
218    /// branch completes first, then the provided buffer may have been
219    /// partially written, but future calls to `write_all` will start over
220    /// from the beginning of the buffer.
221    fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
222        self.stream.write_all(buf)
223    }
224}
225
226pub async fn single_sequence_step<S>(
227    framed: &mut Framed<S>,
228    sequence: &mut dyn Sequence,
229    buf: &mut WriteBuf,
230) -> ConnectorResult<()>
231where
232    S: FramedWrite + FramedRead,
233{
234    buf.clear();
235    let written = single_sequence_step_read(framed, sequence, buf).await?;
236    single_sequence_step_write(framed, buf, written).await
237}
238
239pub async fn single_sequence_step_read<S>(
240    framed: &mut Framed<S>,
241    sequence: &mut dyn Sequence,
242    buf: &mut WriteBuf,
243) -> ConnectorResult<Written>
244where
245    S: FramedRead,
246{
247    buf.clear();
248
249    if let Some(next_pdu_hint) = sequence.next_pdu_hint() {
250        debug!(
251            connector.state = sequence.state().name(),
252            hint = ?next_pdu_hint,
253            "Wait for PDU"
254        );
255
256        let pdu = framed
257            .read_by_hint(next_pdu_hint)
258            .await
259            .map_err(|e| ironrdp_connector::custom_err!("read frame by hint", e))?;
260
261        trace!(length = pdu.len(), "PDU received");
262
263        sequence.step(&pdu, buf)
264    } else {
265        sequence.step_no_input(buf)
266    }
267}
268
269async fn single_sequence_step_write<S>(
270    framed: &mut Framed<S>,
271    buf: &mut WriteBuf,
272    written: Written,
273) -> ConnectorResult<()>
274where
275    S: FramedWrite,
276{
277    if let Some(response_len) = written.size() {
278        debug_assert_eq!(buf.filled_len(), response_len);
279        let response = buf.filled();
280        trace!(response_len, "Send response");
281        framed
282            .write_all(response)
283            .await
284            .map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
285    }
286
287    Ok(())
288}