use std::collections::{HashMap, HashSet, VecDeque};
use super::constraints::{Constraint, ConstraintExtractor};
use super::types::{AbstractLocation, AliasError, AliasInfo, MAX_FIELD_DEPTH};
/// Maximum number of worklist rounds [`AliasSolver::solve`] will run; a
/// constraint system that needs more makes it return
/// [`AliasError::IterationLimit`].
pub const MAX_ITERATIONS: usize = 100;
/// Worklist-based points-to solver over the constraints collected by a
/// [`ConstraintExtractor`].
///
/// Points-to facts are stored as formatted abstract-location strings, keyed by
/// variable name or by field-location name (`<location>.<field>`).
#[derive(Debug)]
pub struct AliasSolver {
    /// Points-to set of every variable and field location seen so far.
    points_to: HashMap<String, HashSet<String>>,
    /// Names queued for propagation in the next solver round.
    worklist: VecDeque<String>,
    /// Names currently in `worklist`, used to avoid queueing a name twice.
    in_worklist: HashSet<String>,
    /// `Copy` constraints: target -> every source copied into it.
    copy_constraints: HashMap<String, Vec<String>>,
    /// Dependency index: variable -> variables whose constraints read it.
    /// Despite the name it covers copy targets, field-load targets (keyed by
    /// the load's base) and field-store bases (keyed by the stored source).
    reverse_copy: HashMap<String, Vec<String>>,
    /// `FieldLoad` constraints: target -> (base, field) pairs.
    field_loads: HashMap<String, Vec<(String, String)>>,
    /// `FieldStore` constraints: base -> (field, source) pairs.
    field_stores: HashMap<String, Vec<(String, String)>>,
    /// `Alloc` constraints: target -> allocation site (the last one seen wins).
    alloc_sites: HashMap<String, AbstractLocation>,
    /// Names reported by `ConstraintExtractor::phi_targets`; copies into them
    /// are excluded from must-alias computation.
    phi_targets: HashSet<String>,
    /// Names reported by `ConstraintExtractor::parameters`; each is seeded with
    /// its own parameter location and treated as may-aliasing every other one.
    parameters: HashSet<String>,
    /// Names whose points-to set grew during the most recent solver round.
    last_changed: Vec<String>,
    /// Number of rounds run by the most recent `solve` call.
    iterations: usize,
}
impl AliasSolver {
    /// Creates a solver seeded from the constraints gathered by `extractor`.
    ///
    /// Every constraint is indexed, each `Alloc` target receives its
    /// allocation-site location, and each parameter receives its own parameter
    /// location. All seeded names are queued, so [`AliasSolver::solve`] must
    /// run before the points-to sets are complete.
    pub fn new(extractor: &ConstraintExtractor) -> Self {
        let mut solver = AliasSolver {
            points_to: HashMap::new(),
            worklist: VecDeque::new(),
            in_worklist: HashSet::new(),
            copy_constraints: HashMap::new(),
            reverse_copy: HashMap::new(),
            field_loads: HashMap::new(),
            field_stores: HashMap::new(),
            alloc_sites: HashMap::new(),
            phi_targets: extractor.phi_targets().clone(),
            parameters: extractor.parameters().clone(),
            last_changed: Vec::new(),
            iterations: 0,
        };
        solver.index_constraints(extractor.constraints());
        solver.initialize_allocs(extractor);
        solver.initialize_parameters();
        solver
    }

    /// Builds the lookup tables used during propagation.
    ///
    /// Besides the per-kind tables, `reverse_copy` maps each variable to the
    /// variables whose constraints read it: copy targets, field-load targets
    /// (keyed by the load's base) and field-store bases (keyed by the stored
    /// source). Despite its name it is the general dependency index.
    fn index_constraints(&mut self, constraints: &[Constraint]) {
        for constraint in constraints {
            match constraint {
                Constraint::Copy { target, source } => {
                    self.copy_constraints
                        .entry(target.clone())
                        .or_default()
                        .push(source.clone());
                    self.reverse_copy
                        .entry(source.clone())
                        .or_default()
                        .push(target.clone());
                }
                Constraint::Alloc { target, site } => {
                    // A later Alloc for the same target replaces the site kept here;
                    // `initialize_allocs` still seeds every site into `points_to`.
                    self.alloc_sites.insert(target.clone(), site.clone());
                }
                Constraint::FieldLoad {
                    target,
                    base,
                    field,
                } => {
                    self.field_loads
                        .entry(target.clone())
                        .or_default()
                        .push((base.clone(), field.clone()));
                    self.reverse_copy
                        .entry(base.clone())
                        .or_default()
                        .push(target.clone());
                }
                Constraint::FieldStore {
                    base,
                    field,
                    source,
                } => {
                    self.field_stores
                        .entry(base.clone())
                        .or_default()
                        .push((field.clone(), source.clone()));
                    self.reverse_copy
                        .entry(source.clone())
                        .or_default()
                        .push(base.clone());
                }
            }
        }
    }

