use crate::{Result, WorkItem};
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fmt::Debug;
/// Lifecycle status shared by a stage ([`StageState`]) and by each of its
/// fan-out subtasks ([`SubTask`]).
///
/// With serde's default representation each variant serializes as its bare
/// name, e.g. `"InProgress"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StageStatus {
    /// Not started; the value every constructor in this module begins with.
    Pending,
    /// Currently running.
    InProgress,
    /// Finished successfully.
    Complete,
    /// Finished with an error.
    Failed,
    /// Suspended.
    Paused,
}
/// Persisted execution state of one stage.
///
/// Derives `PartialEq`/`Eq` like [`SubTask`] and [`StageStatus`] do, so whole
/// states can be compared (e.g. after a serde round trip) instead of
/// field by field. Every field type is `Eq`, so the derive is sound.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct StageState {
    /// Current lifecycle status.
    pub status: StageStatus,
    /// When the stage was first marked in progress.
    pub started_at: Option<DateTime<Utc>>,
    /// When the stage was marked complete or failed.
    pub completed_at: Option<DateTime<Utc>>,
    /// Number of retries recorded so far.
    pub retry_count: u32,
    /// Message from the most recent failure, if any.
    pub error: Option<String>,
    /// Fan-out subtasks tracked under this stage.
    pub subtasks: Vec<SubTask>,
}
impl StageState {
    /// Creates a fresh state: `Pending`, no timestamps, zero retries, no
    /// error and no subtasks.
    pub fn new() -> Self {
        Self {
            status: StageStatus::Pending,
            started_at: None,
            completed_at: None,
            retry_count: 0,
            error: None,
            subtasks: Vec::new(),
        }
    }

    /// Moves the stage to `InProgress`.
    ///
    /// `started_at` is only set on the first call, so it keeps the time of
    /// the first attempt across retries/resumes. `completed_at` is cleared:
    /// after a failed attempt it would otherwise still hold the failure
    /// time, making a running stage look finished. `error` is left as is.
    pub fn mark_in_progress(&mut self) {
        self.status = StageStatus::InProgress;
        if self.started_at.is_none() {
            self.started_at = Some(Utc::now());
        }
        self.completed_at = None;
    }

    /// Marks the stage `Complete` and stamps `completed_at` with the current time.
    pub fn mark_complete(&mut self) {
        self.status = StageStatus::Complete;
        self.completed_at = Some(Utc::now());
    }

    /// Marks the stage `Failed`, stamps `completed_at` and stores `error`.
    pub fn mark_failed(&mut self, error: String) {
        self.status = StageStatus::Failed;
        self.completed_at = Some(Utc::now());
        self.error = Some(error);
    }

    /// Marks the stage `Paused`; timestamps are left untouched.
    pub fn mark_paused(&mut self) {
        self.status = StageStatus::Paused;
    }

    /// Records one more retry.
    ///
    /// Saturates at `u32::MAX`: a plain `+= 1` panics on overflow in debug
    /// builds and wraps to 0 in release builds, which could defeat any
    /// retry-limit check built on this counter.
    pub fn increment_retry(&mut self) {
        self.retry_count = self.retry_count.saturating_add(1);
    }
}
impl Default for StageState {
fn default() -> Self {
Self::new()
}
}
/// One fan-out unit of work tracked inside a [`StageState`].
///
/// Field order is significant for serde output and `Debug` formatting.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct SubTask {
    /// Identifier of the subtask.
    pub id: String,
    /// Lifecycle status of this subtask.
    pub status: StageStatus,
    /// Number of retries recorded for this subtask.
    pub retry_count: u32,
    /// Message from the most recent failure, if any.
    pub error: Option<String>,
    /// Free-form string key/value pairs attached to the subtask.
    pub metadata: HashMap<String, String>,
}
impl SubTask {
    /// Creates a `Pending` subtask with the given id, zero retries, no error
    /// and empty metadata.
    pub fn new(id: String) -> Self {
        Self {
            id,
            status: StageStatus::Pending,
            retry_count: 0,
            error: None,
            metadata: HashMap::new(),
        }
    }

