use crate::physics::data::PhysicalData;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use serde_json::{Map, Value};
use std::collections::HashMap;
/// Identifies one physical field stored in a `PhysicalState`
/// (it is the key type of `PhysicalState::quantities`).
///
/// The four unit variants are built-in quantities; `Custom` names any other
/// quantity by an arbitrary string.
///
/// Serialized as a bare string equal to its `Display` output. When
/// deserializing, the four built-in names map back to their variants and any
/// other string becomes `Custom`.
/// NOTE(review): because of this, a `Custom` whose name equals a built-in name
/// (e.g. `Custom("Temperature")`) comes back as the built-in variant, so it
/// does not round-trip.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PhysicalQuantity {
    Concentration,
    Temperature,
    Pressure,
    Velocity,
    /// Any other quantity, identified by its name (e.g. `"Viscosity"`).
    Custom(String),
}
impl std::fmt::Display for PhysicalQuantity {
    /// Writes the quantity's name: the variant name for built-in quantities,
    /// the stored string verbatim for `Custom`.
    ///
    /// Formatter width/fill/alignment flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label: &str = match self {
            PhysicalQuantity::Concentration => "Concentration",
            PhysicalQuantity::Temperature => "Temperature",
            PhysicalQuantity::Pressure => "Pressure",
            PhysicalQuantity::Velocity => "Velocity",
            PhysicalQuantity::Custom(name) => name.as_str(),
        };
        f.write_str(label)
    }
}
impl Serialize for PhysicalQuantity {
    /// Serializes the quantity as a plain string equal to its `Display` output,
    /// which also lets it serve as a string key in maps (e.g. JSON objects).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let name = self.to_string();
        serializer.serialize_str(name.as_str())
    }
}
impl<'de> Deserialize<'de> for PhysicalQuantity {
    /// Parses a string into a `PhysicalQuantity`.
    ///
    /// The four built-in names map to their variants; any other string is
    /// accepted and wrapped in `PhysicalQuantity::Custom`. Non-string input is
    /// rejected by the deserializer with the `expecting` message below.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Visitor that only understands strings; owned and borrowed strings
        // reach `visit_str` through serde's default forwarding.
        struct QuantityNameVisitor;

        impl<'de> serde::de::Visitor<'de> for QuantityNameVisitor {
            type Value = PhysicalQuantity;

            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a string representing a PhysicalQuantity")
            }

            fn visit_str<E: serde::de::Error>(self, value: &str) -> Result<Self::Value, E> {
                Ok(match value {
                    "Concentration" => PhysicalQuantity::Concentration,
                    "Temperature" => PhysicalQuantity::Temperature,
                    "Pressure" => PhysicalQuantity::Pressure,
                    "Velocity" => PhysicalQuantity::Velocity,
                    // Unknown names are never an error: they become custom quantities.
                    custom_name => PhysicalQuantity::Custom(custom_name.to_owned()),
                })
            }
        }

        deserializer.deserialize_str(QuantityNameVisitor)
    }
}
/// A set of physical fields keyed by `PhysicalQuantity`, plus named scalar
/// metadata.
///
/// Both maps are (de)serialized with serde. Quantity keys serialize as plain
/// strings, so the state can be written as a JSON object.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PhysicalState {
    /// Field data, one entry per quantity present in this state.
    pub quantities: HashMap<PhysicalQuantity, PhysicalData>,
    /// Auxiliary named scalars (the tests use keys such as "total_mass" and
    /// "flow_rate"). Private; access goes through `get_metadata` / `set_metadata`.
    metadata: HashMap<String, f64>,
}
impl PhysicalState {
    /// Builds a state holding exactly one quantity and no metadata.
    pub fn new(quantity: PhysicalQuantity, value: PhysicalData) -> Self {
        let mut state = Self::empty();
        state.set(quantity, value);
        state
    }

    /// Builds a state with no quantities and no metadata.
    pub fn empty() -> Self {
        Self {
            quantities: HashMap::default(),
            metadata: HashMap::default(),
        }
    }

    /// Borrows the data stored for `quantity`, if present.
    pub fn get(&self, quantity: PhysicalQuantity) -> Option<&PhysicalData> {
        let key = &quantity;
        self.quantities.get(key)
    }

