use crate::ir;
use crate::ir::ast::{
ComponentRefPart, ComponentReference, Connection, Equation, Expression, OpBinary, TerminalType,
Token,
};
use anyhow::Result;
use indexmap::{IndexMap, IndexSet};
use super::ClassDict;
pub(super) fn make_comp_ref(name: &str) -> ComponentReference {
if let Some(bracket_start) = name.find('[') {
let base = &name[..bracket_start];
let subs_str = &name[bracket_start..];
let subs = parse_subscript_string(subs_str);
ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: base.to_string(),
..Default::default()
},
subs: if subs.is_empty() { None } else { Some(subs) },
}],
}
} else {
ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: name.to_string(),
..Default::default()
},
subs: None,
}],
}
}
}
fn parse_subscript_string(s: &str) -> Vec<ir::ast::Subscript> {
let inner = s.trim_start_matches('[').trim_end_matches(']');
if inner.is_empty() {
return vec![];
}
inner
.split(',')
.map(|part| {
let trimmed = part.trim();
if trimmed.parse::<i64>().is_ok() {
ir::ast::Subscript::Expression(Expression::Terminal {
token: Token {
text: trimmed.to_string(),
..Default::default()
},
terminal_type: ir::ast::TerminalType::UnsignedInteger,
})
} else {
ir::ast::Subscript::Expression(Expression::ComponentReference(ComponentReference {
local: false,
parts: vec![ComponentRefPart {
ident: Token {
text: trimmed.to_string(),
..Default::default()
},
subs: None,
}],
}))
}
})
.collect()
}
fn make_simple_eq(lhs: &str, rhs: &str) -> Equation {
Equation::Simple {
lhs: Expression::ComponentReference(make_comp_ref(lhs)),
rhs: Expression::ComponentReference(make_comp_ref(rhs)),
}
}
fn make_sum_eq(vars: &[String]) -> Equation {
if vars.is_empty() {
return Equation::Empty;
}
if vars.len() == 1 {
return Equation::Simple {
lhs: Expression::ComponentReference(make_comp_ref(&vars[0])),
rhs: Expression::Terminal {
token: Token {
text: "0".to_string(),
..Default::default()
},
terminal_type: TerminalType::UnsignedReal,
},
};
}
let mut sum = Expression::ComponentReference(make_comp_ref(&vars[0]));
for var in vars.iter().skip(1) {
sum = Expression::Binary {
op: OpBinary::Add(Token::default()),
lhs: Box::new(sum),
rhs: Box::new(Expression::ComponentReference(make_comp_ref(var))),
};
}
Equation::Simple {
lhs: sum,
rhs: Expression::Terminal {
token: Token {
text: "0".to_string(),
..Default::default()
},
terminal_type: TerminalType::UnsignedReal,
},
}
}
pub(super) fn contains_connects(eq: &Equation) -> bool {
match eq {
Equation::Connect { .. } => true,
Equation::For { equations, .. } => equations.iter().any(contains_connects),
Equation::If {
cond_blocks,
else_block,
} => {
cond_blocks
.iter()
.any(|b| b.eqs.iter().any(contains_connects))
|| else_block
.as_ref()
.is_some_and(|eqs| eqs.iter().any(contains_connects))
}
Equation::When(blocks) => blocks.iter().any(|b| b.eqs.iter().any(contains_connects)),
_ => false,
}
}
fn extract_connect_equations_recursive(
eq: &Equation,
connect_eqs: &mut Vec<(ComponentReference, ComponentReference)>,
) -> Option<Equation> {
match eq {
Equation::Connect { lhs, rhs } => {
connect_eqs.push((lhs.clone(), rhs.clone()));
None }
Equation::For { indices, equations } => {
if equations.iter().any(contains_connects) {
Some(eq.clone())
} else {
let mut filtered_eqs = Vec::new();
for inner_eq in equations {
if let Some(filtered) =
extract_connect_equations_recursive(inner_eq, connect_eqs)
{
filtered_eqs.push(filtered);
}
}
if filtered_eqs.is_empty() {
None
} else {
Some(Equation::For {
indices: indices.clone(),
equations: filtered_eqs,
})
}
}
}
Equation::If {
cond_blocks,
else_block,
} => {
let mut new_cond_blocks = Vec::new();
for block in cond_blocks {
let mut filtered_eqs = Vec::new();
for inner_eq in &block.eqs {
if let Some(filtered) =
extract_connect_equations_recursive(inner_eq, connect_eqs)
{
filtered_eqs.push(filtered);
}
}
new_cond_blocks.push(ir::ast::EquationBlock {
cond: block.cond.clone(),
eqs: filtered_eqs,
});
}
let new_else = else_block.as_ref().map(|eqs| {
let mut filtered = Vec::new();
for inner_eq in eqs {
if let Some(f) = extract_connect_equations_recursive(inner_eq, connect_eqs) {
filtered.push(f);
}
}
filtered
});
Some(Equation::If {
cond_blocks: new_cond_blocks,
else_block: new_else,
})
}
Equation::When(blocks) => {
let mut new_blocks = Vec::new();
for block in blocks {
let mut filtered_eqs = Vec::new();
for inner_eq in &block.eqs {
if let Some(filtered) =
extract_connect_equations_recursive(inner_eq, connect_eqs)
{
filtered_eqs.push(filtered);
}
}
new_blocks.push(ir::ast::EquationBlock {
cond: block.cond.clone(),
eqs: filtered_eqs,
});
}
Some(Equation::When(new_blocks))
}
_ => Some(eq.clone()),
}
}
fn strip_subscripts_from_name(name: &str) -> String {
let mut result = String::new();
let mut in_bracket = false;
for c in name.chars() {
if c == '[' {
in_bracket = true;
} else if c == ']' {
in_bracket = false;
} else if !in_bracket {
result.push(c);
}
}
result
}
fn transform_pin_to_field(pin: &str, field: &str) -> String {
if let Some(bracket_start) = pin.find('[') {
let mut bracket_depth = 0;
let mut bracket_end = bracket_start;
for (i, c) in pin[bracket_start..].char_indices() {
match c {
'[' => bracket_depth += 1,
']' => {
bracket_depth -= 1;
if bracket_depth == 0 {
bracket_end = bracket_start + i;
break;
}
}
_ => {}
}
}
let base = &pin[..bracket_start];
let subscripts = &pin[bracket_start..=bracket_end];
let suffix = &pin[bracket_end + 1..];
format!("{}{}.{}{}", base, suffix, field, subscripts)
} else {
format!("{}.{}", pin, field)
}
}
fn generate_connection_equations(
pins: &[&String],
connector_class: &ir::ast::ClassDefinition,
equations: &mut Vec<Equation>,
) {
for (var_name, var_comp) in &connector_class.components {
let is_flow = matches!(var_comp.connection, Connection::Flow(_));
if is_flow {
let flow_vars: Vec<String> = pins
.iter()
.map(|pin| transform_pin_to_field(pin, var_name))
.collect();
equations.push(make_sum_eq(&flow_vars));
} else {
let first_var = transform_pin_to_field(pins[0], var_name);
for pin in pins.iter().skip(1) {
let other_var = transform_pin_to_field(pin, var_name);
equations.push(make_simple_eq(&first_var, &other_var));
}
}
}
}
fn expand_single_connect(
lhs: &ComponentReference,
rhs: &ComponentReference,
class_dict: &ClassDict,
pin_types: &IndexMap<String, String>,
) -> Vec<Equation> {
let mut equations = Vec::new();
let lhs_name = lhs.to_string();
let rhs_name = rhs.to_string();
let lhs_base = strip_subscripts_from_name(&lhs_name);
let connector_type = pin_types.get(&lhs_base);
if let Some(connector_type) = connector_type
&& let Some(connector_class) = class_dict.get(connector_type)
&& !connector_class.components.is_empty()
{
for (var_name, var_comp) in &connector_class.components {
let is_flow = matches!(var_comp.connection, Connection::Flow(_));
let lhs_field = transform_pin_to_field(&lhs_name, var_name);
let rhs_field = transform_pin_to_field(&rhs_name, var_name);
if is_flow {
equations.push(make_sum_eq(&[lhs_field, rhs_field]));
} else {
equations.push(make_simple_eq(&lhs_field, &rhs_field));
}
}
} else {
equations.push(make_simple_eq(&lhs_name, &rhs_name));
}
equations
}
fn expand_connects_in_equation(
eq: &Equation,
class_dict: &ClassDict,
pin_types: &IndexMap<String, String>,
) -> Vec<Equation> {
match eq {
Equation::Connect { lhs, rhs } => {
expand_single_connect(lhs, rhs, class_dict, pin_types)
}
Equation::For { indices, equations } => {
let mut expanded_inner = Vec::new();
for inner_eq in equations {
expanded_inner.extend(expand_connects_in_equation(inner_eq, class_dict, pin_types));
}
if expanded_inner.is_empty() {
vec![]
} else {
vec![Equation::For {
indices: indices.clone(),
equations: expanded_inner,
}]
}
}
Equation::If {
cond_blocks,
else_block,
} => {
let new_cond_blocks: Vec<_> = cond_blocks
.iter()
.map(|block| {
let mut expanded_eqs = Vec::new();
for inner_eq in &block.eqs {
expanded_eqs
.extend(expand_connects_in_equation(inner_eq, class_dict, pin_types));
}
ir::ast::EquationBlock {
cond: block.cond.clone(),
eqs: expanded_eqs,
}
})
.collect();
let new_else = else_block.as_ref().map(|eqs| {
let mut expanded = Vec::new();
for inner_eq in eqs {
expanded.extend(expand_connects_in_equation(inner_eq, class_dict, pin_types));
}
expanded
});
vec![Equation::If {
cond_blocks: new_cond_blocks,
else_block: new_else,
}]
}
Equation::When(blocks) => {
let new_blocks: Vec<_> = blocks
.iter()
.map(|block| {
let mut expanded_eqs = Vec::new();
for inner_eq in &block.eqs {
expanded_eqs
.extend(expand_connects_in_equation(inner_eq, class_dict, pin_types));
}
ir::ast::EquationBlock {
cond: block.cond.clone(),
eqs: expanded_eqs,
}
})
.collect();
vec![Equation::When(new_blocks)]
}
_ => vec![eq.clone()],
}
}
pub(super) fn expand_connect_equations(
fclass: &mut ir::ast::ClassDefinition,
class_dict: &ClassDict,
pin_types: &IndexMap<String, String>,
) -> Result<()> {
let mut connection_sets: IndexMap<String, IndexSet<String>> = IndexMap::new();
let mut connect_eqs: Vec<(ComponentReference, ComponentReference)> = Vec::new();
let mut other_eqs: Vec<Equation> = Vec::new();
for eq in &fclass.equations {
if let Some(filtered_eq) = extract_connect_equations_recursive(eq, &mut connect_eqs) {
other_eqs.push(filtered_eq);
}
}
let has_for_connects = other_eqs.iter().any(contains_connects);
if connect_eqs.is_empty() && !has_for_connects {
return Ok(());
}
let mut parent: IndexMap<String, String> = IndexMap::new();
fn find(parent: &mut IndexMap<String, String>, x: &str) -> String {
if !parent.contains_key(x) {
parent.insert(x.to_string(), x.to_string());
return x.to_string();
}
let p = parent.get(x).expect("parent should exist").clone();
if p != x {
let root = find(parent, &p);
parent.insert(x.to_string(), root.clone());
return root;
}
p
}
fn union(parent: &mut IndexMap<String, String>, x: &str, y: &str) {
let px = find(parent, x);
let py = find(parent, y);
if px != py {
parent.insert(py, px);
}
}
for (lhs, rhs) in &connect_eqs {
let lhs_name = lhs.to_string();
let rhs_name = rhs.to_string();
union(&mut parent, &lhs_name, &rhs_name);
}
for (lhs, rhs) in &connect_eqs {
let lhs_name = lhs.to_string();
let rhs_name = rhs.to_string();
let root = find(&mut parent, &lhs_name);
connection_sets
.entry(root.clone())
.or_default()
.insert(lhs_name);
let root = find(&mut parent, &rhs_name);
connection_sets.entry(root).or_default().insert(rhs_name);
}
let mut new_equations: Vec<Equation> = Vec::new();
for (_root, pins) in &connection_sets {
if pins.len() < 2 {
continue;
}
let pins_vec: Vec<&String> = pins.iter().collect();
let first_pin = pins_vec[0];
let first_pin_base = strip_subscripts_from_name(first_pin);
let mut generated = false;
if let Some(connector_type) = pin_types.get(&first_pin_base)
&& let Some(connector_class) = class_dict.get(connector_type)
&& !connector_class.components.is_empty()
{
generate_connection_equations(&pins_vec, connector_class, &mut new_equations);
generated = true;
}
if !generated {
let mut output_pins: Vec<&String> = Vec::new();
let mut input_pins: Vec<&String> = Vec::new();
for pin in &pins_vec {
if let Some(comp) = fclass.components.get(*pin) {
if matches!(comp.causality, ir::ast::Causality::Input(..)) {
input_pins.push(pin);
} else {
output_pins.push(pin);
}
} else {
output_pins.push(pin);
}
}
let source_pin = output_pins.first().or(pins_vec.first());
if let Some(source) = source_pin {
for input_pin in &input_pins {
if *input_pin != *source {
new_equations.push(make_simple_eq(input_pin, source));
}
}
for output_pin in output_pins.iter().skip(1) {
if *output_pin != *source {
new_equations.push(make_simple_eq(output_pin, source));
}
}
}
}
}
let mut final_eqs: Vec<Equation> = Vec::new();
for eq in other_eqs {
if contains_connects(&eq) {
let expanded = expand_connects_in_equation(&eq, class_dict, pin_types);
final_eqs.extend(expanded);
} else {
final_eqs.push(eq);
}
}
fclass.equations = final_eqs;
fclass.equations.extend(new_equations);
Ok(())
}