1use br_pgsql::connect::Connect;
2use log::{info, warn};
3use std::collections::HashMap;
4use std::sync::{Arc, Condvar, Mutex, RwLock};
5use std::time::{Duration, Instant};
6
7#[derive(Debug)]
9struct TransactionInfo {
10 conn: Connect,
11 depth: i32,
12 created_at: Instant,
13}
14
15pub struct PgsqlTransactionManager {
16 connections: RwLock<HashMap<String, TransactionInfo>>,
17 table_locks: RwLock<HashMap<String, String>>,
18 table_cond: Condvar,
19 table_mutex: Mutex<()>,
20 timeout: Duration,
21}
22
23impl PgsqlTransactionManager {
24 pub fn new(timeout_secs: u64) -> Self {
25 Self {
26 connections: RwLock::new(HashMap::new()),
27 table_locks: RwLock::new(HashMap::new()),
28 table_cond: Condvar::new(),
29 table_mutex: Mutex::new(()),
30 timeout: Duration::from_secs(timeout_secs),
31 }
32 }
33
34 pub fn is_in_transaction(&self, key: &str) -> bool {
35 match self.connections.read() {
36 Ok(guard) => guard.contains_key(key),
37 Err(poisoned) => poisoned.into_inner().contains_key(key),
38 }
39 }
40
41 pub fn get_depth(&self, key: &str) -> i32 {
42 match self.connections.read() {
43 Ok(guard) => guard.get(key).map(|t| t.depth).unwrap_or(0),
44 Err(poisoned) => poisoned.into_inner().get(key).map(|t| t.depth).unwrap_or(0),
45 }
46 }
47
48 pub fn increment_depth(&self, key: &str) -> bool {
49 let mut conns = match self.connections.write() {
50 Ok(guard) => guard,
51 Err(poisoned) => poisoned.into_inner(),
52 };
53
54 if let Some(txn_info) = conns.get_mut(key) {
55 txn_info.depth += 1;
56 info!(
57 "PgsqlTransactionManager: nested transaction depth={}",
58 txn_info.depth
59 );
60 return true;
61 }
62 false
63 }
64
65 pub fn start(&self, key: &str, conn: Connect) -> bool {
66 let mut conns = match self.connections.write() {
67 Ok(guard) => guard,
68 Err(poisoned) => poisoned.into_inner(),
69 };
70
71 conns.insert(
72 key.to_string(),
73 TransactionInfo {
74 conn,
75 depth: 1,
76 created_at: Instant::now(),
77 },
78 );
79 true
80 }
81
82 pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
84 where
85 F: FnOnce(&mut Connect) -> R,
86 {
87 let mut conns = match self.connections.write() {
88 Ok(guard) => guard,
89 Err(poisoned) => poisoned.into_inner(),
90 };
91
92 conns.get_mut(key).map(|txn_info| f(&mut txn_info.conn))
93 }
94
95 pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
96 let start = Instant::now();
97
98 loop {
99 {
100 let mut locks = match self.table_locks.write() {
101 Ok(guard) => guard,
102 Err(poisoned) => poisoned.into_inner(),
103 };
104
105 match locks.get(table) {
106 None => {
107 locks.insert(table.to_string(), thread_id.to_string());
108 return true;
109 }
110 Some(owner) if owner == thread_id => {
111 return true;
112 }
113 _ => {}
114 }
115 }
116
117 if start.elapsed() > timeout {
118 warn!("PgsqlTransactionManager: table lock timeout for {table}");
119 return false;
120 }
121
122 let guard = match self.table_mutex.lock() {
123 Ok(g) => g,
124 Err(poisoned) => poisoned.into_inner(),
125 };
126
127 let remaining = timeout.saturating_sub(start.elapsed());
128 if remaining.is_zero() {
129 return false;
130 }
131
132 let wait_time = Duration::from_millis(50).min(remaining);
133 let _ = self.table_cond.wait_timeout(guard, wait_time);
134 }
135 }
136
137 pub fn release_table_lock(&self, table: &str, thread_id: &str) {
138 let mut locks = match self.table_locks.write() {
139 Ok(guard) => guard,
140 Err(poisoned) => poisoned.into_inner(),
141 };
142
143 if let Some(owner) = locks.get(table) {
144 if owner == thread_id {
145 locks.remove(table);
146 self.table_cond.notify_all();
147 }
148 }
149 }
150
151 pub fn release_all_table_locks(&self, thread_id: &str) {
152 let mut locks = match self.table_locks.write() {
153 Ok(guard) => guard,
154 Err(poisoned) => poisoned.into_inner(),
155 };
156
157 let tables_to_remove: Vec<String> = locks
158 .iter()
159 .filter(|(_, owner)| *owner == thread_id)
160 .map(|(table, _)| table.clone())
161 .collect();
162
163 for table in tables_to_remove {
164 locks.remove(&table);
165 }
166
167 self.table_cond.notify_all();
168 }
169
170 pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
171 let mut conns = match self.connections.write() {
172 Ok(guard) => guard,
173 Err(poisoned) => poisoned.into_inner(),
174 };
175
176 if let Some(txn_info) = conns.get_mut(key) {
177 if txn_info.depth > 1 {
178 txn_info.depth -= 1;
179 return Some(txn_info.depth);
180 }
181 }
182
183 conns.remove(key);
184 drop(conns);
185 self.release_all_table_locks(thread_id);
186 Some(0)
187 }
188
189 pub fn remove(&self, key: &str, thread_id: &str) -> Option<Connect> {
190 let mut conns = match self.connections.write() {
191 Ok(guard) => guard,
192 Err(poisoned) => poisoned.into_inner(),
193 };
194 let removed = conns.remove(key);
195 drop(conns);
196 self.release_all_table_locks(thread_id);
197 removed.map(|t| t.conn)
198 }
199
200 pub fn cleanup_expired(&self) {
201 let expired: Vec<String> = {
202 let conns = match self.connections.read() {
203 Ok(guard) => guard,
204 Err(poisoned) => poisoned.into_inner(),
205 };
206 conns
207 .iter()
208 .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
209 .map(|(key, _)| key.clone())
210 .collect()
211 };
212
213 if expired.is_empty() {
214 return;
215 }
216
217 let mut conns = match self.connections.write() {
218 Ok(guard) => guard,
219 Err(poisoned) => poisoned.into_inner(),
220 };
221
222 for key in expired {
223 warn!("PgsqlTransactionManager: cleaning up expired transaction: {key}");
224 conns.remove(&key);
225 }
226 }
227
228 pub fn stats(&self) -> (usize, usize) {
229 let conn_count = match self.connections.read() {
230 Ok(g) => g.len(),
231 Err(poisoned) => poisoned.into_inner().len(),
232 };
233 let lock_count = match self.table_locks.read() {
234 Ok(g) => g.len(),
235 Err(poisoned) => poisoned.into_inner().len(),
236 };
237 (conn_count, lock_count)
238 }
239}
240
241lazy_static::lazy_static! {
242 pub static ref PGSQL_TRANSACTION_MANAGER: Arc<PgsqlTransactionManager> =
243 Arc::new(PgsqlTransactionManager::new(300));
244}
245
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = PgsqlTransactionManager::new(60);
        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 0);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = PgsqlTransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        let (_, lock_count) = tm.stats();
        assert_eq!(lock_count, 1);
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_1";

        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        // Re-acquisition is not counted: still one lock, freed by a single release.
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100))
        });

        assert!(!handle.join().unwrap());
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_succeeds_after_release() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_secs(5))
        });

        // Give thread_b a chance to start waiting; the assertions hold either way.
        thread::sleep(Duration::from_millis(20));
        tm.release_table_lock("users", "thread_a");

        assert!(handle.join().unwrap());
        // The lock now belongs to thread_b, so thread_a can no longer release it.
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = PgsqlTransactionManager::new(60);
        let thread_id = "test_thread_2";

        for table in ["table1", "table2", "table3"] {
            assert!(tm.acquire_table_lock(table, thread_id, Duration::from_secs(1)));
        }
        assert_eq!(tm.stats().1, 3);

        tm.release_all_table_locks(thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = Arc::new(PgsqlTransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("table_a", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("table_b", "thread_b", Duration::from_secs(1))
        });
        assert!(handle.join().unwrap());

        assert_eq!(tm.stats().1, 2);

        tm.release_all_table_locks("thread_a");
        assert_eq!(tm.stats().1, 1);

        // The survivor is thread_b's lock.
        tm.release_table_lock("table_b", "thread_b");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let result = tm.decrement_or_finish("nonexistent", "thread_a");
        assert_eq!(result, Some(0));

        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        assert!(tm.remove("nonexistent", "thread_a").is_none());
        // Table locks of the given thread are released even without a transaction.
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.cleanup_expired();
        // Nothing to expire, and table locks are untouched.
        assert_eq!(tm.stats(), (0, 1));
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = PgsqlTransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t2", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t3", "thread_b", Duration::from_secs(1)));

        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 3);
    }
}