1use super::{ConnectionError, ConnectionResult, ConnectionString, Driver};
4use std::collections::HashMap;
5use std::time::Duration;
6
/// TLS negotiation policy for a connection.
///
/// The canonical spellings returned by [`SslMode::as_str`] match the libpq
/// `sslmode` keywords. [`SslMode::Prefer`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// Parsed from `disable`, `false` or `0`.
    Disable,
    /// Parsed from `allow`.
    Allow,
    /// Parsed from `prefer`; the default mode.
    #[default]
    Prefer,
    /// Parsed from `require`, `true` or `1`.
    Require,
    /// Parsed from `verify-ca` or `verify_ca`.
    VerifyCa,
    /// Parsed from `verify-full` or `verify_full`.
    VerifyFull,
}

impl SslMode {
    /// Parse a mode name, ignoring ASCII case.
    ///
    /// Besides the canonical names, the boolean-ish aliases `true`/`1`
    /// (→ [`SslMode::Require`]) and `false`/`0` (→ [`SslMode::Disable`]) and
    /// underscore variants of the `verify-*` names are accepted. Returns `None`
    /// for anything else; surrounding whitespace is not trimmed.
    pub fn from_str(s: &str) -> Option<Self> {
        const SPELLINGS: &[(&str, SslMode)] = &[
            ("disable", SslMode::Disable),
            ("false", SslMode::Disable),
            ("0", SslMode::Disable),
            ("allow", SslMode::Allow),
            ("prefer", SslMode::Prefer),
            ("require", SslMode::Require),
            ("true", SslMode::Require),
            ("1", SslMode::Require),
            ("verify-ca", SslMode::VerifyCa),
            ("verify_ca", SslMode::VerifyCa),
            ("verify-full", SslMode::VerifyFull),
            ("verify_full", SslMode::VerifyFull),
        ];

        // All spellings are ASCII, so an ASCII-insensitive comparison accepts
        // exactly the same inputs as lower-casing first, without allocating.
        SPELLINGS
            .iter()
            .find(|(spelling, _)| spelling.eq_ignore_ascii_case(s))
            .map(|&(_, mode)| mode)
    }

