Skip to main content

brainwires_agent/
saga.rs

1//! Saga-Style Compensating Transactions
2//!
3//! Based on SagaLLM (arXiv:2503.11951), this module implements the Saga pattern
4//! for multi-step operations. When an operation fails mid-way, compensation
5//! actions are executed in reverse order to undo completed sub-operations.
6//!
7//! # Key Concepts
8//!
9//! - **CompensableOperation**: A trait for operations that can be undone
10//! - **SagaExecutor**: Manages execution and compensation of operation sequences
11//! - **Checkpoint**: Captures state before operations for rollback
12//! - **CompensationReport**: Summary of compensation actions taken
13//!
//! # Example
//!
//! ```ignore
//! let saga = SagaExecutor::new("agent-1", "edit-and-build");
//!
//! // Execute operations in sequence, stopping at the first error
//! let outcome: anyhow::Result<()> = async {
//!     saga.execute_step(Arc::new(FileEditOp { path, old_content, new_content })).await?;
//!     saga.execute_step(Arc::new(GitStageOp { files, working_dir })).await?;
//!     saga.execute_step(Arc::new(BuildOp { project })).await?; // user-defined operation
//!     Ok(())
//! }
//! .await;
//!
//! // If any step failed, compensate all completed operations in reverse order
//! if outcome.is_err() || saga.status().await == SagaStatus::Failed {
//!     let report = saga.compensate_all().await?;
//!     // Files restored, staging undone
//! } else {
//!     saga.complete().await;
//! }
//! ```
30
31use std::path::PathBuf;
32use std::sync::Arc;
33use std::time::Instant;
34
35use anyhow::Result;
36use async_trait::async_trait;
37use serde::{Deserialize, Serialize};
38use tokio::sync::RwLock;
39
/// A compensable operation that can be undone.
///
/// [`SagaExecutor::execute_step`] runs [`execute`](Self::execute) and, when the returned
/// [`OperationResult`] reports success, records it. [`SagaExecutor::compensate_all`] later hands
/// that same result to [`compensate`](Self::compensate), newest step first.
///
/// NOTE: `compensate_all` skips any operation whose [`operation_type`](Self::operation_type) is
/// not [`SagaOperationType::is_compensable`]. The default type, [`SagaOperationType::Generic`],
/// is *not* compensable. An implementation must therefore override `operation_type` for the
/// executor to ever call its `compensate`.
#[async_trait]
pub trait CompensableOperation: Send + Sync {
    /// Execute the forward operation.
    ///
    /// If the returned result has `success: false`, the executor marks the saga failed and
    /// does not record the step for compensation. An `Err` is returned to the caller of
    /// `execute_step`.
    async fn execute(&self) -> Result<OperationResult>;

    /// Compensate (undo) the operation.
    ///
    /// `result` is the value this operation's `execute` returned. Its `checkpoint` field is not
    /// serialized, so it may be `None` if the result was round-tripped through serde.
    async fn compensate(&self, result: &OperationResult) -> Result<()>;

    /// Get the operation description for logging.
    ///
    /// The executor also uses it as the entry name in [`CompensationReport`].
    fn description(&self) -> String;

