use super::function_collector::collect_all_functions;
use super::result::CompilationResult;
use crate::dae::balance::BalanceResult;
use crate::ir::analysis::type_checker::{
check_algorithm_assignments_with_types, check_array_bounds, check_assert_arguments,
check_break_return_context, check_builtin_attribute_modifiers, check_cardinality_arguments,
check_cardinality_context, check_class_member_access, check_component_bindings,
check_scalar_subscripts, check_start_modification_dimensions,
};
use crate::ir::analysis::var_validator::VarValidator;
use crate::ir::ast::{ClassType, StoredDefinition};
use crate::ir::structural::create_dae::create_dae;
use crate::ir::transform::array_comprehension::expand_array_comprehensions;
use crate::ir::transform::constant_substitutor::ConstantSubstitutor;
use crate::ir::transform::enum_substitutor::EnumSubstitutor;
use crate::ir::transform::equation_expander::expand_equations;
use crate::ir::transform::flatten::{
ClassDict, FileDependencies, flatten, flatten_with_deps, flatten_with_library_dicts,
is_cache_enabled,
};
use crate::ir::transform::function_inliner::FunctionInliner;
use crate::ir::transform::import_resolver::ImportResolver;
use crate::ir::transform::operator_expand::{
build_operator_record_map, build_type_map, expand_complex_equations,
};
use crate::ir::transform::tuple_expander::expand_tuple_equations;
use crate::ir::visitor::MutVisitable;
use anyhow::Result;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, LazyLock, RwLock};
#[cfg(not(target_arch = "wasm32"))]
use std::time::Instant;
#[cfg(target_arch = "wasm32")]
use web_time::Instant;
/// One on-disk cache record, serialized and deserialized with `bincode`.
///
/// Only compiled on non-wasm targets when the `cache` feature is enabled.
#[cfg(all(feature = "cache", not(target_arch = "wasm32")))]
#[derive(serde::Serialize, serde::Deserialize)]
struct DaeCacheEntry {
/// The balance result stored for a cache key.
result: BalanceResult,
/// Dependencies saved with the result; `is_valid()` is checked when the entry is loaded.
dependencies: FileDependencies,
}
/// Process-wide in-memory cache: cache key -> (balance result, file dependencies).
static DAE_CACHE: LazyLock<RwLock<HashMap<String, (BalanceResult, FileDependencies)>>> =
LazyLock::new(|| RwLock::new(HashMap::new()));
/// Cache format version. It is hashed into every cache key and written to the
/// disk cache's `.version` marker, so changing it invalidates existing entries.
const DAE_CACHE_VERSION: u32 = 2;
/// Crate version captured at compile time; part of the cache key and version marker.
const RUMOCA_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Value of the `RUMOCA_GIT_VERSION` environment variable at compile time;
/// part of the cache key and version marker.
const GIT_VERSION: &str = env!("RUMOCA_GIT_VERSION");
/// Builds the cache key for `model_name` within `def`.
///
/// The key is a 16-digit, zero-padded lowercase hex rendering of a
/// `DefaultHasher` digest over, in order: the cache version, the crate
/// version, the git version, the model name, and every class in
/// `def.class_list` (via `hash_class_for_cache`).
fn compute_dae_cache_key(model_name: &str, def: &StoredDefinition) -> String {
    let mut state = std::collections::hash_map::DefaultHasher::new();

    // Version prefix: a different build or cache format yields a different key.
    DAE_CACHE_VERSION.hash(&mut state);
    RUMOCA_VERSION.hash(&mut state);
    GIT_VERSION.hash(&mut state);

    model_name.hash(&mut state);

    for (class_name, class_def) in def.class_list.iter() {
        hash_class_for_cache(class_name, class_def, &mut state);
    }

    let digest = state.finish();
    format!("{:016x}", digest)
}
/// Feeds the cache-relevant parts of one class into `hasher`.
///
/// Hashed, in order: the class name; the discriminant of its class type; for
/// each component its name, type name, variability and causality
/// discriminants, shape, and the `Debug` rendering of its start value; the
/// name of each extended class; the equation count followed by the `Debug`
/// rendering of each equation; and finally every nested class, recursively.
fn hash_class_for_cache(
    name: &str,
    class: &crate::ir::ast::ClassDefinition,
    hasher: &mut impl std::hash::Hasher,
) {
    name.hash(hasher);
    std::mem::discriminant(&class.class_type).hash(hasher);

    for (component_name, component) in class.components.iter() {
        component_name.hash(hasher);

        let type_name = component.type_name.to_string();
        type_name.hash(hasher);

        // Only the enum variant is hashed for variability and causality.
        std::mem::discriminant(&component.variability).hash(hasher);
        std::mem::discriminant(&component.causality).hash(hasher);

        component.shape.hash(hasher);

        let start_repr = format!("{:?}", component.start);
        start_repr.hash(hasher);
    }

    for base in &class.extends {
        let base_name = base.comp.to_string();
        base_name.hash(hasher);
    }

    let equation_count = class.equations.len();
    equation_count.hash(hasher);
    for equation in &class.equations {
        let equation_repr = format!("{:?}", equation);
        equation_repr.hash(hasher);
    }

    for (inner_name, inner_class) in &class.classes {
        hash_class_for_cache(inner_name, inner_class, hasher);
    }
}
/// On-disk persistence for balance results.
///
/// Only compiled on non-wasm targets when the `cache` feature is enabled.
/// Every operation here is best-effort: I/O and (de)serialization failures
/// are swallowed and simply behave like a cache miss.
#[cfg(all(feature = "cache", not(target_arch = "wasm32")))]
mod disk_cache {
    use super::*;

