mongodb/sdam/
monitor.rs

1use std::{
2    sync::{
3        atomic::{AtomicU32, Ordering},
4        Arc,
5    },
6    time::{Duration, Instant},
7};
8
9use crate::bson::doc;
10use tokio::sync::watch;
11
12use super::{
13    description::server::{ServerDescription, TopologyVersion},
14    topology::{SdamEventEmitter, TopologyCheckRequestReceiver},
15    TopologyUpdater,
16    TopologyWatcher,
17};
18use crate::{
19    client::options::ServerMonitoringMode,
20    cmap::{establish::ConnectionEstablisher, Connection},
21    error::{Error, Result},
22    event::sdam::{
23        SdamEvent,
24        ServerHeartbeatFailedEvent,
25        ServerHeartbeatStartedEvent,
26        ServerHeartbeatSucceededEvent,
27    },
28    hello::{hello_command, run_hello, AwaitableHelloOptions, HelloReply},
29    options::{ClientOptions, ServerAddress},
30    runtime::{self, stream::DEFAULT_CONNECT_TIMEOUT, WorkerHandle, WorkerHandleListener},
31};
32
/// Returns a process-wide unique id for a monitoring connection.
///
/// Ids start at 0 and increase by one on every call.
fn next_monitoring_connection_id() -> u32 {
    static NEXT_ID: AtomicU32 = AtomicU32::new(0);
    NEXT_ID.fetch_add(1, Ordering::SeqCst)
}
38
/// Default interval between server checks, used when `heartbeat_freq` is not configured.
pub(crate) const DEFAULT_HEARTBEAT_FREQUENCY: Duration = Duration::from_secs(10);
/// Minimum heartbeat frequency (500ms, the SDAM spec's minHeartbeatFrequencyMS).
/// Not referenced in this file; presumably enforced by callers elsewhere.
pub(crate) const MIN_HEARTBEAT_FREQUENCY: Duration = Duration::from_millis(500);
41
/// Monitor that performs regular heartbeats to determine server status.
pub(crate) struct Monitor {
    /// Address of the server being monitored.
    address: ServerAddress,
    /// The dedicated monitoring connection. `None` until the first successful
    /// handshake, and reset to `None` after a failed or cancelled check.
    connection: Option<Connection>,
    /// Used to establish the monitoring connection (and re-establish it after errors).
    connection_establisher: ConnectionEstablisher,
    /// Used to publish updated server descriptions and to report monitor errors.
    topology_updater: TopologyUpdater,
    /// Used to read the current description of the monitored server.
    topology_watcher: TopologyWatcher,
    /// Emitter for SDAM heartbeat events; `None` when event monitoring is disabled.
    sdam_event_emitter: Option<SdamEventEmitter>,
    /// Client options governing timeouts, heartbeat frequency, and the hello command.
    client_options: ClientOptions,

    /// Whether this monitor is allowed to use the streaming protocol.
    allow_streaming: bool,

    /// The most recent topology version returned by the server in a hello response.
    topology_version: Option<TopologyVersion>,

    /// The RTT monitor; once it's started this is None.
    pending_rtt_monitor: Option<RttMonitor>,

    /// Handle to the RTT monitor, used to get the latest known round trip time for a given server
    /// and to reset the RTT when the monitor disconnects from the server.
    rtt_monitor_handle: RttMonitorHandle,

