Skip to main content

br_db/types/
sqlite_transaction.rs

1use log::{info, warn};
2use sqlite::ConnectionThreadSafe;
3use std::collections::HashMap;
4use std::sync::atomic::{AtomicU32, Ordering};
5use std::sync::{Arc, Condvar, Mutex, RwLock};
6use std::time::{Duration, Instant};
7
/// Bookkeeping for one registered transaction.
struct TransactionInfo {
    /// Connection supplied to `SqliteTransactionManager::start`; lent to callers by `with_conn`.
    conn: Arc<ConnectionThreadSafe>,
    /// Nesting depth: initialised to 1 by `start`, raised by `increment_depth`,
    /// lowered by `decrement_or_finish`.
    depth: AtomicU32,
    /// Creation time; `cleanup_expired` compares its age against the manager timeout.
    created_at: Instant,
}
13
/// Registry of open transactions keyed by string, plus a single writer slot
/// that callers acquire and release by identifier.
pub struct SqliteTransactionManager {
    /// Registered transactions, keyed by the caller-supplied string.
    connections: RwLock<HashMap<String, Arc<TransactionInfo>>>,
    /// Identifier of the current write-lock holder, or `None` when the slot is free.
    active_writer: Mutex<Option<String>>,
    /// Notified when the writer slot is released so blocked acquirers can retry.
    writer_cond: Condvar,
    /// Age after which `cleanup_expired` evicts a transaction.
    timeout: Duration,
}
20
21impl SqliteTransactionManager {
22    pub fn new(timeout_secs: u64) -> Self {
23        Self {
24            connections: RwLock::new(HashMap::new()),
25            active_writer: Mutex::new(None),
26            writer_cond: Condvar::new(),
27            timeout: Duration::from_secs(timeout_secs),
28        }
29    }
30
31    pub fn is_in_transaction(&self, key: &str) -> bool {
32        match self.connections.read() {
33            Ok(guard) => guard.contains_key(key),
34            Err(poisoned) => poisoned.into_inner().contains_key(key),
35        }
36    }
37
38    pub fn get_depth(&self, key: &str) -> u32 {
39        let conns = match self.connections.read() {
40            Ok(guard) => guard,
41            Err(poisoned) => poisoned.into_inner(),
42        };
43        conns
44            .get(key)
45            .map(|t| t.depth.load(Ordering::SeqCst))
46            .unwrap_or(0)
47    }
48
49    pub fn increment_depth(&self, key: &str) -> bool {
50        let conns = match self.connections.read() {
51            Ok(guard) => guard,
52            Err(poisoned) => poisoned.into_inner(),
53        };
54
55        if let Some(txn_info) = conns.get(key) {
56            let new_depth = txn_info.depth.fetch_add(1, Ordering::SeqCst) + 1;
57            info!(
58                "SqliteTransactionManager: nested transaction depth={}",
59                new_depth
60            );
61            return true;
62        }
63        false
64    }
65
66    pub fn acquire_write_lock(&self, thread_id: &str, timeout: Duration) -> bool {
67        let start = Instant::now();
68
69        let mut guard = match self.active_writer.lock() {
70            Ok(g) => g,
71            Err(poisoned) => poisoned.into_inner(),
72        };
73
74        loop {
75            match &*guard {
76                None => {
77                    *guard = Some(thread_id.to_string());
78                    return true;
79                }
80                Some(owner) if owner == thread_id => {
81                    return true;
82                }
83                _ => {}
84            }
85
86            let remaining = timeout.saturating_sub(start.elapsed());
87            if remaining.is_zero() {
88                warn!("SqliteTransactionManager: write lock timeout for {thread_id}");
89                return false;
90            }
91
92            let wait_time = Duration::from_millis(100).min(remaining);
93            let result = self.writer_cond.wait_timeout(guard, wait_time);
94            guard = match result {
95                Ok((g, _)) => g,
96                Err(poisoned) => poisoned.into_inner().0,
97            };
98        }
99    }
100
101    pub fn release_write_lock(&self, thread_id: &str) {
102        let mut guard = match self.active_writer.lock() {
103            Ok(g) => g,
104            Err(poisoned) => poisoned.into_inner(),
105        };
106
107        if let Some(owner) = &*guard {
108            if owner == thread_id {
109                *guard = None;
110                self.writer_cond.notify_all();
111            }
112        }
113    }
114
115    pub fn start(&self, key: &str, conn: Arc<ConnectionThreadSafe>) -> bool {
116        let mut conns = match self.connections.write() {
117            Ok(guard) => guard,
118            Err(poisoned) => poisoned.into_inner(),
119        };
120
121        conns.insert(
122            key.to_string(),
123            Arc::new(TransactionInfo {
124                conn,
125                depth: AtomicU32::new(1),
126                created_at: Instant::now(),
127            }),
128        );
129        true
130    }
131
132    pub fn with_conn<F, R>(&self, key: &str, f: F) -> Option<R>
133    where
134        F: FnOnce(&ConnectionThreadSafe) -> R,
135    {
136        let txn_info = {
137            let conns = match self.connections.read() {
138                Ok(guard) => guard,
139                Err(poisoned) => poisoned.into_inner(),
140            };
141            conns.get(key).cloned()
142        };
143
144        txn_info.map(|info| f(&info.conn))
145    }
146
147    pub fn decrement_or_finish(&self, key: &str, thread_id: &str) -> Option<u32> {
148        let txn_info = {
149            let conns = match self.connections.read() {
150                Ok(guard) => guard,
151                Err(poisoned) => poisoned.into_inner(),
152            };
153            conns.get(key).cloned()
154        };
155
156        if let Some(info) = txn_info {
157            let old_depth = info.depth.fetch_sub(1, Ordering::SeqCst);
158            if old_depth > 1 {
159                return Some(old_depth - 1);
160            }
161        }
162
163        let mut conns = match self.connections.write() {
164            Ok(guard) => guard,
165            Err(poisoned) => poisoned.into_inner(),
166        };
167        conns.remove(key);
168        drop(conns);
169        self.release_write_lock(thread_id);
170        Some(0)
171    }
172
173    pub fn remove(&self, key: &str, thread_id: &str) {
174        let mut conns = match self.connections.write() {
175            Ok(guard) => guard,
176            Err(poisoned) => poisoned.into_inner(),
177        };
178        conns.remove(key);
179        drop(conns);
180        self.release_write_lock(thread_id);
181    }
182
183    pub fn cleanup_expired(&self) {
184        let expired: Vec<(String, String)> = {
185            let conns = match self.connections.read() {
186                Ok(guard) => guard,
187                Err(poisoned) => poisoned.into_inner(),
188            };
189            conns
190                .iter()
191                .filter(|(_, txn_info)| txn_info.created_at.elapsed() > self.timeout)
192                .map(|(key, _)| (key.clone(), key.clone()))
193                .collect()
194        };
195
196        if expired.is_empty() {
197            return;
198        }
199
200        for (key, thread_id) in expired {
201            warn!("SqliteTransactionManager: cleaning up expired transaction: {key}");
202            let mut conns = match self.connections.write() {
203                Ok(guard) => guard,
204                Err(poisoned) => poisoned.into_inner(),
205            };
206            conns.remove(&key);
207            drop(conns);
208            self.release_write_lock(&thread_id);
209        }
210    }
211
212    pub fn stats(&self) -> (usize, bool) {
213        let conn_count = match self.connections.read() {
214            Ok(g) => g.len(),
215            Err(poisoned) => poisoned.into_inner().len(),
216        };
217        let has_writer = match self.active_writer.lock() {
218            Ok(g) => g.is_some(),
219            Err(poisoned) => poisoned.into_inner().is_some(),
220        };
221        (conn_count, has_writer)
222    }
223}
224
lazy_static::lazy_static! {
    /// Process-wide transaction manager, built on first access with a
    /// 300-second timeout (the age used by `cleanup_expired`).
    pub static ref SQLITE_TRANSACTION_MANAGER: Arc<SqliteTransactionManager> =
        Arc::new(SqliteTransactionManager::new(300));
}
229
#[cfg(test)]
mod tests {
    use super::*;
    use std::thread;

