prax_query/connection/
options.rs

1//! Connection and pool options.
2
3use std::collections::HashMap;
4use std::time::Duration;
5
/// SSL/TLS mode for connections.
///
/// [`SslMode::as_str`] renders each variant with PostgreSQL's `sslmode` spelling.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// Disable SSL.
    Disable,
    /// Allow SSL but don't require it.
    Allow,
    /// Prefer SSL but allow non-SSL.
    #[default]
    Prefer,
    /// Require SSL.
    Require,
    /// Require SSL and verify the server certificate.
    VerifyCa,
    /// Require SSL and verify the server certificate and hostname.
    VerifyFull,
}

impl SslMode {
    /// Parse an SSL mode from a connection-string value.
    ///
    /// Matching is case-insensitive and ignores surrounding whitespace. Besides the
    /// canonical spellings produced by [`SslMode::as_str`], this accepts:
    /// - boolean-style flags: `false`/`0` → [`SslMode::Disable`],
    ///   `true`/`1` → [`SslMode::Require`];
    /// - underscore forms: `verify_ca`, `verify_full`;
    /// - MySQL `ssl-mode` spellings: `disabled`, `preferred`, `required`,
    ///   `verify_identity` (equivalent to [`SslMode::VerifyFull`]).
    ///
    /// Returns `None` for any other input, including the empty string.
    pub fn parse(s: &str) -> Option<Self> {
        match s.trim().to_lowercase().as_str() {
            "disable" | "disabled" | "false" | "0" => Some(Self::Disable),
            "allow" => Some(Self::Allow),
            "prefer" | "preferred" => Some(Self::Prefer),
            "require" | "required" | "true" | "1" => Some(Self::Require),
            "verify-ca" | "verify_ca" => Some(Self::VerifyCa),
            "verify-full" | "verify_full" | "verify-identity" | "verify_identity" => {
                Some(Self::VerifyFull)
            }
            _ => None,
        }
    }

