use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use uuid::Uuid;
use crate::domain::version::SemanticVersion;
/// Errors produced by [`CheckpointStore`] implementations.
#[derive(Debug, thiserror::Error)]
pub enum StoreError {
/// A requested checkpoint does not exist.
/// NOTE(review): the payload is presumably the requested id — confirm with the backends that
/// construct it (the in-memory store in this file never does).
#[error("Checkpoint not found: {0}")]
NotFound(String),
/// A mutating operation failed. The in-memory store uses this for lock failures in
/// `save`, `delete` and `clear`; the payload is a human-readable cause.
#[error("Failed to save checkpoint: {0}")]
SaveFailed(String),
/// A read operation failed. The in-memory store uses this for lock failures in
/// `load`, `load_by_id` and `list`; the payload is a human-readable cause.
#[error("Failed to load checkpoint: {0}")]
LoadFailed(String),
/// Wraps an I/O failure; `#[from]` lets `?` convert a `std::io::Error` into this variant.
#[error("IO error: {0}")]
Io(#[from] std::io::Error),
/// Encoding or decoding a checkpoint failed; the payload is the underlying error message.
#[error("Serialization error: {0}")]
SerializationError(String),
}
/// Persistence backend for workflow [`Checkpoint`]s.
///
/// The `Send + Sync` bound allows one store to be shared between threads/tasks
/// (for example behind an `Arc`).
#[async_trait]
pub trait CheckpointStore: Send + Sync {
/// Persists `checkpoint`.
async fn save(&self, checkpoint: &Checkpoint) -> Result<(), StoreError>;
/// Loads the store's current checkpoint, or `None` if it holds none.
/// NOTE(review): which checkpoint counts as "current" is implementation-defined; the
/// in-memory store returns the last entry of its list.
async fn load(&self) -> Result<Option<Checkpoint>, StoreError>;
/// Loads the checkpoint whose `id` equals `id`, or `None` if there is no such checkpoint.
async fn load_by_id(&self, id: &str) -> Result<Option<Checkpoint>, StoreError>;
/// Returns a [`CheckpointInfo`] summary for every stored checkpoint.
async fn list(&self) -> Result<Vec<CheckpointInfo>, StoreError>;
/// Removes the checkpoint(s) with the given `id`.
/// NOTE(review): whether a missing id is an error is implementation-defined; the in-memory
/// store treats it as a no-op.
async fn delete(&self, id: &str) -> Result<(), StoreError>;
/// Removes every stored checkpoint.
async fn clear(&self) -> Result<(), StoreError>;
/// Short identifier for this backend (the in-memory store returns `"memory"`).
fn name(&self) -> &str;
}
/// Snapshot of a workflow's progress that can be persisted by a [`CheckpointStore`].
///
/// Serde serializes fields in declaration order; keep the order stable if any persisted
/// format depends on it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// Unique identifier; `Checkpoint::new` generates `chk_` followed by a simple-format UUID v4.
pub id: String,
/// Identifier of the workflow this checkpoint belongs to.
pub workflow_id: String,
/// Step position recorded with the checkpoint.
/// NOTE(review): presumably the index of the next step to run — confirm with callers.
pub step_index: usize,
/// Identifiers of the steps already completed when the checkpoint was taken.
pub completed_steps: Vec<String>,
/// Arbitrary JSON state captured with the checkpoint.
/// NOTE(review): possibly a serialized `WorkflowState` — confirm with callers.
pub state: serde_json::Value,
/// Version the workflow is targeting, if one was attached via `with_version`.
pub target_version: Option<SemanticVersion>,
/// UTC time the checkpoint was created (`Checkpoint::new` sets it to `Utc::now()`).
pub timestamp: chrono::DateTime<chrono::Utc>,
/// `true` once `with_error` has been applied.
pub interrupted_by_error: bool,
/// Error message recorded by `with_error`, if any.
pub error_message: Option<String>,
}
impl Checkpoint {
    /// Creates a checkpoint for `workflow_id` with a freshly generated `chk_`-prefixed id,
    /// stamped with the current UTC time, with no target version and no error recorded.
    #[must_use]
    pub fn new(
        workflow_id: String,
        step_index: usize,
        completed_steps: Vec<String>,
        state: serde_json::Value,
    ) -> Self {
        let generated_id = format!("chk_{}", Uuid::new_v4().simple());
        let created_at = chrono::Utc::now();
        Self {
            id: generated_id,
            workflow_id,
            step_index,
            completed_steps,
            state,
            target_version: None,
            timestamp: created_at,
            interrupted_by_error: false,
            error_message: None,
        }
    }

    /// Borrows this checkpoint's identifier (the same value as the public `id` field).
    #[must_use]
    pub fn id(&self) -> &str {
        self.id.as_str()
    }

    /// Reports whether the workflow may be resumed from this checkpoint.
    ///
    /// The only non-resumable case is a checkpoint that was interrupted by an error after at
    /// least one step had completed; every other checkpoint returns `true`.
    /// NOTE(review): confirm this is the intended policy — it forbids resuming exactly the
    /// partially-completed, failed runs.
    #[must_use]
    pub const fn can_resume(&self) -> bool {
        let failed_after_progress =
            self.interrupted_by_error && !self.completed_steps.is_empty();
        !failed_after_progress
    }

