// burn_ndarray/ops/base.rs
1use alloc::{vec, vec::Vec};
2use burn_backend::element::{Element, ElementConversion};
3#[cfg(feature = "simd")]
4use burn_backend::{DType, quantization::QuantValue};
5use core::fmt::Debug;
6use core::marker::PhantomData;
7use ndarray::IntoDimension;
8use ndarray::SliceInfo;
9use ndarray::Zip;
10use ndarray::s;
11use ndarray::{Array2, ArrayD};
12use num_traits::Signed;
13#[cfg(feature = "simd")]
14use paste::paste;
15
16#[cfg(not(feature = "std"))]
17#[allow(unused_imports)]
18use num_traits::Float;
19
20#[cfg(feature = "simd")]
21use crate::ops::simd::{
22    binary::try_binary_simd,
23    binary_elemwise::{
24        VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecClamp, VecDiv, VecMax, VecMin, VecMul, VecSub,
25        try_binary_scalar_simd,
26    },
27    cmp::{
28        VecEquals, VecGreater, VecGreaterEq, VecLower, VecLowerEq, try_cmp_scalar_simd,
29        try_cmp_simd,
30    },
31    unary::{RecipVec, VecAbs, VecBitNot, try_unary_simd},
32};
33use crate::reshape;
34use crate::{
35    IntNdArrayElement, ShapeOps,
36    ops::macros::{
37        cummax_dim, cummin_dim, cumprod_dim, cumsum_dim, keepdim, mean_dim, prod_dim, sum_dim,
38    },
39};
40use crate::{SharedArray, element::NdArrayElement};
41use burn_backend::ops::unfold::calculate_unfold_shape;
42use burn_backend::{Shape, Slice};
43use ndarray::ArrayView;
44use ndarray::Axis;
45use ndarray::Dim;
46use ndarray::IxDyn;
47use ndarray::SliceInfoElem;
48
/// Zero-sized dispatch type grouping structural array operations
/// (slice, gather/scatter, reshape, cat, flip, ...) for element type `E`.
pub struct NdArrayOps<E> {
    // Marker only; no runtime data is stored.
    e: PhantomData<E>,
}
52
/// Zero-sized dispatch type grouping arithmetic/reduction operations
/// (add, mul, clamp, sum, argmax, ...) for element type `E`.
pub(crate) struct NdArrayMathOps<E> {
    // Marker only; no runtime data is stored.
    e: PhantomData<E>,
}
56
impl<E> NdArrayOps<E>
where
    E: Copy + Debug + Element + crate::AddAssignElement,
{
    /// Slices `tensor` according to `slices` (start/end/step per dimension);
    /// dimensions beyond `slices.len()` keep their full range.
    pub fn slice(tensor: SharedArray<E>, slices: &[Slice]) -> SharedArray<E> {
        let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims());
        tensor.slice_move(slices.as_slice()).into_shared()
    }

    /// Writes `value` into the region of `tensor` selected by `slices`,
    /// returning the updated tensor.
    pub fn slice_assign(
        tensor: SharedArray<E>,
        slices: &[Slice],
        value: SharedArray<E>,
    ) -> SharedArray<E> {
        let slices = Self::to_slice_args_with_steps(slices, tensor.shape().num_dims());
        // into_owned() copies only when the underlying Arc is shared.
        let mut array = tensor.into_owned();
        array.slice_mut(slices.as_slice()).assign(&value);
        array.into_shared()
    }

    /// Element-wise select: where `mask` is true take `source`, otherwise
    /// keep `tensor`. Both operands are broadcast to the mask's shape.
    pub fn mask_where(
        tensor: SharedArray<E>,
        mask: SharedArray<bool>,
        source: SharedArray<E>,
    ) -> SharedArray<E> {
        let tensor = tensor.broadcast(mask.dim()).unwrap();
        let source = source.broadcast(mask.dim()).unwrap();
        Zip::from(&tensor)
            .and(&mask)
            .and(&source)
            .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x })
            .into_shared()
    }

    /// Overwrites elements of `tensor` with `value` wherever `mask` is true;
    /// the mask is broadcast to the tensor's shape.
    pub fn mask_fill(tensor: SharedArray<E>, mask: SharedArray<bool>, value: E) -> SharedArray<E> {
        // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique
        let mut output = tensor.into_owned();
        let broadcast_mask = mask.broadcast(output.dim()).unwrap();
        Zip::from(&mut output)
            .and(&broadcast_mask)
            .for_each(|out, &mask_val| {
                if mask_val {
                    *out = value;
                }
            });
        output.into_shared()
    }

    /// Gathers values along `dim`: `output[..., i] = tensor[..., indices[..., i]]`
    /// (indexing applied on axis `dim`).
    ///
    /// Implementation: moves `dim` to the last axis, flattens the leading axes
    /// into a batch dimension, then copies element-by-element.
    pub fn gather<I: NdArrayElement>(
        dim: usize,
        mut tensor: SharedArray<E>,
        mut indices: SharedArray<I>,
    ) -> SharedArray<E> {
        let ndims = tensor.shape().num_dims();
        // Work on the last axis; swap back at the end.
        if dim != ndims - 1 {
            tensor.swap_axes(ndims - 1, dim);
            indices.swap_axes(ndims - 1, dim);
        }
        let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape().into_shape());
        let (size_tensor, size_index) = (shape_tensor[ndims - 1], shape_indices.dims[ndims - 1]);
        let batch_size = Self::gather_batch_size(shape_tensor, &shape_indices.dims);

        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index]));
        let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor]));
        // Zero-filled buffer; every cell is overwritten by the copy loop below.
        let mut output = Array2::from_elem((batch_size, size_index), 0.elem::<E>());

        for b in 0..batch_size {
            let indices = indices.slice(s!(b, ..));
            for (i, index) in indices.iter().enumerate() {
                output[[b, i]] = tensor[[b, index.elem::<i64>() as usize]];
            }
        }

        // The gathered result has the indices' shape.
        let mut output = NdArrayOps::reshape(output.into_shared().into_dyn(), shape_indices);

        if dim != ndims - 1 {
            output.swap_axes(ndims - 1, dim);
        }

        output
    }

    /// Scatter-add along `dim`: `tensor[..., indices[..., i]] += value[..., i]`
    /// (accumulation on axis `dim`).
    ///
    /// # Panics
    ///
    /// Panics when `indices` and `value` shapes differ, or when a leading
    /// dimension of `indices` differs from `tensor`'s.
    pub fn scatter<I: NdArrayElement>(
        dim: usize,
        mut tensor: SharedArray<E>,
        mut indices: SharedArray<I>,
        mut value: SharedArray<E>,
    ) -> SharedArray<E> {
        let ndims = tensor.shape().num_dims();
        // Work on the last axis; swap back at the end.
        if dim != ndims - 1 {
            tensor.swap_axes(ndims - 1, dim);
            indices.swap_axes(ndims - 1, dim);
            value.swap_axes(ndims - 1, dim);
        }

        let (shape_tensor, shape_indices, shape_value) =
            (tensor.shape().into_shape(), indices.shape(), value.shape());
        let (size_tensor, size_index, size_value) = (
            shape_tensor.dims[ndims - 1],
            shape_indices[ndims - 1],
            shape_value[ndims - 1],
        );
        let batch_size = Self::gather_batch_size(&shape_tensor.dims, shape_indices);

        if shape_value != shape_indices {
            panic!(
                "Invalid dimension: the shape of the index tensor should be the same as the value \
                 tensor: Index {:?} value {:?}",
                shape_indices, shape_value
            );
        }

        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index]));
        let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value]));
        let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor]));

        for b in 0..batch_size {
            let indices = indices.slice(s!(b, ..));

            for (i, index) in indices.iter().enumerate() {
                let index = index.elem::<i64>() as usize;
                // Accumulate (+=) rather than overwrite, so repeated indices sum.
                tensor[[b, index]].add_assign(value[[b, i]]);
            }
        }

        let mut output = NdArrayOps::reshape(tensor.into_shared().into_dyn(), shape_tensor);
        if dim != ndims - 1 {
            output.swap_axes(ndims - 1, dim);
        }
        output
    }

    /// Product of the leading (non-last) dimensions, validating that each one
    /// matches between tensor and indices.
    fn gather_batch_size(shape_tensor: &[usize], shape_indices: &[usize]) -> usize {
        let ndims = shape_tensor.num_dims();
        let mut batch_size = 1;

        for i in 0..ndims - 1 {
            if shape_tensor[i] != shape_indices[i] {
                panic!(
                    "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \
                     {:?}",
                    shape_tensor, shape_indices
                );
            }
            batch_size *= shape_indices[i];
        }

        batch_size
    }

    /// Reshapes `tensor` to `shape`; the `reshape!` macro dispatches on the
    /// target rank.
    pub fn reshape(tensor: SharedArray<E>, shape: Shape) -> SharedArray<E> {
        reshape!(
            ty E,
            shape shape,
            array tensor,
            d shape.num_dims()
        )
    }

    /// Concatenates the given views along axis `dim`, normalizing the result
    /// to standard (row-major) layout.
    pub(crate) fn concatenate(
        arrays: &[ndarray::ArrayView<E, IxDyn>],
        dim: usize,
    ) -> SharedArray<E> {
        let array = ndarray::concatenate(Axis(dim), arrays)
            .unwrap()
            .into_shared();

        // Transform column-major layout into row-major (standard) layout. (fix #1053)
        // Get shape first (via reference), then pass ownership to avoid clone
        let shape = array.shape().into_shape();
        Self::reshape(array, shape)
    }

    /// Concatenates owned tensors along axis `dim`.
    pub fn cat(tensors: Vec<SharedArray<E>>, dim: usize) -> SharedArray<E> {
        let arrays: Vec<_> = tensors.iter().map(|t| t.view()).collect();
        Self::concatenate(&arrays, dim)
    }

    /// Converts burn `Slice`s into ndarray `SliceInfoElem`s, padding missing
    /// trailing dimensions with full-range slices.
    #[allow(clippy::wrong_self_convention)]
    fn to_slice_args_with_steps(
        burn_slices: &[burn_backend::Slice],
        ndims: usize,
    ) -> Vec<SliceInfoElem> {
        // Placeholder values; every entry is overwritten in the loop below.
        let mut slices = vec![SliceInfoElem::NewAxis; ndims];

        for i in 0..ndims {
            slices[i] = if i < burn_slices.len() {
                let slice = &burn_slices[i];

                // Check for empty range (would result in no elements)
                if let Some(end) = slice.end
                    && slice.start == end
                {
                    SliceInfoElem::Slice {
                        start: 0,
                        end: Some(0),
                        step: 1,
                    }
                } else {
                    // Pass slice parameters directly to ndarray
                    // ndarray handles both positive and negative steps correctly:
                    // - Positive step: iterates forward from start
                    // - Negative step: iterates backward from the last element in range
                    SliceInfoElem::Slice {
                        start: slice.start,
                        end: slice.end,
                        step: slice.step,
                    }
                }
            } else {
                // Dimension not specified in slices - use full range
                SliceInfoElem::Slice {
                    start: 0,
                    end: None,
                    step: 1,
                }
            }
        }

        slices
    }

    /// Transposes axes `dim1` and `dim2` (metadata-only; no data copy).
    pub fn swap_dims(mut tensor: SharedArray<E>, dim1: usize, dim2: usize) -> SharedArray<E> {
        tensor.swap_axes(dim1, dim2);

        tensor
    }

    /// Reorders the tensor's axes according to `axes`.
    pub fn permute(tensor: SharedArray<E>, axes: &[usize]) -> SharedArray<E> {
        tensor.permuted_axes(axes.into_dimension())
    }

    /// Broadcasts the tensor to the given shape
    pub(crate) fn expand(tensor: SharedArray<E>, shape: Shape) -> SharedArray<E> {
        tensor
            .broadcast(shape.dims.into_dimension())
            .expect("The shapes should be broadcastable")
            // need to convert view to owned array because NdArrayTensor expects owned array
            // and try_into_owned_nocopy() panics for broadcasted arrays (zero strides)
            .into_owned()
            .into_shared()
    }

    /// Reverses the tensor along each axis listed in `axes` (step -1 slices).
    pub fn flip(tensor: SharedArray<E>, axes: &[usize]) -> SharedArray<E> {
        let slice_items: Vec<_> = (0..tensor.shape().num_dims())
            .map(|i| {
                if axes.contains(&i) {
                    SliceInfoElem::Slice {
                        start: 0,
                        end: None,
                        step: -1,
                    }
                } else {
                    SliceInfoElem::Slice {
                        start: 0,
                        end: None,
                        step: 1,
                    }
                }
            })
            .collect();
        let slice_info =
            SliceInfo::<Vec<SliceInfoElem>, IxDyn, IxDyn>::try_from(slice_items).unwrap();
        tensor.slice(slice_info).into_owned().into_shared()
    }

    /// Unfold windows along a dimension.
    ///
    /// # Warning
    ///
    /// This is a copy impl; `ndarray` doesn't expose the layout machinery
    /// necessary to build the stride view.
    ///
    /// Returns a copy of the tensor with all complete windows of size `size` in dimension `dim`;
    /// where windows are advanced by `step` at each index.
    ///
    /// The number of windows is `max(0, (shape[dim] - size).ceil_div(step))`.
    /// NOTE(review): window count is actually delegated to `calculate_unfold_shape`;
    /// confirm this formula against that helper.
    ///
    /// # Arguments
    ///
    /// * `tensor` - The input tensor to unfold; of shape ``[pre=..., dim shape, post=...]``
    /// * `dim` - the dimension to unfold.
    /// * `size` - the size of each unfolded window.
    /// * `step` - the step between each window.
    ///
    /// # Returns
    ///
    /// A tensor view with shape ``[pre=..., windows, post=..., size]``.
    #[allow(unused)]
    pub(crate) fn unfold(
        tensor: SharedArray<E>,
        dim: usize,
        size: usize,
        step: usize,
    ) -> SharedArray<E> {
        let result_shape = calculate_unfold_shape(tensor.shape(), dim, size, step);
        let windows = result_shape[dim];

        // Full-range slices everywhere; only `slices[dim]` changes per window.
        let mut slices = vec![Slice::new(0, None, 1); tensor.shape().len()];
        let new_axis = slices.len();

        let mut stack = Vec::with_capacity(windows);
        for widx in 0..windows {
            let start = widx * step;
            let end = start + size;
            slices[dim] = Slice::new(start as isize, Some(end as isize), 1);

            let mut window_slice =
                tensor.slice(Self::to_slice_args_with_steps(&slices, slices.len()).as_slice());
            // Append a trailing axis for the window contents, then move the
            // window extent there so `dim` indexes windows.
            window_slice.insert_axis_inplace(Axis(new_axis));
            window_slice.swap_axes(dim, new_axis);

            stack.push(window_slice);
        }
        Self::concatenate(&stack, dim)
    }
}
374
#[cfg(feature = "simd")]
macro_rules! dispatch_binary_simd {
    // Attempts the SIMD fast path for a binary element-wise op. On success the
    // enclosing function returns early with the result; on failure ownership
    // of both operands is handed back for the scalar fallback.
    //
    // `noq` variant: no quantized (QFloat) handling — only the listed dtypes
    // are eligible for SIMD.
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
    // Default variant: additionally routes 8-bit quantized values through the
    // i8 SIMD kernel.
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.value {
                    QuantValue::Q8F | QuantValue::Q8S => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),
                    _ => Err(($lhs, $rhs)),
                },
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}
406
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_binary_simd {
    // SIMD disabled: hand both operands straight back to the scalar path.
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
}
412
#[cfg(feature = "simd")]
macro_rules! dispatch_binary_scalar_simd {
    // Attempts the SIMD fast path for a tensor-scalar op. On success the
    // enclosing function returns early; on failure ownership of the tensor
    // operand is handed back for the scalar fallback.
    //
    // `noq` variant: no quantized (QFloat) handling.
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
    // Default variant: 8-bit quantized values use the i8 SIMD kernel; other
    // quantized layouts fall through to the scalar path.
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.value {
                    QuantValue::Q8F | QuantValue::Q8S => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),
                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs)
                },
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}
444
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_binary_scalar_simd {
    // SIMD disabled: hand the tensor operand straight back to the scalar path.
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
}
450
#[cfg(feature = "simd")]
macro_rules! dispatch_cmp_simd {
    // Attempts the SIMD fast path for an element-wise comparison producing a
    // bool tensor. On success the enclosing function returns early; on failure
    // ownership of both operands is handed back. 8-bit quantized values use
    // the i8 kernel.
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_cmp_simd::<$elem, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.value {
                    QuantValue::Q8F | QuantValue::Q8S => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs),
                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err(($lhs, $rhs))
                },
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}
470
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_cmp_simd {
    // SIMD disabled: hand both operands straight back to the scalar path.
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
}
475
#[cfg(feature = "simd")]
macro_rules! dispatch_cmp_scalar_simd {
    // Attempts the SIMD fast path for a tensor-scalar comparison. On success
    // the enclosing function returns early; on failure ownership of the
    // tensor operand is handed back. 8-bit quantized values use the i8 kernel.
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_cmp_scalar_simd::<$elem, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.value {
                    QuantValue::Q8F | QuantValue::Q8S => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs),
                    QuantValue::Q4F | QuantValue::Q4S | QuantValue::Q2F | QuantValue::Q2S | QuantValue::E4M3 | QuantValue::E5M2 | QuantValue::E2M1 => Err($lhs)
                },
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}
495
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_cmp_scalar_simd {
    // SIMD disabled: hand the tensor operand straight back to the scalar path.
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
}
500
#[cfg(feature = "simd")]
macro_rules! dispatch_unary_simd {
    // Attempts the SIMD fast path for a unary op. On success the enclosing
    // function returns early; on failure ownership of the operand is handed
    // back for the scalar fallback. No quantized handling here.
    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_unary_simd::<$elem, $elem, $ty, $ty, $op>($lhs),)*
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}
516
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_unary_simd {
    // SIMD disabled: hand the operand straight back to the scalar path.
    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ $lhs }};
}
521
522// Helper function to broadcast two tensors to a common shape for comparison operations
523// Returns broadcasted views that can be safely zipped
524fn broadcast_for_comparison<'a, E: Copy, S1, S2>(
525    lhs: &'a ndarray::ArrayBase<S1, ndarray::IxDyn>,
526    rhs: &'a ndarray::ArrayBase<S2, ndarray::IxDyn>,
527) -> (
528    ndarray::ArrayView<'a, E, ndarray::IxDyn>,
529    ndarray::ArrayView<'a, E, ndarray::IxDyn>,
530)
531where
532    S1: ndarray::Data<Elem = E>,
533    S2: ndarray::Data<Elem = E>,
534{
535    // Get shapes
536    let lhs_shape = lhs.shape();
537    let rhs_shape = rhs.shape();
538
539    // Compute broadcast shape using ndarray's broadcast compatibility rules
540    let ndims = lhs_shape.len().max(rhs_shape.len());
541    let mut broadcast_shape = vec![1; ndims];
542
543    for i in 0..ndims {
544        let lhs_dim = if i < lhs_shape.len() {
545            lhs_shape[lhs_shape.len() - 1 - i]
546        } else {
547            1
548        };
549        let rhs_dim = if i < rhs_shape.len() {
550            rhs_shape[rhs_shape.len() - 1 - i]
551        } else {
552            1
553        };
554
555        if lhs_dim == rhs_dim {
556            broadcast_shape[ndims - 1 - i] = lhs_dim;
557        } else if lhs_dim == 1 {
558            broadcast_shape[ndims - 1 - i] = rhs_dim;
559        } else if rhs_dim == 1 {
560            broadcast_shape[ndims - 1 - i] = lhs_dim;
561        } else {
562            panic!(
563                "Incompatible shapes for broadcasting: {:?} and {:?}",
564                lhs_shape, rhs_shape
565            );
566        }
567    }
568
569    // Create IxDyn from broadcast shape
570    let broadcast_dim = ndarray::IxDyn(&broadcast_shape);
571
572    // Broadcast both arrays
573    let lhs_broadcast = lhs
574        .broadcast(broadcast_dim.clone())
575        .expect("Failed to broadcast lhs");
576    let rhs_broadcast = rhs
577        .broadcast(broadcast_dim)
578        .expect("Failed to broadcast rhs");
579
580    (lhs_broadcast, rhs_broadcast)
581}
582
583impl<E> NdArrayMathOps<E>
584where
585    E: Copy + NdArrayElement,
586{
    /// Element-wise addition; tries the SIMD fast path (early return on
    /// success), otherwise falls back to ndarray's broadcasting `+`.
    pub fn add(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
        let (lhs, rhs) = dispatch_binary_simd!(
            E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
        );

        let array = &lhs + &rhs;
        array.into_shared()
    }
595
    /// Adds scalar `rhs` to every element; SIMD fast path when available.
    pub fn add_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E> {
        let lhs = dispatch_binary_scalar_simd!(
            E,
            VecAdd,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        let array = lhs + rhs;
        array.into_shared()
    }
617
    /// Element-wise subtraction; SIMD fast path when available.
    pub fn sub(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
        let (lhs, rhs) = dispatch_binary_simd!(
            E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
        );

        let array = lhs - rhs;
        array.into_shared()
    }
626
    /// Subtracts scalar `rhs` from every element; SIMD fast path when available.
    pub fn sub_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E> {
        let lhs = dispatch_binary_scalar_simd!(
            E,
            VecSub,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        let array = lhs - rhs;
        array.into_shared()
    }
648
    /// Element-wise multiplication; SIMD fast path for the listed dtypes only
    /// (`noq`: quantized values always take the scalar path).
    pub fn mul(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
        let (lhs, rhs) =
            dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64);

        let array = lhs * rhs;
        array.into_shared()
    }
656
    /// Multiplies every element by scalar `rhs`; SIMD fast path for the
    /// listed dtypes only (`noq`: no quantized handling).
    pub fn mul_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E> {
        let lhs = dispatch_binary_scalar_simd!(
            noq,
            E,
            VecMul,
            lhs,
            rhs.elem(),
            u16,
            i16,
            u32,
            i32,
            f32,
            f64
        );

        let array = lhs * rhs;
        array.into_shared()
    }
675
    /// Element-wise division; SIMD fast path for f32/f64 only.
    pub fn div(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
        let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64);

        let array = lhs / rhs;
        array.into_shared()
    }
682
    /// Divides every element by scalar `rhs`; SIMD fast path for f32/f64 only.
    pub fn div_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E> {
        let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64);

        let array = lhs / rhs;
        array.into_shared()
    }
