Skip to main content

mongodb/
client.rs

1pub mod action;
2pub mod auth;
3#[cfg(feature = "in-use-encryption")]
4pub(crate) mod csfle;
5mod executor;
6pub mod options;
7pub mod session;
8
9use std::{
10    collections::HashSet,
11    sync::atomic::{AtomicBool, Ordering},
12    time::{Duration, Instant},
13};
14
15#[cfg(feature = "in-use-encryption")]
16pub use self::csfle::client_builder::*;
17use derive_where::derive_where;
18use futures_core::Future;
19use futures_util::FutureExt;
20use tokio::sync::Mutex;
21use tokio_util::task::TaskTracker;
22
23#[cfg(feature = "tracing-unstable")]
24use crate::trace::{
25    command::CommandTracingEventEmitter,
26    server_selection::ServerSelectionTracingEventEmitter,
27    trace_or_log_enabled,
28    TracingOrLogLevel,
29    COMMAND_TRACING_EVENT_TARGET,
30};
31use crate::{
32    bson::doc,
33    concern::{ReadConcern, WriteConcern},
34    db::Database,
35    error::{Error, ErrorKind, Result},
36    event::command::CommandEvent,
37    operation::OverrideCriteriaFn,
38    options::{
39        ClientOptions,
40        DatabaseOptions,
41        DriverInfo,
42        ReadPreference,
43        SelectionCriteria,
44        ServerAddress,
45    },
46    sdam::{
47        server_selection::{self, attempt_to_select_server},
48        SelectedServer,
49        Topology,
50    },
51    tracking_arc::TrackingArc,
52    BoxFuture,
53    ClientSession,
54    TopologyType,
55};
56
57pub(crate) use executor::{retry::Retry, HELLO_COMMAND_NAMES, REDACTED_COMMANDS};
58pub(crate) use session::{ClusterTime, SESSIONS_UNSUPPORTED_COMMANDS};
59
60use session::{ServerSession, ServerSessionPool};
61
/// Server selection timeout used when `ClientOptions::server_selection_timeout` is not set.
const DEFAULT_SERVER_SELECTION_TIMEOUT: Duration = Duration::from_millis(30_000);
63
/// This is the main entry point for the API. A `Client` is used to connect to a MongoDB cluster.
/// By default, it will monitor the topology of the cluster, keeping track of any changes, such
/// as servers being added or removed.
///
/// `Client` uses [`std::sync::Arc`](https://doc.rust-lang.org/std/sync/struct.Arc.html) internally,
/// so it can safely be shared across threads or async tasks. For example:
///
/// ```rust
/// # use mongodb::{bson::Document, Client, error::Result};
/// #
/// # async fn start_workers() -> Result<()> {
/// let client = Client::with_uri_str("mongodb://example.com").await?;
///
/// for i in 0..5 {
///     let client_ref = client.clone();
///
///     tokio::task::spawn(async move {
///         let collection = client_ref.database("items").collection::<Document>(&format!("coll{}", i));
///
///         // Do something with the collection
///     });
/// }
/// #
/// # Ok(())
/// # }
/// ```
/// ## Notes on performance
/// Spawning many asynchronous tasks that use the driver concurrently like this is often the best
/// way to achieve maximum performance, as the driver is designed to work well in such situations.
///
/// Additionally, using a custom Rust type that implements `Serialize` and `Deserialize` as the
/// generic parameter of [`Collection`](../struct.Collection.html) instead of
/// [`Document`](crate::bson::Document) can reduce the amount of time the driver and your
/// application spend serializing and deserializing BSON, which can also lead to increased
/// performance.
///
/// ## TCP Keepalive
/// TCP keepalive is enabled by default with ``tcp_keepalive_time`` set to 120 seconds. The
/// driver does not set ``tcp_keepalive_intvl``. See the
/// [MongoDB Diagnostics FAQ keepalive section](https://www.mongodb.com/docs/manual/faq/diagnostics/#does-tcp-keepalive-time-affect-mongodb-deployments)
/// for instructions on setting these values at the system level.
///
/// ## Clean shutdown
/// Because Rust has no async equivalent of `Drop`, values that require server-side cleanup when
/// dropped spawn a new async task to perform that cleanup. This can cause two potential issues:
///
/// * Drop tasks pending or in progress when the async runtime shuts down may not complete, causing
///   server-side resources to not be freed.
/// * Drop tasks may run at an arbitrary time even after no `Client` values exist, making it hard to
///   reason about associated resources (e.g. event handlers).
///
/// To address these issues, we highly recommend you use [`Client::shutdown`] in the termination
/// path of your application. This will ensure that outstanding resources have been cleaned up and
/// terminate internal worker tasks before returning. Please note that `shutdown` will wait for
/// _all_ outstanding resource handles to be dropped, so they must either have been dropped before
/// calling `shutdown` or in a concurrent task; see the documentation of `shutdown` for more
/// details.
///
/// ## Overload Retry Behavior
/// All operations executed by a `Client` may retry if the selected server is overloaded. For
/// details on server load-shedding, see the documentation for
/// [Intelligent Workload Management](https://www.mongodb.com/docs/atlas/intelligent-workload-management/)
/// and [Overload Errors](https://www.mongodb.com/docs/atlas/overload-errors).
///
/// The following options can be configured to customize this behavior:
/// - Set [`ClientOptions::retry_reads`] to false to disable retrying all reads. Note that this will
///   also disable non-overload retries.
/// - Set [`ClientOptions::retry_writes`] to false to disable retrying all writes. Note that this
///   will also disable non-overload retries.
/// - Set [`ClientOptions::max_adaptive_retries`] to adjust the number of retries to perform when
///   overload errors are encountered.
/// - Set [`ClientOptions::enable_overload_retargeting`] to deprioritize servers on which an
///   overload error has occurred.
#[derive(Debug, Clone)]
pub struct Client {
    // Shared, reference-counted client state; every clone of a `Client` points at the same
    // `ClientInner`.
    inner: TrackingArc<ClientInner>,
}
141
142#[allow(dead_code, unreachable_code, clippy::diverging_sub_expression)]
143const _: fn() = || {
144    fn assert_send<T: Send>(_t: T) {}
145    fn assert_sync<T: Sync>(_t: T) {}
146
147    let _c: super::Client = todo!();
148    assert_send(_c);
149    assert_sync(_c);
150};
151
/// State shared by every clone of a [`Client`] (held behind a `TrackingArc`).
#[derive(Debug)]
struct ClientInner {
    /// Topology used for server selection, topology watching and handshake metadata.
    topology: Topology,
    /// Options the client was created with.
    options: ClientOptions,
    /// Pool of server sessions; sessions are checked in here and ended on the server by
    /// `Client::end_all_sessions`.
    session_pool: ServerSessionPool,
    /// Bookkeeping for clean shutdown.
    shutdown: ClientShutdown,
    /// Set by `Drop for Client` once it has scheduled `end_all_sessions`, so the clone moved into
    /// that cleanup future does not schedule it again when it is dropped.
    dropped: AtomicBool,
    /// Single-use token for the cleanup task spawned in `Client::with_options`; used by
    /// `Drop for Client` to run `end_all_sessions`.
    end_sessions_token: std::sync::Mutex<AsyncDropToken>,
    /// Retry token bucket (values scaled by 10 from the spec). `Some` only when adaptive retries
    /// are enabled in the options.
    token_bucket: Option<Mutex<u16>>,
    /// In-use encryption state; `None` when auto-encryption is not configured for this client.
    #[cfg(feature = "in-use-encryption")]
    csfle: tokio::sync::RwLock<Option<csfle::ClientState>>,
    /// Tracer obtained from the options at construction time.
    #[cfg(feature = "opentelemetry")]
    tracer: opentelemetry::global::BoxedTracer,
    /// Test-only switch that makes `Client::emit_command_event` drop all events.
    #[cfg(test)]
    disable_command_events: AtomicBool,
}
168
/// Bookkeeping used when shutting the client down.
#[derive(Debug)]
struct ClientShutdown {
    /// Tracks the cleanup tasks spawned by `Client::register_async_drop`.
    pending_drops: TaskTracker,
    /// Whether shutdown has been executed; when set, `Drop for Client` does not schedule
    /// `end_all_sessions`. NOTE(review): presumably set by `Client::shutdown`, which is not
    /// visible here — confirm.
    executed: AtomicBool,
}
174
175impl Client {
176    /// Creates a new `Client` connected to the cluster specified by `uri`. `uri` must be a valid
177    /// MongoDB connection string.
178    ///
179    /// See the documentation on
180    /// [`ClientOptions::parse`](options/struct.ClientOptions.html#method.parse) for more details.
181    pub async fn with_uri_str(uri: impl AsRef<str>) -> Result<Self> {
182        let options = ClientOptions::parse(uri.as_ref()).await?;
183
184        Client::with_options(options)
185    }
186
187    /// Creates a new `Client` connected to the cluster specified by `options`.
188    pub fn with_options(options: ClientOptions) -> Result<Self> {
189        options.validate()?;
190
191        // Spawn a cleanup task, similar to register_async_drop
192        let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
193        crate::runtime::spawn(async move {
194            // If the cleanup channel is closed, that task was dropped.
195            if let Ok(cleanup) = cleanup_rx.await {
196                cleanup.await;
197            }
198        });
199        let end_sessions_token = std::sync::Mutex::new(AsyncDropToken {
200            tx: Some(cleanup_tx),
201        });
202
203        #[cfg(feature = "opentelemetry")]
204        let tracer = options.tracer();
205
206        let token_bucket = options
207            .adaptive_retries
208            .unwrap_or(false)
209            .then(|| Mutex::new(MAX_BUCKET_CAPACITY));
210
211        let inner = TrackingArc::new(ClientInner {
212            topology: Topology::new(options.clone())?,
213            session_pool: ServerSessionPool::new(),
214            options,
215            shutdown: ClientShutdown {
216                pending_drops: TaskTracker::new(),
217                executed: AtomicBool::new(false),
218            },
219            dropped: AtomicBool::new(false),
220            end_sessions_token,
221            token_bucket,
222            #[cfg(feature = "in-use-encryption")]
223            csfle: Default::default(),
224            #[cfg(feature = "opentelemetry")]
225            tracer,
226            #[cfg(test)]
227            disable_command_events: AtomicBool::new(false),
228        });
229        Ok(Self { inner })
230    }
231
232    /// Return an `EncryptedClientBuilder` for constructing a `Client` with auto-encryption enabled.
233    ///
234    /// ```no_run
235    /// # use mongocrypt::ctx::KmsProvider;
236    /// # use mongodb::{Client, bson::{self, doc}, error::Result};
237    /// # async fn func() -> Result<()> {
238    /// # let client_options = todo!();
239    /// # let key_vault_namespace = todo!();
240    /// # let key_vault_client: Client = todo!();
241    /// # let local_key: bson::Binary = todo!();
242    /// let encrypted_client = Client::encrypted_builder(
243    ///     client_options,
244    ///     key_vault_namespace,
245    ///     [(KmsProvider::local(), doc! { "key": local_key }, None)],
246    /// )?
247    /// .key_vault_client(key_vault_client)
248    /// .build()
249    /// .await?;
250    /// # Ok(())
251    /// # }
252    /// ```
253    #[cfg(feature = "in-use-encryption")]
254    pub fn encrypted_builder(
255        client_options: ClientOptions,
256        key_vault_namespace: crate::Namespace,
257        kms_providers: impl IntoIterator<
258            Item = (
259                mongocrypt::ctx::KmsProvider,
260                crate::bson::Document,
261                Option<options::TlsOptions>,
262            ),
263        >,
264    ) -> Result<EncryptedClientBuilder> {
265        Ok(EncryptedClientBuilder::new(
266            client_options,
267            csfle::options::AutoEncryptionOptions::new(
268                key_vault_namespace,
269                csfle::options::KmsProviders::new(kms_providers)?,
270            ),
271        ))
272    }
273
274    /// Whether commands sent via this client should be auto-encrypted.
275    pub(crate) async fn should_auto_encrypt(&self) -> bool {
276        #[cfg(feature = "in-use-encryption")]
277        {
278            let csfle = self.inner.csfle.read().await;
279            match *csfle {
280                Some(ref csfle) => csfle
281                    .opts()
282                    .bypass_auto_encryption
283                    .map(|b| !b)
284                    .unwrap_or(true),
285                None => false,
286            }
287        }
288        #[cfg(not(feature = "in-use-encryption"))]
289        {
290            false
291        }
292    }
293
294    #[cfg(all(test, feature = "in-use-encryption"))]
295    pub(crate) async fn mongocryptd_spawned(&self) -> bool {
296        self.inner
297            .csfle
298            .read()
299            .await
300            .as_ref()
301            .is_some_and(|cs| cs.exec().mongocryptd_spawned())
302    }
303
304    #[cfg(all(test, feature = "in-use-encryption"))]
305    pub(crate) async fn has_mongocryptd_client(&self) -> bool {
306        self.inner
307            .csfle
308            .read()
309            .await
310            .as_ref()
311            .is_some_and(|cs| cs.exec().has_mongocryptd_client())
312    }
313
314    fn test_command_event_channel(&self) -> Option<&options::TestEventSender> {
315        #[cfg(test)]
316        {
317            self.inner
318                .options
319                .test_options
320                .as_ref()
321                .and_then(|t| t.async_event_listener.as_ref())
322        }
323        #[cfg(not(test))]
324        {
325            None
326        }
327    }
328
329    pub(crate) async fn emit_command_event(&self, generate_event: impl FnOnce() -> CommandEvent) {
330        #[cfg(test)]
331        if self
332            .inner
333            .disable_command_events
334            .load(std::sync::atomic::Ordering::SeqCst)
335        {
336            return;
337        }
338        #[cfg(feature = "tracing-unstable")]
339        let tracing_emitter = if trace_or_log_enabled!(
340            target: COMMAND_TRACING_EVENT_TARGET,
341            TracingOrLogLevel::Debug
342        ) {
343            Some(CommandTracingEventEmitter::new(
344                self.inner.options.tracing_max_document_length_bytes,
345                self.inner.topology.id,
346            ))
347        } else {
348            None
349        };
350        let test_channel = self.test_command_event_channel();
351        let should_send = test_channel.is_some() || self.options().command_event_handler.is_some();
352        #[cfg(feature = "tracing-unstable")]
353        let should_send = should_send || tracing_emitter.is_some();
354        if !should_send {
355            return;
356        }
357
358        let event = generate_event();
359        if let Some(tx) = test_channel {
360            let (msg, ack) = crate::runtime::AcknowledgedMessage::package(event.clone());
361            let _ = tx.send(msg).await;
362            ack.wait_for_acknowledgment().await;
363        }
364        #[cfg(feature = "tracing-unstable")]
365        if let Some(ref tracing_emitter) = tracing_emitter {
366            tracing_emitter.handle(event.clone());
367        }
368        if let Some(handler) = &self.options().command_event_handler {
369            handler.handle(event);
370        }
371    }
372
373    /// Gets the default selection criteria the `Client` uses for operations..
374    pub fn selection_criteria(&self) -> Option<&SelectionCriteria> {
375        self.inner.options.selection_criteria.as_ref()
376    }
377
378    /// Gets the default read concern the `Client` uses for operations.
379    pub fn read_concern(&self) -> Option<&ReadConcern> {
380        self.inner.options.read_concern.as_ref()
381    }
382
383    /// Gets the default write concern the `Client` uses for operations.
384    pub fn write_concern(&self) -> Option<&WriteConcern> {
385        self.inner.options.write_concern.as_ref()
386    }
387
388    /// Gets a handle to a database specified by `name` in the cluster the `Client` is connected to.
389    /// The `Database` options (e.g. read preference and write concern) will default to those of the
390    /// `Client`.
391    ///
392    /// This method does not send or receive anything across the wire to the database, so it can be
393    /// used repeatedly without incurring any costs from I/O.
394    pub fn database(&self, name: &str) -> Database {
395        Database::new(self.clone(), name, None)
396    }
397
398    /// Gets a handle to a database specified by `name` in the cluster the `Client` is connected to.
399    /// Operations done with this `Database` will use the options specified by `options` by default
400    /// and will otherwise default to those of the `Client`.
401    ///
402    /// This method does not send or receive anything across the wire to the database, so it can be
403    /// used repeatedly without incurring any costs from I/O.
404    pub fn database_with_options(&self, name: &str, options: DatabaseOptions) -> Database {
405        Database::new(self.clone(), name, Some(options))
406    }
407
408    /// Gets a handle to the default database specified in the `ClientOptions` or MongoDB connection
409    /// string used to construct this `Client`.
410    ///
411    /// If no default database was specified, `None` will be returned.
412    pub fn default_database(&self) -> Option<Database> {
413        self.inner
414            .options
415            .default_database
416            .as_ref()
417            .map(|db_name| self.database(db_name))
418    }
419
420    /// Append new information to the metadata of the handshake with the server.
421    pub fn append_metadata(&self, driver_info: DriverInfo) -> Result<()> {
422        self.inner
423            .topology
424            .metadata
425            .write()
426            .unwrap()
427            .append(driver_info)
428    }
429
430    pub(crate) fn register_async_drop(&self) -> AsyncDropToken {
431        let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
432        crate::runtime::spawn(self.inner.shutdown.pending_drops.track_future(async move {
433            // If the cleanup channel is closed, that task was dropped.
434            if let Ok(cleanup) = cleanup_rx.await {
435                cleanup.await;
436            }
437        }));
438        AsyncDropToken {
439            tx: Some(cleanup_tx),
440        }
441    }
442
443    /// Check in a server session to the server session pool. The session will be discarded if it is
444    /// expired or dirty.
445    pub(crate) async fn check_in_server_session(&self, session: ServerSession) {
446        let timeout = self.inner.topology.watcher().logical_session_timeout();
447        self.inner.session_pool.check_in(session, timeout).await;
448    }
449
450    #[cfg(test)]
451    pub(crate) async fn clear_session_pool(&self) {
452        self.inner.session_pool.clear().await;
453    }
454
455    #[cfg(test)]
456    pub(crate) async fn is_session_checked_in(&self, id: &crate::bson::Document) -> bool {
457        self.inner.session_pool.contains(id).await
458    }
459
460    #[cfg(test)]
461    pub(crate) fn disable_command_events(&self, disable: bool) {
462        self.inner
463            .disable_command_events
464            .store(disable, std::sync::atomic::Ordering::SeqCst);
465    }
466
467    /// Get the address of the server selected according to the given criteria.
468    /// This method is only used in tests.
469    #[cfg(test)]
470    pub(crate) async fn test_select_server(
471        &self,
472        criteria: Option<&SelectionCriteria>,
473    ) -> Result<ServerAddress> {
474        let (server, _) = self
475            .select_server(criteria, None, OpSelectionInfo::new("Test select server"))
476            .await?;
477        Ok(server.address.clone())
478    }
479
480    /// Select a server using the provided criteria. If none is provided, a primary read preference
481    /// will be used instead.
482    async fn select_server(
483        &self,
484        criteria: Option<&SelectionCriteria>,
485        deprioritized: Option<&HashSet<ServerAddress>>,
486        op_info: OpSelectionInfo<'_>,
487    ) -> Result<(SelectedServer, SelectionCriteria)> {
488        let criteria =
489            criteria.unwrap_or(&SelectionCriteria::ReadPreference(ReadPreference::Primary));
490
491        let start_time = Instant::now();
492        let timeout = self
493            .inner
494            .options
495            .server_selection_timeout
496            .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
497
498        #[cfg(feature = "tracing-unstable")]
499        let event_emitter = ServerSelectionTracingEventEmitter::new(
500            self.inner.topology.id,
501            criteria,
502            op_info.name,
503            start_time,
504            timeout,
505            self.options().tracing_max_document_length_bytes,
506        );
507        #[cfg(feature = "tracing-unstable")]
508        event_emitter.emit_started_event(self.inner.topology.latest().description.clone());
509        // We only want to emit this message once per operation at most.
510        #[cfg(feature = "tracing-unstable")]
511        let mut emitted_waiting_message = false;
512
513        let mut watcher = self.inner.topology.watcher().clone();
514        loop {
515            let state = watcher.observe_latest();
516            let override_slot;
517            let effective_criteria =
518                if let Some(oc) = (op_info.override_criteria)(criteria, &state.description) {
519                    override_slot = oc;
520                    &override_slot
521                } else {
522                    criteria
523                };
524            let result = server_selection::attempt_to_select_server(
525                effective_criteria,
526                &state.description,
527                &state.servers(),
528                deprioritized,
529            );
530            match result {
531                Err(error) => {
532                    #[cfg(feature = "tracing-unstable")]
533                    event_emitter.emit_failed_event(&state.description, &error);
534
535                    return Err(error);
536                }
537                Ok(result) => {
538                    if let Some(server) = result {
539                        #[cfg(feature = "tracing-unstable")]
540                        event_emitter.emit_succeeded_event(&state.description, &server);
541
542                        return Ok((server, effective_criteria.clone()));
543                    } else {
544                        #[cfg(feature = "tracing-unstable")]
545                        if !emitted_waiting_message {
546                            event_emitter.emit_waiting_event(&state.description);
547                            emitted_waiting_message = true;
548                        }
549
550                        watcher.request_immediate_check();
551
552                        let elapsed = start_time.elapsed();
553                        let change_occurred = elapsed < timeout
554                            && watcher
555                                .wait_for_update(
556                                    timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO),
557                                )
558                                .await;
559                        if !change_occurred {
560                            let error: Error = ErrorKind::ServerSelection {
561                                message: state
562                                    .description
563                                    .server_selection_timeout_error_message(criteria),
564                            }
565                            .into();
566
567                            #[cfg(feature = "tracing-unstable")]
568                            event_emitter.emit_failed_event(&state.description, &error);
569
570                            return Err(error);
571                        }
572                    }
573                }
574            }
575        }
576    }
577
578    #[cfg(all(test, feature = "dns-resolver"))]
579    pub(crate) fn get_hosts(&self) -> Vec<String> {
580        let state = self.inner.topology.latest();
581
582        state
583            .servers()
584            .keys()
585            .map(|stream_address| format!("{stream_address}"))
586            .collect()
587    }
588
589    #[cfg(test)]
590    pub(crate) async fn sync_workers(&self) {
591        self.inner.topology.updater().sync_workers().await;
592    }
593
594    #[cfg(test)]
595    pub(crate) fn topology_description(&self) -> crate::sdam::TopologyDescription {
596        self.inner.topology.latest().description.clone()
597    }
598
599    pub(crate) fn is_sharded(&self) -> bool {
600        self.inner.topology.latest().description.topology_type == TopologyType::Sharded
601    }
602
603    #[cfg(test)]
604    pub(crate) fn topology(&self) -> &Topology {
605        &self.inner.topology
606    }
607
608    #[cfg(feature = "in-use-encryption")]
609    pub(crate) async fn primary_description(&self) -> Option<crate::sdam::ServerDescription> {
610        let start_time = Instant::now();
611        let timeout = self
612            .inner
613            .options
614            .server_selection_timeout
615            .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
616        let mut watcher = self.inner.topology.watcher().clone();
617        loop {
618            let topology = watcher.observe_latest();
619            if let Some(desc) = topology.description.primary() {
620                return Some(desc.clone());
621            }
622            let remaining = timeout
623                .checked_sub(start_time.elapsed())
624                .unwrap_or(Duration::ZERO);
625            if !watcher.wait_for_update(remaining).await {
626                return None;
627            }
628        }
629    }
630
631    pub(crate) fn weak(&self) -> WeakClient {
632        WeakClient {
633            inner: TrackingArc::downgrade(&self.inner),
634        }
635    }
636
637    #[cfg(feature = "in-use-encryption")]
638    pub(crate) async fn auto_encryption_opts(
639        &self,
640    ) -> Option<tokio::sync::RwLockReadGuard<'_, csfle::options::AutoEncryptionOptions>> {
641        tokio::sync::RwLockReadGuard::try_map(self.inner.csfle.read().await, |csfle| {
642            csfle.as_ref().map(|cs| cs.opts())
643        })
644        .ok()
645    }
646
647    pub(crate) fn options(&self) -> &ClientOptions {
648        &self.inner.options
649    }
650
651    /// Ends all sessions contained in this client's session pool on the server.
652    pub(crate) async fn end_all_sessions(&self) {
653        // The maximum number of session IDs that should be sent in a single endSessions command.
654        const MAX_END_SESSIONS_BATCH_SIZE: usize = 10_000;
655
656        let mut watcher = self.inner.topology.watcher().clone();
657        let selection_criteria =
658            SelectionCriteria::from(ReadPreference::PrimaryPreferred { options: None });
659
660        let session_ids = self.inner.session_pool.get_session_ids().await;
661        for chunk in session_ids.chunks(MAX_END_SESSIONS_BATCH_SIZE) {
662            let state = watcher.observe_latest();
663            let Ok(Some(_)) = attempt_to_select_server(
664                &selection_criteria,
665                &state.description,
666                &state.servers(),
667                None,
668            ) else {
669                // If a suitable server is not available, do not proceed with the operation to avoid
670                // spinning for server_selection_timeout.
671                return;
672            };
673
674            let end_sessions = doc! {
675                "endSessions": chunk,
676            };
677            let _ = self
678                .database("admin")
679                .run_command(end_sessions)
680                .selection_criteria(selection_criteria.clone())
681                .await;
682        }
683    }
684
685    #[cfg(feature = "opentelemetry")]
686    pub(crate) fn tracer(&self) -> &opentelemetry::global::BoxedTracer {
687        &self.inner.tracer
688    }
689}
690
691#[derive(Clone, Debug)]
692pub(crate) struct WeakClient {
693    inner: crate::tracking_arc::Weak<ClientInner>,
694}
695
696impl WeakClient {
697    pub(crate) fn upgrade(&self) -> Option<Client> {
698        self.inner.upgrade().map(|inner| Client { inner })
699    }
700}
701
702#[derive_where(Debug)]
703pub(crate) struct AsyncDropToken {
704    #[derive_where(skip)]
705    tx: Option<tokio::sync::oneshot::Sender<BoxFuture<'static, ()>>>,
706}
707
708impl AsyncDropToken {
709    pub(crate) fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
710        if let Some(tx) = self.tx.take() {
711            let _ = tx.send(fut.boxed());
712        } else {
713            #[cfg(debug_assertions)]
714            panic!("exhausted AsyncDropToken");
715        }
716    }
717}
718
719impl Drop for Client {
720    fn drop(&mut self) {
721        if !self.inner.shutdown.executed.load(Ordering::SeqCst)
722            && !self.inner.dropped.load(Ordering::SeqCst)
723            && TrackingArc::strong_count(&self.inner) == 1
724        {
725            // We need an owned copy of the client to move into the spawned future. However, if this
726            // call to drop completes before the spawned future completes, the number of strong
727            // references to the inner client will again be 1 when the cloned client drops, and thus
728            // end_all_sessions will be called continuously until the runtime shuts down. Storing a
729            // flag indicating whether end_all_sessions has already been called breaks
730            // this cycle.
731            self.inner.dropped.store(true, Ordering::SeqCst);
732            let client = self.clone();
733            self.inner
734                .end_sessions_token
735                .lock()
736                .unwrap()
737                .spawn(async move {
738                    client.end_all_sessions().await;
739                });
740        }
741    }
742}
743
744/// Operation-specific parameters to server selection
745pub(crate) struct OpSelectionInfo<'a> {
746    #[allow(unused)]
747    name: &'a str,
748    override_criteria: OverrideCriteriaFn,
749}
750
751impl<'a, T: crate::operation::Operation> From<&'a T> for OpSelectionInfo<'a> {
752    fn from(op: &'a T) -> Self {
753        Self {
754            name: crate::bson_compat::cstr_to_str(op.name()),
755            override_criteria: op.override_criteria(),
756        }
757    }
758}
759
760impl<'a> OpSelectionInfo<'a> {
761    fn new(name: &'a str) -> Self {
762        Self {
763            name,
764            override_criteria: |_, _| None,
765        }
766    }
767}
768
769/// Retry token bucket functionality. Note that the values used are scaled by a factor of 10 from
770/// those defined in the spec to allow for the use of integers.
771const MAX_BUCKET_CAPACITY: u16 = 10_000;
772const RETRY_TOKEN_RETURN_RATE: u16 = 1;
773const RETRY_TOKEN_CONSUME_RATE: u16 = 10;
774impl Client {
775    pub(crate) async fn consume_from_token_bucket(&self) -> bool {
776        let Some(ref bucket) = self.inner.token_bucket else {
777            return true;
778        };
779        let mut tokens = bucket.lock().await;
780        if *tokens >= RETRY_TOKEN_CONSUME_RATE {
781            *tokens -= RETRY_TOKEN_CONSUME_RATE;
782            true
783        } else {
784            false
785        }
786    }
787
788    pub(crate) async fn deposit_success_in_token_bucket(&self, is_retry: bool) {
789        let Some(ref bucket) = self.inner.token_bucket else {
790            return;
791        };
792        let mut deposit = RETRY_TOKEN_RETURN_RATE;
793        if is_retry {
794            deposit += 10;
795        }
796        let mut tokens = bucket.lock().await;
797        *tokens = std::cmp::min(*tokens + deposit, MAX_BUCKET_CAPACITY);
798    }
799
800    pub(crate) async fn deposit_retry_error_in_token_bucket(&self) {
801        let Some(ref bucket) = self.inner.token_bucket else {
802            return;
803        };
804        let mut tokens = bucket.lock().await;
805        *tokens = std::cmp::min(*tokens + RETRY_TOKEN_RETURN_RATE, MAX_BUCKET_CAPACITY);
806    }
807
808    #[cfg(test)]
809    #[expect(dead_code)]
810    pub(crate) async fn get_num_tokens_in_bucket(&self) -> Option<u16> {
811        let bucket = self.inner.token_bucket.as_ref()?;
812        Some(*bucket.lock().await)
813    }
814
815    #[cfg(test)]
816    #[expect(dead_code)]
817    pub(crate) async fn set_num_tokens_in_bucket(&self, tokens: u16) {
818        let Some(ref bucket) = self.inner.token_bucket else {
819            return;
820        };
821        *bucket.lock().await = tokens;
822    }
823}