use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use super::clock::MvccClock;
/// Identifier of a transaction tracked by the registry.
///
/// A thin newtype over `u64`. [`ActiveTxRegistry::register`] builds it from
/// the begin timestamp it obtains from the clock.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TxId(pub u64);
impl std::fmt::Display for TxId {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "tx{}", self.0)
}
}
/// Either a raw `u64` timestamp or a transaction id.
///
/// The two variants never compare equal to each other, even when they wrap
/// the same number.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TxTimestampOrId {
    /// A bare timestamp value.
    Timestamp(u64),
    /// The id of a transaction.
    Id(TxId),
}
/// Shared registry of the transactions that are currently registered.
///
/// Cloning is cheap: every clone points at the same mutex-protected state
/// through an `Arc`, so all clones observe the same set of transactions.
#[derive(Debug, Default, Clone)]
pub struct ActiveTxRegistry {
    /// State shared by every clone of this registry.
    inner: Arc<Mutex<RegistryInner>>,
}
/// Bookkeeping kept behind the registry mutex.
#[derive(Default, Debug)]
struct RegistryInner {
    /// Begin timestamp of each registered transaction, keyed by its id.
    by_id: BTreeMap<TxId, u64>,
    /// Number of registered transactions per begin timestamp. The map is
    /// ordered, so its first key is the smallest begin timestamp.
    by_ts: BTreeMap<u64, usize>,
}
impl ActiveTxRegistry {
pub fn new() -> Self {
Self::default()
}
pub fn register(&self, clock: &MvccClock) -> TxHandle {
let begin_ts = clock.tick();
let id = TxId(begin_ts);
let mut g = self.lock();
g.by_id.insert(id, begin_ts);
*g.by_ts.entry(begin_ts).or_insert(0) += 1;
drop(g);
TxHandle {
id,
begin_ts,
registry: self.clone(),
}
}
pub fn min_active_begin_ts(&self) -> Option<u64> {
self.lock().by_ts.keys().next().copied()
}
pub fn active_count(&self) -> usize {
self.lock().by_id.len()
}
fn unregister(&self, id: TxId, begin_ts: u64) {
let mut g = self.lock();
g.by_id.remove(&id);
if let Some(slot) = g.by_ts.get_mut(&begin_ts) {
*slot = slot.saturating_sub(1);
if *slot == 0 {
g.by_ts.remove(&begin_ts);
}
}
}
fn lock(&self) -> std::sync::MutexGuard<'_, RegistryInner> {
self.inner
.lock()
.unwrap_or_else(|e| panic!("sqlrite: ActiveTxRegistry mutex poisoned: {e}"))
}
}
/// Handle for one registered transaction.
///
/// Carries the id and begin timestamp assigned at registration, together
/// with a clone of the registry the transaction was registered in. Dropping
/// the handle unregisters the transaction.
#[derive(Debug)]
pub struct TxHandle {
    /// Id assigned when the transaction was registered.
    id: TxId,
    /// Begin timestamp assigned when the transaction was registered.
    begin_ts: u64,
    /// Registry this transaction is unregistered from on drop.
    registry: ActiveTxRegistry,
}
impl TxHandle {
    /// Returns the id assigned to this transaction at registration.
    pub fn id(&self) -> TxId {
        let Self { id, .. } = self;
        *id
    }