    /// Get the operation type (for categorization); defaults to
    /// [`SagaOperationType::Generic`].
    fn operation_type(&self) -> SagaOperationType {
        SagaOperationType::Generic
    }
}
57
/// Result of an operation, needed for compensation.
///
/// Produced by [`CompensableOperation::execute`] and passed back to
/// [`CompensableOperation::compensate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OperationResult {
    /// Unique identifier for this operation.
    pub operation_id: String,
    /// Whether the operation succeeded.
    pub success: bool,
    /// State captured before the operation (for rollback).
    ///
    /// Excluded from serialization (`#[serde(skip)]`), so a deserialized result always has
    /// `None` here.
    #[serde(skip)]
    pub checkpoint: Option<Checkpoint>,
    /// Operation-specific metadata needed for compensation.
    ///
    /// The `success` and `failure` constructors leave it `Null`.
    pub compensation_data: serde_json::Value,
    /// Output from the operation
    pub output: Option<String>,
}
73
74impl OperationResult {
75    /// Create a successful operation result.
76    pub fn success(operation_id: impl Into<String>) -> Self {
77        Self {
78            operation_id: operation_id.into(),
79            success: true,
80            checkpoint: None,
81            compensation_data: serde_json::Value::Null,
82            output: None,
83        }
84    }
85
86    /// Create a successful operation result with compensation data.
87    pub fn success_with_data(operation_id: impl Into<String>, data: serde_json::Value) -> Self {
88        Self {
89            operation_id: operation_id.into(),
90            success: true,
91            checkpoint: None,
92            compensation_data: data,
93            output: None,
94        }
95    }
96
97    /// Create a failed operation result.
98    pub fn failure(operation_id: impl Into<String>) -> Self {
99        Self {
100            operation_id: operation_id.into(),
101            success: false,
102            checkpoint: None,
103            compensation_data: serde_json::Value::Null,
104            output: None,
105        }
106    }
107
108    /// Attach a checkpoint for rollback.
109    pub fn with_checkpoint(mut self, checkpoint: Checkpoint) -> Self {
110        self.checkpoint = Some(checkpoint);
111        self
112    }
113
114    /// Attach output text to the result.
115    pub fn with_output(mut self, output: impl Into<String>) -> Self {
116        self.output = Some(output.into());
117        self
118    }
119}
120
/// Checkpoint for state restoration.
///
/// Only `Debug` and `Clone` are derived, so a checkpoint cannot be serialized.
/// [`OperationResult`] marks its `checkpoint` field `#[serde(skip)]` accordingly.
#[derive(Debug, Clone)]
pub struct Checkpoint {
    /// Checkpoint identifier.
    pub id: String,
    /// When the checkpoint was created (a process-local monotonic `Instant`).
    pub timestamp: Instant,
    /// File states before modification
    pub file_states: Vec<FileState>,
    /// Git state before modification
    pub git_state: Option<GitCheckpoint>,
}
133
134impl Checkpoint {
135    /// Create a new checkpoint with the given identifier.
136    pub fn new(id: impl Into<String>) -> Self {
137        Self {
138            id: id.into(),
139            timestamp: Instant::now(),
140            file_states: Vec::new(),
141            git_state: None,
142        }
143    }
144
145    /// Add a file state to the checkpoint.
146    pub fn with_file(mut self, path: PathBuf, content: String) -> Self {
147        self.file_states.push(FileState {
148            path,
149            content_hash: Self::hash_content(&content),
150            original_content: Some(content),
151        });
152        self
153    }
154
155    /// Set all file states at once.
156    pub fn with_files(mut self, files: Vec<FileState>) -> Self {
157        self.file_states = files;
158        self
159    }
160
161    /// Attach a git checkpoint.
162    pub fn with_git(mut self, git_state: GitCheckpoint) -> Self {
163        self.git_state = Some(git_state);
164        self
165    }
166
167    fn hash_content(content: &str) -> String {
168        use std::collections::hash_map::DefaultHasher;
169        use std::hash::{Hash, Hasher};
170
171        let mut hasher = DefaultHasher::new();
172        content.hash(&mut hasher);
173        format!("{:x}", hasher.finish())
174    }
175}
176
/// Snapshot of a file's state for restoration.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileState {
    /// File path.
    pub path: PathBuf,
    /// Hash of the file content.
    ///
    /// When built via `with_content`, this is the hex-encoded std `DefaultHasher` digest. That
    /// digest is not cryptographic, and std does not promise it is stable across Rust
    /// releases. `FileState::new` stores whatever hash the caller supplies.
    pub content_hash: String,
    /// Original content, used for direct restoration.
    ///
    /// `None` when only the hash was recorded (see `FileState::new`).
    pub original_content: Option<String>,
}
187
188impl FileState {
189    /// Create a file state with path and content hash.
190    pub fn new(path: PathBuf, content_hash: String) -> Self {
191        Self {
192            path,
193            content_hash,
194            original_content: None,
195        }
196    }
197
198    /// Create a file state with path and full content (auto-hashed).
199    pub fn with_content(path: PathBuf, content: String) -> Self {
200        use std::collections::hash_map::DefaultHasher;
201        use std::hash::{Hash, Hasher};
202
203        let mut hasher = DefaultHasher::new();
204        content.hash(&mut hasher);
205
206        Self {
207            path,
208            content_hash: format!("{:x}", hasher.finish()),
209            original_content: Some(content),
210        }
211    }
212}
213
/// Snapshot of git state for restoration.
///
/// NOTE(review): none of the operations defined in this module read this snapshot back;
/// restoration from it is up to the caller.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GitCheckpoint {
    /// HEAD commit hash.
    pub head_commit: String,
    /// List of staged files (empty unless set via `with_staged`).
    pub staged_files: Vec<String>,
    /// Current branch name.
    pub branch: String,
}
224
225impl GitCheckpoint {
226    /// Create a new git checkpoint.
227    pub fn new(head_commit: impl Into<String>, branch: impl Into<String>) -> Self {
228        Self {
229            head_commit: head_commit.into(),
230            staged_files: Vec::new(),
231            branch: branch.into(),
232        }
233    }
234
235    /// Set staged files for this checkpoint.
236    pub fn with_staged(mut self, files: Vec<String>) -> Self {
237        self.staged_files = files;
238        self
239    }
240}
241
/// Types of saga operations for categorization.
///
/// The type decides whether `SagaExecutor::compensate_all` calls an operation's `compensate`
/// (see `SagaOperationType::is_compensable`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum SagaOperationType {
    /// File write operation
    FileWrite,
    /// File edit operation
    FileEdit,
    /// File delete operation
    FileDelete,
    /// Git staging operation
    GitStage,
    /// Git unstaging operation
    GitUnstage,
    /// Git commit operation
    GitCommit,
    /// Git branch creation
    GitBranchCreate,
    /// Git branch deletion
    GitBranchDelete,
    /// Build operation (not compensated)
    Build,
    /// Test operation (not compensated)
    Test,
    /// Generic operation; the default for `CompensableOperation::operation_type`, and not
    /// compensated
    Generic,
}
268
269impl SagaOperationType {
270    /// Returns true if this operation type can be compensated
271    pub fn is_compensable(&self) -> bool {
272        match self {
273            SagaOperationType::FileWrite
274            | SagaOperationType::FileEdit
275            | SagaOperationType::FileDelete
276            | SagaOperationType::GitStage
277            | SagaOperationType::GitUnstage
278            | SagaOperationType::GitCommit
279            | SagaOperationType::GitBranchCreate
280            | SagaOperationType::GitBranchDelete => true,
281            // Build and test are idempotent, no compensation needed
282            SagaOperationType::Build | SagaOperationType::Test | SagaOperationType::Generic => {
283                false
284            }
285        }
286    }
287
288    /// Get the compensation description for this operation type
289    pub fn compensation_description(&self) -> &'static str {
290        match self {
291            SagaOperationType::FileWrite => "Delete written file or restore from backup",
292            SagaOperationType::FileEdit => "Restore original file content",
293            SagaOperationType::FileDelete => "Restore deleted file",
294            SagaOperationType::GitStage => "git reset HEAD <files>",
295            SagaOperationType::GitUnstage => "git add <files>",
296            SagaOperationType::GitCommit => "git reset --soft HEAD~1",
297            SagaOperationType::GitBranchCreate => "git branch -d <branch>",
298            SagaOperationType::GitBranchDelete => "Restore branch from reflog",
299            SagaOperationType::Build => "No compensation (idempotent)",
300            SagaOperationType::Test => "No compensation (idempotent)",
301            SagaOperationType::Generic => "Custom compensation",
302        }
303    }
304}
305
306/// Saga executor that manages compensating transactions
307pub struct SagaExecutor {
308    /// Unique saga identifier
309    pub saga_id: String,
310    /// Agent executing this saga
311    pub agent_id: String,
312    /// Description of the saga's purpose
313    pub description: String,
314    /// When the saga started
315    pub started_at: Instant,
316    /// Completed operations in execution order
317    completed_ops: RwLock<Vec<(Arc<dyn CompensableOperation>, OperationResult)>>,
318    /// Compensation callbacks
319    #[allow(clippy::type_complexity)]
320    compensation_hooks: RwLock<Vec<Box<dyn Fn(&str, bool) + Send + Sync>>>,
321    /// Current status
322    status: RwLock<SagaStatus>,
323}
324
/// Current status of a saga execution.
///
/// A new saga starts `Running`, and only a `Running` saga accepts new steps
/// (`SagaExecutor::execute_step` rejects steps otherwise).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SagaStatus {
    /// Saga is in progress and accepts new steps
    Running,
    /// Saga completed successfully (set by `SagaExecutor::complete`)
    Completed,
    /// Saga failed and compensation is needed.
    ///
    /// Set by `SagaExecutor::fail`, or when a step reports `success: false`.
    Failed,
    /// Compensation is in progress
    Compensating,
    /// Compensation finished.
    ///
    /// Individual compensations may still have failed; check the `CompensationReport`.
    Compensated,
}
339
340impl SagaExecutor {
341    /// Create a new saga executor for the given agent and description.
342    pub fn new(agent_id: impl Into<String>, description: impl Into<String>) -> Self {
343        let agent_id = agent_id.into();
344        let description = description.into();
345        let saga_id = format!(
346            "saga-{}-{}",
347            agent_id,
348            std::time::SystemTime::now()
349                .duration_since(std::time::UNIX_EPOCH)
350                .expect("system clock before UNIX epoch")
351                .as_millis()
352        );
353
354        Self {
355            saga_id,
356            agent_id,
357            description,
358            started_at: Instant::now(),
359            completed_ops: RwLock::new(Vec::new()),
360            compensation_hooks: RwLock::new(Vec::new()),
361            status: RwLock::new(SagaStatus::Running),
362        }
363    }
364
365    /// Get current status
366    pub async fn status(&self) -> SagaStatus {
367        *self.status.read().await
368    }
369
370    /// Get number of completed operations
371    pub async fn operation_count(&self) -> usize {
372        self.completed_ops.read().await.len()
373    }
374
375    /// Execute an operation within the saga
376    pub async fn execute_step(&self, op: Arc<dyn CompensableOperation>) -> Result<OperationResult> {
377        // Check if saga is still running
378        if *self.status.read().await != SagaStatus::Running {
379            anyhow::bail!("Cannot execute step: saga is not running");
380        }
381
382        tracing::debug!(
383            saga_id = %self.saga_id,
384            operation = %op.description(),
385            "Executing saga step"
386        );
387
388        let result = op.execute().await?;
389
390        if result.success {
391            self.completed_ops.write().await.push((op, result.clone()));
392            tracing::debug!(
393                saga_id = %self.saga_id,
394                "Saga step completed successfully"
395            );
396        } else {
397            *self.status.write().await = SagaStatus::Failed;
398            tracing::warn!(
399                saga_id = %self.saga_id,
400                "Saga step failed"
401            );
402        }
403
404        Ok(result)
405    }
406
407    /// Mark the saga as successfully completed
408    pub async fn complete(&self) {
409        *self.status.write().await = SagaStatus::Completed;
410        tracing::info!(
411            saga_id = %self.saga_id,
412            operations = self.completed_ops.read().await.len(),
413            "Saga completed successfully"
414        );
415    }
416
417    /// Mark the saga as failed (triggers compensation need)
418    pub async fn fail(&self) {
419        *self.status.write().await = SagaStatus::Failed;
420        tracing::warn!(
421            saga_id = %self.saga_id,
422            operations = self.completed_ops.read().await.len(),
423            "Saga marked as failed"
424        );
425    }
426
427    /// Compensate all completed operations in reverse order
428    pub async fn compensate_all(&self) -> Result<CompensationReport> {
429        *self.status.write().await = SagaStatus::Compensating;
430
431        let mut report = CompensationReport::new(self.saga_id.clone());
432
433        tracing::info!(
434            saga_id = %self.saga_id,
435            "Starting saga compensation"
436        );
437
438        // Pop operations in reverse order
439        while let Some((op, result)) = self.completed_ops.write().await.pop() {
440            let description = op.description();
441
442            // Skip non-compensable operations
443            if !op.operation_type().is_compensable() {
444                tracing::debug!(
445                    operation = %description,
446                    "Skipping non-compensable operation"
447                );
448                report.add_skipped(&description, "Non-compensable operation type");
449                continue;
450            }
451
452            tracing::debug!(
453                operation = %description,
454                "Compensating operation"
455            );
456
457            match op.compensate(&result).await {
458                Ok(()) => {
459                    report.add_success(&description);
460                    tracing::debug!(
461                        operation = %description,
462                        "Compensation successful"
463                    );
464                }
465                Err(e) => {
466                    let error_msg = e.to_string();
467                    report.add_failure(&description, error_msg.clone());
468                    tracing::error!(
469                        operation = %description,
470                        error = %error_msg,
471                        "Compensation failed"
472                    );
473                    // Continue compensating other operations even if one fails
474                }
475            }
476        }
477
478        *self.status.write().await = SagaStatus::Compensated;
479
480        // Call compensation hooks
481        let hooks = self.compensation_hooks.read().await;
482        let summary = report.summary();
483        let all_successful = report.all_successful();
484        for hook in hooks.iter() {
485            hook(&summary, all_successful);
486        }
487
488        tracing::info!(
489            saga_id = %self.saga_id,
490            summary = %summary,
491            "Saga compensation completed"
492        );
493
494        Ok(report)
495    }
496
497    /// Add a hook called after compensation
498    pub async fn on_compensation<F>(&self, hook: F)
499    where
500        F: Fn(&str, bool) + Send + Sync + 'static,
501    {
502        self.compensation_hooks.write().await.push(Box::new(hook));
503    }
504
505    /// Get descriptions of all completed operations
506    pub async fn get_operation_descriptions(&self) -> Vec<String> {
507        self.completed_ops
508            .read()
509            .await
510            .iter()
511            .map(|(op, _)| op.description())
512            .collect()
513    }
514}
515
/// Report of compensation actions, one entry per operation considered, in compensation order.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompensationReport {
    /// Saga this report belongs to.
    pub saga_id: String,
    /// Status of each compensation action.
    pub operations: Vec<CompensationStatus>,
    /// When the report was created (epoch millis).
    pub started_at: u64,
    /// When compensation finished (epoch millis).
    ///
    /// `None` until `CompensationReport::mark_completed` is called.
    pub completed_at: Option<u64>,
}
528
529impl CompensationReport {
530    /// Create a new empty compensation report.
531    pub fn new(saga_id: String) -> Self {
532        Self {
533            saga_id,
534            operations: Vec::new(),
535            started_at: std::time::SystemTime::now()
536                .duration_since(std::time::UNIX_EPOCH)
537                .expect("system clock before UNIX epoch")
538                .as_millis() as u64,
539            completed_at: None,
540        }
541    }
542
543    /// Record a successful compensation action.
544    pub fn add_success(&mut self, description: &str) {
545        self.operations.push(CompensationStatus {
546            description: description.to_string(),
547            status: CompensationOutcome::Success,
548            error: None,
549        });
550    }
551
552    /// Record a failed compensation action.
553    pub fn add_failure(&mut self, description: &str, error: String) {
554        self.operations.push(CompensationStatus {
555            description: description.to_string(),
556            status: CompensationOutcome::Failed,
557            error: Some(error),
558        });
559    }
560
561    /// Record a skipped compensation action.
562    pub fn add_skipped(&mut self, description: &str, reason: &str) {
563        self.operations.push(CompensationStatus {
564            description: description.to_string(),
565            status: CompensationOutcome::Skipped,
566            error: Some(reason.to_string()),
567        });
568    }
569
570    /// Returns true if all compensations succeeded or were skipped.
571    pub fn all_successful(&self) -> bool {
572        self.operations.iter().all(|s| {
573            matches!(
574                s.status,
575                CompensationOutcome::Success | CompensationOutcome::Skipped
576            )
577        })
578    }
579
580    /// Generate a human-readable summary of compensation outcomes.
581    pub fn summary(&self) -> String {
582        let successful = self
583            .operations
584            .iter()
585            .filter(|s| s.status == CompensationOutcome::Success)
586            .count();
587        let failed = self
588            .operations
589            .iter()
590            .filter(|s| s.status == CompensationOutcome::Failed)
591            .count();
592        let skipped = self
593            .operations
594            .iter()
595            .filter(|s| s.status == CompensationOutcome::Skipped)
596            .count();
597
598        format!(
599            "{} successful, {} failed, {} skipped (total: {})",
600            successful,
601            failed,
602            skipped,
603            self.operations.len()
604        )
605    }
606
607    /// Mark the compensation report as completed with a timestamp.
608    pub fn mark_completed(&mut self) {
609        self.completed_at = Some(
610            std::time::SystemTime::now()
611                .duration_since(std::time::UNIX_EPOCH)
612                .expect("system clock before UNIX epoch")
613                .as_millis() as u64,
614        );
615    }
616}
617
/// Status of a single compensation action.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompensationStatus {
    /// Description of the compensated operation.
    pub description: String,
    /// Outcome of the compensation attempt.
    pub status: CompensationOutcome,
    /// Error message if compensation failed, or the skip reason if it was skipped.
    ///
    /// `None` on success.
    pub error: Option<String>,
}
628
/// Outcome of a compensation attempt.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum CompensationOutcome {
    /// Compensation succeeded.
    Success,
    /// Compensation failed.
    Failed,
    /// Compensation was skipped (e.g. the operation type is not compensable).
    Skipped,
}
639
640// =============================================================================
641// Common Compensable Operations
642// =============================================================================
643
/// File write operation with compensation.
///
/// The `CompensableOperation` impl is only compiled with the `native` feature.
pub struct FileWriteOp {
    /// Path to the file being written.
    pub path: PathBuf,
    /// Content to write.
    pub content: String,
    /// Whether the caller expects this write to create a new file rather than overwrite one.
    ///
    /// Caller-supplied; `execute` does not check it against the filesystem.
    pub is_new_file: bool,
}
653
#[cfg(feature = "native")]
#[async_trait]
impl CompensableOperation for FileWriteOp {
    /// Write `content` to `path`, first capturing any existing content in a checkpoint.
    ///
    /// A checkpoint is attached only when the file already existed. An existing file that is
    /// not valid UTF-8 (or is unreadable) makes the step fail before anything is written.
    async fn execute(&self) -> Result<OperationResult> {
        // Read directly instead of `path.exists()` followed by a read. This avoids a blocking
        // stat on the async runtime and a race between the check and the read.
        let checkpoint = match tokio::fs::read_to_string(&self.path).await {
            Ok(original_content) => Some(
                Checkpoint::new(format!("file-write-{}", self.path.display()))
                    .with_file(self.path.clone(), original_content),
            ),
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => None,
            Err(err) => return Err(err.into()),
        };

        // Write new content
        tokio::fs::write(&self.path, &self.content).await?;

        let mut result = OperationResult::success_with_data(
            format!("file-write-{}", self.path.display()),
            serde_json::json!({
                "path": self.path.display().to_string(),
                "is_new_file": self.is_new_file,
            }),
        );

        if let Some(cp) = checkpoint {
            result = result.with_checkpoint(cp);
        }

        Ok(result)
    }