    /// Seeds each `Alloc` target with the formatted location of its site and queues it.
    fn initialize_allocs(&mut self, extractor: &ConstraintExtractor) {
        for constraint in extractor.constraints() {
            if let Constraint::Alloc { target, site } = constraint {
                self.points_to
                    .entry(target.clone())
                    .or_default()
                    .insert(site.format());
                self.add_to_worklist(target);
            }
        }
    }

    /// Seeds each parameter with its own abstract parameter location and queues it.
    fn initialize_parameters(&mut self) {
        // Cloned because seeding needs `&mut self` while iterating the set.
        for param in &self.parameters.clone() {
            let location_str = AbstractLocation::param(param).format();
            self.points_to
                .entry(param.clone())
                .or_default()
                .insert(location_str);
            self.add_to_worklist(param);
        }
    }

    /// Queues `var` for the next round unless it is already queued.
    fn add_to_worklist(&mut self, var: &str) {
        if self.in_worklist.insert(var.to_string()) {
            self.worklist.push_back(var.to_string());
        }
    }

    /// Runs worklist rounds until no points-to set grows.
    ///
    /// Each round drains the whole worklist and propagates every queued name;
    /// anything whose set grew is queued for the following round.
    ///
    /// # Errors
    ///
    /// Returns [`AliasError::IterationLimit`] when a round beyond
    /// [`MAX_ITERATIONS`] would be needed. `last_changed` then still holds the
    /// names that grew in the final completed round.
    pub fn solve(&mut self) -> Result<(), AliasError> {
        self.iterations = 0;
        while !self.worklist.is_empty() {
            self.iterations += 1;
            if self.iterations > MAX_ITERATIONS {
                return Err(AliasError::IterationLimit(self.iterations));
            }
            self.last_changed.clear();
            let round: Vec<String> = self.worklist.drain(..).collect();
            self.in_worklist.clear();
            for var in round {
                self.propagate_variable(&var);
            }
        }
        Ok(())
    }

    /// Pushes the current points-to set of `var` through every constraint that reads it.
    ///
    /// Covers `t = var` copies, `t = var.f` loads, `b.f = var` stores (with
    /// `var` as the stored value) and `var.f = s` stores (with `var` as base).
    fn propagate_variable(&mut self, var: &str) {
        let current_pts = self.points_to.get(var).cloned().unwrap_or_default();

        if let Some(dependents) = self.reverse_copy.get(var).cloned() {
            // A dependent can be listed several times (repeated constraints);
            // one visit already handles all of its constraints that read `var`.
            let mut seen: HashSet<String> = HashSet::new();
            for dependent in dependents {
                if seen.insert(dependent.clone()) {
                    self.propagate_to_dependent(var, &dependent, &current_pts);
                }
            }
        }

        if let Some(stores) = self.field_stores.get(var).cloned() {
            for (field, source) in stores {
                self.propagate_field_store(&current_pts, &field, &source);
            }
        }
    }

    /// Re-evaluates the constraints of `dependent` that read `var`.
    ///
    /// Each constraint kind is matched against `var` explicitly. This keeps a
    /// field store `dependent.f = var` from leaking `var`'s targets into
    /// `dependent` through its unrelated copy constraints, and makes the store
    /// itself re-run when the stored value changes.
    fn propagate_to_dependent(&mut self, var: &str, dependent: &str, var_pts: &HashSet<String>) {
        // dependent = var
        let copies_var = self
            .copy_constraints
            .get(dependent)
            .is_some_and(|sources| sources.iter().any(|source| source.as_str() == var));
        if copies_var {
            self.propagate_copy(dependent, var_pts);
        }

        // dependent = var.field
        if let Some(loads) = self.field_loads.get(dependent).cloned() {
            for (base, field) in loads {
                if base == var {
                    self.propagate_field_load(dependent, var_pts, &field);
                }
            }
        }

        // dependent.field = var
        if let Some(stores) = self.field_stores.get(dependent).cloned() {
            let base_pts = self.points_to.get(dependent).cloned().unwrap_or_default();
            for (field, source) in stores {
                if source == var {
                    self.propagate_field_store(&base_pts, &field, var);
                }
            }
        }
    }

