1use super::{ConnectionError, ConnectionResult, ConnectionString, Driver};
4use std::collections::HashMap;
5use std::time::Duration;
6
/// TLS negotiation mode for a connection.
///
/// The variant set and canonical spellings match PostgreSQL/libpq's `sslmode`
/// values; how each mode is actually enforced is up to the individual driver.
/// The default is [`SslMode::Prefer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SslMode {
    /// Canonical `"disable"`; also parsed from `"false"` and `"0"`.
    Disable,
    /// Canonical `"allow"`.
    Allow,
    /// Canonical `"prefer"`; the default mode.
    #[default]
    Prefer,
    /// Canonical `"require"`; also parsed from `"true"` and `"1"`.
    Require,
    /// Canonical `"verify-ca"`; also parsed from `"verify_ca"`.
    VerifyCa,
    /// Canonical `"verify-full"`; also parsed from `"verify_full"`.
    VerifyFull,
}

impl SslMode {
    /// Every accepted (lowercase) spelling, paired with the mode it selects.
    const ALIASES: [(&'static str, SslMode); 12] = [
        ("disable", SslMode::Disable),
        ("false", SslMode::Disable),
        ("0", SslMode::Disable),
        ("allow", SslMode::Allow),
        ("prefer", SslMode::Prefer),
        ("require", SslMode::Require),
        ("true", SslMode::Require),
        ("1", SslMode::Require),
        ("verify-ca", SslMode::VerifyCa),
        ("verify_ca", SslMode::VerifyCa),
        ("verify-full", SslMode::VerifyFull),
        ("verify_full", SslMode::VerifyFull),
    ];

    /// Parses a mode name case-insensitively.
    ///
    /// Accepts the canonical names, underscore variants of the `verify-*`
    /// modes, and the boolean-ish shorthands `true`/`1` (→ `Require`) and
    /// `false`/`0` (→ `Disable`). Returns `None` for anything else; the input
    /// is not trimmed.
    pub fn from_str(s: &str) -> Option<Self> {
        let wanted = s.to_lowercase();
        Self::ALIASES
            .iter()
            .find(|(alias, _)| *alias == wanted.as_str())
            .map(|&(_, mode)| mode)
    }

    /// Returns the canonical lowercase, hyphenated name of this mode.
    pub fn as_str(&self) -> &'static str {
        match *self {
            SslMode::VerifyFull => "verify-full",
            SslMode::VerifyCa => "verify-ca",
            SslMode::Require => "require",
            SslMode::Prefer => "prefer",
            SslMode::Allow => "allow",
            SslMode::Disable => "disable",
        }
    }
}
51
52#[derive(Debug, Clone, Default)]
54pub struct SslConfig {
55 pub mode: SslMode,
57 pub ca_cert: Option<String>,
59 pub client_cert: Option<String>,
61 pub client_key: Option<String>,
63 pub server_name: Option<String>,
65}
66
67impl SslConfig {
68 pub fn new(mode: SslMode) -> Self {
70 Self {
71 mode,
72 ..Default::default()
73 }
74 }
75
76 pub fn require() -> Self {
78 Self::new(SslMode::Require)
79 }
80
81 pub fn with_ca_cert(mut self, path: impl Into<String>) -> Self {
83 self.ca_cert = Some(path.into());
84 self
85 }
86
87 pub fn with_client_cert(mut self, path: impl Into<String>) -> Self {
89 self.client_cert = Some(path.into());
90 self
91 }
92
93 pub fn with_client_key(mut self, path: impl Into<String>) -> Self {
95 self.client_key = Some(path.into());
96 self
97 }
98}
99
100#[derive(Debug, Clone)]
102pub struct ConnectionOptions {
103 pub connect_timeout: Duration,
105 pub read_timeout: Option<Duration>,
107 pub write_timeout: Option<Duration>,
109 pub ssl: SslConfig,
111 pub application_name: Option<String>,
113 pub schema: Option<String>,
115 pub extra: HashMap<String, String>,
117}
118
119impl Default for ConnectionOptions {
120 fn default() -> Self {
121 Self {
122 connect_timeout: Duration::from_secs(30),
123 read_timeout: None,
124 write_timeout: None,
125 ssl: SslConfig::default(),
126 application_name: None,
127 schema: None,
128 extra: HashMap::new(),
129 }
130 }
131}
132
133impl ConnectionOptions {
134 pub fn new() -> Self {
136 Self::default()
137 }
138
139 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
141 self.connect_timeout = timeout;
142 self
143 }
144
145 pub fn read_timeout(mut self, timeout: Duration) -> Self {
147 self.read_timeout = Some(timeout);
148 self
149 }
150
151 pub fn write_timeout(mut self, timeout: Duration) -> Self {
153 self.write_timeout = Some(timeout);
154 self
155 }
156
157 pub fn ssl_mode(mut self, mode: SslMode) -> Self {
159 self.ssl.mode = mode;
160 self
161 }
162
163 pub fn ssl(mut self, config: SslConfig) -> Self {
165 self.ssl = config;
166 self
167 }
168
169 pub fn application_name(mut self, name: impl Into<String>) -> Self {
171 self.application_name = Some(name.into());
172 self
173 }
174
175 pub fn schema(mut self, schema: impl Into<String>) -> Self {
177 self.schema = Some(schema.into());
178 self
179 }
180
181 pub fn option(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
183 self.extra.insert(key.into(), value.into());
184 self
185 }
186
187 pub fn from_params(params: &HashMap<String, String>) -> Self {
189 let mut opts = Self::default();
190
191 if let Some(timeout) = params.get("connect_timeout") {
192 if let Ok(secs) = timeout.parse::<u64>() {
193 opts.connect_timeout = Duration::from_secs(secs);
194 }
195 }
196
197 if let Some(timeout) = params.get("read_timeout") {
198 if let Ok(secs) = timeout.parse::<u64>() {
199 opts.read_timeout = Some(Duration::from_secs(secs));
200 }
201 }
202
203 if let Some(timeout) = params.get("write_timeout") {
204 if let Ok(secs) = timeout.parse::<u64>() {
205 opts.write_timeout = Some(Duration::from_secs(secs));
206 }
207 }
208
209 if let Some(ssl) = params.get("sslmode").or_else(|| params.get("ssl")) {
210 if let Some(mode) = SslMode::from_str(ssl) {
211 opts.ssl.mode = mode;
212 }
213 }
214
215 if let Some(name) = params.get("application_name") {
216 opts.application_name = Some(name.clone());
217 }
218
219 if let Some(schema) = params.get("schema").or_else(|| params.get("search_path")) {
220 opts.schema = Some(schema.clone());
221 }
222
223 for (key, value) in params {
225 if !matches!(
226 key.as_str(),
227 "connect_timeout"
228 | "read_timeout"
229 | "write_timeout"
230 | "sslmode"
231 | "ssl"
232 | "application_name"
233 | "schema"
234 | "search_path"
235 ) {
236 opts.extra.insert(key.clone(), value.clone());
237 }
238 }
239
240 opts
241 }
242}
243
/// Sizing and lifecycle settings for a connection pool.
///
/// No cross-field validation is done here (e.g. `min_connections` is not
/// checked against `max_connections`).
#[derive(Debug, Clone)]
pub struct PoolOptions {
    /// Upper bound on connections (default 10).
    pub max_connections: u32,
    /// Lower bound on connections (default 1).
    pub min_connections: u32,
    /// How long to wait when acquiring a connection (default 30 s).
    pub acquire_timeout: Duration,
    /// Idle timeout; `None` disables it (default 10 min).
    pub idle_timeout: Option<Duration>,
    /// Maximum connection lifetime; `None` disables it (default 30 min).
    pub max_lifetime: Option<Duration>,
    /// Whether a connection is tested before being handed out (default `true`).
    pub test_before_acquire: bool,
}