    /// Canonical lower-case name of this mode (hyphenated for the `verify-*` modes).
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::VerifyFull => "verify-full",
            Self::VerifyCa => "verify-ca",
            Self::Require => "require",
            Self::Prefer => "prefer",
            Self::Allow => "allow",
            Self::Disable => "disable",
        }
    }
}
51
52#[derive(Debug, Clone, Default)]
54pub struct SslConfig {
55 pub mode: SslMode,
57 pub ca_cert: Option<String>,
59 pub client_cert: Option<String>,
61 pub client_key: Option<String>,
63 pub server_name: Option<String>,
65}
66
67impl SslConfig {
68 pub fn new(mode: SslMode) -> Self {
70 Self {
71 mode,
72 ..Default::default()
73 }
74 }
75
76 pub fn require() -> Self {
78 Self::new(SslMode::Require)
79 }
80
81 pub fn with_ca_cert(mut self, path: impl Into<String>) -> Self {
83 self.ca_cert = Some(path.into());
84 self
85 }
86
87 pub fn with_client_cert(mut self, path: impl Into<String>) -> Self {
89 self.client_cert = Some(path.into());
90 self
91 }
92
93 pub fn with_client_key(mut self, path: impl Into<String>) -> Self {
95 self.client_key = Some(path.into());
96 self
97 }
98}
99
100#[derive(Debug, Clone)]
102pub struct ConnectionOptions {
103 pub connect_timeout: Duration,
105 pub read_timeout: Option<Duration>,
107 pub write_timeout: Option<Duration>,
109 pub ssl: SslConfig,
111 pub application_name: Option<String>,
113 pub schema: Option<String>,
115 pub extra: HashMap<String, String>,
117}
118
119impl Default for ConnectionOptions {
120 fn default() -> Self {
121 Self {
122 connect_timeout: Duration::from_secs(30),
123 read_timeout: None,
124 write_timeout: None,
125 ssl: SslConfig::default(),
126 application_name: None,
127 schema: None,
128 extra: HashMap::new(),
129 }
130 }
131}
132
133impl ConnectionOptions {
134 pub fn new() -> Self {
136 Self::default()
137 }
138
139 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
141 self.connect_timeout = timeout;
142 self
143 }
144
145 pub fn read_timeout(mut self, timeout: Duration) -> Self {
147 self.read_timeout = Some(timeout);
148 self
149 }
150
151 pub fn write_timeout(mut self, timeout: Duration) -> Self {
153 self.write_timeout = Some(timeout);
154 self
155 }
156
157 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
159 self.ssl.mode = mode;
160 self
161 }
162
163 pub fn ssl(mut self, config: SslConfig) -> Self {
165 self.ssl = config;
166 self
167 }
168
169 pub fn application_name(mut self, name: impl Into<String>) -> Self {
171 self.application_name = Some(name.into());
172 self
173 }
174
175 pub fn schema(mut self, schema: impl Into<String>) -> Self {
177 self.schema = Some(schema.into());
178 self
179 }
180
181 pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
183 self.extra.insert(key.into(), value.into());
184 self
185 }
186
187 pub fn from_params(params: &HashMap<String, String>) -> Self {
189 let mut opts = Self::default();
190
191 if let Some(timeout) = params.get("connect_timeout") {
192 if let Ok(secs) = timeout.parse::<u64>() {
193 opts.connect_timeout = Duration::from_secs(secs);
194 }
195 }
196
197 if let Some(timeout) = params.get("read_timeout") {
198 if let Ok(secs) = timeout.parse::<u64>() {
199 opts.read_timeout = Some(Duration::from_secs(secs));
200 }
201 }
202
203 if let Some(timeout) = params.get("write_timeout") {
204 if let Ok(secs) = timeout.parse::<u64>() {
205 opts.write_timeout = Some(Duration::from_secs(secs));
206 }
207 }
208
209 if let Some(ssl) = params.get("sslmode").or_else(|| params.get("ssl")) {
210 if let Some(mode) = SslMode::from_str(ssl) {
211 opts.ssl.mode = mode;
212 }
213 }
214
215 if let Some(name) = params.get("application_name") {
216 opts.application_name = Some(name.clone());
217 }
218
219 if let Some(schema) = params.get("schema").or_else(|| params.get("search_path")) {
220 opts.schema = Some(schema.clone());
221 }
222
223 for (key, value) in params {
225 if !matches!(
226 key.as_str(),
227 "connect_timeout"
228 | "read_timeout"
229 | "write_timeout"
230 | "sslmode"
231 | "ssl"
232 | "application_name"
233 | "schema"
234 | "search_path"
235 ) {
236 opts.extra.insert(key.clone(), value.clone());
237 }
238 }
239
240 opts
241 }
242}
243
/// Sizing and lifetime settings for a connection pool.
///
/// Defaults:
/// - 10 max / 1 min connections;
/// - a 30 s acquire timeout;
/// - a 10 min idle timeout and a 30 min max lifetime;
/// - `test_before_acquire` enabled.
///
/// The builders perform no validation; for example, `min_connections` may be
/// set above `max_connections`.
#[derive(Debug, Clone)]
pub struct PoolOptions {
    /// Upper bound on pooled connections.
    pub max_connections: u32,
    /// Lower bound on pooled connections.
    pub min_connections: u32,
    /// How long to wait when acquiring a connection.
    pub acquire_timeout: Duration,
    /// Idle timeout, or `None` for no limit.
    pub idle_timeout: Option<Duration>,
    /// Maximum connection lifetime, or `None` for no limit.
    pub max_lifetime: Option<Duration>,
    /// Whether a connection is tested before being handed out.
    pub test_before_acquire: bool,
}

