Skip to main content

qail_pg/driver/
auth_types.rs

//! Authentication and security types: ScramChannelBindingMode, EnterpriseAuthMechanism,
//! GssTokenProvider, GssTokenProviderEx, GssTokenRequest, AuthSettings, TlsMode, GssEncMode,
//! ConnectOptions.
3
4use super::connection::TlsConfig;
5use std::sync::Arc;
6
/// Policy for SCRAM channel binding (`SCRAM-SHA-256-PLUS`) during SASL negotiation.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum ScramChannelBindingMode {
    /// Never negotiate `SCRAM-SHA-256-PLUS`, even if the server offers it.
    Disable,
    /// Use `SCRAM-SHA-256-PLUS` when possible, otherwise fall back to plain SCRAM.
    /// This is the default.
    #[default]
    Prefer,
    /// Insist on `SCRAM-SHA-256-PLUS`; negotiation fails without it.
    Require,
}

impl ScramChannelBindingMode {
    /// Parses a configuration string into a channel-binding mode.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
    /// Accepted spellings:
    /// - `disable`, `off`, `false`, `no` → [`ScramChannelBindingMode::Disable`]
    /// - `prefer`, `on`, `true`, `yes` → [`ScramChannelBindingMode::Prefer`]
    /// - `require`, `required` → [`ScramChannelBindingMode::Require`]
    ///
    /// Any other input yields `None`.
    pub fn parse(value: &str) -> Option<Self> {
        let normalized = value.trim().to_ascii_lowercase();
        let mode = match normalized.as_str() {
            "disable" | "off" | "false" | "no" => Self::Disable,
            "prefer" | "on" | "true" | "yes" => Self::Prefer,
            "require" | "required" => Self::Require,
            _ => return None,
        };
        Some(mode)
    }
}
30
/// Enterprise (non-password) authentication mechanisms a PostgreSQL server can request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EnterpriseAuthMechanism {
    /// Kerberos V5 — server message `AuthenticationKerberosV5`, auth code `2`.
    KerberosV5,
    /// GSSAPI — server message `AuthenticationGSS`, auth code `7`.
    GssApi,
    /// SSPI — server message `AuthenticationSSPI`, auth code `9` (mainly sent by Windows servers).
    Sspi,
}
41
42/// Callback used to generate GSS/SSPI response tokens.
43///
44/// The callback receives:
45/// - negotiated enterprise auth mechanism
46/// - optional server challenge bytes (`None` for initial token)
47///
48/// It must return the client response token bytes to send in `GSSResponse`.
49pub type GssTokenProvider = fn(EnterpriseAuthMechanism, Option<&[u8]>) -> Result<Vec<u8>, String>;
50
51/// Structured token request for stateful Kerberos/GSS/SSPI providers.
52#[derive(Debug, Clone, Copy)]
53pub struct GssTokenRequest<'a> {
54    /// Stable per-handshake identifier so providers can keep per-connection state.
55    pub session_id: u64,
56    /// Negotiated enterprise auth mechanism.
57    pub mechanism: EnterpriseAuthMechanism,
58    /// Server challenge token (`None` for initial token).
59    pub server_token: Option<&'a [u8]>,
60}
61
62/// Stateful callback for Kerberos/GSS/SSPI response generation.
63///
64/// Use this when the underlying auth stack needs per-handshake context between
65/// `AuthenticationGSS` and `AuthenticationGSSContinue` messages.
66pub type GssTokenProviderEx =
67    Arc<dyn for<'a> Fn(GssTokenRequest<'a>) -> Result<Vec<u8>, String> + Send + Sync>;
68
69/// Password-auth mechanism policy.
70///
71/// Defaults allow all PostgreSQL password mechanisms for compatibility.
72#[derive(Debug, Clone, Copy, PartialEq, Eq)]
73pub struct AuthSettings {
74    /// Allow server-requested cleartext password auth.
75    pub allow_cleartext_password: bool,
76    /// Allow server-requested MD5 password auth.
77    pub allow_md5_password: bool,
78    /// Allow server-requested SCRAM auth.
79    pub allow_scram_sha_256: bool,
80    /// Allow server-requested Kerberos V5 auth flow.
81    pub allow_kerberos_v5: bool,
82    /// Allow server-requested GSSAPI auth flow.
83    pub allow_gssapi: bool,
84    /// Allow server-requested SSPI auth flow.
85    pub allow_sspi: bool,
86    /// SCRAM channel-binding requirement.
87    pub channel_binding: ScramChannelBindingMode,
88}
89
90impl Default for AuthSettings {
91    fn default() -> Self {
92        Self {
93            allow_cleartext_password: true,
94            allow_md5_password: true,
95            allow_scram_sha_256: true,
96            allow_kerberos_v5: false,
97            allow_gssapi: false,
98            allow_sspi: false,
99            channel_binding: ScramChannelBindingMode::Prefer,
100        }
101    }
102}
103
104impl AuthSettings {
105    /// Restrictive mode: SCRAM-only password auth.
106    pub fn scram_only() -> Self {
107        Self {
108            allow_cleartext_password: false,
109            allow_md5_password: false,
110            allow_scram_sha_256: true,
111            allow_kerberos_v5: false,
112            allow_gssapi: false,
113            allow_sspi: false,
114            channel_binding: ScramChannelBindingMode::Prefer,
115        }
116    }
117
118    /// Restrictive mode: enterprise Kerberos/GSS only (no password auth).
119    pub fn gssapi_only() -> Self {
120        Self {
121            allow_cleartext_password: false,
122            allow_md5_password: false,
123            allow_scram_sha_256: false,
124            allow_kerberos_v5: true,
125            allow_gssapi: true,
126            allow_sspi: true,
127            channel_binding: ScramChannelBindingMode::Prefer,
128        }
129    }
130
131    pub(crate) fn has_any_password_method(self) -> bool {
132        self.allow_cleartext_password || self.allow_md5_password || self.allow_scram_sha_256
133    }
134}
135
/// How TLS is negotiated when a connection is established.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum TlsMode {
    /// Never attempt TLS (the default).
    #[default]
    Disable,
    /// Attempt TLS first and fall back to plaintext only if the server lacks TLS support.
    Prefer,
    /// Insist on TLS; fail if it is unavailable.
    Require,
}

