#![allow(dead_code)]
use proptest::collection::vec as prop_vec;
use proptest::prelude::*;
use std::sync::Arc;
use vyre_foundation::ir::model::program::ProgramStats;
use vyre_foundation::ir::{
AtomicOp, BinOp, BufferDecl, DataType, Expr, ExprNode, Node, NodeExtension, Program, UnOp,
};
use vyre_foundation::MemoryOrdering;
use vyre_spec::data_type::TypeId;
use vyre_spec::extension::{
ExtensionAtomicOpId, ExtensionBinOpId, ExtensionDataTypeId, ExtensionUnOpId,
};
// Capability flag bits. The manual oracle in this file ORs these into a `u32`
// that the property test compares against `ProgramStats::capability_bits`.
// NOTE(review): presumably these mirror the bit layout used by the production
// stats computation — confirm there before adding or moving a bit.
// Set for subgroup expressions and for call ids accepted by `is_subgroup_intrinsic_id`.
const CAP_SUBGROUP_OPS: u32 = 1 << 0;
// Set when `DataType::F16` is seen by `mark_datatype_bits`.
const CAP_F16: u32 = 1 << 1;
// Set when `DataType::BF16` is seen by `mark_datatype_bits`.
const CAP_BF16: u32 = 1 << 2;
// Set when `DataType::F64` is seen by `mark_datatype_bits`.
const CAP_F64: u32 = 1 << 3;
// Set for `AsyncLoad`, `AsyncStore` and `AsyncWait` nodes.
const CAP_ASYNC_DISPATCH: u32 = 1 << 4;
// Set for `IndirectDispatch` nodes.
const CAP_INDIRECT_DISPATCH: u32 = 1 << 5;
// Set when `DataType::Tensor` or `DataType::TensorShaped` is seen by `mark_datatype_bits`.
const CAP_TENSOR_OPS: u32 = 1 << 6;
// Set for `Trap` nodes.
const CAP_TRAP: u32 = 1 << 7;
/// Stateless opaque expression used to place `Expr::Opaque` leaves into
/// generated expression trees.
///
/// Every hook returns a fixed value, so all instances are interchangeable.
#[derive(Debug)]
struct TestOpaqueExpr;

impl TestOpaqueExpr {
    /// Extension kind reported by every instance.
    const KIND: &'static str = "test.stats.expr";
    /// Debug identity reported by every instance.
    const IDENTITY: &'static str = "test-expr";
    /// All-zero fingerprint shared by every instance.
    const FINGERPRINT: [u8; 32] = [0u8; 32];
}

