prax_query/connection/
options.rs

1//! Connection and pool options.
2
3use super::{ConnectionError, ConnectionResult, ConnectionString, Driver};
4use std::collections::HashMap;
5use std::time::Duration;
6
/// SSL/TLS mode for connections.
///
/// Variant names follow PostgreSQL's `sslmode` vocabulary; [`SslMode::from_str`] also
/// understands boolean-style shortcuts and MySQL's `ssl-mode` names.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// Disable SSL.
    Disable,
    /// Allow SSL but don't require it.
    Allow,
    /// Prefer SSL but allow non-SSL.
    #[default]
    Prefer,
    /// Require SSL.
    Require,
    /// Require SSL and verify the server certificate.
    VerifyCa,
    /// Require SSL and verify the server certificate and hostname.
    VerifyFull,
}

impl SslMode {
    /// Parse an SSL mode from its textual form.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
    /// Accepted spellings:
    ///
    /// - `disable`, `disabled`, `false`, `0` → [`SslMode::Disable`]
    /// - `allow` → [`SslMode::Allow`]
    /// - `prefer`, `preferred` → [`SslMode::Prefer`]
    /// - `require`, `required`, `true`, `1` → [`SslMode::Require`]
    /// - `verify-ca`, `verify_ca` → [`SslMode::VerifyCa`]
    /// - `verify-full`, `verify_full`, `verify-identity`, `verify_identity`
    ///   → [`SslMode::VerifyFull`]
    ///
    /// The `*ed` forms and `verify_identity` are MySQL's `ssl-mode` names.
    ///
    /// Returns `None` for any other input.
    pub fn from_str(s: &str) -> Option<Self> {
        // `Self` cannot be named inside a nested item, so the enum is spelled out.
        const ALIASES: &[(&str, SslMode)] = &[
            ("disable", SslMode::Disable),
            ("disabled", SslMode::Disable),
            ("false", SslMode::Disable),
            ("0", SslMode::Disable),
            ("allow", SslMode::Allow),
            ("prefer", SslMode::Prefer),
            ("preferred", SslMode::Prefer),
            ("require", SslMode::Require),
            ("required", SslMode::Require),
            ("true", SslMode::Require),
            ("1", SslMode::Require),
            ("verify-ca", SslMode::VerifyCa),
            ("verify_ca", SslMode::VerifyCa),
            ("verify-full", SslMode::VerifyFull),
            ("verify_full", SslMode::VerifyFull),
            ("verify-identity", SslMode::VerifyFull),
            ("verify_identity", SslMode::VerifyFull),
        ];

        let wanted = s.trim();
        // Every alias is ASCII, so an ASCII case-insensitive compare matches what
        // `to_lowercase()` did, without allocating.
        ALIASES
            .iter()
            .find(|(alias, _)| alias.eq_ignore_ascii_case(wanted))
            .map(|&(_, mode)| mode)
    }

