pub mod dir;
pub mod file;
pub mod storage;
pub use dir::Provider as Dir;
pub use file::Provider as File;
use ic_bn_lib_common::{
traits::{
Healthy, Run,
tls::{ProvidesCertificates, StoresCertificates},
},
types::tls::{CertKey, Pem},
};
use std::sync::{Arc, Mutex};
use anyhow::{Context, Error, anyhow};
use async_trait::async_trait;
use rustls::sign::CertifiedKey;
use tokio_util::sync::CancellationToken;
use tracing::{debug, warn};
use crate::tls::{extract_sans_der, extract_validity_der, pem_convert_to_rustls_single};
/// Parses a PEM bundle (certificate chain plus private key) into a [`CertKey`].
///
/// The SubjectAlternativeName entries and the `notAfter` bound are taken from the first
/// certificate of the chain.
///
/// # Errors
///
/// Returns an error if:
/// - the PEM cannot be converted into a rustls certified key;
/// - the resulting chain contains no certificates;
/// - the SANs or the validity period cannot be extracted;
/// - the first certificate carries no supported SAN entries.
pub fn pem_convert_to_certkey(pem: &[u8]) -> Result<CertKey, Error> {
    let cert_key = pem_convert_to_rustls_single(pem)
        .context("unable to convert certificate chain and/or private key from PEM")?;

    // Checked lookup instead of `cert[0]`: an empty chain becomes an error, not a panic.
    let leaf = cert_key.cert.first().context("certificate chain is empty")?;

    let san = extract_sans_der(leaf).context("unable to extract SANs")?;
    if san.is_empty() {
        return Err(anyhow!(
            "no supported names found in SubjectAlternativeName extension"
        ));
    }

    let (_, not_after) = extract_validity_der(leaf).context("unable to extract validity")?;

    Ok(CertKey {
        san,
        not_after,
        cert: Arc::new(cert_key),
    })
}
/// Per-provider state remembered by the aggregator between refreshes.
///
/// Both vectors are indexed by provider position. A slot is `None` until that provider
/// has returned certificates that were fetched and parsed successfully.
#[derive(Clone, Debug)]
struct AggregatorSnapshot {
    /// Raw PEM blobs returned by each provider; the basis for change detection.
    pem: Vec<Option<Vec<Pem>>>,
    /// Parsed certificates for each provider, produced from the same fetch as `pem`.
    parsed: Vec<Option<Vec<CertKey>>>,
}

impl AggregatorSnapshot {
    /// Returns every parsed certificate from every provider as one list, in provider order.
    /// Providers without data yet contribute nothing.
    fn flatten(&self) -> Vec<CertKey> {
        // Walk the nested structure by reference and clone only the `CertKey`s, instead of
        // cloning the whole `Vec<Option<Vec<_>>>` up front and then consuming the copy.
        self.parsed.iter().flatten().flatten().cloned().collect()
    }
}

// Snapshots compare equal when their raw PEM data is equal. `parsed` is set together with
// `pem` from the same fetch result, so comparing it as well would be redundant.
impl PartialEq for AggregatorSnapshot {
    fn eq(&self, other: &Self) -> bool {
        self.pem == other.pem
    }
}

impl Eq for AggregatorSnapshot {}
/// Gathers certificates from several providers, parses them and hands the combined set
/// to a certificate storage whenever the providers' raw data changes.
pub struct Aggregator {
    /// Certificate sources, queried in order on each refresh.
    providers: Vec<Arc<dyn ProvidesCertificates>>,
    /// Destination that receives the combined certificate list.
    storage: Arc<dyn StoresCertificates<Arc<CertifiedKey>>>,
    /// Last committed per-provider state, guarded for access through `&self`.
    snapshot: Mutex<AggregatorSnapshot>,
}

