use crate::error::{DbxError, DbxResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};
/// A named snapshot of arbitrary, JSON-serializable state.
///
/// Checkpoints are persisted by [`RollbackManager`] as pretty-printed JSON, so the
/// field order below is also the key order of the on-disk representation.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Identifier; [`RollbackManager`] also uses it as the file stem on disk.
    pub id: String,
    /// Creation time in whole seconds since the Unix epoch, as set by `Checkpoint::new`.
    pub timestamp: i64,
    /// Free-form human-readable description.
    pub description: String,
    /// Captured state, keyed by caller-chosen names and stored as JSON values.
    pub state: HashMap<String, serde_json::Value>,
}
impl Checkpoint {
    /// Creates an empty checkpoint stamped with the current Unix time in seconds.
    ///
    /// If the system clock reports a time before the Unix epoch, the timestamp
    /// falls back to `0` instead of panicking.
    pub fn new(id: String, description: String) -> Self {
        let secs = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        Self {
            id,
            // Saturate rather than wrap: `as i64` would turn an out-of-range u64 negative.
            timestamp: i64::try_from(secs).unwrap_or(i64::MAX),
            description,
            state: HashMap::new(),
        }
    }

    /// Serializes `value` to JSON and stores it under `key`, replacing any previous value.
    ///
    /// # Errors
    /// Returns an error if `value` cannot be converted to a `serde_json::Value`.
    pub fn add_state<T: Serialize>(&mut self, key: String, value: &T) -> DbxResult<()> {
        let json_value = serde_json::to_value(value)?;
        self.state.insert(key, json_value);
        Ok(())
    }

    /// Deserializes the value stored under `key` into `T`.
    ///
    /// # Errors
    /// Returns `DbxError::Serialization` if `key` is absent, or a conversion error
    /// if the stored JSON does not match `T`.
    pub fn get_state<T: for<'de> Deserialize<'de>>(&self, key: &str) -> DbxResult<T> {
        let json_value = self
            .state
            .get(key)
            .ok_or_else(|| DbxError::Serialization(format!("State key '{}' not found", key)))?;
        // `&Value` is itself a serde Deserializer, so deserialize in place instead of
        // cloning the whole value first as `serde_json::from_value` would require.
        let value = T::deserialize(json_value)?;
        Ok(value)
    }
}
/// Keeps checkpoints in a shared in-memory cache and mirrors each one to a
/// JSON file under `checkpoint_dir`.
#[derive(Debug)]
pub struct RollbackManager {
    /// In-memory cache of checkpoints, keyed by checkpoint id.
    checkpoints: Arc<RwLock<HashMap<String, Checkpoint>>>,
    /// Directory that holds one `<id>.json` file per persisted checkpoint.
    checkpoint_dir: PathBuf,
    /// Whether `trigger_auto_rollback` actually performs a rollback.
    auto_rollback_enabled: bool,
}
impl RollbackManager {
    /// Creates a manager with an empty cache, persisting to `target/checkpoints`,
    /// with auto-rollback disabled.
    pub fn new() -> Self {
        Self {
            checkpoints: Arc::new(RwLock::new(HashMap::new())),
            checkpoint_dir: PathBuf::from("target/checkpoints"),
            auto_rollback_enabled: false,
        }
    }

    /// Builder-style setter for the directory checkpoint files are written to.
    pub fn with_checkpoint_dir(mut self, dir: PathBuf) -> Self {
        self.checkpoint_dir = dir;
        self
    }

    /// Builder-style setter that enables or disables `trigger_auto_rollback`.
    pub fn with_auto_rollback(mut self, enabled: bool) -> Self {
        self.auto_rollback_enabled = enabled;
        self
    }

    /// Creates an empty checkpoint, persists it to disk, then caches it in memory.
    ///
    /// # Errors
    /// Fails if `id` is not a plain file name, or if the checkpoint cannot be
    /// serialized or written. On failure nothing is cached.
    ///
    /// # Panics
    /// Panics if the checkpoint lock is poisoned.
    pub fn create_checkpoint(&self, id: String, description: String) -> DbxResult<Checkpoint> {
        let checkpoint = Checkpoint::new(id, description);
        // Persist first so the cache never holds a checkpoint with no file behind it.
        self.save_checkpoint(&checkpoint)?;
        self.checkpoints
            .write()
            .unwrap()
            .insert(checkpoint.id.clone(), checkpoint.clone());
        Ok(checkpoint)
    }

