mod cluster_time;
mod pool;
#[cfg(test)]
mod test;
use std::{
collections::HashSet,
sync::Arc,
time::{Duration, Instant},
};
use lazy_static::lazy_static;
use uuid::Uuid;
use crate::{
bson::{doc, spec::BinarySubtype, Binary, Bson, Document, Timestamp},
cmap::conn::PinnedConnectionHandle,
error::{ErrorKind, Result},
operation::{AbortTransaction, CommitTransaction, Operation},
options::{SessionOptions, TransactionOptions},
sdam::{ServerInfo, TransactionSupportStatus},
selection_criteria::SelectionCriteria,
Client,
};
pub use cluster_time::ClusterTime;
pub(super) use pool::ServerSessionPool;
use super::{options::ServerAddress, AsyncDropToken};
lazy_static! {
    /// Lowercase command names grouped under "sessions unsupported".
    ///
    /// NOTE(review): presumably consulted before attaching a session to an outgoing
    /// command — confirm at the use sites, which are outside this file.
    pub(crate) static ref SESSIONS_UNSUPPORTED_COMMANDS: HashSet<&'static str> =
        ["killcursors", "parallelcollectionscan"]
            .iter()
            .copied()
            .collect();
}
/// A client-side session: pairs a pooled `ServerSession` with the owning `Client`, the
/// session options and the state of the current transaction.
///
/// Dropping a session hands its server session back to the client; if a transaction is
/// still in progress at that point, an abort is attempted first (see the `Drop` impl).
#[derive(Debug)]
pub struct ClientSession {
/// Cluster time tracked for this session; only moved forward by `advance_cluster_time`.
cluster_time: Option<ClusterTime>,
/// The server session checked out of the client's session pool in `new`.
server_session: ServerSession,
/// The client this session was created from.
client: Client,
/// Flag supplied to `new`; when no explicit causal-consistency option is set, causal
/// consistency defaults to the negation of this flag.
is_implicit: bool,
/// Options supplied when the session was created, if any.
options: Option<SessionOptions>,
/// Token used by `Drop` to spawn the asynchronous cleanup work.
drop_token: AsyncDropToken,
/// State, options, pin and recovery token of the current (or most recent) transaction.
pub(crate) transaction: Transaction,
/// NOTE(review): presumably the read timestamp of a snapshot session; it is only stored
/// and copied in this file — confirm against the code that sets it.
pub(crate) snapshot_time: Option<Timestamp>,
/// Operation time tracked for this session; only moved forward by `advance_operation_time`.
pub(crate) operation_time: Option<Timestamp>,
/// Test-only override for the overall time limit used by `with_transaction`.
#[cfg(test)]
pub(crate) convenient_transaction_timeout: Option<Duration>,
}
/// Bookkeeping for the transaction associated with a `ClientSession`.
#[derive(Debug)]
pub(crate) struct Transaction {
/// Where the transaction currently is in its lifecycle.
pub(crate) state: TransactionState,
/// Options the transaction was started with; cleared by `abort` and `reset`.
pub(crate) options: Option<TransactionOptions>,
/// What the transaction is pinned to (a mongos or a connection), if anything.
pub(crate) pinned: Option<TransactionPin>,
/// Recovery token document; cleared by `start` and `reset`.
/// NOTE(review): presumably populated from server responses elsewhere — not set in this file.
pub(crate) recovery_token: Option<Document>,
}
impl Transaction {
    /// Enters the `Starting` state with the given options and discards any recovery token
    /// left over from a previous transaction. The pin is left untouched.
    pub(crate) fn start(&mut self, options: Option<TransactionOptions>) {
        self.state = TransactionState::Starting;
        self.options = options;
        self.recovery_token = None;
    }

    /// Enters the `Committed` state, recording whether a commit was actually sent.
    pub(crate) fn commit(&mut self, data_committed: bool) {
        self.state = TransactionState::Committed { data_committed };
    }

