use std::{
sync::Arc,
time::{Duration, Instant},
};
use bson::doc;
use tokio::sync::watch;
use super::{
description::server::{ServerDescription, TopologyVersion},
topology::{SdamEventEmitter, TopologyCheckRequestReceiver},
TopologyUpdater,
TopologyWatcher,
};
use crate::{
cmap::{establish::ConnectionEstablisher, Connection},
error::{Error, Result},
event::sdam::{
SdamEvent,
ServerHeartbeatFailedEvent,
ServerHeartbeatStartedEvent,
ServerHeartbeatSucceededEvent,
},
hello::{hello_command, run_hello, AwaitableHelloOptions, HelloReply},
options::{ClientOptions, ServerAddress},
runtime::{self, stream::DEFAULT_CONNECT_TIMEOUT, WorkerHandle, WorkerHandleListener},
};
/// Interval between server checks when `ClientOptions::heartbeat_freq` is unset.
pub(crate) const DEFAULT_HEARTBEAT_FREQUENCY: Duration = Duration::from_secs(10);
/// Minimum wait between consecutive checks (per the SDAM spec's
/// minHeartbeatFrequencyMS). Not referenced in this chunk; presumably consumed
/// by `ClientOptions::min_heartbeat_frequency()` — confirm at the definition.
pub(crate) const MIN_HEARTBEAT_FREQUENCY: Duration = Duration::from_millis(500);
/// Background task state for monitoring a single server via periodic (or
/// streaming, when the server supports it) `hello` checks, publishing updated
/// server descriptions to the topology.
pub(crate) struct Monitor {
// Address of the server this monitor checks.
address: ServerAddress,
// Dedicated monitoring connection; `None` until established, and reset to
// `None` after a failed check.
connection: Option<Connection>,
// Used to (re-)establish the monitoring connection.
connection_establisher: ConnectionEstablisher,
// Sink for publishing new server descriptions and monitor errors.
topology_updater: TopologyUpdater,
// Read-only view of current topology state (used to look up the previous
// description of this server).
topology_watcher: TopologyWatcher,
// Optional emitter for SDAM heartbeat started/succeeded/failed events.
sdam_event_emitter: Option<SdamEventEmitter>,
client_options: ClientOptions,
// Last `topologyVersion` received from the server; `Some` means awaitable
// (streaming) hello is in use.
topology_version: Option<TopologyVersion>,
// Handle for reading/resetting/adding to the RTT average maintained by the
// companion RTT monitor task.
rtt_monitor_handle: RttMonitorHandle,
// Source of check requests, cancellation notices, and monitor liveness.
request_receiver: MonitorRequestReceiver,
}
impl Monitor {
/// Creates the monitor (and its paired RTT monitor) for `address` and spawns
/// both tasks onto the runtime. Returns immediately after spawning.
pub(crate) fn start(
address: ServerAddress,
topology_updater: TopologyUpdater,
topology_watcher: TopologyWatcher,
sdam_event_emitter: Option<SdamEventEmitter>,
manager_receiver: MonitorRequestReceiver,
client_options: ClientOptions,
connection_establisher: ConnectionEstablisher,
) {
let (rtt_monitor, rtt_monitor_handle) = RttMonitor::new(
address.clone(),
topology_watcher.clone(),
connection_establisher.clone(),
client_options.clone(),
);
let monitor = Self {
address,
client_options,
connection_establisher,
topology_updater,
topology_watcher,
sdam_event_emitter,
rtt_monitor_handle,
request_receiver: manager_receiver,
connection: None,
topology_version: None,
};
runtime::execute(monitor.execute());
runtime::execute(rtt_monitor.execute());
}
/// Main monitor loop: repeatedly checks the server until the monitor is shut
/// down. When streaming hello is active (a `topologyVersion` is known) and the
/// last check succeeded, the next check begins immediately; otherwise the loop
/// waits for a check request or for the heartbeat interval to elapse.
async fn execute(mut self) {
let heartbeat_frequency = self.heartbeat_frequency();
while self.is_alive() {
let check_succeeded = self.check_server().await;
// A known topologyVersion plus a successful check means the server is
// streaming responses, so skip the wait and go straight into the next
// (awaited) check.
if self.topology_version.is_none() || !check_succeeded {
self.request_receiver
.wait_for_check_request(
self.client_options.min_heartbeat_frequency(),
heartbeat_frequency,
)
.await;
}
}
}
/// Whether the monitor should keep running (some handle to it is still alive).
fn is_alive(&self) -> bool {
self.request_receiver.is_alive()
}
/// Performs a single server check and updates the topology with the result.
/// Returns `true` on success.
///
/// A network error against a server that was previously available gets one
/// immediate retry (after reporting the first error to the topology),
/// mirroring the SDAM spec's retry-once behavior for monitor checks.
async fn check_server(&mut self) -> bool {
let check_result = match self.perform_hello().await {
HelloResult::Err(e) => {
let previous_description = self.topology_watcher.server_description(&self.address);
if e.is_network_error()
&& previous_description
.map(|sd| sd.is_available())
.unwrap_or(false)
{
// Report the failure (which clears server state), then retry once.
self.handle_error(e).await;
self.perform_hello().await
} else {
HelloResult::Err(e)
}
}
other => other,
};
match check_result {
HelloResult::Ok(reply) => {
let avg_rtt = self.rtt_monitor_handle.average_rtt();
// A successful hello implies a monitoring connection was established,
// and establishment records an RTT sample first (see `perform_hello`),
// so an average should exist; fall back to Duration::MAX defensively
// in release builds.
debug_assert!(avg_rtt.is_some());
let avg_rtt = avg_rtt.unwrap_or(Duration::MAX);
let server_description =
ServerDescription::new_from_hello_reply(self.address.clone(), reply, avg_rtt);
self.topology_updater.update(server_description).await;
true
}
HelloResult::Err(e) => {
self.handle_error(e).await;
false
}
// NOTE(review): a cancelled check does not update the topology here;
// presumably the cancelling party has already done so — confirm.
HelloResult::Cancelled { .. } => false,
}
}
/// Issues one `hello` to the server (or awaits the next streamed reply),
/// racing it against cancellation and an overall timeout, and emits the
/// corresponding SDAM heartbeat events. Updates the connection,
/// `topology_version`, and RTT state based on the outcome.
async fn perform_hello(&mut self) -> HelloResult {
self.emit_event(|| {
SdamEvent::ServerHeartbeatStarted(ServerHeartbeatStartedEvent {
server_address: self.address.clone(),
// Heartbeats are "awaited" when the streaming protocol is in use.
awaited: self.topology_version.is_some(),
})
});
let heartbeat_frequency = self.heartbeat_frequency();
// Overall deadline for the attempt: unlimited when the connect timeout is 0;
// for awaited hellos the server may hold the response up to the max await
// time (set to the heartbeat frequency below), so allow that plus the
// connect timeout; otherwise just the connect timeout.
let timeout = if self.connect_timeout().is_zero() {
Duration::MAX
} else if self.topology_version.is_some() {
heartbeat_frequency
.checked_add(self.connect_timeout())
.unwrap_or(Duration::MAX)
} else {
self.connect_timeout()
};
let execute_hello = async {
match self.connection {
Some(ref mut conn) => {
// A previous awaited hello is still streaming responses: just read
// the next reply off the wire.
if conn.is_streaming() {
conn.receive_message()
.await
.and_then(|r| r.into_hello_reply())
} else {
// Otherwise send a hello command, made awaitable if we have a
// topologyVersion from the server.
let opts = self.topology_version.map(|tv| AwaitableHelloOptions {
topology_version: tv,
max_await_time: heartbeat_frequency,
});
let command = hello_command(
self.client_options.server_api.as_ref(),
self.client_options.load_balanced,
Some(conn.stream_description()?.hello_ok),
opts,
);
run_hello(conn, command).await
}
}
None => {
// No monitoring connection yet (first check, or the previous one
// failed): establish one. The handshake yields a hello reply.
let start = Instant::now();
let res = self
.connection_establisher
.establish_monitoring_connection(self.address.clone())
.await;
match res {
Ok((conn, hello_reply)) => {
// Use the handshake duration as an initial RTT sample.
self.rtt_monitor_handle.add_sample(start.elapsed());
self.connection = Some(conn);
Ok(hello_reply)
}
Err(e) => Err(e),
}
}
}
};
let start = Instant::now();
// Race the hello against cancellation and the overall timeout.
let result = tokio::select! {
result = execute_hello => match result {
Ok(reply) => HelloResult::Ok(reply),
Err(e) => HelloResult::Err(e)
},
r = self.request_receiver.wait_for_cancellation() => {
let reason_error = match r {
CancellationReason::Error(e) => e,
CancellationReason::ServerClosed => Error::internal("server closed")
};
HelloResult::Cancelled { reason: reason_error }
}
_ = runtime::delay_for(timeout) => {
HelloResult::Err(Error::network_timeout())
}
};
let duration = start.elapsed();
match result {
HelloResult::Ok(ref r) => {
self.emit_event(|| {
let mut reply = r
.raw_command_response
.to_document()
.unwrap_or_else(|e| doc! { "deserialization error": e.to_string() });
// Strip speculative-auth payloads before exposing the reply in events.
reply.remove("speculativeAuthenticate");
SdamEvent::ServerHeartbeatSucceeded(ServerHeartbeatSucceededEvent {
duration,
reply,
server_address: self.address.clone(),
awaited: self.topology_version.is_some(),
})
});
// Remember the latest topologyVersion so the next hello can be awaited.
self.topology_version = r.command_response.topology_version;
}
HelloResult::Err(ref e) | HelloResult::Cancelled { reason: ref e } => {
// On any failure, discard the connection and RTT average so the next
// check starts fresh with a new handshake.
self.connection = None;
self.rtt_monitor_handle.reset_average_rtt();
self.emit_event(|| {
SdamEvent::ServerHeartbeatFailed(ServerHeartbeatFailedEvent {
duration,
failure: e.clone(),
server_address: self.address.clone(),
awaited: self.topology_version.is_some(),
})
});
// Fall back to polling (non-awaitable) hello after a failure.
self.topology_version.take();
}
}
result
}
/// Forwards the error to the topology updater. The returned flag comes from
/// `handle_monitor_error`; presumably it indicates whether the error changed
/// the topology — confirm at that definition.
async fn handle_error(&mut self, error: Error) -> bool {
self.topology_updater
.handle_monitor_error(self.address.clone(), error)
.await
}
/// Emits an SDAM event if an emitter is configured; takes a closure so event
/// construction is deferred until it's known to be needed.
fn emit_event<F>(&self, event: F)
where
F: FnOnce() -> SdamEvent,
{
if let Some(ref emitter) = self.sdam_event_emitter {
let _ = emitter.emit(event());
}
}
/// The configured connect timeout, or the driver default if unset.
fn connect_timeout(&self) -> Duration {
self.client_options
.connect_timeout
.unwrap_or(DEFAULT_CONNECT_TIMEOUT)
}
/// The configured heartbeat frequency, or the driver default if unset.
fn heartbeat_frequency(&self) -> Duration {
self.client_options
.heartbeat_freq
.unwrap_or(DEFAULT_HEARTBEAT_FREQUENCY)
}
}
/// Companion task that periodically measures round-trip time to the server on
/// its own dedicated connection, publishing samples through a watch channel.
struct RttMonitor {
// Publishes updated RTT averages; shared (via `Arc`) with the handle so the
// server monitor can also reset or add samples.
sender: Arc<watch::Sender<RttInfo>>,
// Dedicated connection used solely for RTT measurement; `None` until
// established or after a failed measurement.
connection: Option<Connection>,
// Used to detect topology shutdown and stop the loop.
topology: TopologyWatcher,
address: ServerAddress,
client_options: ClientOptions,
connection_establisher: ConnectionEstablisher,
}
/// Running round-trip-time statistics published by the RTT monitor.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct RttInfo {
    /// Moving average of observed RTT samples, or `None` before any sample
    /// has been recorded.
    pub(crate) average: Option<Duration>,
}

