1use br_pgsql::connect::Connect;
2use log::{info, warn};
3use std::collections::HashMap;
4use std::sync::{Arc, Condvar, Mutex, RwLock};
5use std::time::{Duration, Instant};
6
/// Bookkeeping for one open transaction tracked by `PgsqlTransactionManager`.
#[derive(Debug)]
struct TransactionInfo {
    /// Connection the transaction runs on. Shared through an `Arc` so `with_conn` can clone
    /// it and release the manager's map lock before locking the connection itself.
    conn: Arc<Mutex<Connect>>,
    /// Nesting depth: `start` sets it to 1 and `increment_depth` adds one per nested begin.
    depth: i32,
    /// Registration time; `cleanup_expired` compares its age against the manager's timeout.
    created_at: Instant,
}
13
/// Tracks open database transactions (connection plus nesting depth) per caller-supplied key,
/// together with in-process, per-table exclusive locks owned by string identifiers
/// (the `thread_id` arguments of the table-lock methods).
pub struct PgsqlTransactionManager {
    /// Open transactions, keyed by the caller's transaction key.
    connections: RwLock<HashMap<String, TransactionInfo>>,
    /// Table name -> identifier of the owner currently holding that table's lock.
    table_locks: RwLock<HashMap<String, String>>,
    /// Notified when table locks are released; waiters block on it together with `table_mutex`.
    table_cond: Condvar,
    /// Mutex paired with `table_cond` for waiting; it guards no data of its own.
    table_mutex: Mutex<()>,
    /// Maximum age of a transaction before `cleanup_expired` discards it.
    timeout: Duration,
}
21
22impl PgsqlTransactionManager {
23 pub fn new(timeout_secs: u64) -> Self {
24 Self {
25 connections: RwLock::new(HashMap::new()),
26 table_locks: RwLock::new(HashMap::new()),
27 table_cond: Condvar::new(),
28 table_mutex: Mutex::new(()),
29 timeout: Duration::from_secs(timeout_secs),
30 }
31 }
32
33 pub fn is_in_transaction(&self, key: &str) -> bool {
34 match self.connections.read() {
35 Ok(guard) => guard.contains_key(key),
36 Err(poisoned) => poisoned.into_inner().contains_key(key),
37 }
38 }
39
40 pub fn get_depth(&self, key: &str) -> i32 {
41 match self.connections.read() {
42 Ok(guard) => guard.get(key).map(|t| t.depth).unwrap_or(0),
43 Err(poisoned) => poisoned.into_inner().get(key).map(|t| t.depth).unwrap_or(0),
44 }
45 }
46
47 pub fn increment_depth(&self, key: &str) -> bool {
48 let mut conns = match self.connections.write() {
49 Ok(guard) => guard,
50 Err(poisoned) => poisoned.into_inner(),
51 };
52
53 if let Some(txn_info) = conns.get_mut(key) {
54 txn_info.depth += 1;
55 info!(
56 "PgsqlTransactionManager: nested transaction depth={}",
57 txn_info.depth
58 );
59 return true;
60 }
61 false
62 }
63
64 pub fn start(&self, key: &str, conn: Connect) -> bool {
65 let mut conns = match self.connections.write() {
66 Ok(guard) => guard,
67 Err(poisoned) => poisoned.into_inner(),
68 };
69
70 conns.insert(
71 key.to_string(),
72 TransactionInfo {
73 conn: Arc::new(Mutex::new(conn)),
74 depth: 1,
75 created_at: Instant::now(),
76 },
77 );
78 true
79 }
80
81 pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
82 where
83 F: FnOnce(&mut Connect) -> R,
84 {
85 let conn = {
86 let conns = match self.connections.read() {
87 Ok(guard) => guard,
88 Err(poisoned) => poisoned.into_inner(),
89 };
90 conns.get(key).map(|txn_info| Arc::clone(&txn_info.conn))
91 };
92
93 conn.map(|arc_conn| {
94 let mut conn_guard = match arc_conn.lock() {
95 Ok(guard) => guard,
96 Err(poisoned) => poisoned.into_inner(),
97 };
98 f(&mut conn_guard)
99 })
100 }
101
102 pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
103 let start = Instant::now();
104
105 loop {
106 {
107 let mut locks = match self.table_locks.write() {
108 Ok(guard) => guard,
109 Err(poisoned) => poisoned.into_inner(),
110 };
111
112 match locks.get(table) {
113 None => {
114 locks.insert(table.to_string(), thread_id.to_string());
115 return true;
116 }
117 Some(owner) if owner == thread_id => {
118 return true;
119 }
120 _ => {}
121 }
122 }
123
124 if start.elapsed() > timeout {
125 warn!("PgsqlTransactionManager: table lock timeout for {table}");
126 return false;
127 }
128
129 let guard = match self.table_mutex.lock() {
130 Ok(g) => g,
131 Err(poisoned) => poisoned.into_inner(),
132 };
133
134 let remaining = timeout.saturating_sub(start.elapsed());
135 if remaining.is_zero() {
136 return false;
137 }
138
139 let wait_time = Duration::from_millis(50).min(remaining);
140 let _ = self.table_cond.wait_timeout(guard, wait_time);
141 }
142 }
143
144 pub fn release_table_lock(&self, table: &str, thread_id: &str) {
145 let mut locks = match self.table_locks.write() {
146 Ok(guard) => guard,
147 Err(poisoned) => poisoned.into_inner(),
148 };
149
150 if let Some(owner) = locks.get(table) {
151 if owner == thread_id {
152 locks.remove(table);
153 self.table_cond.notify_all();
154 }
155 }
156 }
157
158 pub fn release_all_table_locks(&self, thread_id: &str) {
159 let mut locks = match self.table_locks.write() {
160 Ok(guard) => guard,
161 Err(poisoned) => poisoned.into_inner(),
162 };
163
164 let tables_to_remove: Vec<String> = locks
165 .iter()
166 .filter(|(_, owner)| *owner == thread_id)
167 .map(|(table, _)| table.clone())
168 .collect();
169
170 for table in tables_to_remove {
171 locks.remove(&table);
172 }
173
174 self.table_cond.notify_all();
175 }
176
177 pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
178 let mut conns = match self.connections.write() {
179 Ok(guard) => guard,
180 Err(poisoned) => poisoned.into_inner(),
181 };
182
183 if let Some(txn_info) = conns.get_mut(key) {
184 if txn_info.depth > 1 {
185 txn_info.depth -= 1;
186 return Some(txn_info.depth);
187 }
188 }
189
190 conns.remove(key);
191 drop(conns);
192 self.release_all_table_locks(thread_id);
193 Some(0)
194 }
195
196 pub fn remove(&self, key: &str, thread_id: &str) -> Option<Connect> {
197 let mut conns = match self.connections.write() {
198 Ok(guard) => guard,
199 Err(poisoned) => poisoned.into_inner(),
200 };
201 let removed = conns.remove(key);
202 drop(conns);
203 self.release_all_table_locks(thread_id);
204 removed.and_then(|txn_info| match Arc::try_unwrap(txn_info.conn) {
205 Ok(conn_mutex) => Some(match conn_mutex.into_inner() {
206 Ok(conn) => conn,
207 Err(poisoned) => poisoned.into_inner(),
208 }),
209 Err(_) => {
210 warn!("PgsqlTransactionManager: transaction connection still borrowed when removing {key}");
211 None
212 }
213 })
214 }
215
216 pub fn cleanup_expired(&self) {
217 let expired: Vec<String> = {
218 let conns = match self.connections.read() {
219 Ok(guard) => guard,
220 Err(poisoned) => poisoned.into_inner(),
221 };
222 conns
223 .iter()
224 .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
225 .map(|(key, _)| key.clone())
226 .collect()
227 };
228
229 if expired.is_empty() {
230 return;
231 }
232
233 let mut conns = match self.connections.write() {
234 Ok(guard) => guard,
235 Err(poisoned) => poisoned.into_inner(),
236 };
237
238 for key in expired {
239 warn!("PgsqlTransactionManager: cleaning up expired transaction: {key}");
240 conns.remove(&key);
241 }
242 }
243
244 pub fn stats(&self) -> (usize, usize) {
245 let conn_count = match self.connections.read() {
246 Ok(g) => g.len(),
247 Err(poisoned) => poisoned.into_inner().len(),
248 };
249 let lock_count = match self.table_locks.read() {
250 Ok(g) => g.len(),
251 Err(poisoned) => poisoned.into_inner().len(),
252 };
253 (conn_count, lock_count)
254 }
255}
256
257lazy_static::lazy_static! {
258 pub static ref PGSQL_TRANSACTION_MANAGER: Arc<PgsqlTransactionManager> =
259 Arc::new(PgsqlTransactionManager::new(300));
260}
261
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Timeout for acquisitions that are expected to succeed immediately.
    const SHORT: Duration = Duration::from_secs(1);

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = PgsqlTransactionManager::new(60);
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
        // A failed increment must not create a transaction.
        assert!(!tm.is_in_transaction("nonexistent"));
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", SHORT));
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_acquire_table_lock_zero_timeout_succeeds_when_free() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::ZERO));
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_1";

        assert!(tm.acquire_table_lock("users", thread_id, SHORT));
        assert!(tm.acquire_table_lock("users", thread_id, SHORT));
        assert_eq!(tm.stats().1, 1);

        // Re-entry is not counted: a single release frees the table.
        tm.release_table_lock("users", thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("users", "thread_a", SHORT));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100))
        });

        assert!(!handle.join().unwrap());
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_zero_timeout_fails_when_held() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", SHORT));

        assert!(!tm.acquire_table_lock("users", "thread_b", Duration::ZERO));

        // The failed attempt must not disturb the existing owner.
        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_waiting_thread_acquires_after_release() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        assert!(tm.acquire_table_lock("users", "thread_a", SHORT));

        let waiter = {
            let tm = Arc::clone(&tm);
            thread::spawn(move || {
                tm.acquire_table_lock("users", "thread_b", Duration::from_secs(5))
            })
        };

        // Give the waiter a moment to block on the contended lock, then release it.
        thread::sleep(Duration::from_millis(20));
        tm.release_table_lock("users", "thread_a");

        assert!(waiter.join().unwrap());
        // Ownership has moved to thread_b.
        assert!(!tm.acquire_table_lock("users", "thread_a", Duration::ZERO));
        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", SHORT));
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", SHORT));

        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_2";

        for table in ["table1", "table2", "table3"] {
            assert!(tm.acquire_table_lock(table, thread_id, SHORT));
        }
        assert_eq!(tm.stats().1, 3);

        tm.release_all_table_locks(thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = PgsqlTransactionManager::new(60);

        assert!(tm.acquire_table_lock("table_a", "thread_a", SHORT));
        assert!(tm.acquire_table_lock("table_b", "thread_b", SHORT));
        assert_eq!(tm.stats().1, 2);

        tm.release_all_table_locks("thread_a");
        assert_eq!(tm.stats().1, 1);

        // table_b is still held by thread_b; table_a is free again.
        assert!(!tm.acquire_table_lock("table_b", "thread_a", Duration::ZERO));
        assert!(tm.acquire_table_lock("table_a", "thread_c", Duration::ZERO));
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", SHORT));

        assert_eq!(tm.decrement_or_finish("nonexistent", "thread_a"), Some(0));

        assert!(!tm.is_in_transaction("nonexistent"));
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", SHORT));

        assert!(tm.remove("nonexistent", "thread_a").is_none());

        // Table locks of the thread are released even when there was no transaction.
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = PgsqlTransactionManager::new(60);
        tm.cleanup_expired();
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", SHORT));
        assert!(tm.acquire_table_lock("t2", "thread_a", SHORT));
        assert!(tm.acquire_table_lock("t3", "thread_b", SHORT));

        assert_eq!(tm.stats(), (0, 3));
    }
}