entromatica 0.23.0

A simulation library for Markov chains with a uniform steady state distribution centered around their entropy
Documentation
use std::collections::hash_map::DefaultHasher;
use std::fmt::Debug;
use std::fmt::Display;
use std::hash::{Hash, Hasher};
use std::sync::mpsc::SendError;
use std::sync::Mutex;

use backtrace::Backtrace as trc;
use derive_more::*;
use hashbrown::HashMap;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use thiserror::Error;

use crate::prelude::*;

/// Name under which a [`Parameter`] is stored inside an [`Entity`].
///
/// Newtype around `String`. The wrapped field is private, so values are built
/// through [`ParameterName::new`] or the derived `From` conversion.
#[derive(
    Clone,
    PartialEq,
    Eq,
    Hash,
    Debug,
    Display,
    Default,
    From,
    AsRef,
    AsMut,
    Into,
    Serialize,
    Deserialize,
)]
pub struct ParameterName(String);

impl ParameterName {
    /// Creates a parameter name holding an owned copy of `name`.
    pub fn new(name: &str) -> Self {
        let owned = String::from(name);
        Self(owned)
    }
}

/// A single value of type `T`, stored under a [`ParameterName`] in an [`Entity`].
#[derive(Hash, Clone, PartialEq, Debug, Default, Serialize, Deserialize)]
pub struct Parameter<T> {
    /// The wrapped value; exposed through `value`, `value_mut` and `set_value`.
    value: T,
}

impl<T> Parameter<T>
where
    T: Hash
        + Clone
        + PartialEq
        + Debug
        + Default
        + Serialize
        + Send
        + Sync
        + for<'a> Deserialize<'a>,
{
    /// Wraps `value` in a new parameter.
    pub fn new(value: T) -> Self {
        Self { value }
    }

    /// Borrows the stored value.
    pub fn value(&self) -> &T {
        let Self { value } = self;
        value
    }

    /// Mutably borrows the stored value.
    pub fn value_mut(&mut self) -> &mut T {
        let Self { value } = self;
        value
    }

    /// Replaces the stored value with `value`; the previous value is dropped.
    pub fn set_value(&mut self, value: T) {
        *self.value_mut() = value;
    }
}

/// A collection of named [`Parameter`]s.
///
/// `Hash`, `PartialEq` and `Display` are implemented by hand elsewhere in this file.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct Entity<T> {
    /// Parameters keyed by name.
    parameters: HashMap<ParameterName, Parameter<T>>,
}

impl<T: Hash> Hash for Entity<T> {
    /// Feeds an order-independent digest of all parameters into `state`.
    ///
    /// The parameters live in a `HashMap`, whose iteration order is arbitrary and
    /// may differ between two maps holding the same entries. Hashing the entries
    /// sequentially would therefore let two entities that compare equal produce
    /// different hashes. Each `(name, parameter)` pair is hashed on its own and
    /// the per-entry hashes are combined with a commutative operation
    /// (wrapping addition), so the result does not depend on iteration order.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let digest = self
            .parameters
            .iter()
            .map(|(parameter_name, parameter)| {
                // A fresh hasher per entry keeps the entries independent of each other.
                let mut entry_hasher = DefaultHasher::new();
                parameter_name.hash(&mut entry_hasher);
                parameter.hash(&mut entry_hasher);
                entry_hasher.finish()
            })
            .fold(0_u64, u64::wrapping_add);
        digest.hash(state);
    }
}

impl<T: PartialEq> PartialEq for Entity<T> {
    /// Two entities are equal when their parameter maps are equal.
    fn eq(&self, other: &Self) -> bool {
        self.parameters.eq(&other.parameters)
    }
}

impl<T: Debug> Display for Entity<T> {
    /// Writes an `Entity:` header followed by one indented line per parameter,
    /// each parameter rendered with its pretty `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "Entity:")?;
        self.parameters
            .iter()
            .try_for_each(|(parameter_name, parameter)| {
                writeln!(f, "  {parameter_name}: {parameter:#?}")
            })
    }
}

impl<T> Entity<T> {
    /// Builds an entity from `(name, parameter)` pairs.
    ///
    /// The pairs are collected into a map, so a later pair replaces an earlier
    /// one that uses the same name.
    pub fn new(parameters: Vec<(ParameterName, Parameter<T>)>) -> Self {
        let parameters = parameters.into_iter().collect();
        Self { parameters }
    }

