use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, ThreadId};
use std::time::Instant;
use std::collections::VecDeque;
use crate::error::{ZiporaError, Result};
/// Concurrency model selector for token/version based access control.
///
/// Levels are ordered from least to most permissive; the `u8` discriminants
/// are part of the public interface and must not be reordered.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[repr(u8)]
pub enum ConcurrencyLevel {
    NoWriteReadOnly = 0,
    SingleThreadStrict = 1,
    SingleThreadShared = 2,
    OneWriteMultiRead = 3,
    MultiWriteMultiRead = 4,
}

impl ConcurrencyLevel {
    /// Whether several reader tokens may be live at the same time.
    #[inline]
    pub const fn allows_concurrent_readers(self) -> bool {
        match self {
            Self::NoWriteReadOnly | Self::SingleThreadStrict | Self::SingleThreadShared => false,
            Self::OneWriteMultiRead | Self::MultiWriteMultiRead => true,
        }
    }

    /// Whether several writer tokens may be live at the same time.
    #[inline]
    pub const fn allows_concurrent_writers(self) -> bool {
        match self {
            Self::MultiWriteMultiRead => true,
            Self::NoWriteReadOnly
            | Self::SingleThreadStrict
            | Self::SingleThreadShared
            | Self::OneWriteMultiRead => false,
        }
    }

    /// True for every level that needs locking/atomics; the two strictly
    /// single-owner levels get by without any synchronization.
    #[inline]
    pub const fn requires_synchronization(self) -> bool {
        match self {
            Self::NoWriteReadOnly | Self::SingleThreadStrict => false,
            Self::SingleThreadShared | Self::OneWriteMultiRead | Self::MultiWriteMultiRead => true,
        }
    }

    /// True for levels that defer memory reclamation via a lazy-free list.
    #[inline]
    pub const fn uses_lazy_cleanup(self) -> bool {
        match self {
            Self::NoWriteReadOnly | Self::SingleThreadStrict => false,
            Self::SingleThreadShared | Self::OneWriteMultiRead | Self::MultiWriteMultiRead => true,
        }
    }

    /// Upper bound on concurrent readers, or `None` when unbounded.
    pub const fn max_concurrent_readers(self) -> Option<usize> {
        match self {
            Self::SingleThreadStrict | Self::SingleThreadShared => Some(1),
            Self::NoWriteReadOnly | Self::OneWriteMultiRead | Self::MultiWriteMultiRead => None,
        }
    }

    /// Upper bound on concurrent writers, or `None` when unbounded.
    pub const fn max_concurrent_writers(self) -> Option<usize> {
        match self {
            Self::NoWriteReadOnly => Some(0),
            Self::SingleThreadStrict | Self::SingleThreadShared | Self::OneWriteMultiRead => {
                Some(1)
            }
            Self::MultiWriteMultiRead => None,
        }
    }
}
impl Default for ConcurrencyLevel {
fn default() -> Self {
Self::SingleThreadStrict
}
}
impl std::fmt::Display for ConcurrencyLevel {
    /// Renders the level as its variant name (same text `Debug` derives).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            Self::NoWriteReadOnly => "NoWriteReadOnly",
            Self::SingleThreadStrict => "SingleThreadStrict",
            Self::SingleThreadShared => "SingleThreadShared",
            Self::OneWriteMultiRead => "OneWriteMultiRead",
            Self::MultiWriteMultiRead => "MultiWriteMultiRead",
        };
        f.write_str(name)
    }
}
/// A deferred deallocation record: a memory region that may only be
/// reclaimed once no live token can still observe it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct LazyFreeItem {
    /// Version at which the region was retired.
    pub age: u64,
    /// Offset of the region inside the owning allocation.
    pub memory_offset: u32,
    /// Region size (presumably bytes — TODO confirm unit with the allocator).
    pub size: u32,
}

