1#![deny(missing_docs, missing_debug_implementations, unsafe_code)]
7#![warn(unreachable_pub, unused_qualifications, unused_lifetimes)]
8#![warn(clippy::must_use_candidate, clippy::unwrap_in_result, clippy::panic_in_result_fn)]
9
10use futures_io::{AsyncRead, AsyncWrite};
11use openssl::{
12 error::ErrorStack,
13 ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef},
14};
15use std::{
16 fmt, future,
17 io::{self, Read, Write},
18 pin::Pin,
19 task::{Context, Poll, Waker},
20};
21
22#[cfg(test)]
23mod test;
24
/// Adapts an async stream to the synchronous `Read`/`Write` traits that OpenSSL drives.
///
/// Each synchronous call polls `stream` once with a `Context` built from `waker`; a
/// `Poll::Pending` result is reported back to OpenSSL as `io::ErrorKind::WouldBlock`.
struct StreamWrapper<S>
where
    S: Unpin,
{
    /// The wrapped transport.
    stream: S,
    /// Waker used to build the `Context` for every poll of `stream`.
    waker: Waker,
}
29
30impl<S> fmt::Debug for StreamWrapper<S>
31where
32 S: fmt::Debug + Unpin,
33{
34 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
35 self.stream.fmt(fmt)
36 }
37}
38
39impl<S: Unpin> StreamWrapper<S> {
40 fn parts(&mut self) -> (Pin<&mut S>, Context<'_>) {
41 let stream = Pin::new(&mut self.stream);
42 let context = Context::from_waker(&self.waker);
43 (stream, context)
44 }
45}
46
47impl<S> Read for StreamWrapper<S>
48where
49 S: AsyncRead + Unpin,
50{
51 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
52 let (stream, mut cx) = self.parts();
53 match stream.poll_read(&mut cx, buf)? {
54 Poll::Ready(nread) => Ok(nread),
55 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
56 }
57 }
58}
59
60impl<S> Write for StreamWrapper<S>
61where
62 S: AsyncWrite + Unpin,
63{
64 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
65 let (stream, mut cx) = self.parts();
66 match stream.poll_write(&mut cx, buf) {
67 Poll::Ready(r) => r,
68 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
69 }
70 }
71
72 fn flush(&mut self) -> io::Result<()> {
73 let (stream, mut cx) = self.parts();
74 match stream.poll_flush(&mut cx) {
75 Poll::Ready(r) => r,
76 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
77 }
78 }
79}
80
/// Converts a synchronous I/O result into a poll result.
///
/// An `io::ErrorKind::WouldBlock` error becomes `Poll::Pending`; every other result, success
/// or failure, is passed through as `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        settled => Poll::Ready(settled),
    }
}
88
89fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
90 match r {
91 Ok(v) => Poll::Ready(Ok(v)),
92 Err(e) => match e.code() {
93 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
94 _ => Poll::Ready(Err(e)),
95 },
96 }
97}
98
99#[derive(Debug)]
101pub struct SslStream<S: Unpin>(ssl::SslStream<StreamWrapper<S>>);
102
103impl<S> SslStream<S>
104where
105 S: AsyncRead + AsyncWrite + Unpin,
106{
107 pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
109 ssl::SslStream::new(
110 ssl,
111 StreamWrapper {
112 stream,
113 waker: Waker::noop().clone(),
114 },
115 )
116 .map(SslStream)
117 }
118
119 pub fn poll_connect(
121 self: Pin<&mut Self>,
122 cx: &mut Context<'_>,
123 ) -> Poll<Result<(), ssl::Error>> {
124 self.with_context(cx, |s| cvt_ossl(s.connect()))
125 }
126
127 pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
129 future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
130 }
131
132 pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
134 self.with_context(cx, |s| cvt_ossl(s.accept()))
135 }
136
137 pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
139 future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
140 }
141
142 pub fn poll_do_handshake(
144 self: Pin<&mut Self>,
145 cx: &mut Context<'_>,
146 ) -> Poll<Result<(), ssl::Error>> {
147 self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
148 }
149
150 pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
152 future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
153 }
154
155 pub fn poll_peek(
157 self: Pin<&mut Self>,
158 cx: &mut Context<'_>,
159 buf: &mut [u8],
160 ) -> Poll<Result<usize, ssl::Error>> {
161 self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
162 }
163
164 pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
166 future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
167 }
168
169 #[cfg(ossl111)]
171 pub fn poll_read_early_data(
172 self: Pin<&mut Self>,
173 cx: &mut Context<'_>,
174 buf: &mut [u8],
175 ) -> Poll<Result<usize, ssl::Error>> {
176 self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
177 }
178
179 #[cfg(ossl111)]
181 pub async fn read_early_data(
182 mut self: Pin<&mut Self>,
183 buf: &mut [u8],
184 ) -> Result<usize, ssl::Error> {
185 future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
186 }
187
188 #[cfg(ossl111)]
190 pub fn poll_write_early_data(
191 self: Pin<&mut Self>,
192 cx: &mut Context<'_>,
193 buf: &[u8],
194 ) -> Poll<Result<usize, ssl::Error>> {
195 self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
196 }
197
198 #[cfg(ossl111)]
200 pub async fn write_early_data(
201 mut self: Pin<&mut Self>,
202 buf: &[u8],
203 ) -> Result<usize, ssl::Error> {
204 future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
205 }
206}
207
208impl<S: Unpin> SslStream<S> {
209 pub fn ssl(&self) -> &SslRef {
211 self.0.ssl()
212 }
213
214 pub fn get_ref(&self) -> &S {
216 &self.0.get_ref().stream
217 }
218
219 pub fn get_mut(&mut self) -> &mut S {
221 &mut self.0.get_mut().stream
222 }
223
224 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
226 Pin::new(&mut self.get_mut().0.get_mut().stream)
227 }
228
229 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
230 where
231 F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
232 {
233 let this = self.get_mut();
234 this.0.get_mut().waker = ctx.waker().clone();
235 f(&mut this.0)
236 }
237}
238
239impl<S> AsyncRead for SslStream<S>
240where
241 S: AsyncRead + AsyncWrite + Unpin,
242{
243 fn poll_read(
244 self: Pin<&mut Self>,
245 ctx: &mut Context<'_>,
246 buf: &mut [u8],
247 ) -> Poll<io::Result<usize>> {
248 self.with_context(ctx, |s| cvt(s.read(buf)))
249 }
250}
251
252impl<S> AsyncWrite for SslStream<S>
253where
254 S: AsyncRead + AsyncWrite + Unpin,
255{
256 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
257 self.with_context(ctx, |s| cvt(s.write(buf)))
258 }
259
260 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
261 self.with_context(ctx, |s| cvt(s.flush()))
262 }
263
264 fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
265 match self.as_mut().with_context(ctx, |s| s.shutdown()) {
269 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
270 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
271 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
272 return Poll::Pending;
273 }
274 Err(e) => {
275 return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
276 }
277 }
278
279 self.get_pin_mut().poll_close(ctx)
280 }
281}