use crate::*;
use crate::{core::ResolvedCall, typechecking::FuncType};
use egglog_ast::generic_ast::{GenericAction, GenericExpr, GenericFact, GenericRule};
/// State carried by the global-removal pass.
///
/// Holds the symbol generator used to mint fresh, non-global variable names
/// when globals referenced in a rule's head are rebound in its body.
struct GlobalRemover<'a> {
    /// Source of fresh names for the variables introduced when rewriting rules.
    fresh: &'a mut SymbolGen,
}
/// Rewrite a resolved program so that no command refers to a global variable directly.
///
/// Each input command is lowered by `GlobalRemover::remove_globals_cmd`.
/// A single command may expand into several, so the output can be longer than `prog`.
/// The relative order of commands is preserved.
pub(crate) fn remove_globals(
    prog: Vec<ResolvedNCommand>,
    fresh: &mut SymbolGen,
) -> Vec<ResolvedNCommand> {
    let mut pass = GlobalRemover { fresh };
    let mut lowered = Vec::new();
    for command in prog {
        lowered.extend(pass.remove_globals_cmd(command));
    }
    lowered
}
/// Build the zero-input custom-function call that stands in for the global `var`.
///
/// The function shares the global's name, and its output sort is the global's sort.
///
/// # Panics
/// Panics if `var` is not a global reference (`is_global_ref == false`).
fn resolved_var_to_call(var: &ResolvedVar) -> ResolvedCall {
    if !var.is_global_ref {
        panic!("resolved_var_to_call called on non-global var");
    }
    let output = var.sort.clone();
    ResolvedCall::Func(FuncType {
        name: var.name.clone(),
        subtype: FunctionSubtype::Custom,
        input: Vec::new(),
        output,
    })
}
/// Node-level rewrite meant to be passed to `visit_exprs`.
///
/// A reference to a global variable is turned into an argument-less call to
/// that global's backing function, keeping the original span.
/// Every other expression is returned untouched.
fn replace_global_vars(expr: ResolvedExpr) -> ResolvedExpr {
    if let Some(global) = expr.get_global_var() {
        let span = expr.span();
        return GenericExpr::Call(span, resolved_var_to_call(&global), vec![]);
    }
    expr
}
/// Replace every global-variable reference anywhere inside `expr` with a call
/// to the function backing that global.
fn remove_globals_expr(expr: ResolvedExpr) -> ResolvedExpr {
    let mut rewrite = replace_global_vars;
    expr.visit_exprs(&mut rewrite)
}
/// Replace every global-variable reference in the expressions of `action` with
/// a call to the function backing that global.
fn remove_globals_action(action: ResolvedAction) -> ResolvedAction {
    let mut rewrite = replace_global_vars;
    action.visit_exprs(&mut rewrite)
}
impl GlobalRemover<'_> {
fn remove_globals_cmd(&mut self, cmd: ResolvedNCommand) -> Vec<ResolvedNCommand> {
match cmd {
GenericNCommand::CoreAction(action) => match action {
GenericAction::Let(span, name, expr) => {
let ty = expr.output_type();
let resolved_call = ResolvedCall::Func(FuncType {
name: name.name.clone(),
subtype: FunctionSubtype::Custom,
input: vec![],
output: ty.clone(),
});
let func_decl = ResolvedFunctionDecl {
name: name.name,
subtype: FunctionSubtype::Custom,
schema: Schema {
input: vec![],
output: ty.name().to_owned(),
},
resolved_schema: resolved_call.clone(),
merge: None,
cost: None,
unextractable: true,
let_binding: true,
span: span.clone(),
};
vec![
GenericNCommand::Function(func_decl),
GenericNCommand::CoreAction(GenericAction::Set(
span,
resolved_call,
vec![],
remove_globals_expr(expr),
)),
]
}
_ => vec![GenericNCommand::CoreAction(remove_globals_action(action))],
},
GenericNCommand::NormRule { rule } => {
let mut globals = HashMap::default();
rule.head.clone().visit_exprs(&mut |expr| {
if let Some(resolved_var) = expr.get_global_var() {
let new_name = self.fresh.fresh(&resolved_var.name);
globals.insert(
resolved_var.clone(),
GenericExpr::Var(
expr.span(),
ResolvedVar {
name: new_name,
sort: resolved_var.sort.clone(),
is_global_ref: false,
},
),
);
}
expr
});
let new_facts: Vec<ResolvedFact> = globals
.iter()
.map(|(old, new)| {
GenericFact::Eq(
new.span(),
GenericExpr::Call(new.span(), resolved_var_to_call(old), vec![]),
new.clone(),
)
})
.collect();
let new_rule = GenericRule {
span: rule.span,
body: rule
.body
.iter()
.map(|fact| fact.clone().visit_exprs(&mut replace_global_vars))
.chain(new_facts)
.collect(),
head: rule.head.clone().visit_exprs(&mut |expr| {
if let Some(resolved_var) = expr.get_global_var() {
globals.get(&resolved_var).unwrap().clone()
} else {
expr
}
}),
name: rule.name.clone(),
ruleset: rule.ruleset.clone(),
};
vec![GenericNCommand::NormRule { rule: new_rule }]
}
GenericNCommand::Fail(span, cmd) => {
let mut removed = self.remove_globals_cmd(*cmd);
let last = removed.pop().unwrap();
let boxed_last = Box::new(last);
let new_command = GenericNCommand::Fail(span, boxed_last);
removed.push(new_command);
removed
}
_ => vec![cmd.visit_exprs(&mut replace_global_vars)],
}
}
}