    /// Returns `<user cache dir>/rumoca/dae`, or `None` when
    /// `dirs::cache_dir()` returns `None`.
    pub fn dae_cache_dir() -> Option<std::path::PathBuf> {
        dirs::cache_dir().map(|d| d.join("rumoca").join("dae"))
    }

    /// Checks the `.version` marker inside `cache_dir`.
    ///
    /// Returns `true` when the marker exists and matches the current
    /// `DAE_CACHE_VERSION:RUMOCA_VERSION:GIT_VERSION` string. Otherwise the
    /// whole directory is removed, recreated with a fresh marker, and `false`
    /// is returned.
    fn check_and_update_version_marker(cache_dir: &std::path::Path) -> bool {
        let version_file = cache_dir.join(".version");
        let current_version = format!(
            "{}:{}:{}",
            super::DAE_CACHE_VERSION,
            super::RUMOCA_VERSION,
            super::GIT_VERSION
        );

        // A missing or unreadable marker fails the read and is treated as a
        // mismatch, so no separate existence check is needed.
        let marker_is_current = std::fs::read_to_string(&version_file)
            .is_ok_and(|stored_version| stored_version.trim() == current_version);
        if marker_is_current {
            return true;
        }

        // Stale or absent cache: wipe it (an error for a non-existent
        // directory is ignored) and start over with a new marker.
        let _ = std::fs::remove_dir_all(cache_dir);
        if std::fs::create_dir_all(cache_dir).is_ok() {
            let _ = std::fs::write(&version_file, &current_version);
        }
        false
    }

    /// Loads the cached balance result for `cache_key`.
    ///
    /// Returns `None` when there is no cache directory, the version marker
    /// did not match (the cache is reset in that case), the entry file cannot
    /// be read or deserialized, or the entry's dependencies are no longer
    /// valid (the entry file is then deleted).
    pub fn load_dae_from_disk(cache_key: &str) -> Option<BalanceResult> {
        let cache_dir = dae_cache_dir()?;
        if !check_and_update_version_marker(&cache_dir) {
            return None;
        }

        let cache_file = cache_dir.join(format!("{}.bin", cache_key));
        // Reading a missing file fails, which already yields `None`.
        let data = std::fs::read(&cache_file).ok()?;
        let entry: DaeCacheEntry = bincode::deserialize(&data).ok()?;

        if !entry.dependencies.is_valid() {
            let _ = std::fs::remove_file(&cache_file);
            return None;
        }
        Some(entry.result)
    }

