1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
use std::{future::Future, io, pin::Pin, sync::Arc, task::Context, task::Poll};

pub use rust_tls::Session;
pub use tokio_rustls::{client::TlsStream, rustls::ClientConfig};

use tokio_rustls::{self, TlsConnector};
use webpki::DNSNameRef;

use crate::rt::net::TcpStream;
use crate::service::{Service, ServiceFactory};
use crate::util::Ready;

use super::{Address, Connect, ConnectError, Connector};

/// Rustls connector factory.
///
/// Pairs a plain TCP [`Connector`] with a shared rustls [`ClientConfig`].
/// The same type is used both as the `ServiceFactory` and as the `Service`
/// it produces.
pub struct RustlsConnector<T> {
    /// Opens the underlying TCP connection for a `Connect<T>` request.
    connector: Connector<T>,
    /// Client TLS configuration, shared (via `Arc`) between clones.
    config: Arc<ClientConfig>,
}

impl<T> RustlsConnector<T> {
    /// Creates a connector that performs TLS handshakes with `config`.
    ///
    /// The TCP layer is a `Connector::default()`.
    pub fn new(config: Arc<ClientConfig>) -> Self {
        let connector = Connector::default();
        Self { connector, config }
    }
}

impl<T: Address + 'static> RustlsConnector<T> {
    /// Resolve and connect to remote host, then perform a TLS handshake.
    ///
    /// The name used for the handshake is `req.host()` truncated at the
    /// first `:`, so a `host:port` value yields just `host`.
    ///
    /// # Errors
    ///
    /// - errors from the underlying TCP connector are propagated;
    /// - if the host is not a valid ASCII DNS name, an `io::ErrorKind::Other`
    ///   error carrying the validation message is returned;
    /// - if the TLS handshake fails, its `io::Error` is returned as-is
    ///   (converted into `ConnectError`), preserving its `ErrorKind`.
    pub fn connect<U>(
        &self,
        message: U,
    ) -> impl Future<Output = Result<TlsStream<TcpStream>, ConnectError>>
    where
        Connect<T>: From<U>,
    {
        let req = Connect::from(message);
        // `split` always yields at least one item; the default is never used
        // in practice, it just avoids a panic path.
        let host = req.host().split(':').next().unwrap_or_default().to_owned();
        let conn = self.connector.call(req);
        let config = self.config.clone();

        async move {
            let io = conn.await?;
            trace!("SSL Handshake start for: {:?}", host);

            let host = DNSNameRef::try_from_ascii_str(&host)
                .map_err(|e| io::Error::new(io::ErrorKind::Other, e.to_string()))?;

            match TlsConnector::from(config).connect(host, io).await {
                Ok(io) => {
                    trace!("SSL Handshake success: {:?}", host);
                    Ok(io)
                }
                Err(e) => {
                    trace!("SSL Handshake error: {:?}", e);
                    // The handshake already yields an `io::Error`. Propagate it
                    // directly instead of re-wrapping it as `ErrorKind::Other`,
                    // so callers keep the original kind and source.
                    Err(e.into())
                }
            }
        }
    }
}

impl<T> Clone for RustlsConnector<T> {
    fn clone(&self) -> Self {
        Self {
            config: self.config.clone(),
            connector: self.connector.clone(),
        }
    }
}

impl<T: Address + 'static> ServiceFactory for RustlsConnector<T> {
    type Config = ();
    type Request = Connect<T>;
    type Response = TlsStream<TcpStream>;
    type Error = ConnectError;
    type InitError = ();
    type Service = RustlsConnector<T>;
    type Future = Ready<Self::Service, Self::InitError>;

    /// Builds a service by cloning this factory; construction never fails.
    fn new_service(&self, _: ()) -> Self::Future {
        let service = self.clone();
        Ready::Ok(service)
    }
}

impl<T: Address + 'static> Service for RustlsConnector<T> {
    type Request = Connect<T>;
    type Response = TlsStream<TcpStream>;
    type Error = ConnectError;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    /// Always reports readiness; this service applies no back-pressure.
    #[inline]
    fn poll_ready(&self, _cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    /// Returns a boxed future that connects to `req` and performs the TLS handshake.
    fn call(&self, req: Connect<T>) -> Self::Future {
        let handshake = self.connect(req);
        Box::pin(handshake)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::service::{Service, ServiceFactory};

    /// Connecting through TLS to a test server that does not speak TLS must fail.
    #[crate::rt_test]
    async fn test_rustls_connect() {
        // The server's service ignores its input and returns immediately,
        // so no TLS handshake can complete against it.
        let server = crate::server::test_server(|| {
            crate::service::fn_service(|_| async { Ok::<_, ()>(()) })
        });

        let mut client_config = ClientConfig::new();
        client_config
            .root_store
            .add_server_trust_anchors(&webpki_roots::TLS_SERVER_ROOTS);

        // The `.clone()` also exercises the manual `Clone` impl.
        let factory = RustlsConnector::new(Arc::new(client_config)).clone();
        let service = factory.new_service(()).await.unwrap();

        let request = Connect::new("www.rust-lang.org").set_addr(Some(server.addr()));
        let outcome = service.call(request).await;
        assert!(outcome.is_err());
    }
}