use async_trait::async_trait;
use serde::{Deserialize, Serialize};
#[derive(Debug, thiserror::Error)]
pub enum OperationError {
#[error("Operation failed: {0}")]
Failed(String),
#[error("Rollback failed: {0}")]
RollbackFailed(String),
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
#[error("Serialization error: {0}")]
SerializationError(String),
#[error("State not found for rollback")]
StateNotFound,
}
/// Outcome reported by a single [`ReversibleOperation::execute`] call.
///
/// Field order is kept stable because it determines the serialized key order.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct OperationResult {
    /// Whether the operation succeeded; `OperationChain` treats `false` as a failure.
    pub success: bool,
    /// Name of the operation that produced this result.
    pub operation_name: String,
    /// Duration in milliseconds as supplied by the caller (`failed` records 0).
    pub duration_ms: u64,
    /// Optional human-readable message; for failures it becomes the error text.
    pub message: Option<String>,
    /// Optional JSON state attached via `with_state`.
    pub state: Option<serde_json::Value>,
    /// Arbitrary JSON payload; the constructors initialise it to an empty object.
    pub data: serde_json::Value,
}
impl OperationResult {
    /// Builds a successful result for `name` that took `duration_ms` milliseconds.
    ///
    /// `message` and `state` start as `None` and `data` as an empty JSON object.
    #[must_use]
    pub fn success(name: String, duration_ms: u64) -> Self {
        let empty_data = serde_json::Value::Object(serde_json::Map::new());
        Self {
            operation_name: name,
            success: true,
            duration_ms,
            message: None,
            state: None,
            data: empty_data,
        }
    }

    /// Builds a failed result for `name` carrying `message`, with a duration of 0.
    #[must_use]
    pub fn failed(name: String, message: String) -> Self {
        // Start from an empty success record (duration 0) and flip the failure-specific fields.
        Self {
            success: false,
            message: Some(message),
            ..Self::success(name, 0)
        }
    }

    /// Returns this result with `state` attached, replacing any previous state.
    #[must_use]
    pub fn with_state(self, state: serde_json::Value) -> Self {
        Self {
            state: Some(state),
            ..self
        }
    }

    /// Returns this result with its `data` payload replaced by `data`.
    #[must_use]
    pub fn with_data(self, data: serde_json::Value) -> Self {
        Self { data, ..self }
    }

    /// Returns this result with `message` set, replacing any previous message.
    #[must_use]
    pub fn with_message(self, message: String) -> Self {
        Self {
            message: Some(message),
            ..self
        }
    }
}
/// Snapshot taken by [`ReversibleOperation::capture_state`] and handed back to
/// [`ReversibleOperation::rollback`].
///
/// Field order is kept stable because it determines the serialized key order.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct OperationState {
    /// Name of the operation this snapshot belongs to.
    pub operation_name: String,
    /// Operation-specific JSON payload needed to undo the operation.
    pub data: serde_json::Value,
    /// UTC time at which the snapshot was created.
    pub captured_at: chrono::DateTime<chrono::Utc>,
}
impl OperationState {
    /// Creates a snapshot of `data` for `operation_name`, stamped with the current UTC time.
    #[must_use]
    pub fn new(operation_name: String, data: serde_json::Value) -> Self {
        let captured_at = chrono::Utc::now();
        Self {
            captured_at,
            operation_name,
            data,
        }
    }
}
/// An operation that can be executed and later undone from a previously captured snapshot.
///
/// Implementors must be `Send + Sync`.
#[async_trait]
pub trait ReversibleOperation: Send + Sync {
    /// Human-readable name of the operation.
    fn name(&self) -> &str;

    /// Reports whether the operation is idempotent (presumably: safe to run more than once).
    /// The default implementation returns `false`.
    fn is_idempotent(&self) -> bool {
        false
    }

    /// Takes a snapshot to be passed to `rollback` later.
    /// `OperationChain` calls this immediately before `execute`.
    async fn capture_state(&self) -> Result<OperationState, OperationError>;

