use core::ops::AddAssign;
use std::borrow::Cow;
use futures::lock::{Mutex, MutexGuard};
use smallvec_wrapper::TinyVec;
use txn_core::future::AsyncCm;
use wmark::{AsyncCloser, AsyncSpawner, AsyncWaterMark};
/// Mutable oracle state; `Oracle` keeps it behind its `inner` mutex.
#[derive(Debug)]
pub(super) struct OracleInner<C> {
/// Timestamp that `new_commit_ts` hands to the next committing transaction
/// before incrementing it; `read_ts` reports one less than this value.
next_txn_ts: u64,
/// Read watermark recorded by the most recent cleanup of `committed_txns`.
last_cleanup_ts: u64,
/// Committed transactions retained for conflict checks. Cleanup drops every
/// entry whose timestamp is at or below the read watermark.
pub(super) committed_txns: TinyVec<CommittedTxn<C>>,
}
/// Outcome of `Oracle::new_commit_ts`.
pub(super) enum CreateCommitTimestampResult<C> {
/// No conflict was found; carries the commit timestamp that was assigned.
Timestamp(u64),
/// A conflict with a previously committed transaction was detected.
/// `new_commit_ts` hands the caller's conflict manager back in this variant.
Conflict(Option<C>),
}
/// Timestamp oracle: issues read and commit timestamps and tracks, through two
/// watermarks, which of them are still in flight.
#[derive(Debug)]
pub(super) struct Oracle<C, S>
where
S: AsyncSpawner,
{
/// NOTE(review): not used in this part of the file; presumably held by
/// committers to serialize writes — confirm against callers.
pub(super) write_serialize_lock: Mutex<()>,
/// Timestamp counter and committed-transaction list, guarded by an async mutex.
pub(super) inner: Mutex<OracleInner<C>>,
/// Watermark over read timestamps: `read_ts` begins an entry, `done_read`
/// and `new_commit_ts` finish it.
pub(super) read_mark: AsyncWaterMark<S>,
/// Watermark over commit timestamps: `new_commit_ts` begins an entry,
/// `done_commit` finishes it, and `read_ts` waits on it.
pub(super) txn_mark: AsyncWaterMark<S>,
/// Closer handed to both watermarks in `new`; signalled by `stop` and on drop.
closer: AsyncCloser<S>,
}
impl<C, S> Oracle<C, S>
where
    C: AsyncCm,
    S: AsyncSpawner,
{
    /// Runs conflict detection for a committing transaction and, when it
    /// passes, assigns the transaction its commit timestamp.
    ///
    /// * `done_read` - whether the transaction's read timestamp has already
    ///   been marked done on the read watermark; set to `true` here if this
    ///   call marks it.
    /// * `read_ts` - the timestamp the transaction read at.
    /// * `conflict_manager` - the transaction's conflict manager; it is either
    ///   stored with the new commit record or handed back on conflict.
    ///
    /// Returns `Conflict(Some(manager))` if any transaction committed after
    /// `read_ts` conflicts with this one, otherwise `Timestamp(ts)`.
    ///
    /// # Panics
    ///
    /// Panics if `conflict_manager` is `None`, if a watermark operation fails,
    /// or if the new timestamp is below the last cleanup timestamp.
    pub(super) async fn new_commit_ts(
        &self,
        done_read: &mut bool,
        read_ts: u64,
        mut conflict_manager: Option<C>,
    ) -> CreateCommitTimestampResult<C> {
        // The guard is held for the whole call, so the conflict check and the
        // timestamp allocation happen under one lock acquisition.
        let mut inner = self.inner.lock().await;
        let current = conflict_manager.take().unwrap();

        for committed in inner.committed_txns.iter() {
            // Only transactions committed after our read timestamp are checked.
            if committed.ts > read_ts {
                if let Some(previous) = committed.conflict_manager.as_ref() {
                    if current.has_conflict(previous).await {
                        return CreateCommitTimestampResult::Conflict(Some(current));
                    }
                }
            }
        }

        // This transaction's read is finished: release its read-watermark
        // entry unless the caller has already done so.
        if !*done_read {
            self.read_mark.done(read_ts).unwrap();
            *done_read = true;
        }

        self.cleanup_committed_transactions(true, &mut inner);

        // Allocate the commit timestamp and register it as in flight.
        let ts = inner.next_txn_ts;
        inner.next_txn_ts += 1;
        self.txn_mark.begin(ts).unwrap();

        assert!(ts >= inner.last_cleanup_ts);

        // Record the commit so later committers can be checked against it.
        inner.committed_txns.push(CommittedTxn {
            ts,
            conflict_manager: Some(current),
        });

        CreateCommitTimestampResult::Timestamp(ts)
    }

    /// Drops committed-transaction records whose timestamp is at or below the
    /// read watermark.
    ///
    /// Does nothing when `detect_conflicts` is `false` or when the read
    /// watermark has not moved since the previous cleanup.
    ///
    /// # Panics
    ///
    /// Panics if the read watermark cannot be queried or if it is below the
    /// previously recorded cleanup timestamp.
    #[inline]
    fn cleanup_committed_transactions(
        &self,
        detect_conflicts: bool,
        inner: &mut MutexGuard<'_, OracleInner<C>>,
    ) {
        if detect_conflicts {
            let max_read_ts = self.read_mark.done_until().unwrap();
            assert!(max_read_ts >= inner.last_cleanup_ts);

            // Nothing new has become unreachable unless the watermark advanced.
            if max_read_ts != inner.last_cleanup_ts {
                inner.last_cleanup_ts = max_read_ts;
                inner.committed_txns.retain(|txn| txn.ts > max_read_ts);
            }
        }
    }
}
impl<C, S> Oracle<C, S>
where
    S: AsyncSpawner,
{
    /// Creates an oracle whose next commit timestamp is `next_txn_ts`.
    ///
    /// `read_mark_name` and `txn_mark_name` are passed to the read and commit
    /// watermarks respectively. Both watermarks are initialised with a clone
    /// of the same closer.
    #[inline]
    pub fn new(
        read_mark_name: Cow<'static, str>,
        txn_mark_name: Cow<'static, str>,
        next_txn_ts: u64,
    ) -> Self {
        // Two watermarks share this closer.
        // NOTE(review): the `2` presumably is the number of parties the closer
        // waits for — confirm against `AsyncCloser::new`.
        let closer = AsyncCloser::new(2);
        let mut orc = Self {
            write_serialize_lock: Mutex::new(()),
            inner: Mutex::new(OracleInner {
                next_txn_ts,
                last_cleanup_ts: 0,
                committed_txns: TinyVec::new(),
            }),
            read_mark: AsyncWaterMark::new(read_mark_name),
            txn_mark: AsyncWaterMark::new(txn_mark_name),
            closer,
        };
        orc.read_mark.init(orc.closer.clone());
        orc.txn_mark.init(orc.closer.clone());
        orc
    }

    /// Returns the read timestamp for a new transaction.
    ///
    /// The timestamp is `next_txn_ts - 1`. It is registered on the read
    /// watermark, and the call then waits until the commit watermark has
    /// reached it before returning.
    ///
    /// # Panics
    ///
    /// Panics if `next_txn_ts` is `0`, or if a watermark operation fails.
    #[inline]
    pub(super) async fn read_ts(&self) -> u64 {
        let read_ts = {
            let inner = self.inner.lock().await;
            // Checked subtraction: a plain `- 1` on a zero counter panics only
            // in builds with overflow checks and otherwise wraps to
            // `u64::MAX`. Fail loudly in every build profile instead.
            let read_ts = inner
                .next_txn_ts
                .checked_sub(1)
                .expect("next_txn_ts must be at least 1 before a read timestamp is requested");
            self.read_mark.begin(read_ts).unwrap();
            read_ts
        };
        // The state lock is released before waiting on the commit watermark.
        if let Err(e) = self.txn_mark.wait_for_mark(read_ts).await {
            panic!("{e}");
        }
        read_ts
    }

    /// Advances the next commit timestamp by one.
    #[inline]
    pub(super) async fn increment_next_ts(&self) {
        self.inner.lock().await.next_txn_ts.add_assign(1);
    }

    /// Returns the current read watermark.
    ///
    /// NOTE(review): the name suggests versions at or below this timestamp may
    /// be discarded — confirm against callers.
    ///
    /// # Panics
    ///
    /// Panics if the read watermark cannot be queried.
    #[inline]
    pub(super) fn discard_at_or_below(&self) -> u64 {
        self.read_mark.done_until().unwrap()
    }

    /// Marks `read_ts` as done on the read watermark.
    ///
    /// # Panics
    ///
    /// Panics if the watermark operation fails.
    #[inline]
    pub(super) fn done_read(&self, read_ts: u64) {
        self.read_mark.done(read_ts).unwrap();
    }
}
impl<C, S> Oracle<C, S>
where
    S: AsyncSpawner,
{
    /// Marks the commit timestamp `cts` as done on the commit watermark.
    ///
    /// # Panics
    ///
    /// Panics if the watermark operation fails.
    #[inline]
    pub(super) fn done_commit(&self, cts: u64) {
        let outcome = self.txn_mark.done(cts);
        outcome.unwrap();
    }
}
impl<C, S> Oracle<C, S>
where
    S: AsyncSpawner,
{
    /// Signals the shared closer and awaits `signal_and_wait` to completion.
    #[inline]
    pub(super) async fn stop(&self) {
        let shutdown = self.closer.signal_and_wait();
        shutdown.await;
    }
}
impl<C, S> Drop for Oracle<C, S>
where
    S: AsyncSpawner,
{
    /// Signals the shared closer when the oracle is dropped, via the closer's
    /// detached (non-async) `signal_and_wait_detach` call.
    fn drop(&mut self) {
        self.closer.signal_and_wait_detach();
    }
}
/// Record of a committed transaction, kept so later committers can be checked
/// against it for conflicts.
#[derive(Debug)]
pub(super) struct CommittedTxn<C> {
/// Commit timestamp assigned to the transaction.
ts: u64,
/// Conflict manager of the committed transaction; when present it is passed
/// to `has_conflict` during later commits.
conflict_manager: Option<C>,
}