pub(super) mod server;
use std::{
collections::{HashMap, HashSet},
sync::{
atomic::{AtomicBool, Ordering},
Arc,
Weak,
},
};
use tokio::sync::{RwLock, RwLockWriteGuard};
use self::server::Server;
use super::{
description::topology::{server_selection::SelectedServer, TransactionSupportStatus},
message_manager::TopologyMessageSubscriber,
ServerInfo,
SessionSupportStatus,
TopologyDescription,
};
use crate::{
bson::oid::ObjectId,
client::ClusterTime,
cmap::{conn::ConnectionGeneration, Command, Connection, PoolGeneration},
error::{load_balanced_mode_mismatch, Error, Result},
event::sdam::{
ServerClosedEvent,
ServerDescriptionChangedEvent,
ServerOpeningEvent,
TopologyClosedEvent,
TopologyDescriptionChangedEvent,
TopologyOpeningEvent,
},
options::{ClientOptions, SelectionCriteria, ServerAddress},
runtime::HttpClient,
sdam::{
description::{
server::{ServerDescription, ServerType},
topology::{server_selection, TopologyType},
},
srv_polling::SrvPollingMonitor,
TopologyMessageManager,
},
RUNTIME,
};
/// Strong handle to the client's view of the deployment topology.
///
/// The mutable state lives behind an `Arc<RwLock<..>>`, so every clone of a `Topology`
/// reads and writes the same underlying `TopologyState`.
#[derive(Debug, Clone)]
pub(crate) struct Topology {
    state: Arc<RwLock<TopologyState>>,
    common: Common,
}
/// Non-owning counterpart of [`Topology`].
///
/// Holds only a `Weak` reference to the shared state, so it does not keep the topology
/// alive on its own; use `upgrade` to obtain a strong handle.
#[derive(Debug, Clone)]
pub(crate) struct WeakTopology {
    state: Weak<RwLock<TopologyState>>,
    common: Common,
}
/// Data carried by both strong and weak topology handles, outside the `RwLock`.
#[derive(Debug, Clone)]
struct Common {
    /// Liveness flag shared by all handles; cleared by `Topology::close`.
    is_alive: Arc<AtomicBool>,
    /// Channel for topology-check requests and topology-change notifications.
    message_manager: TopologyMessageManager,
    /// Client options this topology was built from (in `Topology::new` the seed `hosts`
    /// list has already been drained out of this copy).
    options: ClientOptions,
    /// Identifier reported as `topology_id` in SDAM events.
    id: ObjectId,
}
/// Mutable topology state, guarded by the `RwLock` inside [`Topology`].
#[derive(Debug)]
struct TopologyState {
    /// HTTP client handed to every `Server` created for this topology.
    http_client: HttpClient,
    /// Current description of the deployment.
    description: TopologyDescription,
    /// Servers currently tracked, keyed by address.
    servers: HashMap<ServerAddress, Arc<Server>>,
    /// Test-only: when `true`, monitors for newly added servers are not started.
    #[cfg(test)]
    mocked: bool,
}
impl Topology {
    /// Creates a topology for unit tests whose servers are never monitored.
    ///
    /// Emits the same `TopologyOpeningEvent`, `TopologyDescriptionChangedEvent`, and
    /// per-host `ServerOpeningEvent`s as [`Topology::new`]. The monitors returned by
    /// `Server::create` are discarded instead of started.
    #[cfg(test)]
    pub(super) fn new_mocked(options: ClientOptions) -> Self {
        let description = TopologyDescription::new(options.clone()).unwrap();
        let id = ObjectId::new();
        if let Some(ref handler) = options.sdam_event_handler {
            let event = TopologyOpeningEvent { topology_id: id };
            handler.handle_topology_opening_event(event);
        }
        let common = Common {
            is_alive: Arc::new(AtomicBool::new(true)),
            message_manager: TopologyMessageManager::new(),
            options: options.clone(),
            id,
        };
        let http_client = HttpClient::default();
        let state = TopologyState {
            description,
            servers: Default::default(),
            http_client: http_client.clone(),
            mocked: true,
        };
        let topology = Self {
            state: Arc::new(RwLock::new(state)),
            common,
        };
        // The topology was created just above and has not been shared yet, so this write
        // lock is uncontended.
        let mut topology_state = RUNTIME.block_in_place(topology.state.write());
        for address in &options.hosts {
            topology_state.servers.insert(
                address.clone(),
                Server::create(
                    address.clone(),
                    &options,
                    topology.downgrade(),
                    http_client.clone(),
                )
                .0,
            );
        }
        if let Some(ref handler) = options.sdam_event_handler {
            let event = TopologyDescriptionChangedEvent {
                topology_id: id,
                previous_description: TopologyDescription::new_empty().into(),
                new_description: topology_state.description.clone().into(),
            };
            handler.handle_topology_description_changed_event(event);
            for server_address in &options.hosts {
                let event = ServerOpeningEvent {
                    topology_id: id,
                    address: server_address.clone(),
                };
                handler.handle_server_opening_event(event);
            }
        }
        drop(topology_state);
        topology
    }

