use crate::{AdapterError, SymbolTable};
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard, TryLockError};
use std::time::{Duration, Instant};
/// Running usage counters for a lock-protected symbol table.
///
/// Every field is a plain cumulative total. The accessor methods derive
/// averages and rates from those totals and yield `0.0` whenever the relevant
/// denominator is zero, so they never divide by zero.
#[derive(Debug, Clone, Default)]
pub struct LockStats {
    /// Successful read-lock acquisitions.
    pub read_locks: usize,
    /// Successful write-lock acquisitions.
    pub write_locks: usize,
    /// Read attempts that found the lock busy.
    pub read_contentions: usize,
    /// Write attempts that found the lock busy.
    pub write_contentions: usize,
    /// Accumulated read-lock wait time, in whole milliseconds.
    pub read_wait_ms: u128,
    /// Accumulated write-lock wait time, in whole milliseconds.
    pub write_wait_ms: u128,
    /// Transactions begun.
    pub transactions_started: usize,
    /// Transactions committed.
    pub transactions_committed: usize,
    /// Transactions rolled back, explicitly or on drop.
    pub transactions_rolled_back: usize,
}

impl LockStats {
    /// Returns an all-zero set of counters (identical to `LockStats::default()`).
    pub fn new() -> Self {
        Self::default()
    }

    /// `numerator / denominator`, or `0.0` when `denominator` is zero.
    fn ratio(numerator: f64, denominator: usize) -> f64 {
        match denominator {
            0 => 0.0,
            n => numerator / n as f64,
        }
    }

    /// Mean milliseconds waited per read lock; `0.0` if no read lock was taken.
    pub fn avg_read_wait_ms(&self) -> f64 {
        Self::ratio(self.read_wait_ms as f64, self.read_locks)
    }

    /// Mean milliseconds waited per write lock; `0.0` if no write lock was taken.
    pub fn avg_write_wait_ms(&self) -> f64 {
        Self::ratio(self.write_wait_ms as f64, self.write_locks)
    }

    /// Fraction of read attempts (successes + contentions) that hit contention.
    pub fn read_contention_rate(&self) -> f64 {
        let attempts = self.read_locks + self.read_contentions;
        Self::ratio(self.read_contentions as f64, attempts)
    }

    /// Fraction of write attempts (successes + contentions) that hit contention.
    pub fn write_contention_rate(&self) -> f64 {
        let attempts = self.write_locks + self.write_contentions;
        Self::ratio(self.write_contentions as f64, attempts)
    }

    /// Fraction of started transactions that were committed.
    pub fn commit_rate(&self) -> f64 {
        Self::ratio(
            self.transactions_committed as f64,
            self.transactions_started,
        )
    }
}
/// A `SymbolTable` behind an `RwLock`, paired with usage counters that
/// live in a second, independent `RwLock`.
pub struct LockedSymbolTable {
    table: RwLock<SymbolTable>,
    stats: RwLock<LockStats>,
}

impl LockedSymbolTable {
    /// Wraps a new `SymbolTable::new()` with zeroed statistics.
    pub fn new() -> Self {
        Self::from_table(SymbolTable::new())
    }

    /// Wraps an existing `table` with zeroed statistics.
    pub fn from_table(table: SymbolTable) -> Self {
        let stats = RwLock::new(LockStats::new());
        Self {
            table: RwLock::new(table),
            stats,
        }
    }

    /// Applies `update` to the counters.
    ///
    /// Bookkeeping is best-effort: if the stats lock is poisoned the update
    /// is skipped rather than panicking.
    fn record(&self, update: impl FnOnce(&mut LockStats)) {
        if let Ok(mut stats) = self.stats.write() {
            update(&mut *stats);
        }
    }

