1use std::time::Duration;
4
5use mssql_auth::Credentials;
6use mssql_tls::TlsConfig;
7use tds_protocol::version::TdsVersion;
8
/// Settings describing how server redirects should be handled.
///
/// This type only stores the two settings; the code that acts on them is
/// not part of this file.
/// NOTE(review): presumably these govern TDS routing redirects received
/// during login (e.g. from an Azure SQL gateway) — confirm against the
/// connection code.
#[derive(Debug, Clone)]
pub struct RedirectConfig {
    /// Upper bound on the number of redirects (default `2`).
    pub max_redirects: u8,
    /// Whether redirects are followed at all (default `true`).
    pub follow_redirects: bool,
}

impl Default for RedirectConfig {
    /// Follows redirects, with a limit of two.
    fn default() -> Self {
        RedirectConfig {
            max_redirects: 2,
            follow_redirects: true,
        }
    }
}

impl RedirectConfig {
    /// Creates a configuration holding the default values.
    #[must_use]
    pub fn new() -> Self {
        <Self as Default>::default()
    }

    /// Returns the configuration with `max_redirects` replaced by `max`.
    #[must_use]
    pub fn max_redirects(self, max: u8) -> Self {
        Self {
            max_redirects: max,
            ..self
        }
    }

    /// Returns the configuration with `follow_redirects` replaced by `follow`.
    #[must_use]
    pub fn follow_redirects(self, follow: bool) -> Self {
        Self {
            follow_redirects: follow,
            ..self
        }
    }

    /// Creates a configuration that disables redirects: `follow_redirects`
    /// is `false` and `max_redirects` is `0`.
    #[must_use]
    pub fn no_follow() -> Self {
        Self::new().max_redirects(0).follow_redirects(false)
    }
}
63
/// Timeout settings for the different phases of a connection's life.
///
/// This type only stores the durations; the code that enforces them is not
/// part of this file, so the per-field descriptions below state the default
/// and hedge on the exact phase each value covers.
#[derive(Debug, Clone)]
pub struct TimeoutConfig {
    /// Budget for the connect phase (default 15 s).
    /// NOTE(review): presumably TCP connection establishment — confirm.
    pub connect_timeout: Duration,
    /// Budget for the TLS phase (default 10 s).
    /// NOTE(review): presumably the TLS handshake — confirm.
    pub tls_timeout: Duration,
    /// Budget for the login phase (default 30 s).
    pub login_timeout: Duration,
    /// Budget for a command (default 30 s).
    pub command_timeout: Duration,
    /// Idle timeout (default 300 s).
    /// NOTE(review): presumably how long an unused connection is kept — confirm.
    pub idle_timeout: Duration,
    /// Keepalive interval, or `None` for no keepalive (default `Some(30 s)`).
    pub keepalive_interval: Option<Duration>,
}

impl Default for TimeoutConfig {
    fn default() -> Self {
        Self {
            connect_timeout: Duration::from_secs(15),
            tls_timeout: Duration::from_secs(10),
            login_timeout: Duration::from_secs(30),
            command_timeout: Duration::from_secs(30),
            idle_timeout: Duration::from_secs(300),
            keepalive_interval: Some(Duration::from_secs(30)),
        }
    }
}

impl TimeoutConfig {
    /// Creates a configuration holding the default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `connect_timeout`.
    #[must_use]
    pub fn connect_timeout(mut self, timeout: Duration) -> Self {
        self.connect_timeout = timeout;
        self
    }

    /// Sets `tls_timeout`.
    #[must_use]
    pub fn tls_timeout(mut self, timeout: Duration) -> Self {
        self.tls_timeout = timeout;
        self
    }

    /// Sets `login_timeout`.
    #[must_use]
    pub fn login_timeout(mut self, timeout: Duration) -> Self {
        self.login_timeout = timeout;
        self
    }

    /// Sets `command_timeout`.
    #[must_use]
    pub fn command_timeout(mut self, timeout: Duration) -> Self {
        self.command_timeout = timeout;
        self
    }

