use crate::ir::{Expr, Node, Program};
use crate::optimizer::AdapterCaps;
/// Static, per-dimension cost summary of a program.
///
/// `CostCertificate::for_program` fills every field except `divergence_score`
/// from `Program::stats()`; `divergence_score` is computed in this module.
/// Two certificates are compared dimension by dimension via
/// `dominates_or_equal` and `dimensions_increased_over`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
// Non-exhaustive: code outside this crate cannot build the struct with a
// literal or match it exhaustively, so dimensions can be added later.
#[non_exhaustive]
pub struct CostCertificate {
/// `node_count` as reported by `Program::stats()`.
pub node_count: usize,
/// `instruction_count` as reported by `Program::stats()`.
pub instruction_count: u64,
/// `memory_op_count` as reported by `Program::stats()`.
pub memory_op_count: u64,
/// `atomic_op_count` as reported by `Program::stats()`.
pub atomic_op_count: u64,
/// `control_flow_count` as reported by `Program::stats()`.
pub control_flow_count: u64,
/// `register_pressure_estimate` as reported by `Program::stats()`.
pub register_pressure_estimate: u32,
/// `static_storage_bytes` as reported by `Program::stats()`.
pub static_storage_bytes: u64,
/// Number of `If` nodes whose condition is an `==`/`!=` comparison between
/// an invocation-id-like expression and a `u32` literal.
pub divergence_score: u64,
}
/// A `CostCertificate` projected onto one adapter, as produced by
/// `CostCertificate::estimate_for_adapter`.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct DeviceCostEstimate {
/// The adapter-independent certificate this estimate was derived from.
pub base: CostCertificate,
/// Vector pack width chosen by `SchedulingPolicy::select_vector_pack_bits`.
pub vector_pack_bits: u32,
/// Unroll depth chosen by `SchedulingPolicy::select_unroll_depth`.
pub unroll_depth: u32,
/// Workgroup tile chosen by `SchedulingPolicy::select_workgroup_tile`.
pub workgroup_tile: [u32; 3],
/// Saturating weighted sum of the certificate's dimensions after scaling by
/// the selections above; the unit tests treat a lower score as a lower
/// projected device cost.
pub score: u64,
}
impl CostCertificate {
/// Builds the static cost certificate for `program`.
///
/// Every dimension except `divergence_score` is copied from
/// `Program::stats()`. `divergence_score` is accumulated by running
/// `count_divergent_patterns` over each node yielded by `program.entry()`.
#[must_use]
pub fn for_program(program: &Program) -> Self {
    let summary = program.stats();

    // Walk every entry node, accumulating into a single shared counter.
    let mut divergence_score = 0u64;
    program
        .entry()
        .iter()
        .for_each(|root| count_divergent_patterns(root, &mut divergence_score));

    Self {
        node_count: summary.node_count,
        instruction_count: summary.instruction_count,
        memory_op_count: summary.memory_op_count,
        atomic_op_count: summary.atomic_op_count,
        control_flow_count: summary.control_flow_count,
        register_pressure_estimate: summary.register_pressure_estimate,
        static_storage_bytes: summary.static_storage_bytes,
        divergence_score,
    }
}
/// Projects this certificate onto the adapter described by `caps`.
///
/// The standard `SchedulingPolicy` picks a vector pack width, an unroll depth
/// and a workgroup tile for the adapter. The score is then the saturating sum
/// of:
///
/// * `memory_op_count * 1024 / max(vector_pack_bits / 32, 1)`
/// * `instruction_count * 1024 / max(unroll_depth, 1)`
/// * `register_pressure_estimate * 1024 / min(max(tile lanes, 1), 1024)`
/// * `atomic_op_count * 2048`
/// * `divergence_score * 4096`
///
/// All arithmetic saturates, so the result never wraps.
#[must_use]
pub fn estimate_for_adapter(&self, caps: &AdapterCaps) -> DeviceCostEstimate {
    // Weights applied to each dimension before the terms are summed.
    const BASE_WEIGHT: u64 = 1024;
    const ATOMIC_WEIGHT: u64 = 2048;
    const DIVERGENCE_WEIGHT: u64 = 4096;
    // Upper bound on the tile-lane divisor of the occupancy term.
    const MAX_TILE_LANES: u64 = 1024;

    let scheduler = crate::execution_plan::SchedulingPolicy::standard();
    let vector_pack_bits = scheduler.select_vector_pack_bits(32, caps);
    let unroll_depth = scheduler.select_unroll_depth(None, caps);
    let workgroup_tile = scheduler.select_workgroup_tile([1, 1, 1], None, caps);

    // Every divisor is clamped to at least 1 so the divisions below are safe.
    let lanes_per_pack = u64::from((vector_pack_bits / 32).max(1));
    let unroll_factor = u64::from(unroll_depth.max(1));
    let lanes_per_tile = u64::from(
        workgroup_tile
            .iter()
            .fold(1u32, |lanes, &extent| lanes.saturating_mul(extent))
            .max(1),
    );

    let memory_cost = self.memory_op_count.saturating_mul(BASE_WEIGHT) / lanes_per_pack;
    let instruction_cost = self.instruction_count.saturating_mul(BASE_WEIGHT) / unroll_factor;
    let occupancy_cost = u64::from(self.register_pressure_estimate).saturating_mul(BASE_WEIGHT)
        / lanes_per_tile.min(MAX_TILE_LANES);
    let atomic_cost = self.atomic_op_count.saturating_mul(ATOMIC_WEIGHT);
    let divergence_cost = self.divergence_score.saturating_mul(DIVERGENCE_WEIGHT);

    let weighted_terms = [
        memory_cost,
        instruction_cost,
        occupancy_cost,
        atomic_cost,
        divergence_cost,
    ];
    let score = weighted_terms
        .iter()
        .fold(0u64, |total, &term| total.saturating_add(term));

    DeviceCostEstimate {
        base: *self,
        vector_pack_bits,
        unroll_depth,
        workgroup_tile,
        score,
    }
}
/// Convenience wrapper: certifies `program` with `for_program`, then projects
/// the result onto `caps` with `estimate_for_adapter`.
#[must_use]
pub fn for_program_on_adapter(program: &Program, caps: &AdapterCaps) -> DeviceCostEstimate {
    let certificate = Self::for_program(program);
    certificate.estimate_for_adapter(caps)
}
/// Returns `true` when no dimension of `self` exceeds the matching dimension
/// of `other`, i.e. `self` is at least as cheap as `other` on every axis.
#[must_use]
pub fn dominates_or_equal(&self, other: &Self) -> bool {
    // One entry per dimension; each comparison is side-effect free.
    let per_dimension = [
        self.node_count <= other.node_count,
        self.instruction_count <= other.instruction_count,
        self.memory_op_count <= other.memory_op_count,
        self.atomic_op_count <= other.atomic_op_count,
        self.control_flow_count <= other.control_flow_count,
        self.register_pressure_estimate <= other.register_pressure_estimate,
        self.static_storage_bytes <= other.static_storage_bytes,
        self.divergence_score <= other.divergence_score,
    ];
    per_dimension.iter().all(|&not_worse| not_worse)
}
/// Lists the field names of every dimension where `self` is strictly greater
/// than `other`.
///
/// Names are returned in field declaration order; the list is empty when
/// `self.dominates_or_equal(other)` holds.
#[must_use]
pub fn dimensions_increased_over(&self, other: &Self) -> Vec<&'static str> {
    // (field name, did it grow?) in declaration order.
    let comparisons: [(&'static str, bool); 8] = [
        ("node_count", self.node_count > other.node_count),
        (
            "instruction_count",
            self.instruction_count > other.instruction_count,
        ),
        (
            "memory_op_count",
            self.memory_op_count > other.memory_op_count,
        ),
        (
            "atomic_op_count",
            self.atomic_op_count > other.atomic_op_count,
        ),
        (
            "control_flow_count",
            self.control_flow_count > other.control_flow_count,
        ),
        (
            "register_pressure_estimate",
            self.register_pressure_estimate > other.register_pressure_estimate,
        ),
        (
            "static_storage_bytes",
            self.static_storage_bytes > other.static_storage_bytes,
        ),
        (
            "divergence_score",
            self.divergence_score > other.divergence_score,
        ),
    ];

    let mut increased = Vec::with_capacity(comparisons.len());
    increased.extend(
        comparisons
            .iter()
            .filter(|&&(_, grew)| grew)
            .map(|&(name, _)| name),
    );
    increased
}
}
/// Adds one to `score` (saturating) for every `Node::If` reached from `node`
/// whose condition satisfies `is_invocation_id_eq_constant`.
///
/// Traversal is delegated to `visit::node_map::any_descendant`. The callback
/// always answers `false`, which presumably keeps the walk going instead of
/// stopping at the first match (confirm against `visit::node_map`), and the
/// walker's own result is discarded.
fn count_divergent_patterns(node: &Node, score: &mut u64) {
    let _ = crate::visit::node_map::any_descendant(node, &mut |visited| {
        match visited {
            Node::If { cond, .. } if is_invocation_id_eq_constant(cond) => {
                *score = score.saturating_add(1);
            }
            _ => {}
        }
        false
    });
}
/// Returns `true` when `cond` is an `==` or `!=` comparison with an
/// invocation-id-like expression on one side and a `u32` literal on the
/// other, in either operand order. Every other expression yields `false`.
fn is_invocation_id_eq_constant(cond: &Expr) -> bool {
    use crate::ir::BinOp;

    if let Expr::BinOp {
        op: BinOp::Eq | BinOp::Ne,
        left,
        right,
    } = cond
    {
        let id_on_left = is_invocation_id_expr(left) && matches!(**right, Expr::LitU32(_));
        let id_on_right = is_invocation_id_expr(right) && matches!(**left, Expr::LitU32(_));
        id_on_left || id_on_right
    } else {
        false
    }
}
/// Returns `true` when `expr` is one of the id expressions treated as
/// invocation identifiers here: `SubgroupLocalId`, `LocalId` or `InvocationId`.
fn is_invocation_id_expr(expr: &Expr) -> bool {
    matches!(
        expr,
        Expr::SubgroupLocalId | Expr::LocalId { .. } | Expr::InvocationId { .. }
    )
}
#[cfg(test)]
mod tests {
use super::*;
use crate::ir::{BinOp, BufferAccess, BufferDecl, DataType, Expr, Node, Program};
/// Smallest fixture used by the tests: one four-element read-write `u32`
/// storage buffer and a single store of `7` to index `0`, with no branches.
fn trivial_program() -> Program {
    let buffers = vec![
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4),
    ];
    let body = vec![Node::store("buf", Expr::u32(0), Expr::u32(7))];
    Program::wrapped(buffers, [1, 1, 1], body)
}
/// A branch-free program has no divergence but still records its store.
#[test]
fn for_program_returns_zero_divergence_on_pure_program() {
    let program = trivial_program();
    let certificate = CostCertificate::for_program(&program);
    assert_eq!(certificate.divergence_score, 0);
    assert!(
        certificate.memory_op_count >= 1,
        "trivial program has one store"
    );
}
/// Comparing a certificate with itself must dominate and report no increases.
#[test]
fn dominates_or_equal_is_reflexive() {
    let program = trivial_program();
    let certificate = CostCertificate::for_program(&program);
    assert!(certificate.dominates_or_equal(&certificate));
    let increased = certificate.dimensions_increased_over(&certificate);
    assert!(increased.is_empty());
}
/// A single larger dimension breaks dominance in one direction only, and
/// every dimension that grew is reported by name.
#[test]
fn dominates_or_equal_detects_per_dimension_increase() {
    let mut baseline = CostCertificate::default();
    let mut heavier = CostCertificate {
        atomic_op_count: 1,
        ..Default::default()
    };

    // Only `atomic_op_count` differs at this point.
    assert!(baseline.dominates_or_equal(&heavier));
    assert!(!heavier.dominates_or_equal(&baseline));
    let first_report = heavier.dimensions_increased_over(&baseline);
    assert_eq!(first_report, vec!["atomic_op_count"]);

    // Grow two more dimensions and check all three are reported.
    baseline.node_count = 0;
    heavier.node_count = 5;
    heavier.divergence_score = 2;
    let second_report = heavier.dimensions_increased_over(&baseline);
    assert!(second_report.contains(&"node_count"));
    assert!(second_report.contains(&"atomic_op_count"));
    assert!(second_report.contains(&"divergence_score"));
}
/// One `if gid_x == 0 { store }` branch must yield a divergence score of 1.
#[test]
fn divergence_score_counts_invocation_id_eq_constant() {
    let buffers = vec![
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4),
    ];
    let gid_is_zero = Expr::BinOp {
        op: BinOp::Eq,
        left: Box::new(Expr::gid_x()),
        right: Box::new(Expr::u32(0)),
    };
    let guarded_store = Node::if_then(
        gid_is_zero,
        vec![Node::store("buf", Expr::u32(0), Expr::u32(1))],
    );
    let program = Program::wrapped(buffers, [256, 1, 1], vec![guarded_store]);

    let certificate = CostCertificate::for_program(&program);
    assert_eq!(
        certificate.divergence_score, 1,
        "divergence walker must count an `if invocation_id == K {{ ... }}` pattern exactly once"
    );
}
/// A branch on `load(buf[0]) < 5` is control flow but not a counted
/// divergence pattern.
#[test]
fn divergence_score_ignores_non_thread_id_comparisons() {
    let buffers = vec![
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4),
    ];
    let loaded_below_five = Expr::BinOp {
        op: BinOp::Lt,
        left: Box::new(Expr::load("buf", Expr::u32(0))),
        right: Box::new(Expr::u32(5)),
    };
    let data_dependent_branch = Node::if_then(
        loaded_below_five,
        vec![Node::store("buf", Expr::u32(0), Expr::u32(1))],
    );
    let program = Program::wrapped(buffers, [256, 1, 1], vec![data_dependent_branch]);

    let certificate = CostCertificate::for_program(&program);
    assert_eq!(
        certificate.divergence_score, 0,
        "divergence walker must NOT count branches whose condition isn't a thread-id-vs-constant"
    );
    assert!(
        certificate.control_flow_count >= 1,
        "branches still count toward control_flow_count regardless of divergence shape"
    );
}
/// Two nested `if gid_x == K` branches must both be counted.
#[test]
fn divergence_score_recurses_into_nested_regions() {
    // Builds `gid_x == k` for the two nesting levels below.
    let gid_equals = |k: u32| Expr::BinOp {
        op: BinOp::Eq,
        left: Box::new(Expr::gid_x()),
        right: Box::new(Expr::u32(k)),
    };

    let inner_branch = Node::if_then(
        gid_equals(1),
        vec![Node::store("buf", Expr::u32(1), Expr::u32(7))],
    );
    let outer_branch = Node::if_then(gid_equals(0), vec![inner_branch]);

    let buffers = vec![
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4),
    ];
    let program = Program::wrapped(buffers, [256, 1, 1], vec![outer_branch]);

    let certificate = CostCertificate::for_program(&program);
    assert_eq!(
        certificate.divergence_score, 2,
        "nested divergence patterns must be counted at every depth"
    );
}
/// The same program projected onto a wider adapter profile must report the
/// wider selections and a strictly lower score.
#[test]
fn device_profile_fields_change_cost_projection() {
    let buffers = vec![
        BufferDecl::storage("buf", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4096),
    ];
    let body = vec![
        Node::let_bind("x", Expr::load("buf", Expr::gid_x())),
        Node::store("buf", Expr::gid_x(), Expr::var("x")),
    ];
    let program = Program::wrapped(buffers, [1, 1, 1], body);

    let compact = AdapterCaps {
        max_workgroup_size: [256, 256, 64],
        max_invocations_per_workgroup: 256,
        ideal_unroll_depth: 4,
        ideal_vector_pack_bits: 64,
        ideal_workgroup_tile: [8, 8, 1],
        ..AdapterCaps::conservative()
    };
    // Same limits as `compact`, with every "ideal" hint doubled.
    let wide = AdapterCaps {
        ideal_unroll_depth: 8,
        ideal_vector_pack_bits: 128,
        ideal_workgroup_tile: [16, 16, 1],
        ..compact
    };

    let on_compact = CostCertificate::for_program_on_adapter(&program, &compact);
    let on_wide = CostCertificate::for_program_on_adapter(&program, &wide);

    assert_eq!(on_compact.vector_pack_bits, 64);
    assert_eq!(on_wide.vector_pack_bits, 128);
    assert_eq!(on_compact.unroll_depth, 4);
    assert_eq!(on_wide.unroll_depth, 8);
    assert_eq!(on_compact.workgroup_tile, [8, 8, 1]);
    assert_eq!(on_wide.workgroup_tile, [16, 16, 1]);
    assert!(
        on_wide.score < on_compact.score,
        "Fix: wider profile vector/unroll/tile facts must lower the projected device cost"
    );
}
// Differential test: a hand-written stack-based walker (kept here as the
// reference) and the shared `visit::node_map::any_descendant` walker must
// agree on both the divergence score and the set of nodes they visit.
#[test]
fn walker_matches_canonical_on_corpus() {
// Reference walker: explicit stack, records a clone of every node it pops.
fn count_divergent_patterns_old(node: &Node, score: &mut u64, visited: &mut Vec<Node>) {
let mut stack: smallvec::SmallVec<[&Node; 64]> = smallvec::SmallVec::new();
stack.push(node);
while let Some(node) = stack.pop() {
visited.push(node.clone());
match node {
Node::If {
cond,
then,
otherwise,
} => {
if super::is_invocation_id_eq_constant(cond) {
*score = score.saturating_add(1);
}
// Both branches are queued so nested patterns are found.
stack.extend(otherwise.iter());
stack.extend(then.iter());
}
Node::Loop { body, .. } | Node::Block(body) => {
stack.extend(body.iter());
}
Node::Region { body, .. } => stack.extend(body.iter()),
// Every other node kind is treated as a leaf by this walker.
_ => {}
}
}
}
// Corpus: an outer `if gid_x == 0` containing an inner `if gid_x == 1`
// and a block holding a `Return`.
let inner = Node::if_then(
Expr::BinOp {
op: BinOp::Eq,
left: Box::new(Expr::gid_x()),
right: Box::new(Expr::u32(1)),
},
vec![Node::store("buf", Expr::u32(1), Expr::u32(7))],
);
let outer = Node::if_then(
Expr::BinOp {
op: BinOp::Eq,
left: Box::new(Expr::gid_x()),
right: Box::new(Expr::u32(0)),
},
vec![inner, Node::Block(vec![Node::Return])],
);
let mut score_old = 0;
let mut visited_old = Vec::new();
count_divergent_patterns_old(&outer, &mut score_old, &mut visited_old);
let mut score_new = 0;
let mut visited_new = Vec::new();
// Same counting logic as the production walker, plus visit recording.
let _ = crate::visit::node_map::any_descendant(&outer, &mut |n| {
visited_new.push(n.clone());
if let Node::If { cond, .. } = n {
if super::is_invocation_id_eq_constant(cond) {
score_new += 1;
}
}
false
});
assert_eq!(score_old, score_new, "Divergence score mismatch");
assert_eq!(
visited_old.len(),
visited_new.len(),
"Node set length mismatch"
);
// Visit order is allowed to differ, so membership is checked in both
// directions rather than comparing the two lists directly.
for node in &visited_old {
assert!(
visited_new.contains(node),
"Old walker visited a node that the new canonical walker missed"
);
}
for node in &visited_new {
assert!(
visited_old.contains(node),
"New canonical walker visited a node that the old walker missed"
);
}
}
}