689
690    pub fn remainder(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<E> {
691        // Use into_owned() instead of clone() - only copies if shared, avoids copy if unique
692        let mut out = lhs.into_owned();
693        Zip::from(&mut out).and(&rhs).for_each(|out_elem, &b| {
694            // out_elem holds lhs value; read it before overwriting with remainder
695            let a_f = (*out_elem).to_f64();
696            let b_f = b.to_f64();
697            let r = a_f - b_f * (a_f / b_f).floor();
698            *out_elem = r.elem();
699        });
700        out.into_shared()
701    }
702
703    pub fn remainder_scalar(lhs: SharedArray<E>, rhs: E) -> SharedArray<E>
704    where
705        E: core::ops::Rem<Output = E>,
706    {
707        let array = lhs.mapv(|x| ((x % rhs) + rhs) % rhs);
708        array.into_shared()
709    }
710
    /// Element-wise reciprocal (1/x); SIMD fast path for f32 only.
    pub fn recip(tensor: SharedArray<E>) -> SharedArray<E> {
        let tensor = dispatch_unary_simd!(E, RecipVec, tensor, f32);

        let array = tensor.map(|x| 1.elem::<E>() / *x);
        array.into_shared()
    }
717
718    /// Sum all elements - zero-copy for borrowed storage.
719    pub fn sum_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {
720        let sum = view.sum();
721        ArrayD::from_elem(IxDyn(&[1]), sum).into_shared()
722    }
723
724    /// Mean of all elements - zero-copy for borrowed storage.
725    pub fn mean_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {
726        let mean = view.mean().unwrap();
727        ArrayD::from_elem(IxDyn(&[1]), mean).into_shared()
728    }
729
730    /// Product of all elements - zero-copy for borrowed storage.
731    pub fn prod_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {
732        let prod = view.iter().fold(E::one(), |acc, &x| acc * x);
733        ArrayD::from_elem(IxDyn(&[1]), prod).into_shared()
734    }
735
736    /// Max of all elements - zero-copy for borrowed storage.
737    pub fn max_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {
738        let max = view
739            .iter()
740            .copied()
741            .reduce(|a, b| if a > b { a } else { b })
742            .expect("Cannot compute max of empty tensor");
743        ArrayD::from_elem(IxDyn(&[1]), max).into_shared()
744    }
745
746    /// Min of all elements - zero-copy for borrowed storage.
747    pub fn min_view(view: ArrayView<'_, E, IxDyn>) -> SharedArray<E> {
748        let min = view
749            .iter()
750            .copied()
751            .reduce(|a, b| if a < b { a } else { b })
752            .expect("Cannot compute min of empty tensor");
753        ArrayD::from_elem(IxDyn(&[1]), min).into_shared()
754    }
755
    /// Argmax along dimension - zero-copy for borrowed storage.
    // Delegates to the shared arg_view helper with the Max comparator.
    pub fn argmax_view<I: NdArrayElement>(
        view: ArrayView<'_, E, IxDyn>,
        dim: usize,
    ) -> SharedArray<I> {
        arg_view(view, dim, CmpType::Max)
    }
763
    /// Argmin along dimension - zero-copy for borrowed storage.
    // Delegates to the shared arg_view helper with the Min comparator.
    pub fn argmin_view<I: NdArrayElement>(
        view: ArrayView<'_, E, IxDyn>,
        dim: usize,
    ) -> SharedArray<I> {
        arg_view(view, dim, CmpType::Min)
    }
771
    /// Mean along `dim`, keeping the reduced dimension (size 1).
    ///
    /// # Panics
    ///
    /// Panics for tensors of rank outside 1..=6 (the `keepdim!` macro only
    /// covers those ranks).
    pub fn mean_dim(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
        let ndims = tensor.shape().num_dims();
        match ndims {
            d if (1..=6).contains(&d) => keepdim!(dim, tensor, mean),
            _ => panic!("Dim not supported {ndims}"),
        }
    }
779
    /// Sum along `dim`, keeping the reduced dimension (size 1).
    ///
    /// # Panics
    ///
    /// Panics for tensors of rank outside 1..=6.
    pub fn sum_dim(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
        let ndims = tensor.shape().num_dims();
        match ndims {
            d if (1..=6).contains(&d) => keepdim!(dim, tensor, sum),
            _ => panic!("Dim not supported {ndims}"),
        }
    }
787
    /// Product along `dim`, keeping the reduced dimension (size 1).
    ///
    /// # Panics
    ///
    /// Panics for tensors of rank outside 1..=6.
    pub fn prod_dim(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
        let ndims = tensor.shape().num_dims();
        match ndims {
            d if (1..=6).contains(&d) => keepdim!(dim, tensor, prod),
            _ => panic!("Dim not supported {ndims}"),
        }
    }
795
    /// Cumulative sum along `dim` (delegates to the macro-defined helper).
    pub fn cumsum(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
        cumsum_dim(tensor, dim)
    }
799
    /// Cumulative product along `dim` (delegates to the macro-defined helper).
    pub fn cumprod(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
        cumprod_dim(tensor, dim)
    }
803
    /// Cumulative minimum along `dim` (delegates to the macro-defined helper).
    pub fn cummin(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
        cummin_dim(tensor, dim)
    }
807
    /// Cumulative maximum along `dim` (delegates to the macro-defined helper).
    pub fn cummax(tensor: SharedArray<E>, dim: usize) -> SharedArray<E> {
        cummax_dim(tensor, dim)
    }
811
812    pub fn select<I: NdArrayElement>(
813        tensor: SharedArray<E>,
814        dim: usize,
815        indices: SharedArray<I>,
816    ) -> SharedArray<E> {
817        let array = tensor.select(
818            Axis(dim),
819            &indices
820                .into_iter()
821                .map(|i| i.elem::<i64>() as usize)
822                .collect::<Vec<_>>(),
823        );
824
825        array.into_shared()
826    }
827
    /// Accumulates (`+=`) slices of `value` into `tensor` along `dim`:
    /// for the k-th entry i of `indices`, `tensor[.., i, ..] += value[.., k, ..]`.
    pub fn select_assign<I: NdArrayElement>(
        tensor: SharedArray<E>,
        dim: usize,
        indices: SharedArray<I>,
        value: SharedArray<E>,
    ) -> SharedArray<E> {
        // into_owned() copies only when the underlying Arc is shared.
        let mut output_array = tensor.into_owned();

        for (index_value, index) in indices.into_iter().enumerate() {
            let mut view = output_array.index_axis_mut(Axis(dim), index.elem::<i64>() as usize);
            let value = value.index_axis(Axis(dim), index_value);

            view.zip_mut_with(&value, |a, b| *a += *b);
        }

        output_array.into_shared()
    }
    /// Indices of the maximum values along `dim` (delegates to `arg`).
    pub fn argmax<I: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<I> {
        arg(tensor, dim, CmpType::Max)
    }
848
    /// Indices of the minimum values along `dim` (delegates to `arg`).
    pub fn argmin<I: NdArrayElement>(tensor: SharedArray<E>, dim: usize) -> SharedArray<I> {
        arg(tensor, dim, CmpType::Min)
    }
852
    /// Clamps every element to be at least `min`; SIMD fast path (VecMax)
    /// returns early, otherwise an in-place scalar pass is used.
    pub fn clamp_min(tensor: SharedArray<E>, min: E) -> SharedArray<E> {
        let mut tensor = dispatch_binary_scalar_simd!(
            E,
            VecMax,
            tensor,
            min.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        tensor.mapv_inplace(|x| match x < min {
            true => min,
            false => x,
        });

        tensor
    }
878
    /// Clamps every element to be at most `max`; SIMD fast path (VecMin)
    /// returns early, otherwise an in-place scalar pass is used.
    pub fn clamp_max(tensor: SharedArray<E>, max: E) -> SharedArray<E> {
        let mut tensor = dispatch_binary_scalar_simd!(
            E,
            VecMin,
            tensor,
            max.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        tensor.mapv_inplace(|x| match x > max {
            true => max,
            false => x,
        });

        tensor
    }
904
    /// Clamps every element into `[min, max]`; SIMD fast path (VecClamp takes
    /// the bounds as a pair) returns early, otherwise an in-place scalar pass.
    pub fn clamp(tensor: SharedArray<E>, min: E, max: E) -> SharedArray<E> {
        let mut tensor = dispatch_binary_scalar_simd!(
            E,
            VecClamp,
            tensor,
            (min.elem(), max.elem()),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        tensor.mapv_inplace(|x| match x < min {
            true => min,
            false => match x > max {
                true => max,
                false => x,
            },
        });

        tensor
    }
933
934    pub(crate) fn elementwise_op<OtherE>(
935        lhs: SharedArray<E>,
936        rhs: SharedArray<OtherE>,
937        var_name: impl FnMut(&E, &OtherE) -> E,
938    ) -> SharedArray<E> {
939        let lhs = lhs.broadcast(rhs.dim()).unwrap_or(lhs.view());
940        let rhs = rhs.broadcast(lhs.dim()).unwrap_or(rhs.view());
941
942        Zip::from(lhs).and(rhs).map_collect(var_name).into_shared()
943    }
944
945    pub(crate) fn elementwise_op_scalar(
946        lhs: SharedArray<E>,
947        var_name: impl FnMut(E) -> E,
948    ) -> SharedArray<E> {
949        lhs.mapv(var_name).into_shared()
950    }
951
952    pub(crate) fn sign_op(tensor: SharedArray<E>) -> SharedArray<E>
953    where
954        E: Signed,
955    {
956        let zero = 0.elem();
957        let one = 1.elem::<E>();
958
959        tensor
960            .mapv(|x| {
961                if x > zero {
962                    one
963                } else if x < zero {
964                    -one
965                } else {
966                    zero
967                }
968            })
969            .into_shared()
970    }
971
972    pub(crate) fn abs(tensor: SharedArray<E>) -> SharedArray<E> {
973        let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64);
974
975        tensor.mapv_into(|a| a.abs_elem()).into_shared()
976    }
977
978    pub(crate) fn equal(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {
979        let (lhs, rhs) = dispatch_cmp_simd!(
980            E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
981        );
982
983        // Use the helper to broadcast both arrays to a common shape
984        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
985        // Now we can safely zip and compare
986        Zip::from(&lhs_broadcast)
987            .and(&rhs_broadcast)
988            .map_collect(|&lhs, &rhs| lhs == rhs)
989            .into_shared()
990    }
991
992    pub(crate) fn equal_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {
993        let lhs = dispatch_cmp_scalar_simd!(
994            E,
995            VecEquals,
996            lhs,
997            rhs.elem(),
998            u8,
999            i8,
1000            u16,
1001            i16,
1002            u32,
1003            f32,
1004            i32,
1005            u64,
1006            i64,
1007            f64
1008        );
1009
1010        lhs.mapv(|a| a == rhs).into_shared()
1011    }
1012
1013    pub(crate) fn greater(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {
1014        let (lhs, rhs) = dispatch_cmp_simd!(
1015            E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
1016        );
1017
1018        // Use the helper to broadcast both arrays to a common shape
1019        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1020        // Now we can safely zip and compare
1021        Zip::from(&lhs_broadcast)
1022            .and(&rhs_broadcast)
1023            .map_collect(|&lhs, &rhs| lhs > rhs)
1024            .into_shared()
1025    }
1026
1027    pub(crate) fn greater_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {
1028        let lhs = dispatch_cmp_scalar_simd!(
1029            E,
1030            VecGreater,
1031            lhs,
1032            rhs.elem(),
1033            u8,
1034            i8,
1035            u16,
1036            i16,
1037            u32,
1038            f32,
1039            i32,
1040            u64,
1041            i64,
1042            f64
1043        );
1044
1045        lhs.mapv(|a| a > rhs).into_shared()
1046    }
1047
1048    pub(crate) fn greater_equal(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {
1049        let (lhs, rhs) = dispatch_cmp_simd!(
1050            E,
1051            VecGreaterEq,
1052            lhs,
1053            rhs,
1054            u8,
1055            i8,
1056            u16,
1057            i16,
1058            u32,
1059            f32,
1060            i32,
1061            u64,
1062            i64,
1063            f64
1064        );
1065
1066        // Use the helper to broadcast both arrays to a common shape
1067        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1068        // Now we can safely zip and compare
1069        Zip::from(&lhs_broadcast)
1070            .and(&rhs_broadcast)
1071            .map_collect(|&lhs, &rhs| lhs >= rhs)
1072            .into_shared()
1073    }
1074
1075    pub(crate) fn greater_equal_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {
1076        let lhs = dispatch_cmp_scalar_simd!(
1077            E,
1078            VecGreaterEq,
1079            lhs,
1080            rhs.elem(),
1081            u8,
1082            i8,
1083            u16,
1084            i16,
1085            u32,
1086            f32,
1087            i32,
1088            u64,
1089            i64,
1090            f64
1091        );
1092
1093        lhs.mapv(|a| a >= rhs).into_shared()
1094    }
1095
1096    pub(crate) fn lower_equal(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {
1097        let (lhs, rhs) = dispatch_cmp_simd!(
1098            E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
1099        );
1100
1101        // Use the helper to broadcast both arrays to a common shape
1102        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1103        // Now we can safely zip and compare
1104        Zip::from(&lhs_broadcast)
1105            .and(&rhs_broadcast)
1106            .map_collect(|&lhs, &rhs| lhs <= rhs)
1107            .into_shared()
1108    }
1109
1110    pub(crate) fn lower_equal_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {
1111        let lhs = dispatch_cmp_scalar_simd!(
1112            E,
1113            VecLowerEq,
1114            lhs,
1115            rhs.elem(),
1116            u8,
1117            i8,
1118            u16,
1119            i16,
1120            u32,
1121            f32,
1122            i32,
1123            u64,
1124            i64,
1125            f64
1126        );
1127
1128        lhs.mapv(|a| a <= rhs).into_shared()
1129    }
1130
1131    pub(crate) fn lower(lhs: SharedArray<E>, rhs: SharedArray<E>) -> SharedArray<bool> {
1132        let (lhs, rhs) = dispatch_cmp_simd!(
1133            E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
1134        );
1135
1136        // Use the helper to broadcast both arrays to a common shape
1137        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1138
1139        // Now we can safely zip and compare
1140        Zip::from(&lhs_broadcast)
1141            .and(&rhs_broadcast)
1142            .map_collect(|&lhs, &rhs| lhs < rhs)
1143            .into_shared()
1144    }
1145
1146    pub(crate) fn lower_elem(lhs: SharedArray<E>, rhs: E) -> SharedArray<bool> {
1147        let lhs = dispatch_cmp_scalar_simd!(
1148            E,
1149            VecLower,
1150            lhs,
1151            rhs.elem(),
1152            u8,
1153            i8,
1154            u16,
1155            i16,
1156            u32,
1157            f32,
1158            i32,
1159            u64,
1160            i64,
1161            f64
1162        );
1163
1164        lhs.mapv(|a| a < rhs).into_shared()
1165    }
1166}
1167
1168pub struct NdArrayBitOps<I: IntNdArrayElement>(PhantomData<I>);
1169
1170impl<I: IntNdArrayElement> NdArrayBitOps<I> {
1171    pub(crate) fn bitand(lhs: SharedArray<I>, rhs: SharedArray<I>) -> SharedArray<I> {
1172        let (lhs, rhs) =
1173            dispatch_binary_simd!(I, VecBitAnd, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);
1174
1175        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
1176            (a.elem::<i64>() & (b.elem::<i64>())).elem()
1177        })
1178    }
1179
1180    pub(crate) fn bitand_scalar(lhs: SharedArray<I>, rhs: I) -> SharedArray<I> {
1181        let lhs = dispatch_binary_scalar_simd!(
1182            I,
1183            VecBitAnd,
1184            lhs,
1185            rhs.elem(),
1186            i8,
1187            u8,
1188            i16,
1189            u16,
1190            i32,
1191            u32,
1192            i64,
1193            u64
1194        );
1195
1196        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
1197            (a.elem::<i64>() & rhs.elem::<i64>()).elem()
1198        })
1199    }
1200
1201    pub(crate) fn bitor(lhs: SharedArray<I>, rhs: SharedArray<I>) -> SharedArray<I> {
1202        let (lhs, rhs) =
1203            dispatch_binary_simd!(I, VecBitOr, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);
1204
1205        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
1206            (a.elem::<i64>() | (b.elem::<i64>())).elem()
1207        })
1208    }
1209
1210    pub(crate) fn bitor_scalar(lhs: SharedArray<I>, rhs: I) -> SharedArray<I> {
1211        let lhs = dispatch_binary_scalar_simd!(
1212            I,
1213            VecBitOr,
1214            lhs,
1215            rhs.elem(),
1216            i8,
1217            u8,
1218            i16,
1219            u16,
1220            i32,
1221            u32,
1222            i64,
1223            u64
1224        );
1225
1226        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
1227            (a.elem::<i64>() | rhs.elem::<i64>()).elem()
1228        })
1229    }
1230
1231    pub(crate) fn bitxor(lhs: SharedArray<I>, rhs: SharedArray<I>) -> SharedArray<I> {
1232        let (lhs, rhs) =
1233            dispatch_binary_simd!(I, VecBitXor, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);
1234
1235        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
1236            (a.elem::<i64>() ^ (b.elem::<i64>())).elem()
1237        })
1238    }
1239
1240    pub(crate) fn bitxor_scalar(lhs: SharedArray<I>, rhs: I) -> SharedArray<I> {
1241        let lhs = dispatch_binary_scalar_simd!(
1242            I,
1243            VecBitXor,
1244            lhs,
1245            rhs.elem(),
1246            i8,
1247            u8,
1248            i16,
1249            u16,
1250            i32,
1251            u32,
1252            i64,
1253            u64
1254        );
1255
1256        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
1257            (a.elem::<i64>() ^ rhs.elem::<i64>()).elem()
1258        })
1259    }
1260
1261    pub(crate) fn bitnot(tensor: SharedArray<I>) -> SharedArray<I> {
1262        let tensor =
1263            dispatch_unary_simd!(I, VecBitNot, tensor, i8, u8, i16, u16, i32, u32, i64, u64);
1264
1265        NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::<i64>()).elem())
1266    }
1267}
1268
1269pub struct NdArrayBoolOps;
1270
1271// Rust booleans are either `00000000` or `00000001`, so bitwise and/or is fine, but bitwise not would
1272// produce invalid values.
1273impl NdArrayBoolOps {
1274    pub(crate) fn equal(lhs: SharedArray<bool>, rhs: SharedArray<bool>) -> SharedArray<bool> {
1275        #[cfg(feature = "simd")]
1276        let (lhs, rhs) = match try_cmp_simd::<bool, u8, VecEquals>(lhs, rhs) {
1277            Ok(out) => return out,
1278            Err(args) => args,
1279        };
1280
1281        // Use the helper to broadcast both arrays to a common shape
1282        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1283        // Now we can safely zip and compare
1284        Zip::from(&lhs_broadcast)
1285            .and(&rhs_broadcast)
1286            .map_collect(|&lhs, &rhs| lhs == rhs)
1287            .into_shared()
1288    }
1289
1290    pub(crate) fn equal_elem(lhs: SharedArray<bool>, rhs: bool) -> SharedArray<bool> {
1291        #[cfg(feature = "simd")]
1292        let lhs = match try_cmp_scalar_simd::<bool, u8, VecEquals>(lhs, rhs.elem()) {
1293            Ok(out) => return out,
1294            Err(args) => args,
1295        };
1296
1297        lhs.mapv(|a| a == rhs).into_shared()
1298    }
1299
1300    pub(crate) fn and(lhs: SharedArray<bool>, rhs: SharedArray<bool>) -> SharedArray<bool> {
1301        #[cfg(feature = "simd")]
1302        let (lhs, rhs) = match try_binary_simd::<bool, bool, u8, u8, VecBitAnd>(lhs, rhs) {
1303            Ok(out) => return out,
1304            Err(args) => args,
1305        };
1306
1307        // Use the helper to broadcast both arrays to a common shape
1308        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1309        // Now we can safely zip and compare
1310        Zip::from(&lhs_broadcast)
1311            .and(&rhs_broadcast)
1312            .map_collect(|&lhs, &rhs| lhs && rhs)
1313            .into_shared()
1314    }
1315
1316    pub(crate) fn or(lhs: SharedArray<bool>, rhs: SharedArray<bool>) -> SharedArray<bool> {
1317        #[cfg(feature = "simd")]
1318        let (lhs, rhs) = match try_binary_simd::<bool, bool, u8, u8, VecBitOr>(lhs, rhs) {
1319            Ok(out) => return out,
1320            Err(args) => args,
1321        };
1322
1323        // Use the helper to broadcast both arrays to a common shape
1324        let (lhs_broadcast, rhs_broadcast) = broadcast_for_comparison(&lhs, &rhs);
1325        // Now we can safely zip and compare
1326        Zip::from(&lhs_broadcast)
1327            .and(&rhs_broadcast)
1328            .map_collect(|&lhs, &rhs| lhs || rhs)
1329            .into_shared()
1330    }
1331
1332    /// Any element is true - zero-copy for borrowed storage.
1333    pub fn any_view(view: ArrayView<'_, bool, IxDyn>) -> bool {
1334        view.iter().any(|&x| x)
1335    }
1336
1337    /// All elements are true - zero-copy for borrowed storage.
1338    pub fn all_view(view: ArrayView<'_, bool, IxDyn>) -> bool {
1339        view.iter().all(|&x| x)
1340    }
1341}
1342
/// Comparison direction used by `arg` / `arg_view` to select between
/// argmin and argmax behavior.
enum CmpType {
    Min,
    Max,
}
1347
/// Finds the index of the min/max element along `dim`, keeping the reduced
/// dimension with size 1. Delegates to the zero-copy view implementation.
fn arg<E: NdArrayElement, I: NdArrayElement>(
    tensor: SharedArray<E>,
    dim: usize,
    cmp: CmpType,
) -> SharedArray<I> {
    arg_view(tensor.view(), dim, cmp)
}
1355
1356/// View-based argmax/argmin - zero-copy for borrowed storage.
1357fn arg_view<E: NdArrayElement, I: NdArrayElement>(
1358    view: ArrayView<'_, E, IxDyn>,
1359    dim: usize,
1360    cmp: CmpType,
1361) -> SharedArray<I> {
1362    let mut reshape = view.shape().to_vec();
1363    reshape[dim] = 1;
1364
1365    let output = view.map_axis(Axis(dim), |arr| {
1366        // Find the min/max value in the array, and return its index.
1367        let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| {
1368            let cmp = match cmp {
1369                CmpType::Min => e < &acc.0,
1370                CmpType::Max => e > &acc.0,
1371            };
1372
1373            if cmp { (*e, idx) } else { acc }
1374        });
1375
1376        (idx as i64).elem()
1377    });
1378
1379    let output = output.to_shape(Dim(reshape.as_slice())).unwrap();
1380
1381    output.into_shared()
1382}
1383
#[cfg(test)]
mod tests {
    use burn_backend::TensorData;

