use std::{
fmt,
sync::atomic::{AtomicU32, Ordering},
};
use chacha20poly1305::{
XChaCha20Poly1305, XNonce,
aead::{Aead, AeadCore, KeyInit, OsRng},
};
use prometheus::{IntCounterVec, Registry, register_int_counter_vec_with_registry};
use rustls::server::ProducesTickets;
use zeroize::ZeroizeOnDrop;
/// Byte length of an XChaCha20-Poly1305 nonce (192 bits); every ticket starts with one.
const NONCE_LEN: usize = 24;
/// Prometheus metrics describing TLS session-ticket processing.
#[derive(Debug)]
pub struct Metrics {
    /// `tls_tickets` counter, labelled by `action` and `result`.
    processed: IntCounterVec,
}

impl Metrics {
    /// Creates the ticket metrics and registers them with `registry`.
    ///
    /// # Panics
    ///
    /// Panics if the `tls_tickets` counter cannot be registered with `registry`
    /// (for example, when a collector with the same name is already registered).
    pub fn new(registry: &Registry) -> Self {
        Self {
            // `Opts::new` takes `Into<String>`, so string literals suffice; the previous
            // `format!` calls only added needless allocations.
            processed: register_int_counter_vec_with_registry!(
                "tls_tickets",
                "Number of TLS tickets that were processed",
                &["action", "result"],
                registry
            )
            .expect("failed to register the `tls_tickets` counter"),
        }
    }
}
/// Session-ticket encrypter/decrypter backed by XChaCha20-Poly1305.
///
/// The cipher state is zeroized when the ticketer is dropped (`ZeroizeOnDrop`);
/// the nonce counter is not secret and is skipped.
#[derive(ZeroizeOnDrop)]
pub struct Ticketer {
    // Per-ticketer counter written into the first 4 bytes of each generated nonce.
    #[zeroize(skip)]
    counter: AtomicU32,
    // AEAD instance holding the ticket key.
    cipher: XChaCha20Poly1305,
}
impl fmt::Debug for Ticketer {
    /// Emits only the type name; no field (and thus no cipher/key state) is printed.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Ticketer")
    }
}
impl Default for Ticketer {
fn default() -> Self {
Self::new()
}
}
impl Ticketer {
pub fn new() -> Self {
let key = XChaCha20Poly1305::generate_key(&mut OsRng);
Self {
cipher: XChaCha20Poly1305::new(&key),
counter: AtomicU32::new(0),
}
}
fn nonce(&self) -> XNonce {
let mut nonce = XChaCha20Poly1305::generate_nonce(&mut OsRng);
let count = self.counter.fetch_add(1, Ordering::SeqCst);
nonce[0..4].copy_from_slice(&count.to_le_bytes());
nonce
}
}
impl ProducesTickets for Ticketer {
    /// Ticket issuance is unconditionally enabled.
    fn enabled(&self) -> bool {
        true
    }

    /// Tickets are advertised as valid for 3600 seconds (one hour).
    fn lifetime(&self) -> u32 {
        60 * 60
    }

    /// Seals `plain` and returns the ticket as `nonce || ciphertext`, where the
    /// ciphertext includes the AEAD tag. Returns `None` if encryption fails.
    fn encrypt(&self, plain: &[u8]) -> Option<Vec<u8>> {
        let nonce = self.nonce();
        let sealed = match self.cipher.encrypt(&nonce, plain) {
            Ok(bytes) => bytes,
            Err(_) => return None,
        };
        // `&nonce[..]` derefs the nonce to a byte slice without the deprecated `as_slice`.
        Some([&nonce[..], &sealed[..]].concat())
    }

    /// Opens a ticket produced by `encrypt`. Returns `None` if the input is not longer
    /// than a nonce or if authentication/decryption fails.
    fn decrypt(&self, cipher: &[u8]) -> Option<Vec<u8>> {
        if cipher.len() <= NONCE_LEN {
            return None;
        }
        let (nonce_bytes, sealed) = cipher.split_at(NONCE_LEN);
        // `from_slice` is deprecated upstream but is still the direct conversion here;
        // the split above guarantees exactly NONCE_LEN bytes, so it cannot panic.
        #[allow(deprecated)]
        let nonce = XNonce::from_slice(nonce_bytes);
        self.cipher.decrypt(nonce, sealed).ok()
    }
}
#[derive(Debug)]
pub struct WithMetrics<T: ProducesTickets>(pub T, pub Metrics);
impl<T: ProducesTickets> WithMetrics<T> {
    /// Increments the `processed` counter for `action`, labelling the result
    /// `"ok"` when `res` holds bytes and `"fail"` otherwise.
    fn record(&self, action: &str, res: &Option<Vec<u8>>) {
        let outcome = match res {
            Some(_) => "ok",
            None => "fail",
        };
        let counter = self.1.processed.with_label_values(&[action, outcome]);
        counter.inc();
    }
}
impl<T: ProducesTickets> ProducesTickets for WithMetrics<T> {
fn enabled(&self) -> bool {
self.0.enabled()
}
fn lifetime(&self) -> u32 {
self.0.lifetime()
}
fn encrypt(&self, plain: &[u8]) -> Option<Vec<u8>> {
let res = self.0.encrypt(plain);
self.record("encrypt", &res);
res
}
fn decrypt(&self, cipher: &[u8]) -> Option<Vec<u8>> {
let res = self.0.decrypt(cipher);
self.record("decrypt", &res);
res
}
}
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_ticketer() {
        let t = Ticketer::new();
        // Each generated nonce carries the little-endian call counter in its first 4 bytes.
        for i in 0..10 {
            #[allow(deprecated)]
            let counter = u32::from_le_bytes(t.nonce().as_slice()[0..4].try_into().unwrap());
            assert_eq!(counter, i);
        }
        let msg = b"The quick brown fox jumps over the lazy dog";
        let ciphertext = t.encrypt(msg).unwrap();
        let plaintext = t.decrypt(&ciphertext).unwrap();
        assert_eq!(&msg[..], plaintext);
        assert!(t.decrypt(msg).is_none());
    }

    #[test]
    fn test_ticket_layout_and_rejection() {
        let t = Ticketer::new();
        let msg = b"ticket";
        let ticket = t.encrypt(msg).unwrap();

        // Layout: nonce, then ciphertext of the same length as the plaintext, then 16-byte tag.
        assert_eq!(ticket.len(), NONCE_LEN + msg.len() + 16);
        // The first ticket of a fresh ticketer uses counter value 0.
        assert_eq!(u32::from_le_bytes(ticket[0..4].try_into().unwrap()), 0);

        // Inputs with nothing (or too little) after the nonce are rejected.
        assert!(t.decrypt(&[]).is_none());
        assert!(t.decrypt(&[0u8; NONCE_LEN]).is_none());
        assert!(t.decrypt(&ticket[..NONCE_LEN + 1]).is_none());

        // Flipping a bit in either the tag or the nonce must fail authentication.
        let mut bad_tag = ticket.clone();
        *bad_tag.last_mut().unwrap() ^= 1;
        assert!(t.decrypt(&bad_tag).is_none());
        let mut bad_nonce = ticket.clone();
        bad_nonce[0] ^= 1;
        assert!(t.decrypt(&bad_nonce).is_none());

        // A ticketer with a different key cannot open the ticket.
        assert!(Ticketer::new().decrypt(&ticket).is_none());

        // The untouched ticket still round-trips.
        assert_eq!(&msg[..], t.decrypt(&ticket).unwrap());
    }
}