use crate::postgres::stream::PgStream;
use crate::url::Url;
/// Negotiates TLS on a freshly opened connection according to the `sslmode` URL parameter.
///
/// Behaviour by `sslmode`:
/// * `disable`, `allow` — stay on plaintext; no `SSLRequest` is sent.
/// * `prefer` or unset — with the `tls` feature, attempt an upgrade without certificate or
///   hostname verification and keep going on plaintext if the server declines. Without the
///   feature, an unset mode stays on plaintext while an explicit `prefer` is an error.
/// * `require` — TLS mandatory; neither certificate nor hostname is verified.
/// * `verify-ca` — TLS mandatory; certificate verified, hostname not.
/// * `verify-full` — TLS mandatory; certificate and hostname verified.
///
/// # Errors
/// Fails if the server declines TLS under `require`/`verify-*`, if a TLS mode is requested but
/// the crate was built without the `tls` feature, if `sslmode` has an unrecognised value, or if
/// the upgrade itself fails.
// Without the `tls` feature every arm that touches `stream` is compiled out, which would
// otherwise trigger an unused-variable warning.
#[cfg_attr(not(feature = "tls"), allow(unused_variables))]
pub(crate) async fn request_if_needed(stream: &mut PgStream, url: &Url) -> crate::Result<()> {
    let ssl_mode = url.param("sslmode");

    match ssl_mode.as_deref() {
        Some("disable") | Some("allow") => Ok(()),

        #[cfg(feature = "tls")]
        Some("prefer") | None => {
            // Opportunistic: the connection is usable whether or not the server agreed,
            // so the "upgraded?" flag is intentionally ignored.
            try_upgrade(stream, url, true, true).await?;
            Ok(())
        }

        #[cfg(not(feature = "tls"))]
        None => Ok(()),

        #[cfg(feature = "tls")]
        Some(mode @ "require") | Some(mode @ "verify-ca") | Some(mode @ "verify-full") => {
            let accept_invalid_certs = mode == "require";
            let accept_invalid_host_names = mode != "verify-full";

            let upgraded =
                try_upgrade(stream, url, accept_invalid_certs, accept_invalid_host_names).await?;

            if upgraded {
                Ok(())
            } else {
                Err(tls_err!("server does not support TLS").into())
            }
        }

        #[cfg(not(feature = "tls"))]
        Some(mode @ "prefer")
        | Some(mode @ "require")
        | Some(mode @ "verify-ca")
        | Some(mode @ "verify-full") => Err(tls_err!(
            "sslmode {:?} unsupported; SQLx was compiled without `tls` feature",
            mode
        )
        .into()),

        Some(mode) => Err(tls_err!("unknown `sslmode` value: {:?}", mode).into()),
    }
}
/// Sends an `SSLRequest` and, if the server accepts, upgrades `stream` to TLS.
///
/// * `accept_invalid_certs` — skip certificate validation; when `false`, a root certificate is
///   loaded via `read_root_certificate` (if one is configured or found) and added to the
///   connector.
/// * `accept_invalid_host_names` — skip hostname verification.
///
/// The TLS server name is the URL host, falling back to the `host` query parameter and finally
/// to `"localhost"`.
///
/// Returns `Ok(true)` after a successful upgrade, or `Ok(false)` if the server answered `N`
/// (declined TLS), in which case the stream is left as-is.
///
/// # Errors
/// I/O errors while writing the request or reading the reply, an empty or unexpected reply
/// byte, failure to read/parse a configured root certificate, or a failed TLS handshake.
#[cfg(feature = "tls")]
async fn try_upgrade(
    stream: &mut PgStream,
    url: &Url,
    accept_invalid_certs: bool,
    accept_invalid_host_names: bool,
) -> crate::Result<bool> {
    use async_native_tls::TlsConnector;
    use std::borrow::Cow;

    stream.write(crate::postgres::protocol::SslRequest);
    stream.flush().await?;

    // The reply to an SSLRequest is a single byte. Use checked access rather than `[0]` so a
    // short/empty read surfaces as an error instead of a panic, and only consume the byte once
    // we know it is actually there.
    let ind = match stream.stream.peek(1).await?.first() {
        Some(&byte) => byte,
        None => {
            return Err(tls_err!("connection closed before response to SSLRequest").into());
        }
    };
    stream.stream.consume(1);

    match ind {
        // 'S': the server is willing to proceed with the TLS handshake.
        b'S' => {}
        // 'N': the server declined TLS; the caller decides whether that is acceptable.
        b'N' => return Ok(false),
        other => {
            return Err(tls_err!("unexpected response from SSLRequest: 0x{:02X}", other).into());
        }
    }

    let mut connector = TlsConnector::new()
        .danger_accept_invalid_certs(accept_invalid_certs)
        .danger_accept_invalid_hostnames(accept_invalid_host_names);

    // A custom trust root only matters when certificates are actually being validated.
    if !accept_invalid_certs {
        if let Some(cert) = read_root_certificate(url).await? {
            connector = connector.add_root_certificate(cert);
        }
    }

    let host = url
        .host()
        .map(Cow::Borrowed)
        .or_else(|| url.param("host"))
        .unwrap_or(Cow::Borrowed("localhost"));

    stream.stream.upgrade(&host, connector).await?;

    Ok(true)
}
/// Loads a PEM-encoded root certificate to trust for the TLS connection, if one is available.
///
/// Lookup order:
/// 1. the `sslrootcert` URL parameter, then the `PGSSLROOTCERT` environment variable — if
///    either is set, that file must be readable, and read errors are returned;
/// 2. otherwise a per-user default path (`%APPDATA%\postgresql\root.crt` on Windows,
///    `$HOME/.postgresql/root.crt` elsewhere), read on a best-effort basis: a missing
///    environment variable or unreadable file simply yields `Ok(None)`.
///
/// # Errors
/// I/O errors reading an explicitly configured file, or a PEM parse failure for whichever file
/// was read.
#[cfg(feature = "tls")]
async fn read_root_certificate(url: &Url) -> crate::Result<Option<async_native_tls::Certificate>> {
    use crate::runtime::fs;
    use std::env;

    let configured_path = url
        .param("sslrootcert")
        .or_else(|| env::var("PGSSLROOTCERT").ok().map(Into::into));

    let pem_bytes = match configured_path {
        Some(path) => Some(fs::read(&*path).await?),
        None => {
            let fallback_path = if cfg!(windows) {
                env::var("APPDATA")
                    .ok()
                    .map(|dir| format!("{}\\postgresql\\root.crt", dir))
            } else {
                env::var("HOME")
                    .ok()
                    .map(|dir| format!("{}/.postgresql/root.crt", dir))
            };

            match fallback_path {
                Some(path) => fs::read(path).await.ok(),
                None => None,
            }
        }
    };

    match pem_bytes {
        Some(bytes) => async_native_tls::Certificate::from_pem(&bytes)
            .map(Some)
            .map_err(Into::into),
        None => Ok(None),
    }
}