burn_ndarray/ops/
base.rs

1use alloc::{vec, vec::Vec};
2use burn_tensor::ElementConversion;
3use burn_tensor::TensorData;
4use burn_tensor::TensorMetadata;
5#[cfg(feature = "simd")]
6use burn_tensor::{DType, quantization::QuantInputType};
7use core::fmt::Debug;
8use core::{marker::PhantomData, ops::Range};
9use ndarray::Array2;
10use ndarray::IntoDimension;
11use ndarray::SliceInfo;
12use ndarray::Zip;
13use ndarray::s;
14use num_traits::Signed;
15#[cfg(feature = "simd")]
16use paste::paste;
17
18#[cfg(not(feature = "std"))]
19#[allow(unused_imports)]
20use num_traits::Float;
21
22use burn_tensor::Shape;
23use ndarray::Axis;
24use ndarray::Dim;
25use ndarray::IxDyn;
26use ndarray::SliceInfoElem;
27
28use crate::element::NdArrayElement;
29#[cfg(feature = "simd")]
30use crate::ops::simd::{
31    binary::try_binary_simd,
32    binary_elemwise::{
33        VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecClamp, VecDiv, VecMax, VecMin, VecMul, VecSub,
34        try_binary_scalar_simd,
35    },
36    cmp::{
37        VecEquals, VecGreater, VecGreaterEq, VecLower, VecLowerEq, try_cmp_scalar_simd,
38        try_cmp_simd,
39    },
40    unary::{RecipVec, VecAbs, VecBitNot, try_unary_simd},
41};
42use crate::{
43    IntNdArrayElement,
44    ops::macros::{keepdim, mean_dim, prod_dim, sum_dim},
45};
46use crate::{reshape, tensor::NdArrayTensor};
47
/// Marker type grouping the generic array operations (slice, reshape,
/// concatenation, permutation, ...) for element type `E`.
///
/// Used only through its associated functions; the `PhantomData` field merely
/// carries the element type.
pub struct NdArrayOps<E> {
    e: PhantomData<E>,
}
51
/// Marker type grouping the arithmetic, reduction and comparison operations
/// for element type `E`.
///
/// Used only through its associated functions; the `PhantomData` field merely
/// carries the element type.
pub(crate) struct NdArrayMathOps<E> {
    e: PhantomData<E>,
}
55
56impl<E> NdArrayOps<E>
57where
58    E: Copy + Debug + burn_tensor::Element,
59{
60    pub fn into_data(tensor: NdArrayTensor<E>) -> TensorData {
61        tensor.into_data()
62    }
63
64    pub fn slice(tensor: NdArrayTensor<E>, ranges: &[Range<usize>]) -> NdArrayTensor<E> {
65        let slices = Self::to_slice_args(ranges, tensor.shape().num_dims());
66        let array = tensor.array.slice_move(slices.as_slice()).into_shared();
67
68        NdArrayTensor { array }
69    }
70
71    pub fn slice_assign(
72        tensor: NdArrayTensor<E>,
73        ranges: &[Range<usize>],
74        value: NdArrayTensor<E>,
75    ) -> NdArrayTensor<E> {
76        let slices = Self::to_slice_args(ranges, tensor.shape().num_dims());
77        let mut array = tensor.array.into_owned();
78        array.slice_mut(slices.as_slice()).assign(&value.array);
79        let array = array.into_shared();
80
81        NdArrayTensor { array }
82    }
83
84    pub fn reshape(tensor: NdArrayTensor<E>, shape: Shape) -> NdArrayTensor<E> {
85        reshape!(
86            ty E,
87            shape shape,
88            array tensor.array,
89            d shape.num_dims()
90        )
91    }
92
93    pub(crate) fn concatenate(
94        arrays: &[ndarray::ArrayView<E, IxDyn>],
95        dim: usize,
96    ) -> NdArrayTensor<E> {
97        let array = ndarray::concatenate(Axis(dim), arrays)
98            .unwrap()
99            .into_shared();
100
101        // Transform column-major layout into row-major (standard) layout. (fix #1053)
102        let array = NdArrayTensor { array };
103        Self::reshape(array.clone(), array.shape())
104    }
105
106    pub fn cat(tensors: Vec<NdArrayTensor<E>>, dim: usize) -> NdArrayTensor<E> {
107        let arrays: Vec<_> = tensors.iter().map(|t| t.array.view()).collect();
108        Self::concatenate(&arrays, dim)
109    }
110
111    fn to_slice_args(ranges: &[Range<usize>], ndims: usize) -> Vec<SliceInfoElem> {
112        let mut slices = vec![SliceInfoElem::NewAxis; ndims];
113        for i in 0..ndims {
114            if i >= ranges.len() {
115                slices[i] = SliceInfoElem::Slice {
116                    start: 0,
117                    end: None,
118                    step: 1,
119                }
120            } else {
121                slices[i] = SliceInfoElem::Slice {
122                    start: ranges[i].start as isize,
123                    end: Some(ranges[i].end as isize),
124                    step: 1,
125                }
126            }
127        }
128        slices
129    }
130
131    pub fn swap_dims(tensor: NdArrayTensor<E>, dim1: usize, dim2: usize) -> NdArrayTensor<E> {
132        let mut array = tensor.array;
133        array.swap_axes(dim1, dim2);
134
135        NdArrayTensor::new(array)
136    }
137
138    pub fn permute(tensor: NdArrayTensor<E>, axes: &[usize]) -> NdArrayTensor<E> {
139        let array = tensor.array.permuted_axes(axes.into_dimension());
140
141        NdArrayTensor::new(array)
142    }
143
144    /// Broadcasts the tensor to the given shape
145    pub(crate) fn expand(tensor: NdArrayTensor<E>, shape: Shape) -> NdArrayTensor<E> {
146        let array = tensor
147            .array
148            .broadcast(shape.dims.into_dimension())
149            .expect("The shapes should be broadcastable")
150            // need to convert view to owned array because NdArrayTensor expects owned array
151            // and try_into_owned_nocopy() panics for broadcasted arrays (zero strides)
152            .into_owned()
153            .into_shared();
154        NdArrayTensor { array }
155    }
156
157    pub fn flip(tensor: NdArrayTensor<E>, axes: &[usize]) -> NdArrayTensor<E> {
158        let slice_items: Vec<_> = (0..tensor.shape().num_dims())
159            .map(|i| {
160                if axes.contains(&i) {
161                    SliceInfoElem::Slice {
162                        start: 0,
163                        end: None,
164                        step: -1,
165                    }
166                } else {
167                    SliceInfoElem::Slice {
168                        start: 0,
169                        end: None,
170                        step: 1,
171                    }
172                }
173            })
174            .collect();
175        let slice_info =
176            SliceInfo::<Vec<SliceInfoElem>, IxDyn, IxDyn>::try_from(slice_items).unwrap();
177        let array = tensor.array.slice(slice_info).into_owned().into_shared();
178
179        NdArrayTensor::new(array)
180    }
181}
182
/// Dispatches a binary element-wise op to the SIMD kernel when the element's
/// runtime `DType` matches one of the listed candidate types.
///
/// On success the SIMD result is returned from the *enclosing function* (note
/// the `return`); on failure the untouched operands are handed back so the
/// caller can run its scalar fallback. The `noq` variant omits the quantized
/// (`QFloat` / `i8`) path — presumably for ops that are unsafe on `i8` lanes
/// (e.g. overflow); confirm intent with the SIMD module.
#[cfg(feature = "simd")]
macro_rules! dispatch_binary_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.q_type {
                    QuantInputType::QInt8 => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),
                },
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

