use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use crate::persistent_artrie_core::concurrency::EpochManager;
/// Point-in-time snapshot of the counters kept by [`MvccStatsTracker`].
///
/// Produced by [`MvccStatsTracker::stats`]. Each field is loaded independently, so a
/// snapshot taken while other threads are updating the tracker is not guaranteed to be
/// mutually consistent (for example `active_transactions` need not equal
/// `transactions_started - transactions_completed` at the instant of the read).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MvccStats {
    /// Number of [`MvccStatsTracker::record_start`] calls.
    pub transactions_started: u64,
    /// Number of [`MvccStatsTracker::record_complete`] calls.
    pub transactions_completed: u64,
    /// Transactions started but not yet completed (saturates at zero).
    pub active_transactions: u64,
    /// Number of [`MvccStatsTracker::record_read`] calls.
    pub total_reads: u64,
    /// Number of [`MvccStatsTracker::record_cache_hit`] calls.
    pub cache_hits: u64,
}

/// Thread-safe counters describing MVCC transaction activity.
///
/// Every update and load uses `Ordering::Relaxed`: the counters are treated as
/// independent statistics, not as synchronization between threads.
#[derive(Debug)]
pub struct MvccStatsTracker {
    transactions_started: AtomicU64,
    transactions_completed: AtomicU64,
    active_transactions: AtomicU64,
    total_reads: AtomicU64,
    cache_hits: AtomicU64,
}

impl MvccStatsTracker {
    /// Creates a tracker with every counter at zero.
    ///
    /// This is a `const fn`, so a tracker can be placed in a `static`.
    pub const fn new() -> Self {
        Self {
            transactions_started: AtomicU64::new(0),
            transactions_completed: AtomicU64::new(0),
            active_transactions: AtomicU64::new(0),
            total_reads: AtomicU64::new(0),
            cache_hits: AtomicU64::new(0),
        }
    }

    /// Records that a transaction started: bumps both `transactions_started` and
    /// `active_transactions`.
    pub fn record_start(&self) {
        self.transactions_started.fetch_add(1, Ordering::Relaxed);
        self.active_transactions.fetch_add(1, Ordering::Relaxed);
    }

    /// Records that a transaction finished: bumps `transactions_completed` and
    /// decrements `active_transactions`.
    ///
    /// The decrement saturates at zero. An unbalanced call (a completion with no
    /// matching start) therefore cannot wrap the active count around to `u64::MAX`.
    pub fn record_complete(&self) {
        self.transactions_completed.fetch_add(1, Ordering::Relaxed);
        // `checked_sub` yields `None` at zero. `fetch_update` then leaves the value
        // untouched and returns `Err`, which is exactly the saturation we want, so the
        // result is deliberately ignored.
        let _ = self.active_transactions.fetch_update(
            Ordering::Relaxed,
            Ordering::Relaxed,
            |active| active.checked_sub(1),
        );
    }

    /// Counts one lookup performed by a transaction.
    pub fn record_read(&self) {
        self.total_reads.fetch_add(1, Ordering::Relaxed);
    }

    /// Counts one cache hit.
    pub fn record_cache_hit(&self) {
        self.cache_hits.fetch_add(1, Ordering::Relaxed);
    }

    /// Returns a snapshot of all counters.
    ///
    /// Fields are read one at a time, so the snapshot is not atomic across fields.
    pub fn stats(&self) -> MvccStats {
        MvccStats {
            transactions_started: self.transactions_started.load(Ordering::Relaxed),
            transactions_completed: self.transactions_completed.load(Ordering::Relaxed),
            active_transactions: self.active_transactions.load(Ordering::Relaxed),
            total_reads: self.total_reads.load(Ordering::Relaxed),
            cache_hits: self.cache_hits.load(Ordering::Relaxed),
        }
    }
}

