pub(crate) mod server_selection;
#[cfg(test)]
pub(crate) mod test;
use std::{
collections::{HashMap, HashSet},
time::Duration,
};
use serde::Deserialize;
use crate::{
bson::oid::ObjectId,
client::ClusterTime,
cmap::Command,
options::{ClientOptions, ServerAddress},
sdam::description::server::{ServerDescription, ServerType},
selection_criteria::{ReadPreference, SelectionCriteria},
};
/// Heartbeat interval used by `TopologyDescription::heartbeat_frequency` when
/// `heartbeat_freq` was not configured.
const DEFAULT_HEARTBEAT_FREQUENCY: Duration = Duration::from_secs(10);
/// The kind of deployment the topology description currently believes it is tracking.
///
/// `TopologyDescription::new` picks the starting value from the client options; it is then
/// updated as server descriptions are applied.
#[derive(Debug, Clone, Copy, Eq, PartialEq, Deserialize)]
#[non_exhaustive]
pub enum TopologyType {
    /// A single server. Chosen up front when `direct_connection` is `Some(true)`, or entered
    /// from `Unknown` when a standalone is discovered and there was exactly one seed.
    Single,
    /// A replica set in which no server is currently recorded as primary.
    ReplicaSetNoPrimary,
    /// A replica set in which a server is recorded as primary.
    ReplicaSetWithPrimary,
    /// A sharded cluster, entered from `Unknown` when a mongos is discovered.
    Sharded,
    /// A deployment accessed through a load balancer; chosen up front when `load_balanced`
    /// is `Some(true)` and neither `direct_connection` nor a replica set name applies.
    LoadBalanced,
    /// Not determined yet. This is also the `Default`.
    Unknown,
}
#[cfg(test)]
impl TopologyType {
    /// Test-only helper returning the variant's name as a static string.
    fn as_str(&self) -> &'static str {
        use TopologyType::*;
        match *self {
            Single => "Single",
            ReplicaSetNoPrimary => "ReplicaSetNoPrimary",
            ReplicaSetWithPrimary => "ReplicaSetWithPrimary",
            Sharded => "Sharded",
            LoadBalanced => "LoadBalanced",
            Unknown => "Unknown",
        }
    }
}
impl Default for TopologyType {
fn default() -> Self {
TopologyType::Unknown
}
}
/// The driver's current view of the whole deployment: its type, the known servers and
/// bookkeeping derived from their descriptions.
///
/// Note that the `PartialEq` impl only compares `compatibility_error`, `servers` and
/// `topology_type`.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub(crate) struct TopologyDescription {
    /// True when the description was created from exactly one (distinct) seed address.
    pub(crate) single_seed: bool,
    /// The current topology type.
    pub(crate) topology_type: TopologyType,
    /// Replica set name, taken from the options or adopted from the first replica set member
    /// whose description is applied while it is unset.
    pub(crate) set_name: Option<String>,
    /// Highest `setVersion` accepted from a primary; used to reject stale primaries.
    pub(crate) max_set_version: Option<i32>,
    /// Highest `electionId` accepted from a primary; used to reject stale primaries.
    pub(crate) max_election_id: Option<ObjectId>,
    /// First compatibility error reported by any server, recomputed after every update.
    pub(crate) compatibility_error: Option<String>,
    /// Whether the deployment supports sessions, and the logical session timeout if known.
    pub(crate) session_support_status: SessionSupportStatus,
    /// Whether the deployment supports transactions.
    pub(crate) transaction_support_status: TransactionSupportStatus,
    /// Greatest cluster time observed so far.
    pub(crate) cluster_time: Option<ClusterTime>,
    /// Local threshold copied from the client options.
    pub(crate) local_threshold: Option<Duration>,
    /// Heartbeat frequency from the client options; `None` means `DEFAULT_HEARTBEAT_FREQUENCY`.
    pub(crate) heartbeat_freq: Option<Duration>,
    /// Description of every currently known server, keyed by address.
    pub(crate) servers: HashMap<ServerAddress, ServerDescription>,
}
/// Equality considers only the topology type, the server descriptions and the
/// compatibility error. The other bookkeeping fields (set name, max set version/election
/// id, support statuses, cluster time, thresholds) are not compared.
impl PartialEq for TopologyDescription {
    fn eq(&self, other: &Self) -> bool {
        let same_error = self.compatibility_error == other.compatibility_error;
        let same_type = self.topology_type == other.topology_type;
        same_error && same_type && self.servers == other.servers
    }
}
impl TopologyDescription {
    /// Builds the initial topology description from the client options.
    ///
    /// The starting topology type is chosen in priority order:
    /// 1. `direct_connection == Some(true)` gives `Single`.
    /// 2. A replica set name gives `ReplicaSetNoPrimary`.
    /// 3. `load_balanced == Some(true)` gives `LoadBalanced`.
    /// 4. Anything else starts as `Unknown`.
    ///
    /// Every seed host gets an initial server description; load-balanced topologies use a
    /// load-balancer description and start with sessions marked as supported.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidArgument` error if the selection criteria specify a max staleness
    /// that is positive but shorter than 90 seconds.
    pub(crate) fn new(options: ClientOptions) -> crate::error::Result<Self> {
        verify_max_staleness(
            options
                .selection_criteria
                .as_ref()
                .and_then(|criteria| criteria.max_staleness()),
        )?;
        let topology_type = if let Some(true) = options.direct_connection {
            TopologyType::Single
        } else if options.repl_set_name.is_some() {
            TopologyType::ReplicaSetNoPrimary
        } else if options.load_balanced.unwrap_or(false) {
            TopologyType::LoadBalanced
        } else {
            TopologyType::Unknown
        };
        let servers: HashMap<_, _> = options
            .hosts
            .into_iter()
            .map(|address| {
                let description = if topology_type == TopologyType::LoadBalanced {
                    ServerDescription::new_load_balancer(address.clone())
                } else {
                    ServerDescription::new(address.clone(), None)
                };
                (address, description)
            })
            .collect();
        // Load-balanced topologies start out with sessions supported (timeout unknown);
        // every other topology waits for server descriptions to decide.
        let session_support_status = if topology_type == TopologyType::LoadBalanced {
            SessionSupportStatus::Supported {
                logical_session_timeout: None,
            }
        } else {
            SessionSupportStatus::Undetermined
        };
        Ok(Self {
            // Counted after collecting into the map, so duplicate seeds count once.
            single_seed: servers.len() == 1,
            topology_type,
            set_name: options.repl_set_name,
            max_set_version: None,
            max_election_id: None,
            compatibility_error: None,
            session_support_status,
            transaction_support_status: TransactionSupportStatus::Undetermined,
            cluster_time: None,
            local_threshold: options.local_threshold,
            heartbeat_freq: options.heartbeat_freq,
            servers,
        })
    }

