Skip to main content

compio_rustls/stream/
monolithic.rs

1use compio_buf::{
2    BufResult, IntoInner as _, IoBuf, IoBufMut, bytes::BytesMut
3};
4use compio_io::{
5    AsyncRead,
6    AsyncWrite,
7};
8use rustls::{
9    ConnectionCommon,
10    SideData,
11};
12
13use crate::{
14    DEFAULT_BUF_CAPACITY,
15    stream::util::{
16        flush_tls_writes,
17        process_tls_reads,
18        read_plaintext,
19    },
20};
21
22pub struct TlsStream<S, C> {
23    io:         S,
24    connection: C,
25    read_buf:   Option<BytesMut>,
26    write_buf:  Option<BytesMut>,
27}
28
29#[cfg(unix)]
30use std::os::unix::io::{
31    AsFd,
32    AsRawFd,
33    BorrowedFd,
34    RawFd,
35};
36use std::{
37    io::{
38        self,
39        Write as _,
40    },
41    ops::DerefMut,
42};
43
44#[cfg(unix)]
45impl<S: AsRawFd, C> AsRawFd for TlsStream<S, C> {
46    fn as_raw_fd(&self) -> RawFd {
47        self.io.as_raw_fd()
48    }
49}
50
51#[cfg(unix)]
52impl<S: AsFd, C> AsFd for TlsStream<S, C> {
53    fn as_fd(&self) -> BorrowedFd<'_> {
54        self.io.as_fd()
55    }
56}
57
58#[cfg(windows)]
59use std::os::windows::io::{
60    AsRawSocket,
61    AsSocket,
62    BorrowedSocket,
63    RawSocket,
64};
65
#[cfg(windows)]
impl<S, C> AsRawSocket for TlsStream<S, C>
where
    S: AsRawSocket,
{
    /// Returns the raw socket handle reported by the wrapped I/O object.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}
72
#[cfg(windows)]
impl<S, C> AsSocket for TlsStream<S, C>
where
    S: AsSocket,
{
    /// Borrows the socket of the wrapped I/O object.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.io)
    }
}
79
80impl<S, C, SD> TlsStream<S, C>
81where
82    S: AsyncRead + AsyncWrite,
83    C: DerefMut<Target = ConnectionCommon<SD>>,
84    SD: SideData,
85{
86    pub(crate) fn new(io: S, connection: C) -> Self {
87        Self::with_capacity(io, connection, DEFAULT_BUF_CAPACITY)
88    }
89
90    pub(crate) fn with_capacity(io: S, connection: C, capacity: usize) -> Self {
91        Self {
92            io,
93            connection,
94            read_buf: Some(BytesMut::with_capacity(capacity)),
95            write_buf: Some(BytesMut::with_capacity(capacity)),
96        }
97    }
98
99    pub fn get_ref(&self) -> (&S, &C) {
100        (&self.io, &self.connection)
101    }
102
103    pub fn get_mut(&mut self) -> (&mut S, &mut C) {
104        (&mut self.io, &mut self.connection)
105    }
106
107    pub fn into_inner(self) -> (S, C) {
108        (self.io, self.connection)
109    }
110
111    async fn flush_tls_writes(&mut self) -> io::Result<()> {
112        flush_tls_writes(&mut self.connection, &mut self.io, &mut self.write_buf).await
113    }
114
115    async fn fetch_tls_reads(&mut self) -> io::Result<usize> {
116        let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
117        if rbuf.buf_len() == rbuf.buf_capacity() {
118            rbuf.reserve(4096);
119        }
120
121        let init_len = rbuf.buf_len();
122        let BufResult(res, slice) = self.io.read(rbuf.slice(init_len..)).await;
123        let mut b = slice.into_inner();
124
125        let n = match res {
126            | Ok(0) => {
127                self.read_buf = Some(b);
128                return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
129            },
130            | Ok(n) => {
131                unsafe { b.set_len(init_len + n) };
132                n
133            },
134            | Err(e) => {
135                self.read_buf = Some(b);
136                return Err(e);
137            },
138        };
139
140        process_tls_reads(&mut self.connection, b, &mut self.read_buf)?;
141        Ok(n)
142    }
143
144    pub(crate) async fn handshake(&mut self) -> io::Result<()> {
145        while self.connection.is_handshaking() {
146            while self.connection.wants_write() {
147                self.flush_tls_writes().await?;
148            }
149            if self.connection.wants_read() {
150                self.fetch_tls_reads().await?;
151            } else if !self.connection.wants_write() {
152                break;
153            }
154        }
155        Ok(())
156    }
157}
158
159impl<S, C, SD> AsyncRead for TlsStream<S, C>
160where
161    S: AsyncRead + AsyncWrite,
162    C: DerefMut<Target = ConnectionCommon<SD>>,
163    SD: SideData,
164{
165    async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
166        loop {
167            // Attempt to read plaintext
168            buf = match read_plaintext(&mut self.connection, buf) {
169                | Ok(res) => return res,
170                | Err(b) => b,
171            };
172
173            // Drive TLS state machine
174            if self.connection.wants_write() {
175                if let Err(e) = self.flush_tls_writes().await {
176                    return BufResult(Err(e), buf);
177                }
178            }
179
180            if self.connection.wants_read() {
181                if let Err(e) = self.fetch_tls_reads().await {
182                    return BufResult(Err(e), buf);
183                }
184            } else if !self.connection.wants_write() {
185                return BufResult(Ok(0), buf);
186            }
187        }
188    }
189}
190
191impl<S, C, SD> AsyncWrite for TlsStream<S, C>
192where
193    S: AsyncRead + AsyncWrite,
194    C: DerefMut<Target = ConnectionCommon<SD>>,
195    SD: SideData,
196{
197    async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
198        let slice = buf.as_init();
199        let written = match self.connection.writer().write(slice) {
200            | Ok(n) => n,
201            | Err(e) => return BufResult(Err(e), buf),
202        };
203
204        if let Err(e) = self.flush_tls_writes().await {
205            return BufResult(Err(e), buf);
206        }
207
208        BufResult(Ok(written), buf)
209    }
210
211    async fn flush(&mut self) -> io::Result<()> {
212        self.connection.writer().flush()?;
213        self.flush_tls_writes().await?;
214        self.io.flush().await
215    }
216
217    async fn shutdown(&mut self) -> io::Result<()> {
218        self.connection.send_close_notify();
219        self.flush_tls_writes().await?;
220        self.io.shutdown().await
221    }
222}