    /// Handle to the `Server` instance in the `Topology`. This is used to detect when a server has
    /// been removed from the topology and no longer needs to be monitored and to receive
    /// cancellation requests.
    request_receiver: MonitorRequestReceiver,
}
70
71impl Monitor {
72    pub(crate) fn start(
73        address: ServerAddress,
74        topology_updater: TopologyUpdater,
75        topology_watcher: TopologyWatcher,
76        sdam_event_emitter: Option<SdamEventEmitter>,
77        manager_receiver: MonitorRequestReceiver,
78        client_options: ClientOptions,
79        connection_establisher: ConnectionEstablisher,
80    ) {
81        let (rtt_monitor, rtt_monitor_handle) = RttMonitor::new(
82            address.clone(),
83            topology_watcher.clone(),
84            connection_establisher.clone(),
85            client_options.clone(),
86        );
87        let allow_streaming = match client_options
88            .server_monitoring_mode
89            .clone()
90            .unwrap_or(ServerMonitoringMode::Auto)
91        {
92            ServerMonitoringMode::Stream => true,
93            ServerMonitoringMode::Poll => false,
94            ServerMonitoringMode::Auto => !crate::cmap::is_faas(),
95        };
96        let monitor = Self {
97            address,
98            client_options,
99            connection_establisher,
100            topology_updater,
101            topology_watcher,
102            sdam_event_emitter,
103            pending_rtt_monitor: Some(rtt_monitor),
104            rtt_monitor_handle,
105            request_receiver: manager_receiver,
106            connection: None,
107            allow_streaming,
108            topology_version: None,
109        };
110
111        runtime::spawn(monitor.execute());
112    }
113
114    async fn execute(mut self) {
115        let heartbeat_frequency = self.heartbeat_frequency();
116
117        while self.is_alive() {
118            let check_succeeded = self.check_server().await;
119
120            if self.topology_version.is_some() && self.allow_streaming {
121                if let Some(rtt_monitor) = self.pending_rtt_monitor.take() {
122                    runtime::spawn(rtt_monitor.execute());
123                }
124            }
125
126            // In the streaming protocol, we read from the socket continuously
127            // rather than polling at specific intervals, unless the most recent check
128            // failed.
129            //
130            // We only go to sleep when using the polling protocol (i.e. server never returned a
131            // topologyVersion) or when the most recent check failed.
132            if self.topology_version.is_none() || !check_succeeded || !self.allow_streaming {
133                self.request_receiver
134                    .wait_for_check_request(
135                        self.client_options.min_heartbeat_frequency(),
136                        heartbeat_frequency,
137                    )
138                    .await;
139            }
140        }
141    }
142
143    fn is_alive(&self) -> bool {
144        self.request_receiver.is_alive()
145    }
146
147    /// Checks the the server by running a hello command. If an I/O error occurs, the
148    /// connection will replaced with a new one.
149    ///
150    /// Returns whether the check succeeded or not.
151    async fn check_server(&mut self) -> bool {
152        let check_result = match self.perform_hello().await {
153            HelloResult::Err(e) => {
154                let previous_description = self.topology_watcher.server_description(&self.address);
155                if e.is_network_error()
156                    && previous_description
157                        .map(|sd| sd.is_available())
158                        .unwrap_or(false)
159                {
160                    self.handle_error(e).await;
161                    self.perform_hello().await
162                } else {
163                    HelloResult::Err(e)
164                }
165            }
166            other => other,
167        };
168
169        match check_result {
170            HelloResult::Ok(reply) => {
171                let avg_rtt = self.rtt_monitor_handle.average_rtt();
172
173                // If we have an Ok result, then we at least performed a handshake, which should
174                // mean that the RTT has a value.
175                debug_assert!(avg_rtt.is_some());
176
177                // In the event that we don't have an average RTT value (e.g. due to a bug), just
178                // default to using the maximum possible value.
179                let avg_rtt = avg_rtt.unwrap_or(Duration::MAX);
180
181                let server_description =
182                    ServerDescription::new_from_hello_reply(self.address.clone(), reply, avg_rtt);
183                self.topology_updater.update(server_description).await;
184                true
185            }
186            HelloResult::Err(e) => {
187                self.handle_error(e).await;
188                false
189            }
190            HelloResult::Cancelled { .. } => false,
191        }
192    }
193
194    async fn perform_hello(&mut self) -> HelloResult {
195        let driver_connection_id = self
196            .connection
197            .as_ref()
198            .map(|c| c.id)
199            .unwrap_or(next_monitoring_connection_id());
200
201        self.emit_event(|| {
202            SdamEvent::ServerHeartbeatStarted(ServerHeartbeatStartedEvent {
203                server_address: self.address.clone(),
204                awaited: self.topology_version.is_some() && self.allow_streaming,
205                driver_connection_id,
206                server_connection_id: self.connection.as_ref().and_then(|c| c.server_id),
207            })
208        });
209
210        let heartbeat_frequency = self.heartbeat_frequency();
211        let timeout = if self.connect_timeout().is_zero() {
212            // If connectTimeoutMS = 0, then the socket timeout for monitoring is unlimited.
213            Duration::MAX
214        } else if self.topology_version.is_some() {
215            // For streaming responses, use connectTimeoutMS + heartbeatFrequencyMS for socket
216            // timeout.
217            heartbeat_frequency
218                .checked_add(self.connect_timeout())
219                .unwrap_or(Duration::MAX)
220        } else {
221            // Otherwise, just use connectTimeoutMS.
222            self.connect_timeout()
223        };
224
225        let execute_hello = async {
226            match self.connection {
227                Some(ref mut conn) => {
228                    // If the server indicated there was moreToCome, just read from the socket.
229                    if conn.is_streaming() {
230                        conn.receive_message()
231                            .await
232                            .and_then(|r| r.into_hello_reply())
233                    // Otherwise, send a regular hello command.
234                    } else {
235                        // If the initial handshake returned a topology version, send it back to the
236                        // server to begin streaming responses.
237                        let opts = if self.allow_streaming {
238                            self.topology_version.map(|tv| AwaitableHelloOptions {
239                                topology_version: tv,
240                                max_await_time: heartbeat_frequency,
241                            })
242                        } else {
243                            None
244                        };
245
246                        let command = hello_command(
247                            self.client_options.server_api.as_ref(),
248                            self.client_options.load_balanced,
249                            Some(conn.stream_description()?.hello_ok),
250                            opts,
251                        );
252
253                        run_hello(conn, command, None).await
254                    }
255                }
256                None => {
257                    let start = Instant::now();
258                    let res = self
259                        .connection_establisher
260                        .establish_monitoring_connection(self.address.clone(), driver_connection_id)
261                        .await;
262                    match res {
263                        Ok((conn, hello_reply)) => {
264                            self.rtt_monitor_handle.add_sample(start.elapsed());
265                            self.connection = Some(conn);
266                            Ok(hello_reply)
267                        }
268                        Err(e) => Err(e),
269                    }
270                }
271            }
272        };
273
274        // Execute the hello while also listening for cancellation and keeping track of the timeout.
275        let start = Instant::now();
276        let result = tokio::select! {
277            result = execute_hello => match result {
278                Ok(mut reply) => {
279                    // Do not propagate server reported cluster time for monitoring hello responses.
280                    reply.cluster_time = None;
281                    HelloResult::Ok(reply)
282                },
283                Err(e) => HelloResult::Err(e)
284            },
285            r = self.request_receiver.wait_for_cancellation() => {
286                let reason_error = match r {
287                    CancellationReason::Error(e) => e,
288                    CancellationReason::ServerClosed => Error::internal("server closed")
289                };
290                HelloResult::Cancelled { reason: reason_error }
291            }
292            _ = tokio::time::sleep(timeout) => {
293                HelloResult::Err(Error::network_timeout())
294            }
295        };
296        let duration = start.elapsed();
297
298        let awaited = self.topology_version.is_some() && self.allow_streaming;
299        match result {
300            HelloResult::Ok(ref r) => {
301                if !awaited {
302                    self.rtt_monitor_handle.add_sample(duration);
303                }
304                self.emit_event(|| {
305                    let mut reply =
306                        crate::bson::Document::try_from(r.raw_command_response.as_ref())
307                            .unwrap_or_else(|e| doc! { "deserialization error": e.to_string() });
308                    // if this hello call is part of a handshake, remove speculative authentication
309                    // information before publishing an event
310                    reply.remove("speculativeAuthenticate");
311                    SdamEvent::ServerHeartbeatSucceeded(ServerHeartbeatSucceededEvent {
312                        duration,
313                        reply,
314                        server_address: self.address.clone(),
315                        awaited,
316                        driver_connection_id,
317                        server_connection_id: self.connection.as_ref().and_then(|c| c.server_id),
318                    })
319                });
320
321                // If the response included a topology version, cache it so that we can return it in
322                // the next hello.
323                self.topology_version = r.command_response.topology_version;
324            }
325            HelloResult::Err(ref e) | HelloResult::Cancelled { reason: ref e } => {
326                self.emit_event(|| {
327                    SdamEvent::ServerHeartbeatFailed(ServerHeartbeatFailedEvent {
328                        duration,
329                        failure: e.clone(),
330                        server_address: self.address.clone(),
331                        awaited,
332                        driver_connection_id,
333                        server_connection_id: self.connection.as_ref().and_then(|c| c.server_id),
334                    })
335                });
336
337                // Per the spec, cancelled requests and errors both require the monitoring
338                // connection to be closed.
339                self.connection = None;
340                self.rtt_monitor_handle.reset_average_rtt();
341                self.topology_version.take();
342            }
343        }
344
345        result
346    }
347
348    async fn handle_error(&mut self, error: Error) -> bool {
349        self.topology_updater
350            .handle_monitor_error(self.address.clone(), error)
351            .await
352    }
353
354    fn emit_event<F>(&self, event: F)
355    where
356        F: FnOnce() -> SdamEvent,
357    {
358        if let Some(ref emitter) = self.sdam_event_emitter {
359            // We don't care about ordering or waiting for the event to have been received.
360            #[allow(clippy::let_underscore_future)]
361            let _ = emitter.emit(event());
362        }
363    }
364
365    fn connect_timeout(&self) -> Duration {
366        self.client_options
367            .connect_timeout
368            .unwrap_or(DEFAULT_CONNECT_TIMEOUT)
369    }
370
371    fn heartbeat_frequency(&self) -> Duration {
372        self.client_options
373            .heartbeat_freq
374            .unwrap_or(DEFAULT_HEARTBEAT_FREQUENCY)
375    }
376}
377
/// The monitor used for tracking the round-trip-time to the server, as described in the SDAM spec.
/// This monitor uses its own connection to make RTT measurements, and it publishes the averages of
/// those measurements to a channel.
struct RttMonitor {
    /// Publishes updated `RttInfo` averages; shared with the `RttMonitorHandle`.
    sender: Arc<watch::Sender<RttInfo>>,
    /// Dedicated measurement connection; `None` until established and after a failed check.
    connection: Option<Connection>,
    /// Used to detect when the topology has been closed.
    topology: TopologyWatcher,
    /// Address of the server whose RTT is being measured.
    address: ServerAddress,
    /// Options used for the hello command, timeouts, and measurement frequency.
    client_options: ClientOptions,
    /// Establishes the dedicated monitoring connection.
    connection_establisher: ConnectionEstablisher,
}
389
/// The current moving average of round-trip-time samples for a server.
#[derive(Debug, Clone, Copy, Default)]
pub(crate) struct RttInfo {
    /// Exponentially weighted average of the samples recorded so far; `None`
    /// until the first sample arrives (and after a reset to the default value).
    pub(crate) average: Option<Duration>,
}
394
395impl RttInfo {
396    pub(crate) fn add_sample(&mut self, sample: Duration) {
397        match self.average {
398            Some(old_rtt) => {
399                // Average is 20% most recent sample and 80% prior sample.
400                self.average = Some((sample / 5) + (old_rtt * 4 / 5))
401            }
402            None => self.average = Some(sample),
403        }
404    }
405}
406
impl RttMonitor {
    /// Creates a new RTT monitor for the server at the given address, returning a receiver that the
    /// RTT statistics will be published to. This does not start the monitor.
    /// [`RttMonitor::execute`] needs to be invoked to start it.
    fn new(
        address: ServerAddress,
        topology: TopologyWatcher,
        connection_establisher: ConnectionEstablisher,
        client_options: ClientOptions,
    ) -> (Self, RttMonitorHandle) {
        // The sender is shared (via `Arc`) between the monitor and the handle so
        // the server monitor can also reset the average or add handshake samples.
        let (sender, rtt_receiver) = watch::channel(RttInfo { average: None });
        let sender = Arc::new(sender);

        let monitor = Self {
            address,
            connection: None,
            topology,
            client_options,
            connection_establisher,
            sender: sender.clone(),
        };

        let handle = RttMonitorHandle {
            reset_sender: sender,
            rtt_receiver,
        };
        (monitor, handle)
    }