impl Default for MvccStatsTracker {
    /// Equivalent to [`MvccStatsTracker::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A trie node type that [`ReadTransaction`] can traverse.
///
/// Lookups start at a root node and call `find_child` once per key of the search
/// term. The node reached at the end is checked with `is_final` and, for value
/// lookups, read with `get_value`. Children are handed out as `Arc<Self>`, and the
/// `Send + Sync + 'static` bound allows those `Arc`s to be shared across threads.
pub trait TrieRoot: Send + Sync + 'static {
/// Edge label type. `ReadTransaction` offers byte-slice lookups when this is `u8`
/// and `&str` lookups (one key per `char`) when it is `u32`.
type Key: Copy;
/// Type returned by `get_value`.
type Value;
/// Whether a stored term ends at this node. `ReadTransaction` reports a term as
/// present only when the node reached by walking it returns `true` here.
fn is_final(&self) -> bool;
/// Returns the child reached via the edge labelled `key`, or `None` if there is
/// no such edge.
fn find_child(&self, key: Self::Key) -> Option<Arc<Self>>;
/// Returns the value stored at this node, if any. `ReadTransaction` only calls
/// this on nodes for which `is_final()` returned `true`.
fn get_value(&self) -> Option<Self::Value>;
}
/// A read-only view of a trie snapshot that stays registered as a reader with an
/// [`EpochManager`] for as long as it lives.
///
/// Created by `ReadTransaction::begin` / `ReadTransaction::begin_with_stats`, which
/// call `enter_read()`. Dropping the transaction calls `exit_read()`.
#[derive(Debug)]
#[must_use = "dropping a ReadTransaction immediately ends its read epoch"]
pub struct ReadTransaction<T: TrieRoot> {
    /// Root of the snapshot being read. Set to `Some` on creation and cleared when
    /// the transaction is dropped.
    root: Option<Arc<T>>,
    /// Value of `EpochManager::current_epoch()` sampled when the transaction began.
    version_id: u64,
    /// Epoch returned by `EpochManager::enter_read()` when the transaction began.
    epoch: u64,
    /// Manager this transaction is registered with. `exit_read()` is called on drop.
    epoch_manager: Arc<EpochManager>,
    /// Optional activity counters. `None` for transactions created with `begin`.
    stats: Option<Arc<MvccStatsTracker>>,
}
impl<T: TrieRoot> ReadTransaction<T> {
    /// Starts a read-only transaction over the snapshot rooted at `root`.
    ///
    /// Registers a reader via `epoch_manager.enter_read()` (the matching `exit_read()`
    /// runs when the transaction is dropped) and samples `current_epoch()` as the
    /// transaction's version id. No statistics are recorded.
    pub fn begin(root: Arc<T>, epoch_manager: Arc<EpochManager>) -> Self {
        Self::begin_inner(root, epoch_manager, None)
    }

    /// Like [`ReadTransaction::begin`], but also reports to `stats`.
    ///
    /// `record_start` is called here, `record_read` on every lookup, and
    /// `record_complete` when the transaction is dropped.
    pub fn begin_with_stats(
        root: Arc<T>,
        epoch_manager: Arc<EpochManager>,
        stats: Arc<MvccStatsTracker>,
    ) -> Self {
        Self::begin_inner(root, epoch_manager, Some(stats))
    }

    /// Shared constructor for both public entry points. It enters the read epoch,
    /// samples the version id, then registers the start with `stats` if one is given.
    fn begin_inner(
        root: Arc<T>,
        epoch_manager: Arc<EpochManager>,
        stats: Option<Arc<MvccStatsTracker>>,
    ) -> Self {
        let pinned_epoch = epoch_manager.enter_read();
        // NOTE(review): sampled after `enter_read()`. If another thread advances the
        // epoch in between, `version_id` can exceed `epoch`; confirm this is intended.
        let snapshot_version = epoch_manager.current_epoch();
        if let Some(tracker) = stats.as_deref() {
            tracker.record_start();
        }
        Self {
            root: Some(root),
            version_id: snapshot_version,
            epoch: pinned_epoch,
            epoch_manager,
            stats,
        }
    }