    /// Marks the subtask `Complete`.
    pub fn mark_complete(&mut self) {
        self.status = StageStatus::Complete;
    }

    /// Marks the subtask `Failed` and stores `error`.
    pub fn mark_failed(&mut self, error: String) {
        self.status = StageStatus::Failed;
        self.error = Some(error);
    }

    /// Records one more retry.
    ///
    /// Saturates at `u32::MAX` rather than panicking (debug) or wrapping to
    /// 0 (release) on overflow.
    pub fn increment_retry(&mut self) {
        self.retry_count = self.retry_count.saturating_add(1);
    }
}
/// A review request raised for one work item at one stage, plus the
/// reviewer's answer once it has been recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReviewData {
    /// Id of the work item under review.
    pub work_item_id: String,
    /// Name of the stage that asked for the review.
    pub stage_name: String,
    /// Text shown to the reviewer.
    pub prompt: String,
    /// Extra JSON values attached to the request.
    pub context: HashMap<String, serde_json::Value>,
    /// When the request was created.
    pub requested_at: DateTime<Utc>,
    /// When a decision was recorded; `None` while still pending.
    pub completed_at: Option<DateTime<Utc>>,
    /// The recorded decision (free-form text).
    pub decision: Option<String>,
    /// Optional reviewer comments.
    pub comments: Option<String>,
}
impl ReviewData {
    /// Opens a pending review for `work_item_id` at `stage_name`, stamped
    /// with the current time, with empty context and no decision yet.
    pub fn new(work_item_id: String, stage_name: String, prompt: String) -> Self {
        let requested_at = Utc::now();
        Self {
            work_item_id,
            stage_name,
            prompt,
            requested_at,
            context: HashMap::new(),
            completed_at: None,
            decision: None,
            comments: None,
        }
    }

    /// Builder-style helper: stores `value` under `key` in the context
    /// (replacing any previous value for that key) and returns the review.
    pub fn with_context(mut self, key: String, value: serde_json::Value) -> Self {
        let context = &mut self.context;
        context.insert(key, value);
        self
    }

    /// Records the reviewer's `decision` and optional `comments`, and stamps
    /// `completed_at` with the current time. Calling it again overwrites the
    /// previous answer.
    pub fn complete(&mut self, decision: String, comments: Option<String>) {
        self.completed_at = Some(Utc::now());
        self.comments = comments;
        self.decision = Some(decision);
    }
}
/// Outcome returned by [`Stage::execute`].
///
/// NOTE(review): how each variant is acted on is decided by whatever drives
/// the stages, which is not in this module. The per-variant notes below are
/// inferred from the names; confirm them against the executor.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StageOutcome {
    /// The stage's work is done.
    Complete,
    /// The stage is waiting on a review (presumably a human one).
    NeedsReview,
    /// The stage asks to be run again.
    Retry,
    /// The stage could not finish.
    Failed,
    /// The stage split its work into the carried subtasks.
    FanOut(Vec<SubTask>),
}
/// Per-execution context handed to a [`Stage`]: the stage's name, its
/// current state, arbitrary JSON metadata, and optionally the subtask being
/// processed.
#[derive(Debug, Clone)]
pub struct StageContext {
    /// Name of the stage this context belongs to.
    pub stage_name: String,
    /// Current execution state of the stage.
    pub stage_state: StageState,
    /// Arbitrary JSON values keyed by name.
    pub metadata: HashMap<String, serde_json::Value>,
    /// Subtask name, when the context targets a single subtask.
    pub subtask_name: Option<String>,
}
impl StageContext {
    /// Builds a context for `stage_name` with a fresh `Pending` [`StageState`],
    /// no metadata and no subtask selected.
    pub fn new(stage_name: String) -> Self {
        let stage_state = StageState::new();
        Self {
            stage_name,
            stage_state,
            metadata: HashMap::new(),
            subtask_name: None,
        }
    }

