1use std::{
2 collections::BTreeMap,
3 fmt::{self, Display},
4 sync::Arc,
5 time::{Duration, Instant},
6};
7
8use anyhow::{anyhow, Context, Error};
9use arc_swap::ArcSwapOption;
10use chrono::{DateTime, FixedOffset, Utc};
11use rasn_ocsp::CertStatus;
12use rustls::{
13 server::{ClientHello, ResolvesServerCert},
14 sign::CertifiedKey,
15};
16use sha1::{Digest, Sha1};
17use tokio::sync::mpsc;
18use tokio_util::{sync::CancellationToken, task::TaskTracker};
19use tracing::{info, warn};
20use x509_parser::prelude::*;
21
22#[cfg(feature = "prometheus")]
23use prometheus::{
24 register_histogram_vec_with_registry, register_int_counter_vec_with_registry,
25 register_int_gauge_vec_with_registry, HistogramVec, IntCounterVec, IntGaugeVec, Registry,
26};
27
28#[cfg(feature = "prometheus")]
29use itertools::Itertools;
30
31use super::{client::Client, Validity, LEEWAY};
32
33type Storage = BTreeMap<Fingerprint, Cert>;
34
35#[derive(Clone, Hash, PartialEq, Eq, PartialOrd, Ord)]
37struct Fingerprint([u8; 20]);
38
39impl From<&CertifiedKey> for Fingerprint {
40 fn from(v: &CertifiedKey) -> Self {
41 let digest = Sha1::digest(v.cert[0].as_ref());
42 Self(digest.into())
43 }
44}
45
/// Successful outcome of a single OCSP refresh attempt for one certificate.
#[derive(Eq, PartialEq)]
enum RefreshResult {
    /// The held OCSP response has not yet passed half of its validity; nothing was fetched.
    StillValid,
    /// A fresh OCSP response was fetched and stored.
    Refreshed,
}
51
52#[derive(Clone)]
53struct Cert {
54 ckey: Arc<CertifiedKey>,
55 subject: String,
56 status: CertStatus,
57 cert_validity: Validity,
58 ocsp_validity: Option<Validity>,
59}
60
61impl Display for Cert {
62 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
63 write!(f, "{}", self.subject)
64 }
65}
66
/// Prometheus metrics exported by the stapler.
#[cfg(feature = "prometheus")]
#[derive(Clone)]
struct Metrics {
    /// Certificate resolves, labelled `stapled` = `yes`/`no`.
    resolves: IntCounterVec,
    /// OCSP refresh outcomes, labelled `status` = `ok`/`error`.
    ocsp_requests: IntCounterVec,
    /// Duration of OCSP refreshes in seconds, labelled `status` = `ok`/`error`.
    ocsp_requests_duration: HistogramVec,
    /// Number of tracked certificates, labelled by OCSP status.
    certificate_count: IntGaugeVec,
}

