use log::{info, warn};
use mysql::PooledConn;
use std::collections::HashMap;
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
/// Bookkeeping for one open transaction registered with a `TransactionManager`.
#[derive(Debug)]
struct TransactionInfo {
    /// Connection dedicated to this transaction for its whole lifetime.
    conn: PooledConn,
    /// Nesting depth: set to 1 by `start`, bumped by `increment_depth`,
    /// lowered by `decrement_or_finish`.
    depth: i32,
    /// Moment `start` registered the transaction; `cleanup_expired` measures
    /// the transaction's age from here.
    created_at: Instant,
}
/// Tracks open transactions (one dedicated connection per key) and
/// cooperative table locks owned by caller-supplied thread identifiers.
pub struct TransactionManager {
    /// Open transactions, keyed by the caller-supplied transaction key.
    connections: RwLock<HashMap<String, TransactionInfo>>,
    /// Cooperative table locks: table name -> owning thread identifier.
    table_locks: RwLock<HashMap<String, String>>,
    /// Signalled when table locks are released; blocked acquirers wait on it.
    table_cond: Condvar,
    /// Data-less mutex that `table_cond` waits are performed with.
    table_mutex: Mutex<()>,
    /// Maximum age of a transaction before `cleanup_expired` discards it.
    timeout: Duration,
}
impl TransactionManager {
pub fn new(timeout_secs: u64) -> Self {
Self {
connections: RwLock::new(HashMap::new()),
table_locks: RwLock::new(HashMap::new()),
table_cond: Condvar::new(),
table_mutex: Mutex::new(()),
timeout: Duration::from_secs(timeout_secs),
}
}
pub fn is_in_transaction(&self, key: &str) -> bool {
match self.connections.read() {
Ok(guard) => guard.contains_key(key),
Err(poisoned) => poisoned.into_inner().contains_key(key),
}
}
pub fn get_depth(&self, key: &str) -> i32 {
match self.connections.read() {
Ok(guard) => guard.get(key).map(|t| t.depth).unwrap_or(0),
Err(poisoned) => poisoned.into_inner().get(key).map(|t| t.depth).unwrap_or(0),
}
}
pub fn increment_depth(&self, key: &str) -> bool {
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
if let Some(txn_info) = conns.get_mut(key) {
txn_info.depth += 1;
info!(
"TransactionManager: nested transaction depth={}",
txn_info.depth
);
return true;
}
false
}
pub fn start(&self, key: &str, conn: PooledConn) -> bool {
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns.insert(
key.to_string(),
TransactionInfo {
conn,
depth: 1,
created_at: Instant::now(),
},
);
true
}
pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
where
F: FnOnce(&mut PooledConn) -> R,
{
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns.get_mut(key).map(|txn_info| f(&mut txn_info.conn))
}
pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
let start = Instant::now();
loop {
{
let mut locks = match self.table_locks.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
match locks.get(table) {
None => {
locks.insert(table.to_string(), thread_id.to_string());
return true;
}
Some(owner) if owner == thread_id => {
return true;
}
_ => {}
}
}
if start.elapsed() > timeout {
warn!("TransactionManager: table lock timeout for {table}");
return false;
}
let guard = match self.table_mutex.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
let remaining = timeout.saturating_sub(start.elapsed());
if remaining.is_zero() {
return false;
}
let wait_time = Duration::from_millis(50).min(remaining);
let _ = self.table_cond.wait_timeout(guard, wait_time);
}
}
pub fn release_table_lock(&self, table: &str, thread_id: &str) {
let mut locks = match self.table_locks.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
if let Some(owner) = locks.get(table) {
if owner == thread_id {
locks.remove(table);
self.table_cond.notify_all();
}
}
}
pub fn release_all_table_locks(&self, thread_id: &str) {
let mut locks = match self.table_locks.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
let tables_to_remove: Vec<String> = locks
.iter()
.filter(|(_, owner)| *owner == thread_id)
.map(|(table, _)| table.clone())
.collect();
for table in tables_to_remove {
locks.remove(&table);
}
self.table_cond.notify_all();
}
pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
if let Some(txn_info) = conns.get_mut(key) {
if txn_info.depth > 1 {
txn_info.depth -= 1;
return Some(txn_info.depth);
}
}
conns.remove(key);
drop(conns);
self.release_all_table_locks(thread_id);
Some(0)
}
pub fn remove(&self, key: &str, thread_id: &str) {
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns.remove(key);
drop(conns);
self.release_all_table_locks(thread_id);
}
pub fn cleanup_expired(&self) {
let expired: Vec<String> = {
let conns = match self.connections.read() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns
.iter()
.filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
.map(|(key, _)| key.clone())
.collect()
};
if expired.is_empty() {
return;
}
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
for key in expired {
warn!("TransactionManager: cleaning up expired transaction: {key}");
conns.remove(&key);
}
}
pub fn stats(&self) -> (usize, usize) {
let conn_count = match self.connections.read() {
Ok(g) => g.len(),
Err(poisoned) => poisoned.into_inner().len(),
};
let lock_count = match self.table_locks.read() {
Ok(g) => g.len(),
Err(poisoned) => poisoned.into_inner().len(),
};
(conn_count, lock_count)
}
}
/// Age, in seconds, after which the global manager treats a transaction as expired.
const DEFAULT_TRANSACTION_TIMEOUT_SECS: u64 = 300;

