// state-machines 0.9.0
//
// Type-safe state machines with hierarchical states, guards, callbacks, and async support —
// a Rust port of Ruby's state_machines gem.
// See the crate documentation for details.
// Compile these tests only when the crate's `async` feature is enabled; presumably required
// because the machine below is declared with `async: true` — confirm against Cargo.toml.
#![cfg(feature = "async")]
// NOTE(review): presumably needed because `state_machine!` expands to items whose names do not
// follow Rust's casing conventions — confirm, and narrow these allows if possible.
#![allow(non_camel_case_types)]
#![allow(non_snake_case)]

use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Mutex, MutexGuard, PoisonError};

use pollster::block_on;
use state_machines::{DynamicError, EventError, state_machine};

// Before-callback (`refresh_token`) controls.
/// When set, the before callback returns an error instead of `Ok(())`.
static BEFORE_FAILS: AtomicBool = AtomicBool::new(false);
/// Set by the before callback every time it runs.
static BEFORE_CALLED: AtomicBool = AtomicBool::new(false);

// After-callback (`record_completion`) controls.
/// When set, the after callback returns an error instead of `Ok(())`.
static AFTER_FAILS: AtomicBool = AtomicBool::new(false);
/// Set by the after callback every time it runs.
static AFTER_CALLED: AtomicBool = AtomicBool::new(false);

/// Error type the callbacks return; each variant corresponds to one injected failure.
#[derive(Clone, Debug, Eq, PartialEq)]
enum AuthError {
    /// Returned by the `refresh_token` callback when `BEFORE_FAILS` is set.
    Network,
    /// Returned by the `record_completion` callback when `AFTER_FAILS` is set.
    Audit,
}

// Single-event machine used to exercise async callback failure handling.
// `error: AuthError` is the error type the callbacks return. The tests below also use
// `DynamicAuthRecovery` / `AuthRecoveryEvent`, which `dynamic: true` presumably generates.
state_machine! {
    name: AuthRecovery,
    dynamic: true,
    async: true,
    error: AuthError,
    initial: RefreshToken,
    states: [RefreshToken, Done],
    events {
        refresh {
            // A failure here aborts the event before the state changes (see the first test).
            before: [refresh_token],
            // Per the second test, a failure here leaves the machine in `RefreshToken`.
            after: [record_completion],
            transition: { from: RefreshToken, to: Done }
        }
    }
}

/// Callbacks wired into the `refresh` event.
///
/// Each callback first records that it ran. It then returns an error if its
/// failure-injection flag is set, so tests can force a failure in a specific phase.
impl<C, S> AuthRecovery<C, S> {
    /// `before` callback: fails with [`AuthError::Network`] while `BEFORE_FAILS` is set.
    async fn refresh_token(&self) -> Result<(), AuthError> {
        BEFORE_CALLED.store(true, Ordering::SeqCst);
        if BEFORE_FAILS.load(Ordering::SeqCst) {
            return Err(AuthError::Network);
        }
        Ok(())
    }

    /// `after` callback: fails with [`AuthError::Audit`] while `AFTER_FAILS` is set.
    async fn record_completion(&self) -> Result<(), AuthError> {
        AFTER_CALLED.store(true, Ordering::SeqCst);
        if AFTER_FAILS.load(Ordering::SeqCst) {
            return Err(AuthError::Audit);
        }
        Ok(())
    }
}

fn reset_flags() {
    BEFORE_FAILS.store(false, Ordering::SeqCst);
    AFTER_FAILS.store(false, Ordering::SeqCst);
    BEFORE_CALLED.store(false, Ordering::SeqCst);
    AFTER_CALLED.store(false, Ordering::SeqCst);
}

#[test]
fn async_before_callback_failure_returns_source_machine() {
    reset_flags();
    BEFORE_FAILS.store(true, Ordering::SeqCst);

    block_on(async {
        let machine = AuthRecovery::new(());
        let err = machine
            .refresh()
            .await
            .expect_err("before callback should fail");

        let (machine, err) = err;
        match err {
            EventError::Callback(err) => {
                assert_eq!(err.action, "refresh_token");
                assert_eq!(err.event, "refresh");
                assert_eq!(err.source, AuthError::Network);
            }
            other => panic!("expected callback error, got {other:?}"),
        }

        assert!(BEFORE_CALLED.load(Ordering::SeqCst));
        assert!(!AFTER_CALLED.load(Ordering::SeqCst));

        BEFORE_FAILS.store(false, Ordering::SeqCst);
        let _done = machine
            .refresh()
            .await
            .expect("retry should succeed from the original state");
    });
}

#[test]
fn async_after_callback_failure_rolls_back_transition() {
    reset_flags();
    AFTER_FAILS.store(true, Ordering::SeqCst);

    block_on(async {
        let machine = AuthRecovery::new(());
        let err = machine
            .refresh()
            .await
            .expect_err("after callback should fail");

        let (machine, err) = err;
        match err {
            EventError::Callback(err) => {
                assert_eq!(err.action, "record_completion");
                assert_eq!(err.event, "refresh");
                assert_eq!(err.source, AuthError::Audit);
            }
            other => panic!("expected callback error, got {other:?}"),
        }

        assert!(BEFORE_CALLED.load(Ordering::SeqCst));
        assert!(AFTER_CALLED.load(Ordering::SeqCst));

        AFTER_FAILS.store(false, Ordering::SeqCst);
        let _done = machine
            .refresh()
            .await
            .expect("machine should still be in RefreshToken after rollback");
    });
}

#[test]
fn async_dynamic_callback_failure_keeps_runtime_state() {
    reset_flags();
    BEFORE_FAILS.store(true, Ordering::SeqCst);

    block_on(async {
        let mut machine = DynamicAuthRecovery::new(());

        let err = machine
            .handle(AuthRecoveryEvent::Refresh)
            .await
            .expect_err("dynamic handle should surface callback errors");

        match err {
            DynamicError::CallbackFailed {
                action,
                event,
                source,
            } => {
                assert_eq!(action, "refresh_token");
                assert_eq!(event, "refresh");
                assert_eq!(source, AuthError::Network);
            }
            other => panic!("expected callback failure, got {other:?}"),
        }

        assert_eq!(machine.current_state(), "RefreshToken");

        BEFORE_FAILS.store(false, Ordering::SeqCst);
        machine
            .handle(AuthRecoveryEvent::Refresh)
            .await
            .expect("retry should succeed without advancing on the failed attempt");
        assert_eq!(machine.current_state(), "Done");
    });
}