use async_trait::async_trait;
use std::sync::atomic::{AtomicU32, Ordering};
use std::sync::Arc;
use treadle::{
MemoryStateStore, Result, Stage, StageContext,
StageOutcome, StageStatus, StateStore, SubTask,
WorkItem, Workflow, WorkflowEvent,
};
/// Minimal work item used throughout these integration tests: an id plus
/// free-form text content.
#[derive(Debug, Clone)]
#[allow(dead_code)] // `content` is never read by any stage in these tests.
struct Document {
// Unique identifier, surfaced through `WorkItem::id`.
id: String,
// Payload text; carried along but unused by the test stages.
content: String,
}
impl Document {
    /// Builds a `Document` from anything convertible into owned `String`s.
    fn new(id: impl Into<String>, content: impl Into<String>) -> Self {
        let id = id.into();
        let content = content.into();
        Self { id, content }
    }
}
impl WorkItem for Document {
    /// Exposes the document's id as the workflow's work-item key.
    fn id(&self) -> &str {
        self.id.as_str()
    }
}
/// First pipeline stage: a stateless unit stage that always completes.
#[derive(Debug)]
struct ParseStage;
#[async_trait]
impl Stage for ParseStage {
    /// Stage name used for workflow wiring and state-store lookups.
    fn name(&self) -> &str {
        "parse"
    }

    /// Parsing is a no-op in these tests; the stage completes unconditionally.
    async fn execute(&self, _item: &dyn WorkItem, _ctx: &mut StageContext) -> Result<StageOutcome> {
        let outcome = StageOutcome::Complete;
        Ok(outcome)
    }
}
/// Stage that fans out into three sub-tasks and counts sub-task executions.
#[derive(Debug)]
struct EnrichStage {
// Shared counter, incremented once per executed sub-task.
call_count: Arc<AtomicU32>,
}
impl EnrichStage {
    /// Creates a stage with a fresh, privately-owned execution counter.
    fn new() -> Self {
        Self::with_counter(Arc::new(AtomicU32::new(0)))
    }

    /// Creates a stage that reports executions through a caller-owned counter.
    fn with_counter(counter: Arc<AtomicU32>) -> Self {
        Self {
            call_count: counter,
        }
    }
}
#[async_trait]
impl Stage for EnrichStage {
    fn name(&self) -> &str {
        "enrich"
    }

    /// On the parent invocation (no sub-task name set) this fans out into
    /// three named sub-tasks; each sub-task invocation bumps the counter
    /// and completes.
    async fn execute(&self, _item: &dyn WorkItem, ctx: &mut StageContext) -> Result<StageOutcome> {
        match &ctx.subtask_name {
            // Sub-task run: record it and finish.
            Some(_) => {
                self.call_count.fetch_add(1, Ordering::SeqCst);
                Ok(StageOutcome::Complete)
            }
            // Parent run: split the work into three sub-tasks.
            None => {
                let subtasks = vec![
                    SubTask::new(String::from("source-1")),
                    SubTask::new(String::from("source-2")),
                    SubTask::new(String::from("source-3")),
                ];
                Ok(StageOutcome::FanOut(subtasks))
            }
        }
    }
}
/// Stage that always pauses the workflow pending a human review decision.
#[derive(Debug)]
struct ReviewStage;
#[async_trait]
impl Stage for ReviewStage {
    fn name(&self) -> &str {
        "review"
    }

    /// Always parks the item for review; it never completes on its own.
    async fn execute(&self, _item: &dyn WorkItem, _ctx: &mut StageContext) -> Result<StageOutcome> {
        Result::Ok(StageOutcome::NeedsReview)
    }
}
/// Terminal stage that records how many times an export was performed.
#[derive(Debug)]
struct ExportStage {
// Shared counter, incremented once per `execute` call.
exported: Arc<AtomicU32>,
}
impl ExportStage {
    /// Creates a stage with a fresh, privately-owned export counter.
    fn new() -> Self {
        Self::with_counter(Arc::new(AtomicU32::new(0)))
    }

    /// Creates a stage that reports exports through a caller-owned counter.
    fn with_counter(counter: Arc<AtomicU32>) -> Self {
        Self {
            exported: counter,
        }
    }
}
#[async_trait]
impl Stage for ExportStage {
    fn name(&self) -> &str {
        "export"
    }