    /// Creates a new topology from `options` and starts a monitored server for every seed
    /// host.
    ///
    /// If an SDAM event handler is configured, it receives events in this order:
    /// `TopologyOpeningEvent`, `TopologyDescriptionChangedEvent`, then one
    /// `ServerOpeningEvent` per seed host. SRV polling is started exactly once, and only
    /// when the topology is not load balanced.
    ///
    /// # Errors
    /// Returns any error produced while building the initial `TopologyDescription`.
    pub(crate) fn new(mut options: ClientOptions) -> Result<Self> {
        let description = TopologyDescription::new(options.clone())?;
        let is_load_balanced = description.topology_type() == TopologyType::LoadBalanced;
        let id = ObjectId::new();
        if let Some(ref handler) = options.sdam_event_handler {
            let event = TopologyOpeningEvent { topology_id: id };
            handler.handle_topology_opening_event(event);
        }
        // The seed list is removed from the options stored in `Common`. `hosts` is the only
        // remaining copy, so both server creation and the `ServerOpeningEvent`s below must
        // use it (`options.hosts` is empty from here on).
        let hosts: Vec<_> = options.hosts.drain(..).collect();
        let common = Common {
            is_alive: Arc::new(AtomicBool::new(true)),
            message_manager: TopologyMessageManager::new(),
            options: options.clone(),
            id,
        };
        let http_client = HttpClient::default();
        #[cfg(test)]
        let topology_state = TopologyState {
            description,
            servers: Default::default(),
            http_client,
            mocked: false,
        };
        #[cfg(not(test))]
        let topology_state = TopologyState {
            description,
            servers: Default::default(),
            http_client,
        };
        let state = Arc::new(RwLock::new(topology_state));
        let topology = Topology { state, common };
        // The topology was created just above and has not been shared yet, so this write
        // lock is uncontended.
        let mut topology_state = RUNTIME.block_in_place(topology.state.write());
        for address in &hosts {
            topology_state.add_new_server(address.clone(), options.clone(), &topology.downgrade());
        }
        if let Some(ref handler) = options.sdam_event_handler {
            let event = TopologyDescriptionChangedEvent {
                topology_id: id,
                previous_description: TopologyDescription::new_empty().into(),
                new_description: topology_state.description.clone().into(),
            };
            handler.handle_topology_description_changed_event(event);
            for server_address in &hosts {
                let event = ServerOpeningEvent {
                    topology_id: id,
                    address: server_address.clone(),
                };
                handler.handle_server_opening_event(event);
            }
        }
        drop(topology_state);
        // SRV polling is skipped for load-balanced topologies; start it at most once.
        if !is_load_balanced {
            SrvPollingMonitor::start(topology.downgrade());
        }
        Ok(topology)
    }

    /// Marks the topology as no longer alive and emits a `TopologyClosedEvent` if an SDAM
    /// event handler is configured.
    pub(crate) fn close(&self) {
        self.common.is_alive.store(false, Ordering::SeqCst);
        if let Some(ref handler) = self.common.options.sdam_event_handler {
            let event = TopologyClosedEvent {
                topology_id: self.common.id,
            };
            handler.handle_topology_closed_event(event);
        }
    }

