1use std::{
2    collections::BTreeMap,
3    fmt::{self, Display},
4    sync::Arc,
5    time::{Duration, Instant},
6};
7
8use anyhow::{anyhow, Context, Error};
9use arc_swap::ArcSwapOption;
10use chrono::{DateTime, FixedOffset, Utc};
11use rasn_ocsp::CertStatus;
12use rustls::{
13    server::{ClientHello, ResolvesServerCert},
14    sign::CertifiedKey,
15};
16use sha1::{Digest, Sha1};
17use tokio::sync::mpsc;
18use tokio_util::{sync::CancellationToken, task::TaskTracker};
19use tracing::{info, warn};
20use x509_parser::prelude::*;
21
22#[cfg(feature = "prometheus")]
23use prometheus::{
24    register_histogram_vec_with_registry, register_int_counter_vec_with_registry,
25    register_int_gauge_vec_with_registry, HistogramVec, IntCounterVec, IntGaugeVec, Registry,
26};
27
28#[cfg(feature = "prometheus")]
29use itertools::Itertools;
30
31use super::{client::Client, Validity, LEEWAY};
32
33type Storage = BTreeMap<Fingerprint, Cert>;
34
35#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
37struct Fingerprint([u8; 20]);
38
39impl From<&CertifiedKey> for Fingerprint {
40    fn from(v: &CertifiedKey) -> Self {
41        let digest = Sha1::digest(v.cert[0].as_ref());
42        Self(digest.into())
43    }
44}
45
/// Outcome of a successful `refresh_certificate` call.
///
/// Fieldless and cheap, so it derives `Copy` and `Debug` in addition to equality, which
/// makes it usable with `assert_eq!` and in log/debug output.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RefreshResult {
    /// The stored OCSP response did not need refreshing yet; no request was made.
    StillValid,
    /// A new OCSP response was fetched and stored.
    Refreshed,
}
51
52#[derive(Clone)]
53struct Cert {
54    ckey: Arc<CertifiedKey>,
55    subject: String,
56    status: CertStatus,
57    cert_validity: Validity,
58    ocsp_validity: Option<Validity>,
59}
60
61impl Display for Cert {
62    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63        write!(f, "{}", self.subject)
64    }
65}
66
/// Prometheus metrics exported by the stapler.
#[cfg(feature = "prometheus")]
#[derive(Clone)]
struct Metrics {
    /// Certificate resolutions, labelled by whether a response was stapled (`yes`/`no`).
    resolves: IntCounterVec,
    /// OCSP requests, labelled by outcome (`ok`/`error`).
    ocsp_requests: IntCounterVec,
    /// OCSP request durations in seconds, labelled by outcome (`ok`/`error`).
    ocsp_requests_duration: HistogramVec,
    /// Number of stored certificates, labelled by OCSP status.
    certificate_count: IntGaugeVec,
}

