use std::path::PathBuf;
use std::sync::Arc;
use rustls::crypto::CryptoProvider;
use rustls_pki_types::CertificateDer;
use rustls_pki_types::ServerName;
use rustls_platform_verifier::BuilderVerifierExt;
use tokio::sync::watch::Receiver;
use tokio_rustls::TlsConnector;
use tokio_rustls::TlsStream as RustlsStream;
use crate::attributes::Attributes;
use crate::credentials::ChannelCredentials;
use crate::credentials::ProtocolInfo;
use crate::credentials::SecurityLevel;
use crate::credentials::call::CallCredentials;
use crate::credentials::client::ClientConnectionSecurityContext;
use crate::credentials::client::ClientConnectionSecurityInfo;
use crate::credentials::client::ClientHandshakeInfo;
use crate::credentials::client::HandshakeOutput;
use crate::credentials::common::Authority;
use crate::credentials::rustls::ALPN_PROTO_STR_H2;
use crate::credentials::rustls::Identity;
use crate::credentials::rustls::Provider;
use crate::credentials::rustls::RootCertificates;
use crate::credentials::rustls::TLS_PROTO_INFO;
use crate::credentials::rustls::key_log::KeyLogFile;
use crate::credentials::rustls::parse_certs;
use crate::credentials::rustls::parse_key;
use crate::credentials::rustls::sanitize_crypto_provider;
use crate::credentials::rustls::tls_stream::TlsStream;
use crate::private;
use crate::rt::AsyncIoAdapter;
use crate::rt::GrpcEndpoint;
use crate::rt::GrpcRuntime;
#[cfg(test)]
mod test;
/// Builder-style configuration consumed by `RustlsChannelCredendials::new`.
///
/// Every field starts out as `None` (see `ClientTlsConfig::new`) and is filled
/// in through the `with_*` / `insecure_with_key_log_path` methods.
pub struct ClientTlsConfig {
// Source of the PEM root certificates to trust. When `None`, credentials
// construction falls back to the platform verifier.
pem_roots_provider: Option<Receiver<RootCertificates>>,
// Source of the client identity (certificate chain + key). When `None`, the
// client config is built without client authentication.
identity_provider: Option<Receiver<Identity>>,
// Path handed to `KeyLogFile::new` during credentials construction. When
// `None`, the client config's `key_log` is left untouched.
key_log_path: Option<PathBuf>,
}
impl ClientTlsConfig {
    /// Creates a configuration with no root-certificate provider, no identity
    /// provider and no key-log path set.
    pub fn new() -> Self {
        Self {
            key_log_path: None,
            identity_provider: None,
            pem_roots_provider: None,
        }
    }

    /// Registers `provider` as the source of trusted root certificates.
    ///
    /// The provider is converted into its receiver immediately; that receiver
    /// replaces any previously configured one. Returns the updated config so
    /// calls can be chained.
    pub fn with_root_certificates_provider<R: Provider<RootCertificates>>(
        self,
        provider: R,
    ) -> Self {
        Self {
            pem_roots_provider: Some(provider.get_receiver(private::Internal)),
            ..self
        }
    }

    /// Registers `provider` as the source of the client identity.
    ///
    /// The provider is converted into its receiver immediately; that receiver
    /// replaces any previously configured one. Returns the updated config so
    /// calls can be chained.
    pub fn with_identity_provider<I: Provider<Identity>>(self, provider: I) -> Self {
        Self {
            identity_provider: Some(provider.get_receiver(private::Internal)),
            ..self
        }
    }