    /// Enters the `Aborted` state and drops the options and the pin. The recovery token
    /// is left untouched.
    pub(crate) fn abort(&mut self) {
        self.state = TransactionState::Aborted;
        self.options = None;
        self.pinned = None;
    }

    /// Returns the transaction to its pristine state: no state, options, pin or token.
    pub(crate) fn reset(&mut self) {
        // `Default` yields exactly the all-empty transaction this method has to produce.
        *self = Self::default();
    }

    /// Returns the selection criteria of the pinned mongos, if the transaction is pinned to one.
    pub(crate) fn pinned_mongos(&self) -> Option<&SelectionCriteria> {
        if let Some(TransactionPin::Mongos(criteria)) = &self.pinned {
            Some(criteria)
        } else {
            None
        }
    }

    /// Returns the pinned connection handle, if the transaction is pinned to a connection.
    pub(crate) fn pinned_connection(&self) -> Option<&PinnedConnectionHandle> {
        if let Some(TransactionPin::Connection(handle)) = &self.pinned {
            Some(handle)
        } else {
            None
        }
    }

    /// Moves the options, pin and recovery token out of `self` into a new `Transaction`,
    /// copying the state. `self` keeps its state but is left without the other three.
    fn take(&mut self) -> Self {
        let state = self.state.clone();
        let options = self.options.take();
        let pinned = self.pinned.take();
        let recovery_token = self.recovery_token.take();
        Transaction {
            state,
            options,
            pinned,
            recovery_token,
        }
    }
}
impl Default for Transaction {
fn default() -> Self {
Self {
state: TransactionState::None,
options: None,
pinned: None,
recovery_token: None,
}
}
}
/// Lifecycle states of a session's transaction.
#[derive(Clone, Debug, PartialEq)]
pub(crate) enum TransactionState {
/// No transaction has been started (also the state after `Transaction::reset`).
None,
/// `start_transaction` succeeded; committing or aborting from this state sends nothing.
Starting,
/// NOTE(review): entered outside this file, presumably once the first operation of the
/// transaction has run; committing or aborting from this state executes an operation.
InProgress,
/// `commit_transaction` was called.
Committed {
/// `true` when the commit happened from `InProgress` (a commit operation was executed);
/// `false` when it happened from `Starting` (nothing was sent).
data_committed: bool,
},
/// `abort_transaction` was called.
Aborted,
}
/// What a transaction is pinned to.
#[derive(Debug)]
pub(crate) enum TransactionPin {
/// Pinned to a server via selection criteria (`pin_mongos` builds a predicate that
/// matches a single server address).
Mongos(SelectionCriteria),
/// Pinned to a specific connection through its handle.
Connection(PinnedConnectionHandle),
}
impl ClientSession {
pub(crate) async fn new(
client: Client,
options: Option<SessionOptions>,
is_implicit: bool,
) -> Self {
let timeout = client.inner.topology.logical_session_timeout();
let server_session = client.inner.session_pool.check_out(timeout).await;
Self {
drop_token: client.register_async_drop(),
client,
server_session,
cluster_time: None,
is_implicit,
options,
transaction: Default::default(),
snapshot_time: None,
operation_time: None,
#[cfg(test)]
convenient_transaction_timeout: None,
}
}
pub fn client(&self) -> Client {
self.client.clone()
}
pub fn id(&self) -> &Document {
&self.server_session.id
}
/// Returns the `is_implicit` flag this session was created with.
pub(crate) fn is_implicit(&self) -> bool {
self.is_implicit
}
/// Returns `true` while the transaction is in the `Starting` or `InProgress` state.
pub(crate) fn in_transaction(&self) -> bool {
    matches!(
        self.transaction.state,
        TransactionState::Starting | TransactionState::InProgress
    )
}
pub fn cluster_time(&self) -> Option<&ClusterTime> {
self.cluster_time.as_ref()
}
pub fn options(&self) -> Option<&SessionOptions> {
self.options.as_ref()
}
/// Sets the session's cluster time to `to` if no cluster time is tracked yet or the
/// tracked one is strictly less than `to`; otherwise does nothing.
pub fn advance_cluster_time(&mut self, to: &ClusterTime) {
    let should_advance = match self.cluster_time() {
        Some(current) => current < to,
        None => true,
    };
    if should_advance {
        self.cluster_time = Some(to.clone());
    }
}
/// Sets the session's operation time to `ts` if no operation time is tracked yet or the
/// tracked one is strictly less than `ts`; otherwise leaves it unchanged.
pub fn advance_operation_time(&mut self, ts: Timestamp) {
    let is_newer = self
        .operation_time
        .map_or(true, |current_op_time| current_op_time < ts);
    if is_newer {
        self.operation_time = Some(ts);
    }
}
/// Returns the operation time currently tracked by this session, if any.
pub fn operation_time(&self) -> Option<Timestamp> {
self.operation_time
}
/// Returns whether causal consistency applies to this session.
///
/// An explicit `causal_consistency` session option wins; without one, the value is the
/// negation of the `is_implicit` flag.
pub(crate) fn causal_consistency(&self) -> bool {
    match self.options().and_then(|opts| opts.causal_consistency) {
        Some(explicit) => explicit,
        None => !self.is_implicit(),
    }
}
/// Sets the `dirty` flag on the underlying server session.
pub(crate) fn mark_dirty(&mut self) {
    let server_session = &mut self.server_session;
    server_session.dirty = true;
}
/// Records the current instant as the server session's last-use time.
pub(crate) fn update_last_use(&mut self) {
    let now = Instant::now();
    self.server_session.last_use = now;
}
/// Returns the server session's current transaction number.
pub(crate) fn txn_number(&self) -> i64 {
self.server_session.txn_number
}
/// Increments the server session's transaction number by one.
pub(crate) fn increment_txn_number(&mut self) {
self.server_session.txn_number += 1;
}
/// Increments the server session's transaction number and returns the new
/// (post-increment) value.
pub(crate) fn get_and_increment_txn_number(&mut self) -> i64 {
    self.increment_txn_number();
    self.txn_number()
}
/// Pins the transaction to the server at `address` by storing selection criteria whose
/// predicate accepts only a server with exactly that address.
pub(crate) fn pin_mongos(&mut self, address: ServerAddress) {
    let criteria = SelectionCriteria::Predicate(Arc::new(move |server_info: &ServerInfo| {
        *server_info.address() == address
    }));
    self.transaction.pinned = Some(TransactionPin::Mongos(criteria));
}
/// Pins the transaction to the connection behind `handle`, replacing any existing pin.
pub(crate) fn pin_connection(&mut self, handle: PinnedConnectionHandle) {
    let pin = TransactionPin::Connection(handle);
    self.transaction.pinned = Some(pin);
}
/// Removes the transaction's pin (mongos or connection), if there is one.
pub(crate) fn unpin(&mut self) {
self.transaction.pinned = None;
}
/// Test-only accessor for the server session's `dirty` flag.
#[cfg(test)]
pub(crate) fn is_dirty(&self) -> bool {
self.server_session.dirty
}
/// Starts a new transaction on this session.
///
/// The given options are completed from the session's default transaction options
/// (`read_concern`, `write_concern`, `selection_criteria`, `max_commit_time`) and then
/// resolved against the client for `read_concern`, `write_concern` and `selection_criteria`.
/// On success the transaction number is incremented and the transaction enters the
/// `Starting` state.
///
/// # Errors
///
/// Returns an `ErrorKind::Transaction` error when:
/// - the session's `snapshot` option is `Some(true)`;
/// - a transaction is already starting or in progress;
/// - the deployment's transaction support status is anything other than `Supported`;
/// - the effective write concern is unacknowledged.
///
/// An error from querying the transaction support status is propagated unchanged.
pub async fn start_transaction(
    &mut self,
    options: impl Into<Option<TransactionOptions>>,
) -> Result<()> {
    let is_snapshot_session = matches!(
        self.options.as_ref().and_then(|o| o.snapshot),
        Some(true)
    );
    if is_snapshot_session {
        return Err(ErrorKind::Transaction {
            message: "Transactions are not supported in snapshot sessions".into(),
        }
        .into());
    }

    if self.in_transaction() {
        return Err(ErrorKind::Transaction {
            message: "transaction already in progress".into(),
        }
        .into());
    }
    // A pin left over from a previously committed transaction must not carry over.
    if matches!(self.transaction.state, TransactionState::Committed { .. }) {
        self.unpin();
    }

    let support_status = self.client.transaction_support_status().await?;
    if !matches!(support_status, TransactionSupportStatus::Supported) {
        return Err(ErrorKind::Transaction {
            message: "Transactions are not supported by this deployment".into(),
        }
        .into());
    }

    let requested: Option<TransactionOptions> = options.into();
    let mut options = match requested {
        Some(mut options) => {
            // Fill the gaps in the caller's options from the session-level defaults.
            if let Some(defaults) = self.default_transaction_options() {
                merge_options!(
                    defaults,
                    options,
                    [
                        read_concern,
                        write_concern,
                        selection_criteria,
                        max_commit_time
                    ]
                );
            }
            Some(options)
        }
        None => self.default_transaction_options().cloned(),
    };
    resolve_options!(
        self.client,
        options,
        [read_concern, write_concern, selection_criteria]
    );

    let has_unacknowledged_write_concern = options
        .as_ref()
        .and_then(|opts| opts.write_concern.as_ref())
        .map_or(false, |wc| !wc.is_acknowledged());
    if has_unacknowledged_write_concern {
        return Err(ErrorKind::Transaction {
            message: "transactions do not support unacknowledged write concerns".into(),
        }
        .into());
    }

    self.increment_txn_number();
    self.transaction.start(options);
    Ok(())
}
/// Commits the session's transaction.
///
/// - From `Starting`, nothing is sent; the state becomes `Committed { data_committed: false }`.
/// - From `InProgress`, the state becomes `Committed { data_committed: true }` and a
///   `CommitTransaction` operation is executed.
/// - From `Committed { data_committed: true }`, the commit is executed again after
///   `update_for_retry` has been applied to the operation.
/// - From `Committed { data_committed: false }`, this is a no-op.
///
/// # Errors
///
/// Returns an `ErrorKind::Transaction` error if no transaction was started or if the
/// transaction was aborted; otherwise returns whatever executing the commit returns.
pub async fn commit_transaction(&mut self) -> Result<()> {
    // The match either finishes the call early or yields the operation to execute.
    let commit_transaction = match self.transaction.state {
        TransactionState::None => {
            return Err(ErrorKind::Transaction {
                message: "no transaction started".into(),
            }
            .into());
        }
        TransactionState::Aborted => {
            return Err(ErrorKind::Transaction {
                message: "Cannot call commitTransaction after calling abortTransaction".into(),
            }
            .into());
        }
        TransactionState::Starting => {
            self.transaction.commit(false);
            return Ok(());
        }
        TransactionState::Committed {
            data_committed: false,
        } => return Ok(()),
        TransactionState::InProgress => {
            let first_attempt = CommitTransaction::new(self.transaction.options.clone());
            // The state is switched before the operation is executed.
            self.transaction.commit(true);
            first_attempt
        }
        TransactionState::Committed {
            data_committed: true,
        } => {
            let mut retry = CommitTransaction::new(self.transaction.options.clone());
            retry.update_for_retry();
            retry
        }
    };

    let client = self.client.clone();
    client.execute_operation(commit_transaction, self).await
}
/// Aborts the session's transaction.
///
/// From `Starting`, the transaction is simply marked aborted. From `InProgress`, the
/// transaction is marked aborted and an `AbortTransaction` operation is executed with the
/// transaction's write concern and pin; the outcome of that operation is deliberately
/// ignored and `Ok(())` is returned regardless.
///
/// # Errors
///
/// Returns an `ErrorKind::Transaction` error if no transaction was started, or if the
/// transaction was already committed or aborted.
pub async fn abort_transaction(&mut self) -> Result<()> {
    let rejection = match self.transaction.state {
        TransactionState::None => Some("no transaction started"),
        TransactionState::Committed { .. } => {
            Some("Cannot call abortTransaction after calling commitTransaction")
        }
        TransactionState::Aborted => Some("cannot call abortTransaction twice"),
        TransactionState::Starting | TransactionState::InProgress => None,
    };
    if let Some(message) = rejection {
        return Err(ErrorKind::Transaction {
            message: message.into(),
        }
        .into());
    }

    if self.transaction.state == TransactionState::Starting {
        // Nothing has been sent for this transaction, so there is nothing to abort remotely.
        self.transaction.abort();
        return Ok(());
    }

    // Only `InProgress` remains. Capture what the operation needs before `abort()` clears it.
    let write_concern = self
        .transaction
        .options
        .as_ref()
        .and_then(|options| options.write_concern.clone());
    let pinned = self.transaction.pinned.take();
    let abort_transaction = AbortTransaction::new(write_concern, pinned);
    self.transaction.abort();

    let client = self.client.clone();
    // Best effort: the result of the abort operation is intentionally discarded.
    let _result = client
        .execute_operation(abort_transaction, &mut *self)
        .await;
    Ok(())
}
/// Runs `callback` inside a transaction, retrying the whole transaction or just the commit
/// on labelled errors until an overall time limit (120 seconds; overridable in tests via
/// `convenient_transaction_timeout`) has elapsed.
///
/// Each attempt starts a transaction with `options`, awaits `callback(self, &mut context)`
/// and then commits, unless the callback itself left the transaction in the `None`,
/// `Aborted` or `Committed` state, in which case the callback's value is returned as is.
///
/// # Errors
///
/// Returns the error from `start_transaction`, the callback's error when it is not
/// retried, or the commit error when it is not retried.
pub async fn with_transaction<R, C, F>(
&mut self,
mut context: C,
mut callback: F,
options: impl Into<Option<TransactionOptions>>,
) -> Result<R>
where
F: for<'a> FnMut(&'a mut ClientSession, &'a mut C) -> BoxFuture<'a, Result<R>>,
{
let options = options.into();
let timeout = Duration::from_secs(120);
// Tests may shorten (or lengthen) the overall time limit.
#[cfg(test)]
let timeout = self.convenient_transaction_timeout.unwrap_or(timeout);
// The time limit is measured from here and covers every retry.
let start = Instant::now();
use crate::error::{TRANSIENT_TRANSACTION_ERROR, UNKNOWN_TRANSACTION_COMMIT_RESULT};
// Outer loop: one iteration per attempt at the whole transaction.
'transaction: loop {
self.start_transaction(options.clone()).await?;
let ret = match callback(self, &mut context).await {
Ok(v) => v,
Err(e) => {
// Abort only if the callback left the transaction open.
if matches!(
self.transaction.state,
TransactionState::Starting | TransactionState::InProgress
) {
self.abort_transaction().await?;
}
// Transient errors restart the whole transaction while time remains.
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) && start.elapsed() < timeout {
continue 'transaction;
}
return Err(e);
}
};
// The callback already ended the transaction itself; there is nothing left to commit.
if matches!(
self.transaction.state,
TransactionState::None
| TransactionState::Aborted
| TransactionState::Committed { .. }
) {
return Ok(ret);
}
// Inner loop: one iteration per commit attempt for the current transaction.
'commit: loop {
match self.commit_transaction().await {
Ok(()) => return Ok(ret),
Err(e) => {
// Give up on a max-time-expired error or once the time limit has passed.
if e.is_max_time_ms_expired_error() || start.elapsed() >= timeout {
return Err(e);
}
// Unknown commit result: retry only the commit.
if e.contains_label(UNKNOWN_TRANSACTION_COMMIT_RESULT) {
continue 'commit;
}
// Transient error: restart the whole transaction, callback included.
if e.contains_label(TRANSIENT_TRANSACTION_ERROR) {
continue 'transaction;
}
return Err(e);
}
}
}
}
}
/// Returns the default transaction options configured in the session options, if any.
fn default_transaction_options(&self) -> Option<&TransactionOptions> {
    let session_options = self.options.as_ref()?;
    session_options.default_transaction_options.as_ref()
}
}
/// A pinned, boxed, `Send` future yielding `T` and valid for `'a`; this is the type that
/// callbacks passed to `ClientSession::with_transaction` must return.
pub type BoxFuture<'a, T> = std::pin::Pin<Box<dyn std::future::Future<Output = T> + Send + 'a>>;
/// Owned copy of a `ClientSession`'s fields (everything except the drop token and the
/// test-only timeout), built in `Drop` so that a session with an in-progress transaction
/// can be reconstructed inside a spawned task and have its transaction aborted there.
struct DroppedClientSession {
cluster_time: Option<ClusterTime>,
server_session: ServerSession,
client: Client,
is_implicit: bool,
options: Option<SessionOptions>,
transaction: Transaction,
snapshot_time: Option<Timestamp>,
operation_time: Option<Timestamp>,
}
impl From<DroppedClientSession> for ClientSession {
fn from(dropped_session: DroppedClientSession) -> Self {
Self {
cluster_time: dropped_session.cluster_time,
server_session: dropped_session.server_session,
drop_token: dropped_session.client.register_async_drop(),
client: dropped_session.client,
is_implicit: dropped_session.is_implicit,
options: dropped_session.options,
transaction: dropped_session.transaction,
snapshot_time: dropped_session.snapshot_time,
operation_time: dropped_session.operation_time,
#[cfg(test)]
convenient_transaction_timeout: None,
}
}
}
impl Drop for ClientSession {
fn drop(&mut self) {
if self.transaction.state == TransactionState::InProgress {
let dropped_session = DroppedClientSession {
cluster_time: self.cluster_time.clone(),
server_session: self.server_session.clone(),
client: self.client.clone(),
is_implicit: self.is_implicit,
options: self.options.clone(),
transaction: self.transaction.take(),
snapshot_time: self.snapshot_time,
operation_time: self.operation_time,
};
self.drop_token.spawn(async move {
let mut session: ClientSession = dropped_session.into();
let _result = session.abort_transaction().await;
});
} else {
let client = self.client.clone();
let server_session = self.server_session.clone();
self.drop_token.spawn(async move {
client.check_in_server_session(server_session).await;
});
}
}
}
/// Client-side record of a server session.
#[derive(Clone, Debug)]
pub(crate) struct ServerSession {
/// Session id document of the form `{ "id": <UUID binary> }`.
id: Document,
/// When the session was created or last marked as used (see `update_last_use`).
last_use: std::time::Instant,
/// Set by `ClientSession::mark_dirty`.
/// NOTE(review): presumably makes the pool discard the session on check-in — confirm in `pool`.
dirty: bool,
/// Transaction number; starts at 0 and is incremented through the owning `ClientSession`.
txn_number: i64,
}
impl ServerSession {
    /// Creates a server session with a freshly generated v4 UUID as its id, last used
    /// "now", not dirty, and with transaction number 0.
    fn new() -> Self {
        let session_uuid = Uuid::new_v4();
        let binary = Bson::Binary(Binary {
            subtype: BinarySubtype::Uuid,
            bytes: session_uuid.as_bytes().to_vec(),
        });
        let id = doc! { "id": binary };
        Self {
            id,
            last_use: Instant::now(),
            dirty: false,
            txn_number: 0,
        }
    }

    /// Returns `true` if the session would expire (last use plus the logical session
    /// timeout) in less than 60 seconds from now. Without a timeout the answer is `false`.
    fn is_about_to_expire(&self, logical_session_timeout: Option<Duration>) -> bool {
        match logical_session_timeout {
            None => false,
            Some(timeout) => {
                let expiration_date = self.last_use + timeout;
                let one_minute_from_now = Instant::now() + Duration::from_secs(60);
                expiration_date < one_minute_from_now
            }
        }
    }
}