1use std::io;
2
3use bytes::{Bytes, BytesMut};
4use ironrdp_connector::{ConnectorResult, Sequence, Written};
5use ironrdp_core::WriteBuf;
6use ironrdp_pdu::PduHint;
7use tracing::{debug, trace};
8
9pub trait FramedRead {
13 type ReadFut<'read>: core::future::Future<Output = io::Result<usize>> + 'read
14 where
15 Self: 'read;
16
17 fn read<'a>(&'a mut self, buf: &'a mut BytesMut) -> Self::ReadFut<'a>;
25}
26
/// Asynchronous writing half used by [`Framed`].
pub trait FramedWrite {
    /// Future returned by [`FramedWrite::write_all`].
    type WriteAllFut<'w>: std::future::Future<Output = io::Result<()>> + 'w
    where
        Self: 'w;

    /// Writes `buf` to the underlying transport.
    ///
    /// As the name suggests, implementors are expected to write the whole slice before the
    /// returned future resolves successfully.
    fn write_all<'b>(&'b mut self, buf: &'b [u8]) -> Self::WriteAllFut<'b>;
}
43
/// Conversion between a wrapper type and the stream it wraps.
///
/// [`Framed`] stores the wrapper and uses these methods to build itself from, and hand back,
/// the inner stream.
pub trait StreamWrapper
where
    Self: Sized,
{
    /// The stream type being wrapped.
    type InnerStream;

    /// Wraps `stream`.
    fn from_inner(stream: Self::InnerStream) -> Self;

    /// Borrows the wrapped stream.
    fn get_inner(&self) -> &Self::InnerStream;

    /// Mutably borrows the wrapped stream.
    fn get_inner_mut(&mut self) -> &mut Self::InnerStream;

    /// Consumes the wrapper and returns the wrapped stream.
    fn into_inner(self) -> Self::InnerStream;
}
55
56pub struct Framed<S> {
57 stream: S,
58 buf: BytesMut,
59}
60
61impl<S> Framed<S> {
62 pub fn peek(&self) -> &[u8] {
63 &self.buf
64 }
65}
66
67impl<S> Framed<S>
68where
69 S: StreamWrapper,
70{
71 pub fn new(stream: S::InnerStream) -> Self {
72 Self::new_with_leftover(stream, BytesMut::new())
73 }
74
75 pub fn new_with_leftover(stream: S::InnerStream, leftover: BytesMut) -> Self {
76 Self {
77 stream: S::from_inner(stream),
78 buf: leftover,
79 }
80 }
81
82 pub fn into_inner(self) -> (S::InnerStream, BytesMut) {
83 (self.stream.into_inner(), self.buf)
84 }
85
86 pub fn into_inner_no_leftover(self) -> S::InnerStream {
87 let (stream, leftover) = self.into_inner();
88 debug_assert_eq!(leftover.len(), 0, "unexpected leftover");
89 stream
90 }
91
92 pub fn get_inner(&self) -> (&S::InnerStream, &BytesMut) {
93 (self.stream.get_inner(), &self.buf)
94 }
95
96 pub fn get_inner_mut(&mut self) -> (&mut S::InnerStream, &mut BytesMut) {
97 (self.stream.get_inner_mut(), &mut self.buf)
98 }
99}
100
101impl<S> Framed<S>
102where
103 S: FramedRead,
104{
105 pub async fn read_exact(&mut self, length: usize) -> io::Result<BytesMut> {
114 loop {
115 if self.buf.len() >= length {
116 return Ok(self.buf.split_to(length));
117 } else {
118 self.buf
119 .reserve(length.checked_sub(self.buf.len()).expect("length > self.buf.len()"));
120 }
121
122 let len = self.read().await?;
123
124 if len == 0 {
126 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
127 }
128 }
129 }
130
131 pub async fn read_pdu(&mut self) -> io::Result<(ironrdp_pdu::Action, BytesMut)> {
140 loop {
141 match ironrdp_pdu::find_size(self.peek()) {
143 Ok(Some(pdu_info)) => {
144 let frame = self.read_exact(pdu_info.length).await?;
145
146 return Ok((pdu_info.action, frame));
147 }
148 Ok(None) => {
149 let len = self.read().await?;
150
151 if len == 0 {
153 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
154 }
155 }
156 Err(e) => return Err(io::Error::other(e)),
157 };
158 }
159 }
160
161 pub async fn read_by_hint(&mut self, hint: &dyn PduHint) -> io::Result<Bytes> {
170 loop {
171 match hint.find_size(self.peek()).map_err(io::Error::other)? {
172 Some((matched, length)) => {
173 let bytes = self.read_exact(length).await?.freeze();
174 if matched {
175 return Ok(bytes);
176 } else {
177 debug!("Received and lost an unexpected PDU");
178 }
179 }
180 None => {
181 let len = self.read().await?;
182
183 if len == 0 {
185 return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
186 }
187 }
188 };
189 }
190 }
191
192 async fn read(&mut self) -> io::Result<usize> {
200 self.stream.read(&mut self.buf).await
201 }
202}
203
204impl<S> FramedWrite for Framed<S>
205where
206 S: FramedWrite,
207{
208 type WriteAllFut<'write>
209 = S::WriteAllFut<'write>
210 where
211 Self: 'write;
212
213 fn write_all<'a>(&'a mut self, buf: &'a [u8]) -> Self::WriteAllFut<'a> {
223 self.stream.write_all(buf)
224 }
225}
226
227pub async fn single_sequence_step<S>(
228 framed: &mut Framed<S>,
229 sequence: &mut dyn Sequence,
230 buf: &mut WriteBuf,
231) -> ConnectorResult<()>
232where
233 S: FramedWrite + FramedRead,
234{
235 buf.clear();
236 let written = single_sequence_step_read(framed, sequence, buf).await?;
237 single_sequence_step_write(framed, buf, written).await
238}
239
240pub async fn single_sequence_step_read<S>(
241 framed: &mut Framed<S>,
242 sequence: &mut dyn Sequence,
243 buf: &mut WriteBuf,
244) -> ConnectorResult<Written>
245where
246 S: FramedRead,
247{
248 buf.clear();
249
250 if let Some(next_pdu_hint) = sequence.next_pdu_hint() {
251 debug!(
252 connector.state = sequence.state().name(),
253 hint = ?next_pdu_hint,
254 "Wait for PDU"
255 );
256
257 let pdu = framed
258 .read_by_hint(next_pdu_hint)
259 .await
260 .map_err(|e| ironrdp_connector::custom_err!("read frame by hint", e))?;
261
262 trace!(length = pdu.len(), "PDU received");
263
264 sequence.step(&pdu, buf)
265 } else {
266 sequence.step_no_input(buf)
267 }
268}
269
270async fn single_sequence_step_write<S>(
271 framed: &mut Framed<S>,
272 buf: &mut WriteBuf,
273 written: Written,
274) -> ConnectorResult<()>
275where
276 S: FramedWrite,
277{
278 if let Some(response_len) = written.size() {
279 debug_assert_eq!(buf.filled_len(), response_len);
280 let response = buf.filled();
281 trace!(response_len, "Send response");
282 framed
283 .write_all(response)
284 .await
285 .map_err(|e| ironrdp_connector::custom_err!("write all", e))?;
286 }
287
288 Ok(())
289}