use crate::event::IntoEvent;
use crate::{
action::ActionType, state::StateType, Action, Context, Error, Event, IntoAction, Result, State,
Transition,
};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
/// A hierarchical state machine built from a [`MachineBuilder`].
///
/// `S` is the user-facing state type produced by the optional state mapper
/// (see `Machine::with_state_mapper`); `E` is carried only as a type marker.
/// States themselves are always stored as [`State`] values keyed by id.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct Machine<S = State, E = Event>
where
    S: Clone + 'static + Default,
    E: Clone + 'static,
{
    /// Name given to the builder.
    pub name: String,
    /// All state definitions, keyed by state id.
    pub states: HashMap<String, State>,
    /// All transitions, in the order they were added (searched in this order).
    pub transitions: Vec<Transition>,
    /// Id of the state entered when the machine is constructed.
    pub initial: String,
    /// Ids of every currently active state; each entered state is inserted,
    /// so ancestors of the innermost states are present too.
    pub current_states: HashSet<String>,
    /// Data passed to guards and actions.
    pub context: Context,
    /// Actions run when a state is entered, keyed by state id (not serialized).
    #[serde(skip)]
    pub(crate) entry_actions: HashMap<String, Vec<Action>>,
    /// Actions run when a state is exited, keyed by state id (not serialized).
    #[serde(skip)]
    pub(crate) exit_actions: HashMap<String, Vec<Action>>,
    /// Maps a parent state id to the id of the child most recently exited
    /// under it; consulted by history states. Serialized with the machine.
    pub(crate) history: HashMap<String, String>,
    /// Converts a state id into the user-facing state type `S` (not serialized).
    #[serde(skip)]
    #[serde(default)]
    state_mapper: Option<fn(&str) -> S>,
    // NOTE(review): this field is only ever reset to `None` in this file and
    // never read; confirm whether it is still needed.
    #[serde(skip)]
    current_state_cache: Option<S>,
    /// Marker for the otherwise-unused `S` parameter.
    #[serde(skip)]
    _phantom_s: std::marker::PhantomData<S>,
    /// Marker for the otherwise-unused `E` parameter.
    #[serde(skip)]
    _phantom_e: std::marker::PhantomData<E>,
}
impl<S, E> Machine<S, E>
where
S: Clone + 'static + Default,
E: Clone + 'static,
{
pub fn new<BuilderS, BuilderE>(builder: MachineBuilder<BuilderS, BuilderE>) -> Result<Self>
where
BuilderS: Clone + 'static + Default,
BuilderE: Clone + 'static,
{
let MachineBuilder {
name,
states,
transitions,
initial,
entry_actions,
exit_actions,
context,
_phantom_s: _,
_phantom_e: _,
} = builder;
if states.is_empty() {
return Err(Error::InvalidConfiguration("No states defined".into()));
}
if !states.contains_key(&initial) {
return Err(Error::StateNotFound(initial.clone()));
}
let mut machine = Self {
name,
states,
transitions,
initial,
current_states: HashSet::new(),
context: context.unwrap_or_else(Context::new),
entry_actions,
exit_actions,
history: HashMap::new(),
state_mapper: None,
current_state_cache: None,
_phantom_s: std::marker::PhantomData,
_phantom_e: std::marker::PhantomData,
};
machine.initialize()?;
Ok(machine)
}
fn initialize(&mut self) -> Result<()> {
let initial_state_id = self.initial.clone();
self.enter_state(&initial_state_id, &Event::new("init"))?;
Ok(())
}
pub fn send<EV: IntoEvent>(&mut self, event: EV) -> Result<bool> {
let event = event.into_event();
let mut processed = false;
let current_states: Vec<_> = self.current_states.iter().cloned().collect();
for state_id in current_states {
if self.process_state_event(&state_id, &event)? {
processed = true;
}
}
if processed {
self.current_state_cache = None;
}
Ok(processed)
}
fn process_state_event(&mut self, state_id: &str, event: &Event) -> Result<bool> {
let enabled_transition = self
.transitions
.iter()
.find(|t| t.source == state_id && t.is_enabled(&self.context, event))
.cloned();
if let Some(transition) = enabled_transition {
self.execute_transition(&transition, event)?;
Ok(true)
} else {
if let Some(parent_id) = self.get_parent_id(state_id) {
self.process_state_event(&parent_id, event)
} else {
Ok(false)
}
}
}
fn execute_transition(&mut self, transition: &Transition, event: &Event) -> Result<()> {
let source_id = transition.source.clone();
if transition.target.is_none() {
transition.execute_actions(&mut self.context, event);
return Ok(());
}
let target_id = transition
.target
.as_ref()
.ok_or_else(|| Error::InvalidTransition("No target state".into()))?
.clone();
self.exit_state(&source_id, event)?;
transition.execute_actions(&mut self.context, event);
self.enter_state(&target_id, event)?;
Ok(())
}
fn enter_state(&mut self, state_id: &str, event: &Event) -> Result<()> {
let state = self
.states
.get(state_id)
.ok_or_else(|| Error::StateNotFound(state_id.to_string()))?
.clone();
self.current_states.insert(state_id.to_string());
if let Some(actions) = self.entry_actions.get(state_id) {
for action in actions.clone() {
action.execute(&mut self.context, event);
}
}
match state.state_type {
StateType::Compound => {
if let Some(initial) = state.initial {
self.enter_state(&initial, event)?;
} else {
return Err(Error::InvalidConfiguration(format!(
"Compound state '{}' has no initial state",
state_id
)));
}
}
StateType::Parallel => {
for child_id in state.children {
self.enter_state(&child_id, event)?;
}
}
StateType::History => {
if let Some(last_active) = self.history.get(state_id).cloned() {
self.enter_state(&last_active, event)?;
} else if let Some(parent_id) = self.get_parent_id(state_id) {
let parent = self
.states
.get(&parent_id)
.ok_or_else(|| Error::StateNotFound(parent_id.to_string()))?
.clone();
if let Some(initial) = parent.initial {
self.enter_state(&initial, event)?;
}
}
}
StateType::DeepHistory => {
if let Some(last_active) = self.history.get(state_id).cloned() {
self.enter_state(&last_active, event)?;
} else if let Some(parent_id) = self.get_parent_id(state_id) {
let parent = self
.states
.get(&parent_id)
.ok_or_else(|| Error::StateNotFound(parent_id.to_string()))?
.clone();
if let Some(initial) = parent.initial {
self.enter_state(&initial, event)?;
}
}
}
_ => {} }
Ok(())
}
fn exit_state(&mut self, state_id: &str, event: &Event) -> Result<()> {
let state = self
.states
.get(state_id)
.ok_or_else(|| Error::StateNotFound(state_id.to_string()))?
.clone();
if let Some(parent_id) = self.get_parent_id(state_id) {
self.history.insert(parent_id, state_id.to_string());
}
match state.state_type {
StateType::Compound | StateType::Parallel => {
let active_children: Vec<_> = state
.children
.iter()
.filter(|child_id| self.current_states.contains(*child_id))
.cloned()
.collect();
for child_id in active_children {
self.exit_state(&child_id, event)?;
}
}
_ => {} }
if let Some(actions) = self.exit_actions.get(state_id) {
for action in actions.clone() {
action.execute(&mut self.context, event);
}
}
self.current_states.remove(state_id);
Ok(())
}
fn get_parent_id(&self, state_id: &str) -> Option<String> {
self.states
.get(state_id)
.and_then(|state| state.parent.clone())
}
pub fn is_in(&self, state_id: &str) -> bool {
self.current_states.contains(state_id)
}
pub fn to_json(&self) -> Result<String> {
let json = serde_json::to_string_pretty(self)?;
Ok(json)
}
pub fn from_json(json: &str) -> Result<Self> {
let machine: Self = serde_json::from_str(json)?;
Ok(machine)
}
pub fn with_state_mapper(mut self, mapper: fn(&str) -> S) -> Self
{
self.state_mapper = Some(mapper);
self
}
pub fn current_state(&self) -> S {
if self.current_states.is_empty() {
panic!("ステートマシンが初期化されていません。send() を呼び出す前に initialize() を呼び出してください。");
}
let state_id = self.current_states.iter().next().unwrap();
if let Some(mapper) = self.state_mapper {
mapper(state_id)
} else {
panic!("状態マッパーが設定されていません。Machine::with_state_mapper()を使用してマッパーを設定してください。");
}
}
pub fn transition<EV: IntoEvent>(&mut self, event: EV, context: Context) -> Result<S> {
self.context = context;
let event = event.into_event();
let result = self.send(event)?;
if result {
self.current_state_cache = None;
}
Ok(self.current_state())
}
}
/// Fluent builder for [`Machine`].
///
/// Collects states, transitions, entry/exit actions and an optional context;
/// `build` validates the configuration and enters the initial state.
#[derive(Clone, Debug)]
pub struct MachineBuilder<S = State, E = Event>
where
    S: Clone + 'static + Default,
    E: Clone + 'static,
{
    /// Name of the machine being built.
    pub name: String,
    /// State definitions, keyed by state id.
    pub states: HashMap<String, State>,
    /// Transitions in insertion order.
    pub transitions: Vec<Transition>,
    /// Id of the initial state; empty until set explicitly or by the first `state` call.
    pub initial: String,
    /// Context to start with; `None` means `Context::new()` is used at build time.
    pub context: Option<Context>,
    /// Entry actions, keyed by state id.
    pub(crate) entry_actions: HashMap<String, Vec<Action>>,
    /// Exit actions, keyed by state id.
    pub(crate) exit_actions: HashMap<String, Vec<Action>>,
    _phantom_s: std::marker::PhantomData<S>,
    _phantom_e: std::marker::PhantomData<E>,
}
impl<S, E> MachineBuilder<S, E>
where
    S: Clone + 'static + Default,
    E: Clone + 'static,
{
    /// Starts a builder for a machine called `name`, with no states,
    /// transitions, actions or explicit context.
    #[must_use]
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            states: HashMap::new(),
            transitions: Vec::new(),
            initial: String::new(),
            context: None,
            entry_actions: HashMap::new(),
            exit_actions: HashMap::new(),
            _phantom_s: std::marker::PhantomData,
            _phantom_e: std::marker::PhantomData,
        }
    }

    /// Sets the id of the state entered when the machine is built, overriding
    /// the default chosen by [`MachineBuilder::state`].
    #[must_use]
    pub fn initial(mut self, state_id: impl Into<String>) -> Self {
        self.initial = state_id.into();
        self
    }

    /// Adds `state`, keyed by its id; a state previously added with the same
    /// id is replaced.
    ///
    /// If this is the first state and no initial state has been set yet, it
    /// becomes the initial state.
    #[must_use]
    pub fn state(mut self, state: State) -> Self {
        if self.states.is_empty() && self.initial.is_empty() {
            self.initial = state.id.clone();
        }
        self.states.insert(state.id.clone(), state);
        self
    }

    /// Appends `transition`. Transitions are kept in insertion order, which is
    /// the order the machine searches them in.
    #[must_use]
    pub fn transition(mut self, transition: Transition) -> Self {
        self.transitions.push(transition);
        self
    }

    /// Registers `action` to run when `state_id` is entered.
    ///
    /// Multiple actions for one state run in registration order. The id is
    /// not checked against the defined states.
    #[must_use]
    pub fn on_entry<A: IntoAction>(mut self, state_id: impl Into<String>, action: A) -> Self {
        let state_id = state_id.into();
        let action = action.into_action(ActionType::Entry);
        self.entry_actions
            .entry(state_id)
            .or_default()
            .push(action);
        self
    }

    /// Registers `action` to run when `state_id` is exited.
    ///
    /// Multiple actions for one state run in registration order. The id is
    /// not checked against the defined states.
    #[must_use]
    pub fn on_exit<A: IntoAction>(mut self, state_id: impl Into<String>, action: A) -> Self {
        let state_id = state_id.into();
        let action = action.into_action(ActionType::Exit);
        self.exit_actions
            .entry(state_id)
            .or_default()
            .push(action);
        self
    }

    /// Sets the context the machine starts with; without it, `Context::new()` is used.
    #[must_use]
    pub fn context(mut self, context: Context) -> Self {
        self.context = Some(context);
        self
    }

    /// Builds the machine and enters its initial state; same as `Machine::new(self)`.
    ///
    /// # Errors
    ///
    /// Fails if no states were added, the initial state is undefined, or
    /// entering the initial configuration fails.
    pub fn build(self) -> Result<Machine<S, E>> {
        Machine::new(self)
    }
}