    /// Convert to the canonical lowercase string; `parse` accepts every value returned here.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Disable => "disable",
            Self::Allow => "allow",
            Self::Prefer => "prefer",
            Self::Require => "require",
            Self::VerifyCa => "verify-ca",
            Self::VerifyFull => "verify-full",
        }
    }
}
50
51/// SSL/TLS configuration.
52#[derive(Debug, Clone, Default)]
53pub struct SslConfig {
54    /// SSL mode.
55    pub mode: SslMode,
56    /// Path to CA certificate.
57    pub ca_cert: Option<String>,
58    /// Path to client certificate.
59    pub client_cert: Option<String>,
60    /// Path to client key.
61    pub client_key: Option<String>,
62    /// Server name for SNI.
63    pub server_name: Option<String>,
64}
65
66impl SslConfig {
67    /// Create a new SSL config.
68    pub fn new(mode: SslMode) -> Self {
69        Self {
70            mode,
71            ..Default::default()
72        }
73    }
74
75    /// Require SSL.
76    pub fn require() -> Self {
77        Self::new(SslMode::Require)
78    }
79
80    /// Set CA certificate path.
81    pub fn with_ca_cert(mut self, path: impl Into<String>) -> Self {
82        self.ca_cert = Some(path.into());
83        self
84    }
85
86    /// Set client certificate path.
87    pub fn with_client_cert(mut self, path: impl Into<String>) -> Self {
88        self.client_cert = Some(path.into());
89        self
90    }
91
92    /// Set client key path.
93    pub fn with_client_key(mut self, path: impl Into<String>) -> Self {
94        self.client_key = Some(path.into());
95        self
96    }
97}
98
99/// Common connection options.
100#[derive(Debug, Clone)]
101pub struct ConnectionOptions {
102    /// Connection timeout.
103    pub connect_timeout: Duration,
104    /// Read timeout.
105    pub read_timeout: Option<Duration>,
106    /// Write timeout.
107    pub write_timeout: Option<Duration>,
108    /// SSL configuration.
109    pub ssl: SslConfig,
110    /// Application name.
111    pub application_name: Option<String>,
112    /// Schema/database to use after connecting.
113    pub schema: Option<String>,
114    /// Additional options as key-value pairs.
115    pub extra: HashMap<String, String>,
116}
117
118impl Default for ConnectionOptions {
119    fn default() -> Self {
120        Self {
121            connect_timeout: Duration::from_secs(30),
122            read_timeout: None,
123            write_timeout: None,
124            ssl: SslConfig::default(),
125            application_name: None,
126            schema: None,
127            extra: HashMap::new(),
128        }
129    }
130}
131
132impl ConnectionOptions {
133    /// Create new connection options.
134    pub fn new() -> Self {
135        Self::default()
136    }
137
138    /// Set connection timeout.
139    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
140        self.connect_timeout = timeout;
141        self
142    }
143
144    /// Set read timeout.
145    pub fn read_timeout(mut self, timeout: Duration) -> Self {
146        self.read_timeout = Some(timeout);
147        self
148    }
149
150    /// Set write timeout.
151    pub fn write_timeout(mut self, timeout: Duration) -> Self {
152        self.write_timeout = Some(timeout);
153        self
154    }
155
156    /// Set SSL mode.
157    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
158        self.ssl.mode = mode;
159        self
160    }
161
162    /// Set SSL configuration.
163    pub fn ssl(mut self, config: SslConfig) -> Self {
164        self.ssl = config;
165        self
166    }
167
168    /// Set application name.
169    pub fn application_name(mut self, name: impl Into<String>) -> Self {
170        self.application_name = Some(name.into());
171        self
172    }
173
174    /// Set schema.
175    pub fn schema(mut self, schema: impl Into<String>) -> Self {
176        self.schema = Some(schema.into());
177        self
178    }
179
180    /// Add extra option.
181    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
182        self.extra.insert(key.into(), value.into());
183        self
184    }
185
186    /// Parse options from URL query parameters.
187    pub fn from_params(params: &HashMap<String, String>) -> Self {
188        let mut opts = Self::default();
189
190        if let Some(timeout) = params.get("connect_timeout") {
191            if let Ok(secs) = timeout.parse::<u64>() {
192                opts.connect_timeout = Duration::from_secs(secs);
193            }
194        }
195
196        if let Some(timeout) = params.get("read_timeout") {
197            if let Ok(secs) = timeout.parse::<u64>() {
198                opts.read_timeout = Some(Duration::from_secs(secs));
199            }
200        }
201
202        if let Some(timeout) = params.get("write_timeout") {
203            if let Ok(secs) = timeout.parse::<u64>() {
204                opts.write_timeout = Some(Duration::from_secs(secs));
205            }
206        }
207
208        if let Some(ssl) = params.get("sslmode").or_else(|| params.get("ssl")) {
209            if let Some(mode) = SslMode::parse(ssl) {
210                opts.ssl.mode = mode;
211            }
212        }
213
214        if let Some(name) = params.get("application_name") {
215            opts.application_name = Some(name.clone());
216        }
217
218        if let Some(schema) = params.get("schema").or_else(|| params.get("search_path")) {
219            opts.schema = Some(schema.clone());
220        }
221
222        // Copy remaining params as extra options
223        for (key, value) in params {
224            if !matches!(
225                key.as_str(),
226                "connect_timeout"
227                    | "read_timeout"
228                    | "write_timeout"
229                    | "sslmode"
230                    | "ssl"
231                    | "application_name"
232                    | "schema"
233                    | "search_path"
234            ) {
235                opts.extra.insert(key.clone(), value.clone());
236            }
237        }
238
239        opts
240    }
241}
242
/// Connection-pool sizing and lifecycle settings.
#[derive(Debug, Clone)]
pub struct PoolOptions {
    /// Upper bound on the number of connections.
    pub max_connections: u32,
    /// Minimum number of connections to keep idle.
    pub min_connections: u32,
    /// Longest time to wait when acquiring a connection.
    pub acquire_timeout: Duration,
    /// Idle time after which a connection is closed (`None` disables the limit).
    pub idle_timeout: Option<Duration>,
    /// Maximum age of a connection (`None` disables the limit).
    pub max_lifetime: Option<Duration>,
    /// Whether connections are tested before being handed out.
    pub test_before_acquire: bool,
}

