use std::{
collections::HashMap,
sync::{
atomic::{AtomicBool, AtomicUsize, Ordering},
Arc,
},
};
use dashmap::DashSet;
use log::debug;
use rayon::prelude::*;
use crate::{
compiler::{
context::CompilerContext,
events::EventKind,
pass::{ModificationScope, PassCapability, PassPhase, SsaPass},
state::ProcessingState,
},
metadata::token::Token,
utils::graph::IndexedGraph,
CilObject, Error, Result,
};
/// Schedules [`SsaPass`]es into capability-ordered layers and runs them towards a fixpoint.
///
/// Regular passes are stored together with their fallback layer index; normalization
/// passes live in a separate list and are run to their own fixpoint between layer rounds.
pub struct PassScheduler {
    /// Regular (non-normalization) passes, each paired with its fallback layer.
    passes: Vec<(Box<dyn SsaPass>, usize)>,
    /// Passes registered with `PassPhase::Normalize`.
    normalize: Vec<Box<dyn SsaPass>>,
    /// Upper bound on outer pipeline iterations.
    max_iterations: usize,
    /// Number of consecutive unchanged iterations after which the pipeline stops early.
    stable_iterations: usize,
    /// Upper bound on rounds spent driving one layer (or normalization) to a fixpoint.
    max_phase_iterations: usize,
}
impl Default for PassScheduler {
fn default() -> Self {
Self::new(5, 2, 15)
}
}
impl PassScheduler {
    /// Creates an empty scheduler with the given iteration limits.
    ///
    /// * `max_iterations` - maximum number of outer iterations performed by
    ///   [`PassScheduler::run_pipeline`].
    /// * `stable_iterations` - number of consecutive iterations without any change
    ///   after which `run_pipeline` stops early.
    /// * `max_phase_iterations` - maximum number of rounds spent driving a single
    ///   layer, or the normalization passes, towards a fixpoint.
    #[must_use]
    pub fn new(
        max_iterations: usize,
        stable_iterations: usize,
        max_phase_iterations: usize,
    ) -> Self {
        Self {
            max_iterations,
            stable_iterations,
            max_phase_iterations,
            passes: Vec::new(),
            normalize: Vec::new(),
        }
    }

    /// Number of registered regular (non-normalization) passes.
    #[must_use]
    pub fn pass_count(&self) -> usize {
        self.passes.len()
    }

    /// Number of registered normalization passes.
    #[must_use]
    pub fn normalize_count(&self) -> usize {
        self.normalize.len()
    }

    /// Registers `pass` for `phase`.
    ///
    /// Passes in `PassPhase::Normalize` go to the normalization list. Every other
    /// pass is stored with `phase.as_layer()` as its fallback layer. Capability
    /// dependencies may later push that layer back (see `compute_layer_assignment`).
    pub fn add(&mut self, pass: Box<dyn SsaPass>, phase: PassPhase) {
        match phase {
            PassPhase::Normalize => self.normalize.push(pass),
            _ => self.passes.push((pass, phase.as_layer())),
        }
    }