/// No-SIMD build: always hand the operands back unchanged so callers take
/// the scalar path.
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_binary_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
}
219
/// Dispatches a tensor-scalar binary op to the SIMD kernel when the element's
/// runtime `DType` matches one of the listed candidate types.
///
/// On success the SIMD result is returned from the *enclosing function* (note
/// the `return`); on failure the untouched tensor operand is handed back for
/// the scalar fallback. The `noq` variant omits the quantized (`QFloat`)
/// path, mirroring `dispatch_binary_simd`.
#[cfg(feature = "simd")]
macro_rules! dispatch_binary_scalar_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.q_type {
                    QuantInputType::QInt8 => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),
                },
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

/// No-SIMD build: hand the tensor back unchanged; note the scalar `$rhs`
/// expression is never expanded (and thus never evaluated) here.
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_binary_scalar_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
}
256
/// Dispatches a tensor-tensor comparison to the SIMD kernel when the
/// element's runtime `DType` matches one of the listed candidate types;
/// quantized `QInt8` tensors are compared on their raw `i8` values.
///
/// On success the boolean result tensor is returned from the *enclosing
/// function*; on failure the untouched operands are handed back.
#[cfg(feature = "simd")]
macro_rules! dispatch_cmp_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_cmp_simd::<$elem, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.q_type {
                    QuantInputType::QInt8 => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs),
                },
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

/// No-SIMD build: always hand the operands back unchanged.
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_cmp_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
}
280
/// Dispatches a tensor-scalar comparison to the SIMD kernel when the
/// element's runtime `DType` matches one of the listed candidate types;
/// quantized `QInt8` tensors are compared on their raw `i8` values.
///
/// On success the boolean result tensor is returned from the *enclosing
/// function*; on failure the untouched tensor operand is handed back.
#[cfg(feature = "simd")]
macro_rules! dispatch_cmp_scalar_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_cmp_scalar_simd::<$elem, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.q_type {
                    QuantInputType::QInt8 => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs),
                },
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

