1mod types;
7pub use types::*;
8
9use std::time::Duration;
10
11use mssql_auth::Credentials;
12#[cfg(feature = "tls")]
13use mssql_tls::TlsConfig;
14use tds_protocol::version::TdsVersion;
15
/// Connection settings for a SQL Server client.
///
/// Build one with [`Config::new`] plus the builder methods, or parse it from an
/// ADO.NET-style string with [`Config::from_connection_string`]. Because the struct is
/// `#[non_exhaustive]`, crates other than this one cannot create it with a struct literal.
///
/// NOTE(review): `Debug` is derived, so debug-formatting a `Config` also formats
/// `credentials`; confirm that the `Debug` impl of `Credentials` redacts secrets.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct Config {
    /// Server host name (default `"localhost"`).
    pub host: String,

    /// Server port (default `1433`).
    pub port: u16,

    /// Initial database (`Database` / `Initial Catalog`); `None` when not specified.
    pub database: Option<String>,

    /// Authentication credentials (default: SQL Server login with empty user and password).
    pub credentials: Credentials,

    /// TLS settings, only present with the `tls` feature. The `trust_server_certificate`,
    /// `strict_mode` and `tds_version` builders update this alongside the matching flags.
    #[cfg(feature = "tls")]
    pub tls: TlsConfig,

    /// Application name (`Application Name` / `App`; default `"mssql-client"`).
    pub application_name: String,

    /// Connect timeout (defaults to `TimeoutConfig::default().connect_timeout`).
    pub connect_timeout: Duration,

    /// Command timeout (defaults to `TimeoutConfig::default().command_timeout`).
    pub command_timeout: Duration,

    /// Requested packet size from the `Packet Size` key (default `4096`).
    pub packet_size: u16,

    /// Strict (TDS 8.0) encryption mode; set by `Encrypt=strict`, a TDS 8.x version, or
    /// [`Config::strict_mode`] (default `false`).
    pub strict_mode: bool,

    /// `TrustServerCertificate` flag (default `false`). Presumably disables server
    /// certificate validation in the TLS layer — not shown here.
    pub trust_server_certificate: bool,

    /// Named instance, taken from `host\instance` in the `Server` value.
    pub instance: Option<String>,

    /// Multiple Active Result Sets (`MultipleActiveResultSets` / `MARS`; default `false`).
    pub mars: bool,

    /// Whether encryption is requested (default `true`).
    pub encrypt: bool,

    /// Opt out of TLS (`Encrypt=no_tls`); turning it on through the builder or the
    /// connection string also sets `encrypt` to `false` (default `false`).
    pub no_tls: bool,

    /// Redirect handling (`max_redirects`, `follow_redirects`).
    pub redirect: RedirectConfig,

    /// Retry policy (maximum retries and backoff settings).
    pub retry: RetryPolicy,

    /// Finer-grained timeouts (connect, TLS, login, command, idle, keep-alive).
    /// [`Config::timeouts`] also copies its connect and command timeouts into
    /// `connect_timeout` and `command_timeout`.
    pub timeouts: TimeoutConfig,

    /// TDS protocol version to use (default `TdsVersion::V7_4`).
    pub tds_version: TdsVersion,
}
112
113impl Default for Config {
114 fn default() -> Self {
115 let timeouts = TimeoutConfig::default();
116 Self {
117 host: "localhost".to_string(),
118 port: 1433,
119 database: None,
120 credentials: Credentials::sql_server("", ""),
121 #[cfg(feature = "tls")]
122 tls: TlsConfig::default(),
123 application_name: "mssql-client".to_string(),
124 connect_timeout: timeouts.connect_timeout,
125 command_timeout: timeouts.command_timeout,
126 packet_size: 4096,
127 strict_mode: false,
128 trust_server_certificate: false,
129 instance: None,
130 mars: false,
131 encrypt: true, no_tls: false, redirect: RedirectConfig::default(),
134 retry: RetryPolicy::default(),
135 timeouts,
136 tds_version: TdsVersion::V7_4, }
138 }
139}
140
141impl Config {
142 #[must_use]
144 pub fn new() -> Self {
145 Self::default()
146 }
147
148 pub fn from_connection_string(conn_str: &str) -> Result<Self, crate::error::Error> {
155 let mut config = Self::default();
156
157 for part in conn_str.split(';') {
158 let part = part.trim();
159 if part.is_empty() {
160 continue;
161 }
162
163 let (key, value) = part
164 .split_once('=')
165 .ok_or_else(|| crate::error::Error::Config(format!("invalid key-value: {part}")))?;
166
167 let key = key.trim().to_lowercase();
168 let value = value.trim();
169
170 match key.as_str() {
171 "server" | "data source" | "host" => {
172 if let Some((host, port_or_instance)) = value.split_once(',') {
174 config.host = host.to_string();
175 config.port = port_or_instance.parse().map_err(|_| {
176 crate::error::Error::Config(format!("invalid port: {port_or_instance}"))
177 })?;
178 } else if let Some((host, instance)) = value.split_once('\\') {
179 config.host = host.to_string();
180 config.instance = Some(instance.to_string());
181 } else {
182 config.host = value.to_string();
183 }
184 }
185 "port" => {
186 config.port = value.parse().map_err(|_| {
187 crate::error::Error::Config(format!("invalid port: {value}"))
188 })?;
189 }
190 "database" | "initial catalog" => {
191 config.database = Some(value.to_string());
192 }
193 "user id" | "uid" | "user" => {
194 if let Credentials::SqlServer { password, .. } = &config.credentials {
196 config.credentials =
197 Credentials::sql_server(value.to_string(), password.clone());
198 }
199 }
200 "password" | "pwd" => {
201 if let Credentials::SqlServer { username, .. } = &config.credentials {
203 config.credentials =
204 Credentials::sql_server(username.clone(), value.to_string());
205 }
206 }
207 "application name" | "app" => {
208 config.application_name = value.to_string();
209 }
210 "connect timeout" | "connection timeout" => {
211 let secs: u64 = value.parse().map_err(|_| {
212 crate::error::Error::Config(format!("invalid timeout: {value}"))
213 })?;
214 config.connect_timeout = Duration::from_secs(secs);
215 }
216 "command timeout" => {
217 let secs: u64 = value.parse().map_err(|_| {
218 crate::error::Error::Config(format!("invalid timeout: {value}"))
219 })?;
220 config.command_timeout = Duration::from_secs(secs);
221 }
222 "trustservercertificate" | "trust server certificate" => {
223 config.trust_server_certificate = value.eq_ignore_ascii_case("true")
224 || value.eq_ignore_ascii_case("yes")
225 || value == "1";
226 }
227 "encrypt" => {
228 if value.eq_ignore_ascii_case("strict") {
230 config.strict_mode = true;
231 config.encrypt = true;
232 config.no_tls = false;
233 } else if value.eq_ignore_ascii_case("no_tls") {
234 config.no_tls = true;
237 config.encrypt = false;
238 } else if value.eq_ignore_ascii_case("true")
239 || value.eq_ignore_ascii_case("yes")
240 || value == "1"
241 {
242 config.encrypt = true;
243 config.no_tls = false;
244 } else if value.eq_ignore_ascii_case("false")
245 || value.eq_ignore_ascii_case("no")
246 || value == "0"
247 {
248 config.encrypt = false;
249 config.no_tls = false;
250 }
251 }
252 "multipleactiveresultsets" | "mars" => {
253 config.mars = value.eq_ignore_ascii_case("true")
254 || value.eq_ignore_ascii_case("yes")
255 || value == "1";
256 }
257 "packet size" => {
258 config.packet_size = value.parse().map_err(|_| {
259 crate::error::Error::Config(format!("invalid packet size: {value}"))
260 })?;
261 }
262 "tdsversion" | "tds version" | "protocolversion" | "protocol version" => {
263 config.tds_version = TdsVersion::parse(value).ok_or_else(|| {
266 crate::error::Error::Config(format!(
267 "invalid TDS version: {value}. Supported values: 7.3, 7.3A, 7.3B, 7.4, 8.0"
268 ))
269 })?;
270 if config.tds_version.is_tds_8() {
272 config.strict_mode = true;
273 }
274 }
275 "integrated security" | "trusted_connection" => {
276 if value.eq_ignore_ascii_case("true")
277 || value.eq_ignore_ascii_case("yes")
278 || value.eq_ignore_ascii_case("sspi")
279 || value == "1"
280 {
281 #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
282 {
283 config.credentials = Credentials::Integrated;
284 }
285 #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
286 {
287 return Err(crate::error::Error::Config(
288 "Integrated Security requires the 'integrated-auth' (Linux/macOS) \
289 or 'sspi-auth' (Windows) feature to be enabled"
290 .into(),
291 ));
292 }
293 }
294 }
295 _ => {
296 tracing::debug!(
298 key = key,
299 value = value,
300 "ignoring unknown connection string option"
301 );
302 }
303 }
304 }
305
306 Ok(config)
307 }
308
309 #[must_use]
311 pub fn host(mut self, host: impl Into<String>) -> Self {
312 self.host = host.into();
313 self
314 }
315
316 #[must_use]
318 pub fn port(mut self, port: u16) -> Self {
319 self.port = port;
320 self
321 }
322
323 #[must_use]
325 pub fn database(mut self, database: impl Into<String>) -> Self {
326 self.database = Some(database.into());
327 self
328 }
329
330 #[must_use]
332 pub fn credentials(mut self, credentials: Credentials) -> Self {
333 self.credentials = credentials;
334 self
335 }
336
337 #[must_use]
339 pub fn application_name(mut self, name: impl Into<String>) -> Self {
340 self.application_name = name.into();
341 self
342 }
343
344 #[must_use]
346 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
347 self.connect_timeout = timeout;
348 self
349 }
350
351 #[must_use]
353 pub fn trust_server_certificate(mut self, trust: bool) -> Self {
354 self.trust_server_certificate = trust;
355 #[cfg(feature = "tls")]
356 {
357 self.tls = self.tls.trust_server_certificate(trust);
358 }
359 self
360 }
361
362 #[must_use]
364 pub fn strict_mode(mut self, enabled: bool) -> Self {
365 self.strict_mode = enabled;
366 #[cfg(feature = "tls")]
367 {
368 self.tls = self.tls.strict_mode(enabled);
369 }
370 if enabled {
371 self.tds_version = TdsVersion::V8_0;
372 }
373 self
374 }
375
376 #[must_use]
400 pub fn tds_version(mut self, version: TdsVersion) -> Self {
401 self.tds_version = version;
402 if version.is_tds_8() {
404 self.strict_mode = true;
405 #[cfg(feature = "tls")]
406 {
407 self.tls = self.tls.strict_mode(true);
408 }
409 }
410 self
411 }
412
413 #[must_use]
421 pub fn encrypt(mut self, enabled: bool) -> Self {
422 self.encrypt = enabled;
423 self
424 }
425
426 #[must_use]
461 pub fn no_tls(mut self, enabled: bool) -> Self {
462 self.no_tls = enabled;
463 if enabled {
464 self.encrypt = false;
465 }
466 self
467 }
468
469 #[must_use]
471 pub fn with_host(mut self, host: &str) -> Self {
472 self.host = host.to_string();
473 self
474 }
475
476 #[must_use]
478 pub fn with_port(mut self, port: u16) -> Self {
479 self.port = port;
480 self
481 }
482
483 #[must_use]
485 pub fn redirect(mut self, redirect: RedirectConfig) -> Self {
486 self.redirect = redirect;
487 self
488 }
489
490 #[must_use]
492 pub fn max_redirects(mut self, max: u8) -> Self {
493 self.redirect.max_redirects = max;
494 self
495 }
496
497 #[must_use]
499 pub fn retry(mut self, retry: RetryPolicy) -> Self {
500 self.retry = retry;
501 self
502 }
503
504 #[must_use]
506 pub fn max_retries(mut self, max: u32) -> Self {
507 self.retry.max_retries = max;
508 self
509 }
510
511 #[must_use]
513 pub fn timeouts(mut self, timeouts: TimeoutConfig) -> Self {
514 self.connect_timeout = timeouts.connect_timeout;
516 self.command_timeout = timeouts.command_timeout;
517 self.timeouts = timeouts;
518 self
519 }
520}
521
#[cfg(test)]
#[allow(clippy::unwrap_used)] // Test code: a panic is the desired failure mode.
mod tests {
    use super::*;