impl LazyFreeItem {
    /// Builds a record for a region retired at version `age`.
    pub fn new(age: u64, memory_offset: u32, size: u32) -> Self {
        LazyFreeItem {
            age,
            memory_offset,
            size,
        }
    }

    /// True when the retire version lies strictly below the oldest version
    /// any live token may still reference, i.e. reclamation is safe.
    #[inline]
    pub fn can_free(&self, min_version: u64) -> bool {
        min_version > self.age
    }
}
/// FIFO queue of retired memory regions awaiting a safe reclamation point.
///
/// Entries are appended in retire order, so the front of the queue is the
/// oldest item; a reclamation pass stops at the first still-visible entry.
#[derive(Debug)]
pub struct LazyFreeList {
    // Oldest-first queue: push_back on retire, pop_front on reclaim.
    items: VecDeque<LazyFreeItem>,
    // Maximum items reclaimed per `process_safe_items` call.
    bulk_threshold: usize,
    stats: LazyFreeStats,
}

impl LazyFreeList {
    /// Default batch size for one reclamation pass.
    pub const BULK_FREE_NUM: usize = 32;

    /// List with the default batch size ([`Self::BULK_FREE_NUM`]).
    pub fn new() -> Self {
        Self::with_bulk_threshold(Self::BULK_FREE_NUM)
    }

    /// List with a caller-chosen batch size.
    pub fn with_bulk_threshold(bulk_threshold: usize) -> Self {
        LazyFreeList {
            items: VecDeque::new(),
            bulk_threshold,
            stats: LazyFreeStats::default(),
        }
    }

    /// Appends a retired region. Callers are expected to push in
    /// non-decreasing `age` order so the front is always the oldest entry.
    #[inline]
    pub fn push(&mut self, item: LazyFreeItem) {
        self.items.push_back(item);
        self.stats.items_added += 1;
    }

    /// Number of pending items.
    #[inline]
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// True when nothing is pending.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Drains items whose `age` is below `min_version`, invoking `free_fn`
    /// for each, and returns how many were reclaimed.
    ///
    /// At least one eligible item is processed before the `bulk_threshold`
    /// batch limit is consulted; the pass also stops early at the first
    /// entry that is still visible to a live token.
    pub fn process_safe_items<F>(&mut self, min_version: u64, mut free_fn: F) -> usize
    where
        F: FnMut(LazyFreeItem),
    {
        let started = Instant::now();
        let mut reclaimed = 0usize;
        loop {
            // Peek; stop on empty queue or a still-referenced entry.
            let next = match self.items.front() {
                Some(item) if item.can_free(min_version) => *item,
                _ => break,
            };
            self.items.pop_front();
            free_fn(next);
            reclaimed += 1;
            if reclaimed >= self.bulk_threshold {
                break;
            }
        }
        self.stats.items_processed += reclaimed as u64;
        self.stats.total_processing_time += started.elapsed();
        reclaimed
    }

    /// Heuristic: the backlog has grown past twice the batch size, so a
    /// bulk pass is worthwhile now.
    pub fn should_bulk_process(&self) -> bool {
        self.items.len() >= 2 * self.bulk_threshold
    }

    /// Read-only view of the accumulated counters.
    pub fn stats(&self) -> &LazyFreeStats {
        &self.stats
    }

    /// Resets all counters to zero.
    pub fn clear_stats(&mut self) {
        self.stats = LazyFreeStats::default();
    }
}
impl Default for LazyFreeList {
    /// Equivalent to [`LazyFreeList::new`]: empty queue, default bulk threshold.
    fn default() -> Self {
        Self::new()
    }
}
/// Counters accumulated by a lazy-free list over its lifetime.
#[derive(Debug, Default, Clone)]
pub struct LazyFreeStats {
    /// Total items ever pushed.
    pub items_added: u64,
    /// Total items handed to a free callback.
    pub items_processed: u64,
    /// Wall-clock time spent inside `process_safe_items`.
    pub total_processing_time: std::time::Duration,
}