    /// Looks up the parameter stored under `parameter_name`.
    ///
    /// # Errors
    ///
    /// Returns [`EntityError::ParameterNotFound`] when no such parameter exists.
    pub fn parameter(&self, parameter_name: &ParameterName) -> Result<&Parameter<T>, EntityError> {
        match self.parameters.get(parameter_name) {
            Some(found) => Ok(found),
            None => Err(EntityError::ParameterNotFound {
                parameter_name: parameter_name.clone(),
                context: get_backtrace(),
            }),
        }
    }

    /// Looks up the parameter stored under `parameter_name` for modification.
    ///
    /// # Errors
    ///
    /// Returns [`EntityError::ParameterNotFound`] when no such parameter exists.
    pub fn parameter_mut(
        &mut self,
        parameter_name: &ParameterName,
    ) -> Result<&mut Parameter<T>, EntityError> {
        match self.parameters.get_mut(parameter_name) {
            Some(found) => Ok(found),
            None => Err(EntityError::ParameterNotFound {
                parameter_name: parameter_name.clone(),
                context: get_backtrace(),
            }),
        }
    }

    /// Iterates over all `(name, parameter)` pairs in the map's arbitrary order.
    pub fn iter_parameters(&self) -> impl Iterator<Item = (&ParameterName, &Parameter<T>)> {
        let Self { parameters } = self;
        parameters.iter()
    }

    /// Iterates over all `(name, parameter)` pairs with mutable access to the parameters.
    pub fn iter_parameters_mut(
        &mut self,
    ) -> impl Iterator<Item = (&ParameterName, &mut Parameter<T>)> {
        let Self { parameters } = self;
        parameters.iter_mut()
    }
}

/// Errors returned by [`Entity`] parameter lookups.
#[non_exhaustive]
#[derive(Debug, Clone, Error)]
pub enum EntityError {
    /// No parameter is stored under the requested name.
    #[error("Parameter not found: {parameter_name:#?}")]
    ParameterNotFound {
        /// The name that was looked up.
        parameter_name: ParameterName,
        /// Backtrace obtained from `get_backtrace()` when the error was created.
        context: trc,
    },
}

/// Name under which an [`Entity`] is stored inside a [`State`].
///
/// Newtype around `String`; unlike `ParameterName`, the wrapped field is public.
#[derive(
    Clone,
    PartialEq,
    Eq,
    Hash,
    Debug,
    Display,
    Default,
    From,
    Into,
    AsRef,
    AsMut,
    Serialize,
    Deserialize,
)]
pub struct EntityName(pub String);

impl EntityName {
    /// Creates an entity name holding an owned copy of `name`.
    pub fn new(name: &str) -> Self {
        let owned = String::from(name);
        Self(owned)
    }
}

/// A collection of named [`Entity`] values.
///
/// `Hash`, `PartialEq` and `Display` are implemented by hand elsewhere in this file.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct State<T> {
    /// Entities keyed by name.
    entities: HashMap<EntityName, Entity<T>>,
}

impl<
        T: Hash
            + Clone
            + PartialEq
            + Debug
            + Default
            + Serialize
            + Send
            + Sync
            + for<'a> Deserialize<'a>,
    > Display for State<T>
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "State:")?;
        for (entity_name, entity) in &self.entities {
            writeln!(f, "  {entity_name}:")?;
            for (parameter_name, parameter) in &entity.parameters {
                writeln!(f, "    {parameter_name}: {parameter:#?}")?;
            }
        }
        Ok(())
    }
}

impl<T: Hash> Hash for State<T> {
    /// Feeds an order-independent digest of all entities into `state`.
    ///
    /// The entities live in a `HashMap`, whose iteration order is arbitrary and
    /// may differ between two maps holding the same entries. Hashing the entries
    /// sequentially would therefore let two states that compare equal produce
    /// different hashes. Each `(name, entity)` pair is hashed on its own and the
    /// per-entry hashes are combined with a commutative operation (wrapping
    /// addition), so the result does not depend on iteration order.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let digest = self
            .entities
            .iter()
            .map(|(entity_name, entity)| {
                // A fresh hasher per entry keeps the entries independent of each other.
                let mut entry_hasher = DefaultHasher::new();
                entity_name.hash(&mut entry_hasher);
                entity.hash(&mut entry_hasher);
                entry_hasher.finish()
            })
            .fold(0_u64, u64::wrapping_add);
        digest.hash(state);
    }
}

impl<T: PartialEq> PartialEq for State<T> {
    /// Two states are equal when their entity maps are equal.
    fn eq(&self, other: &Self) -> bool {
        self.entities.eq(&other.entities)
    }
}

