1#[cfg(feature = "tls")]
7use sqlmodel_core::Error;
8#[cfg(feature = "tls")]
9use sqlmodel_core::error::{ConnectionError, ConnectionErrorKind};
10
11#[cfg(feature = "tls")]
12use crate::config::SslMode;
13
14#[cfg(feature = "tls")]
15use std::sync::Arc;
16
/// Builds a connection-level [`Error`] classified as an SSL failure.
///
/// `message` is converted into an owned `String` and stored verbatim; the
/// resulting error carries no underlying `source`.
#[cfg(feature = "tls")]
fn tls_error(message: impl Into<String>) -> Error {
    let details = ConnectionError {
        kind: ConnectionErrorKind::Ssl,
        message: message.into(),
        source: None,
    };
    Error::Connection(details)
}
25
/// Converts `host` into an owned [`rustls::pki_types::ServerName`].
///
/// # Errors
///
/// Returns an SSL connection error (see `tls_error`) that quotes the offending
/// host and the conversion failure when `host` is not accepted as a server name.
#[cfg(feature = "tls")]
pub(crate) fn server_name(host: &str) -> Result<rustls::pki_types::ServerName<'static>, Error> {
    use rustls::pki_types::ServerName;

    // Converting from an owned `String` is what yields the `'static` lifetime.
    match ServerName::try_from(host.to_owned()) {
        Ok(name) => Ok(name),
        Err(e) => Err(tls_error(format!("Invalid server name '{host}': {e}"))),
    }
}
32
/// Builds the `rustls` client configuration that corresponds to `ssl_mode`.
///
/// * `Prefer` / `Require` use the configuration from `build_no_verify_config`,
///   which does not validate the server certificate.
/// * `VerifyCa` / `VerifyFull` use the configuration from `build_webpki_config`,
///   which trusts the bundled `webpki_roots` anchors.
///
/// # Errors
///
/// Returns an SSL connection error when called with `SslMode::Disable`, or when
/// the selected builder fails.
#[cfg(feature = "tls")]
pub(crate) fn build_client_config(ssl_mode: SslMode) -> Result<rustls::ClientConfig, Error> {
    let provider = Arc::new(rustls::crypto::ring::default_provider());

    // NOTE(review): `VerifyCa` and `VerifyFull` share one configuration here, so
    // `VerifyCa` presumably also gets whatever host-name checking rustls's
    // default verifier performs — confirm that this is intended.
    let verify_certificates = match ssl_mode {
        SslMode::Disable => {
            return Err(tls_error("TLS config requested with SslMode::Disable"));
        }
        SslMode::Prefer | SslMode::Require => false,
        SslMode::VerifyCa | SslMode::VerifyFull => true,
    };

    if verify_certificates {
        build_webpki_config(&provider)
    } else {
        build_no_verify_config(&provider)
    }
}
49
/// Builds a client configuration that encrypts the connection but performs no
/// server authentication at all.
///
/// The installed verifier accepts every certificate chain and every handshake
/// signature unconditionally. TLS 1.2 and TLS 1.3 are enabled and no client
/// certificate is offered.
///
/// # Errors
///
/// Returns an SSL connection error if `provider` cannot be used with the
/// requested protocol versions.
#[cfg(feature = "tls")]
fn build_no_verify_config(
    provider: &Arc<rustls::crypto::CryptoProvider>,
) -> Result<rustls::ClientConfig, Error> {
    use rustls::client::danger::{HandshakeSignatureValid, ServerCertVerified, ServerCertVerifier};
    use rustls::pki_types::{CertificateDer, ServerName, UnixTime};
    use rustls::{DigitallySignedStruct, Error as RustlsError, SignatureScheme};

    /// Signature schemes reported by the verifier, in the order they are listed.
    const ADVERTISED_SCHEMES: [SignatureScheme; 10] = [
        SignatureScheme::RSA_PKCS1_SHA256,
        SignatureScheme::RSA_PKCS1_SHA384,
        SignatureScheme::RSA_PKCS1_SHA512,
        SignatureScheme::ECDSA_NISTP256_SHA256,
        SignatureScheme::ECDSA_NISTP384_SHA384,
        SignatureScheme::ECDSA_NISTP521_SHA512,
        SignatureScheme::RSA_PSS_SHA256,
        SignatureScheme::RSA_PSS_SHA384,
        SignatureScheme::RSA_PSS_SHA512,
        SignatureScheme::ED25519,
    ];

    /// Certificate verifier that approves everything it is shown.
    #[derive(Debug)]
    struct NoVerifier;

    impl ServerCertVerifier for NoVerifier {
        // Every argument is ignored: the certificate is always reported as verified.
        fn verify_server_cert(
            &self,
            _end_entity: &CertificateDer<'_>,
            _intermediates: &[CertificateDer<'_>],
            _server_name: &ServerName<'_>,
            _ocsp_response: &[u8],
            _now: UnixTime,
        ) -> Result<ServerCertVerified, RustlsError> {
            Ok(ServerCertVerified::assertion())
        }

        // The TLS 1.2 handshake signature is not checked.
        fn verify_tls12_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, RustlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        // The TLS 1.3 handshake signature is not checked either.
        fn verify_tls13_signature(
            &self,
            _message: &[u8],
            _cert: &CertificateDer<'_>,
            _dss: &DigitallySignedStruct,
        ) -> Result<HandshakeSignatureValid, RustlsError> {
            Ok(HandshakeSignatureValid::assertion())
        }

        fn supported_verify_schemes(&self) -> Vec<SignatureScheme> {
            ADVERTISED_SCHEMES.to_vec()
        }
    }

    let versioned = rustls::ClientConfig::builder_with_provider(Arc::clone(provider))
        .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13])
        .map_err(|e| tls_error(format!("Failed to set TLS versions: {e}")))?;

    // `dangerous()` is the explicit opt-in rustls requires before a custom
    // certificate verifier may be installed.
    Ok(versioned
        .dangerous()
        .with_custom_certificate_verifier(Arc::new(NoVerifier))
        .with_no_client_auth())
}
117
/// Builds a client configuration whose trust anchors are the roots bundled in
/// `webpki_roots::TLS_SERVER_ROOTS`.
///
/// TLS 1.2 and TLS 1.3 are enabled and no client certificate is offered.
///
/// # Errors
///
/// Returns an SSL connection error if `provider` cannot be used with the
/// requested protocol versions.
#[cfg(feature = "tls")]
fn build_webpki_config(
    provider: &Arc<rustls::crypto::CryptoProvider>,
) -> Result<rustls::ClientConfig, Error> {
    use rustls::RootCertStore;

    // Start from an empty store and copy in every bundled trust anchor.
    let mut trust_anchors = RootCertStore::empty();
    trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    let versioned = rustls::ClientConfig::builder_with_provider(Arc::clone(provider))
        .with_protocol_versions(&[&rustls::version::TLS12, &rustls::version::TLS13])
        .map_err(|e| tls_error(format!("Failed to set TLS versions: {e}")))?;

    Ok(versioned
        .with_root_certificates(trust_anchors)
        .with_no_client_auth())
}