use crate::dae::ast::Dae;
use crate::ir::analysis::condition_finder::ConditionFinder;
use crate::ir::analysis::state_finder::StateFinder;
use crate::ir::ast::{
Causality, ClassDefinition, Component, Equation, Expression, Name, Statement, Token,
Variability,
};
use crate::ir::error::IrError;
use crate::ir::transform::constants::BUILTIN_REINIT;
use crate::ir::transform::eval::eval_boolean;
use crate::ir::visitor::{MutVisitable, Visitable, Visitor};
use git_version::git_version;
use indexmap::IndexMap;
use std::collections::HashSet;
use anyhow::Result;
/// Visitor that records the names of variables assigned by equations.
///
/// A name is recorded when a component reference appears on the left-hand
/// side of an [`Equation::Simple`], either directly (`x = ...`) or as an
/// element of a tuple (`(x, y) = f(...)`). Names are stored as produced by the
/// component reference's `to_string()`.
#[derive(Default)]
struct DefinedVariableCollector {
    /// Names of component references seen on an equation's left-hand side.
    defined: HashSet<String>,
}

impl DefinedVariableCollector {
    /// Creates a collector with an empty set of recorded names.
    fn new() -> Self {
        Self::default()
    }

    /// Consumes the collector and returns the recorded names.
    fn into_defined(self) -> HashSet<String> {
        self.defined
    }
}

