use std::collections::HashMap;
use nalgebra::{DMatrix, DVector};
use crate::context::error::OxiflowError;
use crate::context::value::ContextValue;
use crate::context::variable::ContextVariable;
/// Values available while a computation runs.
///
/// Holds the current simulation time and step size, plus a lookup table of
/// registered [`ContextVariable`] -> [`ContextValue`] entries.
pub struct ComputeContext {
    /// Registered variables, keyed by their identifier.
    variables: HashMap<ContextVariable, ContextValue>,
    /// Current simulation time.
    time: f64,
    /// Step size associated with `time`.
    time_step: f64,
}
impl ComputeContext {
    /// Builds a context for `time` / `time_step` with no variables registered.
    ///
    /// Populate it afterwards with [`ComputeContext::insert`].
    pub fn new(time: f64, time_step: f64) -> Self {
        let variables = HashMap::default();
        Self {
            time,
            time_step,
            variables,
        }
    }

    /// Stores `value` under `var`; an existing entry for the same key is replaced.
    pub fn insert(&mut self, var: ContextVariable, value: ContextValue) {
        self.variables.insert(var, value);
    }

    /// Current simulation time, as passed to [`ComputeContext::new`].
    pub fn time(&self) -> f64 {
        self.time
    }

    /// Step size, as passed to [`ComputeContext::new`].
    pub fn time_step(&self) -> f64 {
        self.time_step
    }

    /// Fetches `var` and converts it with `ContextValue::as_scalar`.
    ///
    /// # Errors
    /// `OxiflowError::MissingCalculator` when `var` is not registered; otherwise
    /// whatever `as_scalar` reports (e.g. `TypeMismatch` for a non-scalar value).
    pub fn scalar(&self, var: ContextVariable) -> Result<f64, OxiflowError> {
        self.get_value(&var).and_then(|value| value.as_scalar())
    }

    /// Borrows the vector stored under `var` via `ContextValue::as_vector`.
    ///
    /// # Errors
    /// `OxiflowError::MissingCalculator` when `var` is not registered; otherwise
    /// whatever `as_vector` reports for a value of the wrong kind.
    pub fn vector(&self, var: ContextVariable) -> Result<&DVector<f64>, OxiflowError> {
        self.get_value(&var).and_then(|value| value.as_vector())
    }

    /// Borrows the matrix stored under `var` via `ContextValue::as_matrix`.
    ///
    /// # Errors
    /// `OxiflowError::MissingCalculator` when `var` is not registered; otherwise
    /// whatever `as_matrix` reports for a value of the wrong kind.
    pub fn matrix(&self, var: ContextVariable) -> Result<&DMatrix<f64>, OxiflowError> {
        self.get_value(&var).and_then(|value| value.as_matrix())
    }

    /// Borrows the scalar field registered as
    /// `ContextVariable::SpatialGradient { dimension: dim, component: None }`.
    ///
    /// Only the `component: None` key is consulted.
    ///
    /// # Errors
    /// `OxiflowError::MissingCalculator` when no gradient is registered for `dim`;
    /// otherwise whatever `as_scalar_field` reports for a value of the wrong kind.
    pub fn gradient(&self, dim: usize) -> Result<&DVector<f64>, OxiflowError> {
        let key = ContextVariable::SpatialGradient {
            dimension: dim,
            component: None,
        };
        self.get_value(&key).and_then(|value| value.as_scalar_field())
    }

    /// Borrows the raw [`ContextValue`] under `var`, with no type conversion.
    ///
    /// # Errors
    /// `OxiflowError::MissingCalculator` when `var` is not registered.
    pub fn external(&self, var: ContextVariable) -> Result<&ContextValue, OxiflowError> {
        self.get_value(&var)
    }

    /// Non-failing lookup: `None` when `var` is not registered.
    pub fn try_get(&self, var: ContextVariable) -> Option<&ContextValue> {
        self.variables.get(&var)
    }

