1use std::time::Duration;
4
5use mssql_auth::Credentials;
6use mssql_tls::TlsConfig;
7
/// Settings that decide whether redirects are followed, and how many times.
///
/// Defaults: follow redirects, at most 2 of them.
#[derive(Debug, Clone)]
pub struct RedirectConfig {
    /// Maximum number of redirects to follow.
    pub max_redirects: u8,
    /// When `false`, redirects are not followed.
    pub follow_redirects: bool,
}

impl Default for RedirectConfig {
    fn default() -> Self {
        let follow_redirects = true;
        let max_redirects = 2;
        Self {
            max_redirects,
            follow_redirects,
        }
    }
}

impl RedirectConfig {
    /// Same as [`RedirectConfig::default`].
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Replaces the redirect limit.
    #[must_use]
    pub fn max_redirects(self, max: u8) -> Self {
        Self {
            max_redirects: max,
            ..self
        }
    }

    /// Enables or disables following redirects.
    #[must_use]
    pub fn follow_redirects(self, follow: bool) -> Self {
        Self {
            follow_redirects: follow,
            ..self
        }
    }

    /// A configuration that never follows redirects (limit 0, following disabled).
    #[must_use]
    pub fn no_follow() -> Self {
        Self::new().max_redirects(0).follow_redirects(false)
    }
}
62
/// Timeout and keep-alive settings for the individual phases of a connection.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Connect-phase timeout (default 15 s).
    pub connect_timeout: Duration,
    /// TLS-phase timeout (default 10 s).
    pub tls_timeout: Duration,
    /// Login-phase timeout (default 30 s).
    pub login_timeout: Duration,
    /// Per-command timeout (default 30 s).
    pub command_timeout: Duration,
    /// Idle timeout (default 300 s).
    pub idle_timeout: Duration,
    /// Keep-alive interval (default 30 s); `None` means no keep-alive.
    pub keepalive_interval: Option<Duration>,
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self {
            connect_timeout: Duration::from_secs(15),
            tls_timeout: Duration::from_secs(10),
            login_timeout: Duration::from_secs(30),
            command_timeout: Duration::from_secs(30),
            idle_timeout: Duration::from_secs(300),
            keepalive_interval: Some(Duration::from_secs(30)),
        }
    }
}

impl TimeoutConfig {
    /// Creates a configuration with the default timeouts.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the connect-phase timeout.
    #[must_use]
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets the TLS-phase timeout.
    #[must_use]
    pub fn tls_timeout(mut self, timeout: Duration) -> Self {
        self.tls_timeout = timeout;
        self
    }

    /// Sets the login-phase timeout.
    #[must_use]
    pub fn login_timeout(mut self, timeout: Duration) -> Self {
        self.login_timeout = timeout;
        self
    }

    /// Sets the per-command timeout.
    #[must_use]
    pub fn command_timeout(mut self, timeout: Duration) -> Self {
        self.command_timeout = timeout;
        self
    }

    /// Sets the idle timeout.
    #[must_use]
    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Sets the keep-alive interval; pass `None` to disable it.
    #[must_use]
    pub fn keepalive_interval(mut self, interval: Option<Duration>) -> Self {
        self.keepalive_interval = interval;
        self
    }

    /// Disables keep-alive (shorthand for `keepalive_interval(None)`).
    #[must_use]
    pub fn no_keepalive(self) -> Self {
        self.keepalive_interval(None)
    }

    /// Returns the combined connect + TLS + login budget.
    ///
    /// The sum saturates at [`Duration::MAX`] rather than panicking. Plain
    /// `Duration + Duration` panics on overflow, which very large timeouts
    /// (such as `Duration::MAX` used to mean "no limit") would trigger.
    #[must_use]
    pub fn total_connect_timeout(&self) -> Duration {
        self.connect_timeout
            .saturating_add(self.tls_timeout)
            .saturating_add(self.login_timeout)
    }
}
158
/// Exponential back-off retry policy.
///
/// The delay before attempt `n >= 1` is
/// `initial_backoff * backoff_multiplier^(n - 1)`, capped at `max_backoff`.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Number of retries allowed; see [`RetryPolicy::should_retry`].
    pub max_retries: u32,
    /// Delay used for attempt 1.
    pub initial_backoff: Duration,
    /// Upper bound for any single delay.
    pub max_backoff: Duration,
    /// Growth factor applied per additional attempt.
    pub backoff_multiplier: f64,
    /// Jitter flag. NOTE(review): `backoff_for_attempt` does not apply any
    /// jitter; confirm whether callers apply it.
    pub jitter: bool,
}

impl Default for RetryPolicy {
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }
}

