1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
7#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
8#![warn(
9 clippy::must_use_candidate,
10 clippy::unwrap_in_result,
11 clippy::panic_in_result_fn
12)]
13
14use futures_io::{AsyncRead, AsyncWrite};
15use openssl::{
16 error::ErrorStack,
17 ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef},
18};
19use std::{
20 fmt, future,
21 io::{self, Read, Write},
22 pin::Pin,
23 task::{Context, Poll, Waker},
24};
25
26#[cfg(test)]
27mod test;
28
/// Adapts an asynchronous stream `S` to the blocking `Read`/`Write` traits
/// that `openssl::ssl::SslStream` expects.
struct StreamWrapper<S>
where
    S: Unpin,
{
    /// The wrapped asynchronous byte stream.
    stream: S,
    /// Waker used to build the `Context` for polling `stream`; when `None`,
    /// `parts` falls back to a no-op waker.
    waker: Option<Waker>,
}
33
34impl<S> fmt::Debug for StreamWrapper<S>
35where
36 S: fmt::Debug + Unpin,
37{
38 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
39 self.stream.fmt(fmt)
40 }
41}
42
43impl<S: Unpin> StreamWrapper<S> {
44 fn parts(&mut self) -> (Pin<&mut S>, Context<'_>) {
45 let stream = Pin::new(&mut self.stream);
46 let context = Context::from_waker(self.waker.as_ref().unwrap_or(Waker::noop()));
47 (stream, context)
48 }
49}
50
51impl<S> Read for StreamWrapper<S>
52where
53 S: AsyncRead + Unpin,
54{
55 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
56 let (stream, mut cx) = self.parts();
57 match stream.poll_read(&mut cx, buf)? {
58 Poll::Ready(nread) => Ok(nread),
59 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
60 }
61 }
62}
63
64impl<S> Write for StreamWrapper<S>
65where
66 S: AsyncWrite + Unpin,
67{
68 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
69 let (stream, mut cx) = self.parts();
70 match stream.poll_write(&mut cx, buf) {
71 Poll::Ready(r) => r,
72 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
73 }
74 }
75
76 fn flush(&mut self) -> io::Result<()> {
77 let (stream, mut cx) = self.parts();
78 match stream.poll_flush(&mut cx) {
79 Poll::Ready(r) => r,
80 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
81 }
82 }
83}
84
/// Converts a blocking-style `io::Result` into a `Poll`.
///
/// An `io::ErrorKind::WouldBlock` error becomes `Poll::Pending`; every other
/// result (success or error) is returned unchanged inside `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        settled => Poll::Ready(settled),
    }
}
92
93fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
94 match r {
95 Ok(v) => Poll::Ready(Ok(v)),
96 Err(e) => match e.code() {
97 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
98 _ => Poll::Ready(Err(e)),
99 },
100 }
101}
102
103#[derive(Debug)]
105pub struct SslStream<S: Unpin>(ssl::SslStream<StreamWrapper<S>>);
106
107impl<S> SslStream<S>
108where
109 S: AsyncRead + AsyncWrite + Unpin,
110{
111 pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
113 ssl::SslStream::new(
114 ssl,
115 StreamWrapper {
116 stream,
117 waker: None,
118 },
119 )
120 .map(SslStream)
121 }
122
123 pub fn poll_connect(
125 self: Pin<&mut Self>,
126 cx: &mut Context<'_>,
127 ) -> Poll<Result<(), ssl::Error>> {
128 self.with_context(cx, |s| cvt_ossl(s.connect()))
129 }
130
131 pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
133 future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
134 }
135
136 pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
138 self.with_context(cx, |s| cvt_ossl(s.accept()))
139 }
140
141 pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
143 future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
144 }
145
146 pub fn poll_do_handshake(
148 self: Pin<&mut Self>,
149 cx: &mut Context<'_>,
150 ) -> Poll<Result<(), ssl::Error>> {
151 self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
152 }
153
154 pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
156 future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
157 }
158
159 pub fn poll_peek(
161 self: Pin<&mut Self>,
162 cx: &mut Context<'_>,
163 buf: &mut [u8],
164 ) -> Poll<Result<usize, ssl::Error>> {
165 self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
166 }
167
168 pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
170 future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
171 }
172
173 #[cfg(ossl111)]
175 pub fn poll_read_early_data(
176 self: Pin<&mut Self>,
177 cx: &mut Context<'_>,
178 buf: &mut [u8],
179 ) -> Poll<Result<usize, ssl::Error>> {
180 self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
181 }
182
183 #[cfg(ossl111)]
185 pub async fn read_early_data(
186 mut self: Pin<&mut Self>,
187 buf: &mut [u8],
188 ) -> Result<usize, ssl::Error> {
189 future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
190 }
191
192 #[cfg(ossl111)]
194 pub fn poll_write_early_data(
195 self: Pin<&mut Self>,
196 cx: &mut Context<'_>,
197 buf: &[u8],
198 ) -> Poll<Result<usize, ssl::Error>> {
199 self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
200 }
201
202 #[cfg(ossl111)]
204 pub async fn write_early_data(
205 mut self: Pin<&mut Self>,
206 buf: &[u8],
207 ) -> Result<usize, ssl::Error> {
208 future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
209 }
210}
211
212impl<S: Unpin> SslStream<S> {
213 pub fn ssl(&self) -> &SslRef {
215 self.0.ssl()
216 }
217
218 pub fn get_ref(&self) -> &S {
220 &self.0.get_ref().stream
221 }
222
223 pub fn get_mut(&mut self) -> &mut S {
225 &mut self.0.get_mut().stream
226 }
227
228 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
230 Pin::new(&mut self.get_mut().0.get_mut().stream)
231 }
232
233 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
234 where
235 F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
236 {
237 let this = self.get_mut();
238 this.0.get_mut().waker = Some(ctx.waker().clone());
239 f(&mut this.0)
240 }
241}
242
243impl<S> AsyncRead for SslStream<S>
244where
245 S: AsyncRead + AsyncWrite + Unpin,
246{
247 fn poll_read(
248 self: Pin<&mut Self>,
249 ctx: &mut Context<'_>,
250 buf: &mut [u8],
251 ) -> Poll<io::Result<usize>> {
252 self.with_context(ctx, |s| cvt(s.read(buf)))
253 }
254}
255
256impl<S> AsyncWrite for SslStream<S>
257where
258 S: AsyncRead + AsyncWrite + Unpin,
259{
260 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
261 self.with_context(ctx, |s| cvt(s.write(buf)))
262 }
263
264 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
265 self.with_context(ctx, |s| cvt(s.flush()))
266 }
267
268 fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
269 match self.as_mut().with_context(ctx, |s| s.shutdown()) {
273 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
274 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
275 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
276 return Poll::Pending;
277 }
278 Err(e) => {
279 return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
280 }
281 }
282
283 self.get_pin_mut().poll_close(ctx)
284 }
285}