1use boring::ssl;
2pub use boring::ssl::*;
3
4mod bridge;
5mod callbacks;
6
7use futures::{AsyncRead, AsyncWrite};
8
9use std::fmt;
10use std::future::Future;
11use std::io::{self, Read, Write};
12use std::pin::Pin;
13use std::task::{Context, Poll};
14
15use bridge::*;
16
17pub use callbacks::SslContextBuilderExt;
18
19pub use boring::ssl::{
20 AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
21 BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
22 BoxSelectCertFuture, ExDataFuture,
23};
24
25pub async fn connect<S>(
30 config: ConnectConfiguration,
31 domain: &str,
32 stream: S,
33) -> Result<SslStream<S>, HandshakeError<S>>
34where
35 S: AsyncRead + AsyncWrite + Unpin,
36{
37 let mid_handshake = config
38 .setup_connect(domain, AsyncStreamBridge::new(stream))
39 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
40
41 HandshakeFuture(Some(mid_handshake)).await
42}
43
44pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
49where
50 S: AsyncRead + AsyncWrite + Unpin,
51{
52 let mid_handshake = acceptor
53 .setup_accept(AsyncStreamBridge::new(stream))
54 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
55
56 HandshakeFuture(Some(mid_handshake)).await
57}
58
/// Converts a blocking-style I/O result into a [`Poll`].
///
/// An error of kind [`io::ErrorKind::WouldBlock`] is discarded and reported as
/// [`Poll::Pending`]. Every other result, success or failure, is returned unchanged inside
/// [`Poll::Ready`].
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    if let Err(err) = &r {
        if err.kind() == io::ErrorKind::WouldBlock {
            // NOTE(review): returning `Pending` assumes the underlying async stream already
            // registered the task's waker when it reported `WouldBlock` (presumably via the
            // bridge) — confirm in `bridge`.
            return Poll::Pending;
        }
    }

    Poll::Ready(r)
}
66
67pub struct SslStreamBuilder<S> {
69 inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
70}
71
72impl<S> SslStreamBuilder<S>
73where
74 S: AsyncRead + AsyncWrite + Unpin,
75{
76 pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
78 Self {
79 inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
80 }
81 }
82
83 pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
85 let mid_handshake = self.inner.setup_accept();
86
87 HandshakeFuture(Some(mid_handshake)).await
88 }
89
90 pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
92 let mid_handshake = self.inner.setup_connect();
93
94 HandshakeFuture(Some(mid_handshake)).await
95 }
96}
97
98impl<S> SslStreamBuilder<S> {
99 pub fn ssl(&self) -> &SslRef {
101 self.inner.ssl()
102 }
103
104 pub fn ssl_mut(&mut self) -> &mut SslRef {
106 self.inner.ssl_mut()
107 }
108}
109
110#[derive(Debug)]
118pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
119
120impl<S> SslStream<S> {
121 pub fn ssl(&self) -> &SslRef {
123 self.0.ssl()
124 }
125
126 pub fn ssl_mut(&mut self) -> &mut SslRef {
128 self.0.ssl_mut()
129 }
130
131 pub fn get_ref(&self) -> &S {
133 &self.0.get_ref().stream
134 }
135
136 pub fn get_mut(&mut self) -> &mut S {
138 &mut self.0.get_mut().stream
139 }
140
141 fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
142 where
143 F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
144 {
145 self.0.get_mut().set_waker(Some(ctx));
146
147 let result = f(&mut self.0);
148
149 self.0.get_mut().set_waker(None);
154
155 result
156 }
157}
158
159impl<S> AsyncRead for SslStream<S>
160where
161 S: AsyncRead + AsyncWrite + Unpin,
162{
163 fn poll_read(
164 mut self: Pin<&mut Self>,
165 ctx: &mut Context<'_>,
166 buf: &mut [u8],
167 ) -> Poll<io::Result<usize>> {
168 self.run_in_context(ctx, |s| cvt(s.read(buf)))
169 }
170}
171
172impl<S> AsyncWrite for SslStream<S>
173where
174 S: AsyncRead + AsyncWrite + Unpin,
175{
176 fn poll_write(
177 mut self: Pin<&mut Self>,
178 ctx: &mut Context,
179 buf: &[u8],
180 ) -> Poll<io::Result<usize>> {
181 self.run_in_context(ctx, |s| cvt(s.write(buf)))
182 }
183
184 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
185 self.run_in_context(ctx, |s| cvt(s.flush()))
186 }
187
188 fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
189 match self.run_in_context(ctx, |s| s.shutdown()) {
190 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
191 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
192 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
193 return Poll::Pending;
194 }
195 Err(e) => {
196 return Poll::Ready(Err(e
197 .into_io_error()
198 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
199 }
200 }
201
202 Pin::new(&mut self.0.get_mut().stream).poll_close(ctx)
203 }
204}
205
206pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
208
209impl<S> HandshakeError<S> {
210 pub fn ssl(&self) -> Option<&SslRef> {
212 match &self.0 {
213 ssl::HandshakeError::Failure(s) => Some(s.ssl()),
214 _ => None,
215 }
216 }
217
218 pub fn into_source_stream(self) -> Option<S> {
220 match self.0 {
221 ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
222 _ => None,
223 }
224 }
225
226 pub fn as_source_stream(&self) -> Option<&S> {
228 match &self.0 {
229 ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
230 _ => None,
231 }
232 }
233
234 pub fn code(&self) -> Option<ErrorCode> {
236 match &self.0 {
237 ssl::HandshakeError::Failure(s) => Some(s.error().code()),
238 _ => None,
239 }
240 }
241
242 pub fn as_io_error(&self) -> Option<&io::Error> {
244 match &self.0 {
245 ssl::HandshakeError::Failure(s) => s.error().io_error(),
246 _ => None,
247 }
248 }
249}
250
251impl<S> fmt::Debug for HandshakeError<S>
252where
253 S: fmt::Debug,
254{
255 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
256 fmt::Debug::fmt(&self.0, fmt)
257 }
258}
259
260impl<S> fmt::Display for HandshakeError<S> {
261 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
262 fmt::Display::fmt(&self.0, fmt)
263 }
264}
265
266impl<S> std::error::Error for HandshakeError<S>
267where
268 S: fmt::Debug,
269{
270 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
271 self.0.source()
272 }
273}
274
275pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
279
280impl<S> Future for HandshakeFuture<S>
281where
282 S: AsyncRead + AsyncWrite + Unpin,
283{
284 type Output = Result<SslStream<S>, HandshakeError<S>>;
285
286 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
287 let mut mid_handshake = self.0.take().expect("future polled after completion");
288
289 mid_handshake.get_mut().set_waker(Some(ctx));
290 mid_handshake
291 .ssl_mut()
292 .set_task_waker(Some(ctx.waker().clone()));
293
294 match mid_handshake.handshake() {
295 Ok(mut stream) => {
296 stream.get_mut().set_waker(None);
297 stream.ssl_mut().set_task_waker(None);
298
299 Poll::Ready(Ok(SslStream(stream)))
300 }
301 Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
302 mid_handshake.get_mut().set_waker(None);
303 mid_handshake.ssl_mut().set_task_waker(None);
304
305 self.0 = Some(mid_handshake);
306
307 Poll::Pending
308 }
309 Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
310 mid_handshake.get_mut().set_waker(None);
311
312 Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
313 mid_handshake,
314 ))))
315 }
316 Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
317 Poll::Ready(Err(HandshakeError(err)))
318 }
319 }
320 }
321}