    /// Returns an `Unknown` topology with no servers, no options applied and all support
    /// statuses undetermined.
    pub(crate) fn new_empty() -> Self {
        Self {
            single_seed: false,
            topology_type: TopologyType::Unknown,
            set_name: None,
            max_set_version: None,
            max_election_id: None,
            compatibility_error: None,
            session_support_status: SessionSupportStatus::Undetermined,
            transaction_support_status: TransactionSupportStatus::Undetermined,
            cluster_time: None,
            local_threshold: None,
            heartbeat_freq: None,
            servers: HashMap::new(),
        }
    }

    /// Returns the current topology type.
    pub(crate) fn topology_type(&self) -> TopologyType {
        self.topology_type
    }

    /// Iterates over the addresses of all known servers (in no particular order).
    pub(crate) fn server_addresses(&self) -> impl Iterator<Item = &ServerAddress> {
        self.servers.keys()
    }

    /// Returns the greatest cluster time observed so far, if any.
    pub(crate) fn cluster_time(&self) -> Option<&ClusterTime> {
        self.cluster_time.as_ref()
    }

    /// Looks up the description of the server at `address`, if it is known.
    pub(crate) fn get_server_description(
        &self,
        address: &ServerAddress,
    ) -> Option<&ServerDescription> {
        self.servers.get(address)
    }

