use super::connection::TlsConfig;
use std::sync::Arc;
/// Channel-binding policy for SCRAM authentication.
///
/// Defaults to [`ScramChannelBindingMode::Prefer`]. How each variant is
/// enforced is decided by the authentication code, not by this type.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum ScramChannelBindingMode {
    /// Parsed from `disable`, `off`, `false` or `no`.
    Disable,
    /// Parsed from `prefer`, `on`, `true` or `yes`; the default.
    #[default]
    Prefer,
    /// Parsed from `require` or `required`.
    Require,
}
impl ScramChannelBindingMode {
    /// Parses a channel-binding setting.
    ///
    /// Leading/trailing whitespace is ignored and the comparison is
    /// ASCII-case-insensitive. Returns `None` for any unrecognised spelling.
    pub fn parse(value: &str) -> Option<Self> {
        // Every accepted spelling paired with the mode it selects.
        const SPELLINGS: [(&str, ScramChannelBindingMode); 10] = [
            ("disable", ScramChannelBindingMode::Disable),
            ("off", ScramChannelBindingMode::Disable),
            ("false", ScramChannelBindingMode::Disable),
            ("no", ScramChannelBindingMode::Disable),
            ("prefer", ScramChannelBindingMode::Prefer),
            ("on", ScramChannelBindingMode::Prefer),
            ("true", ScramChannelBindingMode::Prefer),
            ("yes", ScramChannelBindingMode::Prefer),
            ("require", ScramChannelBindingMode::Require),
            ("required", ScramChannelBindingMode::Require),
        ];
        let needle = value.trim();
        SPELLINGS
            .iter()
            .find(|(spelling, _)| spelling.eq_ignore_ascii_case(needle))
            .map(|&(_, mode)| mode)
    }
}
/// Non-password authentication mechanism that a GSS token is being produced for.
///
/// Passed to [`GssTokenProvider`] callbacks and carried in [`GssTokenRequest`].
/// NOTE(review): presumably these correspond to PostgreSQL's
/// AuthenticationKerberosV5, AuthenticationGSS and AuthenticationSSPI server
/// requests — confirm against the authentication code.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EnterpriseAuthMechanism {
    /// Kerberos V5.
    KerberosV5,
    /// GSSAPI.
    GssApi,
    /// SSPI.
    Sspi,
}
/// Function-pointer hook that produces the client's next GSS token.
///
/// Called with the mechanism in use and an optional server token. It returns
/// the token bytes to send, or an error message. As a plain `fn` pointer it
/// cannot capture state; [`GssTokenProviderEx`] is the closure-based variant.
/// NOTE(review): `None` presumably means "no server token yet" (first round);
/// confirm with the caller.
pub type GssTokenProvider = fn(EnterpriseAuthMechanism, Option<&[u8]>) -> Result<Vec<u8>, String>;
/// Input handed to a [`GssTokenProviderEx`] callback.
///
/// NOTE(review): the derived `Debug` prints the `server_token` bytes verbatim.
#[derive(Debug, Clone, Copy)]
pub struct GssTokenRequest<'a> {
    /// Session identifier supplied by the caller.
    /// NOTE(review): confirm what this identifies (connection, handshake, ...);
    /// it is not visible from this file.
    pub session_id: u64,
    /// Mechanism the token is being produced for.
    pub mechanism: EnterpriseAuthMechanism,
    /// Token received from the server, if any, borrowed for lifetime `'a`.
    pub server_token: Option<&'a [u8]>,
}
/// Shareable, thread-safe (`Send + Sync`) GSS token callback.
///
/// Unlike [`GssTokenProvider`], it may capture state. Cloning it only bumps the
/// `Arc` reference count. The higher-ranked `for<'a>` bound requires the
/// callback to accept a request whose `server_token` has any lifetime.
pub type GssTokenProviderEx =
    Arc<dyn for<'a> Fn(GssTokenRequest<'a>) -> Result<Vec<u8>, String> + Send + Sync>;
