use std::io::BufReader;
use std::sync::Arc;
use narwhal_core::{ConnectionParams, Error, Result, SslMode};
use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
use rustls::pki_types::{CertificateDer, PrivateKeyDer, ServerName, UnixTime};
use rustls::{ClientConfig, DigitallySignedStruct, RootCertStore, SignatureScheme};
pub(crate) use tokio_postgres_rustls::MakeRustlsConnect;
/// TLS policy for a connection, resolved from `ConnectionParams`.
///
/// Variants use libpq-style `sslmode` names (see `as_str`); note that `Verify` is the
/// `verify-full` mode.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum InternalSslMode {
    /// `disable`: no TLS connector is requested for this mode.
    Disable,
    /// `prefer`: currently built with the same hostname-verifying config as `Verify`.
    Prefer,
    /// `require`: chain of trust is verified, hostname mismatches are tolerated.
    Require,
    /// `verify-ca`: chain of trust is verified, hostname mismatches are tolerated.
    VerifyCa,
    /// `verify-full`: chain of trust and hostname are both verified.
    Verify,
}
impl InternalSslMode {
pub(crate) fn from_params(params: &ConnectionParams) -> Result<Self> {
let mode = params.ssl_mode;
let mode = if mode == SslMode::Prefer {
if let Some(raw) = params.options.get("sslmode") {
match raw.to_ascii_lowercase().as_str() {
"disable" => SslMode::Disable,
"prefer" => SslMode::Prefer,
"require" => SslMode::Require,
"verify-ca" => SslMode::VerifyCa,
"verify-full" => SslMode::VerifyFull,
other => {
return Err(Error::Config(format!(
"unsupported sslmode value: {other} \
(use disable|prefer|require|verify-ca|verify-full)"
)));
}
}
} else {
SslMode::Prefer
}
} else {
mode
};
Ok(match mode {
SslMode::Disable => Self::Disable,
SslMode::Prefer => Self::Prefer,
SslMode::Require => Self::Require,
SslMode::VerifyCa => Self::VerifyCa,
SslMode::VerifyFull => Self::Verify,
_ => Self::Verify,
})
}
pub(crate) const fn as_str(self) -> &'static str {
match self {
Self::Disable => "disable",
Self::Prefer => "prefer",
Self::Require => "require",
Self::VerifyCa => "verify-ca",
Self::Verify => "verify-full",
}
}
}
/// Formats the mode as its libpq-style name, e.g. `verify-full`.
///
/// Width/alignment flags on the outer formatter are not applied.
impl std::fmt::Display for InternalSslMode {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = self.as_str();
        write!(f, "{name}")
    }
}
pub(crate) fn make_tls_connector(
mode: InternalSslMode,
params: &ConnectionParams,
) -> Result<MakeRustlsConnect> {
let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
let config = match mode {
InternalSslMode::Disable => unreachable!("disable path does not request a TLS connector"),
InternalSslMode::Prefer | InternalSslMode::Verify => verified_client_config(params)?,
InternalSslMode::Require | InternalSslMode::VerifyCa => verify_ca_client_config(params)?,
};
Ok(MakeRustlsConnect::new(config))
}
fn verified_client_config(params: &ConnectionParams) -> Result<ClientConfig> {
let store = build_root_store(params)?;
if let Some(key_pair) = load_client_cert_key(params)? {
ClientConfig::builder()
.with_root_certificates(store)
.with_client_auth_cert(key_pair.certs, key_pair.key)
.map_err(|e| Error::Config(format!("invalid client cert/key pair: {e}")))
} else {
Ok(ClientConfig::builder()
.with_root_certificates(store)
.with_no_client_auth())
}
}
fn verify_ca_client_config(params: &ConnectionParams) -> Result<ClientConfig> {
let store = build_root_store(params)?;
let verifier = Arc::new(VerifyCaNoHostname::new(store));
if let Some(key_pair) = load_client_cert_key(params)? {
ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(verifier)
.with_client_auth_cert(key_pair.certs, key_pair.key)
.map_err(|e| Error::Config(format!("invalid client cert/key pair: {e}")))
} else {
Ok(ClientConfig::builder()
.dangerous()
.with_custom_certificate_verifier(verifier)
.with_no_client_auth())
}
}
/// Builds the set of trust anchors used to verify the server certificate.
///
/// When `params.ssl_root_cert` is set, only the PEM certificates in that file are trusted
/// and the platform store is not consulted. Otherwise the platform's native certificates
/// (via `rustls_native_certs`) are used. Individual native certificates that fail to load
/// are logged and skipped.
///
/// # Errors
/// Returns `Error::Config` if the root-cert file cannot be read or parsed, or if the
/// resulting store would contain no usable certificate.
fn build_root_store(params: &ConnectionParams) -> Result<RootCertStore> {
    let mut store = RootCertStore::empty();
    if let Some(path) = &params.ssl_root_cert {
        let bytes = std::fs::read(path).map_err(|e| {
            Error::Config(format!(
                "failed to read ssl_root_cert '{}': {e}",
                path.display()
            ))
        })?;
        let certs: Vec<CertificateDer<'static>> =
            rustls_pemfile::certs(&mut BufReader::new(&bytes[..]))
                .collect::<std::result::Result<Vec<_>, _>>()
                .map_err(|e| Error::Config(format!("failed to parse ssl_root_cert PEM: {e}")))?;
        let (added, ignored) = store.add_parsable_certificates(certs);
        if added == 0 {
            return Err(Error::Config(format!(
                "no certificates found in ssl_root_cert '{}'",
                path.display()
            )));
        }
        // A user-supplied bundle with unusable entries is worth surfacing: the operator
        // may believe a CA is trusted when rustls actually dropped it.
        if ignored > 0 {
            tracing::warn!(
                target: "narwhal::postgres::tls",
                path = %path.display(),
                ignored,
                "ignored unparsable certificates in ssl_root_cert"
            );
        }
    } else {
        let load = rustls_native_certs::load_native_certs();
        for err in &load.errors {
            tracing::warn!(target: "narwhal::postgres::tls", error = %err, "failed to load a native CA");
        }
        let (added, _ignored) = store.add_parsable_certificates(load.certs);
        if added == 0 {
            return Err(Error::Config(
                "no trusted CA certificates available; install ca-certificates \
                 or set ssl_root_cert"
                    .into(),
            ));
        }
    }
    Ok(store)
}
/// Server-certificate verifier for the chain-of-trust-only modes (`require`, `verify-ca`).
///
/// Wraps a stock `rustls::client::WebPkiServerVerifier` built over the configured roots.
/// Its `ServerCertVerifier` impl delegates to that verifier.
#[derive(Debug)]
struct VerifyCaNoHostname {
    /// WebPKI verifier over the configured root store; all real checks happen here.
    inner: Arc<dyn ServerCertVerifier>,
}

