use super::ast::Dae;
use crate::ir::ast::{Component, Connection, Equation, Expression, Statement, TerminalType};
use crate::ir::transform::eval::{eval_boolean, eval_integer};
use indexmap::IndexMap;
use serde::{Deserialize, Serialize};
use std::collections::HashSet;
/// Classification of a DAE's equation count versus its unknown count.
///
/// Produced by `Dae::check_balance` (or `BalanceResult::compile_error`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum BalanceStatus {
    /// The number of equations equals the number of unknowns.
    Balanced,
    /// Fewer equations than unknowns, but the model has flow variables
    /// (counted as "external connectors"). Presumably the missing equations
    /// are expected to come from connections in an enclosing model — TODO confirm.
    Partial,
    /// Equations and unknowns differ and `Partial` does not apply: either
    /// over-determined, or under-determined with no flow variables.
    Unbalanced,
    /// Balance could not be computed because compilation failed; the payload is
    /// the error message (see `BalanceResult::compile_error`).
    CompileError(String),
}
/// Equation/unknown counts for a DAE together with the resulting [`BalanceStatus`].
///
/// Serialization is implemented by hand (it additionally emits a derived
/// `is_balanced` flag); deserialization uses the derive below.
#[derive(Debug, Clone, PartialEq, Deserialize)]
pub struct BalanceResult {
    /// Equations counted from `fx`, `fz` and the event statements in `fr`.
    pub num_equations: usize,
    /// `num_states + num_algebraic`.
    pub num_unknowns: usize,
    /// Scalar count of the DAE's `x` variables.
    pub num_states: usize,
    /// Scalar count of the DAE's `y`, `z` and `m` variables.
    pub num_algebraic: usize,
    /// Scalar count of `p` and `cp` (constants in `c` are not included).
    pub num_parameters: usize,
    /// Scalar count of the DAE's `u` variables.
    pub num_inputs: usize,
    /// Scalar count of flow-connection variables among `y`, `z` and `m`.
    pub num_external_connectors: usize,
    /// Outcome of comparing `num_equations` with `num_unknowns`.
    pub status: BalanceStatus,
    /// Compilation time in milliseconds. `check_balance` leaves this at 0;
    /// presumably the caller fills it in — confirm. Defaults to 0 when the
    /// field is missing during deserialization.
    #[serde(default)]
    pub compile_time_ms: u64,
}
impl Serialize for BalanceResult {
    /// Serializes all stored fields, plus an `is_balanced` flag computed from
    /// [`BalanceResult::is_balanced`] and inserted just before `compile_time_ms`.
    ///
    /// The extra flag is not a struct field. The derived `Deserialize` does not
    /// use `deny_unknown_fields`, so it skips the flag when reading the data back.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::SerializeStruct;

        // Seven counters, the status, the derived flag and the timing field.
        const FIELD_COUNT: usize = 10;

        let balanced = self.is_balanced();
        let mut out = serializer.serialize_struct("BalanceResult", FIELD_COUNT)?;
        out.serialize_field("num_equations", &self.num_equations)?;
        out.serialize_field("num_unknowns", &self.num_unknowns)?;
        out.serialize_field("num_states", &self.num_states)?;
        out.serialize_field("num_algebraic", &self.num_algebraic)?;
        out.serialize_field("num_parameters", &self.num_parameters)?;
        out.serialize_field("num_inputs", &self.num_inputs)?;
        out.serialize_field("num_external_connectors", &self.num_external_connectors)?;
        out.serialize_field("status", &self.status)?;
        out.serialize_field("is_balanced", &balanced)?;
        out.serialize_field("compile_time_ms", &self.compile_time_ms)?;
        out.end()
    }
}
impl BalanceResult {
    /// Creates a result representing a compilation failure.
    ///
    /// Every counter and `compile_time_ms` is zero; `message` is stored in
    /// [`BalanceStatus::CompileError`].
    pub fn compile_error(message: String) -> Self {
        let status = BalanceStatus::CompileError(message);
        Self {
            status,
            num_equations: 0,
            num_unknowns: 0,
            num_states: 0,
            num_algebraic: 0,
            num_parameters: 0,
            num_inputs: 0,
            num_external_connectors: 0,
            compile_time_ms: 0,
        }
    }