impl Default for PoolOptions {
    fn default() -> Self {
        let half_minute = Duration::from_secs(30);
        PoolOptions {
            test_before_acquire: true,
            max_lifetime: Some(Duration::from_secs(30 * 60)),
            idle_timeout: Some(Duration::from_secs(10 * 60)),
            acquire_timeout: half_minute,
            min_connections: 1,
            max_connections: 10,
        }
    }
}

impl PoolOptions {
    /// Returns the default pool settings (same as [`PoolOptions::default`]).
    pub fn new() -> Self {
        Default::default()
    }

    /// Sets the maximum number of connections.
    pub fn max_connections(self, n: u32) -> Self {
        Self {
            max_connections: n,
            ..self
        }
    }

    /// Sets the minimum number of connections.
    pub fn min_connections(self, n: u32) -> Self {
        Self {
            min_connections: n,
            ..self
        }
    }

    /// Sets the acquire timeout.
    pub fn acquire_timeout(self, timeout: Duration) -> Self {
        Self {
            acquire_timeout: timeout,
            ..self
        }
    }

    /// Sets (and enables) the idle timeout.
    pub fn idle_timeout(self, timeout: Duration) -> Self {
        Self {
            idle_timeout: Some(timeout),
            ..self
        }
    }

    /// Clears the idle timeout.
    pub fn no_idle_timeout(self) -> Self {
        Self {
            idle_timeout: None,
            ..self
        }
    }

