use std::{
collections::BTreeMap,
fmt,
sync::Arc,
time::{Duration, Instant},
};
use anyhow::{anyhow, Context, Error};
use arc_swap::ArcSwapOption;
use chrono::{DateTime, FixedOffset, TimeDelta, Utc};
use rasn_ocsp::CertStatus;
use rustls::{
pki_types::CertificateDer,
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
};
use sha1::{Digest, Sha1};
use tokio::sync::mpsc;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tracing::warn;
use x509_parser::prelude::*;
use super::{client::Client, OcspValidity};
/// Certificates tracked by the stapler, keyed by the SHA-1 fingerprint of
/// their leaf (first-in-chain) certificate.
type Storage = BTreeMap<Fingerprint, Cert>;
/// SHA-1 digest of a DER-encoded certificate, used as the map key in [`Storage`].
///
/// `Debug` is derived so fingerprints can be logged and compared in assertions.
#[derive(Clone, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct Fingerprint([u8; 20]);
impl From<&CertificateDer<'_>> for Fingerprint {
fn from(v: &CertificateDer) -> Self {
let digest = Sha1::digest(v.as_ref());
Self(digest.into())
}
}
/// A certificate tracked by the stapler, together with its stapling state.
#[derive(Clone)]
struct Cert {
    // The key served to clients. After a successful OCSP query this is a copy
    // of the original `CertifiedKey` with its `ocsp` field populated.
    ckey: Arc<CertifiedKey>,
    // NotAfter of the end-entity certificate; the entry is dropped from the
    // map once this time has passed.
    cert_validity: DateTime<FixedOffset>,
    // Validity of the currently held OCSP response, or `None` when there is no
    // usable response (the resolver then serves the un-stapled key).
    ocsp_validity: Option<OcspValidity>,
}
/// A [`ResolvesServerCert`] wrapper that staples OCSP responses onto the
/// certificates chosen by an inner resolver.
///
/// OCSP responses are fetched by a background actor task. The resolve path
/// only reads the snapshot the actor publishes and enqueues newly seen
/// certificates with a non-blocking `try_send`.
pub struct Stapler {
    // Queue of newly seen certificates for the background actor to start tracking.
    tx: mpsc::Sender<(Fingerprint, Arc<CertifiedKey>)>,
    // Snapshot of tracked certificates, replaced by the actor after each refresh.
    storage: Arc<ArcSwapOption<Storage>>,
    // The wrapped resolver that picks the certificate for each handshake.
    inner: Arc<dyn ResolvesServerCert>,
    // Tracks the background actor task so `stop` can wait for it to finish.
    tracker: TaskTracker,
    // Cancellation signal for the background actor.
    token: CancellationToken,
}
impl Stapler {
pub fn new_with_client(inner: Arc<dyn ResolvesServerCert>, client: Client) -> Self {
let (tx, rx) = mpsc::channel(1024);
let storage = Arc::new(ArcSwapOption::empty());
let tracker = TaskTracker::new();
let token = CancellationToken::new();
let mut actor = StaplerActor {
client,
storage: BTreeMap::new(),
rx,
published: storage.clone(),
};
let actor_token = token.clone();
tracker.spawn(async move {
actor.run(actor_token).await;
});
Self {
tx,
storage,
inner,
tracker,
token,
}
}
pub fn new(inner: Arc<dyn ResolvesServerCert>) -> Self {
Self::new_with_client(inner, Client::new())
}
pub async fn stop(&self) {
self.token.cancel();
self.tracker.close();
self.tracker.wait().await;
}
}
impl fmt::Debug for Stapler {
    /// Always renders as the fixed text `OcspStapler`; internal state is not shown.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("OcspStapler")
    }
}
impl ResolvesServerCert for Stapler {
    /// Resolves a certificate through the wrapped resolver and, when a valid
    /// OCSP response for it is held, returns the variant with it stapled.
    ///
    /// Chains with fewer than two certificates (no issuer to build an OCSP
    /// request from) are returned untouched. Certificates not yet known to the
    /// background actor are queued for tracking with `try_send`; if the queue
    /// is full the request is dropped and retried on a later handshake.
    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
        let ckey = self.inner.resolve(client_hello)?;
        if ckey.cert.len() < 2 {
            return Some(ckey);
        }

