aimds_response/
rollback.rs

//! Rollback manager for safely reversing mitigation actions.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use serde::{Deserialize, Serialize};
7use crate::{MitigationAction, Result, ResponseError};
8
9/// Manages rollback of mitigation actions
10pub struct RollbackManager {
11    /// Stack of reversible actions
12    action_stack: Arc<RwLock<Vec<RollbackEntry>>>,
13
14    /// Rollback history
15    history: Arc<RwLock<Vec<RollbackRecord>>>,
16
17    /// Maximum stack size
18    max_stack_size: usize,
19}
20
21impl RollbackManager {
22    /// Create new rollback manager
23    pub fn new() -> Self {
24        Self {
25            action_stack: Arc::new(RwLock::new(Vec::new())),
26            history: Arc::new(RwLock::new(Vec::new())),
27            max_stack_size: 1000,
28        }
29    }
30
31    /// Create with custom max stack size
32    pub fn with_max_size(max_size: usize) -> Self {
33        Self {
34            action_stack: Arc::new(RwLock::new(Vec::new())),
35            history: Arc::new(RwLock::new(Vec::new())),
36            max_stack_size: max_size,
37        }
38    }
39
40    /// Push action onto rollback stack
41    pub async fn push_action(&self, action: MitigationAction, action_id: String) -> Result<()> {
42        let mut stack = self.action_stack.write().await;
43
44        // Check stack size limit
45        if stack.len() >= self.max_stack_size {
46            // Remove oldest entry
47            stack.remove(0);
48        }
49
50        let entry = RollbackEntry {
51            action,
52            action_id,
53            timestamp: chrono::Utc::now(),
54            context: HashMap::new(),
55        };
56
57        stack.push(entry);
58        Ok(())
59    }
60
61    /// Rollback the last action
62    pub async fn rollback_last(&self) -> Result<()> {
63        let mut stack = self.action_stack.write().await;
64
65        if let Some(entry) = stack.pop() {
66            let result = self.execute_rollback(&entry).await;
67
68            // Record rollback attempt
69            let mut history = self.history.write().await;
70            history.push(RollbackRecord {
71                action_id: entry.action_id.clone(),
72                success: result.is_ok(),
73                timestamp: chrono::Utc::now(),
74                error: result.as_ref().err().map(|e| e.to_string()),
75            });
76
77            result
78        } else {
79            Err(ResponseError::RollbackFailed("No actions to rollback".to_string()))
80        }
81    }
82
83    /// Rollback specific action by ID
84    pub async fn rollback_action(&self, action_id: &str) -> Result<()> {
85        let mut stack = self.action_stack.write().await;
86
87        // Find and remove action from stack
88        if let Some(pos) = stack.iter().position(|e| e.action_id == action_id) {
89            let entry = stack.remove(pos);
90            let result = self.execute_rollback(&entry).await;
91
92            // Record rollback attempt
93            let mut history = self.history.write().await;
94            history.push(RollbackRecord {
95                action_id: entry.action_id.clone(),
96                success: result.is_ok(),
97                timestamp: chrono::Utc::now(),
98                error: result.as_ref().err().map(|e| e.to_string()),
99            });
100
101            result
102        } else {
103            Err(ResponseError::RollbackFailed(
104                format!("Action {} not found", action_id)
105            ))
106        }
107    }
108
109    /// Rollback all actions
110    pub async fn rollback_all(&self) -> Result<Vec<String>> {
111        let mut stack = self.action_stack.write().await;
112        let mut rolled_back = Vec::new();
113        let mut errors = Vec::new();
114
115        while let Some(entry) = stack.pop() {
116            match self.execute_rollback(&entry).await {
117                Ok(_) => {
118                    rolled_back.push(entry.action_id.clone());
119                }
120                Err(e) => {
121                    errors.push(format!("Failed to rollback {}: {}", entry.action_id, e));
122                }
123            }
124
125            // Record rollback attempt
126            let mut history = self.history.write().await;
127            history.push(RollbackRecord {
128                action_id: entry.action_id.clone(),
129                success: errors.is_empty(),
130                timestamp: chrono::Utc::now(),
131                error: errors.last().cloned(),
132            });
133        }
134
135        if errors.is_empty() {
136            Ok(rolled_back)
137        } else {
138            Err(ResponseError::RollbackFailed(errors.join("; ")))
139        }
140    }
141
142    /// Get rollback history
143    pub async fn history(&self) -> Vec<RollbackRecord> {
144        self.history.read().await.clone()
145    }
146
147    /// Get current stack size
148    pub async fn stack_size(&self) -> usize {
149        self.action_stack.read().await.len()
150    }
151
152    /// Clear rollback stack (use with caution)
153    pub async fn clear_stack(&self) {
154        let mut stack = self.action_stack.write().await;
155        stack.clear();
156    }
157
158    /// Execute rollback for entry
159    async fn execute_rollback(&self, entry: &RollbackEntry) -> Result<()> {
160        tracing::info!("Rolling back action: {}", entry.action_id);
161
162        match entry.action.rollback(&entry.action_id) {
163            Ok(_) => {
164                metrics::counter!("rollback.success").increment(1);
165                Ok(())
166            }
167            Err(e) => {
168                metrics::counter!("rollback.failure").increment(1);
169                Err(e)
170            }
171        }
172    }
173}
174
175impl Default for RollbackManager {
176    fn default() -> Self {
177        Self::new()
178    }
179}
180
181/// Entry in rollback stack
182#[derive(Debug, Clone, Serialize, Deserialize)]
183struct RollbackEntry {
184    action: MitigationAction,
185    action_id: String,
186    timestamp: chrono::DateTime<chrono::Utc>,
187    context: HashMap<String, String>,
188}
189
190/// Record of rollback attempt
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct RollbackRecord {
193    pub action_id: String,
194    pub success: bool,
195    pub timestamp: chrono::DateTime<chrono::Utc>,
196    pub error: Option<String>,
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MitigationAction;
    use std::time::Duration;

    /// Shorthand for a `BlockRequest` action carrying `reason`.
    fn block(reason: impl Into<String>) -> MitigationAction {
        MitigationAction::BlockRequest {
            reason: reason.into(),
        }
    }

    /// Push five `BlockRequest` actions with ids `action-0` ..= `action-4`.
    async fn push_five_blocks(manager: &RollbackManager) {
        for n in 0..5 {
            let id = format!("action-{}", n);
            manager
                .push_action(block(format!("Test {}", n)), id)
                .await
                .unwrap();
        }
    }

    #[tokio::test]
    async fn test_rollback_manager_creation() {
        let fresh = RollbackManager::new();
        assert_eq!(fresh.stack_size().await, 0);
    }

    #[tokio::test]
    async fn test_push_action() {
        let mgr = RollbackManager::new();

        mgr.push_action(block("Test"), "action-1".to_string())
            .await
            .unwrap();
        assert_eq!(mgr.stack_size().await, 1);
    }

    #[tokio::test]
    async fn test_rollback_last() {
        let mgr = RollbackManager::new();
        let limit = MitigationAction::RateLimitUser {
            duration: Duration::from_secs(60),
        };

        mgr.push_action(limit, "action-1".to_string()).await.unwrap();
        assert_eq!(mgr.stack_size().await, 1);

        assert!(mgr.rollback_last().await.is_ok());
        assert_eq!(mgr.stack_size().await, 0);
    }

    #[tokio::test]
    async fn test_rollback_specific_action() {
        let mgr = RollbackManager::new();

        for (reason, id) in [("Test 1", "action-1"), ("Test 2", "action-2")] {
            mgr.push_action(block(reason), id.to_string()).await.unwrap();
        }
        assert_eq!(mgr.stack_size().await, 2);

        mgr.rollback_action("action-1").await.unwrap();
        assert_eq!(mgr.stack_size().await, 1);
    }

    #[tokio::test]
    async fn test_rollback_all() {
        let mgr = RollbackManager::new();

        push_five_blocks(&mgr).await;
        assert_eq!(mgr.stack_size().await, 5);

        assert!(mgr.rollback_all().await.is_ok());
        assert_eq!(mgr.stack_size().await, 0);
    }

    #[tokio::test]
    async fn test_max_stack_size() {
        let mgr = RollbackManager::with_max_size(3);

        push_five_blocks(&mgr).await;

        // Only the three most recent entries survive eviction.
        assert_eq!(mgr.stack_size().await, 3);
    }

    #[tokio::test]
    async fn test_rollback_history() {
        let mgr = RollbackManager::new();

        mgr.push_action(block("Test"), "action-1".to_string())
            .await
            .unwrap();
        mgr.rollback_last().await.unwrap();

        let log = mgr.history().await;
        assert_eq!(log.len(), 1);
        assert!(log[0].success);
    }
}
307}