/// No-SIMD build: hand the tensor back unchanged (the scalar expression is
/// never expanded here).
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_cmp_scalar_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
}
304
/// Dispatches a unary element-wise op to the SIMD kernel when the element's
/// runtime `DType` matches one of the listed candidate types. Note: unlike
/// the binary/cmp dispatchers, there is no quantized (`QFloat`) arm here.
///
/// On success the result is returned from the *enclosing function*; on
/// failure the untouched tensor is handed back for the scalar fallback.
#[cfg(feature = "simd")]
macro_rules! dispatch_unary_simd {
    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_unary_simd::<$elem, $elem, $ty, $ty, $op>($lhs),)*
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

/// No-SIMD build: hand the tensor back unchanged.
#[cfg(not(feature = "simd"))]
macro_rules! dispatch_unary_simd {
    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ $lhs }};
}
325
326impl<E> NdArrayMathOps<E>
327where
328    E: Copy + NdArrayElement,
329{
330    pub fn add(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
331        let (lhs, rhs) = dispatch_binary_simd!(
332            E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
333        );
334
335        let array = &lhs.array + &rhs.array;
336        let array = array.into_shared();
337
338        NdArrayTensor { array }
339    }
340
341    pub fn add_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E> {
342        let lhs = dispatch_binary_scalar_simd!(
343            E,
344            VecAdd,
345            lhs,
346            rhs.elem(),
347            u8,
348            i8,
349            u16,
350            i16,
351            u32,
352            i32,
353            f32,
354            u64,
355            i64,
356            f64
357        );
358
359        let array = lhs.array + rhs;
360        let array = array.into_shared();
361
362        NdArrayTensor { array }
363    }
364
365    pub fn sub(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
366        let (lhs, rhs) = dispatch_binary_simd!(
367            E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
368        );
369
370        let array = lhs.array - rhs.array;
371        let array = array.into_shared();
372
373        NdArrayTensor { array }
374    }
375
376    pub fn sub_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E> {
377        let lhs = dispatch_binary_scalar_simd!(
378            E,
379            VecSub,
380            lhs,
381            rhs.elem(),
382            u8,
383            i8,
384            u16,
385            i16,
386            u32,
387            i32,
388            f32,
389            u64,
390            i64,
391            f64
392        );
393
394        let array = lhs.array - rhs;
395        let array = array.into_shared();
396
397        NdArrayTensor { array }
398    }
399
400    pub fn mul(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
401        let (lhs, rhs) =
402            dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64);
403
404        let array = lhs.array * rhs.array;
405        let array = array.into_shared();
406
407        NdArrayTensor { array }
408    }
409
410    pub fn mul_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E> {
411        let lhs = dispatch_binary_scalar_simd!(
412            noq,
413            E,
414            VecMul,
415            lhs,
416            rhs.elem(),
417            u16,
418            i16,
419            u32,
420            i32,
421            f32,
422            f64
423        );
424
425        let array = lhs.array * rhs;
426        let array = array.into_shared();
427
428        NdArrayTensor { array }
429    }
430
431    pub fn div(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
432        let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64);
433
434        let array = lhs.array / rhs.array;
435        let array = array.into_shared();
436
437        NdArrayTensor { array }
438    }
439
440    pub fn div_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E> {
441        let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64);
442
443        let array = lhs.array / rhs;
444        let array = array.into_shared();
445
446        NdArrayTensor { array }
447    }
448
    /// Element-wise remainder computed as `lhs - floor(lhs / rhs) * rhs`.
    ///
    /// NOTE(review): for integer element types, `lhs.array / rhs.array`
    /// already truncates toward zero, so the later `floor()` on the converted
    /// value is a no-op; negative integer operands may therefore disagree
    /// with the floored modulo used by `remainder_scalar` — verify.
    pub fn remainder(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let array = lhs.array.clone()
            - (lhs.array / rhs.array.clone()).mapv_into(|a| (a.to_f64()).floor().elem())
                * rhs.array;
        let array = array.into_shared();
        NdArrayTensor { array }
    }