impl std::fmt::Debug for Aggregator {
    /// Emits a fixed label, which the aggregator's log messages use as their prefix.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CertificateAggregator")
    }
}
/// Converts each PEM blob into a [`CertKey`], in input order.
///
/// Stops at the first blob that fails to parse and returns that error.
fn parse_pem(pem: &[Pem]) -> Result<Vec<CertKey>, Error> {
    let mut certs = Vec::with_capacity(pem.len());
    for blob in pem {
        certs.push(pem_convert_to_certkey(blob)?);
    }
    Ok(certs)
}
impl Aggregator {
    /// Creates an aggregator over `providers` that publishes into `storage`.
    ///
    /// The initial snapshot holds no data for any provider, so [`Self::is_initialized`]
    /// returns `false` until every provider has been committed once.
    pub fn new(
        providers: Vec<Arc<dyn ProvidesCertificates>>,
        storage: Arc<dyn StoresCertificates<Arc<CertifiedKey>>>,
    ) -> Self {
        let snapshot = AggregatorSnapshot {
            pem: vec![None; providers.len()],
            parsed: vec![None; providers.len()],
        };

        Self {
            providers,
            storage,
            snapshot: Mutex::new(snapshot),
        }
    }

    /// Returns `true` when the committed snapshot holds a parsed certificate set for every
    /// provider.
    pub fn is_initialized(&self) -> bool {
        self.snapshot
            .lock()
            .unwrap()
            .parsed
            .iter()
            .all(|x| x.is_some())
    }

    /// Queries every provider sequentially and returns `snapshot` updated with the results.
    ///
    /// A provider's slot is replaced only when its certificates are both fetched and parsed
    /// successfully; the parsed list is then sorted by SAN. On any failure a warning is
    /// logged and the slot keeps whatever it held before.
    async fn fetch(&self, mut snapshot: AggregatorSnapshot) -> AggregatorSnapshot {
        for (i, p) in self.providers.iter().enumerate() {
            match p.get_certificates().await {
                Ok(pem) => match parse_pem(&pem) {
                    Ok(mut parsed) => {
                        parsed.sort_by(|a, b| a.san.cmp(&b.san));
                        snapshot.pem[i] = Some(pem);
                        snapshot.parsed[i] = Some(parsed);
                    }
                    Err(e) => warn!(
                        "{self:?}: failed to parse certificates from provider {p:?}: {e:#}"
                    ),
                },
                Err(e) => warn!("{self:?}: failed to fetch from provider {p:?}: {e:#}"),
            }
        }

        snapshot
    }