impl<
        T: for<'a> Deserialize<'a>
            + std::hash::Hash
            + std::clone::Clone
            + std::cmp::PartialEq
            + std::fmt::Debug
            + std::default::Default
            + std::marker::Send
            + std::marker::Sync
            + Serialize,
    > State<T>
{
    pub fn new(entities: Vec<(EntityName, Entity<T>)>) -> Self {
        Self {
            entities: entities.into_iter().collect(),
        }
    }

    pub fn entity(&self, entity_name: &EntityName) -> Result<&Entity<T>, StateError> {
        self.entities
            .get(entity_name)
            .ok_or_else(|| StateError::EntityNotFound {
                entity_name: entity_name.clone(),
                context: get_backtrace(),
            })
    }

    pub fn insert_entity(&mut self, entity_name: EntityName, entity: Entity<T>) {
        self.entities.insert(entity_name, entity);
    }

    pub fn entity_mut(&mut self, entity_name: &EntityName) -> Result<&mut Entity<T>, StateError> {
        self.entities
            .get_mut(entity_name)
            .ok_or_else(|| StateError::EntityNotFound {
                entity_name: entity_name.clone(),
                context: get_backtrace(),
            })
    }

    pub fn iter_entities(&self) -> impl Iterator<Item = (&EntityName, &Entity<T>)> {
        self.entities.iter()
    }

    pub fn iter_entities_mut(&mut self) -> impl Iterator<Item = (&EntityName, &mut Entity<T>)> {
        self.entities.iter_mut()
    }

    pub fn set_parameter(
        &mut self,
        target: &EntityName,
        parameter_name: ParameterName,
        parameter_val: Parameter<T>,
    ) -> Result<(), StateError> {
        let entity = self.entity_mut(target)?;
        let parameter = entity.parameter_mut(&parameter_name)?;
        *parameter = parameter_val;
        Ok(())
    }

    #[allow(clippy::type_complexity)]
    pub(crate) fn reachable_states(
        &self,
        base_state_probability: &Probability,
        rules: &HashMap<RuleName, Rule<T>>,
        possible_states: &PossibleStates<T>,
        cache: &Cache,
    ) -> Result<
        (
            ReachableStates,
            PossibleStates<T>,
            Vec<ConditionCacheUpdate>,
            Vec<ActionCacheUpdate>,
        ),
        ErrorKind<T>,
    > {
        let base_state_hash = StateHash::new(self);
        let mut new_base_state_probability: Probability = *base_state_probability;
        let mut applying_rules_probability_weight_sum = ProbabilityWeight::from(0.);
        let mut reachable_states_by_rule_probability_weight: HashMap<StateHash, ProbabilityWeight> =
            HashMap::new();

        let mut condition_cache_updates = Vec::new();
        let mut action_cache_updates = Vec::new();

        let mut new_possible_states = PossibleStates::new(HashMap::new());

        for (rule_name, rule) in rules {
            let base_state = possible_states.state(&base_state_hash)?;
            let (rule_applies, condition_cache_update) =
                rule.applies(cache, rule_name.clone(), base_state.clone())?;
            if let Some(cache) = condition_cache_update {
                condition_cache_updates.push(cache);
            }
            if rule_applies.is_true() {
                applying_rules_probability_weight_sum += rule.weight();
                let (new_state, action_cache_update) = rule.apply(
                    cache,
                    possible_states,
                    rule_name.clone(),
                    base_state_hash,
                    base_state.clone(),
                )?;
                if new_state != *self {
                    new_base_state_probability *= 1. - f64::from(rule.weight());
                }
                if let Some(cache) = action_cache_update {
                    action_cache_updates.push(cache);
                }
                let new_state_hash = StateHash::new(&new_state);
                new_possible_states.append_state(new_state_hash, new_state)?;
                reachable_states_by_rule_probability_weight.insert(new_state_hash, rule.weight());
            }
        }

        let mut new_reachable_states = ReachableStates::new();
        if new_base_state_probability > Probability::from(0.) {
            new_reachable_states.append_state(base_state_hash, new_base_state_probability)?;
        }

        let probabilities_for_reachable_states_from_base_state = self
            .probabilities_for_reachable_states(
                reachable_states_by_rule_probability_weight,
                *base_state_probability,
                new_base_state_probability,
                applying_rules_probability_weight_sum,
            );

        for (new_state_hash, new_state_probability) in
            probabilities_for_reachable_states_from_base_state.iter()
        {
            new_reachable_states.append_state(*new_state_hash, *new_state_probability)?;
        }
        Ok((
            new_reachable_states,
            new_possible_states,
            condition_cache_updates,
            action_cache_updates,
        ))
    }

    fn probabilities_for_reachable_states(
        &self,
        reachable_states_by_rule_probability_weight: HashMap<StateHash, ProbabilityWeight>,
        base_state_probability: Probability,
        new_base_state_probability: Probability,
        applying_rules_probability_weight_sum: ProbabilityWeight,
    ) -> ReachableStates {
        ReachableStates::from(HashMap::from_par_iter(
            reachable_states_by_rule_probability_weight
                .par_iter()
                .filter_map(|(new_reachable_state_hash, rule_probability_weight)| {
                    if *new_reachable_state_hash != StateHash::new(self) {
                        let new_reachable_state_probability =
                            Probability::from_probability_weight(*rule_probability_weight)
                                * f64::from(base_state_probability)
                                * f64::from(Probability::from(1.) - new_base_state_probability)
                                / f64::from(applying_rules_probability_weight_sum);
                        Option::Some((*new_reachable_state_hash, new_reachable_state_probability))
                    } else {
                        Option::None
                    }
                }),
        ))
    }
}