    /// Returns the begin timestamp assigned to this transaction at
    /// registration.
    pub fn begin_ts(&self) -> u64 {
        let Self { begin_ts, .. } = self;
        *begin_ts
    }
}
/// Unregisters the transaction from its registry when the handle goes away.
impl Drop for TxHandle {
    fn drop(&mut self) {
        let Self {
            id,
            begin_ts,
            registry,
        } = self;
        registry.unregister(*id, *begin_ts);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created registry tracks nothing.
    #[test]
    fn empty_registry_has_no_minimum() {
        let registry = ActiveTxRegistry::new();
        assert_eq!(registry.min_active_begin_ts(), None);
        assert_eq!(registry.active_count(), 0);
    }

    /// Each registration ticks the clock; the minimum follows the oldest
    /// handle that is still alive.
    #[test]
    fn register_advances_clock_and_updates_minimum() {
        let clock = MvccClock::new(0);
        let registry = ActiveTxRegistry::new();

        let older = registry.register(&clock);
        assert_eq!(older.begin_ts(), 1);
        assert_eq!(registry.min_active_begin_ts(), Some(1));

        let newer = registry.register(&clock);
        assert_eq!(newer.begin_ts(), 2);
        assert_eq!(registry.min_active_begin_ts(), Some(1));

        drop(older);
        assert_eq!(registry.min_active_begin_ts(), Some(2));

        drop(newer);
        assert_eq!(registry.min_active_begin_ts(), None);
    }

    /// Two registrations never share an id or a begin timestamp.
    #[test]
    fn handles_carry_distinct_ids_and_unique_timestamps() {
        let clock = MvccClock::new(0);
        let registry = ActiveTxRegistry::new();

        let first = registry.register(&clock);
        let second = registry.register(&clock);

        assert_ne!(first.id(), second.id());
        assert_ne!(first.begin_ts(), second.begin_ts());
        assert_eq!(registry.active_count(), 2);
    }

    /// Dropping handles out of registration order still yields the right
    /// minimum after every step.
    #[test]
    fn unregister_in_arbitrary_order_keeps_minimum_correct() {
        let clock = MvccClock::new(0);
        let registry = ActiveTxRegistry::new();

        let first = registry.register(&clock);
        let second = registry.register(&clock);
        let third = registry.register(&clock);
        assert_eq!(registry.min_active_begin_ts(), Some(1));

        drop(second);
        assert_eq!(registry.min_active_begin_ts(), Some(1));

        drop(first);
        assert_eq!(registry.min_active_begin_ts(), Some(3));

        drop(third);
        assert_eq!(registry.min_active_begin_ts(), None);
    }

    /// Compile-time check that both public types can cross thread boundaries.
    #[test]
    fn registry_is_send_and_sync() {
        fn assert_send<T: Send>() {}
        fn assert_sync<T: Sync>() {}

        assert_send::<ActiveTxRegistry>();
        assert_sync::<ActiveTxRegistry>();
        assert_send::<TxHandle>();
        assert_sync::<TxHandle>();
    }

    /// Many threads registering at once must all be counted, must all get
    /// distinct begin timestamps, and must all disappear once dropped.
    #[test]
    fn concurrent_registrations_are_consistent() {
        const THREADS: usize = 8;
        const PER_THREAD: usize = 100;

        let clock = Arc::new(MvccClock::new(0));
        let registry = ActiveTxRegistry::new();

        // Spawn every worker before joining any of them.
        let mut workers = Vec::with_capacity(THREADS);
        for _ in 0..THREADS {
            let worker_clock = Arc::clone(&clock);
            let worker_registry = registry.clone();
            workers.push(std::thread::spawn(move || {
                let mut mine: Vec<TxHandle> = Vec::with_capacity(PER_THREAD);
                mine.extend((0..PER_THREAD).map(|_| worker_registry.register(&worker_clock)));
                mine
            }));
        }

        let mut live: Vec<TxHandle> = Vec::with_capacity(THREADS * PER_THREAD);
        for worker in workers {
            live.extend(worker.join().unwrap());
        }
        assert_eq!(registry.active_count(), THREADS * PER_THREAD);

        let distinct_begins: std::collections::BTreeSet<u64> =
            live.iter().map(TxHandle::begin_ts).collect();
        assert_eq!(
            distinct_begins.len(),
            THREADS * PER_THREAD,
            "every concurrent registration must allocate a unique begin_ts"
        );

        drop(live);
        assert_eq!(registry.active_count(), 0);
        assert_eq!(registry.min_active_begin_ts(), None);
    }

    /// `Display` prefixes the raw id with `tx`.
    #[test]
    fn tx_id_displays_with_prefix() {
        assert_eq!(format!("{}", TxId(7)), "tx7");
    }

    /// The two variants are distinct even for the same number, and each
    /// value equals itself.
    #[test]
    fn tx_timestamp_or_id_round_trips() {
        let as_timestamp = TxTimestampOrId::Timestamp(42);
        let as_id = TxTimestampOrId::Id(TxId(42));

        assert_ne!(as_timestamp, as_id);
        assert_eq!(as_timestamp, as_timestamp);
        assert_eq!(as_id, as_id);
    }
}