/// Per-method allow flags for authentication, plus the SCRAM channel-binding
/// policy.
///
/// NOTE(review): the flags are only declared here; how and where they are
/// enforced is not visible in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct AuthSettings {
    /// Whether cleartext password authentication is allowed.
    pub allow_cleartext_password: bool,
    /// Whether MD5 password authentication is allowed.
    pub allow_md5_password: bool,
    /// Whether SCRAM-SHA-256 authentication is allowed.
    pub allow_scram_sha_256: bool,
    /// Whether Kerberos V5 authentication is allowed.
    pub allow_kerberos_v5: bool,
    /// Whether GSSAPI authentication is allowed.
    pub allow_gssapi: bool,
    /// Whether SSPI authentication is allowed.
    pub allow_sspi: bool,
    /// Channel-binding policy applied to SCRAM.
    pub channel_binding: ScramChannelBindingMode,
}
impl Default for AuthSettings {
    /// Allows every password method (cleartext, MD5, SCRAM-SHA-256), leaves
    /// Kerberos V5 / GSSAPI / SSPI off, and prefers channel binding.
    fn default() -> Self {
        let (password_methods, enterprise_methods) = (true, false);
        Self {
            channel_binding: ScramChannelBindingMode::Prefer,
            allow_cleartext_password: password_methods,
            allow_md5_password: password_methods,
            allow_scram_sha_256: password_methods,
            allow_kerberos_v5: enterprise_methods,
            allow_gssapi: enterprise_methods,
            allow_sspi: enterprise_methods,
        }
    }
}
impl AuthSettings {
    /// Settings that allow SCRAM-SHA-256 as the only method.
    ///
    /// Cleartext and MD5 are turned off. Everything else (enterprise methods
    /// off, channel binding `Prefer`) is taken from [`AuthSettings::default`].
    pub fn scram_only() -> Self {
        Self {
            allow_cleartext_password: false,
            allow_md5_password: false,
            ..Self::default()
        }
    }

    /// Settings that allow only Kerberos V5, GSSAPI and SSPI.
    ///
    /// All password methods are off; channel binding stays `Prefer`.
    pub fn gssapi_only() -> Self {
        Self {
            allow_scram_sha_256: false,
            allow_kerberos_v5: true,
            allow_gssapi: true,
            allow_sspi: true,
            ..Self::scram_only()
        }
    }

