prax_query/connection/
options.rs

1//! Connection and pool options.
2
3use super::{ConnectionError, ConnectionResult, ConnectionString, Driver};
4use std::collections::HashMap;
5use std::time::Duration;
6
/// TLS negotiation policy for a connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// SSL is turned off.
    Disable,
    /// SSL is permitted but not mandatory.
    Allow,
    /// SSL is preferred, with non-SSL still accepted (the default mode).
    #[default]
    Prefer,
    /// SSL is mandatory.
    Require,
    /// SSL is mandatory and the server certificate is verified.
    VerifyCa,
    /// SSL is mandatory and both the server certificate and hostname are verified.
    VerifyFull,
}

impl SslMode {
    /// Looks up the mode named by `s`, comparing case-insensitively.
    ///
    /// Accepted spellings are `disable`/`false`/`0`, `allow`, `prefer`,
    /// `require`/`true`/`1`, `verify-ca`/`verify_ca` and
    /// `verify-full`/`verify_full`. Any other input yields `None`.
    pub fn from_str(s: &str) -> Option<Self> {
        let lowered = s.to_lowercase();
        let mode = match lowered.as_str() {
            "disable" | "false" | "0" => Self::Disable,
            "allow" => Self::Allow,
            "prefer" => Self::Prefer,
            "require" | "true" | "1" => Self::Require,
            "verify-ca" | "verify_ca" => Self::VerifyCa,
            "verify-full" | "verify_full" => Self::VerifyFull,
            // Unrecognized names are reported as absent rather than defaulted.
            _ => return None,
        };
        Some(mode)
    }

