1pub mod action;
2pub mod auth;
3#[cfg(feature = "in-use-encryption")]
4pub(crate) mod csfle;
5mod executor;
6pub mod options;
7pub mod session;
8
9use std::{
10 collections::HashSet,
11 sync::{
12 atomic::{AtomicBool, Ordering},
13 Mutex as SyncMutex,
14 },
15 time::{Duration, Instant},
16};
17
18#[cfg(feature = "in-use-encryption")]
19pub use self::csfle::client_builder::*;
20use derive_where::derive_where;
21use futures_core::Future;
22use futures_util::FutureExt;
23use tokio::sync::Mutex;
24
25#[cfg(feature = "tracing-unstable")]
26use crate::trace::{
27 command::CommandTracingEventEmitter,
28 server_selection::ServerSelectionTracingEventEmitter,
29 trace_or_log_enabled,
30 TracingOrLogLevel,
31 COMMAND_TRACING_EVENT_TARGET,
32};
33use crate::{
34 bson::doc,
35 concern::{ReadConcern, WriteConcern},
36 db::Database,
37 error::{Error, ErrorKind, Result},
38 event::command::CommandEvent,
39 id_set::IdSet,
40 operation::OverrideCriteriaFn,
41 options::{
42 ClientOptions,
43 DatabaseOptions,
44 DriverInfo,
45 ReadPreference,
46 SelectionCriteria,
47 ServerAddress,
48 },
49 sdam::{
50 server_selection::{self, attempt_to_select_server},
51 SelectedServer,
52 Topology,
53 },
54 tracking_arc::TrackingArc,
55 BoxFuture,
56 ClientSession,
57 TopologyType,
58};
59
60pub(crate) use executor::{retry::Retry, HELLO_COMMAND_NAMES, REDACTED_COMMANDS};
61pub(crate) use session::{ClusterTime, SESSIONS_UNSUPPORTED_COMMANDS};
62
63use session::{ServerSession, ServerSessionPool};
64
/// Server selection timeout applied when `ClientOptions::server_selection_timeout`
/// is unset (used by `Client::select_server` and `Client::primary_description`).
const DEFAULT_SERVER_SELECTION_TIMEOUT: Duration = Duration::from_secs(30);
66
67#[derive(Debug, Clone)]
141pub struct Client {
142 inner: TrackingArc<ClientInner>,
143}
144
145#[allow(dead_code, unreachable_code, clippy::diverging_sub_expression)]
146const _: fn() = || {
147 fn assert_send<T: Send>(_t: T) {}
148 fn assert_sync<T: Sync>(_t: T) {}
149
150 let _c: super::Client = todo!();
151 assert_send(_c);
152 assert_sync(_c);
153};
154
155#[derive(Debug)]
156struct ClientInner {
157 topology: Topology,
158 options: ClientOptions,
159 session_pool: ServerSessionPool,
160 shutdown: Shutdown,
161 dropped: AtomicBool,
162 end_sessions_token: std::sync::Mutex<AsyncDropToken>,
163 token_bucket: Option<Mutex<u16>>,
164 #[cfg(feature = "in-use-encryption")]
165 csfle: tokio::sync::RwLock<Option<csfle::ClientState>>,
166 #[cfg(feature = "opentelemetry")]
167 tracer: opentelemetry::global::BoxedTracer,
168 #[cfg(test)]
169 disable_command_events: AtomicBool,
170}
171
172#[derive(Debug)]
173struct Shutdown {
174 pending_drops: SyncMutex<IdSet<crate::runtime::AsyncJoinHandle<()>>>,
175 executed: AtomicBool,
176}
177
178impl Client {
179 pub async fn with_uri_str(uri: impl AsRef<str>) -> Result<Self> {
185 let options = ClientOptions::parse(uri.as_ref()).await?;
186
187 Client::with_options(options)
188 }
189
190 pub fn with_options(options: ClientOptions) -> Result<Self> {
192 options.validate()?;
193
194 let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
196 crate::runtime::spawn(async move {
197 if let Ok(cleanup) = cleanup_rx.await {
199 cleanup.await;
200 }
201 });
202 let end_sessions_token = std::sync::Mutex::new(AsyncDropToken {
203 tx: Some(cleanup_tx),
204 });
205
206 #[cfg(feature = "opentelemetry")]
207 let tracer = options.tracer();
208
209 let token_bucket = options
210 .adaptive_retries
211 .unwrap_or(false)
212 .then(|| Mutex::new(MAX_BUCKET_CAPACITY));
213
214 let inner = TrackingArc::new(ClientInner {
215 topology: Topology::new(options.clone())?,
216 session_pool: ServerSessionPool::new(),
217 options,
218 shutdown: Shutdown {
219 pending_drops: SyncMutex::new(IdSet::new()),
220 executed: AtomicBool::new(false),
221 },
222 dropped: AtomicBool::new(false),
223 end_sessions_token,
224 token_bucket,
225 #[cfg(feature = "in-use-encryption")]
226 csfle: Default::default(),
227 #[cfg(feature = "opentelemetry")]
228 tracer,
229 #[cfg(test)]
230 disable_command_events: AtomicBool::new(false),
231 });
232 Ok(Self { inner })
233 }
234
235 #[cfg(feature = "in-use-encryption")]
257 pub fn encrypted_builder(
258 client_options: ClientOptions,
259 key_vault_namespace: crate::Namespace,
260 kms_providers: impl IntoIterator<
261 Item = (
262 mongocrypt::ctx::KmsProvider,
263 crate::bson::Document,
264 Option<options::TlsOptions>,
265 ),
266 >,
267 ) -> Result<EncryptedClientBuilder> {
268 Ok(EncryptedClientBuilder::new(
269 client_options,
270 csfle::options::AutoEncryptionOptions::new(
271 key_vault_namespace,
272 csfle::options::KmsProviders::new(kms_providers)?,
273 ),
274 ))
275 }
276
277 pub(crate) async fn should_auto_encrypt(&self) -> bool {
279 #[cfg(feature = "in-use-encryption")]
280 {
281 let csfle = self.inner.csfle.read().await;
282 match *csfle {
283 Some(ref csfle) => csfle
284 .opts()
285 .bypass_auto_encryption
286 .map(|b| !b)
287 .unwrap_or(true),
288 None => false,
289 }
290 }
291 #[cfg(not(feature = "in-use-encryption"))]
292 {
293 false
294 }
295 }
296
297 #[cfg(all(test, feature = "in-use-encryption"))]
298 pub(crate) async fn mongocryptd_spawned(&self) -> bool {
299 self.inner
300 .csfle
301 .read()
302 .await
303 .as_ref()
304 .is_some_and(|cs| cs.exec().mongocryptd_spawned())
305 }
306
307 #[cfg(all(test, feature = "in-use-encryption"))]
308 pub(crate) async fn has_mongocryptd_client(&self) -> bool {
309 self.inner
310 .csfle
311 .read()
312 .await
313 .as_ref()
314 .is_some_and(|cs| cs.exec().has_mongocryptd_client())
315 }
316
317 fn test_command_event_channel(&self) -> Option<&options::TestEventSender> {
318 #[cfg(test)]
319 {
320 self.inner
321 .options
322 .test_options
323 .as_ref()
324 .and_then(|t| t.async_event_listener.as_ref())
325 }
326 #[cfg(not(test))]
327 {
328 None
329 }
330 }
331
332 pub(crate) async fn emit_command_event(&self, generate_event: impl FnOnce() -> CommandEvent) {
333 #[cfg(test)]
334 if self
335 .inner
336 .disable_command_events
337 .load(std::sync::atomic::Ordering::SeqCst)
338 {
339 return;
340 }
341 #[cfg(feature = "tracing-unstable")]
342 let tracing_emitter = if trace_or_log_enabled!(
343 target: COMMAND_TRACING_EVENT_TARGET,
344 TracingOrLogLevel::Debug
345 ) {
346 Some(CommandTracingEventEmitter::new(
347 self.inner.options.tracing_max_document_length_bytes,
348 self.inner.topology.id,
349 ))
350 } else {
351 None
352 };
353 let test_channel = self.test_command_event_channel();
354 let should_send = test_channel.is_some() || self.options().command_event_handler.is_some();
355 #[cfg(feature = "tracing-unstable")]
356 let should_send = should_send || tracing_emitter.is_some();
357 if !should_send {
358 return;
359 }
360
361 let event = generate_event();
362 if let Some(tx) = test_channel {
363 let (msg, ack) = crate::runtime::AcknowledgedMessage::package(event.clone());
364 let _ = tx.send(msg).await;
365 ack.wait_for_acknowledgment().await;
366 }
367 #[cfg(feature = "tracing-unstable")]
368 if let Some(ref tracing_emitter) = tracing_emitter {
369 tracing_emitter.handle(event.clone());
370 }
371 if let Some(handler) = &self.options().command_event_handler {
372 handler.handle(event);
373 }
374 }
375
376 pub fn selection_criteria(&self) -> Option<&SelectionCriteria> {
378 self.inner.options.selection_criteria.as_ref()
379 }
380
381 pub fn read_concern(&self) -> Option<&ReadConcern> {
383 self.inner.options.read_concern.as_ref()
384 }
385
386 pub fn write_concern(&self) -> Option<&WriteConcern> {
388 self.inner.options.write_concern.as_ref()
389 }
390
391 pub fn database(&self, name: &str) -> Database {
398 Database::new(self.clone(), name, None)
399 }
400
401 pub fn database_with_options(&self, name: &str, options: DatabaseOptions) -> Database {
408 Database::new(self.clone(), name, Some(options))
409 }
410
411 pub fn default_database(&self) -> Option<Database> {
416 self.inner
417 .options
418 .default_database
419 .as_ref()
420 .map(|db_name| self.database(db_name))
421 }
422
423 pub fn append_metadata(&self, driver_info: DriverInfo) -> Result<()> {
425 self.inner
426 .topology
427 .metadata
428 .write()
429 .unwrap()
430 .append(driver_info)
431 }
432
433 pub(crate) fn register_async_drop(&self) -> AsyncDropToken {
434 let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
435 let (id_tx, id_rx) = tokio::sync::oneshot::channel::<crate::id_set::Id>();
436 let weak = self.weak();
437 let handle = crate::runtime::spawn(async move {
438 let id = id_rx.await.unwrap();
441 if let Ok(cleanup) = cleanup_rx.await {
443 cleanup.await;
444 }
445 if let Some(client) = weak.upgrade() {
446 client
447 .inner
448 .shutdown
449 .pending_drops
450 .lock()
451 .unwrap()
452 .remove(&id);
453 }
454 });
455 let id = self
456 .inner
457 .shutdown
458 .pending_drops
459 .lock()
460 .unwrap()
461 .insert(handle);
462 let _ = id_tx.send(id);
463 AsyncDropToken {
464 tx: Some(cleanup_tx),
465 }
466 }
467
468 pub(crate) async fn check_in_server_session(&self, session: ServerSession) {
471 let timeout = self.inner.topology.watcher().logical_session_timeout();
472 self.inner.session_pool.check_in(session, timeout).await;
473 }
474
475 #[cfg(test)]
476 pub(crate) async fn clear_session_pool(&self) {
477 self.inner.session_pool.clear().await;
478 }
479
480 #[cfg(test)]
481 pub(crate) async fn is_session_checked_in(&self, id: &crate::bson::Document) -> bool {
482 self.inner.session_pool.contains(id).await
483 }
484
485 #[cfg(test)]
486 pub(crate) fn disable_command_events(&self, disable: bool) {
487 self.inner
488 .disable_command_events
489 .store(disable, std::sync::atomic::Ordering::SeqCst);
490 }
491
492 #[cfg(test)]
495 pub(crate) async fn test_select_server(
496 &self,
497 criteria: Option<&SelectionCriteria>,
498 ) -> Result<ServerAddress> {
499 let (server, _) = self
500 .select_server(criteria, None, OpSelectionInfo::new("Test select server"))
501 .await?;
502 Ok(server.address.clone())
503 }
504
505 async fn select_server(
508 &self,
509 criteria: Option<&SelectionCriteria>,
510 deprioritized: Option<&HashSet<ServerAddress>>,
511 op_info: OpSelectionInfo<'_>,
512 ) -> Result<(SelectedServer, SelectionCriteria)> {
513 let criteria =
514 criteria.unwrap_or(&SelectionCriteria::ReadPreference(ReadPreference::Primary));
515
516 let start_time = Instant::now();
517 let timeout = self
518 .inner
519 .options
520 .server_selection_timeout
521 .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
522
523 #[cfg(feature = "tracing-unstable")]
524 let event_emitter = ServerSelectionTracingEventEmitter::new(
525 self.inner.topology.id,
526 criteria,
527 op_info.name,
528 start_time,
529 timeout,
530 self.options().tracing_max_document_length_bytes,
531 );
532 #[cfg(feature = "tracing-unstable")]
533 event_emitter.emit_started_event(self.inner.topology.latest().description.clone());
534 #[cfg(feature = "tracing-unstable")]
536 let mut emitted_waiting_message = false;
537
538 let mut watcher = self.inner.topology.watcher().clone();
539 loop {
540 let state = watcher.observe_latest();
541 let override_slot;
542 let effective_criteria =
543 if let Some(oc) = (op_info.override_criteria)(criteria, &state.description) {
544 override_slot = oc;
545 &override_slot
546 } else {
547 criteria
548 };
549 let result = server_selection::attempt_to_select_server(
550 effective_criteria,
551 &state.description,
552 &state.servers(),
553 deprioritized,
554 );
555 match result {
556 Err(error) => {
557 #[cfg(feature = "tracing-unstable")]
558 event_emitter.emit_failed_event(&state.description, &error);
559
560 return Err(error);
561 }
562 Ok(result) => {
563 if let Some(server) = result {
564 #[cfg(feature = "tracing-unstable")]
565 event_emitter.emit_succeeded_event(&state.description, &server);
566
567 return Ok((server, effective_criteria.clone()));
568 } else {
569 #[cfg(feature = "tracing-unstable")]
570 if !emitted_waiting_message {
571 event_emitter.emit_waiting_event(&state.description);
572 emitted_waiting_message = true;
573 }
574
575 watcher.request_immediate_check();
576
577 let elapsed = start_time.elapsed();
578 let change_occurred = elapsed < timeout
579 && watcher
580 .wait_for_update(
581 timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO),
582 )
583 .await;
584 if !change_occurred {
585 let error: Error = ErrorKind::ServerSelection {
586 message: state
587 .description
588 .server_selection_timeout_error_message(criteria),
589 }
590 .into();
591
592 #[cfg(feature = "tracing-unstable")]
593 event_emitter.emit_failed_event(&state.description, &error);
594
595 return Err(error);
596 }
597 }
598 }
599 }
600 }
601 }
602
603 #[cfg(all(test, feature = "dns-resolver"))]
604 pub(crate) fn get_hosts(&self) -> Vec<String> {
605 let state = self.inner.topology.latest();
606
607 state
608 .servers()
609 .keys()
610 .map(|stream_address| format!("{stream_address}"))
611 .collect()
612 }
613
614 #[cfg(test)]
615 pub(crate) async fn sync_workers(&self) {
616 self.inner.topology.updater().sync_workers().await;
617 }
618
619 #[cfg(test)]
620 pub(crate) fn topology_description(&self) -> crate::sdam::TopologyDescription {
621 self.inner.topology.latest().description.clone()
622 }
623
624 pub(crate) fn is_sharded(&self) -> bool {
625 self.inner.topology.latest().description.topology_type == TopologyType::Sharded
626 }
627
628 #[cfg(test)]
629 pub(crate) fn topology(&self) -> &Topology {
630 &self.inner.topology
631 }
632
633 #[cfg(feature = "in-use-encryption")]
634 pub(crate) async fn primary_description(&self) -> Option<crate::sdam::ServerDescription> {
635 let start_time = Instant::now();
636 let timeout = self
637 .inner
638 .options
639 .server_selection_timeout
640 .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
641 let mut watcher = self.inner.topology.watcher().clone();
642 loop {
643 let topology = watcher.observe_latest();
644 if let Some(desc) = topology.description.primary() {
645 return Some(desc.clone());
646 }
647 let remaining = timeout
648 .checked_sub(start_time.elapsed())
649 .unwrap_or(Duration::ZERO);
650 if !watcher.wait_for_update(remaining).await {
651 return None;
652 }
653 }
654 }
655
656 pub(crate) fn weak(&self) -> WeakClient {
657 WeakClient {
658 inner: TrackingArc::downgrade(&self.inner),
659 }
660 }
661
662 #[cfg(feature = "in-use-encryption")]
663 pub(crate) async fn auto_encryption_opts(
664 &self,
665 ) -> Option<tokio::sync::RwLockReadGuard<'_, csfle::options::AutoEncryptionOptions>> {
666 tokio::sync::RwLockReadGuard::try_map(self.inner.csfle.read().await, |csfle| {
667 csfle.as_ref().map(|cs| cs.opts())
668 })
669 .ok()
670 }
671
672 pub(crate) fn options(&self) -> &ClientOptions {
673 &self.inner.options
674 }
675
676 pub(crate) async fn end_all_sessions(&self) {
678 const MAX_END_SESSIONS_BATCH_SIZE: usize = 10_000;
680
681 let mut watcher = self.inner.topology.watcher().clone();
682 let selection_criteria =
683 SelectionCriteria::from(ReadPreference::PrimaryPreferred { options: None });
684
685 let session_ids = self.inner.session_pool.get_session_ids().await;
686 for chunk in session_ids.chunks(MAX_END_SESSIONS_BATCH_SIZE) {
687 let state = watcher.observe_latest();
688 let Ok(Some(_)) = attempt_to_select_server(
689 &selection_criteria,
690 &state.description,
691 &state.servers(),
692 None,
693 ) else {
694 return;
697 };
698
699 let end_sessions = doc! {
700 "endSessions": chunk,
701 };
702 let _ = self
703 .database("admin")
704 .run_command(end_sessions)
705 .selection_criteria(selection_criteria.clone())
706 .await;
707 }
708 }
709
710 #[cfg(feature = "opentelemetry")]
711 pub(crate) fn tracer(&self) -> &opentelemetry::global::BoxedTracer {
712 &self.inner.tracer
713 }
714}
715
716#[derive(Clone, Debug)]
717pub(crate) struct WeakClient {
718 inner: crate::tracking_arc::Weak<ClientInner>,
719}
720
721impl WeakClient {
722 pub(crate) fn upgrade(&self) -> Option<Client> {
723 self.inner.upgrade().map(|inner| Client { inner })
724 }
725}
726
727#[derive_where(Debug)]
728pub(crate) struct AsyncDropToken {
729 #[derive_where(skip)]
730 tx: Option<tokio::sync::oneshot::Sender<BoxFuture<'static, ()>>>,
731}
732
733impl AsyncDropToken {
734 pub(crate) fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
735 if let Some(tx) = self.tx.take() {
736 let _ = tx.send(fut.boxed());
737 } else {
738 #[cfg(debug_assertions)]
739 panic!("exhausted AsyncDropToken");
740 }
741 }
742}
743
744impl Drop for Client {
745 fn drop(&mut self) {
746 if !self.inner.shutdown.executed.load(Ordering::SeqCst)
747 && !self.inner.dropped.load(Ordering::SeqCst)
748 && TrackingArc::strong_count(&self.inner) == 1
749 {
750 self.inner.dropped.store(true, Ordering::SeqCst);
757 let client = self.clone();
758 self.inner
759 .end_sessions_token
760 .lock()
761 .unwrap()
762 .spawn(async move {
763 client.end_all_sessions().await;
764 });
765 }
766 }
767}
768
769pub(crate) struct OpSelectionInfo<'a> {
771 #[allow(unused)]
772 name: &'a str,
773 override_criteria: OverrideCriteriaFn,
774}
775
776impl<'a, T: crate::operation::Operation> From<&'a T> for OpSelectionInfo<'a> {
777 fn from(op: &'a T) -> Self {
778 Self {
779 name: crate::bson_compat::cstr_to_str(op.name()),
780 override_criteria: op.override_criteria(),
781 }
782 }
783}
784
785impl<'a> OpSelectionInfo<'a> {
786 fn new(name: &'a str) -> Self {
787 Self {
788 name,
789 override_criteria: |_, _| None,
790 }
791 }
792}
793
// Adaptive-retry token bucket parameters. All amounts are whole `u16` tokens.
// NOTE(review): these look like fractional specification rates scaled by 10 so
// they fit in integers — confirm against the adaptive retries specification.

