1use btls::{
8 error::ErrorStack,
9 ssl::{self, ErrorCode, Ssl, SslRef, SslStream as SslStreamCore},
10};
11use compio::buf::{IoBuf, IoBufMut};
12use compio::BufResult;
13use compio_io::{compat::SyncStream, AsyncRead, AsyncWrite};
14use std::error::Error;
15use std::pin::Pin;
16use std::task::Context;
17use std::task::Poll;
18use std::{fmt, io};
19
20fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
21 match r {
22 Ok(v) => Poll::Ready(Ok(v)),
23 Err(e) => match e.code() {
24 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
25 _ => Poll::Ready(Err(e)),
26 },
27 }
28}
29
30#[derive(Debug)]
32pub struct SslStream<S>(SslStreamCore<SyncStream<S>>);
33
34impl<S: AsyncRead + AsyncWrite> SslStream<S> {
35 #[inline]
36 pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
38 SslStreamCore::new(ssl, SyncStream::new(stream)).map(SslStream)
39 }
40
41 #[inline]
42 pub fn poll_connect(
44 self: Pin<&mut Self>,
45 cx: &mut Context<'_>,
46 ) -> Poll<Result<(), HandshakeError>> {
47 self.with_context(cx, |s| cvt_ossl(s.connect()))
48 .map_err(HandshakeError::Ssl)
49 }
50
51 #[inline]
52 pub async fn connect(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
54 self.drive_handshake(|s| s.connect()).await
55 }
56
57 #[inline]
58 pub fn poll_accept(
60 self: Pin<&mut Self>,
61 cx: &mut Context<'_>,
62 ) -> Poll<Result<(), HandshakeError>> {
63 self.with_context(cx, |s| cvt_ossl(s.accept()))
64 .map_err(HandshakeError::Ssl)
65 }
66
67 #[inline]
68 pub async fn accept(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
70 self.drive_handshake(|s| s.accept()).await
71 }
72
73 #[inline]
74 pub fn poll_do_handshake(
76 self: Pin<&mut Self>,
77 cx: &mut Context<'_>,
78 ) -> Poll<Result<(), HandshakeError>> {
79 self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
80 .map_err(HandshakeError::Ssl)
81 }
82
83 #[inline]
84 pub async fn do_handshake(self: Pin<&mut Self>) -> Result<(), HandshakeError> {
86 self.drive_handshake(|s| s.do_handshake()).await
87 }
88
89 async fn drive_handshake<F>(mut self: Pin<&mut Self>, mut f: F) -> Result<(), HandshakeError>
90 where
91 F: FnMut(&mut SslStreamCore<SyncStream<S>>) -> Result<(), ssl::Error>,
92 {
93 loop {
94 let res = {
95 let this = unsafe { self.as_mut().get_unchecked_mut() };
96 f(&mut this.0)
97 };
98
99 match res {
100 Ok(()) => {
101 self.as_mut()
103 .flush_write_buf()
104 .await
105 .map_err(HandshakeError::Io)?;
106
107 return Ok(());
108 }
109 Err(e) => match e.code() {
110 ErrorCode::WANT_WRITE => {
111 self.as_mut()
112 .flush_write_buf()
113 .await
114 .map_err(HandshakeError::Io)?;
115 }
116 ErrorCode::WANT_READ => {
117 self.as_mut()
118 .flush_write_buf()
119 .await
120 .map_err(HandshakeError::Io)?;
121
122 self.as_mut()
123 .fill_read_buf()
124 .await
125 .map_err(HandshakeError::Io)?;
126 }
127 _ => return Err(HandshakeError::Ssl(e)),
128 },
129 }
130 }
131 }
132}
133
134impl<S: AsyncRead + AsyncWrite> SslStream<S> {
135 async fn fill_read_buf(mut self: Pin<&mut Self>) -> io::Result<usize> {
136 let this = unsafe { self.as_mut().get_unchecked_mut() };
137 this.0.get_mut().fill_read_buf().await
138 }
139
140 async fn flush_write_buf(mut self: Pin<&mut Self>) -> io::Result<usize> {
141 let this = unsafe { self.as_mut().get_unchecked_mut() };
142 this.0.get_mut().flush_write_buf().await
143 }
144}
145
146impl<S> SslStream<S> {
147 #[inline]
148 pub fn ssl(&self) -> &SslRef {
150 self.0.ssl()
151 }
152
153 #[inline]
154 pub fn get_ref(&self) -> &S {
156 self.0.get_ref().get_ref()
157 }
158
159 #[inline]
160 pub fn get_mut(&mut self) -> &mut S {
162 self.0.get_mut().get_mut()
163 }
164
165 #[inline]
166 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
168 unsafe {
169 let this = self.get_unchecked_mut();
170 Pin::new_unchecked(this.0.get_mut().get_mut())
171 }
172 }
173
174 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
175 where
176 F: FnOnce(&mut SslStreamCore<SyncStream<S>>) -> R,
177 {
178 let this = unsafe { self.get_unchecked_mut() };
179 this.0.ssl_mut().set_task_waker(Some(ctx.waker().clone()));
180 let r = f(&mut this.0);
181 this.0.ssl_mut().set_task_waker(None);
182 r
183 }
184}
185
186impl<S> AsyncRead for SslStream<S>
187where
188 S: AsyncRead + AsyncWrite,
189{
190 async fn read<B: IoBufMut>(&mut self, mut buf: B) -> BufResult<usize, B> {
191 let slice = buf.as_uninit();
192 loop {
193 match self.0.read_uninit(slice) {
195 Ok(res) => {
196 unsafe { buf.advance_to(res) };
198 return BufResult(Ok(res), buf);
199 }
200 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
201 match self.0.get_mut().fill_read_buf().await {
202 Ok(_) => continue,
203 Err(e) => return BufResult(Err(e), buf),
204 }
205 }
206 res => return BufResult(res, buf),
207 }
208 }
209 }
210}
211
212impl<S> AsyncWrite for SslStream<S>
213where
214 S: AsyncRead + AsyncWrite,
215{
216 async fn write<T: IoBuf>(&mut self, buf: T) -> BufResult<usize, T> {
217 let slice = buf.as_init();
218 loop {
219 let res = io::Write::write(&mut self.0, slice);
220 match res {
221 Err(e) if e.kind() == io::ErrorKind::WouldBlock => match self.flush().await {
222 Ok(_) => continue,
223 Err(e) => return BufResult(Err(e), buf),
224 },
225 _ => return BufResult(res, buf),
226 }
227 }
228 }
229
230 async fn flush(&mut self) -> io::Result<()> {
231 loop {
232 match io::Write::flush(&mut self.0) {
233 Ok(()) => break,
234 Err(e) if e.kind() == io::ErrorKind::WouldBlock => {
235 self.0.get_mut().flush_write_buf().await?;
236 }
237 Err(e) => return Err(e),
238 }
239 }
240 self.0.get_mut().flush_write_buf().await?;
241 Ok(())
242 }
243
244 async fn shutdown(&mut self) -> io::Result<()> {
245 self.flush().await?;
246 self.0.get_mut().get_mut().shutdown().await
247 }
248}
249
250pub enum HandshakeError {
252 Ssl(ssl::Error),
254 Io(io::Error),
256}
257
258impl HandshakeError {
259 #[must_use]
261 pub fn code(&self) -> Option<ErrorCode> {
262 match self {
263 HandshakeError::Ssl(e) => Some(e.code()),
264 _ => None,
265 }
266 }
267
268 #[must_use]
270 pub fn as_io_error(&self) -> Option<&io::Error> {
271 match self {
272 HandshakeError::Ssl(e) => e.io_error(),
273 HandshakeError::Io(e) => Some(e),
274 }
275 }
276}
277
278impl fmt::Debug for HandshakeError {
279 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
280 match self {
281 HandshakeError::Ssl(e) => fmt::Debug::fmt(e, fmt),
282 HandshakeError::Io(e) => fmt::Debug::fmt(e, fmt),
283 }
284 }
285}
286
287impl fmt::Display for HandshakeError {
288 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
289 match self {
290 HandshakeError::Ssl(e) => fmt::Display::fmt(e, fmt),
291 HandshakeError::Io(e) => fmt::Display::fmt(e, fmt),
292 }
293 }
294}
295
296impl Error for HandshakeError {
297 fn source(&self) -> Option<&(dyn Error + 'static)> {
298 match self {
299 HandshakeError::Ssl(e) => e.source(),
300 HandshakeError::Io(e) => Some(e),
301 }
302 }
303}