    /// Attaches the appropriate read preference to `command` for a server of `server_type`.
    ///
    /// - A mongos in a sharded or single topology, or any server in a load-balanced topology,
    ///   follows the rules in `update_command_read_pref_for_mongos`.
    /// - A standalone in a single topology gets no read preference.
    /// - Any other server in a single topology gets the requested read preference, with
    ///   `Primary` (or no preference at all) upgraded to `PrimaryPreferred`.
    /// - In every other case the requested read preference is used. No criteria means
    ///   `Primary`; predicate criteria means `PrimaryPreferred`.
    pub(crate) fn update_command_with_read_pref<T>(
        &self,
        server_type: ServerType,
        command: &mut Command<T>,
        criteria: Option<&SelectionCriteria>,
    ) {
        match (self.topology_type, server_type) {
            (TopologyType::Sharded, ServerType::Mongos)
            | (TopologyType::Single, ServerType::Mongos)
            | (TopologyType::LoadBalanced, _) => {
                self.update_command_read_pref_for_mongos(command, criteria)
            }
            (TopologyType::Single, ServerType::Standalone) => {}
            (TopologyType::Single, _) => {
                let specified_read_pref = criteria
                    .and_then(SelectionCriteria::as_read_pref)
                    .cloned();
                let resolved_read_pref = match specified_read_pref {
                    Some(ReadPreference::Primary) | None => ReadPreference::PrimaryPreferred {
                        options: Default::default(),
                    },
                    Some(other) => other,
                };
                command.set_read_preference(resolved_read_pref)
            }
            _ => {
                let read_pref = match criteria {
                    Some(SelectionCriteria::ReadPreference(rp)) => rp.clone(),
                    Some(SelectionCriteria::Predicate(_)) => ReadPreference::PrimaryPreferred {
                        options: Default::default(),
                    },
                    None => ReadPreference::Primary,
                };
                command.set_read_preference(read_pref)
            }
        }
    }

    /// Applies the mongos/load-balancer read preference rules to `command`.
    ///
    /// Only an explicit `ReadPreference` is ever forwarded:
    /// - `Secondary`, `PrimaryPreferred` and `Nearest` always are.
    /// - `SecondaryPreferred` is forwarded only when it sets `max_staleness` or `tag_sets`.
    /// - `Primary` never is.
    ///
    /// Predicate criteria, or no criteria at all, leave the command untouched.
    fn update_command_read_pref_for_mongos<T>(
        &self,
        command: &mut Command<T>,
        criteria: Option<&SelectionCriteria>,
    ) {
        let read_preference = match criteria {
            Some(SelectionCriteria::ReadPreference(rp)) => rp,
            _ => return,
        };
        match read_preference {
            ReadPreference::Secondary { .. }
            | ReadPreference::PrimaryPreferred { .. }
            | ReadPreference::Nearest { .. } => {
                command.set_read_preference(read_preference.clone())
            }
            ReadPreference::SecondaryPreferred { ref options }
                if options.max_staleness.is_some() || options.tag_sets.is_some() =>
            {
                command.set_read_preference(read_preference.clone())
            }
            _ => {}
        }
    }

    /// Returns the configured heartbeat frequency, or `DEFAULT_HEARTBEAT_FREQUENCY` if unset.
    fn heartbeat_frequency(&self) -> Duration {
        self.heartbeat_freq.unwrap_or(DEFAULT_HEARTBEAT_FREQUENCY)
    }

    /// Recomputes `compatibility_error` as the first error reported by any known server.
    ///
    /// Servers are checked in map iteration order. The result is `None` when every server
    /// is compatible.
    fn check_compatibility(&mut self) {
        self.compatibility_error = self
            .servers
            .values()
            .find_map(|server| server.compatibility_error_message());
    }

    /// Returns the current compatibility error, if any server reported one.
    pub(crate) fn compatibility_error(&self) -> Option<&String> {
        self.compatibility_error.as_ref()
    }

