cubecl_cpp/hip/
processors.rs1use cubecl_core::{
2 self as cubecl,
3 ir::{ExpandElement, Instruction, Scope},
4};
5use cubecl_core::{
6 cube,
7 ir::{Allocator, CoopMma, MatrixIdent, Operation, Processor, ScopeProcessing},
8};
9
10#[derive(new, Debug)]
11pub struct HipMmaProcessor;
12
13impl Processor for HipMmaProcessor {
14 fn transform(&self, mut processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
15 let mut instructions = Vec::new();
16 core::mem::swap(&mut processing.instructions, &mut instructions);
17
18 for instruction in instructions {
19 match instruction.operation {
20 Operation::CoopMma(CoopMma::RowIndex { lane_id, i, matrix }) => {
21 let lane_id = ExpandElement::Plain(lane_id);
22 let i = ExpandElement::Plain(i);
23 let mut scope = Scope::root(false).with_allocator(allocator.clone());
24 let row_idx: ExpandElement =
25 row_index::expand(&mut scope, lane_id.into(), i.into(), matrix.ident)
26 .into();
27 let tmp_processing = scope.process([]);
28 for inst in tmp_processing.instructions {
29 processing.instructions.push(inst);
30 }
31 for var in tmp_processing.variables {
32 processing.variables.push(var);
33 }
34
35 processing.instructions.push(Instruction::new(
36 Operation::Copy(*row_idx),
37 instruction.out(),
38 ));
39 }
40 Operation::CoopMma(CoopMma::ColIndex { lane_id, i, matrix }) => {
41 let lane_id = ExpandElement::Plain(lane_id);
42 let i = ExpandElement::Plain(i);
43 let mut scope = Scope::root(false).with_allocator(allocator.clone());
44 let row_idx: ExpandElement =
45 col_index::expand(&mut scope, lane_id.into(), i.into(), matrix.ident)
46 .into();
47 let tmp_processing = scope.process([]);
48 for inst in tmp_processing.instructions {
49 processing.instructions.push(inst);
50 }
51 for var in tmp_processing.variables {
52 processing.variables.push(var);
53 }
54
55 processing.instructions.push(Instruction::new(
56 Operation::Copy(*row_idx),
57 instruction.out(),
58 ));
59 }
60 _ => {
61 processing.instructions.push(instruction);
62 }
63 }
64 }
65
66 processing
67 }
68}
69
70#[cube]
71fn row_index(lane_id: u32, i: u32, #[comptime] ident: MatrixIdent) -> u32 {
72 match ident {
73 MatrixIdent::A => lane_id % 16,
74 MatrixIdent::B => i,
75 MatrixIdent::Accumulator => i * 2 + (lane_id / 16),
77 }
78}
79
80#[cube]
81fn col_index(lane_id: u32, i: u32, #[comptime] ident: MatrixIdent) -> u32 {
82 match ident {
83 MatrixIdent::A => i,
84 MatrixIdent::B => lane_id % 16,
85 MatrixIdent::Accumulator => lane_id % 16,
86 }
87}