    /// Records one export per invocation and completes immediately.
    async fn execute(&self, _item: &dyn WorkItem, _ctx: &mut StageContext) -> Result<StageOutcome> {
        let _ = self.exported.fetch_add(1, Ordering::SeqCst);
        Ok(StageOutcome::Complete)
    }
}
// End-to-end run of the four-stage pipeline (parse -> enrich -> review ->
// export) against the in-memory store, exercising fan-out, the review
// pause/approve cycle, per-stage state, and event streaming.
#[tokio::test]
async fn test_full_pipeline_with_memory_store() {
let enrich_counter = Arc::new(AtomicU32::new(0));
let export_counter = Arc::new(AtomicU32::new(0));
let workflow = Workflow::builder()
.stage("parse", ParseStage)
.stage("enrich", EnrichStage::with_counter(enrich_counter.clone()))
.stage("review", ReviewStage)
.stage("export", ExportStage::with_counter(export_counter.clone()))
.dependency("enrich", "parse")
.dependency("review", "enrich")
.dependency("export", "review")
.build()
.expect("workflow should build");
// Subscribe before advancing so no events are missed.
let mut event_receiver = workflow.subscribe();
let mut collected_events = Vec::new();
let mut store = MemoryStateStore::new();
let doc = Document::new("doc-1", "This is a test document with some content.");
// First advance: should run parse + enrich (with fan-out) and pause at review.
workflow.advance(&doc, &mut store).await.expect("advance should succeed");
// Drain everything emitted so far.
while let Ok(event) = event_receiver.try_recv() {
collected_events.push(event);
}
let status = workflow.status(doc.id(), &store).await.expect("status should succeed");
assert!(!status.is_complete());
assert!(status.has_pending_reviews());
assert_eq!(status.review_stages(), vec!["review"]);
let parse_state = store.get_stage_state(doc.id(), "parse").await.unwrap().unwrap();
assert!(matches!(parse_state.status, StageStatus::Complete));
let enrich_state = store.get_stage_state(doc.id(), "enrich").await.unwrap().unwrap();
assert!(matches!(enrich_state.status, StageStatus::Complete));
// One enrich execution per fan-out sub-task (EnrichStage fans out into 3).
assert_eq!(enrich_counter.load(Ordering::SeqCst), 3);
let review_state = store.get_stage_state(doc.id(), "review").await.unwrap().unwrap();
assert!(matches!(review_state.status, StageStatus::Paused));
// Export must not have started while review is pending.
let export_state = store.get_stage_state(doc.id(), "export").await.unwrap();
assert!(export_state.is_none());
// Approve the review, then advance again to finish the pipeline.
workflow
.approve_review(doc.id(), "review", &mut store)
.await
.expect("approve should succeed");
workflow.advance(&doc, &mut store).await.expect("advance should succeed");
while let Ok(event) = event_receiver.try_recv() {
collected_events.push(event);
}
assert!(workflow.is_complete(doc.id(), &store).await.unwrap());
assert_eq!(export_counter.load(Ordering::SeqCst), 1);
let final_status = workflow.status(doc.id(), &store).await.unwrap();
assert!(final_status.is_complete());
assert_eq!(final_status.progress_percent(), 100.0);
// Expect one StageCompleted per stage, plus a terminal WorkflowCompleted.
let stage_completed_events: Vec<_> = collected_events
.iter()
.filter(|e| matches!(e, WorkflowEvent::StageCompleted { .. }))
.collect();
assert_eq!(stage_completed_events.len(), 4);
let workflow_completed = collected_events
.iter()
.any(|e| matches!(e, WorkflowEvent::WorkflowCompleted { .. }));
assert!(workflow_completed);
println!("{}", final_status);
}
// Same pipeline shape as the memory-store test, but persisted through the
// SQLite backend (feature-gated) to verify store-agnostic behavior:
// advance -> blocked at review -> approve -> advance -> complete.
#[tokio::test]
#[cfg(feature = "sqlite")]
async fn test_full_pipeline_with_sqlite_store() {
use treadle::SqliteStateStore;
let workflow = Workflow::builder()
.stage("parse", ParseStage)
.stage("enrich", EnrichStage::new())
.stage("review", ReviewStage)
.stage("export", ExportStage::new())
.dependency("enrich", "parse")
.dependency("review", "enrich")
.dependency("export", "review")
.build()
.expect("workflow should build");
// In-memory SQLite database: no files left behind by the test.
let mut store = SqliteStateStore::open_in_memory()
.await
.expect("sqlite should open");
let doc = Document::new("doc-sqlite-1", "SQLite test document");
workflow.advance(&doc, &mut store).await.unwrap();
// The review stage should block further progress until approved.
assert!(workflow.is_blocked(doc.id(), &store).await.unwrap());
workflow.approve_review(doc.id(), "review", &mut store).await.unwrap();
workflow.advance(&doc, &mut store).await.unwrap();
assert!(workflow.is_complete(doc.id(), &store).await.unwrap());
}
// Two-stage pipeline with no review gate: one advance drives it to completion.
#[tokio::test]
async fn test_simple_pipeline() {
    let workflow = Workflow::builder()
        .stage("parse", ParseStage)
        .stage("export", ExportStage::new())
        .dependency("export", "parse")
        .build()
        .unwrap();
    let item = Document::new("simple-doc", "Simple content");
    let mut state = MemoryStateStore::new();
    workflow.advance(&item, &mut state).await.unwrap();
    assert!(workflow.is_complete(item.id(), &state).await.unwrap());
}
// Several independent documents sharing one workflow and one store:
// each must complete, and the export stage must run exactly once apiece.
#[tokio::test]
async fn test_multiple_documents() {
    let exports = Arc::new(AtomicU32::new(0));
    let workflow = Workflow::builder()
        .stage("parse", ParseStage)
        .stage("export", ExportStage::with_counter(Arc::clone(&exports)))
        .dependency("export", "parse")
        .build()
        .unwrap();
    let mut store = MemoryStateStore::new();
    let docs = [
        Document::new("doc-1", "Content 1"),
        Document::new("doc-2", "Content 2"),
        Document::new("doc-3", "Content 3"),
    ];
    // Push every document through the pipeline first...
    for doc in &docs {
        workflow.advance(doc, &mut store).await.unwrap();
    }
    // ...then verify each one finished.
    for doc in &docs {
        assert!(workflow.is_complete(doc.id(), &store).await.unwrap());
    }
    assert_eq!(exports.load(Ordering::SeqCst), 3);
}
// The status Display output should mention the work-item id, every stage
// name, and the pending-review marker while the pipeline is paused.
#[tokio::test]
async fn test_status_display() {
    let workflow = Workflow::builder()
        .stage("parse", ParseStage)
        .stage("enrich", EnrichStage::new())
        .stage("review", ReviewStage)
        .stage("export", ExportStage::new())
        .dependency("enrich", "parse")
        .dependency("review", "enrich")
        .dependency("export", "review")
        .build()
        .unwrap();
    let mut store = MemoryStateStore::new();
    let doc = Document::new("status-test", "Test content");
    workflow.advance(&doc, &mut store).await.unwrap();
    let status = workflow.status(doc.id(), &store).await.unwrap();
    let rendered = status.to_string();
    for needle in ["status-test", "parse", "enrich", "review", "export"] {
        assert!(rendered.contains(needle));
    }
    assert!(rendered.contains("Awaiting review"));
    println!("{}", rendered);
}
// A full run should emit at least one StageStarted, one StageCompleted,
// and a terminal WorkflowCompleted event on the subscription channel.
#[tokio::test]
async fn test_event_streaming() {
    let workflow = Workflow::builder()
        .stage("stage1", ParseStage)
        .stage("stage2", ExportStage::new())
        .dependency("stage2", "stage1")
        .build()
        .unwrap();
    let mut events = workflow.subscribe();
    let mut store = MemoryStateStore::new();
    let doc = Document::new("event-test", "Content");
    workflow.advance(&doc, &mut store).await.unwrap();
    // Drain the channel, tagging each event kind we care about.
    let mut seen = Vec::new();
    while let Ok(event) = events.try_recv() {
        let tag = match event {
            WorkflowEvent::StageStarted { .. } => Some("started"),
            WorkflowEvent::StageCompleted { .. } => Some("completed"),
            WorkflowEvent::WorkflowCompleted { .. } => Some("workflow_complete"),
            _ => None,
        };
        if let Some(tag) = tag {
            seen.push(tag);
        }
    }
    for expected in ["started", "completed", "workflow_complete"] {
        assert!(seen.contains(&expected));
    }
}
// A single fan-out stage: the parent completes only after all three
// sub-tasks have executed, and the store records each sub-task's state.
#[tokio::test]
async fn test_fanout_with_subtask_tracking() {
    let calls = Arc::new(AtomicU32::new(0));
    let workflow = Workflow::builder()
        .stage("enrich", EnrichStage::with_counter(Arc::clone(&calls)))
        .build()
        .unwrap();
    let mut store = MemoryStateStore::new();
    let doc = Document::new("fanout-test", "Content");
    workflow.advance(&doc, &mut store).await.unwrap();
    let state = store.get_stage_state(doc.id(), "enrich").await.unwrap().unwrap();
    assert!(matches!(state.status, StageStatus::Complete));
    assert_eq!(state.subtasks.len(), 3);
    assert_eq!(calls.load(Ordering::SeqCst), 3);
    // Every sub-task must itself have run to completion.
    assert!(state
        .subtasks
        .iter()
        .all(|sub| matches!(sub.status, StageStatus::Complete)));
}
// Four chained auto-completing stages: progress reads 0% before the first
// advance and 100% after a single advance drives the chain to the end.
#[tokio::test]
async fn test_pipeline_progress_tracking() {
    let workflow = Workflow::builder()
        .stage("s1", ParseStage)
        .stage("s2", ParseStage)
        .stage("s3", ParseStage)
        .stage("s4", ParseStage)
        .dependency("s2", "s1")
        .dependency("s3", "s2")
        .dependency("s4", "s3")
        .build()
        .unwrap();
    let mut store = MemoryStateStore::new();
    let doc = Document::new("progress-test", "Content");
    let before = workflow.status(doc.id(), &store).await.unwrap();
    assert_eq!(before.progress_percent(), 0.0);
    workflow.advance(&doc, &mut store).await.unwrap();
    let after = workflow.status(doc.id(), &store).await.unwrap();
    assert_eq!(after.progress_percent(), 100.0);
    assert!(after.is_complete());
}