impl RttInfo {
    /// Folds a new RTT sample into the running average.
    ///
    /// The first sample becomes the average directly; each subsequent sample
    /// is blended in with an exponentially weighted moving average, weighting
    /// the new sample 1/5 and the prior average 4/5.
    pub(crate) fn add_sample(&mut self, sample: Duration) {
        self.average = Some(if let Some(current) = self.average {
            sample / 5 + current * 4 / 5
        } else {
            sample
        });
    }
}
impl RttMonitor {
/// Creates the RTT monitor plus the handle the server monitor uses to read,
/// reset, and contribute RTT samples. Both sides share the watch channel's
/// sender via an `Arc`; the handle owns the receiver.
fn new(
address: ServerAddress,
topology: TopologyWatcher,
connection_establisher: ConnectionEstablisher,
client_options: ClientOptions,
) -> (Self, RttMonitorHandle) {
let (sender, rtt_receiver) = watch::channel(RttInfo { average: None });
let sender = Arc::new(sender);
let monitor = Self {
address,
connection: None,
topology,
client_options,
connection_establisher,
sender: sender.clone(),
};
let handle = RttMonitorHandle {
reset_sender: sender,
rtt_receiver,
};
(monitor, handle)
}
/// Measurement loop: once per heartbeat interval, either issues a hello on
/// the existing dedicated connection or establishes a new one, timing the
/// operation and folding the elapsed time into the published average.
/// Exits when the topology dies or when the watch channel closes (i.e. the
/// paired `RttMonitorHandle` holding the receiver is dropped).
async fn execute(mut self) {
while self.topology.is_alive() && !self.sender.is_closed() {
let timeout = self
.client_options
.connect_timeout
.unwrap_or(DEFAULT_CONNECT_TIMEOUT);
let perform_check = async {
match self.connection {
Some(ref mut conn) => {
// Time a plain (non-awaitable) hello on the existing connection.
let command = hello_command(
self.client_options.server_api.as_ref(),
self.client_options.load_balanced,
Some(conn.stream_description()?.hello_ok),
None,
);
conn.send_command(command, None).await?;
}
None => {
// First iteration, or the previous check failed: (re-)establish
// the connection; the establishment itself is what gets timed.
let connection = self
.connection_establisher
.establish_monitoring_connection(self.address.clone())
.await?
.0;
self.connection = Some(connection);
}
};
Result::Ok(())
};
let start = Instant::now();
// Bound each measurement by the connect timeout.
let check_succeded = tokio::select! {
r = perform_check => r.is_ok(),
_ = runtime::delay_for(timeout) => {
false
}
};
if check_succeded {
self.sender
.send_modify(|rtt_info| rtt_info.add_sample(start.elapsed()));
} else {
// Drop the connection so the next iteration re-establishes it.
self.connection = None;
}
// Sleep until the next measurement.
runtime::delay_for(
self.client_options
.heartbeat_freq
.unwrap_or(DEFAULT_HEARTBEAT_FREQUENCY),
)
.await;
}
}
}
/// Handle held by the server monitor for interacting with its RTT monitor:
/// reading the current average, resetting it, and contributing samples.
struct RttMonitorHandle {
// Receives RTT updates published by the RTT monitor task.
rtt_receiver: watch::Receiver<RttInfo>,
// Sender shared with the RTT monitor; used to reset the average or add
// samples directly (e.g. from connection-handshake timings).
reset_sender: Arc<watch::Sender<RttInfo>>,
}
impl RttMonitorHandle {
    /// Returns the most recently published RTT average, if any sample has
    /// been recorded since the last reset.
    fn average_rtt(&self) -> Option<Duration> {
        let info = self.rtt_receiver.borrow();
        info.average
    }

