1use std::{
10 collections::HashSet,
11 fmt::{Debug, Formatter, Result as FmtResult},
12};
13use base64::prelude::*;
15use serde::{Deserialize, Serialize, de::Deserializer};
16use sha2::{Digest, Sha256};
17use url::Url;
18use crate::_prelude::*;
20
21#[derive(Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
23#[serde(try_from = "String", into = "String")]
24pub struct SpkiFingerprint {
25 bytes: Arc<[u8; 32]>,
26}
27impl SpkiFingerprint {
28 pub fn from_b64(value: &str) -> Result<Self> {
30 let cleaned = value.trim();
31 let decoded = BASE64_STANDARD
32 .decode(cleaned)
33 .or_else(|_| BASE64_URL_SAFE_NO_PAD.decode(cleaned))
34 .map_err(|err| Error::Validation {
35 field: "pinned_spki",
36 reason: format!("Invalid base64 fingerprint: {err}."),
37 })?;
38
39 if decoded.len() != 32 {
40 return Err(Error::Validation {
41 field: "pinned_spki",
42 reason: "Fingerprint must decode to 32 bytes (SHA-256).".into(),
43 });
44 }
45
46 let mut bytes = [0u8; 32];
47
48 bytes.copy_from_slice(&decoded);
49
50 Ok(Self { bytes: Arc::new(bytes) })
51 }
52
53 pub fn as_bytes(&self) -> &[u8; 32] {
55 self.bytes.as_ref()
56 }
57}
58impl Debug for SpkiFingerprint {
59 fn fmt(&self, f: &mut Formatter<'_>) -> FmtResult {
60 write!(f, "SpkiFingerprint({})", BASE64_STANDARD.encode(self.bytes.as_ref()))
61 }
62}
63impl TryFrom<String> for SpkiFingerprint {
64 type Error = Error;
65
66 fn try_from(value: String) -> Result<Self> {
67 Self::from_b64(&value)
68 }
69}
70impl From<SpkiFingerprint> for String {
71 fn from(value: SpkiFingerprint) -> Self {
72 BASE64_STANDARD.encode(value.bytes.as_ref())
73 }
74}
75
/// Canonicalizes a DNS name for comparison.
///
/// Surrounding whitespace is trimmed, every trailing `.` is removed, and ASCII
/// letters are lowercased. Returns `None` when nothing is left after trimming.
pub fn canonicalize_dns_name(value: &str) -> Option<String> {
    // An all-whitespace or all-dots input collapses to the empty string here.
    let name = value.trim().trim_end_matches('.');

    (!name.is_empty()).then(|| name.to_ascii_lowercase())
}
90
91pub fn normalize_allowlist(domains: Vec<String>) -> Vec<String> {
93 let mut seen = HashSet::new();
94 let mut normalized = Vec::with_capacity(domains.len());
95
96 for domain in domains {
97 if let Some(canonical) = canonicalize_dns_name(&domain)
98 && seen.insert(canonical.clone())
99 {
100 normalized.push(canonical);
101 }
102 }
103
104 normalized
105}
106
107pub fn deserialize_allowed_domains<'de, D>(
109 deserializer: D,
110) -> std::result::Result<Vec<String>, D::Error>
111where
112 D: Deserializer<'de>,
113{
114 let raw = Vec::<String>::deserialize(deserializer)?;
115 Ok(normalize_allowlist(raw))
116}
117
118pub fn enforce_https(url: &Url) -> Result<()> {
120 if url.scheme() == "https" {
121 Ok(())
122 } else {
123 Err(Error::Security(format!("Upstream URL {url} must use HTTPS.")))
124 }
125}
126
/// Returns `true` when `host` equals `domain` or ends with `.` followed by
/// `domain`.
///
/// The comparison is byte-exact; no case folding or trimming happens here.
#[inline]
fn matches_allowlist(host: &str, domain: &str) -> bool {
    match host.strip_suffix(domain) {
        // Nothing left in front of the suffix: the two names are identical.
        Some("") => true,
        // Something is left: it must end at a label boundary.
        Some(prefix) => prefix.ends_with('.'),
        None => false,
    }
}
135
/// Reports whether `domain` is non-empty, has no trailing `.`, has no leading
/// or trailing whitespace, and contains no ASCII uppercase letters.
fn is_canonical_allowlist_entry(domain: &str) -> bool {
    if domain.is_empty() || domain.ends_with('.') {
        return false;
    }

    let has_outer_whitespace = domain.trim().len() != domain.len();
    // ASCII uppercase bytes never occur inside a multi-byte UTF-8 sequence, so
    // scanning bytes is equivalent to scanning chars.
    let has_ascii_uppercase = domain.bytes().any(|byte| byte.is_ascii_uppercase());

    !(has_outer_whitespace || has_ascii_uppercase)
}
142
143pub fn host_is_allowed(host: &str, allowed_domains: &[String]) -> bool {
147 if allowed_domains.is_empty() {
148 return true;
149 }
150
151 let Some(host) = canonicalize_dns_name(host) else {
152 return false;
153 };
154
155 allowed_domains.iter().any(|domain| {
156 if is_canonical_allowlist_entry(domain) {
157 matches_allowlist(&host, domain)
158 } else if let Some(canonical) = canonicalize_dns_name(domain) {
159 matches_allowlist(&host, &canonical)
160 } else {
161 false
162 }
163 })
164}
165
166pub fn fingerprint_spki(spki_der: &[u8]) -> [u8; 32] {
168 let digest = Sha256::digest(spki_der);
169 let mut bytes = [0u8; 32];
170
171 bytes.copy_from_slice(&digest);
172
173 bytes
174}
175
176pub fn verify_spki_pins<'a, I>(present_spki: I, pins: &[SpkiFingerprint]) -> Result<()>
180where
181 I: IntoIterator<Item = &'a [u8]>,
182{
183 if pins.is_empty() {
184 return Ok(());
185 }
186
187 let mut presented_fingerprints = Vec::new();
188
189 for spki in present_spki {
190 let fingerprint = fingerprint_spki(spki);
191 if pins.iter().any(|pin| pin.as_bytes() == &fingerprint) {
192 return Ok(());
193 }
194 if tracing::enabled!(tracing::Level::WARN) {
195 presented_fingerprints.push(BASE64_STANDARD.encode(fingerprint));
196 }
197 }
198
199 if tracing::enabled!(tracing::Level::WARN) {
200 let expected: Vec<String> =
201 pins.iter().map(|pin| BASE64_STANDARD.encode(pin.as_bytes())).collect();
202 tracing::warn!(
203 expected = ?expected,
204 presented = ?presented_fingerprints,
205 "SPKI pin verification failed — no fingerprints matched",
206 );
207 }
208
209 Err(Error::Security(
210 "Presented certificate chain does not match any configured SPKI pins.".into(),
211 ))
212}
213
#[cfg(test)]
mod tests {
    use super::*;
    use url::Url;

