// asl 0.2.0
//
// Rust implementation of the Amazon States Language (ASL).
// Documentation
use serde_json::value::Value;
use std::cell::Cell;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;

use asl::asl::execution::{
    ExecutionStatus, ExecutionStatusDiscriminants, StateMachineExecutionError,
};
use asl::asl::execution::{StateExecutionHandler, StateExecutionOutput};
use asl::asl::state_machine::StateMachine;
use similar_asserts::assert_eq;

/// Fake `StateExecutionHandler` used by the tests to script what each Task resource does.
struct TestStateExecutionHandler {
    /// Scripted behaviour per resource name; resources missing from this map echo their input.
    resource_name_to_output: HashMap<String, TaskResults>,
    /// How many times each mapped resource has been invoked so far. `Cell` lets the counter be
    /// bumped through the `&self` receiver of `execute_task`.
    times_called: HashMap<String, Cell<usize>>,
}

impl TestStateExecutionHandler {
    fn new() -> TestStateExecutionHandler {
        TestStateExecutionHandler {
            resource_name_to_output: hash_map![],
            times_called: hash_map![],
        }
    }
    fn with_map(
        resource_name_to_output: HashMap<String, TaskResults>,
    ) -> TestStateExecutionHandler {
        let map = resource_name_to_output
            .keys()
            .map(|k| (k.to_string(), Cell::new(0)))
            .collect();
        TestStateExecutionHandler {
            resource_name_to_output,
            times_called: map,
        }
    }
}

/// Error returned by the fake task handler when a task is scripted to fail.
#[derive(Error, Debug)]
enum MyTaskExecutionError {
    /// Carries the scripted error string unchanged; its `Display` output is that string.
    #[error("{0}")]
    ForwardedError(String),
}

impl StateExecutionHandler for TestStateExecutionHandler {
    type TaskExecutionError = MyTaskExecutionError;

    fn execute_task(
        &self,
        resource: &str,
        input: &Value,
        _credentials: Option<&Value>,
    ) -> Result<Option<Value>, Self::TaskExecutionError> {
        let option = self.resource_name_to_output.get(resource);
        match option {
            None => Ok(Some(input.clone())), // resource is not mapped, so just forwards the input
            Some(desired_outputs) => {
                let times_called = self.times_called.get(resource).unwrap();
                let index = times_called.get();
                let desired_output = match desired_outputs {
                    TaskResults::Repeat(o) => o,
                    TaskResults::List(vec) => vec.get(index).unwrap_or_else(|| {
                        panic!(
                            "Task called {} times, but there are only {} expected results.",
                            index + 1,
                            vec.len()
                        )
                    }),
                };
                times_called.set(index + 1);

                match desired_output {
                    TaskBehavior::Output(val) => Ok(Some(val.to_owned())), // resource is mapped, so returns the desired output
                    TaskBehavior::Error(err) => {
                        Err(MyTaskExecutionError::ForwardedError(err.clone()))
                    }
                }
            }
        }
    }

    fn wait(&self, _seconds: f64) {
        //nop on purpose (sleeping a thread is bad for tests).
    }
}

use asl::asl::execution::ExecutionStatus::FinishedWithFailure;
use asl::asl::states::error_handling::StateMachineExecutionPredefinedErrors;
use asl::asl::types::execution::{EmptyContext, StateMachineContext};
use itertools::Itertools;
use map_macro::hash_map;
use rstest::*;
use serde_with::serde_derive::Deserialize;
use testresult::TestResult;
use thiserror::Error;
use wildmatch::WildMatch;