    /// Returns `true` if at least one of cleartext, MD5 or SCRAM-SHA-256 is
    /// allowed.
    pub(crate) fn has_any_password_method(self) -> bool {
        [
            self.allow_cleartext_password,
            self.allow_md5_password,
            self.allow_scram_sha_256,
        ]
        .contains(&true)
    }
}
/// TLS negotiation mode for a connection; defaults to [`TlsMode::Disable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum TlsMode {
    /// Parsed from `disable`; the default.
    #[default]
    Disable,
    /// Parsed from `prefer` and also from `allow`.
    Prefer,
    /// Parsed from `require`, `verify-ca` and `verify-full`.
    Require,
}
impl TlsMode {
    /// Parses a libpq-style `sslmode` value.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII-case-insensitive.
    /// Returns `None` for unknown values.
    ///
    /// The libpq spellings are collapsed onto three modes:
    /// * `allow` is treated the same as `prefer`.
    /// * `verify-ca` and `verify-full` map to `Require`. This type has no way
    ///   to record that certificate or hostname verification was requested.
    ///
    /// NOTE(review): whether `Require` verifies the server certificate depends
    /// on the TLS setup elsewhere (e.g. a configured CA). Confirm this, since
    /// callers passing `verify-full` usually expect hostname checks.
    pub fn parse_sslmode(value: &str) -> Option<Self> {
        let normalized = value.trim().to_ascii_lowercase();
        let mode = match normalized.as_str() {
            "disable" => Self::Disable,
            "allow" | "prefer" => Self::Prefer,
            "require" | "verify-ca" | "verify-full" => Self::Require,
            _ => return None,
        };
        Some(mode)
    }
}
/// GSSAPI transport-encryption negotiation mode; defaults to
/// [`GssEncMode::Disable`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum GssEncMode {
    /// Parsed from `disable`; the default.
    #[default]
    Disable,
    /// Parsed from `prefer`.
    Prefer,
    /// Parsed from `require`.
    Require,
}
impl GssEncMode {
    /// Parses a libpq-style `gssencmode` value (`disable`, `prefer`, `require`).
    ///
    /// Surrounding whitespace is ignored and matching is ASCII-case-insensitive.
    /// Returns `None` for anything else.
    pub fn parse_gssencmode(value: &str) -> Option<Self> {
        let wanted = value.trim();
        [
            ("disable", Self::Disable),
            ("prefer", Self::Prefer),
            ("require", Self::Require),
        ]
        .iter()
        .find_map(|&(name, mode)| name.eq_ignore_ascii_case(wanted).then_some(mode))
    }
}
/// Options controlling how a connection is established: transport encryption,
/// authentication hooks and settings, and extra startup parameters.
///
/// `Debug` is implemented by hand rather than derived. The CA certificate is
/// printed as its length, `mtls` and the token providers as a `<configured>`
/// marker, and the startup parameters as a count.
#[derive(Clone, Default)]
pub struct ConnectOptions {
    /// TLS negotiation mode (default: [`TlsMode::Disable`]).
    pub tls_mode: TlsMode,
    /// GSSAPI encryption negotiation mode (default: [`GssEncMode::Disable`]).
    pub gss_enc_mode: GssEncMode,
    /// CA certificate bytes. Presumably PEM-encoded, given the name; confirm
    /// with the TLS code.
    pub tls_ca_cert_pem: Option<Vec<u8>>,
    /// Client-side TLS configuration.
    /// NOTE(review): presumably carries the client certificate/key for mutual
    /// TLS; confirm against `TlsConfig`.
    pub mtls: Option<TlsConfig>,
    /// Function-pointer GSS token callback.
    pub gss_token_provider: Option<GssTokenProvider>,
    /// Closure-based GSS token callback that also receives a session id.
    /// NOTE(review): which provider wins when both are set is decided
    /// elsewhere; confirm.
    pub gss_token_provider_ex: Option<GssTokenProviderEx>,
    /// Authentication method allow flags (default: [`AuthSettings::default`]).
    pub auth: AuthSettings,
    /// Extra startup parameters as key/value pairs, in insertion order.
    /// `with_startup_param` keeps keys unique, compared case-insensitively
    /// (ASCII).
    pub startup_params: Vec<(String, String)>,
}
impl std::fmt::Debug for ConnectOptions {
    /// Formats the options without dumping their payloads. The CA certificate
    /// becomes its byte length, `mtls` and the token providers become a
    /// presence marker, and `startup_params` becomes a count.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Renders as `Some("<configured>")` when the slot is filled, `None` otherwise.
        fn presence<T>(slot: &Option<T>) -> Option<&'static str> {
            slot.as_ref().map(|_| "<configured>")
        }

        let ca_cert_len = self.tls_ca_cert_pem.as_ref().map(|pem| pem.len());

        let mut out = f.debug_struct("ConnectOptions");
        out.field("tls_mode", &self.tls_mode);
        out.field("gss_enc_mode", &self.gss_enc_mode);
        out.field("tls_ca_cert_pem", &ca_cert_len);
        out.field("mtls", &presence(&self.mtls));
        out.field("gss_token_provider", &presence(&self.gss_token_provider));
        out.field("gss_token_provider_ex", &presence(&self.gss_token_provider_ex));
        out.field("auth", &self.auth);
        out.field("startup_params_count", &self.startup_params.len());
        out.finish()
    }
}
impl ConnectOptions {
    /// Sets a startup parameter and returns the updated options.
    ///
    /// Any existing entry whose key equals `key` ASCII-case-insensitively is
    /// removed first. The new pair is then appended, so it ends up last in
    /// `startup_params`.
    ///
    /// NOTE(review): empty keys, and keys or values containing NUL bytes, are
    /// not rejected here. Confirm they are validated before the startup
    /// message is serialized.
    #[must_use]
    pub fn with_startup_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        let key = key.into();
        let value = value.into();
        self.startup_params
            .retain(|(existing, _)| !existing.eq_ignore_ascii_case(&key));
        self.startup_params.push((key, value));
        self
    }

    /// Requests a logical replication connection by setting the startup
    /// parameter `replication=database`.
    ///
    /// Any earlier `replication` entry (matched case-insensitively) is
    /// replaced. Delegates to [`ConnectOptions::with_startup_param`], so both
    /// builders share one replace-then-append rule.
    #[must_use]
    pub fn with_logical_replication(self) -> Self {
        self.with_startup_param("replication", "database")
    }
}