    /// Returns the canonical lowercase, hyphenated name of this mode.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SslMode::Disable => "disable",
            SslMode::Allow => "allow",
            SslMode::Prefer => "prefer",
            SslMode::Require => "require",
            SslMode::VerifyCa => "verify-ca",
            SslMode::VerifyFull => "verify-full",
        }
    }
}
51
52/// SSL/TLS configuration.
53#[derive(Debug, Clone, Default)]
54pub struct SslConfig {
55    /// SSL mode.
56    pub mode: SslMode,
57    /// Path to CA certificate.
58    pub ca_cert: Option<String>,
59    /// Path to client certificate.
60    pub client_cert: Option<String>,
61    /// Path to client key.
62    pub client_key: Option<String>,
63    /// Server name for SNI.
64    pub server_name: Option<String>,
65}
66
67impl SslConfig {
68    /// Create a new SSL config.
69    pub fn new(mode: SslMode) -> Self {
70        Self {
71            mode,
72            ..Default::default()
73        }
74    }
75
76    /// Require SSL.
77    pub fn require() -> Self {
78        Self::new(SslMode::Require)
79    }
80
81    /// Set CA certificate path.
82    pub fn with_ca_cert(mut self, path: impl Into<String>) -> Self {
83        self.ca_cert = Some(path.into());
84        self
85    }
86
87    /// Set client certificate path.
88    pub fn with_client_cert(mut self, path: impl Into<String>) -> Self {
89        self.client_cert = Some(path.into());
90        self
91    }
92
93    /// Set client key path.
94    pub fn with_client_key(mut self, path: impl Into<String>) -> Self {
95        self.client_key = Some(path.into());
96        self
97    }
98}
99
100/// Common connection options.
101#[derive(Debug, Clone)]
102pub struct ConnectionOptions {
103    /// Connection timeout.
104    pub connect_timeout: Duration,
105    /// Read timeout.
106    pub read_timeout: Option<Duration>,
107    /// Write timeout.
108    pub write_timeout: Option<Duration>,
109    /// SSL configuration.
110    pub ssl: SslConfig,
111    /// Application name.
112    pub application_name: Option<String>,
113    /// Schema/database to use after connecting.
114    pub schema: Option<String>,
115    /// Additional options as key-value pairs.
116    pub extra: HashMap<String, String>,
117}
118
119impl Default for ConnectionOptions {
120    fn default() -> Self {
121        Self {
122            connect_timeout: Duration::from_secs(30),
123            read_timeout: None,
124            write_timeout: None,
125            ssl: SslConfig::default(),
126            application_name: None,
127            schema: None,
128            extra: HashMap::new(),
129        }
130    }
131}
132
133impl ConnectionOptions {
134    /// Create new connection options.
135    pub fn new() -> Self {
136        Self::default()
137    }
138
139    /// Set connection timeout.
140    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
141        self.connect_timeout = timeout;
142        self
143    }
144
145    /// Set read timeout.
146    pub fn read_timeout(mut self, timeout: Duration) -> Self {
147        self.read_timeout = Some(timeout);
148        self
149    }
150
151    /// Set write timeout.
152    pub fn write_timeout(mut self, timeout: Duration) -> Self {
153        self.write_timeout = Some(timeout);
154        self
155    }
156
157    /// Set SSL mode.
158    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
159        self.ssl.mode = mode;
160        self
161    }
162
163    /// Set SSL configuration.
164    pub fn ssl(mut self, config: SslConfig) -> Self {
165        self.ssl = config;
166        self
167    }
168
169    /// Set application name.
170    pub fn application_name(mut self, name: impl Into<String>) -> Self {
171        self.application_name = Some(name.into());
172        self
173    }
174
175    /// Set schema.
176    pub fn schema(mut self, schema: impl Into<String>) -> Self {
177        self.schema = Some(schema.into());
178        self
179    }
180
181    /// Add extra option.
182    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
183        self.extra.insert(key.into(), value.into());
184        self
185    }
186
187    /// Parse options from URL query parameters.
188    pub fn from_params(params: &HashMap<String, String>) -> Self {
189        let mut opts = Self::default();
190
191        if let Some(timeout) = params.get("connect_timeout") {
192            if let Ok(secs) = timeout.parse::<u64>() {
193                opts.connect_timeout = Duration::from_secs(secs);
194            }
195        }
196
197        if let Some(timeout) = params.get("read_timeout") {
198            if let Ok(secs) = timeout.parse::<u64>() {
199                opts.read_timeout = Some(Duration::from_secs(secs));
200            }
201        }
202
203        if let Some(timeout) = params.get("write_timeout") {
204            if let Ok(secs) = timeout.parse::<u64>() {
205                opts.write_timeout = Some(Duration::from_secs(secs));
206            }
207        }
208
209        if let Some(ssl) = params.get("sslmode").or_else(|| params.get("ssl")) {
210            if let Some(mode) = SslMode::from_str(ssl) {
211                opts.ssl.mode = mode;
212            }
213        }
214
215        if let Some(name) = params.get("application_name") {
216            opts.application_name = Some(name.clone());
217        }
218
219        if let Some(schema) = params.get("schema").or_else(|| params.get("search_path")) {
220            opts.schema = Some(schema.clone());
221        }
222
223        // Copy remaining params as extra options
224        for (key, value) in params {
225            if !matches!(
226                key.as_str(),
227                "connect_timeout"
228                    | "read_timeout"
229                    | "write_timeout"
230                    | "sslmode"
231                    | "ssl"
232                    | "application_name"
233                    | "schema"
234                    | "search_path"
235            ) {
236                opts.extra.insert(key.clone(), value.clone());
237            }
238        }
239
240        opts
241    }
242}
243
/// Settings for a connection pool.
#[derive(Debug, Clone)]
pub struct PoolOptions {
    /// Maximum number of connections.
    pub max_connections: u32,
    /// Minimum number of connections to keep idle.
    pub min_connections: u32,
    /// Maximum time to wait for a connection.
    pub acquire_timeout: Duration,
    /// Maximum idle time before a connection is closed; `None` disables it.
    pub idle_timeout: Option<Duration>,
    /// Maximum lifetime of a connection; `None` disables it.
    pub max_lifetime: Option<Duration>,
    /// Whether connections are tested before being returned.
    pub test_before_acquire: bool,
}

impl Default for PoolOptions {
    /// 1 to 10 connections, a 30-second acquire timeout, a 10-minute idle
    /// timeout, a 30-minute maximum lifetime, and testing before acquire
    /// switched on.
    fn default() -> Self {
        let idle = Duration::from_secs(10 * 60);
        let lifetime = Duration::from_secs(30 * 60);
        Self {
            max_connections: 10,
            min_connections: 1,
            acquire_timeout: Duration::from_secs(30),
            idle_timeout: Some(idle),
            max_lifetime: Some(lifetime),
            test_before_acquire: true,
        }
    }
}

impl PoolOptions {
    /// Starts from the default pool options.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the options with the maximum connection count set to `n`.
    pub fn max_connections(self, n: u32) -> Self {
        Self {
            max_connections: n,
            ..self
        }
    }

    /// Returns the options with the minimum connection count set to `n`.
    pub fn min_connections(self, n: u32) -> Self {
        Self {
            min_connections: n,
            ..self
        }
    }