    /// Canonical (PostgreSQL-style) string for this mode.
    pub fn as_str(&self) -> &'static str {
        match self {
            Self::Disable => "disable",
            Self::Allow => "allow",
            Self::Prefer => "prefer",
            Self::Require => "require",
            Self::VerifyCa => "verify-ca",
            Self::VerifyFull => "verify-full",
        }
    }
}
51
52/// SSL/TLS configuration.
53#[derive(Debug, Clone, Default)]
54pub struct SslConfig {
55    /// SSL mode.
56    pub mode: SslMode,
57    /// Path to CA certificate.
58    pub ca_cert: Option<String>,
59    /// Path to client certificate.
60    pub client_cert: Option<String>,
61    /// Path to client key.
62    pub client_key: Option<String>,
63    /// Server name for SNI.
64    pub server_name: Option<String>,
65}
66
67impl SslConfig {
68    /// Create a new SSL config.
69    pub fn new(mode: SslMode) -> Self {
70        Self {
71            mode,
72            ..Default::default()
73        }
74    }
75
76    /// Require SSL.
77    pub fn require() -> Self {
78        Self::new(SslMode::Require)
79    }
80
81    /// Set CA certificate path.
82    pub fn with_ca_cert(mut self, path: impl Into<String>) -> Self {
83        self.ca_cert = Some(path.into());
84        self
85    }
86
87    /// Set client certificate path.
88    pub fn with_client_cert(mut self, path: impl Into<String>) -> Self {
89        self.client_cert = Some(path.into());
90        self
91    }
92
93    /// Set client key path.
94    pub fn with_client_key(mut self, path: impl Into<String>) -> Self {
95        self.client_key = Some(path.into());
96        self
97    }
98}
99
100/// Common connection options.
101#[derive(Debug, Clone)]
102pub struct ConnectionOptions {
103    /// Connection timeout.
104    pub connect_timeout: Duration,
105    /// Read timeout.
106    pub read_timeout: Option<Duration>,
107    /// Write timeout.
108    pub write_timeout: Option<Duration>,
109    /// SSL configuration.
110    pub ssl: SslConfig,
111    /// Application name.
112    pub application_name: Option<String>,
113    /// Schema/database to use after connecting.
114    pub schema: Option<String>,
115    /// Additional options as key-value pairs.
116    pub extra: HashMap<String, String>,
117}
118
119impl Default for ConnectionOptions {
120    fn default() -> Self {
121        Self {
122            connect_timeout: Duration::from_secs(30),
123            read_timeout: None,
124            write_timeout: None,
125            ssl: SslConfig::default(),
126            application_name: None,
127            schema: None,
128            extra: HashMap::new(),
129        }
130    }
131}
132
133impl ConnectionOptions {
134    /// Create new connection options.
135    pub fn new() -> Self {
136        Self::default()
137    }
138
139    /// Set connection timeout.
140    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
141        self.connect_timeout = timeout;
142        self
143    }
144
145    /// Set read timeout.
146    pub fn read_timeout(mut self, timeout: Duration) -> Self {
147        self.read_timeout = Some(timeout);
148        self
149    }
150
151    /// Set write timeout.
152    pub fn write_timeout(mut self, timeout: Duration) -> Self {
153        self.write_timeout = Some(timeout);
154        self
155    }
156
157    /// Set SSL mode.
158    pub fn ssl_mode(mut self, mode: SslMode) -> Self {
159        self.ssl.mode = mode;
160        self
161    }
162
163    /// Set SSL configuration.
164    pub fn ssl(mut self, config: SslConfig) -> Self {
165        self.ssl = config;
166        self
167    }
168
169    /// Set application name.
170    pub fn application_name(mut self, name: impl Into<String>) -> Self {
171        self.application_name = Some(name.into());
172        self
173    }
174
175    /// Set schema.
176    pub fn schema(mut self, schema: impl Into<String>) -> Self {
177        self.schema = Some(schema.into());
178        self
179    }
180
181    /// Add extra option.
182    pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
183        self.extra.insert(key.into(), value.into());
184        self
185    }
186
187    /// Parse options from URL query parameters.
188    pub fn from_params(params: &HashMap<String, String>) -> Self {
189        let mut opts = Self::default();
190
191        if let Some(timeout) = params.get("connect_timeout") {
192            if let Ok(secs) = timeout.parse::<u64>() {
193                opts.connect_timeout = Duration::from_secs(secs);
194            }
195        }
196
197        if let Some(timeout) = params.get("read_timeout") {
198            if let Ok(secs) = timeout.parse::<u64>() {
199                opts.read_timeout = Some(Duration::from_secs(secs));
200            }
201        }
202
203        if let Some(timeout) = params.get("write_timeout") {
204            if let Ok(secs) = timeout.parse::<u64>() {
205                opts.write_timeout = Some(Duration::from_secs(secs));
206            }
207        }
208
209        if let Some(ssl) = params.get("sslmode").or_else(|| params.get("ssl")) {
210            if let Some(mode) = SslMode::from_str(ssl) {
211                opts.ssl.mode = mode;
212            }
213        }
214
215        if let Some(name) = params.get("application_name") {
216            opts.application_name = Some(name.clone());
217        }
218
219        if let Some(schema) = params.get("schema").or_else(|| params.get("search_path")) {
220            opts.schema = Some(schema.clone());
221        }
222
223        // Copy remaining params as extra options
224        for (key, value) in params {
225            if !matches!(
226                key.as_str(),
227                "connect_timeout"
228                    | "read_timeout"
229                    | "write_timeout"
230                    | "sslmode"
231                    | "ssl"
232                    | "application_name"
233                    | "schema"
234                    | "search_path"
235            ) {
236                opts.extra.insert(key.clone(), value.clone());
237            }
238        }
239
240        opts
241    }
242}
243
/// Connection pool sizing and lifecycle options.
#[derive(Debug, Clone)]
pub struct PoolOptions {
    /// Maximum number of connections.
    pub max_connections: u32,
    /// Minimum number of connections to keep idle.
    pub min_connections: u32,
    /// Maximum time to wait when acquiring a connection.
    pub acquire_timeout: Duration,
    /// Maximum idle time before closing a connection (`None` = no limit).
    pub idle_timeout: Option<Duration>,
    /// Maximum lifetime of a connection (`None` = no limit).
    pub max_lifetime: Option<Duration>,
    /// Whether to test connections before returning them.
    pub test_before_acquire: bool,
}

