1#[cfg(test)]
2mod test;
3
4mod bulk_write;
5mod parse;
6mod resolver_config;
7
8use std::{
9 cmp::Ordering,
10 collections::HashSet,
11 convert::TryFrom,
12 fmt::{self, Display, Formatter, Write},
13 hash::{Hash, Hasher},
14 net::Ipv6Addr,
15 path::PathBuf,
16 str::FromStr,
17 time::Duration,
18};
19
20use bson::UuidRepresentation;
21use derive_where::derive_where;
22use macro_magic::export_tokens;
23use once_cell::sync::Lazy;
24use serde::{de::Unexpected, Deserialize, Deserializer, Serialize};
25use serde_with::skip_serializing_none;
26use strsim::jaro_winkler;
27use typed_builder::TypedBuilder;
28
29#[cfg(any(
30 feature = "zstd-compression",
31 feature = "zlib-compression",
32 feature = "snappy-compression"
33))]
34use crate::options::Compressor;
35#[cfg(test)]
36use crate::srv::LookupHosts;
37use crate::{
38 bson::{doc, Bson, Document},
39 client::auth::{AuthMechanism, Credential},
40 concern::{Acknowledgment, ReadConcern, WriteConcern},
41 error::{Error, ErrorKind, Result},
42 event::EventHandler,
43 options::ReadConcernLevel,
44 sdam::{verify_max_staleness, DEFAULT_HEARTBEAT_FREQUENCY, MIN_HEARTBEAT_FREQUENCY},
45 selection_criteria::{ReadPreference, SelectionCriteria, TagSet},
46 serde_util,
47 srv::{OriginalSrvInfo, SrvResolver},
48};
49
50pub use bulk_write::*;
51#[cfg(feature = "dns-resolver")]
52pub use resolver_config::ResolverConfig;
53#[cfg(not(feature = "dns-resolver"))]
54pub(crate) use resolver_config::ResolverConfig;
55
/// Port assumed for a TCP [`ServerAddress`] that was given without an explicit port.
pub(crate) const DEFAULT_PORT: u16 = 27017;

/// Known connection-string option names, all spelled in lowercase.
///
/// Every entry must stay lowercase so it can be compared uniformly against option keys.
/// NOTE(review): this list is consumed by option parsing outside this view (the module imports
/// `strsim::jaro_winkler`, presumably to suggest the closest known name for an unrecognized
/// key) — confirm that consumer compares against lowercased keys.
const URI_OPTIONS: &[&str] = &[
    "appname",
    "authmechanism",
    "authsource",
    "authmechanismproperties",
    "compressors",
    "connecttimeoutms",
    "directconnection",
    "heartbeatfrequencyms",
    "journal",
    "localthresholdms",
    "maxidletimems",
    "maxstalenessseconds",
    "maxpoolsize",
    "minpoolsize",
    "maxconnecting",
    "readconcernlevel",
    "readpreference",
    "readpreferencetags",
    "replicaset",
    "retrywrites",
    "retryreads",
    "servermonitoringmode",
    "serverselectiontimeoutms",
    "sockettimeoutms",
    "tls",
    "ssl",
    "tlsinsecure",
    "tlsallowinvalidcertificates",
    "tlscafile",
    "tlscertificatekeyfile",
    // Previously the lone camelCase entry ("uuidRepresentation"), inconsistent with the rest.
    "uuidrepresentation",
    "w",
    "waitqueuetimeoutms",
    "wtimeoutms",
    "zlibcompressionlevel",
    "srvservicename",
];
96
97static USERINFO_RESERVED_CHARACTERS: Lazy<HashSet<&'static char>> =
101 Lazy::new(|| [':', '/', '?', '#', '[', ']', '@'].iter().collect());
102
103static ILLEGAL_DATABASE_CHARACTERS: Lazy<HashSet<&'static char>> =
104 Lazy::new(|| ['/', '\\', ' ', '"', '$'].iter().collect());
105
/// The address of a MongoDB server: a TCP host with an optional port or, on Unix platforms,
/// the path of a Unix domain socket.
///
/// `PartialEq` and `Hash` are implemented by hand so that a TCP address without a port is
/// treated the same as that host with port 27017.
///
/// NOTE(review): `Serialize` is derived while `Deserialize` is hand-written to accept either a
/// string or a `{ host, port }` object. Confirm that round-tripping through serde is not
/// relied upon.
#[derive(Clone, Debug, Eq, Serialize)]
#[non_exhaustive]
pub enum ServerAddress {
    /// A TCP/IP host and optional port.
    Tcp {
        /// The hostname or IP address. `ServerAddress::parse` stores it lowercased and, for
        /// IPv6 literals, without the surrounding brackets.
        host: String,

        /// The port. `None` is treated as the default port (27017) for display, equality,
        /// and hashing.
        port: Option<u16>,
    },
    /// A Unix domain socket, identified by its filesystem path.
    #[cfg(unix)]
    Unix {
        /// The path of the socket file.
        path: PathBuf,
    },
}
127
128impl<'de> Deserialize<'de> for ServerAddress {
129 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
130 where
131 D: Deserializer<'de>,
132 {
133 #[derive(Deserialize)]
134 #[serde(untagged)]
135 enum ServerAddressHelper {
136 String(String),
137 Object { host: String, port: Option<u16> },
138 }
139
140 let helper = ServerAddressHelper::deserialize(deserializer)?;
141 match helper {
142 ServerAddressHelper::String(string) => {
143 Self::parse(string).map_err(serde::de::Error::custom)
144 }
145 ServerAddressHelper::Object { host, port } => {
146 #[cfg(unix)]
147 if host.ends_with("sock") {
148 return Ok(Self::Unix {
149 path: PathBuf::from(host),
150 });
151 }
152
153 Ok(Self::Tcp { host, port })
154 }
155 }
156 }
157}
158
159impl Default for ServerAddress {
160 fn default() -> Self {
161 Self::Tcp {
162 host: "localhost".into(),
163 port: None,
164 }
165 }
166}
167
168impl PartialEq for ServerAddress {
169 fn eq(&self, other: &Self) -> bool {
170 match (self, other) {
171 (
172 Self::Tcp { host, port },
173 Self::Tcp {
174 host: other_host,
175 port: other_port,
176 },
177 ) => host == other_host && port.unwrap_or(27017) == other_port.unwrap_or(27017),
178 #[cfg(unix)]
179 (Self::Unix { path }, Self::Unix { path: other_path }) => path == other_path,
180 #[cfg(unix)]
181 _ => false,
182 }
183 }
184}
185
186impl Hash for ServerAddress {
187 fn hash<H>(&self, state: &mut H)
188 where
189 H: Hasher,
190 {
191 match self {
192 Self::Tcp { host, port } => {
193 host.hash(state);
194 port.unwrap_or(27017).hash(state);
195 }
196 #[cfg(unix)]
197 Self::Unix { path } => path.hash(state),
198 }
199 }
200}
201
202impl FromStr for ServerAddress {
203 type Err = Error;
204 fn from_str(address: &str) -> Result<Self> {
205 ServerAddress::parse(address)
206 }
207}
208
impl ServerAddress {
    /// Parses a single host entry of a connection string into a [`ServerAddress`].
    ///
    /// Accepted forms:
    /// - an address ending in `.sock`. It is percent-decoded and returned as
    ///   [`ServerAddress::Unix`]; on non-Unix platforms this form is an error.
    /// - a bracketed IPv6 literal, optionally followed by `:port`, e.g. `[::1]:27017`.
    /// - `host` or `host:port`, split at the first `:`.
    ///
    /// TCP hosts are lowercased. IPv6 literals are stored without their brackets.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidArgument` if:
    /// - a socket path does not percent-decode to valid UTF-8;
    /// - the closing `]` of an IP literal is missing, or the literal is not a valid IPv6
    ///   address;
    /// - something other than `:port` follows the `]`;
    /// - the hostname is empty;
    /// - the port is not an integer between 1 and 65535.
    pub fn parse(address: impl AsRef<str>) -> Result<Self> {
        let address = address.as_ref();

        // Unix domain socket paths are recognized purely by their `.sock` suffix.
        if address.ends_with(".sock") {
            #[cfg(unix)]
            {
                let address = percent_decode(address, "unix domain sockets must be URL-encoded")?;
                return Ok(Self::Unix {
                    path: PathBuf::from(address),
                });
            }
            #[cfg(not(unix))]
            return Err(ErrorKind::InvalidArgument {
                message: "unix domain sockets are not supported on this platform".to_string(),
            }
            .into());
        }

        // IPv6 literals contain ':' themselves, so they are bracketed and split at the closing
        // ']' rather than at the first ':'.
        let (hostname, port) = if let Some(ip_literal) = address.strip_prefix("[") {
            let Some((hostname, port)) = ip_literal.split_once("]") else {
                return Err(ErrorKind::InvalidArgument {
                    message: format!(
                        "invalid server address {}: missing closing ']' in IP literal hostname",
                        address
                    ),
                }
                .into());
            };

            if let Err(parse_error) = Ipv6Addr::from_str(hostname) {
                return Err(ErrorKind::InvalidArgument {
                    message: format!("invalid server address {}: {}", address, parse_error),
                }
                .into());
            }

            // After ']' only nothing or ":<port>" is allowed.
            let port = if port.is_empty() {
                None
            } else if let Some(port) = port.strip_prefix(":") {
                Some(port)
            } else {
                return Err(ErrorKind::InvalidArgument {
                    message: format!(
                        "invalid server address {}: the hostname can only be followed by a port \
                         prefixed with ':', got {}",
                        address, port
                    ),
                }
                .into());
            };

            (hostname, port)
        } else {
            match address.split_once(":") {
                Some((hostname, port)) => (hostname, Some(port)),
                None => (address, None),
            }
        };

        if hostname.is_empty() {
            return Err(ErrorKind::InvalidArgument {
                message: format!(
                    "invalid server address {}: the hostname cannot be empty",
                    address
                ),
            }
            .into());
        }

        // Port 0 is rejected along with anything that is not a valid u16.
        // NOTE(review): `u16::from_str` also accepts a leading '+' (e.g. "host:+27017");
        // confirm that is acceptable.
        let port = if let Some(port) = port {
            match u16::from_str(port) {
                Ok(0) | Err(_) => {
                    return Err(ErrorKind::InvalidArgument {
                        message: format!(
                            "invalid server address {}: the port must be an integer between 1 and \
                             65535, got {}",
                            address, port
                        ),
                    }
                    .into())
                }
                Ok(port) => Some(port),
            }
        } else {
            None
        };

        Ok(Self::Tcp {
            host: hostname.to_lowercase(),
            port,
        })
    }

