qail_pg/driver/
builder.rs1use super::auth_types::{
4 AuthSettings, ConnectOptions, GssEncMode, GssTokenProvider, GssTokenProviderEx,
5 ScramChannelBindingMode, TlsMode,
6};
7use super::core::PgDriver;
8use super::types::{PgError, PgResult};
9use crate::driver::connection::TlsConfig;
10
11#[derive(Default)]
28pub struct PgDriverBuilder {
29 host: Option<String>,
30 port: Option<u16>,
31 user: Option<String>,
32 database: Option<String>,
33 password: Option<String>,
34 timeout: Option<std::time::Duration>,
35 pub(crate) connect_options: ConnectOptions,
36}
37
38impl PgDriverBuilder {
39 pub fn new() -> Self {
41 Self::default()
42 }
43
44 pub fn host(mut self, host: impl Into<String>) -> Self {
46 self.host = Some(host.into());
47 self
48 }
49
50 pub fn port(mut self, port: u16) -> Self {
52 self.port = Some(port);
53 self
54 }
55
56 pub fn user(mut self, user: impl Into<String>) -> Self {
58 self.user = Some(user.into());
59 self
60 }
61
62 pub fn database(mut self, database: impl Into<String>) -> Self {
64 self.database = Some(database.into());
65 self
66 }
67
68 pub fn password(mut self, password: impl Into<String>) -> Self {
70 self.password = Some(password.into());
71 self
72 }
73
74 pub fn timeout(mut self, timeout: std::time::Duration) -> Self {
76 self.timeout = Some(timeout);
77 self
78 }
79
80 pub fn tls_mode(mut self, mode: TlsMode) -> Self {
82 self.connect_options.tls_mode = mode;
83 self
84 }
85
86 pub fn gss_enc_mode(mut self, mode: GssEncMode) -> Self {
88 self.connect_options.gss_enc_mode = mode;
89 self
90 }
91
92 pub fn tls_ca_cert_pem(mut self, ca_pem: Vec<u8>) -> Self {
94 self.connect_options.tls_ca_cert_pem = Some(ca_pem);
95 self
96 }
97
98 pub fn mtls(mut self, config: TlsConfig) -> Self {
100 self.connect_options.mtls = Some(config);
101 self.connect_options.tls_mode = TlsMode::Require;
102 self
103 }
104
105 pub fn auth_settings(mut self, settings: AuthSettings) -> Self {
107 self.connect_options.auth = settings;
108 self
109 }
110
111 pub fn channel_binding_mode(mut self, mode: ScramChannelBindingMode) -> Self {
113 self.connect_options.auth.channel_binding = mode;
114 self
115 }
116
117 pub fn gss_token_provider(mut self, provider: GssTokenProvider) -> Self {
119 self.connect_options.gss_token_provider = Some(provider);
120 self
121 }
122
123 pub fn gss_token_provider_ex(mut self, provider: GssTokenProviderEx) -> Self {
125 self.connect_options.gss_token_provider_ex = Some(provider);
126 self
127 }
128
129 pub fn startup_param(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
133 let key = key.into();
134 let value = value.into();
135 self.connect_options
136 .startup_params
137 .retain(|(existing, _)| !existing.eq_ignore_ascii_case(&key));
138 self.connect_options.startup_params.push((key, value));
139 self
140 }
141
142 pub fn logical_replication(mut self) -> Self {
147 self.connect_options
148 .startup_params
149 .retain(|(k, _)| !k.eq_ignore_ascii_case("replication"));
150 self.connect_options
151 .startup_params
152 .push(("replication".to_string(), "database".to_string()));
153 self
154 }
155
156 pub async fn connect(self) -> PgResult<PgDriver> {
158 let host = self.host.unwrap_or_else(|| "127.0.0.1".to_string());
159 let port = self.port.unwrap_or(5432);
160 let user = self
161 .user
162 .ok_or_else(|| PgError::Connection("User is required".to_string()))?;
163 let database = self
164 .database
165 .ok_or_else(|| PgError::Connection("Database is required".to_string()))?;
166
167 let password = self.password;
168 let options = self.connect_options;
169
170 if let Some(timeout) = self.timeout {
171 let options = options.clone();
172 tokio::time::timeout(
173 timeout,
174 PgDriver::connect_with_options(
175 &host,
176 port,
177 &user,
178 &database,
179 password.as_deref(),
180 options,
181 ),
182 )
183 .await
184 .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
185 } else {
186 PgDriver::connect_with_options(
187 &host,
188 port,
189 &user,
190 &database,
191 password.as_deref(),
192 options,
193 )
194 .await
195 }
196 }
197}