    /// Undo the write.
    ///
    /// If the result's checkpoint holds the original content, that content is written back.
    /// This applies even when `is_new_file` is set, so a file that unexpectedly already
    /// existed is restored rather than deleted. Otherwise, for a new file, the file is removed;
    /// an already-missing file is not an error.
    async fn compensate(&self, result: &OperationResult) -> Result<()> {
        let original_content = result
            .checkpoint
            .as_ref()
            .and_then(|checkpoint| checkpoint.file_states.first())
            .and_then(|file_state| file_state.original_content.as_ref());

        if let Some(original_content) = original_content {
            tokio::fs::write(&self.path, original_content).await?;
        } else if self.is_new_file {
            match tokio::fs::remove_file(&self.path).await {
                Ok(()) => {}
                Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
                Err(err) => return Err(err.into()),
            }
        }
        Ok(())
    }

    fn description(&self) -> String {
        format!(
            "Write file: {} ({})",
            self.path.display(),
            if self.is_new_file { "new" } else { "existing" }
        )
    }

    fn operation_type(&self) -> SagaOperationType {
        SagaOperationType::FileWrite
    }
}
716
/// File edit operation with compensation.
///
/// The `CompensableOperation` impl is only compiled with the `native` feature.
pub struct FileEditOp {
    /// Path to the file being edited.
    pub path: PathBuf,
    /// Original file content before the edit; this is the text restored on compensation.
    ///
    /// NOTE(review): `execute` does not verify that the file currently holds this text.
    pub old_content: String,
    /// New file content after edit.
    pub new_content: String,
}
726
#[cfg(feature = "native")]
#[async_trait]
impl CompensableOperation for FileEditOp {
    /// Overwrite `path` with `new_content`, recording `old_content` in a checkpoint.
    ///
    /// NOTE(review): the file's current content is not compared against `old_content`. If they
    /// differ, compensation restores `old_content`, not what was actually on disk.
    async fn execute(&self) -> Result<OperationResult> {
        let checkpoint = Checkpoint::new(format!("file-edit-{}", self.path.display()))
            .with_file(self.path.clone(), self.old_content.clone());

        // Write new content
        tokio::fs::write(&self.path, &self.new_content).await?;

        Ok(
            OperationResult::success(format!("file-edit-{}", self.path.display()))
                .with_checkpoint(checkpoint),
        )
    }