    /// The `current_epoch()` value sampled when this transaction began.
    #[inline]
    pub fn version_id(&self) -> u64 {
        self.version_id
    }

    /// The epoch `enter_read()` returned when this transaction began.
    #[inline]
    pub fn epoch(&self) -> u64 {
        self.epoch
    }

    /// The snapshot root this transaction reads from.
    #[inline]
    pub fn root(&self) -> Option<&Arc<T>> {
        self.root.as_ref()
    }
}
impl<T: TrieRoot<Key = u8>> ReadTransaction<T> {
    /// Reports whether `term` is stored in this transaction's snapshot.
    ///
    /// Walks one child per byte of `term` from the root. The term is present only if
    /// every byte has a matching child and the node reached is final, so an empty
    /// `term` asks whether the root itself is final. Counts as one read on the
    /// attached stats tracker, if any.
    pub fn contains(&self, term: &[u8]) -> bool {
        if let Some(tracker) = self.stats.as_deref() {
            tracker.record_read();
        }
        match self.walk_bytes(term) {
            Some(node) => node.is_final(),
            None => false,
        }
    }

    /// Returns the value stored for `term`, or `None` if `term` is not a stored term.
    ///
    /// Uses the same walk as [`ReadTransaction::contains`]. A node that carries a value
    /// but is not final yields `None`. Counts as one read on the attached stats
    /// tracker, if any.
    pub fn get(&self, term: &[u8]) -> Option<T::Value> {
        if let Some(tracker) = self.stats.as_deref() {
            tracker.record_read();
        }
        self.walk_bytes(term)
            .filter(|node| node.is_final())
            .and_then(|node| node.get_value())
    }

    /// Follows `term` byte by byte from the root and returns the node it ends on.
    /// Returns `None` if there is no root or some byte has no matching child; the walk
    /// stops at the first missing child.
    fn walk_bytes(&self, term: &[u8]) -> Option<Arc<T>> {
        let start = Arc::clone(self.root.as_ref()?);
        term.iter()
            .try_fold(start, |node, &byte| node.find_child(byte))
    }
}
impl<T: TrieRoot<Key = u32>> ReadTransaction<T> {
    /// Reports whether `term` is stored in this transaction's snapshot.
    ///
    /// Walks one child per `char` of `term` (keyed by its Unicode scalar value as
    /// `u32`), not per UTF-8 byte. The term is present only if every char has a
    /// matching child and the node reached is final. Counts as one read on the
    /// attached stats tracker, if any.
    pub fn contains_str(&self, term: &str) -> bool {
        if let Some(tracker) = self.stats.as_deref() {
            tracker.record_read();
        }
        match self.walk_chars(term) {
            Some(node) => node.is_final(),
            None => false,
        }
    }

    /// Returns the value stored for `term`, or `None` if `term` is not a stored term.
    ///
    /// Uses the same per-`char` walk as [`ReadTransaction::contains_str`]. A node that
    /// carries a value but is not final yields `None`. Counts as one read on the
    /// attached stats tracker, if any.
    pub fn get_str(&self, term: &str) -> Option<T::Value> {
        if let Some(tracker) = self.stats.as_deref() {
            tracker.record_read();
        }
        self.walk_chars(term)
            .filter(|node| node.is_final())
            .and_then(|node| node.get_value())
    }