impl VerifyCaNoHostname {
    /// Wraps a WebPKI verifier that trusts exactly the anchors in `store`.
    ///
    /// # Panics
    /// Panics if rustls refuses to build the WebPKI verifier (for instance, when `store`
    /// holds no trust anchors). `build_root_store` rejects empty stores, so configs built
    /// from it should not hit this.
    fn new(store: RootCertStore) -> Self {
        let roots = Arc::new(store);
        let webpki = rustls::client::WebPkiServerVerifier::builder(roots)
            .build()
            .expect("WebPkiServerVerifier construction should not fail with a valid root store");
        Self { inner: webpki }
    }
}
impl ServerCertVerifier for VerifyCaNoHostname {
fn verify_server_cert(
&self,
end_entity: &CertificateDer<'_>,
intermediates: &[CertificateDer<'_>],
_server_name: &ServerName<'_>,
ocsp_response: &[u8],
now: UnixTime,
) -> std::result::Result<ServerCertVerified, rustls::Error> {
match self.inner.verify_server_cert(
end_entity,
intermediates,
_server_name,
ocsp_response,
now,
) {
Ok(v) => Ok(v),
Err(rustls::Error::InvalidCertificate(e)) => {
if matches!(e, rustls::CertificateError::NotValidForName) {
Ok(ServerCertVerified::assertion())
} else {
Err(rustls::Error::InvalidCertificate(e))
}
}
Err(other) => Err(other),
}
}
fn verify_tls12_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls12_signature(message, cert, dss)
}
fn verify_tls13_signature(
&self,
message: &[u8],
cert: &CertificateDer<'_>,
dss: &DigitallySignedStruct,
) -> std::result::Result<HandshakeSignatureValid, rustls::Error> {
self.inner.verify_tls13_signature(message, cert, dss)
}
fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
self.inner.supported_verify_schemes()
}
}
#[derive(Debug)]
struct ClientCertKey {
certs: Vec<CertificateDer<'static>>,
key: PrivateKeyDer<'static>,
}
fn load_client_cert_key(params: &ConnectionParams) -> Result<Option<ClientCertKey>> {
match (¶ms.ssl_cert, ¶ms.ssl_key) {
(Some(cert_path), Some(key_path)) => {
let cert_bytes = std::fs::read(cert_path).map_err(|e| {
Error::Config(format!(
"failed to read ssl_cert '{}': {e}",
cert_path.display()
))
})?;
let key_bytes = std::fs::read(key_path).map_err(|e| {
Error::Config(format!(
"failed to read ssl_key '{}': {e}",
key_path.display()
))
})?;
let mut cert_reader = BufReader::new(&cert_bytes[..]);
let certs: Vec<CertificateDer<'_>> = rustls_pemfile::certs(&mut cert_reader)
.collect::<std::result::Result<Vec<_>, _>>()
.map_err(|e| Error::Config(format!("failed to parse ssl_cert PEM: {e}")))?;
let mut key_reader = BufReader::new(&key_bytes[..]);
let key = rustls_pemfile::private_key(&mut key_reader)
.map_err(|e| Error::Config(format!("failed to parse ssl_key PEM: {e}")))?
.ok_or_else(|| Error::Config("no private key found in ssl_key file".into()))?;
Ok(Some(ClientCertKey { certs, key }))
}
(None, None) => Ok(None),
(Some(_), None) => Err(Error::Config(
"ssl_cert is set but ssl_key is missing; both must be provided together".into(),
)),
(None, Some(_)) => Err(Error::Config(
"ssl_key is set but ssl_cert is missing; both must be provided together".into(),
)),
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::BTreeMap;

    /// Every internal mode, for table-driven tests.
    const ALL_MODES: [InternalSslMode; 5] = [
        InternalSslMode::Disable,
        InternalSslMode::Prefer,
        InternalSslMode::Require,
        InternalSslMode::VerifyCa,
        InternalSslMode::Verify,
    ];

    fn params_with_options(options: BTreeMap<String, String>) -> ConnectionParams {
        ConnectionParams::with(|p| {
            p.options = options;
        })
    }

    fn params_with_ssl_mode(ssl_mode: SslMode) -> ConnectionParams {
        ConnectionParams::with(|p| {
            p.ssl_mode = ssl_mode;
        })
    }

    /// Params whose only customization is a legacy `sslmode` entry in `options`.
    fn params_with_legacy_sslmode(value: &str) -> ConnectionParams {
        let mut opts = BTreeMap::new();
        opts.insert("sslmode".to_string(), value.to_string());
        params_with_options(opts)
    }

    /// Connector construction depends on the host's CA store. Accept success, or the
    /// specific "no trusted CA certificates" error seen on images without one.
    fn assert_ok_or_missing_roots(result: Result<MakeRustlsConnect>) {
        if let Err(e) = result {
            assert!(
                e.to_string().contains("no trusted CA certificates"),
                "unexpected error: {e}"
            );
        }
    }

    #[test]
    fn from_params_default_is_prefer() {
        let params = ConnectionParams::default();
        assert_eq!(
            InternalSslMode::from_params(&params).unwrap(),
            InternalSslMode::Prefer
        );
    }

    #[test]
    fn from_params_disable_mode() {
        let params = params_with_ssl_mode(SslMode::Disable);
        assert_eq!(
            InternalSslMode::from_params(&params).unwrap(),
            InternalSslMode::Disable
        );
    }

    #[test]
    fn from_params_require_maps_to_require_chain() {
        let params = params_with_ssl_mode(SslMode::Require);
        assert_eq!(
            InternalSslMode::from_params(&params).unwrap(),
            InternalSslMode::Require
        );
    }

    #[test]
    fn from_params_prefer_maps_to_prefer() {
        let params = params_with_ssl_mode(SslMode::Prefer);
        assert_eq!(
            InternalSslMode::from_params(&params).unwrap(),
            InternalSslMode::Prefer
        );
    }

    #[test]
    fn from_params_verify_ca() {
        let params = params_with_ssl_mode(SslMode::VerifyCa);
        assert_eq!(
            InternalSslMode::from_params(&params).unwrap(),
            InternalSslMode::VerifyCa
        );
    }

    #[test]
    fn from_params_verify_full() {
        let params = params_with_ssl_mode(SslMode::VerifyFull);
        assert_eq!(
            InternalSslMode::from_params(&params).unwrap(),
            InternalSslMode::Verify
        );
    }

    #[test]
    fn from_params_legacy_options_override() {
        let params = params_with_legacy_sslmode("disable");
        assert_eq!(
            InternalSslMode::from_params(&params).unwrap(),
            InternalSslMode::Disable
        );
    }

    #[test]
    fn from_params_legacy_accepts_every_as_str_value() {
        for mode in ALL_MODES {
            let params = params_with_legacy_sslmode(mode.as_str());
            assert_eq!(InternalSslMode::from_params(&params).unwrap(), mode);
        }
    }

    #[test]
    fn from_params_legacy_is_case_insensitive() {
        let params = params_with_legacy_sslmode("Verify-CA");
        assert_eq!(
            InternalSslMode::from_params(&params).unwrap(),
            InternalSslMode::VerifyCa
        );
    }

    #[test]
    fn from_params_explicit_mode_overrides_legacy() {
        let mut opts = BTreeMap::new();
        opts.insert("sslmode".into(), "disable".into());
        let params = ConnectionParams::with(|p| {
            p.ssl_mode = SslMode::Require;
            p.options = opts;
        });
        assert_eq!(
            InternalSslMode::from_params(&params).unwrap(),
            InternalSslMode::Require
        );
    }

    #[test]
    fn rejects_unknown_legacy_sslmode() {
        let params = params_with_legacy_sslmode("bogus");
        let err = InternalSslMode::from_params(&params).unwrap_err();
        assert!(err.to_string().contains("unsupported sslmode"));
    }

    #[test]
    fn display_matches_as_str() {
        for mode in ALL_MODES {
            assert_eq!(mode.to_string(), mode.as_str());
        }
        assert_eq!(InternalSslMode::Verify.to_string(), "verify-full");
    }

    #[test]
    fn client_cert_key_missing_pair_errors() {
        let params = ConnectionParams::with(|p| {
            p.ssl_cert = Some("/tmp/cert.pem".into());
            p.ssl_key = None;
        });
        let err = load_client_cert_key(&params).unwrap_err();
        assert!(
            err.to_string()
                .contains("ssl_cert is set but ssl_key is missing")
        );
    }

    #[test]
    fn client_key_without_cert_errors() {
        let params = ConnectionParams::with(|p| {
            p.ssl_cert = None;
            p.ssl_key = Some("/tmp/key.pem".into());
        });
        let err = load_client_cert_key(&params).unwrap_err();
        assert!(
            err.to_string()
                .contains("ssl_key is set but ssl_cert is missing")
        );
    }

    #[test]
    fn client_cert_key_absent_is_none() {
        let params = ConnectionParams::with(|p| {
            p.ssl_cert = None;
            p.ssl_key = None;
        });
        assert!(load_client_cert_key(&params).unwrap().is_none());
    }

    #[test]
    fn prefer_uses_chain_verifier() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let params = ConnectionParams::with(|p| {
            p.ssl_mode = SslMode::Prefer;
        });
        let mode = InternalSslMode::from_params(&params).unwrap();
        assert_eq!(mode, InternalSslMode::Prefer);
        assert_ok_or_missing_roots(make_tls_connector(mode, &params));
    }

    #[test]
    fn require_uses_chain_verifier_no_hostname() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let params = ConnectionParams::with(|p| {
            p.ssl_mode = SslMode::Require;
        });
        let mode = InternalSslMode::from_params(&params).unwrap();
        assert_eq!(mode, InternalSslMode::Require);
        assert_ok_or_missing_roots(make_tls_connector(mode, &params));
    }

    #[test]
    fn verify_ca_uses_chain_verifier_no_hostname() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let params = ConnectionParams::with(|p| {
            p.ssl_mode = SslMode::VerifyCa;
            p.ssl_root_cert = None;
        });
        let mode = InternalSslMode::from_params(&params).unwrap();
        assert_eq!(mode, InternalSslMode::VerifyCa);
        assert_ok_or_missing_roots(make_tls_connector(mode, &params));
    }

    #[test]
    fn verify_full_builds_connector() {
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let params = params_with_ssl_mode(SslMode::VerifyFull);
        let mode = InternalSslMode::from_params(&params).unwrap();
        assert_eq!(mode, InternalSslMode::Verify);
        assert_ok_or_missing_roots(make_tls_connector(mode, &params));
    }

    #[test]
    fn verify_ca_not_same_as_verify_full() {
        let ca_mode = InternalSslMode::from_params(&ConnectionParams::with(|p| {
            p.ssl_mode = SslMode::VerifyCa;
        }))
        .unwrap();
        let full_mode = InternalSslMode::from_params(&ConnectionParams::with(|p| {
            p.ssl_mode = SslMode::VerifyFull;
        }))
        .unwrap();
        assert_ne!(ca_mode, full_mode);
    }

    #[test]
    fn chain_verified_mode_sends_client_cert_when_provided() {
        use std::io::Write;
        let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
        let cert_pem = include_str!("../../tests/postgres_fixtures/client.crt");
        let key_pem = include_str!("../../tests/postgres_fixtures/client.key");
        let dir = tempfile::tempdir().expect("tempdir");
        let cert_path = dir.path().join("client.crt");
        let key_path = dir.path().join("client.key");
        std::fs::File::create(&cert_path)
            .and_then(|mut f| f.write_all(cert_pem.as_bytes()))
            .expect("write cert");
        std::fs::File::create(&key_path)
            .and_then(|mut f| f.write_all(key_pem.as_bytes()))
            .expect("write key");
        let params = ConnectionParams::with(|p| {
            p.ssl_mode = SslMode::Require;
            p.ssl_cert = Some(cert_path);
            p.ssl_key = Some(key_path);
        });
        assert_ok_or_missing_roots(make_tls_connector(InternalSslMode::Require, &params));
    }
}