    /// Returns the options with the acquire timeout replaced.
    pub fn acquire_timeout(self, timeout: Duration) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Returns the options with an idle timeout set.
    pub fn idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: Some(timeout),
            ..self
        }
    }

    /// Returns the options with the idle timeout cleared.
    pub fn no_idle_timeout(self) -> Self {
        Self {
            idle_timeout: None,
            ..self
        }
    }

    /// Returns the options with a maximum lifetime set.
    pub fn max_lifetime(self, lifetime: Duration) -> Self {
        Self {
            max_lifetime: Some(lifetime),
            ..self
        }
    }

    /// Returns the options with the maximum lifetime cleared.
    pub fn no_max_lifetime(self) -> Self {
        Self {
            max_lifetime: None,
            ..self
        }
    }

    /// Returns the options with test-before-acquire switched on or off.
    pub fn test_before_acquire(self, enabled: bool) -> Self {
        Self {
            test_before_acquire: enabled,
            ..self
        }
    }
}
328
/// PostgreSQL-specific options.
#[derive(Debug, Clone)]
pub struct PostgresOptions {
    /// Statement cache capacity.
    pub statement_cache_capacity: usize,
    /// Enable prepared statements.
    pub prepared_statements: bool,
    /// Channel binding mode.
    pub channel_binding: Option<String>,
    /// Target session attributes.
    pub target_session_attrs: Option<String>,
}

impl Default for PostgresOptions {
    /// A statement cache of 100 entries, prepared statements enabled, and no
    /// channel binding or target session attributes.
    ///
    /// Written by hand rather than derived: a derived `Default` would zero
    /// the cache capacity and disable prepared statements, disagreeing with
    /// [`PostgresOptions::new`].
    fn default() -> Self {
        Self {
            statement_cache_capacity: 100,
            prepared_statements: true,
            channel_binding: None,
            target_session_attrs: None,
        }
    }
}

impl PostgresOptions {
    /// Create new PostgreSQL options; identical to `PostgresOptions::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Set statement cache capacity.
    pub fn statement_cache(mut self, capacity: usize) -> Self {
        self.statement_cache_capacity = capacity;
        self
    }

    /// Enable/disable prepared statements.
    pub fn prepared_statements(mut self, enabled: bool) -> Self {
        self.prepared_statements = enabled;
        self
    }
}
365
/// Settings that apply only to MySQL connections.
#[derive(Debug, Clone, Default)]
pub struct MySqlOptions {
    /// Whether compression is enabled.
    pub compression: bool,
    /// Character set, if one is set.
    pub charset: Option<String>,
    /// Collation, if one is set.
    pub collation: Option<String>,
    /// SQL mode, if one is set.
    pub sql_mode: Option<String>,
    /// Timezone, if one is set.
    pub timezone: Option<String>,
}

impl MySqlOptions {
    /// Starts from the default MySQL options (compression off, nothing else set).
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns the options with compression switched on or off.
    pub fn compression(self, enabled: bool) -> Self {
        Self {
            compression: enabled,
            ..self
        }
    }

    /// Returns the options with the character set replaced.
    pub fn charset(self, charset: impl Into<String>) -> Self {
        Self {
            charset: Some(charset.into()),
            ..self
        }
    }

    /// Returns the options with the SQL mode replaced.
    pub fn sql_mode(self, mode: impl Into<String>) -> Self {
        Self {
            sql_mode: Some(mode.into()),
            ..self
        }
    }
}
405
406/// SQLite-specific options.
407#[derive(Debug, Clone)]
408pub struct SqliteOptions {
409    /// Journal mode.
410    pub journal_mode: SqliteJournalMode,
411    /// Synchronous mode.
412    pub synchronous: SqliteSynchronous,
413    /// Foreign keys enforcement.
414    pub foreign_keys: bool,
415    /// Busy timeout in milliseconds.
416    pub busy_timeout: u32,
417    /// Cache size in pages (negative for KB).
418    pub cache_size: i32,
419}
420
421impl Default for SqliteOptions {
422    fn default() -> Self {
423        Self {
424            journal_mode: SqliteJournalMode::Wal,
425            synchronous: SqliteSynchronous::Normal,
426            foreign_keys: true,
427            busy_timeout: 5000,
428            cache_size: -2000, // 2MB
429        }
430    }
431}
432
433impl SqliteOptions {
434    /// Create new SQLite options.
435    pub fn new() -> Self {
436        Self::default()
437    }
438
439    /// Set journal mode.
440    pub fn journal_mode(mut self, mode: SqliteJournalMode) -> Self {
441        self.journal_mode = mode;
442        self
443    }
444
445    /// Set synchronous mode.
446    pub fn synchronous(mut self, mode: SqliteSynchronous) -> Self {
447        self.synchronous = mode;
448        self
449    }
450
451    /// Enable/disable foreign keys.
452    pub fn foreign_keys(mut self, enabled: bool) -> Self {
453        self.foreign_keys = enabled;
454        self
455    }
456
457    /// Set busy timeout in milliseconds.
458    pub fn busy_timeout(mut self, ms: u32) -> Self {
459        self.busy_timeout = ms;
460        self
461    }
462
463    /// Generate PRAGMA statements.
464    pub fn to_pragmas(&self) -> Vec<String> {
465        vec![
466            format!("PRAGMA journal_mode = {};", self.journal_mode.as_str()),
467            format!("PRAGMA synchronous = {};", self.synchronous.as_str()),
468            format!(
469                "PRAGMA foreign_keys = {};",
470                if self.foreign_keys { "ON" } else { "OFF" }
471            ),
472            format!("PRAGMA busy_timeout = {};", self.busy_timeout),
473            format!("PRAGMA cache_size = {};", self.cache_size),
474        ]
475    }
476}
477
/// Journal mode for a SQLite database.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteJournalMode {
    /// The journal is deleted after each transaction.
    Delete,
    /// The journal is truncated.
    Truncate,
    /// The journal is persisted.
    Persist,
    /// The journal is kept in memory.
    Memory,
    /// Write-ahead logging (recommended, and the default).
    #[default]
    Wal,
    /// Journaling is disabled.
    Off,
}