#[rstest]
fn execute_hello_world_succeed_state() -> TestResult {
    let definition = include_str!("test-data/hello-world-succeed-state.json");
    let state_machine = StateMachine::parse(definition)?;

    let input = serde_json::from_str(
        r#"
            "Hello world"
        "#,
    )?;
    let val = Value::from("Hello world");

    let mut execution =
        state_machine.start(&input, TestStateExecutionHandler::new(), EmptyContext {});
    assert_eq!(ExecutionStatus::Executing, execution.status);

    // Advance state
    let state_output = execution.next();
    assert_eq!(
        state_output,
        Some(StateExecutionOutput {
            status: ExecutionStatus::Executing,
            state_name: Some("Hello World".to_string()),
            result: Some(val.clone())
        })
    );
    assert_eq!(ExecutionStatus::Executing, execution.status);

    // Advance state
    let state_output = execution.next();
    assert_eq!(
        state_output,
        Some(StateExecutionOutput {
            status: ExecutionStatus::FinishedWithSuccess(Some(val.clone())),
            state_name: Some("Succeed State".to_string()),
            result: Some(val.clone())
        })
    );
    assert_eq!(
        ExecutionStatus::FinishedWithSuccess(Some(val.clone())),
        execution.status
    );

    // Iterator is exhausted
    assert_eq!(None, execution.next());
    assert_eq!(
        ExecutionStatus::FinishedWithSuccess(Some(val.clone())),
        execution.status
    );
    Ok(())
}

#[rstest]
fn execute_hello_world_fail_state() -> TestResult {
    let definition = include_str!("test-data/hello-world-fail-state.json");
    let state_machine = StateMachine::parse(definition)?;

    let val = serde_json::from_str(
        r#"
            "Hello world"
        "#,
    )?;
    let mut execution =
        state_machine.start(&val, TestStateExecutionHandler::new(), EmptyContext {});
    assert_eq!(ExecutionStatus::Executing, execution.status);

    // Advance state
    let state_output = execution.next();
    assert_eq!(
        state_output,
        Some(StateExecutionOutput {
            status: ExecutionStatus::Executing,
            state_name: Some("Hello World".to_string()),
            result: Some(val.clone())
        })
    );
    assert_eq!(ExecutionStatus::Executing, execution.status);

    // Advance state
    let state_output = execution.next();
    let expected_status = with_error_and_cause("ErrorA", "Kaiju attack");
    assert_eq!(
        state_output,
        Some(StateExecutionOutput {
            status: expected_status.clone(),
            state_name: Some("Fail State".to_string()),
            result: None,
        })
    );
    assert_eq!(expected_status, execution.status);

    // Iterator is exhausted
    assert_eq!(None, execution.next());
    assert_eq!(expected_status, execution.status);
    Ok(())
}

pub fn with_error_and_cause(error: &str, cause: &str) -> ExecutionStatus {
    FinishedWithFailure(StateMachineExecutionError {
        error: StateMachineExecutionPredefinedErrors::Custom(error.to_string()),
        cause: Some(String::from(cause)),
    })
}
/// Builds a `FinishedWithSuccess` status from the JSON text `output`.
///
/// The text is deserialized as an `Option<Value>`. A literal `null` therefore yields
/// `FinishedWithSuccess(None)`, not `FinishedWithSuccess(Some(Value::Null))`.
///
/// # Panics
/// Panics (message starting with `Invalid json specified`) if `output` is not valid JSON.
pub fn with_success_and_output(output: &str) -> ExecutionStatus {
    ExecutionStatus::FinishedWithSuccess(
        serde_json::from_str(output).expect("Invalid json specified"),
    )
}

/// Expected outcome of one fixture test case. It is flattened into `ExpectedExecution`, so its
/// key sits next to `input`.
///
/// Because of `rename_all = "snake_case"`, the variant keys are `output` and `error`. The fields
/// inside `error` are renamed to PascalCase: `Error` and `Cause`.
#[derive(Deserialize, Clone, PartialEq, Debug)]
#[serde(rename_all = "snake_case")]
enum ExpectedFinalStatus {
    /// The execution should succeed with this output value.
    Output(Value),
    /// The execution should fail with this error and, optionally, this cause.
    /// `execute_all` matches the cause as a wildcard pattern.
    #[serde(rename_all = "PascalCase")]
    Error {
        error: StateMachineExecutionPredefinedErrors,
        cause: Option<String>,
    },
}

