use std::{
collections::BTreeMap,
fmt::{self, Display},
sync::Arc,
time::{Duration, Instant},
};
use anyhow::{anyhow, Context, Error};
use arc_swap::ArcSwapOption;
use chrono::{DateTime, FixedOffset, Utc};
use itertools::Itertools;
use prometheus::{
register_histogram_vec_with_registry, register_int_counter_vec_with_registry,
register_int_gauge_vec_with_registry, HistogramVec, IntCounterVec, IntGaugeVec, Registry,
};
use rasn_ocsp::{CertStatus, OcspResponseStatus};
use rustls::{
pki_types::CertificateDer,
server::{ClientHello, ResolvesServerCert},
sign::CertifiedKey,
};
use sha1::{Digest, Sha1};
use tokio::sync::mpsc;
use tokio_util::{sync::CancellationToken, task::TaskTracker};
use tracing::{info, warn};
use x509_parser::prelude::*;
use super::{client::Client, Validity, LEEWAY};
/// Tracked certificate chains, keyed by the fingerprint of their leaf certificate.
type Storage = BTreeMap<Fingerprint, Cert>;

/// SHA-1 digest of a DER-encoded certificate; used purely as a map key.
#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
struct Fingerprint([u8; 20]);

impl From<&CertificateDer<'_>> for Fingerprint {
    /// Hashes the raw DER bytes of `der`.
    fn from(der: &CertificateDer) -> Self {
        Self(Sha1::digest(der.as_ref()).into())
    }
}
/// Outcome of a successful [`refresh_certificate`] call.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RefreshResult {
    /// The current OCSP response has not yet passed half of its validity; no request was sent.
    StillValid,
    /// A new OCSP response was fetched and stored.
    Refreshed,
}
/// A tracked certificate chain together with its latest OCSP state.
#[derive(Clone)]
struct Cert {
    /// Key and chain handed to clients; its `ocsp` field carries the stapled response once fetched.
    ckey: Arc<CertifiedKey>,
    /// Rendered subject of the leaf certificate (used in log messages).
    subject: String,
    /// Status from the last OCSP response; starts as `Unknown` when the certificate is added.
    status: CertStatus,
    /// Validity window of the leaf certificate itself.
    cert_validity: Validity,
    /// Validity window of the stapled OCSP response; `None` means nothing should be stapled.
    ocsp_validity: Option<Validity>,
}

impl Display for Cert {
    /// Displays the certificate as its subject.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.subject)
    }
}
/// Prometheus metrics exported by the stapler.
#[derive(Clone)]
struct Metrics {
    /// `ocsp_resolves_total`, labelled by `stapled`.
    resolves: IntCounterVec,
    /// `ocsp_requests_total`, labelled by `status`.
    ocsp_requests: IntCounterVec,
    /// `ocsp_requests_duration` (observed in seconds), labelled by `status`.
    ocsp_requests_duration: HistogramVec,
    /// `ocsp_certificate_count`, labelled by certificate `status`.
    certificate_count: IntGaugeVec,
}

