mongodb/
client.rs

1pub mod action;
2pub mod auth;
3#[cfg(feature = "in-use-encryption")]
4pub(crate) mod csfle;
5mod executor;
6pub mod options;
7pub mod session;
8
9use std::{
10    sync::{
11        atomic::{AtomicBool, Ordering},
12        Mutex as SyncMutex,
13    },
14    time::{Duration, Instant},
15};
16
17#[cfg(feature = "in-use-encryption")]
18pub use self::csfle::client_builder::*;
19use derive_where::derive_where;
20use futures_core::Future;
21use futures_util::FutureExt;
22
23#[cfg(feature = "tracing-unstable")]
24use crate::trace::{
25    command::CommandTracingEventEmitter,
26    server_selection::ServerSelectionTracingEventEmitter,
27    trace_or_log_enabled,
28    TracingOrLogLevel,
29    COMMAND_TRACING_EVENT_TARGET,
30};
31use crate::{
32    bson::doc,
33    concern::{ReadConcern, WriteConcern},
34    db::Database,
35    error::{Error, ErrorKind, Result},
36    event::command::CommandEvent,
37    id_set::IdSet,
38    options::{ClientOptions, DatabaseOptions, ReadPreference, SelectionCriteria, ServerAddress},
39    sdam::{
40        server_selection::{self, attempt_to_select_server},
41        SelectedServer,
42        Topology,
43    },
44    tracking_arc::TrackingArc,
45    BoxFuture,
46    ClientSession,
47};
48
49pub(crate) use executor::{HELLO_COMMAND_NAMES, REDACTED_COMMANDS};
50pub(crate) use session::{ClusterTime, SESSIONS_UNSUPPORTED_COMMANDS};
51
52use session::{ServerSession, ServerSessionPool};
53
/// Server selection timeout applied when `ClientOptions::server_selection_timeout` is unset.
const DEFAULT_SERVER_SELECTION_TIMEOUT: Duration = Duration::new(30, 0);
55
56/// This is the main entry point for the API. A `Client` is used to connect to a MongoDB cluster.
57/// By default, it will monitor the topology of the cluster, keeping track of any changes, such
58/// as servers being added or removed.
59///
60/// `Client` uses [`std::sync::Arc`](https://doc.rust-lang.org/std/sync/struct.Arc.html) internally,
61/// so it can safely be shared across threads or async tasks. For example:
62///
63/// ```rust
64/// # use mongodb::{bson::Document, Client, error::Result};
65/// #
66/// # async fn start_workers() -> Result<()> {
67/// let client = Client::with_uri_str("mongodb://example.com").await?;
68///
69/// for i in 0..5 {
70///     let client_ref = client.clone();
71///
72///     tokio::task::spawn(async move {
73///         let collection = client_ref.database("items").collection::<Document>(&format!("coll{}", i));
74///
75///         // Do something with the collection
76///     });
77/// }
78/// #
79/// # Ok(())
80/// # }
81/// ```
82/// ## Notes on performance
83/// Spawning many asynchronous tasks that use the driver concurrently like this is often the best
84/// way to achieve maximum performance, as the driver is designed to work well in such situations.
85///
86/// Additionally, using a custom Rust type that implements `Serialize` and `Deserialize` as the
87/// generic parameter of [`Collection`](../struct.Collection.html) instead of [`bson::Document`] can
88/// reduce the amount of time the driver and your application spends serializing and deserializing
89/// BSON, which can also lead to increased performance.
90///
91/// ## TCP Keepalive
92/// TCP keepalive is enabled by default with ``tcp_keepalive_time`` set to 120 seconds. The
93/// driver does not set ``tcp_keepalive_intvl``. See the
94/// [MongoDB Diagnostics FAQ keepalive section](https://www.mongodb.com/docs/manual/faq/diagnostics/#does-tcp-keepalive-time-affect-mongodb-deployments)
95/// for instructions on setting these values at the system level.
96///
97/// ## Clean shutdown
98/// Because Rust has no async equivalent of `Drop`, values that require server-side cleanup when
99/// dropped spawn a new async task to perform that cleanup.  This can cause two potential issues:
100///
101/// * Drop tasks pending or in progress when the async runtime shuts down may not complete, causing
102///   server-side resources to not be freed.
103/// * Drop tasks may run at an arbitrary time even after no `Client` values exist, making it hard to
104///   reason about associated resources (e.g. event handlers).
105///
106/// To address these issues, we highly recommend you use [`Client::shutdown`] in the termination
107/// path of your application.  This will ensure that outstanding resources have been cleaned up and
108/// terminate internal worker tasks before returning.  Please note that `shutdown` will wait for
109/// _all_ outstanding resource handles to be dropped, so they must either have been dropped before
110/// calling `shutdown` or in a concurrent task; see the documentation of `shutdown` for more
111/// details.
112#[derive(Debug, Clone)]
113pub struct Client {
114    inner: TrackingArc<ClientInner>,
115}
116
117#[allow(dead_code, unreachable_code, clippy::diverging_sub_expression)]
118const _: fn() = || {
119    fn assert_send<T: Send>(_t: T) {}
120    fn assert_sync<T: Sync>(_t: T) {}
121
122    let _c: super::Client = todo!();
123    assert_send(_c);
124    assert_sync(_c);
125};
126
127#[derive(Debug)]
128struct ClientInner {
129    topology: Topology,
130    options: ClientOptions,
131    session_pool: ServerSessionPool,
132    shutdown: Shutdown,
133    dropped: AtomicBool,
134    end_sessions_token: std::sync::Mutex<AsyncDropToken>,
135    #[cfg(feature = "in-use-encryption")]
136    csfle: tokio::sync::RwLock<Option<csfle::ClientState>>,
137    #[cfg(test)]
138    disable_command_events: AtomicBool,
139}
140
/// Shutdown-related state kept on `ClientInner`.
#[derive(Debug)]
struct Shutdown {
    /// Join handles of the tasks spawned by `Client::register_async_drop`; each task removes
    /// its own entry when it finishes.
    pending_drops: SyncMutex<IdSet<crate::runtime::AsyncJoinHandle<()>>>,
    /// Checked by `Drop for Client` to skip session cleanup.
    // NOTE(review): presumably set by `Client::shutdown`, which is not in this file — confirm.
    executed: AtomicBool,
}
146
147impl Client {
148    /// Creates a new `Client` connected to the cluster specified by `uri`. `uri` must be a valid
149    /// MongoDB connection string.
150    ///
151    /// See the documentation on
152    /// [`ClientOptions::parse`](options/struct.ClientOptions.html#method.parse) for more details.
153    pub async fn with_uri_str(uri: impl AsRef<str>) -> Result<Self> {
154        let options = ClientOptions::parse(uri.as_ref()).await?;
155
156        Client::with_options(options)
157    }
158
159    /// Creates a new `Client` connected to the cluster specified by `options`.
160    pub fn with_options(options: ClientOptions) -> Result<Self> {
161        options.validate()?;
162
163        // Spawn a cleanup task, similar to register_async_drop
164        let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
165        crate::runtime::spawn(async move {
166            // If the cleanup channel is closed, that task was dropped.
167            if let Ok(cleanup) = cleanup_rx.await {
168                cleanup.await;
169            }
170        });
171        let end_sessions_token = std::sync::Mutex::new(AsyncDropToken {
172            tx: Some(cleanup_tx),
173        });
174
175        let inner = TrackingArc::new(ClientInner {
176            topology: Topology::new(options.clone())?,
177            session_pool: ServerSessionPool::new(),
178            options,
179            shutdown: Shutdown {
180                pending_drops: SyncMutex::new(IdSet::new()),
181                executed: AtomicBool::new(false),
182            },
183            dropped: AtomicBool::new(false),
184            end_sessions_token,
185            #[cfg(feature = "in-use-encryption")]
186            csfle: Default::default(),
187            #[cfg(test)]
188            disable_command_events: AtomicBool::new(false),
189        });
190        Ok(Self { inner })
191    }
192
193    /// Return an `EncryptedClientBuilder` for constructing a `Client` with auto-encryption enabled.
194    ///
195    /// ```no_run
196    /// # use bson::doc;
197    /// # use mongocrypt::ctx::KmsProvider;
198    /// # use mongodb::Client;
199    /// # use mongodb::error::Result;
200    /// # async fn func() -> Result<()> {
201    /// # let client_options = todo!();
202    /// # let key_vault_namespace = todo!();
203    /// # let key_vault_client: Client = todo!();
204    /// # let local_key: bson::Binary = todo!();
205    /// let encrypted_client = Client::encrypted_builder(
206    ///     client_options,
207    ///     key_vault_namespace,
208    ///     [(KmsProvider::Local, doc! { "key": local_key }, None)],
209    /// )?
210    /// .key_vault_client(key_vault_client)
211    /// .build()
212    /// .await?;
213    /// # Ok(())
214    /// # }
215    /// ```
216    #[cfg(feature = "in-use-encryption")]
217    pub fn encrypted_builder(
218        client_options: ClientOptions,
219        key_vault_namespace: crate::Namespace,
220        kms_providers: impl IntoIterator<
221            Item = (
222                mongocrypt::ctx::KmsProvider,
223                bson::Document,
224                Option<options::TlsOptions>,
225            ),
226        >,
227    ) -> Result<EncryptedClientBuilder> {
228        Ok(EncryptedClientBuilder::new(
229            client_options,
230            csfle::options::AutoEncryptionOptions::new(
231                key_vault_namespace,
232                csfle::options::KmsProviders::new(kms_providers)?,
233            ),
234        ))
235    }
236
237    /// Whether commands sent via this client should be auto-encrypted.
238    pub(crate) async fn should_auto_encrypt(&self) -> bool {
239        #[cfg(feature = "in-use-encryption")]
240        {
241            let csfle = self.inner.csfle.read().await;
242            match *csfle {
243                Some(ref csfle) => csfle
244                    .opts()
245                    .bypass_auto_encryption
246                    .map(|b| !b)
247                    .unwrap_or(true),
248                None => false,
249            }
250        }
251        #[cfg(not(feature = "in-use-encryption"))]
252        {
253            false
254        }
255    }
256
257    #[cfg(all(test, feature = "in-use-encryption"))]
258    pub(crate) async fn mongocryptd_spawned(&self) -> bool {
259        self.inner
260            .csfle
261            .read()
262            .await
263            .as_ref()
264            .is_some_and(|cs| cs.exec().mongocryptd_spawned())
265    }
266
267    #[cfg(all(test, feature = "in-use-encryption"))]
268    pub(crate) async fn has_mongocryptd_client(&self) -> bool {
269        self.inner
270            .csfle
271            .read()
272            .await
273            .as_ref()
274            .is_some_and(|cs| cs.exec().has_mongocryptd_client())
275    }
276
277    fn test_command_event_channel(&self) -> Option<&options::TestEventSender> {
278        #[cfg(test)]
279        {
280            self.inner
281                .options
282                .test_options
283                .as_ref()
284                .and_then(|t| t.async_event_listener.as_ref())
285        }
286        #[cfg(not(test))]
287        {
288            None
289        }
290    }
291
292    pub(crate) async fn emit_command_event(&self, generate_event: impl FnOnce() -> CommandEvent) {
293        #[cfg(test)]
294        if self
295            .inner
296            .disable_command_events
297            .load(std::sync::atomic::Ordering::SeqCst)
298        {
299            return;
300        }
301        #[cfg(feature = "tracing-unstable")]
302        let tracing_emitter = if trace_or_log_enabled!(
303            target: COMMAND_TRACING_EVENT_TARGET,
304            TracingOrLogLevel::Debug
305        ) {
306            Some(CommandTracingEventEmitter::new(
307                self.inner.options.tracing_max_document_length_bytes,
308                self.inner.topology.id,
309            ))
310        } else {
311            None
312        };
313        let test_channel = self.test_command_event_channel();
314        let should_send = test_channel.is_some() || self.options().command_event_handler.is_some();
315        #[cfg(feature = "tracing-unstable")]
316        let should_send = should_send || tracing_emitter.is_some();
317        if !should_send {
318            return;
319        }
320
321        let event = generate_event();
322        if let Some(tx) = test_channel {
323            let (msg, ack) = crate::runtime::AcknowledgedMessage::package(event.clone());
324            let _ = tx.send(msg).await;
325            ack.wait_for_acknowledgment().await;
326        }
327        #[cfg(feature = "tracing-unstable")]
328        if let Some(ref tracing_emitter) = tracing_emitter {
329            tracing_emitter.handle(event.clone());
330        }
331        if let Some(handler) = &self.options().command_event_handler {
332            handler.handle(event);
333        }
334    }
335
336    /// Gets the default selection criteria the `Client` uses for operations..
337    pub fn selection_criteria(&self) -> Option<&SelectionCriteria> {
338        self.inner.options.selection_criteria.as_ref()
339    }
340
341    /// Gets the default read concern the `Client` uses for operations.
342    pub fn read_concern(&self) -> Option<&ReadConcern> {
343        self.inner.options.read_concern.as_ref()
344    }
345
346    /// Gets the default write concern the `Client` uses for operations.
347    pub fn write_concern(&self) -> Option<&WriteConcern> {
348        self.inner.options.write_concern.as_ref()
349    }
350
351    /// Gets a handle to a database specified by `name` in the cluster the `Client` is connected to.
352    /// The `Database` options (e.g. read preference and write concern) will default to those of the
353    /// `Client`.
354    ///
355    /// This method does not send or receive anything across the wire to the database, so it can be
356    /// used repeatedly without incurring any costs from I/O.
357    pub fn database(&self, name: &str) -> Database {
358        Database::new(self.clone(), name, None)
359    }
360
361    /// Gets a handle to a database specified by `name` in the cluster the `Client` is connected to.
362    /// Operations done with this `Database` will use the options specified by `options` by default
363    /// and will otherwise default to those of the `Client`.
364    ///
365    /// This method does not send or receive anything across the wire to the database, so it can be
366    /// used repeatedly without incurring any costs from I/O.
367    pub fn database_with_options(&self, name: &str, options: DatabaseOptions) -> Database {
368        Database::new(self.clone(), name, Some(options))
369    }
370
371    /// Gets a handle to the default database specified in the `ClientOptions` or MongoDB connection
372    /// string used to construct this `Client`.
373    ///
374    /// If no default database was specified, `None` will be returned.
375    pub fn default_database(&self) -> Option<Database> {
376        self.inner
377            .options
378            .default_database
379            .as_ref()
380            .map(|db_name| self.database(db_name))
381    }
382
383    pub(crate) fn register_async_drop(&self) -> AsyncDropToken {
384        let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
385        let (id_tx, id_rx) = tokio::sync::oneshot::channel::<crate::id_set::Id>();
386        let weak = self.weak();
387        let handle = crate::runtime::spawn(async move {
388            // Unwrap safety: the id is sent immediately after task creation, with no
389            // await points in between.
390            let id = id_rx.await.unwrap();
391            // If the cleanup channel is closed, that task was dropped.
392            if let Ok(cleanup) = cleanup_rx.await {
393                cleanup.await;
394            }
395            if let Some(client) = weak.upgrade() {
396                client
397                    .inner
398                    .shutdown
399                    .pending_drops
400                    .lock()
401                    .unwrap()
402                    .remove(&id);
403            }
404        });
405        let id = self
406            .inner
407            .shutdown
408            .pending_drops
409            .lock()
410            .unwrap()
411            .insert(handle);
412        let _ = id_tx.send(id);
413        AsyncDropToken {
414            tx: Some(cleanup_tx),
415        }
416    }
417
418    /// Check in a server session to the server session pool. The session will be discarded if it is
419    /// expired or dirty.
420    pub(crate) async fn check_in_server_session(&self, session: ServerSession) {
421        let timeout = self.inner.topology.logical_session_timeout();
422        self.inner.session_pool.check_in(session, timeout).await;
423    }
424
    /// Test-only: calls `clear` on this client's server session pool.
    #[cfg(test)]
    pub(crate) async fn clear_session_pool(&self) {
        self.inner.session_pool.clear().await;
    }
429
    /// Test-only: reports whether the server session pool's `contains` check finds `id`.
    #[cfg(test)]
    pub(crate) async fn is_session_checked_in(&self, id: &bson::Document) -> bool {
        self.inner.session_pool.contains(id).await
    }
434
435    #[cfg(test)]
436    pub(crate) fn disable_command_events(&self, disable: bool) {
437        self.inner
438            .disable_command_events
439            .store(disable, std::sync::atomic::Ordering::SeqCst);
440    }
441
442    /// Get the address of the server selected according to the given criteria.
443    /// This method is only used in tests.
444    #[cfg(test)]
445    pub(crate) async fn test_select_server(
446        &self,
447        criteria: Option<&SelectionCriteria>,
448    ) -> Result<ServerAddress> {
449        let server = self
450            .select_server(criteria, "Test select server", None)
451            .await?;
452        Ok(server.address.clone())
453    }
454
455    /// Select a server using the provided criteria. If none is provided, a primary read preference
456    /// will be used instead.
457    async fn select_server(
458        &self,
459        criteria: Option<&SelectionCriteria>,
460        #[allow(unused_variables)] // we only use the operation_name for tracing.
461        operation_name: &str,
462        deprioritized: Option<&ServerAddress>,
463    ) -> Result<SelectedServer> {
464        let criteria =
465            criteria.unwrap_or(&SelectionCriteria::ReadPreference(ReadPreference::Primary));
466
467        let start_time = Instant::now();
468        let timeout = self
469            .inner
470            .options
471            .server_selection_timeout
472            .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
473
474        #[cfg(feature = "tracing-unstable")]
475        let event_emitter = ServerSelectionTracingEventEmitter::new(
476            self.inner.topology.id,
477            criteria,
478            operation_name,
479            start_time,
480            timeout,
481        );
482        #[cfg(feature = "tracing-unstable")]
483        event_emitter.emit_started_event(self.inner.topology.watch().observe_latest().description);
484        // We only want to emit this message once per operation at most.
485        #[cfg(feature = "tracing-unstable")]
486        let mut emitted_waiting_message = false;
487
488        let mut watcher = self.inner.topology.watch();
489        loop {
490            let state = watcher.observe_latest();
491
492            let result = server_selection::attempt_to_select_server(
493                criteria,
494                &state.description,
495                &state.servers(),
496                deprioritized,
497            );
498            match result {
499                Err(error) => {
500                    #[cfg(feature = "tracing-unstable")]
501                    event_emitter.emit_failed_event(&state.description, &error);
502
503                    return Err(error);
504                }
505                Ok(result) => {
506                    if let Some(server) = result {
507                        #[cfg(feature = "tracing-unstable")]
508                        event_emitter.emit_succeeded_event(&state.description, &server);
509
510                        return Ok(server);
511                    } else {
512                        #[cfg(feature = "tracing-unstable")]
513                        if !emitted_waiting_message {
514                            event_emitter.emit_waiting_event(&state.description);
515                            emitted_waiting_message = true;
516                        }
517
518                        watcher.request_immediate_check();
519
520                        let change_occurred = start_time.elapsed() < timeout
521                            && watcher
522                                .wait_for_update(timeout - start_time.elapsed())
523                                .await;
524                        if !change_occurred {
525                            let error: Error = ErrorKind::ServerSelection {
526                                message: state
527                                    .description
528                                    .server_selection_timeout_error_message(criteria),
529                            }
530                            .into();
531
532                            #[cfg(feature = "tracing-unstable")]
533                            event_emitter.emit_failed_event(&state.description, &error);
534
535                            return Err(error);
536                        }
537                    }
538                }
539            }
540        }
541    }
542
543    #[cfg(all(test, feature = "dns-resolver"))]
544    pub(crate) fn get_hosts(&self) -> Vec<String> {
545        let watcher = self.inner.topology.watch();
546        let state = watcher.peek_latest();
547
548        state
549            .servers()
550            .keys()
551            .map(|stream_address| format!("{}", stream_address))
552            .collect()
553    }
554
    /// Test-only: awaits the topology's `sync_workers`.
    #[cfg(test)]
    pub(crate) async fn sync_workers(&self) {
        self.inner.topology.sync_workers().await;
    }
559
560    #[cfg(test)]
561    pub(crate) fn topology_description(&self) -> crate::sdam::TopologyDescription {
562        self.inner
563            .topology
564            .watch()
565            .peek_latest()
566            .description
567            .clone()
568    }
569
    /// Test-only: borrows the client's `Topology`.
    #[cfg(test)]
    pub(crate) fn topology(&self) -> &Topology {
        &self.inner.topology
    }
574
575    #[cfg(feature = "in-use-encryption")]
576    pub(crate) async fn primary_description(&self) -> Option<crate::sdam::ServerDescription> {
577        let start_time = Instant::now();
578        let timeout = self
579            .inner
580            .options
581            .server_selection_timeout
582            .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
583        let mut watcher = self.inner.topology.watch();
584        loop {
585            let topology = watcher.observe_latest();
586            if let Some(desc) = topology.description.primary() {
587                return Some(desc.clone());
588            }
589            if !watcher
590                .wait_for_update(timeout - start_time.elapsed())
591                .await
592            {
593                return None;
594            }
595        }
596    }
597
598    pub(crate) fn weak(&self) -> WeakClient {
599        WeakClient {
600            inner: TrackingArc::downgrade(&self.inner),
601        }
602    }
603
604    #[cfg(feature = "in-use-encryption")]
605    pub(crate) async fn auto_encryption_opts(
606        &self,
607    ) -> Option<tokio::sync::RwLockReadGuard<'_, csfle::options::AutoEncryptionOptions>> {
608        tokio::sync::RwLockReadGuard::try_map(self.inner.csfle.read().await, |csfle| {
609            csfle.as_ref().map(|cs| cs.opts())
610        })
611        .ok()
612    }
613
    /// Borrows the `ClientOptions` this client was constructed with.
    pub(crate) fn options(&self) -> &ClientOptions {
        &self.inner.options
    }
