use crate::context::calculator::ContextCalculator;
use crate::context::error::OxiflowError;
use crate::context::variable::ContextVariable;
/// Context variables that never need a registered calculator: `build_calculator_chain`
/// skips them when checking requirements, and `build_kahn_chain` draws no dependency
/// edge for them. (Presumably the solver supplies these values itself — confirm.)
const BUILTIN_VARIABLES: &[ContextVariable] = &[ContextVariable::Time, ContextVariable::TimeStep];
/// Checks that every non-built-in requirement is provided by at least one calculator,
/// then returns *all* calculators (not only the ones covering `requirements`) in
/// execution order.
///
/// If no calculator declares any `depends_on()` entry, the chain is a stable sort by
/// ascending `priority()`, so registration order breaks ties. Otherwise the chain is
/// ordered topologically, with priority used as the tie-breaker among ready calculators.
///
/// # Errors
///
/// * [`OxiflowError::MissingCalculator`] carrying the first requirement (in the order
///   given) that is neither built-in nor provided by any calculator.
/// * [`OxiflowError::CircularDependency`] when the declared dependencies form a cycle.
pub fn build_calculator_chain<'a>(
    requirements: &[ContextVariable],
    calculators: &'a [Box<dyn ContextCalculator>],
) -> Result<Vec<&'a dyn ContextCalculator>, OxiflowError> {
    let uncovered = requirements
        .iter()
        .filter(|req| !is_builtin(req))
        .find(|req| calculators.iter().all(|calc| calc.provides() != **req));
    if let Some(missing) = uncovered {
        return Err(OxiflowError::MissingCalculator(missing.clone()));
    }

    // Only pay for the topological sort when at least one dependency is declared.
    let any_dependencies = calculators
        .iter()
        .any(|calc| !calc.depends_on().is_empty());
    if any_dependencies {
        build_kahn_chain(calculators)
    } else {
        build_priority_chain(calculators)
    }
}
/// Returns every calculator ordered by ascending `priority()`.
///
/// The sort is stable, so calculators that share a priority keep their registration
/// order. This path never fails; it returns `Result` to match `build_kahn_chain`.
fn build_priority_chain(
    calculators: &[Box<dyn ContextCalculator>],
) -> Result<Vec<&dyn ContextCalculator>, OxiflowError> {
    let mut ordered: Vec<&dyn ContextCalculator> = Vec::with_capacity(calculators.len());
    for boxed in calculators {
        ordered.push(boxed.as_ref());
    }
    // `sort_by` is stable, which is what preserves registration order on ties.
    ordered.sort_by(|lhs, rhs| lhs.priority().cmp(&rhs.priority()));
    Ok(ordered)
}
fn build_kahn_chain(
calculators: &[Box<dyn ContextCalculator>],
) -> Result<Vec<&dyn ContextCalculator>, OxiflowError> {
let n = calculators.len();
let mut successors: Vec<Vec<usize>> = vec![vec![]; n];
let mut in_degree: Vec<usize> = vec![0; n];
for (i, calc) in calculators.iter().enumerate() {
for dep_var in calc.depends_on() {
if is_builtin(&dep_var) {
continue;
}
for (j, provider) in calculators.iter().enumerate() {
if provider.provides() == dep_var {
successors[j].push(i);
in_degree[i] += 1;
}
}
}
}
let mut queue: Vec<usize> = (0..n).filter(|&i| in_degree[i] == 0).collect();
queue.sort_by_key(|&i| calculators[i].priority());
let mut result: Vec<&dyn ContextCalculator> = Vec::with_capacity(n);
while !queue.is_empty() {
let i = queue.remove(0);
result.push(calculators[i].as_ref());
let mut newly_free: Vec<usize> = Vec::new();
for &j in &successors[i] {
in_degree[j] -= 1;
if in_degree[j] == 0 {
newly_free.push(j);
}
}
queue.extend(newly_free);
queue.sort_by_key(|&i| calculators[i].priority());
}
if result.len() < n {
let blocked = (0..n).find(|&i| in_degree[i] > 0).unwrap();
let var = calculators[blocked]
.depends_on()
.into_iter()
.find(|v| !is_builtin(v))
.unwrap_or_else(|| calculators[blocked].provides());
return Err(OxiflowError::CircularDependency(var));
}
Ok(result)
}
/// Reports whether `var` appears in `BUILTIN_VARIABLES`. Such variables never need a
/// calculator and never create dependency edges.
fn is_builtin(var: &ContextVariable) -> bool {
    BUILTIN_VARIABLES.iter().any(|builtin| builtin == var)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::context::compute::ComputeContext;
    use crate::context::value::ContextValue;
    use crate::model::traits::RequiresContext;

    /// Test calculator described only by what it provides, its priority and its
    /// declared dependencies. `compute` just returns `ctx.time()` as a scalar.
    #[derive(Debug)]
    struct NamedCalc {
        provides: ContextVariable,
        priority: u32,
        depends_on: Vec<ContextVariable>,
    }

    impl RequiresContext for NamedCalc {
        fn required_variables(&self) -> Vec<ContextVariable> {
            vec![]
        }

        fn priority(&self) -> u32 {
            self.priority
        }

        fn depends_on(&self) -> Vec<ContextVariable> {
            self.depends_on.clone()
        }
    }

    impl ContextCalculator for NamedCalc {
        fn provides(&self) -> ContextVariable {
            self.provides.clone()
        }

        fn compute(
            &self,
            _state: &ContextValue,
            ctx: &ComputeContext,
        ) -> Result<ContextValue, OxiflowError> {
            Ok(ContextValue::Scalar(ctx.time()))
        }
    }

    /// Builds a calculator with no declared dependencies.
    fn make_calc(provides: ContextVariable, priority: u32) -> Box<dyn ContextCalculator> {
        Box::new(NamedCalc {
            provides,
            priority,
            depends_on: vec![],
        })
    }

    /// Builds a calculator that declares the given dependencies.
    fn make_deps_calc(
        provides: ContextVariable,
        priority: u32,
        depends_on: Vec<ContextVariable>,
    ) -> Box<dyn ContextCalculator> {
        Box::new(NamedCalc {
            provides,
            priority,
            depends_on,
        })
    }

    /// Shorthand for an external context variable with the given name.
    fn var(name: &'static str) -> ContextVariable {
        ContextVariable::External { name: name.into() }
    }

    #[test]
    fn empty_requirements_with_no_calculators_succeeds() {
        let chain = build_calculator_chain(&[], &[]).unwrap();
        assert!(chain.is_empty());
    }

    #[test]
    fn builtin_time_requires_no_calculator() {
        let requirements = vec![ContextVariable::Time, ContextVariable::TimeStep];
        let chain = build_calculator_chain(&requirements, &[]).unwrap();
        assert!(chain.is_empty());
    }

    #[test]
    fn satisfied_requirement_succeeds() {
        let requirements = vec![var("D_ax")];
        let calcs = vec![make_calc(var("D_ax"), 100)];
        let chain = build_calculator_chain(&requirements, &calcs).unwrap();
        assert_eq!(chain.len(), 1);
    }

    #[test]
    fn missing_calculator_returns_error() {
        let requirements = vec![var("missing")];
        let err = build_calculator_chain(&requirements, &[]).unwrap_err();
        assert!(matches!(err, OxiflowError::MissingCalculator(_)));
    }

    #[test]
    fn missing_calculator_error_names_the_variable() {
        let v = ContextVariable::SpatialGradient {
            dimension: 0,
            component: None,
        };
        let requirements = vec![v.clone()];
        let err = build_calculator_chain(&requirements, &[]).unwrap_err();
        assert!(matches!(err, OxiflowError::MissingCalculator(x) if x == v));
    }

    #[test]
    fn duplicate_calculator_for_same_variable_is_accepted() {
        let requirements = vec![var("v")];
        let calcs = vec![make_calc(var("v"), 100), make_calc(var("v"), 50)];
        assert!(build_calculator_chain(&requirements, &calcs).is_ok());
    }

    #[test]
    fn extra_calculators_beyond_requirements_are_included() {
        let requirements = vec![var("a")];
        let calcs = vec![make_calc(var("a"), 100), make_calc(var("b"), 100)];
        let chain = build_calculator_chain(&requirements, &calcs).unwrap();
        assert_eq!(chain.len(), 2);
    }

    #[test]
    fn chain_sorted_by_ascending_priority() {
        let calcs = vec![
            make_calc(var("c"), 200),
            make_calc(var("a"), 50),
            make_calc(var("b"), 100),
        ];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        assert_eq!(chain[0].priority(), 50);
        assert_eq!(chain[1].priority(), 100);
        assert_eq!(chain[2].priority(), 200);
    }

    #[test]
    fn stable_sort_preserves_registration_order_within_same_priority() {
        let calcs = vec![make_calc(var("first"), 100), make_calc(var("second"), 100)];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        assert_eq!(chain[0].provides(), var("first"));
        assert_eq!(chain[1].provides(), var("second"));
    }

    #[test]
    fn mixed_builtin_and_user_requirements() {
        let requirements = vec![ContextVariable::Time, var("D_ax")];
        let calcs = vec![make_calc(var("D_ax"), 100)];
        let chain = build_calculator_chain(&requirements, &calcs).unwrap();
        assert_eq!(chain.len(), 1);
    }

    #[test]
    fn kahn_simple_chain() {
        let calcs = vec![
            // Y depends on X, Z depends on Y; registered in "wrong" order on purpose.
            make_deps_calc(var("Y"), 100, vec![var("X")]),
            make_deps_calc(var("Z"), 100, vec![var("Y")]),
            make_calc(var("X"), 100),
        ];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        assert_eq!(chain[0].provides(), var("X"));
        assert_eq!(chain[1].provides(), var("Y"));
        assert_eq!(chain[2].provides(), var("Z"));
    }

    #[test]
    fn kahn_diamond() {
        let calcs = vec![
            // W depends on Y and Z, which both depend on X.
            make_deps_calc(var("W"), 100, vec![var("Y"), var("Z")]),
            make_deps_calc(var("Y"), 100, vec![var("X")]),
            make_deps_calc(var("Z"), 100, vec![var("X")]),
            make_calc(var("X"), 100),
        ];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        assert_eq!(chain[0].provides(), var("X"));
        assert_eq!(chain[3].provides(), var("W"));
        let middle: Vec<ContextVariable> = chain[1..3].iter().map(|c| c.provides()).collect();
        assert!(middle.contains(&var("Y")));
        assert!(middle.contains(&var("Z")));
    }

    #[test]
    fn kahn_priority_tiebreaker_within_tier() {
        let calcs = vec![
            make_deps_calc(var("B_out"), 200, vec![var("X")]),
            make_deps_calc(var("C_out"), 50, vec![var("X")]),
            make_calc(var("X"), 10),
        ];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        assert_eq!(chain[0].provides(), var("X"));
        assert_eq!(chain[1].provides(), var("C_out"));
        assert_eq!(chain[2].provides(), var("B_out"));
    }

    #[test]
    fn kahn_equal_priority_follows_readiness_order() {
        // "other" is ready from the start, "late" only once X has run, so "other"
        // precedes "late" even though "late" was registered first.
        let calcs = vec![
            make_deps_calc(var("late"), 100, vec![var("X")]),
            make_calc(var("X"), 10),
            make_calc(var("other"), 100),
        ];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        assert_eq!(chain[0].provides(), var("X"));
        assert_eq!(chain[1].provides(), var("other"));
        assert_eq!(chain[2].provides(), var("late"));
    }

    #[test]
    fn kahn_multiple_providers_all_precede_dependent() {
        let calcs = vec![
            // Z depends on X, and two calculators provide X.
            make_deps_calc(var("Z"), 100, vec![var("X")]),
            make_calc(var("X"), 60),
            make_calc(var("X"), 50),
        ];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        let z_pos = chain.iter().position(|c| c.provides() == var("Z")).unwrap();
        let x50_pos = chain
            .iter()
            .position(|c| c.priority() == 50 && c.provides() == var("X"))
            .unwrap();
        let x60_pos = chain
            .iter()
            .position(|c| c.priority() == 60 && c.provides() == var("X"))
            .unwrap();
        assert!(x50_pos < z_pos);
        assert!(x60_pos < z_pos);
    }

    #[test]
    fn kahn_builtin_in_depends_on_is_ignored() {
        let calcs = vec![make_deps_calc(
            var("flux"),
            100,
            vec![ContextVariable::Time],
        )];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        assert_eq!(chain.len(), 1);
        assert_eq!(chain[0].provides(), var("flux"));
    }

    #[test]
    fn kahn_builtin_only_dependencies_still_order_by_priority() {
        let calcs = vec![
            make_deps_calc(var("late"), 200, vec![ContextVariable::TimeStep]),
            make_calc(var("early"), 100),
        ];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        assert_eq!(chain[0].provides(), var("early"));
        assert_eq!(chain[1].provides(), var("late"));
    }

    #[test]
    fn kahn_repeated_dependency_entry_is_not_a_cycle() {
        let calcs = vec![
            make_deps_calc(var("Y"), 100, vec![var("X"), var("X")]),
            make_calc(var("X"), 100),
        ];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        assert_eq!(chain.len(), 2);
        assert_eq!(chain[0].provides(), var("X"));
        assert_eq!(chain[1].provides(), var("Y"));
    }

    #[test]
    fn kahn_cycle_two_nodes_returns_error() {
        let calcs = vec![
            make_deps_calc(var("A_out"), 100, vec![var("B_out")]),
            make_deps_calc(var("B_out"), 100, vec![var("A_out")]),
        ];
        let err = build_calculator_chain(&[], &calcs).unwrap_err();
        assert!(matches!(err, OxiflowError::CircularDependency(_)));
    }

    #[test]
    fn kahn_cycle_three_nodes_returns_error() {
        let calcs = vec![
            make_deps_calc(var("A_out"), 100, vec![var("C_out")]),
            make_deps_calc(var("B_out"), 100, vec![var("A_out")]),
            make_deps_calc(var("C_out"), 100, vec![var("B_out")]),
        ];
        let err = build_calculator_chain(&[], &calcs).unwrap_err();
        assert!(matches!(err, OxiflowError::CircularDependency(_)));
    }

    #[test]
    fn kahn_self_dependency_returns_error_naming_variable() {
        let calcs = vec![make_deps_calc(var("A_out"), 100, vec![var("A_out")])];
        let err = build_calculator_chain(&[], &calcs).unwrap_err();
        assert!(matches!(err, OxiflowError::CircularDependency(x) if x == var("A_out")));
    }

    #[test]
    fn kahn_mixed_with_and_without_deps() {
        let calcs = vec![
            make_calc(var("alpha"), 200),
            make_deps_calc(var("beta"), 100, vec![var("alpha")]),
            make_calc(var("gamma"), 50),
        ];
        let chain = build_calculator_chain(&[], &calcs).unwrap();
        let alpha_pos = chain
            .iter()
            .position(|c| c.provides() == var("alpha"))
            .unwrap();
        let beta_pos = chain
            .iter()
            .position(|c| c.provides() == var("beta"))
            .unwrap();
        assert!(alpha_pos < beta_pos);
        assert_eq!(chain[0].provides(), var("gamma"));
    }
}