use crate::error::{WorkflowError, WorkflowResult};
use crate::status::StatusPhase;
use crate::task_runner::{TaskRunner, TaskSupport};
use crate::tasks::task_name_impl;
use serde_json::Value;
use swf_core::models::duration::OneOfDurationOrIso8601Expression;
use swf_core::models::task::WaitTaskDefinition;
/// Runner for a `wait` task: pauses workflow execution for a configured duration.
pub struct WaitTaskRunner {
    /// Name of the task, used for status updates and error messages.
    name: String,
    /// Duration to wait, either a structured duration or an ISO 8601 expression;
    /// evaluated at run time against the task input.
    duration_expr: OneOfDurationOrIso8601Expression,
}

impl WaitTaskRunner {
    /// Creates a runner named `name` from the `wait` task definition.
    ///
    /// The duration expression is copied out of `task` and only evaluated when the
    /// task runs. Construction itself never fails; the `Result` return type keeps the
    /// signature uniform with other task runners.
    pub fn new(name: &str, task: &WaitTaskDefinition) -> WorkflowResult<Self> {
        let duration_expr = task.wait.clone();
        let runner = Self {
            name: String::from(name),
            duration_expr,
        };
        Ok(runner)
    }
}
#[async_trait::async_trait]
impl TaskRunner for WaitTaskRunner {
    /// Evaluates the configured duration against `input`, then waits until it
    /// elapses or the workflow context is cancelled, whichever happens first.
    ///
    /// # Returns
    /// `input` unchanged once the wait completes. It is also returned immediately,
    /// without touching the task status, when the evaluated duration is below one
    /// millisecond.
    ///
    /// # Errors
    /// Propagates any error from evaluating the duration expression. Returns the
    /// error built by `WorkflowError::timeout` (its message contains "cancelled")
    /// when the context's cancellation token fires before the wait completes.
    async fn run(&self, input: Value, support: &mut TaskSupport<'_>) -> WorkflowResult<Value> {
        let wait_duration = support.eval_duration(&self.duration_expr, &input, &self.name)?;
        // `as_millis` truncates, so sub-millisecond durations are treated as no-ops.
        if wait_duration.as_millis() == 0 {
            return Ok(input);
        }
        support.set_task_status(&self.name, StatusPhase::Waiting);
        let cancel_token = support.context.cancellation_token();
        tokio::select! {
            // `biased` polls branches in source order. If the token is cancelled in
            // the same poll in which the timer elapses, cancellation wins
            // deterministically instead of by tokio's random branch selection.
            biased;
            _ = cancel_token.cancelled() => {
                // NOTE(review): cancellation is reported through the `timeout` error
                // constructor; confirm this is the intended error kind.
                Err(WorkflowError::timeout(
                    format!("wait task '{}' cancelled", self.name),
                    &self.name,
                ))
            }
            _ = tokio::time::sleep(wait_duration) => Ok(input),
        }
    }

    task_name_impl!();
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::WorkflowContext;
    use crate::default_support;
    use crate::tasks::DoTaskRunner;
    use crate::utils::parse_iso8601_duration;
    use serde_json::json;
    use std::collections::HashMap;
    use std::time::Duration;
    use swf_core::models::duration::Duration as SwfDuration;
    use swf_core::models::map::Map;
    use swf_core::models::task::{
        DoTaskDefinition, SetTaskDefinition, SetValue, TaskDefinition, TaskDefinitionFields,
    };
    use swf_core::models::workflow::WorkflowDefinition;

    /// Builds a `WaitTaskDefinition` for `wait` with default common fields.
    fn wait_definition(wait: OneOfDurationOrIso8601Expression) -> WaitTaskDefinition {
        WaitTaskDefinition {
            wait,
            common: TaskDefinitionFields::new(),
        }
    }

    /// Builds a `set` task that assigns each `(key, value)` pair.
    fn set_task(pairs: &[(&str, Value)]) -> TaskDefinition {
        let map: HashMap<String, Value> = pairs
            .iter()
            .map(|(key, value)| ((*key).to_string(), value.clone()))
            .collect();
        TaskDefinition::Set(SetTaskDefinition {
            set: SetValue::Map(map),
            common: TaskDefinitionFields::new(),
        })
    }

