cubecl_cpp/hip/
processors.rs

1use cubecl_core::{
2    self as cubecl,
3    ir::{ExpandElement, Instruction, Scope},
4};
5use cubecl_core::{
6    cube,
7    ir::{Allocator, CoopMma, MatrixIdent, Operation, Processor, ScopeProcessing},
8};
9
10#[derive(new, Debug)]
11pub struct HipMmaProcessor;
12
13impl Processor for HipMmaProcessor {
14    fn transform(&self, mut processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
15        let mut instructions = Vec::new();
16        core::mem::swap(&mut processing.instructions, &mut instructions);
17
18        for instruction in instructions {
19            match instruction.operation {
20                Operation::CoopMma(CoopMma::RowIndex { lane_id, i, matrix }) => {
21                    let lane_id = ExpandElement::Plain(lane_id);
22                    let i = ExpandElement::Plain(i);
23                    let mut scope = Scope::root(false).with_allocator(allocator.clone());
24                    let row_idx: ExpandElement =
25                        row_index::expand(&mut scope, lane_id.into(), i.into(), matrix.ident)
26                            .into();
27                    let tmp_processing = scope.process([]);
28                    for inst in tmp_processing.instructions {
29                        processing.instructions.push(inst);
30                    }
31                    for var in tmp_processing.variables {
32                        processing.variables.push(var);
33                    }
34
35                    processing.instructions.push(Instruction::new(
36                        Operation::Copy(*row_idx),
37                        instruction.out(),
38                    ));
39                }
40                Operation::CoopMma(CoopMma::ColIndex { lane_id, i, matrix }) => {
41                    let lane_id = ExpandElement::Plain(lane_id);
42                    let i = ExpandElement::Plain(i);
43                    let mut scope = Scope::root(false).with_allocator(allocator.clone());
44                    let row_idx: ExpandElement =
45                        col_index::expand(&mut scope, lane_id.into(), i.into(), matrix.ident)
46                            .into();
47                    let tmp_processing = scope.process([]);
48                    for inst in tmp_processing.instructions {
49                        processing.instructions.push(inst);
50                    }
51                    for var in tmp_processing.variables {
52                        processing.variables.push(var);
53                    }
54
55                    processing.instructions.push(Instruction::new(
56                        Operation::Copy(*row_idx),
57                        instruction.out(),
58                    ));
59                }
60                _ => {
61                    processing.instructions.push(instruction);
62                }
63            }
64        }
65
66        processing
67    }
68}
69
/// Computes the row index associated with element `i` of lane `lane_id` for a
/// matrix fragment of kind `ident`.
///
/// * `A`: `lane_id % 16`
/// * `B`: `i`
/// * `Accumulator`: `i * 2 + lane_id / 16`
///
/// `ident` is marked `#[comptime]`, so it is a compile-time argument of the cube function.
/// NOTE(review): the constant 16 presumably matches a 16-wide tile with lanes 16.. mapping
/// to the odd rows of the accumulator — confirm against the HIP WMMA fragment layout.
#[cube]
fn row_index(lane_id: u32, i: u32, #[comptime] ident: MatrixIdent) -> u32 {
    match ident {
        MatrixIdent::A => lane_id % 16,
        MatrixIdent::B => i,
        // 2 * i, offset by 1 if lane_id >= 16
        // (the offset is exactly 1 only while lane_id < 32 — TODO confirm lane range)
        MatrixIdent::Accumulator => i * 2 + (lane_id / 16),
    }
}
79
/// Computes the column index associated with element `i` of lane `lane_id` for a
/// matrix fragment of kind `ident`.
///
/// * `A`: `i`
/// * `B`: `lane_id % 16`
/// * `Accumulator`: `lane_id % 16`
///
/// `ident` is marked `#[comptime]`, so it is a compile-time argument of the cube function.
/// NOTE(review): the constant 16 presumably matches a 16-wide tile — confirm against the
/// HIP WMMA fragment layout.
#[cube]
fn col_index(lane_id: u32, i: u32, #[comptime] ident: MatrixIdent) -> u32 {
    match ident {
        // For `A` the column comes from the element position, not the lane.
        MatrixIdent::A => i,
        // `B` and the accumulator both take the column from the lane, modulo 16.
        MatrixIdent::B => lane_id % 16,
        MatrixIdent::Accumulator => lane_id % 16,
    }
}