    /// Sets `idle_timeout`.
    #[must_use]
    pub fn idle_timeout(mut self, timeout: Duration) -> Self {
        self.idle_timeout = timeout;
        self
    }

    /// Sets `keepalive_interval`; pass `None` to turn keepalive off.
    #[must_use]
    pub fn keepalive_interval(mut self, interval: Option<Duration>) -> Self {
        self.keepalive_interval = interval;
        self
    }

    /// Clears `keepalive_interval` (same as `keepalive_interval(None)`).
    #[must_use]
    pub fn no_keepalive(mut self) -> Self {
        self.keepalive_interval = None;
        self
    }

    /// Returns the sum of `connect_timeout`, `tls_timeout` and
    /// `login_timeout`.
    ///
    /// The sum saturates at [`Duration::MAX`]: `Duration`'s `+` operator
    /// panics on overflow, which would otherwise be triggered by a very
    /// large "effectively no timeout" value in any of the three fields.
    #[must_use]
    pub fn total_connect_timeout(&self) -> Duration {
        self.connect_timeout
            .saturating_add(self.tls_timeout)
            .saturating_add(self.login_timeout)
    }
}
159
/// Retry settings with exponential backoff.
///
/// The policy only answers two questions — "may attempt N be retried?"
/// ([`RetryPolicy::should_retry`]) and "how long to wait before attempt N?"
/// ([`RetryPolicy::backoff_for_attempt`]). Performing the retries is up to
/// the caller.
#[derive(Debug, Clone)]
pub struct RetryPolicy {
    /// Number of retries allowed (default `3`); `0` disables retrying.
    pub max_retries: u32,
    /// Delay used for the first retry (default 100 ms).
    pub initial_backoff: Duration,
    /// Upper bound on any computed delay (default 30 s).
    pub max_backoff: Duration,
    /// Factor by which the delay grows per attempt (default `2.0`).
    pub backoff_multiplier: f64,
    /// Jitter flag (default `true`).
    ///
    /// `backoff_for_attempt` does not currently randomise its result, so
    /// this flag has no effect on the value it returns.
    pub jitter: bool,
}

impl Default for RetryPolicy {
    fn default() -> Self {
        Self {
            max_retries: 3,
            initial_backoff: Duration::from_millis(100),
            max_backoff: Duration::from_secs(30),
            backoff_multiplier: 2.0,
            jitter: true,
        }
    }
}

impl RetryPolicy {
    /// Creates a policy holding the default values.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `max_retries`.
    #[must_use]
    pub fn max_retries(mut self, max: u32) -> Self {
        self.max_retries = max;
        self
    }

    /// Sets `initial_backoff`.
    #[must_use]
    pub fn initial_backoff(mut self, backoff: Duration) -> Self {
        self.initial_backoff = backoff;
        self
    }

    /// Sets `max_backoff`.
    #[must_use]
    pub fn max_backoff(mut self, backoff: Duration) -> Self {
        self.max_backoff = backoff;
        self
    }

    /// Sets `backoff_multiplier`.
    #[must_use]
    pub fn backoff_multiplier(mut self, multiplier: f64) -> Self {
        self.backoff_multiplier = multiplier;
        self
    }

    /// Sets the `jitter` flag.
    #[must_use]
    pub fn jitter(mut self, enabled: bool) -> Self {
        self.jitter = enabled;
        self
    }

    /// Creates a policy with `max_retries` set to `0` and every other field
    /// at its default.
    #[must_use]
    pub fn no_retry() -> Self {
        Self {
            max_retries: 0,
            ..Self::default()
        }
    }

