use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::RwLock;
use serde::{Deserialize, Serialize};
use crate::{MitigationAction, Result, ResponseError};
/// Keeps track of applied mitigation actions so they can be rolled back later,
/// together with a log of every rollback attempt.
///
/// Both collections live behind `Arc<RwLock<..>>` (tokio's async `RwLock`).
pub struct RollbackManager {
/// Actions that can still be rolled back. New entries are pushed at the end;
/// `rollback_last` pops from the end, and `push_action` evicts from the front
/// once the size limit is reached.
action_stack: Arc<RwLock<Vec<RollbackEntry>>>,
/// One `RollbackRecord` per rollback attempt, appended in the order the
/// attempts were made.
history: Arc<RwLock<Vec<RollbackRecord>>>,
/// Upper bound on the number of entries kept in `action_stack`.
max_stack_size: usize,
}
impl RollbackManager {
    /// Stack capacity used by [`RollbackManager::new`].
    const DEFAULT_MAX_STACK_SIZE: usize = 1000;

    /// Creates a manager with an empty stack, an empty history and the
    /// default capacity of 1000 entries.
    pub fn new() -> Self {
        Self::with_max_size(Self::DEFAULT_MAX_STACK_SIZE)
    }

    /// Creates a manager with an empty stack and history whose stack holds at
    /// most `max_size` entries.
    ///
    /// A `max_size` of zero yields a manager that never retains any action.
    pub fn with_max_size(max_size: usize) -> Self {
        Self {
            action_stack: Arc::new(RwLock::new(Vec::new())),
            history: Arc::new(RwLock::new(Vec::new())),
            max_stack_size: max_size,
        }
    }

    /// Registers `action` under `action_id` as the newest entry on the
    /// rollback stack.
    ///
    /// When the stack is already at capacity, the oldest entries are evicted
    /// to make room. With a capacity of zero nothing is stored and the call
    /// still returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    pub async fn push_action(&self, action: MitigationAction, action_id: String) -> Result<()> {
        // A zero-capacity stack cannot hold anything. The previous code tried
        // to evict from an empty Vec here (`remove(0)`) and panicked.
        if self.max_stack_size == 0 {
            tracing::warn!(
                "Rollback stack capacity is zero; action {} was not recorded",
                action_id
            );
            return Ok(());
        }

        let mut stack = self.action_stack.write().await;
        if stack.len() >= self.max_stack_size {
            // Free exactly enough room for the new entry; oldest entries sit
            // at the front of the Vec.
            let excess = stack.len() + 1 - self.max_stack_size;
            stack.drain(..excess);
        }