    /// Follows `term` char by char from the root and returns the node it ends on.
    /// Returns `None` if there is no root or some char has no matching child; the walk
    /// stops at the first missing child.
    fn walk_chars(&self, term: &str) -> Option<Arc<T>> {
        let start = Arc::clone(self.root.as_ref()?);
        term.chars()
            .map(u32::from)
            .try_fold(start, |node, key| node.find_child(key))
    }
}
impl<T: TrieRoot> Drop for ReadTransaction<T> {
    /// Ends the transaction. It releases the snapshot root, leaves the read epoch,
    /// then records completion on the stats tracker, if one is attached.
    fn drop(&mut self) {
        // Release the root while this reader is still registered, so that no
        // reference obtained under the read epoch outlives `exit_read()`.
        // NOTE(review): this matters if `EpochManager` lets retired versions be
        // reclaimed once no readers remain; confirm against its implementation.
        drop(self.root.take());
        self.epoch_manager.exit_read();
        if let Some(stats) = &self.stats {
            stats.record_complete();
        }
    }
}
// SAFETY: NOTE(review) — this asserts `Send` unconditionally instead of letting the
// compiler derive it.
// - `Arc<T>` is `Send` because `TrieRoot: Send + Sync`.
// - `Arc<MvccStatsTracker>` is `Send` because the tracker holds only `AtomicU64`s.
// - The `u64` fields are `Send`.
// - The remaining field, `Arc<EpochManager>`, is `Send` only if `EpochManager: Send + Sync`.
//   That is not visible here; TODO confirm.
// Moving a transaction to another thread also means `Drop` calls `exit_read()` on a
// different thread from the one that called `enter_read()`. This impl is sound only if
// `EpochManager` tolerates that (e.g. it does not track readers per thread); TODO
// confirm. If `EpochManager` is already `Send + Sync`, the auto trait would provide
// `Send` and this impl could be removed.
unsafe impl<T: TrieRoot> Send for ReadTransaction<T> {}
/// RAII guard that keeps its creator registered as a reader with an [`EpochManager`].
///
/// `EpochGuard::new` calls `enter_read()`, and dropping the guard calls `exit_read()`.
/// The guard must therefore be bound to a variable for the duration of the protected
/// region.
#[derive(Debug)]
#[must_use = "the read epoch ends as soon as the EpochGuard is dropped; bind it to a variable"]
pub struct EpochGuard {
    /// Epoch returned by `enter_read()` when the guard was created.
    epoch: u64,
    /// Manager the guard is registered with. `exit_read()` is called on drop.
    epoch_manager: Arc<EpochManager>,
}
impl EpochGuard {
pub fn new(epoch_manager: Arc<EpochManager>) -> Self {
let epoch = epoch_manager.enter_read();
Self {
epoch,
epoch_manager,
}
}
#[inline]
pub fn epoch(&self) -> u64 {
self.epoch
}
}
impl Drop for EpochGuard {
    /// Unregisters the reader that `EpochGuard::new` registered.
    fn drop(&mut self) {
        let manager = &self.epoch_manager;
        manager.exit_read();
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Byte-keyed (`Key = u8`) test node for the `contains` / `get` path.
    #[derive(Debug)]
    struct TestNode {
        is_final: bool,
        value: Option<u64>,
        children: std::collections::HashMap<u8, Arc<TestNode>>,
    }

    impl TestNode {
        fn new() -> Self {
            Self {
                is_final: false,
                value: None,
                children: std::collections::HashMap::new(),
            }
        }

        fn with_final(mut self) -> Self {
            self.is_final = true;
            self
        }

        fn with_value(mut self, value: u64) -> Self {
            self.value = Some(value);
            self
        }

        fn with_child(mut self, key: u8, child: TestNode) -> Self {
            self.children.insert(key, Arc::new(child));
            self
        }
    }

    impl TrieRoot for TestNode {
        type Key = u8;
        type Value = u64;

        fn is_final(&self) -> bool {
            self.is_final
        }

        fn find_child(&self, key: u8) -> Option<Arc<Self>> {
            self.children.get(&key).cloned()
        }

        fn get_value(&self) -> Option<u64> {
            self.value
        }
    }

    /// Char-keyed (`Key = u32`) test node for the `contains_str` / `get_str` path.
    #[derive(Debug, Default)]
    struct CharNode {
        is_final: bool,
        value: Option<u64>,
        children: std::collections::HashMap<u32, Arc<CharNode>>,
    }

    impl CharNode {
        fn leaf(value: u64) -> Self {
            Self {
                is_final: true,
                value: Some(value),
                ..Self::default()
            }
        }

        fn with_child(mut self, c: char, child: CharNode) -> Self {
            self.children.insert(u32::from(c), Arc::new(child));
            self
        }
    }

    impl TrieRoot for CharNode {
        type Key = u32;
        type Value = u64;

        fn is_final(&self) -> bool {
            self.is_final
        }

        fn find_child(&self, key: u32) -> Option<Arc<Self>> {
            self.children.get(&key).cloned()
        }

        fn get_value(&self) -> Option<u64> {
            self.value
        }
    }

    #[test]
    fn test_read_transaction_basic() {
        let epoch_manager = Arc::new(EpochManager::new());
        let leaf = TestNode::new().with_final().with_value(42);
        let mid = TestNode::new().with_child(b'b', leaf);
        let root = Arc::new(TestNode::new().with_child(b'a', mid));
        let tx = ReadTransaction::begin(root, epoch_manager);
        assert!(tx.contains(b"ab"));
        assert!(!tx.contains(b"a"));
        assert!(!tx.contains(b"abc"));
        assert!(!tx.contains(b""));
        assert_eq!(tx.get(b"ab"), Some(42));
        assert_eq!(tx.get(b"a"), None);
    }

    #[test]
    fn test_read_transaction_str_lookup() {
        let epoch_manager = Arc::new(EpochManager::new());
        // 'é' is a single char (U+00E9) but two UTF-8 bytes, so "hé" only matches if
        // the lookup walks one key per char rather than per byte. The 'h' node carries
        // a value but is not final, so it must not be reported as a stored term.
        let h = CharNode {
            value: Some(3),
            ..CharNode::default()
        }
        .with_child('é', CharNode::leaf(7));
        let root = Arc::new(CharNode::default().with_child('h', h));
        let tx = ReadTransaction::begin(root, epoch_manager);
        assert!(tx.contains_str("hé"));
        assert!(!tx.contains_str("h"));
        assert!(!tx.contains_str("he"));
        assert!(!tx.contains_str(""));
        assert_eq!(tx.get_str("hé"), Some(7));
        assert_eq!(tx.get_str("h"), None);
        assert_eq!(tx.get_str("hex"), None);
    }

    #[test]
    fn test_read_transaction_stats() {
        let epoch_manager = Arc::new(EpochManager::new());
        let stats = Arc::new(MvccStatsTracker::new());
        let leaf = TestNode::new().with_final();
        let root = Arc::new(TestNode::new().with_child(b'a', leaf));
        {
            let tx = ReadTransaction::begin_with_stats(
                root.clone(),
                epoch_manager.clone(),
                stats.clone(),
            );
            tx.contains(b"a");
            tx.contains(b"b");
            let current_stats = stats.stats();
            assert_eq!(current_stats.transactions_started, 1);
            assert_eq!(current_stats.active_transactions, 1);
            assert_eq!(current_stats.total_reads, 2);
        }
        let final_stats = stats.stats();
        assert_eq!(final_stats.transactions_completed, 1);
        assert_eq!(final_stats.active_transactions, 0);
    }

    #[test]
    fn test_epoch_guard() {
        let epoch_manager = Arc::new(EpochManager::new());
        assert!(!epoch_manager.has_active_readers());
        {
            let _guard = EpochGuard::new(epoch_manager.clone());
            assert!(epoch_manager.has_active_readers());
        }
        assert!(!epoch_manager.has_active_readers());
    }

    #[test]
    fn test_version_id_and_epoch() {
        let epoch_manager = Arc::new(EpochManager::new());
        let root = Arc::new(TestNode::new());
        epoch_manager.advance();
        epoch_manager.advance();
        let tx = ReadTransaction::begin(root, epoch_manager.clone());
        assert!(tx.version_id() >= 2);
        assert!(tx.epoch() >= 2);
    }
}