1#![warn(missing_docs)]
7
8use async_std::io::{Read as AsyncRead, Write as AsyncWrite};
9use futures_util::future;
10use openssl::error::ErrorStack;
11use openssl::ssl::{self, ErrorCode, ShutdownResult, Ssl, SslRef};
12use std::fmt;
13use std::io::{self, Read, Write};
14use std::pin::Pin;
15use std::task::{Context, Poll};
17
18pub mod tls_stream_wrapper;
19pub use crate::tls_stream_wrapper::SslStreamWrapper;
20
21#[cfg(test)]
22mod test;
23
/// Adapter that lets a poll-based asynchronous stream be driven through the
/// synchronous `Read` / `Write` traits.
struct StreamWrapper<S> {
    /// The wrapped asynchronous stream.
    stream: S,
    /// Address of the task `Context` for the poll currently in progress, or
    /// `0` when no poll is active.
    // NOTE(review): presumably stored as `usize` instead of a raw pointer so
    // the wrapper keeps its auto `Send`/`Sync` impls -- confirm.
    context: usize,
}
28
29impl<S> fmt::Debug for StreamWrapper<S>
30where
31 S: fmt::Debug,
32{
33 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
34 fmt::Debug::fmt(&self.stream, fmt)
35 }
36}
37
38impl<S> StreamWrapper<S> {
39 unsafe fn parts(&mut self) -> (Pin<&mut S>, &mut Context<'_>) {
44 debug_assert_ne!(self.context, 0);
45 let stream = Pin::new_unchecked(&mut self.stream);
46 let context = &mut *(self.context as *mut _);
47 (stream, context)
48 }
49}
50
51impl<S> Read for StreamWrapper<S>
52where
53 S: AsyncRead,
54{
55 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
56 let (stream, cx) = unsafe { self.parts() };
57 match stream.poll_read(cx, buf)? {
59 Poll::Ready(num_bytes_read) => Ok(num_bytes_read),
60 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
61 }
62 }
63}
64
65impl<S> Write for StreamWrapper<S>
66where
67 S: AsyncWrite,
68{
69 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
70 let (stream, cx) = unsafe { self.parts() };
71 match stream.poll_write(cx, buf) {
72 Poll::Ready(r) => r,
73 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
74 }
75 }
76
77 fn flush(&mut self) -> io::Result<()> {
78 let (stream, cx) = unsafe { self.parts() };
79 match stream.poll_flush(cx) {
80 Poll::Ready(r) => r,
81 Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)),
82 }
83 }
84}
85
/// Converts a blocking-style I/O result into a `Poll`.
///
/// An error whose kind is `io::ErrorKind::WouldBlock` becomes
/// `Poll::Pending`; every other result (success or failure) is returned as
/// `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = match r {
        Err(ref err) => err.kind() == io::ErrorKind::WouldBlock,
        Ok(_) => false,
    };

    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}