    /// Sets (and enables) the maximum connection lifetime.
    pub fn max_lifetime(self, lifetime: Duration) -> Self {
        Self {
            max_lifetime: Some(lifetime),
            ..self
        }
    }

    /// Clears the maximum connection lifetime.
    pub fn no_max_lifetime(self) -> Self {
        Self {
            max_lifetime: None,
            ..self
        }
    }

    /// Enables or disables testing a connection before it is handed out.
    pub fn test_before_acquire(self, enabled: bool) -> Self {
        Self {
            test_before_acquire: enabled,
            ..self
        }
    }
}
328
/// PostgreSQL-specific connection settings.
///
/// `PostgresOptions::new()` and `PostgresOptions::default()` return the same
/// value: a statement cache of 100 entries and prepared statements enabled.
/// (Previously `Default` was derived, which silently produced a cache of 0 and
/// prepared statements disabled.)
#[derive(Debug, Clone)]
pub struct PostgresOptions {
    /// Statement cache capacity (default 100).
    pub statement_cache_capacity: usize,
    /// Whether prepared statements are used (default `true`).
    pub prepared_statements: bool,
    /// Optional `channel_binding` value, passed as a raw string
    /// (presumably libpq's parameter of the same name — confirm in the driver).
    pub channel_binding: Option<String>,
    /// Optional `target_session_attrs` value, passed as a raw string
    /// (presumably libpq's parameter of the same name — confirm in the driver).
    pub target_session_attrs: Option<String>,
}

impl Default for PostgresOptions {
    fn default() -> Self {
        Self {
            statement_cache_capacity: 100,
            prepared_statements: true,
            channel_binding: None,
            target_session_attrs: None,
        }
    }
}

impl PostgresOptions {
    /// Returns the default options (same as [`PostgresOptions::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the statement cache capacity.
    pub fn statement_cache(mut self, capacity: usize) -> Self {
        self.statement_cache_capacity = capacity;
        self
    }

    /// Enables or disables prepared statements.
    pub fn prepared_statements(mut self, enabled: bool) -> Self {
        self.prepared_statements = enabled;
        self
    }

    /// Sets the `channel_binding` value.
    pub fn channel_binding(mut self, value: impl Into<String>) -> Self {
        self.channel_binding = Some(value.into());
        self
    }

    /// Sets the `target_session_attrs` value.
    pub fn target_session_attrs(mut self, value: impl Into<String>) -> Self {
        self.target_session_attrs = Some(value.into());
        self
    }
}
365
/// MySQL-specific connection settings. Every field defaults to off/unset.
#[derive(Debug, Clone, Default)]
pub struct MySqlOptions {
    /// Whether compression is requested (default `false`).
    pub compression: bool,
    /// Connection character set, e.g. `"utf8mb4"` (default unset).
    pub charset: Option<String>,
    /// Connection collation (default unset).
    pub collation: Option<String>,
    /// `sql_mode` value (default unset).
    pub sql_mode: Option<String>,
    /// Time zone string (default unset).
    pub timezone: Option<String>,
}