impl Default for PoolOptions {
    /// Defaults:
    /// - 10 max / 1 min connections
    /// - 30 s acquire timeout
    /// - 10 min idle timeout
    /// - 30 min max lifetime
    /// - testing before acquire enabled
    fn default() -> Self {
        const MINUTE: u64 = 60;

        Self {
            max_connections: 10,
            min_connections: 1,
            acquire_timeout: Duration::from_secs(30),
            idle_timeout: Some(Duration::from_secs(10 * MINUTE)),
            max_lifetime: Some(Duration::from_secs(30 * MINUTE)),
            test_before_acquire: true,
        }
    }
}

impl PoolOptions {
    /// Create pool options populated with the default values.
    pub fn new() -> Self {
        Default::default()
    }

    /// Set the maximum number of connections.
    pub fn max_connections(self, n: u32) -> Self {
        Self {
            max_connections: n,
            ..self
        }
    }

    /// Set the minimum number of idle connections.
    pub fn min_connections(self, n: u32) -> Self {
        Self {
            min_connections: n,
            ..self
        }
    }

    /// Set the acquire timeout.
    pub fn acquire_timeout(self, timeout: Duration) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Set the idle timeout.
    pub fn idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: Some(timeout),
            ..self
        }
    }

    /// Remove the idle timeout.
    pub fn no_idle_timeout(self) -> Self {
        Self {
            idle_timeout: None,
            ..self
        }
    }

    /// Set the maximum connection lifetime.
    pub fn max_lifetime(self, lifetime: Duration) -> Self {
        Self {
            max_lifetime: Some(lifetime),
            ..self
        }
    }

    /// Remove the maximum connection lifetime.
    pub fn no_max_lifetime(self) -> Self {
        Self {
            max_lifetime: None,
            ..self
        }
    }

    /// Turn testing before acquire on or off.
    pub fn test_before_acquire(self, enabled: bool) -> Self {
        Self {
            test_before_acquire: enabled,
            ..self
        }
    }
}
328
/// PostgreSQL-specific options.
///
/// `Default` and [`PostgresOptions::new`] produce the same values:
/// - statement cache of 100
/// - prepared statements enabled
/// - no channel binding or target session attributes
#[derive(Debug, Clone)]
pub struct PostgresOptions {
    /// Statement cache capacity.
    pub statement_cache_capacity: usize,
    /// Enable prepared statements.
    pub prepared_statements: bool,
    /// Channel binding mode.
    pub channel_binding: Option<String>,
    /// Target session attributes.
    pub target_session_attrs: Option<String>,
}

impl Default for PostgresOptions {
    // Hand-written rather than derived: a derived `Default` gave a cache capacity of 0 and
    // disabled prepared statements, silently diverging from `new()`. It also affected
    // `..Default::default()` struct updates.
    fn default() -> Self {
        Self {
            statement_cache_capacity: 100,
            prepared_statements: true,
            channel_binding: None,
            target_session_attrs: None,
        }
    }
}