impl Default for PoolOptions {
    /// 10 max / 1 min connections, 30 s acquire timeout, 10 min idle timeout,
    /// 30 min max lifetime, and pre-acquire testing enabled.
    fn default() -> Self {
        let minutes = |m: u64| Duration::from_secs(m * 60);
        Self {
            max_connections: 10,
            min_connections: 1,
            acquire_timeout: Duration::from_secs(30),
            idle_timeout: Some(minutes(10)),
            max_lifetime: Some(minutes(30)),
            test_before_acquire: true,
        }
    }
}

impl PoolOptions {
    /// Start from the default pool settings.
    pub fn new() -> Self {
        Default::default()
    }

    /// Set the maximum number of connections.
    pub fn max_connections(self, n: u32) -> Self {
        Self {
            max_connections: n,
            ..self
        }
    }

    /// Set the minimum number of idle connections.
    pub fn min_connections(self, n: u32) -> Self {
        Self {
            min_connections: n,
            ..self
        }
    }

    /// Set how long to wait for a connection.
    pub fn acquire_timeout(self, timeout: Duration) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Close connections that stay idle longer than `timeout`.
    pub fn idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: Some(timeout),
            ..self
        }
    }

    /// Remove the idle-time limit.
    pub fn no_idle_timeout(self) -> Self {
        Self {
            idle_timeout: None,
            ..self
        }
    }

    /// Retire connections older than `lifetime`.
    pub fn max_lifetime(self, lifetime: Duration) -> Self {
        Self {
            max_lifetime: Some(lifetime),
            ..self
        }
    }

    /// Remove the connection-age limit.
    pub fn no_max_lifetime(self) -> Self {
        Self {
            max_lifetime: None,
            ..self
        }
    }

    /// Turn pre-acquire connection testing on or off.
    pub fn test_before_acquire(self, enabled: bool) -> Self {
        Self {
            test_before_acquire: enabled,
            ..self
        }
    }
}
327
/// PostgreSQL-specific options.
///
/// `PostgresOptions::new()` and `PostgresOptions::default()` produce the same value:
/// a 100-entry statement cache with prepared statements enabled. (Previously the derived
/// `Default` disabled both, silently diverging from `new()`.)
#[derive(Debug, Clone)]
pub struct PostgresOptions {
    /// Statement cache capacity.
    pub statement_cache_capacity: usize,
    /// Enable prepared statements.
    pub prepared_statements: bool,
    /// Channel binding mode.
    pub channel_binding: Option<String>,
    /// Target session attributes.
    pub target_session_attrs: Option<String>,
}

impl Default for PostgresOptions {
    fn default() -> Self {
        Self {
            statement_cache_capacity: 100,
            prepared_statements: true,
            channel_binding: None,
            target_session_attrs: None,
        }
    }
}

impl PostgresOptions {
    /// Create new PostgreSQL options (same as `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set statement cache capacity.
    pub fn statement_cache(mut self, capacity: usize) -> Self {
        self.statement_cache_capacity = capacity;
        self
    }

    /// Enable/disable prepared statements.
    pub fn prepared_statements(mut self, enabled: bool) -> Self {
        self.prepared_statements = enabled;
        self
    }

    /// Set the channel binding mode (stored as given, not validated here).
    pub fn channel_binding(mut self, mode: impl Into<String>) -> Self {
        self.channel_binding = Some(mode.into());
        self
    }

