1use compio_buf::{
2 BufResult,
3 IoBuf,
4 IoBufMut,
5 bytes::BytesMut,
6};
7use compio_io::{
8 AsyncRead,
9 AsyncWrite,
10};
11use rustls::{
12 ConnectionCommon,
13 SideData,
14};
15
16use crate::{
17 DEFAULT_BUF_CAPACITY,
18 stream::util::{
19 flush_tls_writes,
20 process_tls_reads,
21 read_plaintext,
22 },
23};
24
/// An async TLS stream: a transport `io` paired with a TLS connection state
/// `connection`.
///
/// The handshake and I/O impls for this type require
/// `C: DerefMut<Target = ConnectionCommon<_>>`, i.e. `C` is expected to be (or
/// dereference to) a rustls connection.
pub struct TlsStream<S, C> {
    /// Underlying transport that carries the encrypted TLS records.
    io: S,
    /// rustls connection state driven by the handshake, read and write paths.
    connection: C,
    /// Buffer for ciphertext read from `io`. Held in an `Option` because the
    /// buffer is moved into each transport read and put back afterwards; a
    /// replacement is allocated if it is ever missing.
    read_buf: Option<BytesMut>,
    /// Buffer handed to the `flush_tls_writes` helper when sending records.
    write_buf: Option<BytesMut>,
}
31
32#[cfg(unix)]
33use std::os::unix::io::{
34 AsFd,
35 AsRawFd,
36 BorrowedFd,
37 RawFd,
38};
39use std::{
40 io::{
41 self,
42 Write as _,
43 },
44 ops::DerefMut,
45};
46
47#[cfg(unix)]
48impl<S: AsRawFd, C> AsRawFd for TlsStream<S, C> {
49 fn as_raw_fd(&self) -> RawFd {
50 self.io.as_raw_fd()
51 }
52}
53
54#[cfg(unix)]
55impl<S: AsFd, C> AsFd for TlsStream<S, C> {
56 fn as_fd(&self) -> BorrowedFd<'_> {
57 self.io.as_fd()
58 }
59}
60
61#[cfg(windows)]
62use std::os::windows::io::{
63 AsRawSocket,
64 AsSocket,
65 BorrowedSocket,
66 RawSocket,
67};
68
#[cfg(windows)]
impl<S, C> AsRawSocket for TlsStream<S, C>
where
    S: AsRawSocket,
{
    /// Exposes the raw socket of the wrapped transport; the TLS layer has no
    /// socket of its own.
    fn as_raw_socket(&self) -> RawSocket {
        AsRawSocket::as_raw_socket(&self.io)
    }
}
75
#[cfg(windows)]
impl<S, C> AsSocket for TlsStream<S, C>
where
    S: AsSocket,
{
    /// Borrows the socket of the wrapped transport for the lifetime of `&self`.
    fn as_socket(&self) -> BorrowedSocket<'_> {
        AsSocket::as_socket(&self.io)
    }
}
82
83impl<S, C, SD> TlsStream<S, C>
84where
85 S: AsyncRead + AsyncWrite,
86 C: DerefMut<Target = ConnectionCommon<SD>>,
87 SD: SideData,
88{
89 pub(crate) fn new(io: S, connection: C) -> Self {
90 Self::with_capacity(io, connection, DEFAULT_BUF_CAPACITY)
91 }
92
93 pub(crate) fn with_capacity(io: S, connection: C, capacity: usize) -> Self {
94 Self {
95 io,
96 connection,
97 read_buf: Some(BytesMut::with_capacity(capacity)),
98 write_buf: Some(BytesMut::with_capacity(capacity)),
99 }
100 }
101
102 pub fn get_ref(&self) -> (&S, &C) {
103 (&self.io, &self.connection)
104 }
105
106 pub fn get_mut(&mut self) -> (&mut S, &mut C) {
107 (&mut self.io, &mut self.connection)
108 }
109
110 pub fn into_inner(self) -> (S, C) {
111 (self.io, self.connection)
112 }
113
114 async fn flush_tls_writes(&mut self) -> io::Result<()> {
115 flush_tls_writes(&mut self.connection, &mut self.io, &mut self.write_buf).await
116 }
117
118 async fn fetch_tls_reads(&mut self) -> io::Result<usize> {
119 let mut rbuf = self.read_buf.take().unwrap_or_else(|| BytesMut::with_capacity(4096));
120 if rbuf.buf_len() == rbuf.buf_capacity() {
121 rbuf.reserve(4096);
122 }
123
124 let BufResult(res, b) = self.io.read(rbuf).await;
125 let n = match res {
126 | Ok(0) => {
127 self.read_buf = Some(b);
128 return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
129 },
130 | Ok(n) => n,
131 | Err(e) => {
132 self.read_buf = Some(b);
133 return Err(e);
134 },
135 };
136
137 process_tls_reads(&mut self.connection, b, &mut self.read_buf)?;
138 Ok(n)
139 }
140
141 pub(crate) async fn handshake(&mut self) -> io::Result<()> {
142 while self.connection.is_handshaking() {
143 while self.connection.wants_write() {
144 self.flush_tls_writes().await?;
145 }
146 if self.connection.wants_read() {
147 self.fetch_tls_reads().await?;
148 } else if !self.connection.wants_write() {
149 break;
150 }
151 }
152 Ok(())
153 }
154}
155
156impl<S, C, SD> AsyncRead for TlsStream<S, C>
157where
158 S: AsyncRead + AsyncWrite,
159 C: DerefMut<Target = ConnectionCommon<SD>>,
160 SD: SideData,
161{
162 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
163 loop {
164 buf = match read_plaintext(&mut self.connection, buf) {
166 | Ok(res) => return res,
167 | Err(b) => b,
168 };
169
170 if self.connection.wants_write() {
172 if let Err(e) = self.flush_tls_writes().await {
173 return BufResult(Err(e), buf);
174 }
175 }
176
177 if self.connection.wants_read() {
178 if let Err(e) = self.fetch_tls_reads().await {
179 return BufResult(Err(e), buf);
180 }
181 } else if !self.connection.wants_write() {
182 return BufResult(Ok(0), buf);
183 }
184 }
185 }
186}
187
188impl<S, C, SD> AsyncWrite for TlsStream<S, C>
189where
190 S: AsyncRead + AsyncWrite,
191 C: DerefMut<Target = ConnectionCommon<SD>>,
192 SD: SideData,
193{
194 async fn write<B: IoBuf>(&mut self, buf: B) -> BufResult<usize, B> {
195 let slice = buf.as_init();
196 let written = match self.connection.writer().write(slice) {
197 | Ok(n) => n,
198 | Err(e) => return BufResult(Err(e), buf),
199 };
200
201 if let Err(e) = self.flush_tls_writes().await {
202 return BufResult(Err(e), buf);
203 }
204
205 BufResult(Ok(written), buf)
206 }
207
208 async fn flush(&mut self) -> io::Result<()> {
209 self.connection.writer().flush()?;
210 self.flush_tls_writes().await?;
211 self.io.flush().await
212 }
213
214 async fn shutdown(&mut self) -> io::Result<()> {
215 self.connection.send_close_notify();
216 self.flush_tls_writes().await?;
217 self.io.shutdown().await
218 }
219}