async_rustls_stream/
lib.rs

//! An async TLS stream library built on [rustls] and [futures_io], usable by both servers and clients.
2//!
3//! # Examples
4//!
5//! **Server**
6//! ```ignore
7//! let listener = async_net::TcpListener::bind((Ipv4Addr::LOCALHOST, 4443)).await.unwrap();
8//! let (stream, remote_addr) = listener.accept().await.unwrap();
9//!
10//! // Recv Client Hello
11//! let accept = TlsAccepted::accept(stream).await.unwrap();
12//!
13//! let server_config = Arc::new(server_config);
14//! let mut stream = accept.into_stream(server_config.clone()).unwrap();
//! // Flushing drives the handshake to completion.
16//! stream.flush().await.unwrap();
17//! ```
18//!
19//! **Client**
20//!
21//! ```ignore
22//! let server_name = "test.com".try_into().unwrap();
23//! let client_config = Arc::new(client_config);
24//! let connector = TlsConnector::new(client_config.clone(), server_name).unwrap();
25//!
26//! let stream = async_net::TcpStream::connect((Ipv4Addr::LOCALHOST, 4443)).await.unwrap();
27//!
28//! let mut stream = connector.connect(stream);
//! // Flushing drives the handshake to completion.
30//! stream.flush().await.unwrap();
31//! ```
//! See also the [examples](https://github.com/hs-CN/async-rustls-stream/blob/master/examples).
33
34use futures_io::{AsyncRead, AsyncWrite};
35use rustls::{
36    server::{Accepted, Acceptor, ClientHello},
37    ClientConfig, ClientConnection, ConnectionCommon, ServerConfig, ServerConnection, ServerName,
38    SideData, Stream,
39};
40use std::{
41    future::Future,
42    io::{self, Read, Write},
43    ops::{Deref, DerefMut},
44    pin::Pin,
45    sync::Arc,
46    task::{Context, Poll},
47};
48
/// Adapter that lets rustls' blocking-style I/O drive an async stream.
///
/// It bundles the task [`Context`] with a borrow of the async `stream`. The `Read`/`Write`
/// impls poll `stream` once per call with that context.
struct InnerStream<'cx, 'waker, T> {
    /// Task context forwarded to every `poll_*` call on `stream`.
    cx: &'cx mut Context<'waker>,
    /// The wrapped async transport.
    stream: &'cx mut T,
}
53
54impl<'a, 'b, T: AsyncRead + Unpin> Read for InnerStream<'a, 'b, T> {
55    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
56        match Pin::new(&mut self.stream).poll_read(self.cx, buf) {
57            Poll::Ready(res) => res,
58            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
59        }
60    }
61
62    fn read_vectored(&mut self, bufs: &mut [io::IoSliceMut<'_>]) -> io::Result<usize> {
63        match Pin::new(&mut self.stream).poll_read_vectored(self.cx, bufs) {
64            Poll::Ready(res) => res,
65            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
66        }
67    }
68}
69
70impl<'a, 'b, T: AsyncWrite + Unpin> Write for InnerStream<'a, 'b, T> {
71    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
72        match Pin::new(&mut self.stream).poll_write(self.cx, buf) {
73            Poll::Ready(res) => res,
74            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
75        }
76    }
77
78    fn write_vectored(&mut self, bufs: &[io::IoSlice<'_>]) -> io::Result<usize> {
79        match Pin::new(&mut self.stream).poll_write_vectored(self.cx, bufs) {
80            Poll::Ready(res) => res,
81            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
82        }
83    }
84
85    fn flush(&mut self) -> io::Result<()> {
86        match Pin::new(&mut self.stream).poll_flush(self.cx) {
87            Poll::Ready(res) => res,
88            Poll::Pending => Err(io::ErrorKind::WouldBlock.into()),
89        }
90    }
91}
92
/// A TLS session layered over an async byte stream, implementing [AsyncRead] and
/// [AsyncWrite].
///
/// `C` is the rustls connection (this crate builds it as a `ClientConnection` or a
/// `ServerConnection`). `T` is the transport that carries the encrypted records.
pub struct TlsStream<C, T> {
    /// rustls session state.
    connection: C,
    /// Underlying transport.
    stream: T,
}
98
99impl<C, T> TlsStream<C, T> {
100    pub fn get_ref(&self) -> (&C, &T) {
101        (&self.connection, &self.stream)
102    }
103
104    pub fn get_mut(&mut self) -> (&mut C, &mut T) {
105        (&mut self.connection, &mut self.stream)
106    }
107}
108
109impl<C, T, S> AsyncRead for TlsStream<C, T>
110where
111    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Unpin,
112    T: AsyncRead + AsyncWrite + Unpin,
113    S: SideData,
114{
115    fn poll_read(
116        mut self: std::pin::Pin<&mut Self>,
117        cx: &mut std::task::Context<'_>,
118        buf: &mut [u8],
119    ) -> std::task::Poll<std::io::Result<usize>> {
120        let (connection, stream) = (*self).get_mut();
121        let mut stream = Stream {
122            conn: connection,
123            sock: &mut InnerStream { cx, stream },
124        };
125        match stream.read(buf) {
126            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
127            res => Poll::Ready(res),
128        }
129    }
130
131    fn poll_read_vectored(
132        mut self: std::pin::Pin<&mut Self>,
133        cx: &mut std::task::Context<'_>,
134        bufs: &mut [std::io::IoSliceMut<'_>],
135    ) -> std::task::Poll<std::io::Result<usize>> {
136        let (connection, stream) = (*self).get_mut();
137        let mut stream = Stream {
138            conn: connection,
139            sock: &mut InnerStream { cx, stream },
140        };
141        match stream.read_vectored(bufs) {
142            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
143            res => Poll::Ready(res),
144        }
145    }
146}
147
148impl<C, T, S> AsyncWrite for TlsStream<C, T>
149where
150    C: DerefMut + Deref<Target = ConnectionCommon<S>> + Unpin,
151    T: AsyncRead + AsyncWrite + Unpin,
152    S: SideData,
153{
154    fn poll_write(
155        mut self: std::pin::Pin<&mut Self>,
156        cx: &mut std::task::Context<'_>,
157        buf: &[u8],
158    ) -> std::task::Poll<std::io::Result<usize>> {
159        let (connection, stream) = (*self).get_mut();
160        let mut stream = Stream {
161            conn: connection,
162            sock: &mut InnerStream { cx, stream },
163        };
164        match stream.write(buf) {
165            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
166            res => Poll::Ready(res),
167        }
168    }
169
170    fn poll_write_vectored(
171        mut self: std::pin::Pin<&mut Self>,
172        cx: &mut std::task::Context<'_>,
173        bufs: &[std::io::IoSlice<'_>],
174    ) -> std::task::Poll<std::io::Result<usize>> {
175        let (connection, stream) = (*self).get_mut();
176        let mut stream = Stream {
177            conn: connection,
178            sock: &mut InnerStream { cx, stream },
179        };
180        match stream.write_vectored(bufs) {
181            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
182            res => Poll::Ready(res),
183        }
184    }
185
186    fn poll_flush(
187        mut self: std::pin::Pin<&mut Self>,
188        cx: &mut std::task::Context<'_>,
189    ) -> std::task::Poll<std::io::Result<()>> {
190        let (connection, stream) = (*self).get_mut();
191        let mut stream = Stream {
192            conn: connection,
193            sock: &mut InnerStream { cx, stream },
194        };
195        match stream.flush() {
196            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
197            res => Poll::Ready(res),
198        }
199    }
200
201    fn poll_close(
202        self: std::pin::Pin<&mut Self>,
203        cx: &mut std::task::Context<'_>,
204    ) -> std::task::Poll<std::io::Result<()>> {
205        self.poll_flush(cx)
206    }
207}
208
209/// Tls Client Connector.
210///
211/// Use [TlsConnector::connect()] to get [TlsStream] for client.
212///
213/// Then use [TlsStream]::flush() to finish the handshake.
214pub struct TlsConnector(ClientConnection);
215
216impl TlsConnector {
217    pub fn new(config: Arc<ClientConfig>, server_name: ServerName) -> Result<Self, rustls::Error> {
218        let connection = ClientConnection::new(config, server_name)?;
219        Ok(Self(connection))
220    }
221
222    /// The `stream` generally should implement [AsyncRead] and [AsyncWrite].
223    pub fn connect<T>(self, stream: T) -> TlsStream<ClientConnection, T> {
224        TlsStream {
225            connection: self.0,
226            stream,
227        }
228    }
229}
230
231/// Tls Server Accept the `Client Hello` and finish the handshake.
232///
233/// Use [`TlsAccepted::accept()`] to receive the `Client Hello`.
234///
235/// Then use [TlsAccepted::into_stream()] to get [TlsStream].
236///
237/// Then use [TlsStream]::flush() to finish the handshake.
238pub struct TlsAccepted<T> {
239    accepted: Accepted,
240    stream: T,
241}
242
243impl<T> TlsAccepted<T> {
244    /// Get the [`ClientHello`] received form client.
245    pub fn client_hello(&self) -> ClientHello {
246        self.accepted.client_hello()
247    }
248
249    /// Convert Into [`TlsStream`] with [`ServerConfig`].
250    pub fn into_stream(
251        self,
252        config: Arc<ServerConfig>,
253    ) -> Result<TlsStream<ServerConnection, T>, rustls::Error> {
254        let connection = self.accepted.into_connection(config)?;
255        Ok(TlsStream {
256            connection,
257            stream: self.stream,
258        })
259    }
260}
261
262impl<T> TlsAccepted<T>
263where
264    T: AsyncRead + Unpin,
265{
266    /// Receive `Client Hello`. The `stream` generally should implement [AsyncRead] and [AsyncWrite].
267    pub async fn accept(mut stream: T) -> io::Result<TlsAccepted<T>> {
268        let accepted = AcceptFuture {
269            acceptor: Acceptor::new().unwrap(),
270            stream: &mut stream,
271        }
272        .await?;
273        Ok(TlsAccepted { accepted, stream })
274    }
275}
276
277struct AcceptFuture<'a, T> {
278    acceptor: Acceptor,
279    stream: &'a mut T,
280}
281
282impl<'a, T> AcceptFuture<'a, T> {
283    fn get_mut(&mut self) -> (&mut Acceptor, &mut T) {
284        (&mut self.acceptor, self.stream)
285    }
286}
287
288impl<'a, T: AsyncRead + Unpin> Future for AcceptFuture<'a, T> {
289    type Output = io::Result<Accepted>;
290
291    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
292        let (acceptor, stream) = (*self).get_mut();
293        match acceptor.read_tls(&mut InnerStream { cx, stream }) {
294            Ok(_) => match self.acceptor.accept() {
295                Ok(None) => Poll::Pending,
296                Ok(Some(accepted)) => Poll::Ready(Ok(accepted)),
297                Err(err) => Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidData, err))),
298            },
299            Err(err) if err.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
300            Err(err) => Poll::Ready(Err(err)),
301        }
302    }
303}
304
305#[cfg(test)]
306mod test;