    /// Set the target session attributes (stored as given, not validated here).
    pub fn target_session_attrs(mut self, attrs: impl Into<String>) -> Self {
        self.target_session_attrs = Some(attrs.into());
        self
    }
}
364
/// MySQL-specific options.
///
/// By default compression is off and every string setting is unset.
#[derive(Debug, Clone, Default)]
pub struct MySqlOptions {
    /// Enable compression.
    pub compression: bool,
    /// Character set.
    pub charset: Option<String>,
    /// Collation.
    pub collation: Option<String>,
    /// SQL mode.
    pub sql_mode: Option<String>,
    /// Timezone.
    pub timezone: Option<String>,
}

impl MySqlOptions {
    /// Create new MySQL options (same as `Default::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable/disable compression.
    pub fn compression(mut self, enabled: bool) -> Self {
        self.compression = enabled;
        self
    }

    /// Set character set.
    pub fn charset(mut self, charset: impl Into<String>) -> Self {
        self.charset = Some(charset.into());
        self
    }

    /// Set collation.
    pub fn collation(mut self, collation: impl Into<String>) -> Self {
        self.collation = Some(collation.into());
        self
    }

    /// Set SQL mode.
    pub fn sql_mode(mut self, mode: impl Into<String>) -> Self {
        self.sql_mode = Some(mode.into());
        self
    }

    /// Set timezone (stored as given, not validated here).
    pub fn timezone(mut self, timezone: impl Into<String>) -> Self {
        self.timezone = Some(timezone.into());
        self
    }
}
404
405/// SQLite-specific options.
406#[derive(Debug, Clone)]
407pub struct SqliteOptions {
408    /// Journal mode.
409    pub journal_mode: SqliteJournalMode,
410    /// Synchronous mode.
411    pub synchronous: SqliteSynchronous,
412    /// Foreign keys enforcement.
413    pub foreign_keys: bool,
414    /// Busy timeout in milliseconds.
415    pub busy_timeout: u32,
416    /// Cache size in pages (negative for KB).
417    pub cache_size: i32,
418}
419
420impl Default for SqliteOptions {
421    fn default() -> Self {
422        Self {
423            journal_mode: SqliteJournalMode::Wal,
424            synchronous: SqliteSynchronous::Normal,
425            foreign_keys: true,
426            busy_timeout: 5000,
427            cache_size: -2000, // 2MB
428        }
429    }
430}
431
432impl SqliteOptions {
433    /// Create new SQLite options.
434    pub fn new() -> Self {
435        Self::default()
436    }
437
438    /// Set journal mode.
439    pub fn journal_mode(mut self, mode: SqliteJournalMode) -> Self {
440        self.journal_mode = mode;
441        self
442    }
443
444    /// Set synchronous mode.
445    pub fn synchronous(mut self, mode: SqliteSynchronous) -> Self {
446        self.synchronous = mode;
447        self
448    }
449
450    /// Enable/disable foreign keys.
451    pub fn foreign_keys(mut self, enabled: bool) -> Self {
452        self.foreign_keys = enabled;
453        self
454    }
455
456    /// Set busy timeout in milliseconds.
457    pub fn busy_timeout(mut self, ms: u32) -> Self {
458        self.busy_timeout = ms;
459        self
460    }
461
462    /// Generate PRAGMA statements.
463    pub fn to_pragmas(&self) -> Vec<String> {
464        vec![
465            format!("PRAGMA journal_mode = {};", self.journal_mode.as_str()),
466            format!("PRAGMA synchronous = {};", self.synchronous.as_str()),
467            format!(
468                "PRAGMA foreign_keys = {};",
469                if self.foreign_keys { "ON" } else { "OFF" }
470            ),
471            format!("PRAGMA busy_timeout = {};", self.busy_timeout),
472            format!("PRAGMA cache_size = {};", self.cache_size),
473        ]
474    }
475}
476
/// SQLite journal mode, rendered as the `PRAGMA journal_mode` keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteJournalMode {
    /// `DELETE`: delete the journal after each transaction.
    Delete,
    /// `TRUNCATE`: truncate the journal instead of deleting it.
    Truncate,
    /// `PERSIST`: keep the journal file around.
    Persist,
    /// `MEMORY`: keep the journal in memory.
    Memory,
    /// `WAL`: write-ahead logging (recommended; the default).
    #[default]
    Wal,
    /// `OFF`: no journaling.
    Off,
}

