openssl_async/
stream.rs

1use std::fmt;
2use std::io::{self, Read, Write};
3use std::ops::{Deref, DerefMut};
4use std::pin::Pin;
5
6use openssl::ssl;
7
8use futures::io::{AsyncRead, AsyncWrite};
9use futures::ready;
10use futures::task::{Context, Poll};
11
12use async_stdio::*;
13
/// An asynchronous SSL stream.
///
/// Wraps a synchronous OpenSSL [`ssl::SslStream`] whose transport is the
/// user's stream `S` seen through the `AsStdIo` adapter. The wrapper
/// dereferences to that inner stream, so its methods remain reachable.
pub struct SslStream<S> {
    /// The OpenSSL stream, layered over the `AsStdIo` adapter around `S`.
    pub(crate) inner: ssl::SslStream<AsStdIo<S>>,
    /// Waker controller. The read/write/flush poll methods in this file
    /// register the current task's waker here before driving `inner`.
    pub(crate) ctrl: WakerCtrl,
}
19
20impl<S> Deref for SslStream<S> {
21    type Target = ssl::SslStream<AsStdIo<S>>;
22
23    fn deref(&self) -> &Self::Target {
24        &self.inner
25    }
26}
27
28impl<S> DerefMut for SslStream<S> {
29    fn deref_mut(&mut self) -> &mut Self::Target {
30        &mut self.inner
31    }
32}
33
34impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
35    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
36        f.debug_struct("SslStream")
37            .field("inner", &self.inner)
38            .finish()
39    }
40}
41
42impl<S> AsyncRead for SslStream<S>
43where
44    S: Unpin + AsyncRead + AsyncWrite,
45{
46    fn poll_read(
47        self: Pin<&mut Self>,
48        cx: &mut Context<'_>,
49        buf: &mut [u8],
50    ) -> Poll<io::Result<usize>> {
51        let this = self.get_mut();
52
53        this.ctrl.register(cx.waker());
54
55        this.inner.read(buf).into_poll()
56    }
57}
58
59impl<S> AsyncWrite for SslStream<S>
60where
61    S: AsyncWrite + AsyncRead + Unpin,
62{
63    fn poll_write(
64        self: Pin<&mut Self>,
65        cx: &mut Context<'_>,
66        buf: &[u8],
67    ) -> Poll<io::Result<usize>> {
68        let this = self.get_mut();
69
70        this.ctrl.register(cx.waker());
71
72        this.inner.write(buf).into_poll()
73    }
74
75    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
76        let this = self.get_mut();
77
78        this.ctrl.register(cx.waker());
79
80        this.inner.flush().into_poll()
81    }
82
83    fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
84        ready!(self.as_mut().poll_flush(cx)?);
85
86        self.get_mut()
87            .inner
88            .get_mut()
89            .with_context(|stream, cx| stream.poll_close(cx))
90    }
91}
92
// Adapters exposing the stream through the `tokio_io` traits. Every method
// forwards to the corresponding `futures::io` implementation on `SslStream`.
#[cfg(feature = "tokio")]
mod tokio {
    use super::*;

    impl<S> tokio_io::AsyncRead for SslStream<S>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        /// Forwards to the `futures` `AsyncRead::poll_read` implementation.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            <Self as AsyncRead>::poll_read(self, cx, buf)
        }
    }

    impl<S> tokio_io::AsyncWrite for SslStream<S>
    where
        S: AsyncRead + AsyncWrite + Unpin,
    {
        /// Forwards to the `futures` `AsyncWrite::poll_write` implementation.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            <Self as AsyncWrite>::poll_write(self, cx, buf)
        }

        /// Forwards to the `futures` `AsyncWrite::poll_flush` implementation.
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            <Self as AsyncWrite>::poll_flush(self, cx)
        }

        /// Maps tokio's shutdown onto the `futures` `AsyncWrite::poll_close`
        /// implementation.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            <Self as AsyncWrite>::poll_close(self, cx)
        }
    }
}