use crate::ir::ast::{Expression, TerminalType, Token};
use crate::ir::transform::constants::get_modelica_constant;
use crate::ir::visitor::MutVisitor;
/// Unqualified constant names that are substituted even without a `Modelica.Constants.`
/// prefix.
///
/// Names likely to collide with user-declared variables (such as `h`) are left out; those
/// constants are only substituted in their fully qualified form.
///
/// NOTE(review): `sigma`, `gamma`, `eps` and `small` are also plausible user parameter
/// names; without scope information a user variable of the same name would be replaced.
/// Confirm this is acceptable for the intended inputs.
const SAFE_SHORT_NAMES: &[&str] = &[
    "pi", "mu_0", "epsilon_0", "sigma", "g_n", "N_A", "D2R", "R2D", "gamma", "eps", "small",
    "inf", "T_zero",
];

/// Returns `true` when `name` exactly (case-sensitively) matches an entry of
/// `SAFE_SHORT_NAMES`.
fn is_safe_short_name(name: &str) -> bool {
    SAFE_SHORT_NAMES.iter().any(|&allowed| allowed == name)
}
/// AST pass that replaces references to Modelica constants with real literals.
///
/// Drive it through the mutable visitor (`expr.accept_mut(&mut substitutor)`); afterwards
/// `substitution_count` reports how many references were replaced.
#[derive(Debug, Clone, Default)]
pub struct ConstantSubstitutor {
    /// Number of component references replaced by a literal so far.
    pub substitution_count: usize,
}

