monoio_native_tls/
stream.rs

1use std::io::{self, Read, Write};
2
3use monoio::{
4    buf::{IoBuf, IoBufMut, IoVecBuf, IoVecBufMut, RawBuf},
5    io::{AsyncReadRent, AsyncWriteRent, Split},
6    BufResult,
7};
8
9use crate::utils::{Buffers, IOWrapper};
10
11/// A wrapper around an underlying raw stream which implements the TLS or SSL
12/// protocol.
13///
14/// A `TlsStream<S>` represents a handshake that has been completed successfully
15/// and both the server and the client are ready for receiving and sending
16/// data. Bytes read from a `TlsStream` are decrypted from `S` and bytes written
17/// to a `TlsStream` are encrypted when passing through to `S`.
18#[derive(Debug)]
19pub struct TlsStream<S> {
20    tls: native_tls::TlsStream<Buffers>,
21    io: IOWrapper<S>,
22}
23
24impl<S> TlsStream<S> {
25    pub(crate) fn new(tls_stream: native_tls::TlsStream<Buffers>, io: IOWrapper<S>) -> Self {
26        Self {
27            tls: tls_stream,
28            io,
29        }
30    }
31
32    pub fn into_inner(self) -> S {
33        self.io.into_parts().0
34    }
35
36    #[cfg(feature = "alpn")]
37    pub fn alpn_protocol(&self) -> Option<Vec<u8>> {
38        self.tls.negotiated_alpn().ok().flatten()
39    }
40}
41
42unsafe impl<S: Split> Split for TlsStream<S> {}
43
44impl<S: AsyncReadRent> AsyncReadRent for TlsStream<S> {
45    #[allow(clippy::await_holding_refcell_ref)]
46    async fn read<T: IoBufMut>(&mut self, mut buf: T) -> BufResult<usize, T> {
47        let slice = unsafe { std::slice::from_raw_parts_mut(buf.write_ptr(), buf.bytes_total()) };
48
49        loop {
50            // read from native-tls to buffer
51            match self.tls.read(slice) {
52                Ok(n) => {
53                    unsafe { buf.set_init(n) };
54                    return (Ok(n), buf);
55                }
56                // we need more data, read something.
57                Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => (),
58                Err(e) => {
59                    return (Err(e), buf);
60                }
61            }
62
63            // now we need data, read something into native-tls
64            match unsafe { self.io.do_read_io() }.await {
65                Ok(0) => {
66                    return (Ok(0), buf);
67                }
68                Ok(_) => (),
69                Err(e) => {
70                    return (Err(e), buf);
71                }
72            };
73        }
74    }
75
76    async fn readv<T: IoVecBufMut>(&mut self, mut buf: T) -> BufResult<usize, T> {
77        let n = match unsafe { RawBuf::new_from_iovec_mut(&mut buf) } {
78            Some(raw_buf) => self.read(raw_buf).await.0,
79            None => Ok(0),
80        };
81        if let Ok(n) = n {
82            unsafe { buf.set_init(n) };
83        }
84        (n, buf)
85    }
86}
87
88impl<S: AsyncWriteRent> AsyncWriteRent for TlsStream<S> {
89    #[allow(clippy::await_holding_refcell_ref)]
90    async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
91        // construct slice
92        let slice = unsafe { std::slice::from_raw_parts(buf.read_ptr(), buf.bytes_init()) };
93
94        loop {
95            // write slice to native-tls and buffer
96            let maybe_n = match self.tls.write(slice) {
97                Ok(n) => Some(n),
98                Err(e) if e.kind() == io::ErrorKind::WouldBlock => None,
99                Err(e) => return (Err(e), buf),
100            };
101
102            // write from buffer to connection
103            if let Err(e) = unsafe { self.io.do_write_io() }.await {
104                return (Err(e), buf);
105            }
106
107            if let Some(n) = maybe_n {
108                return (Ok(n), buf);
109            }
110        }
111    }
112
113    // TODO: use real writev
114    async fn writev<T: IoVecBuf>(&mut self, buf_vec: T) -> BufResult<usize, T> {
115        let n = match unsafe { RawBuf::new_from_iovec(&buf_vec) } {
116            Some(raw_buf) => self.write(raw_buf).await.0,
117            None => Ok(0),
118        };
119        (n, buf_vec)
120    }
121
122    #[allow(clippy::await_holding_refcell_ref)]
123    async fn flush(&mut self) -> io::Result<()> {
124        loop {
125            match self.tls.flush() {
126                Ok(_) => {
127                    unsafe { self.io.do_write_io() }.await?;
128                    return Ok(());
129                }
130                Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
131                    unsafe { self.io.do_write_io() }.await?;
132                }
133                Err(e) => {
134                    return Err(e);
135                }
136            }
137        }
138    }
139
140    async fn shutdown(&mut self) -> io::Result<()> {
141        self.tls.shutdown()?;
142        unsafe { self.io.do_write_io() }.await?;
143        Ok(())
144    }
145}