/// Errors returned by [`State`] operations.
#[non_exhaustive]
#[derive(Debug, Clone, Error)]
pub enum StateError {
    /// No entity is stored under the requested name.
    #[error("Entity not found: {entity_name:#?}")]
    EntityNotFound {
        /// The name that was looked up.
        entity_name: EntityName,
        /// Backtrace obtained from `get_backtrace()` when the error was created.
        context: trc,
    },

    /// A parameter of an entity was already affected.
    ///
    /// This variant is not constructed anywhere in this file.
    #[error("Parameter {parameter_name:#?} already affected for entity {entity_name:#?}")]
    ParameterAlreadyAffected {
        /// The parameter concerned.
        parameter_name: ParameterName,
        /// The entity owning the parameter.
        entity_name: EntityName,
        /// Backtrace captured for diagnostics.
        context: trc,
    },

    /// An [`EntityError`] raised while accessing an entity; converted
    /// automatically by `?` through the generated `From` impl.
    #[error("EntityError: {0:#?}")]
    EntityError(#[from] EntityError),
}

/// 64-bit hash identifying a [`State`].
///
/// Used as the key of [`PossibleStates`] and [`ReachableStates`]. Built from a
/// state with [`StateHash::new`] or from a raw `u64` via the derived `From`.
#[derive(
    Copy,
    Clone,
    PartialEq,
    Eq,
    Hash,
    Debug,
    Display,
    Default,
    From,
    Into,
    AsRef,
    AsMut,
    Serialize,
    Deserialize,
)]
pub struct StateHash(u64);

impl StateHash {
    /// Hashes `state` with a fresh `DefaultHasher` and wraps the resulting `u64`.
    pub fn new<T>(state: &State<T>) -> Self
    where
        T: Hash + Serialize + for<'a> Deserialize<'a>,
    {
        let mut state_hasher = DefaultHasher::new();
        state.hash(&mut state_hasher);
        let digest = state_hasher.finish();
        Self(digest)
    }
}

/// Map from [`StateHash`] to the full [`State`] it identifies.
#[derive(Clone, Debug, Default, Serialize, Deserialize)]
pub struct PossibleStates<T>(HashMap<StateHash, State<T>>);

impl<T: Debug + for<'a> Deserialize<'a>> Display for PossibleStates<T> {
    /// Writes one `hash: state` line per entry, the state rendered with its
    /// pretty `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0
            .iter()
            .try_for_each(|(state_hash, state)| writeln!(f, "{state_hash}: {state:#?}"))
    }
}

impl<T: Hash> Hash for PossibleStates<T> {
    fn hash<H: Hasher>(&self, hasher: &mut H) {
        for (state_hash, state) in &self.0 {
            state_hash.hash(hasher);
            state.hash(hasher);
        }
    }
}

impl<T: PartialEq> PartialEq for PossibleStates<T> {
    /// Two collections are equal when they hold the same number of entries and
    /// every hash in `self` maps to an equal state in `other`.
    fn eq(&self, other: &Self) -> bool {
        if self.0.len() != other.0.len() {
            return false;
        }
        for (state_hash, state) in &self.0 {
            if other.0.get(state_hash) != Some(state) {
                return false;
            }
        }
        true
    }
}