#[cfg(feature = "prometheus")]
impl Metrics {
    /// Creates all stapler metrics and registers them in `registry`.
    ///
    /// # Panics
    ///
    /// Panics if a metric cannot be registered, e.g. when a collector with the same
    /// name is already present in `registry` (such as a second stapler sharing it).
    fn new(registry: &Registry) -> Self {
        Self {
            resolves: register_int_counter_vec_with_registry!(
                "ocsp_resolves_total",
                "Counts the number of certificate resolve requests",
                &["stapled"],
                registry
            )
            .expect("unable to register metric ocsp_resolves_total"),

            ocsp_requests: register_int_counter_vec_with_registry!(
                "ocsp_requests_total",
                "Counts the number of OCSP requests",
                &["status"],
                registry
            )
            .expect("unable to register metric ocsp_requests_total"),

            ocsp_requests_duration: register_histogram_vec_with_registry!(
                "ocsp_requests_duration",
                "Observes OCSP requests duration",
                &["status"],
                registry
            )
            .expect("unable to register metric ocsp_requests_duration"),

            certificate_count: register_int_gauge_vec_with_registry!(
                "ocsp_certificate_count",
                "Current number of certificates in storage",
                &["status"],
                registry
            )
            .expect("unable to register metric ocsp_certificate_count"),
        }
    }
}
114
115pub struct Stapler {
117 tx: mpsc::Sender<(Fingerprint, Arc<CertifiedKey>)>,
118 storage: Arc<ArcSwapOption<Storage>>,
119 inner: Arc<dyn ResolvesServerCert>,
120 tracker: TaskTracker,
121 token: CancellationToken,
122 #[cfg(feature = "prometheus")]
123 metrics: Option<Metrics>,
124}
125
126impl Stapler {
127 #[cfg(feature = "prometheus")]
129 pub fn new(inner: Arc<dyn ResolvesServerCert>) -> Self {
130 Self::new_with_client_and_registry(inner, Client::new(), None)
131 }
132
133 #[cfg(feature = "prometheus")]
135 pub fn new_with_registry(inner: Arc<dyn ResolvesServerCert>, registry: &Registry) -> Self {
136 Self::new_with_client_and_registry(inner, Client::new(), Some(registry))
137 }
138
139 #[cfg(feature = "prometheus")]
141 pub fn new_with_client(inner: Arc<dyn ResolvesServerCert>, client: Client) -> Self {
142 Self::new_with_client_and_registry(inner, client, None)
143 }
144
145 #[cfg(feature = "prometheus")]
147 pub fn new_with_client_and_registry(
148 inner: Arc<dyn ResolvesServerCert>,
149 client: Client,
150 registry: Option<&Registry>,
151 ) -> Self {
152 let (tx, rx) = mpsc::channel(1024);
153 let storage = Arc::new(ArcSwapOption::empty());
154 let tracker = TaskTracker::new();
155 let token = CancellationToken::new();
156 let metrics = registry.map(Metrics::new);
157
158 let mut actor = StaplerActor {
159 client,
160 storage: BTreeMap::new(),
161 rx,
162 published: storage.clone(),
163 metrics: metrics.clone(),
164 };
165
166 let actor_token = token.clone();
168 tracker.spawn(async move {
169 actor.run(actor_token).await;
170 });
171
172 Self {
173 tx,
174 storage,
175 inner,
176 tracker,
177 token,
178 metrics,
179 }
180 }
181
182 #[cfg(not(feature = "prometheus"))]
184 pub fn new(inner: Arc<dyn ResolvesServerCert>) -> Self {
185 Self::new_with_client(inner, Client::new())
186 }
187
188 #[cfg(not(feature = "prometheus"))]
190 pub fn new_with_client(inner: Arc<dyn ResolvesServerCert>, client: Client) -> Self {
191 let (tx, rx) = mpsc::channel(1024);
192 let storage = Arc::new(ArcSwapOption::empty());
193 let tracker = TaskTracker::new();
194 let token = CancellationToken::new();
195
196 let mut actor = StaplerActor {
197 client,
198 storage: BTreeMap::new(),
199 rx,
200 published: storage.clone(),
201 };
202
203 let actor_token = token.clone();
205 tracker.spawn(async move {
206 actor.run(actor_token).await;
207 });
208
209 Self {
210 tx,
211 storage,
212 inner,
213 tracker,
214 token,
215 }
216 }
217
218 pub fn preload(&self, ckey: Arc<CertifiedKey>) {
224 if ckey.cert.len() < 2 {
225 return;
226 }
227
228 let fp = Fingerprint::from(ckey.as_ref());
229 let _ = self.tx.try_send((fp, ckey));
230 }
231
232 pub fn status(&self, ckey: Arc<CertifiedKey>) -> Option<CertStatus> {
235 if ckey.cert.len() < 2 {
236 return None;
237 }
238
239 let fp = Fingerprint::from(ckey.as_ref());
240 Some(self.storage.load_full()?.get(&fp)?.status.clone())
241 }
242
243 pub async fn stop(&self) {
245 self.token.cancel();
246 self.tracker.close();
247 self.tracker.wait().await;
248 }
249
250 fn staple(&self, ckey: Arc<CertifiedKey>) -> (Arc<CertifiedKey>, bool) {
251 if ckey.cert.len() < 2 {
255 return (ckey, false);
256 }
257
258 let fp = Fingerprint::from(ckey.as_ref());
260
261 if let Some(map) = self.storage.load_full() {
263 if let Some(v) = map.get(&fp) {
265 if v.ocsp_validity.is_some() {
268 return (v.ckey.clone(), true);
269 }
270
271 return (ckey, false);
273 }
274 }
275
276 let _ = self.tx.try_send((fp, ckey.clone()));
279
280 (ckey, false)
282 }
283}
284
285impl fmt::Debug for Stapler {
287 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
288 write!(f, "OcspStapler")
289 }
290}
291
292impl ResolvesServerCert for Stapler {
293 fn resolve(&self, client_hello: ClientHello) -> Option<Arc<CertifiedKey>> {
294 let ckey = self.inner.resolve(client_hello)?;
296
297 #[cfg(feature = "prometheus")]
299 let (ckey, stapled) = self.staple(ckey);
300
301 #[cfg(not(feature = "prometheus"))]
302 let (ckey, _stapled) = self.staple(ckey);
303
304 #[cfg(feature = "prometheus")]
306 if let Some(v) = &self.metrics {
307 v.resolves
308 .with_label_values(&[if stapled { "yes" } else { "no" }])
309 .inc();
310 }
311
312 Some(ckey)
313 }
314}
315
316async fn refresh_certificate(
317 client: &Client,
318 now: DateTime<FixedOffset>,
319 cert: &mut Cert,
320) -> Result<RefreshResult, Error> {
321 if let Some(x) = &cert.ocsp_validity {
323 if !x.past_half_validity(now) {
324 return Ok(RefreshResult::StillValid);
325 }
326 }
327
328 let end_entity = cert.ckey.cert[0].as_ref();
330 let issuer = cert.ckey.cert[1].as_ref();
331
332 let resp = client
334 .query(end_entity, issuer)
335 .await
336 .context("unable to perform OCSP request")?;
337
338 if !resp.ocsp_validity.valid(now) {
339 return Err(anyhow!("the OCSP response is not valid at current time"));
340 }
341
342 let mut ckey = cert.ckey.as_ref().clone();
344 ckey.ocsp = Some(resp.raw);
345
346 cert.ckey = Arc::new(ckey);
348 cert.status = resp.cert_status;
349 cert.ocsp_validity = Some(resp.ocsp_validity);
350
351 Ok(RefreshResult::Refreshed)
352}
353
354struct StaplerActor {
355 client: Client,
356 storage: Storage,
357 rx: mpsc::Receiver<(Fingerprint, Arc<CertifiedKey>)>,
358 published: Arc<ArcSwapOption<Storage>>,
359 #[cfg(feature = "prometheus")]
360 metrics: Option<Metrics>,
361}
362
363impl StaplerActor {
364 async fn refresh(&mut self, now: DateTime<FixedOffset>) {
365 if self.storage.is_empty() {
366 return;
367 }
368
369 self.storage.retain(|_, v| v.cert_validity.valid(now));
371
372 for cert in self.storage.values_mut() {
373 let start = Instant::now();
374 let res = refresh_certificate(&self.client, now, cert).await;
375
376 #[cfg(feature = "prometheus")]
378 if let Some(v) = &self.metrics {
379 let lbl = &[if res.is_err() { "error" } else { "ok" }];
380
381 v.ocsp_requests_duration
382 .with_label_values(lbl)
383 .observe(start.elapsed().as_secs_f64());
384
385 v.ocsp_requests.with_label_values(lbl).inc()
386 };
387
388 match res {
389 Ok(v) => {
390 if v == RefreshResult::Refreshed {
391 info!(
392 "OCSP-Stapler: certificate [{cert}] was refreshed ({}) in {}ms",
393 cert.ocsp_validity.as_ref().unwrap(),
394 start.elapsed().as_millis()
395 );
396 }
397 }
398 Err(e) => warn!("OCSP-Stapler: unable to refresh certificate [{cert}]: {e:#}"),
399 }
400
401 if let Some(v) = &cert.ocsp_validity {
404 if v.not_after - now < LEEWAY {
405 info!("OCSP-Stapler: certificate [{cert}] OCSP response has expired");
406 cert.ocsp_validity = None;
407 }
408 }
409 }
410
411 let new = Arc::new(self.storage.clone());
413 self.published.store(Some(new));
414
415 #[cfg(feature = "prometheus")]
417 if let Some(m) = &self.metrics {
418 let status = self.storage.values().map(|x| x.status.clone()).counts();
419
420 for (k, v) in status {
421 m.certificate_count
422 .with_label_values(&[&format!("{k:?}")])
423 .set(v as i64);
424 }
425 }
426 }
427
428 fn add_certificate(
429 &mut self,
430 fp: Fingerprint,
431 ckey: Arc<CertifiedKey>,
432 now: DateTime<FixedOffset>,
433 ) -> Result<bool, Error> {
434 if self.storage.contains_key(&fp) {
435 return Ok(false);
436 }
437
438 let cert = X509Certificate::from_der(ckey.end_entity_cert().unwrap())
440 .context("unable to parse certificate as X.509")?
441 .1;
442
443 let cert_validity = Validity::try_from(&cert.validity).context(format!(
444 "unable to parse certificate [{}] validity",
445 cert.subject
446 ))?;
447
448 if !cert_validity.valid(now) {
449 return Err(anyhow!(
450 "the certificate [{}] is not valid at current time",
451 cert.subject
452 ));
453 }
454
455 let cert = Cert {
456 ckey: ckey.clone(),
457 subject: cert.subject.to_string(),
458 status: CertStatus::Unknown(()),
459 cert_validity,
460 ocsp_validity: None,
461 };
462
463 self.storage.insert(fp, cert);
464 Ok(true)
465 }
466
467 async fn run(&mut self, token: CancellationToken) {
468 let mut interval = tokio::time::interval(Duration::from_secs(60));
469
470 loop {
471 tokio::select! {
472 biased;
473
474 () = token.cancelled() => {
475 warn!("OCSP-Stapler: exiting");
476 return;
477 }
478
479 _ = interval.tick() => {
480 self.refresh(Utc::now().into()).await;
481 },
482
483 msg = self.rx.recv() => {
484 if let Some((fp, ckey)) = msg {
485 let now = Utc::now().into();
486 if let Err(e) = self.add_certificate(fp, ckey, now) {
487 warn!("OCSP-Stapler: unable to process certificate: {e:#}");
488 } else {
489 self.refresh(now).await;
490 }
491 }
492 }
493 }
494 }
495 }
496}
497
#[cfg(test)]
mod test {
    use super::*;
    use rustls::crypto::ring;

