1use std::time::Duration;
4
5use crate::error::{PgError, PgResult};
6
/// Connection settings for a PostgreSQL database.
///
/// Created either from a connection URL ([`PgConfig::from_url`]) or with
/// [`PgConfig::builder`]. The `Debug` representation masks the password, both in
/// the `password` field and inside `url`, so a `PgConfig` can be logged safely.
#[derive(Clone)]
pub struct PgConfig {
    /// The URL passed to `from_url`, or the URL synthesized by the builder when no
    /// URL was supplied. Builder overrides applied on top of a URL do not rewrite it.
    pub url: String,
    /// Server host name or address.
    pub host: String,
    /// Server port.
    pub port: u16,
    /// Name of the database to connect to.
    pub database: String,
    /// User (role) name to authenticate as.
    pub user: String,
    /// Password, if one was provided.
    pub password: Option<String>,
    /// TLS policy; corresponds to the `sslmode` URL parameter.
    pub ssl_mode: SslMode,
    /// Connect timeout handed to the driver.
    pub connect_timeout: Duration,
    /// Statement timeout parsed from the `statement_timeout` URL parameter (milliseconds), if any.
    pub statement_timeout: Option<Duration>,
    /// Value for the `application_name` connection parameter, if any.
    pub application_name: Option<String>,
    /// URL query parameters that `from_url` does not recognise, kept as `(key, value)` pairs.
    pub options: Vec<(String, String)>,
}

impl std::fmt::Debug for PgConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        /// Replaces the password in `scheme://user:password@host/...` with `***`;
        /// URLs without a password in their authority are returned unchanged.
        fn redact_password(url: &str) -> String {
            let authority_start = match url.find("://") {
                Some(pos) => pos + 3,
                None => return url.to_string(),
            };
            let after_scheme = &url[authority_start..];
            // The authority ends at the first path, query or fragment delimiter.
            let authority_len = after_scheme
                .find(|c: char| matches!(c, '/' | '?' | '#'))
                .unwrap_or(after_scheme.len());
            let authority = &after_scheme[..authority_len];
            match authority.rfind('@') {
                Some(at) => match authority[..at].find(':') {
                    Some(colon) => format!(
                        "{}***{}",
                        &url[..authority_start + colon + 1],
                        &url[authority_start + at..]
                    ),
                    None => url.to_string(),
                },
                None => url.to_string(),
            }
        }

        f.debug_struct("PgConfig")
            .field("url", &redact_password(&self.url))
            .field("host", &self.host)
            .field("port", &self.port)
            .field("database", &self.database)
            .field("user", &self.user)
            .field("password", &self.password.as_ref().map(|_| "***"))
            .field("ssl_mode", &self.ssl_mode)
            .field("connect_timeout", &self.connect_timeout)
            .field("statement_timeout", &self.statement_timeout)
            .field("application_name", &self.application_name)
            .field("options", &self.options)
            .finish()
    }
}