    /// Builder-style helper: stores `value` under `key` in the metadata
    /// (replacing any previous value for that key) and returns the context.
    pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
        let metadata = &mut self.metadata;
        metadata.insert(key, value);
        self
    }

    /// Builder-style helper: records the subtask name this context targets
    /// and returns the context.
    pub fn with_subtask(mut self, subtask_name: impl Into<String>) -> Self {
        let name: String = subtask_name.into();
        self.subtask_name = Some(name);
        self
    }

    /// Looks up a metadata value by key; `None` if the key is absent.
    pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
        self.metadata.get(key)
    }
}
/// A pipeline stage that processes a single [`WorkItem`].
///
/// The `Debug + Send + Sync` bounds let implementations be held as
/// `Box<dyn Stage>` and shared between threads. `#[async_trait]` keeps the
/// async methods usable through a trait object.
#[async_trait]
pub trait Stage: Debug + Send + Sync {
    /// Runs the stage for `item`, possibly updating `context`, and reports
    /// the resulting [`StageOutcome`].
    ///
    /// # Errors
    /// Returns the crate's error type on failure. NOTE(review): which
    /// failures surface as `Err` and which as `StageOutcome::Failed` is left
    /// to implementors; confirm the convention with the executor.
    async fn execute(
        &self,
        item: &dyn WorkItem,
        context: &mut StageContext,
    ) -> Result<StageOutcome>;

    /// Name identifying this stage.
    fn name(&self) -> &str;

    /// Hook presumably invoked before [`Stage::execute`]. The default
    /// implementation does nothing and returns `Ok(())`.
    async fn before_execute(&self, _work_item: &dyn WorkItem, _ctx: &StageContext) -> Result<()> {
        Ok(())
    }