456
457    pub fn remainder_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E>
458    where
459        E: core::ops::Rem<Output = E>,
460    {
461        let array = lhs.array.mapv(|x| ((x % rhs) + rhs) % rhs);
462        let array = array.into_shared();
463
464        NdArrayTensor { array }
465    }
466
467    pub fn recip(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
468        let tensor = dispatch_unary_simd!(E, RecipVec, tensor, f32);
469
470        let array = tensor.array.map(|x| 1.elem::<E>() / *x);
471        let array = array.into_shared();
472
473        NdArrayTensor { array }
474    }
475
476    pub fn mean(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
477        let data = TensorData::from([tensor.array.mean().unwrap()]);
478        NdArrayTensor::from_data(data)
479    }
480
481    pub fn sum(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
482        let data = TensorData::from([tensor.array.sum()]);
483        NdArrayTensor::from_data(data)
484    }
485
486    pub fn prod(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
487        let data = TensorData::from([tensor.array.product()]);
488        NdArrayTensor::from_data(data)
489    }
490
491    pub fn mean_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
492        let ndims = tensor.shape().num_dims();
493        match ndims {
494            d if (1..=6).contains(&d) => keepdim!(dim, tensor, mean),
495            _ => panic!("Dim not supported {ndims}"),
496        }
497    }
498
499    pub fn sum_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
500        let ndims = tensor.shape().num_dims();
501        match ndims {
502            d if (1..=6).contains(&d) => keepdim!(dim, tensor, sum),
503            _ => panic!("Dim not supported {ndims}"),
504        }
505    }
506
507    pub fn prod_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
508        let ndims = tensor.shape().num_dims();
509        match ndims {
510            d if (1..=6).contains(&d) => keepdim!(dim, tensor, prod),
511            _ => panic!("Dim not supported {ndims}"),
512        }
513    }
514
    /// Gathers elements along `dim`: each output position takes the element
    /// of `tensor` at the same coordinates, except along `dim` where the
    /// value stored in `indices` is used instead.
    pub fn gather<I: NdArrayElement>(
        dim: usize,
        mut tensor: NdArrayTensor<E>,
        mut indices: NdArrayTensor<I>,
    ) -> NdArrayTensor<E> {
        let ndims = tensor.shape().num_dims();
        // Move the gather axis to the last position so both arrays can be
        // flattened to [batch, len] and walked with plain 2-D indexing.
        if dim != ndims - 1 {
            tensor.array.swap_axes(ndims - 1, dim);
            indices.array.swap_axes(ndims - 1, dim);
        }
        let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape());
        let (size_tensor, size_index) =
            (shape_tensor.dims[ndims - 1], shape_indices.dims[ndims - 1]);
        // Panics if any leading dimension differs between tensor and indices.
        let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices);

        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array;
        let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
        let mut output = Array2::zeros((batch_size, size_index));

        for b in 0..batch_size {
            let indices = indices.slice(s!(b, ..));

            for (i, index) in indices.iter().enumerate() {
                // NOTE(review): negative index values would wrap via the
                // `as usize` cast before ndarray's bounds check — verify
                // callers guarantee non-negative indices.
                output[[b, i]] = tensor[[b, index.elem::<i64>() as usize]];
            }
        }

        // Restore the original rank, then undo the axis swap.
        let mut output = NdArrayOps::reshape(
            NdArrayTensor::<E>::new(output.into_shared().into_dyn()),
            shape_indices,
        );

        if dim != ndims - 1 {
            output.array.swap_axes(ndims - 1, dim);
        }

        output
    }
553
    /// Scatter-add along `dim`: each element of `value` is accumulated
    /// (`+=`, so duplicate indices add up) into `tensor` at the position
    /// given by the corresponding entry of `indices`.
    ///
    /// # Panics
    /// Panics if `indices` and `value` differ in shape, or if a leading
    /// dimension differs from `tensor` (see `gather_batch_size`).
    pub fn scatter<I: NdArrayElement>(
        dim: usize,
        mut tensor: NdArrayTensor<E>,
        mut indices: NdArrayTensor<I>,
        mut value: NdArrayTensor<E>,
    ) -> NdArrayTensor<E> {
        let ndims = tensor.shape().num_dims();
        // Move the scatter axis to the last position so all three arrays can
        // be flattened to [batch, len].
        if dim != ndims - 1 {
            tensor.array.swap_axes(ndims - 1, dim);
            indices.array.swap_axes(ndims - 1, dim);
            value.array.swap_axes(ndims - 1, dim);
        }

        let (shape_tensor, shape_indices, shape_value) =
            (tensor.shape(), indices.shape(), value.shape());
        let (size_tensor, size_index, size_value) = (
            shape_tensor.dims[ndims - 1],
            shape_indices.dims[ndims - 1],
            shape_value.dims[ndims - 1],
        );
        let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices);

        if shape_value != shape_indices {
            panic!(
                "Invalid dimension: the shape of the index tensor should be the same as the value \
                 tensor: Index {:?} value {:?}",
                shape_indices.dims, shape_value.dims
            );
        }

        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array;
        let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array;
        let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;

        for b in 0..batch_size {
            let indices = indices.slice(s!(b, ..));

            for (i, index) in indices.iter().enumerate() {
                let index = index.elem::<i64>() as usize;
                // Accumulate rather than overwrite.
                tensor[[b, index]] += value[[b, i]];
            }
        }

        // Restore the original rank, then undo the axis swap.
        let mut output = NdArrayOps::reshape(
            NdArrayTensor::<E>::new(tensor.into_shared().into_dyn()),
            shape_tensor,
        );
        if dim != ndims - 1 {
            output.array.swap_axes(ndims - 1, dim);
        }
        output
    }
