use cloacina::{task, CheckpointError, Context, Task, TaskError, TaskNamespace};
use std::sync::{Arc, Mutex};
/// Test double for [`Task`] that keeps the most recently checkpointed context
/// in memory so tests can inspect what `checkpoint` wrote.
struct CheckpointableTask {
    /// Identifier reported by `Task::id`.
    id: String,
    /// Parsed dependency namespaces reported by `Task::dependencies`.
    dependencies: Vec<TaskNamespace>,
    /// Serialized context from the latest successful checkpoint, or `None` if
    /// no checkpoint has been taken yet.
    checkpoint_data: Arc<Mutex<Option<String>>>,
}

impl CheckpointableTask {
    /// Creates a task with the given id and dependency namespace strings.
    ///
    /// # Panics
    /// Panics if any dependency string is rejected by `TaskNamespace::from_string`.
    fn new(id: &str, dependencies: Vec<&str>) -> Self {
        let dependencies: Vec<TaskNamespace> = dependencies
            .into_iter()
            .map(|s| {
                TaskNamespace::from_string(s)
                    .expect("test dependency must be a valid TaskNamespace string")
            })
            .collect();
        Self {
            id: id.to_string(),
            dependencies,
            checkpoint_data: Arc::new(Mutex::new(None)),
        }
    }

    /// Returns a copy of the most recently saved checkpoint, if any.
    fn get_checkpoint_data(&self) -> Option<String> {
        // Recovering from poison is safe here: the stored `Option<String>` is only
        // ever replaced wholesale, so a panicking writer cannot leave it half-updated.
        self.checkpoint_data
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
            .clone()
    }
}
#[async_trait::async_trait]
impl Task for CheckpointableTask {
    /// Stamps the context with this task's id and the current time, then returns it.
    ///
    /// Inserts `processed_by` (this task's id) and `timestamp` (current UTC time
    /// formatted with `to_rfc3339`).
    ///
    /// # Errors
    /// Returns [`TaskError::ContextError`] if either insert fails.
    async fn execute(
        &self,
        mut context: Context<serde_json::Value>,
    ) -> Result<Context<serde_json::Value>, TaskError> {
        context
            .insert("processed_by", serde_json::json!(self.id))
            .map_err(|e| TaskError::ContextError {
                task_id: self.id.clone(),
                error: e,
            })?;
        context
            .insert(
                "timestamp",
                serde_json::json!(chrono::Utc::now().to_rfc3339()),
            )
            .map_err(|e| TaskError::ContextError {
                task_id: self.id.clone(),
                error: e,
            })?;
        Ok(context)
    }

    fn id(&self) -> &str {
        &self.id
    }

    fn dependencies(&self) -> &[TaskNamespace] {
        &self.dependencies
    }