impl Default for PoolOptions {
    fn default() -> Self {
        const MINUTE: u64 = 60;
        Self {
            max_connections: 10,
            min_connections: 1,
            acquire_timeout: Duration::from_secs(30),
            idle_timeout: Some(Duration::from_secs(10 * MINUTE)),
            max_lifetime: Some(Duration::from_secs(30 * MINUTE)),
            test_before_acquire: true,
        }
    }
}

impl PoolOptions {
    /// Same as [`PoolOptions::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the upper bound on pooled connections.
    pub fn max_connections(self, n: u32) -> Self {
        Self { max_connections: n, ..self }
    }

    /// Set the lower bound on pooled connections.
    pub fn min_connections(self, n: u32) -> Self {
        Self { min_connections: n, ..self }
    }

    /// Set the acquire timeout.
    pub fn acquire_timeout(self, timeout: Duration) -> Self {
        Self { acquire_timeout: timeout, ..self }
    }

    /// Set an idle timeout.
    pub fn idle_timeout(self, timeout: Duration) -> Self {
        Self { idle_timeout: Some(timeout), ..self }
    }

    /// Remove the idle timeout.
    pub fn no_idle_timeout(self) -> Self {
        Self { idle_timeout: None, ..self }
    }

    /// Set a maximum connection lifetime.
    pub fn max_lifetime(self, lifetime: Duration) -> Self {
        Self { max_lifetime: Some(lifetime), ..self }
    }

    /// Remove the maximum connection lifetime.
    pub fn no_max_lifetime(self) -> Self {
        Self { max_lifetime: None, ..self }
    }

    /// Enable or disable testing connections before they are handed out.
    pub fn test_before_acquire(self, enabled: bool) -> Self {
        Self { test_before_acquire: enabled, ..self }
    }
}
328
/// PostgreSQL-specific connection settings.
///
/// [`PostgresOptions::default`] and [`PostgresOptions::new`] produce the same
/// values: a statement cache of 100 entries, prepared statements enabled, and
/// no `channel_binding` / `target_session_attrs`. (Previously `Default` was
/// derived and yielded a 0-entry cache with prepared statements disabled,
/// contradicting `new()`.)
#[derive(Debug, Clone)]
pub struct PostgresOptions {
    /// Statement cache capacity; 100 by default.
    pub statement_cache_capacity: usize,
    /// Whether prepared statements are used; `true` by default.
    pub prepared_statements: bool,
    /// `channel_binding` setting as a raw string; unset by default.
    pub channel_binding: Option<String>,
    /// `target_session_attrs` setting as a raw string; unset by default.
    pub target_session_attrs: Option<String>,
}

impl Default for PostgresOptions {
    fn default() -> Self {
        Self {
            statement_cache_capacity: 100,
            prepared_statements: true,
            channel_binding: None,
            target_session_attrs: None,
        }
    }
}

impl PostgresOptions {
    /// Same as [`PostgresOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Set the statement cache capacity.
    #[must_use]
    pub fn statement_cache(mut self, capacity: usize) -> Self {
        self.statement_cache_capacity = capacity;
        self
    }

    /// Enable or disable prepared statements.
    #[must_use]
    pub fn prepared_statements(mut self, enabled: bool) -> Self {
        self.prepared_statements = enabled;
        self
    }

    /// Set the `channel_binding` value (stored verbatim).
    #[must_use]
    pub fn channel_binding(mut self, value: impl Into<String>) -> Self {
        self.channel_binding = Some(value.into());
        self
    }

    /// Set the `target_session_attrs` value (stored verbatim).
    #[must_use]
    pub fn target_session_attrs(mut self, value: impl Into<String>) -> Self {
        self.target_session_attrs = Some(value.into());
        self
    }
}
365
/// MySQL-specific connection settings.
///
/// By default compression is off and every string setting is unset.
#[derive(Debug, Clone, Default)]
pub struct MySqlOptions {
    /// Compression flag; `false` by default.
    pub compression: bool,
    /// Character set name, stored verbatim.
    pub charset: Option<String>,
    /// Collation name, stored verbatim.
    pub collation: Option<String>,
    /// SQL mode string, stored verbatim.
    pub sql_mode: Option<String>,
    /// Time zone string, stored verbatim.
    pub timezone: Option<String>,
}

