use super::{Program, ProgramStats};
use crate::ir::{BufferAccess, BufferDecl, DataType, Expr, Node};
#[test]
fn stats_matches_old_multi_walk_empty() {
    // The empty program is expected to report exactly one node, which is also
    // its only (top-level) region. No calls, opaque nodes, storage or
    // capabilities are recorded. Every field not listed stays at its default.
    let program = Program::empty();
    let expected = ProgramStats {
        node_count: 1,
        region_count: 1,
        top_level_regions: 1,
        call_count: 0,
        opaque_count: 0,
        static_storage_bytes: 0,
        capability_bits: 0,
        ..ProgramStats::default()
    };
    assert_eq!(*program.stats(), expected);
}
#[test]
fn stats_matches_old_multi_walk_single_store() {
    // One read-write u32 storage buffer with a single element, written once,
    // followed by a return.
    let out_buffer =
        BufferDecl::storage("out", 0, BufferAccess::ReadWrite, DataType::U32).with_count(1);
    let body = vec![Node::store("out", Expr::u32(0), Expr::u32(7)), Node::Return];
    let program = Program::wrapped(vec![out_buffer], [1, 1, 1], body);

    // Two body nodes plus one more node that the wrapping contributes.
    // The store is the only memory op and the return the only control-flow
    // node. One u32 element accounts for 4 bytes of static storage.
    let expected = ProgramStats {
        node_count: 3,
        region_count: 1,
        top_level_regions: 1,
        instruction_count: 2,
        memory_op_count: 1,
        control_flow_count: 1,
        static_storage_bytes: 4,
        call_count: 0,
        opaque_count: 0,
        capability_bits: 0,
        ..ProgramStats::default()
    };
    assert_eq!(*program.stats(), expected);
}
#[test]
fn stats_matches_old_multi_walk_batch() {
    // Three stores: out[0] = 1, out[1] = 2, out[2] = 3, then a return.
    let body: Vec<Node> = (0u32..3)
        .map(|slot| Node::store("out", Expr::u32(slot), Expr::u32(slot + 1)))
        .chain(std::iter::once(Node::Return))
        .collect();
    let out_buffer =
        BufferDecl::storage("out", 0, BufferAccess::ReadWrite, DataType::U32).with_count(4);
    let program = Program::wrapped(vec![out_buffer], [1, 1, 1], body);

    // Four u32 slots account for 16 bytes of static storage. Each store is
    // one memory op, and the single return is the only control-flow node.
    let expected = ProgramStats {
        node_count: 5,
        region_count: 1,
        top_level_regions: 1,
        instruction_count: 4,
        memory_op_count: 3,
        control_flow_count: 1,
        static_storage_bytes: 16,
        call_count: 0,
        opaque_count: 0,
        capability_bits: 0,
        ..ProgramStats::default()
    };
    assert_eq!(*program.stats(), expected);
}
#[test]
fn stats_matches_old_multi_walk_region_chain() {
    use std::sync::Arc;

    // Builds an empty region node with the given generator label and no
    // source region.
    let empty_region = |label: &'static str| Node::Region {
        generator: label.into(),
        source_region: None,
        body: Arc::new(vec![]),
    };

    // `Program::new` is deprecated. It is used here presumably because it
    // takes the entry list as-is, so both regions end up at top level.
    // (Elsewhere `Program::wrapped` adds an extra node to the counts.)
    #[allow(deprecated)]
    let program = Program::new(vec![], [1, 1, 1], vec![empty_region("a"), empty_region("b")]);

    // Two sibling regions, both at top level, with no storage or calls.
    let expected = ProgramStats {
        node_count: 2,
        region_count: 2,
        top_level_regions: 2,
        call_count: 0,
        opaque_count: 0,
        static_storage_bytes: 0,
        capability_bits: 0,
        ..ProgramStats::default()
    };
    assert_eq!(*program.stats(), expected);
}
#[test]
fn stats_matches_old_multi_walk_recursive() {
    // A `for i in 0..10` loop whose body binds `x = foo(1)`. The call only
    // appears inside the loop body, so `call_count: 1` below checks that the
    // statistics walk descends into nested bodies.
    let loop_body = vec![Node::let_bind("x", Expr::call("foo", vec![Expr::u32(1)]))];
    let counted_loop = Node::loop_for("i", Expr::u32(0), Expr::u32(10), loop_body);
    let out_buffer =
        BufferDecl::storage("out", 0, BufferAccess::ReadWrite, DataType::U32).with_count(1);
    let program = Program::wrapped(vec![out_buffer], [1, 1, 1], vec![counted_loop]);

    let expected = ProgramStats {
        node_count: 3,
        region_count: 1,
        top_level_regions: 1,
        call_count: 1,
        instruction_count: 3,
        control_flow_count: 1,
        register_pressure_estimate: 1,
        static_storage_bytes: 4,
        opaque_count: 0,
        capability_bits: 0,
        ..ProgramStats::default()
    };
    assert_eq!(*program.stats(), expected);
}
#[test]
fn stats_cache_hit_returns_same_reference() {
    // Calling `stats()` twice on the same program must hand back the very same
    // object (pointer identity), not merely an equal one.
    let program = Program::empty();
    let first = program.stats();
    let second = program.stats();
    let same_object = std::ptr::eq(first, second);
    assert!(
        same_object,
        "Fix: repeated stats() calls must return cached reference"
    );
}