impl ConstantSubstitutor {
    /// Creates a substitutor whose substitution count starts at zero.
    pub fn new() -> Self {
        Self::default()
    }
}
impl MutVisitor for ConstantSubstitutor {
    /// Replaces a reference to a known Modelica constant with a real literal.
    ///
    /// A component reference is replaced when its dotted name either starts with
    /// `Modelica.Constants.` or is one of the allow-listed short names, *and*
    /// `get_modelica_constant` returns a value for that name. The replacement is an
    /// `UnsignedReal` terminal whose text comes from `format_float`, and each replacement
    /// increments `substitution_count`. Every other expression is left untouched.
    fn exit_expression(&mut self, expr: &mut Expression) {
        // Take an owned copy of the name so the borrow of `expr` ends before we overwrite it.
        let name = match expr {
            Expression::ComponentReference(comp_ref) => comp_ref.to_string(),
            _ => return,
        };

        // Qualified names are unambiguous; bare names are trusted only when allow-listed,
        // because a model may declare its own variable with the same name.
        let should_substitute =
            name.starts_with("Modelica.Constants.") || is_safe_short_name(&name);
        if !should_substitute {
            return;
        }
        let Some(value) = get_modelica_constant(&name) else {
            return;
        };

        // NOTE(review): a negative constant (e.g. `T_zero`, if the table stores it as
        // -273.15) yields an `UnsignedReal` token whose text starts with '-'. Confirm that
        // downstream printers parenthesize it where precedence matters (e.g. `T_zero^2`).
        *expr = Expression::Terminal {
            terminal_type: TerminalType::UnsignedReal,
            token: Token {
                text: format_float(value),
                ..Default::default()
            },
        };
        self.substitution_count += 1;
    }
}
/// Formats `value` as the text of a real-number literal.
///
/// Finite values always contain a `.` or an exponent, so they read as reals rather than
/// integers:
/// * magnitudes `>= 1e10` and non-zero magnitudes `< 1e-4` use exponent notation
///   (e.g. `1e10`, `6.62607015e-34`, `1.7976931348623157e308`);
/// * everything else uses plain decimal notation, with `.0` appended to integral values
///   (e.g. `2.0`, `-273.15`).
///
/// Both forms use Rust's shortest round-trip digits, so parsing the result back yields
/// exactly `value`. Non-finite values (`NaN`, `inf`, `-inf`) have no literal form and are
/// returned as their `Display` text unchanged.
fn format_float(value: f64) -> String {
    // Without this guard `NaN` would reach the decimal branch and become "NaN.0".
    if !value.is_finite() {
        return value.to_string();
    }

    let magnitude = value.abs();
    if magnitude >= 1e10 || (value != 0.0 && magnitude < 1e-4) {
        // `{:e}` prints the shortest round-trip mantissa, which already yields the exact
        // strings for extremes such as `f64::MAX` and `f64::MIN_POSITIVE`.
        return format!("{value:e}");
    }

    let text = value.to_string();
    if text.contains('.') || text.contains('e') {
        text
    } else {
        // Integral values print without a fractional part; force a real literal.
        text + ".0"
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::ast::{ComponentRefPart, ComponentReference};
    use crate::ir::visitor::MutVisitable;

    /// Builds a non-local component reference from a dotted name such as `a.b.c`.
    fn make_comp_ref(name: &str) -> Expression {
        let parts: Vec<ComponentRefPart> = name
            .split('.')
            .map(|p| ComponentRefPart {
                ident: Token {
                    text: p.to_string(),
                    ..Default::default()
                },
                subs: None,
            })
            .collect();
        Expression::ComponentReference(ComponentReference {
            local: false,
            parts,
        })
    }

    /// Runs a fresh `ConstantSubstitutor` over a reference to `name` and returns the
    /// resulting expression together with the substitution count.
    fn substitute(name: &str) -> (Expression, usize) {
        let mut expr = make_comp_ref(name);
        let mut sub = ConstantSubstitutor::new();
        expr.accept_mut(&mut sub);
        (expr, sub.substitution_count)
    }

    /// Parses the literal text of a substituted terminal; panics if `expr` is not a terminal.
    fn terminal_value(expr: &Expression) -> f64 {
        match expr {
            Expression::Terminal { token, .. } => token
                .text
                .parse()
                .expect("substituted literal should parse as f64"),
            _ => panic!("Expected Terminal expression"),
        }
    }

    #[test]
    fn test_substitute_pi() {
        let (expr, count) = substitute("pi");
        assert!(matches!(
            expr,
            Expression::Terminal {
                terminal_type: TerminalType::UnsignedReal,
                ..
            }
        ));
        assert!((terminal_value(&expr) - std::f64::consts::PI).abs() < 1e-10);
        assert_eq!(count, 1);
    }

    #[test]
    fn test_substitute_qualified_pi() {
        let (expr, count) = substitute("Modelica.Constants.pi");
        assert!((terminal_value(&expr) - std::f64::consts::PI).abs() < 1e-10);
        assert_eq!(count, 1);
    }

    #[test]
    fn test_substitute_mu_0() {
        let (expr, count) = substitute("mu_0");
        assert!((terminal_value(&expr) - 1.25663706212e-6).abs() < 1e-16);
        assert_eq!(count, 1);
    }

    #[test]
    fn test_no_substitute_unknown() {
        let (expr, count) = substitute("unknown_var");
        assert!(matches!(expr, Expression::ComponentReference(_)));
        assert_eq!(count, 0);
    }

    #[test]
    fn test_no_substitute_unknown_qualified() {
        let (expr, count) = substitute("Modelica.Constants.not_a_constant");
        assert!(matches!(expr, Expression::ComponentReference(_)));
        assert_eq!(count, 0);
    }

    #[test]
    fn test_no_substitute_ambiguous_short_name() {
        let (expr, count) = substitute("h");
        assert!(matches!(expr, Expression::ComponentReference(_)));
        assert_eq!(count, 0);
    }

    #[test]
    fn test_substitute_qualified_h() {
        let (expr, count) = substitute("Modelica.Constants.h");
        assert!((terminal_value(&expr) - 6.62607015e-34).abs() < 1e-44);
        assert_eq!(count, 1);
    }

    #[test]
    fn test_format_float_always_reads_as_real() {
        assert_eq!(format_float(2.0), "2.0");
        assert_eq!(format_float(0.0), "0.0");
        assert_eq!(format_float(-273.15), "-273.15");
        assert_eq!(format_float(1e10), "1e10");
        assert_eq!(format_float(1.25663706212e-6), "1.25663706212e-6");
    }

    #[test]
    fn test_format_float_round_trips() {
        for v in [
            std::f64::consts::PI,
            -273.15,
            6.62607015e-34,
            1e9,
            f64::MAX,
            f64::MIN_POSITIVE,
        ] {
            let parsed: f64 = format_float(v)
                .parse()
                .expect("formatted literal should parse as f64");
            assert_eq!(parsed, v);
        }
    }
}