    /// Computes the layer index of every entry in `self.passes`; the result is
    /// indexed like `self.passes`.
    ///
    /// Each pass starts at its fallback layer. A pass that `requires()` a capability
    /// which some *other* pass `provides()` is pushed to a strictly later layer than
    /// every such provider. Requirements that no registered pass provides are ignored.
    ///
    /// # Errors
    ///
    /// Returns `Error::SsaError` when the capability dependencies form a cycle. The
    /// message names the passes involved when a cycle can be extracted.
    fn compute_layer_assignment(&self) -> Result<Vec<usize>> {
        let n = self.passes.len();
        if n == 0 {
            return Ok(vec![]);
        }
        // capability -> indices of the passes providing it
        let mut providers: HashMap<PassCapability, Vec<usize>> = HashMap::new();
        for (i, (pass, _)) in self.passes.iter().enumerate() {
            for &cap in pass.provides() {
                providers.entry(cap).or_default().push(i);
            }
        }
        let mut graph: IndexedGraph<usize, ()> = IndexedGraph::with_capacity(n, n);
        for i in 0..n {
            graph.add_node(i);
        }
        // deps[i] = providers that pass i must run after
        let mut deps: Vec<Vec<usize>> = vec![vec![]; n];
        for (i, (pass, _)) in self.passes.iter().enumerate() {
            for &cap in pass.requires() {
                if let Some(provider_indices) = providers.get(&cap) {
                    for &j in provider_indices {
                        // A pass that provides what it requires does not depend on itself.
                        if j != i {
                            deps[i].push(j);
                            let _ = graph.add_edge(j, i, ());
                        }
                    }
                }
            }
        }
        if graph.topological_sort().is_none() {
            if let Some(cycle) = graph.find_any_cycle() {
                let names: Vec<&str> = cycle.iter().map(|&i| self.passes[i].0.name()).collect();
                return Err(Error::SsaError(format!(
                    "Cycle detected in pass capability dependencies: {}",
                    names.join(" → ")
                )));
            }
            return Err(Error::SsaError(
                "Cycle detected in pass capability dependencies".to_string(),
            ));
        }
        // Longest-path relaxation starting from the fallback layers. It terminates
        // because the dependency graph was just verified to be acyclic.
        let mut layer: Vec<usize> = self.passes.iter().map(|(_, fallback)| *fallback).collect();
        let mut changed = true;
        while changed {
            changed = false;
            for i in 0..n {
                for &dep in &deps[i] {
                    if layer[i] <= layer[dep] {
                        layer[i] = layer[dep] + 1;
                        changed = true;
                    }
                }
            }
        }
        if !deps.iter().all(Vec::is_empty) {
            let max_layer = layer.iter().copied().max().unwrap_or(0);
            debug!(
                "Capability scheduling: {} passes across {} layers",
                n,
                max_layer + 1
            );
            for (i, (pass, fallback)) in self.passes.iter().enumerate() {
                if layer[i] != *fallback {
                    debug!(
                        " pass '{}': layer {} (moved from fallback {})",
                        pass.name(),
                        layer[i],
                        fallback
                    );
                }
            }
        }
        Ok(layer)
    }

    /// Runs `passes` in rounds until a round reports no change, or until
    /// `max_phase_iterations` rounds have run.
    ///
    /// Returns `true` if at least one round changed something.
    ///
    /// # Errors
    ///
    /// Propagates the first error from a pass's `initialize`, `run_global` or
    /// `finalize` hook.
    fn normalize_to_fixpoint(
        ctx: &CompilerContext,
        passes: &mut [Box<dyn SsaPass>],
        max_phase_iterations: usize,
        assembly: &Arc<CilObject>,
        state: Option<&ProcessingState>,
        iteration_modified: Option<&DashSet<Token>>,
    ) -> Result<bool> {
        let mut any_changed = false;
        for _ in 0..max_phase_iterations {
            let changed = Self::run_passes_once(ctx, passes, assembly, state, iteration_modified)?;
            if !changed {
                break;
            }
            any_changed = true;
        }
        Ok(any_changed)
    }

    /// Drives the passes at `layer_indices` (indices into `all_passes`) towards a
    /// fixpoint. After every round that changed something, the normalization passes
    /// are run to their own fixpoint.
    ///
    /// Returns `true` if any layer round changed something. An empty layer returns
    /// `false` immediately.
    ///
    /// # Errors
    ///
    /// Propagates errors from pass `initialize`, `run_global` and `finalize` hooks.
    // Takes the pass lists as separate borrows (instead of `&mut self`) because
    // `run_pipeline` lends out `self.passes` and `self.normalize` at the same time.
    #[allow(clippy::too_many_arguments)]
    fn layer_to_fixpoint(
        ctx: &CompilerContext,
        all_passes: &mut [(Box<dyn SsaPass>, usize)],
        layer_indices: &[usize],
        normalize_passes: &mut [Box<dyn SsaPass>],
        max_phase_iterations: usize,
        assembly: &Arc<CilObject>,
        state: Option<&ProcessingState>,
        iteration_modified: Option<&DashSet<Token>>,
    ) -> Result<bool> {
        if layer_indices.is_empty() {
            return Ok(false);
        }
        let mut phase_changed = false;
        for _ in 0..max_phase_iterations {
            let pass_changed = Self::run_layer_passes_once(
                ctx,
                all_passes,
                layer_indices,
                assembly,
                state,
                iteration_modified,
            )?;
            if !pass_changed {
                // NOTE(review): normalization already ran to a fixpoint after the
                // previous changing round, so this extra run appears to make progress
                // only if that run stopped at its iteration cap - confirm before removing.
                if phase_changed && !normalize_passes.is_empty() {
                    Self::normalize_to_fixpoint(
                        ctx,
                        normalize_passes,
                        max_phase_iterations,
                        assembly,
                        state,
                        iteration_modified,
                    )?;
                }
                break;
            }
            phase_changed = true;
            if !normalize_passes.is_empty() {
                Self::normalize_to_fixpoint(
                    ctx,
                    normalize_passes,
                    max_phase_iterations,
                    assembly,
                    state,
                    iteration_modified,
                )?;
            }
        }
        Ok(phase_changed)
    }