    /// Measurement loop: periodically performs a hello exchange (or, on the first
    /// iteration / after a failure, a full connection handshake) and publishes the
    /// elapsed time as an RTT sample.
    async fn execute(mut self) {
        // keep executing until either the topology is closed or server monitor is done (i.e. the
        // sender is closed)
        while self.topology.is_alive() && !self.sender.is_closed() {
            let timeout = self
                .client_options
                .connect_timeout
                .unwrap_or(DEFAULT_CONNECT_TIMEOUT);

            let perform_check = async {
                match self.connection {
                    Some(ref mut conn) => {
                        // NOTE(review): the timed operation here is `send_message`;
                        // presumably it completes the full command round trip — confirm
                        // against the `Connection` implementation.
                        let command = hello_command(
                            self.client_options.server_api.as_ref(),
                            self.client_options.load_balanced,
                            Some(conn.stream_description()?.hello_ok),
                            None,
                        );
                        conn.send_message(command).await?;
                    }
                    None => {
                        // No connection yet (first iteration or after a failed
                        // check): the handshake itself serves as the timed operation.
                        let connection = self
                            .connection_establisher
                            .establish_monitoring_connection(
                                self.address.clone(),
                                next_monitoring_connection_id(),
                            )
                            .await?
                            .0;
                        self.connection = Some(connection);
                    }
                };
                Result::Ok(())
            };

            let start = Instant::now();
            let check_succeded = tokio::select! {
                r = perform_check => r.is_ok(),
                _ = tokio::time::sleep(timeout) => {
                    false
                }
            };

            if check_succeded {
                self.sender
                    .send_modify(|rtt_info| rtt_info.add_sample(start.elapsed()));
            } else {
                // From the SDAM spec: "Errors encountered when running a hello or legacy hello
                // command MUST NOT update the topology."
                self.connection = None;

                // Also from the SDAM spec: "Don't call reset() here. The Monitor thread is
                // responsible for resetting the average RTT."
            }

            tokio::time::sleep(
                self.client_options
                    .heartbeat_freq
                    .unwrap_or(DEFAULT_HEARTBEAT_FREQUENCY),
            )
            .await;
        }
    }
}
500
/// Handle used by the server monitor to read, seed, and reset the RTT average
/// published by the `RttMonitor`.
struct RttMonitorHandle {
    /// Receives the latest `RttInfo` published by the RTT monitor.
    rtt_receiver: watch::Receiver<RttInfo>,
    /// Shared sender used to reset the average or fold in samples measured
    /// outside the RTT monitor task (e.g. handshake durations).
    reset_sender: Arc<watch::Sender<RttInfo>>,
}
505
impl RttMonitorHandle {
    /// The latest known average RTT, if any samples have been recorded.
    fn average_rtt(&self) -> Option<Duration> {
        self.rtt_receiver.borrow().average
    }

