1
 2
 3
 4
 5
 6
 7
 8
 9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
// ref https://github.com/async-rs/async-tls/blob/v0.7.1/src/connector.rs

use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use futures_util::io::{AsyncRead, AsyncWrite};
use rustls::{ClientConfig, ClientSession};
use webpki::DNSNameRef;

use crate::{handshake, MidHandshake, TlsStream};

/// Entry point for opening client-side TLS connections.
///
/// The rustls [`ClientConfig`] is held behind an [`Arc`], so cloning a connector only bumps a
/// reference count and every clone shares the same configuration.
#[derive(Clone)]
pub struct TlsConnector {
    inner: Arc<ClientConfig>,
}

impl From<Arc<ClientConfig>> for TlsConnector {
    /// Wraps an already-shared client configuration without copying it.
    fn from(config: Arc<ClientConfig>) -> Self {
        Self { inner: config }
    }
}

impl Default for TlsConnector {
    /// Builds a connector from `ClientConfig::new()` whose root store trusts the server
    /// anchors bundled in `webpki_roots::TLS_SERVER_ROOTS`.
    fn default() -> Self {
        let mut client_config = ClientConfig::new();

        // Borrow the trust store only for as long as it takes to seed it.
        let roots = &mut client_config.root_store;
        roots.add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);

        Self::from(Arc::new(client_config))
    }
}

impl TlsConnector {
    pub fn new() -> Self {
        Default::default()
    }

    pub fn connect<IO>(&self, domain: impl AsRef<str>, stream: IO) -> Connect<IO> {
        let domain = match DNSNameRef::try_from_ascii_str(domain.as_ref()) {
            Ok(domain) => domain,
            Err(_) => {
                return Connect(Err(io::Error::new(
                    io::ErrorKind::InvalidInput,
                    "invalid domain",
                )));
            }
        };

        let session = ClientSession::new(&self.inner, domain);

        Connect(Ok(handshake(session, stream)))
    }
}

/// Future returned by [`TlsConnector::connect`].
///
/// Holds either the in-progress client handshake or an error detected before the handshake
/// could start; it resolves to the `TlsStream` produced by the handshake, or to an I/O error.
pub struct Connect<IO>(io::Result<MidHandshake<ClientSession, IO>>);

impl<IO> Future for Connect<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    type Output = io::Result<TlsStream<ClientSession, IO>>;

    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        match &mut self.0 {
            // Delegate to the handshake future.
            Ok(mid_handshake) => Pin::new(mid_handshake).poll(cx),
            // `io::Error` is not `Clone`, so hand back an equivalent error (same kind and
            // message) and keep the stored one in place; every poll in this state is `Ready`.
            Err(stored) => {
                let kind = stored.kind();
                let message = stored.to_string();
                Poll::Ready(Err(io::Error::new(kind, message)))
            }
        }
    }
}