    /// Folds an externally observed RTT sample (e.g. a connection handshake
    /// duration) into the shared average.
    fn add_sample(&mut self, sample: Duration) {
        self.reset_sender
            .send_modify(|info| info.add_sample(sample));
    }

    /// Clears the running average so the next sample starts it fresh.
    fn reset_average_rtt(&mut self) {
        let _ = self.reset_sender.send(RttInfo::default());
    }
}
/// Outcome of a single heartbeat (`hello`) attempt.
#[allow(clippy::large_enum_variant)] #[derive(Debug, Clone)]
enum HelloResult {
// The server replied to the hello.
Ok(HelloReply),
// The hello failed (network error, timeout, or command error).
Err(Error),
// The in-progress check was cancelled via the monitor's cancellation channel.
Cancelled { reason: Error },
}
/// Handle used to control a running server monitor: cancelling an in-progress
/// check, requesting an immediate check, or shutting the monitor down.
#[derive(Debug, Clone)]
pub(crate) struct MonitorManager {
// Keeps the monitor alive; dropped to initiate shutdown.
handle: WorkerHandle,
// Broadcasts cancellation (with a reason) of the in-progress check.
cancellation_sender: Arc<watch::Sender<CancellationReason>>,
// Broadcasts immediate-check requests targeted at this server's monitor.
check_requester: Arc<watch::Sender<()>>,
}
impl MonitorManager {
    /// Builds a manager around the monitor's worker handle, creating the
    /// cancellation and immediate-check watch channels used to signal the
    /// monitor task.
    pub(crate) fn new(monitor_handle: WorkerHandle) -> Self {
        let cancellation_sender = Arc::new(watch::channel(CancellationReason::ServerClosed).0);
        let (check_sender, _) = watch::channel(());
        MonitorManager {
            handle: monitor_handle,
            cancellation_sender,
            check_requester: Arc::new(check_sender),
        }
    }