    /// Adds `locations` to the points-to set of `key`, creating the entry if absent.
    ///
    /// When the set grows, `key` is recorded in `last_changed` and queued so
    /// its dependents are revisited in the next round.
    fn union_into<I>(&mut self, key: &str, locations: I)
    where
        I: IntoIterator<Item = String>,
    {
        let set = self.points_to.entry(key.to_string()).or_default();
        let before = set.len();
        set.extend(locations);
        if set.len() > before {
            self.last_changed.push(key.to_string());
            self.add_to_worklist(key);
        }
    }

    /// Handles `target = source`: adds `source_pts` to the points-to set of `target`.
    fn propagate_copy(&mut self, target: &str, source_pts: &HashSet<String>) {
        self.union_into(target, source_pts.iter().cloned());
    }

    /// Handles `target = base.field`: points `target` at the `field` location
    /// of every object in `base_pts`.
    ///
    /// NOTE(review): the loaded value is modeled by the field-location name
    /// itself, while stores write into `points_to[field_location]`, which
    /// loads never read. Confirm this modeling is intended.
    fn propagate_field_load(&mut self, target: &str, base_pts: &HashSet<String>, field: &str) {
        let field_locs: Vec<String> = base_pts
            .iter()
            .map(|loc| self.create_field_location(loc, field))
            .collect();
        self.union_into(target, field_locs);
    }

    /// Handles `base.field = source`: adds the points-to set of `source` to the
    /// `field` location of every object in `base_pts`.
    fn propagate_field_store(&mut self, base_pts: &HashSet<String>, field: &str, source: &str) {
        let source_pts = self.points_to.get(source).cloned().unwrap_or_default();
        let field_locs: Vec<String> = base_pts
            .iter()
            .map(|loc| self.create_field_location(loc, field))
            .collect();
        for field_loc in field_locs {
            self.union_into(&field_loc, source_pts.iter().cloned());
        }
    }

    /// Names the abstract location for `field` of the object `base`.
    ///
    /// Depth is the number of `.` separators in `base`. Below
    /// [`MAX_FIELD_DEPTH`] the field is appended. At or beyond it, a single
    /// `.truncated` summary suffix is used, and a location that already ends in
    /// `.truncated` is returned unchanged. That way repeated loads through
    /// recursive structures (e.g. `x = x.next` in a loop) reach a fixpoint
    /// instead of minting ever-longer names.
    fn create_field_location(&self, base: &str, field: &str) -> String {
        let depth = base.matches('.').count();
        if depth < MAX_FIELD_DEPTH {
            format!("{}.{}", base, field)
        } else if base.ends_with(".truncated") {
            base.to_string()
        } else {
            format!("{}.truncated", base)
        }
    }

    /// Packages the solved state into an [`AliasInfo`] for `function_name`.
    ///
    /// Copies every points-to set and registers allocation sites whose
    /// location is `AbstractLocation::Alloc`. Every allocation target gets a
    /// points-to fact for its site. It then derives may-alias, must-alias and
    /// the pairwise parameter may-alias facts.
    pub fn build_alias_info(&self, function_name: &str) -> AliasInfo {
        let mut info = AliasInfo::new(function_name);
        info.points_to = self.points_to.clone();
        for (target, site) in &self.alloc_sites {
            let location = site.format();
            if let AbstractLocation::Alloc { site: line } = site {
                info.add_allocation_site(*line, &location);
            }
            info.add_points_to(target, &location);
        }
        self.compute_may_alias(&mut info);
        self.compute_must_alias(&mut info);
        self.add_parameter_aliasing(&mut info);
        info
    }

    /// Records may-alias facts.
    ///
    /// Two names may alias when their points-to sets intersect. Every copy
    /// `target = source` is also reported, regardless of the sets.
    fn compute_may_alias(&self, info: &mut AliasInfo) {
        let entries: Vec<(&String, &HashSet<String>)> = self.points_to.iter().collect();
        for (i, &(v1, set1)) in entries.iter().enumerate() {
            for &(v2, set2) in &entries[i + 1..] {
                if !set1.is_disjoint(set2) {
                    info.add_may_alias(v1, v2);
                }
            }
        }
        for (target, sources) in &self.copy_constraints {
            for source in sources {
                info.add_may_alias(target, source);
            }
        }
    }