    /// Returns the addresses of all servers currently tracked (test only).
    #[cfg(test)]
    pub(crate) async fn servers(&self) -> HashSet<ServerAddress> {
        self.state.read().await.servers.keys().cloned().collect()
    }

    /// Returns a snapshot of the current topology description (test only).
    #[cfg(test)]
    pub(crate) async fn description(&self) -> TopologyDescription {
        self.state.read().await.description.clone()
    }

    /// Creates a [`WeakTopology`] that does not keep the shared state alive.
    pub(super) fn downgrade(&self) -> WeakTopology {
        WeakTopology {
            state: Arc::downgrade(&self.state),
            common: self.common.clone(),
        }
    }

    /// Runs one round of server selection against the current description under a read
    /// lock. Returns `Ok(None)` when no suitable server is currently known.
    pub(crate) async fn attempt_to_select_server(
        &self,
        criteria: &SelectionCriteria,
    ) -> Result<Option<SelectedServer>> {
        let topology_state = self.state.read().await;
        server_selection::attempt_to_select_server(
            criteria,
            &topology_state.description,
            &topology_state.servers,
        )
    }

    /// Builds the error message reported when server selection with `criteria` times out.
    pub(crate) async fn server_selection_timeout_error_message(
        &self,
        criteria: &SelectionCriteria,
    ) -> String {
        self.state
            .read()
            .await
            .description
            .server_selection_timeout_error_message(criteria)
    }

    /// Asks server monitors to check their servers promptly.
    pub(crate) fn request_topology_check(&self) {
        self.common.message_manager.request_topology_check();
    }

    /// Subscribes to topology-check requests issued via `request_topology_check`.
    pub(crate) fn subscribe_to_topology_check_requests(&self) -> TopologyMessageSubscriber {
        self.common
            .message_manager
            .subscribe_to_topology_check_requests()
    }

    /// Subscribes to notifications sent by `notify_topology_changed`.
    pub(crate) fn subscribe_to_topology_changes(&self) -> TopologyMessageSubscriber {
        self.common.message_manager.subscribe_to_topology_changes()
    }

    /// Wakes every subscriber obtained from `subscribe_to_topology_changes`.
    pub(crate) fn notify_topology_changed(&self) {
        self.common.message_manager.notify_topology_changed();
    }

    /// Processes an error raised by an application operation against `server`.
    ///
    /// Errors tied to a stale pool or connection generation are ignored, and so are
    /// pre-hello errors in load-balanced mode. Otherwise:
    /// - State-change errors mark the server unknown (outside load-balanced mode) and
    ///   request a topology check. They clear the pool when the server is shutting down, or
    ///   when the handshake phase does not report a max wire version of at least 8.
    /// - Non-timeout network errors, and auth, network-timeout or command errors raised
    ///   before the handshake completed, mark the server unknown and clear its pool.
    ///
    /// Returns `true` if the error updated the topology. In load-balanced mode this is
    /// always `true` for a handled error. Returns `false` otherwise.
    pub(crate) async fn handle_application_error(
        &self,
        error: Error,
        handshake: HandshakePhase,
        server: &Server,
    ) -> bool {
        let state_lock = self.state.write().await;
        match &handshake {
            HandshakePhase::PreHello { generation } => {
                match (generation, server.pool.generation()) {
                    (PoolGeneration::Normal(hgen), PoolGeneration::Normal(sgen)) => {
                        // The error came from a pool generation that has since been cleared.
                        if *hgen < sgen {
                            return false;
                        }
                    }
                    // Pre-hello errors are ignored in load-balanced mode.
                    (PoolGeneration::LoadBalanced(_), PoolGeneration::LoadBalanced(_)) => {
                        return false
                    }
                    _ => load_balanced_mode_mismatch!(false),
                }
            }
            HandshakePhase::PostHello { generation }
            | HandshakePhase::AfterCompletion { generation, .. } => {
                if generation.is_stale(&server.pool.generation()) {
                    return false;
                }
            }
        }
        let is_load_balanced = state_lock.description.topology_type() == TopologyType::LoadBalanced;
        if error.is_state_change_error() {
            let updated = is_load_balanced
                || self
                    .mark_server_as_unknown(error.to_string(), server, state_lock)
                    .await;
            if updated && (error.is_shutting_down() || handshake.wire_version().unwrap_or(0) < 8) {
                server.pool.clear(error, handshake.service_id()).await;
            }
            self.request_topology_check();
            updated
        } else if error.is_non_timeout_network_error()
            || (handshake.is_before_completion()
                && (error.is_auth_error()
                    || error.is_network_timeout()
                    || error.is_command_error()))
        {
            let updated = is_load_balanced
                || self
                    .mark_server_as_unknown(error.to_string(), server, state_lock)
                    .await;
            if updated {
                server.pool.clear(error, handshake.service_id()).await;
            }
            updated
        } else {
            false
        }
    }

