br-db 1.8.71

This is an ORM for MySQL, MSSQL, and SQLite databases.
Documentation
use log::{info, warn};
use mysql::PooledConn;
use std::collections::HashMap;
use std::sync::{Arc, Condvar, Mutex, RwLock};
use std::time::{Duration, Instant};

/// State kept by `TransactionManager` for one open transaction, stored in its
/// map under a caller-supplied string key.
#[derive(Debug)]
struct TransactionInfo {
    /// Pooled MySQL connection owned by this transaction; callers reach it
    /// through `TransactionManager::with_conn`.
    conn: PooledConn,
    /// Nesting depth: set to 1 by `start`, raised by `increment_depth`, and
    /// lowered by `decrement_or_finish` (which removes the entry at depth 1).
    depth: i32,
    /// Moment `start` registered the transaction. Nested begins do not refresh
    /// it; `cleanup_expired` compares its age against the manager's timeout.
    created_at: Instant,
}

/// Tracks open database transactions (one pooled connection per caller key)
/// and advisory per-table locks owned by caller-supplied thread ids.
pub struct TransactionManager {
    /// Open transactions, keyed by the caller-supplied transaction key.
    connections: RwLock<HashMap<String, TransactionInfo>>,
    /// Advisory table locks: table name -> owning thread id.
    ///
    /// This is a `Mutex` (not an `RwLock`) so that `table_cond` waits on the
    /// same lock that guards the map. Releases notify while holding it, so a
    /// waiter can never miss a wake-up between checking the map and sleeping.
    table_locks: Mutex<HashMap<String, String>>,
    /// Notified whenever one or more table locks are released.
    table_cond: Condvar,
    /// Maximum age of a transaction before `cleanup_expired` discards it.
    timeout: Duration,
}

/// Unwraps a lock result, recovering the guard if the lock is poisoned.
///
/// Every lock in this manager tolerates poisoning: if another thread panicked
/// while holding it, the guard is used as-is instead of propagating the panic.
fn recover_lock<G>(result: std::sync::LockResult<G>) -> G {
    result.unwrap_or_else(std::sync::PoisonError::into_inner)
}

impl TransactionManager {
    /// Creates an empty manager whose transactions expire `timeout_secs`
    /// seconds after they were started.
    pub fn new(timeout_secs: u64) -> Self {
        Self {
            connections: RwLock::new(HashMap::new()),
            table_locks: Mutex::new(HashMap::new()),
            table_cond: Condvar::new(),
            timeout: Duration::from_secs(timeout_secs),
        }
    }

    /// Returns `true` if a transaction is currently registered under `key`.
    pub fn is_in_transaction(&self, key: &str) -> bool {
        recover_lock(self.connections.read()).contains_key(key)
    }

    /// Returns the nesting depth of the transaction under `key`, or `0` if
    /// none is open.
    pub fn get_depth(&self, key: &str) -> i32 {
        recover_lock(self.connections.read())
            .get(key)
            .map_or(0, |txn| txn.depth)
    }

    /// Records one more nesting level on the transaction under `key`.
    ///
    /// Returns `false`, changing nothing, if no transaction is open under `key`.
    pub fn increment_depth(&self, key: &str) -> bool {
        let mut conns = recover_lock(self.connections.write());
        match conns.get_mut(key) {
            Some(txn) => {
                txn.depth += 1;
                info!("TransactionManager: nested transaction depth={}", txn.depth);
                true
            }
            None => false,
        }
    }

    /// Registers `conn` as the transaction for `key`, with depth 1.
    ///
    /// If a transaction was already registered under `key`, it is replaced (a
    /// warning is logged) and its connection is dropped. Always returns `true`.
    pub fn start(&self, key: &str, conn: PooledConn) -> bool {
        // The map guard is a temporary, released at the end of this statement.
        let replaced = recover_lock(self.connections.write()).insert(
            key.to_string(),
            TransactionInfo {
                conn,
                depth: 1,
                created_at: Instant::now(),
            },
        );
        if replaced.is_some() {
            warn!("TransactionManager: replacing existing transaction for {key}");
        }
        // Any replaced connection is dropped here, outside the map lock.
        drop(replaced);
        true
    }