    /// Runs one round of `passes`: initialize all of them, execute them, then
    /// finalize all of them.
    ///
    /// Returns `true` if any pass reported a change.
    ///
    /// # Errors
    ///
    /// Propagates the first `initialize`, `run_global` or `finalize` error. The
    /// remaining hooks of the round are skipped in that case.
    fn run_passes_once(
        ctx: &CompilerContext,
        passes: &mut [Box<dyn SsaPass>],
        assembly: &Arc<CilObject>,
        state: Option<&ProcessingState>,
        iteration_modified: Option<&DashSet<Token>>,
    ) -> Result<bool> {
        for pass in passes.iter_mut() {
            pass.initialize(ctx)?;
        }
        let changed = {
            let round: Vec<&dyn SsaPass> = passes.iter().map(|pass| &**pass).collect();
            Self::execute_initialized_passes(ctx, &round, assembly, state, iteration_modified)?
        };
        for pass in passes.iter_mut() {
            pass.finalize(ctx)?;
        }
        Ok(changed)
    }

    /// Same as `run_passes_once`, but for the subset of `all_passes` selected by
    /// `indices`. Passes are initialized, executed and finalized in `indices` order.
    ///
    /// # Errors
    ///
    /// Propagates the first `initialize`, `run_global` or `finalize` error.
    fn run_layer_passes_once(
        ctx: &CompilerContext,
        all_passes: &mut [(Box<dyn SsaPass>, usize)],
        indices: &[usize],
        assembly: &Arc<CilObject>,
        state: Option<&ProcessingState>,
        iteration_modified: Option<&DashSet<Token>>,
    ) -> Result<bool> {
        for &idx in indices {
            all_passes[idx].0.initialize(ctx)?;
        }
        let changed = {
            let shared: &[(Box<dyn SsaPass>, usize)] = &*all_passes;
            let round: Vec<&dyn SsaPass> = indices.iter().map(|&idx| &*shared[idx].0).collect();
            Self::execute_initialized_passes(ctx, &round, assembly, state, iteration_modified)?
        };
        for &idx in indices {
            all_passes[idx].0.finalize(ctx)?;
        }
        Ok(changed)
    }

    /// Executes one round of already-initialized passes. Global passes run first,
    /// in order. Every per-method pass then runs, one after another, each
    /// parallelised over its method list.
    ///
    /// Passes with `requires_full_scan()` see every method that has SSA. The other
    /// passes see only the methods in `state.method_dirty`, or every method when
    /// `state` is `None`. Both lists are computed once, before any pass in the
    /// round runs.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by a global pass's `run_global`.
    fn execute_initialized_passes(
        ctx: &CompilerContext,
        passes: &[&dyn SsaPass],
        assembly: &Arc<CilObject>,
        state: Option<&ProcessingState>,
        iteration_modified: Option<&DashSet<Token>>,
    ) -> Result<bool> {
        let all_methods = Self::method_order(ctx, None);
        let dirty_methods = Self::method_order(ctx, state.map(|s| &s.method_dirty));
        let any_changed = AtomicBool::new(false);

        for pass in passes.iter().filter(|pass| pass.is_global()) {
            if pass.run_global(ctx, assembly)? {
                any_changed.store(true, Ordering::Relaxed);
            }
        }

        for pass in passes.iter().filter(|pass| !pass.is_global()) {
            let methods = if pass.requires_full_scan() {
                &all_methods
            } else {
                &dirty_methods
            };
            Self::run_single_pass(
                *pass,
                ctx,
                methods,
                assembly,
                &any_changed,
                iteration_modified,
            );
        }
        Ok(any_changed.load(Ordering::Relaxed))
    }

