use std::collections::HashMap;
use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
/// Identifier of a transaction that owns savepoints.
pub type TxnId = u64;
/// Log sequence number recorded with a savepoint.
pub type Lsn = u64;
/// Wall-clock time; [`Savepoint::new`] fills it with microseconds since the Unix epoch.
pub type Timestamp = u64;

/// Snapshot of a transaction's progress that can later be released or rolled back to.
///
/// Every field except `created_at` is supplied by the caller; `created_at` is stamped
/// by [`Savepoint::new`].
#[derive(Debug, Clone)]
pub struct Savepoint {
    /// Name the savepoint was created under.
    pub name: String,
    /// Transaction that owns this savepoint.
    pub txn_id: TxnId,
    /// LSN supplied at creation time (presumably the log position to undo back to — confirm with callers).
    pub lsn: Lsn,
    /// Creation time in microseconds since the Unix epoch (0 if the clock reads before the epoch).
    pub created_at: Timestamp,
    /// Lock count supplied at creation time.
    pub lock_count: usize,
    /// Write-set index supplied at creation time.
    pub write_set_index: usize,
    /// Zero-based nesting depth supplied at creation time.
    pub depth: usize,
}
impl Savepoint {
    /// Builds a savepoint from caller-supplied state, timestamped with the current time.
    pub fn new(
        name: String,
        txn_id: TxnId,
        lsn: Lsn,
        lock_count: usize,
        write_set_index: usize,
        depth: usize,
    ) -> Self {
        use std::time::{SystemTime, UNIX_EPOCH};
        // A clock set before the epoch yields a zero duration instead of an error.
        let elapsed = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default();
        // u64 microseconds span roughly 584,000 years, so this narrowing is lossless in practice.
        let created_at = elapsed.as_micros() as Timestamp;
        Savepoint {
            depth,
            write_set_index,
            lock_count,
            created_at,
            lsn,
            txn_id,
            name,
        }
    }
}
/// Per-transaction savepoint stack.
///
/// Invariant: `stack` lists every savepoint name exactly once, oldest first, and its
/// entries are exactly the keys of `savepoints`. Each stored savepoint's `depth`
/// equals its index in `stack`.
#[derive(Debug)]
pub struct TxnSavepoints {
    txn_id: TxnId,
    /// Savepoint data keyed by name.
    savepoints: HashMap<String, Savepoint>,
    /// Savepoint names in creation order; index 0 is the outermost savepoint.
    stack: Vec<String>,
}
impl TxnSavepoints {
    /// Creates an empty savepoint stack for `txn_id`.
    pub fn new(txn_id: TxnId) -> Self {
        Self {
            txn_id,
            savepoints: HashMap::new(),
            stack: Vec::new(),
        }
    }

    /// Pushes a new savepoint named `name` on top of the stack and returns it.
    ///
    /// If a savepoint with the same name already exists, it is discarded first and
    /// the savepoints nested inside it are kept (the same behaviour as MySQL's
    /// `SAVEPOINT` statement). Without this, the name would appear twice on the stack.
    /// `release`/`rollback_to` would then act on the oldest position while reporting
    /// the newest savepoint, leaving the map and stack out of sync.
    pub fn create(
        &mut self,
        name: String,
        lsn: Lsn,
        lock_count: usize,
        write_set_index: usize,
    ) -> &Savepoint {
        if self.savepoints.remove(&name).is_some() {
            self.stack.retain(|existing| *existing != name);
            // Savepoints above the discarded one moved down a slot; keep `depth` in sync.
            for (position, remaining) in self.stack.iter().enumerate() {
                if let Some(savepoint) = self.savepoints.get_mut(remaining) {
                    savepoint.depth = position;
                }
            }
        }
        // Depth is the number of savepoints beneath this one.
        let depth = self.stack.len();
        let savepoint = Savepoint::new(
            name.clone(),
            self.txn_id,
            lsn,
            lock_count,
            write_set_index,
            depth,
        );
        self.stack.push(name.clone());
        // The name was removed above if present, so this always inserts.
        self.savepoints.entry(name).or_insert(savepoint)
    }

    /// Looks up a savepoint by name.
    pub fn get(&self, name: &str) -> Option<&Savepoint> {
        self.savepoints.get(name)
    }