    /// Records must-alias facts as the symmetric, transitive closure of copy edges.
    ///
    /// Copies into phi targets are skipped, since a phi target can take any of
    /// several incoming values.
    fn compute_must_alias(&self, info: &mut AliasInfo) {
        let mut direct_aliases: HashMap<String, HashSet<String>> = HashMap::new();
        for (target, sources) in &self.copy_constraints {
            if self.phi_targets.contains(target) {
                continue;
            }
            for source in sources {
                direct_aliases
                    .entry(target.clone())
                    .or_default()
                    .insert(source.clone());
                direct_aliases
                    .entry(source.clone())
                    .or_default()
                    .insert(target.clone());
            }
        }
        for (var, aliases) in self.transitive_closure(&direct_aliases) {
            for alias in aliases {
                info.add_must_alias(&var, &alias);
            }
        }
    }

    /// Returns, for each node of `relation`, every node reachable from it.
    ///
    /// A node is not listed as reaching itself unless `relation` has an
    /// explicit self-edge for it. For a symmetric relation (the only kind
    /// `compute_must_alias` passes) this is each node's connected component
    /// minus the node itself. Runs one BFS per node instead of repeatedly
    /// re-sweeping the whole relation.
    fn transitive_closure(
        &self,
        relation: &HashMap<String, HashSet<String>>,
    ) -> HashMap<String, HashSet<String>> {
        let nodes: HashSet<&String> = relation
            .keys()
            .chain(relation.values().flatten())
            .collect();
        let mut result: HashMap<String, HashSet<String>> = HashMap::with_capacity(nodes.len());
        for &start in &nodes {
            // Direct successors are kept verbatim, so an explicit self-edge survives.
            let mut reachable: HashSet<String> = relation.get(start).cloned().unwrap_or_default();
            let mut visited: HashSet<&String> = HashSet::from([start]);
            let mut queue: VecDeque<&String> = VecDeque::from([start]);
            while let Some(node) = queue.pop_front() {
                for next in relation.get(node).into_iter().flatten() {
                    if visited.insert(next) {
                        reachable.insert(next.clone());
                        queue.push_back(next);
                    }
                }
            }
            if relation.contains_key(start) || !reachable.is_empty() {
                result.insert(start.clone(), reachable);
            }
        }
        result
    }

    /// Conservatively reports every pair of distinct parameters as may-alias.
    fn add_parameter_aliasing(&self, info: &mut AliasInfo) {
        let params: Vec<&String> = self.parameters.iter().collect();
        for (i, &first) in params.iter().enumerate() {
            for &second in &params[i + 1..] {
                info.add_may_alias(first, second);
            }
        }
    }

    /// Rounds executed by the most recent `solve` call.
    ///
    /// This is `MAX_ITERATIONS + 1` when the iteration limit was hit.
    pub fn iterations(&self) -> usize {
        self.iterations
    }

    /// Names whose points-to set grew during the most recent solver round.
    pub fn last_changed(&self) -> &[String] {
        &self.last_changed
    }

