Skip to main content

oxigdal_workflow/versioning/
rollback.rs

1//! Workflow rollback utilities.
2
3use crate::engine::WorkflowDefinition;
4use crate::error::{Result, WorkflowError};
5use chrono::{DateTime, Utc};
6use dashmap::DashMap;
7use serde::{Deserialize, Serialize};
8use std::sync::Arc;
9
10/// Rollback point for a workflow.
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct RollbackPoint {
13    /// Rollback point ID.
14    pub id: String,
15    /// Workflow ID.
16    pub workflow_id: String,
17    /// Workflow definition at this point.
18    pub definition: WorkflowDefinition,
19    /// Creation timestamp.
20    pub created_at: DateTime<Utc>,
21    /// Description.
22    pub description: Option<String>,
23    /// Tags.
24    pub tags: Vec<String>,
25}
26
27/// Rollback manager for workflow versions.
28pub struct RollbackManager {
29    rollback_points: Arc<DashMap<String, RollbackPoint>>,
30    max_rollback_points: usize,
31}
32
33impl RollbackManager {
34    /// Create a new rollback manager.
35    pub fn new() -> Self {
36        Self {
37            rollback_points: Arc::new(DashMap::new()),
38            max_rollback_points: 100,
39        }
40    }
41
42    /// Create a new rollback manager with custom limits.
43    pub fn with_max_points(max_points: usize) -> Self {
44        Self {
45            rollback_points: Arc::new(DashMap::new()),
46            max_rollback_points: max_points,
47        }
48    }
49
50    /// Create a rollback point.
51    pub fn create_rollback_point(
52        &self,
53        workflow_id: String,
54        definition: WorkflowDefinition,
55    ) -> Result<String> {
56        let id = uuid::Uuid::new_v4().to_string();
57
58        let rollback_point = RollbackPoint {
59            id: id.clone(),
60            workflow_id: workflow_id.clone(),
61            definition,
62            created_at: Utc::now(),
63            description: None,
64            tags: Vec::new(),
65        };
66
67        // Check if we've exceeded max rollback points for this workflow
68        let workflow_points: Vec<String> = self
69            .rollback_points
70            .iter()
71            .filter(|entry| entry.value().workflow_id == workflow_id)
72            .map(|entry| entry.key().clone())
73            .collect();
74
75        if workflow_points.len() >= self.max_rollback_points {
76            // Remove the oldest rollback point
77            if let Some(oldest) = workflow_points.first() {
78                self.rollback_points.remove(oldest);
79            }
80        }
81
82        self.rollback_points.insert(id.clone(), rollback_point);
83
84        Ok(id)
85    }
86
87    /// Rollback to a specific rollback point.
88    pub fn rollback(&self, rollback_id: &str) -> Result<WorkflowDefinition> {
89        let rollback_point = self
90            .rollback_points
91            .get(rollback_id)
92            .ok_or_else(|| WorkflowError::not_found(rollback_id))?;
93
94        Ok(rollback_point.definition.clone())
95    }
96
97    /// Get a rollback point.
98    pub fn get_rollback_point(&self, rollback_id: &str) -> Option<RollbackPoint> {
99        self.rollback_points
100            .get(rollback_id)
101            .map(|entry| entry.clone())
102    }
103
104    /// List all rollback points for a workflow.
105    pub fn list_rollback_points(&self, workflow_id: &str) -> Vec<RollbackPoint> {
106        let mut points: Vec<RollbackPoint> = self
107            .rollback_points
108            .iter()
109            .filter(|entry| entry.value().workflow_id == workflow_id)
110            .map(|entry| entry.value().clone())
111            .collect();
112
113        points.sort_by_key(|x| std::cmp::Reverse(x.created_at));
114
115        points
116    }
117
118    /// Delete a rollback point.
119    pub fn delete_rollback_point(&self, rollback_id: &str) -> Option<RollbackPoint> {
120        self.rollback_points
121            .remove(rollback_id)
122            .map(|(_, point)| point)
123    }
124
125    /// Delete all rollback points for a workflow.
126    pub fn delete_workflow_rollback_points(&self, workflow_id: &str) -> usize {
127        let points_to_delete: Vec<String> = self
128            .rollback_points
129            .iter()
130            .filter(|entry| entry.value().workflow_id == workflow_id)
131            .map(|entry| entry.key().clone())
132            .collect();
133
134        let count = points_to_delete.len();
135
136        for id in points_to_delete {
137            self.rollback_points.remove(&id);
138        }
139
140        count
141    }
142
143    /// Get the latest rollback point for a workflow.
144    pub fn get_latest_rollback_point(&self, workflow_id: &str) -> Option<RollbackPoint> {
145        self.list_rollback_points(workflow_id).into_iter().next()
146    }
147
148    /// Clear all rollback points.
149    pub fn clear_all(&self) {
150        self.rollback_points.clear();
151    }
152
153    /// Get total count of rollback points.
154    pub fn count(&self) -> usize {
155        self.rollback_points.len()
156    }
157
158    /// Update rollback point description.
159    pub fn update_description(&self, rollback_id: &str, description: String) -> Result<()> {
160        let mut point = self
161            .rollback_points
162            .get_mut(rollback_id)
163            .ok_or_else(|| WorkflowError::not_found(rollback_id))?;
164
165        point.description = Some(description);
166
167        Ok(())
168    }
169
170    /// Add tag to rollback point.
171    pub fn add_tag(&self, rollback_id: &str, tag: String) -> Result<()> {
172        let mut point = self
173            .rollback_points
174            .get_mut(rollback_id)
175            .ok_or_else(|| WorkflowError::not_found(rollback_id))?;
176
177        if !point.tags.contains(&tag) {
178            point.tags.push(tag);
179        }
180
181        Ok(())
182    }
183
184    /// Search rollback points by tag.
185    pub fn search_by_tag(&self, tag: &str) -> Vec<RollbackPoint> {
186        self.rollback_points
187            .iter()
188            .filter(|entry| entry.value().tags.contains(&tag.to_string()))
189            .map(|entry| entry.value().clone())
190            .collect()
191    }
192}
193
194impl Default for RollbackManager {
195    fn default() -> Self {
196        Self::new()
197    }
198}
199
#[cfg(test)]
mod tests {
    use super::*;
    use crate::dag::WorkflowDag;