    /// Records `path` as the key-log path; credentials construction later
    /// passes it to `KeyLogFile::new` and installs the result as the client
    /// config's `key_log`.
    ///
    /// NOTE(review): the `insecure_` prefix presumably signals that the logged
    /// key material can be used to decrypt traffic — confirm against
    /// `KeyLogFile`.
    pub fn insecure_with_key_log_path(self, path: impl Into<PathBuf>) -> Self {
        Self {
            key_log_path: Some(path.into()),
            ..self
        }
    }
}
impl Default for ClientTlsConfig {
fn default() -> Self {
Self::new()
}
}
/// Channel credentials that perform a client-side TLS handshake via
/// `tokio_rustls`.
///
/// Built from a `ClientTlsConfig` by `RustlsChannelCredendials::new`.
#[derive(Clone)]
pub struct RustlsChannelCredendials {
// Connector wrapping the `rustls::ClientConfig` assembled at construction
// time; used by `connect` to run the handshake.
connector: TlsConnector,
}
impl RustlsChannelCredendials {
pub fn new(config: ClientTlsConfig) -> Result<RustlsChannelCredendials, String> {
let provider = if let Some(p) = CryptoProvider::get_default() {
p.as_ref().clone()
} else {
return Err(
"No crypto provider installed. Enable `tls-aws-lc` feature in rustls or install one manually."
.to_string()
);
};
Self::new_impl(config, provider)
}
fn new_impl(
mut config: ClientTlsConfig,
provider: CryptoProvider,
) -> Result<RustlsChannelCredendials, String> {
let provider = sanitize_crypto_provider(provider)?;
let builder = rustls::ClientConfig::builder_with_provider(Arc::new(provider))
.with_protocol_versions(&[&rustls::version::TLS13, &rustls::version::TLS12])
.map_err(|e| e.to_string())?;
let builder = if let Some(mut roots_provider) = config.pem_roots_provider.take() {
let mut root_store = rustls::RootCertStore::empty();
let ca_pem = roots_provider.borrow_and_update();
let certs = parse_certs(ca_pem.get_ref())?;
for cert in certs {
root_store.add(cert).map_err(|e| e.to_string())?;
}
builder.with_root_certificates(root_store)
} else {
builder
.with_platform_verifier()
.map_err(|e| e.to_string())?
};
let mut client_config = if let Some(mut identity_provider) = config.identity_provider.take()
{
let identity = identity_provider.borrow_and_update();
let certs = parse_certs(&identity.certs)?;
let key = parse_key(&identity.key)?;
builder
.with_client_auth_cert(certs, key)
.map_err(|e| e.to_string())?
} else {
builder.with_no_client_auth()
};
client_config.alpn_protocols = vec![ALPN_PROTO_STR_H2.to_vec()];
client_config.resumption = rustls::client::Resumption::disabled();
if let Some(path) = config.key_log_path {
client_config.key_log = Arc::new(KeyLogFile::new(&path))
}
Ok(RustlsChannelCredendials {
connector: TlsConnector::from(Arc::new(client_config)),
})
}
}
/// Per-connection security context produced by a completed TLS handshake.
pub struct ClientTlsSecurityContext {
// First certificate of the chain the peer presented during the handshake,
// or `None` when the connection reported no peer certificates.
verified_peer_cert: Option<CertificateDer<'static>>,
}
impl ClientConnectionSecurityContext for ClientTlsSecurityContext {
    /// Reports whether the stored peer certificate is valid for the host part
    /// of `authority`.
    ///
    /// Returns `false` when the host is not a usable server name, when no peer
    /// certificate was recorded, when the certificate cannot be parsed as an
    /// end-entity certificate, or when the subject-name check fails.
    fn validate_authority(&self, authority: &Authority) -> bool {
        let Ok(expected_name) = ServerName::try_from(authority.host()) else {
            return false;
        };
        let Some(leaf_der) = self.verified_peer_cert.as_ref() else {
            return false;
        };
        webpki::EndEntityCert::try_from(leaf_der).is_ok_and(|leaf| {
            leaf.verify_is_valid_for_subject_name(&expected_name)
                .is_ok()
        })
    }
}
impl ChannelCredentials for RustlsChannelCredendials {
    type ContextType = ClientTlsSecurityContext;
    type Output<I> = TlsStream<I>;

    /// Runs the TLS client handshake over `source` and returns the secured
    /// endpoint together with its security info.
    ///
    /// The host part of `authority` is used as the TLS server name. After the
    /// handshake the negotiated ALPN protocol must be h2; the first peer
    /// certificate (if any) is stored in the returned security context.
    ///
    /// # Errors
    ///
    /// Returns an error string when the authority host is not a valid server
    /// name, when the handshake fails, or when ALPN is missing or not h2.
    async fn connect<Input: GrpcEndpoint>(
        &self,
        authority: &Authority,
        source: Input,
        _info: &ClientHandshakeInfo,
        _rt: &GrpcRuntime,
        _token: private::Internal,
    ) -> Result<HandshakeOutput<TlsStream<Input>, ClientTlsSecurityContext>, String> {
        let sni = match ServerName::try_from(authority.host()) {
            Ok(name) => name.to_owned(),
            Err(e) => return Err(format!("invalid authority: {}", e)),
        };

        let session = self
            .connector
            .connect(sni, AsyncIoAdapter::new(source))
            .await
            .map_err(|e| e.to_string())?;

        let (_, conn) = session.get_ref();

        // Only h2 was offered, so anything else (or nothing) is rejected.
        match conn.alpn_protocol() {
            Some(proto) if proto == ALPN_PROTO_STR_H2 => {}
            Some(_) => return Err("Server negotiated unexpected ALPN protocol".into()),
            None => return Err("Server did not negotiate ALPN (h2 required)".into()),
        }

        // Keep an owned copy of the leaf certificate so it can outlive the
        // borrow of the connection.
        let leaf_cert = match conn.peer_certificates() {
            Some([leaf, ..]) => Some(leaf.clone().into_owned()),
            _ => None,
        };

        let security = ClientConnectionSecurityInfo::new(
            "tls",
            SecurityLevel::PrivacyAndIntegrity,
            ClientTlsSecurityContext {
                verified_peer_cert: leaf_cert,
            },
            Attributes::new(),
        );

        Ok(HandshakeOutput {
            endpoint: TlsStream::new(RustlsStream::Client(session)),
            security,
        })
    }

    /// Returns the shared TLS protocol description.
    fn info(&self) -> &ProtocolInfo {
        &TLS_PROTO_INFO
    }

    /// These channel credentials carry no call credentials; always `None`.
    fn get_call_credentials(&self, _: private::Internal) -> Option<&Arc<dyn CallCredentials>> {
        None
    }
}