    /// Removes `name` and every savepoint created after it.
    ///
    /// Returns the released savepoint, or `None` if `name` is not on the stack.
    pub fn release(&mut self, name: &str) -> Option<Savepoint> {
        let pos = self.stack.iter().position(|n| n == name)?;
        let released = self.savepoints.remove(name);
        // The first drained entry is `name` itself, already removed above.
        for nested in self.stack.drain(pos..).skip(1) {
            self.savepoints.remove(&nested);
        }
        released
    }

    /// Discards every savepoint created after `name`, keeping `name` itself.
    ///
    /// Returns a clone of `name`'s savepoint plus the discarded names in creation
    /// order, or `None` if `name` is not on the stack.
    pub fn rollback_to(&mut self, name: &str) -> Option<(Savepoint, Vec<String>)> {
        let pos = self.stack.iter().position(|n| n == name)?;
        let savepoint = self.savepoints.get(name)?.clone();
        let released: Vec<String> = self.stack.drain(pos + 1..).collect();
        for nested in &released {
            self.savepoints.remove(nested);
        }
        Some((savepoint, released))
    }

    /// Returns `true` if a savepoint named `name` exists.
    pub fn exists(&self, name: &str) -> bool {
        self.savepoints.contains_key(name)
    }

    /// Number of savepoints currently on the stack.
    pub fn depth(&self) -> usize {
        self.stack.len()
    }

    /// Savepoint names in creation order (outermost first).
    pub fn names(&self) -> &[String] {
        &self.stack
    }

    /// Drops every savepoint.
    pub fn clear(&mut self) {
        self.savepoints.clear();
        self.stack.clear();
    }
}
/// Registry of savepoint stacks for every transaction, shared behind one `RwLock`.
///
/// Mutating operations (`create_savepoint`, `release_savepoint`,
/// `rollback_to_savepoint`) report a poisoned lock as
/// [`SavepointError::LockPoisoned`]. Read-only queries and `cleanup_transaction`
/// recover the inner map from a poisoned lock instead.
pub struct SavepointManager {
    txn_savepoints: RwLock<HashMap<TxnId, TxnSavepoints>>,
}
impl SavepointManager {
    /// Creates a manager with no registered transactions.
    pub fn new() -> Self {
        let registry = HashMap::new();
        Self {
            txn_savepoints: RwLock::new(registry),
        }
    }

