1use sha2::{Sha256, Digest};
2use std::collections::HashSet;
3use std::fmt;
4use x509_parser::der_parser::oid::Oid;
5use x509_parser::pem::parse_x509_pem;
6use x509_parser::prelude::*;
7
8use super::svg::validate_svg_tiny_ps;
9
/// Extended Key Usage OID 1.3.6.1.5.5.7.3.31 (BIMI), which a VMC must carry.
/// Stored as arcs so it can be turned into an `Oid` and also handed to rcgen in tests.
const BIMI_EKU_OID: &[u64] = &[1, 3, 6, 1, 5, 5, 7, 3, 31];

/// OID 1.3.6.1.5.5.7.1.12 of the LogoType certificate extension (RFC 3709), from
/// which the embedded SVG logo is extracted.
const LOGOTYPE_OID: &[u64] = &[1, 3, 6, 1, 5, 5, 7, 1, 12];
16
/// Errors reported while validating a BIMI Verified Mark Certificate (VMC).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum VmcError {
    /// The PEM input could not be parsed; carries the parser's message.
    PemParse(String),
    /// No `CERTIFICATE` PEM block was found in the input.
    NoCertificates,
    /// More than one non-CA (end-entity) certificate is present in the chain.
    MultipleVmcs,
    /// The chain is not ordered VMC-first with each certificate issued by the next one.
    OutOfOrder,
    /// The same certificate appears more than once in the chain.
    DuplicateCert,
    /// The VMC lacks the BIMI extended key usage OID 1.3.6.1.5.5.7.3.31.
    MissingBimiEku,
    /// No DNS subject alternative name of the VMC equals the required name.
    SanMismatch {
        /// The SAN value that was required.
        expected: String,
    },
    /// The VMC's `notAfter` time has passed.
    Expired,
    /// The VMC's `notBefore` time has not been reached yet.
    NotYetValid,
    /// The VMC has no LogoType extension, or the extension holds no SVG data URI.
    MissingLogoType,
    /// SVG data was located in the LogoType extension but could not be decoded.
    LogoTypeExtractFailed(String),
    /// The embedded SVG failed SVG Tiny PS validation.
    SvgValidation(String),
    /// The DNS-fetched logo differs from the logo embedded in the VMC.
    LogoHashMismatch,
    /// A certificate's signature did not verify against the next certificate's key.
    ChainValidation(String),
    /// A certificate or one of its extensions could not be parsed as X.509.
    X509Parse(String),
}

impl fmt::Display for VmcError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            VmcError::PemParse(e) => write!(f, "PEM parse error: {}", e),
            VmcError::NoCertificates => write!(f, "no certificates in PEM data"),
            VmcError::MultipleVmcs => write!(f, "multiple VMC certificates in chain"),
            VmcError::OutOfOrder => write!(f, "certificate chain out of order"),
            VmcError::DuplicateCert => write!(f, "duplicate certificate in chain"),
            VmcError::MissingBimiEku => write!(f, "missing BIMI EKU OID 1.3.6.1.5.5.7.3.31"),
            VmcError::SanMismatch { expected } => {
                write!(f, "SAN does not match {}", expected)
            }
            VmcError::Expired => write!(f, "certificate expired"),
            VmcError::NotYetValid => write!(f, "certificate not yet valid"),
            VmcError::MissingLogoType => write!(f, "LogoType extension not found"),
            VmcError::LogoTypeExtractFailed(e) => {
                write!(f, "LogoType SVG extraction failed: {}", e)
            }
            VmcError::SvgValidation(e) => write!(f, "SVG validation failed: {}", e),
            VmcError::LogoHashMismatch => {
                write!(f, "logo hash mismatch: DNS-fetched != VMC-embedded")
            }
            VmcError::ChainValidation(e) => write!(f, "chain validation: {}", e),
            VmcError::X509Parse(e) => write!(f, "X.509 parse error: {}", e),
        }
    }
}

