1pub mod action;
2pub mod auth;
3#[cfg(feature = "in-use-encryption")]
4pub(crate) mod csfle;
5mod executor;
6pub mod options;
7pub mod session;
8
9use std::{
10 sync::{
11 atomic::{AtomicBool, Ordering},
12 Mutex as SyncMutex,
13 },
14 time::{Duration, Instant},
15};
16
17#[cfg(feature = "in-use-encryption")]
18pub use self::csfle::client_builder::*;
19use derive_where::derive_where;
20use futures_core::Future;
21use futures_util::FutureExt;
22
23#[cfg(feature = "tracing-unstable")]
24use crate::trace::{
25 command::CommandTracingEventEmitter,
26 server_selection::ServerSelectionTracingEventEmitter,
27 trace_or_log_enabled,
28 TracingOrLogLevel,
29 COMMAND_TRACING_EVENT_TARGET,
30};
31use crate::{
32 bson::doc,
33 concern::{ReadConcern, WriteConcern},
34 db::Database,
35 error::{Error, ErrorKind, Result},
36 event::command::CommandEvent,
37 id_set::IdSet,
38 operation::OverrideCriteriaFn,
39 options::{
40 ClientOptions,
41 DatabaseOptions,
42 DriverInfo,
43 ReadPreference,
44 SelectionCriteria,
45 ServerAddress,
46 },
47 sdam::{
48 server_selection::{self, attempt_to_select_server},
49 SelectedServer,
50 Topology,
51 },
52 tracking_arc::TrackingArc,
53 BoxFuture,
54 ClientSession,
55};
56
57pub(crate) use executor::{HELLO_COMMAND_NAMES, REDACTED_COMMANDS};
58pub(crate) use session::{ClusterTime, SESSIONS_UNSUPPORTED_COMMANDS};
59
60use session::{ServerSession, ServerSessionPool};
61
/// Server selection timeout used when `ClientOptions::server_selection_timeout` is not set.
const DEFAULT_SERVER_SELECTION_TIMEOUT: Duration = Duration::from_secs(30);
63
/// Handle to the driver's shared client state.
///
/// Cloning a `Client` is cheap: every clone refers to the same shared state through a
/// `TrackingArc`, so the topology, session pool and options are shared between clones. When
/// the last clone is dropped, the `Drop` impl schedules ending of pooled server sessions unless
/// a shutdown has already been executed.
#[derive(Debug, Clone)]
pub struct Client {
    inner: TrackingArc<ClientInner>,
}
125
// Compile-time assertion that `Client` is `Send` and `Sync`. This unnamed constant can never be
// referenced, so the closure never runs; it only has to type-check, which is why `todo!()` can
// stand in for a value. The `allow` silences the lints this never-executed body would trigger.
// NOTE(review): `super::Client` presumably resolves to the public re-export of this type —
// confirm.
#[allow(dead_code, unreachable_code, clippy::diverging_sub_expression)]
const _: fn() = || {
    fn assert_send<T: Send>(_t: T) {}
    fn assert_sync<T: Sync>(_t: T) {}

    let _c: super::Client = todo!();
    assert_send(_c);
    assert_sync(_c);
};
135
/// State shared by every clone of a `Client`.
#[derive(Debug)]
struct ClientInner {
    /// Topology consulted for server selection; also holds the metadata that
    /// `Client::append_metadata` appends to.
    topology: Topology,
    /// Options the client was created with (validated in `Client::with_options`).
    options: ClientOptions,
    /// Pool that server sessions are checked back into. Its session ids are sent in
    /// `endSessions` commands by `Client::end_all_sessions`.
    session_pool: ServerSessionPool,
    /// Shutdown bookkeeping: pending async-drop tasks and whether a shutdown has run.
    shutdown: Shutdown,
    /// Set by the `Drop` impl once end-of-life session cleanup has been scheduled. The `Drop`
    /// impl checks it to avoid scheduling the cleanup again.
    dropped: AtomicBool,
    /// Token used to hand the end-of-life cleanup future to the background task spawned in
    /// `Client::with_options`.
    end_sessions_token: std::sync::Mutex<AsyncDropToken>,
    /// Client-side encryption state. It starts as the default (`None`) in `with_options`, and
    /// `should_auto_encrypt` reports `false` while it is `None`.
    #[cfg(feature = "in-use-encryption")]
    csfle: tokio::sync::RwLock<Option<csfle::ClientState>>,
    /// Tracer obtained from the client options at construction.
    #[cfg(feature = "opentelemetry")]
    tracer: opentelemetry::global::BoxedTracer,
    /// Test-only switch: while set, `Client::emit_command_event` emits nothing.
    #[cfg(test)]
    disable_command_events: AtomicBool,
}
151
/// Bookkeeping related to shutting a client down.
#[derive(Debug)]
struct Shutdown {
    /// Join handles of the background tasks created by `Client::register_async_drop`.
    ///
    /// Each task removes its own entry once its cleanup has finished or been abandoned, provided
    /// the client is still alive at that point.
    pending_drops: SyncMutex<IdSet<crate::runtime::AsyncJoinHandle<()>>>,
    /// Whether a shutdown has been executed. While set, dropping the last `Client` does not
    /// schedule `end_all_sessions`.
    ///
    /// NOTE(review): this flag is not set anywhere in this file — confirm where it is set.
    executed: AtomicBool,
}
157
158impl Client {
159 pub async fn with_uri_str(uri: impl AsRef<str>) -> Result<Self> {
165 let options = ClientOptions::parse(uri.as_ref()).await?;
166
167 Client::with_options(options)
168 }
169
170 pub fn with_options(options: ClientOptions) -> Result<Self> {
172 options.validate()?;
173
174 let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
176 crate::runtime::spawn(async move {
177 if let Ok(cleanup) = cleanup_rx.await {
179 cleanup.await;
180 }
181 });
182 let end_sessions_token = std::sync::Mutex::new(AsyncDropToken {
183 tx: Some(cleanup_tx),
184 });
185
186 #[cfg(feature = "opentelemetry")]
187 let tracer = options.tracer();
188
189 let inner = TrackingArc::new(ClientInner {
190 topology: Topology::new(options.clone())?,
191 session_pool: ServerSessionPool::new(),
192 options,
193 shutdown: Shutdown {
194 pending_drops: SyncMutex::new(IdSet::new()),
195 executed: AtomicBool::new(false),
196 },
197 dropped: AtomicBool::new(false),
198 end_sessions_token,
199 #[cfg(feature = "in-use-encryption")]
200 csfle: Default::default(),
201 #[cfg(feature = "opentelemetry")]
202 tracer,
203 #[cfg(test)]
204 disable_command_events: AtomicBool::new(false),
205 });
206 Ok(Self { inner })
207 }
208
209 #[cfg(feature = "in-use-encryption")]
231 pub fn encrypted_builder(
232 client_options: ClientOptions,
233 key_vault_namespace: crate::Namespace,
234 kms_providers: impl IntoIterator<
235 Item = (
236 mongocrypt::ctx::KmsProvider,
237 crate::bson::Document,
238 Option<options::TlsOptions>,
239 ),
240 >,
241 ) -> Result<EncryptedClientBuilder> {
242 Ok(EncryptedClientBuilder::new(
243 client_options,
244 csfle::options::AutoEncryptionOptions::new(
245 key_vault_namespace,
246 csfle::options::KmsProviders::new(kms_providers)?,
247 ),
248 ))
249 }
250
251 pub(crate) async fn should_auto_encrypt(&self) -> bool {
253 #[cfg(feature = "in-use-encryption")]
254 {
255 let csfle = self.inner.csfle.read().await;
256 match *csfle {
257 Some(ref csfle) => csfle
258 .opts()
259 .bypass_auto_encryption
260 .map(|b| !b)
261 .unwrap_or(true),
262 None => false,
263 }
264 }
265 #[cfg(not(feature = "in-use-encryption"))]
266 {
267 false
268 }
269 }
270
271 #[cfg(all(test, feature = "in-use-encryption"))]
272 pub(crate) async fn mongocryptd_spawned(&self) -> bool {
273 self.inner
274 .csfle
275 .read()
276 .await
277 .as_ref()
278 .is_some_and(|cs| cs.exec().mongocryptd_spawned())
279 }
280
281 #[cfg(all(test, feature = "in-use-encryption"))]
282 pub(crate) async fn has_mongocryptd_client(&self) -> bool {
283 self.inner
284 .csfle
285 .read()
286 .await
287 .as_ref()
288 .is_some_and(|cs| cs.exec().has_mongocryptd_client())
289 }
290
291 fn test_command_event_channel(&self) -> Option<&options::TestEventSender> {
292 #[cfg(test)]
293 {
294 self.inner
295 .options
296 .test_options
297 .as_ref()
298 .and_then(|t| t.async_event_listener.as_ref())
299 }
300 #[cfg(not(test))]
301 {
302 None
303 }
304 }
305
306 pub(crate) async fn emit_command_event(&self, generate_event: impl FnOnce() -> CommandEvent) {
307 #[cfg(test)]
308 if self
309 .inner
310 .disable_command_events
311 .load(std::sync::atomic::Ordering::SeqCst)
312 {
313 return;
314 }
315 #[cfg(feature = "tracing-unstable")]
316 let tracing_emitter = if trace_or_log_enabled!(
317 target: COMMAND_TRACING_EVENT_TARGET,
318 TracingOrLogLevel::Debug
319 ) {
320 Some(CommandTracingEventEmitter::new(
321 self.inner.options.tracing_max_document_length_bytes,
322 self.inner.topology.id,
323 ))
324 } else {
325 None
326 };
327 let test_channel = self.test_command_event_channel();
328 let should_send = test_channel.is_some() || self.options().command_event_handler.is_some();
329 #[cfg(feature = "tracing-unstable")]
330 let should_send = should_send || tracing_emitter.is_some();
331 if !should_send {
332 return;
333 }
334
335 let event = generate_event();
336 if let Some(tx) = test_channel {
337 let (msg, ack) = crate::runtime::AcknowledgedMessage::package(event.clone());
338 let _ = tx.send(msg).await;
339 ack.wait_for_acknowledgment().await;
340 }
341 #[cfg(feature = "tracing-unstable")]
342 if let Some(ref tracing_emitter) = tracing_emitter {
343 tracing_emitter.handle(event.clone());
344 }
345 if let Some(handler) = &self.options().command_event_handler {
346 handler.handle(event);
347 }
348 }
349
350 pub fn selection_criteria(&self) -> Option<&SelectionCriteria> {
352 self.inner.options.selection_criteria.as_ref()
353 }
354
355 pub fn read_concern(&self) -> Option<&ReadConcern> {
357 self.inner.options.read_concern.as_ref()
358 }
359
360 pub fn write_concern(&self) -> Option<&WriteConcern> {
362 self.inner.options.write_concern.as_ref()
363 }
364
365 pub fn database(&self, name: &str) -> Database {
372 Database::new(self.clone(), name, None)
373 }
374
375 pub fn database_with_options(&self, name: &str, options: DatabaseOptions) -> Database {
382 Database::new(self.clone(), name, Some(options))
383 }
384
385 pub fn default_database(&self) -> Option<Database> {
390 self.inner
391 .options
392 .default_database
393 .as_ref()
394 .map(|db_name| self.database(db_name))
395 }
396
397 pub fn append_metadata(&self, driver_info: DriverInfo) -> Result<()> {
399 self.inner
400 .topology
401 .metadata
402 .write()
403 .unwrap()
404 .append(driver_info)
405 }
406
407 pub(crate) fn register_async_drop(&self) -> AsyncDropToken {
408 let (cleanup_tx, cleanup_rx) = tokio::sync::oneshot::channel::<BoxFuture<'static, ()>>();
409 let (id_tx, id_rx) = tokio::sync::oneshot::channel::<crate::id_set::Id>();
410 let weak = self.weak();
411 let handle = crate::runtime::spawn(async move {
412 let id = id_rx.await.unwrap();
415 if let Ok(cleanup) = cleanup_rx.await {
417 cleanup.await;
418 }
419 if let Some(client) = weak.upgrade() {
420 client
421 .inner
422 .shutdown
423 .pending_drops
424 .lock()
425 .unwrap()
426 .remove(&id);
427 }
428 });
429 let id = self
430 .inner
431 .shutdown
432 .pending_drops
433 .lock()
434 .unwrap()
435 .insert(handle);
436 let _ = id_tx.send(id);
437 AsyncDropToken {
438 tx: Some(cleanup_tx),
439 }
440 }
441
442 pub(crate) async fn check_in_server_session(&self, session: ServerSession) {
445 let timeout = self.inner.topology.watcher().logical_session_timeout();
446 self.inner.session_pool.check_in(session, timeout).await;
447 }
448
449 #[cfg(test)]
450 pub(crate) async fn clear_session_pool(&self) {
451 self.inner.session_pool.clear().await;
452 }
453
454 #[cfg(test)]
455 pub(crate) async fn is_session_checked_in(&self, id: &crate::bson::Document) -> bool {
456 self.inner.session_pool.contains(id).await
457 }
458
459 #[cfg(test)]
460 pub(crate) fn disable_command_events(&self, disable: bool) {
461 self.inner
462 .disable_command_events
463 .store(disable, std::sync::atomic::Ordering::SeqCst);
464 }
465
466 #[cfg(test)]
469 pub(crate) async fn test_select_server(
470 &self,
471 criteria: Option<&SelectionCriteria>,
472 ) -> Result<ServerAddress> {
473 let (server, _) = self
474 .select_server(criteria, "Test select server", None, |_, _| None)
475 .await?;
476 Ok(server.address.clone())
477 }
478
479 async fn select_server(
482 &self,
483 criteria: Option<&SelectionCriteria>,
484 #[allow(unused_variables)] operation_name: &str,
486 deprioritized: Option<&ServerAddress>,
487 override_criteria: OverrideCriteriaFn,
488 ) -> Result<(SelectedServer, SelectionCriteria)> {
489 let criteria =
490 criteria.unwrap_or(&SelectionCriteria::ReadPreference(ReadPreference::Primary));
491
492 let start_time = Instant::now();
493 let timeout = self
494 .inner
495 .options
496 .server_selection_timeout
497 .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
498
499 #[cfg(feature = "tracing-unstable")]
500 let event_emitter = ServerSelectionTracingEventEmitter::new(
501 self.inner.topology.id,
502 criteria,
503 operation_name,
504 start_time,
505 timeout,
506 self.options().tracing_max_document_length_bytes,
507 );
508 #[cfg(feature = "tracing-unstable")]
509 event_emitter.emit_started_event(self.inner.topology.latest().description.clone());
510 #[cfg(feature = "tracing-unstable")]
512 let mut emitted_waiting_message = false;
513
514 let mut watcher = self.inner.topology.watcher().clone();
515 loop {
516 let state = watcher.observe_latest();
517 let override_slot;
518 let effective_criteria =
519 if let Some(oc) = override_criteria(criteria, &state.description) {
520 override_slot = oc;
521 &override_slot
522 } else {
523 criteria
524 };
525 let result = server_selection::attempt_to_select_server(
526 effective_criteria,
527 &state.description,
528 &state.servers(),
529 deprioritized,
530 );
531 match result {
532 Err(error) => {
533 #[cfg(feature = "tracing-unstable")]
534 event_emitter.emit_failed_event(&state.description, &error);
535
536 return Err(error);
537 }
538 Ok(result) => {
539 if let Some(server) = result {
540 #[cfg(feature = "tracing-unstable")]
541 event_emitter.emit_succeeded_event(&state.description, &server);
542
543 return Ok((server, effective_criteria.clone()));
544 } else {
545 #[cfg(feature = "tracing-unstable")]
546 if !emitted_waiting_message {
547 event_emitter.emit_waiting_event(&state.description);
548 emitted_waiting_message = true;
549 }
550
551 watcher.request_immediate_check();
552
553 let elapsed = start_time.elapsed();
554 let change_occurred = elapsed < timeout
555 && watcher
556 .wait_for_update(
557 timeout.checked_sub(elapsed).unwrap_or(Duration::ZERO),
558 )
559 .await;
560 if !change_occurred {
561 let error: Error = ErrorKind::ServerSelection {
562 message: state
563 .description
564 .server_selection_timeout_error_message(criteria),
565 }
566 .into();
567
568 #[cfg(feature = "tracing-unstable")]
569 event_emitter.emit_failed_event(&state.description, &error);
570
571 return Err(error);
572 }
573 }
574 }
575 }
576 }
577 }
578
579 #[cfg(all(test, feature = "dns-resolver"))]
580 pub(crate) fn get_hosts(&self) -> Vec<String> {
581 let state = self.inner.topology.latest();
582
583 state
584 .servers()
585 .keys()
586 .map(|stream_address| format!("{stream_address}"))
587 .collect()
588 }
589
590 #[cfg(test)]
591 pub(crate) async fn sync_workers(&self) {
592 self.inner.topology.updater().sync_workers().await;
593 }
594
595 #[cfg(test)]
596 pub(crate) fn topology_description(&self) -> crate::sdam::TopologyDescription {
597 self.inner.topology.latest().description.clone()
598 }
599
600 #[cfg(test)]
601 pub(crate) fn topology(&self) -> &Topology {
602 &self.inner.topology
603 }
604
605 #[cfg(feature = "in-use-encryption")]
606 pub(crate) async fn primary_description(&self) -> Option<crate::sdam::ServerDescription> {
607 let start_time = Instant::now();
608 let timeout = self
609 .inner
610 .options
611 .server_selection_timeout
612 .unwrap_or(DEFAULT_SERVER_SELECTION_TIMEOUT);
613 let mut watcher = self.inner.topology.watcher().clone();
614 loop {
615 let topology = watcher.observe_latest();
616 if let Some(desc) = topology.description.primary() {
617 return Some(desc.clone());
618 }
619 let remaining = timeout
620 .checked_sub(start_time.elapsed())
621 .unwrap_or(Duration::ZERO);
622 if !watcher.wait_for_update(remaining).await {
623 return None;
624 }
625 }
626 }
627
628 pub(crate) fn weak(&self) -> WeakClient {
629 WeakClient {
630 inner: TrackingArc::downgrade(&self.inner),
631 }
632 }
633
634 #[cfg(feature = "in-use-encryption")]
635 pub(crate) async fn auto_encryption_opts(
636 &self,
637 ) -> Option<tokio::sync::RwLockReadGuard<'_, csfle::options::AutoEncryptionOptions>> {
638 tokio::sync::RwLockReadGuard::try_map(self.inner.csfle.read().await, |csfle| {
639 csfle.as_ref().map(|cs| cs.opts())
640 })
641 .ok()
642 }
643
644 pub(crate) fn options(&self) -> &ClientOptions {
645 &self.inner.options
646 }
647
648 pub(crate) async fn end_all_sessions(&self) {
650 const MAX_END_SESSIONS_BATCH_SIZE: usize = 10_000;
652
653 let mut watcher = self.inner.topology.watcher().clone();
654 let selection_criteria =
655 SelectionCriteria::from(ReadPreference::PrimaryPreferred { options: None });
656
657 let session_ids = self.inner.session_pool.get_session_ids().await;
658 for chunk in session_ids.chunks(MAX_END_SESSIONS_BATCH_SIZE) {
659 let state = watcher.observe_latest();
660 let Ok(Some(_)) = attempt_to_select_server(
661 &selection_criteria,
662 &state.description,
663 &state.servers(),
664 None,
665 ) else {
666 return;
669 };
670
671 let end_sessions = doc! {
672 "endSessions": chunk,
673 };
674 let _ = self
675 .database("admin")
676 .run_command(end_sessions)
677 .selection_criteria(selection_criteria.clone())
678 .await;
679 }
680 }
681
682 #[cfg(feature = "opentelemetry")]
683 pub(crate) fn tracer(&self) -> &opentelemetry::global::BoxedTracer {
684 &self.inner.tracer
685 }
686}
687
/// Weak handle to a client's shared state, created by `Client::weak` via
/// `TrackingArc::downgrade`.
///
/// `upgrade` returns a full `Client` when the shared state can still be reached, and `None`
/// otherwise. NOTE(review): presumably mirrors `std::sync::Weak` semantics (does not keep the
/// state alive) — confirm against `tracking_arc`.
#[derive(Clone, Debug)]
pub(crate) struct WeakClient {
    inner: crate::tracking_arc::Weak<ClientInner>,
}
692
693impl WeakClient {
694 pub(crate) fn upgrade(&self) -> Option<Client> {
695 self.inner.upgrade().map(|inner| Client { inner })
696 }
697}
698
/// One-shot handle for handing a cleanup future to a background task.
///
/// The sender is `None` once it has been used by `spawn` or moved out by `take`. The sender is
/// skipped in the `Debug` output.
#[derive_where(Debug)]
pub(crate) struct AsyncDropToken {
    #[derive_where(skip)]
    tx: Option<tokio::sync::oneshot::Sender<BoxFuture<'static, ()>>>,
}
704
705impl AsyncDropToken {
706 pub(crate) fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
707 if let Some(tx) = self.tx.take() {
708 let _ = tx.send(fut.boxed());
709 } else {
710 #[cfg(debug_assertions)]
711 panic!("exhausted AsyncDropToken");
712 }
713 }
714
715 pub(crate) fn take(&mut self) -> Self {
716 Self { tx: self.tx.take() }
717 }
718}
719
720impl Drop for Client {
721 fn drop(&mut self) {
722 if !self.inner.shutdown.executed.load(Ordering::SeqCst)
723 && !self.inner.dropped.load(Ordering::SeqCst)
724 && TrackingArc::strong_count(&self.inner) == 1
725 {
726 self.inner.dropped.store(true, Ordering::SeqCst);
733 let client = self.clone();
734 self.inner
735 .end_sessions_token
736 .lock()
737 .unwrap()
738 .spawn(async move {
739 client.end_all_sessions().await;
740 });
741 }
742 }
743}