impl MySqlOptions {
    /// Same as [`MySqlOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Enable or disable compression.
    #[must_use]
    pub fn compression(mut self, enabled: bool) -> Self {
        self.compression = enabled;
        self
    }

    /// Set the character set.
    #[must_use]
    pub fn charset(mut self, charset: impl Into<String>) -> Self {
        self.charset = Some(charset.into());
        self
    }

    /// Set the collation.
    ///
    /// Added so that every optional field has a matching builder.
    #[must_use]
    pub fn collation(mut self, collation: impl Into<String>) -> Self {
        self.collation = Some(collation.into());
        self
    }

    /// Set the SQL mode.
    #[must_use]
    pub fn sql_mode(mut self, mode: impl Into<String>) -> Self {
        self.sql_mode = Some(mode.into());
        self
    }

    /// Set the time zone.
    ///
    /// Added so that every optional field has a matching builder.
    #[must_use]
    pub fn timezone(mut self, timezone: impl Into<String>) -> Self {
        self.timezone = Some(timezone.into());
        self
    }
}
405
406#[derive(Debug, Clone)]
408pub struct SqliteOptions {
409 pub journal_mode: SqliteJournalMode,
411 pub synchronous: SqliteSynchronous,
413 pub foreign_keys: bool,
415 pub busy_timeout: u32,
417 pub cache_size: i32,
419}
420
421impl Default for SqliteOptions {
422 fn default() -> Self {
423 Self {
424 journal_mode: SqliteJournalMode::Wal,
425 synchronous: SqliteSynchronous::Normal,
426 foreign_keys: true,
427 busy_timeout: 5000,
428 cache_size: -2000, }
430 }
431}
432
433impl SqliteOptions {
434 pub fn new() -> Self {
436 Self::default()
437 }
438
439 pub fn journal_mode(mut self, mode: SqliteJournalMode) -> Self {
441 self.journal_mode = mode;
442 self
443 }
444
445 pub fn synchronous(mut self, mode: SqliteSynchronous) -> Self {
447 self.synchronous = mode;
448 self
449 }
450
451 pub fn foreign_keys(mut self, enabled: bool) -> Self {
453 self.foreign_keys = enabled;
454 self
455 }
456
457 pub fn busy_timeout(mut self, ms: u32) -> Self {
459 self.busy_timeout = ms;
460 self
461 }
462
463 pub fn to_pragmas(&self) -> Vec<String> {
465 vec![
466 format!("PRAGMA journal_mode = {};", self.journal_mode.as_str()),
467 format!("PRAGMA synchronous = {};", self.synchronous.as_str()),
468 format!(
469 "PRAGMA foreign_keys = {};",
470 if self.foreign_keys { "ON" } else { "OFF" }
471 ),
472 format!("PRAGMA busy_timeout = {};", self.busy_timeout),
473 format!("PRAGMA cache_size = {};", self.cache_size),
474 ]
475 }
476}
477
/// Value written to `PRAGMA journal_mode`; [`SqliteJournalMode::Wal`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteJournalMode {
    /// Rendered as `DELETE`.
    Delete,
    /// Rendered as `TRUNCATE`.
    Truncate,
    /// Rendered as `PERSIST`.
    Persist,
    /// Rendered as `MEMORY`.
    Memory,
    /// Rendered as `WAL`; the default.
    #[default]
    Wal,
    /// Rendered as `OFF`.
    Off,
}

impl SqliteJournalMode {
    /// Upper-case keyword used in `PRAGMA journal_mode = …;`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Wal => "WAL",
            Self::Off => "OFF",
            Self::Memory => "MEMORY",
            Self::Persist => "PERSIST",
            Self::Truncate => "TRUNCATE",
            Self::Delete => "DELETE",
        }
    }
}
509
/// Value written to `PRAGMA synchronous`; [`SqliteSynchronous::Normal`] is the default.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteSynchronous {
    /// Rendered as `OFF`.
    Off,
    /// Rendered as `NORMAL`; the default.
    #[default]
    Normal,
    /// Rendered as `FULL`.
    Full,
    /// Rendered as `EXTRA`.
    Extra,
}