    /// Lists the methods that currently have an SSA function in `ctx.ssa_functions`.
    ///
    /// Methods are ordered by `ctx.methods_reverse_topological()`, or by
    /// `ctx.all_methods()` when that ordering is empty. When `dirty_only` is given,
    /// only methods contained in it are returned.
    fn method_order(ctx: &CompilerContext, dirty_only: Option<&DashSet<Token>>) -> Vec<Token> {
        let topo = ctx.methods_reverse_topological();
        let order: Vec<_> = if topo.is_empty() {
            ctx.all_methods().collect()
        } else {
            topo
        };
        order
            .into_iter()
            .filter(|token| ctx.ssa_functions.contains_key(token))
            .filter(|token| dirty_only.is_none_or(|dirty| dirty.contains(token)))
            .collect()
    }

    /// Runs a single per-method pass over `methods` in parallel.
    ///
    /// For every method the pass reports as changed, this function:
    /// * repairs or rebuilds the SSA according to `modification_scope()`;
    /// * sets `any_changed`;
    /// * records the method in `ctx.processed_methods` and in `iteration_modified`
    ///   (when given).
    ///
    /// A pass error on one method is logged and treated as "no change". It does not
    /// abort the remaining methods.
    fn run_single_pass(
        pass: &dyn SsaPass,
        ctx: &CompilerContext,
        methods: &[Token],
        assembly: &Arc<CilObject>,
        any_changed: &AtomicBool,
        iteration_modified: Option<&DashSet<Token>>,
    ) {
        let event_snapshot = ctx.events.len();
        let pass_change_count = AtomicUsize::new(0);
        // Passes that read other methods' SSA work on a clone, so the original entry
        // stays in `ctx.ssa_functions` while the pass runs. Otherwise the entry is
        // moved out of the map for the duration of the pass.
        let clone_for_visibility = pass.reads_peer_ssa();
        methods.par_iter().for_each(|&method_token| {
            if !pass.should_run(method_token, ctx) {
                return;
            }
            let mut ssa = if clone_for_visibility {
                let Some(ssa_ref) = ctx.ssa_functions.get(&method_token) else {
                    return;
                };
                ssa_ref.clone()
            } else {
                let Some((_, ssa)) = ctx.ssa_functions.remove(&method_token) else {
                    return;
                };
                ssa
            };
            let changed = match pass.run_on_method(&mut ssa, method_token, ctx, assembly) {
                Ok(changed) => changed,
                Err(e) => {
                    // Best-effort: one failing method must not abort the pipeline,
                    // but the failure should not disappear silently either.
                    log::warn!("pass '{}' failed on {}: {}", pass.name(), method_token, e);
                    false
                }
            };
            if changed {
                match pass.modification_scope() {
                    ModificationScope::UsesOnly | ModificationScope::InstructionsOnly => {
                        ssa.repair_ssa();
                    }
                    ModificationScope::CfgModifying => {
                        if let Err(e) = ssa.rebuild_ssa() {
                            log::warn!("SSA rebuild failed for {}: {}", method_token, e);
                        }
                    }
                }
            }
            // Always put the function back (even after an error), because in the
            // non-cloning mode it was removed from the map above.
            ctx.ssa_functions.insert(method_token, ssa);
            if changed {
                any_changed.store(true, Ordering::Relaxed);
                pass_change_count.fetch_add(1, Ordering::Relaxed);
                ctx.processed_methods.insert(method_token);
                if let Some(modified) = iteration_modified {
                    modified.insert(method_token);
                }
            }
        });
        let count = pass_change_count.load(Ordering::Relaxed);
        if count == 0 {
            return;
        }
        // An empty event delta yields an empty summary, so one check covers both.
        let summary = format_event_delta(&ctx.events.count_by_kind_since(event_snapshot));
        if summary.is_empty() {
            debug!(" pass '{}' changed {} methods", pass.name(), count);
        } else {
            debug!(
                " pass '{}' changed {} methods ({})",
                pass.name(),
                count,
                summary
            );
        }
    }