617
618    /// Ends all sessions contained in this client's session pool on the server.
619    pub(crate) async fn end_all_sessions(&self) {
620        // The maximum number of session IDs that should be sent in a single endSessions command.
621        const MAX_END_SESSIONS_BATCH_SIZE: usize = 10_000;
622
623        let mut watcher = self.inner.topology.watch();
624        let selection_criteria =
625            SelectionCriteria::from(ReadPreference::PrimaryPreferred { options: None });
626
627        let session_ids = self.inner.session_pool.get_session_ids().await;
628        for chunk in session_ids.chunks(MAX_END_SESSIONS_BATCH_SIZE) {
629            let state = watcher.observe_latest();
630            let Ok(Some(_)) = attempt_to_select_server(
631                &selection_criteria,
632                &state.description,
633                &state.servers(),
634                None,
635            ) else {
636                // If a suitable server is not available, do not proceed with the operation to avoid
637                // spinning for server_selection_timeout.
638                return;
639            };
640
641            let end_sessions = doc! {
642                "endSessions": chunk,
643            };
644            let _ = self
645                .database("admin")
646                .run_command(end_sessions)
647                .selection_criteria(selection_criteria.clone())
648                .await;
649        }
650    }
651}
652
653#[derive(Clone, Debug)]
654pub(crate) struct WeakClient {
655    inner: crate::tracking_arc::Weak<ClientInner>,
656}
657
658impl WeakClient {
659    pub(crate) fn upgrade(&self) -> Option<Client> {
660        self.inner.upgrade().map(|inner| Client { inner })
661    }
662}
663
/// Single-use handle for submitting one cleanup future to a task that was spawned ahead of
/// time (see `Client::register_async_drop`).
#[derive_where(Debug)]
pub(crate) struct AsyncDropToken {
    /// Sender for the cleanup future; `None` once the token has been used or taken.
    #[derive_where(skip)]
    tx: Option<tokio::sync::oneshot::Sender<BoxFuture<'static, ()>>>,
}
669
670impl AsyncDropToken {
671    pub(crate) fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
672        if let Some(tx) = self.tx.take() {
673            let _ = tx.send(fut.boxed());
674        } else {
675            #[cfg(debug_assertions)]
676            panic!("exhausted AsyncDropToken");
677        }
678    }
679
680    pub(crate) fn take(&mut self) -> Self {
681        Self { tx: self.tx.take() }
682    }
683}
684
685impl Drop for Client {
686    fn drop(&mut self) {
687        if !self.inner.shutdown.executed.load(Ordering::SeqCst)
688            && !self.inner.dropped.load(Ordering::SeqCst)
689            && TrackingArc::strong_count(&self.inner) == 1
690        {
691            // We need an owned copy of the client to move into the spawned future. However, if this
692            // call to drop completes before the spawned future completes, the number of strong
693            // references to the inner client will again be 1 when the cloned client drops, and thus
694            // end_all_sessions will be called continuously until the runtime shuts down. Storing a
695            // flag indicating whether end_all_sessions has already been called breaks
696            // this cycle.
697            self.inner.dropped.store(true, Ordering::SeqCst);
698            let client = self.clone();
699            self.inner
700                .end_sessions_token
701                .lock()
702                .unwrap()
703                .spawn(async move {
704                    client.end_all_sessions().await;
705                });
706        }
707    }
708}