    /// Key used by the single-transaction tests; doubles as the writer id
    /// wherever a test needs both.
    const KEY: &str = "test_key";

    /// Opens a fresh in-memory SQLite connection.
    fn mem_conn() -> Arc<ConnectionThreadSafe> {
        let conn = sqlite::Connection::open_thread_safe(":memory:").unwrap();
        Arc::new(conn)
    }

    /// Builds a manager with the 60-second timeout most tests use.
    fn manager() -> SqliteTransactionManager {
        SqliteTransactionManager::new(60)
    }

    // --- construction and empty-state queries ---

    #[test]
    fn test_new_creates_empty_manager() {
        let (conn_count, has_writer) = manager().stats();
        assert_eq!(conn_count, 0);
        assert!(!has_writer);
    }

    #[test]
    fn test_is_in_transaction_false_when_empty() {
        assert!(!manager().is_in_transaction("nonexistent"));
    }

    #[test]
    fn test_get_depth_zero_when_empty() {
        assert_eq!(manager().get_depth("nonexistent"), 0);
    }

    // --- start / depth tracking ---

    #[test]
    fn test_start_and_is_in_transaction() {
        let mgr = manager();

        assert!(mgr.start(KEY, mem_conn()));
        assert!(mgr.is_in_transaction(KEY));
    }

    #[test]
    fn test_start_and_get_depth() {
        let mgr = manager();

        mgr.start(KEY, mem_conn());
        assert_eq!(mgr.get_depth(KEY), 1);
    }

    #[test]
    fn test_increment_depth() {
        let mgr = manager();

        mgr.start(KEY, mem_conn());
        assert!(mgr.increment_depth(KEY));
        assert_eq!(mgr.get_depth(KEY), 2);
    }