93
94fn cvt_ossl<T>(r: Result<T, ssl::Error>) -> Poll<Result<T, ssl::Error>> {
95 match r {
96 Ok(v) => Poll::Ready(Ok(v)),
97 Err(e) => match e.code() {
98 ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Poll::Pending,
99 _ => Poll::Ready(Err(e)),
100 },
101 }
102}
103
104#[derive(Debug)]
106pub struct SslStream<S>(ssl::SslStream<StreamWrapper<S>>);
107
108impl<S> SslStream<S>
109where
110 S: AsyncRead + AsyncWrite,
111{
112 pub fn new(ssl: Ssl, stream: S) -> Result<Self, ErrorStack> {
114 ssl::SslStream::new(ssl, StreamWrapper { stream, context: 0 }).map(SslStream)
115 }
116
117 pub fn poll_connect(
119 self: Pin<&mut Self>,
120 cx: &mut Context<'_>,
121 ) -> Poll<Result<(), ssl::Error>> {
122 self.with_context(cx, |s| cvt_ossl(s.connect()))
123 }
124
125 pub async fn connect(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
127 future::poll_fn(|cx| self.as_mut().poll_connect(cx)).await
128 }
129
130 pub fn poll_accept(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), ssl::Error>> {
132 self.with_context(cx, |s| cvt_ossl(s.accept()))
133 }
134
135 pub async fn accept(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
137 future::poll_fn(|cx| self.as_mut().poll_accept(cx)).await
138 }
139
140 pub fn poll_do_handshake(
142 self: Pin<&mut Self>,
143 cx: &mut Context<'_>,
144 ) -> Poll<Result<(), ssl::Error>> {
145 self.with_context(cx, |s| cvt_ossl(s.do_handshake()))
146 }
147
148 pub async fn do_handshake(mut self: Pin<&mut Self>) -> Result<(), ssl::Error> {
150 future::poll_fn(|cx| self.as_mut().poll_do_handshake(cx)).await
151 }
152
153 #[cfg(ossl111)]
155 pub fn poll_read_early_data(
156 self: Pin<&mut Self>,
157 cx: &mut Context<'_>,
158 buf: &mut [u8],
159 ) -> Poll<Result<usize, ssl::Error>> {
160 self.with_context(cx, |s| cvt_ossl(s.read_early_data(buf)))
161 }
162
163 #[cfg(ossl111)]
165 pub async fn read_early_data(
166 mut self: Pin<&mut Self>,
167 buf: &mut [u8],
168 ) -> Result<usize, ssl::Error> {
169 future::poll_fn(|cx| self.as_mut().poll_read_early_data(cx, buf)).await
170 }
171
172 #[cfg(ossl111)]
174 pub fn poll_write_early_data(
175 self: Pin<&mut Self>,
176 cx: &mut Context<'_>,
177 buf: &[u8],
178 ) -> Poll<Result<usize, ssl::Error>> {
179 self.with_context(cx, |s| cvt_ossl(s.write_early_data(buf)))
180 }
181
182 #[cfg(ossl111)]
184 pub async fn write_early_data(
185 mut self: Pin<&mut Self>,
186 buf: &[u8],
187 ) -> Result<usize, ssl::Error> {
188 future::poll_fn(|cx| self.as_mut().poll_write_early_data(cx, buf)).await
189 }
190}
191
192impl<S> SslStream<S> {
193 pub fn ssl(&self) -> &SslRef {
195 self.0.ssl()
196 }
197
198 pub fn get_ref(&self) -> &S {
200 &self.0.get_ref().stream
201 }
202
203 pub fn get_mut(&mut self) -> &mut S {
205 &mut self.0.get_mut().stream
206 }
207
208 pub fn get_pin_mut(self: Pin<&mut Self>) -> Pin<&mut S> {
210 unsafe { Pin::new_unchecked(&mut self.get_unchecked_mut().0.get_mut().stream) }
211 }
212
213 fn with_context<F, R>(self: Pin<&mut Self>, ctx: &mut Context<'_>, f: F) -> R
214 where
215 F: FnOnce(&mut ssl::SslStream<StreamWrapper<S>>) -> R,
216 {
217 let this = unsafe { self.get_unchecked_mut() };
218 this.0.get_mut().context = ctx as *mut _ as usize;
219 let r = f(&mut this.0);
220 this.0.get_mut().context = 0;
221 r
222 }
223}
224
225impl<S> AsyncRead for SslStream<S>
226where
227 S: AsyncRead + AsyncWrite,
228{
229 fn poll_read(
230 self: Pin<&mut Self>,
231 ctx: &mut Context<'_>,
232 buf: &mut [u8],
233 ) -> Poll<io::Result<usize>> {
234 self.with_context(ctx, |s| {
235 match cvt(s.read(buf))? {
244 Poll::Ready(nread) => {
245 Poll::Ready(Ok(nread))
250 }
251 Poll::Pending => Poll::Pending,
252 }
253 })
254 }
255}
256
257impl<S> AsyncWrite for SslStream<S>
258where
259 S: AsyncRead + AsyncWrite,
260{
261 fn poll_write(self: Pin<&mut Self>, ctx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
262 self.with_context(ctx, |s| cvt(s.write(buf)))
263 }
264
265 fn poll_flush(self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
266 self.with_context(ctx, |s| cvt(s.flush()))
267 }
268
269 fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
270 match self.as_mut().with_context(ctx, |s| s.shutdown()) {
271 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
272 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
273 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
274 return Poll::Pending;
275 }
276 Err(e) => {
277 return Poll::Ready(Err(e
278 .into_io_error()
279 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
280 }
281 }
282
283 self.get_pin_mut().poll_close(ctx)
284 }
285}