Skip to main content

jwks_cache/
security.rs

1//! Security utilities covering HTTPS enforcement, domain allowlists, and SPKI pinning.
2//!
3//! # Threat Model
4//! These helpers assume upstream TLS validation has already succeeded and focus on defending the
5//! cache pipeline against downgrade attempts (HTTP redirects), host header confusion, and
6//! certificate substitution by validating SPKI fingerprints.
7
8// std
9use std::{
10	collections::HashSet,
11	fmt::{Debug, Formatter, Result as FmtResult},
12};
13// crates.io
14use base64::prelude::*;
15use serde::{Deserialize, Serialize, de::Deserializer};
16use sha2::{Digest, Sha256};
17use url::Url;
18// self
19use crate::_prelude::*;
20
21/// SHA-256 fingerprint of a Subject Public Key Info (SPKI) structure.
22#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
23#[serde(try_from = "String", into = "String")]
24pub struct SpkiFingerprint {
25	bytes: Arc<[u8; 32]>,
26}
27impl SpkiFingerprint {
28	/// Decode a base64(fp) value into a fingerprint.
29	pub fn from_b64(value: &str) -> Result<Self> {
30		let cleaned = value.trim();
31		let decoded = BASE64_STANDARD
32			.decode(cleaned)
33			.or_else(|_| BASE64_URL_SAFE_NO_PAD.decode(cleaned))
34			.map_err(|err| Error::Validation {
35				field: "pinned_spki",
36				reason: format!("Invalid base64 fingerprint: {err}."),
37			})?;
38
39		if decoded.len() != 32 {
40			return Err(Error::Validation {
41				field: "pinned_spki",
42				reason: "Fingerprint must decode to 32 bytes (SHA-256).".into(),
43			});
44		}
45
46		let mut bytes = [0u8; 32];
47
48		bytes.copy_from_slice(&decoded);
49
50		Ok(Self { bytes: Arc::new(bytes) })
51	}
52
53	/// Raw fingerprint bytes.
54	pub fn as_bytes(&self) -> &[u8; 32] {
55		self.bytes.as_ref()
56	}
57}
58impl Debug for SpkiFingerprint {
59	fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
60		write!(f, "SpkiFingerprint({})", BASE64_STANDARD.encode(self.bytes.as_ref()))
61	}
62}
63impl TryFrom<String> for SpkiFingerprint {
64	type Error = Error;
65
66	fn try_from(value: String) -> Result<Self> {
67		Self::from_b64(&value)
68	}
69}
70impl From<SpkiFingerprint> for String {
71	fn from(value: SpkiFingerprint) -> Self {
72		BASE64_STANDARD.encode(value.bytes.as_ref())
73	}
74}
75
/// Canonicalise a DNS name.
///
/// Surrounding whitespace is trimmed first, then every trailing `.` is removed, and finally
/// ASCII letters are lowercased. Returns `None` when nothing is left after trimming (for
/// example for `""`, `"   "` or `"."`).
pub fn canonicalize_dns_name(value: &str) -> Option<String> {
	// Whitespace goes before the dots, so a name such as "a. " still loses its trailing dot.
	let name = value.trim().trim_end_matches('.');

	if name.is_empty() { None } else { Some(name.to_ascii_lowercase()) }
}
90
91/// Normalise an allowlist by canonicalising entries and removing duplicates/empties.
92pub fn normalize_allowlist(domains: Vec<String>) -> Vec<String> {
93	let mut seen = HashSet::new();
94	let mut normalized = Vec::with_capacity(domains.len());
95
96	for domain in domains {
97		if let Some(canonical) = canonicalize_dns_name(&domain)
98			&& seen.insert(canonical.clone())
99		{
100			normalized.push(canonical);
101		}
102	}
103
104	normalized
105}
106
107/// `serde` helper to normalise allowlist domains during deserialisation.
108pub fn deserialize_allowed_domains<'de, D>(
109	deserializer: D,
110) -> std::result::Result<Vec<String>, D::Error>
111where
112	D: Deserializer<'de>,
113{
114	let raw = Vec::<String>::deserialize(deserializer)?;
115	Ok(normalize_allowlist(raw))
116}
117
118/// Ensure the provided URL uses HTTPS.
119pub fn enforce_https(url: &Url) -> Result<()> {
120	if url.scheme() == "https" {
121		Ok(())
122	} else {
123		Err(Error::Security(format!("Upstream URL {url} must use HTTPS.")))
124	}
125}
126
/// Suffix match of `host` against a single allowlist `domain`.
///
/// True when `host` equals `domain`, or when `host` ends with `domain` and the text in front
/// of it ends on a `.` label boundary (so `evilexample.com` does not match `example.com`).
/// No canonicalisation happens here; the comparison is byte-exact.
#[inline]
fn matches_allowlist(host: &str, domain: &str) -> bool {
	match host.strip_suffix(domain) {
		// Nothing in front of the domain: exact match.
		Some("") => true,
		// Something in front: it must end with the label separator.
		Some(prefix) => prefix.ends_with('.'),
		None => false,
	}
}
135
/// Cheap check for whether `domain` is already in canonical form.
///
/// Canonical here means: non-empty, no trailing `.`, no leading or trailing whitespace, and
/// no ASCII uppercase letters.
fn is_canonical_allowlist_entry(domain: &str) -> bool {
	if domain.is_empty() || domain.ends_with('.') {
		return false;
	}
	// `trim` only ever shortens the string, so any difference means surrounding whitespace.
	if domain.trim() != domain {
		return false;
	}

	// Bytes of multi-byte UTF-8 characters are never in the ASCII range, so scanning bytes
	// gives the same answer as scanning chars.
	!domain.bytes().any(|byte| byte.is_ascii_uppercase())
}
142
143/// Evaluate whether the given hostname is allowed by the provided suffix allowlist.
144///
145/// When the list is empty, all hosts are considered valid.
146pub fn host_is_allowed(host: &str, allowed_domains: &[String]) -> bool {
147	if allowed_domains.is_empty() {
148		return true;
149	}
150
151	let Some(host) = canonicalize_dns_name(host) else {
152		return false;
153	};
154
155	allowed_domains.iter().any(|domain| {
156		if is_canonical_allowlist_entry(domain) {
157			matches_allowlist(&host, domain)
158		} else if let Some(canonical) = canonicalize_dns_name(domain) {
159			matches_allowlist(&host, &canonical)
160		} else {
161			false
162		}
163	})
164}
165
166/// Compute the SHA-256 fingerprint of a DER-encoded SPKI payload.
167pub fn fingerprint_spki(spki_der: &[u8]) -> [u8; 32] {
168	let digest = Sha256::digest(spki_der);
169	let mut bytes = [0u8; 32];
170
171	bytes.copy_from_slice(&digest);
172
173	bytes
174}
175
176/// Validate that at least one configured SPKI fingerprint matches the presented SPKI set.
177///
178/// The iterator should provide DER-encoded SPKI payloads extracted from the TLS peer certificates.
179pub fn verify_spki_pins<'a, I>(present_spki: I, pins: &[SpkiFingerprint]) -> Result<()>
180where
181	I: IntoIterator<Item = &'a [u8]>,
182{
183	if pins.is_empty() {
184		return Ok(());
185	}
186
187	let mut presented_fingerprints = Vec::new();
188
189	for spki in present_spki {
190		let fingerprint = fingerprint_spki(spki);
191		if pins.iter().any(|pin| pin.as_bytes() == &fingerprint) {
192			return Ok(());
193		}
194		if tracing::enabled!(tracing::Level::WARN) {
195			presented_fingerprints.push(BASE64_STANDARD.encode(fingerprint));
196		}
197	}
198
199	if tracing::enabled!(tracing::Level::WARN) {
200		let expected: Vec<String> =
201			pins.iter().map(|pin| BASE64_STANDARD.encode(pin.as_bytes())).collect();
202		tracing::warn!(
203			expected = ?expected,
204			presented = ?presented_fingerprints,
205			"SPKI pin verification failed — no fingerprints matched",
206		);
207	}
208
209	Err(Error::Security(
210		"Presented certificate chain does not match any configured SPKI pins.".into(),
211	))
212}
213
#[cfg(test)]
mod tests {
	use super::*;
	use url::Url;