    /// Mutably borrows the data stored for `quantity`, if present.
    pub fn get_mut(&mut self, quantity: PhysicalQuantity) -> Option<&mut PhysicalData> {
        let key = &quantity;
        self.quantities.get_mut(key)
    }

    /// Stores `value` under `quantity`, replacing any previous value.
    pub fn set(&mut self, quantity: PhysicalQuantity, value: PhysicalData) {
        self.quantities.insert(quantity, value);
    }

    /// Takes `quantity` out of the state, returning its data if it was present.
    pub fn remove(&mut self, quantity: PhysicalQuantity) -> Option<PhysicalData> {
        let key = &quantity;
        self.quantities.remove(key)
    }

    /// Lists the stored quantities in the map's (unspecified) iteration order.
    pub fn available_quantities(&self) -> Vec<PhysicalQuantity> {
        let mut present = Vec::with_capacity(self.quantities.len());
        present.extend(self.quantities.keys().cloned());
        present
    }

    /// Returns the metadata value stored under `key`, if any.
    pub fn get_metadata(&self, key: &str) -> Option<f64> {
        self.metadata.get(key).copied()
    }

    /// Inserts a metadata value, overwriting any previous value for `key`.
    pub fn set_metadata(&mut self, key: String, value: f64) {
        self.metadata.insert(key, value);
    }

    /// Total of `memory_bytes()` over every stored quantity's data.
    ///
    /// Metadata and hash-map bookkeeping are not included.
    pub fn memory_bytes(&self) -> usize {
        self.quantities
            .values()
            .fold(0, |total, data| total + data.memory_bytes())
    }

    /// Number of stored quantities; metadata entries are not counted.
    pub fn len(&self) -> usize {
        self.quantities.len()
    }

    /// `true` when no quantities are stored, even if metadata is present.
    pub fn is_empty(&self) -> bool {
        self.quantities.is_empty()
    }
}
impl std::ops::Add for PhysicalState {
    type Output = Self;

    /// Merges `rhs` into `self`, quantity by quantity.
    ///
    /// If a quantity exists on both sides, the result holds `lhs + rhs` using
    /// `PhysicalData`'s `Add`. A quantity present on only one side is carried
    /// over unchanged.
    ///
    /// NOTE(review): only `self`'s metadata survives; `rhs.metadata` is
    /// discarded.
    fn add(mut self, rhs: Self) -> Self::Output {
        for (key, incoming) in rhs.quantities {
            let merged = match self.quantities.remove(&key) {
                Some(existing) => existing + incoming,
                None => incoming,
            };
            self.quantities.insert(key, merged);
        }
        self
    }
}
impl std::ops::Mul<f64> for PhysicalState {
    type Output = Self;

    /// Scales the data of every stored quantity in place by `scalar`,
    /// using `PhysicalData`'s `MulAssign<f64>`. Metadata is left untouched.
    fn mul(mut self, scalar: f64) -> Self::Output {
        self.quantities
            .values_mut()
            .for_each(|field| *field *= scalar);
        self
    }
}
/// A physical model that produces and evaluates `PhysicalState`s.
///
/// `#[typetag::serde]` registers implementors for serde so that
/// `dyn PhysicalModel` trait objects can be (de)serialized. Implementors must
/// be `Send + Sync`.
#[typetag::serde]
pub trait PhysicalModel: Send + Sync {
    /// Number of points the model works with.
    /// NOTE(review): presumably the spatial discretisation size — confirm
    /// against implementors.
    fn points(&self) -> usize;

    /// Evaluates the model's physics for `state` under the given compute
    /// context and returns the result as a new `PhysicalState`.
    /// NOTE(review): whether the returned state is a time derivative or an
    /// updated state is not visible here — confirm with implementors/callers.
    fn compute_physics(
        &self,
        state: &PhysicalState,
        ctx: &crate::physics::context::ComputeContext,
    ) -> PhysicalState;

    /// Builds the state the model starts from.
    fn setup_initial_state(&self) -> PhysicalState;

    /// The model's name.
    fn name(&self) -> &str;

    /// Optional longer description; the default implementation returns `None`.
    fn description(&self) -> Option<&str> {
        None
    }

