1use boring::ssl::*;
2pub use boring::*;
3
4mod bridge;
5mod callbacks;
6
7use futures::{AsyncRead, AsyncWrite, Stream};
8
9use std::error::Error;
10use std::fmt;
11use std::future::Future;
12use std::io::{self, Read, Write};
13use std::pin::Pin;
14
15use std::task::{Context, Poll};
16
17use bridge::*;
18
19pub use callbacks::SslContextBuilderExt;
20
21pub use boring::ssl::{
22 AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
23 BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
24 BoxSelectCertFuture, ExDataFuture,
25};
26
27pub async fn connect<S>(
32 config: ConnectConfiguration,
33 domain: &str,
34 stream: S,
35) -> Result<SslStream<S>, HandshakeError<S>>
36where
37 S: AsyncRead + AsyncWrite + Unpin,
38{
39 let mid_handshake = config
40 .setup_connect(domain, AsyncStreamBridge::new(stream))
41 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
42
43 HandshakeFuture(Some(mid_handshake)).await
44}
45
46pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
51where
52 S: AsyncRead + AsyncWrite + Unpin,
53{
54 let mid_handshake = acceptor
55 .setup_accept(AsyncStreamBridge::new(stream))
56 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
57
58 HandshakeFuture(Some(mid_handshake)).await
59}
60
/// Converts the result of a blocking-style I/O call into a `Poll`.
///
/// An error of kind `WouldBlock` becomes `Poll::Pending`. Every other result, success or
/// failure, is returned unchanged inside `Poll::Ready`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = matches!(&r, Err(e) if e.kind() == io::ErrorKind::WouldBlock);

    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}
