1use super::auth_types::*;
5use super::builder::PgDriverBuilder;
6use super::connection::PgConnection;
7use super::pool;
8use super::rls::RlsContext;
9use super::types::*;
10
/// High-level PostgreSQL driver handle that owns a single [`PgConnection`].
///
/// Construct it with [`PgDriver::new`], one of the `connect*` helpers, or
/// [`PgDriver::builder`].
pub struct PgDriver {
    /// The connection this driver wraps and issues work on.
    pub(super) connection: PgConnection,
    /// Optional context from `super::rls`; `PgDriver::new` initializes it to `None`.
    pub(super) rls_context: Option<RlsContext>,
}
17
18impl PgDriver {
19 pub fn new(connection: PgConnection) -> Self {
21 Self {
22 connection,
23 rls_context: None,
24 }
25 }
26
27 pub fn builder() -> PgDriverBuilder {
40 PgDriverBuilder::new()
41 }
42
43 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
52 let connection = PgConnection::connect(host, port, user, database).await?;
53 Ok(Self::new(connection))
54 }
55
56 pub async fn connect_with_password(
59 host: &str,
60 port: u16,
61 user: &str,
62 database: &str,
63 password: &str,
64 ) -> PgResult<Self> {
65 let connection =
66 PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
67 Ok(Self::new(connection))
68 }
69
70 pub async fn connect_with_options(
72 host: &str,
73 port: u16,
74 user: &str,
75 database: &str,
76 password: Option<&str>,
77 options: ConnectOptions,
78 ) -> PgResult<Self> {
79 let connection =
80 PgConnection::connect_with_options(host, port, user, database, password, options)
81 .await?;
82 Ok(Self::new(connection))
83 }
84
85 pub async fn connect_logical_replication(
90 host: &str,
91 port: u16,
92 user: &str,
93 database: &str,
94 password: Option<&str>,
95 ) -> PgResult<Self> {
96 let options = ConnectOptions::default().with_logical_replication();
97 Self::connect_with_options(host, port, user, database, password, options).await
98 }
99
100 pub async fn connect_logical_replication_with_options(
102 host: &str,
103 port: u16,
104 user: &str,
105 database: &str,
106 password: Option<&str>,
107 options: ConnectOptions,
108 ) -> PgResult<Self> {
109 Self::connect_with_options(
110 host,
111 port,
112 user,
113 database,
114 password,
115 options.with_logical_replication(),
116 )
117 .await
118 }
119
120 pub async fn connect_env() -> PgResult<Self> {
131 let url = std::env::var("DATABASE_URL").map_err(|_| {
132 PgError::Connection("DATABASE_URL environment variable not set".to_string())
133 })?;
134 Self::connect_url(&url).await
135 }
136
137 pub async fn connect_url(url: &str) -> PgResult<Self> {
150 let (host, port, user, database, password) = Self::parse_database_url(url)?;
151
152 let mut pool_cfg = pool::PoolConfig::new(&host, port, &user, &database);
154 if let Some(pw) = &password {
155 pool_cfg = pool_cfg.password(pw);
156 }
157 if let Some((_, query)) = url.split_once('?') {
158 pool::apply_url_query_params(&mut pool_cfg, query, &host)?;
159 }
160
161 let mut opts = ConnectOptions {
162 tls_mode: pool_cfg.tls_mode,
163 gss_enc_mode: pool_cfg.gss_enc_mode,
164 tls_ca_cert_pem: pool_cfg.tls_ca_cert_pem,
165 mtls: pool_cfg.mtls,
166 gss_token_provider: pool_cfg.gss_token_provider,
167 gss_token_provider_ex: pool_cfg.gss_token_provider_ex,
168 auth: pool_cfg.auth_settings,
169 startup_params: Vec::new(),
170 };
171
172 if let Some((_, query)) = url.split_once('?') {
174 for pair in query.split('&') {
175 let mut kv = pair.splitn(2, '=');
176 let key = kv.next().unwrap_or_default().trim();
177 let value = kv.next().unwrap_or_default().trim();
178 if key.eq_ignore_ascii_case("replication") {
179 let replication_mode = if value.eq_ignore_ascii_case("database") {
180 "database"
181 } else if value.eq_ignore_ascii_case("true")
182 || value.eq_ignore_ascii_case("on")
183 || value == "1"
184 {
185 "database"
188 } else {
189 return Err(PgError::Connection(format!(
190 "Invalid replication startup mode '{}': expected database|true|on|1",
191 value
192 )));
193 };
194 opts = opts.with_startup_param("replication", replication_mode);
195 }
196 }
197 }
198
199 Self::connect_with_options(&host, port, &user, &database, password.as_deref(), opts).await
200 }
201
202 pub(crate) fn parse_database_url(
209 url: &str,
210 ) -> PgResult<(String, u16, String, String, Option<String>)> {
211 let after_scheme = if let Some(rest) = url.strip_prefix("postgres://") {
212 rest
213 } else if let Some(rest) = url.strip_prefix("postgresql://") {
214 rest
215 } else {
216 return Err(PgError::Connection(
217 "Invalid DATABASE_URL: expected postgres:// or postgresql://".to_string(),
218 ));
219 };
220
221 let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
223 (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
224 } else {
225 (None, after_scheme)
226 };
227
228 let (user, password) = if let Some(auth) = auth_part {
230 if auth.is_empty() {
231 return Err(PgError::Connection(
232 "Invalid DATABASE_URL: missing user".to_string(),
233 ));
234 }
235 let parts: Vec<&str> = auth.splitn(2, ':').collect();
236 if parts.len() == 2 {
237 let user = Self::percent_decode(parts[0])?;
239 if user.is_empty() {
240 return Err(PgError::Connection(
241 "Invalid DATABASE_URL: missing user".to_string(),
242 ));
243 }
244 (user, Some(Self::percent_decode(parts[1])?))
245 } else {
246 let user = Self::percent_decode(parts[0])?;
247 if user.is_empty() {
248 return Err(PgError::Connection(
249 "Invalid DATABASE_URL: missing user".to_string(),
250 ));
251 }
252 (user, None)
253 }
254 } else {
255 ("postgres".to_string(), None)
256 };
257
258 let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
260 let raw_db = &host_db_part[slash_pos + 1..];
261 let db = Self::percent_decode(raw_db.split('?').next().unwrap_or(raw_db))?;
263 (&host_db_part[..slash_pos], db)
264 } else {
265 return Err(PgError::Connection(
266 "Invalid DATABASE_URL: missing database name".to_string(),
267 ));
268 };
269
270 let (host, port) = if host_port.starts_with('[') {
272 let end = host_port.find(']').ok_or_else(|| {
273 PgError::Connection("Invalid DATABASE_URL: malformed IPv6 host".to_string())
274 })?;
275 let host = &host_port[..=end];
276 if host == "[]" {
277 return Err(PgError::Connection(
278 "Invalid DATABASE_URL: missing host".to_string(),
279 ));
280 }
281 let suffix = &host_port[end + 1..];
282 let port = if suffix.is_empty() {
283 5432
284 } else if let Some(port_str) = suffix.strip_prefix(':') {
285 Self::parse_database_url_port(port_str)?
286 } else {
287 return Err(PgError::Connection(
288 "Invalid DATABASE_URL: malformed IPv6 host".to_string(),
289 ));
290 };
291 (host.to_string(), port)
292 } else if let Some(colon_pos) = host_port.rfind(':') {
293 let port_str = &host_port[colon_pos + 1..];
294 let host = &host_port[..colon_pos];
295 if host.is_empty() {
296 return Err(PgError::Connection(
297 "Invalid DATABASE_URL: missing host".to_string(),
298 ));
299 }
300 let port = Self::parse_database_url_port(port_str)?;
301 (host.to_string(), port)
302 } else {
303 if host_port.is_empty() {
304 return Err(PgError::Connection(
305 "Invalid DATABASE_URL: missing host".to_string(),
306 ));
307 }
308 (host_port.to_string(), 5432) };
310
311 Ok((host, port, user, database, password))
312 }
313
314 fn parse_database_url_port(port_str: &str) -> PgResult<u16> {
315 if port_str.is_empty() {
316 return Err(PgError::Connection(
317 "Invalid DATABASE_URL: missing port after ':'".to_string(),
318 ));
319 }
320 let port = port_str
321 .parse::<u16>()
322 .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
323 if port == 0 {
324 return Err(PgError::Connection(
325 "Invalid port: 0 (expected 1..=65535)".to_string(),
326 ));
327 }
328 Ok(port)
329 }
330
331 pub(crate) fn percent_decode(s: &str) -> PgResult<String> {
334 fn hex_value(byte: u8) -> Option<u8> {
335 match byte {
336 b'0'..=b'9' => Some(byte - b'0'),
337 b'a'..=b'f' => Some(byte - b'a' + 10),
338 b'A'..=b'F' => Some(byte - b'A' + 10),
339 _ => None,
340 }
341 }
342
343 let bytes = s.as_bytes();
344 let mut decoded = Vec::with_capacity(bytes.len());
345 let mut i = 0;
346
347 while i < bytes.len() {
348 if bytes[i] == b'%'
349 && i + 2 < bytes.len()
350 && let (Some(hi), Some(lo)) = (hex_value(bytes[i + 1]), hex_value(bytes[i + 2]))
351 {
352 decoded.push((hi << 4) | lo);
353 i += 3;
354 } else {
355 decoded.push(bytes[i]);
356 i += 1;
357 }
358 }
359
360 String::from_utf8(decoded).map_err(|_| {
361 PgError::Connection(
362 "Invalid DATABASE_URL percent-encoding: decoded value is not UTF-8".to_string(),
363 )
364 })
365 }
366
367 pub async fn connect_with_timeout(
378 host: &str,
379 port: u16,
380 user: &str,
381 database: &str,
382 password: &str,
383 timeout: std::time::Duration,
384 ) -> PgResult<Self> {
385 tokio::time::timeout(
386 timeout,
387 Self::connect_with_password(host, port, user, database, password),
388 )
389 .await
390 .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
391 }
392 pub fn clear_cache(&mut self) {
396 self.connection.clear_prepared_statement_state();
397 }
398
399 pub fn cache_stats(&self) -> (usize, usize) {
402 (
403 self.connection.stmt_cache.len(),
404 self.connection.stmt_cache.cap().get(),
405 )
406 }
407}