impl TlsMode {
    /// Maps a libpq-style `sslmode` value onto a [`TlsMode`].
    ///
    /// Input is trimmed and compared ASCII case-insensitively. libpq's finer-grained
    /// modes are collapsed:
    /// - `allow` is treated like `prefer`.
    /// - `verify-ca` and `verify-full` are treated like `require`.
    ///
    /// This enum itself does not distinguish the certificate-verification levels
    /// that those libpq modes imply. Unrecognised values yield `None`.
    pub fn parse_sslmode(value: &str) -> Option<Self> {
        const SSLMODES: [(&str, TlsMode); 6] = [
            ("disable", TlsMode::Disable),
            ("allow", TlsMode::Prefer),
            ("prefer", TlsMode::Prefer),
            ("require", TlsMode::Require),
            ("verify-ca", TlsMode::Require),
            ("verify-full", TlsMode::Require),
        ];
        let wanted = value.trim().to_ascii_lowercase();
        SSLMODES
            .iter()
            .find(|(name, _)| *name == wanted.as_str())
            .map(|&(_, mode)| mode)
    }
}
159
/// Transport-level GSSAPI (Kerberos) session-encryption policy.
///
/// Decides whether the driver sends a `GSSENCRequest` before falling back to
/// TLS or plaintext. See the "GSSAPI Session Encryption" section of the
/// PostgreSQL frontend/backend protocol documentation.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum GssEncMode {
    /// Never attempt GSSAPI encryption (the default).
    #[default]
    Disable,
    /// Try GSSAPI encryption first, then fall back to TLS or plaintext.
    Prefer,
    /// GSSAPI encryption is mandatory; fail if the server rejects `GSSENCRequest`.
    Require,
}

impl GssEncMode {
    /// Parses a libpq-style `gssencmode` value: `disable`, `prefer` or `require`.
    ///
    /// Surrounding whitespace is ignored and comparison is ASCII case-insensitive;
    /// any other value returns `None`.
    pub fn parse_gssencmode(value: &str) -> Option<Self> {
        let lowered = value.trim().to_ascii_lowercase();
        if lowered == "disable" {
            Some(Self::Disable)
        } else if lowered == "prefer" {
            Some(Self::Prefer)
        } else if lowered == "require" {
            Some(Self::Require)
        } else {
            None
        }
    }
}
188
189/// Advanced connection options for enterprise deployments.
190#[derive(Clone, Default)]
191pub struct ConnectOptions {
192    /// TLS mode for the primary connection.
193    pub tls_mode: TlsMode,
194    /// GSSAPI session encryption mode.
195    pub gss_enc_mode: GssEncMode,
196    /// Optional custom CA bundle (PEM) for TLS server validation.
197    pub tls_ca_cert_pem: Option<Vec<u8>>,
198    /// Optional mTLS client certificate/key config.
199    pub mtls: Option<TlsConfig>,
200    /// Optional callback for Kerberos/GSS/SSPI token generation.
201    pub gss_token_provider: Option<GssTokenProvider>,
202    /// Optional stateful Kerberos/GSS/SSPI token provider.
203    pub gss_token_provider_ex: Option<GssTokenProviderEx>,
204    /// Password-auth policy.
205    pub auth: AuthSettings,
206    /// Additional startup parameters sent in StartupMessage.
207    /// Example: `replication=database` for logical replication mode.
208    pub startup_params: Vec<(String, String)>,
209}
210
211impl std::fmt::Debug for ConnectOptions {
212    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
213        f.debug_struct("ConnectOptions")
214            .field("tls_mode", &self.tls_mode)
215            .field("gss_enc_mode", &self.gss_enc_mode)
216            .field(
217                "tls_ca_cert_pem",
218                &self.tls_ca_cert_pem.as_ref().map(std::vec::Vec::len),
219            )
220            .field("mtls", &self.mtls.as_ref().map(|_| "<configured>"))
221            .field(
222                "gss_token_provider",
223                &self.gss_token_provider.as_ref().map(|_| "<configured>"),
224            )
225            .field(
226                "gss_token_provider_ex",
227                &self.gss_token_provider_ex.as_ref().map(|_| "<configured>"),
228            )
229            .field("auth", &self.auth)
230            .field("startup_params_count", &self.startup_params.len())
231            .finish()
232    }
233}
234
235impl ConnectOptions {
236    /// Add a startup parameter.
237    ///
238    /// Example: `opts.with_startup_param("application_name", "qail-repl")`.
239    pub fn with_startup_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
240        let key = key.into();
241        let value = value.into();
242        self.startup_params
243            .retain(|(existing, _)| !existing.eq_ignore_ascii_case(&key));
244        self.startup_params.push((key, value));
245        self
246    }
247
248    /// Enable logical replication startup mode (`replication=database`).
249    pub fn with_logical_replication(mut self) -> Self {
250        self.startup_params
251            .retain(|(k, _)| !k.eq_ignore_ascii_case("replication"));
252        self.startup_params
253            .push(("replication".to_string(), "database".to_string()));
254        self
255    }
256}