impl LazyFreeStats {
    /// Fraction of added items that have been reclaimed; 0.0 when nothing
    /// has been added yet.
    pub fn efficiency(&self) -> f64 {
        match self.items_added {
            0 => 0.0,
            added => self.items_processed as f64 / added as f64,
        }
    }

    /// Mean processing time per reclaimed item; zero when none processed.
    /// (The count is narrowed to `u32` for `Duration` division — only an
    /// issue past u32::MAX processed items.)
    pub fn avg_processing_time(&self) -> std::time::Duration {
        match self.items_processed {
            0 => std::time::Duration::ZERO,
            n => self.total_processing_time / n as u32,
        }
    }
}
/// Coordinates token-based version tracking for a chosen concurrency level.
#[derive(Debug)]
pub struct VersionManager {
    // Immutable after construction; selects the locking/versioning strategy.
    concurrency_level: ConcurrencyLevel,
    // Monotonically increasing version counter. Starts at 1; 0 is reserved
    // as the sentinel version used by read-only tokens.
    current_version: AtomicU64,
    // Oldest version a live token may still reference; lazily advanced when
    // the last token is released.
    min_version: AtomicU64,
    // Live token counts. Updated with Relaxed ordering — these are counters,
    // not synchronization points.
    active_readers: AtomicU64,
    active_writers: AtomicU64,
    // Serializes version assignment in synchronized concurrency modes.
    token_chain_mutex: Mutex<()>,
    stats: Mutex<VersionManagerStats>,
}
impl VersionManager {
    /// Creates a manager for the given concurrency level. Versions start at 1
    /// so 0 can serve as the "no version" sentinel used by read-only tokens.
    pub fn new(concurrency_level: ConcurrencyLevel) -> Self {
        Self {
            concurrency_level,
            current_version: AtomicU64::new(1),
            min_version: AtomicU64::new(1),
            active_readers: AtomicU64::new(0),
            active_writers: AtomicU64::new(0),
            token_chain_mutex: Mutex::new(()),
            stats: Mutex::new(VersionManagerStats::default()),
        }
    }

    /// Concurrency level this manager was built with.
    #[inline]
    pub fn concurrency_level(&self) -> ConcurrencyLevel {
        self.concurrency_level
    }

    /// Latest version handed out.
    #[inline]
    pub fn current_version(&self) -> u64 {
        self.current_version.load(Ordering::Acquire)
    }

    /// Oldest version that may still be referenced by a live token.
    #[inline]
    pub fn min_version(&self) -> u64 {
        self.min_version.load(Ordering::Acquire)
    }

    /// Number of live reader tokens (approximate under contention).
    #[inline]
    pub fn active_readers(&self) -> u64 {
        self.active_readers.load(Ordering::Relaxed)
    }

    /// Number of live writer tokens (approximate under contention).
    #[inline]
    pub fn active_writers(&self) -> u64 {
        self.active_writers.load(Ordering::Relaxed)
    }

    /// Assigns a (version, min_version) pair for a new token.
    ///
    /// Synchronized modes take the token chain mutex so version assignment is
    /// serialized; unsynchronized modes share the constant pair (1, 1).
    fn assign_version(&self, lock_err_msg: &'static str) -> Result<(u64, u64)> {
        if self.concurrency_level.requires_synchronization() {
            let _lock = self
                .token_chain_mutex
                .lock()
                .map_err(|_| ZiporaError::system_error(lock_err_msg))?;
            let current_min = self.min_version.load(Ordering::Acquire);
            let version = self.current_version.fetch_add(1, Ordering::AcqRel) + 1;
            Ok((version, current_min))
        } else {
            Ok((1, 1))
        }
    }