    #[test]
    fn test_parse_iso8601_duration_seconds() {
        let dur = parse_iso8601_duration("PT5S").unwrap();
        assert_eq!(dur, Duration::from_millis(5000));
    }

    #[test]
    fn test_parse_iso8601_duration_minutes() {
        let dur = parse_iso8601_duration("PT10M").unwrap();
        assert_eq!(dur, Duration::from_millis(10 * 60 * 1000));
    }

    #[test]
    fn test_parse_iso8601_duration_hours() {
        let dur = parse_iso8601_duration("PT1H").unwrap();
        assert_eq!(dur, Duration::from_millis(60 * 60 * 1000));
    }

    #[test]
    fn test_parse_iso8601_duration_days() {
        let dur = parse_iso8601_duration("P1D").unwrap();
        assert_eq!(dur, Duration::from_millis(24 * 60 * 60 * 1000));
    }

    #[test]
    fn test_parse_iso8601_duration_combined() {
        let dur = parse_iso8601_duration("P1DT12H30M5S").unwrap();
        let expected: u64 = (24 + 12) * 60 * 60 * 1000 + 30 * 60 * 1000 + 5000;
        assert_eq!(dur, Duration::from_millis(expected));
    }

    #[test]
    fn test_parse_iso8601_duration_invalid() {
        assert!(
            parse_iso8601_duration("1Y").is_none(),
            "expected None for invalid duration '1Y'"
        );
        assert!(
            parse_iso8601_duration("").is_none(),
            "expected None for empty duration"
        );
        // The parser accepts a bare "P" designator; pin that leniency so any
        // change to it is deliberate.
        assert!(
            parse_iso8601_duration("P").is_some(),
            "'P' alone should parse successfully"
        );
        // Only checks that a year component does not panic; the result is not asserted.
        let _ = parse_iso8601_duration("P1Y");
    }

    #[test]
    fn test_parse_iso8601_duration_fractional_seconds() {
        let dur = parse_iso8601_duration("PT0.25S").unwrap();
        assert_eq!(dur, Duration::from_millis(250));
    }

    #[test]
    fn test_parse_iso8601_duration_milliseconds_suffix() {
        let dur = parse_iso8601_duration("PT250MS").unwrap();
        assert_eq!(dur, Duration::from_millis(250));
    }

    #[test]
    fn test_parse_iso8601_duration_combined_with_ms() {
        let dur = parse_iso8601_duration("P3DT4H5M6S250MS").unwrap();
        let expected: u64 =
            3 * 24 * 3600 * 1000 + 4 * 3600 * 1000 + 5 * 60 * 1000 + 6 * 1000 + 250;
        assert_eq!(dur, Duration::from_millis(expected));
    }

    #[test]
    fn test_parse_iso8601_duration_zero() {
        let dur = parse_iso8601_duration("PT0S").unwrap();
        assert_eq!(dur, Duration::from_millis(0));
    }

    #[tokio::test]
    async fn test_wait_returns_input_unchanged() {
        let task = wait_definition(OneOfDurationOrIso8601Expression::Duration(
            SwfDuration::from_milliseconds(10),
        ));
        let runner = WaitTaskRunner::new("waitTest", &task).unwrap();
        let workflow = WorkflowDefinition::default();
        default_support!(workflow, context, support);
        let input = json!({"data": "preserved"});
        let output = runner.run(input.clone(), &mut support).await.unwrap();
        assert_eq!(output, input);
    }

    #[tokio::test]
    async fn test_wait_zero_duration() {
        let task = wait_definition(OneOfDurationOrIso8601Expression::Duration(
            SwfDuration::from_milliseconds(0),
        ));
        let runner = WaitTaskRunner::new("zeroWait", &task).unwrap();
        let workflow = WorkflowDefinition::default();
        default_support!(workflow, context, support);
        let input = json!({"fast": true});
        let output = runner.run(input.clone(), &mut support).await.unwrap();
        assert_eq!(output, input);
    }

    #[tokio::test]
    async fn test_wait_with_iso8601_string() {
        let task = wait_definition(OneOfDurationOrIso8601Expression::Iso8601Expression(
            "PT0.01S".to_string(),
        ));
        let runner = WaitTaskRunner::new("isoWait", &task).unwrap();
        let workflow = WorkflowDefinition::default();
        default_support!(workflow, context, support);
        let input = json!({"iso": "duration"});
        let output = runner.run(input.clone(), &mut support).await.unwrap();
        assert_eq!(output, input);
    }

