Skip to main content

br_db/types/
pgsql_transaction.rs

1use br_pgsql::connect::Connect;
2use log::{info, warn};
3use std::collections::HashMap;
4use std::sync::{Arc, Condvar, Mutex, RwLock};
5use std::time::{Duration, Instant};
6
7/// 与 MySQL TransactionManager 对齐:全局写锁串行化所有事务操作
8#[derive(Debug)]
9struct TransactionInfo {
10    conn: Connect,
11    depth: i32,
12    created_at: Instant,
13}
14
15pub struct PgsqlTransactionManager {
16    connections: RwLock<HashMap<String, TransactionInfo>>,
17    table_locks: RwLock<HashMap<String, String>>,
18    table_cond: Condvar,
19    table_mutex: Mutex<()>,
20    timeout: Duration,
21}
22
23impl PgsqlTransactionManager {
24    pub fn new(timeout_secs: u64) -> Self {
25        Self {
26            connections: RwLock::new(HashMap::new()),
27            table_locks: RwLock::new(HashMap::new()),
28            table_cond: Condvar::new(),
29            table_mutex: Mutex::new(()),
30            timeout: Duration::from_secs(timeout_secs),
31        }
32    }
33
34    pub fn is_in_transaction(&self, key: &str) -> bool {
35        match self.connections.read() {
36            Ok(guard) => guard.contains_key(key),
37            Err(poisoned) => poisoned.into_inner().contains_key(key),
38        }
39    }
40
41    pub fn get_depth(&self, key: &str) -> i32 {
42        match self.connections.read() {
43            Ok(guard) => guard.get(key).map(|t| t.depth).unwrap_or(0),
44            Err(poisoned) => poisoned.into_inner().get(key).map(|t| t.depth).unwrap_or(0),
45        }
46    }
47
48    pub fn increment_depth(&self, key: &str) -> bool {
49        let mut conns = match self.connections.write() {
50            Ok(guard) => guard,
51            Err(poisoned) => poisoned.into_inner(),
52        };
53
54        if let Some(txn_info) = conns.get_mut(key) {
55            txn_info.depth += 1;
56            info!(
57                "PgsqlTransactionManager: nested transaction depth={}",
58                txn_info.depth
59            );
60            return true;
61        }
62        false
63    }
64
65    pub fn start(&self, key: &str, conn: Connect) -> bool {
66        let mut conns = match self.connections.write() {
67            Ok(guard) => guard,
68            Err(poisoned) => poisoned.into_inner(),
69        };
70
71        conns.insert(
72            key.to_string(),
73            TransactionInfo {
74                conn,
75                depth: 1,
76                created_at: Instant::now(),
77            },
78        );
79        true
80    }
81
82    /// 与 MySQL 对齐:全局写锁串行化,防止多线程事务并发导致死锁
83    pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
84    where
85        F: FnOnce(&mut Connect) -> R,
86    {
87        let mut conns = match self.connections.write() {
88            Ok(guard) => guard,
89            Err(poisoned) => poisoned.into_inner(),
90        };
91
92        conns.get_mut(key).map(|txn_info| f(&mut txn_info.conn))
93    }
94
95    pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
96        let start = Instant::now();
97
98        loop {
99            {
100                let mut locks = match self.table_locks.write() {
101                    Ok(guard) => guard,
102                    Err(poisoned) => poisoned.into_inner(),
103                };
104
105                match locks.get(table) {
106                    None => {
107                        locks.insert(table.to_string(), thread_id.to_string());
108                        return true;
109                    }
110                    Some(owner) if owner == thread_id => {
111                        return true;
112                    }
113                    _ => {}
114                }
115            }
116
117            if start.elapsed() > timeout {
118                warn!("PgsqlTransactionManager: table lock timeout for {table}");
119                return false;
120            }
121
122            let guard = match self.table_mutex.lock() {
123                Ok(g) => g,
124                Err(poisoned) => poisoned.into_inner(),
125            };
126
127            let remaining = timeout.saturating_sub(start.elapsed());
128            if remaining.is_zero() {
129                return false;
130            }
131
132            let wait_time = Duration::from_millis(50).min(remaining);
133            let _ = self.table_cond.wait_timeout(guard, wait_time);
134        }
135    }
136
137    pub fn release_table_lock(&self, table: &str, thread_id: &str) {
138        let mut locks = match self.table_locks.write() {
139            Ok(guard) => guard,
140            Err(poisoned) => poisoned.into_inner(),
141        };
142
143        if let Some(owner) = locks.get(table) {
144            if owner == thread_id {
145                locks.remove(table);
146                self.table_cond.notify_all();
147            }
148        }
149    }
150
151    pub fn release_all_table_locks(&self, thread_id: &str) {
152        let mut locks = match self.table_locks.write() {
153            Ok(guard) => guard,
154            Err(poisoned) => poisoned.into_inner(),
155        };
156
157        let tables_to_remove: Vec<String> = locks
158            .iter()
159            .filter(|(_, owner)| *owner == thread_id)
160            .map(|(table, _)| table.clone())
161            .collect();
162
163        for table in tables_to_remove {
164            locks.remove(&table);
165        }
166
167        self.table_cond.notify_all();
168    }
169
170    pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
171        let mut conns = match self.connections.write() {
172            Ok(guard) => guard,
173            Err(poisoned) => poisoned.into_inner(),
174        };
175
176        if let Some(txn_info) = conns.get_mut(key) {
177            if txn_info.depth > 1 {
178                txn_info.depth -= 1;
179                return Some(txn_info.depth);
180            }
181        }
182
183        conns.remove(key);
184        drop(conns);
185        self.release_all_table_locks(thread_id);
186        Some(0)
187    }
188
189    pub fn remove(&self, key: &str, thread_id: &str) -> Option<Connect> {
190        let mut conns = match self.connections.write() {
191            Ok(guard) => guard,
192            Err(poisoned) => poisoned.into_inner(),
193        };
194        let removed = conns.remove(key);
195        drop(conns);
196        self.release_all_table_locks(thread_id);
197        removed.map(|t| t.conn)
198    }
199
200    pub fn cleanup_expired(&self) {
201        let expired: Vec<String> = {
202            let conns = match self.connections.read() {
203                Ok(guard) => guard,
204                Err(poisoned) => poisoned.into_inner(),
205            };
206            conns
207                .iter()
208                .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
209                .map(|(key, _)| key.clone())
210                .collect()
211        };
212
213        if expired.is_empty() {
214            return;
215        }
216
217        let mut conns = match self.connections.write() {
218            Ok(guard) => guard,
219            Err(poisoned) => poisoned.into_inner(),
220        };
221
222        for key in expired {
223            warn!("PgsqlTransactionManager: cleaning up expired transaction: {key}");
224            conns.remove(&key);
225        }
226    }
227
228    pub fn stats(&self) -> (usize, usize) {
229        let conn_count = match self.connections.read() {
230            Ok(g) => g.len(),
231            Err(poisoned) => poisoned.into_inner().len(),
232        };
233        let lock_count = match self.table_locks.read() {
234            Ok(g) => g.len(),
235            Err(poisoned) => poisoned.into_inner().len(),
236        };
237        (conn_count, lock_count)
238    }
239}
240
241lazy_static::lazy_static! {
242    pub static ref PGSQL_TRANSACTION_MANAGER: Arc<PgsqlTransactionManager> =
243        Arc::new(PgsqlTransactionManager::new(300));
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = PgsqlTransactionManager::new(60);
        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 1);
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_1";

        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        // Re-entry is not counted: a single release frees the table.
        tm.release_table_lock("users", thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = tm.clone();

        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            let result = tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100));
            assert!(!result);
        });

        handle.join().unwrap();
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_zero_timeout_when_contended() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert!(!tm.acquire_table_lock("users", "thread_b", Duration::ZERO));
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_waiting_thread_acquires_after_release() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let waiter = {
            let tm = Arc::clone(&tm);
            thread::spawn(move || {
                tm.acquire_table_lock("users", "thread_b", Duration::from_secs(5))
            })
        };

        // Give the waiter a moment to block before the lock is handed over.
        thread::sleep(Duration::from_millis(20));
        tm.release_table_lock("users", "thread_a");

        assert!(waiter.join().unwrap());
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_table_lock_unlocked_table_is_noop() {
        let tm = PgsqlTransactionManager::new(60);
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_2";

        assert!(tm.acquire_table_lock("table1", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table2", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table3", thread_id, Duration::from_secs(1)));

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 3);

        tm.release_all_table_locks(thread_id);

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = tm.clone();

        assert!(tm.acquire_table_lock("table_a", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("table_b", "thread_b", Duration::from_secs(1))
        });
        assert!(handle.join().unwrap());

        assert_eq!(tm.stats().1, 2);

        tm.release_all_table_locks("thread_a");

        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let result = tm.decrement_or_finish("nonexistent", "thread_a");
        assert_eq!(result, Some(0));

        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_decrement_or_finish_keeps_other_threads_locks() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        assert_eq!(tm.decrement_or_finish("nonexistent", "thread_b"), Some(0));
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        assert!(tm.remove("nonexistent", "thread_a").is_none());
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = PgsqlTransactionManager::new(60);
        tm.cleanup_expired();
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t2", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t3", "thread_b", Duration::from_secs(1)));

        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 3);
    }
}