    /// Runs `f` with mutable access to the connection of the transaction
    /// under `key`. Returns `None`, without calling `f`, if none is open.
    ///
    /// The transaction map's write lock is held for the whole call to `f`, so
    /// every other transaction operation on this manager blocks until `f`
    /// returns. `f` must not call back into this manager's transaction methods
    /// (e.g. `get_depth`): `std::sync::RwLock` is not reentrant, and doing so
    /// may deadlock or panic.
    pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
    where
        F: FnOnce(&mut PooledConn) -> R,
    {
        let mut conns = recover_lock(self.connections.write());
        conns.get_mut(key).map(|txn| f(&mut txn.conn))
    }

    /// Acquires the advisory lock on `table` for `thread_id`, waiting up to
    /// `timeout` for another owner to release it.
    ///
    /// Returns `true` once `thread_id` holds the lock. If it already held it,
    /// the call succeeds immediately; re-entry is not counted, so a single
    /// release frees the table. Returns `false` (and logs a warning) on timeout.
    pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
        let start = Instant::now();
        let mut locks = recover_lock(self.table_locks.lock());

        loop {
            match locks.get(table) {
                None => {
                    locks.insert(table.to_string(), thread_id.to_string());
                    return true;
                }
                Some(owner) if owner == thread_id => return true,
                Some(_) => {}
            }

            let remaining = timeout.saturating_sub(start.elapsed());
            if remaining.is_zero() {
                warn!("TransactionManager: table lock timeout for {table}");
                return false;
            }

            // Waiting atomically releases `locks`. The loop re-checks the map
            // after spurious wake-ups or when another waiter won the race.
            locks = recover_lock(self.table_cond.wait_timeout(locks, remaining)).0;
        }
    }

    /// Releases the lock on `table` if, and only if, it is owned by
    /// `thread_id`, then wakes any waiters. Otherwise does nothing.
    pub fn release_table_lock(&self, table: &str, thread_id: &str) {
        let mut locks = recover_lock(self.table_locks.lock());
        if locks.get(table).map(String::as_str) == Some(thread_id) {
            locks.remove(table);
            self.table_cond.notify_all();
        }
    }

    /// Releases every table lock owned by `thread_id`, waking waiters if any
    /// lock was actually released.
    pub fn release_all_table_locks(&self, thread_id: &str) {
        let mut locks = recover_lock(self.table_locks.lock());
        let held_before = locks.len();
        locks.retain(|_, owner| owner.as_str() != thread_id);
        if locks.len() != held_before {
            self.table_cond.notify_all();
        }
    }

    /// Ends one nesting level of the transaction under `key`.
    ///
    /// If its depth is above 1, the depth is decremented and `Some(new_depth)`
    /// is returned. Otherwise (depth 1, or no transaction under `key`) the
    /// transaction is removed, every table lock owned by `thread_id` is
    /// released, and `Some(0)` is returned. Never returns `None`.
    pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
        let finished = {
            let mut conns = recover_lock(self.connections.write());
            if let Some(txn) = conns.get_mut(key) {
                if txn.depth > 1 {
                    txn.depth -= 1;
                    return Some(txn.depth);
                }
            }
            conns.remove(key)
        };
        // Drop the finished connection only after the map lock is released.
        drop(finished);
        self.release_all_table_locks(thread_id);
        Some(0)
    }

    /// Removes the transaction under `key` regardless of its depth, and
    /// releases every table lock owned by `thread_id`.
    pub fn remove(&self, key: &str, thread_id: &str) {
        // The map guard is a temporary, released at the end of this statement.
        let removed = recover_lock(self.connections.write()).remove(key);
        drop(removed);
        self.release_all_table_locks(thread_id);
    }

    /// Returns `true` once `txn` is older than this manager's timeout.
    fn is_expired(&self, txn: &TransactionInfo) -> bool {
        txn.created_at.elapsed() > self.timeout
    }

    /// Removes every transaction older than the configured timeout, logging a
    /// warning for each one.
    ///
    /// Age is measured from `start`; nested begins do not refresh it.
    ///
    /// NOTE: table locks held on behalf of an expired transaction are not
    /// released here. `TransactionInfo` does not record the owning
    /// `thread_id`, so those locks stay until that id is released elsewhere.
    pub fn cleanup_expired(&self) {
        // First pass uses the shared lock, so the common "nothing expired"
        // case never takes the write lock.
        let candidates: Vec<String> = recover_lock(self.connections.read())
            .iter()
            .filter(|(_, txn)| self.is_expired(txn))
            .map(|(key, _)| key.clone())
            .collect();

        if candidates.is_empty() {
            return;
        }

        let reclaimed: Vec<TransactionInfo> = {
            let mut conns = recover_lock(self.connections.write());
            candidates
                .into_iter()
                .filter_map(|key| {
                    // Between the two lock acquisitions the transaction may have
                    // finished, and a fresh one may have started under the same
                    // key. Remove only entries that are still expired.
                    let still_expired = conns.get(&key).map_or(false, |txn| self.is_expired(txn));
                    if !still_expired {
                        return None;
                    }
                    warn!("TransactionManager: cleaning up expired transaction: {key}");
                    conns.remove(&key)
                })
                .collect()
        };
        // Drop the expired connections only after the map lock is released.
        drop(reclaimed);
    }

    /// Returns `(open transactions, held table locks)`.
    pub fn stats(&self) -> (usize, usize) {
        let conn_count = recover_lock(self.connections.read()).len();
        let lock_count = recover_lock(self.table_locks.lock()).len();
        (conn_count, lock_count)
    }
}

