aimds_response/
rollback.rs1use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use serde::{Deserialize, Serialize};
7use crate::{MitigationAction, Result, ResponseError};
8
9pub struct RollbackManager {
11 action_stack: Arc<RwLock<Vec<RollbackEntry>>>,
13
14 history: Arc<RwLock<Vec<RollbackRecord>>>,
16
17 max_stack_size: usize,
19}
20
21impl RollbackManager {
22 pub fn new() -> Self {
24 Self {
25 action_stack: Arc::new(RwLock::new(Vec::new())),
26 history: Arc::new(RwLock::new(Vec::new())),
27 max_stack_size: 1000,
28 }
29 }
30
31 pub fn with_max_size(max_size: usize) -> Self {
33 Self {
34 action_stack: Arc::new(RwLock::new(Vec::new())),
35 history: Arc::new(RwLock::new(Vec::new())),
36 max_stack_size: max_size,
37 }
38 }
39
40 pub async fn push_action(&self, action: MitigationAction, action_id: String) -> Result<()> {
42 let mut stack = self.action_stack.write().await;
43
44 if stack.len() >= self.max_stack_size {
46 stack.remove(0);
48 }
49
50 let entry = RollbackEntry {
51 action,
52 action_id,
53 timestamp: chrono::Utc::now(),
54 context: HashMap::new(),
55 };
56
57 stack.push(entry);
58 Ok(())
59 }
60
61 pub async fn rollback_last(&self) -> Result<()> {
63 let mut stack = self.action_stack.write().await;
64
65 if let Some(entry) = stack.pop() {
66 let result = self.execute_rollback(&entry).await;
67
68 let mut history = self.history.write().await;
70 history.push(RollbackRecord {
71 action_id: entry.action_id.clone(),
72 success: result.is_ok(),
73 timestamp: chrono::Utc::now(),
74 error: result.as_ref().err().map(|e| e.to_string()),
75 });
76
77 result
78 } else {
79 Err(ResponseError::RollbackFailed("No actions to rollback".to_string()))
80 }
81 }
82
83 pub async fn rollback_action(&self, action_id: &str) -> Result<()> {
85 let mut stack = self.action_stack.write().await;
86
87 if let Some(pos) = stack.iter().position(|e| e.action_id == action_id) {
89 let entry = stack.remove(pos);
90 let result = self.execute_rollback(&entry).await;
91
92 let mut history = self.history.write().await;
94 history.push(RollbackRecord {
95 action_id: entry.action_id.clone(),
96 success: result.is_ok(),
97 timestamp: chrono::Utc::now(),
98 error: result.as_ref().err().map(|e| e.to_string()),
99 });
100
101 result
102 } else {
103 Err(ResponseError::RollbackFailed(
104 format!("Action {} not found", action_id)
105 ))
106 }
107 }
108
109 pub async fn rollback_all(&self) -> Result<Vec<String>> {
111 let mut stack = self.action_stack.write().await;
112 let mut rolled_back = Vec::new();
113 let mut errors = Vec::new();
114
115 while let Some(entry) = stack.pop() {
116 match self.execute_rollback(&entry).await {
117 Ok(_) => {
118 rolled_back.push(entry.action_id.clone());
119 }
120 Err(e) => {
121 errors.push(format!("Failed to rollback {}: {}", entry.action_id, e));
122 }
123 }
124
125 let mut history = self.history.write().await;
127 history.push(RollbackRecord {
128 action_id: entry.action_id.clone(),
129 success: errors.is_empty(),
130 timestamp: chrono::Utc::now(),
131 error: errors.last().cloned(),
132 });
133 }
134
135 if errors.is_empty() {
136 Ok(rolled_back)
137 } else {
138 Err(ResponseError::RollbackFailed(errors.join("; ")))
139 }
140 }
141
142 pub async fn history(&self) -> Vec<RollbackRecord> {
144 self.history.read().await.clone()
145 }
146
147 pub async fn stack_size(&self) -> usize {
149 self.action_stack.read().await.len()
150 }
151
152 pub async fn clear_stack(&self) {
154 let mut stack = self.action_stack.write().await;
155 stack.clear();
156 }
157
158 async fn execute_rollback(&self, entry: &RollbackEntry) -> Result<()> {
160 tracing::info!("Rolling back action: {}", entry.action_id);
161
162 match entry.action.rollback(&entry.action_id) {
163 Ok(_) => {
164 metrics::counter!("rollback.success").increment(1);
165 Ok(())
166 }
167 Err(e) => {
168 metrics::counter!("rollback.failure").increment(1);
169 Err(e)
170 }
171 }
172 }
173}
174
175impl Default for RollbackManager {
176 fn default() -> Self {
177 Self::new()
178 }
179}
180
181#[derive(Debug, Clone, Serialize, Deserialize)]
183struct RollbackEntry {
184 action: MitigationAction,
185 action_id: String,
186 timestamp: chrono::DateTime<chrono::Utc>,
187 context: HashMap<String, String>,
188}
189
190#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct RollbackRecord {
193 pub action_id: String,
194 pub success: bool,
195 pub timestamp: chrono::DateTime<chrono::Utc>,
196 pub error: Option<String>,
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MitigationAction;
    use std::time::Duration;

    /// Builds a `BlockRequest` action with the given reason.
    ///
    /// These tests assume `BlockRequest` rollbacks succeed, as the original
    /// tests already did.
    fn block_request(reason: &str) -> MitigationAction {
        MitigationAction::BlockRequest {
            reason: reason.to_string(),
        }
    }

    /// Pushes `count` block-request actions with ids `action-0 .. action-{count-1}`.
    async fn push_many(manager: &RollbackManager, count: usize) {
        for i in 0..count {
            manager
                .push_action(block_request(&format!("Test {}", i)), format!("action-{}", i))
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_rollback_manager_creation() {
        let manager = RollbackManager::new();
        assert_eq!(manager.stack_size().await, 0);
        assert!(manager.history().await.is_empty());
    }

    #[tokio::test]
    async fn test_push_action() {
        let manager = RollbackManager::new();

        manager.push_action(block_request("Test"), "action-1".to_string()).await.unwrap();
        assert_eq!(manager.stack_size().await, 1);
    }

    #[tokio::test]
    async fn test_rollback_last() {
        let manager = RollbackManager::new();

        let action = MitigationAction::RateLimitUser {
            duration: Duration::from_secs(60),
        };

        manager.push_action(action, "action-1".to_string()).await.unwrap();
        assert_eq!(manager.stack_size().await, 1);

        let result = manager.rollback_last().await;
        assert!(result.is_ok());
        assert_eq!(manager.stack_size().await, 0);
    }

    #[tokio::test]
    async fn test_rollback_last_on_empty_stack_fails() {
        let manager = RollbackManager::new();

        assert!(manager.rollback_last().await.is_err());
        // A missing action is not a rollback attempt, so nothing is logged.
        assert!(manager.history().await.is_empty());
    }

    #[tokio::test]
    async fn test_rollback_specific_action() {
        let manager = RollbackManager::new();

        manager.push_action(block_request("Test 1"), "action-1".to_string()).await.unwrap();
        manager.push_action(block_request("Test 2"), "action-2".to_string()).await.unwrap();
        assert_eq!(manager.stack_size().await, 2);

        manager.rollback_action("action-1").await.unwrap();
        assert_eq!(manager.stack_size().await, 1);

        // The rolled-back entry is gone, so a second attempt cannot find it.
        assert!(manager.rollback_action("action-1").await.is_err());

        // The untouched entry is what remains on the stack.
        manager.rollback_last().await.unwrap();
        let history = manager.history().await;
        let ids: Vec<&str> = history.iter().map(|r| r.action_id.as_str()).collect();
        assert_eq!(ids, ["action-1", "action-2"]);
    }

    #[tokio::test]
    async fn test_rollback_unknown_action_fails() {
        let manager = RollbackManager::new();
        manager.push_action(block_request("Test"), "action-1".to_string()).await.unwrap();

        assert!(manager.rollback_action("missing").await.is_err());
        assert_eq!(manager.stack_size().await, 1);
        assert!(manager.history().await.is_empty());
    }

    #[tokio::test]
    async fn test_rollback_all() {
        let manager = RollbackManager::new();
        push_many(&manager, 5).await;
        assert_eq!(manager.stack_size().await, 5);

        let rolled_back = manager.rollback_all().await.unwrap();
        assert_eq!(manager.stack_size().await, 0);

        // Newest first.
        let expected: Vec<String> = (0..5).rev().map(|i| format!("action-{}", i)).collect();
        assert_eq!(rolled_back, expected);

        let history = manager.history().await;
        assert_eq!(history.len(), 5);
        assert!(history.iter().all(|r| r.success && r.error.is_none()));
    }

    #[tokio::test]
    async fn test_max_stack_size() {
        let manager = RollbackManager::with_max_size(3);
        push_many(&manager, 5).await;

        assert_eq!(manager.stack_size().await, 3);

        // The two oldest entries were evicted; the newer ones remain.
        assert!(manager.rollback_action("action-0").await.is_err());
        assert!(manager.rollback_action("action-1").await.is_err());
        assert!(manager.rollback_action("action-2").await.is_ok());
    }

    #[tokio::test]
    async fn test_clear_stack() {
        let manager = RollbackManager::new();
        push_many(&manager, 3).await;

        manager.clear_stack().await;
        assert_eq!(manager.stack_size().await, 0);
        assert!(manager.history().await.is_empty());
    }

    #[tokio::test]
    async fn test_rollback_history() {
        let manager = RollbackManager::new();

        manager.push_action(block_request("Test"), "action-1".to_string()).await.unwrap();
        manager.rollback_last().await.unwrap();

        let history = manager.history().await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].action_id, "action-1");
        assert!(history[0].success);
        assert!(history[0].error.is_none());
    }
}