1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
use std::future::Future;
use std::io;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use futures_util::io::{AsyncRead, AsyncWrite};
use rustls::{ClientConfig, ClientSession};
use webpki::DNSNameRef;
use crate::{handshake, MidHandshake, TlsStream};
/// A TLS connector wrapping a shared rustls [`ClientConfig`].
///
/// The configuration is kept behind an [`Arc`], so cloning a connector only
/// clones the handle and every clone refers to the same configuration.
#[derive(Clone)]
pub struct TlsConnector {
// Client configuration passed to each `ClientSession` created by `connect`.
inner: Arc<ClientConfig>,
}
impl From<Arc<ClientConfig>> for TlsConnector {
fn from(inner: Arc<ClientConfig>) -> TlsConnector {
TlsConnector { inner }
}
}
impl Default for TlsConnector {
    /// Builds a connector from a fresh [`ClientConfig`] whose root store has
    /// been populated with the trust anchors in `webpki_roots::TLS_SERVER_ROOTS`.
    fn default() -> Self {
        let mut client_config = ClientConfig::new();
        let trust_anchors = &webpki_roots::TLS_SERVER_ROOTS;
        client_config
            .root_store
            .add_server_trust_anchors(trust_anchors);
        Self::from(Arc::new(client_config))
    }
}
impl TlsConnector {
    /// Creates a connector with the default configuration.
    ///
    /// Equivalent to [`TlsConnector::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Starts a client-side TLS session over `stream` for the given `domain`.
    ///
    /// `domain` is parsed with `DNSNameRef::try_from_ascii_str`. When parsing
    /// succeeds, a new `ClientSession` is created from this connector's
    /// configuration and handed, together with `stream`, to `handshake`; the
    /// returned [`Connect`] wraps that pending handshake.
    ///
    /// When parsing fails, no session is created. The returned [`Connect`]
    /// instead carries an [`io::ErrorKind::InvalidInput`] error with the
    /// message `"invalid domain"`, which it reports when polled.
    pub fn connect<IO>(&self, domain: impl AsRef<str>, stream: IO) -> Connect<IO> {
        let outcome = DNSNameRef::try_from_ascii_str(domain.as_ref())
            .map(|dns_name| handshake(ClientSession::new(&self.inner, dns_name), stream))
            .map_err(|_| io::Error::new(io::ErrorKind::InvalidInput, "invalid domain"));
        Connect(outcome)
    }
}
/// Future returned by [`TlsConnector::connect`].
///
/// It wraps either the pending handshake or the error raised while the
/// connection was being set up (for example an invalid domain name). In the
/// error case, polling yields an `io::Error` with the same kind and message.
///
/// Marked `#[must_use]` because a concrete future type does not inherit the
/// lint from the `Future` trait; without it, a discarded `connect(..)` call
/// would compile silently and its result would never be observed.
#[must_use = "a `Connect` future must be polled or awaited to obtain the TLS stream"]
pub struct Connect<IO>(io::Result<MidHandshake<ClientSession, IO>>);
impl<IO> Future for Connect<IO>
where
    IO: AsyncRead + AsyncWrite + Unpin,
{
    type Output = io::Result<TlsStream<ClientSession, IO>>;

    /// Polls the wrapped handshake, or reports the stored setup error.
    ///
    /// If this future was created with an error, every poll returns
    /// `Poll::Ready(Err(..))` with a freshly built `io::Error` that has the
    /// same kind and message text as the stored one. The stored error is only
    /// borrowed here, so it cannot be moved out.
    #[inline]
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let pending = match &mut self.0 {
            Ok(pending) => pending,
            Err(stored) => {
                let rebuilt = io::Error::new(stored.kind(), stored.to_string());
                return Poll::Ready(Err(rebuilt));
            }
        };
        Pin::new(pending).poll(cx)
    }
}