1use log::{info, warn};
2use mysql::PooledConn;
3use std::collections::HashMap;
4use std::sync::{Arc, Condvar, Mutex, RwLock};
5use std::time::{Duration, Instant};
6
/// Bookkeeping for one open transaction, stored in `TransactionManager::connections`
/// under the caller-supplied key.
#[derive(Debug)]
struct TransactionInfo {
    /// Connection the transaction runs on; lent to callers through `with_conn`.
    conn: PooledConn,
    /// Nesting depth: `start` sets it to 1, `increment_depth` raises it and
    /// `decrement_or_finish` lowers it (removing the entry once it would reach 0).
    depth: i32,
    /// Moment `start` registered the transaction; `cleanup_expired` compares the age
    /// measured from here against the manager's `timeout`.
    created_at: Instant,
}
13
/// Tracks open transactions by key, plus in-memory per-table locks keyed by owner id.
pub struct TransactionManager {
    /// Open transactions, keyed by the caller-supplied transaction key.
    connections: RwLock<HashMap<String, TransactionInfo>>,
    /// Table name -> owner id (the `thread_id` argument) currently holding that table.
    table_locks: RwLock<HashMap<String, String>>,
    /// Signalled when table locks are released; `acquire_table_lock` waits on it
    /// using a `table_mutex` guard.
    table_cond: Condvar,
    /// Mutex paired with `table_cond`; it protects no data of its own.
    table_mutex: Mutex<()>,
    /// Age limit after which `cleanup_expired` reclaims a transaction.
    timeout: Duration,
}
21
22impl TransactionManager {
23 pub fn new(timeout_secs: u64) -> Self {
24 Self {
25 connections: RwLock::new(HashMap::new()),
26 table_locks: RwLock::new(HashMap::new()),
27 table_cond: Condvar::new(),
28 table_mutex: Mutex::new(()),
29 timeout: Duration::from_secs(timeout_secs),
30 }
31 }
32
33 pub fn is_in_transaction(&self, key: &str) -> bool {
34 match self.connections.read() {
35 Ok(guard) => guard.contains_key(key),
36 Err(poisoned) => poisoned.into_inner().contains_key(key),
37 }
38 }
39
40 pub fn get_depth(&self, key: &str) -> i32 {
41 match self.connections.read() {
42 Ok(guard) => guard.get(key).map(|t| t.depth).unwrap_or(0),
43 Err(poisoned) => poisoned.into_inner().get(key).map(|t| t.depth).unwrap_or(0),
44 }
45 }
46
47 pub fn increment_depth(&self, key: &str) -> bool {
48 let mut conns = match self.connections.write() {
49 Ok(guard) => guard,
50 Err(poisoned) => poisoned.into_inner(),
51 };
52
53 if let Some(txn_info) = conns.get_mut(key) {
54 txn_info.depth += 1;
55 info!(
56 "TransactionManager: nested transaction depth={}",
57 txn_info.depth
58 );
59 return true;
60 }
61 false
62 }
63
64 pub fn start(&self, key: &str, conn: PooledConn) -> bool {
65 let mut conns = match self.connections.write() {
66 Ok(guard) => guard,
67 Err(poisoned) => poisoned.into_inner(),
68 };
69
70 conns.insert(
71 key.to_string(),
72 TransactionInfo {
73 conn,
74 depth: 1,
75 created_at: Instant::now(),
76 },
77 );
78 true
79 }
80
81 pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
82 where
83 F: FnOnce(&mut PooledConn) -> R,
84 {
85 let mut conns = match self.connections.write() {
86 Ok(guard) => guard,
87 Err(poisoned) => poisoned.into_inner(),
88 };
89
90 conns.get_mut(key).map(|txn_info| f(&mut txn_info.conn))
91 }
92
93 pub fn acquire_table_lock(&self, table: &str, thread_id: &str, timeout: Duration) -> bool {
94 let start = Instant::now();
95
96 loop {
97 {
98 let mut locks = match self.table_locks.write() {
99 Ok(guard) => guard,
100 Err(poisoned) => poisoned.into_inner(),
101 };
102
103 match locks.get(table) {
104 None => {
105 locks.insert(table.to_string(), thread_id.to_string());
106 return true;
107 }
108 Some(owner) if owner == thread_id => {
109 return true;
110 }
111 _ => {}
112 }
113 }
114
115 if start.elapsed() > timeout {
116 warn!("TransactionManager: table lock timeout for {table}");
117 return false;
118 }
119
120 let guard = match self.table_mutex.lock() {
121 Ok(g) => g,
122 Err(poisoned) => poisoned.into_inner(),
123 };
124
125 let remaining = timeout.saturating_sub(start.elapsed());
126 if remaining.is_zero() {
127 return false;
128 }
129
130 let wait_time = Duration::from_millis(50).min(remaining);
131 let _ = self.table_cond.wait_timeout(guard, wait_time);
132 }
133 }
134
135 pub fn release_table_lock(&self, table: &str, thread_id: &str) {
136 let mut locks = match self.table_locks.write() {
137 Ok(guard) => guard,
138 Err(poisoned) => poisoned.into_inner(),
139 };
140
141 if let Some(owner) = locks.get(table) {
142 if owner == thread_id {
143 locks.remove(table);
144 self.table_cond.notify_all();
145 }
146 }
147 }
148
149 pub fn release_all_table_locks(&self, thread_id: &str) {
150 let mut locks = match self.table_locks.write() {
151 Ok(guard) => guard,
152 Err(poisoned) => poisoned.into_inner(),
153 };
154
155 let tables_to_remove: Vec<String> = locks
156 .iter()
157 .filter(|(_, owner)| *owner == thread_id)
158 .map(|(table, _)| table.clone())
159 .collect();
160
161 for table in tables_to_remove {
162 locks.remove(&table);
163 }
164
165 self.table_cond.notify_all();
166 }
167
168 pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<i32> {
169 let mut conns = match self.connections.write() {
170 Ok(guard) => guard,
171 Err(poisoned) => poisoned.into_inner(),
172 };
173
174 if let Some(txn_info) = conns.get_mut(key) {
175 if txn_info.depth > 1 {
176 txn_info.depth -= 1;
177 return Some(txn_info.depth);
178 }
179 }
180
181 conns.remove(key);
182 drop(conns);
183 self.release_all_table_locks(thread_id);
184 Some(0)
185 }
186
187 pub fn remove(&self, key: &str, thread_id: &str) {
188 let mut conns = match self.connections.write() {
189 Ok(guard) => guard,
190 Err(poisoned) => poisoned.into_inner(),
191 };
192 conns.remove(key);
193 drop(conns);
194 self.release_all_table_locks(thread_id);
195 }
196
197 pub fn cleanup_expired(&self) {
198 let expired: Vec<String> = {
199 let conns = match self.connections.read() {
200 Ok(guard) => guard,
201 Err(poisoned) => poisoned.into_inner(),
202 };
203 conns
204 .iter()
205 .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
206 .map(|(key, _)| key.clone())
207 .collect()
208 };
209
210 if expired.is_empty() {
211 return;
212 }
213
214 let mut conns = match self.connections.write() {
215 Ok(guard) => guard,
216 Err(poisoned) => poisoned.into_inner(),
217 };
218
219 for key in expired {
220 warn!("TransactionManager: cleaning up expired transaction: {key}");
221 conns.remove(&key);
222 }
223 }
224
225 pub fn stats(&self) -> (usize, usize) {
226 let conn_count = match self.connections.read() {
227 Ok(g) => g.len(),
228 Err(poisoned) => poisoned.into_inner().len(),
229 };
230 let lock_count = match self.table_locks.read() {
231 Ok(g) => g.len(),
232 Err(poisoned) => poisoned.into_inner().len(),
233 };
234 (conn_count, lock_count)
235 }
236}
237
238lazy_static::lazy_static! {
239 pub static ref TRANSACTION_MANAGER: Arc<TransactionManager> =
240 Arc::new(TransactionManager::new(300));
241}
242
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    #[test]
    fn test_new_creates_empty_manager() {
        let tm = TransactionManager::new(60);
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        let tm = TransactionManager::new(60);
        assert!(!tm.is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        let tm = TransactionManager::new(60);
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(!tm.increment_depth("nonexistent"));
        assert_eq!(tm.get_depth("nonexistent"), 0);
    }

    #[test]
    fn test_acquire_table_lock_new_table() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_acquire_table_lock_same_thread_reentrant() {
        let tm = TransactionManager::new(60);
        let thread_id = "test_thread_1";

        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("users", thread_id, Duration::from_secs(1)));
        // Re-entry is not counted: still one lock, and one release frees it.
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_different_thread_timeout() {
        let tm = Arc::new(TransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("users", "thread_b", Duration::from_millis(100))
        });

        assert!(!handle.join().unwrap());
        // The failed attempt must not have disturbed the owner's lock.
        assert_eq!(tm.stats().1, 1);
        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_acquire_table_lock_zero_timeout_when_contended() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert!(!tm.acquire_table_lock("users", "thread_b", Duration::ZERO));
        // A free table is still granted with a zero timeout.
        assert!(tm.acquire_table_lock("orders", "thread_b", Duration::ZERO));
        assert_eq!(tm.stats().1, 2);
    }

    #[test]
    fn test_waiter_acquires_after_owner_releases() {
        let tm = Arc::new(TransactionManager::new(60));
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let waiter = {
            let tm = Arc::clone(&tm);
            thread::spawn(move || {
                tm.acquire_table_lock("users", "thread_b", Duration::from_secs(5))
            })
        };

        // Give the waiter a chance to block before the lock is handed over.
        thread::sleep(Duration::from_millis(20));
        tm.release_table_lock("users", "thread_a");

        assert!(waiter.join().unwrap());
        // Ownership moved to thread_b.
        assert!(!tm.acquire_table_lock("users", "thread_a", Duration::ZERO));
    }

    #[test]
    fn test_release_table_lock_by_owner() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));
        assert_eq!(tm.stats().1, 1);

        tm.release_table_lock("users", "thread_a");
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_table_lock_wrong_thread_does_nothing() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.release_table_lock("users", "thread_b");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_table_lock_unknown_table_is_noop() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.release_table_lock("orders", "thread_a");
        assert_eq!(tm.stats().1, 1);
    }

    #[test]
    fn test_release_all_table_locks() {
        let tm = TransactionManager::new(60);
        let thread_id = "test_thread_2";

        for table in ["table1", "table2", "table3"] {
            assert!(tm.acquire_table_lock(table, thread_id, Duration::from_secs(1)));
        }
        assert_eq!(tm.stats().1, 3);

        tm.release_all_table_locks(thread_id);
        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_release_all_table_locks_only_removes_own() {
        let tm = Arc::new(TransactionManager::new(60));
        let tm2 = Arc::clone(&tm);

        assert!(tm.acquire_table_lock("table_a", "thread_a", Duration::from_secs(1)));

        let handle = thread::spawn(move || {
            tm2.acquire_table_lock("table_b", "thread_b", Duration::from_secs(1))
        });
        assert!(handle.join().unwrap());

        assert_eq!(tm.stats().1, 2);

        tm.release_all_table_locks("thread_a");

        assert_eq!(tm.stats().1, 1);
        // table_b is still thread_b's; table_a is free again.
        assert!(!tm.acquire_table_lock("table_b", "thread_a", Duration::ZERO));
        assert!(tm.acquire_table_lock("table_a", "thread_c", Duration::ZERO));
    }

    #[test]
    fn test_decrement_or_finish_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        let result = tm.decrement_or_finish("nonexistent", "thread_a");
        assert_eq!(result, Some(0));

        assert_eq!(tm.stats().1, 0);
    }

    #[test]
    fn test_remove_no_transaction() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("users", "thread_a", Duration::from_secs(1)));

        tm.remove("nonexistent", "thread_a");
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_cleanup_expired_no_transactions() {
        let tm = TransactionManager::new(60);
        tm.cleanup_expired();
        assert_eq!(tm.stats(), (0, 0));
    }

    #[test]
    fn test_stats_reflects_locks() {
        let tm = TransactionManager::new(60);
        assert!(tm.acquire_table_lock("t1", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t2", "thread_a", Duration::from_secs(1)));
        assert!(tm.acquire_table_lock("t3", "thread_b", Duration::from_secs(1)));

        let (conn_count, lock_count) = tm.stats();
        assert_eq!(conn_count, 0);
        assert_eq!(lock_count, 3);
    }
}