    /// Returns the host portion of the address.
    ///
    /// For TCP this is the borrowed hostname. For a Unix socket it is the path, converted
    /// lossily to a string.
    #[cfg(feature = "dns-resolver")]
    pub(crate) fn host(&self) -> std::borrow::Cow<'_, str> {
        match self {
            Self::Tcp { host, .. } => std::borrow::Cow::Borrowed(host.as_str()),
            #[cfg(unix)]
            Self::Unix { path } => path.to_string_lossy(),
        }
    }
}
313
314impl fmt::Display for ServerAddress {
315 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
316 match self {
317 Self::Tcp { host, port } => {
318 write!(fmt, "{}:{}", host, port.unwrap_or(DEFAULT_PORT))
319 }
320 #[cfg(unix)]
321 Self::Unix { path } => write!(fmt, "{}", path.display()),
322 }
323 }
324}
325
/// A server API version that can be declared through [`ServerApi`].
///
/// Serialized, displayed, and parsed as the string `"1"`.
#[derive(Clone, Debug, PartialEq, Serialize)]
#[non_exhaustive]
pub enum ServerApiVersion {
    /// API version `"1"`.
    #[serde(rename = "1")]
    V1,
}
334
335impl FromStr for ServerApiVersion {
336 type Err = Error;
337
338 fn from_str(str: &str) -> Result<Self> {
339 match str {
340 "1" => Ok(Self::V1),
341 _ => Err(ErrorKind::InvalidArgument {
342 message: format!("invalid server api version string: {}", str),
343 }
344 .into()),
345 }
346 }
347}
348
349impl Display for ServerApiVersion {
350 fn fmt(&self, f: &mut Formatter) -> fmt::Result {
351 match self {
352 Self::V1 => write!(f, "1"),
353 }
354 }
355}
356
357impl<'de> Deserialize<'de> for ServerApiVersion {
358 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
359 where
360 D: Deserializer<'de>,
361 {
362 let s = String::deserialize(deserializer)?;
363
364 ServerApiVersion::from_str(&s).map_err(|_| {
365 serde::de::Error::invalid_value(Unexpected::Str(&s), &"a valid version number")
366 })
367 }
368}
369
/// A server API declaration: the required [`ServerApiVersion`] plus two optional flags.
///
/// Fields are (de)serialized as `apiVersion`, `apiStrict`, and `apiDeprecationErrors`.
/// Unset optional fields are omitted when serializing.
#[serde_with::skip_serializing_none]
#[derive(Clone, Debug, Deserialize, Serialize, PartialEq, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[non_exhaustive]
pub struct ServerApi {
    /// The declared API version; has no builder default, so it must be set.
    #[serde(rename = "apiVersion")]
    #[builder(!default)]
    pub version: ServerApiVersion,

    /// Sent as `apiStrict`. Presumably asks the server to reject commands outside the
    /// declared API version — confirm.
    #[serde(rename = "apiStrict")]
    pub strict: Option<bool>,

    /// Sent as `apiDeprecationErrors`. Presumably asks the server to error on deprecated
    /// behavior — confirm.
    #[serde(rename = "apiDeprecationErrors")]
    pub deprecation_errors: Option<bool>,
}
393
/// Options used to configure a client.
///
/// Construct with `ClientOptions::builder()` or deserialize from camelCase keys. Fields marked
/// `setter(skip)` have no builder setter. Every field defaults to `None` unless noted. Event
/// handlers are excluded from `Debug` and `PartialEq` via `derive_where(skip)`, and several
/// crate-internal fields are excluded from `Debug`.
#[derive(Clone, Deserialize, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[derive_where(Debug, PartialEq)]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
pub struct ClientOptions {
    /// The seed list of servers.
    ///
    /// The builder default is `localhost` with port 27017. When deserializing without a `hosts`
    /// key, `default_hosts()` supplies `localhost` with no explicit port. The two are equal
    /// under `ServerAddress`'s `PartialEq`.
    #[builder(default_code = "vec![ServerAddress::Tcp {
        host: \"localhost\".to_string(),
        port: Some(27017),
    }]")]
    #[serde(default = "default_hosts")]
    pub hosts: Vec<ServerAddress>,

    /// Application name (the `appname` URI option).
    pub app_name: Option<String>,

    /// Configured compressors. Present only with a compression feature enabled, never
    /// deserialized, and each entry is checked by `validate`.
    #[cfg(any(
        feature = "zstd-compression",
        feature = "zlib-compression",
        feature = "snappy-compression"
    ))]
    #[serde(skip)]
    pub compressors: Option<Vec<Compressor>>,

    /// Handler for `CmapEvent`s; never deserialized.
    #[derive_where(skip)]
    #[builder(setter(strip_option))]
    #[serde(skip)]
    pub cmap_event_handler: Option<EventHandler<crate::event::cmap::CmapEvent>>,

    /// Handler for `CommandEvent`s; never deserialized.
    #[derive_where(skip)]
    #[builder(setter(strip_option))]
    #[serde(skip)]
    pub command_event_handler: Option<EventHandler<crate::event::command::CommandEvent>>,

    /// Connection timeout (the `connecttimeoutms` URI option).
    pub connect_timeout: Option<Duration>,

    /// Authentication credential.
    pub credential: Option<Credential>,

    /// The `directconnection` URI option. `validate` rejects `Some(true)` combined with more
    /// than one host or with `load_balanced = Some(true)`.
    pub direct_connection: Option<bool>,

    /// Additional driver metadata; see [`DriverInfo`].
    pub driver_info: Option<DriverInfo>,

    /// The `heartbeatfrequencyms` URI option.
    ///
    /// `validate` rejects values below `min_heartbeat_frequency()`. It also uses this value
    /// (or `DEFAULT_HEARTBEAT_FREQUENCY`) to check a read preference's max staleness.
    pub heartbeat_freq: Option<Duration>,

    /// Deserialized from the `loadbalanced` key; has no builder setter.
    ///
    /// When `Some(true)`, `validate` requires at most one host, no `repl_set_name`, and no
    /// `direct_connection = Some(true)`.
    #[builder(setter(skip))]
    #[serde(rename = "loadbalanced")]
    pub load_balanced: Option<bool>,

    /// The `localthresholdms` URI option.
    pub local_threshold: Option<Duration>,

    /// The `maxidletimems` URI option.
    pub max_idle_time: Option<Duration>,

    /// The `maxpoolsize` URI option; `validate` rejects `Some(0)`.
    pub max_pool_size: Option<u32>,

    /// The `minpoolsize` URI option.
    pub min_pool_size: Option<u32>,

    /// The `maxconnecting` URI option; `validate` rejects `Some(0)`.
    pub max_connecting: Option<u32>,

    /// Read concern.
    pub read_concern: Option<ReadConcern>,

    /// Replica set name (the `replicaset` URI option). `validate` rejects it when
    /// `load_balanced` is `Some(true)`.
    pub repl_set_name: Option<String>,

    /// The `retryreads` URI option.
    pub retry_reads: Option<bool>,

    /// The `retrywrites` URI option.
    pub retry_writes: Option<bool>,

    /// The `servermonitoringmode` URI option.
    pub server_monitoring_mode: Option<ServerMonitoringMode>,

    /// Handler for `SdamEvent`s; never deserialized.
    #[derive_where(skip)]
    #[builder(setter(strip_option))]
    #[serde(skip)]
    pub sdam_event_handler: Option<EventHandler<crate::event::sdam::SdamEvent>>,

    /// Default selection criteria. When this is a read preference with a max staleness,
    /// `validate` checks it against the heartbeat frequency.
    pub selection_criteria: Option<SelectionCriteria>,

    /// Declared server API; see [`ServerApi`].
    pub server_api: Option<ServerApi>,

    /// The `serverselectiontimeoutms` URI option.
    pub server_selection_timeout: Option<Duration>,

    /// Default database name.
    pub default_database: Option<String>,

    /// SRV service name (the `srvservicename` URI option).
    pub srv_service_name: Option<String>,

    /// The `sockettimeoutms` URI option. Crate-internal, has no builder setter, and is excluded
    /// from `Debug`.
    #[builder(setter(skip))]
    #[derive_where(skip(Debug))]
    pub(crate) socket_timeout: Option<Duration>,

    /// TLS configuration. Because the setter uses `into`, a bare [`TlsOptions`] can be passed
    /// (it converts to `Some(Tls::Enabled(..))`).
    pub tls: Option<Tls>,

    /// Present only with the `tracing-unstable` feature. Presumably limits how much of each
    /// document appears in tracing output — confirm.
    #[cfg(feature = "tracing-unstable")]
    pub tracing_max_document_length_bytes: Option<usize>,

    /// Write concern; checked by `validate`.
    pub write_concern: Option<WriteConcern>,

    /// The `srvmaxhosts` option.
    pub srv_max_hosts: Option<u32>,

    /// Original SRV information (hostname and `min_ttl`), if any. Set internally: no setter
    /// and never deserialized. When present, `validate` uses its hostname for the
    /// hosted-service log messages.
    #[builder(setter(skip))]
    #[serde(skip)]
    #[derive_where(skip(Debug))]
    pub(crate) original_srv_info: Option<OriginalSrvInfo>,

    /// Test-only: presumably the URI these options were parsed from — confirm.
    #[cfg(test)]
    #[builder(setter(skip))]
    #[derive_where(skip(Debug))]
    pub(crate) original_uri: Option<String>,

    /// Custom DNS resolver configuration, present only with the `dns-resolver` feature. Read
    /// via `resolver_config()`.
    #[builder(setter(skip))]
    #[serde(skip)]
    #[derive_where(skip(Debug))]
    #[cfg(feature = "dns-resolver")]
    pub(crate) resolver_config: Option<ResolverConfig>,

    /// Test-only knobs; see `TestOptions`.
    #[cfg(test)]
    #[builder(setter(skip))]
    #[serde(skip)]
    #[derive_where(skip)]
    pub(crate) test_options: Option<TestOptions>,
}
628
/// Test-only settings attached to `ClientOptions::test_options`.
#[cfg(test)]
#[derive(Debug, Clone, Default)]
pub(crate) struct TestOptions {
    /// Overrides `MIN_HEARTBEAT_FREQUENCY` as the floor returned by
    /// `ClientOptions::min_heartbeat_frequency`.
    pub(crate) min_heartbeat_freq: Option<Duration>,

    /// Presumably disables background monitoring tasks — confirm at use sites.
    pub(crate) disable_monitoring_threads: bool,

    /// Presumably a canned SRV lookup result used instead of a real DNS query — confirm at
    /// use sites.
    pub(crate) mock_lookup_hosts: Option<Result<LookupHosts>>,

    /// Channel sender for command events; presumably used by tests that await events —
    /// confirm.
    pub(crate) async_event_listener: Option<TestEventSender>,
}
644
/// Sender half of a tokio mpsc channel of `AcknowledgedMessage<CommandEvent>`s. This is the
/// type of `TestOptions::async_event_listener`.
pub(crate) type TestEventSender = tokio::sync::mpsc::Sender<
    crate::runtime::AcknowledgedMessage<crate::event::command::CommandEvent>,