    #[test]
    fn test_increment_depth_returns_false_when_no_transaction() {
        assert!(!manager().increment_depth("nonexistent"));
    }

    // --- with_conn ---

    #[test]
    fn test_with_conn_executes_closure() {
        let mgr = manager();

        mgr.start(KEY, mem_conn());
        let outcome = mgr.with_conn(KEY, |conn| {
            conn.execute("SELECT 1").unwrap();
            42
        });
        assert_eq!(outcome, Some(42));
    }

    #[test]
    fn test_with_conn_returns_none_when_no_transaction() {
        let outcome = manager().with_conn("nonexistent", |_conn| 42);
        assert_eq!(outcome, None);
    }

    // --- decrement_or_finish / remove ---

    #[test]
    fn test_decrement_or_finish_from_depth_2() {
        let mgr = manager();

        mgr.start(KEY, mem_conn());
        mgr.increment_depth(KEY);
        assert_eq!(mgr.get_depth(KEY), 2);

        let outcome = mgr.decrement_or_finish(KEY, KEY);
        assert_eq!(outcome, Some(1));
        assert!(mgr.is_in_transaction(KEY));
    }

    #[test]
    fn test_decrement_or_finish_from_depth_1() {
        let mgr = manager();

        mgr.start(KEY, mem_conn());
        assert_eq!(mgr.get_depth(KEY), 1);

        let outcome = mgr.decrement_or_finish(KEY, KEY);
        assert_eq!(outcome, Some(0));
        assert!(!mgr.is_in_transaction(KEY));
    }

    #[test]
    fn test_remove_clears_transaction() {
        let mgr = manager();

        mgr.start(KEY, mem_conn());
        assert!(mgr.is_in_transaction(KEY));

        mgr.remove(KEY, KEY);
        assert!(!mgr.is_in_transaction(KEY));
    }

    #[test]
    fn test_remove_releases_write_lock() {
        let mgr = manager();

        mgr.acquire_write_lock(KEY, Duration::from_secs(1));
        assert!(mgr.stats().1);

        mgr.start(KEY, mem_conn());
        mgr.remove(KEY, KEY);

        assert!(!mgr.stats().1);
    }

    // --- writer slot ---

    #[test]
    fn test_acquire_write_lock_same_thread_reentrant() {
        let mgr = manager();
        let owner = "test_thread_1";
        let one_sec = Duration::from_secs(1);

        assert!(mgr.acquire_write_lock(owner, one_sec));
        assert!(mgr.acquire_write_lock(owner, one_sec));

        mgr.release_write_lock(owner);

        let (_, has_writer) = mgr.stats();
        assert!(!has_writer);
    }

    #[test]
    fn test_acquire_write_lock_different_thread_timeout() {
        let mgr = Arc::new(manager());
        let contender_mgr = Arc::clone(&mgr);

        assert!(mgr.acquire_write_lock("thread_1", Duration::from_secs(1)));

        // A second identifier must time out while the first still owns the slot.
        let contender = thread::spawn(move || {
            let acquired =
                contender_mgr.acquire_write_lock("thread_2", Duration::from_millis(100));
            assert!(!acquired);
        });

        contender.join().unwrap();
        mgr.release_write_lock("thread_1");
    }

    #[test]
    fn test_release_write_lock_by_owner() {
        let mgr = manager();
        assert!(mgr.acquire_write_lock("thread_a", Duration::from_secs(1)));
        assert!(mgr.stats().1);

        mgr.release_write_lock("thread_a");
        assert!(!mgr.stats().1);
    }

    #[test]
    fn test_release_write_lock_wrong_thread_does_nothing() {
        let mgr = manager();
        assert!(mgr.acquire_write_lock("thread_a", Duration::from_secs(1)));

        mgr.release_write_lock("thread_b");
        assert!(mgr.stats().1);
    }

    // --- expiry ---

    #[test]
    fn test_cleanup_expired_no_transactions() {
        manager().cleanup_expired();
    }

    #[test]
    fn test_cleanup_expired_removes_old_transactions() {
        // Zero timeout: any entry is expired as soon as time has passed.
        let mgr = SqliteTransactionManager::new(0);

        mgr.acquire_write_lock(KEY, Duration::from_secs(1));
        mgr.start(KEY, mem_conn());
        assert!(mgr.is_in_transaction(KEY));

        thread::sleep(Duration::from_millis(10));
        mgr.cleanup_expired();

        assert!(!mgr.is_in_transaction(KEY));
        assert!(!mgr.stats().1);
    }

    // --- stats ---

    #[test]
    fn test_stats_reflects_state() {
        let mgr = manager();

        let (count_before, writer_before) = mgr.stats();
        assert_eq!(count_before, 0);
        assert!(!writer_before);

        mgr.start("key1", mem_conn());
        mgr.acquire_write_lock("writer", Duration::from_secs(1));

        let (count_after, writer_after) = mgr.stats();
        assert_eq!(count_after, 1);
        assert!(writer_after);
    }
}
452}