1use std::marker::PhantomData;
2
3use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
4use burn_tensor::Shape;
5use cubecl::{
6 calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
7 tensor_line_size_parallel,
8};
9
10use super::into_contiguous;
11
/// Associates a concrete [`BinaryOp`] implementation with every numeric
/// element type `C`, so one marker type (e.g. `AddOp`) can drive the launch
/// functions below for any dtype.
pub(crate) trait BinaryOpFamily: Send + Sync + 'static {
    type BinaryOp<C: Numeric>: BinaryOp<C>;
}
15
// Element-wise binary operation executed inside a cube kernel on one
// vectorized line of values. (Plain `//` comments: doc attributes would be
// forwarded into the #[cube] macro input.)
#[cube]
pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
    // Combines one line of lhs values with the matching line of rhs values.
    fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
}
21
/// Marker type selecting element-wise addition.
pub(crate) struct AddOp;
/// Marker type selecting element-wise subtraction.
pub(crate) struct SubOp;
/// Marker type selecting element-wise multiplication.
pub(crate) struct MulOp;
/// Marker type selecting element-wise division.
pub(crate) struct DivOp;
/// Marker type selecting element-wise remainder.
pub(crate) struct RemainderOp;
27
/// Marker type selecting element-wise power. Operands are cast to the float
/// type `F` before `powf` is applied, then cast back (see the `BinaryOp`
/// impl below).
pub(crate) struct PowOp<F: Float> {
    // Marker only: `F` fixes the intermediate float precision; nothing is
    // stored at runtime.
    _f: PhantomData<F>,
}
37
// Each marker type is its own family: the single generic `BinaryOp` impl
// below is reused for every numeric element type `C`.

impl BinaryOpFamily for AddOp {
    type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for SubOp {
    type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for MulOp {
    type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for DivOp {
    type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for RemainderOp {
    type BinaryOp<C: Numeric> = Self;
}

impl<F: Float> BinaryOpFamily for PowOp<F> {
    type BinaryOp<C: Numeric> = Self;
}
61
// Line-wise implementations of each operation. These bodies run inside the
// generated kernels, so only cube-compatible constructs are used; comments
// stay as `//` so the #[cube] macro never sees them.

#[cube]
impl<N: Numeric> BinaryOp<N> for AddOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs + rhs
    }
}

#[cube]
impl<N: Numeric> BinaryOp<N> for SubOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs - rhs
    }
}

#[cube]
impl<N: Numeric> BinaryOp<N> for MulOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs * rhs
    }
}

#[cube]
impl<N: Numeric> BinaryOp<N> for DivOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs / rhs
    }
}

#[cube]
impl<N: Numeric> BinaryOp<N> for RemainderOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        Line::rem(lhs, rhs)
    }
}

#[cube]
impl<N: Numeric, F: Float> BinaryOp<N> for PowOp<F> {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        // `powf` only exists for floats: widen to `F`, compute, then cast
        // back to the original element type.
        let lhs = Line::<F>::cast_from(lhs);
        let rhs = Line::<F>::cast_from(rhs);
        let out = Line::powf(lhs, rhs);

