// cubek_std/tile/data/rowwise.rs
use cubecl;
use cubecl::prelude::*;

4#[derive(CubeType)]
5pub struct RowWise<E: Numeric> {
32 pub vals: Array<E>,
33 #[cube(comptime)]
34 pub num_rows: usize,
35}
36
37#[cube]
38impl<E: Numeric> RowWise<E> {
39 pub fn new_filled(#[comptime] num_rows: usize, val: E) -> RowWise<E> {
41 let mut vals = Array::new(num_rows);
42 for i in 0..num_rows {
43 vals[i] = val;
44 }
45 RowWise::<E> { vals, num_rows }
46 }
47
48 pub fn fill(&mut self, val: E) {
50 for i in 0..self.num_rows {
51 self.vals[i] = val;
52 }
53 }
54
55 pub fn new_min_value(#[comptime] num_rows: usize) -> RowWise<E> {
57 Self::new_filled(num_rows, E::min_value())
58 }
59
60 pub fn new_zero(#[comptime] num_rows: usize) -> RowWise<E> {
62 Self::new_filled(num_rows, E::from_int(0))
63 }
64
65 pub fn copy_from(&mut self, other: &RowWise<E>) {
67 for i in 0..self.num_rows {
68 self.vals[i] = other.vals[i]
69 }
70 }
71
72 pub fn add(&self, other: &RowWise<E>) -> RowWise<E> {
74 let mut result = Array::new(self.num_rows);
75 for i in 0..self.num_rows {
76 result[i] = self.vals[i] + other.vals[i];
77 }
78 RowWise::<E> {
79 vals: result,
80 num_rows: self.num_rows,
81 }
82 }
83
84 pub fn add_inplace(&mut self, other: &RowWise<E>) {
86 for i in 0..self.num_rows {
87 self.vals[i] += other.vals[i];
88 }
89 }
90
91 pub fn mul(&self, other: &RowWise<E>) -> RowWise<E> {
93 let mut result = Array::new(self.num_rows);
94 for i in 0..self.num_rows {
95 result[i] = self.vals[i] * other.vals[i];
96 }
97 RowWise::<E> {
98 vals: result,
99 num_rows: self.num_rows,
100 }
101 }
102
103 pub fn mul_inplace(&mut self, other: &RowWise<E>) {
105 for i in 0..self.num_rows {
106 self.vals[i] *= other.vals[i];
107 }
108 }
109
110 pub fn max_inplace(&mut self, other: &RowWise<E>) {
112 for i in 0..self.num_rows {
113 self.vals[i] = max(self.vals[i], other.vals[i]);
114 }
115 }
116
117 pub fn replace_at(&mut self, i: usize, new_val: E) {
119 self.vals[i] = new_val;
120 }
121
122 pub fn cast_from<E2: Float>(row_wise: &RowWise<E>) -> RowWise<E2> {
124 let num_rows = row_wise.num_rows;
125 let mut vals = Array::new(num_rows);
126
127 for i in 0..num_rows {
128 vals[i] = E2::cast_from(row_wise.vals[i]);
129 }
130
131 RowWise::<E2> { vals, num_rows }
132 }
133}
134
135#[cube]
136impl<E: Float> RowWise<E> {
137 pub fn exp_diff(&self, other: &RowWise<E>) -> RowWise<E> {
139 let mut vals = Array::new(self.num_rows);
140
141 for i in 0..self.num_rows {
142 vals[i] = (self.vals[i] - other.vals[i]).exp();
143 }
144
145 RowWise::<E> {
146 vals,
147 num_rows: self.num_rows,
148 }
149 }
150}