>;
648
649fn default_hosts() -> Vec<ServerAddress> {
650 vec![ServerAddress::default()]
651}
652
653impl Default for ClientOptions {
654 fn default() -> Self {
655 Self::builder().build()
656 }
657}
658
/// Test-only serialization of client options as a flat map keyed by lowercase URI option
/// names (e.g. `appname`, `connecttimeoutms`). Presumably this lets tests compare parsed
/// options against expected URI options — confirm at call sites.
///
/// Durations are written as integer milliseconds. The credential, read concern, selection
/// criteria, TLS, and write concern are flattened into the same map by their
/// `serialize_for_client_options` helpers. `compressors`, `driver_info`, `server_api`, and
/// `default_database` are not emitted, and `zlibcompressionlevel` is always `None`.
#[cfg(test)]
impl Serialize for ClientOptions {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Field names double as the serialized keys, so they deliberately mirror the
        // lowercase URI option names; their order is the output key order.
        #[derive(Serialize)]
        struct ClientOptionsHelper<'a> {
            appname: &'a Option<String>,

            #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")]
            connecttimeoutms: &'a Option<Duration>,

            #[serde(flatten, serialize_with = "Credential::serialize_for_client_options")]
            credential: &'a Option<Credential>,

            directconnection: &'a Option<bool>,

            #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")]
            heartbeatfrequencyms: &'a Option<Duration>,

            #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")]
            localthresholdms: &'a Option<Duration>,

            #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")]
            maxidletimems: &'a Option<Duration>,

            maxpoolsize: &'a Option<u32>,

            minpoolsize: &'a Option<u32>,

            maxconnecting: &'a Option<u32>,

            #[serde(flatten, serialize_with = "ReadConcern::serialize_for_client_options")]
            readconcern: &'a Option<ReadConcern>,

            replicaset: &'a Option<String>,

            retryreads: &'a Option<bool>,

            retrywrites: &'a Option<bool>,

            servermonitoringmode: Option<String>,

            #[serde(
                flatten,
                serialize_with = "SelectionCriteria::serialize_for_client_options"
            )]
            selectioncriteria: &'a Option<SelectionCriteria>,

            #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")]
            serverselectiontimeoutms: &'a Option<Duration>,

            #[serde(serialize_with = "serde_util::serialize_duration_option_as_int_millis")]
            sockettimeoutms: &'a Option<Duration>,

            #[serde(flatten, serialize_with = "Tls::serialize_for_client_options")]
            tls: &'a Option<Tls>,

            #[serde(flatten, serialize_with = "WriteConcern::serialize_for_client_options")]
            writeconcern: &'a Option<WriteConcern>,

            zlibcompressionlevel: &'a Option<i32>,

            loadbalanced: &'a Option<bool>,

            srvmaxhosts: Option<i32>,

            srvservicename: &'a Option<String>,
        }

        let client_options = ClientOptionsHelper {
            appname: &self.app_name,
            connecttimeoutms: &self.connect_timeout,
            credential: &self.credential,
            directconnection: &self.direct_connection,
            heartbeatfrequencyms: &self.heartbeat_freq,
            localthresholdms: &self.local_threshold,
            maxidletimems: &self.max_idle_time,
            maxpoolsize: &self.max_pool_size,
            minpoolsize: &self.min_pool_size,
            maxconnecting: &self.max_connecting,
            readconcern: &self.read_concern,
            replicaset: &self.repl_set_name,
            retryreads: &self.retry_reads,
            retrywrites: &self.retry_writes,
            // The mode's `Debug` name, lowercased.
            servermonitoringmode: self
                .server_monitoring_mode
                .as_ref()
                .map(|m| format!("{:?}", m).to_lowercase()),
            selectioncriteria: &self.selection_criteria,
            serverselectiontimeoutms: &self.server_selection_timeout,
            sockettimeoutms: &self.socket_timeout,
            tls: &self.tls,
            writeconcern: &self.write_concern,
            loadbalanced: &self.load_balanced,
            // `ClientOptions` has no zlib level field, so this key is always `None`.
            zlibcompressionlevel: &None,
            // u32 -> i32; an out-of-range value is reported as a serializer error.
            srvmaxhosts: self
                .srv_max_hosts
                .map(|v| v.try_into())
                .transpose()
                .map_err(serde::ser::Error::custom)?,
            srvservicename: &self.srv_service_name,
        };

        client_options.serialize(serializer)
    }
}
767
/// The components of a parsed `mongodb://` or `mongodb+srv://` connection string, produced by
/// [`ConnectionString::parse`].
///
/// NOTE(review): except where noted, the optional fields appear to be filled from the matching
/// URI options by `parse_options`, which is not visible here — confirm there.
#[derive(Debug, Default, PartialEq)]
#[non_exhaustive]
pub struct ConnectionString {
    /// The explicit host list (`mongodb://`) or the single hostname to resolve via DNS
    /// (`mongodb+srv://`).
    pub host_info: HostInfo,

    /// Application name.
    pub app_name: Option<String>,

    /// TLS configuration.
    pub tls: Option<Tls>,

    /// Heartbeat frequency.
    pub heartbeat_frequency: Option<Duration>,

    /// Local threshold.
    pub local_threshold: Option<Duration>,

    /// Read concern.
    pub read_concern: Option<ReadConcern>,

    /// Replica set name. `parse` rejects it together with a positive `srv_max_hosts`.
    pub replica_set: Option<String>,

    /// Write concern.
    pub write_concern: Option<WriteConcern>,

    /// Server selection timeout.
    pub server_selection_timeout: Option<Duration>,

    /// Maximum pool size.
    pub max_pool_size: Option<u32>,

    /// Minimum pool size.
    pub min_pool_size: Option<u32>,

    /// Maximum number of concurrently establishing connections (`maxconnecting`).
    pub max_connecting: Option<u32>,

    /// Maximum idle time.
    pub max_idle_time: Option<Duration>,

    /// Compressors; present only with a compression feature enabled.
    #[cfg(any(
        feature = "zstd-compression",
        feature = "zlib-compression",
        feature = "snappy-compression"
    ))]
    pub compressors: Option<Vec<Compressor>>,

    /// Connection timeout.
    pub connect_timeout: Option<Duration>,

    /// Retry-reads flag.
    pub retry_reads: Option<bool>,

    /// Retry-writes flag.
    pub retry_writes: Option<bool>,

    /// Server monitoring mode.
    pub server_monitoring_mode: Option<ServerMonitoringMode>,

    /// Direct-connection flag.
    pub direct_connection: Option<bool>,

    /// Credential. `parse` creates it (from defaults) when the URI contains a username.
    pub credential: Option<Credential>,

    /// Default database; presumably the decoded path segment after the hosts — confirm in
    /// `parse`.
    pub default_database: Option<String>,

    /// Load-balanced flag. `parse` rejects `Some(true)` together with a positive
    /// `srv_max_hosts`.
    pub load_balanced: Option<bool>,

    /// Socket timeout.
    pub socket_timeout: Option<Duration>,

    /// Read preference.
    pub read_preference: Option<ReadPreference>,

    /// UUID representation.
    pub uuid_representation: Option<UuidRepresentation>,

    /// Only allowed with `mongodb+srv`. When positive, incompatible with `replica_set` and
    /// with `load_balanced = Some(true)`.
    pub srv_max_hosts: Option<u32>,

    /// Only allowed with `mongodb+srv`.
    pub srv_service_name: Option<String>,

    // Private fields; presumably from `waitqueuetimeoutms` / `tlsinsecure` — confirm.
    wait_queue_timeout: Option<Duration>,
    tls_insecure: Option<bool>,

    /// Test-only: the full connection string passed to `parse`.
    #[cfg(test)]
    original_uri: String,
}
927
/// Values collected by `ConnectionString::parse_options` that are not stored directly on
/// [`ConnectionString`]. `parse` receives them and presumably merges them into other fields
/// (read preference, credential, compressors) — confirm in the part of `parse` not shown.
#[derive(Debug, Default)]
struct ConnectionStringParts {
    read_preference_tags: Option<Vec<TagSet>>,
    max_staleness: Option<Duration>,
    auth_mechanism: Option<AuthMechanism>,
    auth_mechanism_properties: Option<Document>,
    zlib_compression: Option<i32>,
    auth_source: Option<String>,
}
938
/// Where a connection string's hosts come from.
#[derive(Debug, PartialEq, Clone)]
#[non_exhaustive]
pub enum HostInfo {
    /// Hosts listed explicitly in a `mongodb://` connection string.
    HostIdentifiers(Vec<ServerAddress>),
    /// The single hostname of a `mongodb+srv://` connection string, resolved through
    /// `SrvResolver` by `HostInfo::resolve`.
    DnsRecord(String),
}
948
949impl Default for HostInfo {
950 fn default() -> Self {
951 Self::HostIdentifiers(vec![])
952 }
953}
954
955impl HostInfo {
956 async fn resolve(
957 self,
958 resolver_config: Option<ResolverConfig>,
959 srv_service_name: Option<String>,
960 ) -> Result<ResolvedHostInfo> {
961 Ok(match self {
962 Self::HostIdentifiers(hosts) => ResolvedHostInfo::HostIdentifiers(hosts),
963 Self::DnsRecord(hostname) => {
964 let mut resolver =
965 SrvResolver::new(resolver_config.clone(), srv_service_name).await?;
966 let config = resolver.resolve_client_options(&hostname).await?;
967 ResolvedHostInfo::DnsRecord { hostname, config }
968 }
969 })
970 }
971}
972
/// The output of [`HostInfo::resolve`].
enum ResolvedHostInfo {
    /// Explicit hosts, passed through unchanged.
    HostIdentifiers(Vec<ServerAddress>),
    /// An SRV hostname together with the configuration resolved for it.
    DnsRecord {
        hostname: String,
        config: crate::srv::ResolvedConfig,
    },
}
980
/// TLS configuration: either enabled with specific [`TlsOptions`] or explicitly disabled.
#[derive(Clone, Debug, Deserialize, PartialEq)]
pub enum Tls {
    /// TLS is enabled with the given options.
    Enabled(TlsOptions),

    /// TLS is disabled.
    Disabled,
}
991
992impl From<TlsOptions> for Tls {
993 fn from(options: TlsOptions) -> Self {
994 Self::Enabled(options)
995 }
996}
997
998impl From<TlsOptions> for Option<Tls> {
999 fn from(options: TlsOptions) -> Self {
1000 Some(Tls::Enabled(options))
1001 }
1002}
1003
1004impl Tls {
1005 #[cfg(test)]
1006 pub(crate) fn serialize_for_client_options<S>(
1007 tls: &Option<Tls>,
1008 serializer: S,
1009 ) -> std::result::Result<S::Ok, S::Error>
1010 where
1011 S: serde::Serializer,
1012 {
1013 match tls {
1014 Some(Tls::Enabled(tls_options)) => {
1015 TlsOptions::serialize_for_client_options(tls_options, serializer)
1016 }
1017 _ => serializer.serialize_none(),
1018 }
1019 }
1020}
1021
/// Options for TLS connections. Every field is optional and defaults to `None`.
#[derive(Clone, Debug, Default, Deserialize, PartialEq, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[non_exhaustive]
pub struct TlsOptions {
    /// Serialized in tests as the `tlsallowinvalidcertificates` URI option.
    pub allow_invalid_certificates: Option<bool>,