    /// Smooths the incoming description's average round-trip time.
    ///
    /// This applies only when both the stored description for the same address and the
    /// incoming one carry an RTT. The incoming value becomes `new / 5 + old * 4 / 5`, an
    /// exponentially weighted moving average with weight 0.2 on the new sample.
    fn update_round_trip_time(&self, server_description: &mut ServerDescription) {
        if let Some(old_rtt) = self
            .servers
            .get(&server_description.address)
            .and_then(|server_desc| server_desc.average_round_trip_time)
        {
            if let Some(new_rtt) = server_description.average_round_trip_time {
                server_description.average_round_trip_time =
                    Some((new_rtt / 5) + (old_rtt * 4 / 5));
            }
        }
    }

    /// Updates `session_support_status` after `server_description` has been stored.
    ///
    /// - A non-data-bearing server only matters in a `Single` topology, where it makes
    ///   sessions unsupported.
    /// - A data-bearing server with a logical session timeout makes sessions supported, with
    ///   the topology timeout being the minimum seen. If sessions were previously
    ///   unsupported, support is recomputed from all data-bearing servers.
    /// - A data-bearing server without a timeout makes sessions unsupported.
    fn update_session_support_status(&mut self, server_description: &ServerDescription) {
        if !server_description.server_type.is_data_bearing() {
            if let TopologyType::Single = self.topology_type {
                self.session_support_status = SessionSupportStatus::Unsupported {
                    logical_session_timeout: None,
                };
            }
            return;
        }
        // From here on the server is known to be data-bearing.
        match server_description.logical_session_timeout().ok().flatten() {
            Some(timeout) => match self.session_support_status {
                SessionSupportStatus::Supported {
                    logical_session_timeout: topology_timeout,
                } => {
                    // `Option` orders `None` before `Some`, so a topology with no known
                    // timeout keeps `None`.
                    self.session_support_status = SessionSupportStatus::Supported {
                        logical_session_timeout: std::cmp::min(Some(timeout), topology_timeout),
                    };
                }
                SessionSupportStatus::Undetermined => {
                    self.session_support_status = SessionSupportStatus::Supported {
                        logical_session_timeout: Some(timeout),
                    }
                }
                SessionSupportStatus::Unsupported { .. } => {
                    // Recompute from every data-bearing server (including the one just
                    // stored). Because `None` sorts first, a single server lacking a timeout
                    // keeps sessions unsupported.
                    let min_timeout = self
                        .servers
                        .values()
                        .filter(|s| s.server_type.is_data_bearing())
                        .map(|s| s.logical_session_timeout().ok().flatten())
                        .min()
                        .flatten();
                    match min_timeout {
                        Some(timeout) => {
                            self.session_support_status = SessionSupportStatus::Supported {
                                logical_session_timeout: Some(timeout),
                            }
                        }
                        None => {
                            self.session_support_status = SessionSupportStatus::Unsupported {
                                logical_session_timeout: None,
                            }
                        }
                    }
                }
            },
            // A data-bearing server that reports no session timeout means sessions are not
            // supported by the deployment.
            None => {
                self.session_support_status = SessionSupportStatus::Unsupported {
                    logical_session_timeout: None,
                }
            }
        }
    }

    /// Updates `transaction_support_status` after `server_description` has been stored.
    ///
    /// Transactions are unsupported whenever sessions are not known to be supported.
    /// Otherwise, if the server reports a max wire version, transactions require wire
    /// version 7, or 8 in a sharded topology.
    fn update_transaction_support_status(&mut self, server_description: &ServerDescription) {
        if !matches!(
            self.session_support_status,
            SessionSupportStatus::Supported { .. }
        ) {
            // Return here so the wire-version check below cannot overwrite this with
            // `Supported` for a deployment that lacks session support.
            self.transaction_support_status = TransactionSupportStatus::Unsupported;
            return;
        }
        if let Ok(Some(max_wire_version)) = server_description.max_wire_version() {
            self.transaction_support_status = if max_wire_version < 7
                || (max_wire_version < 8 && self.topology_type == TopologyType::Sharded)
            {
                TransactionSupportStatus::Unsupported
            } else {
                TransactionSupportStatus::Supported
            }
        }
    }

