// Source: compio_rustls/stream/monolithic.rs

1use compio_buf::{
2    BufResult,
3    IoBuf,
4    IoBufMut,
5    bytes::BytesMut,
6};
7use compio_io::{
8    AsyncRead,
9    AsyncWrite,
10};
11use rustls::{
12    ConnectionCommon,
13    SideData,
14};
15
16use crate::{
17    DEFAULT_BUF_CAPACITY,
18    stream::util::{
19        flush_tls_writes,
20        process_tls_reads,
21        read_plaintext,
22    },
23};
24
25pub struct TlsStream<S, C> {
26    io:         S,
27    connection: C,
28    read_buf:   Option<BytesMut>,
29    write_buf:  Option<BytesMut>,
30}
31
32#[cfg(unix)]
33use std::os::unix::io::{
34    AsFd,
35    AsRawFd,
36    BorrowedFd,
37    RawFd,
38};
39use std::{
40    io::{
41        self,
42        Write as _,
43    },
44    ops::DerefMut,
45};
46
47#[cfg(unix)]
48impl<S: AsRawFd, C> AsRawFd for TlsStream<S, C> {
49    fn as_raw_fd(&self) -> RawFd {
50        self.io.as_raw_fd()
51    }
52}
53
54#[cfg(unix)]
55impl<S: AsFd, C> AsFd for TlsStream<S, C> {
56    fn as_fd(&self) -> BorrowedFd<'_> {
57        self.io.as_fd()
58    }
59}
60
61#[cfg(windows)]
62use std::os::windows::io::{
63    AsRawSocket,
64    AsSocket,
65    BorrowedSocket,
66    RawSocket,
67};
68
#[cfg(windows)]
impl<S, C> AsRawSocket for TlsStream<S, C>
where
    S: AsRawSocket,
{
    /// Exposes the raw socket handle of the underlying transport.
    fn as_raw_socket(&self) -> RawSocket {
        <S as AsRawSocket>::as_raw_socket(&self.io)
    }
}
75
#[cfg(windows)]
impl<S, C> AsSocket for TlsStream<S, C>
where
    S: AsSocket,
{
    /// Borrows the socket of the underlying transport for as long as `self`
    /// is borrowed.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        <S as AsSocket>::as_socket(&self.io)
    }
}
82
83impl<S, C, SD> TlsStream<S, C>
84where
85    S: AsyncRead + AsyncWrite,
86    C: DerefMut<Target = ConnectionCommon<SD>>,
87    SD: SideData,
88{
89    pub(crate) fn new(io: S, connection: C) -> Self {
90        Self::with_capacity(io, connection, DEFAULT_BUF_CAPACITY)
91    }
92
93    pub(crate) fn with_capacity(io: S, connection: C, capacity: usize) -> Self {
94        Self {
95            io,
96            connection,
97            read_buf: Some(BytesMut::with_capacity(capacity)),
98            write_buf: Some(BytesMut::with_capacity(capacity)),
99        }
100    }
101
102    pub fn get_ref(&self) -> (&S, &C) {
103        (&self.io, &self.connection)
104    }
105
106    pub fn get_mut(&mut self) -> (&mut S, &mut C) {
107        (&mut self.io, &mut self.connection)
108    }
109
110    pub fn into_inner(self) -> (S, C) {
111        (self.io, self.connection)
112    }
113
114    async fn flush_tls_writes(&mut self) -> io::Result<()> {
115        flush_tls_writes(&mut self.connection, &mut self.io, &mut self.write_buf).await
116    }
117
118    async fn fetch_tls_reads(&mut self) -> io::Result<usize> {
119        let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
120        if rbuf.buf_len() == rbuf.buf_capacity() {
121            rbuf.reserve(4096);
122        }
123
124        let BufResult(res, b) = self.io.read(rbuf).await;
125        let n = match res {
126            | Ok(0) => {
127                self.read_buf = Some(b);
128                return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
129            },
130            | Ok(n) => n,
131            | Err(e) => {
132                self.read_buf = Some(b);
133                return Err(e);
134            },
135        };
136
137        process_tls_reads(&mut self.connection, b, &mut self.read_buf)?;
138        Ok(n)
139    }
140
141    pub(crate) async fn handshake(&mut self) -> io::Result<()> {
142        while self.connection.is_handshaking() {
143            while self.connection.wants_write() {
144                self.flush_tls_writes().await?;
145            }
146            if self.connection.wants_read() {
147                self.fetch_tls_reads().await?;
148            } else if !self.connection.wants_write() {
149                break;
150            }
151        }
152        Ok(())
153    }
154}
155
156impl<S, C, SD> AsyncRead for TlsStream<S, C>
157where
158    S: AsyncRead + AsyncWrite,
159    C: DerefMut<Target = ConnectionCommon<SD>>,
160    SD: SideData,
161{
162    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
163        loop {
164            // Attempt to read plaintext
165            buf = match read_plaintext(&mut self.connection, buf) {
166                | Ok(res) => return res,
167                | Err(b) => b,
168            };
169
170            // Drive TLS state machine
171            if self.connection.wants_write() {
172                if let Err(e) = self.flush_tls_writes().await {
173                    return BufResult(Err(e), buf);
174                }
175            }
176
177            if self.connection.wants_read() {
178                if let Err(e) = self.fetch_tls_reads().await {
179                    return BufResult(Err(e), buf);
180                }
181            } else if !self.connection.wants_write() {
182                return BufResult(Ok(0), buf);
183            }
184        }
185    }
186}
187
188impl<S, C, SD> AsyncWrite for TlsStream<S, C>
189where
190    S: AsyncRead + AsyncWrite,
191    C: DerefMut<Target = ConnectionCommon<SD>>,
192    SD: SideData,
193{
194    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
195        let slice = buf.as_init();
196        let written = match self.connection.writer().write(slice) {
197            | Ok(n) => n,
198            | Err(e) => return BufResult(Err(e), buf),
199        };
200
201        if let Err(e) = self.flush_tls_writes().await {
202            return BufResult(Err(e), buf);
203        }
204
205        BufResult(Ok(written), buf)
206    }
207
208    async fn flush(&mut self) -> io::Result<()> {
209        self.connection.writer().flush()?;
210        self.flush_tls_writes().await?;
211        self.io.flush().await
212    }
213
214    async fn shutdown(&mut self) -> io::Result<()> {
215        self.connection.send_close_notify();
216        self.flush_tls_writes().await?;
217        self.io.shutdown().await
218    }
219}