use crate::context::compute::ComputeContext;
use crate::context::error::OxiflowError;
use crate::context::value::ContextValue;
use crate::context::variable::ContextVariable;
use crate::model::traits::RequiresContext;
/// A pluggable producer of exactly one context variable.
///
/// An implementor names the [`ContextVariable`] it fills in via
/// [`provides`](ContextCalculator::provides) and evaluates it in
/// [`compute`](ContextCalculator::compute). The variables it needs are declared
/// through the [`RequiresContext`] supertrait.
///
/// The `Send + Sync + Debug` supertraits make calculators usable as
/// `Box<dyn ContextCalculator>` values that can be shared between threads and
/// printed for diagnostics.
pub trait ContextCalculator: RequiresContext + Send + Sync + std::fmt::Debug {
    /// Returns the context variable this calculator produces.
    fn provides(&self) -> ContextVariable;

    /// Evaluates the variable named by [`provides`](ContextCalculator::provides).
    ///
    /// * `state` – a value handed in by the caller (its meaning is defined by the
    ///   caller; NOTE(review): presumably the current model state — confirm).
    /// * `ctx` – the [`ComputeContext`] the value is computed against.
    ///
    /// # Errors
    ///
    /// Implementations report any failure to produce the value as an
    /// [`OxiflowError`].
    fn compute(
        &self,
        state: &ContextValue,
        ctx: &ComputeContext,
    ) -> Result<ContextValue, OxiflowError>;

