use crate::engine::WorkflowDefinition;
use crate::error::{Result, WorkflowError};
use chrono::{DateTime, Utc};
use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::Arc;
/// A saved snapshot of a workflow definition that can later be restored.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RollbackPoint {
    /// Unique id of this point; `RollbackManager::create_rollback_point` uses a UUID v4 string.
    pub id: String,
    /// Id of the workflow this snapshot belongs to.
    pub workflow_id: String,
    /// The workflow definition as it was when the snapshot was taken.
    pub definition: WorkflowDefinition,
    /// UTC time the snapshot was created; used to order a workflow's points.
    pub created_at: DateTime<Utc>,
    /// Optional free-form note (set via `RollbackManager::update_description`).
    pub description: Option<String>,
    /// Labels attached via `RollbackManager::add_tag`, searchable with `search_by_tag`.
    pub tags: Vec<String>,
}
/// DashMap-backed store of [`RollbackPoint`]s, keyed by rollback id.
///
/// Points are grouped by `workflow_id`; each workflow is capped at
/// `max_rollback_points` entries, with excess points evicted when a new one is created.
#[derive(Debug)]
pub struct RollbackManager {
    /// Every stored rollback point, keyed by its id.
    rollback_points: Arc<DashMap<String, RollbackPoint>>,
    /// Per-workflow cap on the number of stored rollback points.
    max_rollback_points: usize,
}
impl RollbackManager {
pub fn new() -> Self {
Self {
rollback_points: Arc::new(DashMap::new()),
max_rollback_points: 100,
}
}
pub fn with_max_points(max_points: usize) -> Self {
Self {
rollback_points: Arc::new(DashMap::new()),
max_rollback_points: max_points,
}
}
pub fn create_rollback_point(
&self,
workflow_id: String,
definition: WorkflowDefinition,
) -> Result<String> {
let id = uuid::Uuid::new_v4().to_string();
let rollback_point = RollbackPoint {
id: id.clone(),
workflow_id: workflow_id.clone(),
definition,
created_at: Utc::now(),
description: None,
tags: Vec::new(),
};
let workflow_points: Vec<String> = self
.rollback_points
.iter()
.filter(|entry| entry.value().workflow_id == workflow_id)
.map(|entry| entry.key().clone())
.collect();
if workflow_points.len() >= self.max_rollback_points {
if let Some(oldest) = workflow_points.first() {
self.rollback_points.remove(oldest);
}
}
self.rollback_points.insert(id.clone(), rollback_point);
Ok(id)
}
pub fn rollback(&self, rollback_id: &str) -> Result<WorkflowDefinition> {
let rollback_point = self
.rollback_points
.get(rollback_id)
.ok_or_else(|| WorkflowError::not_found(rollback_id))?;
Ok(rollback_point.definition.clone())
}
pub fn get_rollback_point(&self, rollback_id: &str) -> Option<RollbackPoint> {
self.rollback_points
.get(rollback_id)
.map(|entry| entry.clone())
}
pub fn list_rollback_points(&self, workflow_id: &str) -> Vec<RollbackPoint> {
let mut points: Vec<RollbackPoint> = self
.rollback_points
.iter()
.filter(|entry| entry.value().workflow_id == workflow_id)
.map(|entry| entry.value().clone())
.collect();
points.sort_by_key(|x| std::cmp::Reverse(x.created_at));
points
}
pub fn delete_rollback_point(&self, rollback_id: &str) -> Option<RollbackPoint> {
self.rollback_points
.remove(rollback_id)
.map(|(_, point)| point)
}
pub fn delete_workflow_rollback_points(&self, workflow_id: &str) -> usize {
let points_to_delete: Vec<String> = self
.rollback_points
.iter()
.filter(|entry| entry.value().workflow_id == workflow_id)
.map(|entry| entry.key().clone())
.collect();
let count = points_to_delete.len();
for id in points_to_delete {
self.rollback_points.remove(&id);
}
count
}
pub fn get_latest_rollback_point(&self, workflow_id: &str) -> Option<RollbackPoint> {
self.list_rollback_points(workflow_id).into_iter().next()
}
pub fn clear_all(&self) {
self.rollback_points.clear();
}
pub fn count(&self) -> usize {
self.rollback_points.len()
}
pub fn update_description(&self, rollback_id: &str, description: String) -> Result<()> {
let mut point = self
.rollback_points
.get_mut(rollback_id)
.ok_or_else(|| WorkflowError::not_found(rollback_id))?;
point.description = Some(description);
Ok(())
}
pub fn add_tag(&self, rollback_id: &str, tag: String) -> Result<()> {
let mut point = self
.rollback_points
.get_mut(rollback_id)
.ok_or_else(|| WorkflowError::not_found(rollback_id))?;
if !point.tags.contains(&tag) {
point.tags.push(tag);
}
Ok(())
}
pub fn search_by_tag(&self, tag: &str) -> Vec<RollbackPoint> {
self.rollback_points
.iter()
.filter(|entry| entry.value().tags.contains(&tag.to_string()))
.map(|entry| entry.value().clone())
.collect()
}
}
impl Default for RollbackManager {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::WorkflowDag;

    /// Builds a minimal workflow definition with the given display name and version.
    fn make_definition(name: &str, version: &str) -> WorkflowDefinition {
        WorkflowDefinition {
            id: "test".to_string(),
            name: name.to_string(),
            description: None,
            version: version.to_string(),
            dag: WorkflowDag::new(),
        }
    }