    /// Returns the delay to wait before retry number `attempt`.
    ///
    /// Attempt `0` yields [`Duration::ZERO`]. For `attempt >= 1` the delay
    /// is `initial_backoff * backoff_multiplier^(attempt - 1)`, computed in
    /// milliseconds, capped at `max_backoff` and truncated to whole
    /// milliseconds. The `jitter` flag is not applied here: the result is
    /// deterministic regardless of its value.
    #[must_use]
    pub fn backoff_for_attempt(&self, attempt: u32) -> Duration {
        if attempt == 0 {
            return Duration::ZERO;
        }

        // `powi` takes an `i32`. A bare `as i32` cast wraps exponents above
        // `i32::MAX` to negative numbers, which would shrink the delay
        // towards zero for very large attempt counts; clamp instead.
        let exponent = (attempt - 1).min(i32::MAX as u32) as i32;

        let initial_ms = self.initial_backoff.as_millis() as f64;
        let ceiling_ms = self.max_backoff.as_millis() as f64;

        // `f64::min` ignores a NaN operand, so an overflowed or undefined
        // product falls back to the ceiling.
        let capped_ms = (initial_ms * self.backoff_multiplier.powi(exponent)).min(ceiling_ms);

        // Float-to-integer `as` saturates, so a negative product (negative
        // multiplier) becomes a zero delay rather than wrapping.
        Duration::from_millis(capped_ms as u64)
    }