impl Metrics {
    /// Creates the stapler metrics and registers them with `registry`.
    ///
    /// # Panics
    ///
    /// Panics if a metric cannot be registered, e.g. when metrics with the same names were
    /// already registered in `registry` (such as by another stapler sharing the registry).
    fn new(registry: &Registry) -> Self {
        Self {
            resolves: register_int_counter_vec_with_registry!(
                "ocsp_resolves_total",
                "Counts the number of certificate resolve requests",
                &["stapled"],
                registry
            )
            .expect("unable to register metric ocsp_resolves_total"),
            ocsp_requests: register_int_counter_vec_with_registry!(
                "ocsp_requests_total",
                "Counts the number of OCSP requests",
                &["status"],
                registry
            )
            .expect("unable to register metric ocsp_requests_total"),
            ocsp_requests_duration: register_histogram_vec_with_registry!(
                "ocsp_requests_duration",
                "Observes OCSP requests duration",
                &["status"],
                registry
            )
            .expect("unable to register metric ocsp_requests_duration"),
            certificate_count: register_int_gauge_vec_with_registry!(
                "ocsp_certificate_count",
                "Current number of certificates in storage",
                &["status"],
                registry
            )
            .expect("unable to register metric ocsp_certificate_count"),
        }
    }
}
/// OCSP-stapling wrapper around another [`ResolvesServerCert`].
///
/// Certificates resolved by `inner` are handed to a background actor that fetches OCSP
/// responses; once a valid response is available, the stapled key is returned instead.
pub struct Stapler {
    /// Queue of certificates (keyed by leaf fingerprint) the background actor should track.
    tx: mpsc::Sender<(Fingerprint, Arc<CertifiedKey>)>,
    /// Latest storage snapshot published by the actor; `None` until the first publish.
    storage: Arc<ArcSwapOption<Storage>>,
    /// Resolver that chooses which certificate to serve.
    inner: Arc<dyn ResolvesServerCert>,
    /// Tracks the spawned actor task so `stop` can wait for it.
    tracker: TaskTracker,
    /// Cancellation signal for the actor task.
    token: CancellationToken,
    /// Optional Prometheus metrics (shared with the actor).
    metrics: Option<Metrics>,
}
impl Stapler {
pub fn new(inner: Arc<dyn ResolvesServerCert>) -> Self {
Self::new_with_client_and_registry(inner, Client::new(), None)
}
pub fn new_with_registry(inner: Arc<dyn ResolvesServerCert>, registry: &Registry) -> Self {
Self::new_with_client_and_registry(inner, Client::new(), Some(registry))
}
pub fn new_with_client_and_registry(
inner: Arc<dyn ResolvesServerCert>,
client: Client,
registry: Option<&Registry>,
) -> Self {
let (tx, rx) = mpsc::channel(1024);
let storage = Arc::new(ArcSwapOption::empty());
let tracker = TaskTracker::new();
let token = CancellationToken::new();
let metrics = registry.map(Metrics::new);
let mut actor = StaplerActor {
client,
storage: BTreeMap::new(),
rx,
published: storage.clone(),
metrics: metrics.clone(),
};
let actor_token = token.clone();
tracker.spawn(async move {
actor.run(actor_token).await;
});
Self {
tx,
storage,
inner,
tracker,
token,
metrics,
}
}
pub fn preload(&self, ckey: Arc<CertifiedKey>) {
if ckey.cert.len() < 2 {
return;
}
let fp = Fingerprint::from(&ckey.cert[0]);
let _ = self.tx.try_send((fp, ckey));
}
pub fn status(&self, ckey: Arc<CertifiedKey>) -> Option<CertStatus> {
if ckey.cert.len() < 2 {
return None;
}
let fp = Fingerprint::from(&ckey.cert[0]);
Some(self.storage.load_full()?.get(&fp)?.status.clone())
}
pub async fn stop(&self) {
self.token.cancel();
self.tracker.close();
self.tracker.wait().await;
}
fn staple(&self, ckey: Arc<CertifiedKey>) -> (Arc<CertifiedKey>, bool) {
if ckey.cert.len() < 2 {
return (ckey, false);
}
let fp = Fingerprint::from(&ckey.cert[0]);
if let Some(map) = self.storage.load_full() {
if let Some(v) = map.get(&fp) {
if v.ocsp_validity.is_some() {
return (v.ckey.clone(), true);
}
return (ckey, false);
}
}
let _ = self.tx.try_send((fp, ckey.clone()));
(ckey, false)
}
}
impl fmt::Debug for Stapler {
    /// Prints a fixed name; the internals (channels, trait objects) are not useful to show.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("OcspStapler")
    }
}
impl ResolvesServerCert for Stapler {
    /// Resolves through the inner resolver, then substitutes the stapled key when available.
    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
        let (ckey, stapled) = self.staple(self.inner.resolve(client_hello)?);
        if let Some(metrics) = self.metrics.as_ref() {
            let label = if stapled { "yes" } else { "no" };
            metrics.resolves.with_label_values(&[label]).inc();
        }
        Some(ckey)
    }
}
/// Refreshes the OCSP response of `cert` when needed.
///
/// Returns `StillValid` without any network traffic while the current response has not
/// passed half of its validity window. Otherwise it queries the OCSP responder for the leaf
/// certificate (`chain[0]`) using its issuer (`chain[1]`). When the response is valid at
/// `now`, it stores the raw response in a new stapled key, along with the reported status
/// and validity.
///
/// # Errors
///
/// Fails when the chain lacks an issuer certificate, the OCSP request fails, or the response
/// is not valid at `now`. `cert` is left unchanged on error.
async fn refresh_certificate(
    client: &Client,
    now: DateTime<FixedOffset>,
    cert: &mut Cert,
) -> Result<RefreshResult, Error> {
    if let Some(validity) = &cert.ocsp_validity {
        if !validity.past_half_validity(now) {
            return Ok(RefreshResult::StillValid);
        }
    }
    // Return an error instead of indexing: a panic here would take down the whole actor task.
    let (end_entity, issuer) = match cert.ckey.cert.as_slice() {
        [end_entity, issuer, ..] => (end_entity, issuer),
        _ => {
            return Err(anyhow!(
                "the certificate chain does not include the issuer certificate"
            ))
        }
    };
    let resp = client
        .query(end_entity.as_ref(), issuer.as_ref())
        .await
        .context("unable to perform OCSP request")?;
    if !resp.ocsp_validity.valid(now) {
        return Err(anyhow!("the OCSP response is not valid at current time"));
    }
    // The current Arc may already have been handed out to handshakes, so build a new key
    // rather than mutating the shared one.
    let mut stapled = cert.ckey.as_ref().clone();
    stapled.ocsp = Some(resp.raw);
    cert.ckey = Arc::new(stapled);
    cert.status = resp.cert_status;
    cert.ocsp_validity = Some(resp.ocsp_validity);
    Ok(RefreshResult::Refreshed)
}
/// State of the background task: owns the authoritative certificate map and publishes
/// snapshots of it for [`Stapler`].
struct StaplerActor {
    /// OCSP client used for refresh requests.
    client: Client,
    /// Authoritative certificate map owned by the actor; clones are stored into `published`.
    storage: Storage,
    /// Certificates to start tracking, sent by `Stapler::preload` / `Stapler::staple`.
    rx: mpsc::Receiver<(Fingerprint, Arc<CertifiedKey>)>,
    /// Slot shared with `Stapler` where storage snapshots are published.
    published: Arc<ArcSwapOption<Storage>>,
    /// Optional Prometheus metrics.
    metrics: Option<Metrics>,
}
impl StaplerActor {
    /// Runs one maintenance pass over all tracked certificates.
    ///
    /// * Drops certificates whose own validity window does not contain `now`.
    /// * Refreshes OCSP responses one certificate at a time via [`refresh_certificate`].
    /// * Stops stapling a response once `not_after - now` falls below `LEEWAY`.
    /// * Publishes a fresh snapshot for [`Stapler`] and updates the metrics.
    ///
    /// Does nothing when no certificate is tracked.
    async fn refresh(&mut self, now: DateTime<FixedOffset>) {
        if self.storage.is_empty() {
            return;
        }
        self.storage.retain(|_, v| v.cert_validity.valid(now));
        for cert in self.storage.values_mut() {
            let start = Instant::now();
            let res = refresh_certificate(&self.client, now, cert).await;

            // `StillValid` means no OCSP request was sent, so it must not be counted as one.
            let request_outcome = match &res {
                Ok(RefreshResult::StillValid) => None,
                Ok(RefreshResult::Refreshed) => Some("ok"),
                Err(_) => Some("error"),
            };
            if let (Some(m), Some(label)) = (&self.metrics, request_outcome) {
                m.ocsp_requests_duration
                    .with_label_values(&[label])
                    .observe(start.elapsed().as_secs_f64());
                m.ocsp_requests.with_label_values(&[label]).inc();
            }

            match &res {
                Ok(RefreshResult::Refreshed) => {
                    if let Some(validity) = &cert.ocsp_validity {
                        info!(
                            "OCSP-Stapler: certificate [{cert}] was refreshed ({validity}) in {}ms",
                            start.elapsed().as_millis()
                        );
                    }
                }
                Ok(RefreshResult::StillValid) => {}
                Err(e) => warn!("OCSP-Stapler: unable to refresh certificate [{cert}]: {e:#}"),
            }

            // Stop stapling a response that has expired or is within LEEWAY of expiring
            // (e.g. because refreshes keep failing).
            if let Some(v) = &cert.ocsp_validity {
                if v.not_after - now < LEEWAY {
                    info!("OCSP-Stapler: certificate [{cert}] OCSP response has expired");
                    cert.ocsp_validity = None;
                }
            }
        }
        self.publish();
    }

