1#![warn(missing_docs)]
7
8use futures_io::{AsyncRead, AsyncWrite};
9use openssl::{
10 error::ErrorStack,
11 ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef},
12};
13use std::{
14 fmt, future,
15 io::{self, Read, Write},
16 pin::Pin,
17 task::{Context, Poll},
18};
19
20#[cfg(test)]
21mod test;
22
/// Adapts an asynchronous stream `S` to the blocking `Read`/`Write` traits that
/// `openssl::ssl::SslStream` drives.
///
/// While `SslStream::with_context` is running, `context` holds the address of
/// the caller's task `Context` (stored as a `usize`). Outside of that call it
/// is `0`, which `parts` checks in debug builds.
struct StreamWrapper<S> {
    /// The wrapped asynchronous stream.
    stream: S,
    /// Address of the active `Context<'_>`, or `0` when no poll is in progress.
    context: usize,
}
27
28impl<S> fmt::Debug for StreamWrapper<S>
29where
30 S: fmt::Debug,
31{
32 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
33 self.stream.fmt(fmt)
34 }
35}
36
37impl<S> StreamWrapper<S> {
38 unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
43 debug_assert_ne!(self.context, 0);
44 let stream = Pin::new_unchecked(&mut self.stream);
45 let context = &mut *(self.context as *mut _);
46 (stream, context)
47 }
48}
49
50impl<S> Read for StreamWrapper<S>
51where
52 S: AsyncRead,
53{
54 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
55 let (stream, cx) = unsafe { self.parts() };
56 match stream.poll_read(cx, buf)? {
57 Poll::Ready(nread) => Ok(nread),
58 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
59 }
60 }
61}
62
63impl<S> Write for StreamWrapper<S>
64where
65 S: AsyncWrite,
66{
67 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
68 let (stream, cx) = unsafe { self.parts() };
69 match stream.poll_write(cx, buf) {
70 Poll::Ready(r) => r,
71 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
72 }
73 }
74
75 fn flush(&mut self) -> io::Result<()> {
76 let (stream, cx) = unsafe { self.parts() };
77 match stream.poll_flush(cx) {
78 Poll::Ready(r) => r,
79 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
80 }
81 }
82}
83
/// Converts a blocking-style I/O result into a poll result.
///
/// `WouldBlock` becomes `Poll::Pending`. Every other result, successful or
/// not, is wrapped unchanged in `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    match r {
        Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        settled => Poll::Ready(settled),
    }
}
91
92fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
93 match r {
94 Ok(v) => Poll::Ready(Ok(v)),
95 Err(e) => match e.code() {
96 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
97 _ => Poll::Ready(Err(e)),
98 },
99 }
100}
101
/// An asynchronous TLS stream that wraps `openssl::ssl::SslStream` around a
/// `futures_io` stream `S`.
///
/// The handshake and I/O methods require `S: AsyncRead + AsyncWrite`.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
105
106impl<S> SslStream<S>
107where
108 S: AsyncRead + AsyncWrite,
109{
110 pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
112 ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
113 }
114
115 pub fn poll_connect(
117 self: Pin<&mut Self>,
118 cx: &mut Context<'_>,
119 ) -> Poll<Result<(), ssl::Error>> {
120 self.with_context(cx, |s| cvt_ossl(s.connect()))
121 }
122
123 pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
125 future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
126 }
127
128 pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
130 self.with_context(cx, |s| cvt_ossl(s.accept()))
131 }
132
133 pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
135 future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
136 }
137
138 pub fn poll_do_handshake(
140 self: Pin<&mut Self>,
141 cx: &mut Context<'_>,
142 ) -> Poll<Result<(), ssl::Error>> {
143 self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
144 }
145
146 pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
148 future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
149 }
150
151 pub fn poll_peek(
153 self: Pin<&mut Self>,
154 cx: &mut Context<'_>,
155 buf: &mut [u8],
156 ) -> Poll<Result<usize, ssl::Error>> {
157 self.with_context(cx, |s| cvt_ossl(s.ssl_peek(buf)))
158 }
159
160 pub async fn peek(mut self: Pin<&mut Self>, buf: &mut [u8]) -> Result<usize, ssl::Error> {
162 future::poll_fn(|cx| self.as_mut().poll_peek(cx, buf)).await
163 }
164
165 #[cfg(ossl111)]
167 pub fn poll_read_early_data(
168 self: Pin<&mut Self>,
169 cx: &mut Context<'_>,
170 buf: &mut [u8],
171 ) -> Poll<Result<usize, ssl::Error>> {
172 self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
173 }
174
175 #[cfg(ossl111)]
177 pub async fn read_early_data(
178 mut self: Pin<&mut Self>,
179 buf: &mut [u8],
180 ) -> Result<usize, ssl::Error> {
181 future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
182 }
183
184 #[cfg(ossl111)]
186 pub fn poll_write_early_data(
187 self: Pin<&mut Self>,
188 cx: &mut Context<'_>,
189 buf: &[u8],
190 ) -> Poll<Result<usize, ssl::Error>> {
191 self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
192 }
193
194 #[cfg(ossl111)]
196 pub async fn write_early_data(
197 mut self: Pin<&mut Self>,
198 buf: &[u8],
199 ) -> Result<usize, ssl::Error> {
200 future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
201 }
202}
203
204impl<S> SslStream<S> {
205 pub fn ssl(&self) -> &SslRef {
207 self.0.ssl()
208 }
209
210 pub fn get_ref(&self) -> &S {
212 &self.0.get_ref().stream
213 }
214
215 pub fn get_mut(&mut self) -> &mut S {
217 &mut self.0.get_mut().stream
218 }
219
220 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
222 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
223 }
224
225 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
226 where
227 F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
228 {
229 let this = unsafe { self.get_unchecked_mut() };
230 this.0.get_mut().context = ctx as *mut _ as usize;
231 let r = f(&mut this.0);
232 this.0.get_mut().context = 0;
233 r
234 }
235}
236
237impl<S> AsyncRead for SslStream<S>
238where
239 S: AsyncRead + AsyncWrite,
240{
241 fn poll_read(
242 self: Pin<&mut Self>,
243 ctx: &mut Context<'_>,
244 buf: &mut [u8],
245 ) -> Poll<io::Result<usize>> {
246 self.with_context(ctx, |s| cvt(s.read(buf)))
247 }
248}
249
250impl<S> AsyncWrite for SslStream<S>
251where
252 S: AsyncRead + AsyncWrite,
253{
254 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
255 self.with_context(ctx, |s| cvt(s.write(buf)))
256 }
257
258 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
259 self.with_context(ctx, |s| cvt(s.flush()))
260 }
261
262 fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
263 match self.as_mut().with_context(ctx, |s| s.shutdown()) {
264 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
265 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
266 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
267 return Poll::Pending;
268 }
269 Err(e) => {
270 return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
271 }
272 }
273
274 self.get_pin_mut().poll_close(ctx)
275 }
276}