        Line::cast_from(out)
    }
}
107
// Kernel applying `O` between each line of `input` and a scalar broadcast to
// a full line. One unit handles one vectorized output line.
#[cube(launch_unchecked)]
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
    input: &Tensor<Line<C>>,
    scalar: C,
    output: &mut Tensor<Line<C>>,
) {
    // The cube count is rounded up, so trailing units may fall past the end
    // of the output; they simply exit (launch is unchecked).
    if ABSOLUTE_POS >= output.len() {
        return;
    }

    output[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar));
}
120
// Kernel applying `O` element-wise between `lhs` and `rhs`, writing to `out`.
// Each input may either share the output's (contiguous) layout — in which
// case its offset is just ABSOLUTE_POS — or need a layout-aware index
// translation (broadcasting / non-contiguous strides), selected at compile
// time via the `to_contiguous_*` flags.
#[cube(launch_unchecked)]
pub(crate) fn kernel_binop<C: Numeric, O: BinaryOpFamily>(
    lhs: &Tensor<Line<C>>,
    rhs: &Tensor<Line<C>>,
    out: &mut Tensor<Line<C>>,
    // Comptime rank lets the offset computation unroll; `None` falls back to
    // the runtime rank of `out`.
    #[comptime] rank: Option<u32>,
    #[comptime] to_contiguous_lhs: bool,
    #[comptime] to_contiguous_rhs: bool,
) {
    let offset_out = ABSOLUTE_POS;
    let mut offset_lhs = ABSOLUTE_POS;
    let mut offset_rhs = ABSOLUTE_POS;

    // Trailing units past the end of the output have no work (cube count is
    // rounded up; launch is unchecked).
    if offset_out >= out.len() {
        return;
    }

    if to_contiguous_lhs {
        // Translate the contiguous output index into lhs's strided layout.
        offset_lhs = index_offset_with_layout::<C, C>(
            lhs,
            out,
            offset_out,
            0,
            rank.unwrap_or_else(|| out.rank()),
            rank.is_some(),
        );
    }

    if to_contiguous_rhs {
        // Translate the contiguous output index into rhs's strided layout.
        offset_rhs = index_offset_with_layout::<C, C>(
            rhs,
            out,
            offset_out,
            0,
            rank.unwrap_or_else(|| out.rank()),
            rank.is_some(),
        );
    }

    out[offset_out] = O::BinaryOp::<C>::execute(lhs[offset_lhs], rhs[offset_rhs]);
}
162
163pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOpFamily>(
164 lhs: JitTensor<R>,
165 rhs: JitTensor<R>,
166) -> JitTensor<R> {
167 let ndims = lhs.shape.num_dims();
168 let line_size_lhs = tensor_line_size_parallel(
169 R::line_size_elem(&E::as_elem_native_unchecked()),
170 &lhs.shape.dims,
171 &lhs.strides,
172 ndims - 1,
173 );
174 let line_size_rhs = tensor_line_size_parallel(
175 R::line_size_elem(&E::as_elem_native_unchecked()),
176 &rhs.shape.dims,
177 &rhs.strides,
178 ndims - 1,
179 );
180 let line_size = Ord::min(line_size_lhs, line_size_rhs);
181
182 let mut shape_out = vec![0; ndims];
183 lhs.shape
184 .dims
185 .iter()
186 .zip(rhs.shape.dims.iter())
187 .enumerate()
188 .for_each(|(index, (dim_lhs, dim_rhs))| {
189 shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
190 });
191
192 let shape_out = Shape::from(shape_out);
193 let client = lhs.client.clone();
194 let num_elems = shape_out.num_elements();
195
196 let cube_dim = CubeDim::default();
197 let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
198
199 unsafe {
200 if lhs.can_mut_broadcast(&rhs) {
201 kernel_binop::launch_unchecked::<E, O, R>(
202 &client,
203 cube_count,
204 cube_dim,
205 lhs.as_tensor_arg::<E>(line_size),
206 rhs.as_tensor_arg::<E>(line_size),
207 TensorArg::alias(0),
208 None,
209 false,
210 rhs.strides != lhs.strides || rhs.shape != lhs.shape,
211 );
212
213 lhs
214 } else if rhs.can_mut_broadcast(&lhs) {
215 kernel_binop::launch_unchecked::<E, O, R>(
216 &client,
217 cube_count,
218 cube_dim,
219 lhs.as_tensor_arg::<E>(line_size),
220 rhs.as_tensor_arg::<E>(line_size),
221 TensorArg::alias(1),
222 None,
223 rhs.strides != lhs.strides || rhs.shape != lhs.shape,
224 false,
225 );
226
227 rhs
228 } else {
229 let output = empty_device::<R, E>(lhs.client.clone(), lhs.device.clone(), shape_out);
230 let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
231 let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;
232
233 kernel_binop::launch_unchecked::<E, O, R>(
234 &client,
235 cube_count,
236 cube_dim,
237 lhs.as_tensor_arg::<E>(line_size),
238 rhs.as_tensor_arg::<E>(line_size),
239 output.as_tensor_arg::<E>(line_size),
240 None,
241 to_contiguous_lhs,
242 to_contiguous_rhs,
243 );
244
245 output
246 }
247 }
248}
249
250pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOpFamily>(
251 mut tensor: JitTensor<R>,
252 scalar: E,
253) -> JitTensor<R> {
254 if !tensor.is_contiguous_buffer() {
255 tensor = into_contiguous(tensor);
256 }
257
258 let ndims = tensor.shape.num_dims();
260 let line_size = tensor_line_size_parallel(
261 R::line_size_elem(&E::as_elem_native_unchecked()),
262 &tensor.shape.dims,
263 &tensor.strides,
264 ndims - 1,
265 );
266 let client = tensor.client.clone();
267 let num_elems = tensor.shape.num_elements();
268
269 let cube_dim = CubeDim::default();
270 let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
271
272 unsafe {
273 if tensor.can_mut() {
274 kernel_scalar_binop::launch_unchecked::<E, O, R>(
275 &client,
276 cube_count,
277 cube_dim,
278 tensor.as_tensor_arg::<E>(line_size),
279 ScalarArg::new(scalar),
280 TensorArg::alias(0),
281 );
282
283 tensor
284 } else {
285 let output = empty_device::<R, E>(
286 tensor.client.clone(),
287 tensor.device.clone(),
288 tensor.shape.clone(),
289 );
290
291 kernel_scalar_binop::launch_unchecked::<E, O, R>(
292 &client,
293 cube_count,
294 CubeDim::default(),
295 tensor.as_tensor_arg::<E>(line_size),
296 ScalarArg::new(scalar),
297 output.as_tensor_arg::<E>(line_size),
298 );
299
300 output
301 }
302 }
303}