fsmy 0.1.0

A finite state machine library
use std::marker::PhantomData;

use derive_getters::Getters;

#[cfg(feature = "events")]
use crate::prelude::{EventProducer, StateMachineEvent};

use crate::{RunningFiniteStateMachine, prelude::*, stack::Stack};

/// Configuration-phase wrapper around another configurable state machine.
///
/// All configuration calls are forwarded to the wrapped machine. Starting it
/// yields a [`RunningJournaledStateMachine`], which keeps a history of past
/// `(state, context)` pairs so transitions can be rewound.
#[derive(Debug, Getters, Default)]
pub struct JournaledStateMachine<'a, C: ConfigurableStateMachine<'a>> {
    /// The wrapped machine that receives every forwarded configuration call.
    inner: C,
    /// Anchors the lifetime `'a`, which no other field mentions.
    marker: PhantomData<&'a ()>,
}

impl<'a, C> ConfigurableStateMachine<'a> for JournaledStateMachine<'a, C>
where
    C: ConfigurableStateMachine<
            'a,
            StateMachine = RunningFiniteStateMachine<
                'a,
                <C as ConfigurableStateMachine<'a>>::State,
                <C as ConfigurableStateMachine<'a>>::Input,
                <C as ConfigurableStateMachine<'a>>::Output,
                <C as ConfigurableStateMachine<'a>>::Context,
            >,
        >,
    C::Input: PartialEq,
    C::Context: 'a + Clone,
    C::State: StateLike,
{
    type Input = C::Input;

    type State = C::State;

    type Context = C::Context;

    type Output = C::Output;

    /// The running form: the inner running machine plus a history stack.
    type StateMachine = RunningJournaledStateMachine<
        'a,
        RunningFiniteStateMachine<'a, Self::State, Self::Input, Self::Output, Self::Context>,
    >;

    /// Starts the inner machine and wraps it with an empty history.
    ///
    /// # Errors
    ///
    /// Returns the inner machine's start error unchanged. Previously this
    /// case hit `todo!()` and panicked.
    fn start(self) -> Result<Self::StateMachine, Error<'a, Self::State, Self::Input>> {
        // `Self::State`/`Self::Input` are exactly `C::State`/`C::Input`, so the
        // inner error type already matches and can be propagated as-is.
        let inner = self.inner.start()?;

        Ok(RunningJournaledStateMachine {
            inner,
            history: Stack::default(),
        })
    }

    /// Forwards the initial state to the inner machine.
    fn set_initial_state(&mut self, initial_state: Self::State) {
        self.inner.set_initial_state(initial_state);
    }

    /// Forwards a state definition to the inner machine and returns its result.
    fn define_state(
        &mut self,
        state: Self::State,
        options: state::Options<'a, Self>,
    ) -> Option<Self::State> {
        // `state::Options` is generic over the machine type, so the options
        // built for `Self` are rebuilt field by field as options for `C`.
        self.inner.define_state(
            state,
            state::Options {
                is_virtual: options.is_virtual,
                parent: options.parent,
                on_enter: options.on_enter,
                on_leave: options.on_leave,
            },
        )
    }

    /// Forwards a transition definition unchanged to the inner machine.
    fn add_transition(
        &mut self,
        source_state: Self::State,
        event: Self::Input,
        target_state: Self::State,
        output: Option<Self::Output>,
        mutation: Option<fn(&mut Self::Context)>,
    ) {
        self.inner
            .add_transition(source_state, event, target_state, output, mutation);
    }
}

/// A running state machine that journals its successful transitions.
///
/// Every successful `consume` records the state and a clone of the context
/// as they were *before* the transition. The `rewind` method of the
/// [`History`] trait pops the most recent record and restores both the state
/// and the context, so the machine can be rewound one step at a time.
#[derive(Debug, Getters)]
pub struct RunningJournaledStateMachine<'a, M: RunningStateMachine<'a>> {
    /// The wrapped running machine that performs the actual transitions.
    inner: M,
    /// `(state, context)` snapshots taken before each successful transition;
    /// the most recent one is on top.
    history: Stack<(M::State, M::Context)>,
}

impl<'a, M> RunningStateMachine<'a> for RunningJournaledStateMachine<'a, M>
where
    M: RunningStateMachine<'a>,
    M::Context: Clone,
    M::State: Clone,
{
    type Input = M::Input;

    type State = M::State;

    type Context = M::Context;

    type Output = M::Output;

    /// Delegates to the inner machine.
    fn states(&self) -> Vec<&Self::State> {
        self.inner.states()
    }

    /// Delegates to the inner machine.
    fn state(&self) -> &Self::State {
        self.inner.state()
    }

    /// Delegates to the inner machine.
    fn context(&self) -> &Self::Context {
        self.inner.context()
    }

    /// Delegates to the inner machine.
    fn is_final(&self) -> bool {
        self.inner.is_final()
    }

    /// Delegates to the inner machine.
    fn is_active(&self, state: Self::State) -> bool {
        self.inner.is_active(state)
    }

    /// Feeds `event` to the inner machine, journaling the pre-transition
    /// state and context if the inner machine accepts it.
    ///
    /// # Errors
    ///
    /// Returns the inner machine's error unchanged. In that case nothing is
    /// added to the history.
    fn consume(
        &mut self,
        event: Self::Input,
    ) -> Result<Option<&Self::Output>, crate::prelude::Error<'_, Self::State, Self::Input>> {
        // Snapshot first: the inner `consume` mutates state and context in place.
        let snapshot = (self.inner.state().clone(), self.inner.context().clone());

        let output = self.inner.consume(event)?;

        // Record the snapshot only once the transition has succeeded.
        self.history.push(snapshot);
        Ok(output)
    }
}