    /// Fetches from all providers and, if the raw PEM data differs from the committed
    /// snapshot, pushes the combined certificate list to storage.
    ///
    /// The new snapshot is committed only after `store` succeeds. If storing fails, the old
    /// snapshot is kept, so the next refresh sees the difference again and retries. It
    /// does not treat the unstored data as already published.
    #[allow(clippy::significant_drop_tightening)]
    async fn refresh(&self) {
        // The guard is a temporary dropped at the end of this statement, so the mutex is
        // not held across the `.await` in `fetch`.
        let snapshot_old = self.snapshot.lock().unwrap().clone();
        let snapshot = self.fetch(snapshot_old.clone()).await;

        if snapshot == snapshot_old {
            debug!("{self:?}: certs haven't changed, not updating");
            return;
        }

        let certs = snapshot.flatten();
        warn!(
            "{self:?}: publishing new snapshot with {} certs",
            certs.len()
        );
        debug!("{self:?}: {} certs fetched:", certs.len());
        for v in &certs {
            debug!("{self:?}: {:?}", v.san);
        }

        if let Err(e) = self.storage.store(certs) {
            warn!("{self:?}: error storing certificates: {e:#}");
            // Leave the committed snapshot untouched so this change is retried next time.
            return;
        }

        *self.snapshot.lock().unwrap() = snapshot;
    }
}
// Health mirrors initialization: the aggregator counts as healthy once its committed
// snapshot has a parsed certificate set for every provider.
impl Healthy for Aggregator {
    fn healthy(&self) -> bool {
        Self::is_initialized(self)
    }
}
#[async_trait]
impl Run for Aggregator {
async fn run(&self, _: CancellationToken) -> Result<(), Error> {
self.refresh().await;
Ok(())
}
}
#[cfg(test)]
pub mod test {
// Unit tests for PEM parsing and for the aggregator's refresh/health behaviour.
use std::sync::atomic::{AtomicUsize, Ordering};
use prometheus::Registry;
use crate::tests::{TEST_CERT_1, TEST_CERT_2, TEST_KEY_1, TEST_KEY_2};
use super::*;
/// Provider that returns its PEM blob on the first two calls and an error on every call
/// after that. The counter (second field) starts at 0 and counts the successful calls.
#[derive(Debug)]
struct TestProvider(Pem, AtomicUsize);
#[async_trait]
impl ProvidesCertificates for TestProvider {
async fn get_certificates(&self) -> Result<Vec<Pem>, Error> {
// Counter values 0 and 1 succeed (and increment); 2 and above fail.
if self.1.load(Ordering::SeqCst) <= 1 {
self.1.fetch_add(1, Ordering::SeqCst);
Ok(vec![self.0.clone()])
} else {
Err(anyhow!("foo"))
}
}
}
/// Provider whose fetch always fails.
#[derive(Debug)]
struct TestProviderBroken;
#[async_trait]
impl ProvidesCertificates for TestProviderBroken {
async fn get_certificates(&self) -> Result<Vec<Pem>, Error> {
Err(anyhow!("I'm dead"))
}
}
/// Each test bundle (private key PEM followed by certificate PEM) parses and yields the
/// expected SAN.
#[test]
fn test_pem_convert_to_certkey() -> Result<(), Error> {
let cert = pem_convert_to_certkey([TEST_KEY_1, TEST_CERT_1].concat().as_bytes())?;
assert_eq!(cert.san, vec!["novg"]);
let cert = pem_convert_to_certkey([TEST_KEY_2, TEST_CERT_2].concat().as_bytes())?;
assert_eq!(cert.san, vec!["devenv-igornovg"]);
Ok(())
}
/// Covers health reporting and the retention of previously fetched certificates.
/// The sequence depends on each `TestProvider` succeeding exactly twice.
#[tokio::test]
async fn test_aggregator() -> Result<(), Error> {
let prov1 = Arc::new(TestProvider(
Pem([TEST_KEY_1.as_bytes(), TEST_CERT_1.as_bytes()]
.concat()
.to_vec()),
AtomicUsize::new(0),
));
let prov2 = Arc::new(TestProvider(
Pem([TEST_KEY_2.as_bytes(), TEST_CERT_2.as_bytes()]
.concat()
.to_vec()),
AtomicUsize::new(0),
));
let storage = Arc::new(storage::StorageKey::new(
None,
storage::Metrics::new(&Registry::new()),
));
// Phase 1: both providers succeed (their 1st call), so every slot is filled -> healthy.
let aggregator = Aggregator::new(vec![prov1.clone(), prov2.clone()], storage.clone());
aggregator.refresh().await;
assert!(aggregator.healthy());
// Phase 2: a fresh aggregator reuses the providers (their 2nd call, still succeeding)
// and adds a broken one. One slot stays empty -> not healthy, but the two certs are kept.
let aggregator = Aggregator::new(vec![prov1, prov2, Arc::new(TestProviderBroken)], storage);
aggregator.refresh().await;
assert!(!aggregator.healthy());
// Certificates appear in provider order.
let certs = aggregator.snapshot.lock().unwrap().clone().flatten();
assert_eq!(certs.len(), 2);
assert_eq!(certs[0].san, vec!["novg"]);
assert_eq!(certs[1].san, vec!["devenv-igornovg"]);
// Phase 3: every provider now fails (3rd call), so the previously fetched certs must be
// retained unchanged.
aggregator.refresh().await;
let certs = aggregator.snapshot.lock().unwrap().clone().flatten();
assert_eq!(certs.len(), 2);
assert_eq!(certs[0].san, vec!["novg"]);
assert_eq!(certs[1].san, vec!["devenv-igornovg"]);
Ok(())
}
}