cubek_std/tile/compute/rowwise/
dispatch.rs1use cubecl::prelude::*;
2
3use crate::tile::{
4 LOGIT_MASKED, Plane, Tile, TileExpand,
5 compute::rowwise::reducer::{fragment_row_max, fragment_row_sum},
6 data::RowWise,
7};
8
9#[cube]
21impl<E: Float> Tile<E, Plane, ReadWrite> {
22 pub fn row_max(&self, acc: &mut RowWise<E>, base: &RowWise<E>) {
23 match self {
24 Tile::Unit(t) => {
25 acc.copy_from(base);
26 let m = comptime!(t.layout.num_rows);
27 let n = comptime!(t.layout.num_cols);
28 for r in 0..m as usize {
29 let row_offset = r as u32 * n;
30 let mut val = E::min_value();
31 for c in 0..n {
32 val = max(val, t.data[(row_offset + c) as usize]);
33 }
34 acc.vals[r] = max(acc.vals[r], val);
35 }
36 }
37 Tile::WhiteboxFragment(t) => {
38 fragment_row_max::<E>(acc, base, t);
39 }
40 Tile::Bounce(b) => {
41 fragment_row_max::<E>(acc, base, &b.fragment);
42 }
43 Tile::Register(t) => {
44 acc.copy_from(base);
45 let m = comptime!(t.config.tile_size.m());
46 let n = comptime!(t.config.tile_size.n());
47 for r in 0..m as usize {
48 let row_offset = r as u32 * n;
49 let mut val = E::min_value();
50 for c in 0..n {
51 val = max(val, t.data[(row_offset + c) as usize]);
52 }
53 acc.vals[r] = max(acc.vals[r], val);
54 }
55 }
56 _ => panic!("row_max: unsupported tile variant"),
57 }
58 }
59
60 pub fn row_sum(&self, acc: &mut RowWise<E>) {
61 match self {
62 Tile::Unit(t) => {
63 acc.fill(E::from_int(0));
64 let m = comptime!(t.layout.num_rows);
65 let n = comptime!(t.layout.num_cols);
66 for r in 0..m as usize {
67 let row_offset = r as u32 * n;
68 let mut val = E::from_int(0);
69 for c in 0..n {
70 val += t.data[(row_offset + c) as usize];
71 }
72 acc.vals[r] += val;
73 }
74 }
75 Tile::WhiteboxFragment(t) => {
76 fragment_row_sum::<E>(acc, t);
77 }
78 Tile::Bounce(b) => {
79 fragment_row_sum::<E>(acc, &b.fragment);
80 }
81 Tile::Register(t) => {
82 acc.fill(E::from_int(0));
83 let m = comptime!(t.config.tile_size.m());
84 let n = comptime!(t.config.tile_size.n());
85 for r in 0..m as usize {
86 let row_offset = r as u32 * n;
87 let mut val = E::from_int(0);
88 for c in 0..n {
89 val += t.data[(row_offset + c) as usize];
90 }
91 acc.vals[r] += val;
92 }
93 }
94 _ => panic!("row_sum: unsupported tile variant"),
95 }
96 }
97
98 pub fn exp_diff(&mut self, rowwise: &RowWise<E>) {
99 match self {
100 Tile::Unit(t) => t.exp_diff(rowwise),
101 Tile::WhiteboxFragment(t) => t.exp_diff(rowwise),
102 Tile::Bounce(b) => b.fragment.exp_diff(rowwise),
103 Tile::Register(t) => {
104 let m = comptime!(t.config.tile_size.m());
105 let n = comptime!(t.config.tile_size.n());
106 let threshold = E::new(LOGIT_MASKED);
107 for r in 0..m as usize {
108 let row_offset = r as u32 * n;
109 let val = rowwise.vals[r];
110 let safe_val = clamp_min(val, threshold);
111 let not_masked = E::cast_from(val >= threshold);
112 for c in 0..n {
113 let idx = (row_offset + c) as usize;
114 t.data[idx] = not_masked * (t.data[idx] - safe_val).exp();
115 }
116 }
117 }
118 _ => panic!("exp_diff: unsupported tile variant"),
119 }
120 }
121
122 pub fn rowwise_scale(&mut self, scale: &RowWise<E>) {
123 match self {
124 Tile::Unit(t) => t.rowwise_scale(scale),
125 Tile::WhiteboxFragment(t) => t.rowwise_scale(scale),
126 Tile::Bounce(b) => b.fragment.rowwise_scale(scale),
127 Tile::Register(t) => {
128 let m = comptime!(t.config.tile_size.m());
129 let n = comptime!(t.config.tile_size.n());
130 for r in 0..m as usize {
131 let row_offset = r as u32 * n;
132 for c in 0..n {
133 let idx = (row_offset + c) as usize;
134 t.data[idx] = t.data[idx] * scale.vals[r];
135 }
136 }
137 }
138 _ => panic!("rowwise_scale: unsupported tile variant"),
139 }
140 }
141}