1use std::collections::HashMap;
6use std::sync::{RwLock, RwLockReadGuard, RwLockWriteGuard};
7
/// Transaction identifier.
pub type TxnId = u64;

/// Log sequence number recorded with each savepoint.
pub type Lsn = u64;

/// Timestamp in microseconds since the Unix epoch, as stamped by [`Savepoint::new`].
pub type Timestamp = u64;
16
/// A named point inside a transaction that can later be released or rolled back to.
///
/// Records the transaction position supplied by the caller (log sequence number,
/// lock count, write-set index) at the moment the savepoint was created.
#[derive(Debug, Clone)]
pub struct Savepoint {
    /// Savepoint name; also the key under which the owning transaction stores it.
    pub name: String,
    /// Transaction that owns this savepoint.
    pub txn_id: TxnId,
    /// Log sequence number captured at creation.
    pub lsn: Lsn,
    /// Creation time in microseconds since the Unix epoch (0 if the clock reads
    /// earlier than the epoch).
    pub created_at: Timestamp,
    /// Lock count supplied at creation — presumably the number of locks the
    /// transaction held at that point (TODO confirm against callers).
    pub lock_count: usize,
    /// Write-set index supplied at creation — presumably the write-set length to
    /// truncate back to on rollback (TODO confirm against callers).
    pub write_set_index: usize,
    /// Nesting depth; when created through `TxnSavepoints::create` this is the
    /// number of savepoints already on the stack (0 = outermost).
    pub depth: usize,
}
35
36impl Savepoint {
37 pub fn new(
39 name: String,
40 txn_id: TxnId,
41 lsn: Lsn,
42 lock_count: usize,
43 write_set_index: usize,
44 depth: usize,
45 ) -> Self {
46 use std::time::{SystemTime, UNIX_EPOCH};
47
48 Self {
49 name,
50 txn_id,
51 lsn,
52 created_at: SystemTime::now()
53 .duration_since(UNIX_EPOCH)
54 .unwrap_or_default()
55 .as_micros() as Timestamp,
56 lock_count,
57 write_set_index,
58 depth,
59 }
60 }
61}
62
/// Savepoints belonging to a single transaction.
///
/// `savepoints` maps each name to its snapshot, while `stack` records the names in
/// creation order so release/rollback can discard every nested savepoint.
#[derive(Debug)]
pub struct TxnSavepoints {
    /// Transaction these savepoints belong to; copied into every created savepoint.
    txn_id: TxnId,
    /// Savepoint snapshots keyed by name.
    savepoints: HashMap<String, Savepoint>,
    /// Savepoint names in creation order; the last entry is the innermost savepoint.
    stack: Vec<String>,
}
73
74impl TxnSavepoints {
75 pub fn new(txn_id: TxnId) -> Self {
77 Self {
78 txn_id,
79 savepoints: HashMap::new(),
80 stack: Vec::new(),
81 }
82 }
83
84 pub fn create(
86 &mut self,
87 name: String,
88 lsn: Lsn,
89 lock_count: usize,
90 write_set_index: usize,
91 ) -> &Savepoint {
92 let depth = self.stack.len();
93 let savepoint = Savepoint::new(
94 name.clone(),
95 self.txn_id,
96 lsn,
97 lock_count,
98 write_set_index,
99 depth,
100 );
101
102 self.savepoints.insert(name.clone(), savepoint);
103 self.stack.push(name.clone());
104
105 self.savepoints.get(&name).unwrap()
106 }
107
108 pub fn get(&self, name: &str) -> Option<&Savepoint> {
110 self.savepoints.get(name)
111 }
112
113 pub fn release(&mut self, name: &str) -> Option<Savepoint> {
115 if let Some(pos) = self.stack.iter().position(|n| n == name) {
117 let to_remove: Vec<String> = self.stack.drain(pos..).collect();
119 let removed = self.savepoints.remove(name);
120
121 for nested_name in to_remove.iter().skip(1) {
122 self.savepoints.remove(nested_name);
123 }
124
125 removed
126 } else {
127 None
128 }
129 }
130
131 pub fn rollback_to(&mut self, name: &str) -> Option<(Savepoint, Vec<String>)> {
133 if let Some(pos) = self.stack.iter().position(|n| n == name) {
135 let savepoint = self.savepoints.get(name)?.clone();
137
138 let to_release: Vec<String> = self.stack.drain(pos + 1..).collect();
140
141 for nested_name in &to_release {
143 self.savepoints.remove(nested_name);
144 }
145
146 Some((savepoint, to_release))
147 } else {
148 None
149 }
150 }
151
152 pub fn exists(&self, name: &str) -> bool {
154 self.savepoints.contains_key(name)
155 }
156
157 pub fn depth(&self) -> usize {
159 self.stack.len()
160 }
161
162 pub fn names(&self) -> &[String] {
164 &self.stack
165 }
166
167 pub fn clear(&mut self) {
169 self.savepoints.clear();
170 self.stack.clear();
171 }
172}
173
174pub struct SavepointManager {
176 txn_savepoints: RwLock<HashMap<TxnId, TxnSavepoints>>,
178}
179
180impl SavepointManager {
181 pub fn new() -> Self {
183 Self {
184 txn_savepoints: RwLock::new(HashMap::new()),
185 }
186 }
187
188 fn txn_savepoints_write(
189 &self,
190 ) -> Result<RwLockWriteGuard<'_, HashMap<TxnId, TxnSavepoints>>, SavepointError> {
191 self.txn_savepoints
192 .write()
193 .map_err(|_| SavepointError::LockPoisoned("savepoint registry"))
194 }
195
196 fn txn_savepoints_read(&self) -> RwLockReadGuard<'_, HashMap<TxnId, TxnSavepoints>> {
197 self.txn_savepoints
198 .read()
199 .unwrap_or_else(|poisoned| poisoned.into_inner())
200 }
201
202 pub fn create_savepoint(
204 &self,
205 txn_id: TxnId,
206 name: String,
207 lsn: Lsn,
208 lock_count: usize,
209 write_set_index: usize,
210 ) -> Result<Savepoint, SavepointError> {
211 let mut txn_map = self.txn_savepoints_write()?;
212 let txn_sp = txn_map
213 .entry(txn_id)
214 .or_insert_with(|| TxnSavepoints::new(txn_id));
215
216 if txn_sp.exists(&name) {
218 return Err(SavepointError::DuplicateName(name));
219 }
220
221 Ok(txn_sp
222 .create(name, lsn, lock_count, write_set_index)
223 .clone())
224 }
225
226 pub fn get_savepoint(&self, txn_id: TxnId, name: &str) -> Option<Savepoint> {
228 let txn_map = self.txn_savepoints_read();
229 txn_map.get(&txn_id).and_then(|sp| sp.get(name).cloned())
230 }
231
232 pub fn release_savepoint(
234 &self,
235 txn_id: TxnId,
236 name: &str,
237 ) -> Result<Savepoint, SavepointError> {
238 let mut txn_map = self.txn_savepoints_write()?;
239
240 let txn_sp = txn_map
241 .get_mut(&txn_id)
242 .ok_or(SavepointError::TxnNotFound(txn_id))?;
243
244 txn_sp
245 .release(name)
246 .ok_or_else(|| SavepointError::NotFound(name.to_string()))
247 }
248
249 pub fn rollback_to_savepoint(
251 &self,
252 txn_id: TxnId,
253 name: &str,
254 ) -> Result<(Savepoint, Vec<String>), SavepointError> {
255 let mut txn_map = self.txn_savepoints_write()?;
256
257 let txn_sp = txn_map
258 .get_mut(&txn_id)
259 .ok_or(SavepointError::TxnNotFound(txn_id))?;
260
261 txn_sp
262 .rollback_to(name)
263 .ok_or_else(|| SavepointError::NotFound(name.to_string()))
264 }
265
266 pub fn savepoint_exists(&self, txn_id: TxnId, name: &str) -> bool {
268 let txn_map = self.txn_savepoints_read();
269 txn_map
270 .get(&txn_id)
271 .map(|sp| sp.exists(name))
272 .unwrap_or(false)
273 }
274
275 pub fn savepoint_depth(&self, txn_id: TxnId) -> usize {
277 let txn_map = self.txn_savepoints_read();
278 txn_map.get(&txn_id).map(|sp| sp.depth()).unwrap_or(0)
279 }
280
281 pub fn get_savepoint_names(&self, txn_id: TxnId) -> Vec<String> {
283 let txn_map = self.txn_savepoints_read();
284 txn_map
285 .get(&txn_id)
286 .map(|sp| sp.names().to_vec())
287 .unwrap_or_default()
288 }
289
290 pub fn cleanup_transaction(&self, txn_id: TxnId) {
292 let mut txn_map = self
293 .txn_savepoints
294 .write()
295 .unwrap_or_else(|poisoned| poisoned.into_inner());
296 txn_map.remove(&txn_id);
297 }
298
299 pub fn stats(&self) -> SavepointStats {
301 let txn_map = self.txn_savepoints_read();
302 SavepointStats {
303 active_transactions: txn_map.len(),
304 total_savepoints: txn_map.values().map(|sp| sp.depth()).sum(),
305 }
306 }
307}
308
309impl Default for SavepointManager {
310 fn default() -> Self {
311 Self::new()
312 }
313}
314
/// Errors returned by savepoint operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SavepointError {
    /// No savepoint with the given name exists in the transaction.
    NotFound(String),
    /// The transaction already has a savepoint with the given name.
    DuplicateName(String),
    /// No savepoint state is registered for the given transaction.
    TxnNotFound(TxnId),
    /// A lock was poisoned by a panicking holder; the payload names the lock.
    LockPoisoned(&'static str),
    /// The savepoint stack is inconsistent.
    /// NOTE(review): not constructed anywhere in this file.
    StackCorrupted,
}
329
330impl std::fmt::Display for SavepointError {
331 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
332 match self {
333 SavepointError::NotFound(name) => write!(f, "Savepoint '{}' not found", name),
334 SavepointError::DuplicateName(name) => {
335 write!(f, "Savepoint '{}' already exists", name)
336 }
337 SavepointError::TxnNotFound(id) => write!(f, "Transaction {} not found", id),
338 SavepointError::LockPoisoned(name) => write!(f, "Lock poisoned: {}", name),
339 SavepointError::StackCorrupted => write!(f, "Savepoint stack corrupted"),
340 }
341 }
342}
343
/// Marks [`SavepointError`] as a standard error type; it has no underlying `source()`.
impl std::error::Error for SavepointError {}
345
/// Snapshot of registry-wide savepoint counts, as returned by
/// [`SavepointManager::stats`].
///
/// Implements `PartialEq`/`Eq` so two snapshots can be compared directly.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SavepointStats {
    /// Number of transactions currently registered in the manager.
    pub active_transactions: usize,
    /// Sum of the savepoint stack depths of all registered transactions.
    pub total_savepoints: usize,
}
354
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_savepoint_create() {
        let sp = Savepoint::new("sp1".to_string(), 1, 100, 5, 10, 0);
        assert_eq!(sp.name, "sp1");
        assert_eq!(sp.txn_id, 1);
        assert_eq!(sp.lsn, 100);
        assert_eq!(sp.lock_count, 5);
        assert_eq!(sp.write_set_index, 10);
        assert_eq!(sp.depth, 0);
    }

    #[test]
    fn test_txn_savepoints() {
        let mut sp = TxnSavepoints::new(1);

        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);

        assert_eq!(sp.depth(), 3);
        assert!(sp.exists("sp1"));
        assert!(sp.exists("sp2"));
        assert!(sp.exists("sp3"));

        let sp1 = sp.get("sp1").unwrap();
        assert_eq!(sp1.lsn, 100);
        assert_eq!(sp1.depth, 0);

        let sp3 = sp.get("sp3").unwrap();
        assert_eq!(sp3.depth, 2);
    }

    #[test]
    fn test_savepoint_release() {
        let mut sp = TxnSavepoints::new(1);

        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);

        // Releasing sp2 also releases the nested sp3.
        let released = sp.release("sp2").unwrap();
        assert_eq!(released.name, "sp2");

        assert_eq!(sp.depth(), 1);
        assert!(sp.exists("sp1"));
        assert!(!sp.exists("sp2"));
        assert!(!sp.exists("sp3"));
    }

    #[test]
    fn test_savepoint_rollback() {
        let mut sp = TxnSavepoints::new(1);

        sp.create("sp1".to_string(), 100, 1, 0);
        sp.create("sp2".to_string(), 200, 2, 5);
        sp.create("sp3".to_string(), 300, 3, 10);

        // Rolling back to sp2 keeps sp2 but discards the nested sp3.
        let (savepoint, released) = sp.rollback_to("sp2").unwrap();
        assert_eq!(savepoint.name, "sp2");
        assert_eq!(savepoint.lsn, 200);
        assert_eq!(released, vec!["sp3".to_string()]);

        assert!(sp.exists("sp1"));
        assert!(sp.exists("sp2"));
        assert!(!sp.exists("sp3"));
        assert_eq!(sp.depth(), 2);
    }

    #[test]
    fn test_savepoint_manager() {
        let manager = SavepointManager::new();

        let sp1 = manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        assert_eq!(sp1.name, "sp1");

        let sp2 = manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();
        assert_eq!(sp2.name, "sp2");

        let result = manager.create_savepoint(1, "sp1".to_string(), 300, 3, 0);
        assert!(matches!(result, Err(SavepointError::DuplicateName(_))));

        // The same name is independent across transactions.
        let sp1_tx2 = manager
            .create_savepoint(2, "sp1".to_string(), 400, 4, 0)
            .unwrap();
        assert_eq!(sp1_tx2.txn_id, 2);

        assert!(manager.savepoint_exists(1, "sp1"));
        assert!(manager.savepoint_exists(1, "sp2"));
        assert!(manager.savepoint_exists(2, "sp1"));
        assert!(!manager.savepoint_exists(1, "sp3"));
    }

    #[test]
    fn test_manager_rollback() {
        let manager = SavepointManager::new();

        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp3".to_string(), 300, 3, 0)
            .unwrap();

        let (sp, released) = manager.rollback_to_savepoint(1, "sp2").unwrap();
        assert_eq!(sp.lsn, 200);
        assert_eq!(released, vec!["sp3".to_string()]);

        assert!(manager.savepoint_exists(1, "sp2"));
        assert!(!manager.savepoint_exists(1, "sp3"));
    }

    #[test]
    fn test_manager_cleanup() {
        let manager = SavepointManager::new();

        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();

        manager.cleanup_transaction(1);

        assert!(!manager.savepoint_exists(1, "sp1"));
        assert!(!manager.savepoint_exists(1, "sp2"));
        assert_eq!(manager.savepoint_depth(1), 0);
    }

    #[test]
    fn test_get_savepoint_names() {
        let manager = SavepointManager::new();

        manager
            .create_savepoint(1, "first".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "second".to_string(), 200, 2, 0)
            .unwrap();
        manager
            .create_savepoint(1, "third".to_string(), 300, 3, 0)
            .unwrap();

        let names = manager.get_savepoint_names(1);
        assert_eq!(names, vec!["first", "second", "third"]);
    }

    #[test]
    fn test_create_savepoint_returns_structured_error_when_registry_lock_is_poisoned() {
        let manager = std::sync::Arc::new(SavepointManager::new());
        let poison_target = std::sync::Arc::clone(&manager);
        // Panic while holding the write lock so the registry becomes poisoned.
        let _ = std::thread::spawn(move || {
            let _guard = poison_target
                .txn_savepoints
                .write()
                .expect("savepoint registry lock should be acquired");
            panic!("poison savepoint registry");
        })
        .join();

        let result = manager.create_savepoint(1, "sp1".to_string(), 100, 0, 0);
        assert!(matches!(
            result,
            Err(SavepointError::LockPoisoned("savepoint registry"))
        ));
    }

    // Releasing through the manager cascades to nested savepoints, and unknown
    // transactions / names are reported with structured errors.
    #[test]
    fn test_manager_release_errors_and_cascade() {
        let manager = SavepointManager::new();

        assert_eq!(
            manager.release_savepoint(99, "sp1").unwrap_err(),
            SavepointError::TxnNotFound(99)
        );

        manager
            .create_savepoint(1, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        manager
            .create_savepoint(1, "sp2".to_string(), 200, 2, 0)
            .unwrap();

        assert_eq!(
            manager.release_savepoint(1, "missing").unwrap_err(),
            SavepointError::NotFound("missing".to_string())
        );

        let released = manager.release_savepoint(1, "sp1").unwrap();
        assert_eq!(released.lsn, 100);
        assert_eq!(manager.savepoint_depth(1), 0);
        assert!(!manager.savepoint_exists(1, "sp2"));
    }

    // A failed rollback reports a structured error and leaves state untouched.
    #[test]
    fn test_manager_rollback_errors() {
        let manager = SavepointManager::new();

        assert_eq!(
            manager.rollback_to_savepoint(42, "sp1").unwrap_err(),
            SavepointError::TxnNotFound(42)
        );

        manager
            .create_savepoint(42, "sp1".to_string(), 100, 1, 0)
            .unwrap();
        assert_eq!(
            manager.rollback_to_savepoint(42, "nope").unwrap_err(),
            SavepointError::NotFound("nope".to_string())
        );
        assert!(manager.savepoint_exists(42, "sp1"));
        assert_eq!(manager.savepoint_depth(42), 1);
    }

    #[test]
    fn test_manager_get_and_stats() {
        let manager = SavepointManager::new();

        let empty = manager.stats();
        assert_eq!(empty.active_transactions, 0);
        assert_eq!(empty.total_savepoints, 0);

        manager.create_savepoint(1, "a".to_string(), 10, 0, 0).unwrap();
        manager.create_savepoint(1, "b".to_string(), 20, 0, 0).unwrap();
        manager.create_savepoint(2, "a".to_string(), 30, 0, 0).unwrap();

        let stats = manager.stats();
        assert_eq!(stats.active_transactions, 2);
        assert_eq!(stats.total_savepoints, 3);

        assert_eq!(manager.get_savepoint(2, "a").map(|sp| sp.lsn), Some(30));
        assert!(manager.get_savepoint(3, "a").is_none());
        assert!(manager.get_savepoint(1, "zzz").is_none());
        assert!(manager.get_savepoint_names(3).is_empty());
        assert_eq!(manager.savepoint_depth(3), 0);
    }

    #[test]
    fn test_error_display() {
        assert_eq!(
            SavepointError::NotFound("x".to_string()).to_string(),
            "Savepoint 'x' not found"
        );
        assert_eq!(
            SavepointError::DuplicateName("x".to_string()).to_string(),
            "Savepoint 'x' already exists"
        );
        assert_eq!(
            SavepointError::TxnNotFound(7).to_string(),
            "Transaction 7 not found"
        );
        assert_eq!(
            SavepointError::LockPoisoned("savepoint registry").to_string(),
            "Lock poisoned: savepoint registry"
        );
        assert_eq!(
            SavepointError::StackCorrupted.to_string(),
            "Savepoint stack corrupted"
        );
    }
}