burn_jit/kernel/
binary.rs

1use std::marker::PhantomData;
2
3use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
4use burn_tensor::Shape;
5use cubecl::{
6    calculate_cube_count_elemwise, linalg::tensor::index_offset_with_layout, prelude::*,
7    tensor_line_size_parallel,
8};
9
10use super::into_contiguous;
11
/// Family pattern grouping binary operations so a single generic kernel can be
/// instantiated for any numeric element type `C`.
pub(crate) trait BinaryOpFamily: Send + Sync + 'static {
    /// The concrete operation used for element type `C`.
    type BinaryOp<C: Numeric>: BinaryOp<C>;
}
15
/// An element-wise binary operation executed on vectorized lines of values
/// inside a cube kernel.
#[cube]
pub(crate) trait BinaryOp<C: Numeric>: 'static + Send + Sync {
    /// Execute a binary operation.
    fn execute(lhs: Line<C>, rhs: Line<C>) -> Line<C>;
}
21
/// Element-wise addition.
pub(crate) struct AddOp;
/// Element-wise subtraction.
pub(crate) struct SubOp;
/// Element-wise multiplication.
pub(crate) struct MulOp;
/// Element-wise division.
pub(crate) struct DivOp;
/// Element-wise remainder.
pub(crate) struct RemainderOp;
27
/// Since Powf only works on float, but we still want to implement the numeric binary op family, we
/// set another precision in the family type to cast, when necessary, the input value to a valid
/// float.
///
/// Because of this we won't benefit from the cubecl rust compilation speed improvement from using
/// the family pattern for [PowOp], but at least we don't duplicate code.
pub(crate) struct PowOp<F: Float> {
    // Marker tying the op to the float precision `F` used for the actual powf.
    _f: PhantomData<F>,
}
37
// Every concrete op is its own one-member family: the associated type is
// simply `Self` for all numeric element types.
impl BinaryOpFamily for AddOp {
    type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for SubOp {
    type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for MulOp {
    type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for DivOp {
    type BinaryOp<C: Numeric> = Self;
}

impl BinaryOpFamily for RemainderOp {
    type BinaryOp<C: Numeric> = Self;
}

impl<F: Float> BinaryOpFamily for PowOp<F> {
    type BinaryOp<C: Numeric> = Self;
}
61
// Line-wise addition.
#[cube]
impl<N: Numeric> BinaryOp<N> for AddOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs + rhs
    }
}

// Line-wise subtraction.
#[cube]
impl<N: Numeric> BinaryOp<N> for SubOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs - rhs
    }
}

// Line-wise multiplication.
#[cube]
impl<N: Numeric> BinaryOp<N> for MulOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs * rhs
    }
}

// Line-wise division.
#[cube]
impl<N: Numeric> BinaryOp<N> for DivOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        lhs / rhs
    }
}

// Line-wise remainder.
#[cube]
impl<N: Numeric> BinaryOp<N> for RemainderOp {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        Line::rem(lhs, rhs)
    }
}