    /// Shared lookup; maps an absent key to `OxiflowError::MissingCalculator(var)`.
    fn get_value(&self, var: &ContextVariable) -> Result<&ContextValue, OxiflowError> {
        match self.variables.get(var) {
            Some(value) => Ok(value),
            None => Err(OxiflowError::MissingCalculator(var.clone())),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{DMatrix, DVector};

    /// Fixture registering one entry per `ContextValue` variant exercised below.
    fn ctx_with_all_variants() -> ComputeContext {
        let mut ctx = ComputeContext::new(2.0, 0.01);
        ctx.insert(
            ContextVariable::External {
                name: "coeff".into(),
            },
            ContextValue::Scalar(42.0),
        );
        ctx.insert(
            ContextVariable::External {
                name: "flag".into(),
            },
            ContextValue::Boolean(true),
        );
        ctx.insert(
            ContextVariable::External { name: "vel".into() },
            ContextValue::Vector(DVector::from_vec(vec![1.0, 2.0, 3.0])),
        );
        ctx.insert(
            ContextVariable::External {
                name: "tensor".into(),
            },
            ContextValue::Matrix(DMatrix::from_element(2, 2, 5.0)),
        );
        ctx.insert(
            ContextVariable::SpatialGradient {
                dimension: 0,
                component: None,
            },
            ContextValue::ScalarField(DVector::from_vec(vec![0.1, 0.2, 0.3])),
        );
        ctx.insert(
            ContextVariable::External {
                name: "vfield".into(),
            },
            ContextValue::VectorField(DMatrix::from_element(3, 2, 1.5)),
        );
        ctx
    }

    #[test]
    fn time_returns_correct_value() {
        // Deliberately not 3.14: clippy's deny-by-default `approx_constant`
        // lint rejects literals that approximate PI.
        let ctx = ComputeContext::new(2.5, 0.001);
        assert_eq!(ctx.time(), 2.5);
    }

    #[test]
    fn time_step_returns_correct_value() {
        let ctx = ComputeContext::new(0.0, 0.05);
        assert_eq!(ctx.time_step(), 0.05);
    }

    #[test]
    fn scalar_returns_value_for_scalar_variable() {
        let ctx = ctx_with_all_variants();
        let v = ctx
            .scalar(ContextVariable::External {
                name: "coeff".into(),
            })
            .unwrap();
        assert_eq!(v, 42.0);
    }

    #[test]
    fn scalar_returns_missing_calculator_when_absent() {
        let ctx = ComputeContext::new(0.0, 0.01);
        let err = ctx.scalar(ContextVariable::Time).unwrap_err();
        assert!(matches!(err, OxiflowError::MissingCalculator(_)));
    }

    #[test]
    fn scalar_returns_type_mismatch_for_non_scalar() {
        let ctx = ctx_with_all_variants();
        let err = ctx
            .scalar(ContextVariable::External { name: "vel".into() })
            .unwrap_err();
        assert!(matches!(
            err,
            OxiflowError::TypeMismatch {
                expected: "Scalar",
                ..
            }
        ));
    }

    #[test]
    fn vector_returns_reference_for_vector_variable() {
        let ctx = ctx_with_all_variants();
        let v = ctx
            .vector(ContextVariable::External { name: "vel".into() })
            .unwrap();
        assert_eq!(v.len(), 3);
        assert_eq!(v[1], 2.0);
    }

    #[test]
    fn vector_returns_missing_calculator_when_absent() {
        let ctx = ComputeContext::new(0.0, 0.01);
        let err = ctx
            .vector(ContextVariable::External {
                name: "missing".into(),
            })
            .unwrap_err();
        assert!(matches!(err, OxiflowError::MissingCalculator(_)));
    }

    #[test]
    fn vector_returns_type_mismatch_for_non_vector() {
        let ctx = ctx_with_all_variants();
        let err = ctx
            .vector(ContextVariable::External {
                name: "coeff".into(),
            })
            .unwrap_err();
        assert!(matches!(
            err,
            OxiflowError::TypeMismatch {
                expected: "Vector",
                ..
            }
        ));
    }

    #[test]
    fn matrix_returns_reference_for_matrix_variable() {
        let ctx = ctx_with_all_variants();
        let m = ctx
            .matrix(ContextVariable::External {
                name: "tensor".into(),
            })
            .unwrap();
        assert_eq!(m.shape(), (2, 2));
        assert_eq!(m[(0, 0)], 5.0);
    }

    #[test]
    fn matrix_returns_missing_calculator_when_absent() {
        let ctx = ComputeContext::new(0.0, 0.01);
        let err = ctx
            .matrix(ContextVariable::External {
                name: "missing".into(),
            })
            .unwrap_err();
        assert!(matches!(err, OxiflowError::MissingCalculator(_)));
    }

    #[test]
    fn matrix_returns_type_mismatch_for_non_matrix() {
        let ctx = ctx_with_all_variants();
        let err = ctx
            .matrix(ContextVariable::External {
                name: "coeff".into(),
            })
            .unwrap_err();
        assert!(matches!(
            err,
            OxiflowError::TypeMismatch {
                expected: "Matrix",
                ..
            }
        ));
    }

    #[test]
    fn gradient_returns_full_nodal_field() {
        let ctx = ctx_with_all_variants();
        let g = ctx.gradient(0).unwrap();
        assert_eq!(g.len(), 3);
        assert!((g[0] - 0.1).abs() < 1e-12);
        assert!((g[2] - 0.3).abs() < 1e-12);
    }

    #[test]
    fn gradient_missing_dimension_returns_missing_calculator() {
        let ctx = ctx_with_all_variants();
        let err = ctx.gradient(1).unwrap_err();
        assert!(matches!(err, OxiflowError::MissingCalculator(_)));
    }

    #[test]
    fn gradient_with_wrong_type_returns_type_mismatch() {
        let mut ctx = ComputeContext::new(0.0, 0.01);
        ctx.insert(
            ContextVariable::SpatialGradient {
                dimension: 0,
                component: None,
            },
            ContextValue::Scalar(1.0),
        );
        let err = ctx.gradient(0).unwrap_err();
        assert!(matches!(
            err,
            OxiflowError::TypeMismatch {
                expected: "ScalarField",
                ..
            }
        ));
    }

    #[test]
    fn external_returns_raw_context_value() {
        let ctx = ctx_with_all_variants();
        let val = ctx
            .external(ContextVariable::External {
                name: "flag".into(),
            })
            .unwrap();
        assert!(matches!(val, ContextValue::Boolean(true)));
    }

    #[test]
    fn external_returns_missing_calculator_when_absent() {
        let ctx = ComputeContext::new(0.0, 0.01);
        let err = ctx
            .external(ContextVariable::External {
                name: "absent".into(),
            })
            .unwrap_err();
        assert!(matches!(err, OxiflowError::MissingCalculator(_)));
    }

    #[test]
    fn external_works_for_any_variant_type() {
        let ctx = ctx_with_all_variants();
        let val = ctx
            .external(ContextVariable::External {
                name: "vfield".into(),
            })
            .unwrap();
        assert!(val.is_vector_field());
    }

    #[test]
    fn try_get_returns_some_when_present() {
        let ctx = ctx_with_all_variants();
        let val = ctx.try_get(ContextVariable::SpatialGradient {
            dimension: 0,
            component: None,
        });
        // Single pattern test instead of `is_some()` followed by `unwrap()`.
        assert!(matches!(val, Some(v) if v.is_scalar_field()));
    }

    #[test]
    fn try_get_returns_none_when_absent() {
        let ctx = ComputeContext::new(0.0, 0.01);
        assert!(ctx.try_get(ContextVariable::Time).is_none());
    }

    #[test]
    fn insert_overwrites_previous_value() {
        let mut ctx = ComputeContext::new(0.0, 0.01);
        let var = ContextVariable::External { name: "x".into() };
        ctx.insert(var.clone(), ContextValue::Scalar(1.0));
        ctx.insert(var.clone(), ContextValue::Scalar(2.0));
        assert_eq!(ctx.scalar(var).unwrap(), 2.0);
    }
}