impl Visitor for DefinedVariableCollector {
    fn enter_equation(&mut self, node: &Equation) {
        if let Equation::Simple { lhs, .. } = node {
            match lhs {
                Expression::ComponentReference(cref) => {
                    self.defined.insert(cref.to_string());
                }
                // A multi-output assignment defines every component reference
                // it lists, not just a single target.
                Expression::Tuple { elements } => {
                    for element in elements {
                        if let Expression::ComponentReference(cref) = element {
                            self.defined.insert(cref.to_string());
                        }
                    }
                }
                _ => {}
            }
        }
    }
}
/// Git description of the source tree this crate was built from.
///
/// Expanded at build time by the `git_version` crate's `git_version!` macro,
/// which runs `git describe` with the given `args` (`--tags --always`, and a
/// `-dirty` suffix when the working tree has uncommitted changes). If git
/// information is unavailable, the string `"unknown"` is used instead.
/// Copied into `Dae::git_version` by `create_dae`.
const GIT_VERSION: &str = git_version!(
    args = ["--tags", "--always", "--dirty=-dirty"],
    fallback = "unknown"
);
/// Reports whether `eq` is a simple equation assigning into a filtered component.
///
/// Only `Equation::Simple` with a component-reference left-hand side can match.
/// The test uses the reference's *first* part, so `sub.x = ...` matches when
/// `sub` is in `filtered`. Every other equation kind yields `false`.
fn equation_assigns_to_filtered(eq: &Equation, filtered: &HashSet<String>) -> bool {
    match eq {
        Equation::Simple {
            lhs: Expression::ComponentReference(target),
            ..
        } => target
            .parts
            .first()
            .map_or(false, |head| filtered.contains(&head.ident.text)),
        _ => false,
    }
}
/// Recursively drops equations that assign to a filtered-out component.
///
/// A top-level equation is removed when `equation_assigns_to_filtered` returns
/// `true` for it. The bodies of `if` (every branch, including `else`), `for`
/// and `when` equations are filtered the same way. The enclosing construct is
/// kept even if its body ends up empty. All other equations pass through unchanged.
fn filter_equations(equations: Vec<Equation>, filtered: &HashSet<String>) -> Vec<Equation> {
    let mut kept = Vec::new();
    for equation in equations {
        if equation_assigns_to_filtered(&equation, filtered) {
            continue;
        }
        let rewritten = match equation {
            Equation::If {
                cond_blocks,
                else_block,
            } => {
                let cond_blocks = cond_blocks
                    .into_iter()
                    .map(|mut branch| {
                        branch.eqs = filter_equations(branch.eqs, filtered);
                        branch
                    })
                    .collect();
                let else_block = else_block.map(|body| filter_equations(body, filtered));
                Equation::If {
                    cond_blocks,
                    else_block,
                }
            }
            Equation::For {
                indices,
                equations: body,
            } => {
                let body = filter_equations(body, filtered);
                Equation::For {
                    indices,
                    equations: body,
                }
            }
            Equation::When(blocks) => {
                let blocks = blocks
                    .into_iter()
                    .map(|mut branch| {
                        branch.eqs = filter_equations(branch.eqs, filtered);
                        branch
                    })
                    .collect();
                Equation::When(blocks)
            }
            untouched => untouched,
        };
        kept.push(rewritten);
    }
    kept
}
/// Runs a `DefinedVariableCollector` over each equation in `equations` and
/// returns the set of names it recorded.
fn collect_defined_variables(equations: &[Equation]) -> HashSet<String> {
    let mut visitor = DefinedVariableCollector::new();
    equations.iter().for_each(|equation| {
        equation.accept(&mut visitor);
    });
    visitor.into_defined()
}
/// Scalarizes a component into `(name, component)` pairs.
///
/// - A component with an empty `shape` is returned as a single pair under its own name.
/// - If any dimension is zero, the result is empty.
/// - Otherwise there is one entry per element, named `name[i,j,...]` with
///   1-based subscripts in the order produced by `generate_indices`.
///
/// Each element is a clone of `comp` with `shape` cleared. When `comp.start`
/// is not `Expression::Empty`, the element's start is replaced with
/// `extract_array_element(&comp.start, subscripts)`.
fn expand_array_component(comp: &Component) -> Vec<(String, Component)> {
    if comp.shape.is_empty() {
        return vec![(comp.name.clone(), comp.clone())];
    }

    let element_count: usize = comp.shape.iter().product();
    if element_count == 0 {
        return Vec::new();
    }

    let has_start = !matches!(comp.start, Expression::Empty);
    generate_indices(&comp.shape)
        .into_iter()
        .map(|subscripts| {
            let rendered: Vec<String> = subscripts.iter().map(|s| s.to_string()).collect();
            let element_name = format!("{}[{}]", comp.name, rendered.join(","));

            let mut element = comp.clone();
            element.name = element_name.clone();
            element.shape = vec![];
            if has_start {
                element.start = extract_array_element(&comp.start, &subscripts);
            }
            (element_name, element)
        })
        .collect()
}
/// Enumerates every 1-based multi-index of an array with the given `shape`.
///
/// An empty `shape` (a scalar) yields exactly one empty index. Otherwise the
/// indices come from `generate_indices_recursive`, starting at dimension 0
/// with an empty prefix.
fn generate_indices(shape: &[usize]) -> Vec<Vec<usize>> {
    let mut all_indices = Vec::new();
    if shape.is_empty() {
        all_indices.push(Vec::new());
    } else {
        generate_indices_recursive(shape, 0, &mut Vec::new(), &mut all_indices);
    }
    all_indices
}
/// Appends to `result` every completion of the index prefix `current`.
///
/// Dimensions `dim..` of `shape` are filled with 1-based values. The last
/// dimension varies fastest (row-major order). `current` is restored to its
/// original contents before returning. A zero extent in any remaining
/// dimension produces no output.
fn generate_indices_recursive(
    shape: &[usize],
    dim: usize,
    current: &mut Vec<usize>,
    result: &mut Vec<Vec<usize>>,
) {
    match shape.get(dim) {
        // Every dimension has a value: record the finished index.
        None => result.push(current.clone()),
        Some(&extent) => {
            for value in 1..=extent {
                current.push(value);
                generate_indices_recursive(shape, dim + 1, current, result);
                current.pop();
            }
        }
    }
}
/// Picks the element of an array-literal expression at 1-based `indices`.
///
/// The function descends one nesting level per index. It returns a clone of
/// the current expression instead of descending when:
/// - `indices` is exhausted,
/// - the expression is not an `Expression::Array`, or
/// - the subscript is 0 or past the end.
///
/// A scalar start value is therefore reused for every element.
fn extract_array_element(expr: &Expression, indices: &[usize]) -> Expression {
    let (&first, rest) = match indices.split_first() {
        Some(parts) => parts,
        None => return expr.clone(),
    };

    if let Expression::Array { elements, .. } = expr {
        // Subscripts are 1-based; `checked_sub` rejects 0 and `get` rejects overflow.
        if let Some(element) = first.checked_sub(1).and_then(|pos| elements.get(pos)) {
            return extract_array_element(element, rest);
        }
    }
    expr.clone()
}
pub fn create_dae(fclass: &mut ClassDefinition) -> Result<Dae> {
let mut dae = Dae {
model_name: fclass.name.text.clone(),
rumoca_version: env!("CARGO_PKG_VERSION").to_string(),
git_version: GIT_VERSION.to_string(),
t: Component {
name: "t".to_string(),
type_name: Name {
name: vec![Token {
text: "Real".to_string(),
..Default::default()
}],
},
..Default::default()
},
..Default::default()
};
let mut state_finder = StateFinder::default();
fclass.accept_mut(&mut state_finder);
let mut condition_finder = ConditionFinder::default();
fclass.accept_mut(&mut condition_finder);
let defined_variables = collect_defined_variables(&fclass.equations);
let mut all_params: IndexMap<String, Component> = IndexMap::new();
for (_, comp) in &fclass.components {
if matches!(
comp.variability,
Variability::Parameter(..) | Variability::Constant(..)
) {
let expanded = expand_array_component(comp);
for (name, c) in expanded {
all_params.insert(name, c);
}
}
}
let mut filtered_components: HashSet<String> = HashSet::new();
for (_, comp) in &fclass.components {
if let Some(ref cond_expr) = comp.condition {
match eval_boolean(cond_expr, &all_params) {
Some(false) => {
filtered_components.insert(comp.name.clone());
continue;
}
Some(true) => {
}
None => {
}
}
}
let expanded = expand_array_component(comp);
for (scalar_name, scalar_comp) in expanded {
match scalar_comp.variability {
Variability::Parameter(..) => {
dae.p.insert(scalar_name, scalar_comp);
}
Variability::Constant(..) => {
dae.cp.insert(scalar_name, scalar_comp);
}
Variability::Discrete(..) => {
dae.m.insert(scalar_name, scalar_comp);
}
Variability::Empty => {
match scalar_comp.causality {
Causality::Input(..) => {
let base_name = &comp.name; let is_top_level = !base_name.contains('.');
if is_top_level {
dae.u.insert(scalar_name, scalar_comp);
} else if defined_variables.contains(&scalar_name) {
dae.y.insert(scalar_name, scalar_comp);
} else {
dae.u.insert(scalar_name, scalar_comp);
}
}
Causality::Output(..) | Causality::Empty => {
let base_name = comp.name.clone();
if state_finder.states.contains(&base_name)
|| state_finder.states.contains(&scalar_name)
{
dae.x.insert(scalar_name, scalar_comp);
} else {
dae.y.insert(scalar_name, scalar_comp);
}
}
}
}
}
}
}
dae.c = condition_finder.conditions.clone();
dae.fc = condition_finder.expressions.clone();
let mut exclude_from_matching: HashSet<String> = HashSet::new();
for name in dae.p.keys() {
exclude_from_matching.insert(name.clone());
}
for name in dae.cp.keys() {
exclude_from_matching.insert(name.clone());
}
for name in dae.u.keys() {
exclude_from_matching.insert(name.clone());
}
for name in dae.x.keys() {
exclude_from_matching.insert(name.clone());
}
exclude_from_matching.insert("time".to_string());
let transformed_equations =
crate::ir::structural::blt_transform(fclass.equations.clone(), &exclude_from_matching);
let filtered_equations = if filtered_components.is_empty() {
transformed_equations
} else {
filter_equations(transformed_equations, &filtered_components)
};
for eq in &filtered_equations {
match &eq {
Equation::Simple { .. } => {
dae.fx.push(eq.clone());
}
Equation::If { .. } => {
dae.fx.push(eq.clone());
}
Equation::For { .. } => {
dae.fx.push(eq.clone());
}
Equation::Connect { .. } => {
return Err(IrError::UnexpandedConnectionEquation.into());
}
Equation::When(blocks) => {
for block in blocks {
for eq in &block.eqs {
match eq {
Equation::FunctionCall { comp, args } => {
let name = comp.to_string();
if name == BUILTIN_REINIT {
let cond_name = match &block.cond {
Expression::ComponentReference(cref) => cref.to_string(),
other => {
let loc = other
.get_location()
.map(|l| {
format!(
" at {}:{}:{}",
l.file_name, l.start_line, l.start_column
)
})
.unwrap_or_default();
anyhow::bail!(
"Unsupported condition type in 'when' block{}. \
Expected a component reference.",
loc
)
}
};
if args.len() != 2 {
return Err(
IrError::InvalidReinitArgCount(args.len()).into()
);
}
match &args[0] {
Expression::ComponentReference(cref) => {
dae.fr.insert(
cond_name,
Statement::Assignment {
comp: cref.clone(),
value: args[1].clone(),
},
);
}
_ => {
return Err(IrError::InvalidReinitFirstArg(format!(
"{:?}",
args[0]
))
.into());
}
}
}
}
Equation::Simple { lhs, rhs } => {
let cond_name = match &block.cond {
Expression::ComponentReference(cref) => cref.to_string(),
other => {
let loc = other
.get_location()
.map(|l| {
format!(
" at {}:{}:{}",
l.file_name, l.start_line, l.start_column
)
})
.unwrap_or_default();
anyhow::bail!(
"Unsupported condition type in 'when' block{}. \
Expected a component reference.",
loc
)
}
};
match lhs {
Expression::ComponentReference(cref) => {
dae.fr.insert(
format!("{}_{}", cond_name, cref),
Statement::Assignment {
comp: cref.clone(),
value: rhs.clone(),
},
);
}
Expression::Tuple { elements } => {
for (i, elem) in elements.iter().enumerate() {
if let Expression::ComponentReference(cref) = elem {
dae.fr.insert(
format!("{}_tuple_{}", cond_name, i),
Statement::Assignment {
comp: cref.clone(),
value: rhs.clone(), },
);
}
}
}
_ => {
dae.fz.push(eq.clone());
}
}
}
Equation::If { .. } | Equation::For { .. } => {
dae.fz.push(eq.clone());
}
other => {
let loc = other
.get_location()
.map(|l| {
format!(
" at {}:{}:{}",
l.file_name, l.start_line, l.start_column
)
})
.unwrap_or_default();
anyhow::bail!(
"Unsupported equation type in 'when' block{}. \
Only assignments, 'reinit', 'if' and 'for' are currently supported.",
loc
)
}
}
}
}
}
_ => {}
}
}
for eq in &fclass.initial_equations {
match eq {
Equation::Simple { .. } | Equation::For { .. } | Equation::If { .. } => {
dae.fx_init.push(eq.clone());
}
_ => {
dae.fx_init.push(eq.clone());
}
}
}
Ok(dae)
}