impl PostgresOptions {
    /// Create new PostgreSQL options (identical to `PostgresOptions::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// Set statement cache capacity.
    pub fn statement_cache(mut self, capacity: usize) -> Self {
        self.statement_cache_capacity = capacity;
        self
    }

    /// Enable/disable prepared statements.
    pub fn prepared_statements(mut self, enabled: bool) -> Self {
        self.prepared_statements = enabled;
        self
    }

    /// Set the channel binding mode. The value is stored verbatim and not validated here.
    pub fn channel_binding(mut self, mode: impl Into<String>) -> Self {
        self.channel_binding = Some(mode.into());
        self
    }

    /// Set the target session attributes. The value is stored verbatim and not validated here.
    pub fn target_session_attrs(mut self, attrs: impl Into<String>) -> Self {
        self.target_session_attrs = Some(attrs.into());
        self
    }
}
365
/// MySQL-specific options.
///
/// All string options are stored verbatim and are not validated here.
#[derive(Debug, Clone, Default)]
pub struct MySqlOptions {
    /// Enable compression.
    pub compression: bool,
    /// Character set.
    pub charset: Option<String>,
    /// Collation.
    pub collation: Option<String>,
    /// SQL mode.
    pub sql_mode: Option<String>,
    /// Timezone.
    pub timezone: Option<String>,
}

impl MySqlOptions {
    /// Create new MySQL options with every option unset and compression off.
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable compression.
    pub fn compression(mut self, enabled: bool) -> Self {
        self.compression = enabled;
        self
    }

    /// Set character set.
    pub fn charset(mut self, charset: impl Into<String>) -> Self {
        self.charset = Some(charset.into());
        self
    }

    /// Set collation.
    pub fn collation(mut self, collation: impl Into<String>) -> Self {
        self.collation = Some(collation.into());
        self
    }

    /// Set SQL mode.
    pub fn sql_mode(mut self, mode: impl Into<String>) -> Self {
        self.sql_mode = Some(mode.into());
        self
    }

    /// Set timezone.
    pub fn timezone(mut self, timezone: impl Into<String>) -> Self {
        self.timezone = Some(timezone.into());
        self
    }
}
405
406/// SQLite-specific options.
407#[derive(Debug, Clone)]
408pub struct SqliteOptions {
409    /// Journal mode.
410    pub journal_mode: SqliteJournalMode,
411    /// Synchronous mode.
412    pub synchronous: SqliteSynchronous,
413    /// Foreign keys enforcement.
414    pub foreign_keys: bool,
415    /// Busy timeout in milliseconds.
416    pub busy_timeout: u32,
417    /// Cache size in pages (negative for KB).
418    pub cache_size: i32,
419}
420
421impl Default for SqliteOptions {
422    fn default() -> Self {
423        Self {
424            journal_mode: SqliteJournalMode::Wal,
425            synchronous: SqliteSynchronous::Normal,
426            foreign_keys: true,
427            busy_timeout: 5000,
428            cache_size: -2000, // 2MB
429        }
430    }
431}
432
433impl SqliteOptions {
434    /// Create new SQLite options.
435    pub fn new() -> Self {
436        Self::default()
437    }
438
439    /// Set journal mode.
440    pub fn journal_mode(mut self, mode: SqliteJournalMode) -> Self {
441        self.journal_mode = mode;
442        self
443    }
444
445    /// Set synchronous mode.
446    pub fn synchronous(mut self, mode: SqliteSynchronous) -> Self {
447        self.synchronous = mode;
448        self
449    }
450
451    /// Enable/disable foreign keys.
452    pub fn foreign_keys(mut self, enabled: bool) -> Self {
453        self.foreign_keys = enabled;
454        self
455    }
456
457    /// Set busy timeout in milliseconds.
458    pub fn busy_timeout(mut self, ms: u32) -> Self {
459        self.busy_timeout = ms;
460        self
461    }
462
463    /// Generate PRAGMA statements.
464    pub fn to_pragmas(&self) -> Vec<String> {
465        vec![
466            format!("PRAGMA journal_mode = {};", self.journal_mode.as_str()),
467            format!("PRAGMA synchronous = {};", self.synchronous.as_str()),
468            format!("PRAGMA foreign_keys = {};", if self.foreign_keys { "ON" } else { "OFF" }),
469            format!("PRAGMA busy_timeout = {};", self.busy_timeout),
470            format!("PRAGMA cache_size = {};", self.cache_size),
471        ]
472    }
473}
474
/// SQLite journal mode, rendered by [`SqliteJournalMode::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteJournalMode {
    /// Delete the journal once a transaction finishes.
    Delete,
    /// Truncate the journal instead of deleting it.
    Truncate,
    /// Keep (persist) the journal file.
    Persist,
    /// Keep the journal in memory.
    Memory,
    /// Write-ahead logging (recommended; the default).
    #[default]
    Wal,
    /// No journaling at all.
    Off,
}

