rustls_sni_resolver/
lib.rs1use std::collections::HashMap;
11use std::sync::Arc;
12
13use arc_swap::ArcSwap;
14
15pub trait EntryKey {
20 fn key(&self) -> Arc<rustls::sign::CertifiedKey>;
21}
22
23#[derive(Debug)]
33pub struct CertStore<E: EntryKey> {
34 pub by_sni: HashMap<String, Arc<E>>,
35 pub default: Option<Arc<E>>,
36}
37
38impl<E: EntryKey> CertStore<E> {
39 #[must_use]
40 pub fn new() -> Self {
41 Self { by_sni: HashMap::new(), default: None }
42 }
43
44 #[must_use]
50 pub fn lookup(&self, sni: Option<&str>) -> Option<Arc<rustls::sign::CertifiedKey>> {
51 if let Some(name) = sni
52 && let Some(entry) = self.by_sni.get(name)
53 {
54 return Some(entry.key());
55 }
56 self.default.as_ref().map(|d| d.key())
57 }
58}
59
60impl<E: EntryKey> Default for CertStore<E> {
61 fn default() -> Self {
62 Self::new()
63 }
64}
65
66#[derive(Debug)]
77pub struct Resolver<E: EntryKey> {
78 store: Arc<ArcSwap<CertStore<E>>>,
79}
80
81impl<E: EntryKey> Resolver<E> {
82 #[must_use]
83 pub fn new(store: Arc<ArcSwap<CertStore<E>>>) -> Self {
84 Self { store }
85 }
86}
87
88impl<E: EntryKey + std::fmt::Debug + Send + Sync + 'static> rustls::server::ResolvesServerCert
89 for Resolver<E>
90{
91 fn resolve(
92 &self,
93 hello: rustls::server::ClientHello<'_>,
94 ) -> Option<Arc<rustls::sign::CertifiedKey>> {
95 self.store.load().lookup(hello.server_name())
98 }
99}
100
#[cfg(test)]
mod tests {
    use super::*;
    use rustls::pki_types::{CertificateDer, PrivatePkcs8KeyDer};
    use rustls::sign::CertifiedKey;

    /// Minimal `EntryKey` implementation wrapping a shared certified key.
    #[derive(Debug)]
    struct TestEntry {
        key: Arc<CertifiedKey>,
    }

    impl EntryKey for TestEntry {
        fn key(&self) -> Arc<CertifiedKey> {
            Arc::clone(&self.key)
        }
    }

    /// Installs aws-lc-rs as the process-wide default crypto provider.
    ///
    /// The result is deliberately discarded so repeated calls (one per
    /// `make_entry`) do not fail once a provider is already installed.
    fn install_crypto() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
    }

    /// Generates a self-signed certificate for `host` and wraps it in an entry.
    fn make_entry(host: &str) -> Arc<TestEntry> {
        install_crypto();
        let generated =
            rcgen::generate_simple_self_signed(vec![host.to_owned()]).expect("self-signed cert");
        let provider = rustls::crypto::CryptoProvider::get_default().expect("crypto provider");
        let pkcs8 = PrivatePkcs8KeyDer::from(generated.signing_key.serialize_der());
        let signer = provider
            .key_provider
            .load_private_key(rustls::pki_types::PrivateKeyDer::Pkcs8(pkcs8))
            .expect("load_private_key");
        let chain = vec![CertificateDer::from(generated.cert.der().to_vec())];
        Arc::new(TestEntry {
            key: Arc::new(CertifiedKey::new(chain, signer)),
        })
    }

    /// Builds a store whose SNI map holds `entries` and whose default is `default`.
    fn store_of(
        entries: &[(&str, &Arc<TestEntry>)],
        default: Option<&Arc<TestEntry>>,
    ) -> CertStore<TestEntry> {
        let mut store = CertStore::new();
        for &(name, entry) in entries {
            store.by_sni.insert(name.to_owned(), Arc::clone(entry));
        }
        store.default = default.map(Arc::clone);
        store
    }

    #[test]
    fn lookup_hit_returns_keyed_entry() {
        let entry = make_entry("api.example.com");
        let store = store_of(&[("api.example.com", &entry)], None);
        let resolved = store.lookup(Some("api.example.com")).expect("hit");
        assert!(Arc::ptr_eq(&resolved, &entry.key));
    }

    #[test]
    fn lookup_miss_falls_back_to_default() {
        let api = make_entry("api.example.com");
        let fallback = make_entry("default.example.com");
        let store = store_of(&[("api.example.com", &api)], Some(&fallback));
        let resolved = store.lookup(Some("unknown.example.com")).expect("default fires");
        assert!(Arc::ptr_eq(&resolved, &fallback.key));
    }

    #[test]
    fn lookup_miss_with_no_default_returns_none() {
        let api = make_entry("api.example.com");
        let store = store_of(&[("api.example.com", &api)], None);
        for sni in [Some("unknown.example.com"), None] {
            assert!(store.lookup(sni).is_none());
        }
    }

    #[test]
    fn lookup_no_sni_uses_default() {
        let fallback = make_entry("default.example.com");
        let store = store_of(&[], Some(&fallback));
        let resolved = store.lookup(None).expect("default fires");
        assert!(Arc::ptr_eq(&resolved, &fallback.key));
    }

    #[test]
    fn arcswap_store_visible_to_subsequent_lookup() {
        let api = make_entry("api.example.com");
        let swap = Arc::new(ArcSwap::from_pointee(store_of(
            &[("api.example.com", &api)],
            None,
        )));

        let before = swap.load().lookup(Some("api.example.com")).expect("hit");
        assert!(Arc::ptr_eq(&before, &api.key));

        // Publish a completely different store; the old name must disappear.
        let admin = make_entry("admin.example.com");
        swap.store(Arc::new(store_of(&[("admin.example.com", &admin)], None)));

        let snapshot = swap.load();
        assert!(snapshot.lookup(Some("api.example.com")).is_none());
        let after = snapshot
            .lookup(Some("admin.example.com"))
            .expect("hit fresh");
        assert!(Arc::ptr_eq(&after, &admin.key));
    }

    #[test]
    fn resolver_constructible_from_arcswap() {
        // Compile-time check: the resolver coerces to the rustls trait object.
        let shared = Arc::new(ArcSwap::from_pointee(CertStore::<TestEntry>::new()));
        let _resolver: Arc<dyn rustls::server::ResolvesServerCert> =
            Arc::new(Resolver::new(shared));
    }
}