use super::{
custom_feature_format::CustomFeatureFormat, output_feature::OutputFeature,
state_model_error::StateModelError, update_operation::UpdateOperation,
};
use super::{InputFeature, StateVariable};
use crate::model::unit::{Convert, Grade, GradeUnit, Speed, SpeedUnit};
use crate::util::compact_ordered_hash_map::CompactOrderedHashMap;
use crate::{
model::unit::{Distance, DistanceUnit, Energy, EnergyUnit, Time, TimeUnit},
util::compact_ordered_hash_map::IndexedEntry,
};
use itertools::Itertools;
use serde_json::json;
use std::borrow::Cow;
use std::collections::HashMap;
use std::iter::Enumerate;
/// Named collection of [`OutputFeature`]s. Each feature's position in the
/// underlying map is the index of its value within a state vector
/// (`&[StateVariable]`).
pub struct StateModel(CompactOrderedHashMap<String, OutputFeature>);

/// Boxed iterator over the `(name, feature)` pairs of a [`StateModel`].
type FeatureIterator<'a> = Box<dyn Iterator<Item = (&'a String, &'a OutputFeature)> + 'a>;

/// A [`FeatureIterator`] that also yields each feature's position.
type IndexedFeatureIterator<'a> = Enumerate<FeatureIterator<'a>>;
impl StateModel {
pub fn new(features: Vec<(String, OutputFeature)>) -> StateModel {
let map = CompactOrderedHashMap::new(features);
StateModel(map)
}
pub fn empty() -> StateModel {
StateModel(CompactOrderedHashMap::empty())
}
pub fn register(
&self,
input_features: Vec<(String, InputFeature)>,
output_features: Vec<(String, OutputFeature)>,
) -> Result<StateModel, StateModelError> {
let mut map = self
.0
.iter()
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<CompactOrderedHashMap<_, _>>();
let overwrites = output_features
.iter()
.flat_map(|(name, new)| match map.insert(name.clone(), new.clone()) {
Some(old) if old != *new => Some((name.clone(), old, new.clone())),
_ => None,
})
.collect_vec();
if !overwrites.is_empty() {
let msg = overwrites
.iter()
.map(|(k, old, new)| format!("{} old: {} | new: {}", k, old, new))
.join(", ");
return Err(StateModelError::BuildError(format!(
"new output features overwriting existing: {}",
msg
)));
}
let disconnected = input_features
.into_iter()
.flat_map(|(name, feature)| match map.get(&name) {
Some(_) => None,
None => Some((name, feature)),
})
.collect_vec();
if !disconnected.is_empty() {
let msg = disconnected
.iter()
.map(|(k, v)| format!("({}: {})", k, v))
.join(", ");
return Err(StateModelError::BuildError(format!(
"new input features required but no other model produces these features: {}",
msg
)));
}
for (name, feature) in output_features.iter() {
map.insert(name.to_string(), feature.clone());
}
Ok(Self(map))
}
pub fn len(&self) -> usize {
self.0.len()
}
pub fn is_empty(&self) -> bool {
self.0.is_empty()
}
pub fn contains_key(&self, k: &String) -> bool {
self.0.contains_key(k)
}
pub fn keys<'a>(&'a self) -> Box<dyn Iterator<Item = &'a String> + 'a> {
self.0.keys()
}
pub fn to_vec(&self) -> Vec<(String, IndexedEntry<OutputFeature>)> {
self.0.to_vec()
}
pub fn iter(&self) -> FeatureIterator {
self.0.iter()
}
pub fn indexed_iter(&self) -> IndexedFeatureIterator {
self.0.indexed_iter()
}
pub fn is_accumlator(&self, name: &str) -> Result<bool, StateModelError> {
let feature = self.get_feature(name)?;
Ok(feature.is_accumlator())
}
pub fn initial_state(&self) -> Result<Vec<StateVariable>, StateModelError> {
self.0
.iter()
.map(|(_, feature)| {
let initial = feature.get_initial()?;
Ok(initial)
})
.collect::<Result<Vec<_>, _>>()
}
pub fn get_distance<'a>(
&'a self,
state: &[StateVariable],
name: &str,
unit: Option<&'a DistanceUnit>,
) -> Result<(Distance, &'a DistanceUnit), StateModelError> {
let value: Distance = self.get_state_variable(state, name)?.into();
let feature = self.get_feature(name)?;
let from_unit = feature.get_distance_unit()?;
match unit {
Some(to_unit) => {
let mut v_cow = Cow::Owned(value);
from_unit.convert(&mut v_cow, to_unit)?;
Ok((v_cow.into_owned(), to_unit))
}
None => Ok((value, from_unit)),
}
}
pub fn get_time<'a>(
&'a self,
state: &[StateVariable],
name: &str,
unit: Option<&'a TimeUnit>,
) -> Result<(Time, &'a TimeUnit), StateModelError> {
let value: Time = self.get_state_variable(state, name)?.into();
let feature = self.get_feature(name)?;
let from_unit = feature.get_time_unit()?;
match unit {
Some(to_unit) => {
let mut v_cow = Cow::Owned(value);
from_unit.convert(&mut v_cow, to_unit)?;
Ok((v_cow.into_owned(), to_unit))
}
None => Ok((value, from_unit)),
}
}
pub fn get_energy<'a>(
&'a self,
state: &[StateVariable],
name: &str,
unit: Option<&'a EnergyUnit>,
) -> Result<(Energy, &'a EnergyUnit), StateModelError> {
let value: Energy = self.get_state_variable(state, name)?.into();
let feature = self.get_feature(name)?;
let from_unit = feature.get_energy_unit()?;
match unit {
Some(to_unit) => {
let mut v_cow = Cow::Owned(value);
from_unit.convert(&mut v_cow, to_unit)?;
Ok((v_cow.into_owned(), to_unit))
}
None => Ok((value, from_unit)),
}
}
pub fn get_speed<'a>(
&'a self,
state: &[StateVariable],
name: &str,
unit: Option<&'a SpeedUnit>,
) -> Result<(Speed, &'a SpeedUnit), StateModelError> {
let value: Speed = self.get_state_variable(state, name)?.into();
let feature = self.get_feature(name)?;
let from_unit = feature.get_speed_unit()?;
match unit {
Some(to_unit) => {
let mut v_cow = Cow::Owned(value);
from_unit.convert(&mut v_cow, to_unit)?;
Ok((v_cow.into_owned(), to_unit))
}
None => Ok((value, from_unit)),
}
}
pub fn get_grade<'a>(
&'a self,
state: &[StateVariable],
name: &str,
unit: Option<&'a GradeUnit>,
) -> Result<(Grade, &'a GradeUnit), StateModelError> {
let value: Grade = self.get_state_variable(state, name)?.into();
let feature = self.get_feature(name)?;
let from_unit = feature.get_grade_unit()?;
match unit {
Some(to_unit) => {
let mut v_cow = Cow::Owned(value);
from_unit.convert(&mut v_cow, to_unit)?;
Ok((v_cow.into_owned(), to_unit))
}
None => Ok((value, from_unit)),
}
}
pub fn get_custom_f64(
&self,
state: &[StateVariable],
name: &str,
) -> Result<f64, StateModelError> {
let (value, format) = self.get_custom_state_variable(state, name)?;
let result = format.decode_f64(value)?;
Ok(result)
}
pub fn get_custom_i64(
&self,
state: &[StateVariable],
name: &str,
) -> Result<i64, StateModelError> {
let (value, format) = self.get_custom_state_variable(state, name)?;
let result = format.decode_i64(value)?;
Ok(result)
}
pub fn get_custom_u64(
&self,
state: &[StateVariable],
name: &str,
) -> Result<u64, StateModelError> {
let (value, format) = self.get_custom_state_variable(state, name)?;
let result = format.decode_u64(value)?;
Ok(result)
}
pub fn get_custom_bool(
&self,
state: &[StateVariable],
name: &str,
) -> Result<bool, StateModelError> {
let (value, format) = self.get_custom_state_variable(state, name)?;
let result = format.decode_bool(value)?;
Ok(result)
}
fn get_custom_state_variable<'a>(
&self,
state: &'a [StateVariable],
name: &str,
) -> Result<(&'a StateVariable, &CustomFeatureFormat), StateModelError> {
let value = self.get_state_variable(state, name)?;
let feature = self.get_feature(name)?;
let format = feature.get_custom_feature_format()?;
Ok((value, format))
}
pub fn get_delta<T: From<StateVariable>>(
&self,
prev: &[StateVariable],
next: &[StateVariable],
name: &str,
) -> Result<T, StateModelError> {
let prev_val = self.get_state_variable(prev, name)?;
let next_val = self.get_state_variable(next, name)?;
let delta = *next_val - *prev_val;
Ok(delta.into())
}
pub fn add_distance(
&self,
state: &mut [StateVariable],
name: &str,
distance: &Distance,
from_unit: &DistanceUnit,
) -> Result<(), StateModelError> {
let (prev_distance, _) = self.get_distance(state, name, Some(from_unit))?;
let next_distance = prev_distance + *distance;
self.set_distance(state, name, &next_distance, from_unit)
}
pub fn add_time(
&self,
state: &mut [StateVariable],
name: &str,
time: &Time,
from_unit: &TimeUnit,
) -> Result<(), StateModelError> {
let (prev_time, _) = self.get_time(state, name, Some(from_unit))?;
let next_time = prev_time + *time;
self.set_time(state, name, &next_time, from_unit)
}
pub fn add_energy(
&self,
state: &mut [StateVariable],
name: &str,
energy: &Energy,
from_unit: &EnergyUnit,
) -> Result<(), StateModelError> {
let (prev_energy, _) = self.get_energy(state, name, Some(from_unit))?;
let next_energy = prev_energy + *energy;
self.set_energy(state, name, &next_energy, from_unit)
}
pub fn add_speed(
&self,
state: &mut [StateVariable],
name: &str,
speed: &Speed,
from_unit: &SpeedUnit,
) -> Result<(), StateModelError> {
let (prev_speed, _) = self.get_speed(state, name, Some(from_unit))?;
let next_speed = prev_speed + *speed;
self.set_speed(state, name, &next_speed, from_unit)
}
pub fn add_grade(
&self,
state: &mut [StateVariable],
name: &str,
grade: &Grade,
from_unit: &GradeUnit,
) -> Result<(), StateModelError> {
let (prev_grade, _) = self.get_grade(state, name, Some(from_unit))?;
let next_grade = prev_grade + *grade;
self.set_grade(state, name, &next_grade, from_unit)
}
pub fn set_distance(
&self,
state: &mut [StateVariable],
name: &str,
distance: &Distance,
from_unit: &DistanceUnit,
) -> Result<(), StateModelError> {
let mut dist_cow = Cow::Borrowed(distance);
let to_unit = self.get_feature(name)?.get_distance_unit()?;
from_unit.convert(&mut dist_cow, to_unit)?;
self.update_state(
state,
name,
&dist_cow.into_owned().into(),
UpdateOperation::Replace,
)
}
pub fn set_time(
&self,
state: &mut [StateVariable],
name: &str,
time: &Time,
from_unit: &TimeUnit,
) -> Result<(), StateModelError> {
let mut time_mut = Cow::Borrowed(time);
let to_unit = self.get_feature(name)?.get_time_unit()?;
from_unit.convert(&mut time_mut, to_unit)?;
self.update_state(
state,
name,
&time_mut.into_owned().into(),
UpdateOperation::Replace,
)
}
pub fn set_energy(
&self,
state: &mut [StateVariable],
name: &str,
energy: &Energy,
from_unit: &EnergyUnit,
) -> Result<(), StateModelError> {
let mut energy_mut = Cow::Borrowed(energy);
let to_unit = self.get_feature(name)?.get_energy_unit()?;
from_unit.convert(&mut energy_mut, to_unit)?;
self.update_state(
state,
name,
&energy_mut.into_owned().into(),
UpdateOperation::Replace,
)
}
pub fn set_grade(
&self,
state: &mut [StateVariable],
name: &str,
grade: &Grade,
from_unit: &GradeUnit,
) -> Result<(), StateModelError> {
let mut grade_mut = Cow::Borrowed(grade);
let to_unit = self.get_feature(name)?.get_grade_unit()?;
from_unit.convert(&mut grade_mut, to_unit)?;
self.update_state(
state,
name,
&grade_mut.into_owned().into(),
UpdateOperation::Replace,
)
}
pub fn set_speed(
&self,
state: &mut [StateVariable],
name: &str,
speed: &Speed,
from_unit: &SpeedUnit,
) -> Result<(), StateModelError> {
let mut speed_mut = Cow::Borrowed(speed);
let to_unit = self.get_feature(name)?.get_speed_unit()?;
from_unit.convert(&mut speed_mut, to_unit)?;
self.update_state(
state,
name,
&speed_mut.into_owned().into(),
UpdateOperation::Replace,
)
}
pub fn set_custom_f64(
&self,
state: &mut [StateVariable],
name: &str,
value: &f64,
) -> Result<(), StateModelError> {
let feature = self.get_feature(name)?;
let format = feature.get_custom_feature_format()?;
let encoded_value = format.encode_f64(value)?;
self.update_state(state, name, &encoded_value, UpdateOperation::Replace)
}
pub fn set_custom_i64(
&self,
state: &mut [StateVariable],
name: &str,
value: &i64,
) -> Result<(), StateModelError> {
let feature = self.get_feature(name)?;
let format = feature.get_custom_feature_format()?;
let encoded_value = format.encode_i64(value)?;
self.update_state(state, name, &encoded_value, UpdateOperation::Replace)
}
pub fn set_custom_u64(
&self,
state: &mut [StateVariable],
name: &str,
value: &u64,
) -> Result<(), StateModelError> {
let feature = self.get_feature(name)?;
let format = feature.get_custom_feature_format()?;
let encoded_value = format.encode_u64(value)?;
self.update_state(state, name, &encoded_value, UpdateOperation::Replace)
}
pub fn set_custom_bool(
&self,
state: &mut [StateVariable],
name: &str,
value: &bool,
) -> Result<(), StateModelError> {
let feature = self.get_feature(name)?;
let format = feature.get_custom_feature_format()?;
let encoded_value = format.encode_bool(value)?;
self.update_state(state, name, &encoded_value, UpdateOperation::Replace)
}
pub fn serialize_state(&self, state: &[StateVariable]) -> serde_json::Value {
let output = self
.iter()
.zip(state.iter())
.filter_map(
|((name, feature), state_var)| match feature.is_accumlator() {
false => None,
true => Some((name, state_var)),
},
)
.collect::<HashMap<_, _>>();
json![output]
}
pub fn serialize_state_model(&self) -> serde_json::Value {
let mut out = serde_json::Map::new();
for (i, (name, feature)) in self.indexed_iter() {
let mut f_json = json![feature];
if let Some(map) = f_json.as_object_mut() {
map.insert(String::from("index"), json![i]);
map.insert(String::from("name"), json![name]);
}
out.insert(name.clone(), f_json);
}
json![out]
}
pub fn get_names(&self) -> String {
self.0.iter().map(|(k, _)| k.clone()).join(",")
}
fn get_feature(&self, feature_name: &str) -> Result<&OutputFeature, StateModelError> {
self.0.get(feature_name).ok_or_else(|| {
StateModelError::UnknownStateVariableName(feature_name.to_string(), self.get_names())
})
}
fn get_state_variable<'a>(
&self,
state: &'a [StateVariable],
name: &str,
) -> Result<&'a StateVariable, StateModelError> {
let idx = self.0.get_index(name).ok_or_else(|| {
StateModelError::UnknownStateVariableName(name.to_string(), self.get_names())
})?;
let value = state.get(idx).ok_or_else(|| {
StateModelError::RuntimeError(format!(
"state index {} for {} is out of range for state vector with {} entries",
idx,
name,
state.len()
))
})?;
Ok(value)
}
fn update_state(
&self,
state: &mut [StateVariable],
name: &str,
value: &StateVariable,
op: UpdateOperation,
) -> Result<(), StateModelError> {
let index = self.0.get_index(name).ok_or_else(|| {
StateModelError::UnknownStateVariableName(name.to_string(), self.get_names())
})?;
let prev = state
.get(index)
.ok_or(StateModelError::InvalidStateVariableIndex(
name.to_string(),
index,
state.len(),
))?;
let updated = op.perform_operation(prev, value);
state[index] = updated;
Ok(())
}
}
impl<'a> TryFrom<&'a serde_json::Value> for StateModel {
    type Error = StateModelError;

    /// Builds a [`StateModel`] from a JSON object whose keys are feature names
    /// and whose values deserialize into [`OutputFeature`]s.
    ///
    /// # Errors
    ///
    /// Returns [`StateModelError::BuildError`] if `json` is not an object or
    /// if any entry fails to deserialize into an [`OutputFeature`].
    fn try_from(json: &'a serde_json::Value) -> Result<StateModel, StateModelError> {
        let object = json.as_object().ok_or_else(|| {
            StateModelError::BuildError(String::from(
                "expected state model configuration to be a JSON object {}",
            ))
        })?;
        let features = object
            .iter()
            .map(|(feature_name, feature_json)| {
                // `from_value` takes ownership, so this clone is required; the
                // error message below only needs to borrow.
                let feature = serde_json::from_value::<OutputFeature>(feature_json.clone())
                    .map_err(|e| {
                        StateModelError::BuildError(format!(
                            "unable to parse state feature row with name '{}' contents '{}' due to: {}",
                            feature_name, feature_json, e
                        ))
                    })?;
                Ok((feature_name.clone(), feature))
            })
            .collect::<Result<Vec<_>, StateModelError>>()?;
        Ok(StateModel::new(features))
    }
}