    /// Marks the checkpoint as interrupted by an error and records `message`.
    #[must_use]
    pub fn with_error(self, message: String) -> Self {
        Self {
            interrupted_by_error: true,
            error_message: Some(message),
            ..self
        }
    }

    /// Attaches the version the workflow is targeting.
    #[must_use]
    pub fn with_version(self, version: SemanticVersion) -> Self {
        Self {
            target_version: Some(version),
            ..self
        }
    }
}
/// Lightweight summary of a [`Checkpoint`] without its state payload, as returned by
/// [`CheckpointStore::list`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointInfo {
/// Identifier of the summarized checkpoint.
pub id: String,
/// Workflow the checkpoint belongs to.
pub workflow_id: String,
/// Step position recorded in the checkpoint.
pub step_index: usize,
/// Number of entries in the checkpoint's `completed_steps`.
pub completed_count: usize,
/// Creation time of the checkpoint.
pub timestamp: chrono::DateTime<chrono::Utc>,
/// Copy of the checkpoint's `interrupted_by_error` flag.
pub has_error: bool,
}
impl From<Checkpoint> for CheckpointInfo {
fn from(checkpoint: Checkpoint) -> Self {
Self {
id: checkpoint.id,
workflow_id: checkpoint.workflow_id,
step_index: checkpoint.step_index,
completed_count: checkpoint.completed_steps.len(),
timestamp: checkpoint.timestamp,
has_error: checkpoint.interrupted_by_error,
}
}
}
/// Accumulated progress of a release workflow: crate versions, publish outcomes and
/// associated VCS metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WorkflowState {
/// Version recorded per crate name via `add_crate_version` (later calls overwrite earlier ones).
pub crate_versions: HashMap<String, SemanticVersion>,
/// Crate names passed to `mark_published`, in call order; duplicates are not filtered.
pub published_crates: Vec<String>,
/// Crate names passed to `mark_failed`, in call order; duplicates are not filtered.
pub failed_crates: Vec<String>,
/// Commit hash set via `with_commit_hash`, if any.
/// NOTE(review): presumably the release commit — confirm.
pub commit_hash: Option<String>,
/// Tag set via `with_tag`, if any.
pub tag: Option<String>,
/// Free-form JSON values keyed by name, added via `with_extra`.
pub extra_data: HashMap<String, serde_json::Value>,
}
impl WorkflowState {
    /// Creates an empty state: no versions, no published or failed crates, no commit hash,
    /// no tag and no extra data.
    #[must_use]
    pub fn new() -> Self {
        Self {
            crate_versions: HashMap::default(),
            published_crates: vec![],
            failed_crates: vec![],
            commit_hash: None,
            tag: None,
            extra_data: HashMap::default(),
        }
    }

    /// Records `version` for crate `name`, replacing any version previously recorded for it.
    pub fn add_crate_version(&mut self, name: String, version: SemanticVersion) {
        let versions = &mut self.crate_versions;
        versions.insert(name, version);
    }

    /// Appends `name` to the published list (no de-duplication).
    pub fn mark_published(&mut self, name: String) {
        let published = &mut self.published_crates;
        published.push(name);
    }

    /// Appends `name` to the failed list (no de-duplication).
    pub fn mark_failed(&mut self, name: String) {
        let failed = &mut self.failed_crates;
        failed.push(name);
    }

    /// Returns the state with `commit_hash` set to `hash`.
    #[must_use]
    pub fn with_commit_hash(self, hash: String) -> Self {
        Self {
            commit_hash: Some(hash),
            ..self
        }
    }

    /// Returns the state with `tag` set.
    #[must_use]
    pub fn with_tag(self, tag: String) -> Self {
        Self {
            tag: Some(tag),
            ..self
        }
    }

    /// Returns the state with `value` stored under `key` in `extra_data`, overwriting any
    /// existing entry for that key.
    #[must_use]
    pub fn with_extra(self, key: String, value: serde_json::Value) -> Self {
        let mut updated = self;
        updated.extra_data.insert(key, value);
        updated
    }
}
impl Default for WorkflowState {
fn default() -> Self {
Self::new()
}
}
/// In-memory [`CheckpointStore`] backed by a mutex-guarded `Vec`.
///
/// `Clone` clones the `Arc`, so every clone shares the same checkpoint list. Nothing is
/// persisted outside the process.
#[derive(Debug, Clone, Default)]
pub struct MemoryCheckpointStore {
/// Stored checkpoints in insertion order; the last element is what `load` returns.
checkpoints: std::sync::Arc<std::sync::Mutex<Vec<Checkpoint>>>,
}
impl MemoryCheckpointStore {
    /// Locks the checkpoint list, converting a poisoned-mutex error into the `StoreError`
    /// variant built by `to_error` (e.g. `StoreError::SaveFailed` or `StoreError::LoadFailed`).
    fn lock_checkpoints(
        &self,
        to_error: fn(String) -> StoreError,
    ) -> Result<std::sync::MutexGuard<'_, Vec<Checkpoint>>, StoreError> {
        self.checkpoints
            .lock()
            .map_err(|e| to_error(format!("Failed to acquire lock: {e}")))
    }
}