    /// Human-readable label for this calculator.
    ///
    /// Implementors that do not override it are reported as
    /// `"unnamed calculator"`.
    fn name(&self) -> &str {
        "unnamed calculator"
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Test double that republishes the value of `ComputeContext::time()` as
    /// `ContextVariable::Time`.
    #[derive(Debug)]
    struct TimeCalculator;

    impl RequiresContext for TimeCalculator {
        fn required_variables(&self) -> Vec<ContextVariable> {
            vec![]
        }

        fn priority(&self) -> u32 {
            0
        }
    }

    impl ContextCalculator for TimeCalculator {
        fn provides(&self) -> ContextVariable {
            ContextVariable::Time
        }

        fn compute(
            &self,
            _state: &ContextValue,
            ctx: &ComputeContext,
        ) -> Result<ContextValue, OxiflowError> {
            Ok(ContextValue::Scalar(ctx.time()))
        }

        fn name(&self) -> &str {
            "time"
        }
    }

    /// Test double that always yields `value` for `var`. It deliberately
    /// relies on the trait's default `name()` and `RequiresContext::priority()`.
    #[derive(Debug)]
    struct ConstantCalculator {
        var: ContextVariable,
        value: f64,
    }

    impl RequiresContext for ConstantCalculator {
        fn required_variables(&self) -> Vec<ContextVariable> {
            vec![]
        }
    }

    impl ContextCalculator for ConstantCalculator {
        fn provides(&self) -> ContextVariable {
            self.var.clone()
        }

        fn compute(
            &self,
            _state: &ContextValue,
            _ctx: &ComputeContext,
        ) -> Result<ContextValue, OxiflowError> {
            Ok(ContextValue::Scalar(self.value))
        }
    }

    /// Test double that declares a dependency on `Time` and publishes twice
    /// the value of `ctx.time()` as the external variable `"double_time"`.
    #[derive(Debug)]
    struct TimeDependentCalculator;

    impl RequiresContext for TimeDependentCalculator {
        fn required_variables(&self) -> Vec<ContextVariable> {
            vec![ContextVariable::Time]
        }

        fn depends_on(&self) -> Vec<ContextVariable> {
            vec![ContextVariable::Time]
        }

        fn priority(&self) -> u32 {
            50
        }
    }

    impl ContextCalculator for TimeDependentCalculator {
        fn provides(&self) -> ContextVariable {
            ContextVariable::External {
                name: "double_time".into(),
            }
        }

        fn compute(
            &self,
            _state: &ContextValue,
            ctx: &ComputeContext,
        ) -> Result<ContextValue, OxiflowError> {
            let t = ctx.time();
            Ok(ContextValue::Scalar(t * 2.0))
        }
    }

    #[test]
    fn provides_returns_declared_variable() {
        assert_eq!(TimeCalculator.provides(), ContextVariable::Time);
    }

    #[test]
    fn provides_is_stable_across_calls() {
        let calc = ConstantCalculator {
            var: ContextVariable::TimeStep,
            value: 0.01,
        };
        assert_eq!(calc.provides(), calc.provides());
    }

    #[test]
    fn time_calculator_returns_current_time() {
        // 2.5 is exactly representable in f64, and unlike 3.14 it does not
        // trigger clippy's deny-by-default `approx_constant` lint, which
        // rejects literals that approximate PI.
        let ctx = ComputeContext::new(2.5, 0.01);
        let result = TimeCalculator
            .compute(&ContextValue::Scalar(0.0), &ctx)
            .unwrap();
        assert_eq!(result.as_scalar().unwrap(), 2.5);
    }

    #[test]
    fn constant_calculator_returns_fixed_value() {
        let calc = ConstantCalculator {
            var: ContextVariable::External {
                name: "D_ax".into(),
            },
            value: 1.5e-4,
        };
        let ctx = ComputeContext::new(0.0, 0.01);
        let result = calc.compute(&ContextValue::Scalar(0.0), &ctx).unwrap();
        assert!((result.as_scalar().unwrap() - 1.5e-4).abs() < 1e-12);
    }

    #[test]
    fn time_dependent_calculator_reads_from_ctx() {
        let mut ctx = ComputeContext::new(5.0, 0.01);
        // NOTE(review): `compute` reads `ctx.time()`, and the constructor was
        // already given 5.0, so this insert may be redundant. Confirm whether
        // `time()` is backed by the inserted `Time` entry before removing it.
        ctx.insert(ContextVariable::Time, ContextValue::Scalar(5.0));
        let result = TimeDependentCalculator
            .compute(&ContextValue::Scalar(0.0), &ctx)
            .unwrap();
        assert_eq!(result.as_scalar().unwrap(), 10.0);
    }

    #[test]
    fn name_returns_provided_string() {
        assert_eq!(TimeCalculator.name(), "time");
    }

    #[test]
    fn default_name_is_unnamed() {
        let calc = ConstantCalculator {
            var: ContextVariable::TimeStep,
            value: 0.01,
        };
        assert_eq!(calc.name(), "unnamed calculator");
    }

    #[test]
    fn time_calculator_has_no_requirements() {
        assert!(TimeCalculator.required_variables().is_empty());
        assert_eq!(TimeCalculator.priority(), 0);
    }

    #[test]
    fn time_dependent_requires_time() {
        let calc = TimeDependentCalculator;
        assert!(calc.required_variables().contains(&ContextVariable::Time));
        assert!(calc.depends_on().contains(&ContextVariable::Time));
        assert_eq!(calc.priority(), 50);
    }

    #[test]
    fn trait_is_object_safe() {
        let calcs: Vec<Box<dyn ContextCalculator>> = vec![
            Box::new(TimeCalculator),
            Box::new(ConstantCalculator {
                var: ContextVariable::TimeStep,
                value: 0.01,
            }),
            Box::new(TimeDependentCalculator),
        ];
        assert_eq!(calcs[0].provides(), ContextVariable::Time);
        assert_eq!(calcs[1].provides(), ContextVariable::TimeStep);
        assert_eq!(
            calcs[2].provides(),
            ContextVariable::External {
                name: "double_time".into(),
            }
        );
        // Default and overridden `name()` must both dispatch through `dyn`.
        assert_eq!(calcs[0].name(), "time");
        assert_eq!(calcs[1].name(), "unnamed calculator");
    }

    #[test]
    fn trait_objects_are_send_and_sync() {
        // Compile-time check: the `Send + Sync` supertraits must carry over to
        // `dyn ContextCalculator`, so boxed calculators can cross threads.
        fn assert_send_sync<T: Send + Sync + ?Sized>() {}
        assert_send_sync::<dyn ContextCalculator>();
    }
}