    /// Performs the operation. A result whose `success` is `false` is treated as a failure
    /// by `OperationChain`.
    async fn execute(&self) -> Result<OperationResult, OperationError>;

    /// Undoes the operation using a snapshot previously returned by `capture_state`.
    async fn rollback(&self, state: &OperationState) -> Result<(), OperationError>;
}
/// Ordered sequence of [`ReversibleOperation`]s that is executed front to back and
/// rolled back in reverse order.
pub struct OperationChain {
    /// Operations in execution order.
    operations: Vec<Box<dyn ReversibleOperation>>,

    /// Snapshot per operation, index-aligned with `operations`; `None` when not captured.
    states: Vec<Option<OperationState>>,

    /// Length of the prefix of `operations` currently recorded as completed.
    completed: usize,
}
impl OperationChain {
    /// Creates an empty chain with no operations registered.
    #[must_use]
    pub fn new() -> Self {
        Self {
            operations: Vec::new(),
            states: Vec::new(),
            completed: 0,
        }
    }

    /// Appends `operation` to the end of the chain.
    ///
    /// Operations execute in insertion order and are rolled back in reverse order.
    pub fn add(&mut self, operation: Box<dyn ReversibleOperation>) {
        self.operations.push(operation);
        // Keep `states` index-aligned with `operations`.
        self.states.push(None);
    }

    /// Runs every operation in order, capturing each one's state immediately before it executes.
    ///
    /// On success, returns one [`OperationResult`] per operation, in insertion order.
    ///
    /// # Errors
    ///
    /// Any of the following counts as a failure of the operation at position `i`:
    /// - `capture_state` returns an error;
    /// - `execute` returns an error;
    /// - `execute` reports `success == false`, which is mapped to [`OperationError::Failed`]
    ///   carrying its message.
    ///
    /// Operations `0..i` are then rolled back in reverse order and the original error is
    /// returned. The failing operation itself is not rolled back. If that rollback fails,
    /// the rollback error is returned instead.
    pub async fn execute_all(&mut self) -> Result<Vec<OperationResult>, OperationError> {
        let mut results = Vec::with_capacity(self.operations.len());
        for index in 0..self.operations.len() {
            match self.run_step(index).await {
                Ok(result) => {
                    self.completed = index + 1;
                    results.push(result);
                }
                Err(err) => {
                    // Every failure kind goes through the same rollback, including errors
                    // returned by `execute` itself.
                    self.rollback_to(index).await?;
                    return Err(err);
                }
            }
        }
        Ok(results)
    }

    /// Captures state for, then executes, the operation at `index`, storing the captured
    /// state so a later rollback can use it.
    async fn run_step(&mut self, index: usize) -> Result<OperationResult, OperationError> {
        let op = &self.operations[index];
        let state = op.capture_state().await?;
        self.states[index] = Some(state);
        let result = op.execute().await?;
        if result.success {
            Ok(result)
        } else {
            Err(OperationError::Failed(result.message.unwrap_or_default()))
        }
    }

    /// Rolls back every operation currently recorded as completed, in reverse order, then
    /// clears all captured state.
    ///
    /// If `execute_all` failed and its automatic rollback succeeded, nothing is recorded
    /// as completed and this does nothing.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an operation's `rollback`. The chain keeps the
    /// operations not yet reverted, so calling this again resumes from the one that failed.
    pub async fn rollback_all(&mut self) -> Result<(), OperationError> {
        self.rollback_to(self.completed).await
    }

    /// Rolls back operations `0..index` in reverse order using their captured states.
    /// On success, resets `completed` to 0 and clears every captured state.
    ///
    /// Operations without a captured state are skipped.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by an operation's `rollback`. Progress is recorded
    /// per operation, so `completed` then points just past the operation whose rollback failed.
    async fn rollback_to(&mut self, index: usize) -> Result<(), OperationError> {
        for i in (0..index).rev() {
            if let Some(op) = self.operations.get(i)
                && let Some(state) = &self.states[i]
            {
                op.rollback(state).await?;
            }
            // Operation `i` is now reverted (or had nothing to revert). Forget its snapshot
            // and shrink the completed prefix so a retry after a later error skips it.
            self.states[i] = None;
            self.completed = i;
        }
        self.completed = 0;
        // Also drop snapshots at or past `index`, such as one captured by a failing step.
        self.states.fill(None);
        Ok(())
    }