/// Upper bound on the bucket balance; `Client::with_options` also starts the
/// bucket at this value.
const MAX_BUCKET_CAPACITY: u16 = 10_000;
/// Tokens credited by `deposit_success_in_token_bucket` and
/// `deposit_retry_error_in_token_bucket`.
const RETRY_TOKEN_RETURN_RATE: u16 = 1;
/// Tokens debited by `consume_from_token_bucket` (presumably once per retry).
const RETRY_TOKEN_CONSUME_RATE: u16 = 10;
799impl Client {
800 pub(crate) async fn consume_from_token_bucket(&self) -> bool {
801 let Some(ref bucket) = self.inner.token_bucket else {
802 return true;
803 };
804 let mut tokens = bucket.lock().await;
805 if *tokens >= RETRY_TOKEN_CONSUME_RATE {
806 *tokens -= RETRY_TOKEN_CONSUME_RATE;
807 true
808 } else {
809 false
810 }
811 }
812
813 pub(crate) async fn deposit_success_in_token_bucket(&self, is_retry: bool) {
814 let Some(ref bucket) = self.inner.token_bucket else {
815 return;
816 };
817 let mut deposit = RETRY_TOKEN_RETURN_RATE;
818 if is_retry {
819 deposit += 10;
820 }
821 let mut tokens = bucket.lock().await;
822 *tokens = std::cmp::min(*tokens + deposit, MAX_BUCKET_CAPACITY);
823 }
824
825 pub(crate) async fn deposit_retry_error_in_token_bucket(&self) {
826 let Some(ref bucket) = self.inner.token_bucket else {
827 return;
828 };
829 let mut tokens = bucket.lock().await;
830 *tokens = std::cmp::min(*tokens + RETRY_TOKEN_RETURN_RATE, MAX_BUCKET_CAPACITY);
831 }
832
833 #[cfg(test)]
834 #[expect(dead_code)]
835 pub(crate) async fn get_num_tokens_in_bucket(&self) -> Option<u16> {
836 let bucket = self.inner.token_bucket.as_ref()?;
837 Some(*bucket.lock().await)
838 }
839
840 #[cfg(test)]
841 #[expect(dead_code)]
842 pub(crate) async fn set_num_tokens_in_bucket(&self, tokens: u16) {
843 let Some(ref bucket) = self.inner.token_bucket else {
844 return;
845 };
846 *bucket.lock().await = tokens;
847 }
848}