use metaltile_core::{
dtype::DType,
ir::{
Block,
BlockId,
CoopTileAccMode,
CoopTileScope,
Kernel,
KernelMode,
Op,
Param,
ParamKind,
},
shape::{Dim, Shape},
};
use rustc_hash::FxHashMap;
/// Builds the IR for the `mt_mpp_matmul_smoke` kernel.
///
/// The kernel declares three tensor parameters — `A` (F16, 16x32), `B` (F16, 32x16) and the
/// output `C` (F32, 16x16) — plus one 16x16 return shape. Its single block (id 0) holds six
/// cooperative-tile ops that all refer to a tile named `"gemm"`, in this order:
/// setup, zero, load A, load B, run, store C.
///
/// The same block is installed both as the kernel's `body` and as the only entry of `blocks`.
pub fn kernel_ir() -> Kernel {
    // Every cooperative-tile op below addresses the tile through this one name.
    const TILE: &str = "gemm";

    let mut kernel = Kernel::new("mt_mpp_matmul_smoke");
    kernel.mode = KernelMode::Elementwise;

    // Parameters are pushed in declaration order: A, B, then the output C.
    let tensor_params = [
        Param {
            name: "A".into(),
            dtype: DType::F16,
            shape: Shape::new([Dim::Known(16), Dim::Known(32)]),
            is_output: false,
            kind: ParamKind::Tensor,
        },
        Param {
            name: "B".into(),
            dtype: DType::F16,
            shape: Shape::new([Dim::Known(32), Dim::Known(16)]),
            is_output: false,
            kind: ParamKind::Tensor,
        },
        Param {
            name: "C".into(),
            dtype: DType::F32,
            shape: Shape::new([Dim::Known(16), Dim::Known(16)]),
            is_output: true,
            kind: ParamKind::Tensor,
        },
    ];
    for param in tensor_params {
        kernel.params.push(param);
    }
    kernel
        .return_shapes
        .push(Shape::new([Dim::Known(16), Dim::Known(16)]));

    // The op sequence of the entry block; none of these ops produces a result value.
    // NOTE(review): the `ei`/`eo` values are reproduced as-is; their exact meaning is defined
    // by the `Op` variants elsewhere — confirm there before changing them.
    let tile_ops = [
        Op::CoopTileSetup {
            name: TILE.into(),
            m: 16,
            n: 16,
            k: 32,
            ta: false,
            tb: false,
            tc: false,
            acc_mode: CoopTileAccMode::Overwrite,
            exec_scope: CoopTileScope::SimdGroup,
            act_dtype: DType::F16,
            acc_dtype: DType::F32,
            direct_inputs: false,
            a_is_tg: false,
            a_ei: 0,
            a_eo: 0,
            b_is_tg: false,
            b_ei: 0,
            b_eo: 0,
        },
        Op::CoopTileZero { name: TILE.into() },
        Op::CoopTileLoadA {
            name: TILE.into(),
            ptr_name: "A".into(),
            ptr_offset: None,
            is_tg: false,
            dtype: DType::F16,
            ei: 32,
            eo: 16,
            direct: false,
        },
        Op::CoopTileLoadB {
            name: TILE.into(),
            ptr_name: "B".into(),
            ptr_offset: None,
            is_tg: false,
            dtype: DType::F16,
            ei: 16,
            eo: 32,
            direct: false,
        },
        Op::CoopTileRun {
            name: TILE.into(),
            direct: false,
        },
        Op::CoopTileStoreC {
            name: TILE.into(),
            ptr_name: "C".into(),
            ptr_offset: None,
            is_tg: false,
            dtype: DType::F32,
            ei: 16,
            eo: 16,
        },
    ];

    let mut entry = Block::new(BlockId::new(0));
    for op in tile_ops {
        entry.push_op_no_result(op);
    }

    // Register a copy of the entry block under id 0, then move the block itself into `body`.
    let mut block_table = FxHashMap::default();
    block_table.insert(BlockId::new(0), entry.clone());
    kernel.blocks = block_table;
    kernel.body = entry;

    kernel
}
#[cfg(test)]
mod tests {
    use metaltile_codegen::msl::MslGenerator;

    use super::*;

    /// The kernel carries the expected name, the A/B/C parameters in order (with C marked as
    /// output), and six ops that start with the tile setup and end with the store to C.
    #[test]
    fn kernel_ir_constructs_and_has_three_params() {
        let kernel = kernel_ir();
        assert_eq!(kernel.name, "mt_mpp_matmul_smoke");

        let expected_names = ["A", "B", "C"];
        assert_eq!(kernel.params.len(), expected_names.len());
        for (idx, expected) in expected_names.iter().enumerate() {
            assert_eq!(kernel.params[idx].name, *expected);
        }
        assert!(kernel.params[2].is_output);

        let ops = &kernel.body.ops;
        assert_eq!(ops.len(), 6);
        assert!(matches!(&ops[0], Op::CoopTileSetup { .. }));
        assert!(matches!(&ops[5], Op::CoopTileStoreC { .. }));
    }

    /// Generated MSL must reference the MPP header, the matmul descriptor, and the kernel entry
    /// point by name.
    #[test]
    fn codegen_emits_mpp_include() {
        let kernel = kernel_ir();
        let msl = MslGenerator::default().generate(&kernel).expect("codegen");
        let mentions = |needle: &str| msl.contains(needle);

        assert!(
            mentions("MetalPerformancePrimitives/MetalPerformancePrimitives.h"),
            "MPP include missing from generated MSL:\n{msl}"
        );
        assert!(mentions("mpp::tensor_ops::matmul2d_descriptor"));
        assert!(mentions("kernel void mt_mpp_matmul_smoke"));
    }

    /// Prints the generated MSL between BEGIN/END markers; it makes no assertions beyond
    /// code generation succeeding.
    #[test]
    fn dump_generated_msl() {
        let kernel = kernel_ir();
        let msl = MslGenerator::default().generate(&kernel).expect("codegen");
        println!("===== BEGIN MSL =====\n{}\n===== END MSL =====", msl);
    }
}