use cairo_lang_defs::ids::ExternFunctionId;
use cairo_lang_semantic::helper::ModuleHelper;
use cairo_lang_utils::unordered_hash_set::UnorderedHashSet;
use itertools::chain;
use salsa::Database;
use crate::db::LoweringGroup;
use crate::utils::InliningStrategy;
/// Optimization settings consulted by lowering (read via `db.optimizations()`).
///
/// When [`Optimizations::Disabled`], the accessors on this type report no movable functions,
/// [`InliningStrategy::Avoid`], and skipped const folding.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Optimizations {
    /// Optimizations are disabled.
    Disabled,
    /// Optimizations are enabled, configured by the wrapped [`OptimizationConfig`].
    Enabled(OptimizationConfig),
}
/// Settings carried by [`Optimizations::Enabled`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct OptimizationConfig {
    /// Names of extern functions considered movable, written as `::`-separated paths relative
    /// to the core module (e.g. `"array::array_new"`); resolved by `priv_movable_function_ids`.
    pub(crate) moveable_functions: Vec<String>,
    /// Inlining strategy reported by [`Optimizations::inlining_strategy`].
    pub(crate) inlining_strategy: InliningStrategy,
    /// Whether constant folding should be skipped.
    pub(crate) skip_const_folding: bool,
}
impl OptimizationConfig {
    /// Returns this configuration with `skip_const_folding` set to the given value.
    ///
    /// Builder-style: consumes `self` and returns the updated configuration. Discarding the
    /// result discards the change, hence `#[must_use]`.
    #[must_use]
    pub fn with_skip_const_folding(mut self, skip_const_folding: bool) -> Self {
        self.skip_const_folding = skip_const_folding;
        self
    }
}
impl Optimizations {
pub fn enabled_with_default_movable_functions(inlining_strategy: InliningStrategy) -> Self {
Self::Enabled(OptimizationConfig {
moveable_functions: default_moveable_functions(),
inlining_strategy,
skip_const_folding: false,
})
}
pub fn enabled_with_minimal_movable_functions() -> Self {
Self::Enabled(OptimizationConfig {
moveable_functions: vec!["felt252_sub".to_string()],
inlining_strategy: Default::default(),
skip_const_folding: false,
})
}
pub fn moveable_functions(&self) -> &[String] {
if let Self::Enabled(config) = self { &config.moveable_functions } else { &[] }
}
pub fn inlining_strategy(&self) -> InliningStrategy {
if let Self::Enabled(config) = self {
config.inlining_strategy
} else {
InliningStrategy::Avoid
}
}
pub fn skip_const_folding(&self) -> bool {
if let Self::Enabled(config) = self { config.skip_const_folding } else { true }
}
}
/// Resolves the configured movable function names into extern function ids.
///
/// Each name from `db.optimizations().moveable_functions()` is a `::`-separated path relative
/// to the core module. Every segment but the last selects a submodule, and the last segment
/// names the extern function in that module.
///
/// # Panics
///
/// Panics with "Got empty string as movable_function" if a name has an empty final segment
/// (the empty string, or a path ending in `::`).
/// NOTE(review): `ModuleHelper::submodule` / `extern_function_id` are defined elsewhere and may
/// panic on lookup failure — confirm.
#[salsa::tracked(returns(ref))]
pub fn priv_movable_function_ids<'db>(
    db: &'db dyn Database,
) -> UnorderedHashSet<ExternFunctionId<'db>> {
    db.optimizations()
        .moveable_functions()
        .iter()
        .map(|name: &String| {
            // Everything before the last `::` is the module path; the rest is the function
            // name. `split` always yields at least one (possibly empty) segment, so an empty
            // function name has to be checked for explicitly.
            let (module_path, function_name) = match name.rsplit_once("::") {
                Some((module_path, function_name)) => (Some(module_path), function_name),
                None => (None, name.as_str()),
            };
            assert!(!function_name.is_empty(), "Got empty string as movable_function");
            let mut module = ModuleHelper::core(db);
            for segment in module_path.into_iter().flat_map(|path| path.split("::")) {
                module = module.submodule(segment);
            }
            module.extern_function_id(function_name)
        })
        .collect()
}
/// Builds the default list of movable extern function names.
///
/// Names are `::`-separated paths relative to the core module. The list contains
/// `bool_not_impl`, the felt252 add/sub/mul/div functions, `array_new`/`array_append`,
/// `box_forward_snapshot`, and then `integer::{ty}_wide_mul` for each of
/// i8/i16/i32/i64/u8/u16/u32/u64, in that order.
fn default_moveable_functions() -> Vec<String> {
    const WIDE_MUL_TYPES: [&str; 8] = ["i8", "i16", "i32", "i64", "u8", "u16", "u32", "u64"];
    let fixed_names = chain!(
        ["bool_not_impl"],
        ["felt252_add", "felt252_sub", "felt252_mul", "felt252_div"],
        ["array::array_new", "array::array_append"],
        ["box::box_forward_snapshot"],
    )
    .map(String::from);
    let wide_mul_names = WIDE_MUL_TYPES.iter().map(|ty| format!("integer::{ty}_wide_mul"));
    fixed_names.chain(wide_mul_names).collect()
}