    /// Returns `true` while `attempt` is below `max_retries`.
    #[must_use]
    pub fn should_retry(&self, attempt: u32) -> bool {
        attempt < self.max_retries
    }
}
271
272#[derive(Debug, Clone)]
278#[non_exhaustive]
279pub struct Config {
280 pub host: String,
282
283 pub port: u16,
285
286 pub database: Option<String>,
288
289 pub credentials: Credentials,
291
292 pub tls: TlsConfig,
294
295 pub application_name: String,
297
298 pub connect_timeout: Duration,
300
301 pub command_timeout: Duration,
303
304 pub packet_size: u16,
306
307 pub strict_mode: bool,
309
310 pub trust_server_certificate: bool,
312
313 pub instance: Option<String>,
315
316 pub mars: bool,
318
319 pub encrypt: bool,
323
324 pub redirect: RedirectConfig,
326
327 pub retry: RetryPolicy,
329
330 pub timeouts: TimeoutConfig,
332
333 pub tds_version: TdsVersion,
346}
347
348impl Default for Config {
349 fn default() -> Self {
350 let timeouts = TimeoutConfig::default();
351 Self {
352 host: "localhost".to_string(),
353 port: 1433,
354 database: None,
355 credentials: Credentials::sql_server("", ""),
356 tls: TlsConfig::default(),
357 application_name: "mssql-client".to_string(),
358 connect_timeout: timeouts.connect_timeout,
359 command_timeout: timeouts.command_timeout,
360 packet_size: 4096,
361 strict_mode: false,
362 trust_server_certificate: false,
363 instance: None,
364 mars: false,
365 encrypt: true, redirect: RedirectConfig::default(),
367 retry: RetryPolicy::default(),
368 timeouts,
369 tds_version: TdsVersion::V7_4, }
371 }
372}
373
374impl Config {
375 #[must_use]
377 pub fn new() -> Self {
378 Self::default()
379 }
380
381 pub fn from_connection_string(conn_str: &str) -> Result<Self, crate::error::Error> {
388 let mut config = Self::default();
389
390 for part in conn_str.split(';') {
391 let part = part.trim();
392 if part.is_empty() {
393 continue;
394 }
395
396 let (key, value) = part
397 .split_once('=')
398 .ok_or_else(|| crate::error::Error::Config(format!("invalid key-value: {part}")))?;
399
400 let key = key.trim().to_lowercase();
401 let value = value.trim();
402
403 match key.as_str() {
404 "server" | "data source" | "host" => {
405 if let Some((host, port_or_instance)) = value.split_once(',') {
407 config.host = host.to_string();
408 config.port = port_or_instance.parse().map_err(|_| {
409 crate::error::Error::Config(format!("invalid port: {port_or_instance}"))
410 })?;
411 } else if let Some((host, instance)) = value.split_once('\\') {
412 config.host = host.to_string();
413 config.instance = Some(instance.to_string());
414 } else {
415 config.host = value.to_string();
416 }
417 }
418 "port" => {
419 config.port = value.parse().map_err(|_| {
420 crate::error::Error::Config(format!("invalid port: {value}"))
421 })?;
422 }
423 "database" | "initial catalog" => {
424 config.database = Some(value.to_string());
425 }
426 "user id" | "uid" | "user" => {
427 if let Credentials::SqlServer { password, .. } = &config.credentials {
429 config.credentials =
430 Credentials::sql_server(value.to_string(), password.clone());
431 }
432 }
433 "password" | "pwd" => {
434 if let Credentials::SqlServer { username, .. } = &config.credentials {
436 config.credentials =
437 Credentials::sql_server(username.clone(), value.to_string());
438 }
439 }
440 "application name" | "app" => {
441 config.application_name = value.to_string();
442 }
443 "connect timeout" | "connection timeout" => {
444 let secs: u64 = value.parse().map_err(|_| {
445 crate::error::Error::Config(format!("invalid timeout: {value}"))
446 })?;
447 config.connect_timeout = Duration::from_secs(secs);
448 }
449 "command timeout" => {
450 let secs: u64 = value.parse().map_err(|_| {
451 crate::error::Error::Config(format!("invalid timeout: {value}"))
452 })?;
453 config.command_timeout = Duration::from_secs(secs);
454 }
455 "trustservercertificate" | "trust server certificate" => {
456 config.trust_server_certificate = value.eq_ignore_ascii_case("true")
457 || value.eq_ignore_ascii_case("yes")
458 || value == "1";
459 }
460 "encrypt" => {
461 if value.eq_ignore_ascii_case("strict") {
463 config.strict_mode = true;
464 config.encrypt = true;
465 } else if value.eq_ignore_ascii_case("true")
466 || value.eq_ignore_ascii_case("yes")
467 || value == "1"
468 {
469 config.encrypt = true;
470 } else if value.eq_ignore_ascii_case("false")
471 || value.eq_ignore_ascii_case("no")
472 || value == "0"
473 {
474 config.encrypt = false;
475 }
476 }
477 "multipleactiveresultsets" | "mars" => {
478 config.mars = value.eq_ignore_ascii_case("true")
479 || value.eq_ignore_ascii_case("yes")
480 || value == "1";
481 }
482 "packet size" => {
483 config.packet_size = value.parse().map_err(|_| {
484 crate::error::Error::Config(format!("invalid packet size: {value}"))
485 })?;
486 }
487 "tdsversion" | "tds version" | "protocolversion" | "protocol version" => {
488 config.tds_version = TdsVersion::parse(value).ok_or_else(|| {
491 crate::error::Error::Config(format!(
492 "invalid TDS version: {value}. Supported values: 7.3, 7.3A, 7.3B, 7.4, 8.0"
493 ))
494 })?;
495 if config.tds_version.is_tds_8() {
497 config.strict_mode = true;
498 }
499 }
500 _ => {
501 tracing::debug!(
503 key = key,
504 value = value,
505 "ignoring unknown connection string option"
506 );
507 }
508 }
509 }
510
511 Ok(config)
512 }
513
514 #[must_use]
516 pub fn host(mut self, host: impl Into<String>) -> Self {
517 self.host = host.into();
518 self
519 }
520
521 #[must_use]
523 pub fn port(mut self, port: u16) -> Self {
524 self.port = port;
525 self
526 }
527
528 #[must_use]
530 pub fn database(mut self, database: impl Into<String>) -> Self {
531 self.database = Some(database.into());
532 self
533 }
534
535 #[must_use]
537 pub fn credentials(mut self, credentials: Credentials) -> Self {
538 self.credentials = credentials;
539 self
540 }
541
542 #[must_use]
544 pub fn application_name(mut self, name: impl Into<String>) -> Self {
545 self.application_name = name.into();
546 self
547 }
548
549 #[must_use]
551 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
552 self.connect_timeout = timeout;
553 self
554 }
555
556 #[must_use]
558 pub fn trust_server_certificate(mut self, trust: bool) -> Self {
559 self.trust_server_certificate = trust;
560 self.tls = self.tls.trust_server_certificate(trust);
561 self
562 }
563
564 #[must_use]
566 pub fn strict_mode(mut self, enabled: bool) -> Self {
567 self.strict_mode = enabled;
568 self.tls = self.tls.strict_mode(enabled);
569 if enabled {
570 self.tds_version = TdsVersion::V8_0;
571 }
572 self
573 }
574
575 #[must_use]
599 pub fn tds_version(mut self, version: TdsVersion) -> Self {
600 self.tds_version = version;
601 if version.is_tds_8() {
603 self.strict_mode = true;
604 self.tls = self.tls.strict_mode(true);
605 }
606 self
607 }
608
609 #[must_use]
617 pub fn encrypt(mut self, enabled: bool) -> Self {
618 self.encrypt = enabled;
619 self
620 }
621
622 #[must_use]
624 pub fn with_host(mut self, host: &str) -> Self {
625 self.host = host.to_string();
626 self
627 }
628
629 #[must_use]
631 pub fn with_port(mut self, port: u16) -> Self {
632 self.port = port;
633 self
634 }
635
636 #[must_use]
638 pub fn redirect(mut self, redirect: RedirectConfig) -> Self {
639 self.redirect = redirect;
640 self
641 }
642
643 #[must_use]
645 pub fn max_redirects(mut self, max: u8) -> Self {
646 self.redirect.max_redirects = max;
647 self
648 }
649
650 #[must_use]
652 pub fn retry(mut self, retry: RetryPolicy) -> Self {
653 self.retry = retry;
654 self
655 }
656
657 #[must_use]
659 pub fn max_retries(mut self, max: u32) -> Self {
660 self.retry.max_retries = max;
661 self
662 }
663
664 #[must_use]
666 pub fn timeouts(mut self, timeouts: TimeoutConfig) -> Self {
667 self.connect_timeout = timeouts.connect_timeout;
669 self.command_timeout = timeouts.command_timeout;
670 self.timeouts = timeouts;
671 self
672 }
673}
674
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Shorthand for a whole number of seconds.
    fn secs(n: u64) -> Duration {
        Duration::from_secs(n)
    }

    /// Shorthand for a whole number of milliseconds.
    fn millis(n: u64) -> Duration {
        Duration::from_millis(n)
    }

    // ---- connection strings -------------------------------------------

    #[test]
    fn test_connection_string_parsing() {
        let conn_str = "Server=localhost;Database=test;User Id=sa;Password=secret;";
        let parsed = Config::from_connection_string(conn_str).unwrap();

        assert_eq!(parsed.host, "localhost");
        assert_eq!(parsed.database.as_deref(), Some("test"));
    }

    #[test]
    fn test_connection_string_with_port() {
        let parsed = Config::from_connection_string("Server=localhost,1434;Database=test;").unwrap();

        assert_eq!(parsed.host, "localhost");
        assert_eq!(parsed.port, 1434);
    }

    #[test]
    fn test_connection_string_with_instance() {
        let parsed =
            Config::from_connection_string("Server=localhost\\SQLEXPRESS;Database=test;").unwrap();

        assert_eq!(parsed.host, "localhost");
        assert_eq!(parsed.instance.as_deref(), Some("SQLEXPRESS"));
    }

    // ---- RedirectConfig -----------------------------------------------

    #[test]
    fn test_redirect_config_defaults() {
        let redirect = RedirectConfig::default();
        assert_eq!(redirect.max_redirects, 2);
        assert!(redirect.follow_redirects);
    }

    #[test]
    fn test_redirect_config_builder() {
        let redirect = RedirectConfig::new().max_redirects(5).follow_redirects(false);
        assert_eq!(redirect.max_redirects, 5);
        assert!(!redirect.follow_redirects);
    }

    #[test]
    fn test_redirect_config_no_follow() {
        let redirect = RedirectConfig::no_follow();
        assert_eq!(redirect.max_redirects, 0);
        assert!(!redirect.follow_redirects);
    }

    #[test]
    fn test_config_redirect_builder() {
        let limited = Config::new().max_redirects(3);
        assert_eq!(limited.redirect.max_redirects, 3);

        let disabled = Config::new().redirect(RedirectConfig::no_follow());
        assert!(!disabled.redirect.follow_redirects);
    }

    // ---- RetryPolicy --------------------------------------------------

    #[test]
    fn test_retry_policy_defaults() {
        let retry = RetryPolicy::default();
        assert_eq!(retry.max_retries, 3);
        assert_eq!(retry.initial_backoff, millis(100));
        assert_eq!(retry.max_backoff, secs(30));
        assert!((retry.backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert!(retry.jitter);
    }

    #[test]
    fn test_retry_policy_builder() {
        let retry = RetryPolicy::new()
            .max_retries(5)
            .initial_backoff(millis(200))
            .max_backoff(secs(60))
            .backoff_multiplier(3.0)
            .jitter(false);

        assert_eq!(retry.max_retries, 5);
        assert_eq!(retry.initial_backoff, millis(200));
        assert_eq!(retry.max_backoff, secs(60));
        assert!((retry.backoff_multiplier - 3.0).abs() < f64::EPSILON);
        assert!(!retry.jitter);
    }

    #[test]
    fn test_retry_policy_no_retry() {
        let retry = RetryPolicy::no_retry();
        assert_eq!(retry.max_retries, 0);
        assert!(!retry.should_retry(0));
    }

    #[test]
    fn test_retry_policy_should_retry() {
        let retry = RetryPolicy::new().max_retries(3);

        for attempt in 0..3 {
            assert!(retry.should_retry(attempt));
        }
        for attempt in 3..5 {
            assert!(!retry.should_retry(attempt));
        }
    }

    #[test]
    fn test_retry_policy_backoff_calculation() {
        let retry = RetryPolicy::new()
            .initial_backoff(millis(100))
            .backoff_multiplier(2.0)
            .max_backoff(secs(10))
            .jitter(false);

        assert_eq!(retry.backoff_for_attempt(0), Duration::ZERO);
        for &(attempt, expected_ms) in &[(1u32, 100u64), (2, 200), (3, 400)] {
            assert_eq!(retry.backoff_for_attempt(attempt), millis(expected_ms));
        }
    }

    #[test]
    fn test_retry_policy_backoff_capped() {
        let retry = RetryPolicy::new()
            .initial_backoff(secs(1))
            .backoff_multiplier(10.0)
            .max_backoff(secs(5))
            .jitter(false);

        // 1 s * 10^2 would be 100 s, so the 5 s ceiling applies.
        assert_eq!(retry.backoff_for_attempt(3), secs(5));
    }

    #[test]
    fn test_config_retry_builder() {
        let five = Config::new().max_retries(5);
        assert_eq!(five.retry.max_retries, 5);

        let none = Config::new().retry(RetryPolicy::no_retry());
        assert_eq!(none.retry.max_retries, 0);
    }

    // ---- TimeoutConfig ------------------------------------------------

    #[test]
    fn test_timeout_config_defaults() {
        let timeouts = TimeoutConfig::default();
        assert_eq!(timeouts.connect_timeout, secs(15));
        assert_eq!(timeouts.tls_timeout, secs(10));
        assert_eq!(timeouts.login_timeout, secs(30));
        assert_eq!(timeouts.command_timeout, secs(30));
        assert_eq!(timeouts.idle_timeout, secs(300));
        assert_eq!(timeouts.keepalive_interval, Some(secs(30)));
    }

    #[test]
    fn test_timeout_config_builder() {
        let timeouts = TimeoutConfig::new()
            .connect_timeout(secs(5))
            .tls_timeout(secs(3))
            .login_timeout(secs(10))
            .command_timeout(secs(60))
            .idle_timeout(secs(600))
            .keepalive_interval(Some(secs(60)));

        assert_eq!(timeouts.connect_timeout, secs(5));
        assert_eq!(timeouts.tls_timeout, secs(3));
        assert_eq!(timeouts.login_timeout, secs(10));
        assert_eq!(timeouts.command_timeout, secs(60));
        assert_eq!(timeouts.idle_timeout, secs(600));
        assert_eq!(timeouts.keepalive_interval, Some(secs(60)));
    }

    #[test]
    fn test_timeout_config_no_keepalive() {
        let timeouts = TimeoutConfig::new().no_keepalive();
        assert_eq!(timeouts.keepalive_interval, None);
    }

    #[test]
    fn test_timeout_config_total_connect() {
        let timeouts = TimeoutConfig::new()
            .connect_timeout(secs(5))
            .tls_timeout(secs(3))
            .login_timeout(secs(10));

        // 5 + 3 + 10
        assert_eq!(timeouts.total_connect_timeout(), secs(18));
    }

    #[test]
    fn test_config_timeouts_builder() {
        let custom = TimeoutConfig::new()
            .connect_timeout(secs(5))
            .command_timeout(secs(60));

        let cfg = Config::new().timeouts(custom);
        assert_eq!(cfg.timeouts.connect_timeout, secs(5));
        assert_eq!(cfg.timeouts.command_timeout, secs(60));
        // The flat fields follow the nested ones.
        assert_eq!(cfg.connect_timeout, secs(5));
        assert_eq!(cfg.command_timeout, secs(60));
    }

    // ---- TDS version --------------------------------------------------

    #[test]
    fn test_tds_version_default() {
        let cfg = Config::default();
        assert_eq!(cfg.tds_version, TdsVersion::V7_4);
        assert!(!cfg.strict_mode);
    }

    #[test]
    fn test_tds_version_builder() {
        let v73a = Config::new().tds_version(TdsVersion::V7_3A);
        assert_eq!(v73a.tds_version, TdsVersion::V7_3A);
        assert!(!v73a.strict_mode);

        let v73b = Config::new().tds_version(TdsVersion::V7_3B);
        assert_eq!(v73b.tds_version, TdsVersion::V7_3B);
        assert!(!v73b.strict_mode);

        // Choosing TDS 8.0 switches strict mode on as well.
        let v80 = Config::new().tds_version(TdsVersion::V8_0);
        assert_eq!(v80.tds_version, TdsVersion::V8_0);
        assert!(v80.strict_mode);
    }

    #[test]
    fn test_strict_mode_sets_tds_8() {
        let strict = Config::new().strict_mode(true);
        assert!(strict.strict_mode);
        assert_eq!(strict.tds_version, TdsVersion::V8_0);
    }

    #[test]
    fn test_connection_string_tds_version() {
        let plain_cases = [
            ("7.3", TdsVersion::V7_3A),
            ("7.3A", TdsVersion::V7_3A),
            ("7.3B", TdsVersion::V7_3B),
            ("7.4", TdsVersion::V7_4),
        ];
        for &(text, expected) in &plain_cases {
            let conn_str = format!("Server=localhost;TDSVersion={text};");
            let parsed = Config::from_connection_string(&conn_str).unwrap();
            assert_eq!(parsed.tds_version, expected);
        }

        let v80 = Config::from_connection_string("Server=localhost;TDSVersion=8.0;").unwrap();
        assert_eq!(v80.tds_version, TdsVersion::V8_0);
        assert!(v80.strict_mode);

        // Alternative key spelling.
        let aliased =
            Config::from_connection_string("Server=localhost;ProtocolVersion=7.3;").unwrap();
        assert_eq!(aliased.tds_version, TdsVersion::V7_3A);
    }

    #[test]
    fn test_connection_string_invalid_tds_version() {
        let garbage = Config::from_connection_string("Server=localhost;TDSVersion=invalid;");
        assert!(garbage.is_err());

        let unsupported = Config::from_connection_string("Server=localhost;TDSVersion=9.0;");
        assert!(unsupported.is_err());
    }
}