use serde_json::value::Value;
use std::cell::Cell;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use asl::asl::execution::{
ExecutionStatus, ExecutionStatusDiscriminants, StateMachineExecutionError,
};
use asl::asl::execution::{StateExecutionHandler, StateExecutionOutput};
use asl::asl::state_machine::StateMachine;
use similar_asserts::assert_eq;
/// Test double for `StateExecutionHandler` that answers task calls from a scripted
/// table instead of performing real work.
struct TestStateExecutionHandler {
/// Scripted results, looked up by the task's resource name.
resource_name_to_output: HashMap<String, TaskResults>,
/// Number of calls seen so far per resource name; stored in a `Cell` so the
/// counter can be advanced from `execute_task`, which only receives `&self`.
times_called: HashMap<String, Cell<usize>>,
}
impl TestStateExecutionHandler {
fn new() -> TestStateExecutionHandler {
TestStateExecutionHandler {
resource_name_to_output: hash_map![],
times_called: hash_map![],
}
}
fn with_map(
resource_name_to_output: HashMap<String, TaskResults>,
) -> TestStateExecutionHandler {
let map = resource_name_to_output
.keys()
.map(|k| (k.to_string(), Cell::new(0)))
.collect();
TestStateExecutionHandler {
resource_name_to_output,
times_called: map,
}
}
}
/// Error type the test handler returns from `execute_task`.
#[derive(Error, Debug)]
enum MyTaskExecutionError {
/// Carries a scripted error message; its `Display` output is the message itself.
#[error("{0}")]
ForwardedError(String),
}
impl StateExecutionHandler for TestStateExecutionHandler {
type TaskExecutionError = MyTaskExecutionError;
fn execute_task(
&self,
resource: &str,
input: &Value,
_credentials: Option<&Value>,
) -> Result<Option<Value>, Self::TaskExecutionError> {
let option = self.resource_name_to_output.get(resource);
match option {
None => Ok(Some(input.clone())), Some(desired_outputs) => {
let times_called = self.times_called.get(resource).unwrap();
let index = times_called.get();
let desired_output = match desired_outputs {
TaskResults::Repeat(o) => o,
TaskResults::List(vec) => vec.get(index).unwrap_or_else(|| {
panic!(
"Task called {} times, but there are only {} expected results.",
index + 1,
vec.len()
)
}),
};
times_called.set(index + 1);
match desired_output {
TaskBehavior::Output(val) => Ok(Some(val.to_owned())), TaskBehavior::Error(err) => {
Err(MyTaskExecutionError::ForwardedError(err.clone()))
}
}
}
}
}
fn wait(&self, _seconds: f64) {
}
}
use asl::asl::execution::ExecutionStatus::FinishedWithFailure;
use asl::asl::states::error_handling::StateMachineExecutionPredefinedErrors;
use asl::asl::types::execution::{EmptyContext, StateMachineContext};
use itertools::Itertools;
use map_macro::hash_map;
use rstest::*;
use serde_with::serde_derive::Deserialize;
use testresult::TestResult;
use thiserror::Error;
use wildmatch::WildMatch;
/// Steps through the hello-world/succeed definition one state at a time and
/// checks both the yielded step output and the execution's own status after
/// each step, including after the iterator is exhausted.
#[rstest]
fn execute_hello_world_succeed_state() -> TestResult {
    let state_machine =
        StateMachine::parse(include_str!("test-data/hello-world-succeed-state.json"))?;
    let input = serde_json::from_str(
        r#"
"Hello world"
"#,
    )?;
    let val = Value::from("Hello world");
    // The status the execution must settle on once the Succeed state has run.
    let finished = ExecutionStatus::FinishedWithSuccess(Some(val.clone()));

    let mut execution =
        state_machine.start(&input, TestStateExecutionHandler::new(), EmptyContext {});
    assert_eq!(ExecutionStatus::Executing, execution.status);

    // First step: the "Hello World" state forwards the input.
    let first_step = execution.next();
    assert_eq!(
        first_step,
        Some(StateExecutionOutput {
            status: ExecutionStatus::Executing,
            state_name: Some("Hello World".to_string()),
            result: Some(val.clone())
        })
    );
    assert_eq!(ExecutionStatus::Executing, execution.status);

    // Second step: the "Succeed State" ends the execution successfully.
    let second_step = execution.next();
    assert_eq!(
        second_step,
        Some(StateExecutionOutput {
            status: finished.clone(),
            state_name: Some("Succeed State".to_string()),
            result: Some(val.clone())
        })
    );
    assert_eq!(finished, execution.status);

    // A finished execution yields nothing more and keeps its final status.
    assert_eq!(None, execution.next());
    assert_eq!(finished, execution.status);
    Ok(())
}
/// Steps through the hello-world/fail definition and checks that the Fail
/// state ends the execution with error "ErrorA" and cause "Kaiju attack",
/// and that this status persists after the iterator is exhausted.
#[rstest]
fn execute_hello_world_fail_state() -> TestResult {
    let state_machine =
        StateMachine::parse(include_str!("test-data/hello-world-fail-state.json"))?;
    let input = serde_json::from_str(
        r#"
"Hello world"
"#,
    )?;

    let mut execution =
        state_machine.start(&input, TestStateExecutionHandler::new(), EmptyContext {});
    assert_eq!(ExecutionStatus::Executing, execution.status);

    // First step: the "Hello World" state forwards the input.
    let first_step = execution.next();
    let expected_first_step = StateExecutionOutput {
        status: ExecutionStatus::Executing,
        state_name: Some("Hello World".to_string()),
        result: Some(input.clone()),
    };
    assert_eq!(first_step, Some(expected_first_step));
    assert_eq!(ExecutionStatus::Executing, execution.status);

    // Second step: the "Fail State" terminates the execution without a result.
    let second_step = execution.next();
    let failure = with_error_and_cause("ErrorA", "Kaiju attack");
    let expected_second_step = StateExecutionOutput {
        status: failure.clone(),
        state_name: Some("Fail State".to_string()),
        result: None,
    };
    assert_eq!(second_step, Some(expected_second_step));
    assert_eq!(failure, execution.status);

    // A failed execution yields nothing more and keeps its final status.
    assert_eq!(None, execution.next());
    assert_eq!(failure, execution.status);
    Ok(())
}
pub fn with_error_and_cause(error: &str, cause: &str) -> ExecutionStatus {
FinishedWithFailure(StateMachineExecutionError {
error: StateMachineExecutionPredefinedErrors::Custom(error.to_string()),
cause: Some(String::from(cause)),
})
}
/// Builds a `FinishedWithSuccess` status whose payload is `output` parsed as JSON.
///
/// # Panics
///
/// Panics with "Invalid json specified" when `output` is not valid JSON.
pub fn with_success_and_output(output: &str) -> ExecutionStatus {
    serde_json::from_str(output)
        .map(ExecutionStatus::FinishedWithSuccess)
        .expect("Invalid json specified")
}
/// Final outcome that an expectation entry declares for one execution.
///
/// The enum-level `rename_all = "snake_case"` makes the variants deserialize from
/// the keys `output` and `error`.
#[derive(Deserialize, Clone, PartialEq, Debug)]
#[serde(rename_all = "snake_case")]
enum ExpectedFinalStatus {
/// The execution is expected to succeed with this value.
Output(Value),
/// The execution is expected to fail. The variant-level
/// `rename_all = "PascalCase"` makes the fields deserialize from `Error` and `Cause`.
#[serde(rename_all = "PascalCase")]
Error {
error: StateMachineExecutionPredefinedErrors,
// Optional; `execute_all` treats the expected cause as a wildcard pattern.
cause: Option<String>,
},
}
/// Scripted behavior for a single task invocation, deserialized from the keys
/// `output` or `error`.
#[derive(Deserialize, Clone, PartialEq, Debug)]
#[serde(rename_all = "snake_case")]
enum TaskBehavior {
/// The task returns this value.
Output(Value),
/// The task fails; the string is forwarded as the error message.
Error(String),
}
/// Script for all invocations of one task resource.
///
/// Untagged: serde tries the variants in declaration order, so a single behavior
/// is matched as `Repeat` and an array of behaviors as `List`.
#[derive(Deserialize, Clone, PartialEq, Debug)]
#[serde(untagged)]
enum TaskResults {
/// The same behavior is used for every invocation.
Repeat(TaskBehavior),
/// One behavior per invocation, selected by the zero-based call count.
List(Vec<TaskBehavior>),
}
/// One expected execution read from an expectation file: the input to feed the
/// state machine and what the run is expected to produce.
#[derive(Deserialize, Debug)]
struct ExpectedExecution {
/// Input passed to the state machine when the execution starts.
input: Value,
/// Expected final status; flattened, so its key sits next to the other fields.
#[serde(flatten)]
final_status: ExpectedFinalStatus,
/// Names of the states expected to be reported, in execution order.
states: Vec<String>,
/// Optional scripted task results, keyed by resource name.
task_behavior: Option<HashMap<String, TaskResults>>,
/// Optional context values handed to `MapContext`, which looks them up by state name.
context: Option<HashMap<String, Value>>,
}
impl From<ExpectedFinalStatus> for ExecutionStatus {
fn from(value: ExpectedFinalStatus) -> Self {
match value {
ExpectedFinalStatus::Output(val) => ExecutionStatus::FinishedWithSuccess(Some(val)),
ExpectedFinalStatus::Error { error, cause } => {
FinishedWithFailure(StateMachineExecutionError { error, cause })
}
}
}
}
/// `StateMachineContext` implementation backed by a map from state name to the
/// context value exposed while that state is current.
#[derive(Debug)]
pub struct MapContext {
/// Context value to expose, keyed by state name.
map: HashMap<String, Value>,
/// Name most recently passed to `transition_to_state`; empty until the first transition.
current_state_name: String,
}
impl MapContext {
fn new(map: HashMap<String, Value>) -> Self {
MapContext {
map,
current_state_name: "".to_string(),
}
}
}
impl StateMachineContext for MapContext {
    /// Returns a copy of the value registered for the current state name, or
    /// `Value::Null` when the map has no entry for it.
    fn as_value(&self) -> Value {
        match self.map.get(&self.current_state_name) {
            Some(value) => value.clone(),
            None => Value::Null,
        }
    }