    use crate::NdArrayTensor;

    use super::*;

    /// Regression test: `cat` must produce a contiguous, row-major result,
    /// even when concatenating along the outermost axis, so downstream ops
    /// can rely on standard strides.
    #[test]
    fn should_generate_row_major_layout_for_cat() {
        // Expected: the input values with a zero interleaved after each
        // element along the innermost axis, in standard layout.
        let expected_shape: &[usize] = &[4, 6, 2];
        let expected_strides: &[isize] = &[12, 2, 1];
        let NdArrayTensor::I32(expected_storage) = NdArrayTensor::from_data(TensorData::from([
            [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]],
            [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]],
            [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]],
            [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]],
        ])) else {
            panic!()
        };
        let expected_array = expected_storage.into_shared();

        let NdArrayTensor::I32(tensor_storage) = NdArrayTensor::from_data(TensorData::from([
            [1, 2, 3, 4, 5, 6],
            [7, 8, 9, 10, 11, 12],
            [13, 14, 15, 16, 17, 18],
            [19, 20, 21, 22, 23, 24],
        ])) else {
            panic!()
        };
        let tensor = tensor_storage.into_shared();

        // unsqueeze dim on the outermost axis
        let array = NdArrayOps::reshape(tensor, Shape::from([4, 6, 1]));
        let NdArrayTensor::I32(zeros_storage) =
            NdArrayTensor::from_data(TensorData::zeros::<i32, _>([4, 6, 1]))
        else {
            panic!()
        };
        let zeros = zeros_storage.into_shared();
        // make `ndarray` concatenate the arrays on the outermost axis
        let array = NdArrayOps::cat([array, zeros].to_vec(), 2);

        // The result must be contiguous, with row-major strides, and hold
        // exactly the interleaved expected values.
        assert!(array.is_standard_layout());
        assert_eq!(array.shape(), expected_shape);
        assert_eq!(array.strides(), expected_strides);
        assert_eq!(
            array.into_iter().collect::<Vec<_>>(),
            expected_array.into_iter().collect::<Vec<_>>(),
        );
    }
}
1435}