Skip to main content

qail_pg/driver/
builder.rs

1//! PgDriverBuilder — ergonomic builder pattern for PgDriver connections.
2
3use super::auth_types::{
4    AuthSettings, ConnectOptions, GssEncMode, GssTokenProvider, GssTokenProviderEx,
5    ScramChannelBindingMode, TlsMode,
6};
7use super::core::PgDriver;
8use super::types::{PgError, PgResult};
9use crate::driver::connection::TlsConfig;
10
11// ============================================================================
12// Connection Builder
13// ============================================================================
14
15/// Builder for creating PgDriver connections with named parameters.
16/// # Example
17/// ```ignore
18/// let driver = PgDriver::builder()
19///     .host("localhost")
20///     .port(5432)
21///     .user("admin")
22///     .database("mydb")
23///     .password("secret")
24///     .connect()
25///     .await?;
26/// ```
27#[derive(Default)]
28pub struct PgDriverBuilder {
29    host: Option<String>,
30    port: Option<u16>,
31    user: Option<String>,
32    database: Option<String>,
33    password: Option<String>,
34    timeout: Option<std::time::Duration>,
35    pub(crate) connect_options: ConnectOptions,
36}
37
38impl PgDriverBuilder {
39    /// Create a new builder with default values.
40    pub fn new() -> Self {
41        Self::default()
42    }
43
44    /// Set the host (default: "127.0.0.1").
45    pub fn host(mut self, host: impl Into<String>) -> Self {
46        self.host = Some(host.into());
47        self
48    }
49
50    /// Set the port (default: 5432).
51    pub fn port(mut self, port: u16) -> Self {
52        self.port = Some(port);
53        self
54    }
55
56    /// Set the username (required).
57    pub fn user(mut self, user: impl Into<String>) -> Self {
58        self.user = Some(user.into());
59        self
60    }
61
62    /// Set the database name (required).
63    pub fn database(mut self, database: impl Into<String>) -> Self {
64        self.database = Some(database.into());
65        self
66    }
67
68    /// Set the password (optional, for cleartext/MD5/SCRAM-SHA-256 auth).
69    pub fn password(mut self, password: impl Into<String>) -> Self {
70        self.password = Some(password.into());
71        self
72    }
73
74    /// Set connection timeout (optional).
75    pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
76        self.timeout = Some(timeout);
77        self
78    }
79
80    /// Set TLS policy (`disable`, `prefer`, `require`).
81    pub fn tls_mode(mut self, mode: TlsMode) -> Self {
82        self.connect_options.tls_mode = mode;
83        self
84    }
85
86    /// Set GSSAPI session encryption mode (`disable`, `prefer`, `require`).
87    pub fn gss_enc_mode(mut self, mode: GssEncMode) -> Self {
88        self.connect_options.gss_enc_mode = mode;
89        self
90    }
91
92    /// Set custom CA bundle PEM for TLS validation.
93    pub fn tls_ca_cert_pem(mut self, ca_pem: Vec<u8>) -> Self {
94        self.connect_options.tls_ca_cert_pem = Some(ca_pem);
95        self
96    }
97
98    /// Enable mTLS using client certificate/key config.
99    pub fn mtls(mut self, config: TlsConfig) -> Self {
100        self.connect_options.mtls = Some(config);
101        self.connect_options.tls_mode = TlsMode::Require;
102        self
103    }
104
105    /// Override password-auth policy.
106    pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
107        self.connect_options.auth = settings;
108        self
109    }
110
111    /// Set SCRAM channel-binding mode.
112    pub fn channel_binding_mode(mut self, mode: ScramChannelBindingMode) -> Self {
113        self.connect_options.auth.channel_binding = mode;
114        self
115    }
116
117    /// Set Kerberos/GSS/SSPI token provider callback.
118    pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
119        self.connect_options.gss_token_provider = Some(provider);
120        self
121    }
122
123    /// Set a stateful Kerberos/GSS/SSPI token provider.
124    pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
125        self.connect_options.gss_token_provider_ex = Some(provider);
126        self
127    }
128
129    /// Add a custom StartupMessage parameter.
130    ///
131    /// Example: `.startup_param("application_name", "qail-replica")`
132    pub fn startup_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
133        let key = key.into();
134        let value = value.into();
135        self.connect_options
136            .startup_params
137            .retain(|(existing, _)| !existing.eq_ignore_ascii_case(&key));
138        self.connect_options.startup_params.push((key, value));
139        self
140    }
141
142    /// Enable logical replication startup mode (`replication=database`).
143    ///
144    /// This is required before issuing commands like `IDENTIFY_SYSTEM` or
145    /// `CREATE_REPLICATION_SLOT` on a replication connection.
146    pub fn logical_replication(mut self) -> Self {
147        self.connect_options
148            .startup_params
149            .retain(|(k, _)| !k.eq_ignore_ascii_case("replication"));
150        self.connect_options
151            .startup_params
152            .push(("replication".to_string(), "database".to_string()));
153        self
154    }
155
156    /// Connect to PostgreSQL using the configured parameters.
157    pub async fn connect(self) -> PgResult<PgDriver> {
158        let host = self.host.unwrap_or_else(|| "127.0.0.1".to_string());
159        let port = self.port.unwrap_or(5432);
160        let user = self
161            .user
162            .ok_or_else(|| PgError::Connection("User is required".to_string()))?;
163        let database = self
164            .database
165            .ok_or_else(|| PgError::Connection("Database is required".to_string()))?;
166
167        let password = self.password;
168        let options = self.connect_options;
169
170        if let Some(timeout) = self.timeout {
171            let options = options.clone();
172            tokio::time::timeout(
173                timeout,
174                PgDriver::connect_with_options(
175                    &host,
176                    port,
177                    &user,
178                    &database,
179                    password.as_deref(),
180                    options,
181                ),
182            )
183            .await
184            .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
185        } else {
186            PgDriver::connect_with_options(
187                &host,
188                port,
189                &user,
190                &database,
191                password.as_deref(),
192                options,
193            )
194            .await
195        }
196    }
197}