    /// Resolves the on-disk path `<checkpoint_dir>/<id>.json`.
    ///
    /// `id` must be exactly one ordinary file-name component. Path separators,
    /// `.`, `..`, roots/prefixes and the empty string are rejected, so an id can
    /// never read, write or delete files outside `checkpoint_dir`.
    fn checkpoint_path(&self, id: &str) -> DbxResult<PathBuf> {
        let mut parts = std::path::Path::new(id).components();
        let is_single_name = matches!(
            (parts.next(), parts.next()),
            (Some(std::path::Component::Normal(_)), None)
        );
        if !is_single_name || id.contains(std::path::is_separator) {
            return Err(DbxError::Serialization(format!(
                "Invalid checkpoint id '{}'",
                id
            )));
        }
        Ok(self.checkpoint_dir.join(format!("{}.json", id)))
    }

    /// Writes `checkpoint` as pretty-printed JSON, creating the directory if needed.
    fn save_checkpoint(&self, checkpoint: &Checkpoint) -> DbxResult<()> {
        let file_path = self.checkpoint_path(&checkpoint.id)?;
        fs::create_dir_all(&self.checkpoint_dir)?;
        let json = serde_json::to_string_pretty(checkpoint)?;
        // Write-then-rename so a crash mid-write never leaves a truncated file
        // in place of a previously valid checkpoint.
        let tmp_path = self.checkpoint_dir.join(format!("{}.json.tmp", checkpoint.id));
        fs::write(&tmp_path, json)?;
        if let Err(err) = fs::rename(&tmp_path, &file_path) {
            // Best-effort cleanup; the rename error is what the caller needs to see.
            let _ = fs::remove_file(&tmp_path);
            return Err(err.into());
        }
        Ok(())
    }