impl RetryPolicy {
    /// Creates the default policy (3 retries, 100 ms doubling up to 30 s).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the maximum number of retries.
    #[must_use]
    pub fn max_retries(mut self, max: u32) -> Self {
        self.max_retries = max;
        self
    }

    /// Sets the delay used for the first retry.
    #[must_use]
    pub fn initial_backoff(mut self, backoff: Duration) -> Self {
        self.initial_backoff = backoff;
        self
    }

    /// Sets the cap applied to every computed delay.
    #[must_use]
    pub fn max_backoff(mut self, backoff: Duration) -> Self {
        self.max_backoff = backoff;
        self
    }

    /// Sets the per-attempt growth factor.
    #[must_use]
    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Sets the jitter flag.
    #[must_use]
    pub fn jitter(mut self, enabled: bool) -> Self {
        self.jitter = enabled;
        self
    }

    /// A policy that never retries (`max_retries == 0`, other fields default).
    #[must_use]
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            ..Self::default()
        }
    }

    /// Returns the delay to wait before `attempt`.
    ///
    /// Attempt 0 has no delay. Later attempts grow geometrically from
    /// `initial_backoff` and are capped at `max_backoff`. Sub-millisecond
    /// precision is discarded.
    #[must_use]
    pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }

        // `attempt - 1` cannot underflow here. Clamp into i32 instead of using `as`:
        // a wrapping cast turns huge attempt counts into a *negative* exponent,
        // which would shrink the delay instead of pinning it at the cap.
        let exponent = i32::try_from(attempt - 1).unwrap_or(i32::MAX);
        let base = self.initial_backoff.as_millis() as f64 * self.backoff_multiplier.powi(exponent);
        let capped = base.min(self.max_backoff.as_millis() as f64);

        // The jittered and non-jittered paths were identical; jitter is not applied here.
        // `as u64` saturates (negative/NaN -> 0), so this cannot panic.
        Duration::from_millis(capped as u64)
    }

    /// Returns `true` while `attempt` is below `max_retries`.
    #[must_use]
    pub fn should_retry(&self, attempt: u32) -> bool {
        attempt < self.max_retries
    }
}
270
/// Client connection settings.
///
/// Build one with [`Config::new`] and the builder methods, or parse a
/// semicolon-separated `key=value` string with [`Config::from_connection_string`].
#[derive(Debug, Clone)]
pub struct Config {
    /// Server host name (default `"localhost"`); connection-string keys
    /// `Server` / `Data Source` / `Host`.
    pub host: String,

    /// Server port (default `1433`); connection-string key `Port` or `host,port`.
    pub port: u16,

    /// Initial database, if any (default `None`); keys `Database` / `Initial Catalog`.
    pub database: Option<String>,

    /// Authentication credentials (default: SQL Server credentials with empty
    /// user name and password); keys `User Id` / `Password` and their aliases.
    pub credentials: Credentials,

    /// TLS settings. The `trust_server_certificate` and `strict_mode` builder
    /// methods forward their values here as well as to the matching bool fields.
    pub tls: TlsConfig,

    /// Application name (default `"mssql-client"`); keys `Application Name` / `App`.
    pub application_name: String,

    /// Connect timeout (default taken from `TimeoutConfig::default`). The
    /// `timeouts` builder keeps this equal to `timeouts.connect_timeout`.
    pub connect_timeout: Duration,

    /// Command timeout (default taken from `TimeoutConfig::default`). The
    /// `timeouts` builder keeps this equal to `timeouts.command_timeout`.
    pub command_timeout: Duration,

    /// Packet size (default `4096`); connection-string key `Packet Size`.
    pub packet_size: u16,

    /// Strict mode flag (default `false`); set by `Encrypt=strict`.
    pub strict_mode: bool,

    /// Whether to trust the server certificate (default `false`).
    pub trust_server_certificate: bool,

    /// Named instance, parsed from a `host\instance` server value (default `None`).
    pub instance: Option<String>,

    /// MARS flag (default `false`); keys `MultipleActiveResultSets` / `MARS`.
    pub mars: bool,

    /// Encryption flag (default `true`); connection-string key `Encrypt`.
    pub encrypt: bool,

    /// Redirect handling (default `RedirectConfig::default()`).
    pub redirect: RedirectConfig,

    /// Retry policy (default `RetryPolicy::default()`).
    pub retry: RetryPolicy,