    /// Number of operations registered in the chain.
    #[must_use]
    pub fn len(&self) -> usize {
        self.operations.len()
    }

    /// Returns `true` if no operations have been registered.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.operations.is_empty()
    }
}
impl Default for OperationChain {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    /// Test double whose `execute` succeeds or reports failure on demand, and which counts
    /// how many times it has been rolled back through a shared counter.
    struct TestOperation {
        name: String,
        should_fail: bool,
        rollbacks: Arc<AtomicUsize>,
    }

    impl TestOperation {
        /// Boxes a new test operation that increments `rollbacks` on every rollback.
        fn boxed(
            name: &str,
            should_fail: bool,
            rollbacks: &Arc<AtomicUsize>,
        ) -> Box<dyn ReversibleOperation> {
            Box::new(Self {
                name: name.to_string(),
                should_fail,
                rollbacks: Arc::clone(rollbacks),
            })
        }
    }

    #[async_trait]
    impl ReversibleOperation for TestOperation {
        fn name(&self) -> &str {
            &self.name
        }

        async fn execute(&self) -> Result<OperationResult, OperationError> {
            if self.should_fail {
                Ok(OperationResult::failed(
                    self.name.clone(),
                    "Test failure".to_string(),
                ))
            } else {
                Ok(OperationResult::success(self.name.clone(), 10))
            }
        }

        async fn rollback(&self, _state: &OperationState) -> Result<(), OperationError> {
            self.rollbacks.fetch_add(1, Ordering::SeqCst);
            Ok(())
        }

        async fn capture_state(&self) -> Result<OperationState, OperationError> {
            Ok(OperationState::new(
                self.name.clone(),
                serde_json::Value::Null,
            ))
        }
    }

    #[tokio::test]
    async fn test_operation_chain_success() {
        let rollbacks = Arc::new(AtomicUsize::new(0));
        let mut chain = OperationChain::new();
        chain.add(TestOperation::boxed("op1", false, &rollbacks));
        chain.add(TestOperation::boxed("op2", false, &rollbacks));

        let results = chain.execute_all().await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].success);
        assert!(results[1].success);
        assert_eq!(results[0].operation_name, "op1");
        assert_eq!(results[1].operation_name, "op2");
        assert_eq!(rollbacks.load(Ordering::SeqCst), 0);
    }

    #[tokio::test]
    async fn test_operation_chain_failure() {
        let rollbacks = Arc::new(AtomicUsize::new(0));
        let mut chain = OperationChain::new();
        chain.add(TestOperation::boxed("op1", false, &rollbacks));
        chain.add(TestOperation::boxed("op2", true, &rollbacks));

        let result = chain.execute_all().await;
        assert!(matches!(
            result,
            Err(OperationError::Failed(ref msg)) if msg == "Test failure"
        ));
        // Only op1 completed before op2 failed, so exactly one rollback runs.
        assert_eq!(rollbacks.load(Ordering::SeqCst), 1);
    }

    #[tokio::test]
    async fn test_rollback_all_after_success() {
        let rollbacks = Arc::new(AtomicUsize::new(0));
        let mut chain = OperationChain::new();
        chain.add(TestOperation::boxed("op1", false, &rollbacks));
        chain.add(TestOperation::boxed("op2", false, &rollbacks));

        chain.execute_all().await.unwrap();
        chain.rollback_all().await.unwrap();
        assert_eq!(rollbacks.load(Ordering::SeqCst), 2);

        // Captured state is cleared after a rollback, so a second one is a no-op.
        chain.rollback_all().await.unwrap();
        assert_eq!(rollbacks.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn test_empty_chain() {
        let mut chain = OperationChain::default();
        assert!(chain.is_empty());
        assert_eq!(chain.len(), 0);

        let results = chain.execute_all().await.unwrap();
        assert!(results.is_empty());
    }
}