68
69pub struct SslStreamBuilder<S> {
71 inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
72}
73
74impl<S> SslStreamBuilder<S>
75where
76 S: AsyncRead + AsyncWrite + Unpin,
77{
78 pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
80 Self {
81 inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
82 }
83 }
84
85 pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
87 let mid_handshake = self.inner.setup_accept();
88
89 HandshakeFuture(Some(mid_handshake)).await
90 }
91
92 pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
94 let mid_handshake = self.inner.setup_connect();
95
96 HandshakeFuture(Some(mid_handshake)).await
97 }
98}
99
100impl<S> SslStreamBuilder<S> {
101 pub fn ssl(&self) -> &SslRef {
103 self.inner.ssl()
104 }
105
106 pub fn ssl_mut(&mut self) -> &mut SslRef {
108 self.inner.ssl_mut()
109 }
110}
111
112#[derive(Debug)]
120pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
121
122impl<S> SslStream<S> {
123 pub fn ssl(&self) -> &SslRef {
125 self.0.ssl()
126 }
127
128 pub fn ssl_mut(&mut self) -> &mut SslRef {
130 self.0.ssl_mut()
131 }
132
133 pub fn get_ref(&self) -> &S {
135 &self.0.get_ref().stream
136 }
137
138 pub fn get_mut(&mut self) -> &mut S {
140 &mut self.0.get_mut().stream
141 }
142
143 fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
144 where
145 F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
146 {
147 self.0.get_mut().set_waker(Some(ctx));
148
149 let result = f(&mut self.0);
150
151 self.0.get_mut().set_waker(None);
156
157 result
158 }
159}
160
161impl<S> AsyncRead for SslStream<S>
162where
163 S: AsyncRead + AsyncWrite + Unpin,
164{
165 fn poll_read(
166 mut self: Pin<&mut Self>,
167 ctx: &mut Context<'_>,
168 buf: &mut [u8],
169 ) -> Poll<io::Result<usize>> {
170 self.run_in_context(ctx, |s| cvt(s.read(buf)))
171 }
172}
173
174impl<S> AsyncWrite for SslStream<S>
175where
176 S: AsyncRead + AsyncWrite + Unpin,
177{
178 fn poll_write(
179 mut self: Pin<&mut Self>,
180 ctx: &mut Context,
181 buf: &[u8],
182 ) -> Poll<io::Result<usize>> {
183 self.run_in_context(ctx, |s| cvt(s.write(buf)))
184 }
185
186 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
187 self.run_in_context(ctx, |s| cvt(s.flush()))
188 }
189
190 fn poll_close(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
191 match self.run_in_context(ctx, |s| s.shutdown()) {
192 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
193 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
194 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
195 return Poll::Pending;
196 }
197 Err(e) => {
198 return Poll::Ready(Err(e
199 .into_io_error()
200 .unwrap_or_else(|e| io::Error::new(io::ErrorKind::Other, e))));
201 }
202 }
203
204 Pin::new(&mut self.0.get_mut().stream).poll_close(ctx)
205 }
206}
207
208pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
210
211impl<S> HandshakeError<S> {
212 pub fn ssl(&self) -> Option<&SslRef> {
214 match &self.0 {
215 ssl::HandshakeError::Failure(s) => Some(s.ssl()),
216 _ => None,
217 }
218 }
219
220 pub fn into_source_stream(self) -> Option<S> {
222 match self.0 {
223 ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
224 _ => None,
225 }
226 }
227
228 pub fn as_source_stream(&self) -> Option<&S> {
230 match &self.0 {
231 ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
232 _ => None,
233 }
234 }
235
236 pub fn code(&self) -> Option<ErrorCode> {
238 match &self.0 {
239 ssl::HandshakeError::Failure(s) => Some(s.error().code()),
240 _ => None,
241 }
242 }
243
244 pub fn as_io_error(&self) -> Option<&io::Error> {
246 match &self.0 {
247 ssl::HandshakeError::Failure(s) => s.error().io_error(),
248 _ => None,
249 }
250 }
251}
252
253impl<S> fmt::Debug for HandshakeError<S>
254where
255 S: fmt::Debug,
256{
257 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
258 fmt::Debug::fmt(&self.0, fmt)
259 }
260}
261
262impl<S> fmt::Display for HandshakeError<S> {
263 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
264 fmt::Display::fmt(&self.0, fmt)
265 }
266}
267
268impl<S> std::error::Error for HandshakeError<S>
269where
270 S: fmt::Debug,
271{
272 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
273 self.0.source()
274 }
275}
276
277pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
281
282impl<S> Future for HandshakeFuture<S>
283where
284 S: AsyncRead + AsyncWrite + Unpin,
285{
286 type Output = Result<SslStream<S>, HandshakeError<S>>;
287
288 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
289 let mut mid_handshake = self.0.take().expect("future polled after completion");
290
291 mid_handshake.get_mut().set_waker(Some(ctx));
292 mid_handshake
293 .ssl_mut()
294 .set_task_waker(Some(ctx.waker().clone()));
295
296 match mid_handshake.handshake() {
297 Ok(mut stream) => {
298 stream.get_mut().set_waker(None);
299 stream.ssl_mut().set_task_waker(None);
300
301 Poll::Ready(Ok(SslStream(stream)))
302 }
303 Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
304 mid_handshake.get_mut().set_waker(None);
305 mid_handshake.ssl_mut().set_task_waker(None);
306
307 self.0 = Some(mid_handshake);
308
309 Poll::Pending
310 }
311 Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
312 mid_handshake.get_mut().set_waker(None);
313
314 Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
315 mid_handshake,
316 ))))
317 }
318 Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
319 Poll::Ready(Err(HandshakeError(err)))
320 }
321 }
322 }
323}
324
325pub struct SslListener<S> {
326 incoming: S,
327 acceptor: SslAcceptor,
328}
329
330impl<S> SslListener<S> {
331 pub fn on(incoming: S, acceptor: SslAcceptor) -> Self {
333 Self { incoming, acceptor }
334 }
335
336 pub async fn accept<I, E>(&mut self) -> std::io::Result<SslStream<I>>
337 where
338 S: Stream<Item = Result<I, E>> + Unpin,
339 I: AsyncRead + AsyncWrite + Unpin,
340 E: Error,
341 {
342 use futures::TryStreamExt;
343
344 while let Some(stream) = self
345 .incoming
346 .try_next()
347 .await
348 .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err.to_string()))?
349 {
350 let stream = accept(&self.acceptor, stream).await.map_err(|err| {
351 let err = std::io::Error::new(std::io::ErrorKind::Other, err.to_string());
352
353 log::error!("{}", err);
354
355 err
356 })?;
357
358 return Ok(stream);
359 }
360
361 Err(std::io::Error::new(
362 io::ErrorKind::BrokenPipe,
363 "Ssl listener inner stream broken.",
364 ))
365 }
366
367 pub fn into_incoming<I, E>(self) -> impl Stream<Item = io::Result<SslStream<I>>> + Unpin
368 where
369 S: Stream<Item = Result<I, E>> + Unpin,
370 I: AsyncRead + AsyncWrite + Unpin,
371 E: Error,
372 {
373 Box::pin(futures::stream::unfold(self, |mut listener| async move {
374 let res = listener.accept().await;
375 Some((res, listener))
376 }))
377 }
378}