/// Age in seconds after which the global manager's `cleanup_expired`
/// discards a transaction.
const DEFAULT_TRANSACTION_TIMEOUT_SECS: u64 = 300;

lazy_static::lazy_static! {
    /// Process-wide transaction manager, created on first access.
    pub static ref TRANSACTION_MANAGER: Arc<TransactionManager> = {
        let manager = TransactionManager::new(DEFAULT_TRANSACTION_TIMEOUT_SECS);
        Arc::new(manager)
    };
}

#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = TransactionManager::new(60);
        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = TransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = TransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
        // A failed increment must not create a transaction.
        assert!(!tm.is_in_transaction("nonexistent"));
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 1);
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = TransactionManager::new(60);
        let thread_id = "test_thread_1";

        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        // Re-entry is not counted: a single release frees the table.
        tm.release_table_lock("users", thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(TransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100))
        });

        assert!(!handle.join().unwrap());
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_zero_timeout_fails_when_contended() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        assert!(!tm.acquire_table_lock("users", "thread_b", Duration::from_millis(0)));
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_waiter_acquires_after_owner_releases() {
        let tm = Arc::new(TransactionManager::new(60));
        assert!(tm.acquire_table_lock("orders", "thread_a", Duration::from_secs(1)));

        let tm2 = Arc::clone(&tm);
        let waiter = thread::spawn(move || {
            tm2.acquire_table_lock("orders", "thread_b", Duration::from_secs(5))
        });

        // Give the waiter a chance to block before the owner releases.
        thread::sleep(Duration::from_millis(20));
        tm.release_table_lock("orders", "thread_a");

        assert!(waiter.join().unwrap());
        // Ownership moved to thread_b, so thread_a can no longer release it.
        tm.release_table_lock("orders", "thread_a");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = TransactionManager::new(60);
        let thread_id = "test_thread_2";

        assert!(tm.acquire_table_lock("table1", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table2", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table3", thread_id, Duration::from_secs(1)));

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 3);

        tm.release_all_table_locks(thread_id);

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = Arc::new(TransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("table_a", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("table_b", "thread_b", Duration::from_secs(1))
        });
        assert!(handle.join().unwrap());

        assert_eq!(tm.stats().1, 2);

        tm.release_all_table_locks("thread_a");

        assert_eq!(tm.stats().1, 1);
        // table_b is still owned by thread_b.
        assert!(!tm.acquire_table_lock("table_b", "thread_a", Duration::from_millis(0)));
    }

    #[test]
    fn test_release_all_table_locks_unknown_thread_is_noop() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.release_all_table_locks("thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let result = tm.decrement_or_finish("nonexistent", "thread_a");
        assert_eq!(result, Some(0));

        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.remove("nonexistent", "thread_a");

        assert!(!tm.is_in_transaction("nonexistent"));
        // remove() releases the thread's table locks even without a transaction.
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = TransactionManager::new(60);
        tm.cleanup_expired();
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t2", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t3", "thread_b", Duration::from_secs(1)));

        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 3);
    }
}