1use std::io;
2
3use bytes::{Bytes, BytesMut};
4use ironrdp_connector::{ConnectorResult, Sequence, Written};
5use ironrdp_core::WriteBuf;
6use ironrdp_pdu::PduHint;
7
8pub trait FramedRead {
12 type ReadFut<'read>: core::future::Future<Output = io::Result<usize>> + 'read
13 where
14 Self: 'read;
15
16 fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a>;
24}
25
/// An asynchronous byte sink that whole buffers can be written to.
pub trait FramedWrite {
    /// Future produced by [`FramedWrite::write_all`]; it may borrow both `self` and the
    /// source buffer for `'write`.
    type WriteAllFut<'write>: std::future::Future<Output = io::Result<()>> + 'write
    where
        Self: 'write;

    /// Writes the entire contents of `buf` to the stream.
    fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a>;
}
42
/// Adapter that wraps a raw stream type so it can be driven through [`Framed`].
pub trait StreamWrapper: Sized {
    /// The raw stream type being wrapped.
    type InnerStream;

    /// Wraps `stream` in the adapter.
    fn from_inner(stream: Self::InnerStream) -> Self;

    /// Borrows the wrapped stream.
    fn get_inner(&self) -> &Self::InnerStream;

    /// Mutably borrows the wrapped stream.
    fn get_inner_mut(&mut self) -> &mut Self::InnerStream;

    /// Consumes the adapter and hands back the wrapped stream.
    fn into_inner(self) -> Self::InnerStream;
}
54
55pub struct Framed<S> {
56 stream: S,
57 buf: BytesMut,
58}
59
60impl<S> Framed<S> {
61 pub fn peek(&self) -> &[u8] {
62 &self.buf
63 }
64}
65
66impl<S> Framed<S>
67where
68 S: StreamWrapper,
69{
70 pub fn new(stream: S::InnerStream) -> Self {
71 Self::new_with_leftover(stream, BytesMut::new())
72 }
73
74 pub fn new_with_leftover(stream: S::InnerStream, leftover: BytesMut) -> Self {
75 Self {
76 stream: S::from_inner(stream),
77 buf: leftover,
78 }
79 }
80
81 pub fn into_inner(self) -> (S::InnerStream, BytesMut) {
82 (self.stream.into_inner(), self.buf)
83 }
84
85 pub fn into_inner_no_leftover(self) -> S::InnerStream {
86 let (stream, leftover) = self.into_inner();
87 debug_assert_eq!(leftover.len(), 0, "unexpected leftover");
88 stream
89 }
90
91 pub fn get_inner(&self) -> (&S::InnerStream, &BytesMut) {
92 (self.stream.get_inner(), &self.buf)
93 }
94
95 pub fn get_inner_mut(&mut self) -> (&mut S::InnerStream, &mut BytesMut) {
96 (self.stream.get_inner_mut(), &mut self.buf)
97 }
98}
99
100impl<S> Framed<S>
101where
102 S: FramedRead,
103{
104 pub async fn read_exact(&mut self, length: usize) -> io::Result<BytesMut> {
113 loop {
114 if self.buf.len() >= length {
115 return Ok(self.buf.split_to(length));
116 } else {
117 self.buf
118 .reserve(length.checked_sub(self.buf.len()).expect("length > self.buf.len()"));
119 }
120
121 let len = self.read().await?;
122
123 if len == 0 {
125 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
126 }
127 }
128 }
129
130 pub async fn read_pdu(&mut self) -> io::Result<(ironrdp_pdu::Action, BytesMut)> {
139 loop {
140 match ironrdp_pdu::find_size(self.peek()) {
142 Ok(Some(pdu_info)) => {
143 let frame = self.read_exact(pdu_info.length).await?;
144
145 return Ok((pdu_info.action, frame));
146 }
147 Ok(None) => {
148 let len = self.read().await?;
149
150 if len == 0 {
152 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
153 }
154 }
155 Err(e) => return Err(io::Error::new(io::ErrorKind::Other, e)),
156 };
157 }
158 }
159
160 pub async fn read_by_hint(&mut self, hint: &dyn PduHint) -> io::Result<Bytes> {
169 loop {
170 match hint
171 .find_size(self.peek())
172 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
173 {
174 Some((matched, length)) => {
175 let bytes = self.read_exact(length).await?.freeze();
176 if matched {
177 return Ok(bytes);
178 } else {
179 debug!("Received and lost an unexpected PDU");
180 }
181 }
182 None => {
183 let len = self.read().await?;
184
185 if len == 0 {
187 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
188 }
189 }
190 };
191 }
192 }
193
194 async fn read(&mut self) -> io::Result<usize> {
202 self.stream.read(&mut self.buf).await
203 }
204}
205
206impl<S> FramedWrite for Framed<S>
207where
208 S: FramedWrite,
209{
210 type WriteAllFut<'write>
211 = S::WriteAllFut<'write>
212 where
213 Self: 'write;
214
215 fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
225 self.stream.write_all(buf)
226 }
227}
228
229pub async fn single_sequence_step<S>(
230 framed: &mut Framed<S>,
231 sequence: &mut dyn Sequence,
232 buf: &mut WriteBuf,
233) -> ConnectorResult<()>
234where
235 S: FramedWrite + FramedRead,
236{
237 buf.clear();
238 let written = single_sequence_step_read(framed, sequence, buf).await?;
239 single_sequence_step_write(framed, buf, written).await
240}
241
242pub async fn single_sequence_step_read<S>(
243 framed: &mut Framed<S>,
244 sequence: &mut dyn Sequence,
245 buf: &mut WriteBuf,
246) -> ConnectorResult<Written>
247where
248 S: FramedRead,
249{
250 buf.clear();
251
252 if let Some(next_pdu_hint) = sequence.next_pdu_hint() {
253 debug!(
254 connector.state = sequence.state().name(),
255 hint = ?next_pdu_hint,
256 "Wait for PDU"
257 );
258
259 let pdu = framed
260 .read_by_hint(next_pdu_hint)
261 .await
262 .map_err(|e| ironrdp_connector::custom_err!("read frame by hint", e))?;
263
264 trace!(length = pdu.len(), "PDU received");
265
266 sequence.step(&pdu, buf)
267 } else {
268 sequence.step_no_input(buf)
269 }
270}
271
272async fn single_sequence_step_write<S>(
273 framed: &mut Framed<S>,
274 buf: &mut WriteBuf,
275 written: Written,
276) -> ConnectorResult<()>
277where
278 S: FramedWrite,
279{
280 if let Some(response_len) = written.size() {
281 debug_assert_eq!(buf.filled_len(), response_len);
282 let response = buf.filled();
283 trace!(response_len, "Send response");
284 framed
285 .write_all(response)
286 .await
287 .map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
288 }
289
290 Ok(())
291}