//! MySQL transaction bookkeeping (`br_db/types/mysql_transaction.rs`): per-key
//! connections with nesting depth, plus in-memory per-table locks.

use log::{info, warn};
use mysql::PooledConn;
use std::collections::HashMap;
use std::sync::{
    Arc, Condvar, Mutex, MutexGuard, PoisonError, RwLock, RwLockReadGuard, RwLockWriteGuard,
};
use std::time::{Duration, Instant};

7#[derive(Debug)]
8struct TransactionInfo {
9    conn: PooledConn,
10    depth: i32,
11    created_at: Instant,
12}
13
14pub struct TransactionManager {
15    connections: RwLock<HashMap<String, TransactionInfo>>,
16    table_locks: RwLock<HashMap<String, String>>,
17    table_cond: Condvar,
18    table_mutex: Mutex<()>,
19    timeout: Duration,
20}
21
22impl TransactionManager {
23    pub fn new(timeout_secs: u64) -> Self {
24        Self {
25            connections: RwLock::new(HashMap::new()),
26            table_locks: RwLock::new(HashMap::new()),
27            table_cond: Condvar::new(),
28            table_mutex: Mutex::new(()),
29            timeout: Duration::from_secs(timeout_secs),
30        }
31    }
32
33    pub fn is_in_transaction(&self, key: &str) -> bool {
34        match self.connections.read() {
35            Ok(guard) => guard.contains_key(key),
36            Err(poisoned) => poisoned.into_inner().contains_key(key),
37        }
38    }
39
40    pub fn get_depth(&self, key: &str) -> i32 {
41        match self.connections.read() {
42            Ok(guard) => guard.get(key).map(|t| t.depth).unwrap_or(0),
43            Err(poisoned) => poisoned.into_inner().get(key).map(|t| t.depth).unwrap_or(0),
44        }
45    }
46
47    pub fn increment_depth(&self, key: &str) -> bool {
48        let mut conns = match self.connections.write() {
49            Ok(guard) => guard,
50            Err(poisoned) => poisoned.into_inner(),
51        };
52
53        if let Some(txn_info) = conns.get_mut(key) {
54            txn_info.depth += 1;
55            info!(
56                "TransactionManager: nested transaction depth={}",
57                txn_info.depth
58            );
59            return true;
60        }
61        false
62    }
63
64    pub fn start(&self, key: &str, conn: PooledConn) -> bool {
65        let mut conns = match self.connections.write() {
66            Ok(guard) => guard,
67            Err(poisoned) => poisoned.into_inner(),
68        };
69
70        conns.insert(
71            key.to_string(),
72            TransactionInfo {
73                conn,
74                depth: 1,
75                created_at: Instant::now(),
76            },
77        );
78        true
79    }
80
81    pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
82    where
83        F: FnOnce(&mut PooledConn) -> R,
84    {
85        let mut conns = match self.connections.write() {
86            Ok(guard) => guard,
87            Err(poisoned) => poisoned.into_inner(),
88        };
89
90        conns.get_mut(key).map(|txn_info| f(&mut txn_info.conn))
91    }
92
93    pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
94        let start = Instant::now();
95
96        loop {
97            {
98                let mut locks = match self.table_locks.write() {
99                    Ok(guard) => guard,
100                    Err(poisoned) => poisoned.into_inner(),
101                };
102
103                match locks.get(table) {
104                    None => {
105                        locks.insert(table.to_string(), thread_id.to_string());
106                        return true;
107                    }
108                    Some(owner) if owner == thread_id => {
109                        return true;
110                    }
111                    _ => {}
112                }
113            }
114
115            if start.elapsed() > timeout {
116                warn!("TransactionManager: table lock timeout for {table}");
117                return false;
118            }
119
120            let guard = match self.table_mutex.lock() {
121                Ok(g) => g,
122                Err(poisoned) => poisoned.into_inner(),
123            };
124
125            let remaining = timeout.saturating_sub(start.elapsed());
126            if remaining.is_zero() {
127                return false;
128            }
129
130            let wait_time = Duration::from_millis(50).min(remaining);
131            let _ = self.table_cond.wait_timeout(guard, wait_time);
132        }
133    }
134
135    pub fn release_table_lock(&self, table: &str, thread_id: &str) {
136        let mut locks = match self.table_locks.write() {
137            Ok(guard) => guard,
138            Err(poisoned) => poisoned.into_inner(),
139        };
140
141        if let Some(owner) = locks.get(table) {
142            if owner == thread_id {
143                locks.remove(table);
144                self.table_cond.notify_all();
145            }
146        }
147    }
148
149    pub fn release_all_table_locks(&self, thread_id: &str) {
150        let mut locks = match self.table_locks.write() {
151            Ok(guard) => guard,
152            Err(poisoned) => poisoned.into_inner(),
153        };
154
155        let tables_to_remove: Vec<String> = locks
156            .iter()
157            .filter(|(_, owner)| *owner == thread_id)
158            .map(|(table, _)| table.clone())
159            .collect();
160
161        for table in tables_to_remove {
162            locks.remove(&table);
163        }
164
165        self.table_cond.notify_all();
166    }
167
168    pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
169        let mut conns = match self.connections.write() {
170            Ok(guard) => guard,
171            Err(poisoned) => poisoned.into_inner(),
172        };
173
174        if let Some(txn_info) = conns.get_mut(key) {
175            if txn_info.depth > 1 {
176                txn_info.depth -= 1;
177                return Some(txn_info.depth);
178            }
179        }
180
181        conns.remove(key);
182        drop(conns);
183        self.release_all_table_locks(thread_id);
184        Some(0)
185    }
186
187    pub fn remove(&self, key: &str, thread_id: &str) {
188        let mut conns = match self.connections.write() {
189            Ok(guard) => guard,
190            Err(poisoned) => poisoned.into_inner(),
191        };
192        conns.remove(key);
193        drop(conns);
194        self.release_all_table_locks(thread_id);
195    }
196
197    pub fn cleanup_expired(&self) {
198        let expired: Vec<String> = {
199            let conns = match self.connections.read() {
200                Ok(guard) => guard,
201                Err(poisoned) => poisoned.into_inner(),
202            };
203            conns
204                .iter()
205                .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
206                .map(|(key, _)| key.clone())
207                .collect()
208        };
209
210        if expired.is_empty() {
211            return;
212        }
213
214        let mut conns = match self.connections.write() {
215            Ok(guard) => guard,
216            Err(poisoned) => poisoned.into_inner(),
217        };
218
219        for key in expired {
220            warn!("TransactionManager: cleaning up expired transaction: {key}");
221            conns.remove(&key);
222        }
223    }
224
225    pub fn stats(&self) -> (usize, usize) {
226        let conn_count = match self.connections.read() {
227            Ok(g) => g.len(),
228            Err(poisoned) => poisoned.into_inner().len(),
229        };
230        let lock_count = match self.table_locks.read() {
231            Ok(g) => g.len(),
232            Err(poisoned) => poisoned.into_inner().len(),
233        };
234        (conn_count, lock_count)
235    }
236}
237
238lazy_static::lazy_static! {
239    pub static ref TRANSACTION_MANAGER: Arc<TransactionManager> =
240        Arc::new(TransactionManager::new(300));
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = TransactionManager::new(60);
        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = TransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = TransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_with_conn_returns_none_without_transaction() {
        let tm = TransactionManager::new(60);
        let mut called = false;
        let result = tm.with_conn("nonexistent", |_conn| called = true);
        assert!(result.is_none());
        assert!(!called);
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 1);
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = TransactionManager::new(60);
        let thread_id = "test_thread_1";

        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        // Re-entrant acquisition is not counted: a single release frees it.
        tm.release_table_lock("users", thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(TransactionManager::new(60));
        let tm2 = tm.clone();

        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            let result = tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100));
            assert!(!result);
        });

        handle.join().unwrap();
        tm.release_table_lock("users", "thread_a");
    }

    #[test]
    fn test_acquire_table_lock_zero_timeout_when_contended() {
        let tm = TransactionManager::new(60);
        // One attempt is always made, so an uncontended zero-timeout call succeeds.
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(0)));
        assert!(!tm.acquire_table_lock("users", "thread_b", Duration::from_secs(0)));
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_waiter_acquires_after_owner_releases() {
        let tm = Arc::new(TransactionManager::new(60));
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let waiter = {
            let tm = Arc::clone(&tm);
            thread::spawn(move || {
                tm.acquire_table_lock("users", "thread_b", Duration::from_secs(5))
            })
        };
        // Give the waiter a moment to block; the test passes either way.
        thread::sleep(Duration::from_millis(20));
        tm.release_table_lock("users", "thread_a");

        assert!(waiter.join().unwrap());
        // thread_b now owns the lock, so thread_a can no longer release it.
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = TransactionManager::new(60);
        let thread_id = "test_thread_2";

        assert!(tm.acquire_table_lock("table1", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table2", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table3", thread_id, Duration::from_secs(1)));

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 3);

        tm.release_all_table_locks(thread_id);

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = Arc::new(TransactionManager::new(60));
        let tm2 = tm.clone();

        assert!(tm.acquire_table_lock("table_a", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("table_b", "thread_b", Duration::from_secs(1))
        });
        assert!(handle.join().unwrap());

        assert_eq!(tm.stats().1, 2);

        tm.release_all_table_locks("thread_a");

        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let result = tm.decrement_or_finish("nonexistent", "thread_a");
        assert_eq!(result, Some(0));

        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("orders", "thread_b", Duration::from_secs(1)));

        tm.remove("nonexistent", "thread_a");

        // Only thread_a's lock is released; there was no transaction to drop.
        assert_eq!(tm.stats(), (0, 1));
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = TransactionManager::new(60);
        tm.cleanup_expired();
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t2", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t3", "thread_b", Duration::from_secs(1)));

        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 3);
    }
}