use log::{info, warn};
use sqlite::ConnectionThreadSafe;
use std::collections::HashMap;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};
/// Book-keeping for one transaction registered with [`SqliteTransactionManager`].
struct TransactionInfo {
    /// When the transaction was registered by [`SqliteTransactionManager::start`];
    /// [`SqliteTransactionManager::cleanup_expired`] compares its age to the timeout.
    created_at: Instant,
    /// Nesting level: 1 on `start`, raised by `increment_depth`, lowered by
    /// `decrement_or_finish`.
    depth: AtomicU32,
    /// Connection the transaction runs on; lent to closures via `with_conn`.
    conn: Arc<ConnectionThreadSafe>,
}
/// Tracks open SQLite transactions by key and serializes writers.
///
/// Two independent locks are used: `connections` maps each key to its
/// [`TransactionInfo`], while `active_writer` (paired with `writer_cond`)
/// records which thread id currently owns the exclusive write slot.
pub struct SqliteTransactionManager {
    /// Maximum age of a transaction before `cleanup_expired` evicts it.
    timeout: Duration,
    /// Notified whenever the write slot is released so waiters re-check it.
    writer_cond: Condvar,
    /// Thread id owning the write slot, or `None` when the slot is free.
    active_writer: Mutex<Option<String>>,
    /// Open transactions, keyed by the key passed to `start`.
    connections: RwLock<HashMap<String, Arc<TransactionInfo>>>,
}
impl SqliteTransactionManager {
pub fn new(timeout_secs: u64) -> Self {
Self {
connections: RwLock::new(HashMap::new()),
active_writer: Mutex::new(None),
writer_cond: Condvar::new(),
timeout: Duration::from_secs(timeout_secs),
}
}
pub fn is_in_transaction(&self, key: &str) -> bool {
match self.connections.read() {
Ok(guard) => guard.contains_key(key),
Err(poisoned) => poisoned.into_inner().contains_key(key),
}
}
pub fn get_depth(&self, key: &str) -> u32 {
let conns = match self.connections.read() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns
.get(key)
.map(|t| t.depth.load(Ordering::SeqCst))
.unwrap_or(0)
}
pub fn increment_depth(&self, key: &str) -> bool {
let conns = match self.connections.read() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
if let Some(txn_info) = conns.get(key) {
let new_depth = txn_info.depth.fetch_add(1, Ordering::SeqCst) + 1;
info!(
"SqliteTransactionManager: nested transaction depth={}",
new_depth
);
return true;
}
false
}
pub fn acquire_write_lock(&self, thread_id: &str, timeout: Duration) -> bool {
let start = Instant::now();
let mut guard = match self.active_writer.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
loop {
match &*guard {
None => {
*guard = Some(thread_id.to_string());
return true;
}
Some(owner) if owner == thread_id => {
return true;
}
_ => {}
}
let remaining = timeout.saturating_sub(start.elapsed());
if remaining.is_zero() {
warn!("SqliteTransactionManager: write lock timeout for {thread_id}");
return false;
}
let wait_time = Duration::from_millis(100).min(remaining);
let result = self.writer_cond.wait_timeout(guard, wait_time);
guard = match result {
Ok((g, _)) => g,
Err(poisoned) => poisoned.into_inner().0,
};
}
}
pub fn release_write_lock(&self, thread_id: &str) {
let mut guard = match self.active_writer.lock() {
Ok(g) => g,
Err(poisoned) => poisoned.into_inner(),
};
if let Some(owner) = &*guard {
if owner == thread_id {
*guard = None;
self.writer_cond.notify_all();
}
}
}
pub fn start(&self, key: &str, conn: Arc<ConnectionThreadSafe>) -> bool {
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns.insert(
key.to_string(),
Arc::new(TransactionInfo {
conn,
depth: AtomicU32::new(1),
created_at: Instant::now(),
}),
);
true
}
pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
where
F: FnOnce(&ConnectionThreadSafe) -> R,
{
let txn_info = {
let conns = match self.connections.read() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns.get(key).cloned()
};
txn_info.map(|info| f(&info.conn))
}
pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<u32> {
let txn_info = {
let conns = match self.connections.read() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns.get(key).cloned()
};
if let Some(info) = txn_info {
let old_depth = info.depth.fetch_sub(1, Ordering::SeqCst);
if old_depth > 1 {
return Some(old_depth - 1);
}
}
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns.remove(key);
drop(conns);
self.release_write_lock(thread_id);
Some(0)
}
pub fn remove(&self, key: &str, thread_id: &str) {
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns.remove(key);
drop(conns);
self.release_write_lock(thread_id);
}
pub fn cleanup_expired(&self) {
let expired: Vec<(String, String)> = {
let conns = match self.connections.read() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns
.iter()
.filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
.map(|(key, _)| (key.clone(), key.clone()))
.collect()
};
if expired.is_empty() {
return;
}
for (key, thread_id) in expired {
warn!("SqliteTransactionManager: cleaning up expired transaction: {key}");
let mut conns = match self.connections.write() {
Ok(guard) => guard,
Err(poisoned) => poisoned.into_inner(),
};
conns.remove(&key);
drop(conns);
self.release_write_lock(&thread_id);
}
}
pub fn stats(&self) -> (usize, bool) {
let conn_count = match self.connections.read() {
Ok(g) => g.len(),
Err(poisoned) => poisoned.into_inner().len(),
};
let has_writer = match self.active_writer.lock() {
Ok(g) => g.is_some(),
Err(poisoned) => poisoned.into_inner().is_some(),
};
(conn_count, has_writer)
}
}
lazy_static::lazy_static! {
    /// Process-wide transaction manager; transactions become eligible for
    /// `cleanup_expired` 300 seconds after they are started.
    pub static ref SQLITE_TRANSACTION_MANAGER: Arc<SqliteTransactionManager> = {
        let timeout_secs: u64 = 300;
        Arc::new(SqliteTransactionManager::new(timeout_secs))
    };
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Opens a fresh in-memory SQLite connection for a single test.
    fn mem_conn() -> Arc<ConnectionThreadSafe> {
        Arc::new(
            sqlite::Connection::open_thread_safe(":memory:")
                .expect("in-memory SQLite connection should open"),
        )
    }

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = SqliteTransactionManager::new(60);
        let (conn_count, has_writer) = tm.stats();
        assert_eq!(conn_count, 0);
        assert!(!has_writer);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = SqliteTransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = SqliteTransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_start_and_is_in_transaction() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.start(key, mem_conn()));
        assert!(tm.is_in_transaction(key));
    }

    #[test]
    fn test_start_and_get_depth() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.start(key, mem_conn()));
        assert_eq!(tm.get_depth(key), 1);
    }

    #[test]
    fn test_start_replaces_existing_transaction() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.start(key, mem_conn()));
        assert!(tm.increment_depth(key));
        assert_eq!(tm.get_depth(key), 2);
        assert!(tm.start(key, mem_conn()));
        assert_eq!(tm.get_depth(key), 1);
        assert_eq!(tm.stats().0, 1);
    }

    #[test]
    fn test_increment_depth() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.start(key, mem_conn()));
        assert!(tm.increment_depth(key));
        assert_eq!(tm.get_depth(key), 2);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = SqliteTransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
    }

    #[test]
    fn test_with_conn_executes_closure() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.start(key, mem_conn()));
        let result = tm.with_conn(key, |conn| {
            conn.execute("SELECT 1").unwrap();
            42
        });
        assert_eq!(result, Some(42));
    }

    #[test]
    fn test_with_conn_returns_none_when_no_transaction() {
        let tm = SqliteTransactionManager::new(60);
        let result = tm.with_conn("nonexistent", |_conn| 42);
        assert_eq!(result, None);
    }

    #[test]
    fn test_decrement_or_finish_from_depth_2() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.start(key, mem_conn()));
        assert!(tm.increment_depth(key));
        assert_eq!(tm.get_depth(key), 2);
        let result = tm.decrement_or_finish(key, key);
        assert_eq!(result, Some(1));
        assert!(tm.is_in_transaction(key));
    }

    #[test]
    fn test_decrement_or_finish_from_depth_1() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.start(key, mem_conn()));
        assert_eq!(tm.get_depth(key), 1);
        let result = tm.decrement_or_finish(key, key);
        assert_eq!(result, Some(0));
        assert!(!tm.is_in_transaction(key));
    }

    #[test]
    fn test_decrement_or_finish_releases_write_lock_only_at_outermost_level() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.acquire_write_lock(key, Duration::from_secs(1)));
        assert!(tm.start(key, mem_conn()));
        assert!(tm.increment_depth(key));
        assert_eq!(tm.decrement_or_finish(key, key), Some(1));
        assert!(tm.stats().1);
        assert_eq!(tm.decrement_or_finish(key, key), Some(0));
        assert!(!tm.stats().1);
    }

    #[test]
    fn test_remove_clears_transaction() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.start(key, mem_conn()));
        assert!(tm.is_in_transaction(key));
        tm.remove(key, key);
        assert!(!tm.is_in_transaction(key));
    }

    #[test]
    fn test_remove_releases_write_lock() {
        let tm = SqliteTransactionManager::new(60);
        let key = "test_key";
        assert!(tm.acquire_write_lock(key, Duration::from_secs(1)));
        assert!(tm.stats().1);
        assert!(tm.start(key, mem_conn()));
        tm.remove(key, key);
        assert!(!tm.stats().1);
    }

    #[test]
    fn test_acquire_write_lock_same_thread_reentrant() {
        let tm = SqliteTransactionManager::new(60);
        let thread_id = "test_thread_1";
        assert!(tm.acquire_write_lock(thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_write_lock(thread_id, Duration::from_secs(1)));
        tm.release_write_lock(thread_id);
        let (_, has_writer) = tm.stats();
        assert!(!has_writer);
    }

    #[test]
    fn test_acquire_write_lock_different_thread_timeout() {
        let tm = Arc::new(SqliteTransactionManager::new(60));
        let tm2 = Arc::clone(&tm);
        assert!(tm.acquire_write_lock("thread_1", Duration::from_secs(1)));
        let handle = thread::spawn(move || {
            let result = tm2.acquire_write_lock("thread_2", Duration::from_millis(100));
            assert!(!result);
        });
        handle.join().unwrap();
        tm.release_write_lock("thread_1");
    }

    #[test]
    fn test_waiting_writer_acquires_after_release() {
        let tm = Arc::new(SqliteTransactionManager::new(60));
        assert!(tm.acquire_write_lock("thread_1", Duration::from_secs(1)));
        let tm2 = Arc::clone(&tm);
        let handle =
            thread::spawn(move || tm2.acquire_write_lock("thread_2", Duration::from_secs(5)));
        tm.release_write_lock("thread_1");
        assert!(handle.join().unwrap());
        assert!(tm.stats().1);
        tm.release_write_lock("thread_2");
        assert!(!tm.stats().1);
    }

    #[test]
    fn test_release_write_lock_by_owner() {
        let tm = SqliteTransactionManager::new(60);
        assert!(tm.acquire_write_lock("thread_a", Duration::from_secs(1)));
        assert!(tm.stats().1);
        tm.release_write_lock("thread_a");
        assert!(!tm.stats().1);
    }

    #[test]
    fn test_release_write_lock_wrong_thread_does_nothing() {
        let tm = SqliteTransactionManager::new(60);
        assert!(tm.acquire_write_lock("thread_a", Duration::from_secs(1)));
        tm.release_write_lock("thread_b");
        assert!(tm.stats().1);
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = SqliteTransactionManager::new(60);
        tm.cleanup_expired();
        assert_eq!(tm.stats(), (0, false));
    }

    #[test]
    fn test_cleanup_expired_keeps_fresh_transactions() {
        let tm = SqliteTransactionManager::new(60);
        assert!(tm.start("fresh", mem_conn()));
        tm.cleanup_expired();
        assert!(tm.is_in_transaction("fresh"));
    }

    #[test]
    fn test_cleanup_expired_removes_old_transactions() {
        let tm = SqliteTransactionManager::new(0);
        let key = "test_key";
        assert!(tm.acquire_write_lock(key, Duration::from_secs(1)));
        assert!(tm.start(key, mem_conn()));
        assert!(tm.is_in_transaction(key));
        thread::sleep(Duration::from_millis(10));
        tm.cleanup_expired();
        assert!(!tm.is_in_transaction(key));
        assert!(!tm.stats().1);
    }

    #[test]
    fn test_stats_reflects_state() {
        let tm = SqliteTransactionManager::new(60);
        let (conn_count, has_writer) = tm.stats();
        assert_eq!(conn_count, 0);
        assert!(!has_writer);
        assert!(tm.start("key1", mem_conn()));
        assert!(tm.acquire_write_lock("writer", Duration::from_secs(1)));
        let (conn_count, has_writer) = tm.stats();
        assert_eq!(conn_count, 1);
        assert!(has_writer);
    }
}