#[cfg(feature = "events")]
impl<'a, M> History<'a> for RunningJournaledStateMachine<'a, M>
where
    M: EventProducer<M::State> + StateMachineMut<'a>,
    M::Context: Clone,
    M::State: Clone,
{
    type Len = usize;

    /// Number of snapshots currently stored in the history.
    fn history_len(&self) -> Self::Len {
        self.history.len()
    }

    /// Restores the most recently journaled state and context, then emits a
    /// `StateChanged((current, restored))` event.
    ///
    /// # Errors
    ///
    /// Returns `Error::NoHistory` when there is nothing to rewind.
    fn rewind(&mut self) -> Result<(), crate::prelude::Error<'_, Self::State, Self::Input>> {
        let Some((previous_state, previous_context)) = self.history.pop() else {
            return Err(Error::NoHistory);
        };

        // Captured before restoring so the event can describe the transition.
        // This impl is already gated on `events`, so no inner `cfg` is needed.
        let current = self.inner.state().clone();

        self.inner.set_state(previous_state.clone());
        *self.inner.context_mut() = previous_context;

        // The result of `emit` is deliberately discarded: the rewind has
        // already happened and is reported as successful regardless.
        let _ = self
            .inner
            .emit(StateMachineEvent::StateChanged((current, previous_state)));
        Ok(())
    }
}

#[cfg(not(feature = "events"))]
impl<'a, M> History<'a> for RunningJournaledStateMachine<'a, M>
where
    M: StateMachineMut<'a>,
    M::Context: Clone,
    M::State: Clone,
{
    type Len = usize;

    fn history_len(&self) -> Self::Len {
        self.history.len()
    }

    fn rewind(&mut self) -> Result<(), crate::prelude::Error<'_, Self::State, Self::Input>> {
        if let Some(prev) = self.history.pop() {
            #[cfg(feature = "events")]
            let current = self.inner.state().clone();

            self.inner.set_state(prev.0.clone());
            *self.inner.context_mut() = prev.1;
            Ok(())
        } else {
            Err(Error::NoHistory)
        }
    }
}

impl<'a, M> RunningJournaledStateMachine<'a, M>
where
    M: RunningStateMachine<'a>,
    M::State: Clone,
    M::Context: Clone,
{
    /// Wraps an already-running machine with an empty history created by
    /// `Stack::default()`.
    pub fn with_machine(inner: M) -> Self {
        Self {
            inner,
            history: Stack::default(),
        }
    }

    /// Wraps an already-running machine with an empty history created by
    /// `Stack::with_max_size(max_history_size)`.
    ///
    /// NOTE(review): presumably this caps how many snapshots are retained;
    /// see `Stack` for the exact eviction behaviour.
    pub fn with_machine_and_max_history_size(inner: M, max_history_size: usize) -> Self {
        Self {
            inner,
            history: Stack::with_max_size(max_history_size),
        }
    }
}

#[cfg(test)]
mod test {
    use fsmy_dsl::Primitive;

    use crate::{FiniteStateMachine, JournaledStateMachine, prelude::*};

    #[derive(Debug, Primitive)]
    enum State {
        Open,
        Closed,
    }

    #[derive(Debug, PartialEq)]
    enum Message {
        Push,
        Pull,
    }

    #[derive(Default, Debug, Clone)]
    pub struct Context {
        counter: usize,
    }

    /// Two-state door: `Pull` opens it and increments the counter, `Push` closes it.
    fn machine<'a>()
    -> JournaledStateMachine<'a, FiniteStateMachine<'a, State, Message, (), Context>> {
        let mut machine = JournaledStateMachine::default();
        machine.set_initial_state(State::Closed);

        machine.add_transition_with_mutation(
            State::Closed,
            Message::Pull,
            State::Open,
            None,
            |c: &mut Context| c.counter += 1,
        );
        machine.add_transition(State::Open, Message::Push, State::Closed, None, None);

        machine
    }

    #[test]
    fn rewind() {
        let mut machine = machine().start().unwrap();
        assert_eq!(machine.history_len(), 0);

        assert!(machine.consume(Message::Pull).is_ok());
        assert_eq!(*machine.state(), State::Open);
        assert_eq!(machine.context().counter, 1);
        assert_eq!(machine.history_len(), 1);

        assert!(machine.consume(Message::Push).is_ok());
        assert_eq!(*machine.state(), State::Closed);
        assert_eq!(machine.context().counter, 1);
        assert_eq!(machine.history_len(), 2);

        // Each rewind restores the previous state and context and shrinks the history.
        assert!(machine.rewind().is_ok());
        assert_eq!(*machine.state(), State::Open);
        assert_eq!(machine.context().counter, 1);
        assert_eq!(machine.history_len(), 1);

        assert!(machine.rewind().is_ok());
        assert_eq!(*machine.state(), State::Closed);
        assert_eq!(machine.context().counter, 0);
        assert_eq!(machine.history_len(), 0);

        // An empty history cannot be rewound and must stay empty.
        assert!(machine.rewind().is_err());
        assert_eq!(machine.history_len(), 0);
    }
}