    /// Returns `true` only when the status is [`BalanceStatus::Balanced`].
    pub fn is_balanced(&self) -> bool {
        self.status == BalanceStatus::Balanced
    }

    /// Equations minus unknowns: positive means over-determined, negative
    /// means under-determined.
    pub fn difference(&self) -> i64 {
        let equations = self.num_equations as i64;
        let unknowns = self.num_unknowns as i64;
        equations - unknowns
    }

    /// One-line human-readable summary of the balance status.
    pub fn status_message(&self) -> String {
        match &self.status {
            BalanceStatus::Balanced => String::from("balanced"),
            BalanceStatus::Partial => {
                let shortfall = -self.difference();
                format!(
                    "partial (under by {}, {} external connectors)",
                    shortfall, self.num_external_connectors
                )
            }
            BalanceStatus::Unbalanced => match self.difference() {
                excess if excess > 0 => format!("unbalanced: over-determined by {}", excess),
                excess => format!("unbalanced: under-determined by {}", -excess),
            },
            BalanceStatus::CompileError(msg) => format!("compile error: {}", msg),
        }
    }
}
impl Dae {
    /// Counts the equations and unknowns of this DAE and classifies the result.
    ///
    /// * Unknowns: scalar sizes of `x` (states) plus `y`, `z` and `m` (algebraic).
    /// * Equations: those in `fx` and `fz`, plus one for each distinct non-state
    ///   variable assigned by a statement in `fr`.
    /// * `for` ranges and `if` conditions are evaluated statically against the
    ///   combined `p`, `cp` and `c` components.
    ///
    /// The status is `Balanced` when the counts match. It is `Partial` when
    /// equations are missing but flow variables exist among `y`/`z`/`m`.
    /// Otherwise it is `Unbalanced`. `compile_time_ms` is left at 0.
    pub fn check_balance(&self) -> BalanceResult {
        use std::cmp::Ordering;

        let num_states = count_scalars(&self.x);
        let num_algebraic =
            count_scalars(&self.y) + count_scalars(&self.z) + count_scalars(&self.m);
        let num_unknowns = num_states + num_algebraic;

        // Everything that may appear in a statically evaluable range/condition.
        let all_params: IndexMap<String, Component> = self
            .p
            .iter()
            .chain(self.cp.iter())
            .chain(self.c.iter())
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();

        let fx_count = count_equations(&self.fx, &all_params);
        let fz_count = count_equations(&self.fz, &all_params);
        let num_event_equations = count_unique_event_variables(&self.fr, &self.x);
        let num_equations = fx_count + fz_count + num_event_equations;

        let num_parameters = count_scalars(&self.p) + count_scalars(&self.cp);
        let num_inputs = count_scalars(&self.u);
        let num_external_connectors = count_external_connectors(&self.y)
            + count_external_connectors(&self.z)
            + count_external_connectors(&self.m);

        let status = match num_equations.cmp(&num_unknowns) {
            Ordering::Equal => BalanceStatus::Balanced,
            Ordering::Greater => BalanceStatus::Unbalanced,
            // Missing equations may be supplied through flow connections.
            Ordering::Less if num_external_connectors > 0 => BalanceStatus::Partial,
            Ordering::Less => BalanceStatus::Unbalanced,
        };

        BalanceResult {
            num_equations,
            num_unknowns,
            num_states,
            num_algebraic,
            num_parameters,
            num_inputs,
            num_external_connectors,
            status,
            compile_time_ms: 0,
        }
    }
}
/// Returns the total equation count of `equations`, i.e. the sum of
/// [`count_single_equation`] over every entry, evaluating ranges and
/// conditions against `params`.
fn count_equations(equations: &[Equation], params: &IndexMap<String, Component>) -> usize {
    let mut total = 0;
    for equation in equations {
        total += count_single_equation(equation, params);
    }
    total
}
/// Counts the equations contributed by a single equation node.
///
/// * Simple, connect and function-call equations count as one each; `Empty`
///   counts as zero.
/// * `for` loops multiply the body count by the statically evaluated number of
///   iterations. If the range cannot be evaluated, the body is counted once.
/// * `if` equations: branches whose condition is statically false are skipped.
///   A statically true branch ends the search, because later branches and
///   `else` are unreachable. If every branch is false, `else` is counted.
///   Otherwise the largest count among the branches that may still be taken
///   (plus `else`, if reachable) is used.
/// * `when`/`elsewhen` branches are mutually exclusive alternatives that must
///   assign the same set of variables (Modelica when-equation restrictions).
///   They therefore contribute the largest branch count, not the sum.
fn count_single_equation(eq: &Equation, params: &IndexMap<String, Component>) -> usize {
    match eq {
        Equation::Simple { .. } | Equation::Connect { .. } | Equation::FunctionCall { .. } => 1,
        Equation::Empty => 0,
        Equation::For { indices, equations } => {
            let range_size = evaluate_for_range(indices, params);
            let body = count_equations(equations, params);
            // Unknown iteration counts fall back to counting the body once.
            match range_size {
                Some(iterations) => body * iterations,
                None => body,
            }
        }
        Equation::If {
            cond_blocks,
            else_block,
        } => {
            // Equation counts of branches that may still be selected at runtime.
            let mut reachable: Vec<usize> = Vec::new();
            for block in cond_blocks {
                match eval_boolean(&block.cond, params) {
                    // Never taken: contributes nothing.
                    Some(false) => {}
                    // Taken unless an earlier undecided branch fires; nothing
                    // after it (including `else`) can be reached.
                    Some(true) => {
                        reachable.push(count_equations(&block.eqs, params));
                        return reachable.into_iter().max().unwrap_or(0);
                    }
                    // Undecided at compile time: may or may not be taken.
                    None => reachable.push(count_equations(&block.eqs, params)),
                }
            }
            // No branch is statically true, so `else` is reachable.
            let else_count = else_block
                .as_ref()
                .map(|eqs| count_equations(eqs, params))
                .unwrap_or(0);
            reachable.into_iter().max().unwrap_or(0).max(else_count)
        }
        Equation::When(blocks) => blocks
            .iter()
            .map(|block| count_equations(&block.eqs, params))
            .max()
            .unwrap_or(0),
    }
}
/// Returns the total number of iterations of a (possibly multi-index) `for`
/// loop, i.e. the product of the sizes of its index ranges.
///
/// An empty index list yields `Some(1)`. Returns `None` as soon as any index
/// range cannot be evaluated statically against `params`.
fn evaluate_for_range(
    indices: &[crate::ir::ast::ForIndex],
    params: &IndexMap<String, Component>,
) -> Option<usize> {
    indices.iter().try_fold(1usize, |iterations, index| {
        let size = evaluate_range_size(&index.range, params)?;
        Some(iterations * size)
    })
}
/// Statically computes how many values a `for`-index expression iterates over.
///
/// * `start:step:end` (step defaults to 1) yields the element count of the
///   integer range, which is 0 when the range is empty. A zero step gives `None`.
/// * An unsigned integer literal `n` yields `n`.
/// * A component reference is evaluated as an integer; values `<= 0` yield 0.
///
/// Returns `None` for anything that cannot be evaluated.
fn evaluate_range_size(expr: &Expression, params: &IndexMap<String, Component>) -> Option<usize> {
    match expr {
        Expression::Range { start, step, end } => {
            let first = eval_integer(start, params)?;
            let last = eval_integer(end, params)?;
            let stride = match step {
                Some(s) => eval_integer(s, params)?,
                None => 1,
            };
            if stride == 0 {
                return None;
            }
            // Normalize to an ascending walk from `low` to `high` by `magnitude`.
            let (low, high, magnitude) = if stride > 0 {
                (first, last, stride)
            } else {
                (last, first, -stride)
            };
            Some(if high >= low {
                ((high - low) / magnitude + 1) as usize
            } else {
                0
            })
        }
        Expression::Terminal {
            terminal_type: TerminalType::UnsignedInteger,
            token,
        } => token.text.parse::<usize>().ok(),
        Expression::ComponentReference(_) => {
            eval_integer(expr, params).map(|value| if value > 0 { value as usize } else { 0 })
        }
        _ => None,
    }
}
/// Total number of scalar entries across `components`.
///
/// Each component contributes the product of its `shape` dimensions. The
/// product over an empty shape is 1, so scalars count as one.
fn count_scalars(components: &IndexMap<String, Component>) -> usize {
    components
        .values()
        .map(|comp| comp.shape.iter().product::<usize>())
        .sum()
}
/// Number of scalar entries among `components` whose connection kind is
/// `Connection::Flow`.
///
/// Each matching component contributes the product of its `shape`
/// dimensions (1 for scalars, since the empty product is 1).
fn count_external_connectors(components: &IndexMap<String, Component>) -> usize {
    let mut total = 0usize;
    for comp in components.values() {
        if let Connection::Flow(_) = comp.connection {
            total += comp.shape.iter().product::<usize>();
        }
    }
    total
}
/// Counts the distinct variables assigned by `Statement::Assignment` entries
/// in `fr`, ignoring any variable whose name is a key of `states`.
///
/// Variables are identified by their `to_string()` form. Repeated assignments
/// to the same variable are counted once.
fn count_unique_event_variables(
    fr: &IndexMap<String, Statement>,
    states: &IndexMap<String, Component>,
) -> usize {
    let assigned: HashSet<String> = fr
        .values()
        .filter_map(|stmt| match stmt {
            Statement::Assignment { comp, .. } => Some(comp.to_string()),
            _ => None,
        })
        .filter(|name| !states.contains_key(name))
        .collect();
    assigned.len()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `BalanceResult` with no parameters, no inputs and zero timing.
    fn result(
        num_equations: usize,
        num_unknowns: usize,
        num_states: usize,
        num_algebraic: usize,
        num_external_connectors: usize,
        status: BalanceStatus,
    ) -> BalanceResult {
        BalanceResult {
            num_equations,
            num_unknowns,
            num_states,
            num_algebraic,
            num_parameters: 0,
            num_inputs: 0,
            num_external_connectors,
            status,
            compile_time_ms: 0,
        }
    }

    #[test]
    fn test_balance_result_messages() {
        // Exact comparisons: `contains("balanced")` would also accept "unbalanced: ...".
        let balanced = result(3, 3, 2, 1, 0, BalanceStatus::Balanced);
        assert_eq!(balanced.status_message(), "balanced");
        assert_eq!(balanced.difference(), 0);
        assert_eq!(balanced.status, BalanceStatus::Balanced);
        assert!(balanced.is_balanced());

        let over = result(5, 3, 2, 1, 0, BalanceStatus::Unbalanced);
        assert_eq!(over.status_message(), "unbalanced: over-determined by 2");
        assert_eq!(over.difference(), 2);
        assert_eq!(over.status, BalanceStatus::Unbalanced);
        assert!(!over.is_balanced());

        let under_bug = result(3, 5, 3, 2, 0, BalanceStatus::Unbalanced);
        assert_eq!(under_bug.status_message(), "unbalanced: under-determined by 2");
        assert_eq!(under_bug.difference(), -2);
        assert_eq!(under_bug.status, BalanceStatus::Unbalanced);
        assert!(!under_bug.is_balanced());

        let partial = result(3, 5, 2, 3, 2, BalanceStatus::Partial);
        assert_eq!(
            partial.status_message(),
            "partial (under by 2, 2 external connectors)"
        );
        assert_eq!(partial.difference(), -2);
        assert_eq!(partial.status, BalanceStatus::Partial);
        assert!(!partial.is_balanced());
    }

    #[test]
    fn test_compile_error_result() {
        let failed = BalanceResult::compile_error("boom".to_string());
        assert_eq!(failed.status, BalanceStatus::CompileError("boom".to_string()));
        assert_eq!(failed.status_message(), "compile error: boom");
        assert_eq!(failed.difference(), 0);
        assert!(!failed.is_balanced());
        assert_eq!(failed.num_equations, 0);
        assert_eq!(failed.num_unknowns, 0);
        assert_eq!(failed.num_external_connectors, 0);
        assert_eq!(failed.compile_time_ms, 0);
    }
}