#[async_trait]
impl CheckpointStore for MemoryCheckpointStore {
    /// Stores a copy of `checkpoint`.
    ///
    /// If a checkpoint with the same `id` is already stored, it is replaced rather than
    /// duplicated, so `load_by_id` and `list` never return a stale earlier snapshot. The
    /// saved checkpoint is always moved to the end of the list, so `load` keeps returning
    /// the most recently saved checkpoint.
    ///
    /// # Errors
    /// [`StoreError::SaveFailed`] if the internal mutex is poisoned.
    async fn save(&self, checkpoint: &Checkpoint) -> Result<(), StoreError> {
        // Clone before locking so the critical section stays as short as possible.
        let snapshot = checkpoint.clone();
        let mut checkpoints = self.lock_checkpoints(StoreError::SaveFailed)?;
        checkpoints.retain(|existing| existing.id != snapshot.id);
        checkpoints.push(snapshot);
        Ok(())
    }

    /// Returns the most recently saved checkpoint, or `None` if the store is empty.
    ///
    /// # Errors
    /// [`StoreError::LoadFailed`] if the internal mutex is poisoned.
    async fn load(&self) -> Result<Option<Checkpoint>, StoreError> {
        let checkpoints = self.lock_checkpoints(StoreError::LoadFailed)?;
        Ok(checkpoints.last().cloned())
    }

    /// Returns the checkpoint with the given `id`, or `None` if none matches.
    ///
    /// # Errors
    /// [`StoreError::LoadFailed`] if the internal mutex is poisoned.
    async fn load_by_id(&self, id: &str) -> Result<Option<Checkpoint>, StoreError> {
        let checkpoints = self.lock_checkpoints(StoreError::LoadFailed)?;
        Ok(checkpoints.iter().find(|c| c.id == id).cloned())
    }

    /// Summarizes every stored checkpoint, in storage order.
    ///
    /// # Errors
    /// [`StoreError::LoadFailed`] if the internal mutex is poisoned.
    async fn list(&self) -> Result<Vec<CheckpointInfo>, StoreError> {
        let checkpoints = self.lock_checkpoints(StoreError::LoadFailed)?;
        Ok(checkpoints
            .iter()
            .cloned()
            .map(CheckpointInfo::from)
            .collect())
    }

    /// Removes the checkpoint with the given `id`; deleting an unknown id is a no-op.
    ///
    /// # Errors
    /// [`StoreError::SaveFailed`] if the internal mutex is poisoned.
    async fn delete(&self, id: &str) -> Result<(), StoreError> {
        self.lock_checkpoints(StoreError::SaveFailed)?
            .retain(|c| c.id != id);
        Ok(())
    }

    /// Removes every stored checkpoint.
    ///
    /// # Errors
    /// [`StoreError::SaveFailed`] if the internal mutex is poisoned.
    async fn clear(&self) -> Result<(), StoreError> {
        self.lock_checkpoints(StoreError::SaveFailed)?.clear();
        Ok(())
    }

    fn name(&self) -> &'static str {
        "memory"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Empty JSON object used as the checkpoint state in these tests.
    fn empty_state() -> serde_json::Value {
        serde_json::Value::Object(serde_json::Map::default())
    }

    #[tokio::test]
    async fn test_memory_checkpoint_store() {
        let store = MemoryCheckpointStore::default();
        let checkpoint = Checkpoint::new("test-workflow".to_string(), 0, vec![], empty_state());

        store.save(&checkpoint).await.unwrap();

        let loaded = store
            .load()
            .await
            .unwrap()
            .expect("a checkpoint should be present after save");
        assert_eq!(loaded.workflow_id, "test-workflow");
    }

    #[tokio::test]
    async fn test_checkpoint_list() {
        let store = MemoryCheckpointStore::default();
        let first = Checkpoint::new("workflow-1".to_string(), 0, vec![], empty_state());
        let second = Checkpoint::new(
            "workflow-2".to_string(),
            1,
            vec!["step1".to_string()],
            empty_state(),
        );

        for checkpoint in [&first, &second] {
            store.save(checkpoint).await.unwrap();
        }

        let summaries = store.list().await.unwrap();
        assert_eq!(summaries.len(), 2);
    }

    #[tokio::test]
    async fn test_workflow_state() {
        let mut state = WorkflowState::new();
        let version = SemanticVersion::parse("1.0.0").unwrap();
        state.add_crate_version("crate1".to_string(), version);
        state.mark_published("crate1".to_string());

        assert_eq!(state.crate_versions.len(), 1);
        assert_eq!(state.published_crates.len(), 1);
        assert!(state.failed_crates.is_empty());
    }
}