    #[tokio::test]
    async fn test_wait_then_set() {
        let wait_task = TaskDefinition::Wait(wait_definition(
            OneOfDurationOrIso8601Expression::Duration(SwfDuration::from_milliseconds(50)),
        ));
        let entries = vec![
            ("waitABit".to_string(), wait_task),
            (
                "useExpression".to_string(),
                set_task(&[("name", json!("Javierito"))]),
            ),
        ];
        let do_def = DoTaskDefinition::new(Map { entries });
        let workflow = WorkflowDefinition::default();
        let runner = DoTaskRunner::new("waitSet", &do_def).unwrap();
        default_support!(workflow, context, support);
        let output = runner.run(json!({}), &mut support).await.unwrap();
        assert_eq!(output["name"], json!("Javierito"));
    }

    #[tokio::test]
    async fn test_wait_preserves_and_references_prior_values() {
        let set_prepare = set_task(&[
            ("phase", json!("started")),
            ("waitExpression", json!("PT1S")),
        ]);
        let wait_task = TaskDefinition::Wait(wait_definition(
            OneOfDurationOrIso8601Expression::Iso8601Expression("PT0.01S".to_string()),
        ));
        let set_complete = set_task(&[
            ("phase", json!("completed")),
            ("previousPhase", json!("${ .phase }")),
            ("waitExpression", json!("${ .waitExpression }")),
        ]);
        let entries = vec![
            ("prepareWaitExample".to_string(), set_prepare),
            ("waitOneSecond".to_string(), wait_task),
            ("completeWaitExample".to_string(), set_complete),
        ];
        let do_def = DoTaskDefinition::new(Map { entries });
        let workflow = WorkflowDefinition::default();
        let runner = DoTaskRunner::new("waitPreserve", &do_def).unwrap();
        default_support!(workflow, context, support);
        let output = runner.run(json!({}), &mut support).await.unwrap();
        assert_eq!(output["phase"], json!("completed"));
        assert_eq!(output["previousPhase"], json!("started"));
        assert_eq!(output["waitExpression"], json!("PT1S"));
    }

    #[tokio::test]
    async fn test_wait_cancellation() {
        let task = wait_definition(OneOfDurationOrIso8601Expression::Duration(
            SwfDuration::from_seconds(10),
        ));
        let runner = WaitTaskRunner::new("cancelTest", &task).unwrap();
        let workflow = WorkflowDefinition::default();
        let mut context = WorkflowContext::new(&workflow).unwrap();
        // Cancel before the task starts: the wait must fail without sleeping.
        context.cancel();
        let mut support = TaskSupport::new(&workflow, &mut context);
        let result = runner.run(json!({"data": "test"}), &mut support).await;
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(
            err.contains("cancelled"),
            "Expected cancellation error, got: {}",
            err
        );
    }

    #[tokio::test]
    async fn test_wait_cancellation_during_wait() {
        // The wait is far longer than the elapsed-time bound below, so the test can
        // only pass if cancellation actually interrupts the in-progress sleep.
        let task = wait_definition(OneOfDurationOrIso8601Expression::Duration(
            SwfDuration::from_seconds(10),
        ));
        let runner = WaitTaskRunner::new("midCancel", &task).unwrap();
        let workflow = WorkflowDefinition::default();
        let mut context = WorkflowContext::new(&workflow).unwrap();
        let token = context.cancellation_token();
        let mut support = TaskSupport::new(&workflow, &mut context);
        let canceller = tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            token.cancel();
        });
        let start = std::time::Instant::now();
        let result = runner.run(json!({"data": "test"}), &mut support).await;
        let elapsed = start.elapsed();
        canceller.await.unwrap();
        let err = result
            .expect_err("wait should be interrupted by cancellation")
            .to_string();
        assert!(
            err.contains("cancelled"),
            "Expected cancellation error, got: {}",
            err
        );
        assert!(
            elapsed < Duration::from_secs(2),
            "Should cancel quickly, took {:?}",
            elapsed
        );
    }
}