    /// Hands the model its temporal injections, keyed by `Option<String>`.
    /// NOTE(review): presumably `None` denotes a default/unnamed target — confirm.
    ///
    /// The default implementation ignores the argument and returns `Ok(())`.
    /// Overrides can report a failure as `Err(String)`.
    fn set_injections(
        &mut self,
        _injections: &HashMap<Option<String>, crate::models::TemporalInjection>,
    ) -> Result<(), String> {
        Ok(())
    }
}
/// Failure while reading back an export map (see `Exportable::from_map`).
#[derive(Debug)]
pub enum ExportError {
    /// A required key is absent; carries the key name.
    MissingKey(String),
    /// A key is present but its value is unusable; carries the key and why.
    InvalidValue { key: String, reason: String },
    /// The number of species found differs from the number expected.
    SpeciesCountMismatch { expected: usize, got: usize },
}

impl std::fmt::Display for ExportError {
    /// Renders a one-line message prefixed with `export map:`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingKey(k) => write!(f, "export map: missing key '{k}'"),
            Self::InvalidValue { key, reason } => {
                write!(f, "export map: invalid value for '{key}': {reason}")
            }
            Self::SpeciesCountMismatch { expected, got } => {
                write!(f, "export map: expected {expected} species, got {got}")
            }
        }
    }
}

/// No underlying cause is stored, so the default `source()` (`None`) applies.
impl std::error::Error for ExportError {}
/// Conversion of a result to and from a JSON object (`serde_json::Map`).
pub trait Exportable {
    /// Builds a JSON object from the given time points, trajectory states and
    /// string metadata.
    /// NOTE(review): presumably `time_points[i]` corresponds to
    /// `trajectory[i]` — confirm with implementors.
    fn to_map(
        &self,
        time_points: &[f64],
        trajectory: &[PhysicalState],
        metadata: &HashMap<String, String>,
    ) -> Map<String, Value>;