    /// Build a minimal workflow definition with an empty DAG for use as a
    /// test fixture.
    fn sample_definition(name: &str, version: &str) -> WorkflowDefinition {
        WorkflowDefinition {
            id: "test".to_string(),
            name: name.to_string(),
            description: None,
            version: version.to_string(),
            dag: WorkflowDag::new(),
        }
    }

    #[test]
    fn test_rollback_manager_creation() {
        let manager = RollbackManager::new();
        assert_eq!(manager.count(), 0);
    }

    #[test]
    fn test_create_rollback_point() {
        let manager = RollbackManager::new();

        let rollback_id = manager
            .create_rollback_point(
                "test-workflow".to_string(),
                sample_definition("Test", "1.0.0"),
            )
            .expect("Failed to create rollback point");

        let point = manager
            .get_rollback_point(&rollback_id)
            .expect("Rollback point should exist");
        assert_eq!(point.id, rollback_id);
        assert_eq!(point.workflow_id, "test-workflow");
        assert_eq!(manager.count(), 1);
    }

    #[test]
    fn test_rollback() {
        let manager = RollbackManager::new();

        let rollback_id = manager
            .create_rollback_point(
                "test-workflow".to_string(),
                sample_definition("Test", "1.0.0"),
            )
            .expect("Failed to create");

        let restored = manager.rollback(&rollback_id).expect("Failed to rollback");

        assert_eq!(restored.id, "test");
        assert_eq!(restored.version, "1.0.0");
    }

