prax_postgres/
config.rs

1//! PostgreSQL connection configuration.
2
3use std::time::Duration;
4
5use crate::error::{PgError, PgResult};
6
7/// PostgreSQL connection configuration.
8#[derive(Debug, Clone)]
9pub struct PgConfig {
10    /// Database URL.
11    pub url: String,
12    /// Host (extracted from URL or explicit).
13    pub host: String,
14    /// Port (default: 5432).
15    pub port: u16,
16    /// Database name.
17    pub database: String,
18    /// Username.
19    pub user: String,
20    /// Password.
21    pub password: Option<String>,
22    /// SSL mode.
23    pub ssl_mode: SslMode,
24    /// Connection timeout.
25    pub connect_timeout: Duration,
26    /// Statement timeout.
27    pub statement_timeout: Option<Duration>,
28    /// Application name (shown in pg_stat_activity).
29    pub application_name: Option<String>,
30    /// Additional options.
31    pub options: Vec<(String, String)>,
32}
33
/// SSL policy for a connection.
///
/// Corresponds to the `sslmode` URL parameter values `disable`, `prefer` and
/// `require`; [`SslMode::Prefer`] is the default.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum SslMode {
    /// Do not use SSL.
    Disable,
    /// Use SSL when available, otherwise connect without it.
    #[default]
    Prefer,
    /// Require SSL.
    Require,
}
45
46impl PgConfig {
47    /// Create a new configuration from a database URL.
48    pub fn from_url(url: impl Into<String>) -> PgResult<Self> {
49        let url = url.into();
50        let parsed = url::Url::parse(&url)
51            .map_err(|e| PgError::config(format!("invalid database URL: {}", e)))?;
52
53        if parsed.scheme() != "postgresql" && parsed.scheme() != "postgres" {
54            return Err(PgError::config(format!(
55                "invalid scheme: expected 'postgresql' or 'postgres', got '{}'",
56                parsed.scheme()
57            )));
58        }
59
60        let host = parsed
61            .host_str()
62            .ok_or_else(|| PgError::config("missing host in URL"))?
63            .to_string();
64
65        let port = parsed.port().unwrap_or(5432);
66
67        let database = parsed.path().trim_start_matches('/').to_string();
68
69        if database.is_empty() {
70            return Err(PgError::config("missing database name in URL"));
71        }
72
73        let user = if parsed.username().is_empty() {
74            "postgres".to_string()
75        } else {
76            parsed.username().to_string()
77        };
78
79        let password = parsed.password().map(String::from);
80
81        // Parse query parameters
82        let mut ssl_mode = SslMode::Prefer;
83        let mut connect_timeout = Duration::from_secs(30);
84        let mut statement_timeout = None;
85        let mut application_name = None;
86        let mut options = Vec::new();
87
88        for (key, value) in parsed.query_pairs() {
89            let key_str: &str = &key;
90            let value_str: &str = &value;
91            match key_str {
92                "sslmode" => {
93                    ssl_mode = match value_str {
94                        "disable" => SslMode::Disable,
95                        "prefer" => SslMode::Prefer,
96                        "require" => SslMode::Require,
97                        other => {
98                            return Err(PgError::config(format!("invalid sslmode: {}", other)));
99                        }
100                    };
101                }
102                "connect_timeout" => {
103                    let secs: u64 = value_str
104                        .parse()
105                        .map_err(|_| PgError::config("invalid connect_timeout"))?;
106                    connect_timeout = Duration::from_secs(secs);
107                }
108                "statement_timeout" => {
109                    let ms: u64 = value_str
110                        .parse()
111                        .map_err(|_| PgError::config("invalid statement_timeout"))?;
112                    statement_timeout = Some(Duration::from_millis(ms));
113                }
114                "application_name" => {
115                    application_name = Some(value_str.to_string());
116                }
117                _ => {
118                    options.push((key_str.to_string(), value_str.to_string()));
119                }
120            }
121        }
122
123        Ok(Self {
124            url,
125            host,
126            port,
127            database,
128            user,
129            password,
130            ssl_mode,
131            connect_timeout,
132            statement_timeout,
133            application_name,
134            options,
135        })
136    }
137
138    /// Create a builder for configuration.
139    pub fn builder() -> PgConfigBuilder {
140        PgConfigBuilder::new()
141    }
142
143    /// Convert to tokio-postgres config.
144    pub fn to_pg_config(&self) -> tokio_postgres::Config {
145        let mut config = tokio_postgres::Config::new();
146        config.host(&self.host);
147        config.port(self.port);
148        config.dbname(&self.database);
149        config.user(&self.user);
150
151        if let Some(ref password) = self.password {
152            config.password(password);
153        }
154
155        if let Some(ref app_name) = self.application_name {
156            config.application_name(app_name);
157        }
158
159        config.connect_timeout(self.connect_timeout);
160
161        config
162    }
163}
164
165/// Builder for PostgreSQL configuration.
166#[derive(Debug, Default)]
167pub struct PgConfigBuilder {
168    url: Option<String>,
169    host: Option<String>,
170    port: Option<u16>,
171    database: Option<String>,
172    user: Option<String>,
173    password: Option<String>,
174    ssl_mode: Option<SslMode>,
175    connect_timeout: Option<Duration>,
176    statement_timeout: Option<Duration>,
177    application_name: Option<String>,
178}
179
180impl PgConfigBuilder {
181    /// Create a new builder.
182    pub fn new() -> Self {
183        Self::default()
184    }
185
186    /// Set the database URL (parses all connection parameters).
187    pub fn url(mut self, url: impl Into<String>) -> Self {
188        self.url = Some(url.into());
189        self
190    }
191
192    /// Set the host.
193    pub fn host(mut self, host: impl Into<String>) -> Self {
194        self.host = Some(host.into());
195        self
196    }
197
198    /// Set the port.
199    pub fn port(mut self, port: u16) -> Self {
200        self.port = Some(port);
201        self
202    }
203
204    /// Set the database name.
205    pub fn database(mut self, database: impl Into<String>) -> Self {
206        self.database = Some(database.into());
207        self
208    }
209
210    /// Set the username.
211    pub fn user(mut self, user: impl Into<String>) -> Self {
212        self.user = Some(user.into());
213        self
214    }
215
216    /// Set the password.
217    pub fn password(mut self, password: impl Into<String>) -> Self {
218        self.password = Some(password.into());
219        self
220    }
221
222    /// Set the SSL mode.
223    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
224        self.ssl_mode = Some(mode);
225        self
226    }
227
228    /// Set the connection timeout.
229    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
230        self.connect_timeout = Some(timeout);
231        self
232    }
233
234    /// Set the statement timeout.
235    pub fn statement_timeout(mut self, timeout: Duration) -> Self {
236        self.statement_timeout = Some(timeout);
237        self
238    }
239
240    /// Set the application name.
241    pub fn application_name(mut self, name: impl Into<String>) -> Self {
242        self.application_name = Some(name.into());
243        self
244    }
245
246    /// Build the configuration.
247    pub fn build(self) -> PgResult<PgConfig> {
248        if let Some(url) = self.url {
249            let mut config = PgConfig::from_url(url)?;
250
251            // Override with explicit values
252            if let Some(host) = self.host {
253                config.host = host;
254            }
255            if let Some(port) = self.port {
256                config.port = port;
257            }
258            if let Some(database) = self.database {
259                config.database = database;
260            }
261            if let Some(user) = self.user {
262                config.user = user;
263            }
264            if let Some(password) = self.password {
265                config.password = Some(password);
266            }
267            if let Some(ssl_mode) = self.ssl_mode {
268                config.ssl_mode = ssl_mode;
269            }
270            if let Some(timeout) = self.connect_timeout {
271                config.connect_timeout = timeout;
272            }
273            if let Some(timeout) = self.statement_timeout {
274                config.statement_timeout = Some(timeout);
275            }
276            if let Some(name) = self.application_name {
277                config.application_name = Some(name);
278            }
279
280            Ok(config)
281        } else {
282            // Build from individual components
283            let host = self.host.unwrap_or_else(|| "localhost".to_string());
284            let port = self.port.unwrap_or(5432);
285            let database = self
286                .database
287                .ok_or_else(|| PgError::config("database name is required"))?;
288            let user = self.user.unwrap_or_else(|| "postgres".to_string());
289
290            let url = format!(
291                "postgresql://{}{}@{}:{}/{}",
292                user,
293                self.password
294                    .as_ref()
295                    .map(|p| format!(":{}", p))
296                    .unwrap_or_default(),
297                host,
298                port,
299                database
300            );
301
302            Ok(PgConfig {
303                url,
304                host,
305                port,
306                database,
307                user,
308                password: self.password,
309                ssl_mode: self.ssl_mode.unwrap_or_default(),
310                connect_timeout: self.connect_timeout.unwrap_or(Duration::from_secs(30)),
311                statement_timeout: self.statement_timeout,
312                application_name: self.application_name,
313                options: Vec::new(),
314            })
315        }
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_config_from_url() {
        let config = PgConfig::from_url("postgresql://user:pass@localhost:5432/mydb").unwrap();
        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.database, "mydb");
        assert_eq!(config.user, "user");
        assert_eq!(config.password, Some("pass".to_string()));
    }

    #[test]
    fn test_config_from_url_defaults() {
        let config = PgConfig::from_url("postgres://localhost/mydb").unwrap();
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "postgres");
        assert_eq!(config.password, None);
        assert_eq!(config.ssl_mode, SslMode::Prefer);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.statement_timeout, None);
        assert_eq!(config.application_name, None);
        assert!(config.options.is_empty());
    }

    #[test]
    fn test_config_from_url_with_params() {
        let config =
            PgConfig::from_url("postgresql://localhost/mydb?sslmode=require&application_name=prax")
                .unwrap();
        assert_eq!(config.ssl_mode, SslMode::Require);
        assert_eq!(config.application_name, Some("prax".to_string()));
    }

    #[test]
    fn test_config_from_url_timeouts_and_extra_options() {
        let config = PgConfig::from_url(
            "postgresql://localhost/mydb?connect_timeout=5&statement_timeout=250&keepalives=1",
        )
        .unwrap();
        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.statement_timeout, Some(Duration::from_millis(250)));
        assert_eq!(
            config.options,
            vec![("keepalives".to_string(), "1".to_string())]
        );
    }

    #[test]
    fn test_config_from_url_rejects_bad_params() {
        assert!(PgConfig::from_url("postgresql://localhost/mydb?sslmode=bogus").is_err());
        assert!(PgConfig::from_url("postgresql://localhost/mydb?connect_timeout=soon").is_err());
        assert!(PgConfig::from_url("postgresql://localhost/mydb?statement_timeout=-1").is_err());
    }

    #[test]
    fn test_config_from_url_missing_database() {
        assert!(PgConfig::from_url("postgresql://localhost").is_err());
        assert!(PgConfig::from_url("postgresql://localhost/").is_err());
    }

    #[test]
    fn test_config_from_url_unparseable() {
        assert!(PgConfig::from_url("not a url").is_err());
    }

    #[test]
    fn test_config_invalid_scheme() {
        let result = PgConfig::from_url("mysql://localhost/db");
        assert!(result.is_err());
    }

    #[test]
    fn test_config_builder() {
        let config = PgConfig::builder()
            .host("localhost")
            .port(5432)
            .database("mydb")
            .user("postgres")
            .build()
            .unwrap();

        assert_eq!(config.host, "localhost");
        assert_eq!(config.database, "mydb");
    }

    #[test]
    fn test_config_builder_defaults() {
        let config = PgConfig::builder().database("mydb").build().unwrap();
        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 5432);
        assert_eq!(config.user, "postgres");
        assert_eq!(config.password, None);
        assert_eq!(config.ssl_mode, SslMode::Prefer);
        assert_eq!(config.connect_timeout, Duration::from_secs(30));
        assert_eq!(config.url, "postgresql://postgres@localhost:5432/mydb");
    }

    #[test]
    fn test_config_builder_requires_database() {
        assert!(PgConfig::builder().host("localhost").build().is_err());
    }

    #[test]
    fn test_config_builder_url_overrides() {
        let config = PgConfig::builder()
            .url("postgresql://alice:secret@db.example.com:5433/app")
            .port(6543)
            .database("other")
            .ssl_mode(SslMode::Require)
            .build()
            .unwrap();
        assert_eq!(config.host, "db.example.com");
        assert_eq!(config.port, 6543);
        assert_eq!(config.database, "other");
        assert_eq!(config.user, "alice");
        assert_eq!(config.password, Some("secret".to_string()));
        assert_eq!(config.ssl_mode, SslMode::Require);
    }
}