    /// Clears the published average (used after the monitoring connection errors
    /// or a check is cancelled).
    fn reset_average_rtt(&mut self) {
        let _ = self.reset_sender.send(RttInfo::default());
    }

    /// Folds an externally-measured sample (e.g. a handshake duration) into the
    /// published average.
    fn add_sample(&mut self, sample: Duration) {
        self.reset_sender.send_modify(|rtt_info| {
            rtt_info.add_sample(sample);
        });
    }
}
521
/// Outcome of a single monitor hello exchange.
#[allow(clippy::large_enum_variant)] // The Ok branch is bigger but more common
#[derive(Debug, Clone)]
enum HelloResult {
    /// The server responded with a hello reply.
    Ok(HelloReply),
    /// The exchange failed (e.g. a network error or timeout).
    Err(Error),
    /// The exchange was cancelled, with the reason it was cancelled.
    Cancelled { reason: Error },
}
529
/// Struct used to keep a monitor alive, individually request an immediate check, and to cancel
/// in-progress checks. Cloneable; all clones share the same underlying channels.
#[derive(Debug, Clone)]
pub(crate) struct MonitorManager {
    /// `WorkerHandle` used to keep the monitor alive. When this is dropped, the monitor will exit.
    handle: WorkerHandle,

    /// Sender used to cancel in-progress monitor checks and, if the reason is TopologyClosed,
    /// close the monitor.
    cancellation_sender: Arc<watch::Sender<CancellationReason>>,