impl SqliteJournalMode {
    /// Upper-case SQL keyword for this journal mode.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Wal => "WAL",
            Self::Delete => "DELETE",
            Self::Truncate => "TRUNCATE",
            Self::Persist => "PERSIST",
            Self::Memory => "MEMORY",
            Self::Off => "OFF",
        }
    }
}
506
/// SQLite synchronous mode, rendered by [`SqliteSynchronous::as_str`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteSynchronous {
    /// No synchronization.
    Off,
    /// Normal synchronization (the default).
    #[default]
    Normal,
    /// Full synchronization.
    Full,
    /// Extra synchronization.
    Extra,
}

impl SqliteSynchronous {
    /// Upper-case SQL keyword for this synchronous mode.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Normal => "NORMAL",
            Self::Full => "FULL",
            Self::Extra => "EXTRA",
            Self::Off => "OFF",
        }
    }
}
532
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ssl_mode_parse() {
        let cases = [
            ("disable", Some(SslMode::Disable)),
            ("require", Some(SslMode::Require)),
            ("verify-full", Some(SslMode::VerifyFull)),
            ("invalid", None),
        ];
        for (input, expected) in cases {
            assert_eq!(SslMode::from_str(input), expected, "input: {input}");
        }
    }

    #[test]
    fn test_connection_options_builder() {
        let built = ConnectionOptions::new()
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Require)
            .application_name("test-app");

        assert_eq!(built.connect_timeout, Duration::from_secs(10));
        assert_eq!(built.ssl.mode, SslMode::Require);
        assert_eq!(built.application_name.as_deref(), Some("test-app"));
    }

    #[test]
    fn test_pool_options_builder() {
        let pool = PoolOptions::new()
            .max_connections(20)
            .min_connections(5)
            .no_idle_timeout();

        assert_eq!((pool.max_connections, pool.min_connections), (20, 5));
        assert!(pool.idle_timeout.is_none());
    }

    #[test]
    fn test_sqlite_options_pragmas() {
        let pragmas = SqliteOptions::new()
            .journal_mode(SqliteJournalMode::Wal)
            .foreign_keys(true)
            .to_pragmas();

        for needle in ["journal_mode = WAL", "foreign_keys = ON"] {
            assert!(
                pragmas.iter().any(|p| p.contains(needle)),
                "missing pragma containing {needle:?}"
            );
        }
    }

    #[test]
    fn test_options_from_params() {
        let params: HashMap<String, String> = HashMap::from([
            ("connect_timeout".to_string(), "10".to_string()),
            ("sslmode".to_string(), "require".to_string()),
            ("application_name".to_string(), "myapp".to_string()),
        ]);

        let parsed = ConnectionOptions::from_params(&params);
        assert_eq!(parsed.connect_timeout, Duration::from_secs(10));
        assert_eq!(parsed.ssl.mode, SslMode::Require);
        assert_eq!(parsed.application_name.as_deref(), Some("myapp"));
    }
}
593
594