    /// Path to a CA file; serialized in tests as `tlscafile`.
    pub ca_file_path: Option<PathBuf>,

    /// Path to the certificate/key file; serialized in tests as `tlscertificatekeyfile`.
    pub cert_key_file_path: Option<PathBuf>,

    /// Present only with the `openssl-tls` feature. Presumably controls whether hostname
    /// mismatches are accepted — confirm.
    #[cfg(feature = "openssl-tls")]
    pub allow_invalid_hostnames: Option<bool>,

    /// Present only with the `cert-key-password` feature: the certificate key file password as
    /// raw bytes. Serialized in tests as `tlscertificatekeyfilepassword`.
    #[cfg(feature = "cert-key-password")]
    pub tls_certificate_key_file_password: Option<Vec<u8>>,
}
1056
1057impl TlsOptions {
1058 #[cfg(test)]
1059 pub(crate) fn serialize_for_client_options<S>(
1060 tls_options: &TlsOptions,
1061 serializer: S,
1062 ) -> std::result::Result<S::Ok, S::Error>
1063 where
1064 S: serde::Serializer,
1065 {
1066 #[derive(Serialize)]
1067 struct TlsOptionsHelper<'a> {
1068 tls: bool,
1069 tlscafile: Option<&'a str>,
1070 tlscertificatekeyfile: Option<&'a str>,
1071 tlsallowinvalidcertificates: Option<bool>,
1072 #[cfg(feature = "cert-key-password")]
1073 tlscertificatekeyfilepassword: Option<&'a str>,
1074 }
1075
1076 let state = TlsOptionsHelper {
1077 tls: true,
1078 tlscafile: tls_options
1079 .ca_file_path
1080 .as_ref()
1081 .map(|s| s.to_str().unwrap()),
1082 tlscertificatekeyfile: tls_options
1083 .cert_key_file_path
1084 .as_ref()
1085 .map(|s| s.to_str().unwrap()),
1086 tlsallowinvalidcertificates: tls_options.allow_invalid_certificates,
1087 #[cfg(feature = "cert-key-password")]
1088 tlscertificatekeyfilepassword: tls_options
1089 .tls_certificate_key_file_password
1090 .as_deref()
1091 .map(|b| std::str::from_utf8(b).unwrap()),
1092 };
1093 state.serialize(serializer)
1094 }
1095}
1096
/// Additional driver metadata supplied via `ClientOptions::driver_info`. Presumably this
/// describes a library built on top of this driver — confirm.
#[derive(Clone, Debug, Deserialize, TypedBuilder, PartialEq)]
#[builder(field_defaults(default, setter(into)))]
#[non_exhaustive]
pub struct DriverInfo {
    /// Name; has no builder default, so it must be set.
    #[builder(!default)]
    pub name: String,

    /// Optional version string.
    pub version: Option<String>,

    /// Optional platform string.
    pub platform: Option<String>,
}
1113
impl ClientOptions {
    /// Test-only: default options with `original_srv_info` set to the hostname
    /// `localhost.test.test.build.10gen.cc` and a 60-second `min_ttl`.
    #[cfg(test)]
    pub(crate) fn new_srv() -> Self {
        Self {
            original_srv_info: Some(OriginalSrvInfo {
                hostname: "localhost.test.test.build.10gen.cc".into(),
                min_ttl: Duration::from_secs(60),
            }),
            ..Default::default()
        }
    }

    /// Returns a clone of the TLS options when TLS is explicitly enabled. Returns `None` both
    /// for `Tls::Disabled` and for an unset `tls`.
    pub(crate) fn tls_options(&self) -> Option<TlsOptions> {
        match self.tls {
            Some(Tls::Enabled(ref opts)) => Some(opts.clone()),
            _ => None,
        }
    }