    /// Acquires a reader token.
    ///
    /// In `NoWriteReadOnly` mode a free, unsynchronized token is returned;
    /// otherwise a fresh version is assigned (under the token chain mutex in
    /// synchronized modes).
    ///
    /// # Errors
    /// Returns an error if the token chain mutex is poisoned.
    pub fn acquire_reader_token(&self) -> Result<ReaderToken> {
        if self.concurrency_level == ConcurrencyLevel::NoWriteReadOnly {
            return Ok(ReaderToken::new_readonly());
        }
        let start_time = Instant::now();
        let (version, min_version) =
            self.assign_version("Failed to acquire token chain mutex for reader")?;
        self.active_readers.fetch_add(1, Ordering::Relaxed);
        // Stats are best-effort: a poisoned stats mutex is silently skipped.
        if let Ok(mut stats) = self.stats.lock() {
            stats.reader_tokens_acquired += 1;
            stats.total_reader_acquisition_time += start_time.elapsed();
        }
        Ok(ReaderToken::new(
            version,
            min_version,
            thread::current().id(),
            self.concurrency_level,
            Arc::new(TokenReleaseCallback {
                version_manager: self as *const Self,
                token_type: TokenType::Reader,
            }),
        ))
    }

    /// Acquires a writer token.
    ///
    /// # Errors
    /// - `NoWriteReadOnly`: writers are never allowed.
    /// - `OneWriteMultiRead`: fails when another writer is already active.
    ///   The writer slot is claimed with a single compare-and-swap, so two
    ///   racing callers can never both succeed (the previous load-then-
    ///   increment scheme had a TOCTOU window that allowed exactly that).
    /// - Any synchronized mode: fails if the token chain mutex is poisoned.
    pub fn acquire_writer_token(&self) -> Result<WriterToken> {
        if self.concurrency_level == ConcurrencyLevel::NoWriteReadOnly {
            return Err(ZiporaError::invalid_operation(
                "Writers not allowed in NoWriteReadOnly mode",
            ));
        }
        let start_time = Instant::now();
        // Atomically claim the exclusive writer slot in OneWriteMultiRead.
        let claimed_slot = if self.concurrency_level == ConcurrencyLevel::OneWriteMultiRead {
            if self
                .active_writers
                .compare_exchange(0, 1, Ordering::AcqRel, Ordering::Acquire)
                .is_err()
            {
                return Err(ZiporaError::resource_busy(
                    "Another writer is already active in OneWriteMultiRead mode",
                ));
            }
            true
        } else {
            false
        };
        let (version, min_version) =
            match self.assign_version("Failed to acquire token chain mutex for writer") {
                Ok(pair) => pair,
                Err(e) => {
                    // Roll back the claimed slot so a poisoned mutex does not
                    // permanently block all future writers.
                    if claimed_slot {
                        self.active_writers.fetch_sub(1, Ordering::Relaxed);
                    }
                    return Err(e);
                }
            };
        if !claimed_slot {
            self.active_writers.fetch_add(1, Ordering::Relaxed);
        }
        if let Ok(mut stats) = self.stats.lock() {
            stats.writer_tokens_acquired += 1;
            stats.total_writer_acquisition_time += start_time.elapsed();
        }
        Ok(WriterToken::new(
            version,
            min_version,
            thread::current().id(),
            self.concurrency_level,
            Arc::new(TokenReleaseCallback {
                version_manager: self as *const Self,
                token_type: TokenType::Writer,
            }),
        ))
    }

    /// Called from `ReaderToken::drop` via the release callback.
    fn release_reader_token(&self, _token_version: u64) {
        self.active_readers.fetch_sub(1, Ordering::Relaxed);
        if self.concurrency_level.requires_synchronization() {
            self.try_advance_min_version();
        }
        if let Ok(mut stats) = self.stats.lock() {
            stats.reader_tokens_released += 1;
        }
    }

