Skip to main content

reddb_server/storage/transaction/
savepoint.rs

//! Savepoint Management
//!
//! Tracks named, nestable savepoints so a transaction can be partially rolled back.
4
5use std::collections::HashMap;
6use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
7
/// Numeric identifier of a transaction.
pub type TxnId = u64;

/// Log sequence number (LSN) captured when a savepoint is taken.
pub type Lsn = u64;

/// Point in time; `Savepoint::new` stores microseconds since the Unix epoch here.
pub type Timestamp = u64;
16
17/// A savepoint within a transaction
18#[derive(Debug, Clone)]
19pub struct Savepoint {
20    /// Savepoint name
21    pub name: String,
22    /// Transaction ID
23    pub txn_id: TxnId,
24    /// LSN when savepoint was created
25    pub lsn: Lsn,
26    /// Timestamp when created
27    pub created_at: Timestamp,
28    /// Lock count at savepoint (for lock release)
29    pub lock_count: usize,
30    /// Write set index at savepoint (for partial rollback)
31    pub write_set_index: usize,
32    /// Nested savepoint depth
33    pub depth: usize,
34}
35
36impl Savepoint {
37    /// Create new savepoint
38    pub fn new(
39        name: String,
40        txn_id: TxnId,
41        lsn: Lsn,
42        lock_count: usize,
43        write_set_index: usize,
44        depth: usize,
45    ) -> Self {
46        use std::time::{SystemTime, UNIX_EPOCH};
47
48        Self {
49            name,
50            txn_id,
51            lsn,
52            created_at: SystemTime::now()
53                .duration_since(UNIX_EPOCH)
54                .unwrap_or_default()
55                .as_micros() as Timestamp,
56            lock_count,
57            write_set_index,
58            depth,
59        }
60    }
61}
62
63/// Savepoint manager for a single transaction
64#[derive(Debug)]
65pub struct TxnSavepoints {
66    /// Transaction ID
67    txn_id: TxnId,
68    /// Savepoints by name
69    savepoints: HashMap<String, Savepoint>,
70    /// Savepoint stack (for nested savepoints)
71    stack: Vec<String>,
72}
73
74impl TxnSavepoints {
75    /// Create new savepoint manager for transaction
76    pub fn new(txn_id: TxnId) -> Self {
77        Self {
78            txn_id,
79            savepoints: HashMap::new(),
80            stack: Vec::new(),
81        }
82    }
83
84    /// Create a savepoint
85    pub fn create(
86        &mut self,
87        name: String,
88        lsn: Lsn,
89        lock_count: usize,
90        write_set_index: usize,
91    ) -> &Savepoint {
92        let depth = self.stack.len();
93        let savepoint = Savepoint::new(
94            name.clone(),
95            self.txn_id,
96            lsn,
97            lock_count,
98            write_set_index,
99            depth,
100        );
101
102        self.savepoints.insert(name.clone(), savepoint);
103        self.stack.push(name.clone());
104
105        self.savepoints.get(&name).unwrap()
106    }
107
108    /// Get a savepoint by name
109    pub fn get(&self, name: &str) -> Option<&Savepoint> {
110        self.savepoints.get(name)
111    }
112
113    /// Release a savepoint (and all nested ones)
114    pub fn release(&mut self, name: &str) -> Option<Savepoint> {
115        // Find position in stack
116        if let Some(pos) = self.stack.iter().position(|n| n == name) {
117            // Remove this and all nested savepoints
118            let to_remove: Vec<String> = self.stack.drain(pos..).collect();
119            let removed = self.savepoints.remove(name);
120
121            for nested_name in to_remove.iter().skip(1) {
122                self.savepoints.remove(nested_name);
123            }
124
125            removed
126        } else {
127            None
128        }
129    }
130
131    /// Rollback to savepoint (returns savepoints to release)
132    pub fn rollback_to(&mut self, name: &str) -> Option<(Savepoint, Vec<String>)> {
133        // Find position in stack
134        if let Some(pos) = self.stack.iter().position(|n| n == name) {
135            // Get savepoint info
136            let savepoint = self.savepoints.get(name)?.clone();
137
138            // Collect nested savepoints to release
139            let to_release: Vec<String> = self.stack.drain(pos + 1..).collect();
140
141            // Remove nested savepoints
142            for nested_name in &to_release {
143                self.savepoints.remove(nested_name);
144            }
145
146            Some((savepoint, to_release))
147        } else {
148            None
149        }
150    }
151
152    /// Check if savepoint exists
153    pub fn exists(&self, name: &str) -> bool {
154        self.savepoints.contains_key(name)
155    }
156
157    /// Get current depth
158    pub fn depth(&self) -> usize {
159        self.stack.len()
160    }
161
162    /// Get all savepoint names in order
163    pub fn names(&self) -> &[String] {
164        &self.stack
165    }
166
167    /// Clear all savepoints
168    pub fn clear(&mut self) {
169        self.savepoints.clear();
170        self.stack.clear();
171    }
172}
173
174/// Savepoint manager for all transactions
175pub struct SavepointManager {
176    /// Per-transaction savepoints
177    txn_savepoints: RwLock<HashMap<TxnId, TxnSavepoints>>,
178}
179
180impl SavepointManager {
181    /// Create new savepoint manager
182    pub fn new() -> Self {
183        Self {
184            txn_savepoints: RwLock::new(HashMap::new()),
185        }
186    }
187
188    fn txn_savepoints_write(
189        &self,
190    ) -> Result<RwLockWriteGuard<'_, HashMap<TxnId, TxnSavepoints>>, SavepointError> {
191        self.txn_savepoints
192            .write()
193            .map_err(|_| SavepointError::LockPoisoned("savepoint registry"))
194    }
195
196    fn txn_savepoints_read(&self) -> RwLockReadGuard<'_, HashMap<TxnId, TxnSavepoints>> {
197        self.txn_savepoints
198            .read()
199            .unwrap_or_else(|poisoned| poisoned.into_inner())
200    }
201
202    /// Create a savepoint for a transaction
203    pub fn create_savepoint(
204        &self,
205        txn_id: TxnId,
206        name: String,
207        lsn: Lsn,
208        lock_count: usize,
209        write_set_index: usize,
210    ) -> Result<Savepoint, SavepointError> {
211        let mut txn_map = self.txn_savepoints_write()?;
212        let txn_sp = txn_map
213            .entry(txn_id)
214            .or_insert_with(|| TxnSavepoints::new(txn_id));
215
216        // Check for duplicate name
217        if txn_sp.exists(&name) {
218            return Err(SavepointError::DuplicateName(name));
219        }
220
221        Ok(txn_sp
222            .create(name, lsn, lock_count, write_set_index)
223            .clone())
224    }
225
226    /// Get a savepoint
227    pub fn get_savepoint(&self, txn_id: TxnId, name: &str) -> Option<Savepoint> {
228        let txn_map = self.txn_savepoints_read();
229        txn_map.get(&txn_id).and_then(|sp| sp.get(name).cloned())
230    }
231
232    /// Release a savepoint
233    pub fn release_savepoint(
234        &self,
235        txn_id: TxnId,
236        name: &str,
237    ) -> Result<Savepoint, SavepointError> {
238        let mut txn_map = self.txn_savepoints_write()?;
239
240        let txn_sp = txn_map
241            .get_mut(&txn_id)
242            .ok_or(SavepointError::TxnNotFound(txn_id))?;
243
244        txn_sp
245            .release(name)
246            .ok_or_else(|| SavepointError::NotFound(name.to_string()))
247    }
248
249    /// Rollback to a savepoint
250    pub fn rollback_to_savepoint(
251        &self,
252        txn_id: TxnId,
253        name: &str,
254    ) -> Result<(Savepoint, Vec<String>), SavepointError> {
255        let mut txn_map = self.txn_savepoints_write()?;
256
257        let txn_sp = txn_map
258            .get_mut(&txn_id)
259            .ok_or(SavepointError::TxnNotFound(txn_id))?;
260
261        txn_sp
262            .rollback_to(name)
263            .ok_or_else(|| SavepointError::NotFound(name.to_string()))
264    }
265
266    /// Check if savepoint exists
267    pub fn savepoint_exists(&self, txn_id: TxnId, name: &str) -> bool {
268        let txn_map = self.txn_savepoints_read();
269        txn_map
270            .get(&txn_id)
271            .map(|sp| sp.exists(name))
272            .unwrap_or(false)
273    }
274
275    /// Get savepoint depth for transaction
276    pub fn savepoint_depth(&self, txn_id: TxnId) -> usize {
277        let txn_map = self.txn_savepoints_read();
278        txn_map.get(&txn_id).map(|sp| sp.depth()).unwrap_or(0)
279    }
280
281    /// Get all savepoint names for transaction
282    pub fn get_savepoint_names(&self, txn_id: TxnId) -> Vec<String> {
283        let txn_map = self.txn_savepoints_read();
284        txn_map
285            .get(&txn_id)
286            .map(|sp| sp.names().to_vec())
287            .unwrap_or_default()
288    }
289
290    /// Clean up savepoints for a transaction
291    pub fn cleanup_transaction(&self, txn_id: TxnId) {
292        let mut txn_map = self
293            .txn_savepoints
294            .write()
295            .unwrap_or_else(|poisoned| poisoned.into_inner());
296        txn_map.remove(&txn_id);
297    }
298
299    /// Get statistics
300    pub fn stats(&self) -> SavepointStats {
301        let txn_map = self.txn_savepoints_read();
302        SavepointStats {
303            active_transactions: txn_map.len(),
304            total_savepoints: txn_map.values().map(|sp| sp.depth()).sum(),
305        }
306    }
307}
308
309impl Default for SavepointManager {
310    fn default() -> Self {
311        Self::new()
312    }
313}
314
315/// Savepoint error types
316#[derive(Debug, Clone, PartialEq, Eq)]
317pub enum SavepointError {
318    /// Savepoint not found
319    NotFound(String),
320    /// Duplicate savepoint name
321    DuplicateName(String),
322    /// Transaction not found
323    TxnNotFound(TxnId),
324    /// Internal lock was poisoned by a panic
325    LockPoisoned(&'static str),
326    /// Savepoint stack corrupted
327    StackCorrupted,
328}
329
330impl std::fmt::Display for SavepointError {
331    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332        match self {
333            SavepointError::NotFound(name) => write!(f, "Savepoint '{}' not found", name),
334            SavepointError::DuplicateName(name) => {
335                write!(f, "Savepoint '{}' already exists", name)
336            }
337            SavepointError::TxnNotFound(id) => write!(f, "Transaction {} not found", id),
338            SavepointError::LockPoisoned(name) => write!(f, "Lock poisoned: {}", name),
339            SavepointError::StackCorrupted => write!(f, "Savepoint stack corrupted"),
340        }
341    }
342}
343
344impl std::error::Error for SavepointError {}
345
/// Point-in-time counters reported by `SavepointManager::stats`.
#[derive(Debug, Clone, Default)]
pub struct SavepointStats {
    /// Transactions currently registered with the manager.
    ///
    /// A transaction is registered by its first savepoint and dropped only by
    /// `cleanup_transaction`. This count therefore also includes transactions
    /// whose savepoints have all been released.
    pub active_transactions: usize,
    /// Open savepoints, summed over all registered transactions.
    pub total_savepoints: usize,
}
354
#[cfg(test)]
mod tests {
    use super::*;