    /// Creates `count` points for `workflow_id` with versions `1.0.0`, `1.0.1`, ...
    fn create_versions(manager: &RollbackManager, workflow_id: &str, count: usize) {
        for i in 0..count {
            manager
                .create_rollback_point(
                    workflow_id.to_string(),
                    make_definition(&format!("Test {}", i), &format!("1.0.{}", i)),
                )
                .expect("Failed to create");
        }
    }

    #[test]
    fn test_rollback_manager_creation() {
        let manager = RollbackManager::new();
        assert_eq!(manager.count(), 0);
    }

    #[test]
    fn test_create_rollback_point() {
        let manager = RollbackManager::new();
        let rollback_id = manager
            .create_rollback_point("test-workflow".to_string(), make_definition("Test", "1.0.0"))
            .expect("Failed to create rollback point");
        let point = manager
            .get_rollback_point(&rollback_id)
            .expect("Rollback point not found");
        assert_eq!(point.id, rollback_id);
        assert_eq!(point.workflow_id, "test-workflow");
        assert_eq!(manager.count(), 1);
    }

    #[test]
    fn test_rollback() {
        let manager = RollbackManager::new();
        let rollback_id = manager
            .create_rollback_point("test-workflow".to_string(), make_definition("Test", "1.0.0"))
            .expect("Failed to create");
        let restored = manager.rollback(&rollback_id).expect("Failed to rollback");
        assert_eq!(restored.id, "test");
        assert_eq!(restored.version, "1.0.0");
    }

    #[test]
    fn test_rollback_unknown_id_is_error() {
        let manager = RollbackManager::new();
        assert!(manager.rollback("missing").is_err());
    }

    #[test]
    fn test_list_rollback_points() {
        let manager = RollbackManager::new();
        create_versions(&manager, "test-workflow", 3);
        create_versions(&manager, "other-workflow", 1);
        let points = manager.list_rollback_points("test-workflow");
        assert_eq!(points.len(), 3);
        // Newest first.
        assert!(points
            .windows(2)
            .all(|pair| pair[0].created_at >= pair[1].created_at));
    }

    #[test]
    fn test_delete_rollback_point() {
        let manager = RollbackManager::new();
        let rollback_id = manager
            .create_rollback_point("test".to_string(), make_definition("Test", "1.0.0"))
            .expect("Failed to create");
        assert!(manager.delete_rollback_point(&rollback_id).is_some());
        assert!(manager.get_rollback_point(&rollback_id).is_none());
        assert!(manager.delete_rollback_point(&rollback_id).is_none());
    }

    #[test]
    fn test_delete_workflow_rollback_points() {
        let manager = RollbackManager::new();
        create_versions(&manager, "a", 2);
        create_versions(&manager, "b", 1);
        assert_eq!(manager.delete_workflow_rollback_points("a"), 2);
        assert!(manager.list_rollback_points("a").is_empty());
        assert_eq!(manager.count(), 1);
    }

    #[test]
    fn test_max_rollback_points() {
        let manager = RollbackManager::with_max_points(3);
        create_versions(&manager, "test", 5);
        let points = manager.list_rollback_points("test");
        assert_eq!(points.len(), 3);
        // The point created last is never evicted.
        assert!(points.iter().any(|p| p.definition.version == "1.0.4"));
    }

    #[test]
    fn test_max_rollback_points_is_per_workflow() {
        let manager = RollbackManager::with_max_points(2);
        create_versions(&manager, "a", 3);
        create_versions(&manager, "b", 1);
        assert_eq!(manager.list_rollback_points("a").len(), 2);
        assert_eq!(manager.list_rollback_points("b").len(), 1);
    }

    #[test]
    fn test_get_latest_rollback_point() {
        let manager = RollbackManager::new();
        assert!(manager.get_latest_rollback_point("test").is_none());
        let id = manager
            .create_rollback_point("test".to_string(), make_definition("Test", "1.0.0"))
            .expect("Failed to create");
        assert_eq!(
            manager.get_latest_rollback_point("test").map(|p| p.id),
            Some(id)
        );
    }

    #[test]
    fn test_update_description() {
        let manager = RollbackManager::new();
        let id = manager
            .create_rollback_point("test".to_string(), make_definition("Test", "1.0.0"))
            .expect("Failed to create");
        manager
            .update_description(&id, "Test description".to_string())
            .expect("Failed to update");
        let point = manager.get_rollback_point(&id).expect("Not found");
        assert_eq!(point.description, Some("Test description".to_string()));
        assert!(manager
            .update_description("missing", "x".to_string())
            .is_err());
    }

    #[test]
    fn test_search_by_tag() {
        let manager = RollbackManager::new();
        let id = manager
            .create_rollback_point("test".to_string(), make_definition("Test", "1.0.0"))
            .expect("Failed to create");
        manager
            .add_tag(&id, "production".to_string())
            .expect("Failed to add tag");
        // Adding the same tag again must not duplicate it.
        manager
            .add_tag(&id, "production".to_string())
            .expect("Failed to add tag");
        let tagged = manager.search_by_tag("production");
        assert_eq!(tagged.len(), 1);
        assert_eq!(tagged[0].tags.len(), 1);
        assert!(manager.search_by_tag("staging").is_empty());
    }

    #[test]
    fn test_clear_all() {
        let manager = RollbackManager::new();
        create_versions(&manager, "a", 2);
        manager.clear_all();
        assert_eq!(manager.count(), 0);
    }
}