        let fp = Fingerprint::from(&ckey.cert[0]);

        // This runs on every handshake and only needs a short-lived read, so use
        // `load()` (a cheap temporary guard) rather than `load_full()`, which
        // increments and decrements the `Arc` refcount on each call.
        if let Some(map) = &*self.storage.load() {
            if let Some(entry) = map.get(&fp) {
                // Known certificate: staple only while a valid response is held.
                return Some(if entry.ocsp_validity.is_some() {
                    Arc::clone(&entry.ckey)
                } else {
                    ckey
                });
            }
        }

        // Not tracked yet: ask the actor to start tracking it. Never block the
        // handshake; a full queue simply drops this request.
        let _ = self.tx.try_send((fp, Arc::clone(&ckey)));
        Some(ckey)
    }
}
/// Background task that owns the authoritative certificate map, fetches OCSP
/// responses, and publishes snapshots of the map for [`Stapler`] to read.
struct StaplerActor {
    // OCSP client used to query responders.
    client: Client,
    // Authoritative map of tracked certificates, owned exclusively by the actor.
    storage: Storage,
    // Certificates seen by the resolver that are not yet in the published snapshot.
    rx: mpsc::Receiver<(Fingerprint, Arc<CertifiedKey>)>,
    // Shared slot that each refresh stores a clone of `storage` into.
    published: Arc<ArcSwapOption<Storage>>,
}
impl StaplerActor {
async fn refresh(&mut self) {
if self.storage.is_empty() {
return;
}
let now: DateTime<FixedOffset> = Utc::now().into();
self.storage.retain(|_, v| v.cert_validity > now);
let start = Instant::now();
for v in self.storage.values_mut() {
if let Some(x) = &v.ocsp_validity {
if !x.time_to_update(now) {
continue;
}
if x.next_update - now < TimeDelta::hours(1) {
v.ocsp_validity = None
}
}
let cert = v.ckey.cert[0].as_ref();
let issuer = v.ckey.cert[1].as_ref();
let resp = match self.client.query(cert, issuer).await {
Err(e) => {
warn!("OCSP-Stapler: unable to perform OCSP request: {e:#}");
continue;
}
Ok(v) => v,
};
if let CertStatus::Revoked(x) = resp.cert_status {
warn!("OCSP-Stapler: certificate was revoked: {x:?}");
}
let mut ckey = v.ckey.as_ref().clone();
ckey.ocsp = Some(resp.raw);
v.ckey = Arc::new(ckey);
v.ocsp_validity = Some(resp.ocsp_validity);
}
let new = Arc::new(self.storage.clone());
self.published.store(Some(new));
warn!(
"OCSP-Stapler: certificates refreshed in {}ms",
start.elapsed().as_millis()
);
}
async fn process_certificate(
&mut self,
fp: Fingerprint,
ckey: Arc<CertifiedKey>,
) -> Result<(), Error> {
if self.storage.contains_key(&fp) {
return Ok(());
}
let cert = X509Certificate::from_der(ckey.end_entity_cert().unwrap())
.context("unable to parse certificate as X.509")?
.1;
let cert_validity = DateTime::from_timestamp(cert.validity.not_after.timestamp(), 0)
.ok_or_else(|| anyhow!("unable to parse NotAfter"))?
.into();
let cert = Cert {
ckey: ckey.clone(),
cert_validity,
ocsp_validity: None,
};
self.storage.insert(fp, cert);
self.refresh().await;
Ok(())
}
async fn run(&mut self, token: CancellationToken) {
let mut interval = tokio::time::interval(Duration::from_secs(60));
loop {
tokio::select! {
biased;
() = token.cancelled() => {
warn!("OCSP-Stapler: exiting");
return;
}
_ = interval.tick() => {
self.refresh().await;
},
msg = self.rx.recv() => {
if let Some((fp, ckey)) = msg {
if let Err(e) = self.process_certificate(fp, ckey).await {
warn!("OCSP-Stapler: unable to process certificate: {e:#}");
}
}
}
}
}
}
}