    /// Hook presumably invoked after [`Stage::execute`] with its outcome. The
    /// default implementation does nothing and returns `Ok(())`.
    async fn after_execute(
        &self,
        _work_item: &dyn WorkItem,
        _ctx: &StageContext,
        _result: &StageOutcome,
    ) -> Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal work item shared by the `Stage` tests (previously redefined
    /// inside each async test). `WorkItem` and the serde derives come from
    /// `use super::*`.
    #[derive(Debug, Clone, Serialize, Deserialize)]
    struct TestItem {
        id: String,
    }

    impl WorkItem for TestItem {
        fn id(&self) -> &str {
            &self.id
        }
    }

    /// Builds the work item used by the `Stage` tests.
    fn sample_item() -> TestItem {
        TestItem {
            id: "test-1".to_string(),
        }
    }

    /// Stage stub whose `execute` always reports `Complete`.
    #[derive(Debug)]
    struct TestStage {
        name: String,
    }

    #[async_trait]
    impl Stage for TestStage {
        async fn execute(
            &self,
            _item: &dyn WorkItem,
            _context: &mut StageContext,
        ) -> Result<StageOutcome> {
            Ok(StageOutcome::Complete)
        }

        fn name(&self) -> &str {
            &self.name
        }
    }

    #[test]
    fn test_stage_status_equality() {
        assert_eq!(StageStatus::Pending, StageStatus::Pending);
        assert_ne!(StageStatus::Pending, StageStatus::InProgress);
    }

    #[test]
    fn test_stage_status_serialize() {
        let status = StageStatus::InProgress;
        let json = serde_json::to_string(&status).unwrap();
        assert!(json.contains("InProgress"));
    }

    #[test]
    fn test_stage_status_deserialize() {
        let json = r#""Complete""#;
        let status: StageStatus = serde_json::from_str(json).unwrap();
        assert_eq!(status, StageStatus::Complete);
    }

    #[test]
    fn test_stage_state_new() {
        let state = StageState::new();
        assert_eq!(state.status, StageStatus::Pending);
        assert_eq!(state.retry_count, 0);
        assert!(state.started_at.is_none());
        assert!(state.completed_at.is_none());
        assert!(state.error.is_none());
        assert!(state.subtasks.is_empty());
    }

    #[test]
    fn test_stage_state_default() {
        let state = StageState::default();
        assert_eq!(state.status, StageStatus::Pending);
    }

    #[test]
    fn test_stage_state_mark_in_progress() {
        let mut state = StageState::new();
        state.mark_in_progress();
        assert_eq!(state.status, StageStatus::InProgress);
        assert!(state.started_at.is_some());
    }

    #[test]
    fn test_stage_state_mark_complete() {
        let mut state = StageState::new();
        state.mark_complete();
        assert_eq!(state.status, StageStatus::Complete);
        assert!(state.completed_at.is_some());
    }

    #[test]
    fn test_stage_state_mark_failed() {
        let mut state = StageState::new();
        state.mark_failed("test error".to_string());
        assert_eq!(state.status, StageStatus::Failed);
        assert_eq!(state.error, Some("test error".to_string()));
        assert!(state.completed_at.is_some());
    }

    #[test]
    fn test_stage_state_mark_paused() {
        let mut state = StageState::new();
        state.mark_paused();
        assert_eq!(state.status, StageStatus::Paused);
    }

    #[test]
    fn test_stage_state_increment_retry() {
        let mut state = StageState::new();
        assert_eq!(state.retry_count, 0);
        state.increment_retry();
        assert_eq!(state.retry_count, 1);
        state.increment_retry();
        assert_eq!(state.retry_count, 2);
    }

    #[test]
    fn test_stage_state_serialize() {
        let state = StageState::new();
        let json = serde_json::to_string(&state).unwrap();
        assert!(json.contains("Pending"));
    }

    #[test]
    fn test_stage_state_deserialize() {
        let json = r#"{"status":"Complete","started_at":null,"completed_at":null,"retry_count":0,"error":null,"subtasks":[]}"#;
        let state: StageState = serde_json::from_str(json).unwrap();
        assert_eq!(state.status, StageStatus::Complete);
    }

    #[test]
    fn test_subtask_new() {
        let subtask = SubTask::new("sub-1".to_string());
        assert_eq!(subtask.id, "sub-1");
        assert_eq!(subtask.status, StageStatus::Pending);
        assert_eq!(subtask.retry_count, 0);
        assert!(subtask.error.is_none());
        assert!(subtask.metadata.is_empty());
    }

    #[test]
    fn test_subtask_mark_complete() {
        let mut subtask = SubTask::new("sub-2".to_string());
        subtask.mark_complete();
        assert_eq!(subtask.status, StageStatus::Complete);
    }

    #[test]
    fn test_subtask_mark_failed() {
        let mut subtask = SubTask::new("sub-3".to_string());
        subtask.mark_failed("subtask error".to_string());
        assert_eq!(subtask.status, StageStatus::Failed);
        assert_eq!(subtask.error, Some("subtask error".to_string()));
    }

    #[test]
    fn test_subtask_increment_retry() {
        let mut subtask = SubTask::new("sub-4".to_string());
        assert_eq!(subtask.retry_count, 0);
        subtask.increment_retry();
        assert_eq!(subtask.retry_count, 1);
    }

    #[test]
    fn test_subtask_metadata() {
        let mut subtask = SubTask::new("sub-5".to_string());
        subtask
            .metadata
            .insert("key".to_string(), "value".to_string());
        assert_eq!(subtask.metadata.get("key"), Some(&"value".to_string()));
    }

    #[test]
    fn test_subtask_serialize() {
        let subtask = SubTask::new("sub-6".to_string());
        let json = serde_json::to_string(&subtask).unwrap();
        assert!(json.contains("sub-6"));
    }

    /// A subtask with every field populated survives a JSON round trip.
    #[test]
    fn test_subtask_roundtrip() {
        let mut subtask = SubTask::new("sub-7".to_string());
        subtask.metadata.insert("k".to_string(), "v".to_string());
        subtask.increment_retry();
        subtask.mark_failed("boom".to_string());
        let json = serde_json::to_string(&subtask).unwrap();
        let back: SubTask = serde_json::from_str(&json).unwrap();
        assert_eq!(back, subtask);
    }

    #[test]
    fn test_review_data_new() {
        let review = ReviewData::new(
            "item-1".to_string(),
            "review-stage".to_string(),
            "Please review".to_string(),
        );
        assert_eq!(review.work_item_id, "item-1");
        assert_eq!(review.stage_name, "review-stage");
        assert_eq!(review.prompt, "Please review");
        assert!(review.context.is_empty());
        assert!(review.decision.is_none());
        assert!(review.comments.is_none());
    }

    #[test]
    fn test_review_data_with_context() {
        let review = ReviewData::new(
            "item-2".to_string(),
            "review-stage".to_string(),
            "Review this".to_string(),
        )
        .with_context("key".to_string(), serde_json::json!("value"));
        assert_eq!(review.context.get("key"), Some(&serde_json::json!("value")));
    }

    #[test]
    fn test_review_data_complete() {
        let mut review = ReviewData::new(
            "item-3".to_string(),
            "review-stage".to_string(),
            "Review".to_string(),
        );
        review.complete("approved".to_string(), Some("looks good".to_string()));
        assert_eq!(review.decision, Some("approved".to_string()));
        assert_eq!(review.comments, Some("looks good".to_string()));
        assert!(review.completed_at.is_some());
    }

    #[test]
    fn test_review_data_complete_without_comments() {
        let mut review = ReviewData::new(
            "item-4".to_string(),
            "review-stage".to_string(),
            "Review".to_string(),
        );
        review.complete("rejected".to_string(), None);
        assert_eq!(review.decision, Some("rejected".to_string()));
        assert!(review.comments.is_none());
        assert!(review.completed_at.is_some());
    }

    #[test]
    fn test_review_data_serialize() {
        let review = ReviewData::new(
            "item-5".to_string(),
            "stage".to_string(),
            "prompt".to_string(),
        );
        let json = serde_json::to_string(&review).unwrap();
        assert!(json.contains("item-5"));
    }

    #[test]
    fn test_stage_outcome_equality() {
        assert_eq!(StageOutcome::Complete, StageOutcome::Complete);
        assert_ne!(StageOutcome::Complete, StageOutcome::Failed);
    }

    /// Every variant (including `FanOut`) equals itself and differs from
    /// all others. This replaces the old tautological array-length check.
    #[test]
    fn test_stage_outcome_variants() {
        let outcomes = [
            StageOutcome::Complete,
            StageOutcome::NeedsReview,
            StageOutcome::Retry,
            StageOutcome::Failed,
            StageOutcome::FanOut(Vec::new()),
        ];
        for (i, a) in outcomes.iter().enumerate() {
            for (j, b) in outcomes.iter().enumerate() {
                assert_eq!(i == j, a == b, "comparing {a:?} with {b:?}");
            }
        }
    }

    /// `FanOut` equality depends on the carried subtasks.
    #[test]
    fn test_stage_outcome_fan_out_equality() {
        let subtasks = vec![SubTask::new("sub-1".to_string())];
        assert_eq!(
            StageOutcome::FanOut(subtasks.clone()),
            StageOutcome::FanOut(subtasks)
        );
        assert_ne!(
            StageOutcome::FanOut(vec![SubTask::new("a".to_string())]),
            StageOutcome::FanOut(vec![SubTask::new("b".to_string())])
        );
    }

    #[test]
    fn test_stage_state_preserves_started_at() {
        let mut state = StageState::new();
        state.mark_in_progress();
        let first_start = state.started_at;
        state.mark_in_progress();
        assert_eq!(state.started_at, first_start);
    }

    #[test]
    fn test_stage_state_with_subtasks() {
        let mut state = StageState::new();
        let subtask1 = SubTask::new("sub-1".to_string());
        let subtask2 = SubTask::new("sub-2".to_string());
        state.subtasks.push(subtask1);
        state.subtasks.push(subtask2);
        assert_eq!(state.subtasks.len(), 2);
    }

    #[test]
    fn test_review_data_multiple_contexts() {
        let review = ReviewData::new(
            "item-6".to_string(),
            "stage".to_string(),
            "prompt".to_string(),
        )
        .with_context("key1".to_string(), serde_json::json!("value1"))
        .with_context("key2".to_string(), serde_json::json!(42));
        assert_eq!(review.context.len(), 2);
        assert_eq!(
            review.context.get("key1"),
            Some(&serde_json::json!("value1"))
        );
        assert_eq!(review.context.get("key2"), Some(&serde_json::json!(42)));
    }

    #[test]
    fn test_stage_context_new() {
        let context = StageContext::new("test-stage".to_string());
        assert_eq!(context.stage_name, "test-stage");
        assert_eq!(context.stage_state.status, StageStatus::Pending);
        assert!(context.metadata.is_empty());
    }

    #[test]
    fn test_stage_context_with_metadata() {
        let context = StageContext::new("test-stage".to_string())
            .with_metadata("key".to_string(), serde_json::json!("value"));
        assert_eq!(
            context.get_metadata("key"),
            Some(&serde_json::json!("value"))
        );
    }

    #[test]
    fn test_stage_context_get_metadata_missing() {
        let context = StageContext::new("test-stage".to_string());
        assert_eq!(context.get_metadata("missing"), None);
    }

    #[test]
    fn test_stage_context_multiple_metadata() {
        let context = StageContext::new("test-stage".to_string())
            .with_metadata("key1".to_string(), serde_json::json!("value1"))
            .with_metadata("key2".to_string(), serde_json::json!(42));
        assert_eq!(context.metadata.len(), 2);
        assert_eq!(
            context.get_metadata("key1"),
            Some(&serde_json::json!("value1"))
        );
        assert_eq!(context.get_metadata("key2"), Some(&serde_json::json!(42)));
    }

    #[tokio::test]
    async fn test_stage_execute() {
        let stage = TestStage {
            name: "test".to_string(),
        };
        let item = sample_item();
        let mut context = StageContext::new("test".to_string());
        let outcome = stage.execute(&item, &mut context).await.unwrap();
        assert_eq!(outcome, StageOutcome::Complete);
    }

    /// `name` is synchronous, so no async runtime is needed here.
    #[test]
    fn test_stage_name() {
        let stage = TestStage {
            name: "my-stage".to_string(),
        };
        assert_eq!(stage.name(), "my-stage");
    }

    #[tokio::test]
    async fn test_stage_before_execute() {
        let stage = TestStage {
            name: "test".to_string(),
        };
        let item = sample_item();
        let context = StageContext::new("test".to_string());
        let result = stage.before_execute(&item, &context).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_stage_after_execute() {
        let stage = TestStage {
            name: "test".to_string(),
        };
        let item = sample_item();
        let context = StageContext::new("test".to_string());
        let outcome = StageOutcome::Complete;
        let result = stage.after_execute(&item, &context, &outcome).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_stage_trait_object() {
        let stage: Box<dyn Stage> = Box::new(TestStage {
            name: "boxed".to_string(),
        });
        let item = sample_item();
        let mut context = StageContext::new("boxed".to_string());
        assert_eq!(stage.name(), "boxed");
        let outcome = stage.execute(&item, &mut context).await.unwrap();
        assert_eq!(outcome, StageOutcome::Complete);
    }

    #[test]
    fn test_stage_context_clone() {
        let context = StageContext::new("test".to_string())
            .with_metadata("key".to_string(), serde_json::json!("value"));
        let cloned = context.clone();
        assert_eq!(cloned.stage_name, "test");
        assert_eq!(
            cloned.get_metadata("key"),
            Some(&serde_json::json!("value"))
        );
    }
}