use std::io::{Read, Write};
use std::net::TcpStream;
use std::sync::{Arc, OnceLock};
use crate::proto;
use crate::DriverError;
/// Shared `rustls` client configuration.
///
/// Starts out empty and is populated on first access through
/// `get_or_init(init_tls_config)`; each connection then clones the `Arc`
/// instead of rebuilding the configuration.
static TLS_CONFIG: OnceLock<Arc<rustls::ClientConfig>> = OnceLock::new();
/// Builds the client-side TLS configuration stored in `TLS_CONFIG`.
///
/// The root store is seeded from `webpki_roots::TLS_SERVER_ROOTS`, and the
/// resulting configuration presents no client certificate.
fn init_tls_config() -> Arc<rustls::ClientConfig> {
    let mut trust_anchors = rustls::RootCertStore::empty();
    trust_anchors.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

    let client_config = rustls::ClientConfig::builder()
        .with_root_certificates(trust_anchors)
        .with_no_client_auth();

    Arc::new(client_config)
}
/// Outcome of a successful TLS upgrade performed by `try_upgrade`.
pub struct TlsUpgradeResult {
/// The TLS client connection paired with the TCP socket it runs over.
pub stream: rustls::StreamOwned<rustls::ClientConnection, TcpStream>,
/// SHA-256 digest of the first certificate reported by
/// `peer_certificates()`, or `None` when no peer certificate is available.
pub server_cert_hash: Option<[u8; 32]>,
}
/// Attempts to upgrade an already-connected TCP stream to TLS.
///
/// Writes the request produced by `proto::write_ssl_request`, reads the
/// server's one-byte answer, and acts on it:
///
/// * `b'S'` — the TLS handshake is run to completion on `tcp`, validating the
///   server against `host`, and the encrypted stream is returned together
///   with a SHA-256 digest of the first certificate the server presented.
/// * `b'N'` — an error is returned; `required` only selects the message.
/// * anything else — an error naming the unexpected byte is returned.
///
/// The handshake is driven explicitly before the peer certificate is
/// inspected. `rustls::StreamOwned` only handshakes lazily on first
/// read/write, so querying `peer_certificates()` right after constructing it
/// would always yield `None`.
///
/// NOTE(review): `tcp` is consumed, so when the server declines TLS and
/// `required` is `false` the caller has no socket left to fall back on and
/// presumably has to reconnect — confirm this is the intended contract.
///
/// # Errors
///
/// * `DriverError::Io` if writing the request, reading the answer, creating
///   the TLS session, or the TLS handshake itself fails.
/// * `DriverError::Protocol` if `host` is not a valid TLS server name, if the
///   server declines TLS, or if it answers with an unexpected byte.
pub fn try_upgrade(
    mut tcp: TcpStream,
    host: &str,
    required: bool,
) -> Result<TlsUpgradeResult, DriverError> {
    let mut buf = Vec::with_capacity(8);
    proto::write_ssl_request(&mut buf);
    tcp.write_all(&buf).map_err(DriverError::Io)?;
    tcp.flush().map_err(DriverError::Io)?;

    let mut response = [0u8; 1];
    tcp.read_exact(&mut response).map_err(DriverError::Io)?;

    match response[0] {
        b'S' => {
            let server_name =
                rustls::pki_types::ServerName::try_from(host.to_owned()).map_err(|e| {
                    DriverError::Protocol(format!("invalid TLS server name '{host}': {e}"))
                })?;
            let mut tls_conn = rustls::ClientConnection::new(
                TLS_CONFIG.get_or_init(init_tls_config).clone(),
                server_name,
            )
            .map_err(|e| DriverError::Io(std::io::Error::other(e)))?;

            // Run the handshake now: the peer certificate chain is only
            // available once the server's handshake messages have been
            // processed, and handshake failures should surface here rather
            // than on the caller's first read or write.
            while tls_conn.is_handshaking() {
                tls_conn.complete_io(&mut tcp).map_err(DriverError::Io)?;
            }

            let server_cert_hash = tls_conn
                .peer_certificates()
                .and_then(|certs| certs.first())
                .map(|cert| {
                    use sha2::{Digest, Sha256};
                    let mut hasher = Sha256::new();
                    hasher.update(cert.as_ref());
                    let hash: [u8; 32] = hasher.finalize().into();
                    hash
                });

            Ok(TlsUpgradeResult {
                stream: rustls::StreamOwned::new(tls_conn, tcp),
                server_cert_hash,
            })
        }
        b'N' => {
            let message = if required {
                "server does not support TLS (sslmode=require)"
            } else {
                "server declined TLS (sslmode=prefer, falling back)"
            };
            Err(DriverError::Protocol(message.into()))
        }
        other => Err(DriverError::Protocol(format!(
            "unexpected SSL response byte: 0x{other:02x}"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Two successive lookups of the global configuration must hand back
    /// handles to the very same allocation, i.e. it is built only once.
    #[test]
    fn tls_sync_config_cached() {
        let first = Arc::clone(TLS_CONFIG.get_or_init(init_tls_config));
        let second = Arc::clone(TLS_CONFIG.get_or_init(init_tls_config));
        assert!(Arc::ptr_eq(&first, &second));
    }
}