Skip to main content

sozu_command_lib/
certificate.rs

1use std::{fmt, str::FromStr};
2
3use hex::{FromHex, FromHexError};
4use serde::de::{self, Visitor};
5use sha2::{Digest, Sha256};
6use x509_parser::{
7    certificate::X509Certificate,
8    extensions::{GeneralName, ParsedExtension},
9    oid_registry::{OID_X509_COMMON_NAME, OID_X509_EXT_SUBJECT_ALT_NAME},
10    parse_x509_certificate,
11    pem::{Pem, parse_x509_pem},
12};
13
14use crate::{
15    config::{Config, ConfigError},
16    proto::command::{CertificateAndKey, TlsVersion},
17};
18
19// -----------------------------------------------------------------------------
20// CertificateError
21
22#[derive(thiserror::Error, Debug)]
23pub enum CertificateError {
24    #[error("Could not parse PEM certificate from bytes: {0}")]
25    ParsePEMCertificate(String),
26    #[error("Could not parse X509 certificate from bytes: {0}")]
27    ParseX509Certificate(String),
28    #[error("failed to parse tls version '{0}'")]
29    InvalidTlsVersion(String),
30    #[error("failed to parse fingerprint, {0}")]
31    InvalidFingerprint(FromHexError),
32    #[error("could not load file on path {path}: {error}")]
33    LoadFile { path: String, error: ConfigError },
34    #[error("Failed at decoding the hex encoded certificate: {0}")]
35    DecodeError(FromHexError),
36}
37
38// -----------------------------------------------------------------------------
39// parse
40
41/// parse a pem file encoded as binary and convert it into the right structure
42/// (a.k.a [`Pem`])
43pub fn parse_pem(certificate: &[u8]) -> Result<Pem, CertificateError> {
44    let (_, pem) = parse_x509_pem(certificate)
45        .map_err(|err| CertificateError::ParsePEMCertificate(err.to_string()))?;
46
47    Ok(pem)
48}
49
50/// parse x509 certificate from PEM bytes
51pub fn parse_x509(pem_bytes: &[u8]) -> Result<X509Certificate<'_>, CertificateError> {
52    parse_x509_certificate(pem_bytes)
53        .map_err(|nom_e| CertificateError::ParseX509Certificate(nom_e.to_string()))
54        .map(|t| t.1)
55}
56
57// -----------------------------------------------------------------------------
58// get_cn_and_san_attributes
59
60/// Retrieve the certificate's authoritative DNS identities for routing.
61///
62/// Per RFC 6125 §6.4.4: when the SubjectAlternativeName extension contains
63/// at least one `dNSName` entry, the SAN entries are the sole authoritative
64/// identities and the Common Name is ignored. The CN is only honoured as a
65/// fallback when the certificate omits the SAN extension entirely or
66/// declares it without a `dNSName` (e.g. SAN with only `iPAddress` /
67/// `rfc822Name` / `directoryName` entries — uncommon, but legal).
68///
69/// Aligns Sōzu's coalescing trust boundary with browser implementations
70/// (Firefox / Chrome both stopped honouring CN for hostname verification
71/// circa 2017) so a cert with `CN=tenant-b.example` and `SAN=tenant-a.example`
72/// cannot smuggle `tenant-b.example` into the routing authority list.
73pub fn get_cn_and_san_attributes(x509: &X509Certificate) -> Vec<String> {
74    let mut names: Vec<String> = Vec::new();
75    let mut san_dns_seen = false;
76
77    for extension in x509.extensions() {
78        if extension.oid == OID_X509_EXT_SUBJECT_ALT_NAME {
79            if let ParsedExtension::SubjectAlternativeName(san) = extension.parsed_extension() {
80                for name in &san.general_names {
81                    if let GeneralName::DNSName(name) = name {
82                        san_dns_seen = true;
83                        names.push(name.to_string());
84                    }
85                }
86            }
87        }
88    }
89
90    if !san_dns_seen {
91        for name in x509.subject().iter_by_oid(&OID_X509_COMMON_NAME) {
92            names.push(
93                name.as_str()
94                    .map(String::from)
95                    .unwrap_or_else(|_| String::from_utf8_lossy(name.as_slice()).to_string()),
96            );
97        }
98    }
99    names.dedup();
100    names
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Parse a PEM fixture and return the routing identities extracted from it.
    fn identities_of(pem_fixture: &str) -> Vec<String> {
        let pem = parse_pem(pem_fixture.as_bytes()).expect("parse PEM");
        let x509 = parse_x509(&pem.contents).expect("parse x509");
        get_cn_and_san_attributes(&x509)
    }

    /// RFC 6125 §6.4.4: the fixture has `CN=tenant-b.example` and SAN
    /// `tenant-a.example`. Only the SAN dNSName may be reported.
    #[test]
    fn san_dns_present_excludes_cn() {
        let names = identities_of(include_str!("../../lib/assets/cn-ne-san-cert.pem"));
        assert_eq!(names, vec![String::from("tenant-a.example")]);
    }

    /// Fallback branch: `lib/assets/certificate.pem` has `CN=lolcatho.st` and
    /// no SAN extension, so the CN is the only identity.
    #[test]
    fn cn_used_when_san_absent() {
        let names = identities_of(include_str!("../../lib/assets/certificate.pem"));
        assert_eq!(names, vec![String::from("lolcatho.st")]);
    }

    /// The fixture's CN is one of its SAN dNSNames. The result must be exactly
    /// the four SAN names, with no duplicated entry.
    #[test]
    fn san_dns_present_cn_is_san_member() {
        let names = identities_of(include_str!("../../lib/assets/multi-sni-cert.pem"));
        for expected in ["foo.example.com", "bar.example.com", "baz.example.com", "localhost"] {
            assert!(names.contains(&String::from(expected)));
        }
        assert_eq!(names.len(), 4);
    }
}
148
149// -----------------------------------------------------------------------------
150// TlsVersion
151
152impl FromStr for TlsVersion {
153    type Err = CertificateError;
154
155    fn from_str(s: &str) -> Result<Self, Self::Err> {
156        match s {
157            "SSL_V2" => Ok(TlsVersion::SslV2),
158            "SSL_V3" => Ok(TlsVersion::SslV3),
159            "TLSv1" => Ok(TlsVersion::TlsV10),
160            "TLS_V11" => Ok(TlsVersion::TlsV11),
161            "TLS_V12" => Ok(TlsVersion::TlsV12),
162            "TLS_V13" => Ok(TlsVersion::TlsV13),
163            _ => Err(CertificateError::InvalidTlsVersion(s.to_string())),
164        }
165    }
166}
167
168// -----------------------------------------------------------------------------
169// Fingerprint
170
// FIXME: make fixed size depending on hash algorithm
/// The fingerprint (digest) of a TLS certificate, stored as raw bytes.
///
/// Fingerprints computed in this module are SHA-256 digests of the
/// certificate's DER encoding (32 bytes). Values decoded from hexadecimal
/// strings keep whatever length the input had.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct Fingerprint(pub Vec<u8>);
175
176impl FromStr for Fingerprint {
177    type Err = CertificateError;
178
179    fn from_str(s: &str) -> Result<Self, Self::Err> {
180        hex::decode(s)
181            .map_err(CertificateError::InvalidFingerprint)
182            .map(Fingerprint)
183    }
184}
185
186impl fmt::Debug for Fingerprint {
187    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
188        write!(f, "CertificateFingerprint({})", hex::encode(&self.0))
189    }
190}
191
192impl fmt::Display for Fingerprint {
193    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
194        write!(f, "{}", hex::encode(&self.0))
195    }
196}
197
198impl serde::Serialize for Fingerprint {
199    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
200    where
201        S: serde::Serializer,
202    {
203        serializer.serialize_str(&hex::encode(&self.0))
204    }
205}
206
207struct FingerprintVisitor;
208
209impl Visitor<'_> for FingerprintVisitor {
210    type Value = Fingerprint;
211
212    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
213        formatter.write_str("the certificate fingerprint must be in hexadecimal format")
214    }
215
216    fn visit_str<E>(self, value: &str) -> Result<Fingerprint, E>
217    where
218        E: de::Error,
219    {
220        FromHex::from_hex(value)
221            .map_err(|e| E::custom(format!("could not deserialize hex: {e:?}")))
222            .map(Fingerprint)
223    }
224}
225
226impl<'de> serde::Deserialize<'de> for Fingerprint {
227    fn deserialize<D>(deserializer: D) -> Result<Fingerprint, D::Error>
228    where
229        D: serde::de::Deserializer<'de>,
230    {
231        deserializer.deserialize_str(FingerprintVisitor {})
232    }
233}
234
235/// Compute fingerprint from decoded pem as binary value
236pub fn calculate_fingerprint_from_der(certificate: &[u8]) -> Vec<u8> {
237    Sha256::digest(certificate).iter().cloned().collect()
238}
239
240/// Compute fingerprint from a certificate that is encoded in pem format
241pub fn calculate_fingerprint(certificate: &[u8]) -> Result<Vec<u8>, CertificateError> {
242    let parsed_certificate = parse_pem(certificate)?;
243    let fingerprint = calculate_fingerprint_from_der(&parsed_certificate.contents);
244    Ok(fingerprint)
245}
246
/// Split a PEM certificate chain into its individual certificates.
///
/// Each entry runs from the end of the previous `-----END CERTIFICATE-----`
/// marker (or the start of `chain`) up to and including the next marker, with
/// surrounding whitespace trimmed. Text after the last marker is discarded,
/// and a chain containing no marker yields an empty list.
pub fn split_certificate_chain(chain: String) -> Vec<String> {
    const END_MARKER: &str = "-----END CERTIFICATE-----";

    // A single left-to-right scan. Draining each certificate off the front of
    // the `String` would shift the remaining text once per certificate.
    chain
        .split_inclusive(END_MARKER)
        // `split_inclusive` also yields the text after the last marker; only
        // pieces terminated by the marker are certificates.
        .filter(|piece| piece.ends_with(END_MARKER))
        .map(|piece| piece.trim().to_string())
        .collect()
}
263
264pub fn get_fingerprint_from_certificate_path(
265    certificate_path: &str,
266) -> Result<Fingerprint, CertificateError> {
267    let bytes =
268        Config::load_file_bytes(certificate_path).map_err(|e| CertificateError::LoadFile {
269            path: certificate_path.to_string(),
270            error: e,
271        })?;
272
273    let parsed_bytes = calculate_fingerprint(&bytes)?;
274
275    Ok(Fingerprint(parsed_bytes))
276}
277
278pub fn decode_fingerprint(fingerprint: &str) -> Result<Fingerprint, CertificateError> {
279    let bytes = hex::decode(fingerprint).map_err(CertificateError::DecodeError)?;
280    Ok(Fingerprint(bytes))
281}
282
283pub fn load_full_certificate(
284    certificate_path: &str,
285    certificate_chain_path: &str,
286    key_path: &str,
287    versions: Vec<TlsVersion>,
288    names: Vec<String>,
289) -> Result<CertificateAndKey, CertificateError> {
290    let certificate =
291        Config::load_file(certificate_path).map_err(|e| CertificateError::LoadFile {
292            path: certificate_path.to_string(),
293            error: e,
294        })?;
295
296    let certificate_chain = Config::load_file(certificate_chain_path)
297        .map(split_certificate_chain)
298        .map_err(|e| CertificateError::LoadFile {
299            path: certificate_chain_path.to_string(),
300            error: e,
301        })?;
302
303    let key = Config::load_file(key_path).map_err(|e| CertificateError::LoadFile {
304        path: key_path.to_string(),
305        error: e,
306    })?;
307
308    let versions = versions.iter().map(|v| *v as i32).collect();
309
310    Ok(CertificateAndKey {
311        certificate,
312        certificate_chain,
313        key,
314        versions,
315        names,
316    })
317}
318
319impl CertificateAndKey {
320    pub fn fingerprint(&self) -> Result<Fingerprint, CertificateError> {
321        let pem = parse_pem(self.certificate.as_bytes())?;
322        let fingerprint = Fingerprint(Sha256::digest(pem.contents).iter().cloned().collect());
323        Ok(fingerprint)
324    }
325
326    pub fn get_overriding_names(&self) -> Result<Vec<String>, CertificateError> {
327        if self.names.is_empty() {
328            let pem = parse_pem(self.certificate.as_bytes())?;
329            let x509 = parse_x509(&pem.contents)?;
330
331            let overriding_names = get_cn_and_san_attributes(&x509);
332
333            Ok(overriding_names.into_iter().collect())
334        } else {
335            Ok(self.names.to_owned())
336        }
337    }
338
339    pub fn apply_overriding_names(&mut self) -> Result<(), CertificateError> {
340        self.names = self.get_overriding_names()?;
341        Ok(())
342    }
343}