    /// Processes an error reported by `server`'s monitor. Marks the server unknown and, if
    /// that changed the topology, clears its pool. Returns whether the topology changed.
    pub(crate) async fn handle_monitor_error(&self, error: Error, server: &Server) -> bool {
        let state_lock = self.state.write().await;
        let updated = self
            .mark_server_as_unknown(error.to_string(), server, state_lock)
            .await;
        if updated {
            server.pool.clear(error, None).await;
        }
        updated
    }

    /// Replaces `server`'s description with an error description carrying `error`.
    /// Consumes the write guard. Returns whether the topology changed.
    async fn mark_server_as_unknown(
        &self,
        error: String,
        server: &Server,
        state_lock: RwLockWriteGuard<'_, TopologyState>,
    ) -> bool {
        let description = ServerDescription::new(server.address.clone(), Some(Err(error)));
        self.update_and_notify(server, description, state_lock)
            .await
    }

    /// Applies `server_description` to the locked state.
    ///
    /// If the topology changed, this marks the server's pool ready when appropriate and
    /// wakes waiters subscribed via `subscribe_to_topology_changes`. Returns `true` only
    /// when the update succeeded and changed the topology.
    async fn update_and_notify(
        &self,
        server: &Server,
        server_description: ServerDescription,
        mut state_lock: RwLockWriteGuard<'_, TopologyState>,
    ) -> bool {
        let server_type = server_description.server_type;
        match state_lock.update(server_description, &self.common.options, self.downgrade()) {
            Ok(true) => {
                // Data-bearing servers, or any known server in a Single topology, can accept
                // connections.
                if server_type.is_data_bearing()
                    || (server_type != ServerType::Unknown
                        && state_lock.description.topology_type() == TopologyType::Single)
                {
                    server.pool.mark_as_ready().await;
                }
                // Without this, tasks blocked on `subscribe_to_topology_changes` (for example
                // an in-progress server selection) never observe the change.
                self.notify_topology_changed();
                true
            }
            _ => false,
        }
    }

    /// Applies a new description for `server`, starting monitors for any newly discovered
    /// hosts. Returns whether the topology changed.
    pub(crate) async fn update(
        &self,
        server: &Server,
        server_description: ServerDescription,
    ) -> bool {
        self.update_and_notify(server, server_description, self.state.write().await)
            .await
    }

    /// Synchronizes the tracked servers with `hosts`, adding new servers and dropping
    /// removed ones. Always returns `true`.
    pub(crate) async fn update_hosts(
        &self,
        hosts: HashSet<ServerAddress>,
        options: &ClientOptions,
    ) -> bool {
        self.state
            .write()
            .await
            .update_hosts(&hosts, options, self.downgrade());
        true
    }

    /// Advances the topology's cluster time to `cluster_time` (via the description).
    pub(crate) async fn advance_cluster_time(&self, cluster_time: &ClusterTime) {
        self.state
            .write()
            .await
            .description
            .advance_cluster_time(cluster_time);
    }

