use cutile;
use cutile_compiler::compiler::utils::CompileOptions;
use cutile_compiler::compiler::{CUDATileFunctionCompiler, CUDATileModules};
use cutile_compiler::cuda_tile_runtime_utils::get_gpu_name;
mod common;
#[cutile::module]
mod binary_math_ops_module {
use cutile::core::*;
#[cutile::entry()]
fn minmax_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let x: Tile<f32, S> = load_tile_mut(output);
let y: Tile<f32, S> = load_tile_mut(output);
let rem_result: Tile<f32, S> = remf(x, y);
let max_result: Tile<f32, S> = maxf(rem_result, y, nan::Disabled, ftz::Disabled);
let min_result: Tile<f32, S> = minf(max_result, y, nan::Disabled, ftz::Disabled);
output.store(min_result);
}
#[cutile::entry()]
fn select_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let x: Tile<f32, S> = load_tile_mut(output);
let y: Tile<f32, S> = load_tile_mut(output);
let mask: Tile<bool, S> = cmpf(x, y, predicate::LessThan, cmp_ordering::Ordered);
let result: Tile<f32, S> = select(mask, x, y);
output.store(result);
}
#[cutile::entry()]
fn bf16_binary_arith_kernel<const S: [i32; 1]>(output: &mut Tensor<bf16, S>) {
let x: Tile<bf16, S> = load_tile_mut(output);
let y: Tile<bf16, S> = load_tile_mut(output);
let sum: Tile<bf16, S> = x + y;
let product: Tile<bf16, S> = sum * y;
let result: Tile<bf16, S> = product / x;
output.store(result);
}
#[cutile::entry()]
fn addf_shadow_dispatch_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let x: Tile<f32, S> = load_tile_mut(output);
let y: Tile<f32, S> = load_tile_mut(output);
let result: Tile<f32, S> = addf(x, y, rounding::NearestEven, ftz::Disabled);
output.store(result);
}
#[cutile::entry()]
fn addf_shadow_dispatch_nested_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let x: Tile<f32, S> = load_tile_mut(output);
let y: Tile<f32, S> = load_tile_mut(output);
let z: Tile<f32, S> = load_tile_mut(output);
let result: Tile<f32, S> = addf(
addf(x, y, rounding::NearestEven, ftz::Disabled),
z,
rounding::NearestEven,
ftz::Disabled,
);
output.store(result);
}
#[cutile::entry()]
fn reshape_shadow_dispatch_kernel<const S: [i32; 2]>(output: &mut Tensor<f32, S>) {
let x: Tile<f32, S> = load_tile_mut(output);
let target: Shape<{ [128] }> = Shape::<{ [128] }> { dims: &[128i32] };
let flat: Tile<f32, { [128] }> = reshape(x, target);
let back_shape: Shape<S> = output.shape();
let back: Tile<f32, S> = reshape(flat, back_shape);
output.store(back);
}
#[cutile::entry()]
fn reduce_sum_shadow_dispatch_kernel<const S: [i32; 2]>(
input: &mut Tensor<f32, S>,
output: &mut Tensor<f32, { [1, 1] }>,
) {
let tile: Tile<f32, S> = load_tile_mut(input);
let reduced: Tile<f32, { [1, 1] }> = reduce_sum(tile, 1i32);
output.store(reduced);
}
}
use binary_math_ops_module::__module_ast_self;
#[test]
fn compile_minmax() -> () {
common::with_test_stack(|| {
let modules = CUDATileModules::from_kernel(__module_ast_self())
.expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"binary_math_ops_module",
"minmax_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler.compile().expect("Failed.").to_string();
println!("\n=== MIN/MAX MLIR ===\n{}", module_op_str);
let expected_ops = ["remf", "maxf", "minf"];
for op in expected_ops {
assert!(
module_op_str.contains(format!("= {}", op).as_str()),
"Expected {} operation in MLIR output",
op
);
}
println!(
"\n✓ All {} min/max operations verified in MLIR output",
expected_ops.len()
);
});
}
#[test]
fn compile_select() -> () {
common::with_test_stack(|| {
let modules = CUDATileModules::from_kernel(__module_ast_self())
.expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"binary_math_ops_module",
"select_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler.compile().expect("Failed.").to_string();
println!("\n=== SELECT MLIR ===\n{}", module_op_str);
assert!(
module_op_str.contains("= cmpf"),
"Expected cmpf operation in MLIR output"
);
assert!(
module_op_str.contains("select"),
"Expected select operation in MLIR output"
);
println!("\n✓ select operation verified in MLIR output");
});
}
#[test]
fn compile_bf16_binary_arith() -> () {
common::with_test_stack(|| {
let modules = CUDATileModules::from_kernel(__module_ast_self())
.expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"binary_math_ops_module",
"bf16_binary_arith_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler.compile().expect("Failed.").to_string();
println!("\n=== BF16 BINARY ARITH MLIR ===\n{}", module_op_str);
for op in ["addf", "mulf", "divf"] {
assert!(
module_op_str.contains(format!("= {}", op).as_str()),
"Expected {} operation in MLIR output",
op
);
}
assert!(
module_op_str.contains("bf16"),
"Expected bf16 type in MLIR output"
);
});
}
#[test]
fn compile_addf_shadow_dispatch() -> () {
common::with_test_stack(|| {
let modules = CUDATileModules::from_kernel(__module_ast_self())
.expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"binary_math_ops_module",
"addf_shadow_dispatch_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler.compile().expect("Failed.").to_string();
println!("\n=== ADDF TRAIT DISPATCH MLIR ===\n{}", module_op_str);
assert!(
module_op_str.contains("= addf"),
"Expected addf operation in MLIR output"
);
});
}
#[test]
fn compile_reshape_shadow_dispatch() -> () {
common::with_test_stack(|| {
let modules = CUDATileModules::from_kernel(__module_ast_self())
.expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"binary_math_ops_module",
"reshape_shadow_dispatch_kernel",
&[8.to_string(), 16.to_string()],
&[("output", &[8, 16])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler.compile().expect("Failed.").to_string();
println!("\n=== RESHAPE TRAIT DISPATCH MLIR ===\n{}", module_op_str);
assert!(
module_op_str.contains("reshape"),
"Expected reshape operation in MLIR output"
);
});
}
#[test]
fn compile_reduce_sum_shadow_dispatch() -> () {
common::with_test_stack(|| {
let modules = CUDATileModules::from_kernel(__module_ast_self())
.expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"binary_math_ops_module",
"reduce_sum_shadow_dispatch_kernel",
&[8.to_string(), 16.to_string()],
&[("input", &[8, 16]), ("output", &[1, 1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler.compile().expect("Failed.").to_string();
println!(
"\n=== REDUCE_SUM TRAIT DISPATCH MLIR ===\n{}",
module_op_str
);
assert!(
module_op_str.contains("reduce"),
"Expected reduce operation in MLIR output"
);
});
}
#[test]
fn compile_addf_shadow_dispatch_nested() -> () {
common::with_test_stack(|| {
let modules = CUDATileModules::from_kernel(__module_ast_self())
.expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"binary_math_ops_module",
"addf_shadow_dispatch_nested_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler.compile().expect("Failed.").to_string();
println!(
"\n=== ADDF TRAIT DISPATCH NESTED MLIR ===\n{}",
module_op_str
);
let addf_count = module_op_str.matches("= addf").count();
assert!(
addf_count >= 2,
"Expected at least 2 addf operations (nested), got {}",
addf_count
);
});
}