    /// Records `state` as the current state name used by `as_value`.
    fn transition_to_state(&mut self, state: &str) {
        self.current_state_name = state.to_owned();
    }
}
#[rstest]
fn execute_all(
#[files("**/test-data/expected-executions-valid-cases/valid-*.json5")]
#[exclude("valid-map.*")]
#[exclude("valid-parallel.*")]
#[exclude("valid-parameters-resultSelector.*")]
#[exclude("valid-fail-paths\\.json.*")]
#[exclude("valid-intrinsic-functions.*")]
#[exclude("valid-pass-negativeIndex.*")]
#[exclude("valid-retry-failure.*")]
path: PathBuf,
) -> TestResult {
let all_expected_executions: Vec<ExpectedExecution> =
json5::from_str(&fs::read_to_string(&path)?)?;
let state_machine_definition_filename = path.with_extension("json");
let state_machine_definition_filename = state_machine_definition_filename.file_name().unwrap();
let definition = fs::read_to_string(format!(
"tests/test-data/asl-validator/{}",
state_machine_definition_filename.to_str().unwrap()
))?;
let state_machine = StateMachine::parse(&definition)?;
dbg!("Parsed state machine:", &state_machine);
for (i, execution_expected_input) in all_expected_executions.iter().enumerate() {
let input = &execution_expected_input.input;
let map = execution_expected_input
.task_behavior
.clone()
.unwrap_or(hash_map![]);
let handler = TestStateExecutionHandler::with_map(map);
let context = MapContext::new(
execution_expected_input
.context
.clone()
.unwrap_or(HashMap::new()),
);
dbg!("State machine execution", i);
dbg!("Input:", input);
dbg!("Context:", &context);
let execution = state_machine.start(input, handler, context);
let execution_steps = execution.collect_vec();
for step in &execution_steps {
println!("{:?}", step);
}
let actual_states = &execution_steps
.iter()
.map(|e| e.state_name.as_ref().unwrap_or(&String::new()).clone())
.collect_vec();
let expected_states = &execution_expected_input.states;
assert_eq!(
expected_states, actual_states,
"States are different for test case {i}."
);
let expected_status = &ExecutionStatus::from(execution_expected_input.final_status.clone());
let actual_status = &execution_steps.last().unwrap().status;
assert_eq!(
ExecutionStatusDiscriminants::from(expected_status),
ExecutionStatusDiscriminants::from(actual_status),
"Final execution status are different for test case {i}."
);
if let FinishedWithFailure(expected) = expected_status {
if let FinishedWithFailure(actual) = actual_status {
assert_eq!(expected.error, actual.error);
assert_eq!(expected.cause.is_some(), actual.cause.is_some());
if expected.cause.is_some() {
let expected_cause = expected.cause.as_ref().unwrap();
let actual_cause = actual.cause.as_ref().unwrap();
let matches = WildMatch::new(&expected_cause).matches(&actual_cause);
assert!(
matches,
"Expected wildcard expression for Cause '{}' to match the actual Cause: {}",
expected_cause, actual_cause
);
}
} else {
panic!("wtf");
}
} else {
assert_eq!(
expected_status, actual_status,
"Status are different for test case {i}."
);
}
}
Ok(())
}