    /// Publishes a snapshot of the storage for [`Stapler`] and updates the certificate gauge.
    fn publish(&self) {
        self.published.store(Some(Arc::new(self.storage.clone())));
        if let Some(m) = &self.metrics {
            // Remove every series first: otherwise statuses that no longer occur (e.g. after
            // expired certificates were dropped) would keep reporting their last count.
            m.certificate_count.reset();
            let counts = self.storage.values().map(|x| x.status.clone()).counts();
            for (status, count) in counts {
                // NOTE(review): the label is the `Debug` form of the status; for a revoked
                // certificate this presumably embeds revocation details - consider a fixed set.
                m.certificate_count
                    .with_label_values(&[&format!("{status:?}")])
                    .set(i64::try_from(count).unwrap_or(i64::MAX));
            }
        }
    }

    /// Starts tracking the chain `ckey` under `fp`, then runs a refresh pass so its first OCSP
    /// response is fetched right away instead of at the next tick.
    ///
    /// # Errors
    ///
    /// Fails when the chain lacks an issuer certificate, the leaf cannot be parsed as X.509,
    /// its validity cannot be parsed, or it is not valid now. Already-tracked fingerprints
    /// are accepted silently.
    async fn add_certificate(
        &mut self,
        fp: Fingerprint,
        ckey: Arc<CertifiedKey>,
    ) -> Result<(), Error> {
        if self.storage.contains_key(&fp) {
            return Ok(());
        }
        // OCSP refreshes need both the leaf (chain[0]) and its issuer (chain[1]).
        if ckey.cert.len() < 2 {
            return Err(anyhow!(
                "the certificate chain does not include the issuer certificate"
            ));
        }
        let x509 = X509Certificate::from_der(&ckey.cert[0])
            .context("unable to parse certificate as X.509")?
            .1;
        let cert_validity = Validity::try_from(&x509.validity).with_context(|| {
            format!("unable to parse certificate [{}] validity", x509.subject)
        })?;
        if !cert_validity.valid(Utc::now().into()) {
            return Err(anyhow!(
                "the certificate [{}] is not valid at current time",
                x509.subject
            ));
        }
        let cert = Cert {
            // `x509` borrows the DER bytes inside `ckey`, so share the Arc instead of moving it.
            ckey: Arc::clone(&ckey),
            subject: x509.subject.to_string(),
            status: CertStatus::Unknown(()),
            cert_validity,
            ocsp_validity: None,
        };
        self.storage.insert(fp, cert);
        self.refresh(Utc::now().into()).await;
        Ok(())
    }

    /// Actor event loop. It refreshes every 60 s and registers incoming certificates.
    ///
    /// It returns when `token` is cancelled, or when the request channel is closed (e.g. the
    /// `Stapler` was dropped without calling `stop`).
    async fn run(&mut self, token: CancellationToken) {
        let mut interval = tokio::time::interval(Duration::from_secs(60));
        // After a slow refresh pass, wait a full period instead of bursting missed ticks,
        // which would immediately re-query failing responders.
        interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Delay);
        loop {
            tokio::select! {
                biased;
                () = token.cancelled() => {
                    warn!("OCSP-Stapler: exiting");
                    return;
                }
                _ = interval.tick() => {
                    self.refresh(Utc::now().into()).await;
                },
                msg = self.rx.recv() => {
                    match msg {
                        Some((fp, ckey)) => {
                            if let Err(e) = self.add_certificate(fp, ckey).await {
                                warn!("OCSP-Stapler: unable to process certificate: {e:#}");
                            }
                        }
                        // A closed channel makes `recv` return `None` immediately forever;
                        // looping on it would spin the CPU, so stop the actor instead.
                        None => {
                            warn!("OCSP-Stapler: request channel closed, exiting");
                            return;
                        }
                    }
                }
            }
        }
    }
}