lazy_static::lazy_static! {
    /// Process-wide transaction manager, created on first access.
    pub static ref TRANSACTION_MANAGER: Arc<TransactionManager> =
        Arc::new(TransactionManager::new(DEFAULT_TRANSACTION_TIMEOUT_SECS));
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = TransactionManager::new(60);
        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = TransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = TransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 1);
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = TransactionManager::new(60);
        let thread_id = "test_thread_1";
        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        // Re-entrant acquisition does not add a second entry...
        assert_eq!(tm.stats().1, 1);
        // ...and is not counted: a single release frees the table.
        tm.release_table_lock("users", thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(TransactionManager::new(60));
        let tm2 = Arc::clone(&tm);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let handle = thread::spawn(move || {
            let result = tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100));
            assert!(!result);
        });
        handle.join().unwrap();
        tm.release_table_lock("users", "thread_a");
    }

    #[test]
    fn test_waiter_acquires_after_release() {
        let tm = Arc::new(TransactionManager::new(60));
        let tm2 = Arc::clone(&tm);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_secs(2))
        });
        // Give the waiter a chance to block; it succeeds either way once released.
        thread::sleep(Duration::from_millis(20));
        tm.release_table_lock("users", "thread_a");
        assert!(handle.join().unwrap());
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = TransactionManager::new(60);
        let thread_id = "test_thread_2";
        assert!(tm.acquire_table_lock("table1", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table2", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table3", thread_id, Duration::from_secs(1)));
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 3);
        tm.release_all_table_locks(thread_id);
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = Arc::new(TransactionManager::new(60));
        let tm2 = Arc::clone(&tm);
        assert!(tm.acquire_table_lock("table_a", "thread_a", Duration::from_secs(1)));
        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("table_b", "thread_b", Duration::from_secs(1))
        });
        assert!(handle.join().unwrap());
        assert_eq!(tm.stats().1, 2);
        tm.release_all_table_locks("thread_a");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let result = tm.decrement_or_finish("nonexistent", "thread_a");
        assert_eq!(result, Some(0));
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = TransactionManager::new(60);
        tm.remove("nonexistent", "thread_a");
    }

    #[test]
    fn test_remove_releases_own_table_locks() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t2", "thread_b", Duration::from_secs(1)));
        tm.remove("nonexistent", "thread_a");
        assert_eq!(tm.stats(), (0, 1));
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = TransactionManager::new(60);
        tm.cleanup_expired();
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t2", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t3", "thread_b", Duration::from_secs(1)));
        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 3);
    }
}