    /// Per-phase timeouts (default `TimeoutConfig::default()`).
    pub timeouts: TimeoutConfig,
}
327
328impl Default for Config {
329 fn default() -> Self {
330 let timeouts = TimeoutConfig::default();
331 Self {
332 host: "localhost".to_string(),
333 port: 1433,
334 database: None,
335 credentials: Credentials::sql_server("", ""),
336 tls: TlsConfig::default(),
337 application_name: "mssql-client".to_string(),
338 connect_timeout: timeouts.connect_timeout,
339 command_timeout: timeouts.command_timeout,
340 packet_size: 4096,
341 strict_mode: false,
342 trust_server_certificate: false,
343 instance: None,
344 mars: false,
345 encrypt: true, redirect: RedirectConfig::default(),
347 retry: RetryPolicy::default(),
348 timeouts,
349 }
350 }
351}
352
353impl Config {
354 #[must_use]
356 pub fn new() -> Self {
357 Self::default()
358 }
359
360 pub fn from_connection_string(conn_str: &str) -> Result<Self, crate::error::Error> {
367 let mut config = Self::default();
368
369 for part in conn_str.split(';') {
370 let part = part.trim();
371 if part.is_empty() {
372 continue;
373 }
374
375 let (key, value) = part
376 .split_once('=')
377 .ok_or_else(|| crate::error::Error::Config(format!("invalid key-value: {part}")))?;
378
379 let key = key.trim().to_lowercase();
380 let value = value.trim();
381
382 match key.as_str() {
383 "server" | "data source" | "host" => {
384 if let Some((host, port_or_instance)) = value.split_once(',') {
386 config.host = host.to_string();
387 config.port = port_or_instance.parse().map_err(|_| {
388 crate::error::Error::Config(format!("invalid port: {port_or_instance}"))
389 })?;
390 } else if let Some((host, instance)) = value.split_once('\\') {
391 config.host = host.to_string();
392 config.instance = Some(instance.to_string());
393 } else {
394 config.host = value.to_string();
395 }
396 }
397 "port" => {
398 config.port = value.parse().map_err(|_| {
399 crate::error::Error::Config(format!("invalid port: {value}"))
400 })?;
401 }
402 "database" | "initial catalog" => {
403 config.database = Some(value.to_string());
404 }
405 "user id" | "uid" | "user" => {
406 if let Credentials::SqlServer { password, .. } = &config.credentials {
408 config.credentials =
409 Credentials::sql_server(value.to_string(), password.clone());
410 }
411 }
412 "password" | "pwd" => {
413 if let Credentials::SqlServer { username, .. } = &config.credentials {
415 config.credentials =
416 Credentials::sql_server(username.clone(), value.to_string());
417 }
418 }
419 "application name" | "app" => {
420 config.application_name = value.to_string();
421 }
422 "connect timeout" | "connection timeout" => {
423 let secs: u64 = value.parse().map_err(|_| {
424 crate::error::Error::Config(format!("invalid timeout: {value}"))
425 })?;
426 config.connect_timeout = Duration::from_secs(secs);
427 }
428 "command timeout" => {
429 let secs: u64 = value.parse().map_err(|_| {
430 crate::error::Error::Config(format!("invalid timeout: {value}"))
431 })?;
432 config.command_timeout = Duration::from_secs(secs);
433 }
434 "trustservercertificate" | "trust server certificate" => {
435 config.trust_server_certificate = value.eq_ignore_ascii_case("true")
436 || value.eq_ignore_ascii_case("yes")
437 || value == "1";
438 }
439 "encrypt" => {
440 if value.eq_ignore_ascii_case("strict") {
442 config.strict_mode = true;
443 config.encrypt = true;
444 } else if value.eq_ignore_ascii_case("true")
445 || value.eq_ignore_ascii_case("yes")
446 || value == "1"
447 {
448 config.encrypt = true;
449 } else if value.eq_ignore_ascii_case("false")
450 || value.eq_ignore_ascii_case("no")
451 || value == "0"
452 {
453 config.encrypt = false;
454 }
455 }
456 "multipleactiveresultsets" | "mars" => {
457 config.mars = value.eq_ignore_ascii_case("true")
458 || value.eq_ignore_ascii_case("yes")
459 || value == "1";
460 }
461 "packet size" => {
462 config.packet_size = value.parse().map_err(|_| {
463 crate::error::Error::Config(format!("invalid packet size: {value}"))
464 })?;
465 }
466 _ => {
467 tracing::debug!(
469 key = key,
470 value = value,
471 "ignoring unknown connection string option"
472 );
473 }
474 }
475 }
476
477 Ok(config)
478 }
479
480 #[must_use]
482 pub fn host(mut self, host: impl Into<String>) -> Self {
483 self.host = host.into();
484 self
485 }
486
487 #[must_use]
489 pub fn port(mut self, port: u16) -> Self {
490 self.port = port;
491 self
492 }
493
494 #[must_use]
496 pub fn database(mut self, database: impl Into<String>) -> Self {
497 self.database = Some(database.into());
498 self
499 }
500
501 #[must_use]
503 pub fn credentials(mut self, credentials: Credentials) -> Self {
504 self.credentials = credentials;
505 self
506 }
507
508 #[must_use]
510 pub fn application_name(mut self, name: impl Into<String>) -> Self {
511 self.application_name = name.into();
512 self
513 }
514
515 #[must_use]
517 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
518 self.connect_timeout = timeout;
519 self
520 }
521
522 #[must_use]
524 pub fn trust_server_certificate(mut self, trust: bool) -> Self {
525 self.trust_server_certificate = trust;
526 self.tls = self.tls.trust_server_certificate(trust);
527 self
528 }
529
530 #[must_use]
532 pub fn strict_mode(mut self, enabled: bool) -> Self {
533 self.strict_mode = enabled;
534 self.tls = self.tls.strict_mode(enabled);
535 self
536 }
537
538 #[must_use]
546 pub fn encrypt(mut self, enabled: bool) -> Self {
547 self.encrypt = enabled;
548 self
549 }
550
551 #[must_use]
553 pub fn with_host(mut self, host: &str) -> Self {
554 self.host = host.to_string();
555 self
556 }
557
558 #[must_use]
560 pub fn with_port(mut self, port: u16) -> Self {
561 self.port = port;
562 self
563 }
564
565 #[must_use]
567 pub fn redirect(mut self, redirect: RedirectConfig) -> Self {
568 self.redirect = redirect;
569 self
570 }
571
572 #[must_use]
574 pub fn max_redirects(mut self, max: u8) -> Self {
575 self.redirect.max_redirects = max;
576 self
577 }
578
579 #[must_use]
581 pub fn retry(mut self, retry: RetryPolicy) -> Self {
582 self.retry = retry;
583 self
584 }
585
586 #[must_use]
588 pub fn max_retries(mut self, max: u32) -> Self {
589 self.retry.max_retries = max;
590 self
591 }
592
593 #[must_use]
595 pub fn timeouts(mut self, timeouts: TimeoutConfig) -> Self {
596 self.connect_timeout = timeouts.connect_timeout;
598 self.command_timeout = timeouts.command_timeout;
599 self.timeouts = timeouts;
600 self
601 }
602}
603
#[cfg(test)]
// `unwrap` is acceptable in tests: a panic is how a failing test reports.
#[allow(clippy::unwrap_used)]
mod tests {
    //! Unit tests for connection-string parsing and for the redirect, retry,
    //! and timeout configuration builders.
    use super::*;