    /// Returns a copy of the points-to set of `var`, or an empty set if `var` is unknown.
    pub fn get_points_to(&self, var: &str) -> HashSet<String> {
        self.points_to.get(var).cloned().unwrap_or_default()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A solver built from an extractor that holds no constraints.
    fn empty_solver() -> AliasSolver {
        AliasSolver::new(&ConstraintExtractor::new())
    }

    /// `root` followed by `count` segments of the form `.<prefix><index>`.
    fn dotted_path(root: &str, prefix: &str, count: usize) -> String {
        (0..count).fold(root.to_string(), |path, index| {
            format!("{}.{}{}", path, prefix, index)
        })
    }

    /// Owned string set built from string slices.
    fn string_set(items: &[&str]) -> HashSet<String> {
        items.iter().map(|&item| item.to_string()).collect()
    }

    #[test]
    fn test_solver_new() {
        let solver = empty_solver();
        assert_eq!(solver.iterations, 0);
        assert!(solver.worklist.is_empty());
        assert!(solver.points_to.is_empty());
    }

    #[test]
    fn test_solver_empty_constraints() {
        let mut solver = empty_solver();
        assert!(solver.solve().is_ok());
        assert_eq!(solver.iterations, 0);
    }

    #[test]
    fn test_solver_build_alias_info_empty() {
        let solver = empty_solver();
        let info = solver.build_alias_info("test_func");
        assert_eq!(info.function_name, "test_func");
        assert!(info.may_alias.is_empty());
        assert!(info.must_alias.is_empty());
        assert!(info.points_to.is_empty());
    }

    #[test]
    fn test_create_field_location_simple() {
        let solver = empty_solver();
        assert_eq!(
            solver.create_field_location("alloc_5", "data"),
            "alloc_5.data"
        );
    }

    #[test]
    fn test_create_field_location_truncates_deep() {
        let solver = empty_solver();
        let deep = dotted_path("alloc_1", "field", MAX_FIELD_DEPTH);
        assert!(solver
            .create_field_location(&deep, "too_deep")
            .ends_with(".truncated"));
    }

    #[test]
    fn test_transitive_closure_simple() {
        let solver = empty_solver();
        let relation = HashMap::from([
            ("x".to_string(), string_set(&["y"])),
            ("y".to_string(), string_set(&["x", "z"])),
            ("z".to_string(), string_set(&["y"])),
        ]);
        let closure = solver.transitive_closure(&relation);
        assert!(closure.get("x").is_some_and(|s| s.contains("z")));
        assert!(closure.get("z").is_some_and(|s| s.contains("x")));
    }

    #[test]
    fn test_add_to_worklist_deduplication() {
        let mut solver = empty_solver();
        for _ in 0..3 {
            solver.add_to_worklist("x");
        }
        assert_eq!(solver.worklist.len(), 1);
    }

    #[test]
    fn test_max_iterations_constant() {
        assert_eq!(MAX_ITERATIONS, 100);
    }

    #[test]
    fn test_field_location_propagation() {
        let solver = empty_solver();
        let cases = [
            ("alloc_1", "data", "alloc_1.data"),
            ("alloc_1.inner", "value", "alloc_1.inner.value"),
            ("param_x", "attr", "param_x.attr"),
        ];
        for (base, field, expected) in cases {
            assert_eq!(solver.create_field_location(base, field), expected);
        }
    }

    #[test]
    fn test_field_depth_truncation() {
        let solver = empty_solver();
        let base = dotted_path("alloc_1", "f", MAX_FIELD_DEPTH - 1);
        let within_limit = solver.create_field_location(&base, "ok");
        assert!(!within_limit.contains("truncated"));
        let at_limit = solver.create_field_location(&within_limit, "toomuch");
        assert!(at_limit.ends_with(".truncated"));
    }

    #[test]
    fn test_solver_convergence_simple() {
        let mut solver = empty_solver();
        assert!(solver.solve().is_ok());
        assert!(solver.iterations() <= 1);
    }

    #[test]
    fn test_parameter_aliasing_in_solver() {
        use crate::ssa::types::{
            SsaBlock, SsaFunction, SsaInstruction, SsaInstructionKind, SsaName, SsaNameId,
            SsaStats, SsaType,
        };
        use std::path::PathBuf;

        // Two parameters `a` and `b`, both defined on line 1 of the entry block.
        let instructions = vec![
            SsaInstruction {
                kind: SsaInstructionKind::Param,
                target: Some(SsaNameId(0)),
                uses: vec![],
                line: 1,
                source_text: Some("def f(a, b):".to_string()),
            },
            SsaInstruction {
                kind: SsaInstructionKind::Param,
                target: Some(SsaNameId(1)),
                uses: vec![],
                line: 1,
                source_text: None,
            },
        ];
        let ssa_names = vec![
            SsaName {
                id: SsaNameId(0),
                variable: "a".to_string(),
                version: 0,
                def_block: Some(0),
                def_line: 1,
            },
            SsaName {
                id: SsaNameId(1),
                variable: "b".to_string(),
                version: 0,
                def_block: Some(0),
                def_line: 1,
            },
        ];
        let entry = SsaBlock {
            id: 0,
            label: Some("entry".to_string()),
            lines: (1, 1),
            phi_functions: vec![],
            instructions,
            successors: vec![],
            predecessors: vec![],
        };
        let ssa = SsaFunction {
            function: "test".to_string(),
            file: PathBuf::from("test.py"),
            ssa_type: SsaType::Minimal,
            blocks: vec![entry],
            ssa_names,
            def_use: HashMap::new(),
            stats: SsaStats::default(),
        };

        let extractor = ConstraintExtractor::extract_from_ssa(&ssa).unwrap();
        let mut solver = AliasSolver::new(&extractor);
        solver.solve().unwrap();
        let info = solver.build_alias_info("test");
        assert!(info.may_alias_check("a_0", "b_0"));
    }
}