use nalgebra::{DMatrix, DVector};
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::collections::HashMap;
/// Key identifying a value stored in a [`ComputeContext`].
///
/// The type derives `Eq` and `Hash`, which is what allows it to serve as the
/// key type of the `HashMap` held by [`ComputeContext`].
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ContextVariable {
/// Key for a time entry.
///
/// Note that [`ComputeContext::time`] reads a dedicated struct field, not a
/// map entry stored under this key.
Time,
/// Key for a time-step entry.
///
/// Note that [`ComputeContext::time_step`] reads a dedicated struct field,
/// not a map entry stored under this key.
TimeStep,
/// Key for a spatial-gradient entry, distinguished by `dimension` and
/// `component`.
SpatialGradient {
/// Dimension index.
/// NOTE(review): presumably the spatial axis the gradient is taken along;
/// confirm the indexing convention with callers.
dimension: usize,
/// Optional component index; `None` is a different key from any `Some(_)`.
/// NOTE(review): presumably selects one component of a vector quantity;
/// confirm with callers.
component: Option<usize>,
},
/// Key for an entry identified by a caller-chosen name.
External {
/// Entry name. The `Cow` can hold either a `'static` string slice or an
/// owned `String`.
name: Cow<'static, str>,
},
}
/// Value stored in a [`ComputeContext`] under a [`ContextVariable`] key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ContextValue {
/// A single `f64` value.
Scalar(f64),
/// A boolean flag.
Boolean(bool),
/// A dynamically sized vector of `f64`.
/// NOTE(review): presumably one value per mesh node/cell; confirm the
/// expected length with callers.
ScalarField(DVector<f64>),
/// A dynamically sized matrix of `f64`.
/// NOTE(review): the meaning of rows versus columns is not visible here;
/// confirm the layout with callers.
VectorField(DMatrix<f64>),
}
/// A time value, a time-step value, and a keyed collection of
/// [`ContextValue`]s.
///
/// All fields are private; values are read and written through the methods
/// of the accompanying `impl` block.
#[derive(Debug, Clone)]
pub struct ComputeContext {
/// Time value supplied to `new` and returned by `time()`.
time: f64,
/// Time-step value supplied to `new` and returned by `time_step()`.
time_step: f64,
/// Values stored through `insert`, looked up by their `ContextVariable` key.
variables: HashMap<ContextVariable, ContextValue>,
}
impl ComputeContext {
    /// Creates a context holding the given `time` and `time_step` and no
    /// stored variables.
    pub fn new(time: f64, time_step: f64) -> Self {
        let variables = HashMap::new();
        Self {
            time,
            time_step,
            variables,
        }
    }

    /// Returns the time value this context was created with.
    #[inline]
    pub fn time(&self) -> f64 {
        self.time
    }

    /// Returns the time-step value this context was created with.
    #[inline]
    pub fn time_step(&self) -> f64 {
        self.time_step
    }

    /// Stores `value` under `key`.
    ///
    /// A value already stored under an equal key is replaced and dropped.
    pub fn insert(&mut self, key: ContextVariable, value: ContextValue) {
        let _replaced = self.variables.insert(key, value);
    }

    /// Returns the value stored under `key`, or `None` when the key is absent.
    pub fn get(&self, key: &ContextVariable) -> Option<&ContextValue> {
        self.variables.get(key)
    }

    /// Returns the scalar stored under the `External` key called `name`.
    ///
    /// Yields `None` when no such entry exists or when the stored value is
    /// not a [`ContextValue::Scalar`].
    pub fn external_scalar(&self, name: &str) -> Option<f64> {
        // The map is keyed by whole `ContextVariable` values, so an owned key
        // has to be built for the lookup. An owned `Cow` compares and hashes
        // equal to a borrowed one with the same text.
        let lookup = ContextVariable::External {
            name: Cow::Owned(name.to_owned()),
        };
        if let Some(ContextValue::Scalar(value)) = self.variables.get(&lookup) {
            Some(*value)
        } else {
            None
        }
    }

