use std::net::SocketAddr;
use std::sync::Arc;
use std::sync::OnceLock;
use std::time::Instant;
use bytes::Bytes;
use parking_lot::Mutex;
use rustls_pki_types::CertificateDer;
/// Numeric identifier of a connection.
///
/// The inner `u64` is public. The `Display` form is exactly 16 zero-padded,
/// lowercase hexadecimal digits.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, Hash)]
#[derive(serde::Deserialize, serde::Serialize)]
pub struct ConnId(pub u64);
impl std::fmt::Display for ConnId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "{:016x}", self.0)
}
}
/// Transport-layer protocol of a connection.
///
/// Variant order is significant: it defines the derived `Ord` (`Tcp < Udp`).
#[derive(Clone, Copy, Debug, Hash)]
#[derive(PartialEq, Eq, PartialOrd, Ord)]
#[derive(serde::Deserialize, serde::Serialize)]
pub enum Transport {
    /// TCP.
    Tcp,
    /// UDP.
    Udp,
}
/// HTTP protocol version spoken on a connection.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, Hash)]
#[derive(serde::Deserialize, serde::Serialize)]
pub enum HttpVersion {
    /// HTTP/1.0.
    Http1_0,
    /// HTTP/1.1.
    Http1_1,
    /// HTTP/2.
    Http2,
    /// HTTP/3.
    Http3,
}
/// TLS protocol version of a connection.
#[derive(Clone, Copy, Debug)]
#[derive(PartialEq, Eq, Hash)]
#[derive(serde::Deserialize, serde::Serialize)]
pub enum TlsVersion {
    /// TLS 1.2.
    Tls12,
    /// TLS 1.3.
    Tls13,
}
/// TLS-layer details recorded for a connection.
///
/// `Default` produces an empty record: every `Option` is `None` and
/// `zero_rtt_used` is `false`.
#[derive(Debug, Default, Clone)]
pub struct TlsInfo {
    /// Server Name Indication value, if one is known.
    pub sni: Option<Arc<str>>,
    /// ALPN protocol identifier as raw bytes, if any
    /// (presumably the negotiated protocol — confirm against the writer).
    pub alpn: Option<Arc<[u8]>>,
    /// Protocol version, if known.
    pub version: Option<TlsVersion>,
    /// Parsed peer leaf certificate, if one is available.
    pub peer_cert: Option<Arc<PeerCertificate>>,
    /// Whether 0-RTT was used on this connection.
    pub zero_rtt_used: bool,
}
/// Identity details of a peer's leaf X.509 certificate.
///
/// Normally produced by `PeerCertificate::from_der`, which stores every digest
/// and the serial number as lowercase hexadecimal text. `Default` yields empty
/// bytes and strings, an empty SAN list and `None` common names.
#[derive(Debug, Clone, Default)]
pub struct PeerCertificate {
    /// DER encoding of the leaf certificate.
    pub leaf_der: Bytes,
    /// First commonName of the subject, if present and readable as a string.
    pub subject_cn: Option<Arc<str>>,
    /// dNSName entries from the subjectAltName extension (possibly empty).
    pub san_dns: Arc<[Arc<str>]>,
    /// SHA-256 of `leaf_der`, lowercase hex.
    pub fingerprint_sha256: Arc<str>,
    /// SHA-256 of the raw SubjectPublicKeyInfo DER, lowercase hex.
    pub spki_sha256: Arc<str>,
    /// First commonName of the issuer, if present and readable as a string.
    pub issuer_cn: Option<Arc<str>>,
    /// Serial number's big-endian bytes, lowercase hex.
    pub serial: Arc<str>,
}
impl PeerCertificate {
    /// Parses a DER-encoded leaf certificate and extracts its identity details.
    ///
    /// Returns `None` if `leaf_der` does not begin with a certificate that
    /// `x509_parser` can parse. Any bytes following the certificate's own DER
    /// encoding are ignored. They are excluded from `leaf_der` and from the
    /// fingerprint, so `fingerprint_sha256` is always the SHA-256 of exactly the
    /// parsed certificate.
    ///
    /// Extraction rules:
    /// - `subject_cn` / `issuer_cn`: the first commonName attribute, or `None`
    ///   if there is none or it cannot be read as a string.
    /// - `san_dns`: every dNSName in the subjectAltName extension. It is empty
    ///   if the extension is absent or `x509_parser` reports an error reading it.
    /// - `fingerprint_sha256` / `spki_sha256`: lowercase-hex SHA-256 of the
    ///   certificate DER and of the raw SubjectPublicKeyInfo DER respectively.
    /// - `serial`: lowercase hex of the serial number's big-endian bytes.
    #[must_use]
    pub fn from_der(leaf_der: &CertificateDer<'_>) -> Option<Self> {
        use sha2::{Digest, Sha256};
        use x509_parser::prelude::*;

        let input = leaf_der.as_ref();
        let (rest, cert) = X509Certificate::from_der(input).ok()?;
        // `from_der` consumes one certificate and returns the unparsed tail in
        // `rest` (always a suffix of `input`). Storing or hashing `input`
        // wholesale would fold any trailing bytes into `leaf_der` and the
        // fingerprint, so keep only the bytes that were actually parsed.
        let der = &input[..input.len() - rest.len()];
        let tbs = &cert.tbs_certificate;

        let subject_cn = tbs
            .subject()
            .iter_common_name()
            .next()
            .and_then(|attr| attr.as_str().ok().map(Arc::from));
        let issuer_cn = tbs
            .issuer()
            .iter_common_name()
            .next()
            .and_then(|attr| attr.as_str().ok().map(Arc::from));

        // A SAN extension that fails to parse is treated the same as a missing one.
        let mut san_dns: Vec<Arc<str>> = Vec::new();
        if let Ok(Some(san_ext)) = tbs.subject_alternative_name() {
            for name in &san_ext.value.general_names {
                if let GeneralName::DNSName(dns) = name {
                    san_dns.push(Arc::from(*dns));
                }
            }
        }

        Some(Self {
            leaf_der: Bytes::copy_from_slice(der),
            subject_cn,
            san_dns: san_dns.into(),
            fingerprint_sha256: Arc::from(hex_lower(&Sha256::digest(der))),
            spki_sha256: Arc::from(hex_lower(&Sha256::digest(tbs.subject_pki.raw))),
            issuer_cn,
            serial: Arc::from(hex_lower(&tbs.serial.to_bytes_be())),
        })
    }
}
/// Encodes `bytes` as lowercase hexadecimal, two digits per byte
/// (e.g. `[0x0a, 0xff]` becomes `"0aff"`). Empty input yields an empty string.
fn hex_lower(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";
    let mut out = String::with_capacity(bytes.len() * 2);
    // High nibble first, then low nibble, for each byte in order.
    out.extend(
        bytes
            .iter()
            .flat_map(|&byte| [byte >> 4, byte & 0x0f])
            .map(|nibble| char::from(DIGITS[usize::from(nibble)])),
    );
    out
}
/// Per-connection context.
///
/// `id`, `remote`, `local`, `transport` and `entered_at` are fixed when the
/// context is built with `ConnContext::new`. `tls`, `http_version` and `user`
/// are interior-mutable slots that start empty and can be filled in later
/// through a shared reference.
///
/// `#[non_exhaustive]` prevents code outside this crate from building one with
/// a struct literal; it must go through `ConnContext::new`.
// `Debug` is derivable because every field type implements it. parking_lot's
// `Mutex` formats itself via `try_lock`, so printing a context never blocks on
// a lock held elsewhere.
#[derive(Debug)]
#[non_exhaustive]
pub struct ConnContext {
    /// Identifier of this connection.
    pub id: ConnId,
    /// Socket address of the remote peer.
    pub remote: SocketAddr,
    /// Local socket address of the connection.
    pub local: SocketAddr,
    /// Transport protocol of the connection.
    pub transport: Transport,
    /// Instant supplied by the creator of the context; presumably when the
    /// connection was first seen — confirm against callers.
    pub entered_at: Instant,
    /// TLS details; `None` until something records them.
    pub tls: Mutex<Option<TlsInfo>>,
    /// HTTP version of the connection; a `OnceLock`, so it can be set at most once.
    pub http_version: OnceLock<HttpVersion>,
    /// Arbitrary typed per-connection data, usually accessed via `with_user`.
    pub user: Mutex<http::Extensions>,
}
impl ConnContext {
    /// Builds a context from the connection's fixed facts.
    ///
    /// The interior-mutable slots start empty: TLS info is `None`, the HTTP
    /// version is unset, and the user extension map is empty.
    #[must_use]
    pub fn new(
        id: ConnId,
        remote: SocketAddr,
        local: SocketAddr,
        transport: Transport,
        entered_at: Instant,
    ) -> Self {
        let tls = Mutex::new(None);
        let http_version = OnceLock::new();
        let user = Mutex::new(http::Extensions::new());
        Self { id, remote, local, transport, entered_at, tls, http_version, user }
    }