/// TLS negotiation policy, selected by the `sslmode` URL parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// Corresponds to `sslmode=disable`.
    Disable,
    /// Corresponds to `sslmode=prefer`; this is the default.
    #[default]
    Prefer,
    /// Corresponds to `sslmode=require`.
    Require,
}
45
46impl PgConfig {
47 pub fn from_url(url: impl Into<String>) -> PgResult<Self> {
49 let url = url.into();
50 let parsed = url::Url::parse(&url)
51 .map_err(|e| PgError::config(format!("invalid database URL: {}", e)))?;
52
53 if parsed.scheme() != "postgresql" && parsed.scheme() != "postgres" {
54 return Err(PgError::config(format!(
55 "invalid scheme: expected 'postgresql' or 'postgres', got '{}'",
56 parsed.scheme()
57 )));
58 }
59
60 let host = parsed
61 .host_str()
62 .ok_or_else(|| PgError::config("missing host in URL"))?
63 .to_string();
64
65 let port = parsed.port().unwrap_or(5432);
66
67 let database = parsed.path().trim_start_matches('/').to_string();
68
69 if database.is_empty() {
70 return Err(PgError::config("missing database name in URL"));
71 }
72
73 let user = if parsed.username().is_empty() {
74 "postgres".to_string()
75 } else {
76 parsed.username().to_string()
77 };
78
79 let password = parsed.password().map(String::from);
80
81 let mut ssl_mode = SslMode::Prefer;
83 let mut connect_timeout = Duration::from_secs(30);
84 let mut statement_timeout = None;
85 let mut application_name = None;
86 let mut options = Vec::new();
87
88 for (key, value) in parsed.query_pairs() {
89 let key_str: &str = &key;
90 let value_str: &str = &value;
91 match key_str {
92 "sslmode" => {
93 ssl_mode = match value_str {
94 "disable" => SslMode::Disable,
95 "prefer" => SslMode::Prefer,
96 "require" => SslMode::Require,
97 other => {
98 return Err(PgError::config(format!("invalid sslmode: {}", other)));
99 }
100 };
101 }
102 "connect_timeout" => {
103 let secs: u64 = value_str
104 .parse()
105 .map_err(|_| PgError::config("invalid connect_timeout"))?;
106 connect_timeout = Duration::from_secs(secs);
107 }
108 "statement_timeout" => {
109 let ms: u64 = value_str
110 .parse()
111 .map_err(|_| PgError::config("invalid statement_timeout"))?;
112 statement_timeout = Some(Duration::from_millis(ms));
113 }
114 "application_name" => {
115 application_name = Some(value_str.to_string());
116 }
117 _ => {
118 options.push((key_str.to_string(), value_str.to_string()));
119 }
120 }
121 }
122
123 Ok(Self {
124 url,
125 host,
126 port,
127 database,
128 user,
129 password,
130 ssl_mode,
131 connect_timeout,
132 statement_timeout,
133 application_name,
134 options,
135 })
136 }
137
138 pub fn builder() -> PgConfigBuilder {
140 PgConfigBuilder::new()
141 }
142
143 pub fn to_pg_config(&self) -> tokio_postgres::Config {
145 let mut config = tokio_postgres::Config::new();
146 config.host(&self.host);
147 config.port(self.port);
148 config.dbname(&self.database);
149 config.user(&self.user);
150
151 if let Some(ref password) = self.password {
152 config.password(password);
153 }
154
155 if let Some(ref app_name) = self.application_name {
156 config.application_name(app_name);
157 }
158
159 config.connect_timeout(self.connect_timeout);
160
161 config
162 }
163}
164
165#[derive(Debug, Default)]
167pub struct PgConfigBuilder {
168 url: Option<String>,
169 host: Option<String>,
170 port: Option<u16>,
171 database: Option<String>,
172 user: Option<String>,
173 password: Option<String>,
174 ssl_mode: Option<SslMode>,
175 connect_timeout: Option<Duration>,
176 statement_timeout: Option<Duration>,
177 application_name: Option<String>,
178}
179
180impl PgConfigBuilder {
181 pub fn new() -> Self {
183 Self::default()
184 }
185
186 pub fn url(mut self, url: impl Into<String>) -> Self {
188 self.url = Some(url.into());
189 self
190 }
191
192 pub fn host(mut self, host: impl Into<String>) -> Self {
194 self.host = Some(host.into());
195 self
196 }
197
198 pub fn port(mut self, port: u16) -> Self {
200 self.port = Some(port);
201 self
202 }
203
204 pub fn database(mut self, database: impl Into<String>) -> Self {
206 self.database = Some(database.into());
207 self
208 }
209
210 pub fn user(mut self, user: impl Into<String>) -> Self {
212 self.user = Some(user.into());
213 self
214 }
215
216 pub fn password(mut self, password: impl Into<String>) -> Self {
218 self.password = Some(password.into());
219 self
220 }
221
222 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
224 self.ssl_mode = Some(mode);
225 self
226 }
227
228 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
230 self.connect_timeout = Some(timeout);
231 self
232 }
233
234 pub fn statement_timeout(mut self, timeout: Duration) -> Self {
236 self.statement_timeout = Some(timeout);
237 self
238 }
239
240 pub fn application_name(mut self, name: impl Into<String>) -> Self {
242 self.application_name = Some(name.into());
243 self
244 }
245
246 pub fn build(self) -> PgResult<PgConfig> {
248 if let Some(url) = self.url {
249 let mut config = PgConfig::from_url(url)?;
250
251 if let Some(host) = self.host {
253 config.host = host;
254 }
255 if let Some(port) = self.port {
256 config.port = port;
257 }
258 if let Some(database) = self.database {
259 config.database = database;
260 }
261 if let Some(user) = self.user {
262 config.user = user;
263 }
264 if let Some(password) = self.password {
265 config.password = Some(password);
266 }
267 if let Some(ssl_mode) = self.ssl_mode {
268 config.ssl_mode = ssl_mode;
269 }
270 if let Some(timeout) = self.connect_timeout {
271 config.connect_timeout = timeout;
272 }
273 if let Some(timeout) = self.statement_timeout {
274 config.statement_timeout = Some(timeout);
275 }
276 if let Some(name) = self.application_name {
277 config.application_name = Some(name);
278 }
279
280 Ok(config)
281 } else {
282 let host = self.host.unwrap_or_else(|| "localhost".to_string());
284 let port = self.port.unwrap_or(5432);
285 let database = self
286 .database
287 .ok_or_else(|| PgError::config("database name is required"))?;
288 let user = self.user.unwrap_or_else(|| "postgres".to_string());
289
290 let url = format!(
291 "postgresql://{}{}@{}:{}/{}",
292 user,
293 self.password
294 .as_ref()
295 .map(|p| format!(":{}", p))
296 .unwrap_or_default(),
297 host,
298 port,
299 database
300 );
301
302 Ok(PgConfig {
303 url,
304 host,
305 port,
306 database,
307 user,
308 password: self.password,
309 ssl_mode: self.ssl_mode.unwrap_or_default(),
310 connect_timeout: self.connect_timeout.unwrap_or(Duration::from_secs(30)),
311 statement_timeout: self.statement_timeout,
312 application_name: self.application_name,
313 options: Vec::new(),
314 })
315 }
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_config_from_url() {
        let config = PgConfig::from_url("postgresql://user:pass@localhost:5432/mydb").unwrap();
        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "user");
        assert_eq!(config.password, Some("pass".to_string()));
    }

    #[test]
    fn test_config_from_url_with_params() {
        let config =
            PgConfig::from_url("postgresql://localhost/mydb?sslmode=require&application_name=prax")
                .unwrap();
        assert_eq!(config.ssl_mode, SslMode::Require);
        assert_eq!(config.application_name, Some("prax".to_string()));
    }

    #[test]
    fn test_config_from_url_defaults() {
        let config = PgConfig::from_url("postgres://localhost/mydb").unwrap();
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "postgres");
        assert_eq!(config.password, None);
        assert_eq!(config.ssl_mode, SslMode::Prefer);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.statement_timeout, None);
        assert!(config.options.is_empty());
    }

    #[test]
    fn test_config_from_url_timeouts_and_extra_options() {
        let config = PgConfig::from_url(
            "postgresql://localhost/mydb?connect_timeout=5&statement_timeout=1500&search_path=app",
        )
        .unwrap();
        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.statement_timeout, Some(Duration::from_millis(1500)));
        assert_eq!(
            config.options,
            vec![("search_path".to_string(), "app".to_string())]
        );
    }

    #[test]
    fn test_config_from_url_rejects_invalid_input() {
        let bad_urls = [
            "postgresql://localhost",
            "postgresql://localhost/",
            "postgresql://localhost/db?sslmode=verify-full",
            "postgresql://localhost/db?connect_timeout=soon",
            "postgresql://localhost/db?statement_timeout=-1",
            "not a url",
        ];
        for url in bad_urls.iter() {
            assert!(PgConfig::from_url(*url).is_err(), "expected error for {}", url);
        }
    }

    #[test]
    fn test_config_builder() {
        let config = PgConfig::builder()
            .host("localhost")
            .port(5432)
            .database("mydb")
            .user("postgres")
            .build()
            .unwrap();

        assert_eq!(config.host, "localhost");
        assert_eq!(config.database, "mydb");
    }

    #[test]
    fn test_config_builder_defaults() {
        let config = PgConfig::builder().database("mydb").build().unwrap();
        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "postgres");
        assert_eq!(config.ssl_mode, SslMode::Prefer);
        assert_eq!(config.url, "postgresql://postgres@localhost:5432/mydb");
    }

    #[test]
    fn test_config_builder_requires_database() {
        assert!(PgConfig::builder().host("localhost").build().is_err());
    }

    #[test]
    fn test_config_builder_overrides_url() {
        let config = PgConfig::builder()
            .url("postgresql://user@db.internal:5432/mydb")
            .port(6543)
            .application_name("prax")
            .build()
            .unwrap();
        assert_eq!(config.host, "db.internal");
        assert_eq!(config.port, 6543);
        assert_eq!(config.user, "user");
        assert_eq!(config.application_name, Some("prax".to_string()));
    }

    #[test]
    fn test_config_invalid_scheme() {
        let result = PgConfig::from_url("mysql://localhost/db");
        assert!(result.is_err());
    }
}