use rustc_hir::def_id::DefId;
use rustc_middle::mir::{Body, TerminatorKind};
use rustc_mir_dataflow::ResultsCursor;
use std::collections::HashSet;
use super::super::{AliasPair, FnAliasPairs};
use super::intraproc::{FnAliasAnalyzer, PlaceId};
/// Splits `place` into its root local and the field indices projected from it.
///
/// The indices are gathered while walking from `place` towards its root, so the
/// returned vector lists the innermost projection first; reverse it to obtain
/// root-to-leaf order.
fn extract_fields(place: &PlaceId) -> (usize, Vec<usize>) {
    let mut projections = Vec::new();
    let mut cursor = place;
    // Peel one `Field` layer per step until the underlying local is reached.
    let root = loop {
        match cursor {
            PlaceId::Local(local) => break *local,
            PlaceId::Field { base, field_idx } => {
                projections.push(*field_idx);
                cursor = base;
            }
        }
    };
    (root, projections)
}
fn extract_field_path(place: &PlaceId) -> Vec<usize> {
let mut fields = Vec::new();
let mut current = place;
loop {
match current {
PlaceId::Local(_) => {
fields.reverse(); return fields;
}
PlaceId::Field { base, field_idx } => {
fields.push(*field_idx);
current = base;
}
}
}
}
/// Returns `true` when `full` begins with every element of `prefix`, in order.
///
/// The relation is non-strict: an empty `prefix` and a `prefix` equal to `full`
/// both count as prefixes.
fn is_field_prefix(prefix: &[usize], full: &[usize]) -> bool {
    full.starts_with(prefix)
}
/// Builds the alias summary of a function from the dataflow state at its return points.
///
/// For every basic block whose terminator is `Return`, the cursor is moved to the end
/// of that block and the alias pairs recorded there are turned into [`AliasPair`]s:
///
/// 1. Places rooted at a local `<= arg_count` seed the set of relevant places
///    (in MIR, local 0 is the return place and locals `1..=arg_count` are the arguments).
/// 2. Relevance is propagated across alias pairs for a bounded number of sweeps.
/// 3. Candidate aliases are produced both by joining two pairs whose second places
///    overlap (same root, one field path a prefix of the other) and directly from
///    each pair whose two roots are both `<= arg_count`.
/// 4. Candidates are normalized so that `left_local <= right_local`, redundant ones
///    are filtered out, and the survivors are added to the summary.
///
/// `_def_id` is accepted for interface compatibility and is not used.
pub fn extract_summary<'tcx>(
    results: &mut ResultsCursor<'_, 'tcx, FnAliasAnalyzer<'tcx>>,
    body: &Body<'tcx>,
    _def_id: DefId,
) -> FnAliasPairs {
    // Upper bound on the number of relevance-propagation sweeps.
    // NOTE(review): propagation stops after this many sweeps even if the set is
    // still growing; it would terminate on its own because the set only grows.
    // Confirm the cap is intentional.
    const MAX_ITERATIONS: usize = 10;

    let arg_count = body.arg_count;
    let mut summary = FnAliasPairs::new(arg_count);

    for (block_id, block_data) in body.basic_blocks.iter_enumerated() {
        let is_return = match &block_data.terminator {
            Some(terminator) => matches!(terminator.kind, TerminatorKind::Return),
            None => false,
        };
        if !is_return {
            continue;
        }

        results.seek_to_block_end(block_id);
        let state = results.get();
        let analyzer = results.analysis();
        let place_info = analyzer.place_info();

        // Alias pairs whose two indices both resolve to a known place.
        let mut all_pairs = Vec::new();
        for (idx_i, idx_j) in state.get_all_alias_pairs() {
            if let (Some(place_i), Some(place_j)) =
                (place_info.get_place(idx_i), place_info.get_place(idx_j))
            {
                all_pairs.push((idx_i, idx_j, place_i, place_j));
            }
        }

        // Seed: every place rooted at the return place or at an argument.
        let mut relevant_places = HashSet::new();
        for idx in 0..place_info.num_places() {
            if let Some(place) = place_info.get_place(idx) {
                if place.root_local() <= arg_count {
                    relevant_places.insert(idx);
                }
            }
        }

        // Propagate relevance across alias pairs. `insert` returns `true` only
        // when the index was newly added, which is exactly the "changed" signal.
        for _ in 0..MAX_ITERATIONS {
            let mut changed = false;
            for &(idx_i, idx_j, _, _) in &all_pairs {
                if relevant_places.contains(&idx_i) && relevant_places.insert(idx_j) {
                    changed = true;
                }
                if relevant_places.contains(&idx_j) && relevant_places.insert(idx_i) {
                    changed = true;
                }
            }
            if !changed {
                break;
            }
        }

        // Decompose each relevant pair exactly once so the quadratic join below
        // does no per-iteration allocation or place walking. Tuple layout:
        // (lhs root, lhs path, rhs root, rhs path, rhs `root_local()`), with both
        // paths in root-to-leaf order.
        let decomposed: Vec<_> = all_pairs
            .iter()
            .filter(|(idx_i, idx_j, _, _)| {
                relevant_places.contains(idx_i) && relevant_places.contains(idx_j)
            })
            .map(|&(_, _, place_i, place_j)| {
                let (lhs_root, _) = extract_fields(place_i);
                let (rhs_root, _) = extract_fields(place_j);
                (
                    lhs_root,
                    extract_field_path(place_i),
                    rhs_root,
                    extract_field_path(place_j),
                    place_j.root_local(),
                )
            })
            .collect();

        let mut candidate_aliases = HashSet::new();

        // Join step: pairs (i, j) and (k, m) whose second places overlap yield an
        // alias between the first places i and k.
        for (root_i, path_i, _, path_j, anchor_j) in &decomposed {
            if *root_i > arg_count {
                continue;
            }
            // NOTE(review): only the left-hand path is cut down to its first
            // projection; the right-hand path is kept in full. Confirm that this
            // asymmetry is intended.
            let lhs_fields: Vec<usize> = path_i.iter().take(1).copied().collect();

            for (root_k, path_k, _, path_m, anchor_m) in &decomposed {
                if *root_k > arg_count || anchor_j != anchor_m {
                    continue;
                }
                if !is_field_prefix(path_j, path_m) && !is_field_prefix(path_m, path_j) {
                    continue;
                }
                let mut alias = AliasPair::new(*root_i, *root_k);
                alias.lhs_fields = lhs_fields.clone();
                alias.rhs_fields = path_k.clone();
                candidate_aliases.insert(alias);
            }
        }

        // Direct step: each pair whose two roots are both the return place or an
        // argument contributes its own alias with full field paths.
        for (root_i, path_i, root_j, path_j, _) in decomposed {
            if root_i <= arg_count && root_j <= arg_count {
                let mut alias = AliasPair::new(root_i, root_j);
                alias.lhs_fields = path_i;
                alias.rhs_fields = path_j;
                candidate_aliases.insert(alias);
            }
        }

        // Orient every alias so the smaller local is on the left; mirrored
        // duplicates collapse when collected into the set.
        let normalized_aliases: HashSet<_> = candidate_aliases
            .into_iter()
            .map(|mut alias| {
                if alias.left_local > alias.right_local {
                    alias.swap();
                }
                alias
            })
            .collect();

        for alias in filter_redundant_aliases(normalized_aliases) {
            summary.add_alias(alias);
        }
    }

    summary
}
fn filter_redundant_aliases(
aliases: std::collections::HashSet<AliasPair>,
) -> std::collections::HashSet<AliasPair> {
use std::collections::HashSet;
let aliases_vec: Vec<_> = aliases.iter().cloned().collect();
let mut to_remove = HashSet::new();
for i in 0..aliases_vec.len() {
let alias_a = &aliases_vec[i];
if to_remove.contains(alias_a) {
continue;
}
if alias_a.left_local == alias_a.right_local {
to_remove.insert(alias_a.clone());
continue;
}
for j in 0..aliases_vec.len() {
if i == j {
continue;
}
let alias_b = &aliases_vec[j];
if to_remove.contains(alias_b) {
continue;
}
if alias_a.left_local != alias_b.left_local
|| alias_a.right_local != alias_b.right_local
{
continue;
}
let lhs_a_subsumes_b = is_strict_prefix(&alias_a.lhs_fields, &alias_b.lhs_fields)
|| alias_a.lhs_fields == alias_b.lhs_fields;
let rhs_a_subsumes_b = is_strict_prefix(&alias_a.rhs_fields, &alias_b.rhs_fields)
|| alias_a.rhs_fields == alias_b.rhs_fields;
let lhs_b_subsumes_a = is_strict_prefix(&alias_b.lhs_fields, &alias_a.lhs_fields)
|| alias_b.lhs_fields == alias_a.lhs_fields;
let rhs_b_subsumes_a = is_strict_prefix(&alias_b.rhs_fields, &alias_a.rhs_fields)
|| alias_b.rhs_fields == alias_a.rhs_fields;
let lhs_a_strict = is_strict_prefix(&alias_a.lhs_fields, &alias_b.lhs_fields);
let rhs_a_strict = is_strict_prefix(&alias_a.rhs_fields, &alias_b.rhs_fields);
let lhs_b_strict = is_strict_prefix(&alias_b.lhs_fields, &alias_a.lhs_fields);
let rhs_b_strict = is_strict_prefix(&alias_b.rhs_fields, &alias_a.rhs_fields);
let a_subsumes_b =
lhs_a_subsumes_b && rhs_a_subsumes_b && (lhs_a_strict || rhs_a_strict);
let b_subsumes_a =
lhs_b_subsumes_a && rhs_b_subsumes_a && (lhs_b_strict || rhs_b_strict);
if a_subsumes_b || b_subsumes_a {
let spec_a = alias_specificity(alias_a);
let spec_b = alias_specificity(alias_b);
if spec_a < spec_b {
to_remove.insert(alias_b.clone());
} else if spec_b < spec_a {
to_remove.insert(alias_a.clone());
break; }
}
}
}
aliases.difference(&to_remove).cloned().collect()
}
/// Measures how specific an alias is: the combined length of its two field paths.
fn alias_specificity(alias: &AliasPair) -> usize {
    let lhs_depth = alias.lhs_fields.len();
    let rhs_depth = alias.rhs_fields.len();
    lhs_depth + rhs_depth
}
/// Returns `true` when `prefix` is a proper prefix of `full`: `full` starts with
/// `prefix` and is strictly longer than it.
fn is_strict_prefix(prefix: &[usize], full: &[usize]) -> bool {
    let shorter = prefix.len() < full.len();
    shorter && full.starts_with(prefix)
}