    /// Locks the registry for writing, turning poisoning into a structured error.
    fn txn_savepoints_write(
        &self,
    ) -> Result<RwLockWriteGuard<'_, HashMap<TxnId, TxnSavepoints>>, SavepointError> {
        match self.txn_savepoints.write() {
            Ok(guard) => Ok(guard),
            Err(_) => Err(SavepointError::LockPoisoned("savepoint registry")),
        }
    }

    /// Locks the registry for reading, recovering the data if the lock is poisoned.
    fn txn_savepoints_read(&self) -> RwLockReadGuard<'_, HashMap<TxnId, TxnSavepoints>> {
        match self.txn_savepoints.read() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }

    /// Creates a savepoint for `txn_id`, registering the transaction on first use.
    ///
    /// # Errors
    /// `DuplicateName` if `name` is already in use for this transaction;
    /// `LockPoisoned` if the registry lock is poisoned.
    pub fn create_savepoint(
        &self,
        txn_id: TxnId,
        name: String,
        lsn: Lsn,
        lock_count: usize,
        write_set_index: usize,
    ) -> Result<Savepoint, SavepointError> {
        let mut registry = self.txn_savepoints_write()?;
        let savepoints = registry
            .entry(txn_id)
            .or_insert_with(|| TxnSavepoints::new(txn_id));
        if savepoints.exists(&name) {
            Err(SavepointError::DuplicateName(name))
        } else {
            let created = savepoints.create(name, lsn, lock_count, write_set_index);
            Ok(created.clone())
        }
    }

    /// Returns a copy of the named savepoint, if both the transaction and name exist.
    pub fn get_savepoint(&self, txn_id: TxnId, name: &str) -> Option<Savepoint> {
        let registry = self.txn_savepoints_read();
        let savepoints = registry.get(&txn_id)?;
        savepoints.get(name).cloned()
    }

    /// Releases `name` and all savepoints nested inside it.
    ///
    /// # Errors
    /// `TxnNotFound` if the transaction has no savepoint state; `NotFound` if
    /// `name` is not on its stack; `LockPoisoned` if the registry lock is poisoned.
    pub fn release_savepoint(
        &self,
        txn_id: TxnId,
        name: &str,
    ) -> Result<Savepoint, SavepointError> {
        let mut registry = self.txn_savepoints_write()?;
        let savepoints = match registry.get_mut(&txn_id) {
            Some(found) => found,
            None => return Err(SavepointError::TxnNotFound(txn_id)),
        };
        match savepoints.release(name) {
            Some(released) => Ok(released),
            None => Err(SavepointError::NotFound(name.to_string())),
        }
    }

    /// Rolls the stack back to `name`, returning it and the names discarded above it.
    ///
    /// # Errors
    /// Same conditions as [`SavepointManager::release_savepoint`].
    pub fn rollback_to_savepoint(
        &self,
        txn_id: TxnId,
        name: &str,
    ) -> Result<(Savepoint, Vec<String>), SavepointError> {
        let mut registry = self.txn_savepoints_write()?;
        let savepoints = match registry.get_mut(&txn_id) {
            Some(found) => found,
            None => return Err(SavepointError::TxnNotFound(txn_id)),
        };
        match savepoints.rollback_to(name) {
            Some(outcome) => Ok(outcome),
            None => Err(SavepointError::NotFound(name.to_string())),
        }
    }

    /// Returns `true` if `txn_id` currently has a savepoint named `name`.
    pub fn savepoint_exists(&self, txn_id: TxnId, name: &str) -> bool {
        let registry = self.txn_savepoints_read();
        match registry.get(&txn_id) {
            Some(savepoints) => savepoints.exists(name),
            None => false,
        }
    }

    /// Number of savepoints on `txn_id`'s stack (0 for unknown transactions).
    pub fn savepoint_depth(&self, txn_id: TxnId) -> usize {
        let registry = self.txn_savepoints_read();
        registry.get(&txn_id).map_or(0, TxnSavepoints::depth)
    }

    /// Savepoint names for `txn_id` in creation order (empty for unknown transactions).
    pub fn get_savepoint_names(&self, txn_id: TxnId) -> Vec<String> {
        let registry = self.txn_savepoints_read();
        match registry.get(&txn_id) {
            Some(savepoints) => savepoints.names().to_vec(),
            None => Vec::new(),
        }
    }

    /// Drops all savepoint state for `txn_id`, even if the registry lock is poisoned.
    pub fn cleanup_transaction(&self, txn_id: TxnId) {
        let mut registry = match self.txn_savepoints.write() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        };
        registry.remove(&txn_id);
    }

    /// Summarises how many transactions and savepoints are currently registered.
    pub fn stats(&self) -> SavepointStats {
        let registry = self.txn_savepoints_read();
        let total_savepoints = registry.values().map(TxnSavepoints::depth).sum::<usize>();
        SavepointStats {
            active_transactions: registry.len(),
            total_savepoints,
        }
    }
}
impl Default for SavepointManager {
fn default() -> Self {
Self::new()
}
}
/// Errors reported by the savepoint registry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SavepointError {
    /// No savepoint with the given name exists in the transaction.
    NotFound(String),
    /// A savepoint with the given name already exists in the transaction.
    DuplicateName(String),
    /// The transaction has no registered savepoint state.
    TxnNotFound(TxnId),
    /// A lock was poisoned; the payload names which lock.
    LockPoisoned(&'static str),
    /// The savepoint stack is internally inconsistent (not constructed by the code visible here).
    StackCorrupted,
}
impl std::fmt::Display for SavepointError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::NotFound(ref name) => write!(f, "Savepoint '{}' not found", name),
            Self::DuplicateName(ref name) => write!(f, "Savepoint '{}' already exists", name),
            Self::TxnNotFound(id) => write!(f, "Transaction {} not found", id),
            Self::LockPoisoned(lock) => write!(f, "Lock poisoned: {}", lock),
            Self::StackCorrupted => f.write_str("Savepoint stack corrupted"),
        }
    }
}
impl std::error::Error for SavepointError {}
/// Point-in-time summary of the registry, as produced by `SavepointManager::stats`.
#[derive(Debug, Clone, Default)]
pub struct SavepointStats {
    /// Number of transactions with an entry in the registry (an entry stays after its
    /// savepoints are released, until `cleanup_transaction` removes it).
    pub active_transactions: usize,
    /// Sum of savepoint stack depths across all registered transactions.
    pub total_savepoints: usize,
}
/// Unit tests for the savepoint stack and the thread-safe manager.
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn test_savepoint_create() {
        let sp = Savepoint::new("sp1".to_string(), 1, 100, 5, 10, 0);
        assert_eq!(sp.name, "sp1");
        assert_eq!(sp.txn_id, 1);
        assert_eq!(sp.lsn, 100);
        assert_eq!(sp.lock_count, 5);
        assert_eq!(sp.write_set_index, 10);
        assert_eq!(sp.depth, 0);
    }
    #[test]
    fn test_txn_savepoints() {
        let mut sp = TxnSavepoints::new(1);
        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);
        assert_eq!(sp.depth(), 3);
        assert!(sp.exists("sp1"));
        assert!(sp.exists("sp2"));
        assert!(sp.exists("sp3"));
        let sp1 = sp.get("sp1").unwrap();
        assert_eq!(sp1.lsn, 100);
        assert_eq!(sp1.depth, 0);
        let sp3 = sp.get("sp3").unwrap();
        assert_eq!(sp3.depth, 2);
    }
    #[test]
    fn test_savepoint_release() {
        let mut sp = TxnSavepoints::new(1);
        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);
        let released = sp.release("sp2").unwrap();
        assert_eq!(released.name, "sp2");
        assert_eq!(sp.depth(), 1);
        assert!(sp.exists("sp1"));
        assert!(!sp.exists("sp2"));
        assert!(!sp.exists("sp3"));
    }
    #[test]
    fn test_savepoint_rollback() {
        let mut sp = TxnSavepoints::new(1);
        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);
        let (savepoint, released) = sp.rollback_to("sp2").unwrap();
        assert_eq!(savepoint.name, "sp2");
        assert_eq!(savepoint.lsn, 200);
        assert_eq!(released, vec!["sp3".to_string()]);
        assert!(sp.exists("sp1"));
        assert!(sp.exists("sp2"));
        assert!(!sp.exists("sp3"));
        assert_eq!(sp.depth(), 2);
    }
    #[test]
    fn test_savepoint_manager() {
        let manager = SavepointManager::new();
        let sp1 = manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        assert_eq!(sp1.name, "sp1");
        let sp2 = manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();
        assert_eq!(sp2.name, "sp2");
        let result = manager.create_savepoint(1, "sp1".to_string(), 300, 3, 0);
        assert!(matches!(result, Err(SavepointError::DuplicateName(_))));
        let sp1_tx2 = manager
            .create_savepoint(2, "sp1".to_string(), 400, 4, 0)
            .unwrap();
        assert_eq!(sp1_tx2.txn_id, 2);
        assert!(manager.savepoint_exists(1, "sp1"));
        assert!(manager.savepoint_exists(1, "sp2"));
        assert!(manager.savepoint_exists(2, "sp1"));
        assert!(!manager.savepoint_exists(1, "sp3"));
    }
    #[test]
    fn test_manager_rollback() {
        let manager = SavepointManager::new();
        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp3".to_string(), 300, 3, 0)
            .unwrap();
        let (sp, released) = manager.rollback_to_savepoint(1, "sp2").unwrap();
        assert_eq!(sp.lsn, 200);
        assert_eq!(released, vec!["sp3".to_string()]);
        assert!(manager.savepoint_exists(1, "sp2"));
        assert!(!manager.savepoint_exists(1, "sp3"));
    }
    #[test]
    fn test_manager_cleanup() {
        let manager = SavepointManager::new();
        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();
        manager.cleanup_transaction(1);
        assert!(!manager.savepoint_exists(1, "sp1"));
        assert!(!manager.savepoint_exists(1, "sp2"));
        assert_eq!(manager.savepoint_depth(1), 0);
    }
    #[test]
    fn test_get_savepoint_names() {
        let manager = SavepointManager::new();
        manager
            .create_savepoint(1, "first".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "second".to_string(), 200, 2, 0)
            .unwrap();
        manager
            .create_savepoint(1, "third".to_string(), 300, 3, 0)
            .unwrap();
        let names = manager.get_savepoint_names(1);
        assert_eq!(names, vec!["first", "second", "third"]);
    }
    #[test]
    fn test_create_savepoint_returns_structured_error_when_registry_lock_is_poisoned() {
        let manager = std::sync::Arc::new(SavepointManager::new());
        let poison_target = std::sync::Arc::clone(&manager);
        let _ = std::thread::spawn(move || {
            let _guard = poison_target
                .txn_savepoints
                .write()
                .expect("savepoint registry lock should be acquired");
            panic!("poison savepoint registry");
        })
        .join();
        let result = manager.create_savepoint(1, "sp1".to_string(), 100, 0, 0);
        assert!(matches!(
            result,
            Err(SavepointError::LockPoisoned("savepoint registry"))
        ));
    }
    #[test]
    fn test_manager_release() {
        let manager = SavepointManager::new();
        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp3".to_string(), 300, 3, 0)
            .unwrap();
        let released = manager.release_savepoint(1, "sp2").unwrap();
        assert_eq!(released.name, "sp2");
        assert_eq!(released.lsn, 200);
        // Releasing a savepoint also discards every savepoint nested inside it.
        assert_eq!(manager.get_savepoint_names(1), vec!["sp1"]);
        assert_eq!(
            manager.release_savepoint(1, "sp3").unwrap_err(),
            SavepointError::NotFound("sp3".to_string())
        );
        assert_eq!(
            manager.release_savepoint(99, "sp1").unwrap_err(),
            SavepointError::TxnNotFound(99)
        );
    }
    #[test]
    fn test_manager_rollback_errors() {
        let manager = SavepointManager::new();
        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        assert_eq!(
            manager.rollback_to_savepoint(1, "missing").unwrap_err(),
            SavepointError::NotFound("missing".to_string())
        );
        assert_eq!(
            manager.rollback_to_savepoint(2, "sp1").unwrap_err(),
            SavepointError::TxnNotFound(2)
        );
        // A failed rollback leaves the existing stack untouched.
        assert_eq!(manager.get_savepoint_names(1), vec!["sp1"]);
    }
    #[test]
    fn test_get_savepoint() {
        let manager = SavepointManager::new();
        manager
            .create_savepoint(7, "sp1".to_string(), 100, 4, 2)
            .unwrap();
        let sp = manager.get_savepoint(7, "sp1").unwrap();
        assert_eq!(sp.txn_id, 7);
        assert_eq!(sp.lsn, 100);
        assert_eq!(sp.lock_count, 4);
        assert_eq!(sp.write_set_index, 2);
        assert!(manager.get_savepoint(7, "sp2").is_none());
        assert!(manager.get_savepoint(8, "sp1").is_none());
    }
    #[test]
    fn test_stats() {
        let manager = SavepointManager::new();
        let empty = manager.stats();
        assert_eq!(empty.active_transactions, 0);
        assert_eq!(empty.total_savepoints, 0);
        manager
            .create_savepoint(1, "a".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "b".to_string(), 200, 2, 0)
            .unwrap();
        manager
            .create_savepoint(2, "a".to_string(), 300, 3, 0)
            .unwrap();
        let stats = manager.stats();
        assert_eq!(stats.active_transactions, 2);
        assert_eq!(stats.total_savepoints, 3);
        manager.cleanup_transaction(1);
        let stats = manager.stats();
        assert_eq!(stats.active_transactions, 1);
        assert_eq!(stats.total_savepoints, 1);
    }
    #[test]
    fn test_error_display() {
        assert_eq!(
            SavepointError::NotFound("sp".to_string()).to_string(),
            "Savepoint 'sp' not found"
        );
        assert_eq!(
            SavepointError::DuplicateName("sp".to_string()).to_string(),
            "Savepoint 'sp' already exists"
        );
        assert_eq!(
            SavepointError::TxnNotFound(3).to_string(),
            "Transaction 3 not found"
        );
        assert_eq!(
            SavepointError::LockPoisoned("registry").to_string(),
            "Lock poisoned: registry"
        );
        assert_eq!(
            SavepointError::StackCorrupted.to_string(),
            "Savepoint stack corrupted"
        );
    }
    #[test]
    fn test_reads_and_cleanup_recover_from_poisoned_lock() {
        let manager = std::sync::Arc::new(SavepointManager::new());
        manager
            .create_savepoint(1, "sp1".to_string(), 100, 0, 0)
            .unwrap();
        let poison_target = std::sync::Arc::clone(&manager);
        let _ = std::thread::spawn(move || {
            let _guard = poison_target
                .txn_savepoints
                .write()
                .expect("savepoint registry lock should be acquired");
            panic!("poison savepoint registry");
        })
        .join();
        // Read-only queries still see data written before the lock was poisoned.
        assert!(manager.savepoint_exists(1, "sp1"));
        assert_eq!(manager.savepoint_depth(1), 1);
        assert_eq!(manager.get_savepoint_names(1), vec!["sp1"]);
        // Cleanup ignores poisoning so a transaction can always be torn down.
        manager.cleanup_transaction(1);
        assert!(!manager.savepoint_exists(1, "sp1"));
    }
}