use std::sync::{
Arc,
atomic::{AtomicU64, Ordering},
};
use reifydb_core::{
common::CommitVersion,
encoded::shape::{RowShape, RowShapeField},
key::{EncodableKey, transaction_version::TransactionVersionKey},
};
use reifydb_runtime::sync::mutex::Mutex;
use reifydb_type::{Result, value::r#type::Type};
use crate::single::SingleTransaction;
/// Number of versions covered by one block.
///
/// `StandardVersionProvider` persists only the end of the block it is currently
/// handing versions out of, so a write to storage happens when a block is
/// exhausted rather than on every allocated version.
const BLOCK_SIZE: u64 = 100_000;
/// Source of commit versions.
///
/// Implementations must be `Send + Sync + Clone` so one provider can be handed
/// to several owners and used from several threads.
pub trait VersionProvider: Send + Sync + Clone {
/// Allocates a commit version and returns it, or an error if the
/// implementation could not produce one.
///
/// `StandardVersionProvider` in this file returns the previous counter value
/// plus one on every call.
fn next(&self) -> Result<CommitVersion>;
/// Returns the provider's current commit version without allocating a new one.
///
/// `StandardVersionProvider` in this file reports the present value of its
/// counter, which after a call to `next` is the version that call returned.
fn current(&self) -> Result<CommitVersion>;
/// Asks the provider to move forward to `_version`.
///
/// The default implementation ignores the argument and does nothing;
/// implementations that track a counter can override it.
fn advance_to(&self, _version: CommitVersion) {}
}
/// A range of versions reserved in one go: it begins at `current` and ends at
/// `last`, which lies `BLOCK_SIZE` past the starting point.
#[derive(Debug)]
struct VersionBlock {
    /// End of the block (`start + BLOCK_SIZE`).
    last: u64,
    /// Value the block starts from.
    current: u64,
}

impl VersionBlock {
    /// Builds the block that starts at `start` and spans `BLOCK_SIZE` versions.
    fn new(start: u64) -> Self {
        let last = start + BLOCK_SIZE;
        Self {
            last,
            current: start,
        }
    }
}
/// Version provider that allocates versions from an in-memory atomic counter
/// and persists only the end of the current block through a `SingleTransaction`.
///
/// The counter, the block end and the lock each sit behind an `Arc`, so clones
/// of a provider share them.
#[derive(Clone)]
pub struct StandardVersionProvider {
/// Transaction handle used to read and write the persisted version row.
single: SingleTransaction,
/// Most recently allocated version; `next` returns this value plus one.
next_version: Arc<AtomicU64>,
/// Highest version `next` returns without first persisting a new block end.
current_block_end: Arc<AtomicU64>,
/// Held while `next` persists and publishes a new block end.
block_persist_lock: Arc<Mutex<()>>,
/// Row layout with a single `version` field of `Type::Uint8`, used to encode
/// the persisted value.
shape: RowShape,
}
impl StandardVersionProvider {
pub fn new(single: SingleTransaction) -> Result<Self> {
let shape = RowShape::new(vec![RowShapeField::unconstrained("version", Type::Uint8)]);
let current_version = Self::load_current_version(&shape, &single)?;
let first_block = VersionBlock::new(current_version);
Self::persist_version(&shape, &single, first_block.last)?;
Ok(Self {
single,
next_version: Arc::new(AtomicU64::new(first_block.current)),
current_block_end: Arc::new(AtomicU64::new(first_block.last)),
block_persist_lock: Arc::new(Mutex::new(())),
shape,
})
}
fn load_current_version(shape: &RowShape, single: &SingleTransaction) -> Result<u64> {
let key = TransactionVersionKey {}.encode();
let mut tx = single.begin_query([&key])?;
match tx.get(&key)? {
None => Ok(0),
Some(single) => Ok(shape.get_u64(&single.row, 0)),
}
}
fn persist_version(shape: &RowShape, single: &SingleTransaction, version: u64) -> Result<()> {
let key = TransactionVersionKey {}.encode();
let mut row = shape.allocate();
shape.set_u64(&mut row, 0, version);
let mut tx = single.begin_command([&key])?;
tx.set(&key, row)?;
tx.commit()
}
}
impl VersionProvider for StandardVersionProvider {
fn next(&self) -> Result<CommitVersion> {
let version = self.next_version.fetch_add(1, Ordering::SeqCst) + 1;
let block_end = self.current_block_end.load(Ordering::SeqCst);
if version <= block_end {
return Ok(CommitVersion(version));
}
let _lock = self.block_persist_lock.lock();
let block_end = self.current_block_end.load(Ordering::SeqCst);
if version <= block_end {
return Ok(CommitVersion(version));
}
let new_block_start = (version / BLOCK_SIZE) * BLOCK_SIZE;
let new_block_end = new_block_start + BLOCK_SIZE;
Self::persist_version(&self.shape, &self.single, new_block_end)?;
self.current_block_end.store(new_block_end, Ordering::SeqCst);
Ok(CommitVersion(version))
}
fn current(&self) -> Result<CommitVersion> {
Ok(CommitVersion(self.next_version.load(Ordering::SeqCst)))
}
fn advance_to(&self, version: CommitVersion) {
self.next_version.fetch_max(version.0, Ordering::SeqCst);
self.current_block_end.fetch_max(version.0, Ordering::SeqCst);
}
}
#[cfg(test)]
pub mod tests {
    use std::{sync::Arc, thread};

    use super::*;

    /// Builds a provider on top of a fresh testing transaction.
    fn fresh_provider() -> StandardVersionProvider {
        StandardVersionProvider::new(SingleTransaction::testing()).unwrap()
    }

    /// A provider over an empty store reports version 0.
    #[test]
    fn test_new_version_provider() {
        assert_eq!(fresh_provider().current().unwrap(), 0);
    }

    /// `next` counts up by one and `current` follows it.
    #[test]
    fn test_next_version_sequential() {
        let provider = fresh_provider();
        for expected in 1..=3u64 {
            assert_eq!(provider.next().unwrap(), expected);
            assert_eq!(provider.current().unwrap(), expected);
        }
    }

    /// A second provider on the same store resumes after the first block.
    #[test]
    fn test_version_persistence() {
        let single = SingleTransaction::testing();
        {
            let first = StandardVersionProvider::new(single.clone()).unwrap();
            for expected in 1..=3u64 {
                assert_eq!(first.next().unwrap(), expected);
            }
        }

        let second = StandardVersionProvider::new(single.clone()).unwrap();
        let resumed = BLOCK_SIZE + 1;
        assert_eq!(second.next().unwrap(), resumed);
        assert_eq!(second.current().unwrap(), resumed);
    }

    /// Allocation keeps counting across the end of the first block.
    #[test]
    fn test_block_exhaustion_and_allocation() {
        let provider = fresh_provider();
        (0..BLOCK_SIZE).for_each(|_| {
            provider.next().unwrap();
        });
        assert_eq!(provider.current().unwrap(), BLOCK_SIZE);

        for offset in 1..=2u64 {
            assert_eq!(provider.next().unwrap(), BLOCK_SIZE + offset);
            assert_eq!(provider.current().unwrap(), BLOCK_SIZE + offset);
        }
    }

    /// Ten threads taking 100 versions each receive 1..=1000 with no duplicates.
    #[test]
    fn test_concurrent_version_allocation() {
        let provider = Arc::new(fresh_provider());

        let workers: Vec<_> = (0..10)
            .map(|_| {
                let shared = Arc::clone(&provider);
                thread::spawn(move || {
                    (0..100).map(|_| shared.next().unwrap()).collect::<Vec<_>>()
                })
            })
            .collect();

        let mut all_versions: Vec<_> = workers
            .into_iter()
            .flat_map(|worker| worker.join().unwrap())
            .collect();
        all_versions.sort();

        for pair in all_versions.windows(2) {
            assert_ne!(pair[0], pair[1], "Duplicate version found: {}", pair[1]);
        }
        assert_eq!(all_versions.len(), 1000);
        assert_eq!(all_versions[0], 1);
        assert_eq!(all_versions[999], 1000);
    }

    /// A block starts at the given value and ends `BLOCK_SIZE` later.
    #[test]
    fn test_version_block_initialization() {
        let start = 100;
        let VersionBlock { last, current } = VersionBlock::new(start);
        assert_eq!(current, start);
        assert_eq!(last, start + BLOCK_SIZE);
    }

    /// A value already present in the store becomes the starting point.
    #[test]
    fn test_load_existing_version() {
        let single = SingleTransaction::testing();
        let shape = RowShape::testing(&[Type::Uint8]);
        let key = TransactionVersionKey {}.encode();

        let mut seeded = shape.allocate();
        shape.set_u64(&mut seeded, 0, 500u64);
        {
            // Scoped so the seeding transaction is gone before the provider is built.
            let mut seed_tx = single.begin_command([&key]).unwrap();
            seed_tx.set(&key, seeded).unwrap();
            seed_tx.commit().unwrap();
        }

        let provider = StandardVersionProvider::new(single.clone()).unwrap();
        assert_eq!(provider.current().unwrap(), 500);
        assert_eq!(provider.next().unwrap(), 501);
    }
}