    #[test]
    fn test_config_defaults() {
        let config = Config::default();
        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 1433);
        assert_eq!(config.database, None);
        assert_eq!(config.instance, None);
        assert_eq!(config.application_name, "mssql-client");
        assert_eq!(config.packet_size, 4096);
        assert!(config.encrypt);
        assert!(!config.no_tls);
        assert!(!config.mars);
        assert!(!config.trust_server_certificate);
        assert_eq!(config.connect_timeout, config.timeouts.connect_timeout);
        assert_eq!(config.command_timeout, config.timeouts.command_timeout);
    }

    #[test]
    fn test_connection_string_parsing() {
        let config = Config::from_connection_string(
            "Server=localhost;Database=test;User Id=sa;Password=secret;",
        )
        .unwrap();

        assert_eq!(config.host, "localhost");
        assert_eq!(config.database, Some("test".to_string()));
    }

    #[test]
    fn test_connection_string_whitespace_and_case() {
        let config = Config::from_connection_string(
            "  SERVER = db.example.com ;; Initial Catalog = sales ; ",
        )
        .unwrap();

        assert_eq!(config.host, "db.example.com");
        assert_eq!(config.database, Some("sales".to_string()));
    }

    #[test]
    fn test_connection_string_with_port() {
        let config =
            Config::from_connection_string("Server=localhost,1434;Database=test;").unwrap();

        assert_eq!(config.host, "localhost");
        assert_eq!(config.port, 1434);
    }

    #[test]
    fn test_connection_string_port_key() {
        let config = Config::from_connection_string("Server=localhost;Port=1500;").unwrap();
        assert_eq!(config.port, 1500);
    }

    #[test]
    fn test_connection_string_with_instance() {
        let config =
            Config::from_connection_string("Server=localhost\\SQLEXPRESS;Database=test;").unwrap();

        assert_eq!(config.host, "localhost");
        assert_eq!(config.instance, Some("SQLEXPRESS".to_string()));
    }

    #[test]
    fn test_connection_string_invalid_values() {
        assert!(Config::from_connection_string("Server=localhost,notaport;").is_err());
        assert!(Config::from_connection_string("Server=localhost;Port=70000;").is_err());
        assert!(Config::from_connection_string("Server=localhost;Packet Size=big;").is_err());
        assert!(Config::from_connection_string("Server=localhost;Connect Timeout=-1;").is_err());
        assert!(Config::from_connection_string("Server=localhost;Command Timeout=soon;").is_err());
        assert!(Config::from_connection_string("Server=localhost;garbage;").is_err());
    }

    #[test]
    fn test_connection_string_unknown_key_ignored() {
        let config = Config::from_connection_string("Server=localhost;Foo=bar;").unwrap();
        assert_eq!(config.host, "localhost");
    }

    #[test]
    fn test_connection_string_misc_options() {
        let config = Config::from_connection_string(concat!(
            "Server=localhost;Application Name=reporting;Packet Size=8192;",
            "MultipleActiveResultSets=true;TrustServerCertificate=yes;",
            "Connect Timeout=5;Command Timeout=120;",
        ))
        .unwrap();

        assert_eq!(config.application_name, "reporting");
        assert_eq!(config.packet_size, 8192);
        assert!(config.mars);
        assert!(config.trust_server_certificate);
        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.command_timeout, Duration::from_secs(120));

        let config = Config::from_connection_string("Server=localhost;MARS=no;App=x;").unwrap();
        assert!(!config.mars);
        assert_eq!(config.application_name, "x");
    }

    #[test]
    fn test_connection_string_encrypt_false() {
        for value in ["false", "no", "0", "FALSE"] {
            let config =
                Config::from_connection_string(&format!("Server=localhost;Encrypt={value};"))
                    .unwrap();
            assert!(!config.encrypt);
            assert!(!config.no_tls);
        }
    }

    #[test]
    fn test_redirect_config_defaults() {
        let config = RedirectConfig::default();
        assert_eq!(config.max_redirects, 2);
        assert!(config.follow_redirects);
    }

    #[test]
    fn test_redirect_config_builder() {
        let config = RedirectConfig::new()
            .max_redirects(5)
            .follow_redirects(false);
        assert_eq!(config.max_redirects, 5);
        assert!(!config.follow_redirects);
    }

    #[test]
    fn test_redirect_config_no_follow() {
        let config = RedirectConfig::no_follow();
        assert_eq!(config.max_redirects, 0);
        assert!(!config.follow_redirects);
    }

    #[test]
    fn test_config_redirect_builder() {
        let config = Config::new().max_redirects(3);
        assert_eq!(config.redirect.max_redirects, 3);

        let config2 = Config::new().redirect(RedirectConfig::no_follow());
        assert!(!config2.redirect.follow_redirects);
    }

    #[test]
    fn test_retry_policy_defaults() {
        let policy = RetryPolicy::default();
        assert_eq!(policy.max_retries, 3);
        assert_eq!(policy.initial_backoff, Duration::from_millis(100));
        assert_eq!(policy.max_backoff, Duration::from_secs(30));
        assert!((policy.backoff_multiplier - 2.0).abs() < f64::EPSILON);
        assert!(policy.jitter);
    }

    #[test]
    fn test_retry_policy_builder() {
        let policy = RetryPolicy::new()
            .max_retries(5)
            .initial_backoff(Duration::from_millis(200))
            .max_backoff(Duration::from_secs(60))
            .backoff_multiplier(3.0)
            .jitter(false);

        assert_eq!(policy.max_retries, 5);
        assert_eq!(policy.initial_backoff, Duration::from_millis(200));
        assert_eq!(policy.max_backoff, Duration::from_secs(60));
        assert!((policy.backoff_multiplier - 3.0).abs() < f64::EPSILON);
        assert!(!policy.jitter);
    }

    #[test]
    fn test_retry_policy_no_retry() {
        let policy = RetryPolicy::no_retry();
        assert_eq!(policy.max_retries, 0);
        assert!(!policy.should_retry(0));
    }

    #[test]
    fn test_retry_policy_should_retry() {
        let policy = RetryPolicy::new().max_retries(3);
        assert!(policy.should_retry(0));
        assert!(policy.should_retry(1));
        assert!(policy.should_retry(2));
        assert!(!policy.should_retry(3));
        assert!(!policy.should_retry(4));
    }

    #[test]
    fn test_retry_policy_backoff_calculation() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_millis(100))
            .backoff_multiplier(2.0)
            .max_backoff(Duration::from_secs(10))
            .jitter(false);

        assert_eq!(policy.backoff_for_attempt(0), Duration::ZERO);
        assert_eq!(policy.backoff_for_attempt(1), Duration::from_millis(100));
        assert_eq!(policy.backoff_for_attempt(2), Duration::from_millis(200));
        assert_eq!(policy.backoff_for_attempt(3), Duration::from_millis(400));
    }

    #[test]
    fn test_retry_policy_backoff_capped() {
        let policy = RetryPolicy::new()
            .initial_backoff(Duration::from_secs(1))
            .backoff_multiplier(10.0)
            .max_backoff(Duration::from_secs(5))
            .jitter(false);

        assert_eq!(policy.backoff_for_attempt(3), Duration::from_secs(5));
    }

    #[test]
    fn test_config_retry_builder() {
        let config = Config::new().max_retries(5);
        assert_eq!(config.retry.max_retries, 5);

        let config2 = Config::new().retry(RetryPolicy::no_retry());
        assert_eq!(config2.retry.max_retries, 0);
    }

    #[test]
    fn test_timeout_config_defaults() {
        let config = TimeoutConfig::default();
        assert_eq!(config.connect_timeout, Duration::from_secs(15));
        assert_eq!(config.tls_timeout, Duration::from_secs(10));
        assert_eq!(config.login_timeout, Duration::from_secs(30));
        assert_eq!(config.command_timeout, Duration::from_secs(30));
        assert_eq!(config.idle_timeout, Duration::from_secs(300));
        assert_eq!(config.keepalive_interval, Some(Duration::from_secs(30)));
    }

    #[test]
    fn test_timeout_config_builder() {
        let config = TimeoutConfig::new()
            .connect_timeout(Duration::from_secs(5))
            .tls_timeout(Duration::from_secs(3))
            .login_timeout(Duration::from_secs(10))
            .command_timeout(Duration::from_secs(60))
            .idle_timeout(Duration::from_secs(600))
            .keepalive_interval(Some(Duration::from_secs(60)));

        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.tls_timeout, Duration::from_secs(3));
        assert_eq!(config.login_timeout, Duration::from_secs(10));
        assert_eq!(config.command_timeout, Duration::from_secs(60));
        assert_eq!(config.idle_timeout, Duration::from_secs(600));
        assert_eq!(config.keepalive_interval, Some(Duration::from_secs(60)));
    }

    #[test]
    fn test_timeout_config_no_keepalive() {
        let config = TimeoutConfig::new().no_keepalive();
        assert_eq!(config.keepalive_interval, None);
    }

    #[test]
    fn test_timeout_config_total_connect() {
        let config = TimeoutConfig::new()
            .connect_timeout(Duration::from_secs(5))
            .tls_timeout(Duration::from_secs(3))
            .login_timeout(Duration::from_secs(10));

        assert_eq!(config.total_connect_timeout(), Duration::from_secs(18));
    }

    #[test]
    fn test_config_timeouts_builder() {
        let timeouts = TimeoutConfig::new()
            .connect_timeout(Duration::from_secs(5))
            .command_timeout(Duration::from_secs(60));

        let config = Config::new().timeouts(timeouts);
        assert_eq!(config.timeouts.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.timeouts.command_timeout, Duration::from_secs(60));
        assert_eq!(config.connect_timeout, Duration::from_secs(5));
        assert_eq!(config.command_timeout, Duration::from_secs(60));
    }

    #[test]
    fn test_builder_setters() {
        let config = Config::new()
            .with_host("db")
            .with_port(1500)
            .database("sales")
            .application_name("reporting")
            .connect_timeout(Duration::from_secs(3))
            .trust_server_certificate(true)
            .encrypt(false);

        assert_eq!(config.host, "db");
        assert_eq!(config.port, 1500);
        assert_eq!(config.database, Some("sales".to_string()));
        assert_eq!(config.application_name, "reporting");
        assert_eq!(config.connect_timeout, Duration::from_secs(3));
        assert!(config.trust_server_certificate);
        assert!(!config.encrypt);
    }

    #[test]
    fn test_tds_version_default() {
        let config = Config::default();
        assert_eq!(config.tds_version, TdsVersion::V7_4);
        assert!(!config.strict_mode);
    }

    #[test]
    fn test_tds_version_builder() {
        let config = Config::new().tds_version(TdsVersion::V7_3A);
        assert_eq!(config.tds_version, TdsVersion::V7_3A);
        assert!(!config.strict_mode);

        let config = Config::new().tds_version(TdsVersion::V7_3B);
        assert_eq!(config.tds_version, TdsVersion::V7_3B);
        assert!(!config.strict_mode);

        let config = Config::new().tds_version(TdsVersion::V8_0);
        assert_eq!(config.tds_version, TdsVersion::V8_0);
        assert!(config.strict_mode);
    }

    #[test]
    fn test_strict_mode_sets_tds_8() {
        let config = Config::new().strict_mode(true);
        assert!(config.strict_mode);
        assert_eq!(config.tds_version, TdsVersion::V8_0);
    }

    #[test]
    fn test_strict_mode_disable_keeps_version() {
        let config = Config::new().strict_mode(false);
        assert!(!config.strict_mode);
        assert_eq!(config.tds_version, TdsVersion::V7_4);
    }

    #[test]
    fn test_connection_string_tds_version() {
        let config = Config::from_connection_string("Server=localhost;TDSVersion=7.3;").unwrap();
        assert_eq!(config.tds_version, TdsVersion::V7_3A);

        let config = Config::from_connection_string("Server=localhost;TDSVersion=7.3A;").unwrap();
        assert_eq!(config.tds_version, TdsVersion::V7_3A);

        let config = Config::from_connection_string("Server=localhost;TDSVersion=7.3B;").unwrap();
        assert_eq!(config.tds_version, TdsVersion::V7_3B);

        let config = Config::from_connection_string("Server=localhost;TDSVersion=7.4;").unwrap();
        assert_eq!(config.tds_version, TdsVersion::V7_4);

        let config = Config::from_connection_string("Server=localhost;TDSVersion=8.0;").unwrap();
        assert_eq!(config.tds_version, TdsVersion::V8_0);
        assert!(config.strict_mode);

        let config =
            Config::from_connection_string("Server=localhost;ProtocolVersion=7.3;").unwrap();
        assert_eq!(config.tds_version, TdsVersion::V7_3A);
    }

    #[test]
    fn test_connection_string_invalid_tds_version() {
        let result = Config::from_connection_string("Server=localhost;TDSVersion=invalid;");
        assert!(result.is_err());

        let result = Config::from_connection_string("Server=localhost;TDSVersion=9.0;");
        assert!(result.is_err());
    }

    #[test]
    fn test_connection_string_no_tls() {
        let config = Config::from_connection_string("Server=legacy;Encrypt=no_tls;").unwrap();
        assert!(config.no_tls);
        assert!(!config.encrypt);
        assert!(!config.strict_mode);

        // The Encrypt value is matched case-insensitively.
        let config = Config::from_connection_string("Server=legacy;Encrypt=NO_TLS;").unwrap();
        assert!(config.no_tls);

        let config = Config::from_connection_string("Server=localhost;Encrypt=true;").unwrap();
        assert!(!config.no_tls);
        assert!(config.encrypt);

        let config = Config::from_connection_string("Server=localhost;Encrypt=strict;").unwrap();
        assert!(!config.no_tls);
        assert!(config.encrypt);
        assert!(config.strict_mode);
    }

    #[test]
    fn test_no_tls_builder() {
        let config = Config::new().no_tls(true);
        assert!(config.no_tls);
        assert!(!config.encrypt);

        let config = Config::new().no_tls(true).no_tls(false);
        assert!(!config.no_tls);
    }

    #[test]
    #[cfg(any(feature = "integrated-auth", feature = "sspi-auth"))]
    fn test_connection_string_integrated_security() {
        for conn_str in [
            "Server=localhost;Integrated Security=true;",
            "Server=localhost;Integrated Security=yes;",
            "Server=localhost;Integrated Security=sspi;",
            "Server=localhost;Integrated Security=1;",
            "Server=localhost;Trusted_Connection=true;",
        ] {
            let config = Config::from_connection_string(conn_str).unwrap();
            assert_eq!(
                config.credentials.method_name(),
                "Integrated Authentication"
            );
        }
    }

    #[test]
    #[cfg(not(any(feature = "integrated-auth", feature = "sspi-auth")))]
    fn test_connection_string_integrated_security_without_feature() {
        let result = Config::from_connection_string("Server=localhost;Integrated Security=true;");
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("integrated-auth"));
    }
}