ureq 3.0.0-rc5

Simple, safe HTTP client
Documentation
use std::convert::TryInto;
use std::fmt;
use std::io::{Read, Write};
use std::sync::Arc;

use once_cell::sync::OnceCell;
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::crypto::CryptoProvider;
use rustls::{ClientConfig, ClientConnection, RootCertStore, StreamOwned, ALL_VERSIONS};
use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer, PrivatePkcs8KeyDer};
use rustls_pki_types::{PrivateSec1KeyDer, ServerName};

use crate::tls::cert::KeyKind;
use crate::tls::{RootCerts, TlsProvider};
use crate::transport::{Buffers, ConnectionDetails, Connector, LazyBuffers};
use crate::transport::{Either, NextTimeout, Transport, TransportAdapter};
use crate::Error;

use super::TlsConfig;

/// Wrapper for TLS using rustls.
///
/// Requires feature flag **rustls**.
///
/// The rustls [`ClientConfig`] is built lazily, the first time `connect` actually
/// needs to wrap a transport in TLS. It is then reused (cheap `Arc` clone) for every
/// later connection made through this connector.
///
/// NOTE(review): because the config is cached, it reflects the `TlsConfig` seen on
/// that first TLS connection; later changes to the TLS config are not picked up by
/// this connector instance.
#[derive(Default)]
pub struct RustlsConnector {
    /// Lazily-initialized rustls client configuration shared by all connections.
    config: OnceCell<Arc<ClientConfig>>,
}

impl<In: Transport> Connector<In> for RustlsConnector {
    type Out = Either<In, RustlsTransport>;

    fn connect(
        &self,
        details: &ConnectionDetails,
        chained: Option<In>,
    ) -> Result<Option<Self::Out>, Error> {
        let Some(transport) = chained else {
            panic!("RustlConnector requires a chained transport");
        };

        // Only add TLS if we are connecting via HTTPS and the transport isn't TLS
        // already, otherwise use chained transport as is.
        if !details.needs_tls() || transport.is_tls() {
            trace!("Skip");
            return Ok(Some(Either::A(transport)));
        }

        if details.config.tls_config().provider != TlsProvider::Rustls {
            debug!("Skip because config is not set to Rustls");
            return Ok(Some(Either::A(transport)));
        }

        trace!("Try wrap in TLS");

        let tls_config = details.config.tls_config();

        // Initialize the config on first run.
        let config_ref = self.config.get_or_init(|| build_config(tls_config));
        let config = config_ref.clone(); // cheap clone due to Arc

        let name_borrowed: ServerName<'_> = details
            .uri
            .authority()
            .expect("uri authority for tls")
            .host()
            .try_into()
            .map_err(|e| {
                warn!("rustls invalid dns name: {}", e);
                Error::Tls("Rustls invalid dns name error")
            })?;

        let name = name_borrowed.to_owned();

        let conn = ClientConnection::new(config, name)?;
        let stream = StreamOwned {
            conn,
            sock: TransportAdapter::new(transport.boxed()),
        };

        let buffers = LazyBuffers::new(
            details.config.input_buffer_size(),
            details.config.output_buffer_size(),
        );

        let transport = RustlsTransport { buffers, stream };

        debug!("Wrapped TLS");

        Ok(Some(Either::B(transport)))
    }
}

fn build_config(tls_config: &TlsConfig) -> Arc<ClientConfig> {
    // 1. Prefer provider set by TlsConfig.
    // 2. Use process wide default set in rustls library.
    // 3. Pick ring, if it is enabled (the default behavior).
    // 4. Error (never pick up a default from feature flags alone).
    let provider = tls_config
        .rustls_crypto_provider
        .clone()
        .or(rustls::crypto::CryptoProvider::get_default().cloned())
        .unwrap_or(ring_if_enabled());

    #[cfg(feature = "_ring")]
    fn ring_if_enabled() -> Arc<CryptoProvider> {
        Arc::new(rustls::crypto::ring::default_provider())
    }

    #[cfg(not(feature = "_ring"))]
    fn ring_if_enabled() -> Arc<CryptoProvider> {
        panic!("No CryptoProvider for Rustls. Enable the feature `ring` or configure the Agent.");
    }

    let builder = ClientConfig::builder_with_provider(provider.clone())
        .with_protocol_versions(ALL_VERSIONS)
        .expect("all TLS versions");

    let builder = if tls_config.disable_verification {
        debug!("Certificate verification disabled");
        builder
            .dangerous()
            .with_custom_certificate_verifier(Arc::new(DisabledVerifier))
    } else {
        match &tls_config.root_certs {
            RootCerts::Specific(certs) => {
                let root_certs = certs.iter().map(|c| CertificateDer::from(c.der()));

                let mut root_store = RootCertStore::empty();
                let (added, ignored) = root_store.add_parsable_certificates(root_certs);
                debug!("Added {} and ignored {} root certs", added, ignored);

                builder.with_root_certificates(root_store)
            }
            #[cfg(not(feature = "platform-verifier"))]
            RootCerts::PlatformVerifier => {
                panic!("Rustls + PlatformVerifier requires feature: platform-verifier");
            }
            #[cfg(feature = "platform-verifier")]
            RootCerts::PlatformVerifier => builder
                // This actually not dangerous. The rustls_platform_verifier is safe.
                .dangerous()
                .with_custom_certificate_verifier(Arc::new(
                    rustls_platform_verifier::Verifier::new().with_provider(provider),
                )),
            RootCerts::WebPki => {
                let root_store = RootCertStore {
                    roots: webpki_roots::TLS_SERVER_ROOTS.to_vec(),
                };
                builder.with_root_certificates(root_store)
            }
        }
    };

    let mut config = if let Some(certs_and_key) = &tls_config.client_cert {
        let cert_chain = certs_and_key
            .certs()
            .iter()
            .map(|c| CertificateDer::from(c.der()).into_owned());

        let key = certs_and_key.private_key();

        let key_der = match key.kind() {
            KeyKind::Pkcs1 => PrivateKeyDer::Pkcs1(PrivatePkcs1KeyDer::from(key.der())),
            KeyKind::Pkcs8 => PrivateKeyDer::Pkcs8(PrivatePkcs8KeyDer::from(key.der())),
            KeyKind::Sec1 => PrivateKeyDer::Sec1(PrivateSec1KeyDer::from(key.der())),
        }
        .clone_key();
        debug!("Use client certficiate with key kind {:?}", key.kind());

        builder
            .with_client_auth_cert(cert_chain.collect(), key_der)
            .expect("valid client auth certificate")
    } else {
        builder.with_no_client_auth()
    };

    config.enable_sni = tls_config.use_sni;

    if !tls_config.use_sni {
        debug!("Disable SNI");
    }

    Arc::new(config)
}