    /// Reads checkpoint `id` from disk.
    ///
    /// A missing file is reported as `DbxError::Serialization("Checkpoint '<id>' not found")`;
    /// other I/O and parse failures are propagated.
    fn load_checkpoint(&self, id: &str) -> DbxResult<Checkpoint> {
        let file_path = self.checkpoint_path(id)?;
        // Read directly and map NotFound instead of checking `exists()` first,
        // which raced with concurrent deletion.
        let json = match fs::read_to_string(&file_path) {
            Ok(json) => json,
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
                return Err(DbxError::Serialization(format!(
                    "Checkpoint '{}' not found",
                    id
                )));
            }
            Err(err) => return Err(err.into()),
        };
        let checkpoint: Checkpoint = serde_json::from_str(&json)?;
        Ok(checkpoint)
    }

    /// Reloads checkpoint `id` from disk and replaces the cached copy with it.
    ///
    /// Returns the reloaded checkpoint; applying its state is left to the caller.
    ///
    /// # Panics
    /// Panics if the checkpoint lock is poisoned.
    pub fn rollback_to_checkpoint(&self, id: &str) -> DbxResult<Checkpoint> {
        let checkpoint = self.load_checkpoint(id)?;
        self.checkpoints
            .write()
            .unwrap()
            .insert(id.to_string(), checkpoint.clone());
        Ok(checkpoint)
    }

    /// Returns a clone of the cached checkpoint `id`, if present.
    pub fn get_checkpoint(&self, id: &str) -> Option<Checkpoint> {
        self.checkpoints.read().unwrap().get(id).cloned()
    }

    /// Returns clones of all cached checkpoints, in no particular order.
    pub fn list_checkpoints(&self) -> Vec<Checkpoint> {
        self.checkpoints.read().unwrap().values().cloned().collect()
    }

    /// Removes checkpoint `id` from the cache and deletes its file if one exists.
    ///
    /// Deleting a checkpoint that has no file is not an error.
    pub fn delete_checkpoint(&self, id: &str) -> DbxResult<()> {
        self.checkpoints.write().unwrap().remove(id);
        let file_path = self.checkpoint_path(id)?;
        match fs::remove_file(file_path) {
            Ok(()) => Ok(()),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(err) => Err(err.into()),
        }
    }

    /// If auto-rollback is enabled, rolls back to the most recent cached checkpoint.
    ///
    /// Does nothing (and returns `Ok`) when auto-rollback is disabled or the cache is empty.
    pub fn trigger_auto_rollback(&self, reason: &str) -> DbxResult<()> {
        if !self.auto_rollback_enabled {
            return Ok(());
        }
        // Only the id is needed, so avoid cloning every checkpoint's state.
        // Timestamps have one-second resolution; ties are broken by id so the
        // choice does not depend on HashMap iteration order. The read guard is a
        // temporary released at the end of this statement, before the rollback
        // below takes the write lock.
        let latest_id = self
            .checkpoints
            .read()
            .unwrap()
            .values()
            .max_by(|a, b| a.timestamp.cmp(&b.timestamp).then_with(|| a.id.cmp(&b.id)))
            .map(|c| c.id.clone());
        if let Some(id) = latest_id {
            eprintln!("Auto-rollback triggered: {}", reason);
            eprintln!("Rolling back to checkpoint: {}", id);
            self.rollback_to_checkpoint(&id)?;
        }
        Ok(())
    }
}
impl Default for RollbackManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory shared by these tests; each test uses distinct checkpoint ids
    /// so the tests can run in parallel without clobbering each other's files.
    const TEST_DIR: &str = "target/test_checkpoints";

    fn test_manager() -> RollbackManager {
        RollbackManager::new().with_checkpoint_dir(PathBuf::from(TEST_DIR))
    }

    #[test]
    fn test_checkpoint_creation() {
        let manager = test_manager();
        let checkpoint = manager
            .create_checkpoint("test_cp_1".to_string(), "Test checkpoint".to_string())
            .unwrap();
        assert_eq!(checkpoint.id, "test_cp_1");
        assert_eq!(checkpoint.description, "Test checkpoint");
        assert!(checkpoint.timestamp > 0);
        assert!(manager.get_checkpoint("test_cp_1").is_some());

        // The checkpoint must also be persisted, and deletion must remove both copies.
        let file_path = PathBuf::from(TEST_DIR).join("test_cp_1.json");
        assert!(file_path.exists());
        manager.delete_checkpoint("test_cp_1").unwrap();
        assert!(manager.get_checkpoint("test_cp_1").is_none());
        assert!(!file_path.exists());
    }

    #[test]
    fn test_rollback_to_checkpoint() {
        let manager = test_manager();
        let mut checkpoint = manager
            .create_checkpoint("test_cp_2".to_string(), "Rollback test".to_string())
            .unwrap();
        checkpoint
            .add_state("key1".to_string(), &"value1".to_string())
            .unwrap();
        checkpoint.add_state("key2".to_string(), &42).unwrap();
        manager
            .checkpoints
            .write()
            .unwrap()
            .insert("test_cp_2".to_string(), checkpoint.clone());
        manager.save_checkpoint(&checkpoint).unwrap();
        manager.checkpoints.write().unwrap().clear();

        let restored = manager.rollback_to_checkpoint("test_cp_2").unwrap();
        assert_eq!(restored.id, "test_cp_2");
        let value1: String = restored.get_state("key1").unwrap();
        let value2: i32 = restored.get_state("key2").unwrap();
        assert_eq!(value1, "value1");
        assert_eq!(value2, 42);
        let _ = manager.delete_checkpoint("test_cp_2");
    }

    #[test]
    fn test_rollback_missing_checkpoint_fails() {
        let manager = test_manager();
        assert!(manager.rollback_to_checkpoint("test_cp_never_saved").is_err());
    }

    #[test]
    fn test_delete_missing_checkpoint_is_ok() {
        let manager = test_manager();
        assert!(manager.delete_checkpoint("test_cp_never_saved").is_ok());
    }

    #[test]
    fn test_auto_rollback_on_regression() {
        let manager = test_manager().with_auto_rollback(true);
        manager
            .create_checkpoint("test_cp_3".to_string(), "Auto-rollback test".to_string())
            .unwrap();

        // Let the cached copy drift away from what is on disk.
        let mut drifted = manager.get_checkpoint("test_cp_3").unwrap();
        drifted.add_state("drift".to_string(), &1).unwrap();
        manager
            .checkpoints
            .write()
            .unwrap()
            .insert("test_cp_3".to_string(), drifted);

        manager
            .trigger_auto_rollback("Performance regression detected")
            .unwrap();

        // The rollback must have restored the persisted (state-free) version.
        let checkpoint = manager.get_checkpoint("test_cp_3").unwrap();
        assert!(checkpoint.state.is_empty());
        let _ = manager.delete_checkpoint("test_cp_3");
    }

    #[test]
    fn test_auto_rollback_disabled_is_noop() {
        let manager = test_manager();
        manager
            .create_checkpoint("test_cp_4".to_string(), "Disabled test".to_string())
            .unwrap();
        let mut drifted = manager.get_checkpoint("test_cp_4").unwrap();
        drifted.add_state("drift".to_string(), &1).unwrap();
        manager
            .checkpoints
            .write()
            .unwrap()
            .insert("test_cp_4".to_string(), drifted);

        manager.trigger_auto_rollback("ignored").unwrap();

        let checkpoint = manager.get_checkpoint("test_cp_4").unwrap();
        let drift: i32 = checkpoint.get_state("drift").unwrap();
        assert_eq!(drift, 1);
        let _ = manager.delete_checkpoint("test_cp_4");
    }
}