Skip to main content

pg_ephemeral/
certificate.rs

/// PEM-encoded TLS material: a CA certificate, a server certificate issued by that CA,
/// and the private key belonging to the server certificate.
///
/// Instances come from [`Bundle::generate`] (fresh material) or [`Bundle::build`]
/// (existing PEM text that passed validation). `Debug` is intentionally not derived
/// because the bundle carries a private key.
pub struct Bundle {
    /// The CA certificate, PEM-encoded.
    pub ca_cert_pem: String,
    /// The server certificate signed by the CA, PEM-encoded.
    pub server_cert_pem: String,
    /// The server certificate's private key, PEM-encoded. Treat as secret.
    pub server_key_pem: String,
}
6
7impl Bundle {
8    pub fn generate(hostname: &str) -> Result<Self, rcgen::Error> {
9        let ca_key = rcgen::KeyPair::generate()?;
10        let mut ca_params = rcgen::CertificateParams::new(vec![])?;
11        ca_params
12            .distinguished_name
13            .push(rcgen::DnType::CommonName, "pg-ephemeral CA");
14        ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
15        ca_params.key_usages = vec![
16            rcgen::KeyUsagePurpose::KeyCertSign,
17            rcgen::KeyUsagePurpose::CrlSign,
18        ];
19
20        let ca_cert = ca_params.self_signed(&ca_key)?;
21        let ca_cert_pem = ca_cert.pem();
22
23        let server_key = rcgen::KeyPair::generate()?;
24        let mut server_params = rcgen::CertificateParams::new(vec![hostname.to_string()])?;
25        server_params
26            .distinguished_name
27            .push(rcgen::DnType::CommonName, "pg-ephemeral server");
28
29        server_params.key_usages = vec![
30            rcgen::KeyUsagePurpose::DigitalSignature,
31            rcgen::KeyUsagePurpose::KeyEncipherment,
32        ];
33
34        server_params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ServerAuth];
35
36        let ca_issuer = rcgen::Issuer::from_params(&ca_params, &ca_key);
37        let server_cert = server_params.signed_by(&server_key, &ca_issuer)?;
38        let server_cert_pem = server_cert.pem();
39        let server_key_pem = server_key.serialize_pem();
40
41        Ok(Bundle {
42            ca_cert_pem,
43            server_cert_pem,
44            server_key_pem,
45        })
46    }
47
48    pub fn build(
49        ca_cert_pem: String,
50        server_cert_pem: String,
51        server_key_pem: String,
52        expected_hostname: &str,
53    ) -> Result<Self, ValidationError> {
54        use x509_parser::prelude::*;
55        use x509_parser::verify::verify_signature;
56
57        let (_, pem) = parse_x509_pem(ca_cert_pem.as_bytes()).map_err(|e| {
58            ValidationError::ParseError(format!("Failed to parse CA certificate PEM: {e}"))
59        })?;
60        let ca_cert = pem.parse_x509().map_err(|e| {
61            ValidationError::ParseError(format!("Failed to parse CA certificate: {e}"))
62        })?;
63
64        let ca_basic_constraints = ca_cert
65            .basic_constraints()
66            .map_err(|e| {
67                ValidationError::ValidationError(format!("CA should have basic constraints: {e}"))
68            })?
69            .ok_or_else(|| {
70                ValidationError::ValidationError(
71                    "CA basic constraints should be present".to_string(),
72                )
73            })?;
74
75        if !ca_basic_constraints.value.ca {
76            return Err(ValidationError::ValidationError(
77                "CA certificate should have CA=true".to_string(),
78            ));
79        }
80
81        if ca_cert.subject().to_string() != "CN=pg-ephemeral CA" {
82            return Err(ValidationError::ValidationError(format!(
83                "CA subject should be CN=pg-ephemeral CA, got: {}",
84                ca_cert.subject()
85            )));
86        }
87
88        let (_, pem) = parse_x509_pem(server_cert_pem.as_bytes()).map_err(|e| {
89            ValidationError::ParseError(format!("Failed to parse server certificate PEM: {e}"))
90        })?;
91        let server_cert = pem.parse_x509().map_err(|e| {
92            ValidationError::ParseError(format!("Failed to parse server certificate: {e}"))
93        })?;
94
95        if server_cert.subject().to_string() != "CN=pg-ephemeral server" {
96            return Err(ValidationError::ValidationError(format!(
97                "Server subject should be CN=pg-ephemeral server, got: {}",
98                server_cert.subject()
99            )));
100        }
101
102        // Validate that the DNS name in SAN matches the expected hostname
103        let san_ext = server_cert
104            .subject_alternative_name()
105            .map_err(|e| {
106                ValidationError::ValidationError(format!(
107                    "Failed to read subject alternative name: {e}"
108                ))
109            })?
110            .ok_or_else(|| {
111                ValidationError::ValidationError(
112                    "Server certificate should have subject alternative name extension".to_string(),
113                )
114            })?;
115
116        let dns_names: Vec<&str> = san_ext
117            .value
118            .general_names
119            .iter()
120            .filter_map(|name| {
121                if let x509_parser::extensions::GeneralName::DNSName(dns) = name {
122                    Some(*dns)
123                } else {
124                    None
125                }
126            })
127            .collect();
128
129        if !dns_names.contains(&expected_hostname) {
130            return Err(ValidationError::ValidationError(format!(
131                "Server certificate DNS names {dns_names:?} should contain {expected_hostname}"
132            )));
133        }
134
135        if server_cert.issuer().to_string() != ca_cert.subject().to_string() {
136            return Err(ValidationError::ValidationError(
137                "Server certificate issuer should match CA subject".to_string(),
138            ));
139        }
140
141        verify_signature(
142            ca_cert.public_key(),
143            &server_cert.signature_algorithm,
144            &server_cert.signature_value,
145            server_cert.tbs_certificate.as_ref(),
146        )
147        .map_err(|e| {
148            ValidationError::ValidationError(format!(
149                "Server certificate should be signed by CA: {e}"
150            ))
151        })?;
152
153        let server_key = rcgen::KeyPair::from_pem(&server_key_pem)
154            .map_err(|e| ValidationError::ParseError(format!("Failed to parse server key: {e}")))?;
155
156        let server_key_public = rcgen::PublicKeyData::subject_public_key_info(&server_key);
157        let cert_public_key = server_cert.public_key().raw;
158
159        if server_key_public != cert_public_key {
160            return Err(ValidationError::ValidationError(
161                "Server certificate public key should match server private key".to_string(),
162            ));
163        }
164
165        Ok(Bundle {
166            ca_cert_pem,
167            server_cert_pem,
168            server_key_pem,
169        })
170    }
171}
172
/// Reason [`Bundle::build`] rejected the supplied certificate material.
///
/// Each variant carries a human-readable description of the specific failure.
#[derive(Debug)]
pub enum ValidationError {
    /// A PEM block, certificate, or private key could not be parsed.
    ParseError(String),
    /// The material parsed, but failed one of the consistency checks.
    ValidationError(String),
}
178
179impl std::fmt::Display for ValidationError {
180    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
181        match self {
182            ValidationError::ParseError(msg) => write!(f, "Parse error: {msg}"),
183            ValidationError::ValidationError(msg) => write!(f, "Validation error: {msg}"),
184        }
185    }
186}
187
188impl std::error::Error for ValidationError {}
189
#[cfg(test)]
mod test {
    use super::*;