    /// Restore the pre-edit content.
    ///
    /// Uses the checkpoint from `result` when present. The checkpoint is `#[serde(skip)]`, so a
    /// result that was serialized has none. In that case this falls back to `old_content`,
    /// which is the same text `execute` stored in the checkpoint. Previously that case silently
    /// did nothing.
    async fn compensate(&self, result: &OperationResult) -> Result<()> {
        let original_content = result
            .checkpoint
            .as_ref()
            .and_then(|checkpoint| checkpoint.file_states.first())
            .and_then(|file_state| file_state.original_content.as_deref())
            .unwrap_or(self.old_content.as_str());
        tokio::fs::write(&self.path, original_content).await?;
        Ok(())
    }

    fn description(&self) -> String {
        format!("Edit file: {}", self.path.display())
    }

    fn operation_type(&self) -> SagaOperationType {
        SagaOperationType::FileEdit
    }
}
761
/// Git stage operation with compensation.
///
/// The `CompensableOperation` impl is only compiled with the `native` feature.
pub struct GitStageOp {
    /// Files to stage, passed to git as paths (presumably relative to `working_dir` or
    /// absolute; verify against callers).
    pub files: Vec<PathBuf>,
    /// Working directory for the git command.
    pub working_dir: PathBuf,
}
769
/// Run `git` with `args` in `working_dir`.
///
/// A non-zero exit status becomes an error prefixed with `what` and carrying git's stderr.
#[cfg(feature = "native")]
async fn run_git_checked(
    working_dir: &std::path::Path,
    args: &[&std::ffi::OsStr],
    what: &str,
) -> Result<()> {
    let output = tokio::process::Command::new("git")
        .args(args)
        .current_dir(working_dir)
        .output()
        .await?;
    if !output.status.success() {
        anyhow::bail!(
            "{} failed: {}",
            what,
            String::from_utf8_lossy(&output.stderr).trim()
        );
    }
    Ok(())
}

