const OP_TYPE_COOPERATIVE_MATRIX_KHR: u32 = 4456;
const OP_COOPERATIVE_MATRIX_LOAD_KHR: u32 = 4457;
const OP_COOPERATIVE_MATRIX_STORE_KHR: u32 = 4458;
const OP_COOPERATIVE_MATRIX_MUL_ADD_KHR: u32 = 4459;
const CAPABILITY_COOPERATIVE_MATRIX_KHR: u32 = 6022;
const CAPABILITY_SHADER: u32 = 1;
const CAPABILITY_FLOAT16: u32 = 9;
const ADDRESSING_MODEL_LOGICAL: u32 = 0;
const MEMORY_MODEL_GLSL450: u32 = 1;
const EXECUTION_MODEL_GLCOMPUTE: u32 = 5;
const EXECUTION_MODE_LOCAL_SIZE: u32 = 17;
const STORAGE_CLASS_STORAGE_BUFFER: u32 = 12;
const STORAGE_CLASS_INPUT: u32 = 1;
const DECORATION_DESCRIPTOR_SET: u32 = 34;
const DECORATION_BINDING: u32 = 33;
const DECORATION_BLOCK: u32 = 2;
const DECORATION_BUILTIN: u32 = 11;
const DECORATION_NON_WRITABLE: u32 = 24;
const BUILTIN_WORKGROUP_ID: u32 = 26;
const OP_EXTENSION: u32 = 10;
const OP_CAPABILITY: u32 = 17;
const OP_MEMORY_MODEL: u32 = 14;
const OP_ENTRY_POINT: u32 = 15;
const OP_EXECUTION_MODE: u32 = 16;
const OP_DECORATE: u32 = 71;
const OP_MEMBER_DECORATE: u32 = 72;
const OP_TYPE_VOID: u32 = 19;
const OP_TYPE_INT: u32 = 21;
const OP_TYPE_FLOAT: u32 = 22;
const OP_TYPE_POINTER: u32 = 32;
const OP_TYPE_FUNCTION: u32 = 33;
const OP_TYPE_STRUCT: u32 = 30;
const OP_TYPE_RUNTIME_ARRAY: u32 = 29;
const OP_CONSTANT: u32 = 43;
const OP_FUNCTION: u32 = 54;
const OP_FUNCTION_END: u32 = 56;
const OP_VARIABLE: u32 = 59;
const OP_LOAD: u32 = 61;
const OP_ACCESS_CHAIN: u32 = 65;
const OP_IN_BOUNDS_ACCESS_CHAIN: u32 = 66;
const OP_LABEL: u32 = 248;
const OP_RETURN: u32 = 253;
const OP_COMPOSITE_EXTRACT: u32 = 81;
const OP_I_MUL: u32 = 132;
const OP_I_ADD: u32 = 128;
const SCOPE_SUBGROUP: u32 = 3;
const MATRIX_USE_A: u32 = 0;
const MATRIX_USE_B: u32 = 1;
const MATRIX_USE_ACCUMULATOR: u32 = 2;
const MATRIX_LAYOUT_ROW_MAJOR: u32 = 0;
const COOPERATIVE_MATRIX_OPERANDS_NONE: u32 = 0;
const SPIRV_MAGIC: u32 = 0x07230203;
const SPIRV_VERSION_1_6: u32 = 0x0001_0600;
const SPIRV_GENERATOR: u32 = 0x000D_0003;
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct XmxTileConfig {
pub m: u32,
pub n: u32,
pub k: u32,
}
impl XmxTileConfig {
pub const XE_HPC_FP16: Self = Self { m: 8, n: 16, k: 16 };
pub const XE_DEFAULT: Self = Self { m: 8, n: 16, k: 16 };
pub fn accum_elements(&self) -> u32 {
self.m * self.n
}
}
impl Default for XmxTileConfig {
fn default() -> Self {
Self::XE_HPC_FP16
}
}
struct XmxSpvModule {
words: Vec<u32>,
id_bound: u32,
}
impl XmxSpvModule {
fn new() -> Self {
let words = vec![SPIRV_MAGIC, SPIRV_VERSION_1_6, SPIRV_GENERATOR, 0, 0];
Self { words, id_bound: 1 }
}
fn alloc_id(&mut self) -> u32 {
let id = self.id_bound;
self.id_bound += 1;
id
}
fn emit(&mut self, opcode: u32, operands: &[u32]) {
let word_count = (1 + operands.len()) as u32;
self.words.push((word_count << 16) | opcode);
self.words.extend_from_slice(operands);
}
fn string_words(s: &str) -> Vec<u32> {
let bytes = s.as_bytes();
let padded_len = (bytes.len() + 4) & !3;
let mut out = vec![0u32; padded_len / 4];
for (i, &b) in bytes.iter().enumerate() {
out[i / 4] |= (b as u32) << ((i % 4) * 8);
}
out
}
fn finalize(mut self) -> Vec<u32> {
self.words[3] = self.id_bound;
self.words
}
fn emit_capability(&mut self, cap: u32) {
self.emit(OP_CAPABILITY, &[cap]);
}
fn emit_extension(&mut self, name: &str) {
let mut ops = Self::string_words(name);
let word_count = (1 + ops.len()) as u32;
self.words.push((word_count << 16) | OP_EXTENSION);
self.words.append(&mut ops);
}
fn emit_memory_model(&mut self, addr: u32, model: u32) {
self.emit(OP_MEMORY_MODEL, &[addr, model]);
}
fn emit_entry_point(&mut self, model: u32, func_id: u32, name: &str, interfaces: &[u32]) {
let mut ops = vec![model, func_id];
ops.extend(Self::string_words(name));
ops.extend_from_slice(interfaces);
self.emit(OP_ENTRY_POINT, &ops);
}
fn emit_execution_mode_local_size(&mut self, func_id: u32, x: u32, y: u32, z: u32) {
self.emit(
OP_EXECUTION_MODE,
&[func_id, EXECUTION_MODE_LOCAL_SIZE, x, y, z],
);
}
fn emit_decorate(&mut self, target: u32, decoration: u32, extra: &[u32]) {
let mut ops = vec![target, decoration];
ops.extend_from_slice(extra);
self.emit(OP_DECORATE, &ops);
}
fn emit_member_decorate(
&mut self,
struct_id: u32,
member: u32,
decoration: u32,
extra: &[u32],
) {
let mut ops = vec![struct_id, member, decoration];
ops.extend_from_slice(extra);
self.emit(OP_MEMBER_DECORATE, &ops);
}
fn emit_type_void(&mut self, id: u32) {
self.emit(OP_TYPE_VOID, &[id]);
}
fn emit_type_int(&mut self, id: u32, width: u32, sign: u32) {
self.emit(OP_TYPE_INT, &[id, width, sign]);
}
fn emit_type_float(&mut self, id: u32, width: u32) {
self.emit(OP_TYPE_FLOAT, &[id, width]);
}
fn emit_type_ptr(&mut self, id: u32, sc: u32, pointee: u32) {
self.emit(OP_TYPE_POINTER, &[id, sc, pointee]);
}
fn emit_type_fn(&mut self, id: u32, ret: u32, params: &[u32]) {
let mut ops = vec![id, ret];
ops.extend_from_slice(params);
self.emit(OP_TYPE_FUNCTION, &ops);
}
fn emit_type_struct(&mut self, id: u32, members: &[u32]) {
let mut ops = vec![id];
ops.extend_from_slice(members);
self.emit(OP_TYPE_STRUCT, &ops);
}
fn emit_type_runtime_array(&mut self, id: u32, elem: u32) {
self.emit(OP_TYPE_RUNTIME_ARRAY, &[id, elem]);
}
fn emit_const_u32(&mut self, ty: u32, id: u32, val: u32) {
self.emit(OP_CONSTANT, &[ty, id, val]);
}
fn emit_variable(&mut self, ty: u32, id: u32, sc: u32) {
self.emit(OP_VARIABLE, &[ty, id, sc]);
}
fn emit_load(&mut self, ty: u32, id: u32, ptr: u32) {
self.emit(OP_LOAD, &[ty, id, ptr]);
}
fn emit_label(&mut self, id: u32) {
self.emit(OP_LABEL, &[id]);
}
fn emit_return(&mut self) {
self.emit(OP_RETURN, &[]);
}
fn emit_function_end(&mut self) {
self.emit(OP_FUNCTION_END, &[]);
}
fn emit_function(&mut self, ret_ty: u32, id: u32, ctrl: u32, fn_ty: u32) {
self.emit(OP_FUNCTION, &[ret_ty, id, ctrl, fn_ty]);
}
fn emit_i_add(&mut self, ty: u32, id: u32, a: u32, b: u32) {
self.emit(OP_I_ADD, &[ty, id, a, b]);
}
fn emit_i_mul(&mut self, ty: u32, id: u32, a: u32, b: u32) {
self.emit(OP_I_MUL, &[ty, id, a, b]);
}
fn emit_composite_extract(&mut self, ty: u32, id: u32, composite: u32, idx: u32) {
self.emit(OP_COMPOSITE_EXTRACT, &[ty, id, composite, idx]);
}
fn emit_access_chain(&mut self, ty: u32, id: u32, base: u32, indices: &[u32]) {
let mut ops = vec![ty, id, base];
ops.extend_from_slice(indices);
self.emit(OP_ACCESS_CHAIN, &ops);
}
fn emit_in_bounds_access_chain(&mut self, ty: u32, id: u32, base: u32, indices: &[u32]) {
let mut ops = vec![ty, id, base];
ops.extend_from_slice(indices);
self.emit(OP_IN_BOUNDS_ACCESS_CHAIN, &ops);
}
fn emit_type_cooperative_matrix(
&mut self,
id: u32,
component_type: u32,
scope: u32,
rows: u32,
cols: u32,
matrix_use: u32,
) {
self.emit(
OP_TYPE_COOPERATIVE_MATRIX_KHR,
&[id, component_type, scope, rows, cols, matrix_use],
);
}
fn emit_coop_matrix_load(
&mut self,
result_ty: u32,
result: u32,
pointer: u32,
layout: u32,
stride: u32,
) {
self.emit(
OP_COOPERATIVE_MATRIX_LOAD_KHR,
&[
result_ty,
result,
pointer,
layout,
stride,
COOPERATIVE_MATRIX_OPERANDS_NONE,
],
);
}
fn emit_coop_matrix_store(&mut self, pointer: u32, object: u32, layout: u32, stride: u32) {
self.emit(
OP_COOPERATIVE_MATRIX_STORE_KHR,
&[
pointer,
object,
layout,
stride,
COOPERATIVE_MATRIX_OPERANDS_NONE,
],
);
}
fn emit_coop_matrix_muladd(
&mut self,
result_ty: u32,
result: u32,
a: u32,
b: u32,
c: u32,
operands: u32,
) {
self.emit(
OP_COOPERATIVE_MATRIX_MUL_ADD_KHR,
&[result_ty, result, a, b, c, operands],
);
}
}
pub fn gemm_xmx_spirv(tile: XmxTileConfig, wg_x: u32, wg_y: u32) -> Vec<u32> {
let mut m = XmxSpvModule::new();
m.emit_capability(CAPABILITY_SHADER);
m.emit_capability(CAPABILITY_COOPERATIVE_MATRIX_KHR);
m.emit_extension("SPV_KHR_cooperative_matrix");
m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
let ty_void = m.alloc_id();
let ty_u32 = m.alloc_id();
let ty_f32 = m.alloc_id();
let ty_rt_f32 = m.alloc_id(); let ty_rt_u32 = m.alloc_id(); let ty_sb_f32 = m.alloc_id(); let ty_sb_u32 = m.alloc_id(); let ty_ptr_sb_f32 = m.alloc_id();
let ty_ptr_sb_u32 = m.alloc_id();
let ty_ptr_f32_sb = m.alloc_id();
let ty_ptr_u32_sb = m.alloc_id();
let ty_cmat_a = m.alloc_id(); let ty_cmat_b = m.alloc_id(); let ty_cmat_c = m.alloc_id();
let ty_fn_void = m.alloc_id();
let ty_v3u32 = m.alloc_id();
let ty_ptr_in_v3u32 = m.alloc_id();
let c0 = m.alloc_id();
let c1 = m.alloc_id();
let c_tile_m = m.alloc_id();
let c_tile_n = m.alloc_id();
let c_tile_k = m.alloc_id();
let var_a = m.alloc_id();
let var_b = m.alloc_id();
let var_c = m.alloc_id();
let var_dim = m.alloc_id();
let var_wg_id = m.alloc_id();
let fn_main = m.alloc_id();
let lbl_entry = m.alloc_id();
m.emit_entry_point(
EXECUTION_MODEL_GLCOMPUTE,
fn_main,
"gemm_xmx_f32",
&[var_a, var_b, var_c, var_dim, var_wg_id],
);
m.emit_execution_mode_local_size(fn_main, wg_x, wg_y, 1);
m.emit_decorate(ty_rt_f32, 6 , &[4]);
m.emit_decorate(ty_rt_u32, 6 , &[4]);
m.emit_decorate(ty_sb_f32, DECORATION_BLOCK, &[]);
m.emit_decorate(ty_sb_u32, DECORATION_BLOCK, &[]);
m.emit_member_decorate(ty_sb_f32, 0, 35 , &[0]);
m.emit_member_decorate(ty_sb_u32, 0, 35 , &[0]);
m.emit_decorate(var_a, DECORATION_DESCRIPTOR_SET, &[0]);
m.emit_decorate(var_a, DECORATION_BINDING, &[0]);
m.emit_decorate(var_a, DECORATION_NON_WRITABLE, &[]);
m.emit_decorate(var_b, DECORATION_DESCRIPTOR_SET, &[0]);
m.emit_decorate(var_b, DECORATION_BINDING, &[1]);
m.emit_decorate(var_b, DECORATION_NON_WRITABLE, &[]);
m.emit_decorate(var_c, DECORATION_DESCRIPTOR_SET, &[0]);
m.emit_decorate(var_c, DECORATION_BINDING, &[2]);
m.emit_decorate(var_dim, DECORATION_DESCRIPTOR_SET, &[0]);
m.emit_decorate(var_dim, DECORATION_BINDING, &[3]);
m.emit_decorate(var_dim, DECORATION_NON_WRITABLE, &[]);
m.emit_decorate(var_wg_id, DECORATION_BUILTIN, &[BUILTIN_WORKGROUP_ID]);
m.emit_type_void(ty_void);
m.emit_type_int(ty_u32, 32, 0);
m.emit_type_float(ty_f32, 32);
m.emit_type_runtime_array(ty_rt_f32, ty_f32);
m.emit_type_runtime_array(ty_rt_u32, ty_u32);
m.emit_type_struct(ty_sb_f32, &[ty_rt_f32]);
m.emit_type_struct(ty_sb_u32, &[ty_rt_u32]);
m.emit_type_ptr(ty_ptr_sb_f32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_f32);
m.emit_type_ptr(ty_ptr_sb_u32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_u32);
m.emit_type_ptr(ty_ptr_f32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_f32);
m.emit_type_ptr(ty_ptr_u32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_u32);
m.emit_type_cooperative_matrix(
ty_cmat_a,
ty_f32,
SCOPE_SUBGROUP,
tile.m,
tile.k,
MATRIX_USE_A,
);
m.emit_type_cooperative_matrix(
ty_cmat_b,
ty_f32,
SCOPE_SUBGROUP,
tile.k,
tile.n,
MATRIX_USE_B,
);
m.emit_type_cooperative_matrix(
ty_cmat_c,
ty_f32,
SCOPE_SUBGROUP,
tile.m,
tile.n,
MATRIX_USE_ACCUMULATOR,
);
let ty_v3u32_actual = ty_v3u32;
m.emit(30 , &[ty_v3u32_actual, ty_u32, 3]);
m.emit_type_ptr(ty_ptr_in_v3u32, STORAGE_CLASS_INPUT, ty_v3u32_actual);
m.emit_type_fn(ty_fn_void, ty_void, &[]);
m.emit_const_u32(ty_u32, c0, 0);
m.emit_const_u32(ty_u32, c1, 1);
m.emit_const_u32(ty_u32, c_tile_m, tile.m);
m.emit_const_u32(ty_u32, c_tile_n, tile.n);
m.emit_const_u32(ty_u32, c_tile_k, tile.k);
m.emit_variable(ty_ptr_sb_f32, var_a, STORAGE_CLASS_STORAGE_BUFFER);
m.emit_variable(ty_ptr_sb_f32, var_b, STORAGE_CLASS_STORAGE_BUFFER);
m.emit_variable(ty_ptr_sb_f32, var_c, STORAGE_CLASS_STORAGE_BUFFER);
m.emit_variable(ty_ptr_sb_u32, var_dim, STORAGE_CLASS_STORAGE_BUFFER);
m.emit_variable(ty_ptr_in_v3u32, var_wg_id, STORAGE_CLASS_INPUT);
m.emit_function(ty_void, fn_main, 0, ty_fn_void);
m.emit_label(lbl_entry);
let wg_id = m.alloc_id();
m.emit_load(ty_v3u32_actual, wg_id, var_wg_id);
let wg_col = m.alloc_id();
let wg_row = m.alloc_id();
m.emit_composite_extract(ty_u32, wg_col, wg_id, 0);
m.emit_composite_extract(ty_u32, wg_row, wg_id, 1);
let ptr_m = m.alloc_id();
let ptr_n = m.alloc_id();
let ptr_k = m.alloc_id();
let dim_m = m.alloc_id();
let dim_n = m.alloc_id();
let dim_k = m.alloc_id();
m.emit_access_chain(ty_ptr_u32_sb, ptr_m, var_dim, &[c0, c0]);
m.emit_access_chain(ty_ptr_u32_sb, ptr_n, var_dim, &[c0, c1]);
let c2 = m.alloc_id();
m.emit_const_u32(ty_u32, c2, 2);
m.emit_access_chain(ty_ptr_u32_sb, ptr_k, var_dim, &[c0, c2]);
m.emit_load(ty_u32, dim_m, ptr_m);
m.emit_load(ty_u32, dim_n, ptr_n);
m.emit_load(ty_u32, dim_k, ptr_k);
let row_base = m.alloc_id();
let col_base = m.alloc_id();
m.emit_i_mul(ty_u32, row_base, wg_row, c_tile_m);
m.emit_i_mul(ty_u32, col_base, wg_col, c_tile_n);
let c_row_stride = dim_n; let c_base_flat = m.alloc_id();
let c_base_tmp = m.alloc_id();
m.emit_i_mul(ty_u32, c_base_tmp, row_base, c_row_stride);
m.emit_i_add(ty_u32, c_base_flat, c_base_tmp, col_base);
let ptr_c_tile = m.alloc_id();
m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_c_tile, var_c, &[c0, c_base_flat]);
let mat_c_init = m.alloc_id();
m.emit_coop_matrix_load(
ty_cmat_c,
mat_c_init,
ptr_c_tile,
MATRIX_LAYOUT_ROW_MAJOR,
c_row_stride,
);
let mat_acc_after = {
let a_base_flat = m.alloc_id();
m.emit_i_mul(ty_u32, a_base_flat, row_base, dim_k);
let ptr_a_tile = m.alloc_id();
m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_a_tile, var_a, &[c0, a_base_flat]);
let mat_a = m.alloc_id();
m.emit_coop_matrix_load(ty_cmat_a, mat_a, ptr_a_tile, MATRIX_LAYOUT_ROW_MAJOR, dim_k);
let ptr_b_tile = m.alloc_id();
m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_b_tile, var_b, &[c0, col_base]);
let mat_b = m.alloc_id();
m.emit_coop_matrix_load(ty_cmat_b, mat_b, ptr_b_tile, MATRIX_LAYOUT_ROW_MAJOR, dim_n);
let mat_tmp = m.alloc_id();
m.emit_coop_matrix_muladd(
ty_cmat_c,
mat_tmp,
mat_a,
mat_b,
mat_c_init,
COOPERATIVE_MATRIX_OPERANDS_NONE,
);
mat_tmp
};
m.emit_coop_matrix_store(
ptr_c_tile,
mat_acc_after,
MATRIX_LAYOUT_ROW_MAJOR,
c_row_stride,
);
m.emit_return();
m.emit_function_end();
m.finalize()
}
pub fn gemm_xmx_f16_spirv(tile: XmxTileConfig, wg_x: u32, wg_y: u32) -> Vec<u32> {
let mut m = XmxSpvModule::new();
m.emit_capability(CAPABILITY_SHADER);
m.emit_capability(CAPABILITY_FLOAT16);
m.emit_capability(CAPABILITY_COOPERATIVE_MATRIX_KHR);
m.emit_extension("SPV_KHR_cooperative_matrix");
m.emit_memory_model(ADDRESSING_MODEL_LOGICAL, MEMORY_MODEL_GLSL450);
let ty_void = m.alloc_id();
let ty_u32 = m.alloc_id();
let ty_f16 = m.alloc_id();
let ty_f32 = m.alloc_id();
let ty_rt_f16 = m.alloc_id();
let ty_rt_f32 = m.alloc_id();
let ty_rt_u32 = m.alloc_id();
let ty_sb_f16 = m.alloc_id();
let ty_sb_f32 = m.alloc_id();
let ty_sb_u32 = m.alloc_id();
let ty_ptr_sb_f16 = m.alloc_id();
let ty_ptr_sb_f32 = m.alloc_id();
let ty_ptr_sb_u32 = m.alloc_id();
let ty_ptr_f16_sb = m.alloc_id();
let ty_ptr_f32_sb = m.alloc_id();
let ty_ptr_u32_sb = m.alloc_id();
let ty_cmat_a = m.alloc_id();
let ty_cmat_b = m.alloc_id();
let ty_cmat_c = m.alloc_id();
let ty_v3u32 = m.alloc_id();
let ty_ptr_in_v3u32 = m.alloc_id();
let ty_fn_void = m.alloc_id();
let var_a = m.alloc_id();
let var_b = m.alloc_id();
let var_c = m.alloc_id();
let var_dim = m.alloc_id();
let var_wg = m.alloc_id();
let fn_main = m.alloc_id();
let lbl = m.alloc_id();
m.emit_entry_point(
EXECUTION_MODEL_GLCOMPUTE,
fn_main,
"gemm_xmx_f16",
&[var_a, var_b, var_c, var_dim, var_wg],
);
m.emit_execution_mode_local_size(fn_main, wg_x, wg_y, 1);
m.emit_decorate(ty_rt_f16, 6, &[2]); m.emit_decorate(ty_rt_f32, 6, &[4]);
m.emit_decorate(ty_rt_u32, 6, &[4]);
m.emit_decorate(ty_sb_f16, DECORATION_BLOCK, &[]);
m.emit_decorate(ty_sb_f32, DECORATION_BLOCK, &[]);
m.emit_decorate(ty_sb_u32, DECORATION_BLOCK, &[]);
m.emit_member_decorate(ty_sb_f16, 0, 35, &[0]);
m.emit_member_decorate(ty_sb_f32, 0, 35, &[0]);
m.emit_member_decorate(ty_sb_u32, 0, 35, &[0]);
for (var, set, binding, writable) in [
(var_a, 0u32, 0u32, false),
(var_b, 0, 1, false),
(var_c, 0, 2, true),
(var_dim, 0, 3, false),
] {
m.emit_decorate(var, DECORATION_DESCRIPTOR_SET, &[set]);
m.emit_decorate(var, DECORATION_BINDING, &[binding]);
if !writable {
m.emit_decorate(var, DECORATION_NON_WRITABLE, &[]);
}
}
m.emit_decorate(var_wg, DECORATION_BUILTIN, &[BUILTIN_WORKGROUP_ID]);
m.emit_type_void(ty_void);
m.emit_type_int(ty_u32, 32, 0);
m.emit_type_float(ty_f16, 16);
m.emit_type_float(ty_f32, 32);
m.emit_type_runtime_array(ty_rt_f16, ty_f16);
m.emit_type_runtime_array(ty_rt_f32, ty_f32);
m.emit_type_runtime_array(ty_rt_u32, ty_u32);
m.emit_type_struct(ty_sb_f16, &[ty_rt_f16]);
m.emit_type_struct(ty_sb_f32, &[ty_rt_f32]);
m.emit_type_struct(ty_sb_u32, &[ty_rt_u32]);
m.emit_type_ptr(ty_ptr_sb_f16, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_f16);
m.emit_type_ptr(ty_ptr_sb_f32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_f32);
m.emit_type_ptr(ty_ptr_sb_u32, STORAGE_CLASS_STORAGE_BUFFER, ty_sb_u32);
m.emit_type_ptr(ty_ptr_f16_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_f16);
m.emit_type_ptr(ty_ptr_f32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_f32);
m.emit_type_ptr(ty_ptr_u32_sb, STORAGE_CLASS_STORAGE_BUFFER, ty_u32);
m.emit_type_cooperative_matrix(
ty_cmat_a,
ty_f16,
SCOPE_SUBGROUP,
tile.m,
tile.k,
MATRIX_USE_A,
);
m.emit_type_cooperative_matrix(
ty_cmat_b,
ty_f16,
SCOPE_SUBGROUP,
tile.k,
tile.n,
MATRIX_USE_B,
);
m.emit_type_cooperative_matrix(
ty_cmat_c,
ty_f32,
SCOPE_SUBGROUP,
tile.m,
tile.n,
MATRIX_USE_ACCUMULATOR,
);
m.emit(30, &[ty_v3u32, ty_u32, 3]); m.emit_type_ptr(ty_ptr_in_v3u32, STORAGE_CLASS_INPUT, ty_v3u32);
m.emit_type_fn(ty_fn_void, ty_void, &[]);
let c0 = m.alloc_id();
m.emit_const_u32(ty_u32, c0, 0);
let c1 = m.alloc_id();
m.emit_const_u32(ty_u32, c1, 1);
let c2 = m.alloc_id();
m.emit_const_u32(ty_u32, c2, 2);
let c_tm = m.alloc_id();
m.emit_const_u32(ty_u32, c_tm, tile.m);
let c_tn = m.alloc_id();
m.emit_const_u32(ty_u32, c_tn, tile.n);
let c_tk = m.alloc_id();
m.emit_const_u32(ty_u32, c_tk, tile.k);
m.emit_variable(ty_ptr_sb_f16, var_a, STORAGE_CLASS_STORAGE_BUFFER);
m.emit_variable(ty_ptr_sb_f16, var_b, STORAGE_CLASS_STORAGE_BUFFER);
m.emit_variable(ty_ptr_sb_f32, var_c, STORAGE_CLASS_STORAGE_BUFFER);
m.emit_variable(ty_ptr_sb_u32, var_dim, STORAGE_CLASS_STORAGE_BUFFER);
m.emit_variable(ty_ptr_in_v3u32, var_wg, STORAGE_CLASS_INPUT);
m.emit_function(ty_void, fn_main, 0, ty_fn_void);
m.emit_label(lbl);
let wg_id = m.alloc_id();
m.emit_load(ty_v3u32, wg_id, var_wg);
let wg_col = m.alloc_id();
m.emit_composite_extract(ty_u32, wg_col, wg_id, 0);
let wg_row = m.alloc_id();
m.emit_composite_extract(ty_u32, wg_row, wg_id, 1);
let ptr_m = m.alloc_id();
m.emit_access_chain(ty_ptr_u32_sb, ptr_m, var_dim, &[c0, c0]);
let ptr_n = m.alloc_id();
m.emit_access_chain(ty_ptr_u32_sb, ptr_n, var_dim, &[c0, c1]);
let ptr_k = m.alloc_id();
m.emit_access_chain(ty_ptr_u32_sb, ptr_k, var_dim, &[c0, c2]);
let dim_m = m.alloc_id();
m.emit_load(ty_u32, dim_m, ptr_m);
let dim_n = m.alloc_id();
m.emit_load(ty_u32, dim_n, ptr_n);
let dim_k = m.alloc_id();
m.emit_load(ty_u32, dim_k, ptr_k);
let row_base = m.alloc_id();
m.emit_i_mul(ty_u32, row_base, wg_row, c_tm);
let col_base = m.alloc_id();
m.emit_i_mul(ty_u32, col_base, wg_col, c_tn);
let c_base_tmp = m.alloc_id();
m.emit_i_mul(ty_u32, c_base_tmp, row_base, dim_n);
let c_base_flat = m.alloc_id();
m.emit_i_add(ty_u32, c_base_flat, c_base_tmp, col_base);
let ptr_c_tile = m.alloc_id();
m.emit_in_bounds_access_chain(ty_ptr_f32_sb, ptr_c_tile, var_c, &[c0, c_base_flat]);
let mat_c_init = m.alloc_id();
m.emit_coop_matrix_load(
ty_cmat_c,
mat_c_init,
ptr_c_tile,
MATRIX_LAYOUT_ROW_MAJOR,
dim_n,
);
let a_base = m.alloc_id();
m.emit_i_mul(ty_u32, a_base, row_base, dim_k);
let ptr_a = m.alloc_id();
m.emit_in_bounds_access_chain(ty_ptr_f16_sb, ptr_a, var_a, &[c0, a_base]);
let mat_a = m.alloc_id();
m.emit_coop_matrix_load(ty_cmat_a, mat_a, ptr_a, MATRIX_LAYOUT_ROW_MAJOR, dim_k);
let ptr_b = m.alloc_id();
m.emit_in_bounds_access_chain(ty_ptr_f16_sb, ptr_b, var_b, &[c0, col_base]);
let mat_b = m.alloc_id();
m.emit_coop_matrix_load(ty_cmat_b, mat_b, ptr_b, MATRIX_LAYOUT_ROW_MAJOR, dim_n);
let mat_out = m.alloc_id();
m.emit_coop_matrix_muladd(
ty_cmat_c,
mat_out,
mat_a,
mat_b,
mat_c_init,
COOPERATIVE_MATRIX_OPERANDS_NONE,
);
m.emit_coop_matrix_store(ptr_c_tile, mat_out, MATRIX_LAYOUT_ROW_MAJOR, dim_n);
m.emit_return();
m.emit_function_end();
m.finalize()
}
pub fn matmul_xmx_bf16_spirv(tile: XmxTileConfig, wg_x: u32, wg_y: u32) -> Vec<u32> {
let mut words = gemm_xmx_f16_spirv(tile, wg_x, wg_y);
let old = b"gemm_xmx_f16\0\0\0\0";
let new = b"matmul_xmx_bf\0\0\0"; patch_entry_point_name(&mut words, old, new);
words
}
fn patch_entry_point_name(words: &mut [u32], old: &[u8; 16], new: &[u8; 16]) {
let old_words = [
u32::from_le_bytes([old[0], old[1], old[2], old[3]]),
u32::from_le_bytes([old[4], old[5], old[6], old[7]]),
u32::from_le_bytes([old[8], old[9], old[10], old[11]]),
u32::from_le_bytes([old[12], old[13], old[14], old[15]]),
];
let new_words = [
u32::from_le_bytes([new[0], new[1], new[2], new[3]]),
u32::from_le_bytes([new[4], new[5], new[6], new[7]]),
u32::from_le_bytes([new[8], new[9], new[10], new[11]]),
u32::from_le_bytes([new[12], new[13], new[14], new[15]]),
];
'outer: for i in 0..words.len().saturating_sub(3) {
for (j, &ow) in old_words.iter().enumerate() {
if words[i + j] != ow {
continue 'outer;
}
}
for (j, &nw) in new_words.iter().enumerate() {
words[i + j] = nw;
}
break;
}
}
pub fn device_supports_xmx(device_name: &str) -> bool {
let name = device_name.to_ascii_lowercase();
name.contains("arc")
|| name.contains("data center gpu max")
|| name.contains("ponte vecchio")
|| name.contains("max 1")
|| name.contains("max 12")
|| name.contains("iris xe")
|| name.contains("uhd graphics")
}
pub fn best_xmx_tile(device_name: &str) -> XmxTileConfig {
let name = device_name.to_ascii_lowercase();
if name.contains("max") || name.contains("ponte vecchio") {
XmxTileConfig { m: 8, n: 32, k: 16 }
} else if name.contains("arc") || name.contains("iris xe") {
XmxTileConfig::XE_HPC_FP16
} else {
XmxTileConfig::XE_DEFAULT
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn gemm_xmx_spirv_starts_with_magic() {
let words = gemm_xmx_spirv(XmxTileConfig::default(), 16, 16);
assert!(!words.is_empty(), "output must not be empty");
assert_eq!(words[0], 0x07230203, "first word must be SPIR-V magic");
}
#[test]
fn gemm_xmx_spirv_version_1_6() {
let words = gemm_xmx_spirv(XmxTileConfig::default(), 16, 16);
assert_eq!(words[1], 0x0001_0600, "version must be SPIR-V 1.6");
}
#[test]
fn gemm_xmx_spirv_id_bound_nonzero() {
let words = gemm_xmx_spirv(XmxTileConfig::default(), 16, 16);
assert!(words[3] > 0, "ID bound must be > 0");
}
#[test]
fn gemm_xmx_f16_produces_valid_header() {
let words = gemm_xmx_f16_spirv(XmxTileConfig::XE_HPC_FP16, 16, 16);
assert_eq!(words[0], SPIRV_MAGIC);
assert_eq!(words[1], SPIRV_VERSION_1_6);
assert!(words.len() > 20, "module must have non-trivial content");
}
#[test]
fn matmul_xmx_bf16_produces_valid_header() {
let words = matmul_xmx_bf16_spirv(XmxTileConfig::default(), 16, 16);
assert_eq!(words[0], SPIRV_MAGIC);
}
#[test]
fn xmx_tile_accum_elements() {
let tile = XmxTileConfig { m: 8, n: 16, k: 16 };
assert_eq!(tile.accum_elements(), 128);
}
#[test]
fn device_supports_xmx_arc() {
assert!(device_supports_xmx("Intel Arc A770 Graphics"));
assert!(device_supports_xmx("Intel Data Center GPU Max 1550"));
assert!(!device_supports_xmx("AMD Radeon RX 7900 XTX"));
}
#[test]
fn best_xmx_tile_xe_hpc() {
let tile = best_xmx_tile("Intel Data Center GPU Max 1550");
assert_eq!(tile.m, 8);
assert_eq!(tile.n, 32);
}
#[test]
fn different_tile_sizes_produce_different_binaries() {
let a = gemm_xmx_spirv(XmxTileConfig { m: 8, n: 16, k: 16 }, 16, 16);
let b = gemm_xmx_spirv(XmxTileConfig { m: 8, n: 32, k: 16 }, 16, 16);
assert_ne!(
a, b,
"different tile configurations must yield distinct SPIR-V"
);
}
#[test]
fn gemm_xmx_spirv_contains_cooperative_matrix_opcode() {
let words = gemm_xmx_spirv(XmxTileConfig::default(), 16, 16);
let has_cmat = words
.iter()
.any(|&w| (w & 0xFFFF) == OP_TYPE_COOPERATIVE_MATRIX_KHR);
assert!(has_cmat, "module must declare OpTypeCooperativeMatrixKHR");
}
#[test]
fn gemm_xmx_f16_contains_float16_type() {
let words = gemm_xmx_f16_spirv(XmxTileConfig::XE_HPC_FP16, 16, 16);
let has_f16 = words.windows(3).any(|w| {
(w[0] & 0xFFFF) == 22 && w[2] == 16
});
assert!(has_f16, "FP16 module must declare 16-bit float type");
}
}