#[cfg(feature = "prometheus")]
impl Metrics {
    /// Creates all stapler metrics and registers them in `registry`.
    ///
    /// # Panics
    /// Panics if a metric cannot be registered, for instance because metrics with the same
    /// names are already registered in `registry`.
    fn new(registry: &Registry) -> Self {
        Self {
            resolves: register_int_counter_vec_with_registry!(
                "ocsp_resolves_total",
                "Counts the number of certificate resolve requests",
                &["stapled"],
                registry
            )
            .expect("unable to register metric ocsp_resolves_total"),

            ocsp_requests: register_int_counter_vec_with_registry!(
                "ocsp_requests_total",
                "Counts the number of OCSP requests",
                &["status"],
                registry
            )
            .expect("unable to register metric ocsp_requests_total"),

            ocsp_requests_duration: register_histogram_vec_with_registry!(
                "ocsp_requests_duration",
                "Observes OCSP requests duration",
                &["status"],
                registry
            )
            .expect("unable to register metric ocsp_requests_duration"),

            certificate_count: register_int_gauge_vec_with_registry!(
                "ocsp_certificate_count",
                "Current number of certificates in storage",
                &["status"],
                registry
            )
            .expect("unable to register metric ocsp_certificate_count"),
        }
    }
}
114
115pub struct Stapler {
117    tx: mpsc::Sender<(Fingerprint, Arc<CertifiedKey>)>,
118    storage: Arc<ArcSwapOption<Storage>>,
119    inner: Arc<dyn ResolvesServerCert>,
120    tracker: TaskTracker,
121    token: CancellationToken,
122    #[cfg(feature = "prometheus")]
123    metrics: Option<Metrics>,
124}
125
126impl Stapler {
127    #[cfg(feature = "prometheus")]
129    pub fn new(inner: Arc<dyn ResolvesServerCert>) -> Self {
130        Self::new_with_client_and_registry(inner, Client::new(), None)
131    }
132
133    #[cfg(feature = "prometheus")]
135    pub fn new_with_registry(inner: Arc<dyn ResolvesServerCert>, registry: &Registry) -> Self {
136        Self::new_with_client_and_registry(inner, Client::new(), Some(registry))
137    }
138
139    #[cfg(feature = "prometheus")]
141    pub fn new_with_client(inner: Arc<dyn ResolvesServerCert>, client: Client) -> Self {
142        Self::new_with_client_and_registry(inner, client, None)
143    }
144
145    #[cfg(feature = "prometheus")]
147    pub fn new_with_client_and_registry(
148        inner: Arc<dyn ResolvesServerCert>,
149        client: Client,
150        registry: Option<&Registry>,
151    ) -> Self {
152        let (tx, rx) = mpsc::channel(1024);
153        let storage = Arc::new(ArcSwapOption::empty());
154        let tracker = TaskTracker::new();
155        let token = CancellationToken::new();
156        let metrics = registry.map(Metrics::new);
157
158        let mut actor = StaplerActor {
159            client,
160            storage: BTreeMap::new(),
161            rx,
162            published: storage.clone(),
163            metrics: metrics.clone(),
164        };
165
166        let actor_token = token.clone();
168        tracker.spawn(async move {
169            actor.run(actor_token).await;
170        });
171
172        Self {
173            tx,
174            storage,
175            inner,
176            tracker,
177            token,
178            metrics,
179        }
180    }
181
182    #[cfg(not(feature = "prometheus"))]
184    pub fn new(inner: Arc<dyn ResolvesServerCert>) -> Self {
185        Self::new_with_client(inner, Client::new())
186    }
187
188    #[cfg(not(feature = "prometheus"))]
190    pub fn new_with_client(inner: Arc<dyn ResolvesServerCert>, client: Client) -> Self {
191        let (tx, rx) = mpsc::channel(1024);
192        let storage = Arc::new(ArcSwapOption::empty());
193        let tracker = TaskTracker::new();
194        let token = CancellationToken::new();
195
196        let mut actor = StaplerActor {
197            client,
198            storage: BTreeMap::new(),
199            rx,
200            published: storage.clone(),
201        };
202
203        let actor_token = token.clone();
205        tracker.spawn(async move {
206            actor.run(actor_token).await;
207        });
208
209        Self {
210            tx,
211            storage,
212            inner,
213            tracker,
214            token,
215        }
216    }
217
218    pub fn preload(&self, ckey: Arc<CertifiedKey>) {
224        if ckey.cert.len() < 2 {
225            return;
226        }
227
228        let fp = Fingerprint::from(ckey.as_ref());
229        let _ = self.tx.try_send((fp, ckey));
230    }
231
232    pub fn status(&self, ckey: Arc<CertifiedKey>) -> Option<CertStatus> {
235        if ckey.cert.len() < 2 {
236            return None;
237        }
238
239        let fp = Fingerprint::from(ckey.as_ref());
240        Some(self.storage.load_full()?.get(&fp)?.status.clone())
241    }
242
243    pub async fn stop(&self) {
245        self.token.cancel();
246        self.tracker.close();
247        self.tracker.wait().await;
248    }
249
250    fn staple(&self, ckey: Arc<CertifiedKey>) -> (Arc<CertifiedKey>, bool) {
251        if ckey.cert.len() < 2 {
255            return (ckey, false);
256        }
257
258        let fp = Fingerprint::from(ckey.as_ref());
260
261        if let Some(map) = self.storage.load_full() {
263            if let Some(v) = map.get(&fp) {
265                if v.ocsp_validity.is_some() {
268                    return (v.ckey.clone(), true);
269                }
270
271                return (ckey, false);
273            }
274        }
275
276        let _ = self.tx.try_send((fp, ckey.clone()));
279
280        (ckey, false)
282    }
283}
284
285impl fmt::Debug for Stapler {
287    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288        write!(f, "OcspStapler")
289    }
290}
291
292impl ResolvesServerCert for Stapler {
293    fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
294        let ckey = self.inner.resolve(client_hello)?;
296
297        #[cfg(feature = "prometheus")]
299        let (ckey, stapled) = self.staple(ckey);
300
301        #[cfg(not(feature = "prometheus"))]
302        let (ckey, _stapled) = self.staple(ckey);
303
304        #[cfg(feature = "prometheus")]
306        if let Some(v) = &self.metrics {
307            v.resolves
308                .with_label_values(&[if stapled { "yes" } else { "no" }])
309                .inc();
310        }
311
312        Some(ckey)
313    }
314}
315
316async fn refresh_certificate(
317    client: &Client,
318    now: DateTime<FixedOffset>,
319    cert: &mut Cert,
320) -> Result<RefreshResult, Error> {
321    if let Some(x) = &cert.ocsp_validity {
323        if !x.past_half_validity(now) {
324            return Ok(RefreshResult::StillValid);
325        }
326    }
327
328    let end_entity = cert.ckey.cert[0].as_ref();
330    let issuer = cert.ckey.cert[1].as_ref();
331
332    let resp = client
334        .query(end_entity, issuer)
335        .await
336        .context("unable to perform OCSP request")?;
337
338    if !resp.ocsp_validity.valid(now) {
339        return Err(anyhow!("the OCSP response is not valid at current time"));
340    }
341
342    let mut ckey = cert.ckey.as_ref().clone();
344    ckey.ocsp = Some(resp.raw);
345
346    cert.ckey = Arc::new(ckey);
348    cert.status = resp.cert_status;
349    cert.ocsp_validity = Some(resp.ocsp_validity);
350
351    Ok(RefreshResult::Refreshed)
352}
353
354struct StaplerActor {
355    client: Client,
356    storage: Storage,
357    rx: mpsc::Receiver<(Fingerprint, Arc<CertifiedKey>)>,
358    published: Arc<ArcSwapOption<Storage>>,
359    #[cfg(feature = "prometheus")]
360    metrics: Option<Metrics>,
361}
362
363impl StaplerActor {
364    async fn refresh(&mut self, now: DateTime<FixedOffset>) {
365        if self.storage.is_empty() {
366            return;
367        }
368
369        self.storage.retain(|_, v| v.cert_validity.valid(now));
371
372        for cert in self.storage.values_mut() {
373            let start = Instant::now();
374            let res = refresh_certificate(&self.client, now, cert).await;
375
376            #[cfg(feature = "prometheus")]
378            if let Some(v) = &self.metrics {
379                let lbl = &[if res.is_err() { "error" } else { "ok" }];
380
381                v.ocsp_requests_duration
382                    .with_label_values(lbl)
383                    .observe(start.elapsed().as_secs_f64());
384
385                v.ocsp_requests.with_label_values(lbl).inc()
386            };
387
388            match res {
389                Ok(v) => {
390                    if v == RefreshResult::Refreshed {
391                        info!(
392                            "OCSP-Stapler: certificate [{cert}] was refreshed ({}) in {}ms",
393                            cert.ocsp_validity.as_ref().unwrap(),
394                            start.elapsed().as_millis()
395                        );
396                    }
397                }
398                Err(e) => warn!("OCSP-Stapler: unable to refresh certificate [{cert}]: {e:#}"),
399            }
400
401            if let Some(v) = &cert.ocsp_validity {
404                if v.not_after - now < LEEWAY {
405                    info!("OCSP-Stapler: certificate [{cert}] OCSP response has expired");
406                    cert.ocsp_validity = None;
407                }
408            }
409        }
410
411        let new = Arc::new(self.storage.clone());
413        self.published.store(Some(new));
414
415        #[cfg(feature = "prometheus")]
417        if let Some(m) = &self.metrics {
418            let status = self.storage.values().map(|x| x.status.clone()).counts();
419
420            for (k, v) in status {
421                m.certificate_count
422                    .with_label_values(&[&format!("{k:?}")])
423                    .set(v as i64);
424            }
425        }
426    }
427
428    fn add_certificate(
429        &mut self,
430        fp: Fingerprint,
431        ckey: Arc<CertifiedKey>,
432        now: DateTime<FixedOffset>,
433    ) -> Result<bool, Error> {
434        if self.storage.contains_key(&fp) {
435            return Ok(false);
436        }
437
438        let cert = X509Certificate::from_der(ckey.end_entity_cert().unwrap())
440            .context("unable to parse certificate as X.509")?
441            .1;
442
443        let cert_validity = Validity::try_from(&cert.validity).context(format!(
444            "unable to parse certificate [{}] validity",
445            cert.subject
446        ))?;
447
448        if !cert_validity.valid(now) {
449            return Err(anyhow!(
450                "the certificate [{}] is not valid at current time",
451                cert.subject
452            ));
453        }
454
455        let cert = Cert {
456            ckey: ckey.clone(),
457            subject: cert.subject.to_string(),
458            status: CertStatus::Unknown(()),
459            cert_validity,
460            ocsp_validity: None,
461        };
462
463        self.storage.insert(fp, cert);
464        Ok(true)
465    }
466
467    async fn run(&mut self, token: CancellationToken) {
468        let mut interval = tokio::time::interval(Duration::from_secs(60));
469
470        loop {
471            tokio::select! {
472                biased;
473
474                () = token.cancelled() => {
475                    warn!("OCSP-Stapler: exiting");
476                    return;
477                }
478
479                _ = interval.tick() => {
480                    self.refresh(Utc::now().into()).await;
481                },
482
483                msg = self.rx.recv() => {
484                    if let Some((fp, ckey)) = msg {
485                        let now = Utc::now().into();
486                        if let Err(e) = self.add_certificate(fp, ckey, now) {
487                            warn!("OCSP-Stapler: unable to process certificate: {e:#}");
488                        } else {
489                            self.refresh(now).await;
490                        }
491                    }
492                }
493            }
494        }
495    }
496}
497
#[cfg(test)]
mod test {
    use super::*;
    use rustls::crypto::ring;

