1use std::fmt;
2use std::io::{self, Read, Write};
3use std::ops::{Deref, DerefMut};
4use std::pin::Pin;
5
6use openssl::ssl;
7
8use futures::io::{AsyncRead, AsyncWrite};
9use futures::ready;
10use futures::task::{Context, Poll};
11
12use async_stdio::*;
13
14pub struct SslStream<S> {
16 pub(crate) inner: ssl::SslStream<AsStdIo<S>>,
17 pub(crate) ctrl: WakerCtrl,
18}
19
20impl<S> Deref for SslStream<S> {
21 type Target = ssl::SslStream<AsStdIo<S>>;
22
23 fn deref(&self) -> &Self::Target {
24 &self.inner
25 }
26}
27
28impl<S> DerefMut for SslStream<S> {
29 fn deref_mut(&mut self) -> &mut Self::Target {
30 &mut self.inner
31 }
32}
33
34impl<S: fmt::Debug> fmt::Debug for SslStream<S> {
35 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
36 f.debug_struct("SslStream")
37 .field("inner", &self.inner)
38 .finish()
39 }
40}
41
42impl<S> AsyncRead for SslStream<S>
43where
44 S: Unpin + AsyncRead + AsyncWrite,
45{
46 fn poll_read(
47 self: Pin<&mut Self>,
48 cx: &mut Context<'_>,
49 buf: &mut [u8],
50 ) -> Poll<io::Result<usize>> {
51 let this = self.get_mut();
52
53 this.ctrl.register(cx.waker());
54
55 this.inner.read(buf).into_poll()
56 }
57}
58
59impl<S> AsyncWrite for SslStream<S>
60where
61 S: AsyncWrite + AsyncRead + Unpin,
62{
63 fn poll_write(
64 self: Pin<&mut Self>,
65 cx: &mut Context<'_>,
66 buf: &[u8],
67 ) -> Poll<io::Result<usize>> {
68 let this = self.get_mut();
69
70 this.ctrl.register(cx.waker());
71
72 this.inner.write(buf).into_poll()
73 }
74
75 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
76 let this = self.get_mut();
77
78 this.ctrl.register(cx.waker());
79
80 this.inner.flush().into_poll()
81 }
82
83 fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
84 ready!(self.as_mut().poll_flush(cx)?);
85
86 self.get_mut()
87 .inner
88 .get_mut()
89 .with_context(|stream, cx| stream.poll_close(cx))
90 }
91}
92
/// Compatibility impls for `tokio_io`. They forward to the `futures` I/O
/// impls defined in the parent module.
#[cfg(feature = "tokio")]
mod tokio {
    use super::*;

    impl<S> tokio_io::AsyncRead for SslStream<S>
    where
        S: Unpin + AsyncRead + AsyncWrite,
    {
        /// Forwards to the `futures` `AsyncRead::poll_read` impl.
        fn poll_read(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &mut [u8],
        ) -> Poll<io::Result<usize>> {
            <Self as AsyncRead>::poll_read(self, cx, buf)
        }
    }

    impl<S> tokio_io::AsyncWrite for SslStream<S>
    where
        S: AsyncWrite + AsyncRead + Unpin,
    {
        /// Forwards to the `futures` `AsyncWrite::poll_write` impl.
        fn poll_write(
            self: Pin<&mut Self>,
            cx: &mut Context<'_>,
            buf: &[u8],
        ) -> Poll<io::Result<usize>> {
            <Self as AsyncWrite>::poll_write(self, cx, buf)
        }

        /// Forwards to the `futures` `AsyncWrite::poll_flush` impl.
        fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            <Self as AsyncWrite>::poll_flush(self, cx)
        }

        /// Maps tokio's shutdown onto the `futures` `AsyncWrite::poll_close`.
        fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
            <Self as AsyncWrite>::poll_close(self, cx)
        }
    }
}