    /// Blocks until a shared lock on the table is held.
    ///
    /// Records one read lock plus the time spent waiting, in whole
    /// milliseconds (sub-millisecond waits therefore add 0).
    ///
    /// # Panics
    /// Panics if the table lock is poisoned.
    pub fn read(&self) -> RwLockReadGuard<'_, SymbolTable> {
        let started = Instant::now();
        let guard = self.table.read().expect("lock should not be poisoned");
        let waited_ms = started.elapsed().as_millis();
        self.record(|s| {
            s.read_locks += 1;
            s.read_wait_ms += waited_ms;
        });
        guard
    }

    /// Takes a shared lock only if it is available immediately.
    ///
    /// Returns `None` when acquiring would block (counted as a read
    /// contention) or when the lock is poisoned (not counted).
    pub fn try_read(&self) -> Option<RwLockReadGuard<'_, SymbolTable>> {
        match self.table.try_read() {
            Err(TryLockError::Poisoned(_)) => None,
            Err(TryLockError::WouldBlock) => {
                self.record(|s| s.read_contentions += 1);
                None
            }
            Ok(guard) => {
                self.record(|s| s.read_locks += 1);
                Some(guard)
            }
        }
    }

    /// Blocks until an exclusive lock on the table is held.
    ///
    /// Records one write lock plus the time spent waiting, in whole
    /// milliseconds.
    ///
    /// # Panics
    /// Panics if the table lock is poisoned.
    pub fn write(&self) -> RwLockWriteGuard<'_, SymbolTable> {
        let started = Instant::now();
        let guard = self.table.write().expect("lock should not be poisoned");
        let waited_ms = started.elapsed().as_millis();
        self.record(|s| {
            s.write_locks += 1;
            s.write_wait_ms += waited_ms;
        });
        guard
    }

    /// Takes an exclusive lock only if it is available immediately.
    ///
    /// Returns `None` when acquiring would block (counted as a write
    /// contention) or when the lock is poisoned (not counted).
    pub fn try_write(&self) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
        match self.table.try_write() {
            Err(TryLockError::Poisoned(_)) => None,
            Err(TryLockError::WouldBlock) => {
                self.record(|s| s.write_contentions += 1);
                None
            }
            Ok(guard) => {
                self.record(|s| s.write_locks += 1);
                Some(guard)
            }
        }
    }

    /// Returns a copy of the current counters.
    ///
    /// # Panics
    /// Panics if the stats lock is poisoned.
    pub fn stats(&self) -> LockStats {
        let current = self.stats.read().expect("lock should not be poisoned");
        LockStats::clone(&current)
    }

    /// Resets every counter to zero.
    ///
    /// # Panics
    /// Panics if the stats lock is poisoned.
    pub fn reset_stats(&self) {
        let mut current = self.stats.write().expect("lock should not be poisoned");
        *current = LockStats::new();
    }

    /// Counts a started transaction, then creates one bound to this table.
    pub fn begin_transaction(&self) -> Transaction<'_> {
        self.record(|s| s.transactions_started += 1);
        Transaction::new(self)
    }
}

impl Default for LockedSymbolTable {
    /// Same as [`LockedSymbolTable::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// A snapshot-based transaction over a [`LockedSymbolTable`].
///
/// A clone of the table is captured when the transaction is created. Calling
/// [`Transaction::commit`] keeps the changes; [`Transaction::rollback`] writes
/// the snapshot back. An uncommitted transaction is also rolled back when it
/// is dropped.
///
/// NOTE(review): the snapshot is taken under a read lock that is released
/// immediately, and every `execute` call takes the write lock afresh. Restoring
/// the snapshot therefore also discards writes that other threads made after
/// the transaction began. This is not isolated the way a database transaction is.
pub struct Transaction<'a> {
    locked_table: &'a LockedSymbolTable,
    snapshot: Option<SymbolTable>,
    committed: bool,
}

impl<'a> Transaction<'a> {
    /// Clones the current table contents (via a short-lived read lock) as the
    /// rollback point.
    fn new(locked_table: &'a LockedSymbolTable) -> Self {
        let snapshot = Some(locked_table.read().clone());
        Self {
            locked_table,
            snapshot,
            committed: false,
        }
    }