// Lets callers box the error as `dyn Error` or propagate it with `?` into
// `Box<dyn Error>` / error-chaining types; the message comes from `Display` above.
impl std::error::Error for VmcError {}
79
/// Successful outcome of [`validate_vmc`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct VmcValidationResult {
    /// The SVG document decoded from the VMC's LogoType extension.
    pub embedded_svg: String,
}
86
87fn parse_pem_chain(pem_data: &[u8]) -> Result<Vec<Vec<u8>>, VmcError> {
90 let mut der_certs: Vec<Vec<u8>> = Vec::new();
91 let mut remaining = pem_data;
92
93 loop {
94 match parse_x509_pem(remaining) {
95 Ok((rest, pem)) => {
96 if pem.label != "CERTIFICATE" {
97 remaining = rest;
98 continue;
99 }
100 der_certs.push(pem.contents);
101 if rest.is_empty() {
102 break;
103 }
104 remaining = rest;
105 }
106 Err(_) => break,
107 }
108 }
109
110 if der_certs.is_empty() {
111 return Err(VmcError::NoCertificates);
112 }
113
114 let mut seen = HashSet::new();
116 for cert_der in &der_certs {
117 if !seen.insert(cert_der.clone()) {
118 return Err(VmcError::DuplicateCert);
119 }
120 }
121
122 Ok(der_certs)
123}
124
125pub fn validate_vmc(
134 pem_data: &[u8],
135 selector: &str,
136 domain: &str,
137 dns_logo_svg: Option<&str>,
138) -> Result<VmcValidationResult, VmcError> {
139 let der_certs = parse_pem_chain(pem_data)?;
140
141 let mut parsed_certs: Vec<X509Certificate<'_>> = Vec::new();
143
144 for der in &der_certs {
147 let (_, cert) = X509Certificate::from_der(der)
148 .map_err(|e| VmcError::X509Parse(format!("{}", e)))?;
149 parsed_certs.push(cert);
150 }
151
152 if parsed_certs.is_empty() {
153 return Err(VmcError::NoCertificates);
154 }
155
156 let vmc_count = parsed_certs.iter().filter(|c| !c.tbs_certificate.is_ca()).count();
159 if vmc_count > 1 {
160 return Err(VmcError::MultipleVmcs);
161 }
162
163 let vmc = &parsed_certs[0];
165 if vmc.tbs_certificate.is_ca() && parsed_certs.len() > 1 {
166 return Err(VmcError::OutOfOrder);
167 }
168
169 for i in 0..parsed_certs.len().saturating_sub(1) {
171 let child = &parsed_certs[i];
172 let parent = &parsed_certs[i + 1];
173 if child.issuer() != parent.subject() {
174 return Err(VmcError::OutOfOrder);
175 }
176 }
177
178 let validity = vmc.validity();
180 if !validity.is_valid() {
181 if validity.not_after.timestamp() < chrono_now() {
182 return Err(VmcError::Expired);
183 }
184 return Err(VmcError::NotYetValid);
185 }
186
187 check_bimi_eku(vmc)?;
189
190 let expected_san = format!("{}._bimi.{}", selector, domain);
192 check_san_match(vmc, &expected_san)?;
193
194 let embedded_svg = extract_logotype_svg(vmc)?;
196
197 validate_svg_tiny_ps(&embedded_svg)
199 .map_err(|e| VmcError::SvgValidation(format!("{}", e)))?;
200
201 if let Some(dns_svg) = dns_logo_svg {
203 let dns_hash = sha256_hash(dns_svg.as_bytes());
204 let vmc_hash = sha256_hash(embedded_svg.as_bytes());
205 if dns_hash != vmc_hash {
206 return Err(VmcError::LogoHashMismatch);
207 }
208 }
209
210 validate_chain_signatures(&parsed_certs)?;
212
213 Ok(VmcValidationResult { embedded_svg })
214}
215
216fn check_bimi_eku(cert: &X509Certificate<'_>) -> Result<(), VmcError> {
218 let eku = cert
219 .tbs_certificate
220 .extended_key_usage()
221 .map_err(|e| VmcError::X509Parse(format!("EKU: {}", e)))?;
222
223 match eku {
224 Some(ext) => {
225 let bimi_oid = Oid::from(BIMI_EKU_OID)
226 .map_err(|_| VmcError::MissingBimiEku)?;
227 if ext.value.other.iter().any(|o| o == &bimi_oid) {
228 Ok(())
229 } else {
230 Err(VmcError::MissingBimiEku)
231 }
232 }
233 None => Err(VmcError::MissingBimiEku),
234 }
235}
236
237fn check_san_match(cert: &X509Certificate<'_>, expected: &str) -> Result<(), VmcError> {
239 let san = cert
240 .tbs_certificate
241 .subject_alternative_name()
242 .map_err(|e| VmcError::X509Parse(format!("SAN: {}", e)))?;
243
244 match san {
245 Some(ext) => {
246 for name in &ext.value.general_names {
247 if let GeneralName::DNSName(dns) = name {
248 if dns.eq_ignore_ascii_case(expected) {
249 return Ok(());
250 }
251 }
252 }
253 Err(VmcError::SanMismatch {
254 expected: expected.to_string(),
255 })
256 }
257 None => Err(VmcError::SanMismatch {
258 expected: expected.to_string(),
259 }),
260 }
261}
262
263fn extract_logotype_svg(cert: &X509Certificate<'_>) -> Result<String, VmcError> {
271 let logotype_oid = Oid::from(LOGOTYPE_OID)
272 .map_err(|_| VmcError::MissingLogoType)?;
273
274 let ext = cert
275 .tbs_certificate
276 .get_extension_unique(&logotype_oid)
277 .map_err(|e| VmcError::X509Parse(format!("LogoType: {}", e)))?
278 .ok_or(VmcError::MissingLogoType)?;
279
280 let raw = ext.value;
283 let raw_str = String::from_utf8_lossy(raw);
284
285 let marker = "data:image/svg+xml;base64,";
287 if let Some(start) = raw_str.find(marker) {
288 let b64_start = start + marker.len();
289 let b64_data: String = raw_str[b64_start..]
291 .chars()
292 .take_while(|c| c.is_ascii_alphanumeric() || *c == '+' || *c == '/' || *c == '=')
293 .collect();
294
295 if b64_data.is_empty() {
296 return Err(VmcError::LogoTypeExtractFailed(
297 "empty base64 data after marker".into(),
298 ));
299 }
300
301 use base64::Engine;
302 let svg_bytes = base64::engine::general_purpose::STANDARD
303 .decode(&b64_data)
304 .map_err(|e| VmcError::LogoTypeExtractFailed(format!("base64 decode: {}", e)))?;
305
306 let svg = String::from_utf8(svg_bytes)
307 .map_err(|e| VmcError::LogoTypeExtractFailed(format!("UTF-8: {}", e)))?;
308
309 return Ok(svg);
310 }
311
312 let marker_bytes = marker.as_bytes();
314 if let Some(pos) = raw.windows(marker_bytes.len()).position(|w| w == marker_bytes) {
315 let b64_start = pos + marker_bytes.len();
316 let b64_data: Vec<u8> = raw[b64_start..]
317 .iter()
318 .copied()
319 .take_while(|b| b.is_ascii_alphanumeric() || *b == b'+' || *b == b'/' || *b == b'=')
320 .collect();
321
322 if b64_data.is_empty() {
323 return Err(VmcError::LogoTypeExtractFailed(
324 "empty base64 data".into(),
325 ));
326 }
327
328 use base64::Engine;
329 let svg_bytes = base64::engine::general_purpose::STANDARD
330 .decode(&b64_data)
331 .map_err(|e| VmcError::LogoTypeExtractFailed(format!("base64 decode: {}", e)))?;
332
333 let svg = String::from_utf8(svg_bytes)
334 .map_err(|e| VmcError::LogoTypeExtractFailed(format!("UTF-8: {}", e)))?;
335
336 return Ok(svg);
337 }
338
339 Err(VmcError::MissingLogoType)
340}
341
342fn sha256_hash(data: &[u8]) -> Vec<u8> {
344 let mut hasher = Sha256::new();
345 hasher.update(data);
346 hasher.finalize().to_vec()
347}
348
/// Current Unix time in whole seconds, read from the system clock.
///
/// Returns 0 if the system clock reports a time before the Unix epoch. Despite the
/// name, this does not use the `chrono` crate.
fn chrono_now() -> i64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs() as i64,
        Err(_) => 0,
    }
}
356
357fn validate_chain_signatures(certs: &[X509Certificate<'_>]) -> Result<(), VmcError> {
360 for i in 0..certs.len().saturating_sub(1) {
361 let child = &certs[i];
362 let parent = &certs[i + 1];
363 child
364 .verify_signature(Some(&parent.tbs_certificate.subject_pki))
365 .map_err(|e| {
366 VmcError::ChainValidation(format!(
367 "cert {} not signed by cert {}: {}",
368 i,
369 i + 1,
370 e
371 ))
372 })?;
373 }
374 Ok(())
375}
376
#[cfg(test)]
mod tests {
    use super::*;
    use rcgen::{
        BasicConstraints, CertificateParams, CustomExtension, DnType, ExtendedKeyUsagePurpose,
        IsCa, KeyPair, SanType,
    };

    /// Minimal SVG Tiny PS document embedded as the logo in test certificates.
    const TEST_SVG: &str = r#"<svg xmlns="http://www.w3.org/2000/svg" version="1.2" baseProfile="tiny-ps" viewBox="0 0 100 100"><title>Test</title><rect width="100" height="100" fill="red"/></svg>"#;

    /// Approximate current calendar year, derived from `chrono_now()`.
    ///
    /// Validity windows for "currently valid" and "not yet valid" certificates are
    /// placed relative to this year instead of hard-coded dates, so the tests keep
    /// passing after such dates go by. The mean Gregorian year length is used, so the
    /// estimate can be off by one within about a day of New Year. The year margins
    /// used below absorb that.
    fn current_year() -> i32 {
        const SECS_PER_MEAN_YEAR: i64 = 31_556_952;
        1970 + (chrono_now() / SECS_PER_MEAN_YEAR) as i32
    }

    /// Gives `params` a validity window that contains the current time.
    fn set_current_validity(params: &mut CertificateParams) {
        let year = current_year();
        params.not_before = rcgen::date_time_ymd(year - 1, 1, 1);
        params.not_after = rcgen::date_time_ymd(year + 5, 12, 31);
    }

    /// Encodes `svg` as a `data:image/svg+xml;base64,` URI and returns its raw bytes.
    ///
    /// The bytes are used directly as the LogoType extension value, without the ASN.1
    /// structure a real certificate wraps around the URI. `extract_logotype_svg` only
    /// searches for the data-URI marker, so this is sufficient here.
    fn build_logotype_extension(svg: &str) -> Vec<u8> {
        use base64::Engine;
        let b64 = base64::engine::general_purpose::STANDARD.encode(svg.as_bytes());
        let data_uri = format!("data:image/svg+xml;base64,{}", b64);
        data_uri.into_bytes()
    }

    /// The BIMI extended key usage as an rcgen EKU purpose.
    fn bimi_eku() -> ExtendedKeyUsagePurpose {
        ExtendedKeyUsagePurpose::Other(BIMI_EKU_OID.to_vec())
    }

    /// Builds a self-signed, non-CA test VMC for `<selector>._bimi.<domain>` and
    /// returns its PEM together with its key pair.
    ///
    /// `expired` / `not_yet_valid` choose the validity window (current otherwise).
    /// The `include_*` flags control whether the BIMI EKU, the matching SAN, and the
    /// LogoType extension carrying `svg` are added.
    // Each flag maps to one failure case exercised below; a parameter struct would
    // add more boilerplate than it saves for a test helper.
    #[allow(clippy::too_many_arguments)]
    fn make_vmc_cert(
        selector: &str,
        domain: &str,
        svg: &str,
        expired: bool,
        not_yet_valid: bool,
        include_eku: bool,
        include_san: bool,
        include_logotype: bool,
    ) -> (String, KeyPair) {
        let mut params =
            CertificateParams::new(Vec::<String>::new()).expect("CertificateParams");

        params
            .distinguished_name
            .push(DnType::CommonName, format!("{}._bimi.{}", selector, domain));

        if expired {
            // A fixed window in the past is always expired.
            params.not_before = rcgen::date_time_ymd(2020, 1, 1);
            params.not_after = rcgen::date_time_ymd(2021, 1, 1);
        } else if not_yet_valid {
            // Starts at least a year in the future, even if `current_year` is off by one.
            let year = current_year();
            params.not_before = rcgen::date_time_ymd(year + 2, 1, 1);
            params.not_after = rcgen::date_time_ymd(year + 3, 1, 1);
        } else {
            set_current_validity(&mut params);
        }

        if include_eku {
            params.extended_key_usages.push(bimi_eku());
        }

        if include_san {
            let san_name = format!("{}._bimi.{}", selector, domain);
            params
                .subject_alt_names
                .push(SanType::DnsName(san_name.try_into().expect("dns name")));
        }

        if include_logotype {
            let ext =
                CustomExtension::from_oid_content(LOGOTYPE_OID, build_logotype_extension(svg));
            params.custom_extensions.push(ext);
        }

        params.is_ca = IsCa::NoCa;

        let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).expect("key pair");
        let cert = params.self_signed(&key_pair).expect("self-signed cert");
        (cert.pem(), key_pair)
    }

    /// Returns parameters and a key pair for an unconstrained CA named `cn` with a
    /// current validity window. The caller signs the certificate.
    fn make_ca_cert(cn: &str) -> (CertificateParams, KeyPair) {
        let mut params =
            CertificateParams::new(Vec::<String>::new()).expect("CertificateParams");
        params.distinguished_name.push(DnType::CommonName, cn);
        params.is_ca = IsCa::Ca(BasicConstraints::Unconstrained);
        set_current_validity(&mut params);

        let key_pair = KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).expect("key pair");
        (params, key_pair)
    }

    /// Builds a complete VMC for `default._bimi.example.com`, signs it with the given
    /// CA, and returns its PEM.
    fn signed_vmc_pem(ca_cert: &rcgen::Certificate, ca_kp: &KeyPair) -> String {
        let mut vmc_params =
            CertificateParams::new(Vec::<String>::new()).expect("CertificateParams");
        vmc_params
            .distinguished_name
            .push(DnType::CommonName, "default._bimi.example.com");
        vmc_params.is_ca = IsCa::NoCa;
        set_current_validity(&mut vmc_params);
        vmc_params.extended_key_usages.push(bimi_eku());
        vmc_params.subject_alt_names.push(SanType::DnsName(
            "default._bimi.example.com".try_into().expect("dns"),
        ));
        vmc_params.custom_extensions.push(CustomExtension::from_oid_content(
            LOGOTYPE_OID,
            build_logotype_extension(TEST_SVG),
        ));

        let vmc_kp =
            KeyPair::generate_for(&rcgen::PKCS_ECDSA_P256_SHA256).expect("vmc key pair");
        vmc_params
            .signed_by(&vmc_kp, ca_cert, ca_kp)
            .expect("VMC signed")
            .pem()
    }

    #[test]
    fn valid_vmc() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, true, true, true);
        let result = validate_vmc(pem.as_bytes(), "default", "example.com", None);
        assert!(result.is_ok(), "expected Ok, got {:?}", result);
        assert_eq!(result.unwrap().embedded_svg, TEST_SVG);
    }

    #[test]
    fn missing_bimi_eku() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, false, true, true);
        let result = validate_vmc(pem.as_bytes(), "default", "example.com", None);
        assert_eq!(result.unwrap_err(), VmcError::MissingBimiEku);
    }

    #[test]
    fn san_match() {
        let (pem, _kp) =
            make_vmc_cert("brand", "example.com", TEST_SVG, false, false, true, true, true);
        let result = validate_vmc(pem.as_bytes(), "brand", "example.com", None);
        assert!(result.is_ok());
    }

    #[test]
    fn san_mismatch() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, true, true, true);
        let result = validate_vmc(pem.as_bytes(), "default", "other.com", None);
        assert!(matches!(result.unwrap_err(), VmcError::SanMismatch { .. }));
    }

    #[test]
    fn expired_cert() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, true, false, true, true, true);
        let result = validate_vmc(pem.as_bytes(), "default", "example.com", None);
        assert_eq!(result.unwrap_err(), VmcError::Expired);
    }

    #[test]
    fn not_yet_valid_cert() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, true, true, true, true);
        let result = validate_vmc(pem.as_bytes(), "default", "example.com", None);
        assert_eq!(result.unwrap_err(), VmcError::NotYetValid);
    }

    #[test]
    fn extract_logotype_svg_test() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, true, true, true);
        let result = validate_vmc(pem.as_bytes(), "default", "example.com", None);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().embedded_svg, TEST_SVG);
    }

    #[test]
    fn logo_hash_match() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, true, true, true);
        let result = validate_vmc(pem.as_bytes(), "default", "example.com", Some(TEST_SVG));
        assert!(result.is_ok());
    }

    #[test]
    fn logo_hash_mismatch() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, true, true, true);
        let different_svg = r#"<svg xmlns="http://www.w3.org/2000/svg" version="1.2" baseProfile="tiny-ps" viewBox="0 0 100 100"><title>Diff</title><rect width="100" height="100" fill="blue"/></svg>"#;
        let result =
            validate_vmc(pem.as_bytes(), "default", "example.com", Some(different_svg));
        assert_eq!(result.unwrap_err(), VmcError::LogoHashMismatch);
    }

    #[test]
    fn valid_pem_chain() {
        let (ca_params, ca_kp) = make_ca_cert("Test CA");
        let ca_cert = ca_params.self_signed(&ca_kp).expect("CA cert");
        let vmc_pem = signed_vmc_pem(&ca_cert, &ca_kp);

        // Correct order: the VMC first, then its issuer.
        let chain_pem = format!("{}{}", vmc_pem, ca_cert.pem());
        let result = validate_vmc(chain_pem.as_bytes(), "default", "example.com", None);
        assert!(result.is_ok(), "expected Ok, got {:?}", result);
    }

    #[test]
    fn out_of_order_chain() {
        let (ca_params, ca_kp) = make_ca_cert("Test CA");
        let ca_cert = ca_params.self_signed(&ca_kp).expect("CA cert");
        let vmc_pem = signed_vmc_pem(&ca_cert, &ca_kp);

        // Issuer placed before the VMC. The chain-shape check rejects a CA in first
        // position before any per-certificate check (such as the EKU) runs.
        let chain_pem = format!("{}{}", ca_cert.pem(), vmc_pem);
        let result = validate_vmc(chain_pem.as_bytes(), "default", "example.com", None);
        assert_eq!(result.unwrap_err(), VmcError::OutOfOrder);
    }

    #[test]
    fn multiple_vmcs_in_chain() {
        let (pem1, _kp1) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, true, true, true);
        let (pem2, _kp2) =
            make_vmc_cert("default", "other.com", TEST_SVG, false, false, true, true, true);

        let chain_pem = format!("{}{}", pem1, pem2);
        let result = validate_vmc(chain_pem.as_bytes(), "default", "example.com", None);
        assert_eq!(result.unwrap_err(), VmcError::MultipleVmcs);
    }

    #[test]
    fn duplicate_cert_in_chain() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, true, true, true);
        let chain_pem = format!("{}{}", pem, pem);
        let result = validate_vmc(chain_pem.as_bytes(), "default", "example.com", None);
        assert_eq!(result.unwrap_err(), VmcError::DuplicateCert);
    }

    #[test]
    fn no_certificates() {
        let result = validate_vmc(b"not a PEM", "default", "example.com", None);
        assert_eq!(result.unwrap_err(), VmcError::NoCertificates);
    }

    #[test]
    fn missing_logotype() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, true, true, false);
        let result = validate_vmc(pem.as_bytes(), "default", "example.com", None);
        assert_eq!(result.unwrap_err(), VmcError::MissingLogoType);
    }

    #[test]
    fn missing_san() {
        let (pem, _kp) =
            make_vmc_cert("default", "example.com", TEST_SVG, false, false, true, false, true);
        let result = validate_vmc(pem.as_bytes(), "default", "example.com", None);
        assert!(matches!(result.unwrap_err(), VmcError::SanMismatch { .. }));
    }
}