irontide_wire/mse/
stream.rs1use std::io;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9use super::cipher::Rc4;
10
/// Number of bytes (16 KiB) reserved up front for the buffer that holds
/// outgoing ciphertext on encrypted streams.
const WRITE_BUF_CAPACITY: usize = 16 * 1024;
13
14pub struct MseStream<S> {
18 inner: S,
19 read_cipher: Option<Rc4>,
20 write_cipher: Option<Rc4>,
21 write_buf: Vec<u8>,
22 initial_read: Vec<u8>,
23}
24
25impl<S> MseStream<S> {
26 pub fn plaintext(inner: S) -> Self {
28 Self {
29 inner,
30 read_cipher: None,
31 write_cipher: None,
32 write_buf: Vec::new(),
33 initial_read: Vec::new(),
34 }
35 }
36
37 pub(crate) fn encrypted(
43 inner: S,
44 read_cipher: Rc4,
45 write_cipher: Rc4,
46 initial_read: Vec<u8>,
47 ) -> Self {
48 Self {
49 inner,
50 read_cipher: Some(read_cipher),
51 write_cipher: Some(write_cipher),
52 write_buf: Vec::with_capacity(WRITE_BUF_CAPACITY),
53 initial_read,
54 }
55 }
56}
57
58impl<S: AsyncRead + Unpin> AsyncRead for MseStream<S> {
59 fn poll_read(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 buf: &mut ReadBuf<'_>,
63 ) -> Poll<io::Result<()>> {
64 let this = self.get_mut();
65
66 if !this.initial_read.is_empty() {
67 let to_copy = this.initial_read.len().min(buf.remaining());
68 buf.put_slice(&this.initial_read[..to_copy]);
69 this.initial_read.drain(..to_copy);
70 return Poll::Ready(Ok(()));
71 }
72
73 let before = buf.filled().len();
74
75 match Pin::new(&mut this.inner).poll_read(cx, buf) {
76 Poll::Ready(Ok(())) => {
77 if let Some(cipher) = &mut this.read_cipher {
78 let filled = buf.filled_mut();
79 cipher.apply(&mut filled[before..]);
80 }
81 Poll::Ready(Ok(()))
82 }
83 other => other,
84 }
85 }
86}
87
88impl<S: AsyncWrite + Unpin> AsyncWrite for MseStream<S> {
89 fn poll_write(
90 self: Pin<&mut Self>,
91 cx: &mut Context<'_>,
92 buf: &[u8],
93 ) -> Poll<io::Result<usize>> {
94 let this = self.get_mut();
95
96 if let Some(cipher) = &mut this.write_cipher {
97 this.write_buf.clear();
98 this.write_buf.extend_from_slice(buf);
99 cipher.apply(&mut this.write_buf);
100 Pin::new(&mut this.inner).poll_write(cx, &this.write_buf)
101 } else {
102 Pin::new(&mut this.inner).poll_write(cx, buf)
103 }
104 }
105
106 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
107 Pin::new(&mut self.get_mut().inner).poll_flush(cx)
108 }
109
110 fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
111 Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
112 }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    /// With no ciphers installed, bytes arrive exactly as they were written.
    #[tokio::test]
    async fn plaintext_passthrough() {
        let (client, server) = tokio::io::duplex(1024);
        let mut client = MseStream::plaintext(client);
        let mut server = MseStream::plaintext(server);

        client.write_all(b"hello").await.unwrap();
        client.flush().await.unwrap();

        let mut buf = [0u8; 5];
        server.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");
    }

    /// Two encrypted endpoints with mirrored keys recover each other's
    /// plaintext in both directions.
    #[tokio::test]
    async fn encrypted_roundtrip() {
        let key_a = b"key for direction A!";
        let key_b = b"key for direction B!";

        let (raw_client, raw_server) = tokio::io::duplex(1024);

        // The client reads with key B and writes with key A.
        let mut client = MseStream::encrypted(
            raw_client,
            Rc4::new(key_b),
            Rc4::new(key_a),
            Vec::new(),
        );

        // The server mirrors the client: it reads with key A and writes with key B.
        let mut server = MseStream::encrypted(
            raw_server,
            Rc4::new(key_a),
            Rc4::new(key_b),
            Vec::new(),
        );

        client.write_all(b"client to server").await.unwrap();
        client.flush().await.unwrap();

        let mut buf = [0u8; 16];
        server.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"client to server");

        server.write_all(b"server to client").await.unwrap();
        server.flush().await.unwrap();

        let mut buf = [0u8; 16];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"server to client");
    }

    /// `encrypted` reserves the write buffer up front.
    #[test]
    fn encrypted_write_buf_pre_allocated() {
        let (raw, _) = tokio::io::duplex(1024);
        let stream = MseStream::encrypted(raw, Rc4::new(b"r"), Rc4::new(b"w"), Vec::new());
        assert_eq!(stream.write_buf.capacity(), WRITE_BUF_CAPACITY);
    }

    /// Writing exactly `WRITE_BUF_CAPACITY` bytes does not grow the buffer.
    #[tokio::test]
    async fn encrypted_no_realloc_on_chunk_write() {
        let (raw_client, _raw_server) = tokio::io::duplex(32768);
        let mut client =
            MseStream::encrypted(raw_client, Rc4::new(b"r"), Rc4::new(b"w"), Vec::new());
        let data = vec![0xABu8; WRITE_BUF_CAPACITY];
        client.write_all(&data).await.unwrap();
        assert_eq!(client.write_buf.capacity(), WRITE_BUF_CAPACITY);
    }

    /// `initial_read` bytes come out first and unchanged; bytes from the inner
    /// stream follow and are passed through the read cipher.
    #[tokio::test]
    async fn initial_read_drains_before_inner() {
        let (raw_client, mut raw_server) = tokio::io::duplex(1024);

        let initial = b"overflow".to_vec();
        let mut client = MseStream::encrypted(raw_client, Rc4::new(b"r"), Rc4::new(b"w"), initial);

        raw_server.write_all(b"inner").await.unwrap();
        raw_server.flush().await.unwrap();

        let mut buf = [0u8; 8];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"overflow");

        let mut buf = [0u8; 5];
        client.read_exact(&mut buf).await.unwrap();

        // The raw bytes written above pass through the client's read cipher,
        // starting from its initial state, since `initial_read` never touched
        // it. Check the actual content, not just the array length.
        let mut expected = b"inner".to_vec();
        let mut cipher = Rc4::new(b"r");
        cipher.apply(&mut expected);
        assert_eq!(&buf[..], &expected[..]);
    }
}