    /// Checks the options for invalid values and conflicting combinations.
    ///
    /// Checks run in this order and the first failure is returned:
    /// 1. `direct_connection = Some(true)` with more than one host;
    /// 2. the write concern's own validation;
    /// 3. with `load_balanced = Some(true)`: more than one host, a replica set name, or
    ///    `direct_connection = Some(true)`;
    /// 4. each compressor's validation (compression features only);
    /// 5. `max_pool_size = Some(0)` or `max_connecting = Some(0)`;
    /// 6. a read preference's max staleness, checked against `heartbeat_freq`, or
    ///    `DEFAULT_HEARTBEAT_FREQUENCY` if unset;
    /// 7. `heartbeat_freq` below `min_heartbeat_frequency()`.
    ///
    /// With the `tracing-unstable` feature, this also logs an informational message when the
    /// original SRV hostname, or else the TCP hosts, end with a CosmosDB or DocumentDB domain.
    ///
    /// # Errors
    ///
    /// Returns `ErrorKind::InvalidArgument` for the conflicts above. Errors from the
    /// write-concern, compressor, and max-staleness checks are propagated as-is.
    pub(crate) fn validate(&self) -> Result<()> {
        if let Some(true) = self.direct_connection {
            if self.hosts.len() > 1 {
                return Err(ErrorKind::InvalidArgument {
                    message: "cannot specify multiple seeds with directConnection=true".to_string(),
                }
                .into());
            }
        }

        if let Some(ref write_concern) = self.write_concern {
            write_concern.validate()?;
        }

        if self.load_balanced.unwrap_or(false) {
            if self.hosts.len() > 1 {
                return Err(ErrorKind::InvalidArgument {
                    message: "cannot specify multiple seeds with loadBalanced=true".to_string(),
                }
                .into());
            }
            if self.repl_set_name.is_some() {
                return Err(ErrorKind::InvalidArgument {
                    message: "cannot specify replicaSet with loadBalanced=true".to_string(),
                }
                .into());
            }
            if self.direct_connection == Some(true) {
                return Err(ErrorKind::InvalidArgument {
                    message: "cannot specify directConnection=true with loadBalanced=true"
                        .to_string(),
                }
                .into());
            }
        }

        #[cfg(any(
            feature = "zstd-compression",
            feature = "zlib-compression",
            feature = "snappy-compression"
        ))]
        if let Some(ref compressors) = self.compressors {
            for compressor in compressors {
                compressor.validate()?;
            }
        }

        if let Some(0) = self.max_pool_size {
            return Err(Error::invalid_argument("cannot specify maxPoolSize=0"));
        }

        if let Some(0) = self.max_connecting {
            return Err(Error::invalid_argument("cannot specify maxConnecting=0"));
        }

        if let Some(SelectionCriteria::ReadPreference(ref rp)) = self.selection_criteria {
            if let Some(max_staleness) = rp.max_staleness() {
                verify_max_staleness(
                    max_staleness,
                    self.heartbeat_freq.unwrap_or(DEFAULT_HEARTBEAT_FREQUENCY),
                )?;
            }
        }

        if let Some(heartbeat_frequency) = self.heartbeat_freq {
            if heartbeat_frequency < self.min_heartbeat_frequency() {
                return Err(ErrorKind::InvalidArgument {
                    message: format!(
                        "'heartbeat_freq' must be at least {}ms, but {}ms was given",
                        self.min_heartbeat_frequency().as_millis(),
                        heartbeat_frequency.as_millis()
                    ),
                }
                .into());
            }
        }

        #[cfg(feature = "tracing-unstable")]
        {
            // Prefer the user-supplied SRV hostname; otherwise inspect the TCP seed hosts.
            let hostnames = if let Some(info) = &self.original_srv_info {
                vec![info.hostname.to_ascii_lowercase()]
            } else {
                self.hosts
                    .iter()
                    .filter_map(|addr| match addr {
                        ServerAddress::Tcp { host, .. } => Some(host.to_ascii_lowercase()),
                        #[cfg(unix)]
                        _ => None,
                    })
                    .collect()
            };
            if hostnames.iter().any(|s| s.ends_with(".cosmos.azure.com")) {
                tracing::info!("You appear to be connected to a CosmosDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/cosmosdb");
            }
            if hostnames.iter().any(|s| {
                s.ends_with(".docdb.amazonaws.com") || s.ends_with(".docdb-elastic.amazonaws.com")
            }) {
                tracing::info!("You appear to be connected to a DocumentDB cluster. For more information regarding feature compatibility and support please visit https://www.mongodb.com/supportability/documentdb");
            }
        }

        Ok(())
    }

    /// Test-only: combines fields from `other` into `self`.
    ///
    /// `hosts` is taken from `other` only when `self.hosts` is empty. The listed fields are
    /// combined by `merge_options!`, which is defined elsewhere; see it for the precedence
    /// rules.
    ///
    /// NOTE(review): `max_connecting`, `server_monitoring_mode`, `sdam_event_handler`,
    /// `default_database`, `srv_service_name`, `srv_max_hosts`, and the feature-gated
    /// `tracing_max_document_length_bytes` / `resolver_config` are not merged. Confirm this is
    /// intentional.
    #[cfg(test)]
    pub(crate) fn merge(&mut self, other: ClientOptions) {
        if self.hosts.is_empty() {
            self.hosts = other.hosts;
        }

        #[cfg(any(
            feature = "zstd-compression",
            feature = "zlib-compression",
            feature = "snappy-compression"
        ))]
        merge_options!(other, self, [compressors]);

        merge_options!(
            other,
            self,
            [
                app_name,
                cmap_event_handler,
                command_event_handler,
                connect_timeout,
                credential,
                direct_connection,
                driver_info,
                heartbeat_freq,
                load_balanced,
                local_threshold,
                max_idle_time,
                max_pool_size,
                min_pool_size,
                read_concern,
                repl_set_name,
                retry_reads,
                retry_writes,
                selection_criteria,
                server_api,
                server_selection_timeout,
                socket_timeout,
                test_options,
                tls,
                write_concern,
                original_srv_info,
                original_uri
            ]
        );
    }

    /// Test-only: returns the test options, inserting `TestOptions::default()` first if none
    /// are set.
    #[cfg(test)]
    pub(crate) fn test_options_mut(&mut self) -> &mut TestOptions {
        self.test_options.get_or_insert_with(Default::default)
    }

    /// The smallest `heartbeat_freq` accepted by `validate`.
    ///
    /// This is `MIN_HEARTBEAT_FREQUENCY`, unless (in test builds only)
    /// `TestOptions::min_heartbeat_freq` overrides it.
    pub(crate) fn min_heartbeat_frequency(&self) -> Duration {
        #[cfg(test)]
        {
            self.test_options
                .as_ref()
                .and_then(|to| to.min_heartbeat_freq)
                .unwrap_or(MIN_HEARTBEAT_FREQUENCY)
        }

        #[cfg(not(test))]
        {
            MIN_HEARTBEAT_FREQUENCY
        }
    }

    /// The custom DNS resolver configuration, if any. Always `None` when the `dns-resolver`
    /// feature is disabled.
    pub(crate) fn resolver_config(&self) -> Option<&ResolverConfig> {
        #[cfg(feature = "dns-resolver")]
        {
            self.resolver_config.as_ref()
        }
        #[cfg(not(feature = "dns-resolver"))]
        {
            None
        }
    }
}
1319
/// Splits `s` around the byte at index `i` and drops that byte, which is the delimiter.
///
/// Each side is returned as `None` when it is empty. `i` must be a char boundary within `s`,
/// and the byte at `i` must be a single-byte (ASCII) delimiter; otherwise this panics. All
/// callers pass positions found by `find` for `/`, `?`, `@`, or `:`.
fn exclusive_split_at(s: &str, i: usize) -> (Option<&str>, Option<&str>) {
    let (before, from_delim) = s.split_at(i);
    let left = (!before.is_empty()).then_some(before);
    // Lazily slice so an empty `from_delim` is never indexed.
    let right = (from_delim.len() > 1).then(|| &from_delim[1..]);
    (left, right)
}
1330
1331fn percent_decode(s: &str, err_message: &str) -> Result<String> {
1332 match percent_encoding::percent_decode_str(s).decode_utf8() {
1333 Ok(result) => Ok(result.to_string()),
1334 Err(_) => Err(ErrorKind::InvalidArgument {
1335 message: err_message.to_string(),
1336 }
1337 .into()),
1338 }
1339}
1340
1341fn validate_userinfo(s: &str, userinfo_type: &str) -> Result<()> {
1342 if s.chars().any(|c| USERINFO_RESERVED_CHARACTERS.contains(&c)) {
1343 return Err(ErrorKind::InvalidArgument {
1344 message: format!("{} must be URL encoded", userinfo_type),
1345 }
1346 .into());
1347 }
1348
1349 if s.split('%')
1352 .skip(1)
1353 .any(|part| part.len() < 2 || part[0..2].chars().any(|c| !c.is_ascii_hexdigit()))
1354 {
1355 return Err(ErrorKind::InvalidArgument {
1356 message: "username/password cannot contain unescaped %".to_string(),
1357 }
1358 .into());
1359 }
1360
1361 Ok(())
1362}
1363
1364impl TryFrom<&str> for ConnectionString {
1365 type Error = Error;
1366
1367 fn try_from(value: &str) -> Result<Self> {
1368 Self::parse(value)
1369 }
1370}
1371
1372impl TryFrom<&String> for ConnectionString {
1373 type Error = Error;
1374
1375 fn try_from(value: &String) -> Result<Self> {
1376 Self::parse(value)
1377 }
1378}
1379
1380impl TryFrom<String> for ConnectionString {
1381 type Error = Error;
1382
1383 fn try_from(value: String) -> Result<Self> {
1384 Self::parse(value)
1385 }
1386}
1387
1388impl ConnectionString {
1389 pub fn parse(s: impl AsRef<str>) -> Result<Self> {
1392 let s = s.as_ref();
1393 let end_of_scheme = match s.find("://") {
1394 Some(index) => index,
1395 None => {
1396 return Err(ErrorKind::InvalidArgument {
1397 message: "connection string contains no scheme".to_string(),
1398 }
1399 .into())
1400 }
1401 };
1402
1403 let srv = match &s[..end_of_scheme] {
1404 "mongodb" => false,
1405 "mongodb+srv" => true,
1406 _ => {
1407 return Err(ErrorKind::InvalidArgument {
1408 message: format!("invalid connection string scheme: {}", &s[..end_of_scheme]),
1409 }
1410 .into())
1411 }
1412 };
1413 #[cfg(not(feature = "dns-resolver"))]
1414 if srv {
1415 return Err(Error::invalid_argument(
1416 "mongodb+srv connection strings cannot be used when the 'dns-resolver' feature is \
1417 disabled",
1418 ));
1419 }
1420
1421 let after_scheme = &s[end_of_scheme + 3..];
1422
1423 let (pre_slash, post_slash) = match after_scheme.find('/') {
1424 Some(slash_index) => match exclusive_split_at(after_scheme, slash_index) {
1425 (Some(section), o) => (section, o),
1426 (None, _) => {
1427 return Err(ErrorKind::InvalidArgument {
1428 message: "missing hosts".to_string(),
1429 }
1430 .into())
1431 }
1432 },
1433 None => {
1434 if after_scheme.find('?').is_some() {
1435 return Err(ErrorKind::InvalidArgument {
1436 message: "Missing delimiting slash between hosts and options".to_string(),
1437 }
1438 .into());
1439 }
1440 (after_scheme, None)
1441 }
1442 };
1443
1444 let (database, options_section) = match post_slash {
1445 Some(section) => match section.find('?') {
1446 Some(index) => exclusive_split_at(section, index),
1447 None => (post_slash, None),
1448 },
1449 None => (None, None),
1450 };
1451
1452 let db = match database {
1453 Some(db) => {
1454 let decoded = percent_decode(db, "database name must be URL encoded")?;
1455 if decoded
1456 .chars()
1457 .any(|c| ILLEGAL_DATABASE_CHARACTERS.contains(&c))
1458 {
1459 return Err(ErrorKind::InvalidArgument {
1460 message: "illegal character in database name".to_string(),
1461 }
1462 .into());
1463 }
1464 Some(decoded)
1465 }
1466 None => None,
1467 };
1468
1469 let (authentication_requested, cred_section, hosts_section) = match pre_slash.rfind('@') {
1470 Some(index) => {
1471 let (creds, hosts) = exclusive_split_at(pre_slash, index);
1474 match hosts {
1475 Some(hs) => (true, creds, hs),
1476 None => {
1477 return Err(ErrorKind::InvalidArgument {
1478 message: "missing hosts".to_string(),
1479 }
1480 .into())
1481 }
1482 }
1483 }
1484 None => (false, None, pre_slash),
1485 };
1486
1487 let (username, password) = match cred_section {
1488 Some(creds) => match creds.find(':') {
1489 Some(index) => match exclusive_split_at(creds, index) {
1490 (username, None) => (username, Some("")),
1491 (username, password) => (username, password),
1492 },
1493 None => (Some(creds), None), },
1495 None => (None, None),
1496 };
1497
1498 let hosts = hosts_section
1499 .split(',')
1500 .map(ServerAddress::parse)
1501 .collect::<Result<Vec<ServerAddress>>>()?;
1502
1503 let host_info = if !srv {
1504 HostInfo::HostIdentifiers(hosts)
1505 } else {
1506 match &hosts[..] {
1507 [ServerAddress::Tcp { host, port: None }] => HostInfo::DnsRecord(host.clone()),
1508 [ServerAddress::Tcp {
1509 host: _,
1510 port: Some(_),
1511 }] => {
1512 return Err(Error::invalid_argument(
1513 "a port cannot be specified with 'mongodb+srv'",
1514 ));
1515 }
1516 #[cfg(unix)]
1517 [ServerAddress::Unix { .. }] => {
1518 return Err(Error::invalid_argument(
1519 "unix sockets cannot be used with 'mongodb+srv'",
1520 ));
1521 }
1522 _ => {
1523 return Err(Error::invalid_argument(
1524 "exactly one host must be specified with 'mongodb+srv'",
1525 ))
1526 }
1527 }
1528 };
1529
1530 let mut conn_str = ConnectionString {
1531 host_info,
1532 #[cfg(test)]
1533 original_uri: s.into(),
1534 ..Default::default()
1535 };
1536
1537 let mut parts = if let Some(opts) = options_section {
1538 conn_str.parse_options(opts)?
1539 } else {
1540 ConnectionStringParts::default()
1541 };
1542
1543 if conn_str.srv_service_name.is_some() && !srv {
1544 return Err(Error::invalid_argument(
1545 "srvServiceName cannot be specified with a non-SRV URI",
1546 ));
1547 }
1548
1549 if let Some(srv_max_hosts) = conn_str.srv_max_hosts {
1550 if !srv {
1551 return Err(Error::invalid_argument(
1552 "srvMaxHosts cannot be specified with a non-SRV URI",
1553 ));
1554 }
1555 if srv_max_hosts > 0 {
1556 if conn_str.replica_set.is_some() {
1557 return Err(Error::invalid_argument(
1558 "srvMaxHosts and replicaSet cannot both be present",
1559 ));
1560 }
1561 if conn_str.load_balanced == Some(true) {
1562 return Err(Error::invalid_argument(
1563 "srvMaxHosts and loadBalanced=true cannot both be present",
1564 ));
1565 }
1566 }
1567 }
1568
1569 if let Some(u) = username {
1571 let credential = conn_str.credential.get_or_insert_with(Default::default);
1572 validate_userinfo(u, "username")?;
1573 let decoded_u = percent_decode(u, "username must be URL encoded")?;
1574
1575 credential.username = Some(decoded_u);
1576
1577 if let Some(pass) = password {
1578 validate_userinfo(pass, "password")?;
1579 let decoded_p = percent_decode(pass, "password must be URL encoded")?;
1580 credential.password = Some(decoded_p)
1581 }
1582 }
1583
1584 if parts.auth_source.as_deref() == Some("") {
1585 return Err(ErrorKind::InvalidArgument {
1586 message: "empty authSource provided".to_string(),
1587 }
1588 .into());
1589 }
1590
1591 match parts.auth_mechanism {
1592 Some(ref mechanism) => {
1593 let credential = conn_str.credential.get_or_insert_with(Default::default);
1594 credential.source = parts.auth_source;
1595
1596 if let Some(mut doc) = parts.auth_mechanism_properties.take() {
1597 match doc.remove("CANONICALIZE_HOST_NAME") {
1598 Some(Bson::String(s)) => {
1599 let val = match &s.to_lowercase()[..] {
1600 "true" => Bson::Boolean(true),
1601 "false" => Bson::Boolean(false),
1602 _ => Bson::String(s),
1603 };
1604 doc.insert("CANONICALIZE_HOST_NAME", val);
1605 }
1606 Some(val) => {
1607 doc.insert("CANONICALIZE_HOST_NAME", val);
1608 }
1609 None => {}
1610 }
1611
1612 credential.mechanism_properties = Some(doc);
1613 }
1614
1615 credential.mechanism = Some(mechanism.clone());
1616 mechanism.validate_credential(credential)?;
1617 }
1618 None => {
1619 if let Some(ref mut credential) = conn_str.credential {
1620 credential.source = parts.auth_source;
1621 } else if authentication_requested {
1622 return Err(ErrorKind::InvalidArgument {
1623 message: "username and mechanism both not provided, but authentication \
1624 was requested"
1625 .to_string(),
1626 }
1627 .into());
1628 }
1629 }
1630 };
1631
1632 conn_str.default_database = db;
1634
1635 if conn_str.tls.is_none() && conn_str.is_srv() {
1636 conn_str.tls = Some(Tls::Enabled(Default::default()));
1637 }
1638
1639 Ok(conn_str)
1640 }
1641
    /// Returns the `waitQueueTimeoutMS` value parsed from the connection string, if one was
    /// given.
    pub fn wait_queue_timeout(&self) -> Option<Duration> {
        self.wait_queue_timeout
    }
1647
    /// Returns the stored `tls_insecure` value, if any.
    ///
    /// NOTE(review): the URI option parser in this file folds the `tlsInsecure` key into the TLS
    /// options' `allow_invalid_certificates` and does not assign this field there; confirm where,
    /// if anywhere, it is populated.
    pub fn tls_insecure(&self) -> Option<bool> {
        self.tls_insecure
    }
1653
1654 fn is_srv(&self) -> bool {
1655 matches!(self.host_info, HostInfo::DnsRecord(_))
1656 }
1657
1658 fn parse_options(&mut self, options: &str) -> Result<ConnectionStringParts> {
1659 let mut parts = ConnectionStringParts::default();
1660 if options.is_empty() {
1661 return Ok(parts);
1662 }
1663
1664 let mut keys: Vec<&str> = Vec::new();
1665
1666 for option_pair in options.split('&') {
1667 let (key, value) = match option_pair.find('=') {
1668 Some(index) => option_pair.split_at(index),
1669 None => {
1670 return Err(ErrorKind::InvalidArgument {
1671 message: format!(
1672 "connection string options is not a `key=value` pair: {}",
1673 option_pair,
1674 ),
1675 }
1676 .into())
1677 }
1678 };
1679
1680 if key.to_lowercase() != "readpreferencetags" && keys.contains(&key) {
1681 return Err(ErrorKind::InvalidArgument {
1682 message: "repeated options are not allowed in the connection string"
1683 .to_string(),
1684 }
1685 .into());
1686 } else {
1687 keys.push(key);
1688 }
1689
1690 self.parse_option_pair(
1692 &mut parts,
1693 &key.to_lowercase(),
1694 percent_encoding::percent_decode(&value.as_bytes()[1..])
1695 .decode_utf8_lossy()
1696 .as_ref(),
1697 )?;
1698 }
1699
1700 if let Some(tags) = parts.read_preference_tags.take() {
1701 self.read_preference = match self.read_preference.take() {
1702 Some(read_pref) => Some(read_pref.with_tags(tags)?),
1703 None => {
1704 return Err(ErrorKind::InvalidArgument {
1705 message: "cannot set read preference tags without also setting read \
1706 preference mode"
1707 .to_string(),
1708 }
1709 .into())
1710 }
1711 };
1712 }
1713
1714 if let Some(max_staleness) = parts.max_staleness.take() {
1715 self.read_preference = match self.read_preference.take() {
1716 Some(read_pref) => Some(read_pref.with_max_staleness(max_staleness)?),
1717 None => {
1718 return Err(ErrorKind::InvalidArgument {
1719 message: "cannot set max staleness without also setting read preference \
1720 mode"
1721 .to_string(),
1722 }
1723 .into())
1724 }
1725 };
1726 }
1727
1728 if let Some(true) = self.direct_connection {
1729 if self.is_srv() {
1730 return Err(ErrorKind::InvalidArgument {
1731 message: "cannot use SRV-style URI with directConnection=true".to_string(),
1732 }
1733 .into());
1734 }
1735 }
1736
1737 #[cfg(feature = "zlib-compression")]
1738 if let Some(zlib_compression_level) = parts.zlib_compression {
1739 if let Some(compressors) = self.compressors.as_mut() {
1740 for compressor in compressors {
1741 compressor.write_zlib_level(zlib_compression_level)?;
1742 }
1743 }
1744 }
1745 #[cfg(not(feature = "zlib-compression"))]
1746 if parts.zlib_compression.is_some() {
1747 return Err(ErrorKind::InvalidArgument {
1748 message: "zlibCompressionLevel may not be specified without the zlib-compression \
1749 feature flag enabled"
1750 .into(),
1751 }
1752 .into());
1753 }
1754
1755 Ok(parts)
1756 }
1757
1758 fn parse_option_pair(
1759 &mut self,
1760 parts: &mut ConnectionStringParts,
1761 key: &str,
1762 value: &str,
1763 ) -> Result<()> {
1764 macro_rules! get_bool {
1765 ($value:expr, $option:expr) => {
1766 match $value {
1767 "true" => true,
1768 "false" => false,
1769 _ => {
1770 return Err(ErrorKind::InvalidArgument {
1771 message: format!(
1772 "connection string `{}` option must be a boolean",
1773 $option,
1774 ),
1775 }
1776 .into())
1777 }
1778 }
1779 };
1780 }
1781
1782 macro_rules! get_duration {
1783 ($value:expr, $option:expr) => {
1784 match $value.parse::<u64>() {
1785 Ok(i) => i,
1786 _ => {
1787 return Err(ErrorKind::InvalidArgument {
1788 message: format!(
1789 "connection string `{}` option must be a non-negative integer",
1790 $option
1791 ),
1792 }
1793 .into())
1794 }
1795 }
1796 };
1797 }
1798
1799 macro_rules! get_u32 {
1800 ($value:expr, $option:expr) => {
1801 match value.parse::<u32>() {
1802 Ok(u) => u,
1803 Err(_) => {
1804 return Err(ErrorKind::InvalidArgument {
1805 message: format!(
1806 "connection string `{}` argument must be a positive integer",
1807 $option,
1808 ),
1809 }
1810 .into())
1811 }
1812 }
1813 };
1814 }
1815
1816 macro_rules! get_i32 {
1817 ($value:expr, $option:expr) => {
1818 match value.parse::<i32>() {
1819 Ok(u) => u,
1820 Err(_) => {
1821 return Err(ErrorKind::InvalidArgument {
1822 message: format!(
1823 "connection string `{}` argument must be an integer",
1824 $option
1825 ),
1826 }
1827 .into())
1828 }
1829 }
1830 };
1831 }
1832
1833 match key {
1834 "appname" => {
1835 self.app_name = Some(value.into());
1836 }
1837 "authmechanism" => {
1838 parts.auth_mechanism = Some(AuthMechanism::from_str(value)?);
1839 }
1840 "authsource" => parts.auth_source = Some(value.to_string()),
1841 "authmechanismproperties" => {
1842 let mut doc = Document::new();
1843 let err_func = || {
1844 ErrorKind::InvalidArgument {
1845 message: "improperly formatted authMechanismProperties".to_string(),
1846 }
1847 .into()
1848 };
1849
1850 for kvp in value.split(',') {
1851 match kvp.find(':') {
1852 Some(index) => {
1853 let (k, v) = exclusive_split_at(kvp, index);
1854 let key = k.ok_or_else(err_func)?;
1855 match key {
1856 "ALLOWED_HOSTS" => {
1857 return Err(Error::invalid_argument(
1858 "ALLOWED_HOSTS must only be specified through client \
1859 options",
1860 ));
1861 }
1862 "OIDC_CALLBACK" => {
1863 return Err(Error::invalid_argument(
1864 "OIDC_CALLBACK must only be specified through client \
1865 options",
1866 ));
1867 }
1868 "OIDC_HUMAN_CALLBACK" => {
1869 return Err(Error::invalid_argument(
1870 "OIDC_HUMAN_CALLBACK must only be specified through \
1871 client options",
1872 ));
1873 }
1874 _ => {}
1875 }
1876 let value = v.ok_or_else(err_func)?;
1877 doc.insert(key, value);
1878 }
1879 None => return Err(err_func()),
1880 };
1881 }
1882 parts.auth_mechanism_properties = Some(doc);
1883 }
1884 #[cfg(any(
1885 feature = "zstd-compression",
1886 feature = "zlib-compression",
1887 feature = "snappy-compression"
1888 ))]
1889 "compressors" => {
1890 let mut compressors: Option<Vec<Compressor>> = None;
1891 for compressor in value.split(',') {
1892 let compressor = Compressor::from_str(compressor)?;
1893 compressors
1894 .get_or_insert_with(Default::default)
1895 .push(compressor);
1896 }
1897 self.compressors = compressors;
1898 }
1899 k @ "connecttimeoutms" => {
1900 self.connect_timeout = Some(Duration::from_millis(get_duration!(value, k)));
1901 }
1902 k @ "directconnection" => {
1903 self.direct_connection = Some(get_bool!(value, k));
1904 }
1905 k @ "heartbeatfrequencyms" => {
1906 self.heartbeat_frequency = Some(Duration::from_millis(get_duration!(value, k)));
1907 }
1908 k @ "journal" => {
1909 let write_concern = self.write_concern.get_or_insert_with(Default::default);
1910 write_concern.journal = Some(get_bool!(value, k));
1911 }
1912 k @ "loadbalanced" => {
1913 self.load_balanced = Some(get_bool!(value, k));
1914 }
1915 k @ "localthresholdms" => {
1916 self.local_threshold = Some(Duration::from_millis(get_duration!(value, k)))
1917 }
1918 k @ "maxidletimems" => {
1919 self.max_idle_time = Some(Duration::from_millis(get_duration!(value, k)));
1920 }
1921 "maxstalenessseconds" => {
1922 let max_staleness_seconds = value.parse::<i64>().map_err(|e| {
1923 Error::invalid_argument(format!("invalid maxStalenessSeconds value: {}", e))
1924 })?;
1925
1926 let max_staleness = match max_staleness_seconds.cmp(&-1) {
1927 Ordering::Less => {
1928 return Err(Error::invalid_argument(format!(
1929 "maxStalenessSeconds must be -1 or positive, instead got {}",
1930 max_staleness_seconds
1931 )));
1932 }
1933 Ordering::Equal => {
1934 return Ok(());
1936 }
1937 Ordering::Greater => Duration::from_secs(max_staleness_seconds as u64),
1938 };
1939
1940 parts.max_staleness = Some(max_staleness);
1941 }
1942 k @ "maxpoolsize" => {
1943 self.max_pool_size = Some(get_u32!(value, k));
1944 }
1945 k @ "minpoolsize" => {
1946 self.min_pool_size = Some(get_u32!(value, k));
1947 }
1948 k @ "maxconnecting" => {
1949 self.max_connecting = Some(get_u32!(value, k));
1950 }
1951 "readconcernlevel" => {
1952 self.read_concern = Some(ReadConcernLevel::from_str(value).into());
1953 }
1954 "readpreference" => {
1955 self.read_preference = Some(match &value.to_lowercase()[..] {
1956 "primary" => ReadPreference::Primary,
1957 "secondary" => ReadPreference::Secondary {
1958 options: Default::default(),
1959 },
1960 "primarypreferred" => ReadPreference::PrimaryPreferred {
1961 options: Default::default(),
1962 },
1963 "secondarypreferred" => ReadPreference::SecondaryPreferred {
1964 options: Default::default(),
1965 },
1966 "nearest" => ReadPreference::Nearest {
1967 options: Default::default(),
1968 },
1969 other => {
1970 return Err(ErrorKind::InvalidArgument {
1971 message: format!("'{}' is not a valid read preference", other),
1972 }
1973 .into())
1974 }
1975 });
1976 }
1977 "readpreferencetags" => {
1978 let tags: Result<TagSet> = if value.is_empty() {
1979 Ok(TagSet::new())
1980 } else {
1981 value
1982 .split(',')
1983 .map(|tag| {
1984 let mut values = tag.split(':');
1985
1986 match (values.next(), values.next()) {
1987 (Some(key), Some(value)) => {
1988 Ok((key.to_string(), value.to_string()))
1989 }
1990 _ => Err(ErrorKind::InvalidArgument {
1991 message: format!(
1992 "'{}' is not a valid read preference tag (which must be \
1993 of the form 'key:value'",
1994 value,
1995 ),
1996 }
1997 .into()),
1998 }
1999 })
2000 .collect()
2001 };
2002
2003 parts
2004 .read_preference_tags
2005 .get_or_insert_with(Vec::new)
2006 .push(tags?);
2007 }
2008 "replicaset" => {
2009 self.replica_set = Some(value.to_string());
2010 }
2011 k @ "retrywrites" => {
2012 self.retry_writes = Some(get_bool!(value, k));
2013 }
2014 k @ "retryreads" => {
2015 self.retry_reads = Some(get_bool!(value, k));
2016 }
2017 "servermonitoringmode" => {
2018 self.server_monitoring_mode = Some(match value.to_lowercase().as_str() {
2019 "stream" => ServerMonitoringMode::Stream,
2020 "poll" => ServerMonitoringMode::Poll,
2021 "auto" => ServerMonitoringMode::Auto,
2022 other => {
2023 return Err(Error::invalid_argument(format!(
2024 "{:?} is not a valid server monitoring mode",
2025 other
2026 )));
2027 }
2028 });
2029 }
2030 k @ "serverselectiontimeoutms" => {
2031 self.server_selection_timeout = Some(Duration::from_millis(get_duration!(value, k)))
2032 }
2033 k @ "sockettimeoutms" => {
2034 self.socket_timeout = Some(Duration::from_millis(get_duration!(value, k)));
2035 }
2036 k @ "srvmaxhosts" => {
2037 self.srv_max_hosts = Some(get_u32!(value, k));
2038 }
2039 "srvservicename" => {
2040 self.srv_service_name = Some(value.to_string());
2041 }
2042 k @ "tls" | k @ "ssl" => {
2043 let tls = get_bool!(value, k);
2044
2045 match (self.tls.as_ref(), tls) {
2046 (Some(Tls::Disabled), true) | (Some(Tls::Enabled(..)), false) => {
2047 return Err(ErrorKind::InvalidArgument {
2048 message: "All instances of `tls` and `ssl` must have the same
2049 value"
2050 .to_string(),
2051 }
2052 .into());
2053 }
2054 _ => {}
2055 };
2056
2057 if self.tls.is_none() {
2058 let tls = if tls {
2059 Tls::Enabled(Default::default())
2060 } else {
2061 Tls::Disabled
2062 };
2063
2064 self.tls = Some(tls);
2065 }
2066 }
2067 k @ "tlsinsecure" | k @ "tlsallowinvalidcertificates" => {
2068 let val = get_bool!(value, k);
2069
2070 let allow_invalid_certificates = if k == "tlsinsecure" { !val } else { val };
2071
2072 match self.tls {
2073 Some(Tls::Disabled) => {
2074 return Err(ErrorKind::InvalidArgument {
2075 message: "'tlsInsecure' can't be set if tls=false".into(),
2076 }
2077 .into())
2078 }
2079 Some(Tls::Enabled(ref options))
2080 if options.allow_invalid_certificates.is_some()
2081 && options.allow_invalid_certificates
2082 != Some(allow_invalid_certificates) =>
2083 {
2084 return Err(ErrorKind::InvalidArgument {
2085 message: "all instances of 'tlsInsecure' and \
2086 'tlsAllowInvalidCertificates' must be consistent (e.g. \
2087 'tlsInsecure' cannot be true when \
2088 'tlsAllowInvalidCertificates' is false, or vice-versa)"
2089 .into(),
2090 }
2091 .into());
2092 }
2093 Some(Tls::Enabled(ref mut options)) => {
2094 options.allow_invalid_certificates = Some(allow_invalid_certificates);
2095 }
2096 None => {
2097 self.tls = Some(Tls::Enabled(
2098 TlsOptions::builder()
2099 .allow_invalid_certificates(allow_invalid_certificates)
2100 .build(),
2101 ))
2102 }
2103 }
2104 }
2105 "tlscafile" => match self.tls {
2106 Some(Tls::Disabled) => {
2107 return Err(ErrorKind::InvalidArgument {
2108 message: "'tlsCAFile' can't be set if tls=false".into(),
2109 }
2110 .into());
2111 }
2112 Some(Tls::Enabled(ref mut options)) => {
2113 options.ca_file_path = Some(value.into());
2114 }
2115 None => {
2116 self.tls = Some(Tls::Enabled(
2117 TlsOptions::builder()
2118 .ca_file_path(PathBuf::from(value))
2119 .build(),
2120 ))
2121 }
2122 },
2123 "tlscertificatekeyfile" => match self.tls {
2124 Some(Tls::Disabled) => {
2125 return Err(ErrorKind::InvalidArgument {
2126 message: "'tlsCertificateKeyFile' can't be set if tls=false".into(),
2127 }
2128 .into());
2129 }
2130 Some(Tls::Enabled(ref mut options)) => {
2131 options.cert_key_file_path = Some(value.into());
2132 }
2133 None => {
2134 self.tls = Some(Tls::Enabled(
2135 TlsOptions::builder()
2136 .cert_key_file_path(PathBuf::from(value))
2137 .build(),
2138 ))
2139 }
2140 },
2141 #[cfg(feature = "cert-key-password")]
2142 "tlscertificatekeyfilepassword" => match &mut self.tls {
2143 Some(Tls::Disabled) => {
2144 return Err(ErrorKind::InvalidArgument {
2145 message: "'tlsCertificateKeyFilePassword' can't be set if tls=false".into(),
2146 }
2147 .into());
2148 }
2149 Some(Tls::Enabled(options)) => {
2150 options.tls_certificate_key_file_password = Some(value.as_bytes().to_vec());
2151 }
2152 None => {
2153 self.tls = Some(Tls::Enabled(
2154 TlsOptions::builder()
2155 .tls_certificate_key_file_password(value.as_bytes().to_vec())
2156 .build(),
2157 ))
2158 }
2159 },
2160 #[cfg(not(feature = "cert-key-password"))]
2161 "tlscertificatekeyfilepassword" => {
2162 return Err(Error::invalid_argument(
2163 "the cert-key-password feature must be enabled to specify \
2164 tlsCertificateKeyFilePassword in the URI",
2165 ));
2166 }
2167 "uuidrepresentation" => match value.to_lowercase().as_str() {
2168 "csharplegacy" => self.uuid_representation = Some(UuidRepresentation::CSharpLegacy),
2169 "javalegacy" => self.uuid_representation = Some(UuidRepresentation::JavaLegacy),
2170 "pythonlegacy" => self.uuid_representation = Some(UuidRepresentation::PythonLegacy),
2171 _ => {
2172 return Err(ErrorKind::InvalidArgument {
2173 message: format!(
2174 "connection string `uuidRepresentation` option can be one of \
2175 `csharpLegacy`, `javaLegacy`, or `pythonLegacy`. Received invalid \
2176 `{}`",
2177 value
2178 ),
2179 }
2180 .into())
2181 }
2182 },
2183 "w" => {
2184 let write_concern = self.write_concern.get_or_insert_with(Default::default);
2185
2186 match value.parse::<i32>() {
2187 Ok(w) => match u32::try_from(w) {
2188 Ok(uw) => write_concern.w = Some(Acknowledgment::from(uw)),
2189 Err(_) => {
2190 return Err(ErrorKind::InvalidArgument {
2191 message: "connection string `w` option cannot be a negative \
2192 integer"
2193 .to_string(),
2194 }
2195 .into())
2196 }
2197 },
2198 Err(_) => {
2199 write_concern.w = Some(Acknowledgment::from(value.to_string()));
2200 }
2201 };
2202 }
2203 k @ "waitqueuetimeoutms" => {
2204 self.wait_queue_timeout = Some(Duration::from_millis(get_duration!(value, k)));
2205 }
2206 k @ "wtimeoutms" => {
2207 let write_concern = self.write_concern.get_or_insert_with(Default::default);
2208 write_concern.w_timeout = Some(Duration::from_millis(get_duration!(value, k)));
2209 }
2210 k @ "zlibcompressionlevel" => {
2211 let i = get_i32!(value, k);
2212 if i < -1 {
2213 return Err(ErrorKind::InvalidArgument {
2214 message: "'zlibCompressionLevel' cannot be less than -1".to_string(),
2215 }
2216 .into());
2217 }
2218
2219 if i > 9 {
2220 return Err(ErrorKind::InvalidArgument {
2221 message: "'zlibCompressionLevel' cannot be greater than 9".to_string(),
2222 }
2223 .into());
2224 }
2225
2226 parts.zlib_compression = Some(i);
2227 }
2228
2229 other => {
2230 let (jaro_winkler, option) = URI_OPTIONS.iter().fold((0.0, ""), |acc, option| {
2231 let jaro_winkler = jaro_winkler(option, other).abs();
2232 if jaro_winkler > acc.0 {
2233 return (jaro_winkler, option);
2234 }
2235 acc
2236 });
2237 let mut message = format!("{} is an invalid option", other);
2238 if jaro_winkler >= 0.84 {
2239 let _ = write!(
2240 message,
2241 ". An option with a similar name exists: {}",
2242 option
2243 );
2244 }
2245 return Err(ErrorKind::InvalidArgument { message }.into());
2246 }
2247 }
2248
2249 Ok(())
2250 }
2251}
2252
2253impl FromStr for ConnectionString {
2254 type Err = Error;
2255 fn from_str(s: &str) -> Result<Self> {
2256 ConnectionString::parse(s)
2257 }
2258}
2259
2260impl<'de> Deserialize<'de> for ConnectionString {
2261 fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
2262 where
2263 D: Deserializer<'de>,
2264 {
2265 deserializer.deserialize_str(ConnectionStringVisitor)
2266 }
2267}
2268
2269struct ConnectionStringVisitor;
2270
2271impl serde::de::Visitor<'_> for ConnectionStringVisitor {
2272 type Value = ConnectionString;
2273
2274 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
2275 write!(formatter, "a MongoDB connection string")
2276 }
2277
2278 fn visit_str<E>(self, v: &str) -> std::result::Result<Self::Value, E>
2279 where
2280 E: serde::de::Error,
2281 {
2282 ConnectionString::parse(v).map_err(serde::de::Error::custom)
2283 }
2284}
2285
#[cfg(test)]
mod tests {
    // Unit tests for `ServerAddress` parsing and for `ClientOptions::parse` on plain
    // (non-SRV) connection strings.
    use std::time::Duration;

