Skip to main content

rustls_sni_resolver/
lib.rs

1//! A minimal `ResolvesServerCert` implementation backed by
2//! `{ by_sni: HashMap<String, Arc<E>>, default: Option<Arc<E>> }`,
3//! with the whole struct designed to live behind an `Arc<ArcSwap<_>>`
4//! so a config reload is one atomic pointer swap.
5//!
//! The entry type `E` is bounded by the [`EntryKey`] trait, so callers
//! can attach their own per-cert state (expiry timestamps, OCSP staple
//! handles, ACME order IDs) without forking the crate.
9
10use std::collections::HashMap;
11use std::sync::Arc;
12
13use arc_swap::ArcSwap;
14
15/// A trait satisfied by anything that carries a rustls
16/// `Arc<CertifiedKey>` (cert chain + signing key + optional OCSP
17/// staple). Used by [`CertStore::lookup`] to extract the handshake
18/// material from a caller-defined entry type.
19pub trait EntryKey {
20	fn key(&self) -> Arc<rustls::sign::CertifiedKey>;
21}
22
/// Per-listener cert pool: zero or more SNI-keyed entries plus an
/// optional SNI-less default.
///
/// The default is selected when a `ClientHello` carries no SNI extension,
/// or when its SNI does not match any [`Self::by_sni`] key. A listener has
/// at most one default.
///
/// # Key normalization
///
/// Keys in [`Self::by_sni`] must be ASCII-lowercase. rustls lowercases the
/// SNI host name before handing it to the resolver, because DNS names
/// compare case-insensitively. Resolver-side lookups are therefore a plain
/// byte-for-byte `HashMap` hit with no `eq_ignore_ascii_case` shim.
///
/// The fields are public, so this invariant is **not** enforced here. An
/// entry inserted under a key containing an uppercase byte will never be
/// selected, and handshakes for that name fall through to the default.
#[derive(Debug)]
pub struct CertStore<E> {
	/// SNI host name (must be ASCII-lowercase) → entry.
	pub by_sni: HashMap<String, Arc<E>>,
	/// Fallback entry for a missing or unmatched SNI; `None` means such
	/// handshakes resolve no certificate.
	pub default: Option<Arc<E>>,
}
37
38impl<E: EntryKey> CertStore<E> {
39	#[must_use]
40	pub fn new() -> Self {
41		Self { by_sni: HashMap::new(), default: None }
42	}
43
44	/// Resolve a `ClientHello`'s SNI against the store. The hot-path
45	/// resolver delegates to this so unit tests can exercise the
46	/// lookup without constructing a `rustls::ClientHello` (which is
47	/// not user-constructible). `sni` is expected to already be
48	/// ASCII-lowercased by rustls per RFC 6066 § 3.
49	#[must_use]
50	pub fn lookup(&self, sni: Option<&str>) -> Option<Arc<rustls::sign::CertifiedKey>> {
51		if let Some(name) = sni
52			&& let Some(entry) = self.by_sni.get(name)
53		{
54			return Some(entry.key());
55		}
56		self.default.as_ref().map(|d| d.key())
57	}
58}
59
60impl<E: EntryKey> Default for CertStore<E> {
61	fn default() -> Self {
62		Self::new()
63	}
64}
65
66/// `rustls::server::ResolvesServerCert` implementation backed by an
67/// `ArcSwap<CertStore<E>>`. Reads the current store on every
68/// handshake — a populator-driven swap is observed by the next
69/// `ClientHello`, never mid-connection (TLS does not permit that).
70///
71/// We do **not** delegate to rustls's built-in
72/// `rustls::server::ResolvesServerCertUsingSni` because it returns
73/// `None` (handshake failure) on unmatched SNI with no built-in
74/// fallback hook; this resolver uses [`CertStore::default`] as the
75/// explicit no-SNI fallback.
76#[derive(Debug)]
77pub struct Resolver<E: EntryKey> {
78	store: Arc<ArcSwap<CertStore<E>>>,
79}
80
81impl<E: EntryKey> Resolver<E> {
82	#[must_use]
83	pub fn new(store: Arc<ArcSwap<CertStore<E>>>) -> Self {
84		Self { store }
85	}
86}
87
88impl<E: EntryKey + std::fmt::Debug + Send + Sync + 'static> rustls::server::ResolvesServerCert
89	for Resolver<E>
90{
91	fn resolve(
92		&self,
93		hello: rustls::server::ClientHello<'_>,
94	) -> Option<Arc<rustls::sign::CertifiedKey>> {
95		// `server_name()` is already ASCII-lowercased by rustls per
96		// RFC 6066 § 3, so a direct map lookup suffices.
97		self.store.load().lookup(hello.server_name())
98	}
99}
100
#[cfg(test)]
mod tests {
	use super::*;
	use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
	use rustls::sign::CertifiedKey;

	/// Minimal `EntryKey` implementor that carries only the rustls key.
	#[derive(Debug)]
	struct TestEntry {
		key: Arc<CertifiedKey>,
	}

	impl EntryKey for TestEntry {
		fn key(&self) -> Arc<CertifiedKey> {
			Arc::clone(&self.key)
		}
	}