    /// Returns a copy of the topology's current cluster time, if any.
    pub(crate) async fn cluster_time(&self) -> Option<ClusterTime> {
        self.state
            .read()
            .await
            .description
            .cluster_time()
            .cloned()
    }

    /// Updates `command`'s read preference for a command about to be sent to
    /// `server_address`.
    pub(crate) async fn update_command_with_read_pref<T>(
        &self,
        server_address: &ServerAddress,
        command: &mut Command<T>,
        criteria: Option<&SelectionCriteria>,
    ) {
        self.state
            .read()
            .await
            .update_command_with_read_pref(server_address, command, criteria)
    }

    /// Returns the topology's current session support status.
    pub(crate) async fn session_support_status(&self) -> SessionSupportStatus {
        self.state.read().await.description.session_support_status()
    }

    /// Returns the topology's current transaction support status.
    pub(crate) async fn transaction_support_status(&self) -> TransactionSupportStatus {
        self.state
            .read()
            .await
            .description
            .transaction_support_status()
    }

    /// Returns the current topology type.
    pub(crate) async fn topology_type(&self) -> TopologyType {
        self.state.read().await.description.topology_type()
    }

    /// Returns a copy of the description for `address`, if that server is known.
    pub(crate) async fn get_server_description(
        &self,
        address: &ServerAddress,
    ) -> Option<ServerDescription> {
        self.state
            .read()
            .await
            .description
            .get_server_description(address)
            .cloned()
    }

    /// Returns weak references to every tracked server, keyed by address (test only).
    #[cfg(test)]
    pub(crate) async fn get_servers(&self) -> HashMap<ServerAddress, Weak<Server>> {
        self.state
            .read()
            .await
            .servers
            .iter()
            .map(|(addr, server)| (addr.clone(), Arc::downgrade(server)))
            .collect()
    }
}
impl WeakTopology {
pub(crate) fn upgrade(&self) -> Option<Topology> {
Some(Topology {
state: self.state.upgrade()?,
common: self.common.clone(),
})
}
pub(crate) fn is_alive(&self) -> bool {
self.common.is_alive.load(Ordering::SeqCst)
}
pub(crate) fn client_options(&self) -> &ClientOptions {
&self.common.options
}
}
impl TopologyState {
    /// Creates and registers a `Server` for `address` unless one is already tracked.
    ///
    /// The new server's monitor is started immediately, except in mocked test topologies.
    fn add_new_server(
        &mut self,
        address: ServerAddress,
        options: ClientOptions,
        topology: &WeakTopology,
    ) {
        if self.servers.contains_key(&address) {
            return;
        }
        let (server, monitor) = Server::create(
            address.clone(),
            &options,
            topology.clone(),
            self.http_client.clone(),
        );
        self.servers.insert(address, server);
        #[cfg(test)]
        if self.mocked {
            return;
        }
        monitor.start();
    }

    /// Updates `command`'s read preference for `server_address`.
    ///
    /// The server's type comes from the current description; an address that is not in the
    /// description is treated as `ServerType::Unknown`.
    pub(crate) fn update_command_with_read_pref<T>(
        &self,
        server_address: &ServerAddress,
        command: &mut Command<T>,
        criteria: Option<&SelectionCriteria>,
    ) {
        let server_type = self
            .description
            .get_server_description(server_address)
            .map(|desc| desc.server_type)
            .unwrap_or(ServerType::Unknown);
        self.description
            .update_command_with_read_pref(server_type, command, criteria)
    }