#[cfg(feature = "native")]
#[async_trait]
impl CompensableOperation for GitStageOp {
    /// Stage `files` with a single `git add -- <files>` in `working_dir`.
    ///
    /// Paths are passed as raw `OsStr`s after `--`, so non-UTF-8 names and names starting
    /// with `-` are not misread. A failing git invocation is reported as an error; previously
    /// the exit status was ignored. With no files, git is not invoked.
    async fn execute(&self) -> Result<OperationResult> {
        if !self.files.is_empty() {
            let mut args: Vec<&std::ffi::OsStr> =
                vec![std::ffi::OsStr::new("add"), std::ffi::OsStr::new("--")];
            args.extend(self.files.iter().map(|file| file.as_os_str()));
            run_git_checked(&self.working_dir, &args, "git add").await?;
        }

        Ok(OperationResult::success_with_data(
            "git-stage",
            serde_json::json!({
                "files": self.files.iter().map(|f| f.display().to_string()).collect::<Vec<_>>(),
            }),
        ))
    }

    /// Unstage `files` with `git reset -q HEAD -- <files>`.
    ///
    /// A failing git invocation is reported as an error (e.g. in a repository without a HEAD
    /// commit). NOTE(review): files that were already staged before this step are unstaged
    /// too.
    async fn compensate(&self, _result: &OperationResult) -> Result<()> {
        if self.files.is_empty() {
            return Ok(());
        }
        let mut args: Vec<&std::ffi::OsStr> = vec![
            std::ffi::OsStr::new("reset"),
            std::ffi::OsStr::new("-q"),
            std::ffi::OsStr::new("HEAD"),
            std::ffi::OsStr::new("--"),
        ];
        args.extend(self.files.iter().map(|file| file.as_os_str()));
        run_git_checked(&self.working_dir, &args, "git reset").await
    }