    /// Returns the field stored under the `SpatialGradient` key with the
    /// given `dimension` and `component`.
    ///
    /// Yields `None` when no such entry exists or when the stored value is
    /// not a [`ContextValue::ScalarField`].
    pub fn spatial_gradient(
        &self,
        dimension: usize,
        component: Option<usize>,
    ) -> Option<&DVector<f64>> {
        let lookup = ContextVariable::SpatialGradient {
            dimension,
            component,
        };
        // `?` covers the missing-key case; the match filters on the variant.
        match self.variables.get(&lookup)? {
            ContextValue::ScalarField(field) => Some(field),
            _ => None,
        }
    }

    /// Returns `true` when no variables are stored.
    ///
    /// The `time` and `time_step` fields are not counted.
    pub fn is_empty(&self) -> bool {
        self.variables.is_empty()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::DVector;

    /// Builds an `External` key from a string literal.
    fn external(name: &'static str) -> ContextVariable {
        ContextVariable::External { name: name.into() }
    }

    /// Builds a `SpatialGradient` key.
    fn gradient(dimension: usize, component: Option<usize>) -> ContextVariable {
        ContextVariable::SpatialGradient {
            dimension,
            component,
        }
    }

    /// A new context reports the constructor arguments and holds no variables.
    #[test]
    fn test_context_new() {
        let ctx = ComputeContext::new(10.0, 0.01);
        let time_error = (ctx.time() - 10.0).abs();
        let step_error = (ctx.time_step() - 0.01).abs();
        assert!(time_error < 1e-12);
        assert!(step_error < 1e-12);
        assert!(ctx.is_empty());
    }

    /// A scalar stored under an external key can be read back through `get`.
    #[test]
    fn test_context_insert_and_get_scalar() {
        let mut ctx = ComputeContext::new(0.0, 0.01);
        ctx.insert(external("pressure"), ContextValue::Scalar(101325.0));
        assert!(!ctx.is_empty());

        let stored = ctx.get(&external("pressure"));
        if let Some(ContextValue::Scalar(v)) = stored {
            assert!((v - 101325.0).abs() < 1e-6);
        } else {
            panic!("Expected Scalar");
        }
    }

    /// `external_scalar` finds stored scalars and reports missing names as `None`.
    #[test]
    fn test_external_scalar_accessor() {
        let mut ctx = ComputeContext::new(0.0, 0.01);
        ctx.insert(external("flow"), ContextValue::Scalar(1e-6));

        let flow = ctx.external_scalar("flow").unwrap();
        assert!((flow - 1e-6).abs() < 1e-18);
        assert!(ctx.external_scalar("unknown").is_none());
    }

    /// `spatial_gradient` returns the stored field only for the matching key.
    #[test]
    fn test_spatial_gradient_accessor() {
        let mut ctx = ComputeContext::new(0.0, 0.01);
        let field = DVector::from_vec(vec![1.0, 2.0, 3.0]);
        ctx.insert(gradient(0, None), ContextValue::ScalarField(field.clone()));

        let retrieved = ctx.spatial_gradient(0, None).unwrap();
        assert_eq!(retrieved, &field);
        assert!(ctx.spatial_gradient(1, None).is_none());
    }

    /// External keys with equal names are equal and hash to the same map slot.
    #[test]
    fn test_context_variable_hash_eq() {
        let first = external("pressure");
        let same = external("pressure");
        let other = external("flow");
        assert_eq!(first, same);
        assert_ne!(first, other);

        let mut lookup = HashMap::new();
        lookup.insert(first.clone(), 1u32);
        assert_eq!(lookup.get(&same), Some(&1u32));
        assert_eq!(lookup.get(&other), None);
    }

    /// Spatial-gradient keys differ when their `component` differs.
    #[test]
    fn test_context_variable_spatial_gradient_eq() {
        let first = gradient(0, Some(1));
        let same = gradient(0, Some(1));
        let without_component = gradient(0, None);
        assert_eq!(first, same);
        assert_ne!(first, without_component);
    }

    /// A boolean value can be stored and read back through `get`.
    #[test]
    fn test_context_value_boolean() {
        let mut ctx = ComputeContext::new(0.0, 0.01);
        ctx.insert(ContextVariable::Time, ContextValue::Boolean(true));

        if let Some(ContextValue::Boolean(flag)) = ctx.get(&ContextVariable::Time) {
            assert!(*flag);
        } else {
            panic!("Expected Boolean");
        }
    }
}