1use std::future::Future;
2use std::io::{self, Read, Write};
3use std::pin::Pin;
4use std::task::{Context, Poll};
5
6use futures_io_traits_sync_wrapper::Wrapper as AsyncRWSyncWrapper;
7use futures_util::io::{AsyncRead, AsyncWrite};
8use rustls::{Session, Stream};
9
// Server-side support: the `acceptor` module and its `TlsAcceptor` re-export
// are only compiled when the `acceptor` cargo feature is enabled.
#[cfg(feature = "acceptor")]
mod acceptor;
#[cfg(feature = "acceptor")]
pub use acceptor::TlsAcceptor;

// Client-side support: the `connector` module and its `TlsConnector` re-export
// are only compiled when the `connector` cargo feature is enabled.
#[cfg(feature = "connector")]
mod connector;
#[cfg(feature = "connector")]
pub use connector::TlsConnector;
19
/// Convenience re-exports of the `rustls` (and, with the `connector` feature,
/// `webpki` / `webpki_roots`) items that users of this crate commonly need.
pub mod prelude {
    // `Session` is re-exported under the alias `RustlsSession`.
    pub use rustls::{
        internal::pemfile, ClientConfig, ClientSession, NoClientAuth, ServerConfig, ServerSession,
        Session as RustlsSession,
    };

    // The two re-exports below exist only when the `connector` feature is on.
    #[cfg(feature = "connector")]
    pub use webpki::DNSNameRef;
    #[cfg(feature = "connector")]
    pub use webpki_roots::TLS_SERVER_ROOTS;
}
31
/// A TLS stream: a session state `S` paired with the transport `IO` it runs over.
///
/// The struct itself places no bounds on `S` or `IO`; the `AsyncRead` /
/// `AsyncWrite` implementations in this file require `S: rustls::Session` and
/// `IO: AsyncRead + AsyncWrite + Unpin`.
pub struct TlsStream<S, IO> {
    /// TLS session state driven by the read/write implementations.
    session: S,
    /// Underlying transport, stored on the heap.
    io: Box<IO>,
}
36
37impl<S, IO> TlsStream<S, IO> {
38 pub fn get_mut(&mut self) -> (&mut S, &mut IO) {
39 (&mut self.session, self.io.as_mut())
40 }
41
42 pub fn get_ref(&self) -> (&S, &IO) {
43 (&self.session, self.io.as_ref())
44 }
45
46 pub fn into_inner(self) -> (S, IO) {
47 (self.session, *self.io)
48 }
49}
50
51pub async fn handshake<S, IO>(session: S, io: IO) -> io::Result<TlsStream<S, IO>>
52where
53 S: Session + Unpin,
54 IO: AsyncRead + AsyncWrite + Unpin,
55{
56 Handshake(Some((session, io))).await
57}
58
/// Future behind [`handshake`]: holds the session and transport while the
/// handshake is in progress. The slot is `None` once the future has resolved.
struct Handshake<S, IO>(Option<(S, IO)>);
60
61impl<S, IO> Future for Handshake<S, IO>
62where
63 S: Session + Unpin,
64 IO: AsyncRead + AsyncWrite + Unpin,
65{
66 type Output = io::Result<TlsStream<S, IO>>;
67
68 fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
69 let this = self.get_mut();
70
71 let (mut session, mut stream) = this.0.take().expect("never");
72
73 let mut io = AsyncRWSyncWrapper::new(&mut stream, cx);
74
75 match session.complete_io(&mut io) {
76 Ok(_) => Poll::Ready(Ok(TlsStream {
77 session,
78 io: Box::new(stream),
79 })),
80 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => {
81 this.0 = Some((session, stream));
82
83 Poll::Pending
84 }
85 Err(err) => Poll::Ready(Err(err)),
86 }
87 }
88}
89
90impl<S, IO> AsyncRead for TlsStream<S, IO>
91where
92 S: Session + Unpin,
93 IO: AsyncRead + AsyncWrite + Unpin,
94{
95 fn poll_read(
96 self: Pin<&mut Self>,
97 cx: &mut Context,
98 buf: &mut [u8],
99 ) -> Poll<io::Result<usize>> {
100 let this = self.get_mut();
101
102 let mut io = AsyncRWSyncWrapper::new(&mut this.io, cx);
103
104 let mut stream = Stream::new(&mut this.session, &mut io);
105
106 match stream.read(buf) {
107 Ok(n) => Poll::Ready(Ok(n)),
108 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
109 Err(err) => Poll::Ready(Err(err)),
110 }
111 }
112}
113
114impl<S, IO> AsyncWrite for TlsStream<S, IO>
115where
116 S: Session + Unpin,
117 IO: AsyncRead + AsyncWrite + Unpin,
118{
119 fn poll_write(self: Pin<&mut Self>, cx: &mut Context, buf: &[u8]) -> Poll<io::Result<usize>> {
120 let this = self.get_mut();
121
122 let mut io = AsyncRWSyncWrapper::new(&mut this.io, cx);
123
124 let mut stream = Stream::new(&mut this.session, &mut io);
125
126 match stream.write(buf) {
127 Ok(n) => Poll::Ready(Ok(n)),
128 Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
129 Err(err) => Poll::Ready(Err(err)),
130 }
131 }
132
133 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
134 let this = self.get_mut();
135
136 let mut io = AsyncRWSyncWrapper::new(&mut this.io, cx);
137
138 let mut stream = Stream::new(&mut this.session, &mut io);
139
140 stream.flush()?;
141
142 Pin::new(&mut this.io).poll_flush(cx)
143 }
144
145 fn poll_close(self: Pin<&mut Self>, cx: &mut Context) -> Poll<io::Result<()>> {
146 let this = self.get_mut();
147
148 Pin::new(&mut this.io).poll_close(cx)
149 }
150}