    /// Runs all registered passes and returns the number of outer iterations executed.
    ///
    /// Each iteration drives every non-empty layer, in ascending layer order, to a
    /// fixpoint (see `layer_to_fixpoint`). If no layer changed anything in the first
    /// iteration, the normalization passes are still run once to their fixpoint. The
    /// loop stops after `max_iterations`, or earlier once `stable_iterations`
    /// consecutive iterations changed nothing.
    ///
    /// When `state` is given:
    /// * if anything changed, dirty methods not modified this iteration are passed
    ///   to `mark_method_stable`, and modified methods to `mark_method_dirty`;
    /// * if nothing changed, every dirty method is passed to `mark_method_stable`.
    ///
    /// # Errors
    ///
    /// Returns an error if the capability dependencies contain a cycle, or if any
    /// pass `initialize`, `run_global` or `finalize` hook fails.
    pub fn run_pipeline(
        &mut self,
        ctx: &CompilerContext,
        assembly: &Arc<CilObject>,
        state: Option<&ProcessingState>,
    ) -> Result<usize> {
        let layer_assignment = self.compute_layer_assignment()?;
        let num_layers = layer_assignment.iter().copied().max().map_or(0, |m| m + 1);
        let mut layer_indices: Vec<Vec<usize>> = vec![vec![]; num_layers];
        for (i, &layer) in layer_assignment.iter().enumerate() {
            layer_indices[layer].push(i);
        }
        // Fallback layers and capability bumps can leave gaps; drop empty layers.
        layer_indices.retain(|layer| !layer.is_empty());
        let mut stable_count = 0;
        let mut iterations = 0;
        let max_phase = self.max_phase_iterations;
        let max_iterations = self.max_iterations;
        let stable_iterations = self.stable_iterations;
        for iteration in 0..max_iterations {
            iterations = iteration + 1;
            debug!("Pipeline iteration {}/{}", iterations, max_iterations);
            let iteration_modified = DashSet::new();
            // Modified methods are only tracked when there is state to update.
            let modified_ref = state.map(|_| &iteration_modified);
            let mut iteration_changed = false;
            for layer in &layer_indices {
                if Self::layer_to_fixpoint(
                    ctx,
                    &mut self.passes,
                    layer,
                    &mut self.normalize,
                    max_phase,
                    assembly,
                    state,
                    modified_ref,
                )? {
                    iteration_changed = true;
                }
            }
            if iteration == 0 && !iteration_changed && !self.normalize.is_empty() {
                iteration_changed = Self::normalize_to_fixpoint(
                    ctx,
                    &mut self.normalize,
                    max_phase,
                    assembly,
                    state,
                    modified_ref,
                )?;
            }
            if let Some(state) = state {
                // Snapshot the dirty set before marking anything, so it is not
                // iterated while being updated.
                let dirty: Vec<Token> = state.method_dirty.iter().map(|t| *t).collect();
                if iteration_changed {
                    // NOTE(review): changes made by global passes (`run_global`) are
                    // never recorded in `iteration_modified`, so an iteration changed
                    // only by global passes marks every dirty method stable - confirm
                    // this is intended.
                    for token in dirty {
                        if !iteration_modified.contains(&token) {
                            state.mark_method_stable(token);
                        }
                    }
                    for token in iteration_modified.iter() {
                        state.mark_method_dirty(*token);
                    }
                } else {
                    for token in dirty {
                        state.mark_method_stable(token);
                    }
                }
            }
            if iteration_changed {
                stable_count = 0;
            } else {
                stable_count += 1;
                if stable_count >= stable_iterations {
                    debug!("Pipeline stable after {} iterations", iterations);
                    break;
                }
            }
        }
        Ok(iterations)
    }
}
/// Renders the transformation events in `delta` as a sorted, comma-separated
/// list of `"<count> <description>"` entries.
///
/// Event kinds for which `is_transformation()` is false are skipped. An empty
/// (or fully filtered) delta yields an empty string.
fn format_event_delta(delta: &HashMap<EventKind, usize>) -> String {
    let mut entries: Vec<String> = Vec::with_capacity(delta.len());
    for (kind, count) in delta {
        if kind.is_transformation() {
            entries.push(format!("{} {}", count, kind.description()));
        }
    }
    // Equal strings are indistinguishable, so an unstable sort gives the same output.
    entries.sort_unstable();
    entries.join(", ")
}
#[cfg(test)]
mod tests {
    use crate::{
        analysis::SsaFunction,
        compiler::{
            context::CompilerContext,
            pass::{PassCapability, PassPhase, SsaPass},
            EventKind, PassScheduler,
        },
        metadata::token::Token,
        CilObject, Result,
    };