    /// Poisons the manager's registry lock by panicking while holding its write guard.
    fn poison_registry(manager: &std::sync::Arc<SavepointManager>) {
        let poison_target = std::sync::Arc::clone(manager);
        let _ = std::thread::spawn(move || {
            let _guard = poison_target
                .txn_savepoints
                .write()
                .expect("savepoint registry lock should be acquired");
            panic!("poison savepoint registry");
        })
        .join();
    }

    #[test]
    fn test_savepoint_create() {
        let sp = Savepoint::new("sp1".to_string(), 1, 100, 5, 10, 0);
        assert_eq!(sp.name, "sp1");
        assert_eq!(sp.txn_id, 1);
        assert_eq!(sp.lsn, 100);
        assert_eq!(sp.lock_count, 5);
        assert_eq!(sp.write_set_index, 10);
        assert_eq!(sp.depth, 0);
    }

    #[test]
    fn test_txn_savepoints() {
        let mut sp = TxnSavepoints::new(1);

        // Create savepoints (name, lsn, lock_count, write_set_index)
        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);

        assert_eq!(sp.depth(), 3);
        assert!(sp.exists("sp1"));
        assert!(sp.exists("sp2"));
        assert!(sp.exists("sp3"));

        let sp1 = sp.get("sp1").unwrap();
        assert_eq!(sp1.lsn, 100);
        assert_eq!(sp1.depth, 0);