    fn description(&self) -> String {
        format!("Git stage: {} files", self.files.len())
    }

    fn operation_type(&self) -> SagaOperationType {
        SagaOperationType::GitStage
    }
}
813
/// Git commit operation with compensation.
///
/// The `CompensableOperation` impl is only compiled with the `native` feature.
pub struct GitCommitOp {
    /// Commit message (passed to `git commit -m`).
    pub message: String,
    /// Working directory for the git command.
    pub working_dir: PathBuf,
}
821
#[cfg(feature = "native")]
#[async_trait]
impl CompensableOperation for GitCommitOp {
    /// Commit the currently staged changes with `message`.
    ///
    /// The pre-commit HEAD is recorded as `previous_head` in the compensation data. It is
    /// `null` when `git rev-parse HEAD` fails, e.g. in a repository with no commits yet.
    ///
    /// # Errors
    ///
    /// Fails if `git commit` exits non-zero, with git's stderr in the message.
    async fn execute(&self) -> Result<OperationResult> {
        use tokio::process::Command;

        // Get current HEAD before commit
        let head_output = Command::new("git")
            .args(["rev-parse", "HEAD"])
            .current_dir(&self.working_dir)
            .output()
            .await?;
        // Only trust stdout when rev-parse succeeded; otherwise it is not a commit id.
        let previous_head = head_output
            .status
            .success()
            .then(|| {
                String::from_utf8_lossy(&head_output.stdout)
                    .trim()
                    .to_string()
            })
            .filter(|head| !head.is_empty());

        // Perform commit
        let output = Command::new("git")
            .args(["commit", "-m", &self.message])
            .current_dir(&self.working_dir)
            .output()
            .await?;

        if !output.status.success() {
            anyhow::bail!(
                "Git commit failed: {}",
                String::from_utf8_lossy(&output.stderr)
            );
        }

        Ok(OperationResult::success_with_data(
            "git-commit",
            serde_json::json!({
                "previous_head": previous_head,
                "message": self.message,
            }),
        ))
    }

