oxify_model/
rollback.rs

1//! Execution rollback mechanism for workflows
2//!
3//! This module provides the ability to create snapshots of execution state
4//! and rollback to previous states when needed.
5//!
6//! # Features
7//!
8//! - Create point-in-time snapshots of execution state
9//! - Store execution history with automatic pruning
10//! - Rollback to any previous snapshot
11//! - Support for selective rollback (specific variables, node results)
12//!
13//! # Example
14//!
15//! ```
16//! use oxify_model::rollback::{ExecutionSnapshot, RollbackManager};
17//! use oxify_model::execution::ExecutionContext;
18//! use uuid::Uuid;
19//!
20//! let mut manager = RollbackManager::new(10); // Keep 10 snapshots max
21//! let mut context = ExecutionContext::new(Uuid::new_v4());
22//!
23//! // Create a snapshot before risky operation
24//! let snapshot = ExecutionSnapshot::from_context(&context);
25//! manager.push_snapshot(snapshot);
26//!
27//! // If operation fails, rollback
28//! let result = manager.rollback(&mut context);
29//! assert!(result.success);
30//! ```
31
32use crate::execution::{ExecutionContext, ExecutionState, NodeExecutionResult};
33use crate::NodeId;
34use chrono::{DateTime, Utc};
35use serde::{Deserialize, Serialize};
36use std::collections::{HashMap, VecDeque};
37use uuid::Uuid;
38
39#[cfg(feature = "openapi")]
40use utoipa::ToSchema;
41
/// A point-in-time snapshot of execution state
///
/// Holds copies of the state, variables and node results of an
/// [`ExecutionContext`] (see [`ExecutionSnapshot::from_context`]) so that they
/// can be written back later with [`ExecutionSnapshot::apply_to`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct ExecutionSnapshot {
    /// Unique snapshot identifier
    ///
    /// `from_context` assigns a freshly generated v4 UUID.
    #[cfg_attr(feature = "openapi", schema(value_type = String))]
    pub id: Uuid,

    /// Timestamp when snapshot was created
    ///
    /// `from_context` sets this to the current UTC time.
    pub created_at: DateTime<Utc>,

    /// Label/description for this snapshot
    ///
    /// `None` unless set, e.g. through `with_label`.
    pub label: Option<String>,

    /// Node that triggered this snapshot (if any)
    ///
    /// `None` unless set, e.g. through `with_trigger_node`.
    #[cfg_attr(feature = "openapi", schema(value_type = Option<String>))]
    pub trigger_node: Option<NodeId>,

    /// Execution state at snapshot time
    pub state: ExecutionState,

    /// Variables at snapshot time
    pub variables: HashMap<String, serde_json::Value>,

    /// Node results at snapshot time
    #[cfg_attr(feature = "openapi", schema(value_type = HashMap<String, NodeExecutionResult>))]
    pub node_results: HashMap<NodeId, NodeExecutionResult>,

    /// Metadata for the snapshot
    ///
    /// Falls back to `SnapshotMetadata::default()` when the field is missing
    /// from deserialized input.
    #[serde(default)]
    pub metadata: SnapshotMetadata,
}
74
75/// Metadata about a snapshot
76#[derive(Debug, Clone, Default, Serialize, Deserialize)]
77#[cfg_attr(feature = "openapi", derive(ToSchema))]
78pub struct SnapshotMetadata {
79    /// Reason for creating the snapshot
80    pub reason: Option<String>,
81
82    /// User who created the snapshot (if applicable)
83    pub created_by: Option<String>,
84
85    /// Whether this is an automatic checkpoint
86    pub is_auto: bool,
87
88    /// Custom metadata
89    #[serde(default)]
90    pub custom: HashMap<String, serde_json::Value>,
91}
92
93impl ExecutionSnapshot {
94    /// Create a new snapshot from an execution context
95    pub fn from_context(ctx: &ExecutionContext) -> Self {
96        Self {
97            id: Uuid::new_v4(),
98            created_at: Utc::now(),
99            label: None,
100            trigger_node: None,
101            state: ctx.state.clone(),
102            variables: ctx.variables.clone(),
103            node_results: ctx.node_results.clone(),
104            metadata: SnapshotMetadata::default(),
105        }
106    }
107
108    /// Create a snapshot with a label
109    pub fn with_label(mut self, label: impl Into<String>) -> Self {
110        self.label = Some(label.into());
111        self
112    }
113
114    /// Create a snapshot with a trigger node
115    pub fn with_trigger_node(mut self, node_id: NodeId) -> Self {
116        self.trigger_node = Some(node_id);
117        self
118    }
119
120    /// Create a snapshot with a reason
121    pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
122        self.metadata.reason = Some(reason.into());
123        self
124    }
125
126    /// Mark as automatic snapshot
127    pub fn as_auto(mut self) -> Self {
128        self.metadata.is_auto = true;
129        self
130    }
131
132    /// Apply this snapshot to an execution context
133    pub fn apply_to(&self, ctx: &mut ExecutionContext) {
134        ctx.state = self.state.clone();
135        ctx.variables = self.variables.clone();
136        ctx.node_results = self.node_results.clone();
137    }
138
139    /// Get the number of completed nodes in this snapshot
140    pub fn completed_node_count(&self) -> usize {
141        self.node_results.len()
142    }
143
144    /// Get the variable count in this snapshot
145    pub fn variable_count(&self) -> usize {
146        self.variables.len()
147    }
148}
149
/// Result of a rollback operation
///
/// Built through [`RollbackResult::success`] or [`RollbackResult::failure`].
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct RollbackResult {
    /// Whether rollback was successful
    pub success: bool,

    /// The snapshot that was applied
    ///
    /// `None` when the result was built with `failure`.
    #[cfg_attr(feature = "openapi", schema(value_type = Option<String>))]
    pub applied_snapshot_id: Option<Uuid>,

    /// Number of node results removed
    ///
    /// Always 0 for results built with `failure`.
    pub nodes_removed: usize,

    /// Number of variables changed
    ///
    /// Always 0 for results built with `failure`.
    pub variables_changed: usize,

    /// Error message if failed
    ///
    /// `None` when the result was built with `success`.
    pub error: Option<String>,

    /// Timestamp of the rollback
    ///
    /// Both constructors set this to the current UTC time.
    pub rolled_back_at: DateTime<Utc>,
}
173
174impl RollbackResult {
175    /// Create a successful rollback result
176    pub fn success(snapshot_id: Uuid, nodes_removed: usize, variables_changed: usize) -> Self {
177        Self {
178            success: true,
179            applied_snapshot_id: Some(snapshot_id),
180            nodes_removed,
181            variables_changed,
182            error: None,
183            rolled_back_at: Utc::now(),
184        }
185    }
186
187    /// Create a failed rollback result
188    pub fn failure(error: impl Into<String>) -> Self {
189        Self {
190            success: false,
191            applied_snapshot_id: None,
192            nodes_removed: 0,
193            variables_changed: 0,
194            error: Some(error.into()),
195            rolled_back_at: Utc::now(),
196        }
197    }
198}
199
/// Manages execution snapshots and rollback operations
///
/// Keeps a bounded history of [`ExecutionSnapshot`]s and can optionally take
/// snapshots automatically as nodes execute.
#[derive(Debug, Clone)]
pub struct RollbackManager {
    /// Snapshot history (newest first)
    ///
    /// New snapshots are pushed to the front; pruning drops from the back.
    snapshots: VecDeque<ExecutionSnapshot>,