    #[tokio::test]
    async fn test_add_certificate() {
        // `install_default` fails if a process-wide provider is already installed
        // (e.g. by another test); that is fine here.
        ring::default_provider()
            .install_default()
            .unwrap_or_default();

        let ckey = crate::client::test::test_ckey();
        let storage = Arc::new(ArcSwapOption::empty());
        let (_, rx) = mpsc::channel(1024);

        let mut actor = StaplerActor {
            client: Client::new(),
            storage: BTreeMap::new(),
            rx,
            published: storage.clone(),
            // The field only exists with the `prometheus` feature. Without this
            // attribute the test module fails to compile in default builds.
            #[cfg(feature = "prometheus")]
            metrics: None,
        };

        let fp = Fingerprint::from(&ckey);
        let ckey = Arc::new(ckey);

        // Presumably before the test certificate's notBefore: must be rejected.
        let now = DateTime::parse_from_rfc3339("2024-05-25T00:00:00-00:00").unwrap();
        assert!(actor
            .add_certificate(fp.clone(), ckey.clone(), now)
            .is_err());

        // Inside the validity window: first insert adds, second is a no-op.
        let now = DateTime::parse_from_rfc3339("2024-05-28T00:00:00-00:00").unwrap();
        assert!(actor
            .add_certificate(fp.clone(), ckey.clone(), now)
            .unwrap());
        assert!(!actor
            .add_certificate(fp.clone(), ckey.clone(), now)
            .unwrap());
    }
}