    /// Undo the commit with a soft reset, keeping its changes staged.
    ///
    /// Resets to the recorded `previous_head` when available, otherwise to `HEAD~1`.
    ///
    /// # Errors
    ///
    /// Fails if `git reset` exits non-zero; previously its exit status was ignored.
    async fn compensate(&self, result: &OperationResult) -> Result<()> {
        use tokio::process::Command;

        // Resetting to the recorded parent returns HEAD exactly to its pre-step position.
        // `HEAD~1` only drops whichever commit is currently newest, which may not be ours if
        // HEAD has moved since.
        let target = result
            .compensation_data
            .get("previous_head")
            .and_then(serde_json::Value::as_str)
            .filter(|head| !head.is_empty())
            .unwrap_or("HEAD~1");

        let output = Command::new("git")
            .args(["reset", "--soft", target])
            .current_dir(&self.working_dir)
            .output()
            .await?;

        if !output.status.success() {
            anyhow::bail!(
                "Git reset failed: {}",
                String::from_utf8_lossy(&output.stderr)
            );
        }
        Ok(())
    }

    fn description(&self) -> String {
        format!("Git commit: {}", self.message)
    }

    fn operation_type(&self) -> SagaOperationType {
        SagaOperationType::GitCommit
    }
}
881
/// No-op compensable operation (for operations that don't need compensation).
///
/// Both `execute` and `compensate` always succeed without side effects. `op_type` decides
/// whether `SagaExecutor::compensate_all` reports the step as compensated (`Success`) or
/// `Skipped`.
pub struct NoOpCompensation {
    /// Description of the operation (also used as its operation id).
    pub description: String,
    /// Type of the operation.
    pub op_type: SagaOperationType,
}
889
890#[async_trait]
891impl CompensableOperation for NoOpCompensation {
892    async fn execute(&self) -> Result<OperationResult> {
893        Ok(OperationResult::success(&self.description))
894    }
895
896    async fn compensate(&self, _result: &OperationResult) -> Result<()> {
897        // No compensation needed
898        Ok(())
899    }
900
901    fn description(&self) -> String {
902        self.description.clone()
903    }
904
905    fn operation_type(&self) -> SagaOperationType {
906        self.op_type
907    }
908}
909
#[cfg(test)]
mod tests {
    use super::*;
    // Only the file-operation tests need a temp dir, and they require the `native` impls.
    #[cfg(feature = "native")]
    use tempfile::tempdir;