impl SqliteJournalMode {
    /// Uppercase keyword for this mode, as written in `PRAGMA journal_mode = …`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SqliteJournalMode::Wal => "WAL",
            SqliteJournalMode::Delete => "DELETE",
            SqliteJournalMode::Truncate => "TRUNCATE",
            SqliteJournalMode::Persist => "PERSIST",
            SqliteJournalMode::Memory => "MEMORY",
            SqliteJournalMode::Off => "OFF",
        }
    }
}
508
/// SQLite synchronous level, rendered as the `PRAGMA synchronous` keyword.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteSynchronous {
    /// `OFF`: no synchronization.
    Off,
    /// `NORMAL`: normal synchronization (the default).
    #[default]
    Normal,
    /// `FULL`: full synchronization.
    Full,
    /// `EXTRA`: extra synchronization.
    Extra,
}

impl SqliteSynchronous {
    /// Uppercase keyword for this level, as written in `PRAGMA synchronous = …`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SqliteSynchronous::Normal => "NORMAL",
            SqliteSynchronous::Full => "FULL",
            SqliteSynchronous::Extra => "EXTRA",
            SqliteSynchronous::Off => "OFF",
        }
    }
}
534
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an owned parameter map from string pairs.
    fn param_map(pairs: &[(&str, &str)]) -> HashMap<String, String> {
        pairs
            .iter()
            .map(|(k, v)| (k.to_string(), v.to_string()))
            .collect()
    }

    #[test]
    fn test_ssl_mode_parse() {
        assert_eq!(SslMode::parse("disable"), Some(SslMode::Disable));
        assert_eq!(SslMode::parse("require"), Some(SslMode::Require));
        assert_eq!(SslMode::parse("verify-full"), Some(SslMode::VerifyFull));
        assert_eq!(SslMode::parse("invalid"), None);
    }

    #[test]
    fn test_ssl_mode_parse_aliases_and_case() {
        assert_eq!(SslMode::parse("DISABLE"), Some(SslMode::Disable));
        assert_eq!(SslMode::parse("false"), Some(SslMode::Disable));
        assert_eq!(SslMode::parse("0"), Some(SslMode::Disable));
        assert_eq!(SslMode::parse("true"), Some(SslMode::Require));
        assert_eq!(SslMode::parse("1"), Some(SslMode::Require));
        assert_eq!(SslMode::parse("Verify_CA"), Some(SslMode::VerifyCa));
        assert_eq!(SslMode::parse("verify_full"), Some(SslMode::VerifyFull));
        assert_eq!(SslMode::parse(""), None);
    }

    #[test]
    fn test_ssl_mode_round_trip() {
        let modes = [
            SslMode::Disable,
            SslMode::Allow,
            SslMode::Prefer,
            SslMode::Require,
            SslMode::VerifyCa,
            SslMode::VerifyFull,
        ];
        for mode in modes.iter() {
            assert_eq!(SslMode::parse(mode.as_str()), Some(*mode));
        }
    }

    #[test]
    fn test_connection_options_builder() {
        let opts = ConnectionOptions::new()
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Require)
            .application_name("test-app");

        assert_eq!(opts.connect_timeout, Duration::from_secs(10));
        assert_eq!(opts.ssl.mode, SslMode::Require);
        assert_eq!(opts.application_name, Some("test-app".to_string()));
    }

    #[test]
    fn test_pool_options_builder() {
        let opts = PoolOptions::new()
            .max_connections(20)
            .min_connections(5)
            .no_idle_timeout();

        assert_eq!(opts.max_connections, 20);
        assert_eq!(opts.min_connections, 5);
        assert_eq!(opts.idle_timeout, None);
    }

    #[test]
    fn test_pool_options_defaults() {
        let opts = PoolOptions::default();
        assert_eq!(opts.max_connections, 10);
        assert_eq!(opts.min_connections, 1);
        assert_eq!(opts.acquire_timeout, Duration::from_secs(30));
        assert_eq!(opts.idle_timeout, Some(Duration::from_secs(600)));
        assert_eq!(opts.max_lifetime, Some(Duration::from_secs(1800)));
        assert!(opts.test_before_acquire);
    }

    #[test]
    fn test_sqlite_options_pragmas() {
        let opts = SqliteOptions::new()
            .journal_mode(SqliteJournalMode::Wal)
            .foreign_keys(true);

        let pragmas = opts.to_pragmas();
        assert!(pragmas.iter().any(|p| p.contains("journal_mode = WAL")));
        assert!(pragmas.iter().any(|p| p.contains("foreign_keys = ON")));
    }

    #[test]
    fn test_sqlite_default_pragmas_exact() {
        let pragmas = SqliteOptions::new()
            .journal_mode(SqliteJournalMode::Delete)
            .foreign_keys(false)
            .busy_timeout(1000)
            .to_pragmas();
        assert_eq!(
            pragmas,
            vec![
                "PRAGMA journal_mode = DELETE;".to_string(),
                "PRAGMA synchronous = NORMAL;".to_string(),
                "PRAGMA foreign_keys = OFF;".to_string(),
                "PRAGMA busy_timeout = 1000;".to_string(),
                "PRAGMA cache_size = -2000;".to_string(),
            ]
        );
    }

    #[test]
    fn test_options_from_params() {
        let mut params = HashMap::new();
        params.insert("connect_timeout".to_string(), "10".to_string());
        params.insert("sslmode".to_string(), "require".to_string());
        params.insert("application_name".to_string(), "myapp".to_string());

        let opts = ConnectionOptions::from_params(&params);
        assert_eq!(opts.connect_timeout, Duration::from_secs(10));
        assert_eq!(opts.ssl.mode, SslMode::Require);
        assert_eq!(opts.application_name, Some("myapp".to_string()));
        assert!(opts.extra.is_empty());
    }

    #[test]
    fn test_options_from_params_aliases_and_extra() {
        let opts = ConnectionOptions::from_params(&param_map(&[
            ("ssl", "disable"),
            ("search_path", "app"),
            ("read_timeout", "5"),
            ("statement_timeout", "100"),
        ]));
        assert_eq!(opts.ssl.mode, SslMode::Disable);
        assert_eq!(opts.schema, Some("app".to_string()));
        assert_eq!(opts.read_timeout, Some(Duration::from_secs(5)));
        assert_eq!(opts.write_timeout, None);
        assert_eq!(opts.extra.len(), 1);
        assert_eq!(opts.extra.get("statement_timeout"), Some(&"100".to_string()));
    }

    #[test]
    fn test_options_from_params_schema_wins_over_search_path() {
        let opts = ConnectionOptions::from_params(&param_map(&[
            ("schema", "main"),
            ("search_path", "other"),
        ]));
        assert_eq!(opts.schema, Some("main".to_string()));
        assert!(opts.extra.is_empty());
    }

    #[test]
    fn test_options_from_params_invalid_values_keep_defaults() {
        let opts = ConnectionOptions::from_params(&param_map(&[
            ("connect_timeout", "soon"),
            ("sslmode", "bogus"),
        ]));
        assert_eq!(opts.connect_timeout, Duration::from_secs(30));
        assert_eq!(opts.ssl.mode, SslMode::Prefer);
        // Recognized keys are never forwarded, even when their value is rejected.
        assert!(opts.extra.is_empty());
    }
}