cubecl_cpp/hip/
processors.rs1use cubecl_core::{
2 self as cubecl,
3 ir::{ExpandElement, Instruction, Scope},
4};
5use cubecl_core::{
6 cube,
7 ir::{Allocator, CoopMma, MatrixIdent, Operation, Processor, ScopeProcessing},
8};
9
10#[derive(new, Debug)]
11pub struct HipMmaProcessor;
12
13impl Processor for HipMmaProcessor {
14 fn transform(&self, mut processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
15 let mut instructions = Vec::new();
16 core::mem::swap(&mut processing.instructions, &mut instructions);
17
18 for instruction in instructions {
19 match instruction.operation {
20 Operation::CoopMma(CoopMma::RowIndex { lane_id, i, matrix }) => {
21 let lane_id = ExpandElement::Plain(lane_id);
22 let i = ExpandElement::Plain(i);
23 let mut scope = Scope::root(false)
24 .with_allocator(allocator.clone())
25 .with_types(processing.typemap.clone());
26 let row_idx: ExpandElement =
27 row_index::expand(&mut scope, lane_id.into(), i.into(), matrix.ident)
28 .into();
29 let tmp_processing = scope.process([]);
30 for inst in tmp_processing.instructions {
31 processing.instructions.push(inst);
32 }
33 for var in tmp_processing.variables {
34 processing.variables.push(var);
35 }
36
37 processing.instructions.push(Instruction::new(
38 Operation::Copy(*row_idx),
39 instruction.out(),
40 ));
41 }
42 Operation::CoopMma(CoopMma::ColIndex { lane_id, i, matrix }) => {
43 let lane_id = ExpandElement::Plain(lane_id);
44 let i = ExpandElement::Plain(i);
45 let mut scope = Scope::root(false)
46 .with_allocator(allocator.clone())
47 .with_types(processing.typemap.clone());
48 let row_idx: ExpandElement =
49 col_index::expand(&mut scope, lane_id.into(), i.into(), matrix.ident)
50 .into();
51 let tmp_processing = scope.process([]);
52 for inst in tmp_processing.instructions {
53 processing.instructions.push(inst);
54 }
55 for var in tmp_processing.variables {
56 processing.variables.push(var);
57 }
58
59 processing.instructions.push(Instruction::new(
60 Operation::Copy(*row_idx),
61 instruction.out(),
62 ));
63 }
64 _ => {
65 processing.instructions.push(instruction);
66 }
67 }
68 }
69
70 processing
71 }
72}
73
/// Returns the row of the `i`-th fragment element held by lane `lane_id`.
///
/// `ident` is a `#[comptime]` parameter, so the `match` presumably selects an
/// arm at expansion time rather than at kernel run time.
///
/// NOTE(review): the constant 16 and the `lane_id / 16` term suggest 16-wide
/// fragments spread over 32-lane waves (e.g. an RDNA3 WMMA-style layout) —
/// TODO confirm against the targeted HIP architecture.
#[cube]
fn row_index(lane_id: u32, i: u32, #[comptime] ident: MatrixIdent) -> u32 {
    match ident {
        // A: the row is fixed by the lane; `i` walks along that row.
        MatrixIdent::A => lane_id % 16,
        // B: the row is the element index `i`.
        MatrixIdent::B => i,
        // Accumulator: each `i` step covers two rows; `lane_id / 16` picks the
        // even (0) or odd (1) row, assuming `lane_id < 32`.
        MatrixIdent::Accumulator => i * 2 + (lane_id / 16),
    }
}
83
/// Returns the column of the `i`-th fragment element held by lane `lane_id`.
///
/// Companion of `row_index`: for each `ident`, the two functions together give
/// the (row, column) coordinate of an element. `ident` is `#[comptime]`.
///
/// NOTE(review): the constant 16 presumes 16-wide fragments — TODO confirm
/// against the targeted HIP architecture.
#[cube]
fn col_index(lane_id: u32, i: u32, #[comptime] ident: MatrixIdent) -> u32 {
    match ident {
        // A: the column is the element index `i`.
        MatrixIdent::A => i,
        // B and Accumulator: the column is fixed by the lane (`lane_id % 16`).
        MatrixIdent::B => lane_id % 16,
        MatrixIdent::Accumulator => lane_id % 16,
    }
}