    use pretty_assertions::assert_eq;

    use super::{ClientOptions, ServerAddress};
    use crate::{
        concern::{Acknowledgment, ReadConcernLevel, WriteConcern},
        selection_criteria::{ReadPreference, ReadPreferenceOptions},
    };

    /// Builds a read preference tag set (`HashMap<String, String>`) from `key => value` pairs.
    macro_rules! tag_set {
        ( $($k:expr => $v:expr),* ) => {
            #[allow(clippy::let_and_return)]
            {
                use std::collections::HashMap;

                #[allow(unused_mut)]
                let mut ts = HashMap::new();
                $(
                    ts.insert($k.to_string(), $v.to_string());
                )*

                ts
            }
        }
    }

    /// A TCP address with no explicit port.
    fn host_without_port(hostname: &str) -> ServerAddress {
        ServerAddress::Tcp {
            host: hostname.to_string(),
            port: None,
        }
    }

    #[test]
    fn test_parse_address_with_from_str() {
        let x = "localhost:27017".parse::<ServerAddress>().unwrap();
        match x {
            ServerAddress::Tcp { host, port } => {
                assert_eq!(host, "localhost");
                assert_eq!(port, Some(27017));
            }
            #[cfg(unix)]
            _ => panic!("expected ServerAddress::Tcp"),
        }

        let x = "localhost".parse::<ServerAddress>().unwrap();
        match x {
            ServerAddress::Tcp { host, port } => {
                assert_eq!(host, "localhost");
                assert_eq!(port, None);
            }
            #[cfg(unix)]
            _ => panic!("expected ServerAddress::Tcp"),
        }

        let x = "localhost:not a number".parse::<ServerAddress>();
        assert!(x.is_err());

        #[cfg(unix)]
        {
            let x = "/path/to/socket.sock".parse::<ServerAddress>().unwrap();
            match x {
                ServerAddress::Unix { path } => {
                    assert_eq!(path.to_str().unwrap(), "/path/to/socket.sock");
                }
                _ => panic!("expected ServerAddress::Unix"),
            }
        }
    }

