use super::state_variable_config::StateVariableConfig;
use super::StateVariable;
use super::{
custom_variable_config::CustomVariableConfig, state_model_error::StateModelError,
update_operation::UpdateOperation,
};
use crate::model::state::InputFeature;
use crate::model::unit::{
DistanceUnit, EnergyUnit, RatioUnit, SpeedUnit, TemperatureUnit, TimeUnit,
};
use indexmap::IndexMap;
use itertools::Itertools;
use serde_json::json;
use std::collections::HashMap;
use std::iter::Enumerate;
use uom::si::f64::*;
/// Ordered registry of named state variables.
///
/// Each entry maps a feature name to its [`StateVariableConfig`]. The position of an entry in
/// the map is the index used to locate that feature's value inside a state vector
/// (`&[StateVariable]`), so insertion order is significant.
#[derive(Debug)]
pub struct StateModel(IndexMap<String, StateVariableConfig>);
/// Boxed iterator over the `(name, config)` pairs of a [`StateModel`], in index order.
type FeatureIterator<'a> = Box<dyn Iterator<Item = (&'a String, &'a StateVariableConfig)> + 'a>;
/// Boxed iterator over `(index, (name, config))` entries of a [`StateModel`], in index order.
type IndexedFeatureIterator<'a> =
Box<Enumerate<indexmap::map::Iter<'a, String, StateVariableConfig>>>;
impl StateModel {
    /// Creates a state model from `(name, config)` pairs.
    ///
    /// Repeated names are first collapsed by the module-level `deduplicate` helper. Each
    /// surviving name's position in the resulting map is its index in a state vector.
    pub fn new(features: Vec<(String, StateVariableConfig)>) -> StateModel {
        let deduped = deduplicate(&features);
        let map = IndexMap::from_iter(deduped);
        StateModel(map)
    }

    /// Creates a state model with no features.
    pub fn empty() -> StateModel {
        StateModel(IndexMap::new())
    }

    /// Returns a new model that extends this one with `output_features`, after checking that
    /// every entry of `input_features` is available.
    ///
    /// `self` is not modified. The returned model holds a copy of this model's features with
    /// the outputs inserted. Names that already exist keep their position; new names are
    /// appended. Inputs are checked after the outputs are added, so an input may be satisfied
    /// by an output registered in this same call.
    ///
    /// # Errors
    ///
    /// Returns [`StateModelError::BuildError`] when:
    /// - an output replaces an existing entry (from this model, or earlier in
    ///   `output_features`) with a configuration that compares unequal, or
    /// - an input feature's name is absent from the model once the outputs are added.
    pub fn register(
        &self,
        input_features: Vec<InputFeature>,
        output_features: Vec<(String, StateVariableConfig)>,
    ) -> Result<StateModel, StateModelError> {
        let mut map = self.0.clone();

        // Insert every output exactly once, recording any entry replaced by a *different*
        // configuration. Re-declaring an identical configuration is allowed.
        let mut overwrites = Vec::new();
        for (name, new) in output_features.iter() {
            match map.insert(name.clone(), new.clone()) {
                Some(old) if old != *new => overwrites.push((name, old, new)),
                _ => {}
            }
        }
        if !overwrites.is_empty() {
            let msg = overwrites
                .iter()
                .map(|(k, old, new)| format!("{k} old: {old} | new: {new}"))
                .join(", ");
            return Err(StateModelError::BuildError(format!(
                "new output features overwriting existing: {msg}"
            )));
        }

        let disconnected = input_features
            .into_iter()
            .filter_map(|feature| {
                let name = feature.name();
                if map.contains_key(&name) {
                    None
                } else {
                    Some(name)
                }
            })
            .collect_vec();
        if !disconnected.is_empty() {
            let msg = disconnected.iter().join(", ");
            return Err(StateModelError::BuildError(format!(
                "new input features required but no other model produces these features: {msg}"
            )));
        }
        Ok(Self(map))
    }

    /// Number of features (and therefore the expected state vector length).
    pub fn len(&self) -> usize {
        self.0.len()
    }

    /// `true` when the model has no features.
    pub fn is_empty(&self) -> bool {
        self.0.is_empty()
    }

    /// `true` when a feature named `k` is registered.
    pub fn contains_key(&self, k: &str) -> bool {
        self.0.contains_key(k)
    }

