Skip to main content

br_db/types/
pgsql_transaction.rs

1use br_pgsql::connect::Connect;
2use log::{info, warn};
3use std::collections::HashMap;
4use std::sync::{Arc, Condvar, Mutex, RwLock};
5use std::time::{Duration, Instant};
6
/// Bookkeeping for one open transaction registered with [`PgsqlTransactionManager`].
#[derive(Debug)]
struct TransactionInfo {
    /// Connection the transaction runs on. Wrapped in `Arc<Mutex<_>>` so `with_conn`
    /// can clone a handle and release the map lock before using the connection.
    conn: Arc<Mutex<Connect>>,
    /// Nesting depth: set to 1 by `start`, incremented by `increment_depth`,
    /// decremented by `decrement_or_finish` (which removes the entry at depth 1).
    depth: i32,
    /// When `start` registered the transaction; compared against the manager's
    /// timeout in `cleanup_expired`. Not refreshed by nested increments.
    created_at: Instant,
}
13
/// Tracks open transactions by caller-supplied key and coordinates advisory
/// per-table locks between owners identified by a `thread_id` string.
pub struct PgsqlTransactionManager {
    /// Open transactions, keyed by the transaction key passed to `start`.
    connections: RwLock<HashMap<String, TransactionInfo>>,
    /// Table name -> `thread_id` of the current lock owner.
    table_locks: RwLock<HashMap<String, String>>,
    /// Notified by `release_table_lock` / `release_all_table_locks` so callers
    /// blocked in `acquire_table_lock` re-check the lock table.
    table_cond: Condvar,
    /// Mutex paired with `table_cond` for waiting; it protects no data of its own.
    table_mutex: Mutex<()>,
    /// Age (since `start`) after which `cleanup_expired` drops a transaction.
    timeout: Duration,
}
21
22impl PgsqlTransactionManager {
23    pub fn new(timeout_secs: u64) -> Self {
24        Self {
25            connections: RwLock::new(HashMap::new()),
26            table_locks: RwLock::new(HashMap::new()),
27            table_cond: Condvar::new(),
28            table_mutex: Mutex::new(()),
29            timeout: Duration::from_secs(timeout_secs),
30        }
31    }
32
33    pub fn is_in_transaction(&self, key: &str) -> bool {
34        match self.connections.read() {
35            Ok(guard) => guard.contains_key(key),
36            Err(poisoned) => poisoned.into_inner().contains_key(key),
37        }
38    }
39
40    pub fn get_depth(&self, key: &str) -> i32 {
41        match self.connections.read() {
42            Ok(guard) => guard.get(key).map(|t| t.depth).unwrap_or(0),
43            Err(poisoned) => poisoned.into_inner().get(key).map(|t| t.depth).unwrap_or(0),
44        }
45    }
46
47    pub fn increment_depth(&self, key: &str) -> bool {
48        let mut conns = match self.connections.write() {
49            Ok(guard) => guard,
50            Err(poisoned) => poisoned.into_inner(),
51        };
52
53        if let Some(txn_info) = conns.get_mut(key) {
54            txn_info.depth += 1;
55            info!(
56                "PgsqlTransactionManager: nested transaction depth={}",
57                txn_info.depth
58            );
59            return true;
60        }
61        false
62    }
63
64    pub fn start(&self, key: &str, conn: Connect) -> bool {
65        let mut conns = match self.connections.write() {
66            Ok(guard) => guard,
67            Err(poisoned) => poisoned.into_inner(),
68        };
69
70        conns.insert(
71            key.to_string(),
72            TransactionInfo {
73                conn: Arc::new(Mutex::new(conn)),
74                depth: 1,
75                created_at: Instant::now(),
76            },
77        );
78        true
79    }
80
81    pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
82    where
83        F: FnOnce(&mut Connect) -> R,
84    {
85        let conn = {
86            let conns = match self.connections.read() {
87                Ok(guard) => guard,
88                Err(poisoned) => poisoned.into_inner(),
89            };
90            conns.get(key).map(|txn_info| Arc::clone(&txn_info.conn))
91        };
92
93        conn.map(|arc_conn| {
94            let mut conn_guard = match arc_conn.lock() {
95                Ok(guard) => guard,
96                Err(poisoned) => poisoned.into_inner(),
97            };
98            f(&mut conn_guard)
99        })
100    }
101
102    pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
103        let start = Instant::now();
104
105        loop {
106            {
107                let mut locks = match self.table_locks.write() {
108                    Ok(guard) => guard,
109                    Err(poisoned) => poisoned.into_inner(),
110                };
111
112                match locks.get(table) {
113                    None => {
114                        locks.insert(table.to_string(), thread_id.to_string());
115                        return true;
116                    }
117                    Some(owner) if owner == thread_id => {
118                        return true;
119                    }
120                    _ => {}
121                }
122            }
123
124            if start.elapsed() > timeout {
125                warn!("PgsqlTransactionManager: table lock timeout for {table}");
126                return false;
127            }
128
129            let guard = match self.table_mutex.lock() {
130                Ok(g) => g,
131                Err(poisoned) => poisoned.into_inner(),
132            };
133
134            let remaining = timeout.saturating_sub(start.elapsed());
135            if remaining.is_zero() {
136                return false;
137            }
138
139            let wait_time = Duration::from_millis(50).min(remaining);
140            let _ = self.table_cond.wait_timeout(guard, wait_time);
141        }
142    }
143
144    pub fn release_table_lock(&self, table: &str, thread_id: &str) {
145        let mut locks = match self.table_locks.write() {
146            Ok(guard) => guard,
147            Err(poisoned) => poisoned.into_inner(),
148        };
149
150        if let Some(owner) = locks.get(table) {
151            if owner == thread_id {
152                locks.remove(table);
153                self.table_cond.notify_all();
154            }
155        }
156    }
157
158    pub fn release_all_table_locks(&self, thread_id: &str) {
159        let mut locks = match self.table_locks.write() {
160            Ok(guard) => guard,
161            Err(poisoned) => poisoned.into_inner(),
162        };
163
164        let tables_to_remove: Vec<String> = locks
165            .iter()
166            .filter(|(_, owner)| *owner == thread_id)
167            .map(|(table, _)| table.clone())
168            .collect();
169
170        for table in tables_to_remove {
171            locks.remove(&table);
172        }
173
174        self.table_cond.notify_all();
175    }
176
177    pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
178        let mut conns = match self.connections.write() {
179            Ok(guard) => guard,
180            Err(poisoned) => poisoned.into_inner(),
181        };
182
183        if let Some(txn_info) = conns.get_mut(key) {
184            if txn_info.depth > 1 {
185                txn_info.depth -= 1;
186                return Some(txn_info.depth);
187            }
188        }
189
190        conns.remove(key);
191        drop(conns);
192        self.release_all_table_locks(thread_id);
193        Some(0)
194    }
195
196    pub fn remove(&self, key: &str, thread_id: &str) -> Option<Connect> {
197        let mut conns = match self.connections.write() {
198            Ok(guard) => guard,
199            Err(poisoned) => poisoned.into_inner(),
200        };
201        let removed = conns.remove(key);
202        drop(conns);
203        self.release_all_table_locks(thread_id);
204        removed.and_then(|txn_info| match Arc::try_unwrap(txn_info.conn) {
205            Ok(conn_mutex) => Some(match conn_mutex.into_inner() {
206                Ok(conn) => conn,
207                Err(poisoned) => poisoned.into_inner(),
208            }),
209            Err(_) => {
210                warn!("PgsqlTransactionManager: transaction connection still borrowed when removing {key}");
211                None
212            }
213        })
214    }
215
216    pub fn cleanup_expired(&self) {
217        let expired: Vec<String> = {
218            let conns = match self.connections.read() {
219                Ok(guard) => guard,
220                Err(poisoned) => poisoned.into_inner(),
221            };
222            conns
223                .iter()
224                .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
225                .map(|(key, _)| key.clone())
226                .collect()
227        };
228
229        if expired.is_empty() {
230            return;
231        }
232
233        let mut conns = match self.connections.write() {
234            Ok(guard) => guard,
235            Err(poisoned) => poisoned.into_inner(),
236        };
237
238        for key in expired {
239            warn!("PgsqlTransactionManager: cleaning up expired transaction: {key}");
240            conns.remove(&key);
241        }
242    }
243
244    pub fn stats(&self) -> (usize, usize) {
245        let conn_count = match self.connections.read() {
246            Ok(g) => g.len(),
247            Err(poisoned) => poisoned.into_inner().len(),
248        };
249        let lock_count = match self.table_locks.read() {
250            Ok(g) => g.len(),
251            Err(poisoned) => poisoned.into_inner().len(),
252        };
253        (conn_count, lock_count)
254    }
255}
256
257lazy_static::lazy_static! {
258    pub static ref PGSQL_TRANSACTION_MANAGER: Arc<PgsqlTransactionManager> =
259        Arc::new(PgsqlTransactionManager::new(300));
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = PgsqlTransactionManager::new(60);
        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
        // A failed increment must not create a transaction as a side effect.
        assert!(!tm.is_in_transaction("nonexistent"));
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 1);
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_1";

        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        // Re-acquisitions are not counted: one release frees the table.
        tm.release_table_lock("users", thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = tm.clone();

        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            let result = tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100));
            assert!(!result);
        });

        handle.join().unwrap();
        // The timed-out attempt must not have disturbed thread_a's ownership.
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_zero_timeout_fails_when_held() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        assert!(!tm.acquire_table_lock("users", "thread_b", Duration::ZERO));
        // A free table is still granted immediately, even with a zero timeout.
        assert!(tm.acquire_table_lock("orders", "thread_b", Duration::ZERO));
        assert_eq!(tm.stats().1, 2);
    }

    #[test]
    fn test_waiter_acquires_after_release() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let tm2 = Arc::clone(&tm);
        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_secs(5))
        });

        tm.release_table_lock("users", "thread_a");
        assert!(handle.join().unwrap());

        // Ownership moved to thread_b: the old owner can no longer release it.
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_2";

        assert!(tm.acquire_table_lock("table1", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table2", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table3", thread_id, Duration::from_secs(1)));

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 3);

        tm.release_all_table_locks(thread_id);

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = tm.clone();

        assert!(tm.acquire_table_lock("table_a", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("table_b", "thread_b", Duration::from_secs(1))
        });
        assert!(handle.join().unwrap());

        assert_eq!(tm.stats().1, 2);

        tm.release_all_table_locks("thread_a");

        assert_eq!(tm.stats().1, 1);
        // The surviving lock is thread_b's.
        tm.release_table_lock("table_b", "thread_b");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let result = tm.decrement_or_finish("nonexistent", "thread_a");
        assert_eq!(result, Some(0));

        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        assert!(tm.remove("nonexistent", "thread_a").is_none());
        // remove always releases the caller's table locks.
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.cleanup_expired();

        // With no transactions there is nothing to drop, and table locks are untouched.
        assert_eq!(tm.stats(), (0, 1));
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t2", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t3", "thread_b", Duration::from_secs(1)));

        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 3);
    }
}
414        let (conn_count, lock_count) = tm.stats();
415        assert_eq!(conn_count, 0);
416        assert_eq!(lock_count, 3);
417    }
418}