Skip to main content

sozu_command_lib/
certificate.rs

1use std::{fmt, str::FromStr};
2
3use hex::{FromHex, FromHexError};
4use serde::de::{self, Visitor};
5use sha2::{Digest, Sha256};
6use x509_parser::{
7    certificate::X509Certificate,
8    extensions::{GeneralName, ParsedExtension},
9    oid_registry::{OID_X509_COMMON_NAME, OID_X509_EXT_SUBJECT_ALT_NAME},
10    parse_x509_certificate,
11    pem::{Pem, parse_x509_pem},
12};
13
14use crate::{
15    config::{Config, ConfigError},
16    proto::command::{CertificateAndKey, TlsVersion},
17};
18
/// Length in bytes of a SHA-256 digest (256 bits / 8).
///
/// Every fingerprint Sōzu *computes* is a SHA-256 hash and therefore exactly
/// this long. Fingerprints *parsed* from operator-supplied hex may have any
/// length and are never checked against this value.
///
/// Deliberately not gated behind `#[cfg(debug_assertions)]`: `debug_assert!`
/// still type-checks its arguments in release builds (it only skips running
/// them), so a cfg-gated constant referenced from one would make the release
/// build fail with E0425. In release the constant has no runtime effect.
#[allow(dead_code)]
const SHA256_FINGERPRINT_LEN: usize = 256 / 8;
30
31// -----------------------------------------------------------------------------
32// CertificateError
33
/// Errors raised while parsing, fingerprinting or loading certificates and
/// their related settings.
#[derive(thiserror::Error, Debug)]
pub enum CertificateError {
    /// The input bytes could not be decoded as a PEM block (see `parse_pem`).
    #[error("Could not parse PEM certificate from bytes: {0}")]
    ParsePEMCertificate(String),
    /// The DER payload could not be parsed as an X.509 certificate
    /// (see `parse_x509`).
    #[error("Could not parse X509 certificate from bytes: {0}")]
    ParseX509Certificate(String),
    /// A TLS version name was not recognised by `TlsVersion::from_str`.
    #[error("failed to parse tls version '{0}'")]
    InvalidTlsVersion(String),
    /// Hex decoding failed in `Fingerprint::from_str`.
    #[error("failed to parse fingerprint, {0}")]
    InvalidFingerprint(FromHexError),
    /// A certificate, chain or key file could not be read; `path` names the
    /// file and `error` carries the underlying configuration error.
    #[error("could not load file on path {path}: {error}")]
    LoadFile { path: String, error: ConfigError },
    /// Hex decoding failed in `decode_fingerprint`.
    // NOTE(review): the message says "certificate" although the decoded input
    // is a fingerprint; left as-is because it is user-visible output.
    #[error("Failed at decoding the hex encoded certificate: {0}")]
    DecodeError(FromHexError),
}
49
50// -----------------------------------------------------------------------------
51// parse
52
53/// parse a pem file encoded as binary and convert it into the right structure
54/// (a.k.a [`Pem`])
55pub fn parse_pem(certificate: &[u8]) -> Result<Pem, CertificateError> {
56    let (_, pem) = parse_x509_pem(certificate)
57        .map_err(|err| CertificateError::ParsePEMCertificate(err.to_string()))?;
58
59    Ok(pem)
60}
61
62/// parse x509 certificate from PEM bytes
63pub fn parse_x509(pem_bytes: &[u8]) -> Result<X509Certificate<'_>, CertificateError> {
64    parse_x509_certificate(pem_bytes)
65        .map_err(|nom_e| CertificateError::ParseX509Certificate(nom_e.to_string()))
66        .map(|t| t.1)
67}
68
69// -----------------------------------------------------------------------------
70// get_cn_and_san_attributes
71
72/// Retrieve the certificate's authoritative DNS identities for routing.
73///
74/// Per RFC 6125 §6.4.4: when the SubjectAlternativeName extension contains
75/// at least one `dNSName` entry, the SAN entries are the sole authoritative
76/// identities and the Common Name is ignored. The CN is only honoured as a
77/// fallback when the certificate omits the SAN extension entirely or
78/// declares it without a `dNSName` (e.g. SAN with only `iPAddress` /
79/// `rfc822Name` / `directoryName` entries — uncommon, but legal).
80///
81/// Aligns Sōzu's coalescing trust boundary with browser implementations
82/// (Firefox / Chrome both stopped honouring CN for hostname verification
83/// circa 2017) so a cert with `CN=tenant-b.example` and `SAN=tenant-a.example`
84/// cannot smuggle `tenant-b.example` into the routing authority list.
85pub fn get_cn_and_san_attributes(x509: &X509Certificate) -> Vec<String> {
86    let mut names: Vec<String> = Vec::new();
87    let mut san_dns_seen = false;
88
89    for extension in x509.extensions() {
90        if extension.oid == OID_X509_EXT_SUBJECT_ALT_NAME {
91            if let ParsedExtension::SubjectAlternativeName(san) = extension.parsed_extension() {
92                for name in &san.general_names {
93                    if let GeneralName::DNSName(name) = name {
94                        san_dns_seen = true;
95                        names.push(name.to_string());
96                    }
97                }
98            }
99        }
100    }
101
102    // POST: a dNSName SAN entry was observed iff at least one name has been
103    // collected from the SAN branch. `san_dns_seen` and a non-empty `names`
104    // must agree before the CN fallback runs — otherwise the RFC 6125
105    // §6.4.4 trust boundary (SAN dNSName is authoritative when present) is
106    // broken and the CN could smuggle an extra identity below.
107    debug_assert_eq!(
108        san_dns_seen,
109        !names.is_empty(),
110        "SAN dNSName presence must match the collected-names state before CN fallback"
111    );
112
113    if !san_dns_seen {
114        for name in x509.subject().iter_by_oid(&OID_X509_COMMON_NAME) {
115            names.push(
116                name.as_str()
117                    .map(String::from)
118                    .unwrap_or_else(|_| String::from_utf8_lossy(name.as_slice()).to_string()),
119            );
120        }
121    }
122    let before_dedup = names.len();
123    names.dedup();
124    // POST: dedup only removes *consecutive* equal entries, so it can never
125    // grow the list; the deduped length is an invariant upper bound.
126    debug_assert!(
127        names.len() <= before_dedup,
128        "dedup must not grow the identity list"
129    );
130    names
131}
132
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a PEM fixture and return the identities Sōzu would route it for.
    fn routing_names(fixture: &str) -> Vec<String> {
        let pem = parse_pem(fixture.as_bytes()).expect("parse PEM");
        let x509 = parse_x509(&pem.contents).expect("parse x509");
        get_cn_and_san_attributes(&x509)
    }

    /// RFC 6125 §6.4.4: a SAN holding at least one dNSName makes the CN
    /// irrelevant. The fixture has `CN=tenant-b.example` and SAN
    /// `tenant-a.example`, so only `tenant-a.example` is authoritative.
    #[test]
    fn san_dns_present_excludes_cn() {
        let names = routing_names(include_str!("../../lib/assets/cn-ne-san-cert.pem"));
        assert_eq!(names, vec![String::from("tenant-a.example")]);
    }

    /// Fallback branch: without any SAN dNSName the CN is used.
    /// `lib/assets/certificate.pem` (CN=lolcatho.st, no SAN extension)
    /// exercises exactly that path.
    #[test]
    fn cn_used_when_san_absent() {
        let names = routing_names(include_str!("../../lib/assets/certificate.pem"));
        assert_eq!(names, vec![String::from("lolcatho.st")]);
    }

    /// When the CN is also one of the SAN dNSNames, the result is just the
    /// SAN dNSName set: the CN is never added, so no duplicate appears.
    #[test]
    fn san_dns_present_cn_is_san_member() {
        let names = routing_names(include_str!("../../lib/assets/multi-sni-cert.pem"));
        for expected in [
            "foo.example.com",
            "bar.example.com",
            "baz.example.com",
            "localhost",
        ] {
            assert!(names.iter().any(|name| name == expected));
        }
        assert_eq!(names.len(), 4);
    }
}
178
179// -----------------------------------------------------------------------------
180// TlsVersion
181
182impl FromStr for TlsVersion {
183    type Err = CertificateError;
184
185    fn from_str(s: &str) -> Result<Self, Self::Err> {
186        match s {
187            "SSL_V2" => Ok(TlsVersion::SslV2),
188            "SSL_V3" => Ok(TlsVersion::SslV3),
189            "TLSv1" => Ok(TlsVersion::TlsV10),
190            "TLS_V11" => Ok(TlsVersion::TlsV11),
191            "TLS_V12" => Ok(TlsVersion::TlsV12),
192            "TLS_V13" => Ok(TlsVersion::TlsV13),
193            _ => Err(CertificateError::InvalidTlsVersion(s.to_string())),
194        }
195    }
196}
197
198// -----------------------------------------------------------------------------
199// Fingerprint
200
// FIXME: make the size fixed depending on the hash algorithm
/// Fingerprint of a TLS certificate, stored as raw digest bytes.
///
/// Fingerprints computed in this module are SHA-256 digests (32 bytes), while
/// fingerprints parsed from hex (`FromStr`, deserialization,
/// `decode_fingerprint`) are not length-checked. `Display`, `Debug` and
/// serialization all render the bytes as hex.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct Fingerprint(pub Vec<u8>);
205
206impl FromStr for Fingerprint {
207    type Err = CertificateError;
208
209    fn from_str(s: &str) -> Result<Self, Self::Err> {
210        hex::decode(s)
211            .map_err(CertificateError::InvalidFingerprint)
212            .map(Fingerprint)
213    }
214}
215
216impl fmt::Debug for Fingerprint {
217    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
218        write!(f, "CertificateFingerprint({})", hex::encode(&self.0))
219    }
220}
221
222impl fmt::Display for Fingerprint {
223    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
224        write!(f, "{}", hex::encode(&self.0))
225    }
226}
227
228impl serde::Serialize for Fingerprint {
229    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
230    where
231        S: serde::Serializer,
232    {
233        serializer.serialize_str(&hex::encode(&self.0))
234    }
235}
236
237struct FingerprintVisitor;
238
239impl Visitor<'_> for FingerprintVisitor {
240    type Value = Fingerprint;
241
242    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
243        formatter.write_str("the certificate fingerprint must be in hexadecimal format")
244    }
245
246    fn visit_str<E>(self, value: &str) -> Result<Fingerprint, E>
247    where
248        E: de::Error,
249    {
250        FromHex::from_hex(value)
251            .map_err(|e| E::custom(format!("could not deserialize hex: {e:?}")))
252            .map(Fingerprint)
253    }
254}
255
256impl<'de> serde::Deserialize<'de> for Fingerprint {
257    fn deserialize<D>(deserializer: D) -> Result<Fingerprint, D::Error>
258    where
259        D: serde::de::Deserializer<'de>,
260    {
261        deserializer.deserialize_str(FingerprintVisitor {})
262    }
263}
264
265/// Compute fingerprint from decoded pem as binary value
266pub fn calculate_fingerprint_from_der(certificate: &[u8]) -> Vec<u8> {
267    let fingerprint: Vec<u8> = Sha256::digest(certificate).iter().cloned().collect();
268    // POST: a SHA-256 digest is unconditionally 32 bytes. Anything else
269    // means the digest collection went wrong and downstream fingerprint
270    // comparison / map keying would silently mismatch.
271    debug_assert_eq!(
272        fingerprint.len(),
273        SHA256_FINGERPRINT_LEN,
274        "SHA-256 fingerprint must be exactly 32 bytes"
275    );
276    fingerprint
277}
278
279/// Compute fingerprint from a certificate that is encoded in pem format
280pub fn calculate_fingerprint(certificate: &[u8]) -> Result<Vec<u8>, CertificateError> {
281    let parsed_certificate = parse_pem(certificate)?;
282    let fingerprint = calculate_fingerprint_from_der(&parsed_certificate.contents);
283    // POST: the result is a SHA-256 digest and recomputing it over the same
284    // DER bytes yields the same value — the fingerprint is a pure function of
285    // the parsed certificate contents, never of the surrounding PEM framing.
286    debug_assert_eq!(
287        fingerprint.len(),
288        SHA256_FINGERPRINT_LEN,
289        "PEM fingerprint must be a 32-byte SHA-256 digest"
290    );
291    debug_assert_eq!(
292        fingerprint,
293        calculate_fingerprint_from_der(&parsed_certificate.contents),
294        "fingerprint must be a deterministic function of the DER contents"
295    );
296    Ok(fingerprint)
297}
298
/// Split a concatenated PEM chain into one trimmed string per certificate.
///
/// Each returned block ends with its `-----END CERTIFICATE-----` marker and
/// starts with whatever preceded it since the previous marker (surrounding
/// whitespace trimmed). Text after the last END marker is discarded. Blocks
/// are returned in file order, so the first certificate in the input (the
/// leaf, by PEM convention) is at index 0.
pub fn split_certificate_chain(chain: String) -> Vec<String> {
    const END_MARKER: &str = "-----END CERTIFICATE-----";

    // PRE: one output block is expected per END marker (non-overlapping count).
    let marker_count = chain.matches(END_MARKER).count();

    let blocks: Vec<String> = chain
        .split_inclusive(END_MARKER)
        // The only piece not ending with the marker is a trailing remainder
        // after the last END marker, which is dropped.
        .filter(|piece| piece.ends_with(END_MARKER))
        .map(|piece| {
            // INV: every emitted block carries the END marker that closed it.
            debug_assert!(
                piece.contains(END_MARKER),
                "each split block must contain its END CERTIFICATE marker"
            );
            piece.trim().to_string()
        })
        .collect();

    // POST: one block per END marker, leaf at index 0.
    debug_assert_eq!(
        blocks.len(),
        marker_count,
        "split must yield exactly one certificate per END marker"
    );
    blocks
}
334
335pub fn get_fingerprint_from_certificate_path(
336    certificate_path: &str,
337) -> Result<Fingerprint, CertificateError> {
338    let bytes =
339        Config::load_file_bytes(certificate_path).map_err(|e| CertificateError::LoadFile {
340            path: certificate_path.to_string(),
341            error: e,
342        })?;
343
344    let parsed_bytes = calculate_fingerprint(&bytes)?;
345
346    // POST: a computed fingerprint is always a 32-byte SHA-256 digest. (This
347    // is distinct from a *parsed* fingerprint — see `decode_fingerprint` /
348    // `Fingerprint::from_str` — which may be any operator-supplied length.)
349    debug_assert_eq!(
350        parsed_bytes.len(),
351        SHA256_FINGERPRINT_LEN,
352        "fingerprint loaded from a certificate path must be 32 bytes"
353    );
354    Ok(Fingerprint(parsed_bytes))
355}
356
357pub fn decode_fingerprint(fingerprint: &str) -> Result<Fingerprint, CertificateError> {
358    let bytes = hex::decode(fingerprint).map_err(CertificateError::DecodeError)?;
359    Ok(Fingerprint(bytes))
360}
361
362pub fn load_full_certificate(
363    certificate_path: &str,
364    certificate_chain_path: &str,
365    key_path: &str,
366    versions: Vec<TlsVersion>,
367    names: Vec<String>,
368) -> Result<CertificateAndKey, CertificateError> {
369    let certificate =
370        Config::load_file(certificate_path).map_err(|e| CertificateError::LoadFile {
371            path: certificate_path.to_string(),
372            error: e,
373        })?;
374
375    let certificate_chain = Config::load_file(certificate_chain_path)
376        .map(split_certificate_chain)
377        .map_err(|e| CertificateError::LoadFile {
378            path: certificate_chain_path.to_string(),
379            error: e,
380        })?;
381
382    let key = Config::load_file(key_path).map_err(|e| CertificateError::LoadFile {
383        path: key_path.to_string(),
384        error: e,
385    })?;
386
387    let versions_len = versions.len();
388    let names_len = names.len();
389    let versions: Vec<i32> = versions.iter().map(|v| *v as i32).collect();
390
391    // POST: the i32-encoded TLS-version list is a 1:1 map of the input — no
392    // version is dropped or duplicated by the `as i32` projection.
393    debug_assert_eq!(
394        versions.len(),
395        versions_len,
396        "version encoding must preserve the input cardinality"
397    );
398
399    let built = CertificateAndKey {
400        certificate,
401        certificate_chain,
402        key,
403        versions,
404        names,
405    };
406
407    // POST: the routing-name list is carried through verbatim — the builder
408    // does not synthesize or drop names (overriding names are derived later
409    // via `apply_overriding_names`, never here).
410    debug_assert_eq!(
411        built.names.len(),
412        names_len,
413        "names must be carried through the builder unchanged"
414    );
415    Ok(built)
416}
417
418impl CertificateAndKey {
419    pub fn fingerprint(&self) -> Result<Fingerprint, CertificateError> {
420        let pem = parse_pem(self.certificate.as_bytes())?;
421        let fingerprint = Fingerprint(Sha256::digest(&pem.contents).iter().cloned().collect());
422        // POST: the certificate fingerprint is a 32-byte SHA-256 digest of the
423        // DER contents and agrees with the free-function recompute over the
424        // same bytes — the two fingerprint paths must never diverge or a cert
425        // would key into two different map slots.
426        debug_assert_eq!(
427            fingerprint.0.len(),
428            SHA256_FINGERPRINT_LEN,
429            "CertificateAndKey fingerprint must be 32 bytes"
430        );
431        debug_assert_eq!(
432            fingerprint.0,
433            calculate_fingerprint_from_der(&pem.contents),
434            "method and free-function fingerprints must agree on the same DER"
435        );
436        Ok(fingerprint)
437    }
438
439    pub fn get_overriding_names(&self) -> Result<Vec<String>, CertificateError> {
440        if self.names.is_empty() {
441            let pem = parse_pem(self.certificate.as_bytes())?;
442            let x509 = parse_x509(&pem.contents)?;
443
444            let overriding_names = get_cn_and_san_attributes(&x509);
445
446            Ok(overriding_names.into_iter().collect())
447        } else {
448            let names = self.names.to_owned();
449            // POST: when explicit names are set, they are returned verbatim —
450            // the cert is NOT consulted, so the operator's intent is the sole
451            // authority. (`to_owned` preserves both length and order.)
452            debug_assert_eq!(
453                names, self.names,
454                "explicit names must be returned unchanged when present"
455            );
456            Ok(names)
457        }
458    }
459
460    pub fn apply_overriding_names(&mut self) -> Result<(), CertificateError> {
461        let resolved = self.get_overriding_names()?;
462        self.names = resolved.clone();
463        // POST: after applying, the stored names are exactly the resolved set
464        // and re-resolving is idempotent — a second `apply_overriding_names`
465        // would now hit the "explicit names present" branch and be a no-op.
466        debug_assert_eq!(
467            self.names, resolved,
468            "applied names must equal the resolved set"
469        );
470        Ok(())
471    }
472}