1#![warn(missing_docs)]
14#![cfg_attr(docsrs, feature(doc_auto_cfg))]
15
16use rama_boring::ssl::{
17 self, ConnectConfiguration, ErrorCode, MidHandshakeSslStream, ShutdownResult, SslAcceptor,
18 SslRef,
19};
20use rama_boring_sys as ffi;
21use std::error::Error;
22use std::fmt;
23use std::future::Future;
24use std::io::{self, Write};
25use std::pin::Pin;
26use std::task::{Context, Poll};
27use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
28
29mod async_callbacks;
30mod bridge;
31
32use self::bridge::AsyncStreamBridge;
33
34pub use crate::async_callbacks::SslContextBuilderExt;
35pub use rama_boring::ssl::{
36 AsyncPrivateKeyMethod, AsyncPrivateKeyMethodError, AsyncSelectCertError, BoxGetSessionFinish,
37 BoxGetSessionFuture, BoxPrivateKeyMethodFinish, BoxPrivateKeyMethodFuture, BoxSelectCertFinish,
38 BoxSelectCertFuture, ExDataFuture,
39};
40
41pub async fn connect<S>(
46 config: ConnectConfiguration,
47 domain: &str,
48 stream: S,
49) -> Result<SslStream<S>, HandshakeError<S>>
50where
51 S: AsyncRead + AsyncWrite + Unpin,
52{
53 let mid_handshake = config
54 .setup_connect(domain, AsyncStreamBridge::new(stream))
55 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
56
57 HandshakeFuture(Some(mid_handshake)).await
58}
59
60pub async fn accept<S>(acceptor: &SslAcceptor, stream: S) -> Result<SslStream<S>, HandshakeError<S>>
65where
66 S: AsyncRead + AsyncWrite + Unpin,
67{
68 let mid_handshake = acceptor
69 .setup_accept(AsyncStreamBridge::new(stream))
70 .map_err(|err| HandshakeError(ssl::HandshakeError::SetupFailure(err)))?;
71
72 HandshakeFuture(Some(mid_handshake)).await
73}
74
/// Converts a blocking-style I/O result into a [`Poll`].
///
/// An [`io::ErrorKind::WouldBlock`] error becomes [`Poll::Pending`]. Any other
/// result, success or failure, is passed through unchanged as [`Poll::Ready`].
///
/// Every caller in this file uses it inside a closure run while a waker is
/// installed on the I/O bridge. Presumably that waker is what wakes the task
/// after a `Pending`.
fn cvt<T>(r: io::Result<T>) -> Poll<io::Result<T>> {
    let would_block = matches!(&r, Err(err) if err.kind() == io::ErrorKind::WouldBlock);

    if would_block {
        Poll::Pending
    } else {
        Poll::Ready(r)
    }
}
82
/// A TLS stream whose handshake has not started yet.
///
/// Lets the caller adjust the `Ssl` (via `ssl_mut`) before choosing which side of the
/// handshake to perform with `accept` or `connect`.
pub struct SslStreamBuilder<S> {
    // The `rama_boring` builder, operating over the crate's async I/O bridge.
    inner: ssl::SslStreamBuilder<AsyncStreamBridge<S>>,
}
87
88impl<S> SslStreamBuilder<S>
89where
90 S: AsyncRead + AsyncWrite + Unpin,
91{
92 pub fn new(ssl: ssl::Ssl, stream: S) -> Self {
94 Self {
95 inner: ssl::SslStreamBuilder::new(ssl, AsyncStreamBridge::new(stream)),
96 }
97 }
98
99 pub async fn accept(self) -> Result<SslStream<S>, HandshakeError<S>> {
101 let mid_handshake = self.inner.setup_accept();
102
103 HandshakeFuture(Some(mid_handshake)).await
104 }
105
106 pub async fn connect(self) -> Result<SslStream<S>, HandshakeError<S>> {
108 let mid_handshake = self.inner.setup_connect();
109
110 HandshakeFuture(Some(mid_handshake)).await
111 }
112}
113
114impl<S> SslStreamBuilder<S> {
115 pub fn ssl(&self) -> &SslRef {
117 self.inner.ssl()
118 }
119
120 pub fn ssl_mut(&mut self) -> &mut SslRef {
122 self.inner.ssl_mut()
123 }
124}
125
/// A TLS session over an asynchronous stream `S`.
///
/// Implements tokio's `AsyncRead` and `AsyncWrite` when
/// `S: AsyncRead + AsyncWrite + Unpin`. The TLS work is delegated to the wrapped
/// `rama_boring` `SslStream`, which runs over the crate's async I/O bridge.
#[derive(Debug)]
pub struct SslStream<S>(ssl::SslStream<AsyncStreamBridge<S>>);
135
136impl<S> SslStream<S> {
137 pub fn ssl(&self) -> &SslRef {
139 self.0.ssl()
140 }
141
142 pub fn ssl_mut(&mut self) -> &mut SslRef {
144 self.0.ssl_mut()
145 }
146
147 pub fn get_ref(&self) -> &S {
149 &self.0.get_ref().stream
150 }
151
152 pub fn get_mut(&mut self) -> &mut S {
154 &mut self.0.get_mut().stream
155 }
156
157 fn run_in_context<F, R>(&mut self, ctx: &mut Context<'_>, f: F) -> R
158 where
159 F: FnOnce(&mut ssl::SslStream<AsyncStreamBridge<S>>) -> R,
160 {
161 self.0.get_mut().set_waker(Some(ctx));
162
163 let result = f(&mut self.0);
164
165 self.0.get_mut().set_waker(None);
170
171 result
172 }
173}
174
175impl<S> SslStream<S>
176where
177 S: AsyncRead + AsyncWrite + Unpin,
178{
179 pub unsafe fn from_raw_parts(ssl: *mut ffi::SSL, stream: S) -> Self {
187 Self(ssl::SslStream::from_raw_parts(
188 ssl,
189 AsyncStreamBridge::new(stream),
190 ))
191 }
192}
193
194impl<S> AsyncRead for SslStream<S>
195where
196 S: AsyncRead + AsyncWrite + Unpin,
197{
198 fn poll_read(
199 mut self: Pin<&mut Self>,
200 ctx: &mut Context<'_>,
201 buf: &mut ReadBuf,
202 ) -> Poll<io::Result<()>> {
203 self.run_in_context(ctx, |s| {
204 match cvt(s.read_uninit(unsafe { buf.unfilled_mut() }))? {
206 Poll::Ready(nread) => {
207 unsafe {
208 buf.assume_init(nread);
209 }
210 buf.advance(nread);
211 Poll::Ready(Ok(()))
212 }
213 Poll::Pending => Poll::Pending,
214 }
215 })
216 }
217}
218
219impl<S> AsyncWrite for SslStream<S>
220where
221 S: AsyncRead + AsyncWrite + Unpin,
222{
223 fn poll_write(
224 mut self: Pin<&mut Self>,
225 ctx: &mut Context,
226 buf: &[u8],
227 ) -> Poll<io::Result<usize>> {
228 self.run_in_context(ctx, |s| cvt(s.write(buf)))
229 }
230
231 fn poll_flush(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
232 self.run_in_context(ctx, |s| cvt(s.flush()))
233 }
234
235 fn poll_shutdown(mut self: Pin<&mut Self>, ctx: &mut Context) -> Poll<io::Result<()>> {
236 match self.run_in_context(ctx, |s| s.shutdown()) {
237 Ok(ShutdownResult::Sent) | Ok(ShutdownResult::Received) => {}
238 Err(ref e) if e.code() == ErrorCode::ZERO_RETURN => {}
239 Err(ref e) if e.code() == ErrorCode::WANT_READ || e.code() == ErrorCode::WANT_WRITE => {
240 return Poll::Pending;
241 }
242 Err(e) => {
243 return Poll::Ready(Err(e.into_io_error().unwrap_or_else(io::Error::other)));
244 }
245 }
246
247 Pin::new(&mut self.0.get_mut().stream).poll_shutdown(ctx)
248 }
249}
250
/// The error type returned when a TLS handshake fails.
///
/// Wraps `rama_boring`'s handshake error. When the failure happened after setup,
/// the `Ssl`, the error code, any I/O error and the original transport stream can be
/// inspected or recovered through the accessor methods.
pub struct HandshakeError<S>(ssl::HandshakeError<AsyncStreamBridge<S>>);
253
254impl<S> HandshakeError<S> {
255 pub fn ssl(&self) -> Option<&SslRef> {
257 match &self.0 {
258 ssl::HandshakeError::Failure(s) => Some(s.ssl()),
259 _ => None,
260 }
261 }
262
263 pub fn into_source_stream(self) -> Option<S> {
265 match self.0 {
266 ssl::HandshakeError::Failure(s) => Some(s.into_source_stream().stream),
267 _ => None,
268 }
269 }
270
271 pub fn as_source_stream(&self) -> Option<&S> {
273 match &self.0 {
274 ssl::HandshakeError::Failure(s) => Some(&s.get_ref().stream),
275 _ => None,
276 }
277 }
278
279 pub fn code(&self) -> Option<ErrorCode> {
281 match &self.0 {
282 ssl::HandshakeError::Failure(s) => Some(s.error().code()),
283 _ => None,
284 }
285 }
286
287 pub fn as_io_error(&self) -> Option<&io::Error> {
289 match &self.0 {
290 ssl::HandshakeError::Failure(s) => s.error().io_error(),
291 _ => None,
292 }
293 }
294}
295
296impl<S> fmt::Debug for HandshakeError<S>
297where
298 S: fmt::Debug,
299{
300 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
301 fmt::Debug::fmt(&self.0, fmt)
302 }
303}
304
305impl<S> fmt::Display for HandshakeError<S> {
306 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
307 fmt::Display::fmt(&self.0, fmt)
308 }
309}
310
311impl<S> Error for HandshakeError<S>
312where
313 S: fmt::Debug,
314{
315 fn source(&self) -> Option<&(dyn Error + 'static)> {
316 self.0.source()
317 }
318}
319
/// Future that drives a TLS handshake to completion.
///
/// Used by [`connect`], [`accept`] and [`SslStreamBuilder`]'s handshake methods. The
/// mid-handshake stream is held as `Some` between polls and taken out once the
/// future completes.
pub struct HandshakeFuture<S>(Option<MidHandshakeSslStream<AsyncStreamBridge<S>>>);
324
325impl<S> Future for HandshakeFuture<S>
326where
327 S: AsyncRead + AsyncWrite + Unpin,
328{
329 type Output = Result<SslStream<S>, HandshakeError<S>>;
330
331 fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
332 let mut mid_handshake = self.0.take().expect("future polled after completion");
333
334 mid_handshake.get_mut().set_waker(Some(ctx));
335 mid_handshake
336 .ssl_mut()
337 .set_task_waker(Some(ctx.waker().clone()));
338
339 match mid_handshake.handshake() {
340 Ok(mut stream) => {
341 stream.get_mut().set_waker(None);
342 stream.ssl_mut().set_task_waker(None);
343
344 Poll::Ready(Ok(SslStream(stream)))
345 }
346 Err(ssl::HandshakeError::WouldBlock(mut mid_handshake)) => {
347 mid_handshake.get_mut().set_waker(None);
348 mid_handshake.ssl_mut().set_task_waker(None);
349
350 self.0 = Some(mid_handshake);
351
352 Poll::Pending
353 }
354 Err(ssl::HandshakeError::Failure(mut mid_handshake)) => {
355 mid_handshake.get_mut().set_waker(None);
356
357 Poll::Ready(Err(HandshakeError(ssl::HandshakeError::Failure(
358 mid_handshake,
359 ))))
360 }
361 Err(err @ ssl::HandshakeError::SetupFailure(_)) => {
362 Poll::Ready(Err(HandshakeError(err)))
363 }
364 }
365 }
366}