    /// Maximum number of snapshots to keep
    max_snapshots: usize,

    /// Whether automatic snapshots are enabled
    ///
    /// `false` after `new`; switched on by `with_auto_snapshot`.
    auto_snapshot: bool,

    /// Auto-snapshot interval (every N nodes)
    ///
    /// Initialised to 5 by `new` and overwritten by `with_auto_snapshot`.
    auto_snapshot_interval: usize,

    /// Counter for auto-snapshots
    ///
    /// Incremented on every `on_node_execute` call and reset by `clear`.
    node_counter: usize,
}
218
219impl RollbackManager {
220    /// Create a new rollback manager
221    ///
222    /// # Arguments
223    /// * `max_snapshots` - Maximum number of snapshots to keep
224    pub fn new(max_snapshots: usize) -> Self {
225        Self {
226            snapshots: VecDeque::new(),
227            max_snapshots,
228            auto_snapshot: false,
229            auto_snapshot_interval: 5,
230            node_counter: 0,
231        }
232    }
233
234    /// Enable automatic snapshots every N nodes
235    pub fn with_auto_snapshot(mut self, interval: usize) -> Self {
236        self.auto_snapshot = true;
237        self.auto_snapshot_interval = interval;
238        self
239    }
240
241    /// Push a new snapshot
242    ///
243    /// If max_snapshots is reached, the oldest snapshot is removed.
244    pub fn push_snapshot(&mut self, snapshot: ExecutionSnapshot) {
245        // Add to front (newest first)
246        self.snapshots.push_front(snapshot);
247
248        // Prune if over limit
249        while self.snapshots.len() > self.max_snapshots {
250            self.snapshots.pop_back();
251        }
252    }
253
254    /// Create and push a snapshot from an execution context
255    pub fn create_snapshot(&mut self, ctx: &ExecutionContext) -> Uuid {
256        let snapshot = ExecutionSnapshot::from_context(ctx);
257        let id = snapshot.id;
258        self.push_snapshot(snapshot);
259        id
260    }
261
262    /// Create a labeled snapshot
263    pub fn create_labeled_snapshot(
264        &mut self,
265        ctx: &ExecutionContext,
266        label: impl Into<String>,
267    ) -> Uuid {
268        let snapshot = ExecutionSnapshot::from_context(ctx).with_label(label);
269        let id = snapshot.id;
270        self.push_snapshot(snapshot);
271        id
272    }
273
274    /// Called when a node is about to execute
275    ///
276    /// Creates automatic snapshots if enabled.
277    pub fn on_node_execute(&mut self, ctx: &ExecutionContext, node_id: NodeId) -> Option<Uuid> {
278        self.node_counter += 1;
279
280        if self.auto_snapshot
281            && self
282                .node_counter
283                .is_multiple_of(self.auto_snapshot_interval)
284        {
285            let snapshot = ExecutionSnapshot::from_context(ctx)
286                .with_trigger_node(node_id)
287                .as_auto();
288            let id = snapshot.id;
289            self.push_snapshot(snapshot);
290            Some(id)
291        } else {
292            None
293        }
294    }
295
296    /// Rollback to the most recent snapshot
297    pub fn rollback(&mut self, ctx: &mut ExecutionContext) -> RollbackResult {
298        if let Some(snapshot) = self.snapshots.front() {
299            let nodes_before = ctx.node_results.len();
300            let vars_before = ctx.variables.clone();
301
302            snapshot.apply_to(ctx);
303
304            let nodes_removed = nodes_before.saturating_sub(ctx.node_results.len());
305            let variables_changed = vars_before
306                .iter()
307                .filter(|(k, v)| ctx.variables.get(*k) != Some(*v))
308                .count();
309
310            RollbackResult::success(snapshot.id, nodes_removed, variables_changed)
311        } else {
312            RollbackResult::failure("No snapshots available")
313        }
314    }
315
316    /// Rollback to a specific snapshot by ID
317    pub fn rollback_to(&mut self, ctx: &mut ExecutionContext, snapshot_id: Uuid) -> RollbackResult {
318        if let Some(snapshot) = self.snapshots.iter().find(|s| s.id == snapshot_id) {
319            let nodes_before = ctx.node_results.len();
320            let vars_before = ctx.variables.clone();
321
322            snapshot.clone().apply_to(ctx);
323
324            let nodes_removed = nodes_before.saturating_sub(ctx.node_results.len());
325            let variables_changed = vars_before
326                .iter()
327                .filter(|(k, v)| ctx.variables.get(*k) != Some(*v))
328                .count();
329
330            RollbackResult::success(snapshot_id, nodes_removed, variables_changed)
331        } else {
332            RollbackResult::failure(format!("Snapshot {} not found", snapshot_id))
333        }
334    }
335
336    /// Rollback N steps in history
337    ///
338    /// - `rollback_n(1)` returns to the most recent snapshot (index 0)
339    /// - `rollback_n(2)` returns to 2 snapshots ago (index 1)
340    /// - `rollback_n(3)` returns to 3 snapshots ago (index 2)
341    ///
342    /// Note: "Steps" refers to how many states back in history to go,
343    /// where each snapshot is one state.
344    pub fn rollback_n(&mut self, ctx: &mut ExecutionContext, steps: usize) -> RollbackResult {
345        if steps == 0 {
346            return RollbackResult::failure("Cannot rollback 0 steps");
347        }
348
349        if steps > self.snapshots.len() {
350            return RollbackResult::failure(format!(
351                "Cannot rollback {} steps, only {} snapshots available",
352                steps,
353                self.snapshots.len()
354            ));
355        }
356
357        // Get the snapshot N steps back (1-indexed for user convenience)
358        // rollback_n(1) -> index 0, rollback_n(2) -> index 1, etc.
359        if let Some(snapshot) = self.snapshots.get(steps - 1) {
360            let nodes_before = ctx.node_results.len();
361            let vars_before = ctx.variables.clone();
362
363            snapshot.clone().apply_to(ctx);
364
365            let nodes_removed = nodes_before.saturating_sub(ctx.node_results.len());
366            let variables_changed = vars_before
367                .iter()
368                .filter(|(k, v)| ctx.variables.get(*k) != Some(*v))
369                .count();
370
371            RollbackResult::success(snapshot.id, nodes_removed, variables_changed)
372        } else {
373            RollbackResult::failure("Snapshot not found")
374        }
375    }
376
377    /// Get the most recent snapshot
378    pub fn latest_snapshot(&self) -> Option<&ExecutionSnapshot> {
379        self.snapshots.front()
380    }
381
382    /// Get a snapshot by ID
383    pub fn get_snapshot(&self, id: Uuid) -> Option<&ExecutionSnapshot> {
384        self.snapshots.iter().find(|s| s.id == id)
385    }
386
387    /// List all snapshots (newest first)
388    pub fn list_snapshots(&self) -> Vec<&ExecutionSnapshot> {
389        self.snapshots.iter().collect()
390    }
391
392    /// Get the number of available snapshots
393    pub fn snapshot_count(&self) -> usize {
394        self.snapshots.len()
395    }
396
397    /// Clear all snapshots
398    pub fn clear(&mut self) {
399        self.snapshots.clear();
400        self.node_counter = 0;
401    }
402
403    /// Remove snapshots older than a timestamp
404    pub fn prune_before(&mut self, timestamp: DateTime<Utc>) -> usize {
405        let before = self.snapshots.len();
406        self.snapshots.retain(|s| s.created_at >= timestamp);
407        before - self.snapshots.len()
408    }
409
410    /// Get summary of snapshot history
411    pub fn summary(&self) -> RollbackSummary {
412        RollbackSummary {
413            total_snapshots: self.snapshots.len(),
414            max_snapshots: self.max_snapshots,
415            auto_snapshot_enabled: self.auto_snapshot,
416            auto_snapshot_interval: self.auto_snapshot_interval,
417            oldest_snapshot: self.snapshots.back().map(|s| s.created_at),
418            newest_snapshot: self.snapshots.front().map(|s| s.created_at),
419            nodes_processed: self.node_counter,
420        }
421    }
422}
423
424impl Default for RollbackManager {
425    fn default() -> Self {
426        Self::new(10)
427    }
428}
429
/// Summary of rollback manager state
///
/// Produced by `RollbackManager::summary`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[cfg_attr(feature = "openapi", derive(ToSchema))]
pub struct RollbackSummary {
    /// Total snapshots stored
    pub total_snapshots: usize,