    /// Iterates feature names in index order.
    pub fn keys(&self) -> Box<dyn Iterator<Item = &String> + '_> {
        Box::new(self.0.keys())
    }

    /// Returns an owned copy of all `(name, config)` pairs in index order.
    pub fn to_vec(&self) -> Vec<(String, StateVariableConfig)> {
        self.0
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect_vec()
    }

    /// Iterates `(name, config)` pairs in index order.
    pub fn iter(&self) -> FeatureIterator<'_> {
        Box::new(self.0.iter())
    }

    /// Iterates `(index, (name, config))` entries in index order.
    pub fn indexed_iter(&self) -> IndexedFeatureIterator<'_> {
        Box::new(self.0.iter().enumerate())
    }

    /// Reports whether feature `name` is configured as an accumulator.
    ///
    /// NOTE: the method name is misspelled ("accumlator"). It is left as-is because renaming
    /// a public method would break callers.
    ///
    /// # Errors
    ///
    /// Returns [`StateModelError::UnknownStateVariableName`] if `name` is not registered.
    pub fn is_accumlator(&self, name: &str) -> Result<bool, StateModelError> {
        let feature = self.get_feature(name)?;
        Ok(feature.is_accumulator())
    }

    /// Builds the initial state vector, one value per feature in index order.
    ///
    /// When `prev_state` is given, accumulator features copy their value from the same index
    /// of `prev_state`. All other features, or every feature when `prev_state` is `None`, use
    /// their configured initial value.
    ///
    /// # Errors
    ///
    /// - [`StateModelError::RuntimeError`] if `prev_state` is shorter than an accumulator's
    ///   index.
    /// - Any error returned by a feature's `initial_value`.
    pub fn initial_state(
        &self,
        prev_state: Option<&[StateVariable]>,
    ) -> Result<Vec<StateVariable>, StateModelError> {
        let mut result: Vec<StateVariable> = Vec::with_capacity(self.0.len());
        for (idx, (name, feature)) in self.0.iter().enumerate() {
            let value = match prev_state {
                Some(prev) if feature.is_accumulator() => prev.get(idx).cloned().ok_or_else(|| {
                    // This branch is only reached for accumulators, so the message says so.
                    StateModelError::RuntimeError(format!("while initializing state variable '{name}' (an accumulator), did not find expected previous value at index {idx} in previous state"))
                }),
                _ => feature.initial_value(),
            }?;
            result.push(value);
        }
        Ok(result)
    }

    /// Reads feature `name` as a [`Length`], interpreting the raw value in
    /// `DistanceUnit::default()`.
    ///
    /// # Errors
    ///
    /// Fails if `name` is unknown or its index lies outside `state`.
    pub fn get_distance(
        &self,
        state: &[StateVariable],
        name: &str,
    ) -> Result<Length, StateModelError> {
        let value: &StateVariable = self.get_raw_state_variable(state, name)?;
        Ok(DistanceUnit::default().to_uom(value.0))
    }

    /// Reads feature `name` as a [`Time`], interpreting the raw value in
    /// `TimeUnit::default()`.
    ///
    /// # Errors
    ///
    /// Fails if `name` is unknown or its index lies outside `state`.
    pub fn get_time(&self, state: &[StateVariable], name: &str) -> Result<Time, StateModelError> {
        let value: &StateVariable = self.get_raw_state_variable(state, name)?;
        Ok(TimeUnit::default().to_uom(value.0))
    }

    /// Reads feature `name` as an [`Energy`], interpreting the raw value in
    /// `EnergyUnit::default()`.
    ///
    /// # Errors
    ///
    /// Fails if `name` is unknown or its index lies outside `state`.
    pub fn get_energy(
        &self,
        state: &[StateVariable],
        name: &str,
    ) -> Result<Energy, StateModelError> {
        let value: &StateVariable = self.get_raw_state_variable(state, name)?;
        Ok(EnergyUnit::default().to_uom(value.0))
    }

    /// Reads feature `name` as a [`Velocity`], interpreting the raw value in
    /// `SpeedUnit::default()`.
    ///
    /// # Errors
    ///
    /// Fails if `name` is unknown or its index lies outside `state`.
    pub fn get_speed(
        &self,
        state: &[StateVariable],
        name: &str,
    ) -> Result<Velocity, StateModelError> {
        let value: &StateVariable = self.get_raw_state_variable(state, name)?;
        Ok(SpeedUnit::default().to_uom(value.0))
    }

    /// Reads feature `name` as a [`Ratio`], interpreting the raw value in
    /// `RatioUnit::default()`.
    ///
    /// # Errors
    ///
    /// Fails if `name` is unknown or its index lies outside `state`.
    pub fn get_ratio(&self, state: &[StateVariable], name: &str) -> Result<Ratio, StateModelError> {
        let value: &StateVariable = self.get_raw_state_variable(state, name)?;
        Ok(RatioUnit::default().to_uom(value.0))
    }

    /// Reads feature `name` as a [`ThermodynamicTemperature`], interpreting the raw value in
    /// `TemperatureUnit::default()`.
    ///
    /// # Errors
    ///
    /// Fails if `name` is unknown or its index lies outside `state`.
    pub fn get_temperature(
        &self,
        state: &[StateVariable],
        name: &str,
    ) -> Result<ThermodynamicTemperature, StateModelError> {
        let value: &StateVariable = self.get_raw_state_variable(state, name)?;
        Ok(TemperatureUnit::default().to_uom(value.0))
    }

    /// Decodes custom feature `name` as an `f64` using its configured custom format.
    ///
    /// # Errors
    ///
    /// Fails on an unknown name or out-of-range index. Errors from
    /// `get_custom_feature_format` or from decoding are propagated.
    pub fn get_custom_f64(
        &self,
        state: &[StateVariable],
        name: &str,
    ) -> Result<f64, StateModelError> {
        let (value, format) = self.get_custom_state_variable(state, name)?;
        let result = format.decode_f64(value)?;
        Ok(result)
    }

    /// Decodes custom feature `name` as an `i64` using its configured custom format.
    ///
    /// # Errors
    ///
    /// Fails on an unknown name or out-of-range index. Errors from
    /// `get_custom_feature_format` or from decoding are propagated.
    pub fn get_custom_i64(
        &self,
        state: &[StateVariable],
        name: &str,
    ) -> Result<i64, StateModelError> {
        let (value, format) = self.get_custom_state_variable(state, name)?;
        let result = format.decode_i64(value)?;
        Ok(result)
    }

    /// Decodes custom feature `name` as a `u64` using its configured custom format.
    ///
    /// # Errors
    ///
    /// Fails on an unknown name or out-of-range index. Errors from
    /// `get_custom_feature_format` or from decoding are propagated.
    pub fn get_custom_u64(
        &self,
        state: &[StateVariable],
        name: &str,
    ) -> Result<u64, StateModelError> {
        let (value, format) = self.get_custom_state_variable(state, name)?;
        let result = format.decode_u64(value)?;
        Ok(result)
    }

    /// Decodes custom feature `name` as a `bool` using its configured custom format.
    ///
    /// # Errors
    ///
    /// Fails on an unknown name or out-of-range index. Errors from
    /// `get_custom_feature_format` or from decoding are propagated.
    pub fn get_custom_bool(
        &self,
        state: &[StateVariable],
        name: &str,
    ) -> Result<bool, StateModelError> {
        let (value, format) = self.get_custom_state_variable(state, name)?;
        let result = format.decode_bool(value)?;
        Ok(result)
    }

    /// Looks up the raw value of `name` together with that feature's custom format.
    fn get_custom_state_variable<'a>(
        &self,
        state: &'a [StateVariable],
        name: &str,
    ) -> Result<(&'a StateVariable, &CustomVariableConfig), StateModelError> {
        let value = self.get_raw_state_variable(state, name)?;
        let feature = self.get_feature(name)?;
        let format = feature.get_custom_feature_format()?;
        Ok((value, format))
    }

    /// Returns `next[name] - prev[name]` (raw values) converted into `T`.
    ///
    /// # Errors
    ///
    /// Fails if `name` is unknown or its index lies outside either state vector.
    pub fn get_delta<T: From<StateVariable>>(
        &self,
        prev: &[StateVariable],
        next: &[StateVariable],
        name: &str,
    ) -> Result<T, StateModelError> {
        let prev_val = self.get_raw_state_variable(prev, name)?;
        let next_val = self.get_raw_state_variable(next, name)?;
        let delta = *next_val - *prev_val;
        Ok(delta.into())
    }

    /// Adds `distance` to the current value of feature `name`.
    pub fn add_distance(
        &self,
        state: &mut [StateVariable],
        name: &str,
        distance: &Length,
    ) -> Result<(), StateModelError> {
        let prev_distance: Length = self.get_distance(state, name)?;
        let next_distance = prev_distance + *distance;
        self.set_distance(state, name, &next_distance)
    }

    /// Adds `time` to the current value of feature `name`.
    pub fn add_time(
        &self,
        state: &mut [StateVariable],
        name: &str,
        time: &Time,
    ) -> Result<(), StateModelError> {
        let prev_time = self.get_time(state, name)?;
        let next_time = prev_time + *time;
        self.set_time(state, name, &next_time)
    }

    /// Adds `energy` to the current value of feature `name`.
    pub fn add_energy(
        &self,
        state: &mut [StateVariable],
        name: &str,
        energy: &Energy,
    ) -> Result<(), StateModelError> {
        let prev_energy = self.get_energy(state, name)?;
        let next_energy = prev_energy + *energy;
        self.set_energy(state, name, &next_energy)
    }

    /// Adds `speed` to the current value of feature `name`.
    pub fn add_speed(
        &self,
        state: &mut [StateVariable],
        name: &str,
        speed: &Velocity,
    ) -> Result<(), StateModelError> {
        let prev_speed = self.get_speed(state, name)?;
        let next_speed = prev_speed + *speed;
        self.set_speed(state, name, &next_speed)
    }

    /// Adds the ratio `grade` to the current ratio value of feature `name`.
    pub fn add_grade(
        &self,
        state: &mut [StateVariable],
        name: &str,
        grade: &Ratio,
    ) -> Result<(), StateModelError> {
        let prev_grade = self.get_ratio(state, name)?;
        let next_grade = prev_grade + *grade;
        self.set_ratio(state, name, &next_grade)
    }

    /// Stores `distance` in feature `name`, encoded in `DistanceUnit::default()`.
    pub fn set_distance(
        &self,
        state: &mut [StateVariable],
        name: &str,
        distance: &Length,
    ) -> Result<(), StateModelError> {
        let value = StateVariable(DistanceUnit::default().from_uom(*distance));
        self.update_state(state, name, &value, UpdateOperation::Replace)
    }

    /// Stores `time` in feature `name`, encoded in `TimeUnit::default()`.
    pub fn set_time(
        &self,
        state: &mut [StateVariable],
        name: &str,
        time: &Time,
    ) -> Result<(), StateModelError> {
        let value = StateVariable(TimeUnit::default().from_uom(*time));
        self.update_state(state, name, &value, UpdateOperation::Replace)
    }

    /// Stores `energy` in feature `name`, encoded in `EnergyUnit::default()`.
    pub fn set_energy(
        &self,
        state: &mut [StateVariable],
        name: &str,
        energy: &Energy,
    ) -> Result<(), StateModelError> {
        let value = StateVariable(EnergyUnit::default().from_uom(*energy));
        self.update_state(state, name, &value, UpdateOperation::Replace)
    }

    /// Stores the ratio `grade` in feature `name`, encoded in `RatioUnit::default()`.
    pub fn set_ratio(
        &self,
        state: &mut [StateVariable],
        name: &str,
        grade: &Ratio,
    ) -> Result<(), StateModelError> {
        let value = StateVariable(RatioUnit::default().from_uom(*grade));
        self.update_state(state, name, &value, UpdateOperation::Replace)
    }

    /// Stores `speed` in feature `name`, encoded in `SpeedUnit::default()`.
    pub fn set_speed(
        &self,
        state: &mut [StateVariable],
        name: &str,
        speed: &Velocity,
    ) -> Result<(), StateModelError> {
        let value = StateVariable(SpeedUnit::default().from_uom(*speed));
        self.update_state(state, name, &value, UpdateOperation::Replace)
    }

    /// Stores `temperature` in feature `name`, encoded in `TemperatureUnit::default()`.
    pub fn set_temperature(
        &self,
        state: &mut [StateVariable],
        name: &str,
        temperature: &ThermodynamicTemperature,
    ) -> Result<(), StateModelError> {
        let value = StateVariable(TemperatureUnit::default().from_uom(*temperature));
        self.update_state(state, name, &value, UpdateOperation::Replace)
    }

    /// Encodes `value` with the custom format of feature `name` and stores it.
    pub fn set_custom_f64(
        &self,
        state: &mut [StateVariable],
        name: &str,
        value: &f64,
    ) -> Result<(), StateModelError> {
        let feature = self.get_feature(name)?;
        let format = feature.get_custom_feature_format()?;
        let encoded_value = format.encode_f64(value)?;
        self.update_state(state, name, &encoded_value, UpdateOperation::Replace)
    }

    /// Encodes `value` with the custom format of feature `name` and stores it.
    pub fn set_custom_i64(
        &self,
        state: &mut [StateVariable],
        name: &str,
        value: &i64,
    ) -> Result<(), StateModelError> {
        let feature = self.get_feature(name)?;
        let format = feature.get_custom_feature_format()?;
        let encoded_value = format.encode_i64(value)?;
        self.update_state(state, name, &encoded_value, UpdateOperation::Replace)
    }

    /// Encodes `value` with the custom format of feature `name` and stores it.
    pub fn set_custom_u64(
        &self,
        state: &mut [StateVariable],
        name: &str,
        value: &u64,
    ) -> Result<(), StateModelError> {
        let feature = self.get_feature(name)?;
        let format = feature.get_custom_feature_format()?;
        let encoded_value = format.encode_u64(value)?;
        self.update_state(state, name, &encoded_value, UpdateOperation::Replace)
    }

    /// Encodes `value` with the custom format of feature `name` and stores it.
    pub fn set_custom_bool(
        &self,
        state: &mut [StateVariable],
        name: &str,
        value: &bool,
    ) -> Result<(), StateModelError> {
        let feature = self.get_feature(name)?;
        let format = feature.get_custom_feature_format()?;
        let encoded_value = format.encode_bool(value)?;
        self.update_state(state, name, &encoded_value, UpdateOperation::Replace)
    }

    /// Serializes `state` as a JSON object keyed by feature name.
    ///
    /// Features and values are paired positionally. If `state` and the model differ in
    /// length, the extra entries are ignored (the pairing uses `zip`). When
    /// `accumulators_only` is set, only accumulator features are included.
    pub fn serialize_state(
        &self,
        state: &[StateVariable],
        accumulators_only: bool,
    ) -> Result<serde_json::Value, StateModelError> {
        let output = self
            .iter()
            .zip(state.iter())
            .filter(|((_, feature), _)| !accumulators_only || feature.is_accumulator())
            .map(|((name, feature), state_var)| {
                let serialized = feature.serialize_variable(state_var)?;
                Ok((name, serialized))
            })
            .collect::<Result<HashMap<_, _>, StateModelError>>()?;
        Ok(json![output])
    }

    /// Serializes the model's configuration as a JSON object keyed by feature name.
    ///
    /// When a feature's configuration serializes to a JSON object, `index` and `name` fields
    /// are added to it.
    pub fn serialize_state_model(&self) -> serde_json::Value {
        let mut out = serde_json::Map::new();
        for (i, (name, feature)) in self.indexed_iter() {
            let mut f_json = json![feature];
            if let Some(map) = f_json.as_object_mut() {
                map.insert(String::from("index"), json![i]);
                map.insert(String::from("name"), json![name]);
            }
            out.insert(name.clone(), f_json);
        }
        json![out]
    }

    /// Comma-separated list of feature names in index order (used in error messages).
    pub fn get_names(&self) -> String {
        self.0.keys().join(",")
    }

    /// Looks up the configuration of `feature_name`.
    fn get_feature(&self, feature_name: &str) -> Result<&StateVariableConfig, StateModelError> {
        self.0.get(feature_name).ok_or_else(|| {
            StateModelError::UnknownStateVariableName(feature_name.to_string(), self.get_names())
        })
    }

    /// Returns a reference to the raw value of feature `name` in `state`.
    ///
    /// # Errors
    ///
    /// - [`StateModelError::UnknownStateVariableName`] if `name` is not registered.
    /// - [`StateModelError::RuntimeError`] if its index is outside `state`.
    pub fn get_raw_state_variable<'a>(
        &self,
        state: &'a [StateVariable],
        name: &str,
    ) -> Result<&'a StateVariable, StateModelError> {
        let idx = self.0.get_index_of(name).ok_or_else(|| {
            StateModelError::UnknownStateVariableName(name.to_string(), self.get_names())
        })?;
        let value = state.get(idx).ok_or_else(|| {
            StateModelError::RuntimeError(format!(
                "state index {} for {} is out of range for state vector with {} entries",
                idx,
                name,
                state.len()
            ))
        })?;
        Ok(value)
    }

    /// Combines the current value of `name` with `value` using `op`, then writes the result
    /// back into `state`.
    ///
    /// # Errors
    ///
    /// - [`StateModelError::UnknownStateVariableName`] if `name` is not registered.
    /// - [`StateModelError::InvalidStateVariableIndex`] if its index is outside `state`.
    fn update_state(
        &self,
        state: &mut [StateVariable],
        name: &str,
        value: &StateVariable,
        op: UpdateOperation,
    ) -> Result<(), StateModelError> {
        let index = self.0.get_index_of(name).ok_or_else(|| {
            StateModelError::UnknownStateVariableName(name.to_string(), self.get_names())
        })?;
        // Read the length before borrowing `state` mutably. The error is built lazily so the
        // success path (run on every state write) does not allocate.
        let len = state.len();
        let slot = state.get_mut(index).ok_or_else(|| {
            StateModelError::InvalidStateVariableIndex(name.to_string(), index, len)
        })?;
        *slot = op.perform_operation(&*slot, value);
        Ok(())
    }
}
fn deduplicate(features: &[(String, StateVariableConfig)]) -> HashMap<String, StateVariableConfig> {
let mut map: HashMap<String, StateVariableConfig> = HashMap::new();
for (name, next) in features.iter() {
map.entry(name.clone())
.and_modify(|prev| {
match (prev.get_unit_name(), next.get_unit_name()) {
(Some(prev_unit), Some(next_unit)) => {
log::warn!("{}", duplicate_message(name, &prev_unit, &next_unit));
}
(None, Some(_)) => {
*prev = next.clone();
}
_ => { }
}
})
.or_insert(next.clone());
}
map
}
/// Builds the warning text for a feature whose output unit is declared twice.
///
/// * `name` - the duplicated feature name
/// * `prev` - unit name from the earlier declaration, reported as the one kept
/// * `next` - unit name from the later declaration
///
/// The three sentences are now separated by proper terminators. Previously they were joined
/// with bare spaces ("...feature 'x' Previous uses ...").
fn duplicate_message(name: &str, prev: &str, next: &str) -> String {
    format!(
        "Two traversal models are producing output feature '{name}'. \
         Previous uses '{prev}', next uses '{next}'. Keeping '{prev}'. \
         To avoid non-deterministic behavior, only set output_unit in one traversal model."
    )
}
impl<'a> TryFrom<&'a serde_json::Value> for StateModel {
    type Error = StateModelError;

