1#![warn(missing_docs)]
7
8use futures_io::{AsyncRead, AsyncWrite};
9use openssl::{
10 error::ErrorStack,
11 ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef},
12};
13use std::{
14 fmt, future,
15 io::{self, Read, Write},
16 pin::Pin,
17 task::{Context, Poll, Waker},
18};
19
20#[cfg(test)]
21mod test;
22
/// Pairs an I/O stream with the waker used to build a task [`Context`] for it.
///
/// The blocking-style `Read`/`Write` implementations on this type poll the
/// wrapped stream with a context created from `waker`.
struct StreamWrapper<S>
where
    S: Unpin,
{
    /// The wrapped transport.
    stream: S,
    /// Waker handed to the transport whenever it is polled.
    waker: Waker,
}
27
28impl<S> fmt::Debug for StreamWrapper<S>
29where
30 S: fmt::Debug + Unpin,
31{
32 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
33 self.stream.fmt(fmt)
34 }
35}
36
37impl<S: Unpin> StreamWrapper<S> {
38 fn parts(&mut self) -> (Pin<&mut S>, Context<'_>) {
39 let stream = Pin::new(&mut self.stream);
40 let context = Context::from_waker(&self.waker);
41 (stream, context)
42 }
43}
44
45impl<S> Read for StreamWrapper<S>
46where
47 S: AsyncRead + Unpin,
48{
49 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
50 let (stream, mut cx) = self.parts();
51 match stream.poll_read(&mut cx, buf)? {
52 Poll::Ready(nread) => Ok(nread),
53 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
54 }
55 }
56}
57
58impl<S> Write for StreamWrapper<S>
59where
60 S: AsyncWrite + Unpin,
61{
62 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
63 let (stream, mut cx) = self.parts();
64 match stream.poll_write(&mut cx, buf) {
65 Poll::Ready(r) => r,
66 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
67 }
68 }
69
70 fn flush(&mut self) -> io::Result<()> {
71 let (stream, mut cx) = self.parts();
72 match stream.poll_flush(&mut cx) {
73 Poll::Ready(r) => r,
74 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
75 }
76 }
77}
78
/// Converts a blocking-style I/O result into a [`Poll`].
///
/// An error of kind `WouldBlock` becomes [`Poll::Pending`]; every other
/// outcome, success or failure, is returned as [`Poll::Ready`] unchanged.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        settled => Poll::Ready(settled),
    }
}
86
87fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
88 match r {
89 Ok(v) => Poll::Ready(Ok(v)),
90 Err(e) => match e.code() {
91 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
92 _ => Poll::Ready(Err(e)),
93 },
94 }
95}
96
97#[derive(Debug)]
99pub struct SslStream<S: Unpin>(ssl::SslStream<StreamWrapper<S>>);
100
101impl<S> SslStream<S>
102where
103 S: AsyncRead + AsyncWrite + Unpin,
104{
105 pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
107 ssl::SslStream::new(
108 ssl,
109 StreamWrapper {
110 stream,
111 waker: Waker::noop().clone(),
112 },
113 )
114 .map(SslStream)
115 }
116
117 pub fn poll_connect(
119 self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 ) -> Poll<Result<(), ssl::Error>> {
122 self.with_context(cx, |s| cvt_ossl(s.connect()))
123 }
124
125 pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
127 future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
128 }
129
130 pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
132 self.with_context(cx, |s| cvt_ossl(s.accept()))
133 }
134
135 pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
137 future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
138 }
139
140 pub fn poll_do_handshake(
142 self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 ) -> Poll<Result<(), ssl::Error>> {
145 self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
146 }
147
148 pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
150 future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
151 }
152
153 pub fn poll_peek(
155 self: Pin<&mut Self>,
156 cx: &mut Context<'_>,
157 buf: &mut [u8],
158 ) -> Poll<Result<usize, ssl::Error>> {
159 self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
160 }
161
162 pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
164 future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
165 }
166
167 #[cfg(ossl111)]
169 pub fn poll_read_early_data(
170 self: Pin<&mut Self>,
171 cx: &mut Context<'_>,
172 buf: &mut [u8],
173 ) -> Poll<Result<usize, ssl::Error>> {
174 self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
175 }
176
177 #[cfg(ossl111)]
179 pub async fn read_early_data(
180 mut self: Pin<&mut Self>,
181 buf: &mut [u8],
182 ) -> Result<usize, ssl::Error> {
183 future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
184 }
185
186 #[cfg(ossl111)]
188 pub fn poll_write_early_data(
189 self: Pin<&mut Self>,
190 cx: &mut Context<'_>,
191 buf: &[u8],
192 ) -> Poll<Result<usize, ssl::Error>> {
193 self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
194 }
195
196 #[cfg(ossl111)]
198 pub async fn write_early_data(
199 mut self: Pin<&mut Self>,
200 buf: &[u8],
201 ) -> Result<usize, ssl::Error> {
202 future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
203 }
204}
205
206impl<S: Unpin> SslStream<S> {
207 pub fn ssl(&self) -> &SslRef {
209 self.0.ssl()
210 }
211
212 pub fn get_ref(&self) -> &S {
214 &self.0.get_ref().stream
215 }
216
217 pub fn get_mut(&mut self) -> &mut S {
219 &mut self.0.get_mut().stream
220 }
221
222 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
224 Pin::new(&mut self.get_mut().0.get_mut().stream)
225 }
226
227 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
228 where
229 F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
230 {
231 let this = self.get_mut();
232 this.0.get_mut().waker = ctx.waker().clone();
233 f(&mut this.0)
234 }
235}
236
237impl<S> AsyncRead for SslStream<S>
238where
239 S: AsyncRead + AsyncWrite + Unpin,
240{
241 fn poll_read(
242 self: Pin<&mut Self>,
243 ctx: &mut Context<'_>,
244 buf: &mut [u8],
245 ) -> Poll<io::Result<usize>> {
246 self.with_context(ctx, |s| cvt(s.read(buf)))
247 }
248}
249
250impl<S> AsyncWrite for SslStream<S>
251where
252 S: AsyncRead + AsyncWrite + Unpin,
253{
254 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
255 self.with_context(ctx, |s| cvt(s.write(buf)))
256 }
257
258 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
259 self.with_context(ctx, |s| cvt(s.flush()))
260 }
261
262 fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
263 match self.as_mut().with_context(ctx, |s| s.shutdown()) {
264 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
265 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
266 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
267 return Poll::Pending;
268 }
269 Err(e) => {
270 return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
271 }
272 }
273
274 self.get_pin_mut().poll_close(ctx)
275 }
276}