/// Transport that carries traffic over a rustls TLS session.
///
/// Created by `RustlsConnector::connect` when it wraps a chained transport.
pub struct RustlsTransport {
    /// Plaintext buffers: output is written into `stream`, and input read from
    /// `stream` is appended here.
    buffers: LazyBuffers,
    /// The rustls client session (`conn`) paired with the boxed chained transport
    /// (`sock`) that carries the encrypted bytes.
    stream: StreamOwned<ClientConnection, TransportAdapter>,
}

impl Transport for RustlsTransport {
    fn buffers(&mut self) -> &mut dyn Buffers {
        &mut self.buffers
    }

    fn transmit_output(&mut self, amount: usize, timeout: NextTimeout) -> Result<(), Error> {
        self.stream.get_mut().set_timeout(timeout);

        let output = &self.buffers.output()[..amount];
        self.stream.write_all(output)?;

        Ok(())
    }

    fn await_input(&mut self, timeout: NextTimeout) -> Result<bool, Error> {
        if self.buffers.can_use_input() {
            return Ok(true);
        }

        self.stream.get_mut().set_timeout(timeout);

        let input = self.buffers.input_append_buf();
        let amount = self.stream.read(input)?;
        self.buffers.input_appended(amount);

        Ok(amount > 0)
    }

    fn is_open(&mut self) -> bool {
        self.stream.get_mut().get_mut().is_open()
    }

    fn is_tls(&self) -> bool {
        true
    }
}

#[derive(Debug)]
struct DisabledVerifier;

impl ServerCertVerifier for DisabledVerifier {
    fn verify_server_cert(
        &self,
        _end_entity: &CertificateDer<'_>,
        _intermediates: &[CertificateDer<'_>],
        _server_name: &rustls_pki_types::ServerName<'_>,
        _ocsp_response: &[u8],
        _now: rustls_pki_types::UnixTime,
    ) -> Result<ServerCertVerified, rustls::Error> {
        Ok(ServerCertVerified::assertion())
    }

    fn verify_tls12_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    fn verify_tls13_signature(
        &self,
        _message: &[u8],
        _cert: &CertificateDer<'_>,
        _dss: &rustls::DigitallySignedStruct,
    ) -> Result<HandshakeSignatureValid, rustls::Error> {
        Ok(HandshakeSignatureValid::assertion())
    }

    fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
        vec![
            rustls::SignatureScheme::RSA_PKCS1_SHA1,
            rustls::SignatureScheme::RSA_PKCS1_SHA256,
            rustls::SignatureScheme::RSA_PKCS1_SHA384,
            rustls::SignatureScheme::RSA_PKCS1_SHA512,
            rustls::SignatureScheme::ECDSA_NISTP256_SHA256,
            rustls::SignatureScheme::ECDSA_NISTP384_SHA384,
            rustls::SignatureScheme::ECDSA_NISTP521_SHA512,
            rustls::SignatureScheme::RSA_PSS_SHA256,
            rustls::SignatureScheme::RSA_PSS_SHA384,
            rustls::SignatureScheme::RSA_PSS_SHA512,
            rustls::SignatureScheme::ED25519,
            rustls::SignatureScheme::ED448,
        ]
    }
}

impl fmt::Debug for RustlsConnector {
    /// Print just the type name; the cached rustls config is deliberately omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Equivalent to a field-less `debug_struct(..).finish()`, which emits only the name.
        f.write_str("RustlsConnector")
    }
}

impl fmt::Debug for RustlsTransport {
    /// Show the chained transport underneath the TLS layer; TLS session state is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let chained = self.stream.sock.inner();

        let mut out = f.debug_struct("RustlsTransport");
        out.field("chained", &chained);
        out.finish()
    }
}