        stack.push(RollbackEntry {
            action,
            action_id,
            timestamp: chrono::Utc::now(),
            context: HashMap::new(),
        });
        Ok(())
    }

    /// Rolls back the most recently pushed action and records the attempt in
    /// the history.
    ///
    /// The entry is removed from the stack whether or not its rollback
    /// succeeds.
    ///
    /// # Errors
    ///
    /// Returns `ResponseError::RollbackFailed` when the stack is empty, or the
    /// error produced by the action's own rollback.
    pub async fn rollback_last(&self) -> Result<()> {
        let mut stack = self.action_stack.write().await;
        let entry = stack.pop().ok_or_else(|| {
            ResponseError::RollbackFailed("No actions to rollback".to_string())
        })?;
        self.rollback_and_record(&entry).await
    }

    /// Rolls back the first stack entry whose id equals `action_id` and
    /// records the attempt in the history.
    ///
    /// The entry is removed from the stack whether or not its rollback
    /// succeeds.
    ///
    /// # Errors
    ///
    /// Returns `ResponseError::RollbackFailed` when no entry has that id, or
    /// the error produced by the action's own rollback.
    pub async fn rollback_action(&self, action_id: &str) -> Result<()> {
        let mut stack = self.action_stack.write().await;
        let pos = stack
            .iter()
            .position(|e| e.action_id == action_id)
            .ok_or_else(|| {
                ResponseError::RollbackFailed(format!("Action {} not found", action_id))
            })?;
        let entry = stack.remove(pos);
        self.rollback_and_record(&entry).await
    }

    /// Rolls back every action on the stack, newest first, recording each
    /// attempt in the history. A failing rollback does not stop the loop.
    ///
    /// Returns the ids that were rolled back successfully, in the order they
    /// were processed.
    ///
    /// # Errors
    ///
    /// If at least one rollback failed, returns
    /// `ResponseError::RollbackFailed` containing all failure messages joined
    /// by `"; "`. The stack is empty afterwards in either case.
    pub async fn rollback_all(&self) -> Result<Vec<String>> {
        let mut stack = self.action_stack.write().await;
        let mut rolled_back = Vec::new();
        let mut errors = Vec::new();

        while let Some(entry) = stack.pop() {
            // Each history record must reflect THIS entry's outcome only.
            // Previously success/error were derived from the accumulated
            // `errors` list, so every entry processed after the first failure
            // was logged as failed with a stale error message.
            let failure = match self.execute_rollback(&entry).await {
                Ok(()) => None,
                Err(e) => Some(format!("Failed to rollback {}: {}", entry.action_id, e)),
            };
            self.record_outcome(&entry.action_id, failure.clone()).await;

            match failure {
                None => rolled_back.push(entry.action_id),
                Some(message) => errors.push(message),
            }
        }

        if errors.is_empty() {
            Ok(rolled_back)
        } else {
            Err(ResponseError::RollbackFailed(errors.join("; ")))
        }
    }

    /// Returns a snapshot (clone) of all rollback records collected so far.
    pub async fn history(&self) -> Vec<RollbackRecord> {
        self.history.read().await.clone()
    }

    /// Returns the number of actions currently on the rollback stack.
    pub async fn stack_size(&self) -> usize {
        self.action_stack.read().await.len()
    }

    /// Discards all pending actions without rolling them back. The history is
    /// left untouched.
    pub async fn clear_stack(&self) {
        self.action_stack.write().await.clear();
    }

    /// Executes the rollback for `entry`, appends the matching history record
    /// and hands the rollback result back to the caller.
    async fn rollback_and_record(&self, entry: &RollbackEntry) -> Result<()> {
        let result = self.execute_rollback(entry).await;
        let error = result.as_ref().err().map(|e| e.to_string());
        self.record_outcome(&entry.action_id, error).await;
        result
    }

    /// Appends one record to the history; the attempt counts as successful
    /// exactly when `error` is `None`.
    async fn record_outcome(&self, action_id: &str, error: Option<String>) {
        let mut history = self.history.write().await;
        history.push(RollbackRecord {
            action_id: action_id.to_string(),
            success: error.is_none(),
            timestamp: chrono::Utc::now(),
            error,
        });
    }

    /// Invokes the action's `rollback`, logging the attempt and bumping the
    /// `rollback.success` / `rollback.failure` counter according to the result.
    async fn execute_rollback(&self, entry: &RollbackEntry) -> Result<()> {
        tracing::info!("Rolling back action: {}", entry.action_id);
        match entry.action.rollback(&entry.action_id) {
            Ok(_) => {
                metrics::counter!("rollback.success").increment(1);
                Ok(())
            }
            Err(e) => {
                metrics::counter!("rollback.failure").increment(1);
                Err(e)
            }
        }
    }
}
impl Default for RollbackManager {
fn default() -> Self {
Self::new()
}
}
/// A single action on the rollback stack, together with the metadata captured
/// when it was pushed.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct RollbackEntry {
/// The mitigation action whose `rollback` is invoked to undo it.
action: MitigationAction,
/// Caller-supplied identifier; used to look the entry up and passed to the
/// action's `rollback`.
action_id: String,
/// UTC time at which the entry was pushed.
timestamp: chrono::DateTime<chrono::Utc>,
/// Free-form key/value metadata. `push_action` always creates it empty and
/// nothing in the visible code reads or fills it.
context: HashMap<String, String>,
}
/// History entry describing the outcome of one rollback attempt.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RollbackRecord {
/// Identifier of the action the attempt was made for.
pub action_id: String,
/// Whether the attempt was recorded as successful.
pub success: bool,
/// UTC time at which this record was created.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Failure description, if one was recorded for this attempt.
pub error: Option<String>,
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::MitigationAction;
    use std::time::Duration;

    /// Builds a `BlockRequest` action carrying the given reason.
    fn block_request(reason: String) -> MitigationAction {
        MitigationAction::BlockRequest { reason }
    }

    /// Pushes `count` block-request actions with ids `action-0`, `action-1`, ...
    async fn push_numbered(manager: &RollbackManager, count: usize) {
        for i in 0..count {
            manager
                .push_action(block_request(format!("Test {}", i)), format!("action-{}", i))
                .await
                .unwrap();
        }
    }

    // A freshly created manager has nothing to roll back.
    #[tokio::test]
    async fn test_rollback_manager_creation() {
        let mgr = RollbackManager::new();
        assert_eq!(mgr.stack_size().await, 0);
    }

    // Pushing one action grows the stack by one.
    #[tokio::test]
    async fn test_push_action() {
        let mgr = RollbackManager::new();
        mgr.push_action(block_request("Test".to_string()), "action-1".to_string())
            .await
            .unwrap();
        assert_eq!(mgr.stack_size().await, 1);
    }

    // Rolling back the last action succeeds and empties the stack.
    #[tokio::test]
    async fn test_rollback_last() {
        let mgr = RollbackManager::new();
        let rate_limit = MitigationAction::RateLimitUser {
            duration: Duration::from_secs(60),
        };
        mgr.push_action(rate_limit, "action-1".to_string())
            .await
            .unwrap();
        assert_eq!(mgr.stack_size().await, 1);

        let outcome = mgr.rollback_last().await;
        assert!(outcome.is_ok());
        assert_eq!(mgr.stack_size().await, 0);
    }

    // Rolling back by id removes only the targeted entry.
    #[tokio::test]
    async fn test_rollback_specific_action() {
        let mgr = RollbackManager::new();
        let first = block_request("Test 1".to_string());
        let second = block_request("Test 2".to_string());
        mgr.push_action(first, "action-1".to_string()).await.unwrap();
        mgr.push_action(second, "action-2".to_string()).await.unwrap();
        assert_eq!(mgr.stack_size().await, 2);

        mgr.rollback_action("action-1").await.unwrap();
        assert_eq!(mgr.stack_size().await, 1);
    }

    // Rolling back everything succeeds and leaves an empty stack.
    #[tokio::test]
    async fn test_rollback_all() {
        let mgr = RollbackManager::new();
        push_numbered(&mgr, 5).await;
        assert_eq!(mgr.stack_size().await, 5);

        let outcome = mgr.rollback_all().await;
        assert!(outcome.is_ok());
        assert_eq!(mgr.stack_size().await, 0);
    }

    // The stack never grows beyond the configured maximum.
    #[tokio::test]
    async fn test_max_stack_size() {
        let mgr = RollbackManager::with_max_size(3);
        push_numbered(&mgr, 5).await;
        assert_eq!(mgr.stack_size().await, 3);
    }

    // A successful rollback leaves exactly one successful history record.
    #[tokio::test]
    async fn test_rollback_history() {
        let mgr = RollbackManager::new();
        mgr.push_action(block_request("Test".to_string()), "action-1".to_string())
            .await
            .unwrap();
        mgr.rollback_last().await.unwrap();

        let records = mgr.history().await;
        assert_eq!(records.len(), 1);
        assert!(records[0].success);
    }
}