    #[tokio::test]
    async fn test_add_certificate() {
        ring::default_provider()
            .install_default()
            .unwrap_or_default();

        let ckey = crate::client::test::test_ckey();
        let storage = Arc::new(ArcSwapOption::empty());
        let (_, rx) = mpsc::channel(1024);

        let mut actor = StaplerActor {
            client: Client::new(),
            storage: BTreeMap::new(),
            rx,
            published: storage.clone(),
            // The field only exists with the `prometheus` feature; without this gate the test
            // would not compile in the default build.
            #[cfg(feature = "prometheus")]
            metrics: None,
        };

        let fp = Fingerprint::from(&ckey);
        let ckey = Arc::new(ckey);

        // Outside the certificate's validity period: rejected and not stored.
        let now = DateTime::parse_from_rfc3339("2024-05-25T00:00:00-00:00").unwrap();
        assert!(actor
            .add_certificate(fp.clone(), ckey.clone(), now)
            .is_err());
        assert!(actor.storage.is_empty());

        // Valid: added once, then reported as already tracked.
        let now = DateTime::parse_from_rfc3339("2024-05-28T00:00:00-00:00").unwrap();
        assert!(actor
            .add_certificate(fp.clone(), ckey.clone(), now)
            .unwrap());
        assert!(actor.storage.contains_key(&fp));
        assert!(!actor
            .add_certificate(fp.clone(), ckey.clone(), now)
            .unwrap());
        assert_eq!(actor.storage.len(), 1);
    }
}
540}