    const HOSTNAME: &str = "test.example.com";

    /// Generates a fresh bundle for [`HOSTNAME`]; generation failing is a test bug.
    fn generated() -> Bundle {
        Bundle::generate(HOSTNAME).expect("generating a bundle should succeed")
    }

    #[test]
    fn test_generate_bundle() {
        let generated = generated();

        Bundle::build(
            generated.ca_cert_pem,
            generated.server_cert_pem,
            generated.server_key_pem,
            HOSTNAME,
        )
        .unwrap();
    }

    #[test]
    fn test_build_rejects_wrong_hostname() {
        let generated = generated();

        let result = Bundle::build(
            generated.ca_cert_pem,
            generated.server_cert_pem,
            generated.server_key_pem,
            "other.example.com",
        );

        assert!(matches!(result, Err(ValidationError::ValidationError(_))));
    }

    #[test]
    fn test_build_rejects_certificate_signed_by_other_ca() {
        // Both CAs share the same subject, so only the signature check can catch this.
        let first = generated();
        let second = generated();

        let result = Bundle::build(
            second.ca_cert_pem,
            first.server_cert_pem,
            first.server_key_pem,
            HOSTNAME,
        );

        assert!(matches!(result, Err(ValidationError::ValidationError(_))));
    }

    #[test]
    fn test_build_rejects_mismatched_private_key() {
        let first = generated();
        let second = generated();

        let result = Bundle::build(
            first.ca_cert_pem,
            first.server_cert_pem,
            second.server_key_pem,
            HOSTNAME,
        );

        assert!(matches!(result, Err(ValidationError::ValidationError(_))));
    }

    #[test]
    fn test_build_rejects_unparseable_ca_pem() {
        let generated = generated();

        let result = Bundle::build(
            "not a certificate".to_string(),
            generated.server_cert_pem,
            generated.server_key_pem,
            HOSTNAME,
        );

        assert!(matches!(result, Err(ValidationError::ParseError(_))));
    }
}