606
607    pub fn mask_where(
608        tensor: NdArrayTensor<E>,
609        mask: NdArrayTensor<bool>,
610        source: NdArrayTensor<E>,
611    ) -> NdArrayTensor<E> {
612        let tensor = tensor.array.broadcast(mask.array.dim()).unwrap();
613        let source = source.array.broadcast(mask.array.dim()).unwrap();
614        let output = Zip::from(&tensor)
615            .and(&mask.array)
616            .and(&source)
617            .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x })
618            .into_shared();
619        NdArrayTensor::new(output)
620    }
621
622    pub fn mask_fill(
623        tensor: NdArrayTensor<E>,
624        mask: NdArrayTensor<bool>,
625        value: E,
626    ) -> NdArrayTensor<E> {
627        let mut output = tensor.array.clone();
628        let broadcast_mask = mask.array.broadcast(output.dim()).unwrap();
629        Zip::from(&mut output)
630            .and(&broadcast_mask)
631            .for_each(|out, &mask_val| {
632                if mask_val {
633                    *out = value;
634                }
635            });
636        NdArrayTensor::new(output.into_shared())
637    }
638
639    fn gather_batch_size(shape_tensor: &Shape, shape_indices: &Shape) -> usize {
640        let ndims = shape_tensor.num_dims();
641        let mut batch_size = 1;
642
643        for i in 0..ndims - 1 {
644            if shape_tensor.dims[i] != shape_indices.dims[i] {
645                panic!(
646                    "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \
647                     {:?}",
648                    shape_tensor.dims, shape_indices.dims
649                );
650            }
651            batch_size *= shape_indices.dims[i];
652        }
653
654        batch_size
655    }
656
657    pub fn select<I: NdArrayElement>(
658        tensor: NdArrayTensor<E>,
659        dim: usize,
660        indices: NdArrayTensor<I>,
661    ) -> NdArrayTensor<E> {
662        let array = tensor.array.select(
663            Axis(dim),
664            &indices
665                .array
666                .into_iter()
667                .map(|i| i.elem::<i64>() as usize)
668                .collect::<Vec<_>>(),
669        );
670
671        NdArrayTensor::new(array.into_shared())
672    }
673
674    pub fn select_assign<I: NdArrayElement>(
675        tensor: NdArrayTensor<E>,
676        dim: usize,
677        indices: NdArrayTensor<I>,
678        value: NdArrayTensor<E>,
679    ) -> NdArrayTensor<E> {
680        let mut output_array = tensor.array.into_owned();
681
682        for (index_value, index) in indices.array.into_iter().enumerate() {
683            let mut view = output_array.index_axis_mut(Axis(dim), index.elem::<i64>() as usize);
684            let value = value.array.index_axis(Axis(dim), index_value);
685
686            view.zip_mut_with(&value, |a, b| *a += *b);
687        }
688
689        NdArrayTensor::new(output_array.into_shared())
690    }
    /// Indices of the maximum values along `dim`; delegates to the shared
    /// `arg` helper (defined elsewhere in this file) with `CmpType::Max`.
    pub fn argmax<I: NdArrayElement>(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
        arg(tensor, dim, CmpType::Max)
    }
694
    /// Indices of the minimum values along `dim`; delegates to the shared
    /// `arg` helper (defined elsewhere in this file) with `CmpType::Min`.
    pub fn argmin<I: NdArrayElement>(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
        arg(tensor, dim, CmpType::Min)
    }
698
699    pub fn clamp_min(tensor: NdArrayTensor<E>, min: E) -> NdArrayTensor<E> {
700        let mut tensor = dispatch_binary_scalar_simd!(
701            E,
702            VecMax,
703            tensor,
704            min.elem(),
705            u8,
706            i8,
707            u16,
708            i16,
709            u32,
710            i32,
711            f32,
712            u64,
713            i64,
714            f64
715        );
716
717        tensor.array.mapv_inplace(|x| match x < min {
718            true => min,
719            false => x,
720        });
721
722        tensor
723    }
724
725    pub fn clamp_max(tensor: NdArrayTensor<E>, max: E) -> NdArrayTensor<E> {
726        let mut tensor = dispatch_binary_scalar_simd!(
727            E,
728            VecMin,
729            tensor,
730            max.elem(),
731            u8,
732            i8,
733            u16,
734            i16,
735            u32,
736            i32,
737            f32,
738            u64,
739            i64,
740            f64
741        );
742
743        tensor.array.mapv_inplace(|x| match x > max {
744            true => max,
745            false => x,
746        });
747
748        tensor
749    }
750
751    pub fn clamp(tensor: NdArrayTensor<E>, min: E, max: E) -> NdArrayTensor<E> {
752        let mut tensor = dispatch_binary_scalar_simd!(
753            E,
754            VecClamp,
755            tensor,
756            (min.elem(), max.elem()),
757            u8,
758            i8,
759            u16,
760            i16,
761            u32,
762            i32,
763            f32,
764            u64,
765            i64,
766            f64
767        );
768
769        tensor.array.mapv_inplace(|x| match x < min {
770            true => min,
771            false => match x > max {
772                true => max,
773                false => x,
774            },
775        });
776
777        tensor
778    }
779
    /// Applies a binary closure element-wise over both tensors, broadcasting
    /// each side toward the other when possible.
    pub(crate) fn elementwise_op<OtherE>(
        lhs: NdArrayTensor<E>,
        rhs: NdArrayTensor<OtherE>,
        var_name: impl FnMut(&E, &OtherE) -> E,
    ) -> NdArrayTensor<E> {
        // Try to broadcast each operand to the other's shape, keeping the
        // original view when broadcasting fails (shapes already match, or
        // this side is the larger one).
        // NOTE(review): mutual broadcasting (e.g. [1, 3] vs [3, 1]) is not
        // handled and would panic inside `Zip` — confirm callers never rely
        // on it.
        let lhs = lhs
            .array
            .broadcast(rhs.array.dim())
            .unwrap_or(lhs.array.view());
        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());

        NdArrayTensor::new(Zip::from(lhs).and(rhs).map_collect(var_name).into_shared())
    }
