cubecl_cpp/hip/
processors.rs

1use cubecl_core::{
2    self as cubecl,
3    ir::{ExpandElement, Instruction, Scope},
4};
5use cubecl_core::{
6    cube,
7    ir::{Allocator, CoopMma, MatrixIdent, Operation, Processor, ScopeProcessing},
8};
9
10#[derive(new, Debug)]
11pub struct HipMmaProcessor;
12
13impl Processor for HipMmaProcessor {
14    fn transform(&self, mut processing: ScopeProcessing, allocator: Allocator) -> ScopeProcessing {
15        let mut instructions = Vec::new();
16        core::mem::swap(&mut processing.instructions, &mut instructions);
17
18        for instruction in instructions {
19            match instruction.operation {
20                Operation::CoopMma(CoopMma::RowIndex { lane_id, i, matrix }) => {
21                    let lane_id = ExpandElement::Plain(lane_id);
22                    let i = ExpandElement::Plain(i);
23                    let mut scope = Scope::root(false)
24                        .with_allocator(allocator.clone())
25                        .with_types(processing.typemap.clone());
26                    let row_idx: ExpandElement =
27                        row_index::expand(&mut scope, lane_id.into(), i.into(), matrix.ident)
28                            .into();
29                    let tmp_processing = scope.process([]);
30                    for inst in tmp_processing.instructions {
31                        processing.instructions.push(inst);
32                    }
33                    for var in tmp_processing.variables {
34                        processing.variables.push(var);
35                    }
36
37                    processing.instructions.push(Instruction::new(
38                        Operation::Copy(*row_idx),
39                        instruction.out(),
40                    ));
41                }
42                Operation::CoopMma(CoopMma::ColIndex { lane_id, i, matrix }) => {
43                    let lane_id = ExpandElement::Plain(lane_id);
44                    let i = ExpandElement::Plain(i);
45                    let mut scope = Scope::root(false)
46                        .with_allocator(allocator.clone())
47                        .with_types(processing.typemap.clone());
48                    let row_idx: ExpandElement =
49                        col_index::expand(&mut scope, lane_id.into(), i.into(), matrix.ident)
50                            .into();
51                    let tmp_processing = scope.process([]);
52                    for inst in tmp_processing.instructions {
53                        processing.instructions.push(inst);
54                    }
55                    for var in tmp_processing.variables {
56                        processing.variables.push(var);
57                    }
58
59                    processing.instructions.push(Instruction::new(
60                        Operation::Copy(*row_idx),
61                        instruction.out(),
62                    ));
63                }
64                _ => {
65                    processing.instructions.push(instruction);
66                }
67            }
68        }
69
70        processing
71    }
72}
73
74#[cube]
75fn row_index(lane_id: u32, i: u32, #[comptime] ident: MatrixIdent) -> u32 {
76    match ident {
77        MatrixIdent::A => lane_id % 16,
78        MatrixIdent::B => i,
79        // 2 * i, offset by 1 if lane_id >= 16
80        MatrixIdent::Accumulator => i * 2 + (lane_id / 16),
81    }
82}
83
84#[cube]
85fn col_index(lane_id: u32, i: u32, #[comptime] ident: MatrixIdent) -> u32 {
86    match ident {
87        MatrixIdent::A => i,
88        MatrixIdent::B => lane_id % 16,
89        MatrixIdent::Accumulator => lane_id % 16,
90    }
91}