impl SqliteJournalMode {
    /// Returns the uppercase keyword for this mode as used in SQL.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SqliteJournalMode::Delete => "DELETE",
            SqliteJournalMode::Truncate => "TRUNCATE",
            SqliteJournalMode::Persist => "PERSIST",
            SqliteJournalMode::Memory => "MEMORY",
            SqliteJournalMode::Wal => "WAL",
            SqliteJournalMode::Off => "OFF",
        }
    }
}
509
/// Synchronous mode for a SQLite database.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteSynchronous {
    /// No synchronization.
    Off,
    /// Normal synchronization (the default).
    #[default]
    Normal,
    /// Full synchronization.
    Full,
    /// Extra synchronization.
    Extra,
}

impl SqliteSynchronous {
    /// Returns the uppercase keyword for this mode as used in SQL.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SqliteSynchronous::Off => "OFF",
            SqliteSynchronous::Normal => "NORMAL",
            SqliteSynchronous::Full => "FULL",
            SqliteSynchronous::Extra => "EXTRA",
        }
    }
}
535
#[cfg(test)]
mod tests {
    use super::*;

    /// Known mode names parse to their variants; an unknown name is rejected.
    #[test]
    fn test_ssl_mode_parse() {
        let cases = [
            ("disable", Some(SslMode::Disable)),
            ("require", Some(SslMode::Require)),
            ("verify-full", Some(SslMode::VerifyFull)),
            ("invalid", None),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(SslMode::from_str(input), *expected);
        }
    }

    /// Chained builder calls land in the corresponding fields.
    #[test]
    fn test_connection_options_builder() {
        let ten_seconds = Duration::from_secs(10);
        let built = ConnectionOptions::new()
            .connect_timeout(ten_seconds)
            .ssl_mode(SslMode::Require)
            .application_name("test-app");

        assert_eq!(built.connect_timeout, ten_seconds);
        assert_eq!(built.ssl.mode, SslMode::Require);
        assert_eq!(built.application_name, Some(String::from("test-app")));
    }

    /// Connection limits are stored and the idle timeout can be cleared.
    #[test]
    fn test_pool_options_builder() {
        let built = PoolOptions::new()
            .max_connections(20)
            .min_connections(5)
            .no_idle_timeout();

        assert_eq!(built.max_connections, 20);
        assert_eq!(built.min_connections, 5);
        assert_eq!(built.idle_timeout, None);
    }

    /// The generated PRAGMA list reflects the journal mode and foreign-key flag.
    #[test]
    fn test_sqlite_options_pragmas() {
        let pragmas = SqliteOptions::new()
            .journal_mode(SqliteJournalMode::Wal)
            .foreign_keys(true)
            .to_pragmas();

        let has = |needle: &str| pragmas.iter().any(|p| p.contains(needle));
        assert!(has("journal_mode = WAL"));
        assert!(has("foreign_keys = ON"));
    }

    /// Recognized query parameters populate the typed fields.
    #[test]
    fn test_options_from_params() {
        let params: HashMap<String, String> = [
            ("connect_timeout", "10"),
            ("sslmode", "require"),
            ("application_name", "myapp"),
        ]
        .iter()
        .map(|(key, value)| (key.to_string(), value.to_string()))
        .collect();

        let parsed = ConnectionOptions::from_params(&params);
        assert_eq!(parsed.connect_timeout, Duration::from_secs(10));
        assert_eq!(parsed.ssl.mode, SslMode::Require);
        assert_eq!(parsed.application_name, Some(String::from("myapp")));
    }
}