    /// Sender used to individually request an immediate check from the monitor associated with
    /// this manager.
    check_requester: Arc<watch::Sender<()>>,
}
545
546impl MonitorManager {
547    pub(crate) fn new(monitor_handle: WorkerHandle) -> Self {
548        // The CancellationReason used as the initial value is just a placeholder. The only receiver
549        // that could have seen it is dropped in this scope, and the monitor's receiver will
550        // never observe it.
551        let (tx, _) = watch::channel(CancellationReason::ServerClosed);
552        let check_requester = Arc::new(watch::channel(()).0);
553
554        MonitorManager {
555            handle: monitor_handle,
556            cancellation_sender: Arc::new(tx),
557            check_requester,
558        }
559    }
560
561    /// Cancel any in progress checks, notify the monitor that it should close, and wait for it to
562    /// do so.
563    pub(crate) async fn close_monitor(self) {
564        drop(self.handle);
565        let _ = self
566            .cancellation_sender
567            .send(CancellationReason::ServerClosed);
568        self.cancellation_sender.closed().await;
569    }
570
571    /// Cancel any in progress check with the provided reason.
572    pub(crate) fn cancel_in_progress_check(&mut self, reason: impl Into<CancellationReason>) {
573        let _ = self.cancellation_sender.send(reason.into());
574    }
575
576    /// Request an immediate topology check by this monitor. If the monitor is currently performing
577    /// a check, this request will be ignored.
578    pub(crate) fn request_immediate_check(&mut self) {
579        let _ = self.check_requester.send(());
580    }
581}
582
/// Struct used to receive cancellation and immediate check requests from various different places.
/// Owned by the `Monitor`; constructed from the corresponding `MonitorManager`.
pub(crate) struct MonitorRequestReceiver {
    /// Handle listener used to determine whether this monitor should continue to execute or not.
    /// The `MonitorManager` owned by the `TopologyWorker` owns the handle that this listener
    /// corresponds to.
    handle_listener: WorkerHandleListener,