/// Controls what the Task will do in the test cases.
/// If it finds an `output` key, then it will return that value as the task output and *succeed*
/// the task.
/// If it finds an `error` key, then it will error with the string provided and *fail* the task.
/// The keys are lowercase because of `rename_all = "snake_case"`.
#[derive(Deserialize, Clone, PartialEq, Debug)]
#[serde(rename_all = "snake_case")]
enum TaskBehavior {
    Output(Value),
    Error(String),
}

/// Scripted results for one Task resource, deserialized without a tag.
///
/// A single behaviour object is reused on every call (`Repeat`). An array supplies one behaviour
/// per call, in order (`List`), and calling the task more times than the array holds panics in
/// `execute_task`. Variant order matters for untagged deserialization: `Repeat` is tried first.
#[derive(Deserialize, Clone, PartialEq, Debug)]
#[serde(untagged)]
enum TaskResults {
    Repeat(TaskBehavior),
    List(Vec<TaskBehavior>),
}

/// One test case read from an expected-executions `.json5` fixture.
#[derive(Deserialize, Debug)]
struct ExpectedExecution {
    /// Input handed to the state machine.
    input: Value,
    /// Expected final status. It is flattened, so its `output` or `error` key sits beside
    /// `input`.
    #[serde(flatten)]
    final_status: ExpectedFinalStatus,
    /// Names of the states the execution is expected to report, in order.
    states: Vec<String>,
    /// Optional scripted Task behaviour per resource name; unmapped resources echo their input.
    task_behavior: Option<HashMap<String, TaskResults>>,
    /// Optional context value per state name, exposed through `MapContext`.
    context: Option<HashMap<String, Value>>,
}

impl From<ExpectedFinalStatus> for ExecutionStatus {
    fn from(value: ExpectedFinalStatus) -> Self {
        match value {
            ExpectedFinalStatus::Output(val) => ExecutionStatus::FinishedWithSuccess(Some(val)),
            ExpectedFinalStatus::Error { error, cause } => {
                FinishedWithFailure(StateMachineExecutionError { error, cause })
            }
        }
    }
}

/// Test `StateMachineContext` that serves a fixed context value per state name.
#[derive(Debug)]
pub struct MapContext {
    /// Context value for each state name; states without an entry read as `Value::Null`.
    map: HashMap<String, Value>,
    /// Name of the state most recently passed to `transition_to_state`.
    current_state_name: String,
}

impl MapContext {
    fn new(map: HashMap<String, Value>) -> Self {
        MapContext {
            map,
            current_state_name: "".to_string(),
        }
    }
}

impl StateMachineContext for MapContext {
    /// Returns the context configured for the current state, or `Value::Null` when the
    /// fixture provides none for it.
    fn as_value(&self) -> Value {
        match self.map.get(&self.current_state_name) {
            Some(context) => context.clone(),
            None => Value::Null,
        }
    }

    /// Records `state` as the current state, so later `as_value` calls use its entry.
    fn transition_to_state(&mut self, state: &str) {
        self.current_state_name = state.to_owned();
    }
}

