use cutile;
use cutile_compiler::compiler::utils::CompileOptions;
use cutile_compiler::compiler::{CUDATileFunctionCompiler, CUDATileModules};
mod common;
// Kernel module used by the tests in this file. Every entry point reads at
// least one input through `load_tile_like(src, like)`, passing the mutable
// output tensor as the `like` argument.
//
// NOTE(review): plain `//` comments are used here on purpose so that no
// `#[doc]` attributes are added to the token stream handed to the
// `#[cutile::module]` / `#[cutile::entry()]` macros.
#[cutile::module]
mod load_tile_like_examples_module {
use cutile::core::*;
// Element-wise add: stores `x + y` into `z`.
//
// `z` carries the const shape parameter `S` (rank 1), while `x` and `y` are
// declared with shape `{ [-1] }` — presumably `-1` marks a dimension that is
// not fixed at compile time; confirm against the cutile docs.
#[cutile::entry()]
fn add_refs_like<const S: [i32; 1]>(
z: &mut Tensor<f32, S>,
x: &Tensor<f32, { [-1] }>,
y: &Tensor<f32, { [-1] }>,
) {
// Both inputs are loaded with `z` as the reference tensor.
let tile_x = load_tile_like(x, z);
let tile_y = load_tile_like(y, z);
z.store(tile_x + tile_y);
}
// SAXPY-style update on rank-2 `f32` tensors: stores `a * x + y` back into `y`.
//
// `y` carries the const shape parameter `S`; `x` is declared with shape
// `{ [-1, -1] }`.
#[cutile::entry()]
fn saxpy_like<const S: [i32; 2]>(
y: &mut Tensor<f32, S>,
a: f32,
x: &Tensor<f32, { [-1, -1] }>,
) {
// The scalar is broadcast to `y.shape()` before the multiply.
let tile_a = a.broadcast(y.shape());
// `x` is loaded with `y` as the reference tensor; `y` itself is loaded directly.
let tile_x = load_tile_like(x, y);
let tile_y = y.load();
y.store(tile_a * tile_x + tile_y);
}
// Same computation as `saxpy_like`, but generic over the element type
// `T: ElementType` instead of being fixed to `f32`.
#[cutile::entry()]
fn generic_saxpy_like<T: ElementType, const S: [i32; 2]>(
y: &mut Tensor<T, S>,
a: T,
x: &Tensor<T, { [-1, -1] }>,
) {
let tile_a = a.broadcast(y.shape());
let tile_x = load_tile_like(x, y);
let tile_y = y.load();
y.store(tile_a * tile_x + tile_y);
}
}
use load_tile_like_examples_module::__module_ast_self;
/// Compiles one entry point of `load_tile_like_examples_module` and returns
/// the resulting MLIR rendered as a `String`.
///
/// * `kernel` - name of the entry function to compile.
/// * `generics` - generic arguments, already rendered as strings, forwarded
///   unchanged to `CUDATileFunctionCompiler::new`.
/// * `strides` - `(name, values)` pairs forwarded unchanged to
///   `CUDATileFunctionCompiler::new`; the tests in this file key them by
///   kernel parameter name.
///
/// The MLIR is also printed to stdout under a `=== MLIR for <kernel> ===`
/// header.
///
/// # Panics
///
/// Panics if building the module set, constructing the compiler, or the
/// compilation itself fails.
fn compile(kernel: &str, generics: &[String], strides: &[(&str, &[i32])]) -> String {
    let module_set = CUDATileModules::from_kernel(__module_ast_self())
        .expect("Failed to create CUDATileModules");

    // Fixed inputs shared by every test in this file.
    let module_name = "load_tile_like_examples_module";
    let gpu_target = "sm_120".to_string();
    let options = CompileOptions::default();

    let function_compiler = CUDATileFunctionCompiler::new(
        &module_set,
        module_name,
        kernel,
        generics,
        strides,
        &[],
        &[],
        None,
        gpu_target,
        &options,
    )
    .expect("Failed to create compiler");

    let compiled = function_compiler.compile().expect("Failed to compile");
    let mlir = compiled.to_string();
    println!("=== MLIR for {kernel} ===\n{mlir}");
    mlir
}
/// Compiles `add_refs_like` with a single generic argument `4` and unit
/// strides for `z`, `x` and `y`, then checks that the MLIR mentions
/// `load_view_tko` exactly twice.
#[test]
fn compiles_add_refs_style_1d_inference() {
    common::with_test_stack(|| {
        let generics = [4.to_string()];
        let strides: &[(&str, &[i32])] = &[("z", &[1]), ("x", &[1]), ("y", &[1])];

        let mlir = compile("add_refs_like", &generics, strides);
        let load_count = mlir.matches("load_view_tko").count();

        assert!(mlir.contains("load_view_tko"));
        assert_eq!(load_count, 2);
    });
}
/// Compiles `saxpy_like` with generic arguments `2` and `4` and strides
/// `[4, 1]` for both `y` and `x`, then checks that the MLIR contains a
/// `load_view_tko` and a `tile<2x4xf32>` type.
#[test]
fn compiles_saxpy_style_2d_inference() {
    common::with_test_stack(|| {
        let generics = [2.to_string(), 4.to_string()];
        let strides: &[(&str, &[i32])] = &[("y", &[4, 1]), ("x", &[4, 1])];

        let mlir = compile("saxpy_like", &generics, strides);

        assert!(mlir.contains("load_view_tko"));
        assert!(mlir.contains("tile<2x4xf32>"));
    });
}
/// Compiles `generic_saxpy_like` with generic arguments `f32`, `2` and `4`
/// and strides `[4, 1]` for both `y` and `x`, then checks that the MLIR
/// contains a `load_view_tko` and a `tile<2x4xf32>` type.
#[test]
fn compiles_generic_saxpy_style_2d_inference() {
    common::with_test_stack(|| {
        let generics = ["f32".to_string(), 2.to_string(), 4.to_string()];
        let strides: &[(&str, &[i32])] = &[("y", &[4, 1]), ("x", &[4, 1])];

        let mlir = compile("generic_saxpy_like", &generics, strides);

        assert!(mlir.contains("load_view_tko"));
        assert!(mlir.contains("tile<2x4xf32>"));
    });
}