// ecksport_core/stream_framing.rs

1use std::future::Future;
2
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
4use tracing::*;
5
6use crate::constants::MAX_FRAME_BODY_LEN;
7use crate::errors::ConnError;
8use crate::frame::{FrameBody, FrameType};
9use crate::traits::{AsyncRecvFrame, AsyncSendFrame};
10
/// Wrapper around an arbitrary reader/writer type that implements the
/// corresponding frame send/recv traits ([`AsyncSendFrame`] /
/// [`AsyncRecvFrame`]) on top of it.
///
/// The type parameter is deliberately unbounded here; the capabilities each
/// operation needs (`AsyncRead`/`AsyncWrite`, `Send`, `Sync`, `Unpin`) are
/// required by the individual `impl` blocks instead.
pub struct StreamFramer<T> {
    /// The wrapped stream that frames are read from and/or written to.
    inner: T,
    /// Scratch buffer reused across receives to hold a frame body.
    read_buf: Vec<u8>,
    /// Scratch buffer reused across sends to hold a serialized frame body.
    write_buf: Vec<u8>,
}
18
19impl<T: Sync + Send + Unpin> StreamFramer<T> {
20    pub fn new(inner: T) -> Self {
21        Self {
22            inner,
23            read_buf: Vec::new(),
24            write_buf: Vec::new(),
25        }
26    }
27
28    pub fn inner(&self) -> &T {
29        &self.inner
30    }
31
32    pub fn inner_mut(&mut self) -> &mut T {
33        &mut self.inner
34    }
35
36    pub fn into_inner(self) -> T {
37        self.inner
38    }
39}
40
41impl<T: AsyncRead + Sync + Send + Unpin> AsyncRecvFrame for StreamFramer<T> {
42    fn recv_frame_async(&mut self) -> impl Future<Output = Result<FrameBody, ConnError>> + Send {
43        async {
44            // Read the flags, safety check.
45            let flags = self.inner.read_u8().await?;
46            if flags != 0 {
47                return Err(ConnError::UnknownFlags(flags));
48            }
49
50            // Read the type, make sure we recognize it.
51            let ty_tag = self.inner.read_u8().await?;
52            if FrameType::try_from(ty_tag).is_err() {
53                return Err(ConnError::UnknownFrameType(ty_tag));
54            }
55
56            // Read the length, make sure it's not oversized.
57            let len = self.inner.read_u32().await? as usize;
58            if len > MAX_FRAME_BODY_LEN {
59                return Err(ConnError::FrameTooLarge(len));
60            }
61
62            // *now* we can read and parse the vec
63            self.read_buf.resize(len, 0);
64            self.inner.read_exact(&mut self.read_buf).await?;
65            let frame = FrameBody::from_buf(&self.read_buf)?;
66            let ty = frame.ty();
67            trace!(?ty, %len, "recvd frame");
68            self.read_buf.clear();
69
70            // Sanity check.
71            // TODO remove this eventually
72            if u8::from(frame.ty()) != ty_tag {
73                return Err(ConnError::MalformedFrame);
74            }
75
76            Ok(frame)
77        }
78    }
79}
80
81impl<T: AsyncWrite + Sync + Send + Unpin> AsyncSendFrame for StreamFramer<T> {
82    fn send_frame_async(
83        &mut self,
84        body: &FrameBody,
85    ) -> impl Future<Output = Result<(), ConnError>> + Send {
86        async {
87            let ty = body.ty();
88            let flags = 0u8; // TODO support more, like per-frame signing
89            let ty_tag: u8 = ty.into();
90
91            body.into_vec(&mut self.write_buf)?; // FIXME this sucks, use a write buf
92            let len = self.write_buf.len();
93
94            if len > MAX_FRAME_BODY_LEN {
95                return Err(ConnError::FrameTooLarge(len));
96            }
97
98            trace!(?ty, %len, "sending frame");
99
100            self.inner.write_u8(flags).await?;
101            self.inner.write_u8(ty_tag).await?;
102            self.inner.write_u32(len as u32).await?;
103            self.inner.write_all(&self.write_buf).await?;
104            self.write_buf.clear();
105
106            Ok(())
107        }
108    }
109}