	/// Padded standard and unpadded URL-safe encodings of one digest decode to the same bytes.
	#[test]
	fn base64_variants_are_accepted() {
		let raw = [42u8; 32];
		let encodings = [BASE64_STANDARD.encode(raw), BASE64_URL_SAFE_NO_PAD.encode(raw)];

		for text in &encodings {
			let parsed = SpkiFingerprint::from_b64(text).expect("valid fingerprint");

			assert_eq!(parsed.as_bytes(), &raw);
		}
	}

	/// "AQID" is valid base64 but decodes to three bytes, not 32.
	#[test]
	fn base64_length_error_is_reported() {
		assert!(SpkiFingerprint::from_b64("AQID").is_err());
	}

	/// Matching ignores case and trailing dots; an empty allowlist permits everything.
	#[test]
	fn host_allowlist_handles_case_and_trailing_dot() {
		let allowlist = normalize_allowlist(vec![String::from("Example.COM.")]);

		for permitted in ["api.EXAMPLE.com.", "example.com."] {
			assert!(host_is_allowed(permitted, &allowlist));
		}
		assert!(!host_is_allowed("other.org", &allowlist));

		let unrestricted: Vec<String> = Vec::new();

		assert!(host_is_allowed("anything.example", &unrestricted));
	}

	/// A pinned payload verifies; an unrelated payload is rejected.
	#[test]
	fn verify_spki_pins_success_and_failure() {
		let trusted: &[u8] = b"primary";
		let untrusted: &[u8] = b"other";
		let encoded_pin = BASE64_STANDARD.encode(fingerprint_spki(trusted));
		let pins = vec![SpkiFingerprint::from_b64(&encoded_pin).unwrap()];

		assert!(verify_spki_pins([trusted], &pins).is_ok());
		assert!(verify_spki_pins([untrusted], &pins).is_err());
	}

	/// Plain `http` URLs are refused.
	#[test]
	fn enforce_https_rejects_insecure_scheme() {
		let insecure = Url::parse("http://example.com/jwks").unwrap();

		assert!(enforce_https(&insecure).is_err());
	}
}