use std::{collections::HashSet, vec};
use rustc_hash::FxHashMap;
use crate::{
dominator::{self},
AnalysisResults, BinaryOpKind, Context, Function, InitAggrInitializer, InsertionPosition,
InstOp, Instruction, InstructionInserter, IrError, MetadataIndex, Pass, PassMutability,
Predicate, ScopedPass, Type, TypeContent, Value,
};
/// Name under which the `init_aggr` lowering pass is registered.
pub const INIT_AGGR_LOWERING_NAME: &str = "lower-init-aggr";

/// Builds the function-scoped, IR-transforming pass that runs [`init_aggr_lowering`].
///
/// The pass declares no dependencies (`deps` is empty).
pub fn create_init_aggr_lowering_pass() -> Pass {
    let runner = ScopedPass::FunctionPass(PassMutability::Transform(init_aggr_lowering));
    Pass {
        name: INIT_AGGR_LOWERING_NAME,
        descr: "Lowering of `init_aggr` instructions",
        deps: Vec::new(),
        runner,
    }
}
pub fn init_aggr_lowering<'a, 'b>(
context: &'a mut Context<'b>,
_analyses: &AnalysisResults,
function: Function,
) -> Result<bool, IrError> {
let root_init_aggrs = find_root_init_aggrs(context, function);
if root_init_aggrs.is_empty() {
return Ok(false);
}
let mut replace_map = FxHashMap::<Value, Value>::default();
for root_init_aggr in root_init_aggrs.iter() {
let (root_aggr_ptr, initializers) = deconstruct_init_aggr(context, *root_init_aggr);
replace_map.insert(*root_init_aggr, root_aggr_ptr);
let aggr_type = root_aggr_ptr
.match_ptr_type(context)
.expect("`root_aggr_ptr` must be a pointer");
let _ = lower_mostly_zeroed_aggregate()
|| lower_to_stores(
context,
*root_init_aggr,
aggr_type,
root_aggr_ptr,
&mut Vec::new(),
&initializers,
);
}
function.replace_values(context, &replace_map, None);
function.remove_instructions(context, |inst| root_init_aggrs.contains(&inst));
Ok(true)
}
/// Splits an `init_aggr` instruction into the pointer to the aggregate it
/// initializes and the list of its initializers.
///
/// # Panics
///
/// Panics if `init_aggr` is not an instruction whose op is `InstOp::InitAggr`.
fn deconstruct_init_aggr(context: &Context, init_aggr: Value) -> (Value, Vec<InitAggrInitializer>) {
    // Borrow the instruction rather than cloning it: only the pointer (a `Value`)
    // and a freshly collected initializer list are needed.
    let Some(Instruction {
        parent: _,
        op: InstOp::InitAggr(init_aggr_op),
    }) = init_aggr.get_instruction(context)
    else {
        panic!("`init_aggr` must be an `Instruction` with `op` of variant `InstOp::InitAggr`");
    };
    (
        init_aggr_op.aggr_ptr,
        init_aggr_op.initializers(context).collect(),
    )
}
/// Placeholder for a specialised lowering of aggregates that are mostly zero.
///
/// Not implemented: it never lowers anything and always reports `false`, so the
/// caller falls back to the general store-based lowering.
fn lower_mostly_zeroed_aggregate() -> bool {
    // Nothing was lowered.
    false
}
/// Lowers the initialization performed by `init_aggr` into explicit IR.
///
/// `aggr_type` is the type of the (sub-)aggregate being initialized, and
/// `gep_indices` is the index path from `root_aggr_ptr` to that sub-aggregate
/// (empty when initializing the root itself). Indices are pushed and popped
/// symmetrically while recursing, so `gep_indices` holds the same contents on
/// return as on entry.
///
/// Straight-line `get_elem_ptr`s and `store`s are inserted immediately before
/// `init_aggr` and tagged with its metadata.
///
/// If every initializer of an array is equal, the repeated value is computed
/// once. For more than 5 elements it is then written by a loop built with
/// [`generate_array_init_loop`]; otherwise one store is emitted per element.
/// All other arrays and all structs are lowered element by element via
/// [`lower_single_initializer_to_stores`].
///
/// Always returns `true`.
///
/// # Panics
///
/// Panics if the number of initializers differs from the array length or the
/// number of struct fields, or if `aggr_type` is neither an array nor a struct.
fn lower_to_stores<'a, 'b>(
    context: &'a mut Context<'b>,
    init_aggr: Value,
    aggr_type: Type,
    root_aggr_ptr: Value,
    gep_indices: &mut Vec<u64>,
    initializers: &[InitAggrInitializer],
) -> bool {
    // Metadata copied onto the instructions emitted below.
    let init_aggr_metadata = init_aggr.get_metadata(context);
    match aggr_type.get_content(context).clone() {
        TypeContent::Array(arr_elem_type, length) => {
            assert_eq!(
                length as usize,
                initializers.len(),
                "`init_aggr` initializers must match the length of the array type"
            );
            // Returns the initializer shared by every element together with the
            // element count, or `None` if the array is empty or any two
            // initializers differ.
            fn as_repeat_array(
                initializers: &[InitAggrInitializer],
            ) -> Option<(InitAggrInitializer, u64)> {
                initializers.split_first().and_then(|(first_init, rest)| {
                    if rest.iter().all(|init| init == first_init) {
                        Some((first_init.clone(), initializers.len() as u64))
                    } else {
                        None
                    }
                })
            }
            match as_repeat_array(initializers) {
                // `length` here equals the array length (checked by the `assert_eq!` above).
                Some((initializer, length)) => {
                    let repeated_value = match initializer {
                        InitAggrInitializer::Value(value) => value,
                        // A shared nested aggregate is lowered once into its own
                        // memory (`nested_aggr_ptr`, with a fresh index path).
                        // The nested `init_aggr` is then deleted and its `load`
                        // rewired to read `nested_aggr_ptr`; the loaded value is
                        // what gets stored into every element.
                        InitAggrInitializer::NestedInitAggr {
                            load: nested_ia_load,
                            init_aggr: nested_init_aggr,
                        } => {
                            let (nested_aggr_ptr, nested_ia_initializers) =
                                deconstruct_init_aggr(context, nested_init_aggr);
                            let mut gep_indices: Vec<u64> = vec![];
                            let nested_aggr_type = nested_aggr_ptr
                                .match_ptr_type(context)
                                .expect("`nested_aggr_ptr` must be a pointer");
                            lower_to_stores(
                                context,
                                nested_init_aggr,
                                nested_aggr_type,
                                nested_aggr_ptr,
                                &mut gep_indices,
                                &nested_ia_initializers,
                            );
                            // Re-query the parent: lowering may have split the block.
                            let nested_ia_block = nested_init_aggr
                                .get_parent_block(context)
                                .expect(
                                    "`nested_init_aggr` is an instruction and must have a parent block",
                                );
                            nested_ia_block.remove_instruction(context, nested_init_aggr);
                            nested_ia_load.replace_instruction_value(
                                context,
                                nested_init_aggr,
                                nested_aggr_ptr,
                            );
                            nested_ia_load
                        }
                    };
                    // Long repeated arrays get a runtime loop instead of `length`
                    // unrolled stores. The threshold of 5 is presumably a
                    // code-size heuristic; TODO confirm.
                    if length > 5 {
                        // The loop indexes from a pointer to the array itself:
                        // the root pointer, or a `get_elem_ptr` along
                        // `gep_indices` when this array is nested in the root.
                        let array_ptr = if gep_indices.is_empty() {
                            root_aggr_ptr
                        } else {
                            let inserter =
                                get_inst_inserter_for_before_init_aggr(context, init_aggr);
                            inserter
                                .get_elem_ptr_with_idcs(root_aggr_ptr, aggr_type, gep_indices)
                                .add_metadatum(context, init_aggr_metadata)
                        };
                        generate_array_init_loop(
                            context,
                            array_ptr,
                            arr_elem_type,
                            repeated_value,
                            length,
                            init_aggr,
                            init_aggr_metadata,
                        );
                    } else {
                        // Short repeated arrays: one `get_elem_ptr` + `store`
                        // per element, with the element index appended to the path.
                        for insert_idx in 0..length {
                            gep_indices.push(insert_idx);
                            let inserter =
                                get_inst_inserter_for_before_init_aggr(context, init_aggr);
                            let gep_val = inserter
                                .get_elem_ptr_with_idcs(root_aggr_ptr, arr_elem_type, gep_indices)
                                .add_metadatum(context, init_aggr_metadata);
                            let inserter =
                                get_inst_inserter_for_before_init_aggr(context, init_aggr);
                            inserter
                                .store(gep_val, repeated_value)
                                .add_metadatum(context, init_aggr_metadata);
                            gep_indices.pop();
                        }
                    }
                }
                // Distinct (or zero) initializers: lower each element separately.
                None => {
                    for (insert_idx, initializer) in initializers.iter().enumerate() {
                        gep_indices.push(insert_idx as u64);
                        lower_single_initializer_to_stores(
                            context,
                            init_aggr,
                            root_aggr_ptr,
                            gep_indices,
                            init_aggr_metadata,
                            initializer,
                            arr_elem_type,
                        );
                        gep_indices.pop();
                    }
                }
            }
        }
        TypeContent::Struct(field_types) => {
            assert_eq!(
                field_types.len(),
                initializers.len(),
                "`init_aggr` initializers must match the number of fields in the struct type"
            );
            // Lower each field at its positional index, using that field's type
            // as the element type of the `get_elem_ptr`.
            for (insert_idx, (initializer, field_type)) in
                initializers.iter().zip(field_types).enumerate()
            {
                gep_indices.push(insert_idx as u64);
                lower_single_initializer_to_stores(
                    context,
                    init_aggr,
                    root_aggr_ptr,
                    gep_indices,
                    init_aggr_metadata,
                    initializer,
                    field_type,
                );
                gep_indices.pop();
            }
        }
        _ => unreachable!("`aggr_ptr` must point to an array or struct IR type"),
    }
    true
}
/// Creates an [`InstructionInserter`] that places new instructions immediately
/// before `init_aggr` in the block that currently contains it.
///
/// The parent block is looked up on every call rather than cached.
///
/// # Panics
///
/// Panics if `init_aggr` has no parent block.
fn get_inst_inserter_for_before_init_aggr<'a, 'b>(
    context: &'a mut Context<'b>,
    init_aggr: Value,
) -> InstructionInserter<'a, 'b> {
    let position = InsertionPosition::Before(init_aggr);
    let parent = init_aggr
        .get_parent_block(context)
        .expect("`init_aggr` is an instruction and must have a parent block");
    InstructionInserter::new(context, parent, position)
}
/// Emits the IR that writes a single `initializer` into the element of
/// `root_aggr_ptr` addressed by `gep_indices`, whose type is `elem_ty`.
///
/// * A plain value becomes a `get_elem_ptr` followed by a `store`, both
///   inserted right before `init_aggr`.
/// * A nested `init_aggr` is lowered recursively straight into the parent's
///   memory, with `gep_indices` as the path prefix. The nested instruction is
///   then removed and its `load` is rewired to read from the element pointer.
///
/// Every emitted `get_elem_ptr`/`store` is tagged with `init_aggr_metadata`.
fn lower_single_initializer_to_stores(
    context: &mut Context<'_>,
    init_aggr: Value,
    root_aggr_ptr: Value,
    gep_indices: &mut Vec<u64>,
    init_aggr_metadata: Option<MetadataIndex>,
    initializer: &InitAggrInitializer,
    elem_ty: Type,
) {
    match *initializer {
        InitAggrInitializer::Value(stored_value) => {
            let ptr_inserter = get_inst_inserter_for_before_init_aggr(context, init_aggr);
            let elem_ptr = ptr_inserter
                .get_elem_ptr_with_idcs(root_aggr_ptr, elem_ty, gep_indices)
                .add_metadatum(context, init_aggr_metadata);
            let store_inserter = get_inst_inserter_for_before_init_aggr(context, init_aggr);
            store_inserter
                .store(elem_ptr, stored_value)
                .add_metadatum(context, init_aggr_metadata);
        }
        InitAggrInitializer::NestedInitAggr {
            load: nested_load,
            init_aggr: nested,
        } => {
            let (nested_ptr, nested_initializers) = deconstruct_init_aggr(context, nested);
            let nested_type = nested_ptr
                .match_ptr_type(context)
                .expect("`nested_aggr_ptr` must be a pointer");

            // Address of this element inside the parent aggregate; it replaces
            // the nested aggregate's own pointer for the `load` below.
            let ptr_inserter = get_inst_inserter_for_before_init_aggr(context, nested);
            let elem_ptr = ptr_inserter
                .get_elem_ptr_with_idcs(root_aggr_ptr, elem_ty, gep_indices)
                .add_metadatum(context, init_aggr_metadata);

            // Write the nested contents directly into the parent's memory.
            lower_to_stores(
                context,
                nested,
                nested_type,
                root_aggr_ptr,
                gep_indices,
                &nested_initializers,
            );

            // The nested instruction is fully lowered; drop it and rewire its load.
            let owning_block = nested
                .get_parent_block(context)
                .expect("`nested_init_aggr` is an instruction and must have a parent block");
            owning_block.remove_instruction(context, nested);
            nested_load.replace_instruction_value(context, nested, elem_ptr);
        }
    }
}
/// Collects the root `init_aggr` instructions of `function`, i.e. those not
/// reachable as a nested initializer of another `init_aggr`.
///
/// Blocks are visited in post order and the instructions of each block in
/// reverse. This assumes a parent `init_aggr` is reached before the nested
/// `init_aggr`s it consumes (nested ones precede their parent in program order).
/// Every `init_aggr` reachable from a root is recorded as nested and skipped.
fn find_root_init_aggrs(context: &Context, function: Function) -> Vec<Value> {
    /// Adds every `init_aggr` reachable, transitively, through
    /// `parent_initializers` to `nested_init_aggrs`.
    fn visit_nested_init_aggrs(
        context: &Context,
        parent_initializers: impl Iterator<Item = InitAggrInitializer>,
        nested_init_aggrs: &mut HashSet<Value>,
    ) {
        for initializer in parent_initializers {
            let InitAggrInitializer::NestedInitAggr {
                load: _,
                init_aggr: init_aggr_val,
            } = initializer
            else {
                continue;
            };
            let Some(Instruction {
                parent: _,
                op: InstOp::InitAggr(init_aggr),
            }) = init_aggr_val.get_instruction(context)
            else {
                unreachable!("`init_aggr` is an `InstOp::InitAggr`");
            };
            // Repeated arrays reference the same nested `init_aggr` once per
            // element. Its whole subtree was recorded the first time it was
            // inserted, so descend only on first insertion. Without this check
            // the walk grows as length^depth for nested repeated arrays.
            if nested_init_aggrs.insert(init_aggr_val) {
                visit_nested_init_aggrs(
                    context,
                    init_aggr.initializers(context),
                    nested_init_aggrs,
                );
            }
        }
    }

    let mut result = vec![];
    let mut nested_init_aggrs = HashSet::new();
    let po = dominator::compute_post_order(context, &function);
    for block in po.po_to_block.iter() {
        for inst in block.instruction_iter(context).rev() {
            let Some(Instruction {
                parent: _,
                op: InstOp::InitAggr(init_aggr),
            }) = inst.get_instruction(context)
            else {
                continue;
            };
            if nested_init_aggrs.contains(&inst) {
                continue;
            }
            result.push(inst);
            visit_nested_init_aggrs(
                context,
                init_aggr.initializers(context),
                &mut nested_init_aggrs,
            );
        }
    }
    result
}
/// Emits a loop that stores `repeated_value` into each of the `length` elements
/// of the array behind `array_ptr`, placed immediately after `init_aggr`.
///
/// The block containing `init_aggr` is split just after it (schematically):
///
/// ```text
/// pre_block:            ...; init_aggr; br array_init_loop(0)
/// array_init_loop(i):   p = get_elem_ptr array_ptr[i]; store repeated_value -> p
///                       next = i + 1; cbr next < length, array_init_loop(next), exit
/// array_init_loop_exit: <rest of the original block>
/// ```
///
/// Only the `store` is tagged with `md_idx`. The body executes before the bound
/// is checked, so `length` must be at least 1.
fn generate_array_init_loop(
    context: &mut Context,
    array_ptr: Value,
    elem_type: Type,
    repeated_value: Value,
    length: u64,
    init_aggr: Value,
    md_idx: Option<MetadataIndex>,
) {
    // Split right after `init_aggr`; everything following it moves to the exit block.
    let original_block = init_aggr
        .get_parent_block(context)
        .expect("`init_aggr` is an instruction and must have a parent block");
    let Some(init_aggr_pos) = original_block
        .instruction_iter(context)
        .position(|inst| inst == init_aggr)
    else {
        panic!("`init_aggr` must be in its parent block");
    };
    let (pre_block, exit_block) = original_block.split_at(context, init_aggr_pos + 1);
    exit_block.set_label(context, Some("array_init_loop_exit".into()));

    // The loop block sits between the halves and takes the counter as its argument.
    let function = pre_block.get_function(context);
    let loop_block = function
        .create_block_before(context, &exit_block, Some("array_init_loop".into()))
        .expect("`exit_block` exists in the `pre_block`'s function");
    let counter_ty = Type::get_uint64(context);
    let counter_arg_idx = loop_block.new_arg(context, counter_ty);
    let counter = loop_block.get_arg(context, counter_arg_idx).unwrap();

    // Enter the loop with the counter at 0.
    let start = Value::new_u64_constant(context, 0);
    pre_block.append(context).branch(loop_block, vec![start]);

    // Body: store into element `counter`.
    let elem_ptr = loop_block
        .append(context)
        .get_elem_ptr(array_ptr, elem_type, vec![counter]);
    loop_block
        .append(context)
        .store(elem_ptr, repeated_value)
        .add_metadatum(context, md_idx);

    // Latch: advance the counter and loop back while it is below `length`.
    let step = Value::new_u64_constant(context, 1);
    let next_counter = loop_block
        .append(context)
        .binary_op(BinaryOpKind::Add, counter, step);
    let bound = Value::new_u64_constant(context, length);
    let keep_going = loop_block
        .append(context)
        .cmp(Predicate::LessThan, next_counter, bound);
    loop_block.append(context).conditional_branch(
        keep_going,
        loop_block,
        exit_block,
        vec![next_counter],
        vec![],
    );
}