    /// Writes `result` and `dependencies` to `<cache dir>/<cache_key>.bin`.
    ///
    /// Silently does nothing when the cache directory is unavailable or
    /// cannot be created, or when serialization or the write fails.
    pub fn save_dae_to_disk(
        cache_key: &str,
        result: &BalanceResult,
        dependencies: &FileDependencies,
    ) {
        let Some(cache_dir) = dae_cache_dir() else {
            return;
        };

        // Resets the directory first if it belongs to a different version.
        check_and_update_version_marker(&cache_dir);
        if std::fs::create_dir_all(&cache_dir).is_err() {
            return;
        }

        let cache_file = cache_dir.join(format!("{}.bin", cache_key));
        let entry = DaeCacheEntry {
            result: result.clone(),
            dependencies: dependencies.clone(),
        };
        if let Ok(data) = bincode::serialize(&entry) {
            let _ = std::fs::write(&cache_file, data);
        }
    }

    /// Removes the whole disk cache directory, ignoring any error.
    pub fn clear_disk_cache() {
        if let Some(cache_dir) = dae_cache_dir() {
            let _ = std::fs::remove_dir_all(&cache_dir);
        }
    }
}
/// Empties the in-memory DAE cache and, on non-wasm builds with the `cache`
/// feature, also removes the on-disk cache.
///
/// # Panics
/// Panics if the in-memory cache lock is poisoned.
pub fn clear_dae_cache() {
    {
        // Scoped so the write guard is released before touching the disk.
        let mut in_memory = DAE_CACHE.write().expect("DAE cache lock poisoned");
        in_memory.clear();
    }

    #[cfg(all(feature = "cache", not(target_arch = "wasm32")))]
    disk_cache::clear_disk_cache();
}
/// Compiles an owned, already-parsed definition.
///
/// Convenience wrapper that borrows `def` and forwards every argument
/// unchanged to `compile_from_ast_ref`; see that function for the pipeline.
pub fn compile_from_ast(
    def: StoredDefinition,
    model_name: Option<&str>,
    model_hash: String,
    parse_time: std::time::Duration,
    verbose: bool,
) -> Result<CompilationResult> {
    let borrowed_def = &def;
    compile_from_ast_ref(borrowed_def, model_name, model_hash, parse_time, verbose)
}
pub fn compile_from_ast_ref(
def: &StoredDefinition,
model_name: Option<&str>,
model_hash: String,
parse_time: std::time::Duration,
verbose: bool,
) -> Result<CompilationResult> {
let flatten_start = Instant::now();
let fclass_result = flatten(def, model_name);
let mut fclass = match fclass_result {
Ok(fc) => fc,
Err(e) => {
return Err(e);
}
};
let flatten_time = flatten_start.elapsed();
if verbose {
println!("Flattening took {} ms", flatten_time.as_millis());
println!("Flattened class:\n{:#?}\n", fclass);
}
let mut import_resolver = ImportResolver::new(&fclass, def);
fclass.accept_mut(&mut import_resolver);
let expanded_class = fclass.clone();
let mut const_substitutor = ConstantSubstitutor::new();
fclass.accept_mut(&mut const_substitutor);
let mut enum_substitutor = EnumSubstitutor::new();
fclass.accept_mut(&mut enum_substitutor);
let function_names = collect_all_functions(def);
let has_nested_functions = fclass
.classes
.values()
.any(|c| c.class_type == ClassType::Function);
let should_validate = !matches!(fclass.class_type, ClassType::Package) && !has_nested_functions;
if should_validate {
let peer_class_names: Vec<String> = def.class_list.keys().cloned().collect();
let mut validator = VarValidator::with_context(&fclass, &function_names, &peer_class_names);
fclass.accept_mut(&mut validator);
if !validator.undefined_vars.is_empty() {
let (var_name, context) = &validator.undefined_vars[0];
return Err(anyhow::anyhow!(
"Undefined variable '{}' in {}",
var_name,
context
));
}
let type_check_result = check_component_bindings(&expanded_class);
if type_check_result.has_errors() {
let first_error = &type_check_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
let peer_class_types: std::collections::HashMap<String, ClassType> = def
.class_list
.iter()
.map(|(name, class)| (name.clone(), class.class_type.clone()))
.collect();
let original_comp_types: std::collections::HashMap<String, String> = if let Some(name) =
model_name
{
def.class_list
.get(name)
.map(|class| {
class
.components
.iter()
.map(|(comp_name, comp)| (comp_name.clone(), comp.type_name.to_string()))
.collect()
})
.unwrap_or_default()
} else {
std::collections::HashMap::new()
};
let assign_check_result = check_algorithm_assignments_with_types(
&expanded_class,
&peer_class_types,
&original_comp_types,
);
if assign_check_result.has_errors() {
let first_error = &assign_check_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
for (_, file_class) in &def.class_list {
if file_class.class_type == ClassType::Function {
let file_assign_result = check_algorithm_assignments_with_types(
file_class,
&peer_class_types,
&std::collections::HashMap::new(),
);
if file_assign_result.has_errors() {
let first_error = &file_assign_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
}
}
let assert_check_result = check_assert_arguments(&expanded_class);
if assert_check_result.has_errors() {
let first_error = &assert_check_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
let modifier_check_result = check_builtin_attribute_modifiers(&expanded_class);
if modifier_check_result.has_errors() {
let first_error = &modifier_check_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
let start_dim_result = check_start_modification_dimensions(&expanded_class);
if start_dim_result.has_errors() {
let first_error = &start_dim_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
let break_check_result = check_break_return_context(&expanded_class);
if break_check_result.has_errors() {
let first_error = &break_check_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
for (_, class) in &def.class_list {
let file_break_result = check_break_return_context(class);
if file_break_result.has_errors() {
let first_error = &file_break_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
}
let scalar_subscript_result = check_scalar_subscripts(&expanded_class);
if scalar_subscript_result.has_errors() {
let first_error = &scalar_subscript_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
let cardinality_result = check_cardinality_context(&expanded_class);
if cardinality_result.has_errors() {
let first_error = &cardinality_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
let cardinality_args_result = check_cardinality_arguments(&expanded_class);
if cardinality_args_result.has_errors() {
let first_error = &cardinality_args_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
let class_member_result = check_class_member_access(&expanded_class);
if class_member_result.has_errors() {
let first_error = &class_member_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
}
let mut inliner = FunctionInliner::from_class_list(&def.class_list);
fclass.accept_mut(&mut inliner);
drop(inliner);
expand_tuple_equations(&mut fclass);
expand_array_comprehensions(&mut fclass);
expand_equations(&mut fclass);
if should_validate {
let bounds_check_result = check_array_bounds(&fclass);
if bounds_check_result.has_errors() {
let first_error = &bounds_check_result.errors[0];
return Err(anyhow::anyhow!("{}", first_error.message));
}
}
let operator_records = build_operator_record_map(&def.class_list);
let type_map = build_type_map(&fclass, &def.class_list);
if verbose {
println!("=== Complex expansion debug ===");
println!(
"Operator records: {:?}",
operator_records.keys().collect::<Vec<_>>()
);
println!(
"Type map entries: {:?}",
type_map.keys().collect::<Vec<_>>()
);
}
let eq_count_before = fclass.equations.len();
fclass.equations = expand_complex_equations(&fclass.equations, &type_map, &operator_records);
if verbose {
println!(
"Equations: {} -> {}",
eq_count_before,
fclass.equations.len()
);
println!(
"After function inlining, tuple expansion, array comprehension, and equation expansion:\n{:#?}\n",
fclass
);
}
let dae_start = Instant::now();
let mut dae = create_dae(&mut fclass)?;
dae.model_hash = model_hash.clone();
let dae_time = dae_start.elapsed();
if verbose {
println!("DAE creation took {} ms", dae_time.as_millis());
println!("DAE:\n{:#?}\n", dae);
}
let balance = dae.check_balance();
if verbose {
println!("{}", balance.status_message());
}
Ok(CompilationResult {
dae,
def: def.clone(), expanded_class,
parse_time,
flatten_time,
dae_time,
model_hash,
balance,
})
}
pub fn check_balance_only(
def: &StoredDefinition,
model_name: Option<&str>,
) -> Result<crate::dae::balance::BalanceResult> {
let model_name_str = model_name.unwrap_or("");
let cache_key = compute_dae_cache_key(model_name_str, def);
if is_cache_enabled() {
let cache = DAE_CACHE.read().expect("DAE cache lock poisoned");
if let Some((result, _deps)) = cache.get(&cache_key) {
return Ok(result.clone());
}
}
#[cfg(all(feature = "cache", not(target_arch = "wasm32")))]
if is_cache_enabled()
&& let Some(result) = disk_cache::load_dae_from_disk(&cache_key)
{
let deps = FileDependencies::default(); DAE_CACHE
.write()
.expect("DAE cache lock poisoned")
.insert(cache_key, (result.clone(), deps));
return Ok(result);
}
let flatten_result = flatten_with_deps(def, model_name);
let mut fclass = match flatten_result {
Ok(fr) => fr,
Err(e) => {
return Err(anyhow::anyhow!("Flatten error: {}", e));
}
};
let mut const_substitutor = ConstantSubstitutor::new();
fclass.class.accept_mut(&mut const_substitutor);
let mut enum_substitutor = EnumSubstitutor::new();
fclass.class.accept_mut(&mut enum_substitutor);
let mut inliner = FunctionInliner::from_class_list(&def.class_list);
fclass.class.accept_mut(&mut inliner);
drop(inliner);
expand_tuple_equations(&mut fclass.class);
expand_array_comprehensions(&mut fclass.class);
expand_equations(&mut fclass.class);
let operator_records = build_operator_record_map(&def.class_list);
let type_map = build_type_map(&fclass.class, &def.class_list);
fclass.class.equations =
expand_complex_equations(&fclass.class.equations, &type_map, &operator_records);
let dae = create_dae(&mut fclass.class)?;
let result = dae.check_balance();
if is_cache_enabled() {
#[cfg(all(feature = "cache", not(target_arch = "wasm32")))]
disk_cache::save_dae_to_disk(&cache_key, &result, &fclass.dependencies);
DAE_CACHE
.write()
.expect("DAE cache lock poisoned")
.insert(cache_key, (result.clone(), fclass.dependencies));
}
Ok(result)
}
pub fn check_balance_with_library_dicts(
user_def: &StoredDefinition,
library_dicts: &[Arc<ClassDict>],
model_name: Option<&str>,
) -> Result<crate::dae::balance::BalanceResult> {
let flatten_result = flatten_with_library_dicts(user_def, library_dicts, model_name);
let mut fclass = match flatten_result {
Ok(fr) => fr,
Err(e) => {
return Err(anyhow::anyhow!("Flatten error: {}", e));
}
};
let mut const_substitutor = ConstantSubstitutor::new();
fclass.class.accept_mut(&mut const_substitutor);
let mut enum_substitutor = EnumSubstitutor::new();
fclass.class.accept_mut(&mut enum_substitutor);
let mut inliner = FunctionInliner::from_class_list(&user_def.class_list);
fclass.class.accept_mut(&mut inliner);
drop(inliner);
expand_tuple_equations(&mut fclass.class);
expand_array_comprehensions(&mut fclass.class);
expand_equations(&mut fclass.class);
let operator_records = build_operator_record_map(&user_def.class_list);
let type_map = build_type_map(&fclass.class, &user_def.class_list);
fclass.class.equations =
expand_complex_equations(&fclass.class.equations, &type_map, &operator_records);
let dae = create_dae(&mut fclass.class)?;
Ok(dae.check_balance())
}
/// Compiles `model_name` from `user_def`, flattening against the supplied
/// library class dictionaries.
///
/// Stages, in order: flatten; resolve imports (a snapshot taken at this point
/// becomes `expanded_class`); substitute constants and enums; validate
/// variable references and component bindings (skipped for packages and for
/// classes containing nested functions); inline functions; expand tuples,
/// array comprehensions, equations and complex equations; create the DAE;
/// check its balance.
///
/// The returned `parse_time` is always zero and `model_hash` is the MD5 of the
/// `Debug` rendering of `user_def`.
///
/// # Errors
/// Returns the flatten or DAE-creation error unchanged, an
/// `"Undefined variable …"` error for the first undefined variable, or the
/// message of the first component-binding error.
pub fn compile_with_library_dicts(
    user_def: &StoredDefinition,
    library_dicts: &[Arc<ClassDict>],
    model_name: &str,
) -> Result<CompilationResult> {
    // --- Flatten ---------------------------------------------------------
    let flatten_timer = Instant::now();
    let mut flattened = flatten_with_library_dicts(user_def, library_dicts, Some(model_name))?;
    let flatten_time = flatten_timer.elapsed();

    // --- Import resolution and substitutions -----------------------------
    let mut import_resolver = ImportResolver::new(&flattened.class, user_def);
    flattened.class.accept_mut(&mut import_resolver);

    // Snapshot taken before constant/enum substitution; used by the binding
    // check and returned to the caller.
    let expanded_class = flattened.class.clone();

    flattened.class.accept_mut(&mut ConstantSubstitutor::new());
    flattened.class.accept_mut(&mut EnumSubstitutor::new());

    // --- Validation ------------------------------------------------------
    let function_names = collect_all_functions(user_def);
    let contains_nested_functions = flattened
        .class
        .classes
        .values()
        .any(|nested| nested.class_type == ClassType::Function);
    let is_package = matches!(flattened.class.class_type, ClassType::Package);
    let should_validate = !(is_package || contains_nested_functions);

    if should_validate {
        let peer_class_names: Vec<String> = user_def.class_list.keys().cloned().collect();
        let mut validator =
            VarValidator::with_context(&flattened.class, &function_names, &peer_class_names);
        flattened.class.accept_mut(&mut validator);
        if let Some((var_name, context)) = validator.undefined_vars.first() {
            return Err(anyhow::anyhow!(
                "Undefined variable '{}' in {}",
                var_name,
                context
            ));
        }

        let binding_check = check_component_bindings(&expanded_class);
        if binding_check.has_errors() {
            return Err(anyhow::anyhow!("{}", binding_check.errors[0].message));
        }
    }

    // --- Inlining and expansion -------------------------------------------
    {
        let mut inliner = FunctionInliner::from_class_list(&user_def.class_list);
        flattened.class.accept_mut(&mut inliner);
    }
    expand_tuple_equations(&mut flattened.class);
    expand_array_comprehensions(&mut flattened.class);
    expand_equations(&mut flattened.class);

    let operator_records = build_operator_record_map(&user_def.class_list);
    let type_map = build_type_map(&flattened.class, &user_def.class_list);
    let expanded_equations =
        expand_complex_equations(&flattened.class.equations, &type_map, &operator_records);
    flattened.class.equations = expanded_equations;

    // --- DAE creation and balance ------------------------------------------
    let dae_timer = Instant::now();
    let dae = create_dae(&mut flattened.class)?;
    let dae_time = dae_timer.elapsed();
    let balance = dae.check_balance();

    let definition_repr = format!("{:?}", user_def);
    let model_hash = format!("{:x}", chksum_md5::hash(definition_repr));

    Ok(CompilationResult {
        dae,
        def: user_def.clone(),
        expanded_class,
        parse_time: std::time::Duration::ZERO,
        flatten_time,
        dae_time,
        model_hash,
        balance,
    })
}