// dbx_core/engine/rollback.rs
use crate::error::{DbxError, DbxResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::io;
use std::path::PathBuf;
use std::sync::{Arc, RwLock};
use std::time::{SystemTime, UNIX_EPOCH};

16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct Checkpoint {
19 pub id: String,
21
22 pub timestamp: i64,
24
25 pub description: String,
27
28 pub state: HashMap<String, serde_json::Value>,
30}
31
32impl Checkpoint {
33 pub fn new(id: String, description: String) -> Self {
35 Self {
36 id,
37 timestamp: SystemTime::now()
38 .duration_since(UNIX_EPOCH)
39 .unwrap()
40 .as_secs() as i64,
41 description,
42 state: HashMap::new(),
43 }
44 }
45
46 pub fn add_state<T: Serialize>(&mut self, key: String, value: &T) -> DbxResult<()> {
48 let json_value = serde_json::to_value(value)?;
49 self.state.insert(key, json_value);
50 Ok(())
51 }
52
53 pub fn get_state<T: for<'de> Deserialize<'de>>(&self, key: &str) -> DbxResult<T> {
55 let json_value = self
56 .state
57 .get(key)
58 .ok_or_else(|| DbxError::Serialization(format!("State key '{}' not found", key)))?;
59
60 let value = serde_json::from_value(json_value.clone())?;
61 Ok(value)
62 }
63}
64
65pub struct RollbackManager {
67 checkpoints: Arc<RwLock<HashMap<String, Checkpoint>>>,
69
70 checkpoint_dir: PathBuf,
72
73 auto_rollback_enabled: bool,
75}
76
77impl RollbackManager {
78 pub fn new() -> Self {
80 Self {
81 checkpoints: Arc::new(RwLock::new(HashMap::new())),
82 checkpoint_dir: PathBuf::from("target/checkpoints"),
83 auto_rollback_enabled: false,
84 }
85 }
86
87 pub fn with_checkpoint_dir(mut self, dir: PathBuf) -> Self {
89 self.checkpoint_dir = dir;
90 self
91 }
92
93 pub fn with_auto_rollback(mut self, enabled: bool) -> Self {
95 self.auto_rollback_enabled = enabled;
96 self
97 }
98
99 pub fn create_checkpoint(&self, id: String, description: String) -> DbxResult<Checkpoint> {
101 let checkpoint = Checkpoint::new(id.clone(), description);
102
103 self.checkpoints
105 .write()
106 .unwrap()
107 .insert(id.clone(), checkpoint.clone());
108
109 self.save_checkpoint(&checkpoint)?;
111
112 Ok(checkpoint)
113 }
114
115 fn save_checkpoint(&self, checkpoint: &Checkpoint) -> DbxResult<()> {
117 fs::create_dir_all(&self.checkpoint_dir)?;
119
120 let file_path = self.checkpoint_dir.join(format!("{}.json", checkpoint.id));
122
123 let json = serde_json::to_string_pretty(checkpoint)?;
125
126 fs::write(file_path, json)?;
128
129 Ok(())
130 }
131
132 fn load_checkpoint(&self, id: &str) -> DbxResult<Checkpoint> {
134 let file_path = self.checkpoint_dir.join(format!("{}.json", id));
135
136 if !file_path.exists() {
137 return Err(DbxError::Serialization(format!(
138 "Checkpoint '{}' not found",
139 id
140 )));
141 }
142
143 let json = fs::read_to_string(file_path)?;
144 let checkpoint: Checkpoint = serde_json::from_str(&json)?;
145
146 Ok(checkpoint)
147 }
148
149 pub fn rollback_to_checkpoint(&self, id: &str) -> DbxResult<Checkpoint> {
151 let checkpoint = self.load_checkpoint(id)?;
153
154 self.checkpoints
156 .write()
157 .unwrap()
158 .insert(id.to_string(), checkpoint.clone());
159
160 Ok(checkpoint)
161 }
162
163 pub fn get_checkpoint(&self, id: &str) -> Option<Checkpoint> {
165 self.checkpoints.read().unwrap().get(id).cloned()
166 }
167
168 pub fn list_checkpoints(&self) -> Vec<Checkpoint> {
170 self.checkpoints.read().unwrap().values().cloned().collect()
171 }
172
173 pub fn delete_checkpoint(&self, id: &str) -> DbxResult<()> {
175 self.checkpoints.write().unwrap().remove(id);
177
178 let file_path = self.checkpoint_dir.join(format!("{}.json", id));
180 if file_path.exists() {
181 fs::remove_file(file_path)?;
182 }
183
184 Ok(())
185 }
186
187 pub fn trigger_auto_rollback(&self, reason: &str) -> DbxResult<()> {
189 if !self.auto_rollback_enabled {
190 return Ok(());
191 }
192
193 let checkpoints = self.list_checkpoints();
195 if let Some(latest) = checkpoints.iter().max_by_key(|c| c.timestamp) {
196 eprintln!("Auto-rollback triggered: {}", reason);
197 eprintln!("Rolling back to checkpoint: {}", latest.id);
198 self.rollback_to_checkpoint(&latest.id)?;
199 }
200
201 Ok(())
202 }
203}
204
205impl Default for RollbackManager {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
#[cfg(test)]
mod tests {
    use super::*;

    /// Directory shared by these tests; each test uses its own checkpoint ids so parallel
    /// test threads do not interfere.
    fn test_dir() -> PathBuf {
        PathBuf::from("target/test_checkpoints")
    }

    /// Replaces the in-memory copy of `id` with one carrying an extra, never-saved state key.
    fn add_unsaved_drift(manager: &RollbackManager, id: &str) {
        let mut drifted = manager.get_checkpoint(id).unwrap();
        drifted.add_state("unsaved".to_string(), &true).unwrap();
        manager
            .checkpoints
            .write()
            .unwrap()
            .insert(id.to_string(), drifted);
    }

    #[test]
    fn test_checkpoint_creation() {
        let manager = RollbackManager::new().with_checkpoint_dir(test_dir());

        let checkpoint = manager
            .create_checkpoint("test_cp_1".to_string(), "Test checkpoint".to_string())
            .unwrap();

        assert_eq!(checkpoint.id, "test_cp_1");
        assert_eq!(checkpoint.description, "Test checkpoint");
        assert!(checkpoint.timestamp > 0);

        // Registered in memory and persisted on disk.
        assert!(manager.get_checkpoint("test_cp_1").is_some());
        assert!(test_dir().join("test_cp_1.json").exists());

        manager.delete_checkpoint("test_cp_1").unwrap();
        assert!(manager.get_checkpoint("test_cp_1").is_none());
        assert!(!test_dir().join("test_cp_1.json").exists());
    }

    #[test]
    fn test_rollback_to_checkpoint() {
        let manager = RollbackManager::new().with_checkpoint_dir(test_dir());

        let mut checkpoint = manager
            .create_checkpoint("test_cp_2".to_string(), "Rollback test".to_string())
            .unwrap();

        checkpoint
            .add_state("key1".to_string(), &"value1".to_string())
            .unwrap();
        checkpoint.add_state("key2".to_string(), &42).unwrap();

        manager
            .checkpoints
            .write()
            .unwrap()
            .insert("test_cp_2".to_string(), checkpoint.clone());
        manager.save_checkpoint(&checkpoint).unwrap();

        // Forget everything in memory; the rollback must come from disk.
        manager.checkpoints.write().unwrap().clear();

        let restored = manager.rollback_to_checkpoint("test_cp_2").unwrap();

        assert_eq!(restored.id, "test_cp_2");
        let value1: String = restored.get_state("key1").unwrap();
        let value2: i32 = restored.get_state("key2").unwrap();
        assert_eq!(value1, "value1");
        assert_eq!(value2, 42);
        assert!(manager.get_checkpoint("test_cp_2").is_some());

        manager.delete_checkpoint("test_cp_2").unwrap();
    }

    #[test]
    fn test_auto_rollback_on_regression() {
        let manager = RollbackManager::new()
            .with_checkpoint_dir(test_dir())
            .with_auto_rollback(true);

        manager
            .create_checkpoint("test_cp_3".to_string(), "Auto-rollback test".to_string())
            .unwrap();
        add_unsaved_drift(&manager, "test_cp_3");

        manager
            .trigger_auto_rollback("Performance regression detected")
            .unwrap();

        // The in-memory copy was replaced by the on-disk version, dropping the drift.
        let restored = manager.get_checkpoint("test_cp_3").unwrap();
        assert!(!restored.state.contains_key("unsaved"));

        manager.delete_checkpoint("test_cp_3").unwrap();
    }

    #[test]
    fn test_auto_rollback_disabled_is_noop() {
        // Auto-rollback is disabled by default.
        let manager = RollbackManager::new().with_checkpoint_dir(test_dir());

        manager
            .create_checkpoint("test_cp_4".to_string(), "Disabled test".to_string())
            .unwrap();
        add_unsaved_drift(&manager, "test_cp_4");

        manager.trigger_auto_rollback("ignored").unwrap();

        let current = manager.get_checkpoint("test_cp_4").unwrap();
        assert!(current.state.contains_key("unsaved"));

        manager.delete_checkpoint("test_cp_4").unwrap();
    }

    #[test]
    fn test_rollback_to_missing_checkpoint_fails() {
        let manager = RollbackManager::new().with_checkpoint_dir(test_dir());

        assert!(manager.rollback_to_checkpoint("test_cp_missing").is_err());
        assert!(manager.get_checkpoint("test_cp_missing").is_none());
    }
}