use crate::ir::ast::{
ClassDefinition, Component, ComponentRefPart, ComponentReference, Equation, Expression,
ForIndex, Statement, Subscript, TerminalType, Token,
};
use crate::ir::transform::eval::{eval_boolean, eval_integer, eval_real};
use crate::ir::visitor::{Visitable, Visitor};
use indexmap::IndexMap;
use std::collections::HashSet;
struct DerVarFinder {
states: HashSet<String>,
}
impl DerVarFinder {
fn new() -> Self {
Self {
states: HashSet::new(),
}
}
fn into_states(self) -> HashSet<String> {
self.states
}
}
impl Visitor for DerVarFinder {
fn enter_expression(&mut self, node: &Expression) {
let Expression::FunctionCall { comp, args } = node else {
return;
};
if comp.parts.first().is_none_or(|p| p.ident.text != "der") {
return;
}
if let Some(name) = args.first().and_then(|arg| {
if let Expression::ComponentReference(comp_ref) = arg {
comp_ref.parts.first().map(|p| p.ident.text.clone())
} else {
None
}
}) {
self.states.insert(name);
}
}
}
struct AssignedVarFinder {
assigned: HashSet<String>,
}
impl AssignedVarFinder {
fn new() -> Self {
Self {
assigned: HashSet::new(),
}
}
fn into_assigned(self) -> HashSet<String> {
self.assigned
}
}
impl Visitor for AssignedVarFinder {
fn enter_statement(&mut self, node: &Statement) {
match node {
Statement::Assignment { comp, .. } => {
if let Some(first_part) = comp.parts.first() {
self.assigned.insert(first_part.ident.text.clone());
}
}
Statement::FunctionCall { outputs, .. } => {
for name in outputs.iter().filter_map(|o| {
if let Expression::ComponentReference(comp_ref) = o {
comp_ref.parts.first().map(|p| p.ident.text.clone())
} else {
None
}
}) {
self.assigned.insert(name);
}
}
_ => {}
}
}
}
/// Rewrites the equations of `class` in place into an expanded, mostly scalar form.
///
/// Steps, in order:
/// 1. Components with an empty start value that are assigned by a simple initial equation
///    `name = <expr>` receive a literal start value when `<expr>` evaluates to a number.
/// 2. Array dimensions given as expressions are evaluated where possible.
/// 3. Every regular equation is expanded (arrays, `for`, `if`, nested blocks).
/// 4. One placeholder equation is appended per variable assigned in an algorithm section.
/// 5. Equations for declaration bindings not already covered are appended.
/// 6. The initial equations are expanded the same way as in step 3.
pub fn expand_equations(class: &mut ClassDefinition) {
    // Sizes and parameter values must be known before anything is expanded, because every
    // later step looks them up through `class.components`.
    evaluate_computed_parameters(&mut class.components, &class.initial_equations);
    evaluate_array_shapes(&mut class.components);

    let components = &class.components;

    let mut equations = Vec::new();
    class
        .equations
        .iter()
        .for_each(|eq| expand_equation(eq, components, &mut equations));
    equations.extend(convert_algorithms_to_equations(&class.algorithms, components));
    // Binding extraction inspects the equations gathered so far (to find `der` targets).
    let bindings = extract_binding_equations(components, &equations);
    equations.extend(bindings);

    let mut initial = Vec::new();
    class
        .initial_equations
        .iter()
        .for_each(|eq| expand_equation(eq, components, &mut initial));

    class.equations = equations;
    class.initial_equations = initial;
}
/// Gives start values to components that are defined by a simple initial equation.
///
/// For every initial equation of the form `name = <expr>` where `name` is an unsubscripted
/// reference to a component whose `start` is `Expression::Empty`, the right-hand side is
/// evaluated first as an integer, then as a real. On success the component's `start` is
/// replaced by the corresponding literal. Dotted references are looked up by joining their
/// parts with `.`.
///
/// Values written in one pass are visible to the evaluators in the next pass. Assuming
/// `eval_integer`/`eval_real` resolve references through components' start values, chains
/// of dependent equations therefore resolve over several passes, capped at 10.
fn evaluate_computed_parameters(
    components: &mut IndexMap<String, Component>,
    initial_equations: &[Equation],
) {
    const MAX_ITERATIONS: usize = 10;

    for _pass in 0..MAX_ITERATIONS {
        let known: &IndexMap<String, Component> = components;
        let updates: Vec<(String, Expression)> = initial_equations
            .iter()
            .filter_map(|eq| {
                let Equation::Simple {
                    lhs: Expression::ComponentReference(target),
                    rhs,
                } = eq
                else {
                    return None;
                };
                if target.parts.iter().any(|part| part.subs.is_some()) {
                    return None;
                }
                let name = target
                    .parts
                    .iter()
                    .map(|part| part.ident.text.as_str())
                    .collect::<Vec<_>>()
                    .join(".");
                let still_unset = known
                    .get(&name)
                    .is_some_and(|comp| matches!(comp.start, Expression::Empty));
                if !still_unset {
                    return None;
                }
                let (terminal_type, text) = match eval_integer(rhs, known) {
                    Some(int) => (TerminalType::UnsignedInteger, int.to_string()),
                    None => (TerminalType::UnsignedReal, eval_real(rhs, known)?.to_string()),
                };
                Some((
                    name,
                    Expression::Terminal {
                        terminal_type,
                        token: Token {
                            text,
                            ..Default::default()
                        },
                    },
                ))
            })
            .collect();

        // Nothing new could be evaluated: further passes would change nothing.
        if updates.is_empty() {
            break;
        }
        for (name, value) in updates {
            if let Some(comp) = components.get_mut(&name) {
                comp.start = value;
            }
        }
    }
}
/// Fills in `shape` for components whose dimensions are only known as expressions.
///
/// A component is considered when `shape` is empty but `shape_expr` is not. Its shape is set
/// only if every subscript is an expression that evaluates to a non-negative integer. A
/// `Subscript::Range` or `Subscript::Empty`, a failed evaluation or a negative value leaves it
/// untouched. Shapes resolved in one pass can help evaluate others in the next, capped at
/// 10 passes.
fn evaluate_array_shapes(components: &mut IndexMap<String, Component>) {
    const MAX_ITERATIONS: usize = 10;

    for _pass in 0..MAX_ITERATIONS {
        let mut resolved: Vec<(String, Vec<usize>)> = Vec::new();
        for (name, comp) in components.iter() {
            if !comp.shape.is_empty() || comp.shape_expr.is_empty() {
                continue;
            }
            // `collect` into `Option` stops at the first dimension that cannot be evaluated.
            let dims: Option<Vec<usize>> = comp
                .shape_expr
                .iter()
                .map(|sub| match sub {
                    Subscript::Expression(expr) => eval_integer(expr, components)
                        .filter(|extent| *extent >= 0)
                        .map(|extent| extent as usize),
                    Subscript::Range { .. } | Subscript::Empty => None,
                })
                .collect();
            if let Some(dims) = dims {
                resolved.push((name.clone(), dims));
            }
        }

        if resolved.is_empty() {
            break;
        }
        for (name, dims) in resolved {
            if let Some(comp) = components.get_mut(&name) {
                comp.shape = dims;
            }
        }
    }
}
/// Builds `name = start` equations for components whose start expression should act as a
/// binding equation.
///
/// A component is skipped when any of the following holds:
/// - `start_is_modification` is set (presumably a `start = …` modifier rather than a
///   declaration binding — confirm);
/// - `start` is empty, or is a default-value literal (see `is_default_value`);
/// - it is a parameter or constant, or an input;
/// - it appears as the argument of `der(...)` in `equations`.
///
/// Scalars yield one equation. Arrays are expanded element-wise by `expand_array_binding`.
/// Equations follow the iteration order of `components`.
fn extract_binding_equations(
    components: &IndexMap<String, Component>,
    equations: &[Equation],
) -> Vec<Equation> {
    let states = find_differentiated_variables(equations);
    let is_skipped = |name: &String, comp: &Component| {
        comp.start_is_modification
            || matches!(comp.start, Expression::Empty)
            || is_default_value(&comp.start)
            || matches!(
                comp.variability,
                crate::ir::ast::Variability::Parameter(_)
                    | crate::ir::ast::Variability::Constant(_)
            )
            || matches!(comp.causality, crate::ir::ast::Causality::Input(..))
            || states.contains(name)
    };

    let mut bindings = Vec::new();
    for (name, comp) in components {
        if is_skipped(name, comp) {
            continue;
        }
        if !comp.shape.is_empty() {
            expand_array_binding(name, &comp.shape, &comp.start, &mut bindings);
            continue;
        }
        let target = ComponentReference {
            local: false,
            parts: vec![ComponentRefPart {
                ident: Token {
                    text: name.clone(),
                    ..Default::default()
                },
                subs: None,
            }],
        };
        bindings.push(Equation::Simple {
            lhs: Expression::ComponentReference(target),
            rhs: comp.start.clone(),
        });
    }
    bindings
}
/// Emits one placeholder equation per variable assigned in each algorithm section.
///
/// For every variable written by a section (see `find_assigned_variables`), except
/// components declared as inputs, an identity equation `v = v` is produced. NOTE(review):
/// presumably this only lets equation/variable counting see algorithm outputs as
/// determined — confirm with the consumers of these equations.
///
/// Within a section, equations are ordered by the variable's declaration position in
/// `components`; names that are not components come last, ordered alphabetically. This
/// keeps the output deterministic even though the assigned names arrive in a `HashSet`.
fn convert_algorithms_to_equations(
    algorithms: &[Vec<Statement>],
    components: &IndexMap<String, Component>,
) -> Vec<Equation> {
    // Declaration position of a variable; unknown names sort after every component.
    let rank = |name: &str| components.get_index_of(name).unwrap_or(usize::MAX);

    let mut equations = Vec::new();
    for algorithm_section in algorithms {
        // `HashSet` iteration order is randomized per process, so iterating it directly made
        // the emitted equation order differ between runs. Sort into a stable order first.
        let mut assigned_vars: Vec<String> =
            find_assigned_variables(algorithm_section).into_iter().collect();
        assigned_vars.sort_unstable_by(|a, b| rank(a).cmp(&rank(b)).then_with(|| a.cmp(b)));

        for var_name in assigned_vars {
            let is_input = components.get(&var_name).is_some_and(|comp| {
                matches!(comp.causality, crate::ir::ast::Causality::Input(..))
            });
            if is_input {
                continue;
            }
            let comp_ref = ComponentReference {
                local: false,
                parts: vec![ComponentRefPart {
                    ident: Token {
                        text: var_name,
                        ..Default::default()
                    },
                    subs: None,
                }],
            };
            let lhs = Expression::ComponentReference(comp_ref.clone());
            let rhs = Expression::ComponentReference(comp_ref);
            equations.push(Equation::Simple { lhs, rhs });
        }
    }
    equations
}
/// Returns the base names of all variables written by `statements`, using
/// `AssignedVarFinder`.
fn find_assigned_variables(statements: &[Statement]) -> HashSet<String> {
    let mut collector = AssignedVarFinder::new();
    statements.iter().for_each(|statement| {
        statement.accept(&mut collector);
    });
    collector.into_assigned()
}
/// Returns the base names of all variables that appear as `der(...)` arguments in
/// `equations`, using `DerVarFinder`.
fn find_differentiated_variables(equations: &[Equation]) -> HashSet<String> {
    let mut collector = DerVarFinder::new();
    equations.iter().for_each(|equation| {
        equation.accept(&mut collector);
    });
    collector.into_states()
}
/// Returns `true` when `expr` is a zero/false literal whose token has no source file name.
///
/// Recognised literals are the integer text `0`, any real that parses to `0.0`, and the
/// boolean text `false`. Tokens with a non-empty `location.file_name` are never treated as
/// defaults. Presumably they were written by the user, while location-less tokens are
/// compiler-synthesized — confirm.
fn is_default_value(expr: &Expression) -> bool {
    let Expression::Terminal {
        terminal_type,
        token,
    } = expr
    else {
        return false;
    };
    if !token.location.file_name.is_empty() {
        return false;
    }
    match terminal_type {
        TerminalType::UnsignedInteger => token.text == "0",
        TerminalType::UnsignedReal => token.text.parse::<f64>().is_ok_and(|v| v == 0.0),
        TerminalType::Bool => token.text == "false",
        _ => false,
    }
}
/// Appends element-wise equations `name[...] = rhs[...]` for an array binding of `shape`.
///
/// One-dimensional arrays take elements directly from an array-literal `rhs` when the index
/// is in range, and use `subscript_expr` otherwise. Higher-dimensional arrays are delegated
/// to `expand_array_binding_nd`.
fn expand_array_binding(
    name: &str,
    shape: &[usize],
    rhs: &Expression,
    equations: &mut Vec<Equation>,
) {
    let [size] = shape else {
        expand_array_binding_nd(name, shape, 0, &[], rhs, equations);
        return;
    };
    for i in 1..=*size {
        let literal_element = match rhs {
            Expression::Array { elements, .. } => elements.get(i - 1).cloned(),
            _ => None,
        };
        let element = literal_element.unwrap_or_else(|| subscript_expr(rhs.clone(), &[i]));
        equations.push(Equation::Simple {
            lhs: make_subscripted_ref(name, &[i]),
            rhs: element,
        });
    }
}
/// Recursively enumerates every 1-based index tuple of `shape` (last dimension varying
/// fastest) and pushes `name[indices] = rhs[indices]` for each.
///
/// `dim` is the dimension currently being enumerated and `indices` holds the indices chosen
/// for the earlier dimensions. Callers start with `dim = 0` and an empty prefix.
/// Right-hand-side element selection is delegated to `subscript_expr_nd`.
fn expand_array_binding_nd(
    name: &str,
    shape: &[usize],
    dim: usize,
    indices: &[usize],
    rhs: &Expression,
    equations: &mut Vec<Equation>,
) {
    let Some(&extent) = shape.get(dim) else {
        // Every dimension has an index: emit one scalar equation.
        equations.push(Equation::Simple {
            lhs: make_subscripted_ref(name, indices),
            rhs: subscript_expr_nd(rhs.clone(), indices),
        });
        return;
    };
    let mut prefix = Vec::with_capacity(indices.len() + 1);
    for i in 1..=extent {
        prefix.clear();
        prefix.extend_from_slice(indices);
        prefix.push(i);
        expand_array_binding_nd(name, shape, dim + 1, &prefix, rhs, equations);
    }
}
/// Expands one equation into `out`.
///
/// * `Simple`: when the left-hand side is array-valued (per `get_equation_array_size`), the
///   equation is split into one equation per element by `expand_array_equation`. A
///   zero-sized left-hand side produces no equations. Otherwise the equation is copied.
/// * `For`: delegated to `expand_for_equation`.
/// * `If`: conditions are evaluated in order with `eval_boolean`. The first branch known to
///   be true is expanded in place. If every condition is known to be false, the `else`
///   branch is expanded instead (or nothing, without an `else`). When a condition cannot be
///   evaluated, the whole `if` equation is kept, with each branch body expanded recursively.
/// * `When`: kept, with each block body expanded recursively.
/// * `Connect` / `FunctionCall`: copied unchanged. `Empty`: dropped.
fn expand_equation(
    eq: &Equation,
    components: &IndexMap<String, Component>,
    out: &mut Vec<Equation>,
) {
    match eq {
        Equation::Empty => {}
        Equation::Simple { lhs, rhs } => {
            // The equation is sized from its left-hand side only.
            if let Some(size) = get_equation_array_size(lhs, components) {
                // A zero-sized array contributes no equations.
                if size == 0 {
                    return;
                }
                if size > 1 {
                    expand_array_equation(lhs, rhs, size, components, out);
                    return;
                }
            }
            out.push(eq.clone());
        }
        Equation::For { indices, equations } => {
            expand_for_equation(indices, equations, components, out);
        }
        Equation::If {
            cond_blocks,
            else_block,
        } => {
            // Try to resolve the `if` statically: pick the first condition known to be true.
            let mut selected_branch: Option<&Vec<Equation>> = None;
            for block in cond_blocks {
                if let Some(val) = eval_boolean(&block.cond, components) {
                    if val {
                        selected_branch = Some(&block.eqs);
                        break;
                    }
                    // Known-false condition: continue with the next branch.
                } else {
                    // This condition cannot be evaluated here, so no branch can be chosen
                    // statically. Keep the whole `if` equation, expanding every branch body.
                    let mut expanded_cond_blocks = Vec::new();
                    for block in cond_blocks {
                        let mut expanded_eqs = Vec::new();
                        for inner_eq in &block.eqs {
                            expand_equation(inner_eq, components, &mut expanded_eqs);
                        }
                        expanded_cond_blocks.push(crate::ir::ast::EquationBlock {
                            cond: block.cond.clone(),
                            eqs: expanded_eqs,
                        });
                    }
                    let expanded_else = else_block.as_ref().map(|eqs| {
                        let mut expanded = Vec::new();
                        for inner_eq in eqs {
                            expand_equation(inner_eq, components, &mut expanded);
                        }
                        expanded
                    });
                    out.push(Equation::If {
                        cond_blocks: expanded_cond_blocks,
                        else_block: expanded_else,
                    });
                    return;
                }
            }
            // Every condition was evaluable: use the selected branch, else the `else` block.
            let eqs_to_expand = selected_branch.or(else_block.as_ref());
            if let Some(eqs) = eqs_to_expand {
                for inner_eq in eqs {
                    expand_equation(inner_eq, components, out);
                }
            }
        }
        Equation::When(blocks) => {
            // `when` equations stay symbolic; only their bodies are expanded.
            let mut expanded_blocks = Vec::new();
            for block in blocks {
                let mut expanded_eqs = Vec::new();
                for inner_eq in &block.eqs {
                    expand_equation(inner_eq, components, &mut expanded_eqs);
                }
                expanded_blocks.push(crate::ir::ast::EquationBlock {
                    cond: block.cond.clone(),
                    eqs: expanded_eqs,
                });
            }
            out.push(Equation::When(expanded_blocks));
        }
        Equation::Connect { .. } | Equation::FunctionCall { .. } => {
            out.push(eq.clone());
        }
    }
}
fn expand_for_equation(
indices: &[ForIndex],
equations: &[Equation],
components: &IndexMap<String, Component>,
out: &mut Vec<Equation>,
) {
if indices.is_empty() {
for eq in equations {
expand_equation(eq, components, out);
}
return;
}
let index = &indices[0];
let range = get_iteration_range(&index.range, components);
if let Some((start, end, step)) = range {
let index_name = &index.ident.text;
let mut i = start;
while (step > 0 && i <= end) || (step < 0 && i >= end) {
for eq in equations {
let substituted = substitute_index(eq, index_name, i);
expand_for_equation(&indices[1..], &[substituted], components, out);
}
i += step;
}
} else {
let mut expanded_inner = Vec::new();
for eq in equations {
expand_equation(eq, components, &mut expanded_inner);
}
out.push(Equation::For {
indices: indices.to_vec(),
equations: expanded_inner,
});
}
}
/// Evaluates a `for` iterator range to `(start, end, step)`, or `None` if it is not constant.
///
/// * `a:b` / `a:s:b` ranges: all parts must evaluate with `eval_integer`. The step defaults
///   to 1.
/// * An unsigned integer literal `n` is read as `1:n`.
/// * A component reference evaluating to `n` is also read as `1:n`.
/// * Any other expression yields `None`.
fn get_iteration_range(
    expr: &Expression,
    components: &IndexMap<String, Component>,
) -> Option<(i64, i64, i64)> {
    let eval = |e: &Expression| eval_integer(e, components);
    match expr {
        Expression::Range { start, step, end } => {
            let first = eval(start)?;
            let last = eval(end)?;
            let stride = match step {
                Some(s) => eval(s)?,
                None => 1,
            };
            Some((first, last, stride))
        }
        Expression::Terminal {
            terminal_type: TerminalType::UnsignedInteger,
            token,
        } => token.text.parse::<i64>().ok().map(|n| (1, n, 1)),
        Expression::ComponentReference(_) => eval(expr).map(|n| (1, n, 1)),
        _ => None,
    }
}
fn substitute_index(eq: &Equation, index_name: &str, value: i64) -> Equation {
match eq {
Equation::Simple { lhs, rhs } => Equation::Simple {
lhs: substitute_in_expr(lhs, index_name, value),
rhs: substitute_in_expr(rhs, index_name, value),
},
Equation::For { indices, equations } => {
let is_shadowed = indices.iter().any(|idx| idx.ident.text == index_name);
if is_shadowed {
eq.clone()
} else {
Equation::For {
indices: indices.clone(),
equations: equations
.iter()
.map(|e| substitute_index(e, index_name, value))
.collect(),
}
}
}
Equation::If {
cond_blocks,
else_block,
} => Equation::If {
cond_blocks: cond_blocks
.iter()
.map(|b| crate::ir::ast::EquationBlock {
cond: substitute_in_expr(&b.cond, index_name, value),
eqs: b
.eqs
.iter()
.map(|e| substitute_index(e, index_name, value))
.collect(),
})
.collect(),
else_block: else_block.as_ref().map(|eqs| {
eqs.iter()
.map(|e| substitute_index(e, index_name, value))
.collect()
}),
},
Equation::When(blocks) => Equation::When(
blocks
.iter()
.map(|b| crate::ir::ast::EquationBlock {
cond: substitute_in_expr(&b.cond, index_name, value),
eqs: b
.eqs
.iter()
.map(|e| substitute_index(e, index_name, value))
.collect(),
})
.collect(),
),
_ => eq.clone(),
}
}
fn substitute_in_expr(expr: &Expression, index_name: &str, value: i64) -> Expression {
match expr {
Expression::ComponentReference(comp_ref) => {
if comp_ref.parts.len() == 1
&& comp_ref.parts[0].subs.is_none()
&& comp_ref.parts[0].ident.text == index_name
{
return Expression::Terminal {
terminal_type: TerminalType::UnsignedInteger,
token: Token {
text: value.to_string(),
..Default::default()
},
};
}
let new_parts: Vec<ComponentRefPart> = comp_ref
.parts
.iter()
.map(|part| ComponentRefPart {
ident: part.ident.clone(),
subs: part.subs.as_ref().map(|subs| {
subs.iter()
.map(|s| match s {
Subscript::Expression(e) => {
Subscript::Expression(substitute_in_expr(e, index_name, value))
}
_ => s.clone(),
})
.collect()
}),
})
.collect();
Expression::ComponentReference(ComponentReference {
local: comp_ref.local,
parts: new_parts,
})
}
Expression::Binary { op, lhs, rhs } => Expression::Binary {
op: op.clone(),
lhs: Box::new(substitute_in_expr(lhs, index_name, value)),
rhs: Box::new(substitute_in_expr(rhs, index_name, value)),
},
Expression::Unary { op, rhs } => Expression::Unary {
op: op.clone(),
rhs: Box::new(substitute_in_expr(rhs, index_name, value)),
},
Expression::FunctionCall { comp, args } => Expression::FunctionCall {
comp: comp.clone(),
args: args
.iter()
.map(|a| substitute_in_expr(a, index_name, value))
.collect(),
},
Expression::Array {
elements,
is_matrix,
} => Expression::Array {
elements: elements
.iter()
.map(|e| substitute_in_expr(e, index_name, value))
.collect(),
is_matrix: *is_matrix,
},
Expression::If {
branches,
else_branch,
} => Expression::If {
branches: branches
.iter()
.map(|(cond, expr)| {
(
substitute_in_expr(cond, index_name, value),
substitute_in_expr(expr, index_name, value),
)
})
.collect(),
else_branch: Box::new(substitute_in_expr(else_branch, index_name, value)),
},
Expression::Range { start, step, end } => Expression::Range {
start: Box::new(substitute_in_expr(start, index_name, value)),
step: step
.as_ref()
.map(|s| Box::new(substitute_in_expr(s, index_name, value))),
end: Box::new(substitute_in_expr(end, index_name, value)),
},
Expression::Parenthesized { inner } => Expression::Parenthesized {
inner: Box::new(substitute_in_expr(inner, index_name, value)),
},
_ => expr.clone(),
}
}
/// Returns the number of scalar elements of an equation side, as far as it can be told.
///
/// * Component reference: 1 if its first part already has a non-empty subscript list, if the
///   component is unknown or if it is scalar. Otherwise the product of the component's shape.
/// * `der(arg)`: the size of `arg`. Any other call counts as 1.
/// * Array literal: the sum of its element sizes (0 for `{}`).
/// * Anything else: 1.
///
/// `None` is only propagated from array-literal elements; every other case yields `Some`.
fn get_equation_array_size(
    lhs: &Expression,
    components: &IndexMap<String, Component>,
) -> Option<usize> {
    match lhs {
        Expression::ComponentReference(comp_ref) => {
            let Some(head) = comp_ref.parts.first() else {
                return Some(1);
            };
            let already_indexed = head.subs.as_ref().is_some_and(|subs| !subs.is_empty());
            if already_indexed {
                return Some(1);
            }
            let size = components
                .get(&head.ident.text)
                .filter(|comp| !comp.shape.is_empty())
                .map_or(1, |comp| comp.shape.iter().product::<usize>());
            Some(size)
        }
        Expression::FunctionCall { comp, args } => {
            let is_der = comp.parts.first().is_some_and(|p| p.ident.text == "der");
            match args.first() {
                Some(arg) if is_der => get_equation_array_size(arg, components),
                _ => Some(1),
            }
        }
        Expression::Array { elements, .. } => elements
            .iter()
            .map(|elem| get_equation_array_size(elem, components))
            .sum::<Option<usize>>(),
        _ => Some(1),
    }
}
/// Splits an array equation `lhs = rhs` of `size` elements into `size` element equations,
/// pairing flat element `i` (1-based) of both sides via `flatten_and_subscript`.
fn expand_array_equation(
    lhs: &Expression,
    rhs: &Expression,
    size: usize,
    components: &IndexMap<String, Component>,
    out: &mut Vec<Equation>,
) {
    out.extend((1..=size).map(|flat| Equation::Simple {
        lhs: flatten_and_subscript(lhs, flat, components),
        rhs: flatten_and_subscript(rhs, flat, components),
    }));
}
fn make_subscripted_ref(name: &str, indices: &[usize]) -> Expression {
Expression::ComponentReference(ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: name.to_string(),
..Default::default()
},
subs: Some(
indices
.iter()
.map(|&i| {
Subscript::Expression(Expression::Terminal {
terminal_type: TerminalType::UnsignedInteger,
token: Token {
text: i.to_string(),
..Default::default()
},
})
})
.collect(),
),
}],
})
}
/// Applies the literal 1-based `indices` to `expr`, consuming it.
///
/// * Component reference: the first part's subscripts are replaced by `indices`.
/// * `der(arg)` with exactly one argument: the indices are pushed into `arg`.
/// * Unary operator: the indices are pushed into the operand.
/// * Array literal: an empty literal is returned as-is. With a single in-range index the
///   matching element is returned. Otherwise the indices are applied to the first element.
/// * Anything else is returned unchanged.
fn subscript_expr(expr: Expression, indices: &[usize]) -> Expression {
    match expr {
        Expression::ComponentReference(mut comp_ref) => {
            if let Some(head) = comp_ref.parts.first_mut() {
                head.subs = Some(
                    indices
                        .iter()
                        .map(|&i| {
                            Subscript::Expression(Expression::Terminal {
                                terminal_type: TerminalType::UnsignedInteger,
                                token: Token {
                                    text: i.to_string(),
                                    ..Default::default()
                                },
                            })
                        })
                        .collect(),
                );
            }
            Expression::ComponentReference(comp_ref)
        }
        Expression::FunctionCall { comp, args }
            if args.len() == 1 && comp.parts.first().is_some_and(|p| p.ident.text == "der") =>
        {
            Expression::FunctionCall {
                comp,
                args: args
                    .into_iter()
                    .map(|arg| subscript_expr(arg, indices))
                    .collect(),
            }
        }
        Expression::Unary { op, rhs } => Expression::Unary {
            op,
            rhs: Box::new(subscript_expr(*rhs, indices)),
        },
        Expression::Array { ref elements, .. } => {
            if elements.is_empty() {
                return expr;
            }
            match indices {
                &[k] if (1..=elements.len()).contains(&k) => elements[k - 1].clone(),
                _ => subscript_expr(elements[0].clone(), indices),
            }
        }
        other => other,
    }
}
fn flatten_and_subscript(
expr: &Expression,
flat_index: usize,
components: &IndexMap<String, Component>,
) -> Expression {
match expr {
Expression::Array { elements, .. } => {
let mut cumulative = 0;
for elem in elements {
let elem_size = get_equation_array_size(elem, components).unwrap_or(1);
if flat_index <= cumulative + elem_size {
let local_index = flat_index - cumulative;
if elem_size == 1 {
if let Expression::ComponentReference(comp_ref) = elem
&& let Some(first_part) = comp_ref.parts.first()
{
let name = &first_part.ident.text;
if let Some(comp) = components.get(name)
&& !comp.shape.is_empty()
{
return subscript_expr(elem.clone(), &[1]);
}
}
return elem.clone();
} else {
return flatten_and_subscript(elem, local_index, components);
}
}
cumulative += elem_size;
}
expr.clone()
}
Expression::ComponentReference(comp_ref) => {
if let Some(first_part) = comp_ref.parts.first() {
let name = &first_part.ident.text;
if let Some(comp) = components.get(name)
&& !comp.shape.is_empty()
{
return subscript_expr(expr.clone(), &[flat_index]);
}
}
expr.clone()
}
_ => expr.clone(),
}
}
fn subscript_expr_nd(expr: Expression, indices: &[usize]) -> Expression {
subscript_expr(expr, indices)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an unsigned integer literal with a default (location-free) token.
    fn int_lit(text: &str) -> Expression {
        Expression::Terminal {
            terminal_type: TerminalType::UnsignedInteger,
            token: Token {
                text: text.to_string(),
                ..Default::default()
            },
        }
    }

    #[test]
    fn test_eval_integer() {
        let components = IndexMap::new();
        assert_eq!(eval_integer(&int_lit("3"), &components), Some(3));

        let negated = Expression::Unary {
            op: crate::ir::ast::OpUnary::Minus(Token::default()),
            rhs: Box::new(int_lit("5")),
        };
        assert_eq!(eval_integer(&negated, &components), Some(-5));
    }

    #[test]
    fn test_get_iteration_range() {
        let range = Expression::Range {
            start: Box::new(int_lit("1")),
            step: None,
            end: Box::new(int_lit("3")),
        };
        let components = IndexMap::new();
        assert_eq!(get_iteration_range(&range, &components), Some((1, 3, 1)));
    }
}