    /// Minimal pass that records `changes_to_make` `ConstantFolded` events per
    /// method and reports a change whenever that count is non-zero.
    struct TestPass {
        name: &'static str,
        changes_to_make: usize,
    }

    impl TestPass {
        fn new(name: &'static str, changes: usize) -> Self {
            Self {
                name,
                changes_to_make: changes,
            }
        }
    }

    impl SsaPass for TestPass {
        fn name(&self) -> &'static str {
            self.name
        }

        fn run_on_method(
            &self,
            _ssa: &mut SsaFunction,
            method_token: Token,
            ctx: &CompilerContext,
            _assembly: &CilObject,
        ) -> Result<bool> {
            for i in 0..self.changes_to_make {
                ctx.events
                    .record(EventKind::ConstantFolded)
                    .at(method_token, i)
                    .message("test");
            }
            Ok(self.changes_to_make > 0)
        }
    }

    /// Pass that never changes code and only declares the capabilities it
    /// provides and requires; used to exercise layer assignment.
    struct CapabilityPass {
        name: &'static str,
        provides: Vec<PassCapability>,
        requires: Vec<PassCapability>,
    }

    impl SsaPass for CapabilityPass {
        fn name(&self) -> &'static str {
            self.name
        }

        fn run_on_method(
            &self,
            _ssa: &mut SsaFunction,
            _method_token: Token,
            _ctx: &CompilerContext,
            _assembly: &CilObject,
        ) -> Result<bool> {
            Ok(false)
        }

        fn provides(&self) -> &[PassCapability] {
            &self.provides
        }

        fn requires(&self) -> &[PassCapability] {
            &self.requires
        }
    }

    #[test]
    fn test_scheduler_iteration_limits() {
        let scheduler = PassScheduler::new(10, 3, 5);
        assert_eq!(scheduler.max_iterations, 10);
        assert_eq!(scheduler.stable_iterations, 3);
        assert_eq!(scheduler.max_phase_iterations, 5);
    }

    #[test]
    fn test_default_scheduler() {
        let scheduler = PassScheduler::default();
        assert_eq!(scheduler.max_iterations, 5);
        assert_eq!(scheduler.stable_iterations, 2);
        assert_eq!(scheduler.max_phase_iterations, 15);
    }

    #[test]
    fn test_pass_names() {
        let passes: Vec<Box<dyn SsaPass>> = vec![
            Box::new(TestPass::new("pass1", 0)),
            Box::new(TestPass::new("pass2", 0)),
        ];
        assert_eq!(passes.len(), 2);
        assert_eq!(passes[0].name(), "pass1");
        assert_eq!(passes[1].name(), "pass2");
    }

    #[test]
    fn test_add_pass() {
        let mut scheduler = PassScheduler::new(5, 2, 15);
        scheduler.add(
            Box::new(TestPass::new("structure_pass", 0)),
            PassPhase::Structure,
        );
        scheduler.add(Box::new(TestPass::new("value_pass", 0)), PassPhase::Value);
        scheduler.add(
            Box::new(TestPass::new("simplify_pass", 0)),
            PassPhase::Simplify,
        );
        assert_eq!(scheduler.pass_count(), 3);
    }

    #[test]
    fn test_normalize_passes_are_kept_separately() {
        let mut scheduler = PassScheduler::new(5, 2, 15);
        scheduler.add(
            Box::new(TestPass::new("normalize_pass", 0)),
            PassPhase::Normalize,
        );
        scheduler.add(Box::new(TestPass::new("value_pass", 0)), PassPhase::Value);
        assert_eq!(scheduler.normalize_count(), 1);
        assert_eq!(scheduler.pass_count(), 1);
    }

    #[test]
    fn test_empty_scheduler_has_no_layers() {
        let scheduler = PassScheduler::new(5, 2, 15);
        assert!(scheduler.compute_layer_assignment().unwrap().is_empty());
    }

    #[test]
    fn test_capability_layer_computation() {
        let mut scheduler = PassScheduler::new(5, 2, 15);
        scheduler.add(
            Box::new(CapabilityPass {
                name: "value-resolver",
                provides: vec![PassCapability::ResolvedStaticFields],
                requires: vec![],
            }),
            PassPhase::Value,
        );
        scheduler.add(
            Box::new(CapabilityPass {
                name: "cff-reconstruction",
                provides: vec![PassCapability::RestoredControlFlow],
                requires: vec![PassCapability::ResolvedStaticFields],
            }),
            PassPhase::Structure,
        );
        scheduler.add(
            Box::new(CapabilityPass {
                name: "opaque-predicates",
                provides: vec![PassCapability::SimplifiedPredicates],
                requires: vec![PassCapability::RestoredControlFlow],
            }),
            PassPhase::Simplify,
        );
        let layers = scheduler.compute_layer_assignment().unwrap();
        // Fallbacks are [Value = 1, Structure = 0, Simplify = 2]; each dependency
        // pushes the consumer one layer past its provider.
        assert_eq!(layers[0], 1);
        assert_eq!(layers[1], 2);
        assert_eq!(layers[2], 3);
    }

    #[test]
    fn test_no_capabilities_uses_fallback() {
        let mut scheduler = PassScheduler::new(5, 2, 15);
        scheduler.add(
            Box::new(TestPass::new("structure", 0)),
            PassPhase::Structure,
        );
        scheduler.add(Box::new(TestPass::new("value", 0)), PassPhase::Value);
        scheduler.add(Box::new(TestPass::new("simplify", 0)), PassPhase::Simplify);
        let layers = scheduler.compute_layer_assignment().unwrap();
        assert_eq!(layers[0], 0);
        assert_eq!(layers[1], 1);
        assert_eq!(layers[2], 2);
    }

    #[test]
    fn test_missing_provider_uses_fallback() {
        let mut scheduler = PassScheduler::new(5, 2, 15);
        scheduler.add(
            Box::new(CapabilityPass {
                name: "cff",
                provides: vec![PassCapability::RestoredControlFlow],
                requires: vec![PassCapability::ResolvedStaticFields],
            }),
            PassPhase::Structure,
        );
        let layers = scheduler.compute_layer_assignment().unwrap();
        assert_eq!(layers[0], 0);
    }

    #[test]
    fn test_self_dependency_is_not_a_cycle() {
        let mut scheduler = PassScheduler::new(5, 2, 15);
        scheduler.add(
            Box::new(CapabilityPass {
                name: "self-sufficient",
                provides: vec![PassCapability::ResolvedStaticFields],
                requires: vec![PassCapability::ResolvedStaticFields],
            }),
            PassPhase::Value,
        );
        let layers = scheduler.compute_layer_assignment().unwrap();
        assert_eq!(layers, vec![1]);
    }

    #[test]
    fn test_cycle_detection() {
        let mut scheduler = PassScheduler::new(5, 2, 15);
        scheduler.add(
            Box::new(CapabilityPass {
                name: "pass-a",
                provides: vec![PassCapability::ResolvedStaticFields],
                requires: vec![PassCapability::RestoredControlFlow],
            }),
            PassPhase::Structure,
        );
        scheduler.add(
            Box::new(CapabilityPass {
                name: "pass-b",
                provides: vec![PassCapability::RestoredControlFlow],
                requires: vec![PassCapability::ResolvedStaticFields],
            }),
            PassPhase::Structure,
        );
        let result = scheduler.compute_layer_assignment();
        assert!(result.is_err());
    }
}