use cutile;
use cutile_compiler::compiler::utils::CompileOptions;
use cutile_compiler::compiler::{CUDATileFunctionCompiler, CUDATileModules};
use cutile_compiler::cuda_tile_runtime_utils::get_gpu_name;
mod common;
// Kernels that the tests in this file only compile (never launch). Every kernel
// is generic over a rank-1 shape (`S: [i32; 1]`) and exercises one reduce/scan
// form of the cutile DSL: a built-in scan, closure-based reductions, and a
// closure-based scan.
// NOTE: plain `//` comments are used here (not `///`) so no extra `#[doc]`
// attributes are fed to the `cutile` proc macros.
#[cutile::module]
mod reduce_scan_ops_module {
use cutile::core::*;
// Running sum over `output`, written back in place, via the built-in `scan_sum`.
// Extra args are (0i32, false, 0.0f32). The compile test expects `reverse=false`
// in the MLIR, so `false` is presumably the reverse flag, `0i32` the scan axis
// and `0.0f32` the identity/seed — TODO confirm against cutile::core.
#[cutile::entry()]
fn scan_sum_test_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let tile: Tile<f32, S> = load_tile_mut(output);
let prefix_sums: Tile<f32, S> = scan_sum(tile, 0i32, false, 0.0f32);
output.store(prefix_sums);
}
// Sums `input` with a user closure (`acc + x`, seed 0.0f32). It reshapes the
// scalar result to a 1-element tile and stores it into `result`.
// `input` is loaded via `load_tile_mut` although this kernel only reads it.
#[cutile::entry()]
fn reduce_closure_test_kernel<const S: [i32; 1]>(
input: &mut Tensor<f32, S>,
result: &mut Tensor<f32, { [1] }>,
) {
let tile: Tile<f32, S> = load_tile_mut(input);
let sum_scalar = reduce(tile, 0i32, 0.0f32, |acc, x| acc + x);
let sum_as_array: Tile<f32, { [1] }> = sum_scalar.reshape(const_shape![1]);
result.store(sum_as_array);
}
// Product of `input` via a closure (`acc * x`). The seed is 1.0f32, the
// multiplicative identity. The result is stored as a 1-element tile in `result`.
#[cutile::entry()]
fn reduce_product_closure_test_kernel<const S: [i32; 1]>(
input: &mut Tensor<f32, S>,
result: &mut Tensor<f32, { [1] }>,
) {
let tile: Tile<f32, S> = load_tile_mut(input);
let product_scalar = reduce(tile, 0i32, 1.0f32, |acc, x| acc * x);
let product_as_array: Tile<f32, { [1] }> = product_scalar.reshape(const_shape![1]);
result.store(product_as_array);
}
// Maximum of `input` via a closure calling `max(acc, x)`. The seed is
// `f32::NEG_INFINITY`, so any element replaces it. The result is stored as a
// 1-element tile in `result`.
#[cutile::entry()]
fn reduce_max_closure_test_kernel<const S: [i32; 1]>(
input: &mut Tensor<f32, S>,
result: &mut Tensor<f32, { [1] }>,
) {
let tile: Tile<f32, S> = load_tile_mut(input);
let max_scalar = reduce(tile, 0i32, f32::NEG_INFINITY, |acc, x| max(acc, x));
let max_as_array: Tile<f32, { [1] }> = max_scalar.reshape(const_shape![1]);
result.store(max_as_array);
}
// Running (prefix) product over `output`, written back in place, via the generic
// closure-based `scan` (seed 1.0f32, combiner `acc * x`). The arg order mirrors
// `scan_sum` above, plus the trailing closure.
#[cutile::entry()]
fn scan_closure_test_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let tile: Tile<f32, S> = load_tile_mut(output);
let prefix_products: Tile<f32, S> = scan(tile, 0i32, false, 1.0f32, |acc, x| acc * x);
output.store(prefix_products);
}
}
use reduce_scan_ops_module::_module_asts;
#[test]
/// Compiles `scan_sum_test_kernel` (tile size 128) and checks that the emitted
/// MLIR contains a forward (`reverse=false`) scan with a yielding region.
fn compile_scan_sum_test() {
    // Symbol names that are likely to appear verbatim in the emitted MLIR
    // (as module/function symbols). Both contain the substring "scan", so they
    // are stripped before the op-presence checks. Otherwise those checks would
    // pass even if no scan op were emitted.
    const MODULE_NAME: &str = "reduce_scan_ops_module";
    const KERNEL_NAME: &str = "scan_sum_test_kernel";

    common::with_test_stack(|| {
        let modules =
            CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
        let gpu_name = get_gpu_name(0);
        let compiler = CUDATileFunctionCompiler::new(
            &modules,
            MODULE_NAME,
            KERNEL_NAME,
            &[128.to_string()],
            &[("output", &[1])],
            &[],
            &[],
            None,
            gpu_name,
            &CompileOptions::default(),
        )
        .expect("Failed to construct CUDATileFunctionCompiler");
        let module_op_str = compiler
            .compile()
            .expect("Failed to compile kernel")
            .to_string();
        println!("\n=== SCAN_SUM MLIR ===\n{}", module_op_str);

        // MLIR text with the symbol names removed, so keyword checks only match ops.
        let ops_text = module_op_str
            .replace(MODULE_NAME, "")
            .replace(KERNEL_NAME, "");
        assert!(
            ops_text.contains("scan"),
            "Expected scan operation in MLIR output"
        );
        assert!(
            ops_text.contains("yield"),
            "Expected yield in scan region"
        );
        assert!(
            ops_text.contains("reverse=false") || ops_text.contains("reverse = false"),
            "Expected reverse=false in scan operation"
        );
        println!("\n✓ scan_sum operation verified (with prefix sum scan region)");
    });
}
#[test]
/// Compiles `reduce_closure_test_kernel` (tile size 128) and checks that the
/// emitted MLIR contains a reduce op whose region adds and yields.
fn compile_reduce_closure_test() {
    // Symbol names that are likely to appear verbatim in the emitted MLIR.
    // Both contain the substring "reduce", so they are stripped before the
    // op-presence check. Otherwise that check could never fail.
    const MODULE_NAME: &str = "reduce_scan_ops_module";
    const KERNEL_NAME: &str = "reduce_closure_test_kernel";

    common::with_test_stack(|| {
        let modules =
            CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
        let gpu_name = get_gpu_name(0);
        let compiler = CUDATileFunctionCompiler::new(
            &modules,
            MODULE_NAME,
            KERNEL_NAME,
            &[128.to_string()],
            &[("input", &[1]), ("result", &[1])],
            &[],
            &[],
            None,
            gpu_name,
            &CompileOptions::default(),
        )
        .expect("Failed to construct CUDATileFunctionCompiler");
        let module_op_str = compiler
            .compile()
            .expect("Failed to compile kernel")
            .to_string();
        println!(
            "\n=== REDUCE WITH CLOSURE (SUM) MLIR ===\n{}",
            module_op_str
        );

        // MLIR text with the symbol names removed, so keyword checks only match ops.
        let ops_text = module_op_str
            .replace(MODULE_NAME, "")
            .replace(KERNEL_NAME, "");
        assert!(
            ops_text.contains("reduce"),
            "Expected reduce operation in MLIR output"
        );
        assert!(
            ops_text.contains("addf") || ops_text.contains("addi"),
            "Expected add operation in reduce region"
        );
        assert!(
            ops_text.contains("yield"),
            "Expected yield in reduce region"
        );
        println!("\n✓ reduce with closure (sum) operation verified");
    });
}
#[test]
/// Compiles `reduce_product_closure_test_kernel` (tile size 128) and checks that
/// the emitted MLIR contains a reduce op whose region multiplies and yields.
fn compile_reduce_product_closure_test() {
    // Symbol names that are likely to appear verbatim in the emitted MLIR.
    // Both contain the substring "reduce", so they are stripped before the
    // op-presence check. Otherwise that check could never fail.
    const MODULE_NAME: &str = "reduce_scan_ops_module";
    const KERNEL_NAME: &str = "reduce_product_closure_test_kernel";

    common::with_test_stack(|| {
        let modules =
            CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
        let gpu_name = get_gpu_name(0);
        let compiler = CUDATileFunctionCompiler::new(
            &modules,
            MODULE_NAME,
            KERNEL_NAME,
            &[128.to_string()],
            &[("input", &[1]), ("result", &[1])],
            &[],
            &[],
            None,
            gpu_name,
            &CompileOptions::default(),
        )
        .expect("Failed to construct CUDATileFunctionCompiler");
        let module_op_str = compiler
            .compile()
            .expect("Failed to compile kernel")
            .to_string();
        println!(
            "\n=== REDUCE WITH CLOSURE (PRODUCT) MLIR ===\n{}",
            module_op_str
        );

        // MLIR text with the symbol names removed, so keyword checks only match ops.
        let ops_text = module_op_str
            .replace(MODULE_NAME, "")
            .replace(KERNEL_NAME, "");
        assert!(
            ops_text.contains("reduce"),
            "Expected reduce operation in MLIR output"
        );
        assert!(
            ops_text.contains("mulf") || ops_text.contains("muli"),
            "Expected multiply operation in reduce region"
        );
        assert!(
            ops_text.contains("yield"),
            "Expected yield in reduce region"
        );
        println!("\n✓ reduce with closure (product) operation verified");
    });
}
#[test]
/// Compiles `reduce_max_closure_test_kernel` (tile size 128) and checks that the
/// emitted MLIR contains a reduce op whose region takes a float max and yields.
fn compile_reduce_max_closure_test() {
    // Symbol names that are likely to appear verbatim in the emitted MLIR.
    // Both contain the substring "reduce", so they are stripped before the
    // op-presence check. Otherwise that check could never fail.
    const MODULE_NAME: &str = "reduce_scan_ops_module";
    const KERNEL_NAME: &str = "reduce_max_closure_test_kernel";

    common::with_test_stack(|| {
        let modules =
            CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
        let gpu_name = get_gpu_name(0);
        let compiler = CUDATileFunctionCompiler::new(
            &modules,
            MODULE_NAME,
            KERNEL_NAME,
            &[128.to_string()],
            &[("input", &[1]), ("result", &[1])],
            &[],
            &[],
            None,
            gpu_name,
            &CompileOptions::default(),
        )
        .expect("Failed to construct CUDATileFunctionCompiler");
        let module_op_str = compiler
            .compile()
            .expect("Failed to compile kernel")
            .to_string();
        println!(
            "\n=== REDUCE WITH CLOSURE (MAX) MLIR ===\n{}",
            module_op_str
        );

        // MLIR text with the symbol names removed, so keyword checks only match ops.
        let ops_text = module_op_str
            .replace(MODULE_NAME, "")
            .replace(KERNEL_NAME, "");
        assert!(
            ops_text.contains("reduce"),
            "Expected reduce operation in MLIR output"
        );
        assert!(
            ops_text.contains("maxf"),
            "Expected maxf operation in reduce region"
        );
        assert!(
            ops_text.contains("yield"),
            "Expected yield in reduce region"
        );
        println!("\n✓ reduce with closure (max) operation verified");
    });
}
#[test]
/// Compiles `scan_closure_test_kernel` (tile size 128) and checks that the
/// emitted MLIR contains a forward (`reverse=false`) scan whose region
/// multiplies and yields.
fn compile_scan_closure_test() {
    // Symbol names that are likely to appear verbatim in the emitted MLIR.
    // Both contain the substring "scan", so they are stripped before the
    // op-presence check. Otherwise that check could never fail.
    const MODULE_NAME: &str = "reduce_scan_ops_module";
    const KERNEL_NAME: &str = "scan_closure_test_kernel";

    common::with_test_stack(|| {
        let modules =
            CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
        let gpu_name = get_gpu_name(0);
        let compiler = CUDATileFunctionCompiler::new(
            &modules,
            MODULE_NAME,
            KERNEL_NAME,
            &[128.to_string()],
            &[("output", &[1])],
            &[],
            &[],
            None,
            gpu_name,
            &CompileOptions::default(),
        )
        .expect("Failed to construct CUDATileFunctionCompiler");
        let module_op_str = compiler
            .compile()
            .expect("Failed to compile kernel")
            .to_string();
        println!(
            "\n=== SCAN WITH CLOSURE (PREFIX PRODUCT) MLIR ===\n{}",
            module_op_str
        );

        // MLIR text with the symbol names removed, so keyword checks only match ops.
        let ops_text = module_op_str
            .replace(MODULE_NAME, "")
            .replace(KERNEL_NAME, "");
        assert!(
            ops_text.contains("scan"),
            "Expected scan operation in MLIR output"
        );
        assert!(
            ops_text.contains("mulf") || ops_text.contains("muli"),
            "Expected multiply operation in scan region"
        );
        assert!(
            ops_text.contains("yield"),
            "Expected yield in scan region"
        );
        assert!(
            ops_text.contains("reverse=false") || ops_text.contains("reverse = false"),
            "Expected reverse=false in scan operation"
        );
        println!("\n✓ scan with closure (prefix product) operation verified");
    });
}