    /// Locks the TLS slot and returns the guard.
    ///
    /// The `tls` mutex stays locked until the guard is dropped, so keep it
    /// short-lived. Locking it again on the same thread while the guard is
    /// alive deadlocks, because parking_lot mutexes are not reentrant.
    pub fn tls(&self) -> parking_lot::MutexGuard<'_, Option<TlsInfo>> {
        self.tls.lock()
    }

    /// Runs `f` with exclusive access to the user extension map and returns
    /// whatever `f` returns.
    ///
    /// The `user` mutex is held for the whole call to `f` and is not
    /// reentrant, so `f` must not call `with_user` on the same context.
    pub fn with_user<R>(&self, f: impl FnOnce(&mut http::Extensions) -> R) -> R {
        f(&mut self.user.lock())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to JSON and back, asserting nothing was lost.
    fn assert_json_round_trip<T>(value: T)
    where
        T: serde::Serialize + serde::de::DeserializeOwned + PartialEq + std::fmt::Debug,
    {
        let encoded = serde_json::to_string(&value).expect("serialize");
        let decoded: T = serde_json::from_str(&encoded).expect("deserialize");
        assert_eq!(decoded, value);
    }

    /// Builds a TCP context on loopback addresses for the `ConnContext` tests.
    fn loopback_context() -> ConnContext {
        let remote: std::net::SocketAddr = "127.0.0.1:50000".parse().expect("remote addr");
        let local: std::net::SocketAddr = "127.0.0.1:443".parse().expect("local addr");
        ConnContext::new(ConnId(7), remote, local, Transport::Tcp, std::time::Instant::now())
    }

    #[test]
    fn conn_id_display_pads_zero_to_sixteen_hex_digits() {
        let rendered = format!("{}", ConnId(0));
        assert_eq!(rendered, "0000000000000000");
        assert_eq!(rendered.len(), 16);
    }

    #[test]
    fn conn_id_display_is_lowercase_hex() {
        let rendered = format!("{}", ConnId(0x0bad_f00d_dead_beef));
        assert_eq!(rendered, "0badf00ddeadbeef");
        assert!(rendered.chars().all(|c| c.is_ascii_digit() || ('a'..='f').contains(&c)));
    }

    #[test]
    fn conn_id_display_zero_pads_small_values() {
        let rendered = format!("{}", ConnId(1));
        assert_eq!(rendered, "0000000000000001");
    }

    #[test]
    fn conn_id_display_renders_u64_max() {
        let rendered = format!("{}", ConnId(u64::MAX));
        assert_eq!(rendered, "ffffffffffffffff");
        assert_eq!(rendered.len(), 16);
    }

    #[test]
    fn conn_id_serde_round_trip() {
        assert_json_round_trip(ConnId(0x1234_5678_9abc_def0));
    }

    #[test]
    fn tls_version_variants_are_exhaustive_at_two() {
        for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
            let matched = match v {
                TlsVersion::Tls12 => "1.2",
                TlsVersion::Tls13 => "1.3",
            };
            assert!(!matched.is_empty());
        }
    }

    #[test]
    fn tls_version_serde_round_trip_per_variant() {
        for v in [TlsVersion::Tls12, TlsVersion::Tls13] {
            assert_json_round_trip(v);
        }
    }

    #[test]
    fn transport_serde_round_trip_per_variant() {
        for t in [Transport::Tcp, Transport::Udp] {
            assert_json_round_trip(t);
        }
    }

    #[test]
    fn http_version_serde_round_trip_per_variant() {
        for v in [HttpVersion::Http1_0, HttpVersion::Http1_1, HttpVersion::Http2, HttpVersion::Http3] {
            assert_json_round_trip(v);
        }
    }

    #[test]
    fn hex_lower_of_empty_input_is_empty() {
        assert_eq!(hex_lower(&[]), "");
    }

    #[test]
    fn hex_lower_emits_two_lowercase_digits_per_byte() {
        assert_eq!(hex_lower(&[0x00, 0x0a, 0xf0, 0xff]), "000af0ff");
        let all: Vec<u8> = (0..=u8::MAX).collect();
        let rendered = hex_lower(&all);
        assert_eq!(rendered.len(), 512);
        for (byte, pair) in all.iter().zip(rendered.as_bytes().chunks(2)) {
            assert_eq!(pair, format!("{byte:02x}").as_bytes());
        }
    }

    #[test]
    fn peer_certificate_from_der_rejects_non_certificate_bytes() {
        let garbage = rustls_pki_types::CertificateDer::from(b"not a certificate".to_vec());
        assert!(PeerCertificate::from_der(&garbage).is_none());
        let empty = rustls_pki_types::CertificateDer::from(Vec::<u8>::new());
        assert!(PeerCertificate::from_der(&empty).is_none());
    }

    #[test]
    fn conn_context_new_starts_with_empty_slots() {
        let ctx = loopback_context();
        assert_eq!(ctx.id, ConnId(7));
        assert_eq!(ctx.transport, Transport::Tcp);
        assert!(ctx.tls().is_none());
        assert!(ctx.http_version.get().is_none());
        assert!(ctx.with_user(|ext| ext.get::<u32>().is_none()));
    }

    #[test]
    fn conn_context_user_extensions_persist_between_calls() {
        let ctx = loopback_context();
        ctx.with_user(|ext| {
            ext.insert(42_u32);
        });
        assert_eq!(ctx.with_user(|ext| ext.get::<u32>().copied()), Some(42));
    }

    #[test]
    fn conn_context_http_version_can_only_be_set_once() {
        let ctx = loopback_context();
        assert_eq!(ctx.http_version.set(HttpVersion::Http2), Ok(()));
        assert_eq!(ctx.http_version.set(HttpVersion::Http3), Err(HttpVersion::Http3));
        assert_eq!(ctx.http_version.get(), Some(&HttpVersion::Http2));
    }
}