1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
#![doc(html_root_url = "https://docs.rs/tokio-postgres-native-tls/0.1.0-rc.1")]
#![warn(rust_2018_idioms, clippy::all, missing_docs)]
use futures::{try_ready, Async, Future, Poll};
use tokio_io::{AsyncRead, AsyncWrite};
#[cfg(feature = "runtime")]
use tokio_postgres::tls::MakeTlsConnect;
use tokio_postgres::tls::{ChannelBinding, TlsConnect};
use tokio_tls::{Connect, TlsStream};
#[cfg(test)]
mod test;
/// A `tokio_postgres::tls::MakeTlsConnect` implementation backed by `native-tls`.
///
/// Wraps a shared `native_tls::TlsConnector` configuration. Each call to
/// `make_tls_connect` clones that configuration and pairs it with the
/// requested server domain to produce a per-connection [`TlsConnector`].
///
/// Only available when the `runtime` feature is enabled.
#[cfg(feature = "runtime")]
#[derive(Clone)]
pub struct MakeTlsConnector(native_tls::TlsConnector);

#[cfg(feature = "runtime")]
impl MakeTlsConnector {
    /// Creates a new `MakeTlsConnector` that uses `connector` as the TLS
    /// configuration for every connection it produces.
    pub fn new(connector: native_tls::TlsConnector) -> MakeTlsConnector {
        MakeTlsConnector(connector)
    }
}
// Hands tokio-postgres a fresh `TlsConnector` for each connection attempt,
// reusing the shared `native_tls` configuration held by `MakeTlsConnector`.
#[cfg(feature = "runtime")]
impl<S> MakeTlsConnect<S> for MakeTlsConnector
where
    S: AsyncRead + AsyncWrite,
{
    type Stream = TlsStream<S>;
    type TlsConnect = TlsConnector;
    type Error = native_tls::Error;

    // Never fails: cloning the configuration and copying the domain are infallible.
    fn make_tls_connect(&mut self, domain: &str) -> Result<Self::TlsConnect, Self::Error> {
        let shared_config = self.0.clone();
        let per_connection = TlsConnector::new(shared_config, domain);
        Ok(per_connection)
    }
}
/// A `tokio_postgres::tls::TlsConnect` implementation backed by `native-tls`.
///
/// Bundles a TLS configuration with the domain name of the server being
/// connected to. Both are consumed when `TlsConnect::connect` starts the
/// handshake.
pub struct TlsConnector {
    connector: tokio_tls::TlsConnector,
    domain: String,
}

impl TlsConnector {
    /// Creates a new `TlsConnector` for the server at `domain`.
    ///
    /// `connector` is converted into a `tokio_tls::TlsConnector`. `domain` is
    /// copied and later passed to `tokio_tls::TlsConnector::connect` when the
    /// handshake begins.
    pub fn new(connector: native_tls::TlsConnector, domain: &str) -> TlsConnector {
        TlsConnector {
            connector: tokio_tls::TlsConnector::from(connector),
            domain: domain.to_string(),
        }
    }
}
impl<S> TlsConnect<S> for TlsConnector
where
S: AsyncRead + AsyncWrite,
{
type Stream = TlsStream<S>;
type Error = native_tls::Error;
type Future = TlsConnectFuture<S>;
fn connect(self, stream: S) -> TlsConnectFuture<S> {
TlsConnectFuture(self.connector.connect(&self.domain, stream))
}
}
/// The future returned by `TlsConnector`'s `TlsConnect::connect`.
///
/// Resolves to the established `TlsStream` together with the
/// `ChannelBinding` data for the session. When tls-server-end-point data
/// cannot be obtained, the binding is `ChannelBinding::none()`.
#[must_use = "futures do nothing unless polled"]
pub struct TlsConnectFuture<S>(Connect<S>);
impl<S> Future for TlsConnectFuture<S>
where
S: AsyncRead + AsyncWrite,
{
type Item = (TlsStream<S>, ChannelBinding);
type Error = native_tls::Error;
fn poll(&mut self) -> Poll<(TlsStream<S>, ChannelBinding), native_tls::Error> {
let stream = try_ready!(self.0.poll());
let channel_binding = match stream.get_ref().tls_server_end_point().unwrap_or(None) {
Some(buf) => ChannelBinding::tls_server_end_point(buf),
None => ChannelBinding::none(),
};
Ok(Async::Ready((stream, channel_binding)))
}
}