impl ExprNode for TestOpaqueExpr {
    fn extension_kind(&self) -> &'static str {
        Self::KIND
    }

    fn debug_identity(&self) -> &str {
        Self::IDENTITY
    }

    /// Always reports a `U32` result.
    fn result_type(&self) -> Option<DataType> {
        Some(DataType::U32)
    }

    /// Always reports the expression as CSE-safe.
    fn cse_safe(&self) -> bool {
        true
    }

    fn stable_fingerprint(&self) -> [u8; 32] {
        Self::FINGERPRINT
    }

    /// Validation never fails for this test double.
    fn validate_extension(&self) -> Result<(), String> {
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
/// Stateless opaque statement used to place `Node::Opaque` entries into
/// generated programs.
///
/// Every hook returns a fixed value, so all instances are interchangeable.
#[derive(Debug)]
struct TestOpaqueNode;

impl TestOpaqueNode {
    /// Extension kind reported by every instance.
    const KIND: &'static str = "test.stats.node";
    /// Debug identity reported by every instance.
    const IDENTITY: &'static str = "test-node";
    /// All-zero fingerprint shared by every instance.
    const FINGERPRINT: [u8; 32] = [0u8; 32];
}

impl NodeExtension for TestOpaqueNode {
    fn extension_kind(&self) -> &'static str {
        Self::KIND
    }

    fn debug_identity(&self) -> &str {
        Self::IDENTITY
    }

    fn stable_fingerprint(&self) -> [u8; 32] {
        Self::FINGERPRINT
    }

    /// Validation never fails for this test double.
    fn validate_extension(&self) -> Result<(), String> {
        Ok(())
    }

    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
}
// Variable names sampled by `arb_ident`: the empty string, plain ASCII names,
// a name containing a non-ASCII character, and a name with an embedded NUL.
const VAR_NAMES: &[&str] = &["", "x", "alpha", "snow_雪", "nul\0name"];
// Call ids sampled by `arb_call_id`. The last three entries contain markers
// recognised by `is_subgroup_intrinsic_id` ("subgroup_", "::wave::", "warp_");
// the first four contain none of its markers.
const CALL_IDS: &[&str] = &[
"",
"call",
"筛选",
"op::雪",
"subgroup_reduce",
"my::wave::add",
"warp_shuffle",
];
// Tags sampled by `arb_tag`: empty, plain ASCII, non-ASCII, and embedded NUL.
const TAG_NAMES: &[&str] = &["", "tag", "stream-雪", "wait\0tag"];
// Buffer names referenced by generated `BufLen` and `Load` expressions. Each
// one is declared by `arb_program`; its `extra_a` / `extra_b` buffers are not
// listed here.
const BUFFER_NAMES: &[&str] = &[
"out",
"input",
"rw",
"bytes_in",
"bytes_out",
"counts",
"scratch",
];
/// Strategy yielding an owned copy of one entry of `VAR_NAMES`.
fn arb_ident() -> BoxedStrategy<String> {
    // `select` accepts the static slice directly, so no intermediate `Vec` is needed.
    let names = prop::sample::select(VAR_NAMES);
    names.prop_map(String::from).boxed()
}
/// Strategy yielding an owned copy of one entry of `CALL_IDS`.
fn arb_call_id() -> BoxedStrategy<String> {
    // `select` accepts the static slice directly, so no intermediate `Vec` is needed.
    let ids = prop::sample::select(CALL_IDS);
    ids.prop_map(String::from).boxed()
}
/// Strategy yielding an owned copy of one entry of `TAG_NAMES`.
fn arb_tag() -> BoxedStrategy<String> {
    // `select` accepts the static slice directly, so no intermediate `Vec` is needed.
    let tags = prop::sample::select(TAG_NAMES);
    tags.prop_map(String::from).boxed()
}
/// Strategy over the axis values 0, 1, 2 and 255.
///
/// NOTE(review): 255 is presumably an out-of-range axis included to exercise
/// robustness — confirm against the IR's axis rules.
fn arb_axis() -> BoxedStrategy<u8> {
    let axes = prop_oneof![Just(0u8), Just(1u8), Just(2u8), Just(u8::MAX)];
    axes.boxed()
}
fn arb_datatype() -> BoxedStrategy<DataType> {
let leaf = prop_oneof![
Just(DataType::U8),
Just(DataType::U16),
Just(DataType::U32),
Just(DataType::I8),
Just(DataType::I16),
Just(DataType::I32),
Just(DataType::I64),
Just(DataType::U64),
Just(DataType::Vec2U32),
Just(DataType::Vec4U32),
Just(DataType::Bool),
Just(DataType::Bytes),
(0usize..=64).prop_map(|element_size| DataType::Array { element_size }),
Just(DataType::F16),
Just(DataType::BF16),
Just(DataType::F32),
Just(DataType::F64),
Just(DataType::Tensor),
any::<u32>().prop_map(|id| DataType::Handle(TypeId(id))),
any::<u32>().prop_map(|id| DataType::Opaque(ExtensionDataTypeId(id | 0x8000_0000))),
];
leaf.prop_recursive(3, 24, 3, |inner| {
prop_oneof![
(inner.clone(), 0u8..=4).prop_map(|(element, count)| DataType::Vec {
element: Box::new(element),
count,
}),
(inner.clone(), prop_vec(any::<u32>(), 0..=4)).prop_map(|(element, shape)| {
DataType::TensorShaped {
element: Box::new(element),
shape: shape.into_iter().collect(),
}
}),
]
})
.boxed()
}
fn arb_buffer_datatype() -> BoxedStrategy<DataType> {
prop_oneof![
Just(DataType::U8),
Just(DataType::U16),
Just(DataType::U32),
Just(DataType::I8),
Just(DataType::I16),
Just(DataType::I32),
Just(DataType::I64),
Just(DataType::U64),
Just(DataType::Vec2U32),
Just(DataType::Vec4U32),
Just(DataType::Bool),
Just(DataType::Bytes),
(0usize..=64).prop_map(|element_size| DataType::Array { element_size }),
Just(DataType::F16),
Just(DataType::BF16),
Just(DataType::F32),
Just(DataType::F64),
Just(DataType::Tensor),
]
.boxed()
}
fn arb_literal() -> BoxedStrategy<Expr> {
let adversarial_f32_bits = prop_oneof![
Just(0x0000_0000u32),
Just(0x0000_0001u32),
Just(0x007f_ffffu32),
Just(f32::MIN_POSITIVE.to_bits()),
Just(f32::MIN.to_bits()),
Just(f32::MAX.to_bits()),
any::<u32>().prop_filter("exclude NaN and -0.0", |bits| !f32::from_bits(*bits)
.is_nan()
&& *bits != (-0.0f32).to_bits()),
];
prop_oneof![
any::<u32>().prop_map(Expr::LitU32),
any::<i32>().prop_map(Expr::LitI32),
any::<bool>().prop_map(Expr::LitBool),
adversarial_f32_bits.prop_map(|bits| Expr::LitF32(f32::from_bits(bits))),
]
.boxed()
}
fn arb_expr() -> BoxedStrategy<Expr> {
let leaf = prop_oneof![
arb_literal(),
arb_ident().prop_map(Expr::var),
prop::sample::select(BUFFER_NAMES.to_vec()).prop_map(Expr::buf_len),
arb_axis().prop_map(|axis| Expr::InvocationId { axis }),
arb_axis().prop_map(|axis| Expr::WorkgroupId { axis }),
arb_axis().prop_map(|axis| Expr::LocalId { axis }),
Just(Expr::Opaque(Arc::new(TestOpaqueExpr))),
];
leaf.prop_recursive(4, 128, 4, |inner| {
prop_oneof![
(prop::sample::select(BUFFER_NAMES.to_vec()), inner.clone()).prop_map(
|(buffer, index)| Expr::Load {
buffer: buffer.into(),
index: Box::new(index),
}
),
(
prop_oneof![
Just(BinOp::Add),
Just(BinOp::Sub),
Just(BinOp::Mul),
Just(BinOp::Div),
Just(BinOp::Mod),
Just(BinOp::BitAnd),
Just(BinOp::BitOr),
Just(BinOp::BitXor),
Just(BinOp::Shl),
Just(BinOp::Shr),
Just(BinOp::Eq),
Just(BinOp::Ne),
Just(BinOp::Lt),
Just(BinOp::Gt),
Just(BinOp::Le),
Just(BinOp::Ge),
Just(BinOp::And),
Just(BinOp::Or),
Just(BinOp::AbsDiff),
Just(BinOp::Min),
Just(BinOp::Max),
Just(BinOp::SaturatingAdd),
Just(BinOp::SaturatingSub),
Just(BinOp::SaturatingMul),
Just(BinOp::Shuffle),
Just(BinOp::Ballot),
Just(BinOp::WaveReduce),
Just(BinOp::WaveBroadcast),
Just(BinOp::RotateLeft),
Just(BinOp::RotateRight),
any::<u32>().prop_map(|id| BinOp::Opaque(ExtensionBinOpId(id | 0x8000_0000))),
],
inner.clone(),
inner.clone(),
)
.prop_map(|(op, left, right)| Expr::BinOp {
op,
left: Box::new(left),
right: Box::new(right),
}),
(
prop_oneof![
Just(UnOp::Negate),
Just(UnOp::BitNot),
Just(UnOp::LogicalNot),
Just(UnOp::Popcount),
Just(UnOp::Clz),
Just(UnOp::Ctz),
Just(UnOp::ReverseBits),
Just(UnOp::Cos),
Just(UnOp::Sin),
Just(UnOp::Abs),
Just(UnOp::Sqrt),
Just(UnOp::Floor),
Just(UnOp::Ceil),
Just(UnOp::Round),
Just(UnOp::Trunc),
Just(UnOp::Sign),
Just(UnOp::IsNan),
Just(UnOp::IsInf),
Just(UnOp::IsFinite),
Just(UnOp::Exp),
Just(UnOp::Log),
Just(UnOp::Log2),
Just(UnOp::Exp2),
Just(UnOp::Tan),
Just(UnOp::Acos),
Just(UnOp::Asin),
Just(UnOp::Atan),
Just(UnOp::Tanh),
Just(UnOp::Sinh),
Just(UnOp::Cosh),
Just(UnOp::InverseSqrt),
Just(UnOp::Reciprocal),
Just(UnOp::Unpack4Low),
Just(UnOp::Unpack4High),
Just(UnOp::Unpack8Low),
Just(UnOp::Unpack8High),
any::<u32>().prop_map(|id| UnOp::Opaque(ExtensionUnOpId(id | 0x8000_0000))),
],
inner.clone(),
)
.prop_map(|(op, operand)| Expr::UnOp {
op,
operand: Box::new(operand),
}),
(arb_call_id(), prop_vec(inner.clone(), 0..=4)).prop_map(|(op_id, args)| Expr::Call {
op_id: op_id.into(),
args,
}),
(inner.clone(), inner.clone(), inner.clone()).prop_map(
|(cond, true_val, false_val)| Expr::Select {
cond: Box::new(cond),
true_val: Box::new(true_val),
false_val: Box::new(false_val),
}
),
(arb_buffer_datatype(), inner.clone()).prop_map(|(target, value)| Expr::Cast {
target,
value: Box::new(value),
}),
(inner.clone(), inner.clone(), inner.clone()).prop_map(|(a, b, c)| Expr::Fma {
a: Box::new(a),
b: Box::new(b),
c: Box::new(c),
}),
(
prop_oneof![
Just(AtomicOp::Add),
Just(AtomicOp::Or),
Just(AtomicOp::And),
Just(AtomicOp::Xor),
Just(AtomicOp::Min),
Just(AtomicOp::Max),
Just(AtomicOp::Exchange),
Just(AtomicOp::CompareExchange),
Just(AtomicOp::CompareExchangeWeak),
Just(AtomicOp::FetchNand),
Just(AtomicOp::LruUpdate),
any::<u32>()
.prop_map(|id| AtomicOp::Opaque(ExtensionAtomicOpId(id | 0x8000_0000))),
],
prop::sample::select(vec!["rw", "out", "counts", "bytes_out"]),
inner.clone(),
proptest::option::of(inner.clone()),
inner.clone(),
)
.prop_map(|(op, buffer, index, expected, value)| Expr::Atomic {
op,
buffer: buffer.into(),
index: Box::new(index),
expected: expected.map(Box::new),
value: Box::new(value),
ordering: MemoryOrdering::SeqCst,
}),
inner.clone().prop_map(|value| Expr::SubgroupAdd {
value: Box::new(value)
}),
(inner.clone(), inner.clone()).prop_map(|(value, lane)| Expr::SubgroupShuffle {
value: Box::new(value),
lane: Box::new(lane),
}),
inner.prop_map(|cond| Expr::SubgroupBallot {
cond: Box::new(cond)
}),
]
})
.boxed()
}
fn arb_node_with_depth(depth: u32) -> BoxedStrategy<Node> {
let leaf = prop_oneof![
(arb_ident(), arb_expr()).prop_map(|(name, value)| Node::Let {
name: name.into(),
value
}),
(arb_ident(), arb_expr()).prop_map(|(name, value)| Node::Assign {
name: name.into(),
value
}),
(
prop::sample::select(vec!["out", "rw", "bytes_out"]),
arb_expr(),
arb_expr(),
)
.prop_map(|(buffer, index, value)| Node::Store {
buffer: buffer.into(),
index,
value,
}),
Just(Node::Return),
Just(Node::barrier()),
];
if depth == 0 {
return leaf.boxed();
}
let deeper = arb_node_with_depth(depth - 1);
leaf.prop_recursive(3, 64, 3, move |inner| {
prop_oneof![
(
arb_expr(),
prop_vec(inner.clone(), 0..=3),
prop_vec(inner.clone(), 0..=3),
)
.prop_map(|(cond, then, otherwise)| Node::If {
cond,
then,
otherwise
}),
(
arb_ident(),
arb_expr(),
arb_expr(),
prop_vec(inner.clone(), 0..=3),
)
.prop_map(|(var, from, to, body)| Node::Loop {
var: var.into(),
from,
to,
body
}),
prop_vec(inner.clone(), 0..=3).prop_map(Node::Block),
(arb_ident(), prop_vec(deeper.clone(), 0..=3),).prop_map(|(generator, body)| {
Node::Region {
generator: generator.into(),
source_region: None,
body: Arc::new(body),
}
}),
(arb_ident(), arb_ident(), arb_expr(), arb_expr(), arb_tag(),).prop_map(
|(source, destination, offset, size, tag)| Node::AsyncLoad {
source: source.into(),
destination: destination.into(),
offset: Box::new(offset),
size: Box::new(size),
tag: tag.into(),
}
),
(arb_ident(), arb_ident(), arb_expr(), arb_expr(), arb_tag(),).prop_map(
|(source, destination, offset, size, tag)| Node::AsyncStore {
source: source.into(),
destination: destination.into(),
offset: Box::new(offset),
size: Box::new(size),
tag: tag.into(),
}
),
arb_tag().prop_map(|tag| Node::AsyncWait { tag: tag.into() }),
(arb_ident(), any::<u64>()).prop_map(|(count_buffer, count_offset)| {
Node::IndirectDispatch {
count_buffer: count_buffer.into(),
count_offset,
}
}),
(arb_expr(), arb_tag()).prop_map(|(address, tag)| Node::Trap {
address: Box::new(address),
tag: tag.into(),
}),
arb_tag().prop_map(|tag| Node::Resume { tag: tag.into() }),
Just(Node::Opaque(Arc::new(TestOpaqueNode))),
]
})
.boxed()
}
/// Statement strategy with the default region-nesting budget.
fn arb_node() -> BoxedStrategy<Node> {
    /// Depth budget passed to `arb_node_with_depth`.
    const REGION_DEPTH_BUDGET: u32 = 3;
    arb_node_with_depth(REGION_DEPTH_BUDGET)
}
/// Strategy producing whole programs.
///
/// Every program declares the same seven named buffers plus two `extra_*`
/// buffers whose element types are generated, uses a `[1, 1, 1]` workgroup
/// size, and has up to six generated entry statements. Roughly one program
/// in ten is flagged non-composable with itself.
fn arb_program() -> BoxedStrategy<Program> {
    // Weighted 9:1 in favour of `false`.
    let self_exclusive = prop_oneof![9 => Just(false), 1 => Just(true)];
    let inputs = (
        arb_buffer_datatype(),
        arb_buffer_datatype(),
        prop_vec(arb_node(), 0..=6),
        self_exclusive,
    );

    inputs
        .prop_map(|(elem_a, elem_b, body, flag)| {
            let buffers = vec![
                BufferDecl::output("out", 0, DataType::U32)
                    .with_count(8)
                    .with_output_byte_range(0..16),
                BufferDecl::read("input", 1, DataType::U32).with_count(8),
                BufferDecl::read_write("rw", 2, DataType::U32).with_count(8),
                BufferDecl::read("bytes_in", 3, DataType::Bytes).with_count(16),
                BufferDecl::read_write("bytes_out", 4, DataType::Bytes).with_count(16),
                BufferDecl::read("counts", 5, DataType::U32).with_count(8),
                // NOTE(review): the second argument of `workgroup` is presumably an
                // element count rather than a binding slot — confirm.
                BufferDecl::workgroup("scratch", 4, DataType::U32),
                BufferDecl::read("extra_a", 6, elem_a).with_count(1),
                BufferDecl::read("extra_b", 7, elem_b).with_count(1),
            ];
            let program = Program::wrapped(buffers, [1, 1, 1], body);
            program.with_non_composable_with_self(flag)
        })
        .boxed()
}
/// ORs into `bits` the capability flag implied by `ty`, if any.
///
/// Only the outermost variant is inspected: element types nested inside a
/// composite type do not contribute.
#[inline]
fn mark_datatype_bits(ty: &DataType, bits: &mut u32) {
    let flag = match ty {
        DataType::F16 => CAP_F16,
        DataType::BF16 => CAP_BF16,
        DataType::F64 => CAP_F64,
        DataType::Tensor | DataType::TensorShaped { .. } => CAP_TENSOR_OPS,
        // Every other type requires no capability.
        _ => 0,
    };
    *bits |= flag;
}
/// Returns `true` when `op_id` contains any of the subgroup / wave / warp
/// markers as a substring.
///
/// The match is a plain, case-sensitive substring search, so a marker may
/// appear anywhere in the id.
fn is_subgroup_intrinsic_id(op_id: &str) -> bool {
    // "::subgroup::" is already covered by "::subgroup"; the table is kept
    // as-is so it stays comparable with whatever list it mirrors.
    const MARKERS: [&str; 7] = [
        "subgroup_",
        "::subgroup::",
        "::subgroup",
        "wave_",
        "::wave::",
        "warp_",
        "::warp::",
    ];

    for marker in MARKERS.iter().copied() {
        if op_id.contains(marker) {
            return true;
        }
    }
    false
}
/// Reference walker for one expression tree.
///
/// Updates `calls`, `opaque` and `bits` for `expr` and everything nested in
/// it. `nodes` and `regions` are never modified here; they are only threaded
/// through the recursion, which is why the clippy lint is silenced.
#[allow(clippy::only_used_in_recursion)]
fn manual_walk_expr(
    expr: &Expr,
    nodes: &mut usize,
    regions: &mut u32,
    calls: &mut u32,
    opaque: &mut u32,
    bits: &mut u32,
) {
    // Step 1: what this expression contributes on its own.
    match expr {
        Expr::SubgroupAdd { .. }
        | Expr::SubgroupBallot { .. }
        | Expr::SubgroupShuffle { .. } => {
            *bits |= CAP_SUBGROUP_OPS;
        }
        Expr::Cast { target, .. } => mark_datatype_bits(target, bits),
        Expr::Call { op_id, .. } => {
            if is_subgroup_intrinsic_id(op_id.as_str()) {
                *bits |= CAP_SUBGROUP_OPS;
            }
            *calls = calls.saturating_add(1);
        }
        Expr::Opaque(_) => *opaque = opaque.saturating_add(1),
        _ => {}
    }

    // Step 2: recurse into operands. All updates are ORs or saturating
    // increments, so doing them after step 1 yields the same totals.
    let mut descend = |child: &Expr| manual_walk_expr(child, nodes, regions, calls, opaque, bits);
    match expr {
        Expr::SubgroupAdd { value } => descend(value),
        Expr::SubgroupBallot { cond } => descend(cond),
        Expr::SubgroupShuffle { value, lane } => {
            descend(value);
            descend(lane);
        }
        Expr::BinOp { left, right, .. } => {
            descend(left);
            descend(right);
        }
        Expr::UnOp { operand, .. } => descend(operand),
        Expr::Fma { a, b, c } => {
            descend(a);
            descend(b);
            descend(c);
        }
        Expr::Select {
            cond,
            true_val,
            false_val,
        } => {
            descend(cond);
            descend(true_val);
            descend(false_val);
        }
        Expr::Cast { value, .. } => descend(value),
        Expr::Load { index, .. } => descend(index),
        Expr::Call { args, .. } => {
            for arg in args.iter() {
                descend(arg);
            }
        }
        Expr::Atomic {
            index,
            expected,
            value,
            ..
        } => {
            descend(index);
            if let Some(prior) = expected.as_deref() {
                descend(prior);
            }
            descend(value);
        }
        // Literals, variables, buffer lengths, id builtins, opaque
        // expressions and any variant not listed above have no operands
        // to visit.
        _ => {}
    }
}
fn manual_walk_node(
node: &Node,
nodes: &mut usize,
regions: &mut u32,
calls: &mut u32,
opaque: &mut u32,
bits: &mut u32,
) {
*nodes = nodes.saturating_add(1);
match node {
Node::Let { value, .. } | Node::Assign { value, .. } => {
manual_walk_expr(value, nodes, regions, calls, opaque, bits);
}
Node::Store { index, value, .. } => {
manual_walk_expr(index, nodes, regions, calls, opaque, bits);
manual_walk_expr(value, nodes, regions, calls, opaque, bits);
}
Node::If {
cond,
then,
otherwise,
} => {
manual_walk_expr(cond, nodes, regions, calls, opaque, bits);
for child in then.iter().chain(otherwise.iter()) {
manual_walk_node(child, nodes, regions, calls, opaque, bits);
}
}
Node::Loop { from, to, body, .. } => {
manual_walk_expr(from, nodes, regions, calls, opaque, bits);
manual_walk_expr(to, nodes, regions, calls, opaque, bits);
for child in body.iter() {
manual_walk_node(child, nodes, regions, calls, opaque, bits);
}
}
Node::Block(children) => {
for child in children.iter() {
manual_walk_node(child, nodes, regions, calls, opaque, bits);
}
}
Node::Region { body, .. } => {
*regions = regions.saturating_add(1);
for child in body.iter() {
manual_walk_node(child, nodes, regions, calls, opaque, bits);
}
}
Node::AsyncLoad { offset, size, .. } | Node::AsyncStore { offset, size, .. } => {
*bits |= CAP_ASYNC_DISPATCH;
manual_walk_expr(offset, nodes, regions, calls, opaque, bits);
manual_walk_expr(size, nodes, regions, calls, opaque, bits);
}
Node::AsyncWait { .. } => {
*bits |= CAP_ASYNC_DISPATCH;
}
Node::IndirectDispatch { .. } => {
*bits |= CAP_INDIRECT_DISPATCH;
}
Node::Trap { address, .. } => {
*bits |= CAP_TRAP;
manual_walk_expr(address, nodes, regions, calls, opaque, bits);
}
Node::Opaque(_) => {
*opaque = opaque.saturating_add(1);
}
Node::Return | Node::Barrier { .. } | Node::Resume { .. } => {}
_ => {}
}
}
/// Independently recomputes the statistics the property test compares
/// against `Program::stats()`.
///
/// Buffer declarations contribute to `static_storage_bytes` (count times
/// element size, when the element has a known size) and to the capability
/// bits of their element type. The entry statements are walked with
/// `manual_walk_node`. Fields not computed here keep their
/// `ProgramStats::default()` value.
///
/// All arithmetic saturates, so the oracle never panics or wraps on overflow.
fn manual_compute_stats(program: &Program) -> ProgramStats {
    let mut node_count = 0usize;
    let mut region_count = 0u32;
    let mut call_count = 0u32;
    let mut opaque_count = 0u32;
    let mut capability_bits = 0u32;
    let mut static_storage_bytes = 0u64;

    for decl in program.buffers().iter() {
        let count = decl.count();
        if count != 0 {
            if let Some(elem) = decl.element().size_bytes() {
                // Saturate the product as well as the running sum: a plain `*`
                // would panic in debug builds and wrap in release builds.
                let buffer_bytes = u64::from(count).saturating_mul(elem as u64);
                static_storage_bytes = static_storage_bytes.saturating_add(buffer_bytes);
            }
        }
        // The element type is recorded even for zero-count buffers.
        mark_datatype_bits(&decl.element(), &mut capability_bits);
    }

    for node in program.entry().iter() {
        manual_walk_node(
            node,
            &mut node_count,
            &mut region_count,
            &mut call_count,
            &mut opaque_count,
            &mut capability_bits,
        );
    }

    // Only regions that sit directly in the entry list count as top-level.
    // Counted with a saturating fold instead of a truncating `as u32` cast.
    let top_level_regions = program
        .entry()
        .iter()
        .filter(|n| matches!(n, Node::Region { .. }))
        .fold(0u32, |seen, _| seen.saturating_add(1));

    ProgramStats {
        node_count,
        region_count,
        call_count,
        opaque_count,
        top_level_regions,
        static_storage_bytes,
        capability_bits,
        ..ProgramStats::default()
    }
}
proptest! {
    #![proptest_config(ProptestConfig {
        cases: 50,
        ..ProptestConfig::default()
    })]

    // Checks two properties of `Program::stats()` on generated programs:
    // repeated calls return the same reference, and every field computed by
    // `manual_compute_stats` matches the cached value.
    #[test]
    fn program_stats_cache_invariants(program in arb_program()) {
        let first = program.stats();
        let second = program.stats();
        prop_assert!(
            std::ptr::eq(first, second),
            "program.stats() must return the same cached reference on repeated calls"
        );

        let manual = manual_compute_stats(&program);
        let cached = program.stats();

        prop_assert_eq!(
            cached.node_count,
            manual.node_count,
            "node_count mismatch: cached={}, manual={}",
            cached.node_count,
            manual.node_count
        );
        prop_assert_eq!(
            cached.region_count,
            manual.region_count,
            "region_count mismatch: cached={}, manual={}",
            cached.region_count,
            manual.region_count
        );
        prop_assert_eq!(
            cached.call_count,
            manual.call_count,
            "call_count mismatch: cached={}, manual={}",
            cached.call_count,
            manual.call_count
        );
        prop_assert_eq!(
            cached.opaque_count,
            manual.opaque_count,
            "opaque_count mismatch: cached={}, manual={}",
            cached.opaque_count,
            manual.opaque_count
        );
        prop_assert_eq!(
            cached.top_level_regions,
            manual.top_level_regions,
            "top_level_regions mismatch: cached={}, manual={}",
            cached.top_level_regions,
            manual.top_level_regions
        );
        prop_assert_eq!(
            cached.static_storage_bytes,
            manual.static_storage_bytes,
            "static_storage_bytes mismatch: cached={}, manual={}",
            cached.static_storage_bytes,
            manual.static_storage_bytes
        );
        // Capability bits are printed in binary so individual flags are visible.
        prop_assert_eq!(
            cached.capability_bits,
            manual.capability_bits,
            "capability_bits mismatch: cached={:08b}, manual={:08b}",
            cached.capability_bits,
            manual.capability_bits
        );
    }
}