/// Runs every `valid-*.json5` expected-execution fixture against its ASL definition.
///
/// Each fixture holds a list of test cases. A case gives the input, optional scripted task
/// behaviour, optional per-state context, the expected sequence of states and the expected
/// final status. The definition is read from `tests/test-data/asl-validator/`, using the
/// fixture's file name with a `.json` extension.
///
/// For failures, the expected `Cause` is treated as a wildcard pattern (via `WildMatch`) so that
/// error messages can be matched loosely.
#[rstest]
fn execute_all(
    #[files("**/test-data/expected-executions-valid-cases/valid-*.json5")]
    // TODO: Support Map state
    #[exclude("valid-map.*")]
    // TODO: Support Parallel state
    #[exclude("valid-parallel.*")]
    #[exclude("valid-parameters-resultSelector.*")]
    // TODO: Support Intrinsic Functions
    #[exclude("valid-fail-paths\\.json.*")]
    #[exclude("valid-intrinsic-functions.*")]
    // TODO: Support negative index
    #[exclude("valid-pass-negativeIndex.*")]
    // TODO: Support retries
    #[exclude("valid-retry-failure.*")]
    path: PathBuf,
) -> TestResult {
    let all_expected_executions: Vec<ExpectedExecution> =
        json5::from_str(&fs::read_to_string(&path)?)?;

    // The definition shares the fixture's file name, but with a `.json` extension.
    let definition_path = path.with_extension("json");
    let definition_file_name = definition_path
        .file_name()
        .expect("fixture paths produced by #[files] always have a file name");
    let definition = fs::read_to_string(
        PathBuf::from("tests/test-data/asl-validator").join(definition_file_name),
    )?;
    let state_machine = StateMachine::parse(&definition)?;
    dbg!("Parsed state machine:", &state_machine);

    // The loop is for each test case contained within the JSON file
    for (i, execution_expected_input) in all_expected_executions.iter().enumerate() {
        let input = &execution_expected_input.input;
        let map = execution_expected_input
            .task_behavior
            .clone()
            .unwrap_or_default();
        let handler = TestStateExecutionHandler::with_map(map);
        let context = MapContext::new(
            execution_expected_input
                .context
                .clone()
                .unwrap_or_default(),
        );
        dbg!("State machine execution", i);
        dbg!("Input:", input);
        dbg!("Context:", &context);
        let execution = state_machine.start(input, handler, context);

        // Execute the state machine and collect all steps so they can be compared
        let execution_steps = execution.collect_vec();
        for step in &execution_steps {
            println!("{:?}", step);
        }

        // Compare the states seen by the execution (an unnamed step is recorded as "").
        let actual_states = execution_steps
            .iter()
            .map(|e| e.state_name.clone().unwrap_or_default())
            .collect_vec();
        let expected_states = &execution_expected_input.states;
        assert_eq!(
            expected_states, &actual_states,
            "States are different for test case {i}."
        );

        // Compare final status
        let expected_status = &ExecutionStatus::from(execution_expected_input.final_status.clone());
        let last_step = execution_steps
            .last()
            .unwrap_or_else(|| panic!("Execution produced no steps for test case {i}."));
        let actual_status = &last_step.status;
        assert_eq!(
            ExecutionStatusDiscriminants::from(expected_status),
            ExecutionStatusDiscriminants::from(actual_status),
            "Final execution status are different for test case {i}."
        );

        match (expected_status, actual_status) {
            // When it's a failure, then we treat the "Cause" as a wildcard so that error
            // messages can be easily matched.
            (FinishedWithFailure(expected), FinishedWithFailure(actual)) => {
                assert_eq!(
                    expected.error, actual.error,
                    "Errors are different for test case {i}."
                );
                match (&expected.cause, &actual.cause) {
                    (Some(expected_cause), Some(actual_cause)) => {
                        let matches = WildMatch::new(expected_cause).matches(actual_cause);
                        assert!(
                            matches,
                            "Expected wildcard expression for Cause '{}' to match the actual Cause: {}",
                            expected_cause, actual_cause
                        );
                    }
                    (None, None) => {}
                    (expected_cause, actual_cause) => panic!(
                        "Cause presence differs for test case {i}: expected {expected_cause:?}, got {actual_cause:?}."
                    ),
                }
            }
            (FinishedWithFailure(_), _) => {
                unreachable!("the discriminants were asserted equal, so both statuses are failures")
            }
            _ => {
                assert_eq!(
                    expected_status, actual_status,
                    "Status are different for test case {i}."
                );
            }
        }
    }

    Ok(())
}