    /// Serializes `context` to JSON and stores it, replacing any earlier checkpoint.
    ///
    /// # Errors
    /// Returns [`CheckpointError::SaveFailed`] if `context.to_json()` fails.
    fn checkpoint(&self, context: &Context<serde_json::Value>) -> Result<(), CheckpointError> {
        let checkpoint_json = context.to_json().map_err(|e| CheckpointError::SaveFailed {
            task_id: self.id.clone(),
            message: format!("Failed to serialize context: {:?}", e),
        })?;
        // The slot is only ever overwritten as a whole, so a poisoned lock cannot hold
        // a half-written value; recover the guard rather than panicking inside a
        // method that already reports failure through `Result`.
        let mut slot = self
            .checkpoint_data
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *slot = Some(checkpoint_json);
        Ok(())
    }
}
/// Checks that a task built with the `#[task]` macro can checkpoint an empty context.
#[test]
fn test_default_checkpoint_implementation() {
    #[task(id = "simple-task", dependencies = [])]
    async fn simple_task(_context: &mut Context<serde_json::Value>) -> Result<(), TaskError> {
        Ok(())
    }

    let generated = simple_task_task();
    let outcome = generated.checkpoint(&Context::new());
    assert!(outcome.is_ok());
}
/// A custom `checkpoint` implementation should persist the serialized context so
/// it can be read back afterwards.
#[test]
fn test_custom_checkpoint_save() {
    let task = CheckpointableTask::new("checkpoint-task", vec![]);
    let mut context = Context::new();
    context
        .insert("test_data", serde_json::json!("checkpoint_test"))
        .unwrap();
    context.insert("number", serde_json::json!(42)).unwrap();

    // `expect` reports the actual error on failure, unlike `assert!(result.is_ok())`.
    task.checkpoint(&context)
        .expect("checkpoint should succeed for a serializable context");

    let saved_data = task
        .get_checkpoint_data()
        .expect("checkpoint should have stored serialized data");
    assert!(saved_data.contains("checkpoint_test"));
    assert!(saved_data.contains("42"));
}
/// Data written by `checkpoint` should deserialize back into a context holding
/// the same key/value pairs.
#[test]
fn test_checkpoint_restore() {
    let task = CheckpointableTask::new("restore-task", vec![]);

    // Single source of truth for what goes in and what must come back out.
    let expected_fields = [
        ("original_data", serde_json::json!("test_value")),
        ("count", serde_json::json!(100)),
    ];

    let mut original_context = Context::new();
    for (key, value) in &expected_fields {
        original_context.insert(*key, value.clone()).unwrap();
    }
    task.checkpoint(&original_context).unwrap();

    let saved = task.get_checkpoint_data().unwrap();
    let restored_context: Context<serde_json::Value> = Context::from_json(saved).unwrap();
    for (key, value) in &expected_fields {
        assert_eq!(restored_context.get(*key).unwrap(), value);
    }
}
/// A `checkpoint` override that fails should surface `SaveFailed` carrying the
/// task id and the failure message.
#[test]
fn test_checkpoint_serialization_error() {
    /// Task whose `checkpoint` always reports a simulated save failure.
    struct FailingCheckpointTask;

    #[async_trait::async_trait]
    impl Task for FailingCheckpointTask {
        async fn execute(
            &self,
            context: Context<serde_json::Value>,
        ) -> Result<Context<serde_json::Value>, TaskError> {
            Ok(context)
        }

        fn id(&self) -> &str {
            "failing-checkpoint"
        }

        fn dependencies(&self) -> &[TaskNamespace] {
            &[]
        }

        fn checkpoint(&self, _context: &Context<serde_json::Value>) -> Result<(), CheckpointError> {
            let failure = CheckpointError::SaveFailed {
                task_id: String::from(self.id()),
                message: String::from("Simulated checkpoint failure"),
            };
            Err(failure)
        }
    }

    let outcome = FailingCheckpointTask.checkpoint(&Context::new());
    assert!(outcome.is_err());

    if let Err(CheckpointError::SaveFailed { task_id, message }) = outcome {
        assert_eq!(task_id, "failing-checkpoint");
        assert!(message.contains("Simulated checkpoint failure"));
    } else {
        panic!("Expected SaveFailed error");
    }
}
/// Successive checkpoints should each succeed, including one of an empty context.
/// Each checkpoint should replace the previously saved data rather than merge with it.
#[test]
fn test_checkpoint_validation() {
    let task = CheckpointableTask::new("validation-task", vec![]);

    let mut context = Context::new();
    context
        .insert("valid_field", serde_json::json!("valid_value"))
        .unwrap();
    task.checkpoint(&context)
        .expect("checkpoint of a populated context should succeed");
    let saved = task
        .get_checkpoint_data()
        .expect("first checkpoint should have stored data");
    assert!(saved.contains("valid_value"));

    let empty_context = Context::new();
    task.checkpoint(&empty_context)
        .expect("checkpoint of an empty context should succeed");
    // The second checkpoint must overwrite the first, so the old value is gone.
    let saved = task
        .get_checkpoint_data()
        .expect("second checkpoint should have stored data");
    assert!(!saved.contains("valid_value"));
}