// Powf is only defined on floats, so both operands are cast to the float
// precision `F` carried by the family type, then the result is cast back.
#[cube]
impl<N: Numeric, F: Float> BinaryOp<N> for PowOp<F> {
    fn execute(lhs: Line<N>, rhs: Line<N>) -> Line<N> {
        let lhs = Line::<F>::cast_from(lhs);
        let rhs = Line::<F>::cast_from(rhs);
        let out = Line::powf(lhs, rhs);

        Line::cast_from(out)
    }
}
107
/// Cube kernel applying the binary op `O` between each line of `input` and a
/// broadcast scalar, writing into `output`.
#[cube(launch_unchecked)]
pub(crate) fn kernel_scalar_binop<C: Numeric, O: BinaryOpFamily>(
    input: &Tensor<Line<C>>,
    scalar: C,
    output: &mut Tensor<Line<C>>,
) {
    // The cube count may cover more positions than there are lines; excess
    // units exit early (launched unchecked, so this guard is mandatory).
    if ABSOLUTE_POS >= output.len() {
        return;
    }

    output[ABSOLUTE_POS] = O::BinaryOp::<C>::execute(input[ABSOLUTE_POS], Line::new(scalar));
}
120
/// Cube kernel applying the binary op `O` element-wise on `lhs` and `rhs`.
///
/// When an input is not contiguous with the same layout as `out`, its read
/// offset is recomputed from the output layout (`to_contiguous_*` flags).
/// `rank` may be provided at comptime to avoid a runtime rank lookup.
#[cube(launch_unchecked)]
pub(crate) fn kernel_binop<C: Numeric, O: BinaryOpFamily>(
    lhs: &Tensor<Line<C>>,
    rhs: &Tensor<Line<C>>,
    out: &mut Tensor<Line<C>>,
    #[comptime] rank: Option<u32>,
    #[comptime] to_contiguous_lhs: bool,
    #[comptime] to_contiguous_rhs: bool,
) {
    let offset_out = ABSOLUTE_POS;
    let mut offset_lhs = ABSOLUTE_POS;
    let mut offset_rhs = ABSOLUTE_POS;

    // Excess units exit early (launched unchecked, so this guard is mandatory).
    if offset_out >= out.len() {
        return;
    }

    // Translate the contiguous output offset into a strided lhs offset.
    if to_contiguous_lhs {
        offset_lhs = index_offset_with_layout::<C, C>(
            lhs,
            out,
            offset_out,
            0,
            rank.unwrap_or_else(|| out.rank()),
            rank.is_some(),
        );
    }

    // Translate the contiguous output offset into a strided rhs offset.
    if to_contiguous_rhs {
        offset_rhs = index_offset_with_layout::<C, C>(
            rhs,
            out,
            offset_out,
            0,
            rank.unwrap_or_else(|| out.rank()),
            rank.is_some(),
        );
    }

    out[offset_out] = O::BinaryOp::<C>::execute(lhs[offset_lhs], rhs[offset_rhs]);
}
162
163pub(crate) fn launch_binop<R: JitRuntime, E: JitElement, O: BinaryOpFamily>(
164    lhs: JitTensor<R>,
165    rhs: JitTensor<R>,
166) -> JitTensor<R> {
167    let ndims = lhs.shape.num_dims();
168    let line_size_lhs = tensor_line_size_parallel(
169        R::line_size_elem(&E::as_elem_native_unchecked()),
170        &lhs.shape.dims,
171        &lhs.strides,
172        ndims - 1,
173    );
174    let line_size_rhs = tensor_line_size_parallel(
175        R::line_size_elem(&E::as_elem_native_unchecked()),
176        &rhs.shape.dims,
177        &rhs.strides,
178        ndims - 1,
179    );
180    let line_size = Ord::min(line_size_lhs, line_size_rhs);
181
182    let mut shape_out = vec![0; ndims];
183    lhs.shape
184        .dims
185        .iter()
186        .zip(rhs.shape.dims.iter())
187        .enumerate()
188        .for_each(|(index, (dim_lhs, dim_rhs))| {
189            shape_out[index] = usize::max(*dim_lhs, *dim_rhs);
190        });
191
192    let shape_out = Shape::from(shape_out);
193    let client = lhs.client.clone();
194    let num_elems = shape_out.num_elements();
195
196    let cube_dim = CubeDim::default();
197    let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
198
199    unsafe {
200        if lhs.can_mut_broadcast(&rhs) {
201            kernel_binop::launch_unchecked::<E, O, R>(
202                &client,
203                cube_count,
204                cube_dim,
205                lhs.as_tensor_arg::<E>(line_size),
206                rhs.as_tensor_arg::<E>(line_size),
207                TensorArg::alias(0),
208                None,
209                false,
210                rhs.strides != lhs.strides || rhs.shape != lhs.shape,
211            );
212
213            lhs
214        } else if rhs.can_mut_broadcast(&lhs) {
215            kernel_binop::launch_unchecked::<E, O, R>(
216                &client,
217                cube_count,
218                cube_dim,
219                lhs.as_tensor_arg::<E>(line_size),
220                rhs.as_tensor_arg::<E>(line_size),
221                TensorArg::alias(1),
222                None,
223                rhs.strides != lhs.strides || rhs.shape != lhs.shape,
224                false,
225            );
226
227            rhs
228        } else {
229            let output = empty_device::<R, E>(lhs.client.clone(), lhs.device.clone(), shape_out);
230            let to_contiguous_lhs = lhs.strides != output.strides || lhs.shape != output.shape;
231            let to_contiguous_rhs = rhs.strides != output.strides || rhs.shape != output.shape;
232
233            kernel_binop::launch_unchecked::<E, O, R>(
234                &client,
235                cube_count,
236                cube_dim,
237                lhs.as_tensor_arg::<E>(line_size),
238                rhs.as_tensor_arg::<E>(line_size),
239                output.as_tensor_arg::<E>(line_size),
240                None,
241                to_contiguous_lhs,
242                to_contiguous_rhs,
243            );
244
245            output
246        }
247    }
248}
249
250pub(crate) fn launch_scalar_binop<R: JitRuntime, E: JitElement, O: BinaryOpFamily>(
251    mut tensor: JitTensor<R>,
252    scalar: E,
253) -> JitTensor<R> {
254    if !tensor.is_contiguous_buffer() {
255        tensor = into_contiguous(tensor);
256    }
257
258    // Vectorization is only enabled when the last dimension is contiguous.
259    let ndims = tensor.shape.num_dims();
260    let line_size = tensor_line_size_parallel(
261        R::line_size_elem(&E::as_elem_native_unchecked()),
262        &tensor.shape.dims,
263        &tensor.strides,
264        ndims - 1,
265    );
266    let client = tensor.client.clone();
267    let num_elems = tensor.shape.num_elements();
268
269    let cube_dim = CubeDim::default();
270    let cube_count = calculate_cube_count_elemwise(num_elems / line_size as usize, cube_dim);
271
272    unsafe {
273        if tensor.can_mut() {
274            kernel_scalar_binop::launch_unchecked::<E, O, R>(
275                &client,
276                cube_count,
277                cube_dim,
278                tensor.as_tensor_arg::<E>(line_size),
279                ScalarArg::new(scalar),
280                TensorArg::alias(0),
281            );
282
283            tensor
284        } else {
285            let output = empty_device::<R, E>(
286                tensor.client.clone(),
287                tensor.device.clone(),
288                tensor.shape.clone(),
289            );
290
291            kernel_scalar_binop::launch_unchecked::<E, O, R>(
292                &client,
293                cube_count,
294                CubeDim::default(),
295                tensor.as_tensor_arg::<E>(line_size),
296                ScalarArg::new(scalar),
297                output.as_tensor_arg::<E>(line_size),
298            );
299
300            output
301        }
302    }
303}