1pub mod action;
2pub mod auth;
3#[cfg(feature = "in-use-encryption")]
4pub(crate) mod csfle;
5mod executor;
6pub mod options;
7pub mod session;
8
9use std::{
10 collections::HashSet,
11 sync::atomic::{AtomicBool, Ordering},
12 time::{Duration, Instant},
13};
14
15#[cfg(feature = "in-use-encryption")]
16pub use self::csfle::client_builder::*;
17use derive_where::derive_where;
18use futures_core::Future;
19use futures_util::FutureExt;
20use tokio::sync::Mutex;
21use tokio_util::task::TaskTracker;
22
23#[cfg(feature = "tracing-unstable")]
24use crate::trace::{
25 command::CommandTracingEventEmitter,
26 server_selection::ServerSelectionTracingEventEmitter,
27 trace_or_log_enabled,
28 TracingOrLogLevel,
29 COMMAND_TRACING_EVENT_TARGET,
30};
31use crate::{
32 bson::doc,
33 concern::{ReadConcern, WriteConcern},
34 db::Database,
35 error::{Error, ErrorKind, Result},
36 event::command::CommandEvent,
37 operation::OverrideCriteriaFn,
38 options::{
39 ClientOptions,
40 DatabaseOptions,
41 DriverInfo,
42 ReadPreference,
43 SelectionCriteria,
44 ServerAddress,
45 },
46 sdam::{
47 server_selection::{self, attempt_to_select_server},
48 SelectedServer,
49 Topology,
50 },
51 tracking_arc::TrackingArc,
52 BoxFuture,
53 ClientSession,
54 TopologyType,
55};
56
57pub(crate) use executor::{retry::Retry, HELLO_COMMAND_NAMES, REDACTED_COMMANDS};
58pub(crate) use session::{ClusterTime, SESSIONS_UNSUPPORTED_COMMANDS};
59
60use session::{ServerSession, ServerSessionPool};
61
/// Server selection timeout used when `ClientOptions::server_selection_timeout` is not set
/// (thirty seconds, expressed in milliseconds).
const DEFAULT_SERVER_SELECTION_TIMEOUT: Duration = Duration::from_millis(30_000);
63
/// Client handle: the entry point holding the topology, options and session pool (all stored
/// in the shared `ClientInner`).
///
/// Cloning (via the derived `Clone`) clones the `TrackingArc`, so every clone shares the same
/// `ClientInner`. The `Drop` impl in this file runs end-of-life cleanup only when the last
/// strong handle goes away.
#[derive(Debug, Clone)]
pub struct Client {
    // Reference-counted shared state; `strong_count`/`downgrade` on it drive `Drop` and `weak()`.
    inner: TrackingArc<ClientInner>,
}
141
142#[allow(dead_code, unreachable_code, clippy::diverging_sub_expression)]
143const _: fn() = || {
144 fn assert_send<T: Send>(_t: T) {}
145 fn assert_sync<T: Sync>(_t: T) {}
146
147 let _c: super::Client = todo!();
148 assert_send(_c);
149 assert_sync(_c);
150};
151
/// State shared by every clone of a `Client`.
#[derive(Debug)]
struct ClientInner {
    /// Topology handle used for server selection, session timeouts and driver metadata
    /// (see `append_metadata`).
    topology: Topology,
    /// Options the client was built from; validated in `Client::with_options`.
    options: ClientOptions,
    /// Pool that server sessions are checked back into via `check_in_server_session`.
    session_pool: ServerSessionPool,
    /// Shutdown bookkeeping (tracked drop tasks and the "already executed" flag).
    shutdown: ClientShutdown,
    /// Set by `Drop for Client` once the end-sessions cleanup has been scheduled, so that it
    /// is scheduled at most once.
    dropped: AtomicBool,
    /// Drop token whose cleanup future (`end_all_sessions`) is spawned when the last `Client`
    /// handle is dropped.
    end_sessions_token: std::sync::Mutex<AsyncDropToken>,
    /// Retry token bucket; `Some` only when `adaptive_retries` is enabled in the options.
    token_bucket: Option<Mutex<u16>>,
    /// In-use encryption state; initialized to `None` (`Default`) in `with_options`.
    #[cfg(feature = "in-use-encryption")]
    csfle: tokio::sync::RwLock<Option<csfle::ClientState>>,
    /// OpenTelemetry tracer obtained from the options at construction.
    #[cfg(feature = "opentelemetry")]
    tracer: opentelemetry::global::BoxedTracer,
    /// Test-only switch that makes `emit_command_event` a no-op.
    #[cfg(test)]
    disable_command_events: AtomicBool,
}
168
/// Shutdown bookkeeping for a client.
#[derive(Debug)]
struct ClientShutdown {
    /// Tracks the background tasks spawned by `register_async_drop`.
    pending_drops: TaskTracker,
    /// Checked by `Drop for Client`; when set, drop skips end-sessions cleanup. Presumably set
    /// by an explicit shutdown path outside this view — TODO confirm.
    executed: AtomicBool,
}
174
175impl Client {
176 pub async fn with_uri_str(uri: impl AsRef<str>) -> Result<Self> {
182 let options = ClientOptions::parse(uri.as_ref()).await?;
183
184 Client::with_options(options)
185 }
186
187 pub fn with_options(options: ClientOptions) -> Result<Self> {
189 options.validate()?;
190
191 let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
193 crate::runtime::spawn(async move {
194 if let Ok(cleanup) = cleanup_rx.await {
196 cleanup.await;
197 }
198 });
199 let end_sessions_token = std::sync::Mutex::new(AsyncDropToken {
200 tx: Some(cleanup_tx),
201 });
202
203 #[cfg(feature = "opentelemetry")]
204 let tracer = options.tracer();
205
206 let token_bucket = options
207 .adaptive_retries
208 .unwrap_or(false)
209 .then(|| Mutex::new(MAX_BUCKET_CAPACITY));
210
211 let inner = TrackingArc::new(ClientInner {
212 topology: Topology::new(options.clone())?,
213 session_pool: ServerSessionPool::new(),
214 options,
215 shutdown: ClientShutdown {
216 pending_drops: TaskTracker::new(),
217 executed: AtomicBool::new(false),
218 },
219 dropped: AtomicBool::new(false),
220 end_sessions_token,
221 token_bucket,
222 #[cfg(feature = "in-use-encryption")]
223 csfle: Default::default(),
224 #[cfg(feature = "opentelemetry")]
225 tracer,
226 #[cfg(test)]
227 disable_command_events: AtomicBool::new(false),
228 });
229 Ok(Self { inner })
230 }
231
/// Starts building a client with automatic in-use encryption configured.
///
/// `key_vault_namespace` and `kms_providers` are combined into `AutoEncryptionOptions`,
/// which are handed to `EncryptedClientBuilder::new` together with `client_options`.
///
/// # Errors
///
/// Returns an error if `KmsProviders::new` rejects `kms_providers`.
#[cfg(feature = "in-use-encryption")]
pub fn encrypted_builder(
    client_options: ClientOptions,
    key_vault_namespace: crate::Namespace,
    kms_providers: impl IntoIterator<
        Item = (
            mongocrypt::ctx::KmsProvider,
            crate::bson::Document,
            Option<options::TlsOptions>,
        ),
    >,
) -> Result<EncryptedClientBuilder> {
    let providers = csfle::options::KmsProviders::new(kms_providers)?;
    let auto_encryption =
        csfle::options::AutoEncryptionOptions::new(key_vault_namespace, providers);
    Ok(EncryptedClientBuilder::new(client_options, auto_encryption))
}
273
274 pub(crate) async fn should_auto_encrypt(&self) -> bool {
276 #[cfg(feature = "in-use-encryption")]
277 {
278 let csfle = self.inner.csfle.read().await;
279 match *csfle {
280 Some(ref csfle) => csfle
281 .opts()
282 .bypass_auto_encryption
283 .map(|b| !b)
284 .unwrap_or(true),
285 None => false,
286 }
287 }
288 #[cfg(not(feature = "in-use-encryption"))]
289 {
290 false
291 }
292 }
293
/// Test helper: whether the encryption executor reports a spawned mongocryptd process.
/// `false` when no encryption state is installed.
#[cfg(all(test, feature = "in-use-encryption"))]
pub(crate) async fn mongocryptd_spawned(&self) -> bool {
    let csfle = self.inner.csfle.read().await;
    matches!(csfle.as_ref(), Some(state) if state.exec().mongocryptd_spawned())
}

/// Test helper: whether the encryption executor holds a mongocryptd client.
/// `false` when no encryption state is installed.
#[cfg(all(test, feature = "in-use-encryption"))]
pub(crate) async fn has_mongocryptd_client(&self) -> bool {
    let csfle = self.inner.csfle.read().await;
    matches!(csfle.as_ref(), Some(state) if state.exec().has_mongocryptd_client())
}
313
314 fn test_command_event_channel(&self) -> Option<&options::TestEventSender> {
315 #[cfg(test)]
316 {
317 self.inner
318 .options
319 .test_options
320 .as_ref()
321 .and_then(|t| t.async_event_listener.as_ref())
322 }
323 #[cfg(not(test))]
324 {
325 None
326 }
327 }
328
329 pub(crate) async fn emit_command_event(&self, generate_event: impl FnOnce() -> CommandEvent) {
330 #[cfg(test)]
331 if self
332 .inner
333 .disable_command_events
334 .load(std::sync::atomic::Ordering::SeqCst)
335 {
336 return;
337 }
338 #[cfg(feature = "tracing-unstable")]
339 let tracing_emitter = if trace_or_log_enabled!(
340 target: COMMAND_TRACING_EVENT_TARGET,
341 TracingOrLogLevel::Debug
342 ) {
343 Some(CommandTracingEventEmitter::new(
344 self.inner.options.tracing_max_document_length_bytes,
345 self.inner.topology.id,
346 ))
347 } else {
348 None
349 };
350 let test_channel = self.test_command_event_channel();
351 let should_send = test_channel.is_some() || self.options().command_event_handler.is_some();
352 #[cfg(feature = "tracing-unstable")]
353 let should_send = should_send || tracing_emitter.is_some();
354 if !should_send {
355 return;
356 }
357
358 let event = generate_event();
359 if let Some(tx) = test_channel {
360 let (msg, ack) = crate::runtime::AcknowledgedMessage::package(event.clone());
361 let _ = tx.send(msg).await;
362 ack.wait_for_acknowledgment().await;
363 }
364 #[cfg(feature = "tracing-unstable")]
365 if let Some(ref tracing_emitter) = tracing_emitter {
366 tracing_emitter.handle(event.clone());
367 }
368 if let Some(handler) = &self.options().command_event_handler {
369 handler.handle(event);
370 }
371 }
372
373 pub fn selection_criteria(&self) -> Option<&SelectionCriteria> {
375 self.inner.options.selection_criteria.as_ref()
376 }
377
378 pub fn read_concern(&self) -> Option<&ReadConcern> {
380 self.inner.options.read_concern.as_ref()
381 }
382
383 pub fn write_concern(&self) -> Option<&WriteConcern> {
385 self.inner.options.write_concern.as_ref()
386 }
387
388 pub fn database(&self, name: &str) -> Database {
395 Database::new(self.clone(), name, None)
396 }
397
398 pub fn database_with_options(&self, name: &str, options: DatabaseOptions) -> Database {
405 Database::new(self.clone(), name, Some(options))
406 }
407
408 pub fn default_database(&self) -> Option<Database> {
413 self.inner
414 .options
415 .default_database
416 .as_ref()
417 .map(|db_name| self.database(db_name))
418 }
419
420 pub fn append_metadata(&self, driver_info: DriverInfo) -> Result<()> {
422 self.inner
423 .topology
424 .metadata
425 .write()
426 .unwrap()
427 .append(driver_info)
428 }
429
430 pub(crate) fn register_async_drop(&self) -> AsyncDropToken {
431 let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
432 crate::runtime::spawn(self.inner.shutdown.pending_drops.track_future(async move {
433 if let Ok(cleanup) = cleanup_rx.await {
435 cleanup.await;
436 }
437 }));
438 AsyncDropToken {
439 tx: Some(cleanup_tx),
440 }
441 }
442
443 pub(crate) async fn check_in_server_session(&self, session: ServerSession) {
446 let timeout = self.inner.topology.watcher().logical_session_timeout();
447 self.inner.session_pool.check_in(session, timeout).await;
448 }
449
/// Test helper: empties the server session pool.
#[cfg(test)]
pub(crate) async fn clear_session_pool(&self) {
    let pool = &self.inner.session_pool;
    pool.clear().await;
}

/// Test helper: whether a session with the given id is currently in the pool.
#[cfg(test)]
pub(crate) async fn is_session_checked_in(&self, id: &crate::bson::Document) -> bool {
    let pool = &self.inner.session_pool;
    pool.contains(id).await
}

/// Test helper: turns command event emission off (`true`) or back on (`false`).
#[cfg(test)]
pub(crate) fn disable_command_events(&self, disable: bool) {
    let flag = &self.inner.disable_command_events;
    flag.store(disable, Ordering::SeqCst);
}
466
/// Test helper: runs server selection with `criteria` and returns the selected address.
#[cfg(test)]
pub(crate) async fn test_select_server(
    &self,
    criteria: Option<&SelectionCriteria>,
) -> Result<ServerAddress> {
    let info = OpSelectionInfo::new("Test select server");
    let (selected, _) = self.select_server(criteria, None, info).await?;
    Ok(selected.address.clone())
}
479
480 async fn select_server(
483 &self,
484 criteria: Option<&SelectionCriteria>,
485 deprioritized: Option<&HashSet<ServerAddress>>,
486 op_info: OpSelectionInfo<'_>,
487 ) -> Result<(SelectedServer, SelectionCriteria)> {
488 let criteria =
489 criteria.unwrap_or(&SelectionCriteria::ReadPreference(ReadPreference::Primary));
490
491 let start_time = Instant::now();
492 let timeout = self
493 .inner
494 .options
495 .server_selection_timeout
496 .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
497
498 #[cfg(feature = "tracing-unstable")]
499 let event_emitter = ServerSelectionTracingEventEmitter::new(
500 self.inner.topology.id,
501 criteria,
502 op_info.name,
503 start_time,
504 timeout,
505 self.options().tracing_max_document_length_bytes,
506 );
507 #[cfg(feature = "tracing-unstable")]
508 event_emitter.emit_started_event(self.inner.topology.latest().description.clone());
509 #[cfg(feature = "tracing-unstable")]
511 let mut emitted_waiting_message = false;
512
513 let mut watcher = self.inner.topology.watcher().clone();
514 loop {
515 let state = watcher.observe_latest();
516 let override_slot;
517 let effective_criteria =
518 if let Some(oc) = (op_info.override_criteria)(criteria, &state.description) {
519 override_slot = oc;
520 &override_slot
521 } else {
522 criteria
523 };
524 let result = server_selection::attempt_to_select_server(
525 effective_criteria,
526 &state.description,
527 &state.servers(),
528 deprioritized,
529 );
530 match result {
531 Err(error) => {
532 #[cfg(feature = "tracing-unstable")]
533 event_emitter.emit_failed_event(&state.description, &error);
534
535 return Err(error);
536 }
537 Ok(result) => {
538 if let Some(server) = result {
539 #[cfg(feature = "tracing-unstable")]
540 event_emitter.emit_succeeded_event(&state.description, &server);
541
542 return Ok((server, effective_criteria.clone()));
543 } else {
544 #[cfg(feature = "tracing-unstable")]
545 if !emitted_waiting_message {
546 event_emitter.emit_waiting_event(&state.description);
547 emitted_waiting_message = true;
548 }
549
550 watcher.request_immediate_check();
551
552 let elapsed = start_time.elapsed();
553 let change_occurred = elapsed < timeout
554 && watcher
555 .wait_for_update(
556 timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO),
557 )
558 .await;
559 if !change_occurred {
560 let error: Error = ErrorKind::ServerSelection {
561 message: state
562 .description
563 .server_selection_timeout_error_message(criteria),
564 }
565 .into();
566
567 #[cfg(feature = "tracing-unstable")]
568 event_emitter.emit_failed_event(&state.description, &error);
569
570 return Err(error);
571 }
572 }
573 }
574 }
575 }
576 }
577
/// Test helper: the addresses of all servers in the latest topology state, rendered via
/// `Display`.
#[cfg(all(test, feature = "dns-resolver"))]
pub(crate) fn get_hosts(&self) -> Vec<String> {
    let state = self.inner.topology.latest();
    let servers = state.servers();
    servers.keys().map(ToString::to_string).collect()
}
588
589 #[cfg(test)]
590 pub(crate) async fn sync_workers(&self) {
591 self.inner.topology.updater().sync_workers().await;
592 }
593
594 #[cfg(test)]
595 pub(crate) fn topology_description(&self) -> crate::sdam::TopologyDescription {
596 self.inner.topology.latest().description.clone()
597 }
598
599 pub(crate) fn is_sharded(&self) -> bool {
600 self.inner.topology.latest().description.topology_type == TopologyType::Sharded
601 }
602
603 #[cfg(test)]
604 pub(crate) fn topology(&self) -> &Topology {
605 &self.inner.topology
606 }
607
/// Waits for the topology to report a primary and returns a clone of its description.
///
/// Returns `None` if no primary appears before the server selection timeout (or
/// `DEFAULT_SERVER_SELECTION_TIMEOUT` when unset), measured from the start of this call.
#[cfg(feature = "in-use-encryption")]
pub(crate) async fn primary_description(&self) -> Option<crate::sdam::ServerDescription> {
    let started = Instant::now();
    let budget = self
        .inner
        .options
        .server_selection_timeout
        .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
    let mut watcher = self.inner.topology.watcher().clone();
    loop {
        let latest = watcher.observe_latest();
        if let Some(primary) = latest.description.primary() {
            return Some(primary.clone());
        }
        let remaining = budget.saturating_sub(started.elapsed());
        let updated = watcher.wait_for_update(remaining).await;
        if !updated {
            return None;
        }
    }
}
630
631 pub(crate) fn weak(&self) -> WeakClient {
632 WeakClient {
633 inner: TrackingArc::downgrade(&self.inner),
634 }
635 }
636
637 #[cfg(feature = "in-use-encryption")]
638 pub(crate) async fn auto_encryption_opts(
639 &self,
640 ) -> Option<tokio::sync::RwLockReadGuard<'_, csfle::options::AutoEncryptionOptions>> {
641 tokio::sync::RwLockReadGuard::try_map(self.inner.csfle.read().await, |csfle| {
642 csfle.as_ref().map(|cs| cs.opts())
643 })
644 .ok()
645 }
646
647 pub(crate) fn options(&self) -> &ClientOptions {
648 &self.inner.options
649 }
650
651 pub(crate) async fn end_all_sessions(&self) {
653 const MAX_END_SESSIONS_BATCH_SIZE: usize = 10_000;
655
656 let mut watcher = self.inner.topology.watcher().clone();
657 let selection_criteria =
658 SelectionCriteria::from(ReadPreference::PrimaryPreferred { options: None });
659
660 let session_ids = self.inner.session_pool.get_session_ids().await;
661 for chunk in session_ids.chunks(MAX_END_SESSIONS_BATCH_SIZE) {
662 let state = watcher.observe_latest();
663 let Ok(Some(_)) = attempt_to_select_server(
664 &selection_criteria,
665 &state.description,
666 &state.servers(),
667 None,
668 ) else {
669 return;
672 };
673
674 let end_sessions = doc! {
675 "endSessions": chunk,
676 };
677 let _ = self
678 .database("admin")
679 .run_command(end_sessions)
680 .selection_criteria(selection_criteria.clone())
681 .await;
682 }
683 }
684
/// The OpenTelemetry tracer captured when the client was created.
#[cfg(feature = "opentelemetry")]
pub(crate) fn tracer(&self) -> &opentelemetry::global::BoxedTracer {
    let inner = &self.inner;
    &inner.tracer
}
689}
690
691#[derive(Clone, Debug)]
692pub(crate) struct WeakClient {
693 inner: crate::tracking_arc::Weak<ClientInner>,
694}
695
696impl WeakClient {
697 pub(crate) fn upgrade(&self) -> Option<Client> {
698 self.inner.upgrade().map(|inner| Client { inner })
699 }
700}
701
702#[derive_where(Debug)]
703pub(crate) struct AsyncDropToken {
704 #[derive_where(skip)]
705 tx: Option<tokio::sync::oneshot::Sender<BoxFuture<'static, ()>>>,
706}
707
708impl AsyncDropToken {
709 pub(crate) fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
710 if let Some(tx) = self.tx.take() {
711 let _ = tx.send(fut.boxed());
712 } else {
713 #[cfg(debug_assertions)]
714 panic!("exhausted AsyncDropToken");
715 }
716 }
717}
718
719impl Drop for Client {
720 fn drop(&mut self) {
721 if !self.inner.shutdown.executed.load(Ordering::SeqCst)
722 && !self.inner.dropped.load(Ordering::SeqCst)
723 && TrackingArc::strong_count(&self.inner) == 1
724 {
725 self.inner.dropped.store(true, Ordering::SeqCst);
732 let client = self.clone();
733 self.inner
734 .end_sessions_token
735 .lock()
736 .unwrap()
737 .spawn(async move {
738 client.end_all_sessions().await;
739 });
740 }
741 }
742}
743
744pub(crate) struct OpSelectionInfo<'a> {
746 #[allow(unused)]
747 name: &'a str,
748 override_criteria: OverrideCriteriaFn,
749}
750
751impl<'a, T: crate::operation::Operation> From<&'a T> for OpSelectionInfo<'a> {
752 fn from(op: &'a T) -> Self {
753 Self {
754 name: crate::bson_compat::cstr_to_str(op.name()),
755 override_criteria: op.override_criteria(),
756 }
757 }
758}
759
760impl<'a> OpSelectionInfo<'a> {
761 fn new(name: &'a str) -> Self {
762 Self {
763 name,
764 override_criteria: |_, _| None,
765 }
766 }
767}
768
/// Upper bound on the retry token bucket; a newly created bucket also starts at this value.
const MAX_BUCKET_CAPACITY: u16 = 10_000;
/// Tokens deposited on every success and on every retry error.
const RETRY_TOKEN_RETURN_RATE: u16 = 1;
/// Tokens withdrawn per successful call to `consume_from_token_bucket`.
const RETRY_TOKEN_CONSUME_RATE: u16 = 10;
774impl Client {
775 pub(crate) async fn consume_from_token_bucket(&self) -> bool {
776 let Some(ref bucket) = self.inner.token_bucket else {
777 return true;
778 };
779 let mut tokens = bucket.lock().await;
780 if *tokens >= RETRY_TOKEN_CONSUME_RATE {
781 *tokens -= RETRY_TOKEN_CONSUME_RATE;
782 true
783 } else {
784 false
785 }
786 }
787
788 pub(crate) async fn deposit_success_in_token_bucket(&self, is_retry: bool) {
789 let Some(ref bucket) = self.inner.token_bucket else {
790 return;
791 };
792 let mut deposit = RETRY_TOKEN_RETURN_RATE;
793 if is_retry {
794 deposit += 10;
795 }
796 let mut tokens = bucket.lock().await;
797 *tokens = std::cmp::min(*tokens + deposit, MAX_BUCKET_CAPACITY);
798 }
799
800 pub(crate) async fn deposit_retry_error_in_token_bucket(&self) {
801 let Some(ref bucket) = self.inner.token_bucket else {
802 return;
803 };
804 let mut tokens = bucket.lock().await;
805 *tokens = std::cmp::min(*tokens + RETRY_TOKEN_RETURN_RATE, MAX_BUCKET_CAPACITY);
806 }
807
808 #[cfg(test)]
809 #[expect(dead_code)]
810 pub(crate) async fn get_num_tokens_in_bucket(&self) -> Option<u16> {
811 let bucket = self.inner.token_bucket.as_ref()?;
812 Some(*bucket.lock().await)
813 }
814
815 #[cfg(test)]
816 #[expect(dead_code)]
817 pub(crate) async fn set_num_tokens_in_bucket(&self, tokens: u16) {
818 let Some(ref bucket) = self.inner.token_bucket else {
819 return;
820 };
821 *bucket.lock().await = tokens;
822 }
823}