1use std::{
2 io::{
3 self,
4 Cursor,
5 Read as _,
6 Write,
7 },
8 ops::DerefMut,
9};
10
11use compio_buf::{
12 BufResult,
13 IoBuf,
14 IoBufMut,
15 bytes::BytesMut,
16};
17use compio_io::{
18 AsyncRead,
19 AsyncWrite,
20};
21use rustls::{
22 ConnectionCommon,
23 SideData,
24};
25
26struct BytesMutWriter<'a>(&'a mut BytesMut);
27
28impl Write for BytesMutWriter<'_> {
29 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
30 self.0.extend_from_slice(buf);
31 Ok(buf.len())
32 }
33
34 fn flush(&mut self) -> io::Result<()> {
35 Ok(())
36 }
37
38 fn write_all(&mut self, buf: &[u8]) -> io::Result<()> {
39 self.0.extend_from_slice(buf);
40 Ok(())
41 }
42}
43
44pub struct TlsStream<S, C> {
47 io: S,
48 connection: C,
49
50 read_buf: Option<BytesMut>,
52 write_buf: Option<BytesMut>,
53}
54
55#[cfg(unix)]
56use std::os::unix::io::{
57 AsFd,
58 AsRawFd,
59 BorrowedFd,
60 RawFd,
61};
62
63#[cfg(unix)]
64impl<S, C> AsRawFd for TlsStream<S, C>
65where
66 S: AsRawFd,
67{
68 fn as_raw_fd(&self) -> RawFd {
69 self.io.as_raw_fd()
70 }
71}
72
73#[cfg(unix)]
74impl<S, C> AsFd for TlsStream<S, C>
75where
76 S: AsFd,
77{
78 fn as_fd(&self) -> BorrowedFd<'_> {
79 self.io.as_fd()
80 }
81}
82
83#[cfg(windows)]
84use std::os::windows::io::{
85 AsRawSocket,
86 AsSocket,
87 BorrowedSocket,
88 RawSocket,
89};
90
91#[cfg(windows)]
92impl<S, C> AsRawSocket for TlsStream<S, C>
93where
94 S: AsRawSocket,
95{
96 fn as_raw_socket(&self) -> RawSocket {
97 self.io.as_raw_socket()
98 }
99}
100
101#[cfg(windows)]
102impl<S, C> AsSocket for TlsStream<S, C>
103where
104 S: AsSocket,
105{
106 fn as_socket(&self) -> BorrowedSocket<'_> {
107 self.io.as_socket()
108 }
109}
110
111impl<S, C, SD> TlsStream<S, C>
112where
113 S: AsyncRead + AsyncWrite,
114 C: DerefMut<Target = ConnectionCommon<SD>>,
115 SD: SideData,
116{
117 pub(crate) fn new(io: S, connection: C) -> Self {
118 Self {
119 io,
120 connection,
121 read_buf: Some(BytesMut::with_capacity(4096)),
122 write_buf: Some(BytesMut::with_capacity(4096)),
123 }
124 }
125
126 pub fn get_ref(&self) -> (&S, &C) {
127 (&self.io, &self.connection)
128 }
129
130 pub fn get_mut(&mut self) -> (&mut S, &mut C) {
131 (&mut self.io, &mut self.connection)
132 }
133
134 pub fn into_inner(self) -> (S, C) {
135 (self.io, self.connection)
136 }
137
138 async fn flush_tls_writes(&mut self) -> io::Result<()> {
141 let mut wbuf = self.write_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
142
143 while self.connection.wants_write() {
145 if let Err(e) = self.connection.write_tls(&mut BytesMutWriter(&mut wbuf)) {
146 self.write_buf = Some(wbuf);
147 return Err(e);
148 }
149 }
150
151 while wbuf.buf_len() > 0 {
153 let BufResult(res, mut b) = self.io.write(wbuf).await;
154
155 let n = match res {
156 | Ok(n) => n,
157 | Err(e) => {
158 self.write_buf = Some(b);
159 return Err(e);
160 },
161 };
162
163 if n == 0 {
164 self.write_buf = Some(b);
165 return Err(io::Error::new(io::ErrorKind::WriteZero, "failed to write tls data"));
166 }
167
168 let len = b.buf_len();
169 if n == len {
170 b.clear();
172 wbuf = b;
173 break;
174 } else {
175 b.copy_within(n .. len, 0);
177 unsafe { b.set_len(len - n) };
178 wbuf = b;
179 }
180 }
181
182 self.write_buf = Some(wbuf);
183 Ok(())
184 }
185
186 async fn fetch_tls_reads(&mut self) -> io::Result<usize> {
188 let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
189
190 if rbuf.buf_len() == rbuf.buf_capacity() {
192 rbuf.reserve(4096);
193 }
194
195 let BufResult(res, mut b) = self.io.read(rbuf).await;
196
197 let n = match res {
198 | Ok(0) => {
199 self.read_buf = Some(b);
200 return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
201 },
202 | Ok(n) => n,
203 | Err(e) => {
204 self.read_buf = Some(b);
205 return Err(e);
206 },
207 };
208
209 let mut cursor = Cursor::new(b.as_init());
211 let read_res = self.connection.read_tls(&mut cursor);
212
213 let consumed = cursor.position() as usize;
215 let len = b.buf_len();
216
217 if consumed == len {
218 b.clear();
219 } else {
220 b.copy_within(consumed .. len, 0);
222 unsafe { b.set_len(len - consumed) };
223 }
224
225 self.read_buf = Some(b);
226
227 read_res?;
228
229 self.connection
230 .process_new_packets()
231 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
232
233 Ok(n)
234 }
235
236 pub(crate) async fn handshake(&mut self) -> io::Result<()> {
237 while self.connection.is_handshaking() {
238 while self.connection.wants_write() {
239 self.flush_tls_writes().await?;
240 }
241 if self.connection.wants_read() {
242 self.fetch_tls_reads().await?;
243 } else if !self.connection.wants_write() {
244 break;
246 }
247 }
248 Ok(())
249 }
250}
251
252impl<S, C, SD> AsyncRead for TlsStream<S, C>
253where
254 S: AsyncRead + AsyncWrite,
255 C: std::ops::DerefMut<Target = ConnectionCommon<SD>>,
256 SD: SideData,
257{
258 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
259 loop {
260 let init_len = buf.buf_len();
261 let cap = buf.buf_capacity();
262
263 if init_len == cap {
266 return BufResult(Ok(0), buf);
267 }
268
269 let mut reader = self.connection.reader();
271
272 let slice =
274 unsafe { std::slice::from_raw_parts_mut(buf.buf_mut_ptr().cast::<u8>().add(init_len), cap - init_len) };
275
276 match reader.read(slice) {
277 | Ok(n) if n > 0 => unsafe {
278 buf.advance_to(init_len + n);
280 return BufResult(Ok(n), buf);
281 },
282 | Err(e) if e.kind() != io::ErrorKind::WouldBlock => return BufResult(Err(e), buf),
283 | _ => {}, }
285
286 if self.connection.wants_write() {
288 if let Err(e) = self.flush_tls_writes().await {
289 return BufResult(Err(e), buf);
290 }
291 }
292
293 if self.connection.wants_read() {
294 if let Err(e) = self.fetch_tls_reads().await {
295 return BufResult(Err(e), buf);
296 }
297 } else if !self.connection.wants_write() {
298 return BufResult(Ok(0), buf);
300 }
301 }
302 }
303}
304
305impl<S, C, SD> AsyncWrite for TlsStream<S, C>
306where
307 S: AsyncRead + AsyncWrite,
308 C: std::ops::DerefMut<Target = ConnectionCommon<SD>>,
309 SD: SideData,
310{
311 async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
312 let slice = buf.as_init();
314
315 let written = match self.connection.writer().write(slice) {
316 | Ok(n) => n,
317 | Err(e) => return BufResult(Err(e), buf),
318 };
319
320 if let Err(e) = self.flush_tls_writes().await {
322 return BufResult(Err(e), buf);
323 }
324
325 BufResult(Ok(written), buf)
326 }
327
328 async fn flush(&mut self) -> io::Result<()> {
329 self.connection.writer().flush()?;
330 self.flush_tls_writes().await?;
331 self.io.flush().await
332 }
333
334 async fn shutdown(&mut self) -> io::Result<()> {
335 self.connection.send_close_notify();
336 self.flush_tls_writes().await?;
337 self.io.shutdown().await
338 }
339}