    /// Called from `WriterToken::drop` via the release callback.
    fn release_writer_token(&self, _token_version: u64) {
        self.active_writers.fetch_sub(1, Ordering::Relaxed);
        if self.concurrency_level.requires_synchronization() {
            self.try_advance_min_version();
        }
        if let Ok(mut stats) = self.stats.lock() {
            stats.writer_tokens_released += 1;
        }
    }

    /// Advances `min_version` to `current_version` once no tokens are live.
    ///
    /// NOTE(review): the zero-check and the store are not atomic with respect
    /// to concurrent acquisitions, so `min_version` may briefly advance past
    /// a token acquired in between — confirm downstream lazy-free consumers
    /// tolerate this, or take `token_chain_mutex` here as well.
    fn try_advance_min_version(&self) {
        if self.active_readers.load(Ordering::Relaxed) == 0
            && self.active_writers.load(Ordering::Relaxed) == 0
        {
            let current = self.current_version.load(Ordering::Acquire);
            self.min_version.store(current, Ordering::Release);
        }
    }

    /// Snapshot of the accumulated statistics.
    ///
    /// # Errors
    /// Returns an error if the stats mutex is poisoned.
    pub fn stats(&self) -> Result<VersionManagerStats> {
        self.stats
            .lock()
            .map(|stats| stats.clone())
            .map_err(|_| ZiporaError::system_error("Failed to acquire stats mutex"))
    }

    /// Resets all statistics to their defaults.
    ///
    /// # Errors
    /// Returns an error if the stats mutex is poisoned.
    pub fn clear_stats(&self) -> Result<()> {
        self.stats
            .lock()
            .map(|mut stats| *stats = VersionManagerStats::default())
            .map_err(|_| ZiporaError::system_error("Failed to acquire stats mutex"))
    }

    /// True when `token_version` lies in the currently valid window
    /// `[min_version, current_version]`.
    pub fn validate_token_version(&self, token_version: u64) -> bool {
        let current = self.current_version();
        let min = self.min_version();
        token_version >= min && token_version <= current
    }
}
/// Aggregate token-acquisition statistics collected by the version manager.
#[derive(Debug, Default, Clone)]
pub struct VersionManagerStats {
    /// Reader tokens ever handed out.
    pub reader_tokens_acquired: u64,
    /// Reader tokens whose release callback has run.
    pub reader_tokens_released: u64,
    /// Writer tokens ever handed out.
    pub writer_tokens_acquired: u64,
    /// Writer tokens whose release callback has run.
    pub writer_tokens_released: u64,
    /// Total wall-clock time spent acquiring reader tokens.
    pub total_reader_acquisition_time: std::time::Duration,
    /// Total wall-clock time spent acquiring writer tokens.
    pub total_writer_acquisition_time: std::time::Duration,
}

impl VersionManagerStats {
    /// Mean reader acquisition latency; zero when none were acquired.
    pub fn avg_reader_acquisition_time(&self) -> std::time::Duration {
        match self.reader_tokens_acquired {
            0 => std::time::Duration::ZERO,
            n => self.total_reader_acquisition_time / n as u32,
        }
    }

    /// Mean writer acquisition latency; zero when none were acquired.
    pub fn avg_writer_acquisition_time(&self) -> std::time::Duration {
        match self.writer_tokens_acquired {
            0 => std::time::Duration::ZERO,
            n => self.total_writer_acquisition_time / n as u32,
        }
    }

    /// Readers acquired minus released (signed: a snapshot may be taken
    /// while a release is in flight).
    pub fn active_readers(&self) -> i64 {
        (self.reader_tokens_acquired as i64) - (self.reader_tokens_released as i64)
    }

