Skip to main content

compio_rustls/
stream.rs

1use std::{
2    io::{
3        self,
4        Cursor,
5        Read as _,
6        Write,
7    },
8    ops::DerefMut,
9};
10
11use compio_buf::{
12    BufResult,
13    IoBuf,
14    IoBufMut,
15    bytes::BytesMut,
16};
17use compio_io::{
18    AsyncRead,
19    AsyncWrite,
20};
21use rustls::{
22    ConnectionCommon,
23    SideData,
24};
25
26struct BytesMutWriter<'a>(&'a mut BytesMut);
27
28impl Write for BytesMutWriter<'_> {
29    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
30        self.0.extend_from_slice(buf);
31        Ok(buf.len())
32    }
33
34    fn flush(&mut self) -> io::Result<()> {
35        Ok(())
36    }
37
38    fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
39        self.0.extend_from_slice(buf);
40        Ok(())
41    }
42}
43
44/// A wrapper around an underlying raw stream which implements the TLS or SSL
45/// protocol.
46pub struct TlsStream<S, C> {
47    io:         S,
48    connection: C,
49
50    // Intermediate buffers for ciphertext.
51    read_buf:  Option<BytesMut>,
52    write_buf: Option<BytesMut>,
53}
54
55#[cfg(unix)]
56use std::os::unix::io::{
57    AsFd,
58    AsRawFd,
59    BorrowedFd,
60    RawFd,
61};
62
63#[cfg(unix)]
64impl<S, C> AsRawFd for TlsStream<S, C>
65where
66    S: AsRawFd,
67{
68    fn as_raw_fd(&self) -> RawFd {
69        self.io.as_raw_fd()
70    }
71}
72
73#[cfg(unix)]
74impl<S, C> AsFd for TlsStream<S, C>
75where
76    S: AsFd,
77{
78    fn as_fd(&self) -> BorrowedFd<'_> {
79        self.io.as_fd()
80    }
81}
82
83#[cfg(windows)]
84use std::os::windows::io::{
85    AsRawSocket,
86    AsSocket,
87    BorrowedSocket,
88    RawSocket,
89};
90
#[cfg(windows)]
impl<S: AsRawSocket, C> AsRawSocket for TlsStream<S, C> {
    /// Returns the raw socket handle of the wrapped transport.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}
100
#[cfg(windows)]
impl<S: AsSocket, C> AsSocket for TlsStream<S, C> {
    /// Borrows the socket of the wrapped transport for the lifetime of `self`.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.io)
    }
}
110
111impl<S, C, SD> TlsStream<S, C>
112where
113    S: AsyncRead + AsyncWrite,
114    C: DerefMut<Target = ConnectionCommon<SD>>,
115    SD: SideData,
116{
117    pub(crate) fn new(io: S, connection: C) -> Self {
118        Self {
119            io,
120            connection,
121            read_buf: Some(BytesMut::with_capacity(4096)),
122            write_buf: Some(BytesMut::with_capacity(4096)),
123        }
124    }
125
126    pub fn get_ref(&self) -> (&S, &C) {
127        (&self.io, &self.connection)
128    }
129
130    pub fn get_mut(&mut self) -> (&mut S, &mut C) {
131        (&mut self.io, &mut self.connection)
132    }
133
134    pub fn into_inner(self) -> (S, C) {
135        (self.io, self.connection)
136    }
137
138    /// Pull generated ciphertext from `rustls` and pushes it to the underlying
139    /// OS socket.
140    async fn flush_tls_writes(&mut self) -> io::Result<()> {
141        let mut wbuf = self.write_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
142
143        // Drain all pending TLS ciphertext into our intermediate buffer
144        while self.connection.wants_write() {
145            if let Err(e) = self.connection.write_tls(&mut BytesMutWriter(&mut wbuf)) {
146                self.write_buf = Some(wbuf);
147                return Err(e);
148            }
149        }
150
151        // Flush the buffer to the OS, handling potential partial writes
152        while wbuf.buf_len() > 0 {
153            let BufResult(res, mut b) = self.io.write(wbuf).await;
154
155            let n = match res {
156                | Ok(n) => n,
157                | Err(e) => {
158                    self.write_buf = Some(b);
159                    return Err(e);
160                },
161            };
162
163            if n == 0 {
164                self.write_buf = Some(b);
165                return Err(io::Error::new(io::ErrorKind::WriteZero, "failed to write tls data"));
166            }
167
168            let len = b.buf_len();
169            if n == len {
170                // Fully written
171                b.clear();
172                wbuf = b;
173                break;
174            } else {
175                // Partial write: shift remaining unwritten bytes to the front
176                b.copy_within(n .. len, 0);
177                unsafe { b.set_len(len - n) };
178                wbuf = b;
179            }
180        }
181
182        self.write_buf = Some(wbuf);
183        Ok(())
184    }
185
186    /// Pull ciphertext from the OS socket and feeds it into the `rustls`.
187    async fn fetch_tls_reads(&mut self) -> io::Result<usize> {
188        let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
189
190        // Ensure we have uninitialized capacity available for the OS to write into
191        if rbuf.buf_len() == rbuf.buf_capacity() {
192            rbuf.reserve(4096);
193        }
194
195        let BufResult(res, mut b) = self.io.read(rbuf).await;
196
197        let n = match res {
198            | Ok(0) => {
199                self.read_buf = Some(b);
200                return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
201            },
202            | Ok(n) => n,
203            | Err(e) => {
204                self.read_buf = Some(b);
205                return Err(e);
206            },
207        };
208
209        // Pass the initialized bytes to rustls
210        let mut cursor = Cursor::new(b.as_init());
211        let read_res = self.connection.read_tls(&mut cursor);
212
213        // Check how many bytes rustls actually consumed
214        let consumed = cursor.position() as usize;
215        let len = b.buf_len();
216
217        if consumed == len {
218            b.clear();
219        } else {
220            // Shift leftover, unconsumed ciphertext to the front for the next read cycle
221            b.copy_within(consumed .. len, 0);
222            unsafe { b.set_len(len - consumed) };
223        }
224
225        self.read_buf = Some(b);
226
227        read_res?;
228
229        self.connection
230            .process_new_packets()
231            .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
232
233        Ok(n)
234    }
235
236    pub(crate) async fn handshake(&mut self) -> io::Result<()> {
237        while self.connection.is_handshaking() {
238            while self.connection.wants_write() {
239                self.flush_tls_writes().await?;
240            }
241            if self.connection.wants_read() {
242                self.fetch_tls_reads().await?;
243            } else if !self.connection.wants_write() {
244                // Handshake is stalled or finished
245                break;
246            }
247        }
248        Ok(())
249    }
250}
251
252impl<S, C, SD> AsyncRead for TlsStream<S, C>
253where
254    S: AsyncRead + AsyncWrite,
255    C: std::ops::DerefMut<Target = ConnectionCommon<SD>>,
256    SD: SideData,
257{
258    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
259        loop {
260            let init_len = buf.buf_len();
261            let cap = buf.buf_capacity();
262
263            // Prevent undefined behavior or infinite loops if passed a completely full
264            // buffer.
265            if init_len == cap {
266                return BufResult(Ok(0), buf);
267            }
268
269            // Try to yield existing plaintext
270            let mut reader = self.connection.reader();
271
272            // Extract the uninitialized portion of the buffer to write plaintext into.
273            let slice =
274                unsafe { std::slice::from_raw_parts_mut(buf.buf_mut_ptr().cast::<u8>().add(init_len), cap - init_len) };
275
276            match reader.read(slice) {
277                | Ok(n) if n > 0 => unsafe {
278                    // Update the initialized length marker in the compio buffer
279                    buf.advance_to(init_len + n);
280                    return BufResult(Ok(n), buf);
281                },
282                | Err(e) if e.kind() != io::ErrorKind::WouldBlock => return BufResult(Err(e), buf),
283                | _ => {}, // Need more data from the socket
284            }
285
286            // Drive the TLS state machine
287            if self.connection.wants_write() {
288                if let Err(e) = self.flush_tls_writes().await {
289                    return BufResult(Err(e), buf);
290                }
291            }
292
293            if self.connection.wants_read() {
294                if let Err(e) = self.fetch_tls_reads().await {
295                    return BufResult(Err(e), buf);
296                }
297            } else if !self.connection.wants_write() {
298                // Connection closed cleanly or EOF
299                return BufResult(Ok(0), buf);
300            }
301        }
302    }
303}
304
305impl<S, C, SD> AsyncWrite for TlsStream<S, C>
306where
307    S: AsyncRead + AsyncWrite,
308    C: std::ops::DerefMut<Target = ConnectionCommon<SD>>,
309    SD: SideData,
310{
311    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
312        // Extract the initialized slice representing the plaintext to be sent
313        let slice = buf.as_init();
314
315        let written = match self.connection.writer().write(slice) {
316            | Ok(n) => n,
317            | Err(e) => return BufResult(Err(e), buf),
318        };
319
320        // Push the newly generated ciphertext down to the OS
321        if let Err(e) = self.flush_tls_writes().await {
322            return BufResult(Err(e), buf);
323        }
324
325        BufResult(Ok(written), buf)
326    }
327
328    async fn flush(&mut self) -> io::Result<()> {
329        self.connection.writer().flush()?;
330        self.flush_tls_writes().await?;
331        self.io.flush().await
332    }
333
334    async fn shutdown(&mut self) -> io::Result<()> {
335        self.connection.send_close_notify();
336        self.flush_tls_writes().await?;
337        self.io.shutdown().await
338    }
339}