1use crate::execution::{ExecutionContext, ExecutionState, NodeExecutionResult};
33use crate::NodeId;
34use chrono::{DateTime, Utc};
35use serde::{Deserialize, Serialize};
36use std::collections::{HashMap, VecDeque};
37use uuid::Uuid;
38
39#[cfg(feature = "openapi")]
40use utoipa::ToSchema;
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44#[cfg_attr(feature = "openapi", derive(ToSchema))]
45pub struct ExecutionSnapshot {
46 #[cfg_attr(feature = "openapi", schema(value_type = String))]
48 pub id: Uuid,
49
50 pub created_at: DateTime<Utc>,
52
53 pub label: Option<String>,
55
56 #[cfg_attr(feature = "openapi", schema(value_type = Option<String>))]
58 pub trigger_node: Option<NodeId>,
59
60 pub state: ExecutionState,
62
63 pub variables: HashMap<String, serde_json::Value>,
65
66 #[cfg_attr(feature = "openapi", schema(value_type = HashMap<String, NodeExecutionResult>))]
68 pub node_results: HashMap<NodeId, NodeExecutionResult>,
69
70 #[serde(default)]
72 pub metadata: SnapshotMetadata,
73}
74
75#[derive(Debug, Clone, Default, Serialize, Deserialize)]
77#[cfg_attr(feature = "openapi", derive(ToSchema))]
78pub struct SnapshotMetadata {
79 pub reason: Option<String>,
81
82 pub created_by: Option<String>,
84
85 pub is_auto: bool,
87
88 #[serde(default)]
90 pub custom: HashMap<String, serde_json::Value>,
91}
92
93impl ExecutionSnapshot {
94 pub fn from_context(ctx: &ExecutionContext) -> Self {
96 Self {
97 id: Uuid::new_v4(),
98 created_at: Utc::now(),
99 label: None,
100 trigger_node: None,
101 state: ctx.state.clone(),
102 variables: ctx.variables.clone(),
103 node_results: ctx.node_results.clone(),
104 metadata: SnapshotMetadata::default(),
105 }
106 }
107
108 pub fn with_label(mut self, label: impl Into<String>) -> Self {
110 self.label = Some(label.into());
111 self
112 }
113
114 pub fn with_trigger_node(mut self, node_id: NodeId) -> Self {
116 self.trigger_node = Some(node_id);
117 self
118 }
119
120 pub fn with_reason(mut self, reason: impl Into<String>) -> Self {
122 self.metadata.reason = Some(reason.into());
123 self
124 }
125
126 pub fn as_auto(mut self) -> Self {
128 self.metadata.is_auto = true;
129 self
130 }
131
132 pub fn apply_to(&self, ctx: &mut ExecutionContext) {
134 ctx.state = self.state.clone();
135 ctx.variables = self.variables.clone();
136 ctx.node_results = self.node_results.clone();
137 }
138
139 pub fn completed_node_count(&self) -> usize {
141 self.node_results.len()
142 }
143
144 pub fn variable_count(&self) -> usize {
146 self.variables.len()
147 }
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152#[cfg_attr(feature = "openapi", derive(ToSchema))]
153pub struct RollbackResult {
154 pub success: bool,
156
157 #[cfg_attr(feature = "openapi", schema(value_type = Option<String>))]
159 pub applied_snapshot_id: Option<Uuid>,
160
161 pub nodes_removed: usize,
163
164 pub variables_changed: usize,
166
167 pub error: Option<String>,
169
170 pub rolled_back_at: DateTime<Utc>,
172}
173
174impl RollbackResult {
175 pub fn success(snapshot_id: Uuid, nodes_removed: usize, variables_changed: usize) -> Self {
177 Self {
178 success: true,
179 applied_snapshot_id: Some(snapshot_id),
180 nodes_removed,
181 variables_changed,
182 error: None,
183 rolled_back_at: Utc::now(),
184 }
185 }
186
187 pub fn failure(error: impl Into<String>) -> Self {
189 Self {
190 success: false,
191 applied_snapshot_id: None,
192 nodes_removed: 0,
193 variables_changed: 0,
194 error: Some(error.into()),
195 rolled_back_at: Utc::now(),
196 }
197 }
198}
199
200#[derive(Debug, Clone)]
202pub struct RollbackManager {
203 snapshots: VecDeque<ExecutionSnapshot>,
205
206 max_snapshots: usize,
208
209 auto_snapshot: bool,
211
212 auto_snapshot_interval: usize,
214
215 node_counter: usize,
217}
218
219impl RollbackManager {
220 pub fn new(max_snapshots: usize) -> Self {
225 Self {
226 snapshots: VecDeque::new(),
227 max_snapshots,
228 auto_snapshot: false,
229 auto_snapshot_interval: 5,
230 node_counter: 0,
231 }
232 }
233
234 pub fn with_auto_snapshot(mut self, interval: usize) -> Self {
236 self.auto_snapshot = true;
237 self.auto_snapshot_interval = interval;
238 self
239 }
240
241 pub fn push_snapshot(&mut self, snapshot: ExecutionSnapshot) {
245 self.snapshots.push_front(snapshot);
247
248 while self.snapshots.len() > self.max_snapshots {
250 self.snapshots.pop_back();
251 }
252 }
253
254 pub fn create_snapshot(&mut self, ctx: &ExecutionContext) -> Uuid {
256 let snapshot = ExecutionSnapshot::from_context(ctx);
257 let id = snapshot.id;
258 self.push_snapshot(snapshot);
259 id
260 }
261
262 pub fn create_labeled_snapshot(
264 &mut self,
265 ctx: &ExecutionContext,
266 label: impl Into<String>,
267 ) -> Uuid {
268 let snapshot = ExecutionSnapshot::from_context(ctx).with_label(label);
269 let id = snapshot.id;
270 self.push_snapshot(snapshot);
271 id
272 }
273
274 pub fn on_node_execute(&mut self, ctx: &ExecutionContext, node_id: NodeId) -> Option<Uuid> {
278 self.node_counter += 1;
279
280 if self.auto_snapshot
281 && self
282 .node_counter
283 .is_multiple_of(self.auto_snapshot_interval)
284 {
285 let snapshot = ExecutionSnapshot::from_context(ctx)
286 .with_trigger_node(node_id)
287 .as_auto();
288 let id = snapshot.id;
289 self.push_snapshot(snapshot);
290 Some(id)
291 } else {
292 None
293 }
294 }
295
296 pub fn rollback(&mut self, ctx: &mut ExecutionContext) -> RollbackResult {
298 if let Some(snapshot) = self.snapshots.front() {
299 let nodes_before = ctx.node_results.len();
300 let vars_before = ctx.variables.clone();
301
302 snapshot.apply_to(ctx);
303
304 let nodes_removed = nodes_before.saturating_sub(ctx.node_results.len());
305 let variables_changed = vars_before
306 .iter()
307 .filter(|(k, v)| ctx.variables.get(*k) != Some(*v))
308 .count();
309
310 RollbackResult::success(snapshot.id, nodes_removed, variables_changed)
311 } else {
312 RollbackResult::failure("No snapshots available")
313 }
314 }
315
316 pub fn rollback_to(&mut self, ctx: &mut ExecutionContext, snapshot_id: Uuid) -> RollbackResult {
318 if let Some(snapshot) = self.snapshots.iter().find(|s| s.id == snapshot_id) {
319 let nodes_before = ctx.node_results.len();
320 let vars_before = ctx.variables.clone();
321
322 snapshot.clone().apply_to(ctx);
323
324 let nodes_removed = nodes_before.saturating_sub(ctx.node_results.len());
325 let variables_changed = vars_before
326 .iter()
327 .filter(|(k, v)| ctx.variables.get(*k) != Some(*v))
328 .count();
329
330 RollbackResult::success(snapshot_id, nodes_removed, variables_changed)
331 } else {
332 RollbackResult::failure(format!("Snapshot {} not found", snapshot_id))
333 }
334 }
335
336 pub fn rollback_n(&mut self, ctx: &mut ExecutionContext, steps: usize) -> RollbackResult {
345 if steps == 0 {
346 return RollbackResult::failure("Cannot rollback 0 steps");
347 }
348
349 if steps > self.snapshots.len() {
350 return RollbackResult::failure(format!(
351 "Cannot rollback {} steps, only {} snapshots available",
352 steps,
353 self.snapshots.len()
354 ));
355 }
356
357 if let Some(snapshot) = self.snapshots.get(steps - 1) {
360 let nodes_before = ctx.node_results.len();
361 let vars_before = ctx.variables.clone();
362
363 snapshot.clone().apply_to(ctx);
364
365 let nodes_removed = nodes_before.saturating_sub(ctx.node_results.len());
366 let variables_changed = vars_before
367 .iter()
368 .filter(|(k, v)| ctx.variables.get(*k) != Some(*v))
369 .count();
370
371 RollbackResult::success(snapshot.id, nodes_removed, variables_changed)
372 } else {
373 RollbackResult::failure("Snapshot not found")
374 }
375 }
376
377 pub fn latest_snapshot(&self) -> Option<&ExecutionSnapshot> {
379 self.snapshots.front()
380 }
381
382 pub fn get_snapshot(&self, id: Uuid) -> Option<&ExecutionSnapshot> {
384 self.snapshots.iter().find(|s| s.id == id)
385 }
386
387 pub fn list_snapshots(&self) -> Vec<&ExecutionSnapshot> {
389 self.snapshots.iter().collect()
390 }
391
392 pub fn snapshot_count(&self) -> usize {
394 self.snapshots.len()
395 }
396
397 pub fn clear(&mut self) {
399 self.snapshots.clear();
400 self.node_counter = 0;
401 }
402
403 pub fn prune_before(&mut self, timestamp: DateTime<Utc>) -> usize {
405 let before = self.snapshots.len();
406 self.snapshots.retain(|s| s.created_at >= timestamp);
407 before - self.snapshots.len()
408 }
409
410 pub fn summary(&self) -> RollbackSummary {
412 RollbackSummary {
413 total_snapshots: self.snapshots.len(),
414 max_snapshots: self.max_snapshots,
415 auto_snapshot_enabled: self.auto_snapshot,
416 auto_snapshot_interval: self.auto_snapshot_interval,
417 oldest_snapshot: self.snapshots.back().map(|s| s.created_at),
418 newest_snapshot: self.snapshots.front().map(|s| s.created_at),
419 nodes_processed: self.node_counter,
420 }
421 }
422}
423
424impl Default for RollbackManager {
425 fn default() -> Self {
426 Self::new(10)
427 }
428}
429
430#[derive(Debug, Clone, Serialize, Deserialize)]
432#[cfg_attr(feature = "openapi", derive(ToSchema))]
433pub struct RollbackSummary {
434 pub total_snapshots: usize,
436
437 pub max_snapshots: usize,
439
440 pub auto_snapshot_enabled: bool,
442
443 pub auto_snapshot_interval: usize,
445
446 pub oldest_snapshot: Option<DateTime<Utc>>,
448
449 pub newest_snapshot: Option<DateTime<Utc>>,
451
452 pub nodes_processed: usize,
454}
455
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a fresh context pre-populated with two variables, `key1` and `key2`.
    fn create_test_context() -> ExecutionContext {
        let mut context = ExecutionContext::new(Uuid::new_v4());
        context.set_variable("key1".to_string(), serde_json::json!("value1"));
        context.set_variable("key2".to_string(), serde_json::json!(42));
        context
    }

    #[test]
    fn test_snapshot_creation() {
        let context = create_test_context();

        let snap = ExecutionSnapshot::from_context(&context);

        // A plain snapshot carries no label or trigger node, only copied state.
        assert!(snap.label.is_none());
        assert!(snap.trigger_node.is_none());
        assert_eq!(snap.variables.len(), 2);
        assert_eq!(snap.state, ExecutionState::Running);
    }

    #[test]
    fn test_snapshot_with_label() {
        let context = create_test_context();

        let snap = ExecutionSnapshot::from_context(&context).with_label("Before LLM call");

        assert_eq!(snap.label, Some("Before LLM call".to_string()));
    }

    #[test]
    fn test_snapshot_with_trigger_node() {
        let context = create_test_context();
        let trigger = Uuid::new_v4();

        let snap = ExecutionSnapshot::from_context(&context).with_trigger_node(trigger);

        assert_eq!(snap.trigger_node, Some(trigger));
    }

    #[test]
    fn test_snapshot_apply_to() {
        let mut context = create_test_context();
        let snap = ExecutionSnapshot::from_context(&context);

        // Diverge from the snapshot: one extra variable and a failed state.
        context.set_variable("key3".to_string(), serde_json::json!("new_value"));
        context.state = ExecutionState::Failed("test error".to_string());

        snap.apply_to(&mut context);

        // Both changes must be undone.
        assert_eq!(context.variables.len(), 2);
        assert!(!context.variables.contains_key("key3"));
        assert_eq!(context.state, ExecutionState::Running);
    }

    #[test]
    fn test_rollback_manager_push() {
        let mut mgr = RollbackManager::new(3);

        for iteration in 0..5 {
            let mut context = create_test_context();
            context.set_variable(format!("iter_{}", iteration), serde_json::json!(iteration));
            mgr.create_snapshot(&context);
        }

        // Five pushes against a capacity of three leave exactly three behind.
        assert_eq!(mgr.snapshot_count(), 3);
    }

    #[test]
    fn test_rollback_latest() {
        let mut mgr = RollbackManager::new(10);
        let mut context = create_test_context();

        mgr.create_snapshot(&context);

        // Changes made after the snapshot should disappear on rollback.
        context.set_variable("new_key".to_string(), serde_json::json!("new_value"));
        context.state = ExecutionState::Failed("error".to_string());

        let outcome = mgr.rollback(&mut context);

        assert!(outcome.success);
        assert!(!context.variables.contains_key("new_key"));
        assert_eq!(context.state, ExecutionState::Running);
    }

    #[test]
    fn test_rollback_to_specific() {
        let mut mgr = RollbackManager::new(10);
        let mut context = create_test_context();

        let first_id = mgr.create_snapshot(&context);

        context.set_variable("modified".to_string(), serde_json::json!(true));
        let _second_id = mgr.create_snapshot(&context);

        context.set_variable("more_changes".to_string(), serde_json::json!("value"));

        // Jump past the second snapshot straight back to the first.
        let outcome = mgr.rollback_to(&mut context, first_id);

        assert!(outcome.success);
        for gone in ["modified", "more_changes"] {
            assert!(!context.variables.contains_key(gone));
        }
    }

    #[test]
    fn test_rollback_n_steps() {
        let mut mgr = RollbackManager::new(10);
        let mut context = create_test_context();

        // Snapshot k (0-based) contains step_0..=step_k.
        for step in 0..5 {
            context.set_variable(format!("step_{}", step), serde_json::json!(step));
            mgr.create_snapshot(&context);
        }

        context.set_variable("step_5".to_string(), serde_json::json!(5));

        // Three steps back from the newest lands on the snapshot holding step_0..=step_2.
        let outcome = mgr.rollback_n(&mut context, 3);

        assert!(outcome.success);
        for kept in ["step_0", "step_1", "step_2"] {
            assert!(context.variables.contains_key(kept));
        }
        for gone in ["step_3", "step_4", "step_5"] {
            assert!(!context.variables.contains_key(gone));
        }
    }

    #[test]
    fn test_rollback_no_snapshots() {
        let mut mgr = RollbackManager::new(10);
        let mut context = create_test_context();

        let outcome = mgr.rollback(&mut context);

        assert!(!outcome.success);
        assert!(outcome.error.is_some());
    }

    #[test]
    fn test_auto_snapshot() {
        let mut mgr = RollbackManager::new(10).with_auto_snapshot(2);
        let context = create_test_context();

        // With an interval of 2, every second execution yields a snapshot.
        for expect_snapshot in [false, true, false, true] {
            let outcome = mgr.on_node_execute(&context, Uuid::new_v4());
            assert_eq!(outcome.is_some(), expect_snapshot);
        }

        assert_eq!(mgr.snapshot_count(), 2);
    }

    #[test]
    fn test_prune_before() {
        let mut mgr = RollbackManager::new(10);
        let context = create_test_context();
        let pause = std::time::Duration::from_millis(10);

        // One snapshot before the cutoff, two after; the sleeps keep the
        // timestamps strictly ordered around the cutoff.
        mgr.create_snapshot(&context);
        std::thread::sleep(pause);
        let cutoff = Utc::now();
        std::thread::sleep(pause);
        mgr.create_snapshot(&context);
        mgr.create_snapshot(&context);

        let removed = mgr.prune_before(cutoff);

        assert_eq!(removed, 1);
        assert_eq!(mgr.snapshot_count(), 2);
    }

    #[test]
    fn test_rollback_summary() {
        let mut mgr = RollbackManager::new(5).with_auto_snapshot(3);
        let context = create_test_context();

        mgr.create_snapshot(&context);
        mgr.create_snapshot(&context);

        let overview = mgr.summary();

        assert_eq!(overview.total_snapshots, 2);
        assert_eq!(overview.max_snapshots, 5);
        assert!(overview.auto_snapshot_enabled);
        assert_eq!(overview.auto_snapshot_interval, 3);
        assert!(overview.oldest_snapshot.is_some());
        assert!(overview.newest_snapshot.is_some());
    }

    #[test]
    fn test_clear() {
        let mut mgr = RollbackManager::new(10).with_auto_snapshot(2);
        let context = create_test_context();

        mgr.create_snapshot(&context);
        mgr.create_snapshot(&context);
        for _ in 0..5 {
            mgr.on_node_execute(&context, Uuid::new_v4());
        }

        mgr.clear();

        // Both the history and the execution counter are reset.
        assert_eq!(mgr.snapshot_count(), 0);
        assert_eq!(mgr.summary().nodes_processed, 0);
    }

    #[test]
    fn test_rollback_result_success() {
        let outcome = RollbackResult::success(Uuid::new_v4(), 3, 2);

        assert!(outcome.success);
        assert!(outcome.applied_snapshot_id.is_some());
        assert_eq!(outcome.nodes_removed, 3);
        assert_eq!(outcome.variables_changed, 2);
        assert!(outcome.error.is_none());
    }

    #[test]
    fn test_rollback_result_failure() {
        let outcome = RollbackResult::failure("No snapshots available");

        assert!(!outcome.success);
        assert!(outcome.applied_snapshot_id.is_none());
        assert_eq!(outcome.error, Some("No snapshots available".to_string()));
    }
}