1use std::io;
2
3use bytes::{Bytes, BytesMut};
4use ironrdp_connector::{ConnectorResult, Sequence, Written};
5use ironrdp_core::WriteBuf;
6use ironrdp_pdu::PduHint;
7
8pub trait FramedRead {
12 type ReadFut<'read>: core::future::Future<Output = io::Result<usize>> + 'read
13 where
14 Self: 'read;
15
16 fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a>;
24}
25
/// A stream that can accept a whole byte buffer through a named future type.
pub trait FramedWrite {
    /// Future produced by [`FramedWrite::write_all`].
    type WriteAllFut<'w>: core::future::Future<Output = io::Result<()>> + 'w
    where
        Self: 'w;

    /// Writes `buf` to the stream.
    ///
    /// NOTE(review): as the name suggests, implementations are presumably expected
    /// to resolve only after the entire buffer was written — confirm per implementation.
    fn write_all<'s>(&'s mut self, buf: &'s [u8]) -> Self::WriteAllFut<'s>;
}
42
/// Conversion between a wrapper type and the stream it wraps.
///
/// [`Framed`] uses this to accept and hand back the inner stream.
pub trait StreamWrapper: Sized {
    /// The stream type being wrapped.
    type InnerStream;

    /// Builds the wrapper around `stream`.
    fn from_inner(stream: Self::InnerStream) -> Self;

    /// Borrows the wrapped stream.
    fn get_inner(&self) -> &Self::InnerStream;

    /// Mutably borrows the wrapped stream.
    fn get_inner_mut(&mut self) -> &mut Self::InnerStream;

    /// Consumes the wrapper and returns the wrapped stream.
    fn into_inner(self) -> Self::InnerStream;
}
54
/// A stream paired with a buffer of bytes received from it but not yet consumed.
///
/// Reading methods fill `buf` from the stream and hand out frames taken from its
/// front; any remaining bytes stay buffered for the next call.
pub struct Framed<S> {
    /// The wrapped stream.
    stream: S,
    /// Bytes read from `stream` (or supplied as leftover at construction) that
    /// have not yet been returned to a caller.
    buf: BytesMut,
}
59
60impl<S> Framed<S> {
61 pub fn peek(&self) -> &[u8] {
62 &self.buf
63 }
64}
65
66impl<S> Framed<S>
67where
68 S: StreamWrapper,
69{
70 pub fn new(stream: S::InnerStream) -> Self {
71 Self::new_with_leftover(stream, BytesMut::new())
72 }
73
74 pub fn new_with_leftover(stream: S::InnerStream, leftover: BytesMut) -> Self {
75 Self {
76 stream: S::from_inner(stream),
77 buf: leftover,
78 }
79 }
80
81 pub fn into_inner(self) -> (S::InnerStream, BytesMut) {
82 (self.stream.into_inner(), self.buf)
83 }
84
85 pub fn into_inner_no_leftover(self) -> S::InnerStream {
86 let (stream, leftover) = self.into_inner();
87 debug_assert_eq!(leftover.len(), 0, "unexpected leftover");
88 stream
89 }
90
91 pub fn get_inner(&self) -> (&S::InnerStream, &BytesMut) {
92 (self.stream.get_inner(), &self.buf)
93 }
94
95 pub fn get_inner_mut(&mut self) -> (&mut S::InnerStream, &mut BytesMut) {
96 (self.stream.get_inner_mut(), &mut self.buf)
97 }
98}
99
100impl<S> Framed<S>
101where
102 S: FramedRead,
103{
104 pub async fn read_exact(&mut self, length: usize) -> io::Result<BytesMut> {
113 loop {
114 if self.buf.len() >= length {
115 return Ok(self.buf.split_to(length));
116 } else {
117 self.buf
118 .reserve(length.checked_sub(self.buf.len()).expect("length > self.buf.len()"));
119 }
120
121 let len = self.read().await?;
122
123 if len == 0 {
125 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
126 }
127 }
128 }
129
130 pub async fn read_pdu(&mut self) -> io::Result<(ironrdp_pdu::Action, BytesMut)> {
139 loop {
140 match ironrdp_pdu::find_size(self.peek()) {
142 Ok(Some(pdu_info)) => {
143 let frame = self.read_exact(pdu_info.length).await?;
144
145 return Ok((pdu_info.action, frame));
146 }
147 Ok(None) => {
148 let len = self.read().await?;
149
150 if len == 0 {
152 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
153 }
154 }
155 Err(e) => return Err(io::Error::other(e)),
156 };
157 }
158 }
159
160 pub async fn read_by_hint(&mut self, hint: &dyn PduHint) -> io::Result<Bytes> {
169 loop {
170 match hint.find_size(self.peek()).map_err(io::Error::other)? {
171 Some((matched, length)) => {
172 let bytes = self.read_exact(length).await?.freeze();
173 if matched {
174 return Ok(bytes);
175 } else {
176 debug!("Received and lost an unexpected PDU");
177 }
178 }
179 None => {
180 let len = self.read().await?;
181
182 if len == 0 {
184 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
185 }
186 }
187 };
188 }
189 }
190
191 async fn read(&mut self) -> io::Result<usize> {
199 self.stream.read(&mut self.buf).await
200 }
201}
202
203impl<S> FramedWrite for Framed<S>
204where
205 S: FramedWrite,
206{
207 type WriteAllFut<'write>
208 = S::WriteAllFut<'write>
209 where
210 Self: 'write;
211
212 fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
222 self.stream.write_all(buf)
223 }
224}
225
226pub async fn single_sequence_step<S>(
227 framed: &mut Framed<S>,
228 sequence: &mut dyn Sequence,
229 buf: &mut WriteBuf,
230) -> ConnectorResult<()>
231where
232 S: FramedWrite + FramedRead,
233{
234 buf.clear();
235 let written = single_sequence_step_read(framed, sequence, buf).await?;
236 single_sequence_step_write(framed, buf, written).await
237}
238
239pub async fn single_sequence_step_read<S>(
240 framed: &mut Framed<S>,
241 sequence: &mut dyn Sequence,
242 buf: &mut WriteBuf,
243) -> ConnectorResult<Written>
244where
245 S: FramedRead,
246{
247 buf.clear();
248
249 if let Some(next_pdu_hint) = sequence.next_pdu_hint() {
250 debug!(
251 connector.state = sequence.state().name(),
252 hint = ?next_pdu_hint,
253 "Wait for PDU"
254 );
255
256 let pdu = framed
257 .read_by_hint(next_pdu_hint)
258 .await
259 .map_err(|e| ironrdp_connector::custom_err!("read frame by hint", e))?;
260
261 trace!(length = pdu.len(), "PDU received");
262
263 sequence.step(&pdu, buf)
264 } else {
265 sequence.step_no_input(buf)
266 }
267}
268
269async fn single_sequence_step_write<S>(
270 framed: &mut Framed<S>,
271 buf: &mut WriteBuf,
272 written: Written,
273) -> ConnectorResult<()>
274where
275 S: FramedWrite,
276{
277 if let Some(response_len) = written.size() {
278 debug_assert_eq!(buf.filled_len(), response_len);
279 let response = buf.filled();
280 trace!(response_len, "Send response");
281 framed
282 .write_all(response)
283 .await
284 .map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
285 }
286
287 Ok(())
288}