    #[tokio::test]
    async fn test_saga_executor_basic() {
        let saga = SagaExecutor::new("test-agent", "test saga");

        assert_eq!(saga.status().await, SagaStatus::Running);
        assert_eq!(saga.operation_count().await, 0);
    }

    #[tokio::test]
    async fn test_saga_execute_and_complete() {
        let saga = SagaExecutor::new("test-agent", "test saga");

        let op = Arc::new(NoOpCompensation {
            description: "test op".to_string(),
            op_type: SagaOperationType::Generic,
        });

        let result = saga.execute_step(op).await.unwrap();
        assert!(result.success);
        assert_eq!(saga.operation_count().await, 1);

        saga.complete().await;
        assert_eq!(saga.status().await, SagaStatus::Completed);
    }

    #[tokio::test]
    async fn test_saga_compensation() {
        let saga = SagaExecutor::new("test-agent", "test saga");

        // Execute a compensable operation
        let op = Arc::new(NoOpCompensation {
            description: "compensable op".to_string(),
            op_type: SagaOperationType::FileWrite,
        });

        saga.execute_step(op).await.unwrap();

        // Trigger compensation
        saga.fail().await;
        let report = saga.compensate_all().await.unwrap();

        assert_eq!(saga.status().await, SagaStatus::Compensated);
        assert_eq!(report.operations.len(), 1);
        assert_eq!(report.operations[0].status, CompensationOutcome::Success);
        assert!(report.all_successful());
        assert_eq!(saga.operation_count().await, 0);
    }

    #[tokio::test]
    async fn test_non_compensable_operation_is_skipped() {
        let saga = SagaExecutor::new("test-agent", "test saga");

        let op = Arc::new(NoOpCompensation {
            description: "build".to_string(),
            op_type: SagaOperationType::Build,
        });
        saga.execute_step(op).await.unwrap();

        let report = saga.compensate_all().await.unwrap();
        assert_eq!(report.operations.len(), 1);
        assert_eq!(report.operations[0].status, CompensationOutcome::Skipped);
        assert!(report.all_successful());
    }

    #[tokio::test]
    async fn test_step_rejected_when_not_running() {
        let saga = SagaExecutor::new("test-agent", "test saga");
        saga.fail().await;

        let op = Arc::new(NoOpCompensation {
            description: "late op".to_string(),
            op_type: SagaOperationType::Generic,
        });
        assert!(saga.execute_step(op).await.is_err());
        assert_eq!(saga.operation_count().await, 0);
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn test_file_write_op_compensation() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.txt");

        // Create initial file
        std::fs::write(&file_path, "original content").unwrap();

        // Create file write operation
        let op = FileWriteOp {
            path: file_path.clone(),
            content: "new content".to_string(),
            is_new_file: false,
        };

        // Execute
        let result = op.execute().await.unwrap();
        assert!(result.success);

        // Verify new content
        let content = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "new content");

        // Compensate
        op.compensate(&result).await.unwrap();

        // Verify original content restored
        let content = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "original content");
    }

    #[cfg(feature = "native")]
    #[tokio::test]
    async fn test_file_write_new_file_compensation() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("new_file.txt");

        // Create file write operation for new file
        let op = FileWriteOp {
            path: file_path.clone(),
            content: "new content".to_string(),
            is_new_file: true,
        };

        // Execute
        let result = op.execute().await.unwrap();
        assert!(result.success);
        assert!(file_path.exists());

        // Compensate (should delete the file)
        op.compensate(&result).await.unwrap();
        assert!(!file_path.exists());
    }

    #[test]
    fn test_compensation_report() {
        let mut report = CompensationReport::new("test-saga".to_string());

        report.add_success("op1");
        report.add_failure("op2", "error".to_string());
        report.add_skipped("op3", "non-compensable");

        assert!(!report.all_successful());
        assert!(report.summary().contains("1 successful"));
        assert!(report.summary().contains("1 failed"));
        assert!(report.summary().contains("1 skipped"));
    }

    #[test]
    fn test_operation_type_compensable() {
        assert!(SagaOperationType::FileWrite.is_compensable());
        assert!(SagaOperationType::FileEdit.is_compensable());
        assert!(SagaOperationType::GitStage.is_compensable());
        assert!(SagaOperationType::GitCommit.is_compensable());

        assert!(!SagaOperationType::Build.is_compensable());
        assert!(!SagaOperationType::Test.is_compensable());
        assert!(!SagaOperationType::Generic.is_compensable());
    }

    #[test]
    fn test_checkpoint_creation() {
        let checkpoint = Checkpoint::new("test-cp")
            .with_file(PathBuf::from("/test/file.txt"), "content".to_string());

        assert_eq!(checkpoint.id, "test-cp");
        assert_eq!(checkpoint.file_states.len(), 1);
        assert_eq!(
            checkpoint.file_states[0].path,
            PathBuf::from("/test/file.txt")
        );
        assert!(checkpoint.file_states[0].original_content.is_some());
    }
}