Skip to main content

pg_ephemeral/
certificate.rs

/// PEM-encoded TLS material: a CA certificate, a server certificate, and the
/// server's private key.
///
/// [`Bundle::generate`] creates fresh material; [`Bundle::build`] validates
/// caller-supplied PEM strings before wrapping them.
pub struct Bundle {
    /// The CA certificate, PEM-encoded.
    pub ca_cert_pem: String,
    /// The server certificate, PEM-encoded.
    pub server_cert_pem: String,
    /// The server private key, PEM-encoded.
    pub server_key_pem: String,
}

7/// Failure modes of [`write_ca_pem_to_temp`].
8#[derive(Debug, thiserror::Error)]
9pub enum WriteCaPemError {
10    /// The system clock reports a time before the Unix epoch, so the
11    /// unique temp-file suffix cannot be derived.
12    #[error("system clock is before the Unix epoch")]
13    SystemClock(#[source] std::time::SystemTimeError),
14    /// The CA certificate PEM could not be written to the temp file.
15    #[error("failed to write CA certificate PEM to temp file")]
16    Write(#[source] std::io::Error),
17}
18
19/// Persist a CA-certificate PEM to a uniquely-named file under
20/// `std::env::temp_dir()` and return its path.
21///
22/// Leak-on-purpose: the caller becomes the file's owner and nothing
23/// removes it today. Centralised here so the `pg_ephemeral_ca_*.crt`
24/// naming scheme has a single source of truth — both `Container`
25/// startup and `label::Metadata::prepare_config` route through this.
26pub(crate) fn write_ca_pem_to_temp(pem: &[u8]) -> Result<std::path::PathBuf, WriteCaPemError> {
27    let timestamp = std::time::SystemTime::now()
28        .duration_since(std::time::UNIX_EPOCH)
29        .map_err(WriteCaPemError::SystemClock)?
30        .as_nanos();
31    let path = std::env::temp_dir().join(format!("pg_ephemeral_ca_{timestamp}.crt"));
32    std::fs::write(&path, pem).map_err(WriteCaPemError::Write)?;
33    Ok(path)
34}
35
36impl Bundle {
37    pub fn generate(hostname: &str) -> Result<Self, rcgen::Error> {
38        let ca_key = rcgen::KeyPair::generate()?;
39        let mut ca_params = rcgen::CertificateParams::new(vec![])?;
40        ca_params
41            .distinguished_name
42            .push(rcgen::DnType::CommonName, "pg-ephemeral CA");
43        ca_params.is_ca = rcgen::IsCa::Ca(rcgen::BasicConstraints::Unconstrained);
44        ca_params.key_usages = vec![
45            rcgen::KeyUsagePurpose::KeyCertSign,
46            rcgen::KeyUsagePurpose::CrlSign,
47        ];
48
49        let ca_cert = ca_params.self_signed(&ca_key)?;
50        let ca_cert_pem = ca_cert.pem();
51
52        let server_key = rcgen::KeyPair::generate()?;
53        let mut server_params = rcgen::CertificateParams::new(vec![hostname.to_string()])?;
54        server_params
55            .distinguished_name
56            .push(rcgen::DnType::CommonName, "pg-ephemeral server");
57
58        server_params.key_usages = vec![
59            rcgen::KeyUsagePurpose::DigitalSignature,
60            rcgen::KeyUsagePurpose::KeyEncipherment,
61        ];
62
63        server_params.extended_key_usages = vec![rcgen::ExtendedKeyUsagePurpose::ServerAuth];
64
65        let ca_issuer = rcgen::Issuer::from_params(&ca_params, &ca_key);
66        let server_cert = server_params.signed_by(&server_key, &ca_issuer)?;
67        let server_cert_pem = server_cert.pem();
68        let server_key_pem = server_key.serialize_pem();
69
70        Ok(Bundle {
71            ca_cert_pem,
72            server_cert_pem,
73            server_key_pem,
74        })
75    }
76
77    pub fn build(
78        ca_cert_pem: String,
79        server_cert_pem: String,
80        server_key_pem: String,
81        expected_hostname: &str,
82    ) -> Result<Self, ValidationError> {
83        use x509_parser::prelude::*;
84        use x509_parser::verify::verify_signature;
85
86        let (_, pem) = parse_x509_pem(ca_cert_pem.as_bytes()).map_err(|e| {
87            ValidationError::ParseError(format!("Failed to parse CA certificate PEM: {e}"))
88        })?;
89        let ca_cert = pem.parse_x509().map_err(|e| {
90            ValidationError::ParseError(format!("Failed to parse CA certificate: {e}"))
91        })?;
92
93        let ca_basic_constraints = ca_cert
94            .basic_constraints()
95            .map_err(|e| {
96                ValidationError::ValidationError(format!("CA should have basic constraints: {e}"))
97            })?
98            .ok_or_else(|| {
99                ValidationError::ValidationError(
100                    "CA basic constraints should be present".to_string(),
101                )
102            })?;
103
104        if !ca_basic_constraints.value.ca {
105            return Err(ValidationError::ValidationError(
106                "CA certificate should have CA=true".to_string(),
107            ));
108        }
109
110        if ca_cert.subject().to_string() != "CN=pg-ephemeral CA" {
111            return Err(ValidationError::ValidationError(format!(
112                "CA subject should be CN=pg-ephemeral CA, got: {}",
113                ca_cert.subject()
114            )));
115        }
116
117        let (_, pem) = parse_x509_pem(server_cert_pem.as_bytes()).map_err(|e| {
118            ValidationError::ParseError(format!("Failed to parse server certificate PEM: {e}"))
119        })?;
120        let server_cert = pem.parse_x509().map_err(|e| {
121            ValidationError::ParseError(format!("Failed to parse server certificate: {e}"))
122        })?;
123
124        if server_cert.subject().to_string() != "CN=pg-ephemeral server" {
125            return Err(ValidationError::ValidationError(format!(
126                "Server subject should be CN=pg-ephemeral server, got: {}",
127                server_cert.subject()
128            )));
129        }
130
131        // Validate that the DNS name in SAN matches the expected hostname
132        let san_ext = server_cert
133            .subject_alternative_name()
134            .map_err(|e| {
135                ValidationError::ValidationError(format!(
136                    "Failed to read subject alternative name: {e}"
137                ))
138            })?
139            .ok_or_else(|| {
140                ValidationError::ValidationError(
141                    "Server certificate should have subject alternative name extension".to_string(),
142                )
143            })?;
144
145        let dns_names: Vec<&str> = san_ext
146            .value
147            .general_names
148            .iter()
149            .filter_map(|name| {
150                if let x509_parser::extensions::GeneralName::DNSName(dns) = name {
151                    Some(*dns)
152                } else {
153                    None
154                }
155            })
156            .collect();
157
158        if !dns_names.contains(&expected_hostname) {
159            return Err(ValidationError::ValidationError(format!(
160                "Server certificate DNS names {dns_names:?} should contain {expected_hostname}"
161            )));
162        }
163
164        if server_cert.issuer().to_string() != ca_cert.subject().to_string() {
165            return Err(ValidationError::ValidationError(
166                "Server certificate issuer should match CA subject".to_string(),
167            ));
168        }
169
170        verify_signature(
171            ca_cert.public_key(),
172            &server_cert.signature_algorithm,
173            &server_cert.signature_value,
174            server_cert.tbs_certificate.as_ref(),
175        )
176        .map_err(|e| {
177            ValidationError::ValidationError(format!(
178                "Server certificate should be signed by CA: {e}"
179            ))
180        })?;
181
182        let server_key = rcgen::KeyPair::from_pem(&server_key_pem)
183            .map_err(|e| ValidationError::ParseError(format!("Failed to parse server key: {e}")))?;
184
185        let server_key_public = rcgen::PublicKeyData::subject_public_key_info(&server_key);
186        let cert_public_key = server_cert.public_key().raw;
187
188        if server_key_public != cert_public_key {
189            return Err(ValidationError::ValidationError(
190                "Server certificate public key should match server private key".to_string(),
191            ));
192        }
193
194        Ok(Bundle {
195            ca_cert_pem,
196            server_cert_pem,
197            server_key_pem,
198        })
199    }
200}
201
/// Reasons [`Bundle::build`] can reject caller-supplied PEM material.
#[derive(Debug)]
pub enum ValidationError {
    /// A PEM block, X.509 certificate, or private key could not be parsed.
    ParseError(String),
    /// The input parsed but failed one of the bundle's consistency checks.
    ValidationError(String),
}

208impl std::fmt::Display for ValidationError {
209    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
210        match self {
211            ValidationError::ParseError(msg) => write!(f, "Parse error: {msg}"),
212            ValidationError::ValidationError(msg) => write!(f, "Validation error: {msg}"),
213        }
214    }
215}
216
217impl std::error::Error for ValidationError {}
218
#[cfg(test)]
mod test {
    use super::*;

    /// Material produced by `Bundle::generate` must pass `Bundle::build`
    /// validation for the same hostname.
    #[test]
    fn test_generate_bundle() {
        const HOSTNAME: &str = "test.example.com";

        let Bundle {
            ca_cert_pem,
            server_cert_pem,
            server_key_pem,
        } = Bundle::generate(HOSTNAME).unwrap();

        let validated = Bundle::build(ca_cert_pem, server_cert_pem, server_key_pem, HOSTNAME);
        validated.unwrap();
    }
}