    /// Writers acquired minus released.
    pub fn active_writers(&self) -> i64 {
        (self.writer_tokens_acquired as i64) - (self.writer_tokens_released as i64)
    }
}
/// Selects which release path a `TokenReleaseCallback` drives on token drop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum TokenType {
    Reader,
    Writer,
}
/// Notifies the owning manager when a token is dropped.
///
/// SAFETY-CRITICAL: holds a raw pointer back to the `VersionManager` that
/// issued the token. Soundness relies on the manager outliving every token
/// it issued — this is NOT enforced by the type system. NOTE(review):
/// consider `Arc<VersionManager>`/`Weak` here to make the lifetime
/// relationship safe — confirm all call sites keep the manager alive past
/// the last token.
struct TokenReleaseCallback {
    version_manager: *const VersionManager,
    token_type: TokenType,
}
impl std::fmt::Debug for TokenReleaseCallback {
    /// Manual impl: the raw pointer is shown as a plain address so the
    /// otherwise derive-unfriendly field still appears in diagnostics.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("TokenReleaseCallback");
        dbg.field("version_manager", &(self.version_manager as usize));
        dbg.field("token_type", &self.token_type);
        dbg.finish()
    }
}
impl TokenReleaseCallback {
    /// Routes a token release back to the issuing manager, decrementing the
    /// matching reader/writer counter.
    fn release(&self, token_version: u64) {
        // SAFETY: assumes the issuing VersionManager outlives all of its
        // tokens (invariant established at token creation; not enforced by
        // the type system — see the note on the struct).
        unsafe {
            let manager = &*self.version_manager;
            match self.token_type {
                TokenType::Reader => manager.release_reader_token(token_version),
                TokenType::Writer => manager.release_writer_token(token_version),
            }
        }
    }
}
// SAFETY: the callback holds only a pointer and a Copy enum; the pointed-to
// VersionManager synchronizes internally with atomics and mutexes, so calling
// `release` from any thread is serialized. Soundness still depends on the
// manager outliving the tokens (see TokenReleaseCallback).
unsafe impl Send for TokenReleaseCallback {}
unsafe impl Sync for TokenReleaseCallback {}
/// Read-access token; releases itself back to the issuing manager on drop.
#[derive(Debug)]
pub struct ReaderToken {
    // Version assigned at acquisition (0 for the read-only pseudo-token).
    version: u64,
    // Snapshot of the manager's min_version at acquisition time.
    min_version: u64,
    // Thread on which the token was acquired.
    thread_id: ThreadId,
    concurrency_level: ConcurrencyLevel,
    // None for read-only tokens, which need no release bookkeeping.
    release_callback: Option<Arc<TokenReleaseCallback>>,
}
impl ReaderToken {
fn new(
version: u64,
min_version: u64,
thread_id: ThreadId,
concurrency_level: ConcurrencyLevel,
release_callback: Arc<TokenReleaseCallback>,
) -> Self {
Self {
version,
min_version,
thread_id,
concurrency_level,
release_callback: Some(release_callback),
}
}
fn new_readonly() -> Self {
Self {
version: 0, min_version: 0,
thread_id: thread::current().id(),
concurrency_level: ConcurrencyLevel::NoWriteReadOnly,
release_callback: None,
}
}
#[inline]
pub fn version(&self) -> u64 {
self.version
}
#[inline]
pub fn min_version(&self) -> u64 {
self.min_version
}
#[inline]
pub fn thread_id(&self) -> ThreadId {
self.thread_id
}
#[inline]
pub fn concurrency_level(&self) -> ConcurrencyLevel {
self.concurrency_level
}
#[inline]
pub fn is_valid(&self) -> bool {
self.concurrency_level == ConcurrencyLevel::NoWriteReadOnly || self.version > 0
}
#[inline]
pub fn is_readonly(&self) -> bool {
self.concurrency_level == ConcurrencyLevel::NoWriteReadOnly
}
}
impl Drop for ReaderToken {
    /// Notifies the issuing manager exactly once; `take()` guarantees the
    /// callback cannot fire twice, and read-only tokens (no callback) drop
    /// silently.
    fn drop(&mut self) {
        if let Some(cb) = self.release_callback.take() {
            cb.release(self.version);
        }
    }
}
/// Write-access token; releases itself back to the issuing manager on drop.
#[derive(Debug)]
pub struct WriterToken {
    // Version assigned at acquisition.
    version: u64,
    // Snapshot of the manager's min_version at acquisition time.
    min_version: u64,
    // Thread on which the token was acquired.
    thread_id: ThreadId,
    concurrency_level: ConcurrencyLevel,
    // Option so Drop can release at most once via take().
    release_callback: Option<Arc<TokenReleaseCallback>>,
}
impl WriterToken {
fn new(
version: u64,
min_version: u64,
thread_id: ThreadId,
concurrency_level: ConcurrencyLevel,
release_callback: Arc<TokenReleaseCallback>,
) -> Self {
Self {
version,
min_version,
thread_id,
concurrency_level,
release_callback: Some(release_callback),
}
}
#[inline]
pub fn version(&self) -> u64 {
self.version
}
#[inline]
pub fn min_version(&self) -> u64 {
self.min_version
}
#[inline]
pub fn thread_id(&self) -> ThreadId {
self.thread_id
}
#[inline]
pub fn concurrency_level(&self) -> ConcurrencyLevel {
self.concurrency_level
}
#[inline]
pub fn is_valid(&self) -> bool {
self.version > 0
}
#[inline]
pub fn allows_concurrent_writers(&self) -> bool {
self.concurrency_level.allows_concurrent_writers()
}
}
impl Drop for WriterToken {
    /// Notifies the issuing manager exactly once; `take()` guarantees the
    /// callback cannot fire twice.
    fn drop(&mut self) {
        if let Some(cb) = self.release_callback.take() {
            cb.release(self.version);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;
    use std::time::Duration;

    #[test]
    fn test_concurrency_level_properties() {
        assert!(!ConcurrencyLevel::NoWriteReadOnly.allows_concurrent_writers());
        assert!(ConcurrencyLevel::MultiWriteMultiRead.allows_concurrent_writers());
        assert!(ConcurrencyLevel::OneWriteMultiRead.allows_concurrent_readers());
        assert!(!ConcurrencyLevel::SingleThreadStrict.requires_synchronization());
        assert!(ConcurrencyLevel::OneWriteMultiRead.uses_lazy_cleanup());
    }

    #[test]
    fn test_lazy_free_list() {
        let mut list = LazyFreeList::new();
        assert!(list.is_empty());
        list.push(LazyFreeItem::new(1, 100, 64));
        list.push(LazyFreeItem::new(2, 200, 128));
        list.push(LazyFreeItem::new(3, 300, 256));
        assert_eq!(list.len(), 3);
        // Only items with age < min_version (ages 1 and 2) may be reclaimed.
        let mut freed_items = Vec::new();
        let processed = list.process_safe_items(3, |item| freed_items.push(item));
        assert_eq!(processed, 2);
        assert_eq!(freed_items.len(), 2);
        assert_eq!(list.len(), 1);
        let stats = list.stats();
        assert_eq!(stats.items_added, 3);
        assert_eq!(stats.items_processed, 2);
    }

    #[test]
    fn test_version_manager_single_thread() -> Result<()> {
        let manager = VersionManager::new(ConcurrencyLevel::SingleThreadStrict);
        let reader_token = manager.acquire_reader_token()?;
        assert!(reader_token.is_valid());
        assert_eq!(
            reader_token.concurrency_level(),
            ConcurrencyLevel::SingleThreadStrict
        );
        let writer_token = manager.acquire_writer_token()?;
        assert!(writer_token.is_valid());
        assert!(!writer_token.allows_concurrent_writers());
        Ok(())
    }

    #[test]
    fn test_version_manager_readonly() -> Result<()> {
        let manager = VersionManager::new(ConcurrencyLevel::NoWriteReadOnly);
        let reader_token = manager.acquire_reader_token()?;
        assert!(reader_token.is_valid());
        assert!(reader_token.is_readonly());
        // Writers must always be rejected in read-only mode.
        let result = manager.acquire_writer_token();
        assert!(result.is_err());
        Ok(())
    }

    #[test]
    fn test_version_manager_one_write_multi_read() -> Result<()> {
        let manager = VersionManager::new(ConcurrencyLevel::OneWriteMultiRead);
        // Underscore-prefixed: held only for their effect on the counters.
        let _reader1 = manager.acquire_reader_token()?;
        let _reader2 = manager.acquire_reader_token()?;
        assert_eq!(manager.active_readers(), 2);
        let writer1 = manager.acquire_writer_token()?;
        assert_eq!(manager.active_writers(), 1);
        // A second writer must be rejected while the first is alive.
        let result = manager.acquire_writer_token();
        assert!(result.is_err());
        drop(writer1);
        assert_eq!(manager.active_writers(), 0);
        let writer2 = manager.acquire_writer_token()?;
        assert!(writer2.is_valid());
        Ok(())
    }

    #[test]
    fn test_version_manager_multi_write_multi_read() -> Result<()> {
        let manager = VersionManager::new(ConcurrencyLevel::MultiWriteMultiRead);
        let _reader1 = manager.acquire_reader_token()?;
        let _reader2 = manager.acquire_reader_token()?;
        let _writer1 = manager.acquire_writer_token()?;
        let _writer2 = manager.acquire_writer_token()?;
        assert_eq!(manager.active_readers(), 2);
        assert_eq!(manager.active_writers(), 2);
        Ok(())
    }

    #[test]
    fn test_token_version_validation() -> Result<()> {
        let manager = VersionManager::new(ConcurrencyLevel::OneWriteMultiRead);
        let token = manager.acquire_reader_token()?;
        assert!(manager.validate_token_version(token.version()));
        // Versions start at 1, so 0 is always stale and u64::MAX is unissued.
        assert!(!manager.validate_token_version(0));
        assert!(!manager.validate_token_version(u64::MAX));
        Ok(())
    }

    #[test]
    fn test_concurrent_token_acquisition() -> Result<()> {
        let manager = Arc::new(VersionManager::new(ConcurrencyLevel::MultiWriteMultiRead));
        let num_threads = 4;
        let tokens_per_thread = 10;
        let handles: Vec<_> = (0..num_threads)
            .map(|_| {
                let manager_clone = Arc::clone(&manager);
                thread::spawn(move || -> Result<()> {
                    for _ in 0..tokens_per_thread {
                        let _reader = manager_clone.acquire_reader_token()?;
                        let _writer = manager_clone.acquire_writer_token()?;
                        // Brief hold to encourage interleaving between threads.
                        thread::sleep(Duration::from_millis(1));
                    }
                    Ok(())
                })
            })
            .collect();
        for handle in handles {
            handle.join().unwrap()?;
        }
        assert_eq!(manager.active_readers(), 0);
        assert_eq!(manager.active_writers(), 0);
        let stats = manager.stats()?;
        assert_eq!(stats.reader_tokens_acquired, num_threads * tokens_per_thread);
        assert_eq!(stats.writer_tokens_acquired, num_threads * tokens_per_thread);
        Ok(())
    }

    #[test]
    fn test_token_drop_cleanup() -> Result<()> {
        let manager = VersionManager::new(ConcurrencyLevel::OneWriteMultiRead);
        {
            let _reader = manager.acquire_reader_token()?;
            let _writer = manager.acquire_writer_token()?;
            assert_eq!(manager.active_readers(), 1);
            assert_eq!(manager.active_writers(), 1);
        }
        // Drop runs synchronously, so the counters are already zero here —
        // no sleep needed.
        assert_eq!(manager.active_readers(), 0);
        assert_eq!(manager.active_writers(), 0);
        Ok(())
    }
}