use crate::ir::Program;
use crate::optimizer::{fingerprint_program, vyre_pass, PassAnalysis, PassResult};
/// Optimizer pass that puts a program's buffer declarations into a canonical order.
///
/// The canonical order is ascending `binding`, with `name` as the tiebreaker for
/// declarations that share a binding. The pass is registered through `vyre_pass`
/// under the name `buffer_decl_sort` and declares empty `requires` and
/// `invalidates` lists.
#[derive(Debug, Default)]
#[vyre_pass(
name = "buffer_decl_sort",
requires = [],
invalidates = []
)]
pub struct BufferDeclSortPass;
impl BufferDeclSortPass {
    /// Reports whether the pass has any work to do on `program`.
    ///
    /// Returns `PassAnalysis::RUN` when the buffer declarations are not yet in
    /// canonical order, and `PassAnalysis::SKIP` when they already are.
    #[must_use]
    pub fn analyze(program: &Program) -> PassAnalysis {
        if !buffers_in_canonical_order(program) {
            return PassAnalysis::RUN;
        }
        PassAnalysis::SKIP
    }

    /// Rewrites `program` so its buffer declarations are in canonical order.
    ///
    /// If the declarations are already ordered the program is handed back
    /// untouched with `changed: false`. Otherwise the declarations are copied,
    /// sorted by `binding` and then by `name`, installed through
    /// `Program::with_rewritten_buffers`, and the result reports `changed: true`.
    #[must_use]
    pub fn transform(program: Program) -> PassResult {
        let already_canonical = buffers_in_canonical_order(&program);

        let program = if already_canonical {
            program
        } else {
            let mut ordered = program.buffers().to_vec();
            // `sort_by` is a stable sort, so declarations whose binding and
            // name both compare equal keep their original relative order.
            ordered.sort_by(|lhs, rhs| match lhs.binding.cmp(&rhs.binding) {
                std::cmp::Ordering::Equal => lhs.name.cmp(&rhs.name),
                unequal => unequal,
            });
            program.with_rewritten_buffers(ordered)
        };

        PassResult {
            program,
            changed: !already_canonical,
        }
    }

