use std::time::Duration;
#[cfg(feature = "dns-resolver")]
use crate::error::ErrorKind;
use crate::{client::options::ResolverConfig, error::Result, options::ServerAddress};
#[derive(Debug)]
pub(crate) struct ResolvedConfig {
pub(crate) hosts: Vec<ServerAddress>,
pub(crate) min_ttl: Duration,
pub(crate) auth_source: Option<String>,
pub(crate) replica_set: Option<String>,
pub(crate) load_balanced: Option<bool>,
}
#[derive(Debug, Clone)]
pub(crate) struct LookupHosts {
pub(crate) hosts: Vec<ServerAddress>,
pub(crate) min_ttl: Duration,
}
impl LookupHosts {
#[cfg(feature = "dns-resolver")]
pub(crate) fn validate(mut self, original_hostname: &str, dm: DomainMismatch) -> Result<Self> {
let original_hostname_parts: Vec<_> = original_hostname.split('.').collect();
let original_domain_name = if original_hostname_parts.len() >= 3 {
&original_hostname_parts[1..]
} else {
&original_hostname_parts[..]
};
let mut ok_hosts = vec![];
for addr in self.hosts.drain(..) {
let host = addr.host();
let hostname_parts: Vec<_> = host.split('.').collect();
if hostname_parts[1..].ends_with(original_domain_name) {
ok_hosts.push(addr);
} else {
let message = format!(
"SRV lookup for {} returned result {}, which does not match domain name {}",
original_hostname,
host,
original_domain_name.join(".")
);
match dm {
DomainMismatch::Error => return Err(ErrorKind::DnsResolve { message }.into()),
DomainMismatch::Skip => {
#[cfg(feature = "tracing-unstable")]
{
use crate::trace::SERVER_SELECTION_TRACING_EVENT_TARGET;
if crate::trace::trace_or_log_enabled!(
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
crate::trace::TracingOrLogLevel::Warn
) {
tracing::warn!(
target: SERVER_SELECTION_TRACING_EVENT_TARGET,
message,
);
}
}
continue;
}
}
}
}
self.hosts = ok_hosts;
if self.hosts.is_empty() {
return Err(ErrorKind::DnsResolve {
message: format!("SRV lookup for {} returned no records", original_hostname),
}
.into());
}
Ok(self)
}
}
#[derive(Debug, Clone, PartialEq)]
pub(crate) struct OriginalSrvInfo {
pub(crate) hostname: String,
pub(crate) min_ttl: Duration,
}
pub(crate) enum DomainMismatch {
#[allow(dead_code)]
Error,
Skip,
}
#[cfg(feature = "dns-resolver")]
pub(crate) struct SrvResolver {
resolver: crate::runtime::AsyncResolver,
srv_service_name: Option<String>,
}
#[cfg(feature = "dns-resolver")]
impl SrvResolver {
pub(crate) async fn new(
config: Option<ResolverConfig>,
srv_service_name: Option<String>,
) -> Result<Self> {
let resolver = crate::runtime::AsyncResolver::new(config.map(|c| c.inner)).await?;
Ok(Self {
resolver,
srv_service_name,
})
}
pub(crate) async fn resolve_client_options(
&mut self,
hostname: &str,
) -> Result<ResolvedConfig> {
let lookup_result = self.get_srv_hosts(hostname, DomainMismatch::Error).await?;
let mut config = ResolvedConfig {
hosts: lookup_result.hosts,
min_ttl: lookup_result.min_ttl,
auth_source: None,
replica_set: None,
load_balanced: None,
};
self.get_txt_options(hostname, &mut config).await?;
Ok(config)
}
async fn get_srv_hosts_unvalidated(&self, lookup_hostname: &str) -> Result<LookupHosts> {
use hickory_proto::rr::RData;
let srv_lookup = self.resolver.srv_lookup(lookup_hostname).await?;
let mut hosts = vec![];
let mut min_ttl = u32::MAX;
for record in srv_lookup.as_lookup().record_iter() {
let srv = match record.data() {
Some(RData::SRV(s)) => s,
_ => continue,
};
let mut host = srv.target().to_utf8();
if host.ends_with('.') {
host.pop();
}
let port = Some(srv.port());
hosts.push(ServerAddress::Tcp { host, port });
min_ttl = std::cmp::min(min_ttl, record.ttl());
}
Ok(LookupHosts {
hosts,
min_ttl: Duration::from_secs(min_ttl.into()),
})
}
pub(crate) async fn get_srv_hosts(
&self,
original_hostname: &str,
dm: DomainMismatch,
) -> Result<LookupHosts> {
let lookup_hostname = format!(
"_{}._tcp.{}",
self.srv_service_name.as_deref().unwrap_or("mongodb"),
original_hostname
);
self.get_srv_hosts_unvalidated(&lookup_hostname)
.await?
.validate(original_hostname, dm)
}
async fn get_txt_options(
&self,
original_hostname: &str,
config: &mut ResolvedConfig,
) -> Result<()> {
let txt_records_response = match self.resolver.txt_lookup(original_hostname).await? {
Some(response) => response,
None => return Ok(()),
};
let mut txt_records = txt_records_response.iter();
let txt_record = match txt_records.next() {
Some(record) => record,
None => return Ok(()),
};
if txt_records.next().is_some() {
return Err(ErrorKind::DnsResolve {
message: format!(
"TXT lookup for {} returned more than one record, but more than one are not \
allowed with 'mongodb+srv'",
original_hostname,
),
}
.into());
}
let txt_data: Vec<_> = txt_record
.txt_data()
.iter()
.map(|bytes| String::from_utf8_lossy(bytes.as_ref()).into_owned())
.collect();
let txt_string = txt_data.join("");
for option_pair in txt_string.split('&') {
let parts: Vec<_> = option_pair.split('=').collect();
if parts.len() != 2 {
return Err(ErrorKind::DnsResolve {
message: format!(
"TXT record string '{}' is not a value `key=value` option pair",
option_pair
),
}
.into());
}
match &parts[0].to_lowercase()[..] {
"authsource" => {
config.auth_source = Some(parts[1].to_string());
}
"replicaset" => {
config.replica_set = Some(parts[1].into());
}
"loadbalanced" => {
let val = match parts[1] {
"true" => true,
"false" => false,
_ => {
return Err(ErrorKind::DnsResolve {
message: format!(
"TXT record option 'loadbalanced={}' was returned, only \
'true' and 'false' are allowed values.",
parts[1]
),
}
.into())
}
};
config.load_balanced = Some(val);
}
other => {
return Err(ErrorKind::DnsResolve {
message: format!(
"TXT record option '{}' was returned, but only 'authSource', \
'replicaSet', and 'loadBalanced' are allowed",
other
),
}
.into())
}
};
}
Ok(())
}
}
#[cfg(not(feature = "dns-resolver"))]
pub(crate) struct SrvResolver {}
#[cfg(not(feature = "dns-resolver"))]
impl SrvResolver {
pub(crate) async fn new(
_config: Option<ResolverConfig>,
_srv_service_name: Option<String>,
) -> Result<Self> {
Ok(Self {})
}
pub(crate) async fn resolve_client_options(
&mut self,
_hostname: &str,
) -> Result<ResolvedConfig> {
Err(crate::error::Error::invalid_argument(
"mongodb+srv connection strings cannot be used when the 'dns-resolver' feature is \
disabled",
))
}
pub(crate) async fn get_srv_hosts(
&self,
_original_hostname: &str,
_dm: DomainMismatch,
) -> Result<LookupHosts> {
return Err(crate::error::Error::invalid_argument(
"mongodb+srv connection strings cannot be used when the 'dns-resolver' feature is \
disabled",
));
}
}