impl<T> PossibleStates<T>
where
    T: Hash
        + Clone
        + PartialEq
        + Debug
        + Default
        + Serialize
        + Send
        + Sync
        + for<'a> Deserialize<'a>,
{
    /// Wraps an existing hash-to-state map.
    pub fn new(possible_states: HashMap<StateHash, State<T>>) -> Self {
        Self(possible_states)
    }

    /// Registers `state` under `state_hash`.
    ///
    /// Registering a state that is already present and equal is a no-op.
    ///
    /// # Errors
    ///
    /// Returns [`PossibleStatesError::StateAlreadyExists`] when a *different*
    /// state is already stored under `state_hash`; the stored state is kept.
    pub(crate) fn append_state(
        &mut self,
        state_hash: StateHash,
        state: State<T>,
    ) -> Result<(), PossibleStatesError<T>> {
        if let Some(present_state) = self.0.get(&state_hash) {
            if state != *present_state {
                return Err(PossibleStatesError::StateAlreadyExists {
                    state_hash,
                    context: get_backtrace(),
                });
            }
            return Ok(());
        }
        self.0.insert(state_hash, state);
        Ok(())
    }

    /// Appends a clone of every state in `states` via [`Self::append_state`],
    /// stopping at the first error.
    pub(crate) fn merge(&mut self, states: &PossibleStates<T>) -> Result<(), ErrorKind<T>> {
        for (incoming_hash, incoming_state) in states.iter() {
            self.append_state(*incoming_hash, incoming_state.clone())?;
        }
        Ok(())
    }

    /// Looks up the state stored under `state_hash`.
    ///
    /// # Errors
    ///
    /// Returns [`PossibleStatesError::StateNotFound`] when the hash is unknown.
    pub fn state(&self, state_hash: &StateHash) -> Result<&State<T>, PossibleStatesError<T>> {
        match self.0.get(state_hash) {
            Some(found) => Ok(found),
            None => Err(PossibleStatesError::StateNotFound {
                state_hash: *state_hash,
                context: get_backtrace(),
            }),
        }
    }

    /// Iterates over all `(hash, state)` entries.
    pub fn iter(&self) -> hashbrown::hash_map::Iter<StateHash, State<T>> {
        let Self(map) = self;
        map.iter()
    }

    /// Iterates over all stored states.
    pub fn values(&self) -> hashbrown::hash_map::Values<StateHash, State<T>> {
        let Self(map) = self;
        map.values()
    }

    /// Number of stored states.
    pub fn len(&self) -> usize {
        let Self(map) = self;
        map.len()
    }

    /// Returns `true` when no state is stored.
    pub fn is_empty(&self) -> bool {
        let Self(map) = self;
        map.is_empty()
    }

    /// Returns `true` when a state is stored under `state_hash`.
    pub fn contains(&self, state_hash: &StateHash) -> bool {
        let Self(map) = self;
        map.contains_key(state_hash)
    }
}

/// Errors returned by [`PossibleStates`] operations.
#[non_exhaustive]
#[derive(Debug, Clone, Error)]
pub enum PossibleStatesError<T: Clone + Debug> {
    /// No state is stored under the requested hash.
    #[error("State not found: {state_hash:#?}")]
    StateNotFound { state_hash: StateHash, context: trc },

    /// A different state is already stored under the given hash.
    #[error("State already exists: {state_hash:#?}")]
    StateAlreadyExists { state_hash: StateHash, context: trc },

    /// Sending a `PossibleStates` value over an `mpsc` channel failed.
    ///
    /// This variant is not constructed anywhere in this file.
    #[error("Possible states send error: {source:#?}")]
    PossibleStatesSendError {
        /// The underlying channel error, carrying the value that could not be sent.
        #[source]
        source: SendError<PossibleStates<T>>,
        /// Backtrace captured for diagnostics.
        context: trc,
    },
}

/// Map from [`StateHash`] to the [`Probability`] of being in that state.
#[derive(
    Clone, PartialEq, Debug, Default, From, Into, AsRef, AsMut, Index, Serialize, Deserialize,
)]
pub struct ReachableStates(HashMap<StateHash, Probability>);

impl Display for ReachableStates {
    /// Writes one `hash: probability` line per entry.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        self.0
            .iter()
            .try_for_each(|(state_hash, probability)| writeln!(f, "{state_hash}: {probability}"))
    }
}

impl ReachableStates {
    /// Creates an empty distribution.
    pub fn new() -> Self {
        Self(HashMap::new())
    }

    /// Probability stored for `state_hash`, or zero when the hash is absent.
    pub fn probability(&self, state_hash: &StateHash) -> Probability {
        match self.0.get(state_hash) {
            Some(stored) => *stored,
            None => Probability::from(0.),
        }
    }