	/// Installs aws-lc-rs as the process-wide default crypto provider.
	/// Tests call this repeatedly (and possibly in parallel), so a repeat
	/// install failing is expected and its result is deliberately ignored.
	fn install_crypto() {
		let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
	}

	/// Builds an entry holding a fresh self-signed cert for `host`.
	fn make_entry(host: &str) -> Arc<TestEntry> {
		install_crypto();
		let issued =
			rcgen::generate_simple_self_signed(vec![host.to_owned()]).expect("self-signed cert");
		let cert_der = CertificateDer::from(issued.cert.der().to_vec());
		let key_der = PrivatePkcs8KeyDer::from(issued.signing_key.serialize_der());
		let signing = rustls::crypto::CryptoProvider::get_default()
			.expect("crypto provider")
			.key_provider
			.load_private_key(rustls::pki_types::PrivateKeyDer::Pkcs8(key_der))
			.expect("load_private_key");
		let key = Arc::new(CertifiedKey::new(vec![cert_der], signing));
		Arc::new(TestEntry { key })
	}

	#[test]
	fn lookup_hit_returns_keyed_entry() {
		let entry = make_entry("api.example.com");
		let mut store: CertStore<TestEntry> = CertStore::new();
		store.by_sni.insert("api.example.com".to_owned(), Arc::clone(&entry));
		let got = store.lookup(Some("api.example.com")).expect("hit");
		assert!(Arc::ptr_eq(&got, &entry.key));
	}

	#[test]
	fn lookup_hit_takes_precedence_over_default() {
		let api = make_entry("api.example.com");
		let default = make_entry("default.example.com");
		let mut store: CertStore<TestEntry> = CertStore::new();
		store.by_sni.insert("api.example.com".to_owned(), Arc::clone(&api));
		store.default = Some(Arc::clone(&default));
		let got = store.lookup(Some("api.example.com")).expect("hit");
		assert!(Arc::ptr_eq(&got, &api.key));
		assert!(!Arc::ptr_eq(&got, &default.key));
	}

	#[test]
	fn lookup_miss_falls_back_to_default() {
		let api = make_entry("api.example.com");
		let default = make_entry("default.example.com");
		let mut store: CertStore<TestEntry> = CertStore::new();
		store.by_sni.insert("api.example.com".to_owned(), api);
		store.default = Some(Arc::clone(&default));
		let got = store.lookup(Some("unknown.example.com")).expect("default fires");
		assert!(Arc::ptr_eq(&got, &default.key));
	}

	#[test]
	fn lookup_miss_with_no_default_returns_none() {
		let api = make_entry("api.example.com");
		let mut store: CertStore<TestEntry> = CertStore::new();
		store.by_sni.insert("api.example.com".to_owned(), api);
		assert!(store.lookup(Some("unknown.example.com")).is_none());
		assert!(store.lookup(None).is_none());
	}

	/// Pins the documented contract: matching is byte-for-byte, and
	/// lowercasing is the caller's (rustls's) job, not `lookup`'s.
	#[test]
	fn lookup_is_byte_for_byte_case_sensitive() {
		let api = make_entry("api.example.com");
		let mut store: CertStore<TestEntry> = CertStore::new();
		store.by_sni.insert("api.example.com".to_owned(), api);
		assert!(store.lookup(Some("API.EXAMPLE.COM")).is_none());
	}

	#[test]
	fn lookup_no_sni_uses_default() {
		let default = make_entry("default.example.com");
		let mut store: CertStore<TestEntry> = CertStore::new();
		store.default = Some(Arc::clone(&default));
		let got = store.lookup(None).expect("default fires");
		assert!(Arc::ptr_eq(&got, &default.key));
	}

	#[test]
	fn arcswap_store_visible_to_subsequent_lookup() {
		let api = make_entry("api.example.com");
		let mut initial: CertStore<TestEntry> = CertStore::new();
		initial.by_sni.insert("api.example.com".to_owned(), Arc::clone(&api));
		let arcswap = Arc::new(ArcSwap::from_pointee(initial));

		assert!(Arc::ptr_eq(&arcswap.load().lookup(Some("api.example.com")).expect("hit"), &api.key));

		let admin = make_entry("admin.example.com");
		let mut fresh: CertStore<TestEntry> = CertStore::new();
		fresh.by_sni.insert("admin.example.com".to_owned(), Arc::clone(&admin));
		arcswap.store(Arc::new(fresh));

		assert!(arcswap.load().lookup(Some("api.example.com")).is_none());
		assert!(Arc::ptr_eq(
			&arcswap.load().lookup(Some("admin.example.com")).expect("hit fresh"),
			&admin.key
		));
	}

	#[test]
	fn resolver_constructible_from_arcswap() {
		// Resolver::resolve takes a `rustls::ClientHello`, which has
		// no public constructor; downstream e2e tests exercise the
		// live SNI path. Here we cover construction and trait wiring.
		let store: Arc<ArcSwap<CertStore<TestEntry>>> =
			Arc::new(ArcSwap::from_pointee(CertStore::new()));
		let _resolver: Arc<dyn rustls::server::ResolvesServerCert> = Arc::new(Resolver::new(store));
	}
}