Skip to main content

irontide_wire/mse/
stream.rs

1//! Encrypted stream wrapper implementing `AsyncRead` + `AsyncWrite`.
2
3use std::io;
4use std::pin::Pin;
5use std::task::{Context, Poll};
6
7use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
8
9use super::cipher::Rc4;
10
/// Capacity reserved up front for the encrypted-write buffer: 16 KiB, chosen to
/// match `DEFAULT_CHUNK_SIZE`.
const WRITE_BUF_CAPACITY: usize = 16 * 1024;
13
14/// A stream that optionally encrypts/decrypts all data with RC4.
15///
16/// When ciphers are None, data passes through unmodified (plaintext mode).
17pub struct MseStream<S> {
18    inner: S,
19    read_cipher: Option<Rc4>,
20    write_cipher: Option<Rc4>,
21    write_buf: Vec<u8>,
22    initial_read: Vec<u8>,
23}
24
25impl<S> MseStream<S> {
26    /// Create a plaintext stream (no encryption).
27    pub fn plaintext(inner: S) -> Self {
28        Self {
29            inner,
30            read_cipher: None,
31            write_cipher: None,
32            write_buf: Vec::new(),
33            initial_read: Vec::new(),
34        }
35    }
36
37    /// Create an encrypted stream with RC4 ciphers.
38    ///
39    /// `initial_read` contains overflow bytes read past the VC marker during
40    /// the handshake scan. These bytes are drained first on subsequent reads
41    /// before reading from the inner stream.
42    pub(crate) fn encrypted(
43        inner: S,
44        read_cipher: Rc4,
45        write_cipher: Rc4,
46        initial_read: Vec<u8>,
47    ) -> Self {
48        Self {
49            inner,
50            read_cipher: Some(read_cipher),
51            write_cipher: Some(write_cipher),
52            write_buf: Vec::with_capacity(WRITE_BUF_CAPACITY),
53            initial_read,
54        }
55    }
56}
57
58impl<S: AsyncRead + Unpin> AsyncRead for MseStream<S> {
59    fn poll_read(
60        self: Pin<&mut Self>,
61        cx: &mut Context<'_>,
62        buf: &mut ReadBuf<'_>,
63    ) -> Poll<io::Result<()>> {
64        let this = self.get_mut();
65
66        if !this.initial_read.is_empty() {
67            let to_copy = this.initial_read.len().min(buf.remaining());
68            buf.put_slice(&this.initial_read[..to_copy]);
69            this.initial_read.drain(..to_copy);
70            return Poll::Ready(Ok(()));
71        }
72
73        let before = buf.filled().len();
74
75        match Pin::new(&mut this.inner).poll_read(cx, buf) {
76            Poll::Ready(Ok(())) => {
77                if let Some(cipher) = &mut this.read_cipher {
78                    let filled = buf.filled_mut();
79                    cipher.apply(&mut filled[before..]);
80                }
81                Poll::Ready(Ok(()))
82            }
83            other => other,
84        }
85    }
86}
87
88impl<S: AsyncWrite + Unpin> AsyncWrite for MseStream<S> {
89    fn poll_write(
90        self: Pin<&mut Self>,
91        cx: &mut Context<'_>,
92        buf: &[u8],
93    ) -> Poll<io::Result<usize>> {
94        let this = self.get_mut();
95
96        if let Some(cipher) = &mut this.write_cipher {
97            this.write_buf.clear();
98            this.write_buf.extend_from_slice(buf);
99            cipher.apply(&mut this.write_buf);
100            Pin::new(&mut this.inner).poll_write(cx, &this.write_buf)
101        } else {
102            Pin::new(&mut this.inner).poll_write(cx, buf)
103        }
104    }
105
106    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
107        Pin::new(&mut self.get_mut().inner).poll_flush(cx)
108    }
109
110    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
111        Pin::new(&mut self.get_mut().inner).poll_shutdown(cx)
112    }
113}
114
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};

    #[tokio::test]
    async fn plaintext_passthrough() {
        let (client, server) = tokio::io::duplex(1024);
        let mut client = MseStream::plaintext(client);
        let mut server = MseStream::plaintext(server);

        client.write_all(b"hello").await.unwrap();
        client.flush().await.unwrap();

        let mut buf = [0u8; 5];
        server.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"hello");
    }

    #[tokio::test]
    async fn encrypted_roundtrip() {
        let key_a = b"key for direction A!";
        let key_b = b"key for direction B!";

        let (raw_client, raw_server) = tokio::io::duplex(1024);

        // Client: encrypt with A, decrypt with B
        let mut client = MseStream::encrypted(
            raw_client,
            Rc4::new(key_b), // read (decrypt) = B
            Rc4::new(key_a), // write (encrypt) = A
            Vec::new(),
        );

        // Server: decrypt with A, encrypt with B
        let mut server = MseStream::encrypted(
            raw_server,
            Rc4::new(key_a), // read (decrypt) = A
            Rc4::new(key_b), // write (encrypt) = B
            Vec::new(),
        );

        // Client -> Server
        client.write_all(b"client to server").await.unwrap();
        client.flush().await.unwrap();

        let mut buf = [0u8; 16];
        server.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"client to server");

        // Server -> Client
        server.write_all(b"server to client").await.unwrap();
        server.flush().await.unwrap();

        let mut buf = [0u8; 16];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"server to client");
    }

    #[test]
    fn encrypted_write_buf_pre_allocated() {
        let (raw, _) = tokio::io::duplex(1024);
        let stream = MseStream::encrypted(raw, Rc4::new(b"r"), Rc4::new(b"w"), Vec::new());
        assert_eq!(stream.write_buf.capacity(), WRITE_BUF_CAPACITY);
    }

    #[tokio::test]
    async fn encrypted_no_realloc_on_chunk_write() {
        let (raw_client, _raw_server) = tokio::io::duplex(32768);
        let mut client =
            MseStream::encrypted(raw_client, Rc4::new(b"r"), Rc4::new(b"w"), Vec::new());
        let data = vec![0xABu8; WRITE_BUF_CAPACITY];
        client.write_all(&data).await.unwrap();
        assert_eq!(client.write_buf.capacity(), WRITE_BUF_CAPACITY);
    }

    #[tokio::test]
    async fn initial_read_drains_before_inner() {
        let (raw_client, mut raw_server) = tokio::io::duplex(1024);

        let initial = b"overflow".to_vec();
        let mut client = MseStream::encrypted(raw_client, Rc4::new(b"r"), Rc4::new(b"w"), initial);

        // Write something to the inner stream from the other side
        raw_server.write_all(b"inner").await.unwrap();
        raw_server.flush().await.unwrap();

        // First read should return initial_read bytes (plaintext, not decrypted)
        let mut buf = [0u8; 8];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(&buf, b"overflow");

        // Second read comes from the inner stream and goes through the read cipher.
        // The `initial_read` bytes never touched that cipher, so its keystream
        // starts at these bytes: expect RC4("r") applied to b"inner".
        let mut expected = *b"inner";
        Rc4::new(b"r").apply(&mut expected[..]);

        let mut buf = [0u8; 5];
        client.read_exact(&mut buf).await.unwrap();
        assert_eq!(buf, expected);
    }
}
213}