793
794    pub(crate) fn elementwise_op_scalar(
795        lhs: NdArrayTensor<E>,
796        var_name: impl FnMut(E) -> E,
797    ) -> NdArrayTensor<E> {
798        NdArrayTensor::new(lhs.array.mapv(var_name).into_shared())
799    }
800
801    pub(crate) fn sign_op(tensor: NdArrayTensor<E>) -> NdArrayTensor<E>
802    where
803        E: Signed,
804    {
805        let zero = 0.elem();
806        let one = 1.elem::<E>();
807        NdArrayTensor::new(
808            tensor
809                .array
810                .mapv(|x| {
811                    if x > zero {
812                        one
813                    } else if x < zero {
814                        -one
815                    } else {
816                        zero
817                    }
818                })
819                .into_shared(),
820        )
821    }
822
823    pub(crate) fn abs(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
824        let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64);
825
826        let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared();
827
828        NdArrayTensor::new(array)
829    }
830
831    pub(crate) fn equal(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<bool> {
832        let (lhs, rhs) = dispatch_cmp_simd!(
833            E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
834        );
835
836        let output = Zip::from(&lhs.array)
837            .and(&rhs.array)
838            .map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
839            .into_shared();
840        NdArrayTensor::new(output)
841    }
842
843    pub(crate) fn equal_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
844        let lhs = dispatch_cmp_scalar_simd!(
845            E,
846            VecEquals,
847            lhs,
848            rhs.elem(),
849            u8,
850            i8,
851            u16,
852            i16,
853            u32,
854            f32,
855            i32,
856            u64,
857            i64,
858            f64
859        );
860
861        let array = lhs.array.mapv(|a| a == rhs).into_shared();
862        NdArrayTensor { array }
863    }
864
865    pub(crate) fn greater(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<bool> {
866        let (lhs, rhs) = dispatch_cmp_simd!(
867            E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
868        );
869
870        let lhs = lhs
871            .array
872            .broadcast(rhs.array.dim())
873            .unwrap_or(lhs.array.view());
874        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());
875
876        NdArrayTensor::new(
877            Zip::from(lhs)
878                .and(rhs)
879                .map_collect(|lhs, rhs| lhs > rhs)
880                .into_shared(),
881        )
882    }
883
884    pub(crate) fn greater_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
885        let lhs = dispatch_cmp_scalar_simd!(
886            E,
887            VecGreater,
888            lhs,
889            rhs.elem(),
890            u8,
891            i8,
892            u16,
893            i16,
894            u32,
895            f32,
896            i32,
897            u64,
898            i64,
899            f64
900        );
901
902        let array = lhs.array.mapv(|a| a > rhs).into_shared();
903        NdArrayTensor { array }
904    }
905
906    pub(crate) fn greater_equal(
907        lhs: NdArrayTensor<E>,
908        rhs: NdArrayTensor<E>,
909    ) -> NdArrayTensor<bool> {
910        let (lhs, rhs) = dispatch_cmp_simd!(
911            E,
912            VecGreaterEq,
913            lhs,
914            rhs,
915            u8,
916            i8,
917            u16,
918            i16,
919            u32,
920            f32,
921            i32,
922            u64,
923            i64,
924            f64
925        );
926
927        let lhs = lhs
928            .array
929            .broadcast(rhs.array.dim())
930            .unwrap_or(lhs.array.view());
931        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());
932
933        NdArrayTensor::new(
934            Zip::from(lhs)
935                .and(rhs)
936                .map_collect(|lhs, rhs| lhs >= rhs)
937                .into_shared(),
938        )
939    }
940
941    pub(crate) fn greater_equal_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
942        let lhs = dispatch_cmp_scalar_simd!(
943            E,
944            VecGreaterEq,
945            lhs,
946            rhs.elem(),
947            u8,
948            i8,
949            u16,
950            i16,
951            u32,
952            f32,
953            i32,
954            u64,
955            i64,
956            f64
957        );
958
959        let array = lhs.array.mapv(|a| a >= rhs).into_shared();
960        NdArrayTensor { array }
961    }
962
963    pub(crate) fn lower_equal(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<bool> {
964        let (lhs, rhs) = dispatch_cmp_simd!(
965            E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
966        );
967
968        let lhs = lhs
969            .array
970            .broadcast(rhs.array.dim())
971            .unwrap_or(lhs.array.view());
972        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());
973
974        NdArrayTensor::new(
975            Zip::from(lhs)
976                .and(rhs)
977                .map_collect(|lhs, rhs| lhs <= rhs)
978                .into_shared(),
979        )
980    }
981
982    pub(crate) fn lower_equal_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
983        let lhs = dispatch_cmp_scalar_simd!(
984            E,
985            VecLowerEq,
986            lhs,
987            rhs.elem(),
988            u8,
989            i8,
990            u16,
991            i16,
992            u32,
993            f32,
994            i32,
995            u64,
996            i64,
997            f64
998        );
999
1000        let array = lhs.array.mapv(|a| a <= rhs).into_shared();
1001        NdArrayTensor { array }
1002    }
1003
1004    pub(crate) fn lower(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<bool> {
1005        let (lhs, rhs) = dispatch_cmp_simd!(
1006            E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
1007        );
1008
1009        let lhs = lhs
1010            .array
1011            .broadcast(rhs.array.dim())
1012            .unwrap_or(lhs.array.view());
1013        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());
1014
1015        NdArrayTensor::new(
1016            Zip::from(lhs)
1017                .and(rhs)
1018                .map_collect(|lhs, rhs| lhs < rhs)
1019                .into_shared(),
1020        )
1021    }
1022
1023    pub(crate) fn lower_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
1024        let lhs = dispatch_cmp_scalar_simd!(
1025            E,
1026            VecLower,
1027            lhs,
1028            rhs.elem(),
1029            u8,
1030            i8,
1031            u16,
1032            i16,
1033            u32,
1034            f32,
1035            i32,
1036            u64,
1037            i64,
1038            f64
1039        );
1040
1041        let array = lhs.array.mapv(|a| a < rhs).into_shared();
1042        NdArrayTensor { array }
1043    }
1044}
1045
1046pub struct NdArrayBitOps<I: IntNdArrayElement>(PhantomData<I>);
1047
1048impl<I: IntNdArrayElement> NdArrayBitOps<I> {
1049    pub(crate) fn bitand(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
1050        let (lhs, rhs) =
1051            dispatch_binary_simd!(I, VecBitAnd, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);
1052
1053        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
1054            (a.elem::<i64>() & (b.elem::<i64>())).elem()
1055        })
1056    }
1057
1058    pub(crate) fn bitand_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
1059        let lhs = dispatch_binary_scalar_simd!(
1060            I,
1061            VecBitAnd,
1062            lhs,
1063            rhs.elem(),
1064            i8,
1065            u8,
1066            i16,
1067            u16,
1068            i32,
1069            u32,
1070            i64,
1071            u64
1072        );
1073
1074        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
1075            (a.elem::<i64>() & rhs.elem::<i64>()).elem()
1076        })
1077    }
1078
1079    pub(crate) fn bitor(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
1080        let (lhs, rhs) =
1081            dispatch_binary_simd!(I, VecBitOr, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);
1082
1083        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
1084            (a.elem::<i64>() | (b.elem::<i64>())).elem()
1085        })
1086    }
1087
1088    pub(crate) fn bitor_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
1089        let lhs = dispatch_binary_scalar_simd!(
1090            I,
1091            VecBitOr,
1092            lhs,
1093            rhs.elem(),
1094            i8,
1095            u8,
1096            i16,
1097            u16,
1098            i32,
1099            u32,
1100            i64,
1101            u64
1102        );
1103
1104        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
1105            (a.elem::<i64>() | rhs.elem::<i64>()).elem()
1106        })
1107    }
1108
1109    pub(crate) fn bitxor(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
1110        let (lhs, rhs) =
1111            dispatch_binary_simd!(I, VecBitXor, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);
1112
1113        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
1114            (a.elem::<i64>() ^ (b.elem::<i64>())).elem()
1115        })
1116    }
1117
1118    pub(crate) fn bitxor_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
1119        let lhs = dispatch_binary_scalar_simd!(
1120            I,
1121            VecBitXor,
1122            lhs,
1123            rhs.elem(),
1124            i8,
1125            u8,
1126            i16,
1127            u16,
1128            i32,
1129            u32,
1130            i64,
1131            u64
1132        );
1133
1134        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
1135            (a.elem::<i64>() ^ rhs.elem::<i64>()).elem()
1136        })
1137    }
1138
1139    pub(crate) fn bitnot(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
1140        let tensor =
1141            dispatch_unary_simd!(I, VecBitNot, tensor, i8, u8, i16, u16, i32, u32, i64, u64);
1142
1143        NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::<i64>()).elem())
1144    }
1145}
1146
1147pub struct NdArrayBoolOps;
1148
1149// Rust booleans are either `00000000` or `00000001`, so bitwise and/or is fine, but bitwise not would
1150// produce invalid values.
1151impl NdArrayBoolOps {
1152    pub(crate) fn equal(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
1153        #[cfg(feature = "simd")]
1154        let (lhs, rhs) = match try_cmp_simd::<bool, u8, VecEquals>(lhs, rhs) {
1155            Ok(out) => return out,
1156            Err(args) => args,
1157        };
1158
1159        let output = Zip::from(&lhs.array)
1160            .and(&rhs.array)
1161            .map_collect(|&lhs_val, &rhs_val| (lhs_val == rhs_val))
1162            .into_shared();
1163        NdArrayTensor::new(output)
1164    }
1165
1166    pub(crate) fn and(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
1167        #[cfg(feature = "simd")]
1168        let (lhs, rhs) = match try_binary_simd::<bool, bool, u8, u8, VecBitAnd>(lhs, rhs) {
1169            Ok(out) => return out,
1170            Err(args) => args,
1171        };
1172
1173        let output = Zip::from(&lhs.array)
1174            .and(&rhs.array)
1175            .map_collect(|&lhs_val, &rhs_val| (lhs_val && rhs_val))
1176            .into_shared();
1177        NdArrayTensor::new(output)
1178    }
1179
1180    pub(crate) fn or(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
1181        #[cfg(feature = "simd")]
1182        let (lhs, rhs) = match try_binary_simd::<bool, bool, u8, u8, VecBitOr>(lhs, rhs) {
1183            Ok(out) => return out,
1184            Err(args) => args,
1185        };
1186
1187        let output = Zip::from(&lhs.array)
1188            .and(&rhs.array)
1189            .map_collect(|&lhs_val, &rhs_val| (lhs_val || rhs_val))
1190            .into_shared();
1191        NdArrayTensor::new(output)
1192    }
1193}
1194
/// Which extreme `arg` should locate along an axis.
enum CmpType {
    // Return the index of the smallest element (argmin).
    Min,
    // Return the index of the largest element (argmax).
    Max,
}
1199
1200fn arg<E: NdArrayElement, I: NdArrayElement>(
1201    tensor: NdArrayTensor<E>,
1202    dim: usize,
1203    cmp: CmpType,
1204) -> NdArrayTensor<I> {
1205    let mut reshape = tensor.array.shape().to_vec();
1206    reshape[dim] = 1;
1207
1208    let output = tensor.array.map_axis(Axis(dim), |arr| {
1209        // Find the min/max value in the array, and return its index.
1210        let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| {
1211            let cmp = match cmp {
1212                CmpType::Min => e < &acc.0,
1213                CmpType::Max => e > &acc.0,
1214            };
1215
1216            if cmp { (*e, idx) } else { acc }
1217        });
1218
1219        (idx as i64).elem()
1220    });
1221
1222    let output = output.to_shape(Dim(reshape.as_slice())).unwrap();
1223
1224    NdArrayTensor {
1225        array: output.into_shared(),
1226    }
1227}
1228
#[cfg(test)]
mod tests {
    use super::*;