    /// Returns the fingerprint of `program` as computed by `fingerprint_program`.
    #[must_use]
    pub fn fingerprint(program: &Program) -> u64 {
        fingerprint_program(program)
    }
}
/// Returns `true` when every adjacent pair of buffer declarations in `program`
/// is ordered by ascending `binding`, with ties on `binding` ordered by
/// non-descending `name`.
///
/// Programs with fewer than two buffers are trivially in order: `windows(2)`
/// produces no pairs for them, so `all` is vacuously `true`.
fn buffers_in_canonical_order(program: &Program) -> bool {
    program.buffers().windows(2).all(|pair| {
        let (prev, next) = (&pair[0], &pair[1]);
        prev.binding < next.binding
            || (prev.binding == next.binding && prev.name <= next.name)
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};

    /// Builds a `DataType::U32` storage declaration with the given access mode
    /// and `with_count(4)` applied.
    fn storage_buf(name: &str, binding: u32, access: BufferAccess) -> BufferDecl {
        BufferDecl::storage(name, binding, access, DataType::U32).with_count(4)
    }

    /// Read-write fixture buffer.
    fn buf(name: &str, binding: u32) -> BufferDecl {
        storage_buf(name, binding, BufferAccess::ReadWrite)
    }

    /// Read-only fixture buffer.
    fn ro_buf(name: &str, binding: u32) -> BufferDecl {
        storage_buf(name, binding, BufferAccess::ReadOnly)
    }

    /// Entry body shared by the fixtures: a single store node.
    fn entry() -> Vec<Node> {
        let store = Node::store("a", Expr::u32(0), Expr::u32(7));
        vec![store]
    }

    /// Wraps `buffers` in a program with a `[1, 1, 1]` workgroup and the shared entry body.
    fn program_of(buffers: Vec<BufferDecl>) -> Program {
        Program::wrapped(buffers, [1, 1, 1], entry())
    }

    /// Buffer names of `program`, in declaration order.
    fn names_of(program: &Program) -> Vec<&str> {
        program
            .buffers()
            .iter()
            .map(|decl| decl.name.as_ref())
            .collect()
    }

    /// Buffer bindings of `program`, in declaration order.
    fn bindings_of(program: &Program) -> Vec<u32> {
        program.buffers().iter().map(|decl| decl.binding).collect()
    }

    #[test]
    fn skip_analysis_on_already_sorted() {
        let program = program_of(vec![buf("a", 0), buf("b", 1)]);
        assert_eq!(BufferDeclSortPass::analyze(&program), PassAnalysis::SKIP);
    }

    #[test]
    fn run_analysis_on_unsorted() {
        let program = program_of(vec![buf("a", 1), buf("b", 0)]);
        assert_eq!(BufferDeclSortPass::analyze(&program), PassAnalysis::RUN);
    }

    #[test]
    fn transform_sorts_simple_two_buffer_swap() {
        let result = BufferDeclSortPass::transform(program_of(vec![
            buf("late", 5),
            buf("early", 0),
        ]));
        assert!(result.changed);
        assert_eq!(bindings_of(&result.program), vec![0, 5]);
        assert_eq!(names_of(&result.program), vec!["early", "late"]);
    }

    #[test]
    fn transform_preserves_already_sorted_program_unchanged() {
        let result = BufferDeclSortPass::transform(program_of(vec![
            buf("a", 0),
            buf("b", 3),
            buf("c", 7),
        ]));
        assert!(
            !result.changed,
            "already-sorted Program must not report changed"
        );
    }

    #[test]
    fn transform_uses_name_tiebreaker_when_bindings_collide() {
        let result = BufferDeclSortPass::transform(program_of(vec![
            buf("beta", 3),
            buf("alpha", 3),
        ]));
        assert!(result.changed);
        assert_eq!(names_of(&result.program), vec!["alpha", "beta"]);
    }

    #[test]
    fn transform_preserves_per_buffer_metadata() {
        let read_write = buf("rw", 5);
        let read_only = ro_buf("ro", 0);
        let result = BufferDeclSortPass::transform(program_of(vec![
            read_write.clone(),
            read_only.clone(),
        ]));
        let sorted = result.program.buffers();
        let (first, second) = (&sorted[0], &sorted[1]);
        assert_eq!(first.name.as_ref(), "ro");
        assert_eq!(first.access, BufferAccess::ReadOnly);
        assert_eq!(second.name.as_ref(), "rw");
        assert_eq!(second.access, BufferAccess::ReadWrite);
    }

    #[test]
    fn transform_preserves_entry_body_unchanged() {
        let original_entry = entry();
        let unsorted = vec![buf("late", 5), buf("early", 0)];
        let program = Program::wrapped(unsorted, [1, 1, 1], original_entry.clone());
        let result = BufferDeclSortPass::transform(program);
        assert_eq!(
            result.program.entry().len(),
            original_entry.len(),
            "entry body length must be preserved"
        );
    }

    #[test]
    fn transform_handles_empty_buffer_table() {
        let program = Program::wrapped(Vec::new(), [1, 1, 1], Vec::new());
        let result = BufferDeclSortPass::transform(program);
        assert!(!result.changed);
        assert!(result.program.buffers().is_empty());
    }

    #[test]
    fn transform_handles_single_buffer_no_op() {
        let result = BufferDeclSortPass::transform(program_of(vec![buf("only", 0)]));
        assert!(!result.changed);
    }

    #[test]
    fn transform_sorts_many_scrambled_bindings() {
        let scrambled = [7, 2, 5, 0, 3, 9, 1, 8, 4, 6];
        let buffers: Vec<BufferDecl> = scrambled
            .iter()
            .copied()
            .map(|binding| buf(&format!("buf_{binding}"), binding))
            .collect();
        let result = BufferDeclSortPass::transform(program_of(buffers));
        assert!(result.changed);
        assert_eq!(
            bindings_of(&result.program),
            vec![0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
        );
    }

    #[test]
    fn transform_is_idempotent() {
        let once = BufferDeclSortPass::transform(program_of(vec![
            buf("c", 5),
            buf("a", 1),
            buf("b", 3),
        ]));
        let twice = BufferDeclSortPass::transform(once.program.clone());
        assert!(once.changed);
        assert!(!twice.changed, "second run must report no change");
        assert_eq!(names_of(&once.program), names_of(&twice.program));
    }

    #[test]
    fn fingerprint_returns_stable_value() {
        let program = program_of(vec![buf("a", 0)]);
        let first = BufferDeclSortPass::fingerprint(&program);
        let second = BufferDeclSortPass::fingerprint(&program);
        assert_eq!(first, second);
    }

    #[test]
    fn already_sorted_with_tied_bindings_is_skipped() {
        let program = program_of(vec![buf("alpha", 3), buf("beta", 3)]);
        assert_eq!(BufferDeclSortPass::analyze(&program), PassAnalysis::SKIP);
    }

    #[test]
    fn unsorted_with_tied_bindings_runs() {
        let program = program_of(vec![buf("beta", 3), buf("alpha", 3)]);
        assert_eq!(BufferDeclSortPass::analyze(&program), PassAnalysis::RUN);
    }
}