1use rustls::ClientConfig;
10use tokio_postgres_rustls::MakeRustlsConnect;
11
12pub fn wants_tls(connection_string: &str) -> bool {
14 sslmode(connection_string)
15 .map(|m| {
16 matches!(
17 m.as_str(),
18 "require" | "verify-ca" | "verify-full" | "prefer"
19 )
20 })
21 .unwrap_or(false)
22}
23
/// Extracts the `sslmode` value from a PostgreSQL connection string.
///
/// Two forms are recognised:
/// - the URL form: `postgres://host/db?sslmode=require`
/// - the libpq keyword/value form: `host=localhost sslmode=require`
///
/// In the keyword/value form, whitespace around `=` (`sslmode = require`) and
/// single-quoted values (`sslmode='require'`) are accepted, as libpq allows.
/// Matching is ASCII case-insensitive, and the returned value is lowercased.
/// The key must be a whole parameter name, so `nosslmode=...` is ignored.
/// If the key appears more than once, the last occurrence wins, since later
/// parameters override earlier ones.
///
/// Returns `None` when no `sslmode` parameter is present.
fn sslmode(connection_string: &str) -> Option<String> {
    const KEY: &str = "sslmode";

    // ASCII lowercasing keeps byte offsets unchanged, so every index found in
    // `lower` is a valid char boundary.
    let lower = connection_string.to_ascii_lowercase();
    let mut mode = None;
    let mut search_from = 0;

    while let Some(offset) = lower[search_from..].find(KEY) {
        let key_start = search_from + offset;
        let key_end = key_start + KEY.len();
        search_from = key_end;

        // Only accept `sslmode` as a whole parameter name. It must start the
        // string or follow a separator, not be the tail of a longer token.
        let at_boundary = match lower[..key_start].chars().next_back() {
            None => true,
            Some(c) => c.is_whitespace() || c == '?' || c == '&',
        };
        if !at_boundary {
            continue;
        }

        // libpq permits optional whitespace on both sides of `=`.
        let after_eq = match lower[key_end..].trim_start().strip_prefix('=') {
            Some(rest) => rest.trim_start(),
            None => continue,
        };

        let value = match after_eq.strip_prefix('\'') {
            // Quoted keyword/value form: `sslmode='require'`.
            Some(quoted) => quoted.split('\'').next().unwrap_or(""),
            // Unquoted: stop at whitespace, a URL query separator or fragment,
            // or a stray quote.
            None => after_eq
                .split(|c: char| c.is_whitespace() || matches!(c, '&' | '#' | '\''))
                .next()
                .unwrap_or(""),
        };

        // Keep scanning: a later occurrence overrides this one.
        mode = Some(value.trim().to_string());
    }

    mode
}
34
35pub fn ensure_crypto_provider() {
44 let _ = rustls::crypto::ring::default_provider().install_default();
45}
46
47pub fn make_connector() -> anyhow::Result<MakeRustlsConnect> {
52 ensure_crypto_provider();
54
55 let mut roots = rustls::RootCertStore::empty();
56 let result = rustls_native_certs::load_native_certs();
57 if !result.errors.is_empty() {
58 tracing::warn!(
59 "Some native root certificates failed to load: {:?}",
60 result.errors
61 );
62 }
63 for cert in result.certs {
64 let _ = roots.add(cert);
66 }
67 if roots.is_empty() {
68 anyhow::bail!("No native root certificates available for TLS verification");
69 }
70
71 let config = ClientConfig::builder()
72 .with_root_certificates(roots)
73 .with_no_client_auth();
74
75 Ok(MakeRustlsConnect::new(config))
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_sslmode_url_form() {
        let parsed = sslmode("postgres://u:p@h/db?sslmode=require");
        assert_eq!(parsed.as_deref(), Some("require"));
    }

    #[test]
    fn test_sslmode_kv_form() {
        let parsed = sslmode("host=localhost sslmode=verify-full dbname=x");
        assert_eq!(parsed.as_deref(), Some("verify-full"));
    }

    #[test]
    fn test_wants_tls() {
        // (connection string, expected wants_tls result)
        let cases = [
            ("postgres://h/db?sslmode=require", true),
            ("sslmode=verify-ca", true),
            ("sslmode=prefer", true),
            ("postgres://h/db?sslmode=disable", false),
            ("postgres://h/db", false),
            ("host=localhost dbname=x", false),
        ];
        for &(conn, expected) in cases.iter() {
            assert_eq!(wants_tls(conn), expected, "wants_tls({:?})", conn);
        }
    }
}