use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use dashmap::DashSet;
use log::debug;
use rayon::prelude::*;
use crate::{
error::{Error, Result},
events::EventKind,
graph::IndexedGraph,
ir::function::SsaFunction,
scheduling::pass::{ModificationScope, SsaPass, SsaPassHost},
target::Target,
};
/// A scheduled pass together with the layer it asked for at registration time
/// (its `fallback_layer()` or an explicit layer); the final layer may be raised
/// when capability dependencies are resolved.
type LayeredPass<T, H> = (Box<dyn SsaPass<T, H>>, usize);

/// Schedules SSA passes into capability-ordered layers and runs them, together
/// with a set of normalization passes, until the pipeline stops changing or an
/// iteration budget is exhausted.
pub struct PassScheduler<T, H>
where
    T: Target,
    T::MethodRef: Send + Sync,
    H: SsaPassHost<T>,
{
    /// Layered passes in registration order, each paired with its requested layer.
    passes: Vec<LayeredPass<T, H>>,
    /// Normalization passes, driven to a fixpoint after productive layer rounds.
    normalize: Vec<Box<dyn SsaPass<T, H>>>,
    /// Upper bound on outer iterations performed by `run_pipeline`.
    max_iterations: usize,
    /// Consecutive change-free outer iterations after which `run_pipeline` stops.
    stable_iterations: usize,
    /// Upper bound on rounds used to drive one layer (or normalization) to a fixpoint.
    max_phase_iterations: usize,
    /// Mentions `H` without storing one; a `fn` pointer type is always `Send + Sync`.
    _host: std::marker::PhantomData<fn(&H)>,
}
impl<T, H> PassScheduler<T, H>
where
T: Target,
T::MethodRef: Send + Sync,
H: SsaPassHost<T>,
{
/// Creates an empty scheduler with the given iteration budgets.
///
/// * `max_iterations` - maximum number of outer iterations `run_pipeline` performs.
/// * `stable_iterations` - number of consecutive change-free iterations after which
///   `run_pipeline` stops early.
/// * `max_phase_iterations` - maximum number of rounds used to drive a single layer
///   (or the normalization passes) to a fixpoint.
#[must_use]
pub fn new(
    max_iterations: usize,
    stable_iterations: usize,
    max_phase_iterations: usize,
) -> Self {
    Self {
        passes: Vec::new(),
        normalize: Vec::new(),
        max_iterations,
        stable_iterations,
        max_phase_iterations,
        _host: std::marker::PhantomData,
    }
}

/// Number of scheduled (layered) passes; normalization passes are not counted.
#[must_use]
pub fn pass_count(&self) -> usize {
    self.passes.as_slice().len()
}

/// Number of registered normalization passes.
#[must_use]
pub fn normalize_count(&self) -> usize {
    self.normalize.as_slice().len()
}

/// Schedules `pass` starting at the layer reported by its own `fallback_layer()`.
///
/// The layer actually used may be raised later so the pass runs after every
/// pass that provides a capability it requires.
pub fn add(&mut self, pass: Box<dyn SsaPass<T, H>>) {
    let requested = pass.fallback_layer();
    self.add_at_layer(pass, requested);
}

/// Schedules `pass` starting at an explicit `layer`, ignoring its fallback layer.
pub fn add_at_layer(&mut self, pass: Box<dyn SsaPass<T, H>>, layer: usize) {
    self.passes.push((pass, layer));
}

/// Registers a normalization pass.
///
/// Normalization passes are run to a fixpoint after any layer round that
/// changes something, and once on the first pipeline iteration when no layer
/// changed anything.
pub fn add_normalize(&mut self, pass: Box<dyn SsaPass<T, H>>) {
    self.normalize.push(pass);
}
/// Assigns every registered pass to an execution layer.
///
/// The returned vector has one entry per pass, in registration order. Each
/// pass starts at the layer it was registered with; a pass that `requires()`
/// a capability some other pass `provides()` is then raised to a layer strictly
/// greater than the layer of every such provider. Requirements that no
/// registered pass provides are ignored, as is a pass providing its own
/// requirement.
///
/// # Errors
///
/// Returns an error when the provider -> requirer relation contains a cycle;
/// the message lists the passes on the cycle when one can be reconstructed.
fn compute_layer_assignment(&self) -> Result<Vec<usize>> {
    let n = self.passes.len();
    if n == 0 {
        return Ok(vec![]);
    }

    // Capability -> indices of every pass that provides it.
    let mut providers: HashMap<T::Capability, Vec<usize>> = HashMap::new();
    for (i, (pass, _)) in self.passes.iter().enumerate() {
        for cap in pass.provides() {
            providers.entry(*cap).or_default().push(i);
        }
    }

    // `deps[i]` lists the passes that must sit in an earlier layer than pass `i`;
    // the graph mirrors the same provider -> requirer edges for cycle detection.
    let mut graph: IndexedGraph<usize, ()> = IndexedGraph::with_capacity(n, n);
    for i in 0..n {
        graph.add_node(i);
    }
    let mut deps: Vec<Vec<usize>> = vec![vec![]; n];
    for (i, ((pass, _), slot)) in self.passes.iter().zip(deps.iter_mut()).enumerate() {
        for cap in pass.requires() {
            let Some(provider_indices) = providers.get(cap) else {
                continue;
            };
            for &j in provider_indices {
                if j != i {
                    slot.push(j);
                    let _ = graph.add_edge(j, i, ());
                }
            }
        }
        // One provider can satisfy several requirements; visiting it once per
        // relaxation sweep is enough.
        slot.sort_unstable();
        slot.dedup();
    }

    if graph.topological_sort().is_none() {
        if let Some(cycle) = graph.find_any_cycle() {
            let names: Vec<&str> = cycle
                .iter()
                .filter_map(|&i| self.passes.get(i).map(|p| p.0.name()))
                .collect();
            return Err(Error::new(format!(
                "Cycle detected in pass capability dependencies: {}",
                names.join(" → ")
            )));
        }
        return Err(Error::new("Cycle detected in pass capability dependencies"));
    }

    // Relax until every pass sits above all of its providers. The dependency
    // lists are only borrowed here, so no per-sweep clone is needed; the graph
    // is acyclic (checked above), so this terminates.
    let mut layer: Vec<usize> = self.passes.iter().map(|(_, fallback)| *fallback).collect();
    loop {
        let mut changed = false;
        for (i, dep_list) in deps.iter().enumerate() {
            for &dep in dep_list {
                let layer_dep = layer.get(dep).copied().unwrap_or(0);
                if let Some(slot) = layer.get_mut(i) {
                    if *slot <= layer_dep {
                        *slot = layer_dep.saturating_add(1);
                        changed = true;
                    }
                }
            }
        }
        if !changed {
            break;
        }
    }

    if !deps.iter().all(Vec::is_empty) {
        let max_layer = layer.iter().copied().max().unwrap_or(0);
        debug!(
            "Capability scheduling: {} passes across {} layers",
            n,
            max_layer.saturating_add(1)
        );
        for (i, (pass, fallback)) in self.passes.iter().enumerate() {
            let layer_i = layer.get(i).copied().unwrap_or(*fallback);
            if layer_i != *fallback {
                debug!(
                    " pass '{}': layer {} (moved from fallback {})",
                    pass.name(),
                    layer_i,
                    fallback
                );
            }
        }
    }
    Ok(layer)
}
/// Runs the full pass pipeline against `host` and returns how many outer
/// iterations were executed.
///
/// Every method is marked dirty first. Passes are then bucketed by the layer
/// computed from their capability dependencies, and each outer iteration drives
/// every non-empty layer to a fixpoint in ascending layer order. After each
/// iteration the dirty set is reduced to the methods modified during that
/// iteration (emptied when nothing changed). The loop ends after
/// `max_iterations` iterations, or earlier once `stable_iterations`
/// consecutive iterations report no change.
///
/// # Errors
///
/// Propagates a capability-cycle error from layer assignment and any error
/// returned by a pass's `initialize`, `run_global`, or `finalize`.
pub fn run_pipeline(&mut self, host: &H) -> Result<usize> {
    for method in host.iter_methods() {
        host.mark_dirty(&method);
    }

    let assignment = self.compute_layer_assignment()?;
    let layer_count = assignment
        .iter()
        .max()
        .map_or(0, |&top| top.saturating_add(1));
    // Bucket pass indices by layer number; buckets stay in ascending layer order.
    let mut layers: Vec<Vec<usize>> = vec![Vec::new(); layer_count];
    for (pass_idx, &layer) in assignment.iter().enumerate() {
        if let Some(bucket) = layers.get_mut(layer) {
            bucket.push(pass_idx);
        }
    }
    layers.retain(|bucket| !bucket.is_empty());

    let (max_iterations, stable_iterations, max_phase) = (
        self.max_iterations,
        self.stable_iterations,
        self.max_phase_iterations,
    );
    let mut completed: usize = 0;
    let mut quiet_streak: usize = 0;

    for round in 0..max_iterations {
        completed = round.saturating_add(1);
        debug!("Pipeline iteration {}/{}", completed, max_iterations);

        let modified: DashSet<T::MethodRef> = DashSet::new();
        let mut changed = false;
        for layer in &layers {
            // `|=` always evaluates the call, so every layer runs each round.
            changed |= Self::layer_to_fixpoint(
                host,
                &mut self.passes,
                layer,
                &mut self.normalize,
                max_phase,
                &modified,
            )?;
        }
        if round == 0 && !changed && !self.normalize.is_empty() {
            changed =
                Self::normalize_to_fixpoint(host, &mut self.normalize, max_phase, &modified)?;
        }

        // Keep only this iteration's modified methods dirty; with no change
        // everything currently dirty is cleared.
        for m in host.dirty_snapshot() {
            if !(changed && modified.contains(&m)) {
                host.clear_dirty_for(&m);
            }
        }
        if changed {
            for entry in modified.iter() {
                host.mark_dirty(&entry);
            }
            quiet_streak = 0;
            continue;
        }

        quiet_streak = quiet_streak.saturating_add(1);
        if quiet_streak >= stable_iterations {
            debug!("Pipeline stable after {} iterations", completed);
            break;
        }
    }
    Ok(completed)
}
/// Runs the normalization `passes` round after round until a round reports no
/// change, performing at most `max_phase_iterations` rounds.
///
/// Returns `true` if at least one round changed something.
///
/// # Errors
///
/// Propagates any error from `run_passes_once`.
fn normalize_to_fixpoint(
    host: &H,
    passes: &mut [Box<dyn SsaPass<T, H>>],
    max_phase_iterations: usize,
    iteration_modified: &DashSet<T::MethodRef>,
) -> Result<bool> {
    let mut rounds_left = max_phase_iterations;
    let mut changed_any = false;
    while rounds_left > 0 {
        rounds_left -= 1;
        if !Self::run_passes_once(host, passes, iteration_modified)? {
            break;
        }
        changed_any = true;
    }
    Ok(changed_any)
}
/// Drives the passes at `layer_indices` to a fixpoint, for at most
/// `max_phase_iterations` rounds.
///
/// When normalization passes exist, they are run to their own fixpoint after
/// every round in which the layer changed something, and once more on the
/// round where the layer stops changing (only if it changed earlier).
///
/// Returns `true` if any round of the layer reported a change; an empty layer
/// returns `false` immediately.
///
/// # Errors
///
/// Propagates errors from `run_layer_passes_once` and `normalize_to_fixpoint`.
fn layer_to_fixpoint(
    host: &H,
    all_passes: &mut [LayeredPass<T, H>],
    layer_indices: &[usize],
    normalize_passes: &mut [Box<dyn SsaPass<T, H>>],
    max_phase_iterations: usize,
    iteration_modified: &DashSet<T::MethodRef>,
) -> Result<bool> {
    if layer_indices.is_empty() {
        return Ok(false);
    }
    let normalize_enabled = !normalize_passes.is_empty();
    let mut layer_changed = false;
    for _ in 0..max_phase_iterations {
        let progressed =
            Self::run_layer_passes_once(host, all_passes, layer_indices, iteration_modified)?;
        layer_changed |= progressed;
        if normalize_enabled && layer_changed {
            Self::normalize_to_fixpoint(
                host,
                normalize_passes,
                max_phase_iterations,
                iteration_modified,
            )?;
        }
        if !progressed {
            break;
        }
    }
    Ok(layer_changed)
}
/// Runs one round of `passes` (used for normalization passes).
///
/// Every pass is initialized, then global passes run via `run_global`, then
/// the remaining passes run per method, and finally every pass is finalized.
/// The full and dirty method lists are computed once, before any pass runs,
/// so methods dirtied during this round are not picked up until the next one.
///
/// Returns `true` if any pass reported a change.
///
/// # Errors
///
/// Propagates errors from `initialize`, `run_global`, and `finalize`.
fn run_passes_once(
    host: &H,
    passes: &mut [Box<dyn SsaPass<T, H>>],
    iteration_modified: &DashSet<T::MethodRef>,
) -> Result<bool> {
    for pass in passes.iter_mut() {
        pass.initialize(host)?;
    }

    let everything = Self::method_order(host, false);
    let dirty_only = Self::method_order(host, true);
    let changed = AtomicBool::new(false);

    for pass in passes.iter().filter(|p| p.is_global()) {
        if pass.run_global(host)? {
            changed.store(true, Ordering::Relaxed);
        }
    }

    for pass in passes.iter().filter(|p| !p.is_global()) {
        let targets = if pass.requires_full_scan() {
            &everything
        } else {
            &dirty_only
        };
        Self::run_single_pass(pass.as_ref(), host, targets, &changed, iteration_modified);
    }

    for pass in passes.iter_mut() {
        pass.finalize(host)?;
    }
    Ok(changed.load(Ordering::Relaxed))
}
fn run_layer_passes_once(
host: &H,
all_passes: &mut [LayeredPass<T, H>],
indices: &[usize],
iteration_modified: &DashSet<T::MethodRef>,
) -> Result<bool> {
for &idx in indices {
let pass_entry = all_passes
.get_mut(idx)
.ok_or_else(|| Error::new(format!("scheduler: pass index {idx} out of bounds")))?;
pass_entry.0.initialize(host)?;
}
let all_methods = Self::method_order(host, false);
let dirty_methods = Self::method_order(host, true);
let any_changed = AtomicBool::new(false);
for &idx in indices {
let pass_entry = all_passes
.get(idx)
.ok_or_else(|| Error::new(format!("scheduler: pass index {idx} out of bounds")))?;
let pass = &pass_entry.0;
if pass.is_global() && pass.run_global(host)? {
any_changed.store(true, Ordering::Relaxed);
}
}
for &idx in indices {
let pass_entry = all_passes
.get(idx)
.ok_or_else(|| Error::new(format!("scheduler: pass index {idx} out of bounds")))?;
let pass = &pass_entry.0;
if pass.is_global() {
continue;
}
let methods = if pass.requires_full_scan() {
&all_methods
} else {
&dirty_methods
};
Self::run_single_pass(
pass.as_ref(),
host,
methods,
&any_changed,
iteration_modified,
);
}
for &idx in indices {
let pass_entry = all_passes
.get_mut(idx)
.ok_or_else(|| Error::new(format!("scheduler: pass index {idx} out of bounds")))?;
pass_entry.0.finalize(host)?;
}
Ok(any_changed.load(Ordering::Relaxed))
}
/// Returns the methods a pass should visit, in processing order.
///
/// The order is the host's `methods_reverse_topological()`. When that is
/// empty, it falls back to `iter_methods()`. Methods the host no longer
/// `contains` are dropped. When `dirty_only` is set, only methods in the
/// host's current dirty snapshot are kept.
fn method_order(host: &H, dirty_only: bool) -> Vec<T::MethodRef> {
    let topo = host.methods_reverse_topological();
    let order: Vec<_> = if topo.is_empty() {
        host.iter_methods()
    } else {
        topo
    };
    // Hash the dirty snapshot once. A `Vec::contains` per method made this
    // filter O(methods × dirty).
    let dirty_set: Option<std::collections::HashSet<T::MethodRef>> =
        dirty_only.then(|| host.dirty_snapshot().into_iter().collect());
    order
        .into_iter()
        .filter(|m| host.contains(m))
        .filter(|m| dirty_set.as_ref().is_none_or(|d| d.contains(m)))
        .collect()
}
/// Runs a single non-global `pass` over `methods` in parallel.
///
/// For each method where `should_run` is true:
/// - The pass works on the SSA obtained via `clone_ssa` (when the pass
///   `reads_peer_ssa`) or `take_ssa`. Methods without SSA are skipped.
/// - After a successful change, the SSA is repaired (`UsesOnly` /
///   `InstructionsOnly`) or rebuilt (`CfgModifying`) and written back.
/// - The method is then recorded as processed and modified, and
///   `any_changed` is set.
///
/// A pass error is logged and leaves the method unrecorded. A failed cloned
/// copy is discarded rather than written back. When at least one method
/// changed, a debug line summarizing the pass's transformation events is
/// emitted.
fn run_single_pass(
    pass: &dyn SsaPass<T, H>,
    host: &H,
    methods: &[T::MethodRef],
    any_changed: &AtomicBool,
    iteration_modified: &DashSet<T::MethodRef>,
) {
    let event_snapshot = host.events().len();
    let pass_change_count = AtomicUsize::new(0);
    let clone_for_visibility = pass.reads_peer_ssa();
    methods.par_iter().for_each(|method| {
        if !pass.should_run(method, host) {
            return;
        }
        let acquired: Option<SsaFunction<T>> = if clone_for_visibility {
            host.clone_ssa(method)
        } else {
            host.take_ssa(method)
        };
        let Some(mut ssa) = acquired else {
            return;
        };

        let changed = match pass.run_on_method(&mut ssa, method, host) {
            Ok(changed) => changed,
            Err(e) => {
                log::warn!("pass '{}' failed on a method: {}", pass.name(), e);
                // The failed pass may have left `ssa` half-rewritten. A taken
                // SSA must still go back to the host. A cloned copy is dropped
                // instead of overwriting the host's version.
                // NOTE(review): assumes `clone_ssa` leaves the host's SSA in place.
                if !clone_for_visibility {
                    host.insert_ssa(method.clone(), ssa);
                }
                return;
            }
        };

        if changed {
            match pass.modification_scope() {
                ModificationScope::UsesOnly | ModificationScope::InstructionsOnly => {
                    ssa.repair_ssa();
                }
                ModificationScope::CfgModifying => {
                    if let Err(e) = ssa.rebuild_ssa() {
                        log::warn!("SSA rebuild failed after pass '{}': {}", pass.name(), e);
                    }
                }
            }
        }
        host.insert_ssa(method.clone(), ssa);
        if changed {
            any_changed.store(true, Ordering::Relaxed);
            pass_change_count.fetch_add(1, Ordering::Relaxed);
            host.mark_processed(method);
            iteration_modified.insert(method.clone());
        }
    });

    let count = pass_change_count.load(Ordering::Relaxed);
    if count == 0 {
        return;
    }
    // An empty event delta formats to an empty summary, so one check suffices.
    let summary = format_event_delta(&host.events().count_by_kind_since(event_snapshot));
    if summary.is_empty() {
        debug!(" pass '{}' changed {} methods", pass.name(), count);
    } else {
        debug!(
            " pass '{}' changed {} methods ({})",
            pass.name(),
            count,
            summary
        );
    }
}
}
/// Formats the transformation events in `delta` as a comma-separated summary
/// such as `"10 constants folded, 2 blocks removed"`.
///
/// Only kinds whose `is_transformation()` is true are included. Entries are
/// ordered by count (largest first), with ties broken alphabetically by
/// description. Sorting the formatted strings instead would order counts
/// lexically ("10" < "2" < "9"). Returns an empty string when nothing
/// qualifies.
fn format_event_delta(delta: &HashMap<EventKind, usize>) -> String {
    let mut entries: Vec<(usize, String)> = delta
        .iter()
        .filter(|(kind, _)| kind.is_transformation())
        .map(|(kind, &count)| (count, kind.description().to_string()))
        .collect();
    entries.sort_unstable_by(|a, b| b.0.cmp(&a.0).then_with(|| a.1.cmp(&b.1)));
    entries
        .iter()
        .map(|(count, description)| format!("{count} {description}"))
        .collect::<Vec<_>>()
        .join(", ")
}