1use log::{info, warn};
2use sqlite::ConnectionThreadSafe;
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::{Arc, Condvar, Mutex, RwLock};
6use std::time::{Duration, Instant};
7
8struct TransactionInfo {
9 conn: Arc<ConnectionThreadSafe>,
10 depth: AtomicU32,
11 created_at: Instant,
12}
13
14pub struct SqliteTransactionManager {
15 connections: RwLock<HashMap<String, Arc<TransactionInfo>>>,
16 active_writer: Mutex<Option<String>>,
17 writer_cond: Condvar,
18 timeout: Duration,
19}
20
21impl SqliteTransactionManager {
22 pub fn new(timeout_secs: u64) -> Self {
23 Self {
24 connections: RwLock::new(HashMap::new()),
25 active_writer: Mutex::new(None),
26 writer_cond: Condvar::new(),
27 timeout: Duration::from_secs(timeout_secs),
28 }
29 }
30
31 pub fn is_in_transaction(&self, key: &str) -> bool {
32 match self.connections.read() {
33 Ok(guard) => guard.contains_key(key),
34 Err(poisoned) => poisoned.into_inner().contains_key(key),
35 }
36 }
37
38 pub fn get_depth(&self, key: &str) -> u32 {
39 let conns = match self.connections.read() {
40 Ok(guard) => guard,
41 Err(poisoned) => poisoned.into_inner(),
42 };
43 conns
44 .get(key)
45 .map(|t| t.depth.load(Ordering::SeqCst))
46 .unwrap_or(0)
47 }
48
49 pub fn increment_depth(&self, key: &str) -> bool {
50 let conns = match self.connections.read() {
51 Ok(guard) => guard,
52 Err(poisoned) => poisoned.into_inner(),
53 };
54
55 if let Some(txn_info) = conns.get(key) {
56 let new_depth = txn_info.depth.fetch_add(1, Ordering::SeqCst) + 1;
57 info!(
58 "SqliteTransactionManager: nested transaction depth={}",
59 new_depth
60 );
61 return true;
62 }
63 false
64 }
65
66 pub fn acquire_write_lock(&self, thread_id: &str, timeout: Duration) -> bool {
67 let start = Instant::now();
68
69 let mut guard = match self.active_writer.lock() {
70 Ok(g) => g,
71 Err(poisoned) => poisoned.into_inner(),
72 };
73
74 loop {
75 match &*guard {
76 None => {
77 *guard = Some(thread_id.to_string());
78 return true;
79 }
80 Some(owner) if owner == thread_id => {
81 return true;
82 }
83 _ => {}
84 }
85
86 let remaining = timeout.saturating_sub(start.elapsed());
87 if remaining.is_zero() {
88 warn!("SqliteTransactionManager: write lock timeout for {thread_id}");
89 return false;
90 }
91
92 let wait_time = Duration::from_millis(100).min(remaining);
93 let result = self.writer_cond.wait_timeout(guard, wait_time);
94 guard = match result {
95 Ok((g, _)) => g,
96 Err(poisoned) => poisoned.into_inner().0,
97 };
98 }
99 }
100
101 pub fn release_write_lock(&self, thread_id: &str) {
102 let mut guard = match self.active_writer.lock() {
103 Ok(g) => g,
104 Err(poisoned) => poisoned.into_inner(),
105 };
106
107 if let Some(owner) = &*guard {
108 if owner == thread_id {
109 *guard = None;
110 self.writer_cond.notify_all();
111 }
112 }
113 }
114
115 pub fn start(&self, key: &str, conn: Arc<ConnectionThreadSafe>) -> bool {
116 let mut conns = match self.connections.write() {
117 Ok(guard) => guard,
118 Err(poisoned) => poisoned.into_inner(),
119 };
120
121 conns.insert(
122 key.to_string(),
123 Arc::new(TransactionInfo {
124 conn,
125 depth: AtomicU32::new(1),
126 created_at: Instant::now(),
127 }),
128 );
129 true
130 }
131
132 pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
133 where
134 F: FnOnce(&ConnectionThreadSafe) -> R,
135 {
136 let txn_info = {
137 let conns = match self.connections.read() {
138 Ok(guard) => guard,
139 Err(poisoned) => poisoned.into_inner(),
140 };
141 conns.get(key).cloned()
142 };
143
144 txn_info.map(|info| f(&info.conn))
145 }
146
147 pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<u32> {
148 let txn_info = {
149 let conns = match self.connections.read() {
150 Ok(guard) => guard,
151 Err(poisoned) => poisoned.into_inner(),
152 };
153 conns.get(key).cloned()
154 };
155
156 if let Some(info) = txn_info {
157 let old_depth = info.depth.fetch_sub(1, Ordering::SeqCst);
158 if old_depth > 1 {
159 return Some(old_depth - 1);
160 }
161 }
162
163 let mut conns = match self.connections.write() {
164 Ok(guard) => guard,
165 Err(poisoned) => poisoned.into_inner(),
166 };
167 conns.remove(key);
168 drop(conns);
169 self.release_write_lock(thread_id);
170 Some(0)
171 }
172
173 pub fn remove(&self, key: &str, thread_id: &str) {
174 let mut conns = match self.connections.write() {
175 Ok(guard) => guard,
176 Err(poisoned) => poisoned.into_inner(),
177 };
178 conns.remove(key);
179 drop(conns);
180 self.release_write_lock(thread_id);
181 }
182
183 pub fn cleanup_expired(&self) {
184 let expired: Vec<(String, String)> = {
185 let conns = match self.connections.read() {
186 Ok(guard) => guard,
187 Err(poisoned) => poisoned.into_inner(),
188 };
189 conns
190 .iter()
191 .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
192 .map(|(key, _)| (key.clone(), key.clone()))
193 .collect()
194 };
195
196 if expired.is_empty() {
197 return;
198 }
199
200 for (key, thread_id) in expired {
201 warn!("SqliteTransactionManager: cleaning up expired transaction: {key}");
202 let mut conns = match self.connections.write() {
203 Ok(guard) => guard,
204 Err(poisoned) => poisoned.into_inner(),
205 };
206 conns.remove(&key);
207 drop(conns);
208 self.release_write_lock(&thread_id);
209 }
210 }
211
212 pub fn stats(&self) -> (usize, bool) {
213 let conn_count = match self.connections.read() {
214 Ok(g) => g.len(),
215 Err(poisoned) => poisoned.into_inner().len(),
216 };
217 let has_writer = match self.active_writer.lock() {
218 Ok(g) => g.is_some(),
219 Err(poisoned) => poisoned.into_inner().is_some(),
220 };
221 (conn_count, has_writer)
222 }
223}
224
225lazy_static::lazy_static! {
226 pub static ref SQLITE_TRANSACTION_MANAGER: Arc<SqliteTransactionManager> =
227 Arc::new(SqliteTransactionManager::new(300));
228}
229
230#[cfg(test)]
231mod tests {
232 use super::*;
233 use std::thread;
234
235 fn mem_conn() -> Arc<ConnectionThreadSafe> {
236 Arc::new(sqlite::Connection::open_thread_safe(":memory:").unwrap())
237 }
238
239 #[test]
240 fn test_new_creates_empty_manager() {
241 let tm = SqliteTransactionManager::new(60);
242 let (conn_count, has_writer) = tm.stats();
243 assert_eq!(conn_count, 0);
244 assert!(!has_writer);
245 }
246
247 #[test]
248 fn test_is_in_transaction_false_when_empty() {
249 let tm = SqliteTransactionManager::new(60);
250 assert!(!tm.is_in_transaction("nonexistent"));
251 }
252
253 #[test]
254 fn test_get_depth_zero_when_empty() {
255 let tm = SqliteTransactionManager::new(60);
256 assert_eq!(tm.get_depth("nonexistent"), 0);
257 }
258
259 #[test]
260 fn test_start_and_is_in_transaction() {
261 let tm = SqliteTransactionManager::new(60);
262 let key = "test_key";
263
264 assert!(tm.start(key, mem_conn()));
265 assert!(tm.is_in_transaction(key));
266 }
267
268 #[test]
269 fn test_start_and_get_depth() {
270 let tm = SqliteTransactionManager::new(60);
271 let key = "test_key";
272
273 tm.start(key, mem_conn());
274 assert_eq!(tm.get_depth(key), 1);
275 }
276
277 #[test]
278 fn test_increment_depth() {
279 let tm = SqliteTransactionManager::new(60);
280 let key = "test_key";
281
282 tm.start(key, mem_conn());
283 assert!(tm.increment_depth(key));
284 assert_eq!(tm.get_depth(key), 2);
285 }
286
287 #[test]
288 fn test_increment_depth_returns_false_when_no_transaction() {
289 let tm = SqliteTransactionManager::new(60);
290 assert!(!tm.increment_depth("nonexistent"));
291 }
292
293 #[test]
294 fn test_with_conn_executes_closure() {
295 let tm = SqliteTransactionManager::new(60);
296 let key = "test_key";
297
298 tm.start(key, mem_conn());
299 let result = tm.with_conn(key, |conn| {
300 conn.execute("SELECT 1").unwrap();
301 42
302 });
303 assert_eq!(result, Some(42));
304 }
305
306 #[test]
307 fn test_with_conn_returns_none_when_no_transaction() {
308 let tm = SqliteTransactionManager::new(60);
309 let result = tm.with_conn("nonexistent", |_conn| 42);
310 assert_eq!(result, None);
311 }
312
313 #[test]
314 fn test_decrement_or_finish_from_depth_2() {
315 let tm = SqliteTransactionManager::new(60);
316 let key = "test_key";
317
318 tm.start(key, mem_conn());
319 tm.increment_depth(key);
320 assert_eq!(tm.get_depth(key), 2);
321
322 let result = tm.decrement_or_finish(key, key);
323 assert_eq!(result, Some(1));
324 assert!(tm.is_in_transaction(key));
325 }
326
327 #[test]
328 fn test_decrement_or_finish_from_depth_1() {
329 let tm = SqliteTransactionManager::new(60);
330 let key = "test_key";
331
332 tm.start(key, mem_conn());
333 assert_eq!(tm.get_depth(key), 1);
334
335 let result = tm.decrement_or_finish(key, key);
336 assert_eq!(result, Some(0));
337 assert!(!tm.is_in_transaction(key));
338 }
339
340 #[test]
341 fn test_remove_clears_transaction() {
342 let tm = SqliteTransactionManager::new(60);
343 let key = "test_key";
344
345 tm.start(key, mem_conn());
346 assert!(tm.is_in_transaction(key));
347
348 tm.remove(key, key);
349 assert!(!tm.is_in_transaction(key));
350 }
351
352 #[test]
353 fn test_remove_releases_write_lock() {
354 let tm = SqliteTransactionManager::new(60);
355 let key = "test_key";
356
357 tm.acquire_write_lock(key, Duration::from_secs(1));
358 assert!(tm.stats().1);
359
360 tm.start(key, mem_conn());
361 tm.remove(key, key);
362
363 assert!(!tm.stats().1);
364 }
365
366 #[test]
367 fn test_acquire_write_lock_same_thread_reentrant() {
368 let tm = SqliteTransactionManager::new(60);
369 let thread_id = "test_thread_1";
370
371 assert!(tm.acquire_write_lock(thread_id, Duration::from_secs(1)));
372 assert!(tm.acquire_write_lock(thread_id, Duration::from_secs(1)));
373
374 tm.release_write_lock(thread_id);
375
376 let (_, has_writer) = tm.stats();
377 assert!(!has_writer);
378 }
379
380 #[test]
381 fn test_acquire_write_lock_different_thread_timeout() {
382 let tm = Arc::new(SqliteTransactionManager::new(60));
383 let tm2 = tm.clone();
384
385 assert!(tm.acquire_write_lock("thread_1", Duration::from_secs(1)));
386
387 let handle = thread::spawn(move || {
388 let result = tm2.acquire_write_lock("thread_2", Duration::from_millis(100));
389 assert!(!result);
390 });
391
392 handle.join().unwrap();
393 tm.release_write_lock("thread_1");
394 }
395
396 #[test]
397 fn test_release_write_lock_by_owner() {
398 let tm = SqliteTransactionManager::new(60);
399 assert!(tm.acquire_write_lock("thread_a", Duration::from_secs(1)));
400 assert!(tm.stats().1);
401
402 tm.release_write_lock("thread_a");
403 assert!(!tm.stats().1);
404 }
405
406 #[test]
407 fn test_release_write_lock_wrong_thread_does_nothing() {
408 let tm = SqliteTransactionManager::new(60);
409 assert!(tm.acquire_write_lock("thread_a", Duration::from_secs(1)));
410
411 tm.release_write_lock("thread_b");
412 assert!(tm.stats().1);
413 }
414
415 #[test]
416 fn test_cleanup_expired_no_transactions() {
417 let tm = SqliteTransactionManager::new(60);
418 tm.cleanup_expired();
419 }
420
421 #[test]
422 fn test_cleanup_expired_removes_old_transactions() {
423 let tm = SqliteTransactionManager::new(0);
424 let key = "test_key";
425
426 tm.acquire_write_lock(key, Duration::from_secs(1));
427 tm.start(key, mem_conn());
428 assert!(tm.is_in_transaction(key));
429
430 thread::sleep(Duration::from_millis(10));
431 tm.cleanup_expired();
432
433 assert!(!tm.is_in_transaction(key));
434 assert!(!tm.stats().1);
435 }
436
437 #[test]
438 fn test_stats_reflects_state() {
439 let tm = SqliteTransactionManager::new(60);
440
441 let (conn_count, has_writer) = tm.stats();
442 assert_eq!(conn_count, 0);
443 assert!(!has_writer);
444
445 tm.start("key1", mem_conn());
446 tm.acquire_write_lock("writer", Duration::from_secs(1));
447
448 let (conn_count, has_writer) = tm.stats();
449 assert_eq!(conn_count, 1);
450 assert!(has_writer);
451 }
452}