    /// Stores `cluster_time` if it is strictly greater than the current one, or if none is
    /// stored yet.
    pub(crate) fn advance_cluster_time(&mut self, cluster_time: &ClusterTime) {
        if self.cluster_time.as_ref() >= Some(cluster_time) {
            return;
        }
        self.cluster_time = Some(cluster_time.clone());
    }

    /// Describes how `other` differs from `self`.
    ///
    /// Returns `None` when the two are equal according to `PartialEq`. Otherwise the diff
    /// contains:
    /// - the addresses only in `self` (removed);
    /// - the addresses only in `other` (added);
    /// - for servers present in both with differing descriptions, an `(old, new)` pair.
    pub(crate) fn diff<'a>(
        &'a self,
        other: &'a TopologyDescription,
    ) -> Option<TopologyDescriptionDiff> {
        if self == other {
            return None;
        }
        let addresses: HashSet<&ServerAddress> = self.server_addresses().collect();
        let other_addresses: HashSet<&ServerAddress> = other.server_addresses().collect();
        let changed_servers = self
            .servers
            .iter()
            .filter_map(|(address, description)| match other.servers.get(address) {
                Some(other_description) if description != other_description => {
                    Some((address, (description, other_description)))
                }
                _ => None,
            });
        Some(TopologyDescriptionDiff {
            removed_addresses: addresses.difference(&other_addresses).cloned().collect(),
            added_addresses: other_addresses.difference(&addresses).cloned().collect(),
            changed_servers: changed_servers.collect(),
        })
    }

    /// Makes the set of known servers exactly `hosts`.
    ///
    /// New addresses get a fresh description; existing descriptions for retained addresses
    /// are kept; servers not in `hosts` are dropped.
    pub(crate) fn sync_hosts(&mut self, hosts: &HashSet<ServerAddress>) {
        self.add_new_servers_from_addresses(hosts.iter());
        self.servers.retain(|host, _| hosts.contains(host));
    }

    /// Returns the current session support status.
    pub(crate) fn session_support_status(&self) -> SessionSupportStatus {
        self.session_support_status
    }

    /// Returns the current transaction support status.
    pub(crate) fn transaction_support_status(&self) -> TransactionSupportStatus {
        self.transaction_support_status
    }

    /// Applies a new server description to the topology.
    ///
    /// Descriptions for addresses that are not currently known are ignored. Otherwise the
    /// steps are:
    /// 1. Smooth the RTT and store the description.
    /// 2. Update session/transaction support and the cluster time.
    /// 3. Run the state transition for the current topology type.
    /// 4. Recompute the compatibility error.
    ///
    /// # Errors
    ///
    /// Returns a message if a field of the description cannot be read. It also errors if the
    /// topology is `Unknown` or a replica set and the server reports itself as a load
    /// balancer.
    pub(crate) fn update(
        &mut self,
        mut server_description: ServerDescription,
    ) -> Result<(), String> {
        if !self.servers.contains_key(&server_description.address) {
            return Ok(());
        }
        self.update_round_trip_time(&mut server_description);
        self.servers.insert(
            server_description.address.clone(),
            server_description.clone(),
        );
        self.update_session_support_status(&server_description);
        self.update_transaction_support_status(&server_description);
        if let Some(ref cluster_time) = server_description.cluster_time().ok().flatten() {
            self.advance_cluster_time(cluster_time);
        }
        match self.topology_type {
            TopologyType::Single | TopologyType::LoadBalanced => {}
            TopologyType::Unknown => self.update_unknown_topology(server_description)?,
            TopologyType::Sharded => self.update_sharded_topology(server_description),
            TopologyType::ReplicaSetNoPrimary => {
                self.update_replica_set_no_primary_topology(server_description)?
            }
            TopologyType::ReplicaSetWithPrimary => {
                self.update_replica_set_with_primary_topology(server_description)?;
            }
        }
        self.check_compatibility();
        Ok(())
    }

    /// State transition for an `Unknown` topology.
    ///
    /// - Unknown and ghost servers change nothing.
    /// - A standalone is handled by `update_unknown_with_standalone_server`.
    /// - A mongos makes the topology `Sharded`.
    /// - Replica set members move the topology into one of the replica set types.
    fn update_unknown_topology(
        &mut self,
        server_description: ServerDescription,
    ) -> Result<(), String> {
        match server_description.server_type {
            ServerType::Unknown | ServerType::RsGhost => {}
            ServerType::Standalone => {
                self.update_unknown_with_standalone_server(server_description)
            }
            ServerType::Mongos => self.topology_type = TopologyType::Sharded,
            ServerType::RsPrimary => {
                self.topology_type = TopologyType::ReplicaSetWithPrimary;
                self.update_rs_from_primary_server(server_description)?;
            }
            ServerType::RsSecondary | ServerType::RsArbiter | ServerType::RsOther => {
                self.topology_type = TopologyType::ReplicaSetNoPrimary;
                self.update_rs_without_primary_server(server_description)?;
            }
            ServerType::LoadBalancer => {
                return Err("cannot transition to a load balancer".to_string())
            }
        }
        Ok(())
    }

    /// State transition for a `Sharded` topology: only unknown servers and mongoses are
    /// kept, any other server type is removed.
    fn update_sharded_topology(&mut self, server_description: ServerDescription) {
        match server_description.server_type {
            ServerType::Unknown | ServerType::Mongos => {}
            _ => {
                self.servers.remove(&server_description.address);
            }
        }
    }

    /// State transition for a `ReplicaSetNoPrimary` topology.
    ///
    /// Standalones and mongoses are removed, a primary makes the topology
    /// `ReplicaSetWithPrimary`, and other members are handled by
    /// `update_rs_without_primary_server`.
    fn update_replica_set_no_primary_topology(
        &mut self,
        server_description: ServerDescription,
    ) -> Result<(), String> {
        match server_description.server_type {
            ServerType::Unknown | ServerType::RsGhost => {}
            ServerType::Standalone | ServerType::Mongos => {
                self.servers.remove(&server_description.address);
            }
            ServerType::RsPrimary => {
                self.topology_type = TopologyType::ReplicaSetWithPrimary;
                self.update_rs_from_primary_server(server_description)?
            }
            ServerType::RsSecondary | ServerType::RsArbiter | ServerType::RsOther => {
                self.update_rs_without_primary_server(server_description)?;
            }
            ServerType::LoadBalancer => {
                return Err("cannot transition to a load balancer".to_string())
            }
        }
        Ok(())
    }

    /// State transition for a `ReplicaSetWithPrimary` topology.
    ///
    /// Every path that may have lost the primary re-derives the topology type via
    /// `record_primary_state`.
    fn update_replica_set_with_primary_topology(
        &mut self,
        server_description: ServerDescription,
    ) -> Result<(), String> {
        match server_description.server_type {
            ServerType::Unknown | ServerType::RsGhost => {
                self.record_primary_state();
            }
            ServerType::Standalone | ServerType::Mongos => {
                self.servers.remove(&server_description.address);
                self.record_primary_state();
            }
            ServerType::RsPrimary => self.update_rs_from_primary_server(server_description)?,
            ServerType::RsSecondary | ServerType::RsArbiter | ServerType::RsOther => {
                self.update_rs_with_primary_from_member(server_description)?;
            }
            ServerType::LoadBalancer => {
                return Err("cannot transition to a load balancer".to_string())
            }
        }
        Ok(())
    }

    /// Handles a standalone seen while the topology is `Unknown`.
    ///
    /// With exactly one seed the topology becomes `Single`; otherwise the standalone is
    /// removed.
    fn update_unknown_with_standalone_server(&mut self, server_description: ServerDescription) {
        if self.single_seed {
            self.topology_type = TopologyType::Single;
        } else {
            self.servers.remove(&server_description.address);
        }
    }

    /// Handles a non-primary replica set member while no primary is known.
    ///
    /// If no set name is known yet, the member's set name is adopted; on a set name mismatch
    /// the member is removed. Otherwise the hosts it reports are added, and the member itself
    /// is removed if `invalid_me()` is true. Presumably `invalid_me()` means its
    /// self-reported address differs from ours (TODO confirm).
    fn update_rs_without_primary_server(
        &mut self,
        server_description: ServerDescription,
    ) -> Result<(), String> {
        if self.set_name.is_none() {
            self.set_name = server_description.set_name()?;
        } else if self.set_name != server_description.set_name()? {
            self.servers.remove(&server_description.address);
            return Ok(());
        }
        self.add_new_servers(server_description.known_hosts()?)?;
        if server_description.invalid_me()? {
            self.servers.remove(&server_description.address);
        }
        Ok(())
    }

    /// Handles a non-primary replica set member while a primary is known.
    ///
    /// The member is removed, and the primary state re-derived, if its set name does not
    /// match or `invalid_me()` is true.
    fn update_rs_with_primary_from_member(
        &mut self,
        server_description: ServerDescription,
    ) -> Result<(), String> {
        if self.set_name != server_description.set_name()? {
            self.servers.remove(&server_description.address);
            self.record_primary_state();
            return Ok(());
        }
        if server_description.invalid_me()? {
            self.servers.remove(&server_description.address);
            self.record_primary_state();
            return Ok(());
        }
        Ok(())
    }

    /// Handles a description from a replica set primary.
    ///
    /// 1. Adopts or checks the set name; on mismatch, removes the server.
    /// 2. Rejects a stale primary: if its `(setVersion, electionId)` is lower than the
    ///    topology's maximum, its description is reset to a fresh one.
    /// 3. Records the new maximum election id and set version.
    /// 4. Resets any other server still marked as primary.
    /// 5. Adds the hosts the primary reports and removes previously known servers it does
    ///    not list.
    /// 6. Re-derives the topology type.
    fn update_rs_from_primary_server(
        &mut self,
        server_description: ServerDescription,
    ) -> Result<(), String> {
        if self.set_name.is_none() {
            self.set_name = server_description.set_name()?;
        } else if self.set_name != server_description.set_name()? {
            self.servers.remove(&server_description.address);
            self.record_primary_state();
            return Ok(());
        }
        if let Some(server_set_version) = server_description.set_version()? {
            if let Some(server_election_id) = server_description.election_id()? {
                if let Some(topology_max_set_version) = self.max_set_version {
                    if let Some(ref topology_max_election_id) = self.max_election_id {
                        if topology_max_set_version > server_set_version
                            || (topology_max_set_version == server_set_version
                                && *topology_max_election_id > server_election_id)
                        {
                            // Stale primary: forget what it told us.
                            self.servers.insert(
                                server_description.address.clone(),
                                ServerDescription::new(server_description.address, None),
                            );
                            self.record_primary_state();
                            return Ok(());
                        }
                    }
                }
                self.max_election_id = Some(server_election_id);
            }
        }
        if let Some(server_set_version) = server_description.set_version()? {
            if self
                .max_set_version
                .as_ref()
                .map(|topology_max_set_version| server_set_version > *topology_max_set_version)
                .unwrap_or(true)
            {
                self.max_set_version = Some(server_set_version);
            }
        }
        // Snapshot taken before new hosts are added; only these pre-existing servers are
        // candidates for removal below.
        let addresses: Vec<_> = self.servers.keys().cloned().collect();
        // There can be only one primary: reset any other server still marked as primary.
        for (address, description) in self.servers.iter_mut() {
            if *address != server_description.address
                && matches!(description.server_type, ServerType::RsPrimary)
            {
                *description = ServerDescription::new(address.clone(), None);
            }
        }
        self.add_new_servers(server_description.known_hosts()?)?;
        let known_hosts: HashSet<_> = server_description.known_hosts()?.collect();
        for address in addresses {
            if !known_hosts.contains(&address.to_string()) {
                self.servers.remove(&address);
            }
        }
        self.record_primary_state();
        Ok(())
    }

    /// Sets the topology type to `ReplicaSetWithPrimary` if any known server is a primary,
    /// otherwise to `ReplicaSetNoPrimary`.
    fn record_primary_state(&mut self) {
        self.topology_type = if self
            .servers
            .values()
            .any(|server| server.server_type == ServerType::RsPrimary)
        {
            TopologyType::ReplicaSetWithPrimary
        } else {
            TopologyType::ReplicaSetNoPrimary
        };
    }

    /// Parses `servers` as addresses and adds any that are not yet known.
    ///
    /// All addresses are parsed before anything is added, so on a parse error (returned as
    /// its message) the topology is left unchanged.
    fn add_new_servers<'a>(
        &mut self,
        servers: impl Iterator<Item = &'a String>,
    ) -> Result<(), String> {
        let servers: Result<Vec<_>, String> = servers
            .map(|server| ServerAddress::parse(server).map_err(|e| e.to_string()))
            .collect();
        self.add_new_servers_from_addresses(servers?.iter());
        Ok(())
    }

    /// Adds a fresh description for every address in `servers` that is not yet known;
    /// descriptions of already known servers are left untouched.
    fn add_new_servers_from_addresses<'a>(
        &mut self,
        servers: impl Iterator<Item = &'a ServerAddress>,
    ) {
        for server in servers {
            if !self.servers.contains_key(server) {
                self.servers
                    .insert(server.clone(), ServerDescription::new(server.clone(), None));
            }
        }
    }
}
/// Whether the deployment supports sessions, as derived from server descriptions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum SessionSupportStatus {
    /// No data-bearing server has reported yet (the `Default`).
    Undetermined,
    /// Sessions are not supported.
    Unsupported {
        /// The logical session timeout; always set to `None` by `TopologyDescription`'s
        /// update logic.
        logical_session_timeout: Option<Duration>,
    },
    /// Sessions are supported.
    Supported {
        /// Minimum logical session timeout seen, or `None` if unknown (e.g. the initial
        /// load-balanced status).
        logical_session_timeout: Option<Duration>,
    },
}
impl Default for SessionSupportStatus {
fn default() -> Self {
Self::Undetermined
}
}
impl SessionSupportStatus {
    /// Test-only accessor for the timeout carried by the status.
    ///
    /// `Undetermined` has no timeout; both other variants return their stored value.
    #[cfg(test)]
    fn logical_session_timeout(&self) -> Option<Duration> {
        match *self {
            SessionSupportStatus::Supported {
                logical_session_timeout,
            }
            | SessionSupportStatus::Unsupported {
                logical_session_timeout,
            } => logical_session_timeout,
            SessionSupportStatus::Undetermined => None,
        }
    }
}
/// Whether the deployment supports transactions, as derived from session support and the
/// servers' max wire versions.
#[derive(Debug, Clone, Copy, PartialEq)]
pub(crate) enum TransactionSupportStatus {
    /// Not yet known (the `Default`).
    Undetermined,
    /// Transactions are not supported.
    Unsupported,
    /// Transactions are supported.
    Supported,
}
impl Default for TransactionSupportStatus {
fn default() -> Self {
Self::Undetermined
}
}
/// Differences between two topology descriptions, as produced by
/// `TopologyDescription::diff`. All entries borrow from the two compared descriptions.
#[derive(Debug)]
pub(crate) struct TopologyDescriptionDiff<'a> {
    /// Addresses present only in the old description.
    pub(crate) removed_addresses: HashSet<&'a ServerAddress>,
    /// Addresses present only in the new description.
    pub(crate) added_addresses: HashSet<&'a ServerAddress>,
    /// Servers present in both whose descriptions differ, as `(old, new)` pairs.
    pub(crate) changed_servers:
        HashMap<&'a ServerAddress, (&'a ServerDescription, &'a ServerDescription)>,
}
fn verify_max_staleness(max_staleness: Option<Duration>) -> crate::error::Result<()> {
verify_max_staleness_inner(max_staleness)
.map_err(|s| crate::error::ErrorKind::InvalidArgument { message: s }.into())
}
/// Rejects a max staleness that is strictly positive but shorter than 90 seconds.
///
/// `None` and a zero duration are both accepted.
fn verify_max_staleness_inner(max_staleness: Option<Duration>) -> std::result::Result<(), String> {
    let minimum = Duration::from_secs(90);
    match max_staleness {
        Some(staleness) if staleness > Duration::from_secs(0) && staleness < minimum => {
            Err("max staleness cannot be both positive and below 90 seconds".into())
        }
        _ => Ok(()),
    }
}