    /// Maximum snapshots allowed
    pub max_snapshots: usize,

    /// Whether auto-snapshot is enabled
    pub auto_snapshot_enabled: bool,

    /// Auto-snapshot interval
    pub auto_snapshot_interval: usize,

    /// Timestamp of oldest snapshot
    ///
    /// `None` when no snapshots are stored.
    pub oldest_snapshot: Option<DateTime<Utc>>,

    /// Timestamp of newest snapshot
    ///
    /// `None` when no snapshots are stored.
    pub newest_snapshot: Option<DateTime<Utc>>,

    /// Total nodes processed
    ///
    /// Mirrors the manager's node counter, which `clear` resets to 0.
    pub nodes_processed: usize,
}
455
// Unit tests for snapshot creation, the snapshot builders, and every
// `RollbackManager` operation (push/prune, rollback variants, auto-snapshot,
// summary, clear).
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a context with a random execution id and two variables
    /// (`key1 = "value1"`, `key2 = 42`) that the tests use as a baseline.
    fn create_test_context() -> ExecutionContext {
        let mut ctx = ExecutionContext::new(Uuid::new_v4());
        ctx.set_variable("key1".to_string(), serde_json::json!("value1"));
        ctx.set_variable("key2".to_string(), serde_json::json!(42));
        ctx
    }

    /// A plain snapshot has no label or trigger node and copies the
    /// context's variables and state.
    #[test]
    fn test_snapshot_creation() {
        let ctx = create_test_context();
        let snapshot = ExecutionSnapshot::from_context(&ctx);

        assert!(snapshot.label.is_none());
        assert!(snapshot.trigger_node.is_none());
        assert_eq!(snapshot.variables.len(), 2);
        // The test expects a freshly created context to be in the Running state.
        assert_eq!(snapshot.state, ExecutionState::Running);
    }

    /// `with_label` stores the label on the snapshot.
    #[test]
    fn test_snapshot_with_label() {
        let ctx = create_test_context();
        let snapshot = ExecutionSnapshot::from_context(&ctx).with_label("Before LLM call");

        assert_eq!(snapshot.label, Some("Before LLM call".to_string()));
    }

    /// `with_trigger_node` stores the node id on the snapshot.
    #[test]
    fn test_snapshot_with_trigger_node() {
        let ctx = create_test_context();
        let node_id = Uuid::new_v4();
        let snapshot = ExecutionSnapshot::from_context(&ctx).with_trigger_node(node_id);

        assert_eq!(snapshot.trigger_node, Some(node_id));
    }

    /// Applying a snapshot discards variables added afterwards and restores
    /// the captured state.
    #[test]
    fn test_snapshot_apply_to() {
        let mut ctx = create_test_context();
        let snapshot = ExecutionSnapshot::from_context(&ctx);

        // Modify context
        ctx.set_variable("key3".to_string(), serde_json::json!("new_value"));
        ctx.state = ExecutionState::Failed("test error".to_string());

        // Apply snapshot
        snapshot.apply_to(&mut ctx);

        assert_eq!(ctx.variables.len(), 2);
        assert!(!ctx.variables.contains_key("key3"));
        assert_eq!(ctx.state, ExecutionState::Running);
    }

    /// Pushing more snapshots than `max_snapshots` prunes the history down to
    /// the limit.
    #[test]
    fn test_rollback_manager_push() {
        let mut manager = RollbackManager::new(3);

        for i in 0..5 {
            let mut ctx = create_test_context();
            ctx.set_variable(format!("iter_{}", i), serde_json::json!(i));
            manager.create_snapshot(&ctx);
        }

        // Should only keep 3 snapshots
        assert_eq!(manager.snapshot_count(), 3);
    }

    /// `rollback` restores the most recent snapshot.
    #[test]
    fn test_rollback_latest() {
        let mut manager = RollbackManager::new(10);
        let mut ctx = create_test_context();

        // Create snapshot
        manager.create_snapshot(&ctx);

        // Modify context
        ctx.set_variable("new_key".to_string(), serde_json::json!("new_value"));
        ctx.state = ExecutionState::Failed("error".to_string());

        // Rollback
        let result = manager.rollback(&mut ctx);

        assert!(result.success);
        assert!(!ctx.variables.contains_key("new_key"));
        assert_eq!(ctx.state, ExecutionState::Running);
    }

    /// `rollback_to` can skip over newer snapshots and restore an older one
    /// by id.
    #[test]
    fn test_rollback_to_specific() {
        let mut manager = RollbackManager::new(10);
        let mut ctx = create_test_context();

        // Create first snapshot
        let first_id = manager.create_snapshot(&ctx);

        // Modify and create second snapshot
        ctx.set_variable("modified".to_string(), serde_json::json!(true));
        let _second_id = manager.create_snapshot(&ctx);

        // Modify again
        ctx.set_variable("more_changes".to_string(), serde_json::json!("value"));

        // Rollback to first snapshot
        let result = manager.rollback_to(&mut ctx, first_id);

        assert!(result.success);
        // Neither the change captured by the second snapshot nor the later one survives.
        assert!(!ctx.variables.contains_key("modified"));
        assert!(!ctx.variables.contains_key("more_changes"));
    }

    /// `rollback_n` is 1-indexed from the newest snapshot.
    #[test]
    fn test_rollback_n_steps() {
        let mut manager = RollbackManager::new(10);
        let mut ctx = create_test_context();

        // Create snapshots (each snapshot contains variables up to that point)
        // snap0: step_0, snap1: step_0,1, snap2: step_0,1,2, snap3: step_0,1,2,3, snap4: step_0,1,2,3,4
        for i in 0..5 {
            ctx.set_variable(format!("step_{}", i), serde_json::json!(i));
            manager.create_snapshot(&ctx);
        }

        // Add one more variable that isn't in any snapshot
        ctx.set_variable("step_5".to_string(), serde_json::json!(5));

        // Rollback 3 steps goes to the 3rd most recent snapshot (snap2)
        // snap4(idx0) -> snap3(idx1) -> snap2(idx2)
        let result = manager.rollback_n(&mut ctx, 3);

        assert!(result.success);
        // snap2 has step_0, step_1, step_2 but NOT step_3, step_4, step_5
        assert!(ctx.variables.contains_key("step_0"));
        assert!(ctx.variables.contains_key("step_1"));
        assert!(ctx.variables.contains_key("step_2"));
        assert!(!ctx.variables.contains_key("step_3"));
        assert!(!ctx.variables.contains_key("step_4"));
        assert!(!ctx.variables.contains_key("step_5"));
    }

    /// Rolling back with an empty history reports failure with an error
    /// message.
    #[test]
    fn test_rollback_no_snapshots() {
        let mut manager = RollbackManager::new(10);
        let mut ctx = create_test_context();

        let result = manager.rollback(&mut ctx);

        assert!(!result.success);
        assert!(result.error.is_some());
    }

    /// With an interval of 2, every second `on_node_execute` call creates a
    /// snapshot.
    #[test]
    fn test_auto_snapshot() {
        let mut manager = RollbackManager::new(10).with_auto_snapshot(2);
        let ctx = create_test_context();

        // First node - no snapshot
        let result = manager.on_node_execute(&ctx, Uuid::new_v4());
        assert!(result.is_none());

        // Second node - snapshot
        let result = manager.on_node_execute(&ctx, Uuid::new_v4());
        assert!(result.is_some());

        // Third node - no snapshot
        let result = manager.on_node_execute(&ctx, Uuid::new_v4());
        assert!(result.is_none());

        // Fourth node - snapshot
        let result = manager.on_node_execute(&ctx, Uuid::new_v4());
        assert!(result.is_some());

        assert_eq!(manager.snapshot_count(), 2);
    }

    /// `prune_before` removes only snapshots created before the cutoff.
    #[test]
    fn test_prune_before() {
        let mut manager = RollbackManager::new(10);
        let ctx = create_test_context();

        // Create some snapshots
        manager.create_snapshot(&ctx);
        // The sleeps put the cutoff strictly between the first snapshot and the later two.
        std::thread::sleep(std::time::Duration::from_millis(10));
        let cutoff = Utc::now();
        std::thread::sleep(std::time::Duration::from_millis(10));
        manager.create_snapshot(&ctx);
        manager.create_snapshot(&ctx);

        let pruned = manager.prune_before(cutoff);
        assert_eq!(pruned, 1);
        assert_eq!(manager.snapshot_count(), 2);
    }

    /// `summary` reflects the manager's configuration and current history.
    #[test]
    fn test_rollback_summary() {
        let mut manager = RollbackManager::new(5).with_auto_snapshot(3);
        let ctx = create_test_context();

        manager.create_snapshot(&ctx);
        manager.create_snapshot(&ctx);

        let summary = manager.summary();
        assert_eq!(summary.total_snapshots, 2);
        assert_eq!(summary.max_snapshots, 5);
        assert!(summary.auto_snapshot_enabled);
        assert_eq!(summary.auto_snapshot_interval, 3);
        assert!(summary.oldest_snapshot.is_some());
        assert!(summary.newest_snapshot.is_some());
    }

    /// `clear` drops every snapshot (manual and automatic) and resets the
    /// node counter.
    #[test]
    fn test_clear() {
        let mut manager = RollbackManager::new(10).with_auto_snapshot(2);
        let ctx = create_test_context();

        manager.create_snapshot(&ctx);
        manager.create_snapshot(&ctx);
        // Advance the node counter (and create a few automatic snapshots).
        for _ in 0..5 {
            manager.on_node_execute(&ctx, Uuid::new_v4());
        }

        manager.clear();

        assert_eq!(manager.snapshot_count(), 0);
        assert_eq!(manager.summary().nodes_processed, 0);
    }

    /// `RollbackResult::success` records the id and counters and leaves the
    /// error empty.
    #[test]
    fn test_rollback_result_success() {
        let result = RollbackResult::success(Uuid::new_v4(), 3, 2);
        assert!(result.success);
        assert!(result.applied_snapshot_id.is_some());
        assert_eq!(result.nodes_removed, 3);
        assert_eq!(result.variables_changed, 2);
        assert!(result.error.is_none());
    }

    /// `RollbackResult::failure` records the message and no snapshot id.
    #[test]
    fn test_rollback_result_failure() {
        let result = RollbackResult::failure("No snapshots available");
        assert!(!result.success);
        assert!(result.applied_snapshot_id.is_none());
        assert_eq!(result.error, Some("No snapshots available".to_string()));
    }
}