// br_db/types/pgsql_transaction.rs

use br_pgsql::connect::Connect;
use log::{info, warn};
use std::collections::HashMap;
use std::sync::atomic::{AtomicI32, Ordering};
use std::sync::{Arc, Condvar, Mutex, PoisonError, RwLock};
use std::time::{Duration, Instant};

8struct TransactionInfo {
9    conn: Mutex<Connect>,
10    depth: AtomicI32,
11    created_at: Instant,
12}
13
14pub struct PgsqlTransactionManager {
15    connections: RwLock<HashMap<String, Arc<TransactionInfo>>>,
16    table_locks: RwLock<HashMap<String, String>>,
17    table_cond: Condvar,
18    table_mutex: Mutex<()>,
19    timeout: Duration,
20}
21
22impl PgsqlTransactionManager {
23    pub fn new(timeout_secs: u64) -> Self {
24        Self {
25            connections: RwLock::new(HashMap::new()),
26            table_locks: RwLock::new(HashMap::new()),
27            table_cond: Condvar::new(),
28            table_mutex: Mutex::new(()),
29            timeout: Duration::from_secs(timeout_secs),
30        }
31    }
32
33    pub fn is_in_transaction(&self, key: &str) -> bool {
34        match self.connections.read() {
35            Ok(guard) => guard.contains_key(key),
36            Err(poisoned) => poisoned.into_inner().contains_key(key),
37        }
38    }
39
40    pub fn get_depth(&self, key: &str) -> i32 {
41        let conns = match self.connections.read() {
42            Ok(guard) => guard,
43            Err(poisoned) => poisoned.into_inner(),
44        };
45        conns
46            .get(key)
47            .map(|t| t.depth.load(Ordering::SeqCst))
48            .unwrap_or(0)
49    }
50
51    pub fn increment_depth(&self, key: &str) -> bool {
52        let conns = match self.connections.read() {
53            Ok(guard) => guard,
54            Err(poisoned) => poisoned.into_inner(),
55        };
56
57        if let Some(txn_info) = conns.get(key) {
58            let new_depth = txn_info.depth.fetch_add(1, Ordering::SeqCst) + 1;
59            info!(
60                "PgsqlTransactionManager: nested transaction depth={}",
61                new_depth
62            );
63            return true;
64        }
65        false
66    }
67
68    pub fn start(&self, key: &str, conn: Connect) -> bool {
69        let mut conns = match self.connections.write() {
70            Ok(guard) => guard,
71            Err(poisoned) => poisoned.into_inner(),
72        };
73
74        conns.insert(
75            key.to_string(),
76            Arc::new(TransactionInfo {
77                conn: Mutex::new(conn),
78                depth: AtomicI32::new(1),
79                created_at: Instant::now(),
80            }),
81        );
82        true
83    }
84
85    pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
86    where
87        F: FnOnce(&mut Connect) -> R,
88    {
89        let txn_info = {
90            let conns = match self.connections.read() {
91                Ok(guard) => guard,
92                Err(poisoned) => poisoned.into_inner(),
93            };
94            conns.get(key).cloned()
95        };
96
97        txn_info.map(|info| {
98            let mut conn = match info.conn.lock() {
99                Ok(guard) => guard,
100                Err(poisoned) => poisoned.into_inner(),
101            };
102            f(&mut conn)
103        })
104    }
105
106    pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
107        let start = Instant::now();
108
109        loop {
110            {
111                let mut locks = match self.table_locks.write() {
112                    Ok(guard) => guard,
113                    Err(poisoned) => poisoned.into_inner(),
114                };
115
116                match locks.get(table) {
117                    None => {
118                        locks.insert(table.to_string(), thread_id.to_string());
119                        return true;
120                    }
121                    Some(owner) if owner == thread_id => {
122                        return true;
123                    }
124                    _ => {}
125                }
126            }
127
128            if start.elapsed() > timeout {
129                warn!("PgsqlTransactionManager: table lock timeout for {table}");
130                return false;
131            }
132
133            let guard = match self.table_mutex.lock() {
134                Ok(g) => g,
135                Err(poisoned) => poisoned.into_inner(),
136            };
137
138            let remaining = timeout.saturating_sub(start.elapsed());
139            if remaining.is_zero() {
140                return false;
141            }
142
143            let wait_time = Duration::from_millis(50).min(remaining);
144            let _ = self.table_cond.wait_timeout(guard, wait_time);
145        }
146    }
147
148    pub fn release_table_lock(&self, table: &str, thread_id: &str) {
149        let mut locks = match self.table_locks.write() {
150            Ok(guard) => guard,
151            Err(poisoned) => poisoned.into_inner(),
152        };
153
154        if let Some(owner) = locks.get(table) {
155            if owner == thread_id {
156                locks.remove(table);
157                self.table_cond.notify_all();
158            }
159        }
160    }
161
162    pub fn release_all_table_locks(&self, thread_id: &str) {
163        let mut locks = match self.table_locks.write() {
164            Ok(guard) => guard,
165            Err(poisoned) => poisoned.into_inner(),
166        };
167
168        let tables_to_remove: Vec<String> = locks
169            .iter()
170            .filter(|(_, owner)| *owner == thread_id)
171            .map(|(table, _)| table.clone())
172            .collect();
173
174        for table in tables_to_remove {
175            locks.remove(&table);
176        }
177
178        self.table_cond.notify_all();
179    }
180
181    pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
182        let txn_info = {
183            let conns = match self.connections.read() {
184                Ok(guard) => guard,
185                Err(poisoned) => poisoned.into_inner(),
186            };
187            conns.get(key).cloned()
188        };
189
190        if let Some(info) = txn_info {
191            let old_depth = info.depth.fetch_sub(1, Ordering::SeqCst);
192            if old_depth > 1 {
193                return Some(old_depth - 1);
194            }
195        }
196
197        let mut conns = match self.connections.write() {
198            Ok(guard) => guard,
199            Err(poisoned) => poisoned.into_inner(),
200        };
201        conns.remove(key);
202        drop(conns);
203        self.release_all_table_locks(thread_id);
204        Some(0)
205    }
206
207    pub fn remove(&self, key: &str, thread_id: &str) {
208        let mut conns = match self.connections.write() {
209            Ok(guard) => guard,
210            Err(poisoned) => poisoned.into_inner(),
211        };
212        conns.remove(key);
213        drop(conns);
214        self.release_all_table_locks(thread_id);
215    }
216
217    pub fn cleanup_expired(&self) {
218        let expired: Vec<String> = {
219            let conns = match self.connections.read() {
220                Ok(guard) => guard,
221                Err(poisoned) => poisoned.into_inner(),
222            };
223            conns
224                .iter()
225                .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
226                .map(|(key, _)| key.clone())
227                .collect()
228        };
229
230        if expired.is_empty() {
231            return;
232        }
233
234        let mut conns = match self.connections.write() {
235            Ok(guard) => guard,
236            Err(poisoned) => poisoned.into_inner(),
237        };
238
239        for key in expired {
240            warn!("PgsqlTransactionManager: cleaning up expired transaction: {key}");
241            conns.remove(&key);
242        }
243    }
244
245    pub fn stats(&self) -> (usize, usize) {
246        let conn_count = match self.connections.read() {
247            Ok(g) => g.len(),
248            Err(poisoned) => poisoned.into_inner().len(),
249        };
250        let lock_count = match self.table_locks.read() {
251            Ok(g) => g.len(),
252            Err(poisoned) => poisoned.into_inner().len(),
253        };
254        (conn_count, lock_count)
255    }
256}
257
258lazy_static::lazy_static! {
259    pub static ref PGSQL_TRANSACTION_MANAGER: Arc<PgsqlTransactionManager> =
260        Arc::new(PgsqlTransactionManager::new(300));
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Generous timeout for acquisitions that are expected to succeed at once.
    const LOCK_TIMEOUT: Duration = Duration::from_secs(1);

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = PgsqlTransactionManager::new(60);
        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", LOCK_TIMEOUT));
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 1);
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_1";

        assert!(tm.acquire_table_lock("users", thread_id, LOCK_TIMEOUT));
        assert!(tm.acquire_table_lock("users", thread_id, LOCK_TIMEOUT));
        assert_eq!(tm.stats().1, 1);

        // Re-acquisition does not stack: one release frees the table.
        tm.release_table_lock("users", thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("users", "thread_a", LOCK_TIMEOUT));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100))
        });

        assert!(!handle.join().unwrap());
        // The failed attempt must not disturb the existing owner.
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_zero_timeout_when_held() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", LOCK_TIMEOUT));
        assert!(!tm.acquire_table_lock("users", "thread_b", Duration::from_secs(0)));
    }

    #[test]
    fn test_waiter_acquires_after_owner_releases() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("users", "thread_a", LOCK_TIMEOUT));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_secs(5))
        });

        // Give the waiter a moment to block, then hand the table over.
        thread::sleep(Duration::from_millis(20));
        tm.release_table_lock("users", "thread_a");

        assert!(handle.join().unwrap());
        assert_eq!(tm.stats().1, 1);
        // Ownership moved to thread_b, so thread_a can no longer release it.
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", LOCK_TIMEOUT));
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", LOCK_TIMEOUT));

        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_2";

        assert!(tm.acquire_table_lock("table1", thread_id, LOCK_TIMEOUT));
        assert!(tm.acquire_table_lock("table2", thread_id, LOCK_TIMEOUT));
        assert!(tm.acquire_table_lock("table3", thread_id, LOCK_TIMEOUT));

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 3);

        tm.release_all_table_locks(thread_id);

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("table_a", "thread_a", LOCK_TIMEOUT));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("table_b", "thread_b", LOCK_TIMEOUT)
        });
        assert!(handle.join().unwrap());

        assert_eq!(tm.stats().1, 2);

        tm.release_all_table_locks("thread_a");

        assert_eq!(tm.stats().1, 1);
        // table_a is free again; table_b is still owned by thread_b.
        assert!(tm.acquire_table_lock("table_a", "thread_c", LOCK_TIMEOUT));
        assert!(!tm.acquire_table_lock("table_b", "thread_a", Duration::from_millis(10)));
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", LOCK_TIMEOUT));

        let result = tm.decrement_or_finish("nonexistent", "thread_a");
        assert_eq!(result, Some(0));

        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", LOCK_TIMEOUT));

        tm.remove("nonexistent", "thread_a");

        // Removing still releases every table lock held by the thread.
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = PgsqlTransactionManager::new(60);
        tm.cleanup_expired();
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", LOCK_TIMEOUT));
        assert!(tm.acquire_table_lock("t2", "thread_a", LOCK_TIMEOUT));
        assert!(tm.acquire_table_lock("t3", "thread_b", LOCK_TIMEOUT));

        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 3);
    }
}