1#![doc(
4 html_favicon_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
5)]
6#![doc(
7 html_logo_url = "https://raw.githubusercontent.com/smol-rs/smol/master/assets/images/logo_fullsize_transparent.png"
8)]
9#."
12)]
13
/// Evaluates a `Poll` expression and yields the value inside `Poll::Ready`.
///
/// If the expression is `Poll::Pending`, the enclosing function returns
/// `Poll::Pending` right away.
macro_rules! ready {
    ($poll:expr) => {
        match $poll {
            std::task::Poll::Pending => return std::task::Poll::Pending,
            std::task::Poll::Ready(value) => value,
        }
    };
}
22
23pub mod client;
24mod common;
25pub mod server;
26
27use common::{MidHandshake, Stream, TlsState};
28use futures_io::{AsyncRead, AsyncWrite};
29use rustls::{ClientConfig, ClientConnection, CommonState, ServerConfig, ServerConnection};
30use std::future::Future;
31use std::io;
32#[cfg(unix)]
33use std::os::unix::io::{AsRawFd, RawFd};
34#[cfg(windows)]
35use std::os::windows::io::{AsRawSocket, RawSocket};
36use std::pin::Pin;
37use std::sync::Arc;
38use std::task::{Context, Poll};
39
40pub use rustls;
41
42#[derive(Clone)]
44pub struct TlsConnector {
45 inner: Arc<ClientConfig>,
46 #[cfg(feature = "early-data")]
47 early_data: bool,
48}
49
50#[derive(Clone)]
52pub struct TlsAcceptor {
53 inner: Arc<ServerConfig>,
54}
55
56impl From<Arc<ClientConfig>> for TlsConnector {
57 fn from(inner: Arc<ClientConfig>) -> TlsConnector {
58 TlsConnector {
59 inner,
60 #[cfg(feature = "early-data")]
61 early_data: false,
62 }
63 }
64}
65
66impl From<Arc<ServerConfig>> for TlsAcceptor {
67 fn from(inner: Arc<ServerConfig>) -> TlsAcceptor {
68 TlsAcceptor { inner }
69 }
70}
71
72impl TlsConnector {
73 #[cfg(feature = "early-data")]
78 pub fn early_data(mut self, flag: bool) -> TlsConnector {
79 self.early_data = flag;
80 self
81 }
82
83 #[inline]
84 pub fn connect<IO>(&self, domain: rustls::ServerName, stream: IO) -> Connect<IO>
85 where
86 IO: AsyncRead + AsyncWrite + Unpin,
87 {
88 self.connect_with(domain, stream, |_| ())
89 }
90
91 pub fn connect_with<IO, F>(&self, domain: rustls::ServerName, stream: IO, f: F) -> Connect<IO>
92 where
93 IO: AsyncRead + AsyncWrite + Unpin,
94 F: FnOnce(&mut ClientConnection),
95 {
96 let mut session = match ClientConnection::new(self.inner.clone(), domain) {
97 Ok(session) => session,
98 Err(error) => {
99 return Connect(MidHandshake::Error {
100 io: stream,
101 error: io::Error::new(io::ErrorKind::Other, error),
104 });
105 }
106 };
107 f(&mut session);
108
109 Connect(MidHandshake::Handshaking(client::TlsStream {
110 io: stream,
111
112 #[cfg(not(feature = "early-data"))]
113 state: TlsState::Stream,
114
115 #[cfg(feature = "early-data")]
116 state: if self.early_data && session.early_data().is_some() {
117 TlsState::EarlyData(0, Vec::new())
118 } else {
119 TlsState::Stream
120 },
121
122 #[cfg(feature = "early-data")]
123 early_waker: None,
124
125 session,
126 }))
127 }
128}
129
130impl TlsAcceptor {
131 #[inline]
132 pub fn accept<IO>(&self, stream: IO) -> Accept<IO>
133 where
134 IO: AsyncRead + AsyncWrite + Unpin,
135 {
136 self.accept_with(stream, |_| ())
137 }
138
139 pub fn accept_with<IO, F>(&self, stream: IO, f: F) -> Accept<IO>
140 where
141 IO: AsyncRead + AsyncWrite + Unpin,
142 F: FnOnce(&mut ServerConnection),
143 {
144 let mut session = match ServerConnection::new(self.inner.clone()) {
145 Ok(session) => session,
146 Err(error) => {
147 return Accept(MidHandshake::Error {
148 io: stream,
149 error: io::Error::new(io::ErrorKind::Other, error),
152 });
153 }
154 };
155 f(&mut session);
156
157 Accept(MidHandshake::Handshaking(server::TlsStream {
158 session,
159 io: stream,
160 state: TlsState::Stream,
161 }))
162 }
163}
164
165pub struct LazyConfigAcceptor<IO> {
166 acceptor: rustls::server::Acceptor,
167 io: Option<IO>,
168}
169
170impl<IO> LazyConfigAcceptor<IO>
171where
172 IO: AsyncRead + AsyncWrite + Unpin,
173{
174 #[inline]
175 pub fn new(acceptor: rustls::server::Acceptor, io: IO) -> Self {
176 Self {
177 acceptor,
178 io: Some(io),
179 }
180 }
181}
182
183impl<IO> Future for LazyConfigAcceptor<IO>
184where
185 IO: AsyncRead + AsyncWrite + Unpin,
186{
187 type Output = Result<StartHandshake<IO>, io::Error>;
188
189 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
190 let this = self.get_mut();
191 loop {
192 let io = match this.io.as_mut() {
193 Some(io) => io,
194 None => {
195 panic!("Acceptor cannot be polled after acceptance.");
196 }
197 };
198
199 let mut reader = common::SyncReadAdapter { io, cx };
200 match this.acceptor.read_tls(&mut reader) {
201 Ok(0) => return Poll::Ready(Err(io::ErrorKind::UnexpectedEof.into())),
202 Ok(_) => {}
203 Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Poll::Pending,
204 Err(e) => return Poll::Ready(Err(e)),
205 }
206
207 match this.acceptor.accept() {
208 Ok(Some(accepted)) => {
209 let io = this.io.take().unwrap();
210 return Poll::Ready(Ok(StartHandshake { accepted, io }));
211 }
212 Ok(None) => continue,
213 Err(err) => {
214 return Poll::Ready(Err(io::Error::new(io::ErrorKind::InvalidInput, err)))
215 }
216 }
217 }
218 }
219}
220
221pub struct StartHandshake<IO> {
222 accepted: rustls::server::Accepted,
223 io: IO,
224}
225
226impl<IO> StartHandshake<IO>
227where
228 IO: AsyncRead + AsyncWrite + Unpin,
229{
230 pub fn client_hello(&self) -> rustls::server::ClientHello<'_> {
231 self.accepted.client_hello()
232 }
233
234 pub fn into_stream(self, config: Arc<ServerConfig>) -> Accept<IO> {
235 self.into_stream_with(config, |_| ())
236 }
237
238 pub fn into_stream_with<F>(self, config: Arc<ServerConfig>, f: F) -> Accept<IO>
239 where
240 F: FnOnce(&mut ServerConnection),
241 {
242 let mut conn = match self.accepted.into_connection(config) {
243 Ok(conn) => conn,
244 Err(error) => {
245 return Accept(MidHandshake::Error {
246 io: self.io,
247 error: io::Error::new(io::ErrorKind::Other, error),
250 });
251 }
252 };
253 f(&mut conn);
254
255 Accept(MidHandshake::Handshaking(server::TlsStream {
256 session: conn,
257 io: self.io,
258 state: TlsState::Stream,
259 }))
260 }
261}
262
263pub struct Connect<IO>(MidHandshake<client::TlsStream<IO>>);
266
267pub struct Accept<IO>(MidHandshake<server::TlsStream<IO>>);
270
271pub struct FallibleConnect<IO>(MidHandshake<client::TlsStream<IO>>);
273
274pub struct FallibleAccept<IO>(MidHandshake<server::TlsStream<IO>>);
276
277impl<IO> Connect<IO> {
278 #[inline]
279 pub fn into_fallible(self) -> FallibleConnect<IO> {
280 FallibleConnect(self.0)
281 }
282
283 pub fn get_ref(&self) -> Option<&IO> {
284 match &self.0 {
285 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
286 MidHandshake::Error { io, .. } => Some(io),
287 MidHandshake::End => None,
288 }
289 }
290
291 pub fn get_mut(&mut self) -> Option<&mut IO> {
292 match &mut self.0 {
293 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
294 MidHandshake::Error { io, .. } => Some(io),
295 MidHandshake::End => None,
296 }
297 }
298}
299
300impl<IO> Accept<IO> {
301 #[inline]
302 pub fn into_fallible(self) -> FallibleAccept<IO> {
303 FallibleAccept(self.0)
304 }
305
306 pub fn get_ref(&self) -> Option<&IO> {
307 match &self.0 {
308 MidHandshake::Handshaking(sess) => Some(sess.get_ref().0),
309 MidHandshake::Error { io, .. } => Some(io),
310 MidHandshake::End => None,
311 }
312 }
313
314 pub fn get_mut(&mut self) -> Option<&mut IO> {
315 match &mut self.0 {
316 MidHandshake::Handshaking(sess) => Some(sess.get_mut().0),
317 MidHandshake::Error { io, .. } => Some(io),
318 MidHandshake::End => None,
319 }
320 }
321}
322
323impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Connect<IO> {
324 type Output = io::Result<client::TlsStream<IO>>;
325
326 #[inline]
327 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
328 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
329 }
330}
331
332impl<IO: AsyncRead + AsyncWrite + Unpin> Future for Accept<IO> {
333 type Output = io::Result<server::TlsStream<IO>>;
334
335 #[inline]
336 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
337 Pin::new(&mut self.0).poll(cx).map_err(|(err, _)| err)
338 }
339}
340
341impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleConnect<IO> {
342 type Output = Result<client::TlsStream<IO>, (io::Error, IO)>;
343
344 #[inline]
345 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
346 Pin::new(&mut self.0).poll(cx)
347 }
348}
349
350impl<IO: AsyncRead + AsyncWrite + Unpin> Future for FallibleAccept<IO> {
351 type Output = Result<server::TlsStream<IO>, (io::Error, IO)>;
352
353 #[inline]
354 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
355 Pin::new(&mut self.0).poll(cx)
356 }
357}
358
359#[allow(clippy::large_enum_variant)] #[derive(Debug)]
365pub enum TlsStream<T> {
366 Client(client::TlsStream<T>),
367 Server(server::TlsStream<T>),
368}
369
370impl<T> TlsStream<T> {
371 pub fn get_ref(&self) -> (&T, &CommonState) {
372 use TlsStream::*;
373 match self {
374 Client(io) => {
375 let (io, session) = io.get_ref();
376 (io, session)
377 }
378 Server(io) => {
379 let (io, session) = io.get_ref();
380 (io, session)
381 }
382 }
383 }
384
385 pub fn get_mut(&mut self) -> (&mut T, &mut CommonState) {
386 use TlsStream::*;
387 match self {
388 Client(io) => {
389 let (io, session) = io.get_mut();
390 (io, &mut *session)
391 }
392 Server(io) => {
393 let (io, session) = io.get_mut();
394 (io, &mut *session)
395 }
396 }
397 }
398}
399
400impl<T> From<client::TlsStream<T>> for TlsStream<T> {
401 fn from(s: client::TlsStream<T>) -> Self {
402 Self::Client(s)
403 }
404}
405
406impl<T> From<server::TlsStream<T>> for TlsStream<T> {
407 fn from(s: server::TlsStream<T>) -> Self {
408 Self::Server(s)
409 }
410}
411
412#[cfg(unix)]
413impl<S> AsRawFd for TlsStream<S>
414where
415 S: AsRawFd,
416{
417 #[inline]
418 fn as_raw_fd(&self) -> RawFd {
419 self.get_ref().0.as_raw_fd()
420 }
421}
422
/// Exposes the socket handle of the transport underneath the TLS layer.
#[cfg(windows)]
impl<S> AsRawSocket for TlsStream<S>
where
    S: AsRawSocket,
{
    #[inline]
    fn as_raw_socket(&self) -> RawSocket {
        let (transport, _) = self.get_ref();
        transport.as_raw_socket()
    }
}
433
434impl<T> AsyncRead for TlsStream<T>
435where
436 T: AsyncRead + AsyncWrite + Unpin,
437{
438 #[inline]
439 fn poll_read(
440 self: Pin<&mut Self>,
441 cx: &mut Context<'_>,
442 buf: &mut [u8],
443 ) -> Poll<io::Result<usize>> {
444 match self.get_mut() {
445 TlsStream::Client(x) => Pin::new(x).poll_read(cx, buf),
446 TlsStream::Server(x) => Pin::new(x).poll_read(cx, buf),
447 }
448 }
449}
450
451impl<T> AsyncWrite for TlsStream<T>
452where
453 T: AsyncRead + AsyncWrite + Unpin,
454{
455 #[inline]
456 fn poll_write(
457 self: Pin<&mut Self>,
458 cx: &mut Context<'_>,
459 buf: &[u8],
460 ) -> Poll<io::Result<usize>> {
461 match self.get_mut() {
462 TlsStream::Client(x) => Pin::new(x).poll_write(cx, buf),
463 TlsStream::Server(x) => Pin::new(x).poll_write(cx, buf),
464 }
465 }
466
467 #[inline]
468 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
469 match self.get_mut() {
470 TlsStream::Client(x) => Pin::new(x).poll_flush(cx),
471 TlsStream::Server(x) => Pin::new(x).poll_flush(cx),
472 }
473 }
474
475 #[inline]
476 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
477 match self.get_mut() {
478 TlsStream::Client(x) => Pin::new(x).poll_close(cx),
479 TlsStream::Server(x) => Pin::new(x).poll_close(cx),
480 }
481 }
482}