    // Basic keys, including a trailing `;` that produces an empty part.
    #[test]
    fn test_connection_string_parsing() {
        let config = Config::from_connection_string(
            "Server=localhost;Database=test;User Id=sa;Password=secret;",
        )
        .unwrap();

        assert_eq!(config.host, "localhost");
        assert_eq!(config.database, Some("test".to_string()));
    }

    // `host,port` form of the server value.
    #[test]
    fn test_connection_string_with_port() {
        let config =
            Config::from_connection_string("Server=localhost,1434;Database=test;").unwrap();

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 1434);
    }

    // `host\instance` form of the server value.
    #[test]
    fn test_connection_string_with_instance() {
        let config =
            Config::from_connection_string("Server=localhost\\SQLEXPRESS;Database=test;").unwrap();

        assert_eq!(config.host, "localhost");
        assert_eq!(config.instance, Some("SQLEXPRESS".to_string()));
    }

    #[test]
    fn test_redirect_config_defaults() {
        let config = RedirectConfig::default();
        assert_eq!(config.max_redirects, 2);
        assert!(config.follow_redirects);
    }

    #[test]
    fn test_redirect_config_builder() {
        let config = RedirectConfig::new()
            .max_redirects(5)
            .follow_redirects(false);
        assert_eq!(config.max_redirects, 5);
        assert!(!config.follow_redirects);
    }

    #[test]
    fn test_redirect_config_no_follow() {
        let config = RedirectConfig::no_follow();
        assert_eq!(config.max_redirects, 0);
        assert!(!config.follow_redirects);
    }

    // Config-level shortcuts for redirect settings.
    #[test]
    fn test_config_redirect_builder() {
        let config = Config::new().max_redirects(3);
        assert_eq!(config.redirect.max_redirects, 3);

        let config2 = Config::new().redirect(RedirectConfig::no_follow());
        assert!(!config2.redirect.follow_redirects);
    }

    #[test]
    fn test_retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_retries, 3);
        assert_eq!(policy.initial_backoff, Duration::from_millis(100));
        assert_eq!(policy.max_backoff, Duration::from_secs(30));
        // Float comparison with a tolerance rather than `==`.
        assert!((policy.backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert!(policy.jitter);
    }

    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::new()
            .max_retries(5)
            .initial_backoff(Duration::from_millis(200))
            .max_backoff(Duration::from_secs(60))
            .backoff_multiplier(3.0)
            .jitter(false);

        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(200));
        assert_eq!(policy.max_backoff, Duration::from_secs(60));
        assert!((policy.backoff_multiplier - 3.0).abs() < f64::EPSILON);
        assert!(!policy.jitter);
    }

    #[test]
    fn test_retry_policy_no_retry() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_retries, 0);
        assert!(!policy.should_retry(0));
    }

    // `should_retry` is true strictly below `max_retries`.
    #[test]
    fn test_retry_policy_should_retry() {
        let policy = RetryPolicy::new().max_retries(3);
        assert!(policy.should_retry(0));
        assert!(policy.should_retry(1));
        assert!(policy.should_retry(2));
        assert!(!policy.should_retry(3));
        assert!(!policy.should_retry(4));
    }

    // Geometric growth: 0, 100, 200, 400 ms with jitter off.
    #[test]
    fn test_retry_policy_backoff_calculation() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_millis(100))
            .backoff_multiplier(2.0)
            .max_backoff(Duration::from_secs(10))
            .jitter(false);

        assert_eq!(policy.backoff_for_attempt(0), Duration::ZERO);
        assert_eq!(policy.backoff_for_attempt(1), Duration::from_millis(100));
        assert_eq!(policy.backoff_for_attempt(2), Duration::from_millis(200));
        assert_eq!(policy.backoff_for_attempt(3), Duration::from_millis(400));
    }

    // Attempt 3 would be 1 s * 10^2 = 100 s, which exceeds the 5 s cap.
    #[test]
    fn test_retry_policy_backoff_capped() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(1))
            .backoff_multiplier(10.0)
            .max_backoff(Duration::from_secs(5))
            .jitter(false);

        assert_eq!(policy.backoff_for_attempt(3), Duration::from_secs(5));
    }

    // Config-level shortcuts for retry settings.
    #[test]
    fn test_config_retry_builder() {
        let config = Config::new().max_retries(5);
        assert_eq!(config.retry.max_retries, 5);

        let config2 = Config::new().retry(RetryPolicy::no_retry());
        assert_eq!(config2.retry.max_retries, 0);
    }

    #[test]
    fn test_timeout_config_defaults() {
        let config = TimeoutConfig::default();
        assert_eq!(config.connect_timeout, Duration::from_secs(15));
        assert_eq!(config.tls_timeout, Duration::from_secs(10));
        assert_eq!(config.login_timeout, Duration::from_secs(30));
        assert_eq!(config.command_timeout, Duration::from_secs(30));
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.keepalive_interval, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_timeout_config_builder() {
        let config = TimeoutConfig::new()
            .connect_timeout(Duration::from_secs(5))
            .tls_timeout(Duration::from_secs(3))
            .login_timeout(Duration::from_secs(10))
            .command_timeout(Duration::from_secs(60))
            .idle_timeout(Duration::from_secs(600))
            .keepalive_interval(Some(Duration::from_secs(60)));

        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.tls_timeout, Duration::from_secs(3));
        assert_eq!(config.login_timeout, Duration::from_secs(10));
        assert_eq!(config.command_timeout, Duration::from_secs(60));
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.keepalive_interval, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_timeout_config_no_keepalive() {
        let config = TimeoutConfig::new().no_keepalive();
        assert_eq!(config.keepalive_interval, None);
    }

    // 5 s connect + 3 s TLS + 10 s login = 18 s.
    #[test]
    fn test_timeout_config_total_connect() {
        let config = TimeoutConfig::new()
            .connect_timeout(Duration::from_secs(5))
            .tls_timeout(Duration::from_secs(3))
            .login_timeout(Duration::from_secs(10));

        assert_eq!(config.total_connect_timeout(), Duration::from_secs(18));
    }

    // `Config::timeouts` mirrors connect/command timeouts into the top-level fields.
    #[test]
    fn test_config_timeouts_builder() {
        let timeouts = TimeoutConfig::new()
            .connect_timeout(Duration::from_secs(5))
            .command_timeout(Duration::from_secs(60));

        let config = Config::new().timeouts(timeouts);
        assert_eq!(config.timeouts.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.timeouts.command_timeout, Duration::from_secs(60));
        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.command_timeout, Duration::from_secs(60));
    }
}