1use compio_buf::{
2 BufResult, IntoInner as _, IoBuf, IoBufMut, bytes::BytesMut
3};
4use compio_io::{
5 AsyncRead,
6 AsyncWrite,
7};
8use rustls::{
9 ConnectionCommon,
10 SideData,
11};
12
13use crate::{
14 DEFAULT_BUF_CAPACITY,
15 stream::util::{
16 flush_tls_writes,
17 process_tls_reads,
18 read_plaintext,
19 },
20};
21
/// A TLS stream layered over an async transport `S`, driven by a rustls
/// connection `C` (the impls in this file require
/// `C: DerefMut<Target = ConnectionCommon<SD>>`).
///
/// The ciphertext buffers are wrapped in `Option` so they can be moved out
/// for an owned-buffer I/O call and put back afterwards.
pub struct TlsStream<S, C> {
    /// Underlying transport that carries the TLS records.
    io: S,
    /// rustls connection state that encrypts/decrypts the records.
    connection: C,
    /// Ciphertext read from `io`; taken while a transport read is in flight.
    read_buf: Option<BytesMut>,
    /// Buffer handed to `flush_tls_writes`; presumably stages outgoing
    /// ciphertext — confirm against `stream::util`.
    write_buf: Option<BytesMut>,
}
28
29#[cfg(unix)]
30use std::os::unix::io::{
31 AsFd,
32 AsRawFd,
33 BorrowedFd,
34 RawFd,
35};
36use std::{
37 io::{
38 self,
39 Write as _,
40 },
41 ops::DerefMut,
42};
43
44#[cfg(unix)]
45impl<S: AsRawFd, C> AsRawFd for TlsStream<S, C> {
46 fn as_raw_fd(&self) -> RawFd {
47 self.io.as_raw_fd()
48 }
49}
50
51#[cfg(unix)]
52impl<S: AsFd, C> AsFd for TlsStream<S, C> {
53 fn as_fd(&self) -> BorrowedFd<'_> {
54 self.io.as_fd()
55 }
56}
57
58#[cfg(windows)]
59use std::os::windows::io::{
60 AsRawSocket,
61 AsSocket,
62 BorrowedSocket,
63 RawSocket,
64};
65
#[cfg(windows)]
impl<S, C> AsRawSocket for TlsStream<S, C>
where
    S: AsRawSocket,
{
    /// Returns the raw socket of the wrapped transport; the TLS connection
    /// state is not involved.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}
72
#[cfg(windows)]
impl<S, C> AsSocket for TlsStream<S, C>
where
    S: AsSocket,
{
    /// Borrows the wrapped transport's socket for as long as `self` is
    /// borrowed.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.io)
    }
}
79
80impl<S, C, SD> TlsStream<S, C>
81where
82 S: AsyncRead + AsyncWrite,
83 C: DerefMut<Target = ConnectionCommon<SD>>,
84 SD: SideData,
85{
86 pub(crate) fn new(io: S, connection: C) -> Self {
87 Self::with_capacity(io, connection, DEFAULT_BUF_CAPACITY)
88 }
89
90 pub(crate) fn with_capacity(io: S, connection: C, capacity: usize) -> Self {
91 Self {
92 io,
93 connection,
94 read_buf: Some(BytesMut::with_capacity(capacity)),
95 write_buf: Some(BytesMut::with_capacity(capacity)),
96 }
97 }
98
99 pub fn get_ref(&self) -> (&S, &C) {
100 (&self.io, &self.connection)
101 }
102
103 pub fn get_mut(&mut self) -> (&mut S, &mut C) {
104 (&mut self.io, &mut self.connection)
105 }
106
107 pub fn into_inner(self) -> (S, C) {
108 (self.io, self.connection)
109 }
110
111 async fn flush_tls_writes(&mut self) -> io::Result<()> {
112 flush_tls_writes(&mut self.connection, &mut self.io, &mut self.write_buf).await
113 }
114
115 async fn fetch_tls_reads(&mut self) -> io::Result<usize> {
116 let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
117 if rbuf.buf_len() == rbuf.buf_capacity() {
118 rbuf.reserve(4096);
119 }
120
121 let init_len = rbuf.buf_len();
122 let BufResult(res, slice) = self.io.read(rbuf.slice(init_len..)).await;
123 let mut b = slice.into_inner();
124
125 let n = match res {
126 | Ok(0) => {
127 self.read_buf = Some(b);
128 return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
129 },
130 | Ok(n) => {
131 unsafe { b.set_len(init_len + n) };
132 n
133 },
134 | Err(e) => {
135 self.read_buf = Some(b);
136 return Err(e);
137 },
138 };
139
140 process_tls_reads(&mut self.connection, b, &mut self.read_buf)?;
141 Ok(n)
142 }
143
144 pub(crate) async fn handshake(&mut self) -> io::Result<()> {
145 while self.connection.is_handshaking() {
146 while self.connection.wants_write() {
147 self.flush_tls_writes().await?;
148 }
149 if self.connection.wants_read() {
150 self.fetch_tls_reads().await?;
151 } else if !self.connection.wants_write() {
152 break;
153 }
154 }
155 Ok(())
156 }
157}
158
159impl<S, C, SD> AsyncRead for TlsStream<S, C>
160where
161 S: AsyncRead + AsyncWrite,
162 C: DerefMut<Target = ConnectionCommon<SD>>,
163 SD: SideData,
164{
165 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
166 loop {
167 buf = match read_plaintext(&mut self.connection, buf) {
169 | Ok(res) => return res,
170 | Err(b) => b,
171 };
172
173 if self.connection.wants_write() {
175 if let Err(e) = self.flush_tls_writes().await {
176 return BufResult(Err(e), buf);
177 }
178 }
179
180 if self.connection.wants_read() {
181 if let Err(e) = self.fetch_tls_reads().await {
182 return BufResult(Err(e), buf);
183 }
184 } else if !self.connection.wants_write() {
185 return BufResult(Ok(0), buf);
186 }
187 }
188 }
189}
190
191impl<S, C, SD> AsyncWrite for TlsStream<S, C>
192where
193 S: AsyncRead + AsyncWrite,
194 C: DerefMut<Target = ConnectionCommon<SD>>,
195 SD: SideData,
196{
197 async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
198 let slice = buf.as_init();
199 let written = match self.connection.writer().write(slice) {
200 | Ok(n) => n,
201 | Err(e) => return BufResult(Err(e), buf),
202 };
203
204 if let Err(e) = self.flush_tls_writes().await {
205 return BufResult(Err(e), buf);
206 }
207
208 BufResult(Ok(written), buf)
209 }
210
211 async fn flush(&mut self) -> io::Result<()> {
212 self.connection.writer().flush()?;
213 self.flush_tls_writes().await?;
214 self.io.flush().await
215 }
216
217 async fn shutdown(&mut self) -> io::Result<()> {
218 self.connection.send_close_notify();
219 self.flush_tls_writes().await?;
220 self.io.shutdown().await
221 }
222}