use std::sync::{Arc, RwLock};
use derive_where::derive_where;
use ignore_result::Ignore;
use rustls::pki_types::{CertificateDer, CertificateRevocationListDer, PrivateKeyDer};
use rustls::RootCertStore;
use super::TlsClient;
use crate::client::Result;
use crate::Error;
// Short local alias for the item enum yielded by `rustls_pemfile::read_all`.
type PemItem = rustls_pemfile::Item;
/// Certificate authority material: trust anchors plus certificate revocation lists.
///
/// Usually built from PEM text with `TlsCa::from_pem`.
#[derive(Clone, Debug)]
pub struct TlsCa {
/// Store of root certificates accepted as trust anchors.
pub(super) roots: RootCertStore,
/// Certificate revocation lists found alongside the roots.
pub(super) crls: Vec<CertificateRevocationListDer<'static>>,
}
impl TlsCa {
pub fn from_pem(pem: &str) -> Result<Self> {
let mut ca = Self { roots: RootCertStore::empty(), crls: Vec::new() };
for r in rustls_pemfile::read_all(&mut pem.as_bytes()) {
match r {
Ok(PemItem::X509Certificate(cert)) => ca.roots.add(cert).ignore(),
Ok(PemItem::Crl(crl)) => ca.crls.push(crl),
Ok(_) => continue,
Err(err) => return Err(Error::with_other("fail to read ca", err)),
}
}
if ca.roots.is_empty() {
return Err(Error::BadArguments(&"no valid tls trust anchor in pem"));
}
Ok(ca)
}
fn merge(&mut self, ca: TlsCa) {
self.roots.roots.extend(ca.roots.roots);
self.crls.extend(ca.crls);
}
}
/// Client identity: a certificate chain together with its private key.
///
/// `Debug` is derived through `derive_where` so the private key can be left out of the output.
#[derive_where(Debug)]
pub struct TlsIdentity {
/// Certificates read from the identity PEM, in the order they appeared.
pub(super) cert: Vec<CertificateDer<'static>>,
/// Private key for the identity; skipped by the derived `Debug`.
#[derive_where(skip)]
pub(super) key: PrivateKeyDer<'static>,
}
impl TlsIdentity {
pub fn from_pem(cert: &str, key: &str) -> Result<Self> {
let r: Result<Vec<_>, _> = rustls_pemfile::certs(&mut cert.as_bytes()).collect();
let cert = match r {
Err(err) => return Err(Error::with_other("fail to read cert", err)),
Ok(cert) => cert,
};
let key = match rustls_pemfile::private_key(&mut key.as_bytes()) {
Err(err) => return Err(Error::with_other("fail to read client private key", err)),
Ok(None) => return Err(Error::BadArguments(&"no client private key")),
Ok(Some(key)) => key,
};
Ok(Self { cert, key })
}
}
impl Clone for TlsIdentity {
    /// Duplicates the certificate chain and the private key.
    ///
    /// The key is copied through `clone_key`, which is presumably why this impl
    /// is written by hand instead of derived.
    fn clone(&self) -> Self {
        let Self { cert, key } = self;
        Self { cert: cert.clone(), key: key.clone_key() }
    }
}
/// Complete certificate set for a TLS client: a mandatory CA and an optional client identity.
///
/// Construct it through `TlsCerts::builder`.
#[derive(Clone, Debug)]
pub struct TlsCerts {
/// Trust anchors and revocation lists.
pub(super) ca: TlsCa,
/// Client certificate chain and key, if any.
pub(super) identity: Option<TlsIdentity>,
}
impl TlsCerts {
    /// Starts a [`TlsCertsBuilder`] with neither CA nor identity set.
    pub fn builder() -> TlsCertsBuilder {
        TlsCertsBuilder::new()
    }
}
/// Builder for [`TlsCerts`]; the CA is required at `build` time, the identity is optional.
#[derive(Clone, Debug)]
pub struct TlsCertsBuilder {
/// CA to use; `build` fails while this is `None`.
ca: Option<TlsCa>,
/// Optional client identity passed through to the built [`TlsCerts`].
identity: Option<TlsIdentity>,
}
impl TlsCertsBuilder {
fn new() -> Self {
Self { ca: None, identity: None }
}
pub fn with_ca(mut self, ca: TlsCa) -> Self {
self.ca = Some(ca);
self
}
pub fn with_identity(mut self, identity: TlsIdentity) -> Self {
self.identity = Some(identity);
self
}
pub fn build(self) -> Result<TlsCerts> {
let ca = match self.ca {
None => return Err(Error::BadArguments(&"no tls ca")),
Some(ca) => ca,
};
Ok(TlsCerts { ca, identity: self.identity })
}
}
/// Public wrapper around the certificate source handed to `TlsOptions::with_certs`.
///
/// Obtainable via `From` from either [`TlsCerts`] (fixed) or [`TlsDynamicCerts`] (updatable).
#[derive(Clone, Debug)]
pub struct TlsCertsOptions {
/// The wrapped static-or-dynamic certificate source.
certs: TlsInnerCerts,
}
/// Certificate source: either a fixed set or a handle whose contents can be replaced later.
#[derive(Clone, Debug)]
pub(super) enum TlsInnerCerts {
/// Fixed certificates.
Static(TlsCerts),
/// Shared, versioned certificates that can be updated through [`TlsDynamicCerts`].
Dynamic(TlsDynamicCerts),
}
impl From<TlsCertsOptions> for TlsInnerCerts {
fn from(options: TlsCertsOptions) -> Self {
options.certs
}
}
impl From<TlsInnerCerts> for TlsCertsOptions {
fn from(certs: TlsInnerCerts) -> Self {
Self { certs }
}
}
impl From<TlsCerts> for TlsCertsOptions {
    /// Wraps fixed certificates as a static certificate source.
    fn from(certs: TlsCerts) -> Self {
        Self { certs: TlsInnerCerts::Static(certs) }
    }
}
impl From<TlsDynamicCerts> for TlsCertsOptions {
    /// Wraps an updatable certificate handle as a dynamic certificate source.
    fn from(certs: TlsDynamicCerts) -> Self {
        Self { certs: TlsInnerCerts::Dynamic(certs) }
    }
}
/// Shared handle to a versioned [`TlsCerts`] snapshot that can be replaced at runtime.
///
/// Clones share the same underlying slot, so an update through one handle is
/// visible through all of them.
#[derive(Clone, Debug)]
pub struct TlsDynamicCerts {
/// `(version, snapshot)`: the version starts at 1 and grows by one on each successful update.
certs: Arc<RwLock<(u64, Arc<TlsCerts>)>>,
}
impl TlsDynamicCerts {
pub fn new(certs: TlsCerts) -> Self {
let certs = certs.into();
Self { certs: Arc::new(RwLock::new((1, certs))) }
}
pub fn update(&self, certs: TlsCerts) {
let certs = certs.into();
let mut writer = self.certs.write().unwrap();
writer.0 += 1;
let old = std::mem::replace(&mut writer.1, certs);
drop(writer);
drop(old);
}
pub fn update_ca(&self, ca: TlsCa) {
self.update_partially(|certs| certs.ca = ca.clone())
}
pub fn update_identity(&self, identity: Option<TlsIdentity>) {
self.update_partially(|certs| certs.identity = identity.clone())
}
fn update_versioned(&self, version: u64, certs: TlsCerts) -> bool {
let certs = certs.into();
let mut writer = self.certs.write().unwrap();
if writer.0 != version {
return false;
}
writer.0 += 1;
let old = std::mem::replace(&mut writer.1, certs);
drop(writer);
drop(old);
true
}
fn update_partially(&self, update: impl Fn(&mut TlsCerts)) {
loop {
let (version, certs) = self.get_versioned();
let mut certs = (*certs).clone();
update(&mut certs);
if self.update_versioned(version, certs) {
break;
}
}
}
pub(crate) fn get_versioned(&self) -> (u64, Arc<TlsCerts>) {
self.certs.read().unwrap().clone()
}
pub(crate) fn get_updated(&self, version: u64) -> Option<(u64, Arc<TlsCerts>)> {
let locked = self.certs.read().unwrap();
if version >= locked.0 {
return None;
}
Some(locked.clone())
}
}
/// User-facing TLS configuration, consumed by `into_client` to create a `TlsClient`.
///
/// Certificates can be supplied either piecewise (`ca` / `identity`) or as a
/// whole via `certs`; when `certs` is set it is the one used and `ca` /
/// `identity` are not consulted.
#[derive(Clone, Debug)]
pub struct TlsOptions {
/// CA set through the PEM helpers; used only when `certs` is `None`.
ca: Option<TlsCa>,
/// Client identity set through `with_pem_identity`; used only when `certs` is `None`.
identity: Option<TlsIdentity>,
/// Complete certificate source set through `with_certs`.
certs: Option<TlsCertsOptions>,
/// Hostname verification flag; `true` unless `with_no_hostname_verification` was called.
hostname_verification: bool,
/// FIPS flag set by `with_fips`; the field exists only with feature `fips` and without `fips-only`.
#[cfg(all(feature = "fips", not(feature = "fips-only")))]
fips: bool,
}
impl Default for TlsOptions {
fn default() -> Self {
Self::new()
}
}
impl TlsOptions {
#[deprecated(since = "0.10.0", note = "use TlsOptions::new instead")]
pub fn no_ca() -> Self {
Self::new()
}
pub fn new() -> Self {
Self {
ca: None,
identity: None,
certs: None,
hostname_verification: true,
#[cfg(all(feature = "fips", not(feature = "fips-only")))]
fips: false,
}
}
pub unsafe fn with_no_hostname_verification(mut self) -> Self {
self.hostname_verification = false;
self
}
#[cfg(feature = "fips")]
#[cfg_attr(docsrs, doc(cfg(any(feature = "fips", feature = "fips-only"))))]
pub fn with_fips(self) -> Self {
self.with_fips_internal()
}
#[cfg(all(feature = "fips", not(feature = "fips-only")))]
fn with_fips_internal(mut self) -> Self {
self.fips = true;
self
}
#[cfg(feature = "fips-only")]
fn with_fips_internal(self) -> Self {
self
}
#[deprecated(since = "0.10.0", note = "use TlsOptions::with_pem_ca instead")]
pub fn with_pem_ca_certs(mut self, certs: &str) -> Result<Self> {
let mut ca = TlsCa::from_pem(certs)?;
ca.crls.clear();
match self.ca.as_mut() {
None => self.ca = Some(ca),
Some(existing_ca) => existing_ca.merge(ca),
};
Ok(self)
}
pub fn with_pem_ca(mut self, ca: &str) -> Result<Self> {
self.ca = Some(TlsCa::from_pem(ca)?);
Ok(self)
}
pub fn with_pem_identity(mut self, cert: &str, key: &str) -> Result<Self> {
self.identity = Some(TlsIdentity::from_pem(cert, key)?);
Ok(self)
}
pub fn with_certs(mut self, certs: impl Into<TlsCertsOptions>) -> Self {
self.certs = Some(certs.into());
self
}
pub(super) fn into_client_options(self) -> Result<TlsClientOptions<TlsInnerCerts>> {
let certs = match self.certs.map(TlsInnerCerts::from) {
None => {
let certs = TlsCertsBuilder { ca: self.ca, identity: self.identity }.build()?;
TlsInnerCerts::Static(certs)
},
Some(certs) => certs,
};
Ok(TlsClientOptions {
certs,
hostname_verification: self.hostname_verification,
#[cfg(all(feature = "fips", not(feature = "fips-only")))]
fips: self.fips,
})
}
pub(crate) fn into_client(self) -> Result<TlsClient> {
let options = self.into_client_options()?;
TlsClient::new(options)
}
}
/// Resolved TLS settings handed to `TlsClient::new`, generic over the certificate representation.
pub(super) struct TlsClientOptions<Certs> {
/// Certificate source (`TlsInnerCerts` when produced by `TlsOptions::into_client_options`).
pub certs: Certs,
/// Copied from `TlsOptions`; `false` only after `with_no_hostname_verification`.
pub hostname_verification: bool,
/// Copied from `TlsOptions`; the field exists only with feature `fips` and without `fips-only`.
#[cfg(all(feature = "fips", not(feature = "fips-only")))]
pub fips: bool,
}