    #[tokio::test]
    async fn fails_without_scheme() {
        assert!(ClientOptions::parse("localhost:27017").await.is_err());
    }

    #[tokio::test]
    async fn fails_with_invalid_scheme() {
        assert!(ClientOptions::parse("mangodb://localhost:27017")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn fails_with_nothing_after_scheme() {
        assert!(ClientOptions::parse("mongodb://").await.is_err());
    }

    #[tokio::test]
    async fn fails_with_only_slash_after_scheme() {
        assert!(ClientOptions::parse("mongodb:///").await.is_err());
    }

    #[tokio::test]
    async fn fails_with_no_host() {
        assert!(ClientOptions::parse("mongodb://:27017").await.is_err());
    }

    #[tokio::test]
    async fn no_port() {
        let uri = "mongodb://localhost";

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![host_without_port("localhost")],
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn no_port_trailing_slash() {
        let uri = "mongodb://localhost/";

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![host_without_port("localhost")],
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn with_port() {
        // Explicit port, no trailing slash (the trailing-slash case is covered separately).
        let uri = "mongodb://localhost:27017";

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![ServerAddress::Tcp {
                    host: "localhost".to_string(),
                    port: Some(27017),
                }],
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn with_port_and_trailing_slash() {
        let uri = "mongodb://localhost:27017/";

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![ServerAddress::Tcp {
                    host: "localhost".to_string(),
                    port: Some(27017),
                }],
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn with_read_concern() {
        let uri = "mongodb://localhost:27017/?readConcernLevel=foo";

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![ServerAddress::Tcp {
                    host: "localhost".to_string(),
                    port: Some(27017),
                }],
                read_concern: Some(ReadConcernLevel::Custom("foo".to_string()).into()),
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn with_w_negative_int() {
        assert!(ClientOptions::parse("mongodb://localhost:27017/?w=-1")
            .await
            .is_err());
    }

    #[tokio::test]
    async fn with_w_non_negative_int() {
        let uri = "mongodb://localhost:27017/?w=1";
        let write_concern = WriteConcern::builder().w(Acknowledgment::from(1)).build();

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![ServerAddress::Tcp {
                    host: "localhost".to_string(),
                    port: Some(27017),
                }],
                write_concern: Some(write_concern),
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn with_w_string() {
        let uri = "mongodb://localhost:27017/?w=foo";
        let write_concern = WriteConcern::builder()
            .w(Acknowledgment::from("foo".to_string()))
            .build();

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![ServerAddress::Tcp {
                    host: "localhost".to_string(),
                    port: Some(27017),
                }],
                write_concern: Some(write_concern),
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn with_invalid_j() {
        assert!(
            ClientOptions::parse("mongodb://localhost:27017/?journal=foo")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn with_j() {
        let uri = "mongodb://localhost:27017/?journal=true";
        let write_concern = WriteConcern::builder().journal(true).build();

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![ServerAddress::Tcp {
                    host: "localhost".to_string(),
                    port: Some(27017),
                }],
                write_concern: Some(write_concern),
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn with_wtimeout_non_int() {
        assert!(
            ClientOptions::parse("mongodb://localhost:27017/?wtimeoutMS=foo")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn with_wtimeout_negative_int() {
        assert!(
            ClientOptions::parse("mongodb://localhost:27017/?wtimeoutMS=-1")
                .await
                .is_err()
        );
    }

    #[tokio::test]
    async fn with_wtimeout() {
        let uri = "mongodb://localhost:27017/?wtimeoutMS=27";
        let write_concern = WriteConcern::builder()
            .w_timeout(Duration::from_millis(27))
            .build();

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![ServerAddress::Tcp {
                    host: "localhost".to_string(),
                    port: Some(27017),
                }],
                write_concern: Some(write_concern),
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn with_all_write_concern_options() {
        let uri = "mongodb://localhost:27017/?w=majority&journal=false&wtimeoutMS=27";
        let write_concern = WriteConcern::builder()
            .w(Acknowledgment::Majority)
            .journal(false)
            .w_timeout(Duration::from_millis(27))
            .build();

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![ServerAddress::Tcp {
                    host: "localhost".to_string(),
                    port: Some(27017),
                }],
                write_concern: Some(write_concern),
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }

    #[tokio::test]
    async fn with_mixed_options() {
        let uri = "mongodb://localhost,localhost:27018/?w=majority&readConcernLevel=majority&\
                   journal=false&wtimeoutMS=27&replicaSet=foo&heartbeatFrequencyMS=1000&\
                   localThresholdMS=4000&readPreference=secondaryPreferred&readpreferencetags=dc:\
                   ny,rack:1&serverselectiontimeoutms=2000&readpreferencetags=dc:ny&\
                   readpreferencetags=";
        let write_concern = WriteConcern::builder()
            .w(Acknowledgment::Majority)
            .journal(false)
            .w_timeout(Duration::from_millis(27))
            .build();

        assert_eq!(
            ClientOptions::parse(uri).await.unwrap(),
            ClientOptions {
                hosts: vec![
                    ServerAddress::Tcp {
                        host: "localhost".to_string(),
                        port: None,
                    },
                    ServerAddress::Tcp {
                        host: "localhost".to_string(),
                        port: Some(27018),
                    },
                ],
                selection_criteria: Some(
                    ReadPreference::SecondaryPreferred {
                        options: Some(
                            ReadPreferenceOptions::builder()
                                .tag_sets(vec![
                                    tag_set! {
                                        "dc" => "ny",
                                        "rack" => "1"
                                    },
                                    tag_set! {
                                        "dc" => "ny"
                                    },
                                    tag_set! {},
                                ])
                                .build()
                        )
                    }
                    .into()
                ),
                read_concern: Some(ReadConcernLevel::Majority.into()),
                write_concern: Some(write_concern),
                repl_set_name: Some("foo".to_string()),
                heartbeat_freq: Some(Duration::from_millis(1000)),
                local_threshold: Some(Duration::from_millis(4000)),
                server_selection_timeout: Some(Duration::from_millis(2000)),
                original_uri: Some(uri.into()),
                ..Default::default()
            }
        );
    }
}
2661
/// Options for a client session.
///
/// Deserialized with camelCase field names. `causal_consistency` and `snapshot` must not both
/// be `true`; this is checked by `SessionOptions::validate`.
#[derive(Clone, Debug, Default, Deserialize, TypedBuilder)]
#[builder(field_defaults(default, setter(into)))]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
#[export_tokens]
pub struct SessionOptions {
    /// Default transaction options for the session. NOTE(review): presumably applied to
    /// transactions started on a session created with these options — confirm at the call site.
    pub default_transaction_options: Option<TransactionOptions>,

    /// Whether causal consistency is requested. Must not be `true` together with `snapshot`.
    pub causal_consistency: Option<bool>,

    /// Whether snapshot reads are requested. Must not be `true` together with
    /// `causal_consistency`.
    pub snapshot: Option<bool>,
}
2688
2689impl SessionOptions {
2690 pub(crate) fn validate(&self) -> Result<()> {
2691 if let (Some(causal_consistency), Some(snapshot)) = (self.causal_consistency, self.snapshot)
2692 {
2693 if causal_consistency && snapshot {
2694 return Err(ErrorKind::InvalidArgument {
2695 message: "snapshot and causal consistency are mutually exclusive".to_string(),
2696 }
2697 .into());
2698 }
2699 }
2700 Ok(())
2701 }
2702}
2703
/// Options for a transaction.
///
/// Serialized and deserialized with camelCase field names. Fields that are `None` are
/// omitted when serializing (via `skip_serializing_none`).
#[skip_serializing_none]
#[derive(Debug, Default, Serialize, Deserialize, TypedBuilder, Clone)]
#[builder(field_defaults(default, setter(into)))]
#[serde(rename_all = "camelCase")]
#[non_exhaustive]
#[export_tokens]
pub struct TransactionOptions {
    /// Read concern for the transaction. Accepted when deserializing but never serialized.
    #[builder(default)]
    #[serde(skip_serializing)]
    pub read_concern: Option<ReadConcern>,

    /// Write concern for the transaction.
    #[builder(default)]
    pub write_concern: Option<WriteConcern>,

    /// Selection criteria for the transaction, read from the `readPreference` key. Accepted
    /// when deserializing but never serialized.
    #[builder(default)]
    #[serde(skip_serializing, rename = "readPreference")]
    pub selection_criteria: Option<SelectionCriteria>,

    /// Maximum commit time. Deserialized from `maxCommitTimeMS` and serialized as `maxTimeMS`.
    /// Both keys carry integer milliseconds.
    #[builder(default)]
    #[serde(
        serialize_with = "serde_util::serialize_duration_option_as_int_millis",
        deserialize_with = "serde_util::deserialize_duration_option_from_u64_millis",
        rename(serialize = "maxTimeMS", deserialize = "maxCommitTimeMS"),
        default
    )]
    pub max_commit_time: Option<Duration>,
}
2736
/// The server monitoring mode.
///
/// In a connection string it is set with `serverMonitoringMode`. The value is matched
/// case-insensitively against `stream`, `poll` and `auto`; any other value is rejected.
#[derive(Debug, Serialize, Deserialize, Clone, PartialEq)]
#[non_exhaustive]
pub enum ServerMonitoringMode {
    /// Selected by `serverMonitoringMode=stream`.
    Stream,
    /// Selected by `serverMonitoringMode=poll`.
    Poll,
    /// Selected by `serverMonitoringMode=auto`.
    Auto,
}