1use br_pgsql::connect::Connect;
2use log::{info, warn};
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicI32, Ordering};
5use std::sync::{Arc, Condvar, Mutex, RwLock};
6use std::time::{Duration, Instant};
7
8struct TransactionInfo {
9 conn: Mutex<Connect>,
10 depth: AtomicI32,
11 created_at: Instant,
12}
13
14pub struct PgsqlTransactionManager {
15 connections: RwLock<HashMap<String, Arc<TransactionInfo>>>,
16 table_locks: RwLock<HashMap<String, String>>,
17 table_cond: Condvar,
18 table_mutex: Mutex<()>,
19 timeout: Duration,
20}
21
22impl PgsqlTransactionManager {
23 pub fn new(timeout_secs: u64) -> Self {
24 Self {
25 connections: RwLock::new(HashMap::new()),
26 table_locks: RwLock::new(HashMap::new()),
27 table_cond: Condvar::new(),
28 table_mutex: Mutex::new(()),
29 timeout: Duration::from_secs(timeout_secs),
30 }
31 }
32
33 pub fn is_in_transaction(&self, key: &str) -> bool {
34 match self.connections.read() {
35 Ok(guard) => guard.contains_key(key),
36 Err(poisoned) => poisoned.into_inner().contains_key(key),
37 }
38 }
39
40 pub fn get_depth(&self, key: &str) -> i32 {
41 let conns = match self.connections.read() {
42 Ok(guard) => guard,
43 Err(poisoned) => poisoned.into_inner(),
44 };
45 conns
46 .get(key)
47 .map(|t| t.depth.load(Ordering::SeqCst))
48 .unwrap_or(0)
49 }
50
51 pub fn increment_depth(&self, key: &str) -> bool {
52 let conns = match self.connections.read() {
53 Ok(guard) => guard,
54 Err(poisoned) => poisoned.into_inner(),
55 };
56
57 if let Some(txn_info) = conns.get(key) {
58 let new_depth = txn_info.depth.fetch_add(1, Ordering::SeqCst) + 1;
59 info!(
60 "PgsqlTransactionManager: nested transaction depth={}",
61 new_depth
62 );
63 return true;
64 }
65 false
66 }
67
68 pub fn start(&self, key: &str, conn: Connect) -> bool {
69 let mut conns = match self.connections.write() {
70 Ok(guard) => guard,
71 Err(poisoned) => poisoned.into_inner(),
72 };
73
74 conns.insert(
75 key.to_string(),
76 Arc::new(TransactionInfo {
77 conn: Mutex::new(conn),
78 depth: AtomicI32::new(1),
79 created_at: Instant::now(),
80 }),
81 );
82 true
83 }
84
85 pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
86 where
87 F: FnOnce(&mut Connect) -> R,
88 {
89 let txn_info = {
90 let conns = match self.connections.read() {
91 Ok(guard) => guard,
92 Err(poisoned) => poisoned.into_inner(),
93 };
94 conns.get(key).cloned()
95 };
96
97 txn_info.map(|info| {
98 let mut conn = match info.conn.lock() {
99 Ok(guard) => guard,
100 Err(poisoned) => poisoned.into_inner(),
101 };
102 f(&mut conn)
103 })
104 }
105
106 pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
107 let start = Instant::now();
108
109 loop {
110 {
111 let mut locks = match self.table_locks.write() {
112 Ok(guard) => guard,
113 Err(poisoned) => poisoned.into_inner(),
114 };
115
116 match locks.get(table) {
117 None => {
118 locks.insert(table.to_string(), thread_id.to_string());
119 return true;
120 }
121 Some(owner) if owner == thread_id => {
122 return true;
123 }
124 _ => {}
125 }
126 }
127
128 if start.elapsed() > timeout {
129 warn!("PgsqlTransactionManager: table lock timeout for {table}");
130 return false;
131 }
132
133 let guard = match self.table_mutex.lock() {
134 Ok(g) => g,
135 Err(poisoned) => poisoned.into_inner(),
136 };
137
138 let remaining = timeout.saturating_sub(start.elapsed());
139 if remaining.is_zero() {
140 return false;
141 }
142
143 let wait_time = Duration::from_millis(50).min(remaining);
144 let _ = self.table_cond.wait_timeout(guard, wait_time);
145 }
146 }
147
148 pub fn release_table_lock(&self, table: &str, thread_id: &str) {
149 let mut locks = match self.table_locks.write() {
150 Ok(guard) => guard,
151 Err(poisoned) => poisoned.into_inner(),
152 };
153
154 if let Some(owner) = locks.get(table) {
155 if owner == thread_id {
156 locks.remove(table);
157 self.table_cond.notify_all();
158 }
159 }
160 }
161
162 pub fn release_all_table_locks(&self, thread_id: &str) {
163 let mut locks = match self.table_locks.write() {
164 Ok(guard) => guard,
165 Err(poisoned) => poisoned.into_inner(),
166 };
167
168 let tables_to_remove: Vec<String> = locks
169 .iter()
170 .filter(|(_, owner)| *owner == thread_id)
171 .map(|(table, _)| table.clone())
172 .collect();
173
174 for table in tables_to_remove {
175 locks.remove(&table);
176 }
177
178 self.table_cond.notify_all();
179 }
180
181 pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
182 let txn_info = {
183 let conns = match self.connections.read() {
184 Ok(guard) => guard,
185 Err(poisoned) => poisoned.into_inner(),
186 };
187 conns.get(key).cloned()
188 };
189
190 if let Some(info) = txn_info {
191 let old_depth = info.depth.fetch_sub(1, Ordering::SeqCst);
192 if old_depth > 1 {
193 return Some(old_depth - 1);
194 }
195 }
196
197 let mut conns = match self.connections.write() {
198 Ok(guard) => guard,
199 Err(poisoned) => poisoned.into_inner(),
200 };
201 conns.remove(key);
202 drop(conns);
203 self.release_all_table_locks(thread_id);
204 Some(0)
205 }
206
207 pub fn remove(&self, key: &str, thread_id: &str) {
208 let mut conns = match self.connections.write() {
209 Ok(guard) => guard,
210 Err(poisoned) => poisoned.into_inner(),
211 };
212 conns.remove(key);
213 drop(conns);
214 self.release_all_table_locks(thread_id);
215 }
216
217 pub fn cleanup_expired(&self) {
218 let expired: Vec<String> = {
219 let conns = match self.connections.read() {
220 Ok(guard) => guard,
221 Err(poisoned) => poisoned.into_inner(),
222 };
223 conns
224 .iter()
225 .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
226 .map(|(key, _)| key.clone())
227 .collect()
228 };
229
230 if expired.is_empty() {
231 return;
232 }
233
234 let mut conns = match self.connections.write() {
235 Ok(guard) => guard,
236 Err(poisoned) => poisoned.into_inner(),
237 };
238
239 for key in expired {
240 warn!("PgsqlTransactionManager: cleaning up expired transaction: {key}");
241 conns.remove(&key);
242 }
243 }
244
245 pub fn stats(&self) -> (usize, usize) {
246 let conn_count = match self.connections.read() {
247 Ok(g) => g.len(),
248 Err(poisoned) => poisoned.into_inner().len(),
249 };
250 let lock_count = match self.table_locks.read() {
251 Ok(g) => g.len(),
252 Err(poisoned) => poisoned.into_inner().len(),
253 };
254 (conn_count, lock_count)
255 }
256}
257
258lazy_static::lazy_static! {
259 pub static ref PGSQL_TRANSACTION_MANAGER: Arc<PgsqlTransactionManager> =
260 Arc::new(PgsqlTransactionManager::new(300));
261}
262
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = PgsqlTransactionManager::new(60);
        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 1);
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_1";

        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));

        tm.release_table_lock("users", thread_id);
    }

    #[test]
    fn test_reacquire_is_not_counted() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        // A single release frees the lock regardless of how often it was acquired.
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = tm.clone();

        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            let result = tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100));
            assert!(!result);
        });

        handle.join().unwrap();
        tm.release_table_lock("users", "thread_a");
    }

    #[test]
    fn test_zero_timeout_on_contended_lock_fails() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert!(!tm.acquire_table_lock("users", "thread_b", Duration::ZERO));
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_waiter_acquires_lock_after_release() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let tm2 = Arc::clone(&tm);
        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_secs(5))
        });

        // Give the waiter a moment to block before releasing.
        thread::sleep(Duration::from_millis(20));
        tm.release_table_lock("users", "thread_a");

        assert!(handle.join().unwrap());
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_waiter_acquires_lock_after_release_all() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let tm2 = Arc::clone(&tm);
        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_secs(5))
        });

        thread::sleep(Duration::from_millis(20));
        tm.release_all_table_locks("thread_a");

        assert!(handle.join().unwrap());
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_table_lock_unheld_table_is_noop() {
        let tm = PgsqlTransactionManager::new(60);
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_2";

        assert!(tm.acquire_table_lock("table1", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table2", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("table3", thread_id, Duration::from_secs(1)));

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 3);

        tm.release_all_table_locks(thread_id);

        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = tm.clone();

        assert!(tm.acquire_table_lock("table_a", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            assert!(tm2.acquire_table_lock("table_b", "thread_b", Duration::from_secs(1)));
        });
        handle.join().unwrap();

        assert_eq!(tm.stats().1, 2);

        tm.release_all_table_locks("thread_a");

        assert_eq!(tm.stats().1, 1);
        // table_a is free again, table_b is still owned by thread_b.
        assert!(tm.acquire_table_lock("table_a", "thread_c", Duration::ZERO));
        assert!(!tm.acquire_table_lock("table_b", "thread_c", Duration::ZERO));
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let result = tm.decrement_or_finish("nonexistent", "thread_a");
        assert_eq!(result, Some(0));

        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        tm.remove("nonexistent", "thread_a");
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = PgsqlTransactionManager::new(60);
        tm.cleanup_expired();
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t2", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t3", "thread_b", Duration::from_secs(1)));

        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 3);
    }
}