1use futures_io::{AsyncRead, AsyncWrite};
35use rustls::{
36 server::{Accepted, Acceptor, ClientHello},
37 ClientConfig, ClientConnection, ConnectionCommon, ServerConfig, ServerConnection, ServerName,
38 SideData, Stream,
39};
40use std::{
41 future::Future,
42 io::{self, Read, Write},
43 ops::{Deref, DerefMut},
44 pin::Pin,
45 sync::Arc,
46 task::{Context, Poll},
47};
48
/// Borrowed view of an async transport together with the task `Context` of the
/// poll currently in progress.
///
/// The `Read`/`Write` impls for this type report the transport's `Poll::Pending`
/// as `io::ErrorKind::WouldBlock`, which lets the synchronous rustls APIs
/// (`Stream`, `Acceptor::read_tls`) drive an async transport.
struct InnerStream<'a, 'b, T> {
    /// Context of the in-flight poll; passed to every `poll_*` call on `stream`.
    cx: &'a mut Context<'b>,
    /// Transport that the TLS records are read from / written to.
    stream: &'a mut T,
}
53
54impl<'a, 'b, T: AsyncRead + Unpin> Read for InnerStream<'a, 'b, T> {
55 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
56 match Pin::new(&mut self.stream).poll_read(self.cx, buf) {
57 Poll::Ready(res) => res,
58 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
59 }
60 }
61
62 fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
63 match Pin::new(&mut self.stream).poll_read_vectored(self.cx, bufs) {
64 Poll::Ready(res) => res,
65 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
66 }
67 }
68}
69
70impl<'a, 'b, T: AsyncWrite + Unpin> Write for InnerStream<'a, 'b, T> {
71 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
72 match Pin::new(&mut self.stream).poll_write(self.cx, buf) {
73 Poll::Ready(res) => res,
74 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
75 }
76 }
77
78 fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
79 match Pin::new(&mut self.stream).poll_write_vectored(self.cx, bufs) {
80 Poll::Ready(res) => res,
81 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
82 }
83 }
84
85 fn flush(&mut self) -> io::Result<()> {
86 match Pin::new(&mut self.stream).poll_flush(self.cx) {
87 Poll::Ready(res) => res,
88 Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
89 }
90 }
91}
92
/// An asynchronous TLS stream: a rustls connection paired with the transport
/// that carries its records.
///
/// `C` is the rustls connection type (this file builds it with `ClientConnection`
/// or `ServerConnection`); `T` is the async transport.
pub struct TlsStream<C, T> {
    /// rustls connection state.
    connection: C,
    /// Underlying async transport.
    stream: T,
}
98
99impl<C, T> TlsStream<C, T> {
100 pub fn get_ref(&self) -> (&C, &T) {
101 (&self.connection, &self.stream)
102 }
103
104 pub fn get_mut(&mut self) -> (&mut C, &mut T) {
105 (&mut self.connection, &mut self.stream)
106 }
107}
108
109impl<C, T, S> AsyncRead for TlsStream<C, T>
110where
111 C: DerefMut + Deref<Target = ConnectionCommon<S>> + Unpin,
112 T: AsyncRead + AsyncWrite + Unpin,
113 S: SideData,
114{
115 fn poll_read(
116 mut self: std::pin::Pin<&mut Self>,
117 cx: &mut std::task::Context<'_>,
118 buf: &mut [u8],
119 ) -> std::task::Poll<std::io::Result<usize>> {
120 let (connection, stream) = (*self).get_mut();
121 let mut stream = Stream {
122 conn: connection,
123 sock: &mut InnerStream { cx, stream },
124 };
125 match stream.read(buf) {
126 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
127 res => Poll::Ready(res),
128 }
129 }
130
131 fn poll_read_vectored(
132 mut self: std::pin::Pin<&mut Self>,
133 cx: &mut std::task::Context<'_>,
134 bufs: &mut [std::io::IoSliceMut<'_>],
135 ) -> std::task::Poll<std::io::Result<usize>> {
136 let (connection, stream) = (*self).get_mut();
137 let mut stream = Stream {
138 conn: connection,
139 sock: &mut InnerStream { cx, stream },
140 };
141 match stream.read_vectored(bufs) {
142 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
143 res => Poll::Ready(res),
144 }
145 }
146}
147
148impl<C, T, S> AsyncWrite for TlsStream<C, T>
149where
150 C: DerefMut + Deref<Target = ConnectionCommon<S>> + Unpin,
151 T: AsyncRead + AsyncWrite + Unpin,
152 S: SideData,
153{
154 fn poll_write(
155 mut self: std::pin::Pin<&mut Self>,
156 cx: &mut std::task::Context<'_>,
157 buf: &[u8],
158 ) -> std::task::Poll<std::io::Result<usize>> {
159 let (connection, stream) = (*self).get_mut();
160 let mut stream = Stream {
161 conn: connection,
162 sock: &mut InnerStream { cx, stream },
163 };
164 match stream.write(buf) {
165 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
166 res => Poll::Ready(res),
167 }
168 }
169
170 fn poll_write_vectored(
171 mut self: std::pin::Pin<&mut Self>,
172 cx: &mut std::task::Context<'_>,
173 bufs: &[std::io::IoSlice<'_>],
174 ) -> std::task::Poll<std::io::Result<usize>> {
175 let (connection, stream) = (*self).get_mut();
176 let mut stream = Stream {
177 conn: connection,
178 sock: &mut InnerStream { cx, stream },
179 };
180 match stream.write_vectored(bufs) {
181 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
182 res => Poll::Ready(res),
183 }
184 }
185
186 fn poll_flush(
187 mut self: std::pin::Pin<&mut Self>,
188 cx: &mut std::task::Context<'_>,
189 ) -> std::task::Poll<std::io::Result<()>> {
190 let (connection, stream) = (*self).get_mut();
191 let mut stream = Stream {
192 conn: connection,
193 sock: &mut InnerStream { cx, stream },
194 };
195 match stream.flush() {
196 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
197 res => Poll::Ready(res),
198 }
199 }
200
201 fn poll_close(
202 self: std::pin::Pin<&mut Self>,
203 cx: &mut std::task::Context<'_>,
204 ) -> std::task::Poll<std::io::Result<()>> {
205 self.poll_flush(cx)
206 }
207}
208
209pub struct TlsConnector(ClientConnection);
215
216impl TlsConnector {
217 pub fn new(config: Arc<ClientConfig>, server_name: ServerName) -> Result<Self, rustls::Error> {
218 let connection = ClientConnection::new(config, server_name)?;
219 Ok(Self(connection))
220 }
221
222 pub fn connect<T>(self, stream: T) -> TlsStream<ClientConnection, T> {
224 TlsStream {
225 connection: self.0,
226 stream,
227 }
228 }
229}
230
231pub struct TlsAccepted<T> {
239 accepted: Accepted,
240 stream: T,
241}
242
243impl<T> TlsAccepted<T> {
244 pub fn client_hello(&self) -> ClientHello {
246 self.accepted.client_hello()
247 }
248
249 pub fn into_stream(
251 self,
252 config: Arc<ServerConfig>,
253 ) -> Result<TlsStream<ServerConnection, T>, rustls::Error> {
254 let connection = self.accepted.into_connection(config)?;
255 Ok(TlsStream {
256 connection,
257 stream: self.stream,
258 })
259 }
260}
261
262impl<T> TlsAccepted<T>
263where
264 T: AsyncRead + Unpin,
265{
266 pub async fn accept(mut stream: T) -> io::Result<TlsAccepted<T>> {
268 let accepted = AcceptFuture {
269 acceptor: Acceptor::new().unwrap(),
270 stream: &mut stream,
271 }
272 .await?;
273 Ok(TlsAccepted { accepted, stream })
274 }
275}
276
277struct AcceptFuture<'a, T> {
278 acceptor: Acceptor,
279 stream: &'a mut T,
280}
281
282impl<'a, T> AcceptFuture<'a, T> {
283 fn get_mut(&mut self) -> (&mut Acceptor, &mut T) {
284 (&mut self.acceptor, self.stream)
285 }
286}
287
288impl<'a, T: AsyncRead + Unpin> Future for AcceptFuture<'a, T> {
289 type Output = io::Result<Accepted>;
290
291 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
292 let (acceptor, stream) = (*self).get_mut();
293 match acceptor.read_tls(&mut InnerStream { cx, stream }) {
294 Ok(_) => match self.acceptor.accept() {
295 Ok(None) => Poll::Pending,
296 Ok(Some(accepted)) => Poll::Ready(Ok(accepted)),
297 Err(err) => Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err))),
298 },
299 Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
300 Err(err) => Poll::Ready(Err(err)),
301 }
302 }
303}
304
305#[cfg(test)]
306mod test;