    /// Regression test: `cat` must return a contiguous, row-major
    /// (standard-layout) array, not a view with non-contiguous strides.
    #[test]
    fn should_generate_row_major_layout_for_cat() {
        let expected_shape: &[usize] = &[4, 6, 2];
        let expected_strides: &[isize] = &[12, 2, 1];
        let expected_array: NdArrayTensor<i32> = NdArrayTensor::from_data(TensorData::from([
            [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]],
            [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]],
            [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]],
            [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]],
        ]));

        // Append a trailing unit axis: [4, 6] -> [4, 6, 1].
        let array = NdArrayOps::reshape(
            NdArrayTensor::<i32>::from_data(TensorData::from([
                [1, 2, 3, 4, 5, 6],
                [7, 8, 9, 10, 11, 12],
                [13, 14, 15, 16, 17, 18],
                [19, 20, 21, 22, 23, 24],
            ])),
            Shape::from([4, 6, 1]),
        );
        let zeros = NdArrayTensor::<i32>::from_data(TensorData::zeros::<i32, _>([4, 6, 1]));
        // Concatenate along the last (innermost) axis: two [4, 6, 1] arrays -> [4, 6, 2].
        let array = NdArrayOps::cat([array, zeros].to_vec(), 2);

        assert!(array.array.is_standard_layout());
        assert_eq!(array.array.shape(), expected_shape);
        assert_eq!(array.array.strides(), expected_strides);
        assert_eq!(
            array.array.into_iter().collect::<Vec<_>>(),
            expected_array.array.into_iter().collect::<Vec<_>>(),
        );
    }
}