    /// Receiver for cancellation requests. These come in when an operation encounters network
    /// errors or when the topology is closed.
    cancellation_receiver: watch::Receiver<CancellationReason>,

    /// Receiver used to listen for immediate check requests sent by the `TopologyWorker` that only
    /// apply to the server associated with the monitor, not for the whole topology.
    individual_check_request_receiver: watch::Receiver<()>,

    /// Receiver used to listen for immediate check requests that were broadcast to the entire
    /// topology by operations attempting to select a server.
    topology_check_request_receiver: TopologyCheckRequestReceiver,
}
602
/// The reason an in-progress monitor check was cancelled.
#[derive(Debug, Clone)]
pub(crate) enum CancellationReason {
    /// An error (e.g. a network error encountered by an operation) triggered the
    /// cancellation.
    Error(Error),
    /// The server was removed from the topology or the topology was closed.
    ServerClosed,
}
608
impl From<Error> for CancellationReason {
    /// Wraps an error as a cancellation reason.
    fn from(e: Error) -> Self {
        Self::Error(e)
    }
}
614
impl MonitorRequestReceiver {
    /// Creates a receiver subscribed to the given manager's cancellation and
    /// individual-check channels, bundled with the topology-wide check request
    /// receiver and the server's handle listener.
    pub(crate) fn new(
        manager: &MonitorManager,
        topology_check_request_receiver: TopologyCheckRequestReceiver,
        handle_listener: WorkerHandleListener,
    ) -> Self {
        Self {
            handle_listener,
            cancellation_receiver: manager.cancellation_sender.subscribe(),
            individual_check_request_receiver: manager.check_requester.subscribe(),
            topology_check_request_receiver,
        }
    }

    /// Wait for a request to cancel the current in-progress check to come in, returning the reason
    /// for it. Any check requests that are received during this time will be ignored, as per
    /// the spec.
    async fn wait_for_cancellation(&mut self) -> CancellationReason {
        // `changed()` errors only when all senders have been dropped, which is
        // treated as the server having been closed.
        let err = if self.cancellation_receiver.changed().await.is_ok() {
            self.cancellation_receiver.borrow().clone()
        } else {
            CancellationReason::ServerClosed
        };
        // clear out ignored check requests
        self.individual_check_request_receiver.borrow_and_update();
        err
    }

    /// Wait for a request to immediately check the server to be received, guarded by the provided
    /// timeout. If the server associated with this monitor is removed from the topology, this
    /// method will return.
    ///
    /// The `delay` parameter indicates how long this method should wait before listening to
    /// requests. The time spent in the delay counts toward the provided timeout.
    async fn wait_for_check_request(&mut self, delay: Duration, timeout: Duration) {
        let _ = runtime::timeout(timeout, async {
            // Topology-wide check requests are only honored after the delay has
            // elapsed; individual requests and handle drops take effect immediately.
            let wait_for_check_request = async {
                tokio::time::sleep(delay).await;
                self.topology_check_request_receiver
                    .wait_for_check_request()
                    .await;
            };
            tokio::pin!(wait_for_check_request);

            tokio::select! {
                _ = self.individual_check_request_receiver.changed() => (),
                _ = &mut wait_for_check_request => (),
                // Don't continue waiting after server has been removed from the topology.
                _ = self.handle_listener.wait_for_all_handle_drops() => (),
            }
        })
        .await;

        // clear out ignored cancellation requests while we were waiting to begin a check
        self.cancellation_receiver.borrow_and_update();
    }

    /// Whether the server associated with this monitor is still part of the
    /// topology (i.e. at least one handle for it is still alive).
    fn is_alive(&self) -> bool {
        self.handle_listener.is_alive()
    }
}