use marque_scheme::{
CategoryAction, CategoryId, CategoryPredicate, MarkingScheme, PageRewrite, RewriteId,
};
use std::collections::{BTreeMap, BTreeSet};
use crate::errors::EngineConstructionError;
/// Orders `rewrites` so that each rewrite runs after every *other* rewrite
/// that writes a category it reads.
///
/// Dependencies come only from the declared `reads` / `writes` axes. A rewrite
/// that reads a category it also writes is not treated as depending on itself.
/// Rewrites whose dependencies are all satisfied are released first-in,
/// first-out, seeded in declaration order (Kahn's algorithm).
///
/// # Errors
///
/// * [`EngineConstructionError::UnannotatedCustomAxes`] for the first rewrite
///   whose trigger or action is `Custom` but whose `reads` or `writes` list is
///   empty.
/// * [`EngineConstructionError::RewriteCycle`] when the dependencies form a
///   cycle. `members` holds the ids of the strongly connected component that
///   contains the lowest declaration index, in declaration order, and `axis`
///   is the category `cycle_axis` reports for that component.
///
/// # Panics
///
/// Only on an internal logic error: when Kahn's algorithm stalls but no
/// component of two or more rewrites is found.
pub fn schedule_rewrites<S>(
    rewrites: &[PageRewrite<S>],
) -> Result<Box<[RewriteId]>, EngineConstructionError>
where
    S: MarkingScheme + ?Sized,
{
    // Custom triggers/actions are opaque to the scheduler, so they must spell
    // out which axes they touch.
    if let Some(bad) = rewrites
        .iter()
        .find(|&rw| rewrite_is_custom(rw) && (rw.reads.is_empty() || rw.writes.is_empty()))
    {
        return Err(EngineConstructionError::UnannotatedCustomAxes { rewrite: bad.id });
    }

    let count = rewrites.len();

    // category -> declaration indexes of the rewrites that write it
    let mut producers_of: BTreeMap<CategoryId, Vec<usize>> = BTreeMap::new();
    for (writer, rw) in rewrites.iter().enumerate() {
        for &cat in rw.writes {
            producers_of.entry(cat).or_default().push(writer);
        }
    }

    // Edge writer -> reader for every category the reader consumes; the set
    // deduplicates repeated edges so in-degrees count distinct producers.
    let mut in_degree = vec![0usize; count];
    let mut successors: Vec<BTreeSet<usize>> = vec![BTreeSet::new(); count];
    for (reader, rw) in rewrites.iter().enumerate() {
        let upstream = rw
            .reads
            .iter()
            .filter_map(|cat| producers_of.get(cat))
            .flatten()
            .copied()
            .filter(|&writer| writer != reader);
        for writer in upstream {
            if successors[writer].insert(reader) {
                in_degree[reader] += 1;
            }
        }
    }

    let mut ready: std::collections::VecDeque<usize> =
        (0..count).filter(|&i| in_degree[i] == 0).collect();
    let mut order: Vec<RewriteId> = Vec::with_capacity(count);
    while let Some(next) = ready.pop_front() {
        order.push(rewrites[next].id);
        for &succ in &successors[next] {
            in_degree[succ] -= 1;
            if in_degree[succ] == 0 {
                ready.push_back(succ);
            }
        }
    }

    if order.len() == count {
        return Ok(order.into_boxed_slice());
    }

    // Kahn's algorithm stalled, so the residual graph holds a cycle. Report
    // the multi-node component containing the lowest declaration index.
    let mut cycles: Vec<Vec<usize>> = tarjan_sccs(count, &successors)
        .into_iter()
        .filter(|scc| scc.len() > 1)
        .collect();
    debug_assert!(
        !cycles.is_empty(),
        "scheduled.len() != n but Tarjan found no non-trivial SCC; \
         this indicates a logic error in `schedule_rewrites` or in \
         `tarjan_sccs`, because Kahn's algorithm only leaves nodes \
         unscheduled when the residual graph contains a cycle."
    );
    let first = cycles
        .iter_mut()
        .min_by_key(|scc| scc.iter().min().copied().unwrap_or(usize::MAX))
        .expect("debug_assert above guards the empty-Vec case");
    first.sort_unstable();
    let axis = cycle_axis(rewrites, first);
    let members: Box<[RewriteId]> = first.iter().map(|&i| rewrites[i].id).collect();
    Err(EngineConstructionError::RewriteCycle { axis, members })
}
/// Computes the strongly connected components of the directed graph on nodes
/// `0..n`, where `successors[v]` holds the targets of `v`'s outgoing edges.
///
/// This is Tarjan's algorithm driven by an explicit stack of DFS frames rather
/// than recursion, so deep graphs cannot overflow the call stack. Roots are
/// tried in ascending node order, and each node's successors are visited in
/// ascending order. A component is emitted as soon as its root finishes. Its
/// members are listed in the order they are popped off the Tarjan stack, i.e.
/// most recently discovered first. Every node lands in exactly one component,
/// so single-node components appear in the output too.
fn tarjan_sccs(n: usize, successors: &[BTreeSet<usize>]) -> Vec<Vec<usize>> {
    /// Per-node bookkeeping plus the Tarjan node stack.
    struct Walk {
        index: Vec<Option<usize>>,
        lowlink: Vec<usize>,
        on_stack: Vec<bool>,
        stack: Vec<usize>,
        counter: usize,
    }

    impl Walk {
        /// Gives `v` the next discovery index and pushes it on the Tarjan stack.
        fn discover(&mut self, v: usize) {
            self.index[v] = Some(self.counter);
            self.lowlink[v] = self.counter;
            self.counter += 1;
            self.stack.push(v);
            self.on_stack[v] = true;
        }

        /// Pops nodes off the Tarjan stack up to and including `root`.
        fn pop_component(&mut self, root: usize) -> Vec<usize> {
            let mut component = Vec::new();
            while let Some(w) = self.stack.pop() {
                self.on_stack[w] = false;
                component.push(w);
                if w == root {
                    break;
                }
            }
            component
        }
    }

    /// A suspended DFS call: the node plus its not-yet-visited successors.
    struct Frame<'g> {
        node: usize,
        pending: std::collections::btree_set::Iter<'g, usize>,
    }

    let mut walk = Walk {
        index: vec![None; n],
        lowlink: vec![0; n],
        on_stack: vec![false; n],
        stack: Vec::new(),
        counter: 0,
    };
    let mut components: Vec<Vec<usize>> = Vec::new();
    let mut frames: Vec<Frame<'_>> = Vec::new();

    for root in 0..n {
        if walk.index[root].is_some() {
            continue;
        }
        walk.discover(root);
        frames.push(Frame {
            node: root,
            pending: successors[root].iter(),
        });

        while let Some(top) = frames.last_mut() {
            let v = top.node;
            let Some(&w) = top.pending.next() else {
                // All successors of `v` are done: "return" to the caller frame.
                frames.pop();
                if let Some(caller) = frames.last() {
                    let u = caller.node;
                    walk.lowlink[u] = walk.lowlink[u].min(walk.lowlink[v]);
                }
                let v_index = walk.index[v].expect("index[v] was set at seed");
                if walk.lowlink[v] == v_index {
                    components.push(walk.pop_component(v));
                }
                continue;
            };
            if walk.index[w].is_none() {
                // Tree edge: descend into `w`.
                walk.discover(w);
                frames.push(Frame {
                    node: w,
                    pending: successors[w].iter(),
                });
            } else if walk.on_stack[w] {
                // Edge back into a node still on the stack (same component).
                let w_idx = walk.index[w].expect("index[w] was set when w was pushed");
                walk.lowlink[v] = walk.lowlink[v].min(w_idx);
            }
        }
    }

    components
}
/// Reports whether `rw` uses a `Custom` trigger or a `Custom` action.
///
/// `schedule_rewrites` rejects such rewrites unless both their `reads` and
/// `writes` lists are non-empty, because the scheduler orders rewrites only by
/// those declared axes.
fn rewrite_is_custom<S: MarkingScheme + ?Sized>(rw: &PageRewrite<S>) -> bool {
    let custom_trigger = matches!(&rw.trigger, CategoryPredicate::Custom(_));
    let custom_action = matches!(&rw.action, CategoryAction::Custom(_));
    custom_trigger || custom_action
}
/// Picks the category that explains the dependency cycle formed by the
/// rewrites at `indexes`.
///
/// The result is the smallest category that one member reads and a
/// *different* member writes, i.e. a category that carries an actual edge of
/// the cycle. A category a rewrite both reads and writes itself is skipped
/// unless another member also writes it. `schedule_rewrites` never turns such
/// a self-dependency into an edge, so reporting it would point at the wrong
/// axis.
///
/// `indexes` is expected to name a strongly connected component of at least
/// two rewrites. Every member of such a component has an incoming edge from
/// another member, so a qualifying category always exists. If none is found
/// (a caller logic error), this trips a `debug_assert!` and returns
/// `CategoryId(0)` in release builds.
fn cycle_axis<S: MarkingScheme + ?Sized>(
    rewrites: &[PageRewrite<S>],
    indexes: &[usize],
) -> CategoryId {
    // Is `cat` written by some member other than `reader`? Only then does the
    // scheduler draw a writer -> reader edge for it inside this component.
    let written_by_other = |reader: usize, cat: CategoryId| -> bool {
        indexes
            .iter()
            .any(|&writer| writer != reader && rewrites[writer].writes.contains(&cat))
    };
    let picked = indexes
        .iter()
        .flat_map(move |&reader| rewrites[reader].reads.iter().map(move |&cat| (reader, cat)))
        .filter(|&(reader, cat)| written_by_other(reader, cat))
        .map(|(_, cat)| cat)
        .min();
    debug_assert!(
        picked.is_some(),
        "cycle_axis called with no shared read/write axis; this should \
         be unreachable when `indexes` names a real cycle in the scheduler \
         graph. The release-mode fallback is CategoryId(0), but reaching \
         it means `schedule_rewrites` classified a non-cycle as a cycle.",
    );
    picked.unwrap_or(CategoryId(0))
}
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    // Unit tests for the rewrite scheduler: successful orderings and the
    // cycle error path.
    use super::*;
    use marque_scheme::{
        Category, Constraint, ConstraintViolation, Lattice, Parsed, Scope, Template, TokenId,
        TokenRef,
    };

    /// Contentless marking; both lattice operations return the same value.
    #[derive(Clone, Debug, PartialEq, Eq, Default)]
    struct StubMarking;
    impl Lattice for StubMarking {
        fn join(&self, _other: &Self) -> Self {
            Self
        }
        fn meet(&self, _other: &Self) -> Self {
            Self
        }
    }

    /// Minimal scheme with no categories, constraints or templates; parsing
    /// always fails. The scheduler never calls these methods, so the scheme
    /// only serves as the type parameter of `PageRewrite`.
    struct StubScheme;
    impl MarkingScheme for StubScheme {
        type Token = TokenId;
        type Marking = StubMarking;
        type ParseError = ();
        fn name(&self) -> &str {
            "stub"
        }
        fn schema_version(&self) -> &str {
            "v0"
        }
        fn categories(&self) -> &[Category] {
            &[]
        }
        fn constraints(&self) -> &[Constraint] {
            &[]
        }
        fn templates(&self) -> &[Template] {
            &[]
        }
        fn parse(&self, _: &str) -> Result<Parsed<Self::Marking>, Self::ParseError> {
            Err(())
        }
        fn satisfies(&self, _: &Self::Marking, _: &TokenRef) -> bool {
            false
        }
        fn validate(&self, _: &Self::Marking) -> Vec<ConstraintViolation> {
            vec![]
        }
        fn project(&self, _: Scope, _: &[Self::Marking]) -> Self::Marking {
            StubMarking
        }
        fn render_portion(&self, _: &Self::Marking) -> String {
            String::new()
        }
        fn render_banner(&self, _: &Self::Marking) -> String {
            String::new()
        }
    }

    const CAT_X: CategoryId = CategoryId(1);
    const CAT_Y: CategoryId = CategoryId(2);
    const CAT_Z: CategoryId = CategoryId(3);

    /// Builds a non-custom rewrite whose scheduling is driven purely by the
    /// given `reads` / `writes` axes; the trigger and action are placeholders.
    fn declarative(
        id: RewriteId,
        reads: &'static [CategoryId],
        writes: &'static [CategoryId],
    ) -> PageRewrite<StubScheme> {
        PageRewrite::declarative(
            id,
            "test",
            CategoryPredicate::Empty { category: CAT_X },
            CategoryAction::Clear { category: CAT_X },
            reads,
            writes,
        )
    }

    #[test]
    fn empty_input_is_empty_output() {
        let scheduled = schedule_rewrites::<StubScheme>(&[]).unwrap();
        assert!(scheduled.is_empty());
    }

    #[test]
    fn no_dependencies_preserves_declaration_order() {
        let rewrites = vec![
            declarative("a", &[], &[CAT_X]),
            declarative("b", &[], &[CAT_Y]),
            declarative("c", &[], &[CAT_Z]),
        ];
        let scheduled = schedule_rewrites(&rewrites).unwrap();
        assert_eq!(scheduled.as_ref(), ["a", "b", "c"]);
    }

    #[test]
    fn writer_before_reader() {
        let rewrites = vec![
            declarative("b", &[CAT_X], &[CAT_Y]),
            declarative("a", &[], &[CAT_X]),
        ];
        let scheduled = schedule_rewrites(&rewrites).unwrap();
        assert_eq!(scheduled.as_ref(), ["a", "b"]);
    }

    #[test]
    fn transitive_chain_is_ordered() {
        // Declared in reverse dependency order: c needs b's Y, b needs a's X.
        let rewrites = vec![
            declarative("c", &[CAT_Y], &[CAT_Z]),
            declarative("b", &[CAT_X], &[CAT_Y]),
            declarative("a", &[], &[CAT_X]),
        ];
        let scheduled = schedule_rewrites(&rewrites).unwrap();
        assert_eq!(scheduled.as_ref(), ["a", "b", "c"]);
    }

    #[test]
    fn self_edge_is_permitted() {
        let rewrites = vec![declarative("a", &[CAT_X], &[CAT_X])];
        let scheduled = schedule_rewrites(&rewrites).unwrap();
        assert_eq!(scheduled.as_ref(), ["a"]);
    }

    #[test]
    fn two_rewrite_cycle_is_rejected() {
        let rewrites = vec![
            declarative("a", &[CAT_Y], &[CAT_X]),
            declarative("b", &[CAT_X], &[CAT_Y]),
        ];
        let Err(EngineConstructionError::RewriteCycle { axis, members }) =
            schedule_rewrites(&rewrites)
        else {
            panic!("expected a RewriteCycle error");
        };
        assert!(axis == CAT_X);
        assert_eq!(members.as_ref(), ["a", "b"]);
    }

    #[test]
    fn cycle_members_exclude_downstream_rewrites() {
        // "downstream" depends on the a/b cycle but is not part of it.
        let rewrites = vec![
            declarative("downstream", &[CAT_Y], &[CAT_Z]),
            declarative("a", &[CAT_Y], &[CAT_X]),
            declarative("b", &[CAT_X], &[CAT_Y]),
        ];
        let Err(EngineConstructionError::RewriteCycle { axis, members }) =
            schedule_rewrites(&rewrites)
        else {
            panic!("expected a RewriteCycle error");
        };
        assert!(axis == CAT_X);
        assert_eq!(members.as_ref(), ["a", "b"]);
    }
}