    /// Reconstructs `Self` from a JSON object (presumably one produced by
    /// `to_map` — confirm).
    ///
    /// # Errors
    /// Returns an `ExportError` when the map cannot be interpreted, e.g. a
    /// missing key, an invalid value, or a species-count mismatch.
    fn from_map(map: Map<String, Value>) -> Result<Self, ExportError>
    where
        Self: Sized;
}
/// Returns the last-point ("outlet") values of `quantity` at trajectory step `idx`.
///
/// * `Scalar(c)` yields `[c]`.
/// * `Vector(v)` yields the last element of `v`, or `[]` if `v` is empty.
/// * `Matrix(m)` yields the last row of `m` (one value per column), or `[]`
///   if `m` has no rows.
///
/// An out-of-range `idx`, a quantity missing from that state, or any other
/// data shape yields an empty vector.
pub fn outlet_data(
    quantity: PhysicalQuantity,
    trajectory: &[PhysicalState],
    idx: usize,
) -> Vec<f64> {
    trajectory
        .get(idx)
        .and_then(|state| state.get(quantity))
        .map(|data| -> Vec<f64> {
            match data {
                PhysicalData::Scalar(value) => vec![*value],
                PhysicalData::Vector(profile) => profile
                    .iter()
                    .next_back()
                    .map_or_else(Vec::new, |last| vec![*last]),
                PhysicalData::Matrix(grid) if grid.nrows() > 0 => {
                    let outlet_row = grid.nrows() - 1;
                    (0..grid.ncols())
                        .map(|col| grid[(outlet_row, col)])
                        .collect()
                }
                _ => Vec::new(),
            }
        })
        .unwrap_or_default()
}
/// Picks up to `n` evenly spaced indices from `0..total`.
///
/// * `None`, `Some(0)`, or any `n >= total` returns every index `0..total`
///   (so `total == 0` always yields an empty vector).
/// * `Some(1)` with `total > 1` returns `[0]`.
/// * Otherwise it returns `n` strictly increasing indices. The first is `0`,
///   the last is `total - 1`, and element `i` is
///   `floor(i * (total - 1) / (n - 1))`.
///
/// The product `i * (total - 1)` is formed in `u128`. In `usize` it can
/// overflow for large inputs (easily on 32-bit targets), which panics in debug
/// builds and silently produces wrong indices in release builds.
pub fn sample_indices(total: usize, n: Option<usize>) -> Vec<usize> {
    match n {
        None | Some(0) => (0..total).collect(),
        Some(n) if n >= total => (0..total).collect(),
        Some(1) => vec![0],
        Some(n) => {
            // Here 2 <= n < total, so both divisor and span are non-zero.
            let span = (total - 1) as u128;
            let steps = (n - 1) as u128;
            (0..n)
                .map(|i| {
                    // The quotient is at most `total - 1`, so narrowing back
                    // to usize is lossless. For i = n - 1 it is exactly total - 1.
                    ((i as u128 * span) / steps) as usize
                })
                .collect()
        }
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for quantities, states, the serde impls and the export helpers.
    use super::*;
    use nalgebra::DMatrix;

    #[test]
    fn test_physical_quantity_creation() {
        let c = PhysicalQuantity::Concentration;
        let t = PhysicalQuantity::Temperature;
        let p = PhysicalQuantity::Pressure;
        let v = PhysicalQuantity::Velocity;
        assert_eq!(format!("{}", c), "Concentration");
        assert_eq!(format!("{}", t), "Temperature");
        assert_eq!(format!("{}", p), "Pressure");
        assert_eq!(format!("{}", v), "Velocity");
    }

    #[test]
    fn test_custom_physical_quantity_create() {
        let viscosity = PhysicalQuantity::Custom("Viscosity".to_string());
        let k_langmuir = PhysicalQuantity::Custom("K langmuir".to_string());
        assert_eq!(format!("{}", viscosity), "Viscosity");
        assert_eq!(format!("{}", k_langmuir), "K langmuir");
    }

    #[test]
    fn test_physical_quantity_equality() {
        let c1 = PhysicalQuantity::Concentration;
        let c2 = PhysicalQuantity::Concentration;
        let p1 = PhysicalQuantity::Pressure;
        assert_eq!(c1, c2);
        assert_ne!(c1, p1);
    }

    #[test]
    fn test_custom_physical_quantity_equality() {
        let v1 = PhysicalQuantity::Custom("Viscosity".to_string());
        let v2 = PhysicalQuantity::Custom("Viscosity".to_string());
        let k = PhysicalQuantity::Custom("K langmuir".to_string());
        assert_ne!(v1, k);
        assert_ne!(v2, k);
        assert_eq!(v1, v2);
    }

    #[test]
    fn test_physical_quantity_as_hashmap_key() {
        use std::collections::HashMap;
        let mut map = HashMap::new();
        map.insert(PhysicalQuantity::Custom("Viscosity".to_string()), 0.15);
        map.insert(PhysicalQuantity::Temperature, 350.0);
        assert_eq!(
            map.get(&PhysicalQuantity::Custom("Viscosity".to_string())),
            Some(&0.15)
        );
        assert_eq!(map.get(&PhysicalQuantity::Temperature), Some(&350.0));
    }

    // The hand-written Serialize/Deserialize impls must agree with Display.
    #[test]
    fn test_physical_quantity_serde_roundtrip() {
        let quantities = vec![
            PhysicalQuantity::Concentration,
            PhysicalQuantity::Temperature,
            PhysicalQuantity::Pressure,
            PhysicalQuantity::Velocity,
            PhysicalQuantity::Custom("K langmuir".to_string()),
        ];
        for quantity in quantities {
            let json = serde_json::to_string(&quantity).unwrap();
            assert_eq!(json, format!("\"{}\"", quantity));
            let back: PhysicalQuantity = serde_json::from_str(&json).unwrap();
            assert_eq!(back, quantity);
        }
    }

    #[test]
    fn test_unknown_quantity_name_deserializes_as_custom() {
        let parsed: PhysicalQuantity = serde_json::from_str("\"Viscosity\"").unwrap();
        assert_eq!(parsed, PhysicalQuantity::Custom("Viscosity".to_string()));
    }

    #[test]
    fn test_physical_quantity_rejects_non_string() {
        assert!(serde_json::from_str::<PhysicalQuantity>("42").is_err());
    }

    // Quantities are HashMap keys, so this exercises the string-key path of
    // the custom serde impls through a real JSON object.
    #[test]
    fn test_physical_state_serde_roundtrip() {
        let mut state =
            PhysicalState::new(PhysicalQuantity::Temperature, PhysicalData::Scalar(300.0));
        state.set(
            PhysicalQuantity::Custom("Viscosity".to_string()),
            PhysicalData::Scalar(0.5),
        );
        state.set_metadata("flow_rate".to_string(), 2.0);
        let json = serde_json::to_string(&state).unwrap();
        let back: PhysicalState = serde_json::from_str(&json).unwrap();
        assert_eq!(back.len(), 2);
        assert_eq!(
            back.get(PhysicalQuantity::Temperature).unwrap().as_scalar(),
            300.0
        );
        assert_eq!(
            back.get(PhysicalQuantity::Custom("Viscosity".to_string()))
                .unwrap()
                .as_scalar(),
            0.5
        );
        assert_eq!(back.get_metadata("flow_rate"), Some(2.0));
    }

    #[test]
    fn test_empty_physical_state() {
        let empty = PhysicalState::empty();
        assert_eq!(empty.len(), 0);
        assert!(empty.is_empty());
        assert_eq!(empty.available_quantities().len(), 0);
        assert_eq!(empty.memory_bytes(), 0);
    }

    #[test]
    fn test_physical_state_from_scalar() {
        let state = PhysicalState::new(
            PhysicalQuantity::Custom("Viscosity".to_string()),
            PhysicalData::Scalar(0.15),
        );
        assert!(!state.is_empty());
        assert_eq!(state.len(), 1);
        assert_eq!(state.available_quantities().len(), 1);
    }

    #[test]
    fn test_retrieve_physical_data() {
        let state = PhysicalState::new(PhysicalQuantity::Temperature, PhysicalData::Scalar(348.15));
        assert!(
            state
                .available_quantities()
                .contains(&PhysicalQuantity::Temperature)
        );
        assert_eq!(
            state
                .get(PhysicalQuantity::Temperature)
                .unwrap()
                .as_scalar(),
            348.15
        );
    }

    #[test]
    fn test_physical_state_from_vector() {
        let state = PhysicalState::new(
            PhysicalQuantity::Velocity,
            PhysicalData::from_vec(vec![25.0, 10.0, 33.0]),
        );
        assert_eq!(state.len(), 1);
        assert!(state.get(PhysicalQuantity::Velocity).unwrap().is_vector());
        assert_eq!(
            state
                .get(PhysicalQuantity::Velocity)
                .unwrap()
                .as_vector()
                .len(),
            3
        );
    }

    #[test]
    fn test_physical_state_from_matrix() {
        let state = PhysicalState::new(
            PhysicalQuantity::Custom("Viscosity".to_string()),
            PhysicalData::Matrix(DMatrix::from_row_slice(
                3,
                3,
                &[0.12, 0.5, 0.01, 0.2, 0.23, 0.6, 0.0, 0.0, 1.0],
            )),
        );
        assert_eq!(state.len(), 1);
        assert!(
            state
                .get(PhysicalQuantity::Custom("Viscosity".to_string()))
                .unwrap()
                .is_matrix()
        );
        assert_eq!(
            state
                .get(PhysicalQuantity::Custom("Viscosity".to_string()))
                .unwrap()
                .as_matrix()
                .shape(),
            (3, 3)
        );
    }

    #[test]
    fn test_physical_set_n_get() {
        let mut state = PhysicalState::empty();
        state.set(PhysicalQuantity::Temperature, PhysicalData::Scalar(348.15));
        assert_eq!(state.len(), 1);
        assert!(
            state
                .get(PhysicalQuantity::Temperature)
                .unwrap()
                .is_scalar()
        );
        assert_eq!(
            state
                .get(PhysicalQuantity::Temperature)
                .unwrap()
                .as_scalar(),
            348.15
        );
    }

    #[test]
    fn test_get_mut_physical_state() {
        let mut state = PhysicalState::new(
            PhysicalQuantity::Concentration,
            PhysicalData::uniform_vector(100, 0.1),
        );
        if let Some(concentration) = state.get_mut(PhysicalQuantity::Concentration) {
            concentration.apply(|c| c * 10.0);
        }
        assert_eq!(state.len(), 1);
        assert_eq!(
            state
                .get(PhysicalQuantity::Concentration)
                .unwrap()
                .as_vector()[10],
            1.0
        );
    }

    #[test]
    fn test_remove_quantity() {
        let mut state = PhysicalState::new(
            PhysicalQuantity::Temperature,
            PhysicalData::from_scalar(300.0),
        );
        assert!(!state.is_empty());
        let removed = state.remove(PhysicalQuantity::Temperature);
        assert!(removed.is_some());
        assert!(state.is_empty());
        let removed_again = state.remove(PhysicalQuantity::Temperature);
        assert!(removed_again.is_none());
    }

    #[test]
    fn test_available_quantities() {
        let mut state = PhysicalState::empty();
        state.set(
            PhysicalQuantity::Concentration,
            PhysicalData::from_scalar(1.0),
        );
        state.set(
            PhysicalQuantity::Temperature,
            PhysicalData::from_scalar(298.15),
        );
        state.set(
            PhysicalQuantity::Pressure,
            PhysicalData::from_scalar(101325.0),
        );
        let quantities = state.available_quantities();
        assert_eq!(quantities.len(), 3);
        assert!(quantities.contains(&PhysicalQuantity::Concentration));
        assert!(quantities.contains(&PhysicalQuantity::Temperature));
        assert!(quantities.contains(&PhysicalQuantity::Pressure));
    }

    #[test]
    fn test_metadata_set_get() {
        let mut state = PhysicalState::empty();
        state.set_metadata("total_mass".to_string(), 125.5);
        state.set_metadata("flow_rate".to_string(), 1.0);
        assert_eq!(state.get_metadata("total_mass"), Some(125.5));
        assert_eq!(state.get_metadata("flow_rate"), Some(1.0));
        assert_eq!(state.get_metadata("unknown"), None);
    }

    #[test]
    fn test_metadata_overwrite() {
        let mut state = PhysicalState::empty();
        state.set_metadata("value".to_string(), 1.0);
        assert_eq!(state.get_metadata("value"), Some(1.0));
        state.set_metadata("value".to_string(), 2.0);
        assert_eq!(state.get_metadata("value"), Some(2.0));
    }

    #[test]
    fn test_memory_bytes() {
        let mut state = PhysicalState::empty();
        state.set(
            PhysicalQuantity::Temperature,
            PhysicalData::from_scalar(298.15),
        );
        assert_eq!(state.memory_bytes(), 8);
        state.set(
            PhysicalQuantity::Concentration,
            PhysicalData::uniform_vector(100, 1.0),
        );
        assert_eq!(state.memory_bytes(), 808);
        state.set(
            PhysicalQuantity::Pressure,
            PhysicalData::uniform_matrix(100, 3, 1.0),
        );
        assert_eq!(state.memory_bytes(), 3208);
    }

    #[test]
    fn test_add_states_same_quantities() {
        let state1 = PhysicalState::new(
            PhysicalQuantity::Concentration,
            PhysicalData::uniform_vector(100, 1.0),
        );
        let state2 = PhysicalState::new(
            PhysicalQuantity::Concentration,
            PhysicalData::uniform_vector(100, 0.5),
        );
        let result = state1 + state2;
        let conc = result.get(PhysicalQuantity::Concentration).unwrap();
        assert_eq!(conc.as_vector()[0], 1.5);
        assert_eq!(conc.as_vector()[99], 1.5);
    }

    #[test]
    fn test_add_states_different_quantities() {
        let mut state1 = PhysicalState::empty();
        state1.set(
            PhysicalQuantity::Concentration,
            PhysicalData::from_scalar(1.0),
        );
        let mut state2 = PhysicalState::empty();
        state2.set(
            PhysicalQuantity::Temperature,
            PhysicalData::from_scalar(298.15),
        );
        let result = state1 + state2;
        assert_eq!(result.len(), 2);
        assert!(result.get(PhysicalQuantity::Concentration).is_some());
        assert!(result.get(PhysicalQuantity::Temperature).is_some());
    }

    #[test]
    fn test_add_states_overlapping() {
        let mut state1 = PhysicalState::empty();
        state1.set(
            PhysicalQuantity::Concentration,
            PhysicalData::from_scalar(1.0),
        );
        state1.set(
            PhysicalQuantity::Temperature,
            PhysicalData::from_scalar(300.0),
        );
        let mut state2 = PhysicalState::empty();
        state2.set(
            PhysicalQuantity::Concentration,
            PhysicalData::from_scalar(0.5),
        );
        state2.set(
            PhysicalQuantity::Pressure,
            PhysicalData::from_scalar(101325.0),
        );
        let result = state1 + state2;
        assert_eq!(result.len(), 3);
        assert_eq!(
            result
                .get(PhysicalQuantity::Concentration)
                .unwrap()
                .as_scalar(),
            1.5
        );
        assert_eq!(
            result
                .get(PhysicalQuantity::Temperature)
                .unwrap()
                .as_scalar(),
            300.0
        );
        assert_eq!(
            result.get(PhysicalQuantity::Pressure).unwrap().as_scalar(),
            101325.0
        );
    }

    #[test]
    fn test_mul_scalar() {
        let mut state = PhysicalState::empty();
        state.set(
            PhysicalQuantity::Concentration,
            PhysicalData::uniform_vector(100, 2.0),
        );
        state.set(
            PhysicalQuantity::Temperature,
            PhysicalData::from_scalar(300.0),
        );
        let scaled = state * 0.5;
        let conc = scaled.get(PhysicalQuantity::Concentration).unwrap();
        assert_eq!(conc.as_vector()[0], 1.0);
        let temp = scaled.get(PhysicalQuantity::Temperature).unwrap();
        assert_eq!(temp.as_scalar(), 150.0);
    }

    #[test]
    fn test_outlet_data_scalar() {
        // 2.5 rather than 3.14: the latter trips clippy::approx_constant (deny by default).
        let state = PhysicalState::new(PhysicalQuantity::Concentration, PhysicalData::Scalar(2.5));
        assert_eq!(
            outlet_data(PhysicalQuantity::Concentration, &[state], 0),
            vec![2.5]
        );
    }

    #[test]
    fn test_outlet_data_vector_last_point() {
        let state = PhysicalState::new(
            PhysicalQuantity::Concentration,
            PhysicalData::Vector(nalgebra::DVector::from_vec(vec![0.0, 0.5, 1.0])),
        );
        let outlet = outlet_data(PhysicalQuantity::Concentration, &[state], 0);
        assert!((outlet[0] - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_outlet_data_matrix_last_row() {
        let state = PhysicalState::new(
            PhysicalQuantity::Concentration,
            PhysicalData::Matrix(DMatrix::from_row_slice(
                2,
                3,
                &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
            )),
        );
        assert_eq!(
            outlet_data(PhysicalQuantity::Concentration, &[state], 0),
            vec![4.0, 5.0, 6.0]
        );
    }

    #[test]
    fn test_outlet_data_out_of_bounds() {
        let state = PhysicalState::new(PhysicalQuantity::Concentration, PhysicalData::Scalar(1.0));
        assert!(outlet_data(PhysicalQuantity::Concentration, &[state], 99).is_empty());
    }

    #[test]
    fn test_outlet_data_empty_trajectory() {
        assert!(outlet_data(PhysicalQuantity::Concentration, &[], 0).is_empty());
    }

    #[test]
    fn test_outlet_data_unknown_quantity() {
        let state = PhysicalState::new(PhysicalQuantity::Concentration, PhysicalData::Scalar(1.0));
        assert!(outlet_data(PhysicalQuantity::Temperature, &[state], 0).is_empty());
    }

    #[test]
    fn test_sample_indices_full() {
        assert_eq!(sample_indices(4, None), vec![0, 1, 2, 3]);
    }

    #[test]
    fn test_sample_indices_five_from_hundred() {
        assert_eq!(sample_indices(100, Some(5)), vec![0, 24, 49, 74, 99]);
    }

    #[test]
    fn test_sample_indices_single() {
        assert_eq!(sample_indices(100, Some(1)), vec![0]);
    }

    #[test]
    fn test_sample_indices_larger_than_total() {
        assert_eq!(sample_indices(3, Some(10)), vec![0, 1, 2]);
    }

    #[test]
    fn test_sample_indices_edge_cases() {
        assert!(sample_indices(0, Some(1)).is_empty());
        assert!(sample_indices(0, None).is_empty());
        assert_eq!(sample_indices(5, Some(0)), vec![0, 1, 2, 3, 4]);
        assert_eq!(sample_indices(10, Some(2)), vec![0, 9]);
    }

    #[test]
    fn test_export_error_display_missing_key() {
        let e = ExportError::MissingKey("metadata".into());
        assert_eq!(e.to_string(), "export map: missing key 'metadata'");
    }

    #[test]
    fn test_export_error_display_invalid_value() {
        let e = ExportError::InvalidValue {
            key: "time".into(),
            reason: "not an array".into(),
        };
        assert_eq!(
            e.to_string(),
            "export map: invalid value for 'time': not an array"
        );
    }

    #[test]
    fn test_export_error_display_mismatch() {
        let e = ExportError::SpeciesCountMismatch {
            expected: 2,
            got: 1,
        };
        assert_eq!(e.to_string(), "export map: expected 2 species, got 1");
    }
}