    /// Runs `f` with exclusive access to the table.
    ///
    /// The write lock is held only for the duration of `f`.
    ///
    /// # Errors
    /// Returns whatever error `f` returns. Changes `f` made before failing are
    /// not undone here; they are undone on rollback or drop.
    pub fn execute<F, R>(&mut self, f: F) -> Result<R, AdapterError>
    where
        F: FnOnce(&mut SymbolTable) -> Result<R, AdapterError>,
    {
        let mut table = self.locked_table.write();
        f(&mut *table)
    }

    /// Keeps every change made so far and counts a committed transaction.
    pub fn commit(mut self) {
        self.committed = true;
        if let Ok(mut counters) = self.locked_table.stats.write() {
            counters.transactions_committed += 1;
        }
        // Release the snapshot now; the drop handler checks `committed` and
        // leaves the table untouched.
        self.snapshot = None;
    }

    /// Restores the table to the snapshot and counts a rolled-back transaction.
    ///
    /// # Panics
    /// Panics if the table lock is poisoned.
    pub fn rollback(mut self) {
        if let Some(original) = self.snapshot.take() {
            let mut table = self.locked_table.write();
            *table = original;
        }
        // The snapshot is now `None`, so dropping `self` afterwards does not
        // restore or count a second time.
        if let Ok(mut counters) = self.locked_table.stats.write() {
            counters.transactions_rolled_back += 1;
        }
    }
}
/// Rolls back an uncommitted transaction when it goes out of scope.
///
/// This also runs during unwinding, e.g. when a closure passed to
/// `Transaction::execute` panics while holding the write lock and so poisons
/// it. The snapshot is still written back: the guard is recovered with
/// `PoisonError::into_inner` instead of going through
/// `LockedSymbolTable::write`, which would panic inside `drop`. This keeps
/// half-applied changes from surviving, and a rollback is only counted once
/// the snapshot has actually been restored. The lock itself stays poisoned.
impl Drop for Transaction<'_> {
    fn drop(&mut self) {
        if self.committed {
            return;
        }
        // `None` means an explicit `rollback` already restored the snapshot
        // and counted it.
        if let Some(snapshot) = self.snapshot.take() {
            // The temporary guard is released at the end of this statement,
            // before the stats lock is taken.
            *self
                .locked_table
                .table
                .write()
                .unwrap_or_else(std::sync::PoisonError::into_inner) = snapshot;
            if let Ok(mut stats) = self.locked_table.stats.write() {
                stats.transactions_rolled_back += 1;
            }
        }
    }
}
pub trait LockWithTimeout {
fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>>;
fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>>;
}
impl LockWithTimeout for LockedSymbolTable {
fn read_timeout(&self, timeout: Duration) -> Option<RwLockReadGuard<'_, SymbolTable>> {
let start = Instant::now();
loop {
if let Some(guard) = self.try_read() {
return Some(guard);
}
if start.elapsed() >= timeout {
if let Ok(mut stats) = self.stats.write() {
stats.read_contentions += 1;
}
return None;
}
std::thread::sleep(Duration::from_millis(1));
}
}
fn write_timeout(&self, timeout: Duration) -> Option<RwLockWriteGuard<'_, SymbolTable>> {
let start = Instant::now();
loop {
if let Some(guard) = self.try_write() {
return Some(guard);
}
if start.elapsed() >= timeout {
if let Ok(mut stats) = self.stats.write() {
stats.write_contentions += 1;
}
return None;
}
std::thread::sleep(Duration::from_millis(1));
}
}
}
// Unit tests for the locked symbol table: plain locking, try/timeout
// acquisition, statistics bookkeeping, and snapshot transactions.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DomainInfo;
    use std::sync::Arc;
    use std::thread;

    // A write made through the write guard is visible to a later reader.
    #[test]
    fn test_basic_read_write() {
        let table = LockedSymbolTable::new();
        {
            let mut guard = table.write();
            guard
                .add_domain(DomainInfo::new("User", 100))
                .expect("unwrap");
        }
        {
            let guard = table.read();
            assert_eq!(guard.domains.len(), 1);
            assert!(guard.get_domain("User").is_some());
        }
    }

    // Several threads can each take a shared read lock and see the same data.
    #[test]
    fn test_multiple_readers() {
        let table = Arc::new(LockedSymbolTable::new());
        {
            let mut guard = table.write();
            guard
                .add_domain(DomainInfo::new("User", 100))
                .expect("unwrap");
        }
        let mut handles = vec![];
        for _ in 0..5 {
            let table_clone = Arc::clone(&table);
            handles.push(thread::spawn(move || {
                let guard = table_clone.read();
                assert_eq!(guard.domains.len(), 1);
            }));
        }
        for handle in handles {
            handle.join().expect("unwrap");
        }
    }

    // On an uncontended table both non-blocking acquisitions succeed. Each
    // guard is scoped so it is released before the next attempt.
    #[test]
    fn test_try_read_write() {
        let table = LockedSymbolTable::new();
        {
            let guard = table.try_read();
            assert!(guard.is_some());
        }
        {
            let guard = table.try_write();
            assert!(guard.is_some());
        }
    }

    // A read guard held by this thread makes `try_write` on another thread
    // fail, and the failure is recorded as a write contention.
    #[test]
    fn test_try_write_contention() {
        let table = Arc::new(LockedSymbolTable::new());
        let _read_guard = table.read();
        let table_clone = Arc::clone(&table);
        let handle = thread::spawn(move || {
            let guard = table_clone.try_write();
            assert!(guard.is_none());
        });
        handle.join().expect("unwrap");
        let stats = table.stats();
        assert!(stats.write_contentions > 0);
    }

    // Changes made inside a committed transaction persist and the commit is
    // counted.
    #[test]
    fn test_transaction_commit() {
        let table = LockedSymbolTable::new();
        {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("User", 100)).expect("unwrap");
                t.add_domain(DomainInfo::new("Post", 1000)).expect("unwrap");
                Ok(())
            })
            .expect("unwrap");
            txn.commit();
        }
        let guard = table.read();
        assert_eq!(guard.domains.len(), 2);
        let stats = table.stats();
        assert_eq!(stats.transactions_committed, 1);
    }

    // An explicit rollback restores the pre-transaction contents ("User"
    // only) and is counted exactly once, even though the transaction is
    // dropped afterwards.
    #[test]
    fn test_transaction_rollback() {
        let table = LockedSymbolTable::new();
        {
            let mut guard = table.write();
            guard
                .add_domain(DomainInfo::new("User", 100))
                .expect("unwrap");
        }
        {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("Post", 1000)).expect("unwrap");
                Ok(())
            })
            .expect("unwrap");
            txn.rollback();
        }
        let guard = table.read();
        assert_eq!(guard.domains.len(), 1);
        assert!(guard.get_domain("Post").is_none());
        let stats = table.stats();
        assert_eq!(stats.transactions_rolled_back, 1);
    }

    // Dropping a transaction without committing rolls it back automatically.
    #[test]
    fn test_transaction_auto_rollback() {
        let table = LockedSymbolTable::new();
        {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("User", 100)).expect("unwrap");
                Ok(())
            })
            .expect("unwrap");
        }
        let guard = table.read();
        assert_eq!(guard.domains.len(), 0);
        let stats = table.stats();
        assert_eq!(stats.transactions_rolled_back, 1);
    }

    // Every blocking acquisition is counted. Each loop iteration drops its
    // guard, so the locks do not overlap.
    #[test]
    fn test_lock_stats() {
        let table = LockedSymbolTable::new();
        for _ in 0..3 {
            let _guard = table.read();
        }
        for _ in 0..2 {
            let _guard = table.write();
        }
        let stats = table.stats();
        assert_eq!(stats.read_locks, 3);
        assert_eq!(stats.write_locks, 2);
    }

    // `reset_stats` zeroes the counters. The read guard is still held
    // throughout, which works because stats live in a separate lock.
    #[test]
    fn test_reset_stats() {
        let table = LockedSymbolTable::new();
        let _guard = table.read();
        assert_eq!(table.stats().read_locks, 1);
        table.reset_stats();
        assert_eq!(table.stats().read_locks, 0);
    }

    // A timed read on an uncontended table succeeds.
    #[test]
    fn test_timeout_success() {
        let table = LockedSymbolTable::new();
        let guard = table.read_timeout(Duration::from_millis(100));
        assert!(guard.is_some());
    }

    // A timed write gives up (returns `None`) while another thread holds the
    // write lock.
    #[test]
    fn test_timeout_failure() {
        let table = Arc::new(LockedSymbolTable::new());
        let _write_guard = table.write();
        let table_clone = Arc::clone(&table);
        let handle = thread::spawn(move || {
            let guard = table_clone.write_timeout(Duration::from_millis(50));
            assert!(guard.is_none());
        });
        handle.join().expect("unwrap");
    }

    // Interleaved readers and writers. Final size is 1 initial domain plus
    // 2 writers x 5 domains = 11, and both lock kinds appear in the stats.
    #[test]
    fn test_concurrent_read_write() {
        let table = Arc::new(LockedSymbolTable::new());
        {
            let mut guard = table.write();
            guard
                .add_domain(DomainInfo::new("User", 100))
                .expect("unwrap");
        }
        let mut handles = vec![];
        for _ in 0..3 {
            let table_clone = Arc::clone(&table);
            handles.push(thread::spawn(move || {
                for _ in 0..10 {
                    let guard = table_clone.read();
                    assert!(!guard.domains.is_empty());
                    thread::sleep(Duration::from_millis(1));
                }
            }));
        }
        for i in 0..2 {
            let table_clone = Arc::clone(&table);
            handles.push(thread::spawn(move || {
                for j in 0..5 {
                    let mut guard = table_clone.write();
                    let domain_name = format!("Domain_{}_{}", i, j);
                    guard
                        .add_domain(DomainInfo::new(&domain_name, 100))
                        .expect("unwrap");
                    thread::sleep(Duration::from_millis(2));
                }
            }));
        }
        for handle in handles {
            handle.join().expect("unwrap");
        }
        let guard = table.read();
        assert!(guard.domains.len() >= 11);
        let stats = table.stats();
        assert!(stats.read_locks > 0);
        assert!(stats.write_locks > 0);
    }

    // Derived metrics: 100/10, 200/5, 2/(10+2), 3/(5+3), 8/10.
    #[test]
    fn test_stats_calculations() {
        let mut stats = LockStats::new();
        stats.read_locks = 10;
        stats.write_locks = 5;
        stats.read_wait_ms = 100;
        stats.write_wait_ms = 200;
        stats.read_contentions = 2;
        stats.write_contentions = 3;
        stats.transactions_started = 10;
        stats.transactions_committed = 8;
        assert_eq!(stats.avg_read_wait_ms(), 10.0);
        assert_eq!(stats.avg_write_wait_ms(), 40.0);
        assert!((stats.read_contention_rate() - 0.1667).abs() < 0.001);
        assert_eq!(stats.write_contention_rate(), 0.375);
        assert_eq!(stats.commit_rate(), 0.8);
    }

    // An `Err` from `execute` is propagated to the caller. Dropping the
    // uncommitted transaction at the end of the block undoes the partial
    // change, leaving the table empty.
    #[test]
    fn test_transaction_error_handling() {
        let table = LockedSymbolTable::new();
        let result: Result<(), AdapterError> = {
            let mut txn = table.begin_transaction();
            txn.execute(|t| {
                t.add_domain(DomainInfo::new("User", 100)).expect("unwrap");
                Err(AdapterError::DuplicateDomain("User".to_string()))
            })
        };
        assert!(result.is_err());
        let guard = table.read();
        assert_eq!(guard.domains.len(), 0);
    }

    // `from_table` keeps the contents of the wrapped table.
    #[test]
    fn test_from_table() {
        let mut original = SymbolTable::new();
        original
            .add_domain(DomainInfo::new("User", 100))
            .expect("unwrap");
        let locked = LockedSymbolTable::from_table(original);
        let guard = locked.read();
        assert_eq!(guard.domains.len(), 1);
        assert!(guard.get_domain("User").is_some());
    }
}