    /// Adds `state_probability` to the entry for `state_hash`.
    ///
    /// A missing entry is created only when `state_probability` is greater than
    /// zero; otherwise nothing is stored.
    ///
    /// # Errors
    ///
    /// Returns `UnitsError::ProbabilityOutOfRange` when the entry exists and
    /// the sum would exceed one; the stored value is left untouched.
    pub fn append_state(
        &mut self,
        state_hash: StateHash,
        state_probability: Probability,
    ) -> Result<(), UnitsError> {
        if let Some(stored) = self.0.get_mut(&state_hash) {
            if *stored + state_probability > Probability::from(1.) {
                return Err(UnitsError::ProbabilityOutOfRange {
                    probability: *stored + state_probability,
                    context: get_backtrace(),
                });
            }
            *stored += state_probability;
        } else if state_probability > Probability::from(0.) {
            self.0.insert(state_hash, state_probability);
        }
        Ok(())
    }

    /// Appends every entry of `states` via [`Self::append_state`], stopping at
    /// the first error. `T` only selects the `ErrorKind` flavour of the result.
    pub fn merge<T>(&mut self, states: &ReachableStates) -> Result<(), ErrorKind<T>>
    where
        T: Hash
            + Clone
            + PartialEq
            + Debug
            + Default
            + Serialize
            + Send
            + Sync
            + for<'a> Deserialize<'a>,
    {
        for (incoming_hash, incoming_probability) in states.iter() {
            self.append_state(*incoming_hash, *incoming_probability)?;
        }
        Ok(())
    }

    /// Iterates over all stored probabilities.
    pub fn values(&self) -> hashbrown::hash_map::Values<StateHash, Probability> {
        let Self(map) = self;
        map.values()
    }

    /// Iterates over all `(hash, probability)` entries.
    pub fn iter(&self) -> hashbrown::hash_map::Iter<StateHash, Probability> {
        let Self(map) = self;
        map.iter()
    }

    /// Iterates over all entries with mutable access to the probabilities.
    pub fn iter_mut(&mut self) -> hashbrown::hash_map::IterMut<StateHash, Probability> {
        let Self(map) = self;
        map.iter_mut()
    }

    /// Parallel (rayon) iterator over all `(hash, probability)` entries.
    pub fn par_iter(&self) -> hashbrown::hash_map::rayon::ParIter<'_, StateHash, Probability> {
        let Self(map) = self;
        map.par_iter()
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        let Self(map) = self;
        map.len()
    }

    /// Returns `true` when no entry is stored.
    pub fn is_empty(&self) -> bool {
        let Self(map) = self;
        map.is_empty()
    }

    /// Returns `true` when an entry exists for `state_hash`.
    pub fn contains(&self, state_hash: &StateHash) -> bool {
        let Self(map) = self;
        map.contains_key(state_hash)
    }

    /// Sum of all stored probabilities.
    pub fn probability_sum(&self) -> Probability {
        let total: f64 = self
            .iter()
            .par_bridge()
            .map(|(_, probability)| probability.to_f64())
            .sum();
        Probability::from(total)
    }

    /// Shannon entropy in bits: the sum of `-p * log2(p)` over all entries with `p > 0`.
    pub fn entropy(&self) -> Entropy {
        let bits: f64 = self
            .par_iter()
            .map(|(_, probability)| {
                if *probability > Probability::from(0.) {
                    let p = f64::from(*probability);
                    p * -p.log2()
                } else {
                    0.
                }
            })
            .sum();
        Entropy::from(bits)
    }

    /// Euclidean distance between `self` and `base`, taken over the states
    /// present in `self` only; a state missing from `base` counts as zero there.
    pub fn euclidean_norm(&self, base: &ReachableStates) -> Entropy {
        let squared_distance: f64 = self
            .par_iter()
            .map(|(state_hash, probability)| {
                let difference = probability.to_f64() - base.probability(state_hash).to_f64();
                difference.powi(2)
            })
            .sum();
        Entropy::from(squared_distance.sqrt())
    }

    /// Advances the distribution by one step.
    ///
    /// For every state in `self` (in parallel) the states reachable through
    /// `rules` are computed and merged into a new distribution. States
    /// discovered on the way are merged into `possible_states`, and condition /
    /// action cache updates are applied to a copy of `cache` that is merged back
    /// into `cache` at the end.
    ///
    /// # Errors
    ///
    /// Propagates the first error from the per-state computation, from merging,
    /// or from a poisoned mutex. In debug builds it additionally returns
    /// `UnitsError::ProbabilitySumNot1` when the probabilities of `self` (the
    /// input distribution) do not sum to one.
    pub(crate) fn apply_rules<T>(
        &self,
        possible_states: &mut PossibleStates<T>,
        cache: &mut Cache,
        rules: &HashMap<RuleName, Rule<T>>,
    ) -> Result<ReachableStates, ErrorKind<T>>
    where
        T: Hash
            + Clone
            + PartialEq
            + Debug
            + Default
            + Serialize
            + Send
            + Sync
            + for<'a> Deserialize<'a>,
    {
        // Accumulators shared between the rayon workers.
        let new_reachable_states_mutex = Mutex::new(ReachableStates::new());
        let possible_states_update_mutex = Mutex::new(PossibleStates::default());
        let cache_update_mutex = Mutex::new(cache.clone());

        self.par_iter()
            .map(|(base_state_hash, base_state_probability)| {
                possible_states.state(base_state_hash)?.reachable_states(
                    base_state_probability,
                    rules,
                    possible_states,
                    cache,
                )
            })
            .try_for_each(|outcome| -> Result<(), ErrorKind<T>> {
                let (
                    new_reachable_states,
                    new_possible_states,
                    condition_cache_updates,
                    action_cache_updates,
                ) = outcome?;

                new_reachable_states_mutex
                    .lock()?
                    .merge(&new_reachable_states)?;
                possible_states_update_mutex
                    .lock()?
                    .merge(&new_possible_states)?;
                // The cache lock is taken per update, not held across the loop.
                for condition_cache_update in condition_cache_updates {
                    cache_update_mutex
                        .lock()?
                        .apply_condition_update(condition_cache_update)?;
                }
                for action_cache_update in action_cache_updates {
                    cache_update_mutex
                        .lock()?
                        .apply_action_update(action_cache_update)?;
                }
                Ok(())
            })?;

        if cfg!(debug_assertions) {
            let probability_sum = self.probability_sum();
            if probability_sum != Probability::from(1.) {
                return Err(ErrorKind::UnitsError(UnitsError::ProbabilitySumNot1 {
                    probability_sum,
                    context: get_backtrace(),
                }));
            }
        }

        possible_states.merge(&possible_states_update_mutex.lock()?.clone())?;
        cache.merge(&cache_update_mutex.lock()?.clone())?;
        // Bound to a local so the mutex guard is released before the mutex itself is dropped.
        let stepped_distribution = new_reachable_states_mutex.lock()?.clone();
        Ok(stepped_distribution)
    }
}