    /// Parses a state model from a JSON object. Keys are feature names; values must
    /// deserialize as [`StateVariableConfig`].
    ///
    /// The parsed pairs are passed to [`StateModel::new`] in the iteration order of the JSON
    /// object.
    ///
    /// # Errors
    ///
    /// Returns [`StateModelError::BuildError`] in either of these cases:
    /// - `json` is not an object. The message now includes the offending value instead of a
    ///   stray literal `{}`.
    /// - Any value fails to deserialize.
    fn try_from(json: &'a serde_json::Value) -> Result<StateModel, StateModelError> {
        let object = json.as_object().ok_or_else(|| {
            StateModelError::BuildError(format!(
                "expected state model configuration to be a JSON object, found: {json}"
            ))
        })?;
        let features = object
            .iter()
            .map(|(feature_name, feature_json)| {
                let feature = serde_json::from_value::<StateVariableConfig>(feature_json.clone())
                    .map_err(|e| {
                        StateModelError::BuildError(format!(
                            "unable to parse state feature row with name '{feature_name}' contents '{feature_json}' due to: {e}"
                        ))
                    })?;
                Ok((feature_name.clone(), feature))
            })
            .collect::<Result<Vec<_>, StateModelError>>()?;
        Ok(StateModel::new(features))
    }
}