impl MySqlOptions {
    /// Returns the default options (same as [`MySqlOptions::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Enables or disables compression.
    pub fn compression(mut self, enabled: bool) -> Self {
        self.compression = enabled;
        self
    }

    /// Sets the character set.
    pub fn charset(mut self, charset: impl Into<String>) -> Self {
        self.charset = Some(charset.into());
        self
    }

    /// Sets the collation. Previously `collation` had no builder.
    pub fn collation(mut self, collation: impl Into<String>) -> Self {
        self.collation = Some(collation.into());
        self
    }

    /// Sets the `sql_mode` value.
    pub fn sql_mode(mut self, mode: impl Into<String>) -> Self {
        self.sql_mode = Some(mode.into());
        self
    }

    /// Sets the time zone. Previously `timezone` had no builder.
    pub fn timezone(mut self, timezone: impl Into<String>) -> Self {
        self.timezone = Some(timezone.into());
        self
    }
}
405
406#[derive(Debug, Clone)]
408pub struct SqliteOptions {
409 pub journal_mode: SqliteJournalMode,
411 pub synchronous: SqliteSynchronous,
413 pub foreign_keys: bool,
415 pub busy_timeout: u32,
417 pub cache_size: i32,
419}
420
421impl Default for SqliteOptions {
422 fn default() -> Self {
423 Self {
424 journal_mode: SqliteJournalMode::Wal,
425 synchronous: SqliteSynchronous::Normal,
426 foreign_keys: true,
427 busy_timeout: 5000,
428 cache_size: -2000, }
430 }
431}
432
433impl SqliteOptions {
434 pub fn new() -> Self {
436 Self::default()
437 }
438
439 pub fn journal_mode(mut self, mode: SqliteJournalMode) -> Self {
441 self.journal_mode = mode;
442 self
443 }
444
445 pub fn synchronous(mut self, mode: SqliteSynchronous) -> Self {
447 self.synchronous = mode;
448 self
449 }
450
451 pub fn foreign_keys(mut self, enabled: bool) -> Self {
453 self.foreign_keys = enabled;
454 self
455 }
456
457 pub fn busy_timeout(mut self, ms: u32) -> Self {
459 self.busy_timeout = ms;
460 self
461 }
462
463 pub fn to_pragmas(&self) -> Vec<String> {
465 vec![
466 format!("PRAGMA journal_mode = {};", self.journal_mode.as_str()),
467 format!("PRAGMA synchronous = {};", self.synchronous.as_str()),
468 format!("PRAGMA foreign_keys = {};", if self.foreign_keys { "ON" } else { "OFF" }),
469 format!("PRAGMA busy_timeout = {};", self.busy_timeout),
470 format!("PRAGMA cache_size = {};", self.cache_size),
471 ]
472 }
473}
474
/// Values for SQLite's `PRAGMA journal_mode`; `Wal` is the default here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteJournalMode {
    /// `DELETE` journal mode.
    Delete,
    /// `TRUNCATE` journal mode.
    Truncate,
    /// `PERSIST` journal mode.
    Persist,
    /// `MEMORY` journal mode.
    Memory,
    /// `WAL` (write-ahead log) journal mode; the default.
    #[default]
    Wal,
    /// `OFF`: no journal.
    Off,
}

impl SqliteJournalMode {
    /// Returns the uppercase keyword used in `PRAGMA journal_mode = …`.
    pub fn as_str(&self) -> &'static str {
        use SqliteJournalMode as Mode;
        match *self {
            Mode::Wal => "WAL",
            Mode::Delete => "DELETE",
            Mode::Memory => "MEMORY",
            Mode::Truncate => "TRUNCATE",
            Mode::Persist => "PERSIST",
            Mode::Off => "OFF",
        }
    }
}
506
/// Values for SQLite's `PRAGMA synchronous`; `Normal` is the default here.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum SqliteSynchronous {
    /// `OFF`.
    Off,
    /// `NORMAL`; the default.
    #[default]
    Normal,
    /// `FULL`.
    Full,
    /// `EXTRA`.
    Extra,
}