    #[test]
    fn test_rollback_unknown_id_is_error() {
        let manager = RollbackManager::new();

        assert!(manager.rollback("missing").is_err());
        assert!(manager
            .update_description("missing", "nope".to_string())
            .is_err());
        assert!(manager.add_tag("missing", "nope".to_string()).is_err());
    }

    #[test]
    fn test_list_rollback_points() {
        let manager = RollbackManager::new();

        for i in 0..3 {
            manager
                .create_rollback_point(
                    "test-workflow".to_string(),
                    sample_definition(&format!("Test {}", i), &format!("1.0.{}", i)),
                )
                .expect("Failed to create");
        }

        let points = manager.list_rollback_points("test-workflow");
        assert_eq!(points.len(), 3);
        // The listing is documented as newest first.
        assert!(points
            .windows(2)
            .all(|pair| pair[0].created_at >= pair[1].created_at));
        assert!(manager.list_rollback_points("other-workflow").is_empty());
    }

    #[test]
    fn test_delete_rollback_point() {
        let manager = RollbackManager::new();

        let rollback_id = manager
            .create_rollback_point("test".to_string(), sample_definition("Test", "1.0.0"))
            .expect("Failed to create");

        assert!(manager.delete_rollback_point(&rollback_id).is_some());
        assert!(manager.get_rollback_point(&rollback_id).is_none());
        // Deleting again finds nothing.
        assert!(manager.delete_rollback_point(&rollback_id).is_none());
    }

    #[test]
    fn test_delete_workflow_rollback_points() {
        let manager = RollbackManager::new();

        for workflow_id in ["a", "a", "b"] {
            manager
                .create_rollback_point(
                    workflow_id.to_string(),
                    sample_definition("Test", "1.0.0"),
                )
                .expect("Failed to create");
        }

        assert_eq!(manager.delete_workflow_rollback_points("a"), 2);
        assert_eq!(manager.count(), 1);
        assert_eq!(manager.list_rollback_points("b").len(), 1);
    }

    #[test]
    fn test_max_rollback_points() {
        let manager = RollbackManager::with_max_points(3);

        for i in 0..5 {
            manager
                .create_rollback_point(
                    "test".to_string(),
                    sample_definition(&format!("Test {}", i), &format!("1.0.{}", i)),
                )
                .expect("Failed to create");
        }

        // Exactly the limit must remain: an upper-bound check alone would
        // also pass if every point had been dropped.
        let points = manager.list_rollback_points("test");
        assert_eq!(points.len(), 3);
        assert_eq!(manager.count(), 3);
    }

    #[test]
    fn test_get_latest_rollback_point() {
        let manager = RollbackManager::new();
        assert!(manager.get_latest_rollback_point("test").is_none());

        let id = manager
            .create_rollback_point("test".to_string(), sample_definition("Test", "1.0.0"))
            .expect("Failed to create");

        let latest = manager
            .get_latest_rollback_point("test")
            .expect("Latest rollback point should exist");
        assert_eq!(latest.id, id);
    }

    #[test]
    fn test_update_description() {
        let manager = RollbackManager::new();

        let id = manager
            .create_rollback_point("test".to_string(), sample_definition("Test", "1.0.0"))
            .expect("Failed to create");

        manager
            .update_description(&id, "Test description".to_string())
            .expect("Failed to update");

        let point = manager.get_rollback_point(&id).expect("Not found");
        assert_eq!(point.description, Some("Test description".to_string()));
    }

    #[test]
    fn test_search_by_tag() {
        let manager = RollbackManager::new();

        let id = manager
            .create_rollback_point("test".to_string(), sample_definition("Test", "1.0.0"))
            .expect("Failed to create");

        manager
            .add_tag(&id, "production".to_string())
            .expect("Failed to add tag");
        // Adding the same tag twice must not duplicate it.
        manager
            .add_tag(&id, "production".to_string())
            .expect("Failed to add tag");

        let tagged = manager.search_by_tag("production");
        assert_eq!(tagged.len(), 1);
        assert_eq!(tagged[0].tags, vec!["production".to_string()]);
        assert!(manager.search_by_tag("staging").is_empty());
    }
}