use std::{
collections::{HashMap, HashSet},
future::Future,
sync::Arc,
time::Duration,
};
use bson::oid::ObjectId;
#[cfg(test)]
use futures_util::stream::{FuturesUnordered, StreamExt};
use futures_util::FutureExt;
use tokio::sync::{
broadcast,
mpsc::{self, UnboundedReceiver, UnboundedSender},
watch::{self, Ref},
};
use crate::{
client::options::{ClientOptions, ServerAddress},
cmap::{conn::ConnectionGeneration, Command, Connection, PoolGeneration},
error::{load_balanced_mode_mismatch, Error, Result},
event::sdam::{
handle_sdam_event,
SdamEvent,
ServerClosedEvent,
ServerDescriptionChangedEvent,
ServerOpeningEvent,
TopologyClosedEvent,
TopologyDescriptionChangedEvent,
TopologyOpeningEvent,
},
runtime::{self, AcknowledgedMessage, HttpClient, WorkerHandle, WorkerHandleListener},
selection_criteria::SelectionCriteria,
ClusterTime,
ServerInfo,
ServerType,
TopologyType,
};
use super::{
srv_polling::SrvPollingMonitor,
Monitor,
Server,
ServerDescription,
SessionSupportStatus,
TopologyDescription,
TransactionSupportStatus,
};
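/// Provides access to the client's current view of the cluster topology. When
/// every `WorkerHandle` has been dropped (i.e. when the `Topology` itself is
/// dropped), the background worker shuts down and server monitoring stops.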
#[derive(Debug)]
pub(crate) struct Topology {
watcher: TopologyWatcher,
updater: TopologyUpdater,
check_requester: TopologyCheckRequester,
_worker_handle: WorkerHandle,
}
impl Topology {
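    /// Creates a new topology from the given options, spawning the background
    /// worker task and, unless the deployment is load-balanced or monitoring is
    /// disabled for tests, a monitor task for each seed address.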
pub(crate) fn new(options: ClientOptions) -> Result<Topology> {
let http_client = HttpClient::default();
let description = TopologyDescription::new(options.clone())?;
let is_load_balanced = options.load_balanced == Some(true);
let update_requester = TopologyCheckRequester::new();
let event_emitter = options.sdam_event_handler.as_ref().map(|handler| {
let (tx, mut rx) = mpsc::unbounded_channel::<AcknowledgedMessage<SdamEvent>>();
let handler = handler.clone();
runtime::execute(async move {
while let Some(event) = rx.recv().await {
let (event, ack) = event.into_parts();
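                    // Once `TopologyClosed` has been handled, no further events
                    // will be emitted, so the task can exit.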
let is_closed = matches!(event, SdamEvent::TopologyClosed(_));
handle_sdam_event(handler.as_ref(), event);
ack.acknowledge(());
                    if is_closed {
                        break;
                    }
}
});
SdamEventEmitter { sender: tx }
});
let (updater, update_receiver) = TopologyUpdater::channel();
let (worker_handle, handle_listener) = WorkerHandleListener::channel();
let servers = description
.server_addresses()
.map(|address| {
(
address.clone(),
Server::new(
address.clone(),
options.clone(),
http_client.clone(),
updater.clone(),
),
)
})
.collect();
let state = TopologyState {
description,
servers,
};
let addresses = state.servers.keys().cloned().collect::<Vec<_>>();
let (watcher, publisher) = TopologyWatcher::channel(state);
#[cfg(test)]
let disable_monitoring_threads = options
.test_options
.as_ref()
.map(|to| to.disable_monitoring_threads)
.unwrap_or(false);
#[cfg(not(test))]
let disable_monitoring_threads = false;
if !is_load_balanced && !disable_monitoring_threads {
for address in addresses {
Monitor::start(
address.clone(),
updater.clone(),
watcher.clone(),
event_emitter.clone(),
update_requester.subscribe(),
options.clone(),
);
}
SrvPollingMonitor::start(updater.clone(), watcher.clone(), options.clone());
}
let worker = TopologyWorker {
id: ObjectId::new(),
update_receiver,
publisher,
options,
http_client,
topology_watcher: watcher.clone(),
topology_updater: updater.clone(),
update_requester: update_requester.clone(),
handle_listener,
event_emitter,
};
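        // Emit the initial SDAM events. The acknowledgment futures returned by
        // `emit` are dropped: events are queued synchronously on send, and
        // construction should not block on the user's event handler.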
if let Some(ref emitter) = worker.event_emitter {
let event = SdamEvent::TopologyOpening(TopologyOpeningEvent {
topology_id: worker.id,
});
let _ = emitter.emit(event);
let new_description = worker.borrow_latest_state().description.clone().into();
let event = TopologyDescriptionChangedEvent {
topology_id: worker.id,
previous_description: TopologyDescription::new_empty().into(),
new_description,
};
let _ = emitter.emit(SdamEvent::TopologyDescriptionChanged(Box::new(event)));
for server_address in worker.options.hosts.iter() {
let event = SdamEvent::ServerOpening(ServerOpeningEvent {
topology_id: worker.id,
address: server_address.clone(),
});
let _ = emitter.emit(event);
}
}
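        // Load-balanced topologies have no monitors to discover servers, so each
        // seed is immediately marked as a load balancer with a zero round trip
        // time, making it eligible for server selection right away.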
if is_load_balanced {
let mut new_state = worker.borrow_latest_state().clone();
let old_description = new_state.description.clone();
for server_address in new_state.servers.keys() {
let new_desc = ServerDescription {
server_type: ServerType::LoadBalancer,
average_round_trip_time: Some(Duration::from_nanos(0)),
..ServerDescription::new(server_address.clone(), None)
};
new_state
.description
.update(new_desc)
.map_err(Error::internal)?;
}
worker.process_topology_diff(&old_description, &new_state.description);
worker.publisher.publish_new_state(new_state);
}
worker.start();
Ok(Topology {
watcher,
updater,
check_requester: update_requester,
_worker_handle: worker_handle,
})
}
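    /// Begins watching for changes to the topology. The current state is marked
    /// as seen, so `wait_for_update` on the returned watcher only resolves for
    /// states published after this call.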
pub(crate) fn watch(&self) -> TopologyWatcher {
let mut watcher = self.watcher.clone();
watcher.receiver.borrow_and_update();
watcher
}
#[cfg(test)]
pub(crate) fn clone_updater(&self) -> TopologyUpdater {
self.updater.clone()
}
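    /// Requests that all monitors perform an immediate check of their servers.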
pub(crate) fn request_update(&self) {
self.check_requester.request()
}
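    /// Handles an error that occurred while running an operation against the
    /// server at `address`, updating the topology and clearing the server's
    /// connection pool if warranted.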
pub(crate) async fn handle_application_error(
&self,
address: ServerAddress,
error: Error,
phase: HandshakePhase,
) {
self.updater
.handle_application_error(address, error, phase)
.await;
}
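    /// Returns the latest cluster time seen by this topology, if any.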
pub(crate) fn cluster_time(&self) -> Option<ClusterTime> {
self.watcher
.peek_latest()
.description
.cluster_time()
.cloned()
}
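    /// Submits `to` to the worker so the topology's cached cluster time can be
    /// advanced.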
pub(crate) async fn advance_cluster_time(&self, to: ClusterTime) {
self.updater.advance_cluster_time(to).await;
}
pub(crate) fn topology_type(&self) -> TopologyType {
self.watcher.peek_latest().description.topology_type
}
pub(crate) fn session_support_status(&self) -> SessionSupportStatus {
self.watcher
.peek_latest()
.description
.session_support_status()
}
pub(crate) fn transaction_support_status(&self) -> TransactionSupportStatus {
self.watcher
.peek_latest()
.description
.transaction_support_status()
}
pub(crate) fn update_command_with_read_pref<T>(
&self,
server_address: &ServerAddress,
command: &mut Command<T>,
criteria: Option<&SelectionCriteria>,
) {
self.watcher
.peek_latest()
.description
.update_command_with_read_pref(server_address, command, criteria)
}
pub(crate) fn server_selection_timeout_error_message(
&self,
criteria: &SelectionCriteria,
) -> String {
self.watcher
.peek_latest()
.description
.server_selection_timeout_error_message(criteria)
}
#[cfg(test)]
pub(crate) fn server_addresses(&self) -> HashSet<ServerAddress> {
self.watcher.peek_latest().servers.keys().cloned().collect()
}
#[cfg(test)]
pub(crate) fn servers(&self) -> HashMap<ServerAddress, Arc<Server>> {
self.watcher.peek_latest().servers.clone()
}
#[cfg(test)]
pub(crate) fn description(&self) -> TopologyDescription {
self.watcher.peek_latest().description.clone()
}
#[cfg(test)]
pub(crate) async fn sync_workers(&self) {
self.updater.sync_workers().await;
}
}
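/// A snapshot of the topology: its current description and the active `Server`
/// instances keyed by address.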
#[derive(Debug, Clone)]
pub(crate) struct TopologyState {
pub(crate) description: TopologyDescription,
pub(crate) servers: HashMap<ServerAddress, Arc<Server>>,
}
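/// The messages that can be sent to the topology worker via a `TopologyUpdater`.
/// Each message is acknowledged with a `bool` indicating whether the topology
/// changed as a result of processing it.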
#[derive(Debug)]
pub(crate) enum UpdateMessage {
AdvanceClusterTime(ClusterTime),
ServerUpdate(Box<ServerDescription>),
SyncHosts(HashSet<ServerAddress>),
MonitorError {
address: ServerAddress,
error: Error,
},
ApplicationError {
address: ServerAddress,
error: Error,
phase: HandshakePhase,
},
#[cfg(test)]
SyncWorkers,
}
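/// The background task that owns the topology state. It applies updates received
/// from `TopologyUpdater` handles, publishes new states to `TopologyWatcher`s,
/// and shuts down once every associated `WorkerHandle` has been dropped.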
struct TopologyWorker {
id: ObjectId,
update_receiver: TopologyUpdateReceiver,
handle_listener: WorkerHandleListener,
publisher: TopologyPublisher,
event_emitter: Option<SdamEventEmitter>,
options: ClientOptions,
http_client: HttpClient,
topology_watcher: TopologyWatcher,
topology_updater: TopologyUpdater,
update_requester: TopologyCheckRequester,
}
impl TopologyWorker {
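    /// Spawns the worker loop. The loop processes update messages until all
    /// worker handles are dropped, then drops the publisher (waking any pending
    /// watchers) and emits a final `TopologyClosed` event.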
fn start(mut self) {
runtime::execute(async move {
loop {
tokio::select! {
Some(update) = self.update_receiver.recv() => {
let (update, ack) = update.into_parts();
let changed = match update {
UpdateMessage::AdvanceClusterTime(to) => {
self.advance_cluster_time(to);
true
}
UpdateMessage::SyncHosts(hosts) => {
let mut state = self.borrow_latest_state().clone();
self.sync_hosts(&mut state, hosts);
self.publisher.publish_new_state(state);
true
}
UpdateMessage::ServerUpdate(sd) => self.update_server(*sd).await,
UpdateMessage::MonitorError { address, error } => {
self.handle_monitor_error(address, error).await
}
UpdateMessage::ApplicationError {
address,
error,
phase,
} => self.handle_application_error(address, error, phase).await,
#[cfg(test)]
UpdateMessage::SyncWorkers => {
let rxen: FuturesUnordered<_> = self
.borrow_latest_state()
.servers
.values()
.map(|v| v.pool.sync_worker())
.collect();
let _: Vec<_> = rxen.collect().await;
false
}
};
ack.acknowledge(changed);
},
_ = self.handle_listener.wait_for_all_handle_drops() => {
break
}
}
}
drop(self.publisher);
if let Some(emitter) = self.event_emitter {
emitter
.emit(SdamEvent::TopologyClosed(TopologyClosedEvent {
topology_id: self.id,
}))
.await;
}
});
}
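    /// Borrows the latest published state. The returned `Ref` holds a read lock
    /// on the watch channel, so it must not be held across `.await` points.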
fn borrow_latest_state(&self) -> Ref<TopologyState> {
self.topology_watcher.peek_latest()
}
fn advance_cluster_time(&mut self, to: ClusterTime) {
let mut latest_state = self.borrow_latest_state().clone();
latest_state.description.advance_cluster_time(&to);
self.publisher.publish_new_state(latest_state);
}
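    /// Reconciles the topology's servers with `hosts`: servers absent from
    /// `hosts` are removed, and new `Server` instances (with monitors, unless
    /// monitoring is disabled for tests) are created for addresses not yet
    /// present.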
fn sync_hosts(&self, state: &mut TopologyState, hosts: HashSet<ServerAddress>) {
state.servers.retain(|host, _| hosts.contains(host));
state.description.sync_hosts(&hosts);
for address in hosts {
if state.servers.contains_key(&address) {
continue;
}
let server = Server::new(
address.clone(),
self.options.clone(),
self.http_client.clone(),
self.topology_updater.clone(),
);
state.servers.insert(address.clone(), server);
#[cfg(test)]
let disable_monitoring_threads = self
.options
.test_options
.as_ref()
.map(|to| to.disable_monitoring_threads)
.unwrap_or(false);
#[cfg(not(test))]
let disable_monitoring_threads = false;
if !disable_monitoring_threads {
Monitor::start(
address,
self.topology_updater.clone(),
self.topology_watcher.clone(),
self.event_emitter.clone(),
self.update_requester.subscribe(),
self.options.clone(),
);
}
}
}
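    /// Applies a new description for a single server to the topology, emitting
    /// the appropriate SDAM events and returning whether the topology changed.
    /// If a `replicaSet` name was configured and a direct connection reports a
    /// different set name, the server is instead marked as unknown with an
    /// explanatory error.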
async fn update_server(&mut self, mut sd: ServerDescription) -> bool {
let mut latest_state = self.borrow_latest_state().clone();
let old_description = latest_state.description.clone();
if let Some(expected_name) = &self.options.repl_set_name {
let got_name = sd.set_name();
if latest_state.description.topology_type() == TopologyType::Single
&& got_name.as_ref().map(|opt| opt.as_ref()) != Ok(Some(expected_name))
{
let got_display = match got_name {
Ok(Some(s)) => format!("{:?}", s),
Ok(None) => "<none>".to_string(),
Err(s) => format!("<error: {}>", s),
};
sd = ServerDescription::new(
sd.address,
Some(Err(format!(
"Connection string replicaSet name {:?} does not match actual name {}",
expected_name, got_display,
))),
);
}
}
let server_type = sd.server_type;
let server_address = sd.address.clone();
let _ = latest_state.description.update(sd);
let hosts = latest_state
.description
.server_addresses()
.cloned()
.collect();
self.sync_hosts(&mut latest_state, hosts);
let topology_changed =
self.process_topology_diff(&old_description, &latest_state.description);
if topology_changed {
if server_type.is_data_bearing()
|| (server_type != ServerType::Unknown
&& latest_state.description.topology_type() == TopologyType::Single)
{
if let Some(s) = latest_state.servers.get(&server_address) {
s.pool.mark_as_ready().await;
}
}
self.publisher.publish_new_state(latest_state)
}
topology_changed
}
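    /// Emits SDAM events describing the difference between the old and new
    /// topology descriptions, returning whether the two differ at all.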
fn process_topology_diff(
&self,
old_description: &TopologyDescription,
new_description: &TopologyDescription,
) -> bool {
let diff = old_description.diff(new_description);
let changed = diff.is_some();
if let Some(ref emitter) = self.event_emitter {
if let Some(diff) = diff {
for (address, (previous_description, new_description)) in diff.changed_servers {
let event = ServerDescriptionChangedEvent {
address: address.clone(),
topology_id: self.id,
previous_description: ServerInfo::new_owned(previous_description.clone()),
new_description: ServerInfo::new_owned(new_description.clone()),
};
let _ = emitter.emit(SdamEvent::ServerDescriptionChanged(Box::new(event)));
}
for address in diff.removed_addresses {
let event = SdamEvent::ServerClosed(ServerClosedEvent {
address: address.clone(),
topology_id: self.id,
});
let _ = emitter.emit(event);
}
for address in diff.added_addresses {
let event = ServerOpeningEvent {
address: address.clone(),
topology_id: self.id,
};
let _ = emitter.emit(SdamEvent::ServerOpening(event));
}
let event = TopologyDescriptionChangedEvent {
topology_id: self.id,
previous_description: old_description.clone().into(),
new_description: new_description.clone().into(),
};
let _ = emitter.emit(SdamEvent::TopologyDescriptionChanged(Box::new(event)));
}
}
changed
}
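    /// Marks the server at `address` as `Unknown`, recording the error as the
    /// reason, and returns whether the topology changed as a result.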
async fn mark_server_as_unknown(&mut self, address: ServerAddress, error: Error) -> bool {
let description = ServerDescription::new(address, Some(Err(error.to_string())));
self.update_server(description).await
}
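    /// Handles an error that occurred during operation execution against the
    /// given server, returning whether the topology changed as a result.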
pub(crate) async fn handle_application_error(
&mut self,
address: ServerAddress,
error: Error,
handshake: HandshakePhase,
) -> bool {
let server = match self.server(&address) {
Some(s) => s,
None => return false,
};
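        // Errors from stale connections or pools (those whose generation
        // predates the server's current one) must not update the topology.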
match &handshake {
HandshakePhase::PreHello { generation } => {
match (generation, server.pool.generation()) {
(PoolGeneration::Normal(hgen), PoolGeneration::Normal(sgen)) => {
if *hgen < sgen {
return false;
}
}
(PoolGeneration::LoadBalanced(_), PoolGeneration::LoadBalanced(_)) => {
return false
}
_ => load_balanced_mode_mismatch!(false),
}
}
HandshakePhase::PostHello { generation }
| HandshakePhase::AfterCompletion { generation, .. } => {
if generation.is_stale(&server.pool.generation()) {
return false;
}
}
}
let is_load_balanced =
self.borrow_latest_state().description.topology_type() == TopologyType::LoadBalanced;
if error.is_state_change_error() {
let updated =
is_load_balanced || self.mark_server_as_unknown(address, error.clone()).await;
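            // On state change ("not writable primary" / "node is recovering")
            // errors, the pool is only cleared for pre-4.2 servers (max wire
            // version < 8) or when the error reports the server is shutting down.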
if updated && (error.is_shutting_down() || handshake.wire_version().unwrap_or(0) < 8) {
server.pool.clear(error, handshake.service_id()).await;
}
self.update_requester.request();
updated
} else if error.is_non_timeout_network_error()
|| (handshake.is_before_completion()
&& (error.is_auth_error()
|| error.is_network_timeout()
|| error.is_command_error()))
{
let updated = is_load_balanced
|| self
.mark_server_as_unknown(server.address.clone(), error.clone())
.await;
if updated {
server.pool.clear(error, handshake.service_id()).await;
}
updated
} else {
false
}
}
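    /// Handles an error reported by a server's monitor, marking the server as
    /// unknown and clearing its pool if it is still part of the topology.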
pub(crate) async fn handle_monitor_error(
&mut self,
address: ServerAddress,
error: Error,
) -> bool {
match self.server(&address) {
Some(server) => {
let updated = self.mark_server_as_unknown(address, error.clone()).await;
if updated {
server.pool.clear(error, None).await;
}
updated
}
None => false,
}
}
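    /// Looks up the `Server` instance for `address` in the latest state.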
fn server(&self, address: &ServerAddress) -> Option<Arc<Server>> {
self.borrow_latest_state().servers.get(address).cloned()
}
}
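/// A cloneable handle used to send updates to the topology worker. Each update
/// is acknowledged by the worker with whether it changed the topology.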
#[derive(Debug, Clone)]
pub(crate) struct TopologyUpdater {
sender: UnboundedSender<AcknowledgedMessage<UpdateMessage, bool>>,
}
impl TopologyUpdater {
pub(crate) fn channel() -> (TopologyUpdater, TopologyUpdateReceiver) {
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
let updater = TopologyUpdater { sender: tx };
let update_receiver = TopologyUpdateReceiver {
update_receiver: rx,
};
(updater, update_receiver)
}
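    /// Sends `update` to the worker and waits for it to be processed, returning
    /// whether the topology changed (or `false` if the worker has shut down).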
async fn send_message(&self, update: UpdateMessage) -> bool {
let (message, receiver) = AcknowledgedMessage::package(update);
match self.sender.send(message) {
Ok(_) => receiver.wait_for_acknowledgment().await.unwrap_or(false),
_ => false,
}
}
pub(crate) async fn handle_monitor_error(&self, address: ServerAddress, error: Error) -> bool {
self.send_message(UpdateMessage::MonitorError { address, error })
.await
}
pub(crate) async fn handle_application_error(
&self,
address: ServerAddress,
error: Error,
phase: HandshakePhase,
) -> bool {
self.send_message(UpdateMessage::ApplicationError {
address,
error,
phase,
})
.await
}
pub(crate) async fn update(&self, sd: ServerDescription) -> bool {
self.send_message(UpdateMessage::ServerUpdate(Box::new(sd)))
.await
}
pub(crate) async fn advance_cluster_time(&self, to: ClusterTime) {
self.send_message(UpdateMessage::AdvanceClusterTime(to))
.await;
}
pub(crate) async fn sync_hosts(&self, hosts: HashSet<ServerAddress>) {
self.send_message(UpdateMessage::SyncHosts(hosts)).await;
}
#[cfg(test)]
pub(crate) async fn sync_workers(&self) {
self.send_message(UpdateMessage::SyncWorkers).await;
}
}
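/// The receiving end of the update channel, consumed by the topology worker.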
pub(crate) struct TopologyUpdateReceiver {
update_receiver: UnboundedReceiver<AcknowledgedMessage<UpdateMessage, bool>>,
}
impl TopologyUpdateReceiver {
pub(crate) async fn recv(&mut self) -> Option<AcknowledgedMessage<UpdateMessage, bool>> {
self.update_receiver.recv().await
}
}
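/// A cloneable handle for observing the latest published `TopologyState` and
/// awaiting changes to it.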
#[derive(Debug, Clone)]
pub(crate) struct TopologyWatcher {
receiver: watch::Receiver<TopologyState>,
}
impl TopologyWatcher {
fn channel(initial_state: TopologyState) -> (TopologyWatcher, TopologyPublisher) {
let (tx, rx) = watch::channel(initial_state);
let watcher = TopologyWatcher { receiver: rx };
let publisher = TopologyPublisher { state_sender: tx };
(watcher, publisher)
}
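    /// Whether the topology worker (and thus the publisher) is still running.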
pub(crate) fn is_alive(&self) -> bool {
self.receiver.has_changed().is_ok()
}
pub(crate) fn server_description(&self, address: &ServerAddress) -> Option<ServerDescription> {
self.receiver
.borrow()
.description
.get_server_description(address)
.cloned()
}
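    /// Clones the latest state, marking it as seen.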
pub(crate) fn observe_latest(&mut self) -> TopologyState {
self.receiver.borrow_and_update().clone()
}
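    /// Waits up to `timeout` for a new state to be published, marking whatever
    /// state is current as seen. Returns whether an update arrived in time.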
pub(crate) async fn wait_for_update(&mut self, timeout: Duration) -> bool {
let changed = runtime::timeout(timeout, self.receiver.changed())
.await
.is_ok();
self.receiver.borrow_and_update();
changed
}
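    /// Borrows the latest state without marking it as seen. As with
    /// `TopologyWorker::borrow_latest_state`, the returned `Ref` holds a read
    /// lock and must not be held across `.await` points.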
pub(crate) fn peek_latest(&self) -> Ref<TopologyState> {
self.receiver.borrow()
}
pub(crate) fn topology_type(&self) -> TopologyType {
self.peek_latest().description.topology_type
}
}
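/// The sending side of the topology watch channel; owned by the worker, which
/// publishes each new `TopologyState` through it.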
struct TopologyPublisher {
state_sender: watch::Sender<TopologyState>,
}
impl TopologyPublisher {
fn publish_new_state(&self, state: TopologyState) {
let _ = self.state_sender.send(state);
}
}
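/// Broadcasts requests for all server monitors to perform an immediate check.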
#[derive(Clone, Debug)]
struct TopologyCheckRequester {
sender: broadcast::Sender<()>,
}
impl TopologyCheckRequester {
fn new() -> TopologyCheckRequester {
let (tx, _rx) = broadcast::channel(1);
TopologyCheckRequester { sender: tx }
}
fn request(&self) {
let _ = self.sender.send(());
}
fn subscribe(&self) -> TopologyCheckRequestReceiver {
TopologyCheckRequestReceiver {
receiver: self.sender.subscribe(),
}
}
}
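/// A monitor's subscription to topology check requests.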
pub(crate) struct TopologyCheckRequestReceiver {
receiver: broadcast::Receiver<()>,
}
impl TopologyCheckRequestReceiver {
pub(crate) async fn wait_for_check_request(&mut self, timeout: Duration) {
let _: std::result::Result<_, _> = runtime::timeout(timeout, self.receiver.recv()).await;
}
pub(crate) fn clear_check_requests(&mut self) {
let _: std::result::Result<_, _> = self.receiver.try_recv();
}
}
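/// Sends SDAM events to the background task that invokes the user's event
/// handler, keeping event handling off the hot path.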
#[derive(Clone)]
pub(crate) struct SdamEventEmitter {
sender: UnboundedSender<AcknowledgedMessage<SdamEvent>>,
}
impl SdamEventEmitter {
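    /// Queues `event` for the handler task. The event is enqueued synchronously;
    /// the returned future resolves once the handler has processed it and may be
    /// dropped if acknowledgment is not needed.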
pub(crate) fn emit(&self, event: impl Into<SdamEvent>) -> impl Future<Output = ()> {
let (msg, ack) = AcknowledgedMessage::package(event.into());
let _ = self.sender.send(msg);
ack.wait_for_acknowledgment().map(|_| ())
}
}
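/// The point in connection handshaking at which an application error occurred,
/// used to determine how the error should affect the topology.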
#[derive(Debug, Clone)]
pub(crate) enum HandshakePhase {
PreHello { generation: PoolGeneration },
PostHello { generation: ConnectionGeneration },
AfterCompletion {
generation: ConnectionGeneration,
max_wire_version: i32,
},
}
impl HandshakePhase {
pub(crate) fn after_completion(handshaked_connection: &Connection) -> Self {
Self::AfterCompletion {
generation: handshaked_connection.generation.clone(),
max_wire_version: handshaked_connection
.stream_description()
.ok()
.and_then(|sd| sd.max_wire_version)
.unwrap_or(0),
}
}
pub(crate) fn service_id(&self) -> Option<ObjectId> {
match self {
HandshakePhase::PreHello { .. } => None,
HandshakePhase::PostHello { generation, .. } => generation.service_id(),
HandshakePhase::AfterCompletion { generation, .. } => generation.service_id(),
}
}
fn is_before_completion(&self) -> bool {
!matches!(self, HandshakePhase::AfterCompletion { .. })
}
fn wire_version(&self) -> Option<i32> {
match self {
HandshakePhase::AfterCompletion {
max_wire_version, ..
} => Some(*max_wire_version),
_ => None,
}
}
}