pub mod inline {
use crate::codegen::{Codegen, CodegenStrategy};
use haloumi_ir_gen::{circuit::resolved::ResolvedIRCircuit, ctx::IRCtx};
/// Code generation strategy that emits every IR group into the body of a single main function.
#[derive(Default)]
pub struct InlineConstraintsStrat {}
impl CodegenStrategy for InlineConstraintsStrat {
    /// Emits one main function whose IO is the advice/instance IO of the circuit's main
    /// group and whose body is built from all groups of `ir`.
    ///
    /// # Errors
    /// Returns whatever error `define_main_function_with_body` reports.
    fn codegen<'c: 'st, 's, 'st, C>(
        &self,
        codegen: &C,
        ctx: &IRCtx,
        ir: &ResolvedIRCircuit,
    ) -> Result<(), C::Error>
    where
        C: Codegen<'c, 'st>,
    {
        log::debug!(
            "Performing codegen with {} strategy",
            std::any::type_name_of_val(self)
        );
        log::debug!("Generating main body");
        let main_group = ir.main().id();
        // IO of the main group becomes the IO of the emitted main function.
        let advice_io = ctx.advice_io_of_group(main_group);
        let instance_io = ctx.instance_io_of_group(main_group);
        let every_group = ir.groups().to_vec();
        codegen.define_main_function_with_body(advice_io, instance_io, every_group)
    }
}
}
pub mod groups {
use crate::codegen::{Codegen, CodegenStrategy};
use eqv::EqvRelation;
use haloumi_ir::{SymbolicEqv, expr::IRAexpr, groups::IRGroup};
use haloumi_ir_gen::{circuit::resolved::ResolvedIRCircuit, ctx::IRCtx};
use haloumi_synthesis::groups::GroupKey;
use std::collections::{HashMap, HashSet};
#[derive(Default)]
pub struct GroupConstraintsStrat {}
impl CodegenStrategy for GroupConstraintsStrat {
fn codegen<'c: 'st, 's, 'st, C>(
&self,
codegen: &C,
ctx: &IRCtx,
ir: &ResolvedIRCircuit,
) -> Result<(), C::Error>
where
C: Codegen<'c, 'st>,
{
ir.validate()?;
let mut groups_ir = ir.groups().to_vec();
let (leaders, updated_calldata) = select_leaders(&groups_ir);
log::debug!("Leaders for the non-main groups: {leaders:?}");
log::debug!("Updated calldata: {updated_calldata:?}");
groups_ir.retain_mut(|g| {
let keep = g.is_main() || leaders.contains(&g.id());
if keep {
update_names(g, &updated_calldata)
}
keep
});
for group in groups_ir {
log::debug!("Generating code for group \"{}\"", group.name());
let advice_io = ctx.advice_io_of_group(group.id());
let instance_io = ctx.instance_io_of_group(group.id());
if group.is_main() {
log::debug!("Generating main body");
codegen.define_main_function_with_body(advice_io, instance_io, [group])?;
} else {
log::debug!("Generating body of function {}", group.name());
let name = group.name().to_owned();
codegen.define_function_with_body(
&name,
advice_io.inputs_count() + instance_io.inputs_count(),
advice_io.outputs_count() + instance_io.outputs_count(),
|_, _, _| Ok([group]),
)?;
}
}
Ok(())
}
}
/// Buckets the ids of all non-main groups by their [`GroupKey`].
///
/// The main group is skipped. Within each bucket, ids appear in the same order as the
/// groups in `groups`.
///
/// # Panics
/// Panics if a non-main group has no key.
pub fn organize_groups_by_key(groups: &[IRGroup<IRAexpr>]) -> HashMap<GroupKey, Vec<usize>> {
    let mut by_key: HashMap<GroupKey, Vec<usize>> = HashMap::new();
    groups.iter().for_each(|group| {
        let id = group.id();
        if group.is_main() {
            log::debug!("Group {} is main. Skipping...", id);
            return;
        }
        let key = group.key().expect("Non main group needs a key");
        by_key.entry(key).or_default().push(id);
        log::debug!("Inserting group {} with key {:?}", id, group.key());
    });
    by_key
}
/// Partitions the IR groups into equivalence classes and picks one leader per class.
///
/// Two groups can only share a class when they share a key (see
/// [`organize_groups_by_key`]) and [`SymbolicEqv`] reports them equivalent. Classes are the
/// transitive closure of those pairwise results, maintained in a disjoint-set structure.
///
/// Returns `(leaders, updated_calldata)`:
/// - `leaders`: the first member of every class as listed by `DisjointSet::sets`.
/// - `updated_calldata[i]`: `Some((leader, name))` for each group index `i` that appears in a
///   class. `name` is a fresh name derived from the leader's name via `fresh_group_name`.
///
/// NOTE(review): the main group is never joined with anything, so it forms its own singleton
/// class. It therefore also gets a leader entry and reserves a name here; callers handle it
/// separately via `is_main()`.
///
/// Assumes every group's `id()` equals its index in `groups_ir`, because ids are used to
/// index both `groups_ir` and `updated_calldata`. TODO confirm that `validate()` guarantees
/// this.
fn select_leaders(
    groups_ir: &[IRGroup<IRAexpr>],
) -> (Vec<usize>, Vec<Option<(usize, String)>>) {
    let groups_by_key = organize_groups_by_key(groups_ir);
    log::debug!("Groups: {groups_by_key:?}");
    let mut leaders = vec![];
    let mut updated_calldata: Vec<Option<(usize, String)>> = vec![None; groups_ir.len()];
    let mut used_names: HashSet<String> = HashSet::default();
    // One singleton element per group index (the main group included).
    let mut eqv_class = disjoint::DisjointSet::new();
    let eqv_class_ids: Vec<_> = (0..groups_ir.len())
        .map(|_| eqv_class.add_singleton())
        .collect();
    for groups in groups_by_key.values() {
        for (i, j) in product(groups.as_slice(), groups.as_slice()) {
            let (lhs_class, rhs_class) = (eqv_class_ids[*i], eqv_class_ids[*j]);
            // A pair that is already in one class (including i == j) could only lead to a
            // no-op join, so skip the potentially expensive symbolic comparison.
            if i == j || eqv_class.is_joined(lhs_class, rhs_class) {
                continue;
            }
            if SymbolicEqv::equivalent(&groups_ir[*i], &groups_ir[*j]) {
                eqv_class.join(lhs_class, rhs_class);
            }
        }
    }
    // Map disjoint-set element ids back to group indices.
    let group_of_element: HashMap<_, _> = eqv_class_ids
        .into_iter()
        .enumerate()
        .map(|(n, id)| (id, n))
        .collect();
    for (n, set) in eqv_class.sets().into_iter().enumerate() {
        debug_assert!(!set.is_empty());
        let set: Vec<_> = set.into_iter().map(|id| group_of_element[&id]).collect();
        let leader_id = set[0];
        leaders.push(leader_id);
        let name = fresh_group_name(groups_ir[leader_id].name(), &mut used_names, n);
        // Every member of the class is redirected to the leader under the fresh name.
        for member in &set {
            updated_calldata[*member] = Some((leader_id, name.clone()));
        }
    }
    (leaders, updated_calldata)
}
/// Applies the leader renaming from `updated_calldata` to `group` and its callsites.
///
/// If the group's own id has an entry, the group takes that leader's name and id. Every
/// callsite whose callee id has an entry is redirected to that leader's id and name.
///
/// # Panics
/// Panics if the group's id or any callee id is out of bounds for `updated_calldata`.
fn update_names(group: &mut IRGroup<IRAexpr>, updated_calldata: &[Option<(usize, String)>]) {
    let own_entry = updated_calldata[group.id()].as_ref();
    if let Some((leader_id, leader_name)) = own_entry {
        *group.name_mut() = leader_name.clone();
        group.set_id(*leader_id);
    }
    for site in group.callsites_mut() {
        let Some((leader_id, leader_name)) = updated_calldata[site.callee_id()].as_ref() else {
            continue;
        };
        site.set_callee_id(*leader_id);
        site.set_name(leader_name.clone());
    }
}
/// Returns a name derived from `name` that is not yet in `used_names`, and reserves it.
///
/// `name` itself is tried first. If it is taken, numbered candidates `{name}{n}`,
/// `{name}{n + 1}`, ... are tried in order until one is unused.
fn fresh_group_name(name: &str, used_names: &mut HashSet<String>, n: usize) -> String {
    let fresh = std::iter::once(name.to_owned())
        .chain((n..).map(|suffix| format!("{name}{suffix}")))
        .find(|candidate| !used_names.contains(candidate))
        // `used_names` is finite while the candidate sequence is unbounded.
        .expect("an unbounded range of suffixes always yields an unused name");
    used_names.insert(fresh.clone());
    fresh
}
/// Lazily yields the Cartesian product of `lhs` and `rhs` as `(left, right)` pairs.
///
/// Pairs come in row-major order: for each element of `lhs`, in order, every element of
/// `rhs` is paired with it. `rhs` is cloned once per left element so it can be replayed.
#[inline]
fn product<'a, L: Clone + 'a, R: 'a>(
    lhs: impl IntoIterator<Item = L> + 'a,
    rhs: impl IntoIterator<Item = R> + Clone + 'a,
) -> impl Iterator<Item = (L, R)> + 'a {
    let pair_with_all_of_rhs = move |left: L| {
        let pair_with_left = move |right: R| (left.clone(), right);
        rhs.clone().into_iter().map(pair_with_left)
    };
    lhs.into_iter().flat_map(pair_with_all_of_rhs)
}
}