/// Errors associated with [`ReachableStates`].
///
/// Neither variant is constructed anywhere in this file.
#[non_exhaustive]
#[derive(Debug, Clone, Error)]
pub enum ReachableStatesError {
    /// No entry exists for the requested hash.
    #[error("State not found: {state_hash:#?}")]
    StateNotFound { state_hash: StateHash, context: trc },

    /// Sending a `ReachableStates` value over an `mpsc` channel failed.
    #[error("Reachable states send error: {source:#?}")]
    ReachableStatesSendError {
        /// The underlying channel error, carrying the value that could not be sent.
        #[source]
        source: SendError<ReachableStates>,
        /// Backtrace captured for diagnostics.
        context: trc,
    },
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Entity with a single parameter named `"Parameter"` holding `value`.
    fn entity_with(value: i32) -> Entity<i32> {
        Entity::new(vec![(ParameterName::new("Parameter"), Parameter::new(value))])
    }

    /// State with a single entity `entity_name` built by [`entity_with`].
    fn state_with(entity_name: &str, value: i32) -> State<i32> {
        State::new(vec![(EntityName::new(entity_name), entity_with(value))])
    }

    /// State with entity `"A"` holding `"Parameter"` and `"Parameter2"`.
    fn two_parameter_state(first: i32, second: i32) -> State<i32> {
        State::new(vec![(
            EntityName::new("A"),
            Entity::new(vec![
                (ParameterName::new("Parameter"), Parameter::new(first)),
                (ParameterName::new("Parameter2"), Parameter::new(second)),
            ]),
        )])
    }

    /// Hash of the default (empty) `State<i32>`.
    fn empty_state_hash() -> StateHash {
        StateHash::new::<i32>(&State::default())
    }

    /// Distribution built by appending `(raw hash, probability)` pairs in order.
    fn distribution(entries: &[(u64, f64)]) -> ReachableStates {
        let mut reachable_states = ReachableStates::new();
        for (raw_hash, probability) in entries {
            reachable_states
                .append_state(StateHash::from(*raw_hash), Probability::new(*probability))
                .unwrap();
        }
        reachable_states
    }

    #[test]
    fn entity_get_parameter_should_return_value_on_present_parameter() {
        let name = ParameterName::new("parameter");
        let entity = Entity::new(vec![(name.clone(), Parameter::new(1))]);
        let found = entity.parameter(&name).cloned().unwrap();
        assert_eq!(found, Parameter::new(1));
    }

