1use std::{fmt, str::FromStr};
2
3use hex::{FromHex, FromHexError};
4use serde::de::{self, Visitor};
5use sha2::{Digest, Sha256};
6use x509_parser::{
7 certificate::X509Certificate,
8 extensions::{GeneralName, ParsedExtension},
9 oid_registry::{OID_X509_COMMON_NAME, OID_X509_EXT_SUBJECT_ALT_NAME},
10 parse_x509_certificate,
11 pem::{Pem, parse_x509_pem},
12};
13
14use crate::{
15 config::{Config, ConfigError},
16 proto::command::{CertificateAndKey, TlsVersion},
17};
18
/// Errors produced by the certificate helpers in this module.
#[derive(thiserror::Error, Debug)]
pub enum CertificateError {
    /// The bytes could not be parsed as a PEM block (raised by `parse_pem`);
    /// holds the parser's error message.
    #[error("Could not parse PEM certificate from bytes: {0}")]
    ParsePEMCertificate(String),
    /// The DER bytes could not be parsed as an X.509 certificate (raised by
    /// `parse_x509`); holds the parser's error message.
    #[error("Could not parse X509 certificate from bytes: {0}")]
    ParseX509Certificate(String),
    /// A TLS version name that `TlsVersion::from_str` does not recognize;
    /// holds the rejected input.
    #[error("failed to parse tls version '{0}'")]
    InvalidTlsVersion(String),
    /// A fingerprint string that is not valid hex (e.g. raised by
    /// `Fingerprint::from_str`).
    #[error("failed to parse fingerprint, {0}")]
    InvalidFingerprint(FromHexError),
    /// A certificate, chain or key file could not be read through `Config`;
    /// `path` is the path that was requested.
    #[error("could not load file on path {path}: {error}")]
    LoadFile { path: String, error: ConfigError },
    /// A fingerprint string that is not valid hex (raised by
    /// `decode_fingerprint`).
    #[error("Failed at decoding the hex encoded certificate: {0}")]
    DecodeError(FromHexError),
}
37
38pub fn parse_pem(certificate: &[u8]) -> Result<Pem, CertificateError> {
44 let (_, pem) = parse_x509_pem(certificate)
45 .map_err(|err| CertificateError::ParsePEMCertificate(err.to_string()))?;
46
47 Ok(pem)
48}
49
50pub fn parse_x509(pem_bytes: &[u8]) -> Result<X509Certificate<'_>, CertificateError> {
52 parse_x509_certificate(pem_bytes)
53 .map_err(|nom_e| CertificateError::ParseX509Certificate(nom_e.to_string()))
54 .map(|t| t.1)
55}
56
57pub fn get_cn_and_san_attributes(x509: &X509Certificate) -> Vec<String> {
74 let mut names: Vec<String> = Vec::new();
75 let mut san_dns_seen = false;
76
77 for extension in x509.extensions() {
78 if extension.oid == OID_X509_EXT_SUBJECT_ALT_NAME {
79 if let ParsedExtension::SubjectAlternativeName(san) = extension.parsed_extension() {
80 for name in &san.general_names {
81 if let GeneralName::DNSName(name) = name {
82 san_dns_seen = true;
83 names.push(name.to_string());
84 }
85 }
86 }
87 }
88 }
89
90 if !san_dns_seen {
91 for name in x509.subject().iter_by_oid(&OID_X509_COMMON_NAME) {
92 names.push(
93 name.as_str()
94 .map(String::from)
95 .unwrap_or_else(|_| String::from_utf8_lossy(name.as_slice()).to_string()),
96 );
97 }
98 }
99 names.dedup();
100 names
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses PEM fixture text and returns the names `get_cn_and_san_attributes` extracts.
    fn names_from_pem(pem_text: &str) -> Vec<String> {
        let pem = parse_pem(pem_text.as_bytes()).expect("parse PEM");
        let x509 = parse_x509(&pem.contents).expect("parse x509");
        get_cn_and_san_attributes(&x509)
    }

    /// A SAN DNS entry is present, so the (different) CN must not be returned.
    #[test]
    fn san_dns_present_excludes_cn() {
        let names = names_from_pem(include_str!("../../lib/assets/cn-ne-san-cert.pem"));
        assert_eq!(names, vec![String::from("tenant-a.example")]);
    }

    /// Without SAN DNS entries, the CN is used.
    #[test]
    fn cn_used_when_san_absent() {
        let names = names_from_pem(include_str!("../../lib/assets/certificate.pem"));
        assert_eq!(names, vec![String::from("lolcatho.st")]);
    }

    /// Exactly the four SAN names come back (order-insensitive), with no duplicate CN.
    #[test]
    fn san_dns_present_cn_is_san_member() {
        let mut names = names_from_pem(include_str!("../../lib/assets/multi-sni-cert.pem"));
        names.sort();
        assert_eq!(
            names,
            vec![
                "bar.example.com",
                "baz.example.com",
                "foo.example.com",
                "localhost",
            ]
        );
    }
}
148
149impl FromStr for TlsVersion {
153 type Err = CertificateError;
154
155 fn from_str(s: &str) -> Result<Self, Self::Err> {
156 match s {
157 "SSL_V2" => Ok(TlsVersion::SslV2),
158 "SSL_V3" => Ok(TlsVersion::SslV3),
159 "TLSv1" => Ok(TlsVersion::TlsV10),
160 "TLS_V11" => Ok(TlsVersion::TlsV11),
161 "TLS_V12" => Ok(TlsVersion::TlsV12),
162 "TLS_V13" => Ok(TlsVersion::TlsV13),
163 _ => Err(CertificateError::InvalidTlsVersion(s.to_string())),
164 }
165 }
166}
167
/// Raw bytes of a certificate fingerprint.
///
/// The helpers in this module fill it with the SHA-256 digest of a
/// certificate's DER bytes. It is rendered as hex (via `hex::encode`) by
/// `Display`, `Debug` and `Serialize`, and parsed from hex by `FromStr` and
/// `Deserialize`. `Default` yields an empty fingerprint.
#[derive(Clone, PartialEq, Eq, Hash, PartialOrd, Ord, Default)]
pub struct Fingerprint(pub Vec<u8>);
175
176impl FromStr for Fingerprint {
177 type Err = CertificateError;
178
179 fn from_str(s: &str) -> Result<Self, Self::Err> {
180 hex::decode(s)
181 .map_err(CertificateError::InvalidFingerprint)
182 .map(Fingerprint)
183 }
184}
185
186impl fmt::Debug for Fingerprint {
187 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
188 write!(f, "CertificateFingerprint({})", hex::encode(&self.0))
189 }
190}
191
192impl fmt::Display for Fingerprint {
193 fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
194 write!(f, "{}", hex::encode(&self.0))
195 }
196}
197
198impl serde::Serialize for Fingerprint {
199 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
200 where
201 S: serde::Serializer,
202 {
203 serializer.serialize_str(&hex::encode(&self.0))
204 }
205}
206
207struct FingerprintVisitor;
208
209impl Visitor<'_> for FingerprintVisitor {
210 type Value = Fingerprint;
211
212 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
213 formatter.write_str("the certificate fingerprint must be in hexadecimal format")
214 }
215
216 fn visit_str<E>(self, value: &str) -> Result<Fingerprint, E>
217 where
218 E: de::Error,
219 {
220 FromHex::from_hex(value)
221 .map_err(|e| E::custom(format!("could not deserialize hex: {e:?}")))
222 .map(Fingerprint)
223 }
224}
225
226impl<'de> serde::Deserialize<'de> for Fingerprint {
227 fn deserialize<D>(deserializer: D) -> Result<Fingerprint, D::Error>
228 where
229 D: serde::de::Deserializer<'de>,
230 {
231 deserializer.deserialize_str(FingerprintVisitor {})
232 }
233}
234
235pub fn calculate_fingerprint_from_der(certificate: &[u8]) -> Vec<u8> {
237 Sha256::digest(certificate).iter().cloned().collect()
238}
239
240pub fn calculate_fingerprint(certificate: &[u8]) -> Result<Vec<u8>, CertificateError> {
242 let parsed_certificate = parse_pem(certificate)?;
243 let fingerprint = calculate_fingerprint_from_der(&parsed_certificate.contents);
244 Ok(fingerprint)
245}
246
/// Splits a concatenated PEM chain into one string per certificate.
///
/// The input is cut right after every `-----END CERTIFICATE-----` marker, and
/// each piece is trimmed of surrounding whitespace. Any non-whitespace text
/// before a certificate stays attached to that certificate's piece. Text after
/// the last marker is discarded.
pub fn split_certificate_chain(chain: String) -> Vec<String> {
    const END_MARKER: &str = "-----END CERTIFICATE-----";

    chain
        .split_inclusive(END_MARKER)
        // Only the trailing remainder can lack the marker; it is dropped.
        .filter(|piece| piece.ends_with(END_MARKER))
        .map(|piece| piece.trim().to_string())
        .collect()
}
263
264pub fn get_fingerprint_from_certificate_path(
265 certificate_path: &str,
266) -> Result<Fingerprint, CertificateError> {
267 let bytes =
268 Config::load_file_bytes(certificate_path).map_err(|e| CertificateError::LoadFile {
269 path: certificate_path.to_string(),
270 error: e,
271 })?;
272
273 let parsed_bytes = calculate_fingerprint(&bytes)?;
274
275 Ok(Fingerprint(parsed_bytes))
276}
277
278pub fn decode_fingerprint(fingerprint: &str) -> Result<Fingerprint, CertificateError> {
279 let bytes = hex::decode(fingerprint).map_err(CertificateError::DecodeError)?;
280 Ok(Fingerprint(bytes))
281}
282
283pub fn load_full_certificate(
284 certificate_path: &str,
285 certificate_chain_path: &str,
286 key_path: &str,
287 versions: Vec<TlsVersion>,
288 names: Vec<String>,
289) -> Result<CertificateAndKey, CertificateError> {
290 let certificate =
291 Config::load_file(certificate_path).map_err(|e| CertificateError::LoadFile {
292 path: certificate_path.to_string(),
293 error: e,
294 })?;
295
296 let certificate_chain = Config::load_file(certificate_chain_path)
297 .map(split_certificate_chain)
298 .map_err(|e| CertificateError::LoadFile {
299 path: certificate_chain_path.to_string(),
300 error: e,
301 })?;
302
303 let key = Config::load_file(key_path).map_err(|e| CertificateError::LoadFile {
304 path: key_path.to_string(),
305 error: e,
306 })?;
307
308 let versions = versions.iter().map(|v| *v as i32).collect();
309
310 Ok(CertificateAndKey {
311 certificate,
312 certificate_chain,
313 key,
314 versions,
315 names,
316 })
317}
318
319impl CertificateAndKey {
320 pub fn fingerprint(&self) -> Result<Fingerprint, CertificateError> {
321 let pem = parse_pem(self.certificate.as_bytes())?;
322 let fingerprint = Fingerprint(Sha256::digest(pem.contents).iter().cloned().collect());
323 Ok(fingerprint)
324 }
325
326 pub fn get_overriding_names(&self) -> Result<Vec<String>, CertificateError> {
327 if self.names.is_empty() {
328 let pem = parse_pem(self.certificate.as_bytes())?;
329 let x509 = parse_x509(&pem.contents)?;
330
331 let overriding_names = get_cn_and_san_attributes(&x509);
332
333 Ok(overriding_names.into_iter().collect())
334 } else {
335 Ok(self.names.to_owned())
336 }
337 }
338
339 pub fn apply_overriding_names(&mut self) -> Result<(), CertificateError> {
340 self.names = self.get_overriding_names()?;
341 Ok(())
342 }
343}