impl SqliteSynchronous {
    /// Upper-case keyword used in `PRAGMA synchronous = …;`.
    pub fn as_str(&self) -> &'static str {
        match *self {
            Self::Extra => "EXTRA",
            Self::Full => "FULL",
            Self::Normal => "NORMAL",
            Self::Off => "OFF",
        }
    }
}
535
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ssl_mode_parse() {
        assert_eq!(SslMode::from_str("disable"), Some(SslMode::Disable));
        assert_eq!(SslMode::from_str("require"), Some(SslMode::Require));
        assert_eq!(SslMode::from_str("verify-full"), Some(SslMode::VerifyFull));
        assert_eq!(SslMode::from_str("invalid"), None);
    }

    /// Every canonical name produced by `as_str` must parse back to the same mode.
    #[test]
    fn test_ssl_mode_round_trip() {
        for mode in [
            SslMode::Disable,
            SslMode::Allow,
            SslMode::Prefer,
            SslMode::Require,
            SslMode::VerifyCa,
            SslMode::VerifyFull,
        ] {
            assert_eq!(SslMode::from_str(mode.as_str()), Some(mode));
        }
    }

    #[test]
    fn test_connection_options_builder() {
        let opts = ConnectionOptions::new()
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Require)
            .application_name("test-app");

        assert_eq!(opts.connect_timeout, Duration::from_secs(10));
        assert_eq!(opts.ssl.mode, SslMode::Require);
        assert_eq!(opts.application_name, Some("test-app".to_string()));
    }

    #[test]
    fn test_pool_options_builder() {
        let opts = PoolOptions::new()
            .max_connections(20)
            .min_connections(5)
            .no_idle_timeout();

        assert_eq!(opts.max_connections, 20);
        assert_eq!(opts.min_connections, 5);
        assert_eq!(opts.idle_timeout, None);
    }

    #[test]
    fn test_sqlite_options_pragmas() {
        let opts = SqliteOptions::new()
            .journal_mode(SqliteJournalMode::Wal)
            .foreign_keys(true);

        let pragmas = opts.to_pragmas();
        assert!(pragmas.iter().any(|p| p.contains("journal_mode = WAL")));
        assert!(pragmas.iter().any(|p| p.contains("foreign_keys = ON")));
    }

    #[test]
    fn test_options_from_params() {
        let mut params = HashMap::new();
        params.insert("connect_timeout".to_string(), "10".to_string());
        params.insert("sslmode".to_string(), "require".to_string());
        params.insert("application_name".to_string(), "myapp".to_string());

        // Was `¶ms` (a mis-decoded `&params`), which did not compile.
        let opts = ConnectionOptions::from_params(&params);
        assert_eq!(opts.connect_timeout, Duration::from_secs(10));
        assert_eq!(opts.ssl.mode, SslMode::Require);
        assert_eq!(opts.application_name, Some("myapp".to_string()));
    }

    /// `ssl` and `search_path` act as fallbacks, bad values are ignored, and only
    /// unrecognised keys land in `extra`.
    #[test]
    fn test_options_from_params_aliases_and_extra() {
        let mut params = HashMap::new();
        params.insert("ssl".to_string(), "true".to_string());
        params.insert("search_path".to_string(), "public".to_string());
        params.insert("read_timeout".to_string(), "not-a-number".to_string());
        params.insert("statement_timeout".to_string(), "5000".to_string());

        let opts = ConnectionOptions::from_params(&params);
        assert_eq!(opts.ssl.mode, SslMode::Require);
        assert_eq!(opts.schema, Some("public".to_string()));
        assert_eq!(opts.read_timeout, None);
        assert_eq!(opts.connect_timeout, Duration::from_secs(30));
        assert_eq!(opts.extra.len(), 1);
        assert_eq!(
            opts.extra.get("statement_timeout"),
            Some(&"5000".to_string())
        );
    }
}