    #[test]
    fn entity_get_parameter_should_return_error_on_missing_parameter() {
        let entity = Entity::new(vec![(ParameterName::new("parameter"), Parameter::new(1))]);
        match entity.parameter(&ParameterName::new("missing_parameter")) {
            Err(EntityError::ParameterNotFound { parameter_name, .. }) => {
                assert_eq!(parameter_name, ParameterName::new("missing_parameter"));
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn state_partial_equal_works_as_expected() {
        let state_a_0 = state_with("A", 0);
        let state_a_1 = state_with("A", 0);
        let state_b = state_with("A", 1);
        let state_c = state_with("B", 1);
        assert_eq!(state_a_0, state_a_1);
        assert_ne!(state_a_0, state_b);
        assert_ne!(state_a_1, state_b);
        assert_ne!(state_b, state_c);
    }

    #[test]
    fn state_get_entity_should_return_value_on_present_entity() {
        let state = state_with("A", 0);
        let found = state.entity(&EntityName::new("A")).cloned().unwrap();
        assert_eq!(found, entity_with(0));
    }

    #[test]
    fn state_get_entity_should_return_error_on_missing_entity() {
        let state = state_with("A", 0);
        match state.entity(&EntityName::new("missing_entity")).cloned() {
            Err(StateError::EntityNotFound { entity_name, .. }) => {
                assert_eq!(entity_name, EntityName::new("missing_entity"));
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn state_get_mut_entity_should_return_value_on_present_entity() {
        let mut state = state_with("A", 0);
        assert_eq!(
            state.entity_mut(&EntityName::new("A")).unwrap(),
            &mut entity_with(0)
        );
    }

    #[test]
    fn state_get_mut_entity_should_return_error_on_missing_entity() {
        let mut state = state_with("A", 0);
        match state
            .entity_mut(&EntityName::new("missing_entity"))
            .cloned()
        {
            Err(StateError::EntityNotFound { entity_name, .. }) => {
                assert_eq!(entity_name, EntityName::new("missing_entity"));
            }
            _ => panic!("Unexpected error type"),
        }
    }

    #[test]
    fn possible_states_append_state() {
        let state = two_parameter_state(0, 0);
        let state_hash = StateHash::new(&state);
        let mut possible_states = PossibleStates::default();
        possible_states
            .append_state(state_hash, state.clone())
            .unwrap();
        let expected = HashMap::from([(state_hash, state)]);
        assert_eq!(possible_states.0, expected);

        // A different state under the same hash is rejected and leaves the map untouched.
        let new_state = two_parameter_state(1, 2);
        possible_states
            .append_state(state_hash, new_state)
            .unwrap_err();
        assert_eq!(possible_states.0, expected);
    }

    #[test]
    fn reachable_states_append_state() {
        let mut reachable_states = ReachableStates::new();
        let state_hash = empty_state_hash();
        let probability = Probability::from(1.);
        reachable_states
            .append_state(state_hash, probability)
            .unwrap();
        let expected = HashMap::from([(state_hash, probability)]);
        assert_eq!(reachable_states.0, expected);

        // Appending again would push the probability above one.
        reachable_states
            .append_state(state_hash, probability)
            .unwrap_err();
        assert_eq!(reachable_states.0, expected);
    }

    #[test]
    fn reachable_states_probability_sum() {
        let mut reachable_states = ReachableStates::new();
        reachable_states
            .append_state(empty_state_hash(), Probability::from(0.2))
            .unwrap();
        reachable_states
            .append_state(StateHash::new(&state_with("A", 0)), Probability::from(0.5))
            .unwrap();
        assert_eq!(reachable_states.probability_sum(), Probability::from(0.7));
    }

    #[test]
    fn reachable_states_entropy() {
        let mut reachable_states = ReachableStates::new();
        assert_eq!(reachable_states.entropy(), Entropy::from(0.));
        reachable_states
            .append_state(empty_state_hash(), Probability::from(0.5))
            .unwrap();
        reachable_states
            .append_state(StateHash::new(&state_with("A", 0)), Probability::from(0.5))
            .unwrap();
        assert_eq!(reachable_states.entropy(), Entropy::from(1.));
    }

    #[test]
    fn euclidean_norm() {
        let reachable_states = distribution(&[(1, 0.5), (2, 0.25), (3, 0.25)]);
        let base_reachable_states = distribution(&[(1, 0.5), (2, 0.5)]);
        assert_eq!(
            Entropy::new(0.3535533905932738),
            reachable_states.euclidean_norm(&base_reachable_states)
        );
    }
}