        let sp3 = sp.get("sp3").unwrap();
        assert_eq!(sp3.depth, 2);
    }

    #[test]
    fn test_savepoint_release() {
        let mut sp = TxnSavepoints::new(1);

        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);

        // Release sp2 (should also release sp3)
        let released = sp.release("sp2").unwrap();
        assert_eq!(released.name, "sp2");

        assert_eq!(sp.depth(), 1);
        assert!(sp.exists("sp1"));
        assert!(!sp.exists("sp2"));
        assert!(!sp.exists("sp3"));
    }

    #[test]
    fn test_savepoint_rollback() {
        let mut sp = TxnSavepoints::new(1);

        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);

        // Rollback to sp2
        let (savepoint, released) = sp.rollback_to("sp2").unwrap();
        assert_eq!(savepoint.name, "sp2");
        assert_eq!(savepoint.lsn, 200);
        assert_eq!(released, vec!["sp3".to_string()]);

        // sp2 should still exist
        assert!(sp.exists("sp1"));
        assert!(sp.exists("sp2"));
        assert!(!sp.exists("sp3"));
        assert_eq!(sp.depth(), 2);
    }

    #[test]
    fn test_rollback_keeps_target_and_is_repeatable() {
        let mut sp = TxnSavepoints::new(1);
        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);

        let (first, released) = sp.rollback_to("sp1").unwrap();
        assert_eq!(first.lsn, 100);
        assert_eq!(released, vec!["sp2".to_string()]);

        // The target survives, so rolling back to it again discards nothing.
        let (again, released) = sp.rollback_to("sp1").unwrap();
        assert_eq!(again.lsn, 100);
        assert!(released.is_empty());

        // The discarded savepoint can no longer be targeted.
        assert!(sp.rollback_to("sp2").is_none());
        assert_eq!(sp.names().to_vec(), vec!["sp1".to_string()]);
    }

    #[test]
    fn test_release_outermost_clears_stack() {
        let mut sp = TxnSavepoints::new(1);
        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);

        let released = sp.release("sp1").unwrap();
        assert_eq!(released.lsn, 100);
        assert_eq!(sp.depth(), 0);
        assert!(sp.names().is_empty());
        assert!(sp.release("sp1").is_none());
    }

    #[test]
    fn test_savepoint_manager() {
        let manager = SavepointManager::new();

        // Create savepoints for transaction 1
        let sp1 = manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        assert_eq!(sp1.name, "sp1");

        let sp2 = manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();
        assert_eq!(sp2.name, "sp2");

        // Duplicate should fail
        let result = manager.create_savepoint(1, "sp1".to_string(), 300, 3, 0);
        assert!(matches!(result, Err(SavepointError::DuplicateName(_))));

        // Different transaction can have same name
        let sp1_tx2 = manager
            .create_savepoint(2, "sp1".to_string(), 400, 4, 0)
            .unwrap();
        assert_eq!(sp1_tx2.txn_id, 2);

        // Check existence
        assert!(manager.savepoint_exists(1, "sp1"));
        assert!(manager.savepoint_exists(1, "sp2"));
        assert!(manager.savepoint_exists(2, "sp1"));
        assert!(!manager.savepoint_exists(1, "sp3"));
    }

    #[test]
    fn test_manager_errors_for_unknown_txn_or_savepoint() {
        let manager = SavepointManager::new();
        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();

        assert_eq!(
            manager.release_savepoint(1, "missing").unwrap_err(),
            SavepointError::NotFound("missing".to_string())
        );
        assert_eq!(
            manager.release_savepoint(2, "sp1").unwrap_err(),
            SavepointError::TxnNotFound(2)
        );
        assert_eq!(
            manager.rollback_to_savepoint(1, "missing").unwrap_err(),
            SavepointError::NotFound("missing".to_string())
        );
        assert_eq!(
            manager.rollback_to_savepoint(3, "sp1").unwrap_err(),
            SavepointError::TxnNotFound(3)
        );

        // Failed calls leave existing state untouched.
        assert!(manager.savepoint_exists(1, "sp1"));
    }

    #[test]
    fn test_get_savepoint() {
        let manager = SavepointManager::new();
        manager
            .create_savepoint(1, "sp1".to_string(), 100, 7, 3)
            .unwrap();

        let sp = manager.get_savepoint(1, "sp1").unwrap();
        assert_eq!(sp.txn_id, 1);
        assert_eq!(sp.lsn, 100);
        assert_eq!(sp.lock_count, 7);
        assert_eq!(sp.write_set_index, 3);
        assert_eq!(sp.depth, 0);

        assert!(manager.get_savepoint(1, "missing").is_none());
        assert!(manager.get_savepoint(2, "sp1").is_none());
    }

    #[test]
    fn test_manager_release_and_stats() {
        let manager = SavepointManager::new();
        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();
        manager
            .create_savepoint(2, "sp1".to_string(), 300, 3, 0)
            .unwrap();

        let stats = manager.stats();
        assert_eq!(stats.active_transactions, 2);
        assert_eq!(stats.total_savepoints, 3);

        // Releasing the outermost savepoint drops the nested one as well.
        let released = manager.release_savepoint(1, "sp1").unwrap();
        assert_eq!(released.lsn, 100);
        assert_eq!(manager.savepoint_depth(1), 0);
        assert!(manager.get_savepoint_names(1).is_empty());
        assert_eq!(manager.stats().total_savepoints, 1);
    }

    #[test]
    fn test_manager_rollback() {
        let manager = SavepointManager::new();

        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp3".to_string(), 300, 3, 0)
            .unwrap();

        // Rollback to sp2
        let (sp, released) = manager.rollback_to_savepoint(1, "sp2").unwrap();
        assert_eq!(sp.lsn, 200);
        assert_eq!(released, vec!["sp3".to_string()]);

        // sp2 should still exist, sp3 should not
        assert!(manager.savepoint_exists(1, "sp2"));
        assert!(!manager.savepoint_exists(1, "sp3"));
    }

    #[test]
    fn test_manager_cleanup() {
        let manager = SavepointManager::new();

        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();

        // Cleanup transaction
        manager.cleanup_transaction(1);

        // All savepoints should be gone
        assert!(!manager.savepoint_exists(1, "sp1"));
        assert!(!manager.savepoint_exists(1, "sp2"));
        assert_eq!(manager.savepoint_depth(1), 0);
    }

    #[test]
    fn test_get_savepoint_names() {
        let manager = SavepointManager::new();

        manager
            .create_savepoint(1, "first".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "second".to_string(), 200, 2, 0)
            .unwrap();
        manager
            .create_savepoint(1, "third".to_string(), 300, 3, 0)
            .unwrap();

        let names = manager.get_savepoint_names(1);
        assert_eq!(names, vec!["first", "second", "third"]);
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            SavepointError::NotFound("a".to_string()).to_string(),
            "Savepoint 'a' not found"
        );
        assert_eq!(
            SavepointError::DuplicateName("a".to_string()).to_string(),
            "Savepoint 'a' already exists"
        );
        assert_eq!(
            SavepointError::TxnNotFound(7).to_string(),
            "Transaction 7 not found"
        );
        assert_eq!(
            SavepointError::LockPoisoned("registry").to_string(),
            "Lock poisoned: registry"
        );
        assert_eq!(
            SavepointError::StackCorrupted.to_string(),
            "Savepoint stack corrupted"
        );
    }

    #[test]
    fn test_create_savepoint_returns_structured_error_when_registry_lock_is_poisoned() {
        let manager = std::sync::Arc::new(SavepointManager::new());
        poison_registry(&manager);

        let result = manager.create_savepoint(1, "sp1".to_string(), 100, 0, 0);
        assert!(matches!(
            result,
            Err(SavepointError::LockPoisoned("savepoint registry"))
        ));
    }

    #[test]
    fn test_read_queries_and_cleanup_tolerate_poisoned_lock() {
        let manager = std::sync::Arc::new(SavepointManager::new());
        manager
            .create_savepoint(1, "sp1".to_string(), 100, 0, 0)
            .unwrap();
        poison_registry(&manager);

        // Queries and cleanup have no error channel, so they recover the guard.
        assert!(manager.savepoint_exists(1, "sp1"));
        assert_eq!(manager.savepoint_depth(1), 1);
        manager.cleanup_transaction(1);
        assert_eq!(manager.savepoint_depth(1), 0);
    }
}