    /// Applies `server` to the topology description, then syncs tracked servers with the
    /// description's addresses.
    ///
    /// If the description changed and an SDAM event handler is configured, it emits, in
    /// order: server-description-changed, server-closed and server-opening events from the
    /// diff, followed by a topology-description-changed event.
    ///
    /// # Errors
    /// Propagates the `String` error returned by `TopologyDescription::update`.
    ///
    /// Returns `Ok(true)` if the description changed and `Ok(false)` otherwise.
    fn update(
        &mut self,
        server: ServerDescription,
        options: &ClientOptions,
        topology: WeakTopology,
    ) -> std::result::Result<bool, String> {
        let old_description = self.description.clone();
        self.description.update(server)?;
        let hosts: HashSet<_> = self.description.server_addresses().cloned().collect();
        self.sync_hosts(&hosts, options, &topology);
        let diff = old_description.diff(&self.description);
        let topology_changed = diff.is_some();
        if let Some(ref handler) = options.sdam_event_handler {
            if let Some(diff) = diff {
                for (address, (previous_description, new_description)) in diff.changed_servers {
                    let event = ServerDescriptionChangedEvent {
                        address: address.clone(),
                        topology_id: topology.common.id,
                        previous_description: ServerInfo::new_owned(previous_description.clone()),
                        new_description: ServerInfo::new_owned(new_description.clone()),
                    };
                    handler.handle_server_description_changed_event(event);
                }
                for address in diff.removed_addresses {
                    let event = ServerClosedEvent {
                        address: address.clone(),
                        topology_id: topology.common.id,
                    };
                    handler.handle_server_closed_event(event);
                }
                for address in diff.added_addresses {
                    let event = ServerOpeningEvent {
                        address: address.clone(),
                        topology_id: topology.common.id,
                    };
                    handler.handle_server_opening_event(event);
                }
                let event = TopologyDescriptionChangedEvent {
                    topology_id: topology.common.id,
                    previous_description: old_description.clone().into(),
                    new_description: self.description.clone().into(),
                };
                handler.handle_topology_description_changed_event(event);
            }
        }
        Ok(topology_changed)
    }

    /// Replaces the description's host list with `hosts` and syncs the tracked servers to
    /// match.
    fn update_hosts(
        &mut self,
        hosts: &HashSet<ServerAddress>,
        options: &ClientOptions,
        topology: WeakTopology,
    ) {
        self.description.sync_hosts(hosts);
        self.sync_hosts(hosts, options, &topology);
    }

    /// Adds a server for each address in `hosts` not yet tracked, then drops servers whose
    /// address is not in `hosts`.
    fn sync_hosts(
        &mut self,
        hosts: &HashSet<ServerAddress>,
        options: &ClientOptions,
        topology: &WeakTopology,
    ) {
        for address in hosts {
            // Filter known hosts before calling `add_new_server` so the (non-trivial)
            // `ClientOptions` clone only happens for servers that are actually created.
            // This path runs on every server-description update.
            if !self.servers.contains_key(address) {
                self.add_new_server(address.clone(), options.clone(), topology);
            }
        }
        self.servers.retain(|host, _| hosts.contains(host));
    }
}
/// How far connection establishment had progressed when an application error occurred.
#[derive(Clone, Debug)]
pub(crate) enum HandshakePhase {
    /// The error happened before the hello exchange; only the pool generation is known.
    PreHello { generation: PoolGeneration },
    /// The hello exchange happened but the handshake had not completed.
    PostHello { generation: ConnectionGeneration },
    /// The handshake completed; carries the connection's reported max wire version.
    AfterCompletion {
        generation: ConnectionGeneration,
        max_wire_version: i32,
    },
}
impl HandshakePhase {
pub(crate) fn after_completion(handshaked_connection: &Connection) -> Self {
Self::AfterCompletion {
generation: handshaked_connection.generation.clone(),
max_wire_version: handshaked_connection
.stream_description()
.ok()
.and_then(|sd| sd.max_wire_version)
.unwrap_or(0),
}
}
pub(crate) fn service_id(&self) -> Option<ObjectId> {
match self {
HandshakePhase::PreHello { .. } => None,
HandshakePhase::PostHello { generation, .. } => generation.service_id(),
HandshakePhase::AfterCompletion { generation, .. } => generation.service_id(),
}
}
fn is_before_completion(&self) -> bool {
!matches!(self, HandshakePhase::AfterCompletion { .. })
}
fn wire_version(&self) -> Option<i32> {
match self {
HandshakePhase::AfterCompletion {
max_wire_version, ..
} => Some(*max_wire_version),
_ => None,
}
}
}