// cubek_std/tile/data/rowwise.rs

1use cubecl;
2use cubecl::prelude::*;
3
4#[derive(CubeType)]
5/// Contains one value per row of a fragment for which the unit contributes
6///
7/// Example: For a 8x8 tile shared by a plane of 32 units,
8/// every unit holds 8 values in the tile.
9///
10/// In the following layout, values are held contiguously, and num_rows=1 because
11/// every two occurrences of the same plane id are in the same row
12///  0,  0,  1,  1,  2,  2,  3,  3,
13///  4,  4,  5,  5,  6,  6,  7,  7,
14///  8,  8,  9,  9, 10, 10, 11, 11,
15/// 12, 12, 13, 13, 14, 14, 15, 15,
16/// 16, 16, 17, 17, 18, 18, 19, 19,
17/// 20, 20, 21, 21, 22, 22, 23, 23,
18/// 24, 24, 25, 25, 26, 26, 27, 27,
19/// 28, 28, 29, 29, 30, 30, 31, 31,
20///
21/// In the following layout, values are held disjointly, and num_rows=2 because
22/// the two occurrences of the same plane id are not in the same row
23///  0,  1,  2,  3,  4,  5,  6,  7,
24///  8,  9, 10, 11, 12, 13, 14, 15,
25/// 16, 17, 18, 19, 20, 21, 22, 23,
26/// 24, 25, 26, 27, 28, 29, 30, 31,
27///  0,  1,  2,  3,  4,  5,  6,  7,
28///  8,  9, 10, 11, 12, 13, 14, 15,
29/// 16, 17, 18, 19, 20, 21, 22, 23,
30/// 24, 25, 26, 27, 28, 29, 30, 31,
31pub struct RowWise<E: Numeric> {
32    pub vals: Array<E>,
33    #[cube(comptime)]
34    pub num_rows: usize,
35}
36
37#[cube]
38impl<E: Numeric> RowWise<E> {
39    /// Create a RowWise with the provided value at every row
40    pub fn new_filled(#[comptime] num_rows: usize, val: E) -> RowWise<E> {
41        let mut vals = Array::new(num_rows);
42        for i in 0..num_rows {
43            vals[i] = val;
44        }
45        RowWise::<E> { vals, num_rows }
46    }
47
48    /// Fill the existing RowWise with the provided value at every row
49    pub fn fill(&mut self, val: E) {
50        for i in 0..self.num_rows {
51            self.vals[i] = val;
52        }
53    }
54
55    /// Create a RowWise with -infinity at every row
56    pub fn new_min_value(#[comptime] num_rows: usize) -> RowWise<E> {
57        Self::new_filled(num_rows, E::min_value())
58    }
59
60    /// Create a RowWise with zero at every row
61    pub fn new_zero(#[comptime] num_rows: usize) -> RowWise<E> {
62        Self::new_filled(num_rows, E::from_int(0))
63    }
64
65    /// Fill the current RowWise with the value of other at each row
66    pub fn copy_from(&mut self, other: &RowWise<E>) {
67        for i in 0..self.num_rows {
68            self.vals[i] = other.vals[i]
69        }
70    }
71
72    /// For each row, add the the current and other, and outputs a new RowWise
73    pub fn add(&self, other: &RowWise<E>) -> RowWise<E> {
74        let mut result = Array::new(self.num_rows);
75        for i in 0..self.num_rows {
76            result[i] = self.vals[i] + other.vals[i];
77        }
78        RowWise::<E> {
79            vals: result,
80            num_rows: self.num_rows,
81        }
82    }
83
84    /// For each row, add the other value to the current RowWise
85    pub fn add_inplace(&mut self, other: &RowWise<E>) {
86        for i in 0..self.num_rows {
87            self.vals[i] += other.vals[i];
88        }
89    }
90
91    /// For each row, multiplies the the current and other, and outputs a new RowWise
92    pub fn mul(&self, other: &RowWise<E>) -> RowWise<E> {
93        let mut result = Array::new(self.num_rows);
94        for i in 0..self.num_rows {
95            result[i] = self.vals[i] * other.vals[i];
96        }
97        RowWise::<E> {
98            vals: result,
99            num_rows: self.num_rows,
100        }
101    }
102
103    /// For each row, multiplies the other value to the current RowWise
104    pub fn mul_inplace(&mut self, other: &RowWise<E>) {
105        for i in 0..self.num_rows {
106            self.vals[i] *= other.vals[i];
107        }
108    }
109
110    /// For each row, maxes the other value to the current RowWise
111    pub fn max_inplace(&mut self, other: &RowWise<E>) {
112        for i in 0..self.num_rows {
113            self.vals[i] = max(self.vals[i], other.vals[i]);
114        }
115    }
116
117    /// Changes the value at index i
118    pub fn replace_at(&mut self, i: usize, new_val: E) {
119        self.vals[i] = new_val;
120    }
121
122    /// Return a copy of self, cast into E2
123    pub fn cast_from<E2: Float>(row_wise: &RowWise<E>) -> RowWise<E2> {
124        let num_rows = row_wise.num_rows;
125        let mut vals = Array::new(num_rows);
126
127        for i in 0..num_rows {
128            vals[i] = E2::cast_from(row_wise.vals[i]);
129        }
130
131        RowWise::<E2> { vals, num_rows }
132    }
133}
134
135#[cube]
136impl<E: Float> RowWise<E> {
137    /// Computes e^(self.val - other.val) for every row, and outputs a new RowWise
138    pub fn exp_diff(&self, other: &RowWise<E>) -> RowWise<E> {
139        let mut vals = Array::new(self.num_rows);
140
141        for i in 0..self.num_rows {
142            vals[i] = (self.vals[i] - other.vals[i]).exp();
143        }
144
145        RowWise::<E> {
146            vals,
147            num_rows: self.num_rows,
148        }
149    }
150}