    /// Permanently shuts down the monitor, resolving once it has stopped.
    pub(crate) async fn close_monitor(self) {
        // Releasing the worker handle lets the monitor's handle listener
        // observe that no handles remain.
        drop(self.handle);
        // Interrupt any in-progress check so shutdown isn't delayed by it.
        self.cancellation_sender
            .send(CancellationReason::ServerClosed)
            .ok();
        // Resolves once the monitor side drops its cancellation receiver.
        self.cancellation_sender.closed().await;
    }

    /// Cancels the in-progress check, if any, recording `reason`.
    pub(crate) fn cancel_in_progress_check(&mut self, reason: impl Into<CancellationReason>) {
        self.cancellation_sender.send(reason.into()).ok();
    }

    /// Asks the monitor to perform a check as soon as it is able.
    pub(crate) fn request_immediate_check(&mut self) {
        self.check_requester.send(()).ok();
    }
}
/// Monitor-side receiver for control signals: individual and topology-wide
/// check requests, cancellation notices, and liveness of the owning server.
pub(crate) struct MonitorRequestReceiver {
// Liveness: the monitor stops once all corresponding handles are dropped.
handle_listener: WorkerHandleListener,
// Receives cancellation notices for the in-progress check.
cancellation_receiver: watch::Receiver<CancellationReason>,
// Receives immediate-check requests targeted at this server.
individual_check_request_receiver: watch::Receiver<()>,
// Receives topology-wide requests to re-check servers.
topology_check_request_receiver: TopologyCheckRequestReceiver,
}
/// Why an in-progress server check was cancelled.
#[derive(Debug, Clone)]
pub(crate) enum CancellationReason {
// Cancelled due to an error supplied by the caller (via `From<Error>`).
Error(Error),
// Cancelled because the server/monitor is shutting down.
ServerClosed,
}
impl From<Error> for CancellationReason {
fn from(e: Error) -> Self {
Self::Error(e)
}
}
impl MonitorRequestReceiver {
/// Creates a receiver wired to the given manager's cancellation and
/// immediate-check channels, plus the topology-wide check request channel.
pub(crate) fn new(
manager: &MonitorManager,
topology_check_request_receiver: TopologyCheckRequestReceiver,
handle_listener: WorkerHandleListener,
) -> Self {
Self {
handle_listener,
cancellation_receiver: manager.cancellation_sender.subscribe(),
individual_check_request_receiver: manager.check_requester.subscribe(),
topology_check_request_receiver,
}
}
/// Resolves when the in-progress check is cancelled, yielding the reason.
/// A closed cancellation channel (manager gone) is treated as server
/// shutdown.
async fn wait_for_cancellation(&mut self) -> CancellationReason {
let err = if self.cancellation_receiver.changed().await.is_ok() {
self.cancellation_receiver.borrow().clone()
} else {
CancellationReason::ServerClosed
};
// Mark any pending individual check request as seen; presumably so a
// request that arrived alongside the cancellation doesn't immediately
// re-trigger the next wait — confirm against callers.
self.individual_check_request_receiver.borrow_and_update();
err
}
/// Waits (up to `timeout`) for a reason to start the next check: an
/// individual check request for this server, a topology-wide check request
/// (honored only after `delay`, enforcing the minimum heartbeat interval),
/// or all monitor handles being dropped.
async fn wait_for_check_request(&mut self, delay: Duration, timeout: Duration) {
let _ = runtime::timeout(timeout, async {
let wait_for_check_request = async {
// Topology-wide requests are only honored after the minimum delay.
runtime::delay_for(delay).await;
self.topology_check_request_receiver
.wait_for_check_request()
.await;
};
tokio::pin!(wait_for_check_request);
// NOTE(review): every select arm breaks, so this loop body runs at
// most once; the first of the three futures to complete wins.
loop {
tokio::select! {
_ = self.individual_check_request_receiver.changed() => {
break;
}
_ = &mut wait_for_check_request => {
break;
}
_ = self.handle_listener.wait_for_all_handle_drops() => {
break;
}
}
}
})
.await;
// Mark any pending cancellation as seen so a stale cancellation doesn't
// abort the upcoming check. (Hedged: inferred from `borrow_and_update`
// semantics — confirm intended behavior.)
self.cancellation_receiver.borrow_and_update();
}
/// Whether any handles to this server are still alive; when false the
/// monitor should shut down.
fn is_alive(&self) -> bool {
self.handle_listener.is_alive()
}
}