    /// Padded standard and unpadded URL-safe encodings both parse to the same bytes.
    #[test]
    fn base64_variants_are_accepted() {
        let raw = [42u8; 32];
        let encodings = [BASE64_STANDARD.encode(raw), BASE64_URL_SAFE_NO_PAD.encode(raw)];

        for text in &encodings {
            let parsed = SpkiFingerprint::from_b64(text).expect("valid fingerprint");
            assert_eq!(parsed.as_bytes(), &raw);
        }
    }

    /// "AQID" is valid base64 but decodes to far fewer than 32 bytes.
    #[test]
    fn base64_length_error_is_reported() {
        assert!(SpkiFingerprint::from_b64("AQID").is_err());
    }

    /// Case and trailing dots are ignored on both sides; an empty allowlist permits all.
    #[test]
    fn host_allowlist_handles_case_and_trailing_dot() {
        let allowlist = normalize_allowlist(vec![String::from("Example.COM.")]);

        assert!(host_is_allowed("api.EXAMPLE.com.", &allowlist));
        assert!(host_is_allowed("example.com.", &allowlist));
        assert!(!host_is_allowed("other.org", &allowlist));

        let no_restrictions: Vec<String> = vec![];
        assert!(host_is_allowed("anything.example", &no_restrictions));
    }

    /// A chain containing the pinned key passes; a chain without it is rejected.
    #[test]
    fn verify_spki_pins_success_and_failure() {
        let pinned_key: &[u8] = b"primary";
        let unrelated_key: &[u8] = b"other";

        let encoded_pin = BASE64_STANDARD.encode(fingerprint_spki(pinned_key));
        let pins = vec![SpkiFingerprint::from_b64(&encoded_pin).unwrap()];

        assert!(verify_spki_pins([pinned_key], &pins).is_ok());
        assert!(verify_spki_pins([unrelated_key], &pins).is_err());
    }

    /// Plain `http` upstream URLs are refused.
    #[test]
    fn enforce_https_rejects_insecure_scheme() {
        let insecure = Url::parse("http://example.com/jwks").unwrap();
        assert!(enforce_https(&insecure).is_err());
    }
}