use cutile::{self, api::*, tensor::*, tile_kernel::*};
use cutile_compiler::compiler::utils::CompileOptions;
use cutile_compiler::compiler::{CUDATileFunctionCompiler, CUDATileModules};
use cutile_compiler::cuda_tile_runtime_utils::get_gpu_name;
mod common;
#[cutile::module]
mod control_flow_ops_module {
use cutile::core::*;
#[cutile::entry()]
fn control_flow_test_kernel<const S: [i32; 1]>(
output: &mut Tensor<f32, S>,
dynamic_value: i32,
) {
let mut sum: Tile<f32, S> = load_tile_mut(output);
for _i in 0i32..10i32 {
sum = sum + sum;
}
if dynamic_value < 5i32 {
sum = sum + sum;
} else {
sum = sum - sum;
}
output.store(sum);
}
#[cutile::entry()]
fn break_test_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let mut sum: Tile<f32, S> = load_tile_mut(output);
let mut i: i32 = 0;
loop {
sum = sum + sum;
i = i + 1i32;
if i >= 2i32 {
break;
}
}
output.store(sum);
}
#[cutile::entry()]
fn if_return_test_kernel<const S: [i32; 1]>(output: &mut Tensor<i64, S>, conditional: bool) {
let mut val: Tile<i64, S> = output.load();
let result: Tile<i64, S> = if conditional {
val = val + val;
constant(2, val.shape())
} else {
val = val + val + val;
constant(3, val.shape())
};
val = val + result;
output.store(val);
}
#[cutile::entry()]
fn while_loop_test_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let mut sum: Tile<f32, S> = load_tile_mut(output);
let mut counter: i32 = 0i32;
while counter < 10i32 {
sum = sum + sum;
counter = counter + 1i32;
}
output.store(sum);
}
#[cutile::entry()]
fn infinite_loop_test_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let mut sum: Tile<f32, S> = load_tile_mut(output);
let mut counter: i32 = 0i32;
while counter < 10i32 {
sum = sum + sum;
counter = counter + 1i32;
}
output.store(sum);
}
#[cutile::entry()]
fn step_by_test_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>) {
let mut sum: Tile<f32, S> = load_tile_mut(output);
for _i in (0i32..100i32).step_by(10) {
sum = sum + sum;
}
output.store(sum);
}
#[cutile::entry()]
fn assume_test_kernel<const S: [i32; 1]>(output: &mut Tensor<i64, S>) {
let tile: Tile<i64, S> = load_tile_mut(output);
let assumed_tile: Tile<i64, S> = unsafe { assume_bounds_lower::<_, 0>(tile) };
let result: Tile<i64, S> = assumed_tile + constant(1i64, output.shape());
output.store(result);
}
#[cutile::entry()]
fn assume_non_negative_test_kernel<const S: [i32; 1]>(output: &mut Tensor<i64, S>) {
let tile: Tile<i64, S> = load_tile_mut(output);
let non_neg_tile: Tile<i64, S> = unsafe { assume_bounds_lower::<_, 0>(tile) };
let result: Tile<i64, S> = non_neg_tile + constant(1i64, output.shape());
output.store(result);
}
#[cutile::entry()]
fn assume_div_by_test_kernel<const S: [i32; 1]>(output: &mut Tensor<i64, S>) {
let tile: Tile<i64, S> = load_tile_mut(output);
let aligned_tile: Tile<i64, S> = unsafe { assume_div_by::<_, 16>(tile) };
let result: Tile<i64, S> = aligned_tile + constant(1i64, output.shape());
output.store(result);
}
#[cutile::entry()]
fn assume_same_elements_test_kernel<const S: [i32; 2]>(output: &mut Tensor<i64, S>) {
let tile: Tile<i64, S> = load_tile_mut(output);
let same_tile: Tile<i64, S> = unsafe { assume_same_elements_2d::<_, 2, 4>(tile) };
let result: Tile<i64, S> = same_tile + constant(1i64, output.shape());
output.store(result);
}
#[cutile::entry()]
fn if_for_carry_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>, flag: i32) {
let mut acc: Tile<f32, S> = constant(0.0f32, output.shape());
if flag > 0i32 {
for _i in 0i32..10i32 {
let ones: Tile<f32, S> = constant(1.0f32, output.shape());
acc = acc + ones;
}
} else {
acc = acc;
}
output.store(acc);
}
#[cutile::entry()]
fn if_for_carry_const_kernel<const S: [i32; 1], const FLAG: i32, const N: i32>(
output: &mut Tensor<f32, S>,
) {
let mut acc: Tile<f32, S> = constant(0.0f32, output.shape());
if FLAG > 0i32 {
for _i in 0i32..N {
let ones: Tile<f32, S> = constant(1.0f32, output.shape());
acc = acc + ones;
}
} else {
acc = acc;
}
output.store(acc);
}
#[cutile::entry()]
fn if_else_tile_expr_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>, flag: i32) {
let ones: Tile<f32, S> = constant(1.0f32, output.shape());
let twos: Tile<f32, S> = constant(2.0f32, output.shape());
let result: Tile<f32, S> = if flag > 0i32 { ones } else { twos };
output.store(result);
}
#[cutile::entry()]
fn nested_for_in_if_kernel<const S: [i32; 1]>(output: &mut Tensor<f32, S>, flag: i32) {
let mut acc: Tile<f32, S> = constant(0.0f32, output.shape());
if flag > 0i32 {
for _i in 0i32..5i32 {
let twos: Tile<f32, S> = constant(2.0f32, output.shape());
acc = acc + twos;
}
} else {
let ones: Tile<f32, S> = constant(1.0f32, output.shape());
acc = acc + ones;
}
output.store(acc);
}
#[cutile::entry()]
fn nested_for_in_if_const_kernel<const S: [i32; 1], const FLAG: i32>(
output: &mut Tensor<f32, S>,
) {
let mut acc: Tile<f32, S> = constant(0.0f32, output.shape());
if FLAG > 0i32 {
for _i in 0i32..5i32 {
let twos: Tile<f32, S> = constant(2.0f32, output.shape());
acc = acc + twos;
}
} else {
let ones: Tile<f32, S> = constant(1.0f32, output.shape());
acc = acc + ones;
}
output.store(acc);
}
#[cutile::entry()]
fn nested_if_for_if_kernel<const S: [i32; 1], const OUTER: i32>(
output: &mut Tensor<f32, S>,
inner_flag: i32,
) {
let mut acc: Tile<f32, S> = constant(0.0f32, output.shape());
if OUTER > 0i32 {
for _i in 0i32..4i32 {
if inner_flag > 0i32 {
let threes: Tile<f32, S> = constant(3.0f32, output.shape());
acc = acc + threes;
} else {
acc = acc;
}
}
} else {
acc = acc;
}
output.store(acc);
}
}
use control_flow_ops_module::{
_module_asts, break_test_kernel, if_else_tile_expr_kernel, if_for_carry_const_kernel,
if_for_carry_kernel, if_return_test_kernel, nested_for_in_if_const_kernel,
nested_for_in_if_kernel, nested_if_for_if_kernel,
};
#[test]
fn compile_control_flow_test() -> () {
common::with_test_stack(|| {
let modules =
CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"control_flow_ops_module",
"control_flow_test_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler
.compile()
.expect("Failed.")
.to_string();
println!("\n=== CONTROL FLOW TEST MLIR ===\n{}", module_op_str);
let has_for = module_op_str.contains(" for ");
let has_if = module_op_str.contains("if %");
let has_continue = module_op_str.contains("continue ");
let has_return = module_op_str.contains("return");
println!("\n=== Control Flow Operations Found ===");
println!("for: {}", if has_for { "✓" } else { "✗" });
println!("if: {}", if has_if { "✓" } else { "✗" });
println!("continue: {}", if has_continue { "✓" } else { "✗" });
println!("return: {}", if has_return { "✓" } else { "✗" });
assert!(has_for, "Expected 'for' loop in MLIR");
assert!(has_if, "Expected 'if' in MLIR");
assert!(has_continue, "Expected 'continue' as loop terminator");
assert!(has_return, "Expected 'return' at function end");
println!("\n✓ All control flow operations generated (for, if, continue, return)!");
});
}
#[test]
fn compile_if_result_test() -> () {
common::with_test_stack(|| {
let arg: Tensor<i64> = ones(&[16]).sync().expect("Failed.");
let (result, _) = if_return_test_kernel(arg.partition([4]), true)
.sync()
.expect("Failed.");
let result: Vec<i64> = result.unpartition().to_host_vec().sync().expect("Failed.");
assert!(result.iter().all(|x| *x == 4));
let arg: Tensor<i64> = ones(&[16]).sync().expect("Failed.");
let (result, _) = if_return_test_kernel(arg.partition([4]), false)
.sync()
.expect("Failed.");
let result: Vec<i64> = result.unpartition().to_host_vec().sync().expect("Failed.");
assert!(result.iter().all(|x| *x == 6));
});
}
#[test]
fn execute_break_test() -> () {
common::with_test_stack(|| {
let arg: Tensor<f32> = ones(&[16]).sync().expect("Failed.");
let (result,) = break_test_kernel(arg.partition([4]))
.sync()
.expect("Failed.");
let result: Vec<f32> = result.unpartition().to_host_vec().sync().expect("Failed.");
assert!(
result.iter().all(|x| *x == 4.0),
"Expected all elements to be 4.0, got: {:?}",
result
);
});
}
#[test]
fn compile_break_test() -> () {
common::with_test_stack(|| {
let modules =
CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"control_flow_ops_module",
"break_test_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler
.compile()
.expect("Failed.")
.to_string();
println!("\n=== BREAK TEST MLIR ===\n{}", module_op_str);
let has_loop = module_op_str.contains("cuda_tile.loop");
let has_break = module_op_str.contains("break");
assert!(has_loop, "Expected loop operation in IR");
assert!(has_break, "Expected break operation in IR");
println!("\n✓ break statement compiled successfully");
});
}
#[test]
fn compile_while_loop_test() -> () {
common::with_test_stack(|| {
let modules =
CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"control_flow_ops_module",
"while_loop_test_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler
.compile()
.expect("Failed.")
.to_string();
println!("\n=== WHILE LOOP TEST MLIR ===\n{}", module_op_str);
let has_loop = module_op_str.contains("cuda_tile.loop") || module_op_str.contains(" loop ");
let has_break = module_op_str.contains("break ");
assert!(has_loop, "Expected cuda_tile.loop operation in MLIR");
assert!(has_break, "Expected break operation for while loop exit");
println!("\n✓ while loop compiled to cuda_tile.loop with break!");
});
}
#[test]
fn compile_loop_test() -> () {
common::with_test_stack(|| {
let modules =
CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"control_flow_ops_module",
"infinite_loop_test_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler
.compile()
.expect("Failed.")
.to_string();
println!("\n=== LOOP TEST MLIR ===\n{}", module_op_str);
let has_loop = module_op_str.contains("cuda_tile.loop") || module_op_str.contains(" loop ");
let has_break = module_op_str.contains("break ");
assert!(has_loop, "Expected cuda_tile.loop operation in MLIR");
assert!(has_break, "Expected break operation for loop exit");
println!("\n✓ loop expression compiled to cuda_tile.loop with break!");
});
}
#[test]
fn compile_step_by_test() -> () {
common::with_test_stack(|| {
let modules =
CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"control_flow_ops_module",
"step_by_test_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler
.compile()
.expect("Failed.")
.to_string();
println!("\n=== STEP_BY TEST MLIR ===\n{}", module_op_str);
assert!(
module_op_str.contains(" for "),
"Expected for loop in MLIR output"
);
assert!(
module_op_str.contains(", step %"),
"Expected step_by(10) to compile to a for-loop with step"
);
});
}
#[test]
fn compile_assume_test() -> () {
common::with_test_stack(|| {
let modules =
CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"control_flow_ops_module",
"assume_test_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler
.compile()
.expect("Failed.")
.to_string();
println!("\n=== ASSUME MLIR ===\n{}", module_op_str);
assert!(
module_op_str.contains("assume"),
"Expected assume operation in MLIR output"
);
assert!(
module_op_str.contains("bounded"),
"Expected bounded predicate on assume operation"
);
println!(
"\n✓ assume operation verified (compiler optimization hint with bounded predicate)"
);
});
}
#[test]
fn compile_assume_non_negative_test() -> () {
common::with_test_stack(|| {
let modules =
CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"control_flow_ops_module",
"assume_non_negative_test_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler
.compile()
.expect("Failed.")
.to_string();
println!("\n=== ASSUME_NON_NEGATIVE MLIR ===\n{}", module_op_str);
assert!(
module_op_str.contains("assume"),
"Expected assume operation in MLIR output"
);
assert!(
module_op_str.contains("bounded<0, ?>"),
"Expected bounded<0, ?> predicate on assume operation"
);
println!("\n✓ assume_non_negative operation verified with bounded<0, ?>");
});
}
#[test]
fn compile_assume_div_by_test() -> () {
common::with_test_stack(|| {
let modules =
CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"control_flow_ops_module",
"assume_div_by_test_kernel",
&[128.to_string()],
&[("output", &[1])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler
.compile()
.expect("Failed.")
.to_string();
println!("\n=== ASSUME_DIV_BY MLIR ===\n{}", module_op_str);
assert!(
module_op_str.contains("assume"),
"Expected assume operation in MLIR output"
);
assert!(
module_op_str.contains("div_by<16>"),
"Expected div_by<16> predicate on assume operation"
);
println!("\n✓ assume_div_by operation verified with div_by<16>");
});
}
#[test]
fn compile_assume_same_elements_test() -> () {
common::with_test_stack(|| {
let modules =
CUDATileModules::new(_module_asts()).expect("Failed to create CUDATileModules");
let gpu_name = get_gpu_name(0);
let compiler = CUDATileFunctionCompiler::new(
&modules,
"control_flow_ops_module",
"assume_same_elements_test_kernel",
&[4.to_string(), 8.to_string()],
&[("output", &[2, 2])],
&[],
&[],
None,
gpu_name,
&CompileOptions::default(),
)
.expect("Failed.");
let module_op_str = compiler
.compile()
.expect("Failed.")
.to_string();
println!("\n=== ASSUME_SAME_ELEMENTS MLIR ===\n{}", module_op_str);
assert!(
module_op_str.contains("assume"),
"Expected assume operation in MLIR output"
);
assert!(
module_op_str.contains("same_elements<[2, 4]>"),
"Expected same_elements<[2, 4]> predicate on assume operation"
);
println!("\n✓ assume_same_elements operation verified with same_elements<[2, 4]>");
});
}
#[test]
fn if_for_carry_propagates_mutation() {
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
if_for_carry_kernel((&mut output).partition([128]), 1i32)
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0] - 10.0).abs() < 1e-3,
"Expected 10.0 (for loop ran 10 times), got {}",
host[0]
);
});
}
#[test]
fn if_for_carry_const_propagates_mutation() {
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
if_for_carry_const_kernel((&mut output).partition([128]))
.generics(vec![
"128".to_string(), "1".to_string(), "10".to_string(), ])
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0] - 10.0).abs() < 1e-3,
"Expected 10.0 (const FLAG=1, N=10), got {}",
host[0]
);
});
}
#[test]
fn if_else_tile_expr_returns_value() {
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
if_else_tile_expr_kernel((&mut output).partition([128]), 1i32)
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0] - 1.0).abs() < 1e-3,
"flag=1: expected 1.0, got {}",
host[0]
);
});
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
if_else_tile_expr_kernel((&mut output).partition([128]), 0i32)
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0] - 2.0).abs() < 1e-3,
"flag=0: expected 2.0, got {}",
host[0]
);
});
}
#[test]
fn nested_for_in_if_dynamic() {
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
nested_for_in_if_kernel((&mut output).partition([128]), 1i32)
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0] - 10.0).abs() < 1e-3,
"flag=1: expected 10.0, got {}",
host[0]
);
});
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
nested_for_in_if_kernel((&mut output).partition([128]), 0i32)
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0] - 1.0).abs() < 1e-3,
"flag=0: expected 1.0, got {}",
host[0]
);
});
}
#[test]
fn nested_for_in_if_const_folded() {
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
nested_for_in_if_const_kernel((&mut output).partition([128]))
.generics(vec!["128".into(), "1".into()])
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0] - 10.0).abs() < 1e-3,
"FLAG=1: expected 10.0, got {}",
host[0]
);
});
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
nested_for_in_if_const_kernel((&mut output).partition([128]))
.generics(vec!["128".into(), "0".into()])
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0] - 1.0).abs() < 1e-3,
"FLAG=0: expected 1.0, got {}",
host[0]
);
});
}
#[test]
fn nested_if_for_if_deep() {
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
nested_if_for_if_kernel((&mut output).partition([128]), 1i32)
.generics(vec!["128".into(), "1".into()])
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0] - 12.0).abs() < 1e-3,
"OUTER=1, inner=1: expected 12.0, got {}",
host[0]
);
});
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
nested_if_for_if_kernel((&mut output).partition([128]), 0i32)
.generics(vec!["128".into(), "1".into()])
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0]).abs() < 1e-3,
"OUTER=1, inner=0: expected 0.0, got {}",
host[0]
);
});
common::with_test_stack(|| {
let mut output = cutile::api::zeros::<f32>(&[128]).sync().expect("alloc");
nested_if_for_if_kernel((&mut output).partition([128]), 1i32)
.generics(vec!["128".into(), "0".into()])
.sync()
.expect("kernel");
let host: Vec<f32> = output.dup().to_host_vec().sync().expect("to_host");
assert!(
(host[0]).abs() < 1e-3,
"OUTER=0: expected 0.0, got {}",
host[0]
);
});
}