1use super::auth_types::*;
5use super::builder::PgDriverBuilder;
6use super::connection::PgConnection;
7use super::pool;
8use super::rls::RlsContext;
9use super::types::*;
10
11pub struct PgDriver {
13 #[allow(dead_code)]
14 pub(super) connection: PgConnection,
15 pub(super) rls_context: Option<RlsContext>,
17}
18
19impl PgDriver {
20 pub fn new(connection: PgConnection) -> Self {
22 Self {
23 connection,
24 rls_context: None,
25 }
26 }
27
28 pub fn builder() -> PgDriverBuilder {
41 PgDriverBuilder::new()
42 }
43
44 pub async fn connect(host: &str, port: u16, user: &str, database: &str) -> PgResult<Self> {
53 let connection = PgConnection::connect(host, port, user, database).await?;
54 Ok(Self::new(connection))
55 }
56
57 pub async fn connect_with_password(
60 host: &str,
61 port: u16,
62 user: &str,
63 database: &str,
64 password: &str,
65 ) -> PgResult<Self> {
66 let connection =
67 PgConnection::connect_with_password(host, port, user, database, Some(password)).await?;
68 Ok(Self::new(connection))
69 }
70
71 pub async fn connect_with_options(
73 host: &str,
74 port: u16,
75 user: &str,
76 database: &str,
77 password: Option<&str>,
78 options: ConnectOptions,
79 ) -> PgResult<Self> {
80 let connection =
81 PgConnection::connect_with_options(host, port, user, database, password, options)
82 .await?;
83 Ok(Self::new(connection))
84 }
85
86 pub async fn connect_logical_replication(
91 host: &str,
92 port: u16,
93 user: &str,
94 database: &str,
95 password: Option<&str>,
96 ) -> PgResult<Self> {
97 let options = ConnectOptions::default().with_logical_replication();
98 Self::connect_with_options(host, port, user, database, password, options).await
99 }
100
101 pub async fn connect_logical_replication_with_options(
103 host: &str,
104 port: u16,
105 user: &str,
106 database: &str,
107 password: Option<&str>,
108 options: ConnectOptions,
109 ) -> PgResult<Self> {
110 Self::connect_with_options(
111 host,
112 port,
113 user,
114 database,
115 password,
116 options.with_logical_replication(),
117 )
118 .await
119 }
120
121 pub async fn connect_env() -> PgResult<Self> {
132 let url = std::env::var("DATABASE_URL").map_err(|_| {
133 PgError::Connection("DATABASE_URL environment variable not set".to_string())
134 })?;
135 Self::connect_url(&url).await
136 }
137
138 pub async fn connect_url(url: &str) -> PgResult<Self> {
151 let (host, port, user, database, password) = Self::parse_database_url(url)?;
152
153 let mut pool_cfg = pool::PoolConfig::new(&host, port, &user, &database);
155 if let Some(pw) = &password {
156 pool_cfg = pool_cfg.password(pw);
157 }
158 if let Some(query) = url.split('?').nth(1) {
159 pool::apply_url_query_params(&mut pool_cfg, query, &host)?;
160 }
161
162 let mut opts = ConnectOptions {
163 tls_mode: pool_cfg.tls_mode,
164 gss_enc_mode: pool_cfg.gss_enc_mode,
165 tls_ca_cert_pem: pool_cfg.tls_ca_cert_pem,
166 mtls: pool_cfg.mtls,
167 gss_token_provider: pool_cfg.gss_token_provider,
168 gss_token_provider_ex: pool_cfg.gss_token_provider_ex,
169 auth: pool_cfg.auth_settings,
170 startup_params: Vec::new(),
171 };
172
173 if let Some(query) = url.split('?').nth(1) {
175 for pair in query.split('&') {
176 let mut kv = pair.splitn(2, '=');
177 let key = kv.next().unwrap_or_default().trim();
178 let value = kv.next().unwrap_or_default().trim();
179 if key.eq_ignore_ascii_case("replication") {
180 let replication_mode = if value.eq_ignore_ascii_case("database") {
181 "database"
182 } else if value.eq_ignore_ascii_case("true")
183 || value.eq_ignore_ascii_case("on")
184 || value == "1"
185 {
186 "database"
189 } else {
190 return Err(PgError::Connection(format!(
191 "Invalid replication startup mode '{}': expected database|true|on|1",
192 value
193 )));
194 };
195 opts = opts.with_startup_param("replication", replication_mode);
196 }
197 }
198 }
199
200 Self::connect_with_options(&host, port, &user, &database, password.as_deref(), opts).await
201 }
202
203 pub(crate) fn parse_database_url(
210 url: &str,
211 ) -> PgResult<(String, u16, String, String, Option<String>)> {
212 let after_scheme = url.split("://").nth(1).ok_or_else(|| {
214 PgError::Connection("Invalid DATABASE_URL: missing scheme".to_string())
215 })?;
216
217 let (auth_part, host_db_part) = if let Some(at_pos) = after_scheme.rfind('@') {
219 (Some(&after_scheme[..at_pos]), &after_scheme[at_pos + 1..])
220 } else {
221 (None, after_scheme)
222 };
223
224 let (user, password) = if let Some(auth) = auth_part {
226 let parts: Vec<&str> = auth.splitn(2, ':').collect();
227 if parts.len() == 2 {
228 (
230 Self::percent_decode(parts[0]),
231 Some(Self::percent_decode(parts[1])),
232 )
233 } else {
234 (Self::percent_decode(parts[0]), None)
235 }
236 } else {
237 return Err(PgError::Connection(
238 "Invalid DATABASE_URL: missing user".to_string(),
239 ));
240 };
241
242 let (host_port, database) = if let Some(slash_pos) = host_db_part.find('/') {
244 let raw_db = &host_db_part[slash_pos + 1..];
245 let db = raw_db.split('?').next().unwrap_or(raw_db).to_string();
247 (&host_db_part[..slash_pos], db)
248 } else {
249 return Err(PgError::Connection(
250 "Invalid DATABASE_URL: missing database name".to_string(),
251 ));
252 };
253
254 let (host, port) = if let Some(colon_pos) = host_port.rfind(':') {
256 let port_str = &host_port[colon_pos + 1..];
257 let port = port_str
258 .parse::<u16>()
259 .map_err(|_| PgError::Connection(format!("Invalid port: {}", port_str)))?;
260 (host_port[..colon_pos].to_string(), port)
261 } else {
262 (host_port.to_string(), 5432) };
264
265 Ok((host, port, user, database, password))
266 }
267
268 pub(crate) fn percent_decode(s: &str) -> String {
271 let mut result = String::with_capacity(s.len());
272 let mut chars = s.chars().peekable();
273
274 while let Some(c) = chars.next() {
275 if c == '%' {
276 let hex: String = chars.by_ref().take(2).collect();
278 if hex.len() == 2
279 && let Ok(byte) = u8::from_str_radix(&hex, 16)
280 {
281 result.push(byte as char);
282 continue;
283 }
284 result.push('%');
286 result.push_str(&hex);
287 } else if c == '+' {
288 result.push('+');
291 } else {
292 result.push(c);
293 }
294 }
295
296 result
297 }
298
299 pub async fn connect_with_timeout(
310 host: &str,
311 port: u16,
312 user: &str,
313 database: &str,
314 password: &str,
315 timeout: std::time::Duration,
316 ) -> PgResult<Self> {
317 tokio::time::timeout(
318 timeout,
319 Self::connect_with_password(host, port, user, database, password),
320 )
321 .await
322 .map_err(|_| PgError::Timeout(format!("connection after {:?}", timeout)))?
323 }
324 pub fn clear_cache(&mut self) {
328 self.connection.clear_prepared_statement_state();
329 }
330
331 pub fn cache_stats(&self) -> (usize, usize) {
334 (
335 self.connection.stmt_cache.len(),
336 self.connection.stmt_cache.cap().get(),
337 )
338 }
339}