impl SqliteSynchronous {
    /// Returns the uppercase keyword used in `PRAGMA synchronous = …`.
    pub fn as_str(&self) -> &'static str {
        use SqliteSynchronous as Level;
        let keyword = match *self {
            Level::Normal => "NORMAL",
            Level::Full => "FULL",
            Level::Extra => "EXTRA",
            Level::Off => "OFF",
        };
        keyword
    }
}
532
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ssl_mode_parse() {
        assert_eq!(SslMode::from_str("disable"), Some(SslMode::Disable));
        assert_eq!(SslMode::from_str("require"), Some(SslMode::Require));
        assert_eq!(SslMode::from_str("verify-full"), Some(SslMode::VerifyFull));
        assert_eq!(SslMode::from_str("invalid"), None);
    }

    #[test]
    fn test_ssl_mode_aliases_and_round_trip() {
        assert_eq!(SslMode::from_str("TRUE"), Some(SslMode::Require));
        assert_eq!(SslMode::from_str("0"), Some(SslMode::Disable));
        assert_eq!(SslMode::from_str("verify_ca"), Some(SslMode::VerifyCa));
        // Every canonical name must parse back to the same mode.
        for mode in [
            SslMode::Disable,
            SslMode::Allow,
            SslMode::Prefer,
            SslMode::Require,
            SslMode::VerifyCa,
            SslMode::VerifyFull,
        ] {
            assert_eq!(SslMode::from_str(mode.as_str()), Some(mode));
        }
    }

    #[test]
    fn test_connection_options_builder() {
        let opts = ConnectionOptions::new()
            .connect_timeout(Duration::from_secs(10))
            .ssl_mode(SslMode::Require)
            .application_name("test-app");

        assert_eq!(opts.connect_timeout, Duration::from_secs(10));
        assert_eq!(opts.ssl.mode, SslMode::Require);
        assert_eq!(opts.application_name, Some("test-app".to_string()));
    }

    #[test]
    fn test_pool_options_builder() {
        let opts = PoolOptions::new()
            .max_connections(20)
            .min_connections(5)
            .no_idle_timeout();

        assert_eq!(opts.max_connections, 20);
        assert_eq!(opts.min_connections, 5);
        assert_eq!(opts.idle_timeout, None);
    }

    #[test]
    fn test_sqlite_options_pragmas() {
        let opts = SqliteOptions::new()
            .journal_mode(SqliteJournalMode::Wal)
            .foreign_keys(true);

        let pragmas = opts.to_pragmas();
        assert!(pragmas.iter().any(|p| p.contains("journal_mode = WAL")));
        assert!(pragmas.iter().any(|p| p.contains("foreign_keys = ON")));
    }

    #[test]
    fn test_options_from_params() {
        let mut params = HashMap::new();
        params.insert("connect_timeout".to_string(), "10".to_string());
        params.insert("sslmode".to_string(), "require".to_string());
        params.insert("application_name".to_string(), "myapp".to_string());

        // Was `¶ms` — a mis-encoded `&params` (HTML `&para` entity) that did not compile.
        let opts = ConnectionOptions::from_params(&params);
        assert_eq!(opts.connect_timeout, Duration::from_secs(10));
        assert_eq!(opts.ssl.mode, SslMode::Require);
        assert_eq!(opts.application_name, Some("myapp".to_string()));
    }

    #[test]
    fn test_options_from_params_forwards_unknown_keys() {
        let mut params = HashMap::new();
        params.insert("sslmode".to_string(), "disable".to_string());
        params.insert("search_path".to_string(), "app".to_string());
        params.insert("statement_timeout".to_string(), "5000".to_string());

        let opts = ConnectionOptions::from_params(&params);
        assert_eq!(opts.ssl.mode, SslMode::Disable);
        assert_eq!(opts.schema, Some("app".to_string()));
        // Only the unrecognized key is passed through.
        assert_eq!(opts.extra.len(), 1);
        assert_eq!(opts.extra.get("statement_timeout"), Some(&"5000".to_string()));
    }
}
593
594