use alloc::{vec, vec::Vec};
use core::{fmt::Debug, marker::PhantomData, ops::Range};

use burn_tensor::{ElementConversion, Shape, TensorData, TensorMetadata};
#[cfg(feature = "simd")]
use burn_tensor::{DType, quantization::QuantInputType};
use ndarray::{Array2, Axis, Dim, IntoDimension, IxDyn, SliceInfo, SliceInfoElem, Zip, s};
use num_traits::Signed;
#[cfg(feature = "simd")]
use paste::paste;

#[cfg(not(feature = "std"))]
#[allow(unused_imports)]
use num_traits::Float;

use crate::element::NdArrayElement;
#[cfg(feature = "simd")]
use crate::ops::simd::{
    binary::try_binary_simd,
    binary_elemwise::{
        VecAdd, VecBitAnd, VecBitOr, VecBitXor, VecClamp, VecDiv, VecMax, VecMin, VecMul, VecSub,
        try_binary_scalar_simd,
    },
    cmp::{
        VecEquals, VecGreater, VecGreaterEq, VecLower, VecLowerEq, try_cmp_scalar_simd,
        try_cmp_simd,
    },
    unary::{RecipVec, VecAbs, VecBitNot, try_unary_simd},
};
use crate::{
    IntNdArrayElement,
    ops::macros::{keepdim, mean_dim, prod_dim, sum_dim},
};
use crate::{reshape, tensor::NdArrayTensor};

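/// Generic array operations (slicing, reshaping, concatenation, permutation)
/// shared by all element types.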
pub struct NdArrayOps<E> {
    e: PhantomData<E>,
}

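/// Element-wise math, reduction, and comparison operations backed by `ndarray`,
/// with SIMD fast paths when the `simd` feature is enabled.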
pub(crate) struct NdArrayMathOps<E> {
    e: PhantomData<E>,
}

impl<E> NdArrayOps<E>
where
    E: Copy + Debug + burn_tensor::Element,
{
    pub fn into_data(tensor: NdArrayTensor<E>) -> TensorData {
        tensor.into_data()
    }

    pub fn slice(tensor: NdArrayTensor<E>, ranges: &[Range<usize>]) -> NdArrayTensor<E> {
        let slices = Self::to_slice_args(ranges, tensor.shape().num_dims());
        let array = tensor.array.slice_move(slices.as_slice()).into_shared();

        NdArrayTensor { array }
    }

    pub fn slice_assign(
        tensor: NdArrayTensor<E>,
        ranges: &[Range<usize>],
        value: NdArrayTensor<E>,
    ) -> NdArrayTensor<E> {
        let slices = Self::to_slice_args(ranges, tensor.shape().num_dims());
        let mut array = tensor.array.into_owned();
        array.slice_mut(slices.as_slice()).assign(&value.array);
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn reshape(tensor: NdArrayTensor<E>, shape: Shape) -> NdArrayTensor<E> {
        reshape!(
            ty E,
            shape shape,
            array tensor.array,
            d shape.num_dims()
        )
    }

    pub(crate) fn concatenate(
        arrays: &[ndarray::ArrayView<E, IxDyn>],
        dim: usize,
    ) -> NdArrayTensor<E> {
        let array = ndarray::concatenate(Axis(dim), arrays)
            .unwrap()
            .into_shared();

        // Concatenation can produce a non-contiguous layout; reshape to restore
        // a standard (row-major) layout.
        let array = NdArrayTensor { array };
        Self::reshape(array.clone(), array.shape())
    }

    pub fn cat(tensors: Vec<NdArrayTensor<E>>, dim: usize) -> NdArrayTensor<E> {
        let arrays: Vec<_> = tensors.iter().map(|t| t.array.view()).collect();
        Self::concatenate(&arrays, dim)
    }

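    /// Converts `ranges` into `ndarray` slice arguments. Dimensions beyond
    /// `ranges.len()` are kept whole (full slice).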
    fn to_slice_args(ranges: &[Range<usize>], ndims: usize) -> Vec<SliceInfoElem> {
        let mut slices = vec![SliceInfoElem::NewAxis; ndims];
        for i in 0..ndims {
            if i >= ranges.len() {
                slices[i] = SliceInfoElem::Slice {
                    start: 0,
                    end: None,
                    step: 1,
                }
            } else {
                slices[i] = SliceInfoElem::Slice {
                    start: ranges[i].start as isize,
                    end: Some(ranges[i].end as isize),
                    step: 1,
                }
            }
        }
        slices
    }

    pub fn swap_dims(tensor: NdArrayTensor<E>, dim1: usize, dim2: usize) -> NdArrayTensor<E> {
        let mut array = tensor.array;
        array.swap_axes(dim1, dim2);

        NdArrayTensor::new(array)
    }

    pub fn permute(tensor: NdArrayTensor<E>, axes: &[usize]) -> NdArrayTensor<E> {
        let array = tensor.array.permuted_axes(axes.into_dimension());

        NdArrayTensor::new(array)
    }

    pub(crate) fn expand(tensor: NdArrayTensor<E>, shape: Shape) -> NdArrayTensor<E> {
        let array = tensor
            .array
            .broadcast(shape.dims.into_dimension())
            .expect("The shapes should be broadcastable")
            // Broadcasting returns a view with stride-0 axes, so the data must
            // be copied into an owned array before it can be shared.
            .into_owned()
            .into_shared();
        NdArrayTensor { array }
    }

    pub fn flip(tensor: NdArrayTensor<E>, axes: &[usize]) -> NdArrayTensor<E> {
        let slice_items: Vec<_> = (0..tensor.shape().num_dims())
            .map(|i| {
                if axes.contains(&i) {
                    SliceInfoElem::Slice {
                        start: 0,
                        end: None,
                        step: -1,
                    }
                } else {
                    SliceInfoElem::Slice {
                        start: 0,
                        end: None,
                        step: 1,
                    }
                }
            })
            .collect();
        let slice_info =
            SliceInfo::<Vec<SliceInfoElem>, IxDyn, IxDyn>::try_from(slice_items).unwrap();
        let array = tensor.array.slice(slice_info).into_owned().into_shared();

        NdArrayTensor::new(array)
    }
}

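// The `dispatch_*_simd` macros below try a vectorized kernel for each listed
// primitive dtype and hand the operands back unchanged (via `Err`) when no
// kernel applies, so the caller can fall through to the scalar `ndarray` path.
// The `noq` variants skip the quantized (`QFloat`) dispatch arm. Without the
// `simd` feature they reduce to pass-throughs.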
#[cfg(feature = "simd")]
macro_rules! dispatch_binary_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.q_type {
                    QuantInputType::QInt8 => try_binary_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),
                },
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

#[cfg(not(feature = "simd"))]
macro_rules! dispatch_binary_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
}

#[cfg(feature = "simd")]
macro_rules! dispatch_binary_scalar_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_binary_scalar_simd::<$elem, $elem, $ty, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.q_type {
                    QuantInputType::QInt8 => try_binary_scalar_simd::<$elem, $elem, i8, i8, $op>($lhs, $rhs),
                },
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

#[cfg(not(feature = "simd"))]
macro_rules! dispatch_binary_scalar_simd {
    (noq, $elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
}

#[cfg(feature = "simd")]
macro_rules! dispatch_cmp_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_cmp_simd::<$elem, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.q_type {
                    QuantInputType::QInt8 => try_cmp_simd::<$elem, i8, $op>($lhs, $rhs),
                },
                _ => Err(($lhs, $rhs)),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

#[cfg(not(feature = "simd"))]
macro_rules! dispatch_cmp_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ ($lhs, $rhs) }};
}

#[cfg(feature = "simd")]
macro_rules! dispatch_cmp_scalar_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_cmp_scalar_simd::<$elem, $ty, $op>($lhs, $rhs),)*
                DType::QFloat(strategy) => match strategy.q_type {
                    QuantInputType::QInt8 => try_cmp_scalar_simd::<$elem, i8, $op>($lhs, $rhs),
                },
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

#[cfg(not(feature = "simd"))]
macro_rules! dispatch_cmp_scalar_simd {
    ($elem: ty, $op: ty, $lhs: expr, $rhs: expr, $($ty: ty),*) => {{ $lhs }};
}

#[cfg(feature = "simd")]
macro_rules! dispatch_unary_simd {
    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{
        paste! {
            let simd = match $elem::dtype() {
                $(DType::[<$ty:upper>] => try_unary_simd::<$elem, $elem, $ty, $ty, $op>($lhs),)*
                _ => Err($lhs),
            };
            match simd {
                Ok(out) => return out,
                Err(args) => args,
            }
        }
    }};
}

#[cfg(not(feature = "simd"))]
macro_rules! dispatch_unary_simd {
    ($elem: ty, $op: ty, $lhs: expr, $($ty: ty),*) => {{ $lhs }};
}

impl<E> NdArrayMathOps<E>
where
    E: Copy + NdArrayElement,
{
    pub fn add(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let (lhs, rhs) = dispatch_binary_simd!(
            E, VecAdd, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
        );

        let array = &lhs.array + &rhs.array;
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn add_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E> {
        let lhs = dispatch_binary_scalar_simd!(
            E,
            VecAdd,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        let array = lhs.array + rhs;
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn sub(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let (lhs, rhs) = dispatch_binary_simd!(
            E, VecSub, lhs, rhs, u8, i8, u16, i16, u32, i32, f32, u64, i64, f64
        );

        let array = lhs.array - rhs.array;
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn sub_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E> {
        let lhs = dispatch_binary_scalar_simd!(
            E,
            VecSub,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        let array = lhs.array - rhs;
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn mul(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let (lhs, rhs) =
            dispatch_binary_simd!(noq, E, VecMul, lhs, rhs, u16, i16, u32, i32, f32, f64);

        let array = lhs.array * rhs.array;
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn mul_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E> {
        let lhs = dispatch_binary_scalar_simd!(
            noq,
            E,
            VecMul,
            lhs,
            rhs.elem(),
            u16,
            i16,
            u32,
            i32,
            f32,
            f64
        );

        let array = lhs.array * rhs;
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn div(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let (lhs, rhs) = dispatch_binary_simd!(noq, E, VecDiv, lhs, rhs, f32, f64);

        let array = lhs.array / rhs.array;
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn div_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E> {
        let lhs = dispatch_binary_scalar_simd!(noq, E, VecDiv, lhs, rhs.elem(), f32, f64);

        let array = lhs.array / rhs;
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn remainder(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let array = lhs.array.clone()
            - (lhs.array / rhs.array.clone()).mapv_into(|a| (a.to_f64()).floor().elem())
                * rhs.array;
        let array = array.into_shared();
        NdArrayTensor { array }
    }

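    /// Computes `lhs % rhs` with the result taking the sign of the divisor
    /// (floored modulo), via the `((x % rhs) + rhs) % rhs` identity.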
    pub fn remainder_scalar(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<E>
    where
        E: core::ops::Rem<Output = E>,
    {
        let array = lhs.array.mapv(|x| ((x % rhs) + rhs) % rhs);
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn recip(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let tensor = dispatch_unary_simd!(E, RecipVec, tensor, f32);

        let array = tensor.array.map(|x| 1.elem::<E>() / *x);
        let array = array.into_shared();

        NdArrayTensor { array }
    }

    pub fn mean(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let data = TensorData::from([tensor.array.mean().unwrap()]);
        NdArrayTensor::from_data(data)
    }

    pub fn sum(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let data = TensorData::from([tensor.array.sum()]);
        NdArrayTensor::from_data(data)
    }

    pub fn prod(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let data = TensorData::from([tensor.array.product()]);
        NdArrayTensor::from_data(data)
    }

    pub fn mean_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
        let ndims = tensor.shape().num_dims();
        match ndims {
            d if (1..=6).contains(&d) => keepdim!(dim, tensor, mean),
            _ => panic!("Unsupported number of dimensions: {ndims}"),
        }
    }

    pub fn sum_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
        let ndims = tensor.shape().num_dims();
        match ndims {
            d if (1..=6).contains(&d) => keepdim!(dim, tensor, sum),
            _ => panic!("Unsupported number of dimensions: {ndims}"),
        }
    }

    pub fn prod_dim(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<E> {
        let ndims = tensor.shape().num_dims();
        match ndims {
            d if (1..=6).contains(&d) => keepdim!(dim, tensor, prod),
            _ => panic!("Unsupported number of dimensions: {ndims}"),
        }
    }

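    /// Gathers values along `dim`: `output[..., i] = tensor[..., indices[..., i]]`.
    /// The target axis is swapped to the innermost position, both arrays are
    /// flattened to 2D `[batch, len]`, indexed row by row, then restored.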
    pub fn gather<I: NdArrayElement>(
        dim: usize,
        mut tensor: NdArrayTensor<E>,
        mut indices: NdArrayTensor<I>,
    ) -> NdArrayTensor<E> {
        let ndims = tensor.shape().num_dims();
        if dim != ndims - 1 {
            tensor.array.swap_axes(ndims - 1, dim);
            indices.array.swap_axes(ndims - 1, dim);
        }
        let (shape_tensor, shape_indices) = (tensor.shape(), indices.shape());
        let (size_tensor, size_index) =
            (shape_tensor.dims[ndims - 1], shape_indices.dims[ndims - 1]);
        let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices);

        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array;
        let tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;
        let mut output = Array2::zeros((batch_size, size_index));

        for b in 0..batch_size {
            let indices = indices.slice(s!(b, ..));

            for (i, index) in indices.iter().enumerate() {
                output[[b, i]] = tensor[[b, index.elem::<i64>() as usize]];
            }
        }

        let mut output = NdArrayOps::reshape(
            NdArrayTensor::<E>::new(output.into_shared().into_dyn()),
            shape_indices,
        );

        if dim != ndims - 1 {
            output.array.swap_axes(ndims - 1, dim);
        }

        output
    }

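    /// Scatter-add: accumulates `value` into `tensor` along `dim`, i.e.
    /// `tensor[..., indices[..., i]] += value[..., i]`. `indices` and `value`
    /// must have the same shape.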
    pub fn scatter<I: NdArrayElement>(
        dim: usize,
        mut tensor: NdArrayTensor<E>,
        mut indices: NdArrayTensor<I>,
        mut value: NdArrayTensor<E>,
    ) -> NdArrayTensor<E> {
        let ndims = tensor.shape().num_dims();
        if dim != ndims - 1 {
            tensor.array.swap_axes(ndims - 1, dim);
            indices.array.swap_axes(ndims - 1, dim);
            value.array.swap_axes(ndims - 1, dim);
        }

        let (shape_tensor, shape_indices, shape_value) =
            (tensor.shape(), indices.shape(), value.shape());
        let (size_tensor, size_index, size_value) = (
            shape_tensor.dims[ndims - 1],
            shape_indices.dims[ndims - 1],
            shape_value.dims[ndims - 1],
        );
        let batch_size = Self::gather_batch_size(&shape_tensor, &shape_indices);

        if shape_value != shape_indices {
            panic!(
                "Invalid dimension: the shape of the index tensor should be the same as the value \
                 tensor: Index {:?} value {:?}",
                shape_indices.dims, shape_value.dims
            );
        }

        let indices = NdArrayOps::reshape(indices, Shape::new([batch_size, size_index])).array;
        let value = NdArrayOps::reshape(value, Shape::new([batch_size, size_value])).array;
        let mut tensor = NdArrayOps::reshape(tensor, Shape::new([batch_size, size_tensor])).array;

        for b in 0..batch_size {
            let indices = indices.slice(s!(b, ..));

            for (i, index) in indices.iter().enumerate() {
                let index = index.elem::<i64>() as usize;
                tensor[[b, index]] += value[[b, i]];
            }
        }

        let mut output = NdArrayOps::reshape(
            NdArrayTensor::<E>::new(tensor.into_shared().into_dyn()),
            shape_tensor,
        );
        if dim != ndims - 1 {
            output.array.swap_axes(ndims - 1, dim);
        }
        output
    }

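    /// Returns a tensor whose elements come from `source` where the
    /// corresponding `mask` element is `true`, and from `tensor` otherwise.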
    pub fn mask_where(
        tensor: NdArrayTensor<E>,
        mask: NdArrayTensor<bool>,
        source: NdArrayTensor<E>,
    ) -> NdArrayTensor<E> {
        let tensor = tensor.array.broadcast(mask.array.dim()).unwrap();
        let source = source.array.broadcast(mask.array.dim()).unwrap();
        let output = Zip::from(&tensor)
            .and(&mask.array)
            .and(&source)
            .map_collect(|&x, &mask_val, &y| if mask_val { y } else { x })
            .into_shared();
        NdArrayTensor::new(output)
    }

    pub fn mask_fill(
        tensor: NdArrayTensor<E>,
        mask: NdArrayTensor<bool>,
        value: E,
    ) -> NdArrayTensor<E> {
        let mut output = tensor.array.clone();
        let broadcast_mask = mask.array.broadcast(output.dim()).unwrap();
        Zip::from(&mut output)
            .and(&broadcast_mask)
            .for_each(|out, &mask_val| {
                if mask_val {
                    *out = value;
                }
            });
        NdArrayTensor::new(output.into_shared())
    }

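    /// Computes the product of all but the last dimension, panicking if the
    /// leading dimensions of the tensor and index shapes differ.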
    fn gather_batch_size(shape_tensor: &Shape, shape_indices: &Shape) -> usize {
        let ndims = shape_tensor.num_dims();
        let mut batch_size = 1;

        for i in 0..ndims - 1 {
            if shape_tensor.dims[i] != shape_indices.dims[i] {
                panic!(
                    "Unsupported dimension, only the last dimension can differ: Tensor {:?} Index \
                     {:?}",
                    shape_tensor.dims, shape_indices.dims
                );
            }
            batch_size *= shape_indices.dims[i];
        }

        batch_size
    }

    pub fn select<I: NdArrayElement>(
        tensor: NdArrayTensor<E>,
        dim: usize,
        indices: NdArrayTensor<I>,
    ) -> NdArrayTensor<E> {
        let array = tensor.array.select(
            Axis(dim),
            &indices
                .array
                .into_iter()
                .map(|i| i.elem::<i64>() as usize)
                .collect::<Vec<_>>(),
        );

        NdArrayTensor::new(array.into_shared())
    }

    pub fn select_assign<I: NdArrayElement>(
        tensor: NdArrayTensor<E>,
        dim: usize,
        indices: NdArrayTensor<I>,
        value: NdArrayTensor<E>,
    ) -> NdArrayTensor<E> {
        let mut output_array = tensor.array.into_owned();

        for (index_value, index) in indices.array.into_iter().enumerate() {
            let mut view = output_array.index_axis_mut(Axis(dim), index.elem::<i64>() as usize);
            let value = value.array.index_axis(Axis(dim), index_value);

            view.zip_mut_with(&value, |a, b| *a += *b);
        }

        NdArrayTensor::new(output_array.into_shared())
    }

    pub fn argmax<I: NdArrayElement>(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
        arg(tensor, dim, CmpType::Max)
    }

    pub fn argmin<I: NdArrayElement>(tensor: NdArrayTensor<E>, dim: usize) -> NdArrayTensor<I> {
        arg(tensor, dim, CmpType::Min)
    }

    pub fn clamp_min(tensor: NdArrayTensor<E>, min: E) -> NdArrayTensor<E> {
        let mut tensor = dispatch_binary_scalar_simd!(
            E,
            VecMax,
            tensor,
            min.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        tensor.array.mapv_inplace(|x| match x < min {
            true => min,
            false => x,
        });

        tensor
    }

    pub fn clamp_max(tensor: NdArrayTensor<E>, max: E) -> NdArrayTensor<E> {
        let mut tensor = dispatch_binary_scalar_simd!(
            E,
            VecMin,
            tensor,
            max.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        tensor.array.mapv_inplace(|x| match x > max {
            true => max,
            false => x,
        });

        tensor
    }

    pub fn clamp(tensor: NdArrayTensor<E>, min: E, max: E) -> NdArrayTensor<E> {
        let mut tensor = dispatch_binary_scalar_simd!(
            E,
            VecClamp,
            tensor,
            (min.elem(), max.elem()),
            u8,
            i8,
            u16,
            i16,
            u32,
            i32,
            f32,
            u64,
            i64,
            f64
        );

        tensor.array.mapv_inplace(|x| match x < min {
            true => min,
            false => match x > max {
                true => max,
                false => x,
            },
        });

        tensor
    }

    pub(crate) fn elementwise_op<OtherE>(
        lhs: NdArrayTensor<E>,
        rhs: NdArrayTensor<OtherE>,
        op: impl FnMut(&E, &OtherE) -> E,
    ) -> NdArrayTensor<E> {
        let lhs = lhs
            .array
            .broadcast(rhs.array.dim())
            .unwrap_or(lhs.array.view());
        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());

        NdArrayTensor::new(Zip::from(lhs).and(rhs).map_collect(op).into_shared())
    }

    pub(crate) fn elementwise_op_scalar(
        lhs: NdArrayTensor<E>,
        op: impl FnMut(E) -> E,
    ) -> NdArrayTensor<E> {
        NdArrayTensor::new(lhs.array.mapv(op).into_shared())
    }

    pub(crate) fn sign_op(tensor: NdArrayTensor<E>) -> NdArrayTensor<E>
    where
        E: Signed,
    {
        let zero = 0.elem();
        let one = 1.elem::<E>();
        NdArrayTensor::new(
            tensor
                .array
                .mapv(|x| {
                    if x > zero {
                        one
                    } else if x < zero {
                        -one
                    } else {
                        zero
                    }
                })
                .into_shared(),
        )
    }

    pub(crate) fn abs(tensor: NdArrayTensor<E>) -> NdArrayTensor<E> {
        let tensor = dispatch_unary_simd!(E, VecAbs, tensor, i8, i16, i32, f32, f64);

        let array = tensor.array.mapv_into(|a| a.abs_elem()).into_shared();

        NdArrayTensor::new(array)
    }

    pub(crate) fn equal(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<bool> {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E, VecEquals, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
        );

        let output = Zip::from(&lhs.array)
            .and(&rhs.array)
            .map_collect(|&lhs_val, &rhs_val| lhs_val == rhs_val)
            .into_shared();
        NdArrayTensor::new(output)
    }

    pub(crate) fn equal_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecEquals,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        let array = lhs.array.mapv(|a| a == rhs).into_shared();
        NdArrayTensor { array }
    }

    pub(crate) fn greater(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<bool> {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E, VecGreater, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
        );

        let lhs = lhs
            .array
            .broadcast(rhs.array.dim())
            .unwrap_or(lhs.array.view());
        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());

        NdArrayTensor::new(
            Zip::from(lhs)
                .and(rhs)
                .map_collect(|lhs, rhs| lhs > rhs)
                .into_shared(),
        )
    }

    pub(crate) fn greater_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecGreater,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        let array = lhs.array.mapv(|a| a > rhs).into_shared();
        NdArrayTensor { array }
    }

    pub(crate) fn greater_equal(
        lhs: NdArrayTensor<E>,
        rhs: NdArrayTensor<E>,
    ) -> NdArrayTensor<bool> {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E,
            VecGreaterEq,
            lhs,
            rhs,
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        let lhs = lhs
            .array
            .broadcast(rhs.array.dim())
            .unwrap_or(lhs.array.view());
        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());

        NdArrayTensor::new(
            Zip::from(lhs)
                .and(rhs)
                .map_collect(|lhs, rhs| lhs >= rhs)
                .into_shared(),
        )
    }

    pub(crate) fn greater_equal_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecGreaterEq,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        let array = lhs.array.mapv(|a| a >= rhs).into_shared();
        NdArrayTensor { array }
    }

    pub(crate) fn lower_equal(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<bool> {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E, VecLowerEq, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
        );

        let lhs = lhs
            .array
            .broadcast(rhs.array.dim())
            .unwrap_or(lhs.array.view());
        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());

        NdArrayTensor::new(
            Zip::from(lhs)
                .and(rhs)
                .map_collect(|lhs, rhs| lhs <= rhs)
                .into_shared(),
        )
    }

    pub(crate) fn lower_equal_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecLowerEq,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        let array = lhs.array.mapv(|a| a <= rhs).into_shared();
        NdArrayTensor { array }
    }

    pub(crate) fn lower(lhs: NdArrayTensor<E>, rhs: NdArrayTensor<E>) -> NdArrayTensor<bool> {
        let (lhs, rhs) = dispatch_cmp_simd!(
            E, VecLower, lhs, rhs, u8, i8, u16, i16, u32, f32, i32, u64, i64, f64
        );

        let lhs = lhs
            .array
            .broadcast(rhs.array.dim())
            .unwrap_or(lhs.array.view());
        let rhs = rhs.array.broadcast(lhs.dim()).unwrap_or(rhs.array.view());

        NdArrayTensor::new(
            Zip::from(lhs)
                .and(rhs)
                .map_collect(|lhs, rhs| lhs < rhs)
                .into_shared(),
        )
    }

    pub(crate) fn lower_elem(lhs: NdArrayTensor<E>, rhs: E) -> NdArrayTensor<bool> {
        let lhs = dispatch_cmp_scalar_simd!(
            E,
            VecLower,
            lhs,
            rhs.elem(),
            u8,
            i8,
            u16,
            i16,
            u32,
            f32,
            i32,
            u64,
            i64,
            f64
        );

        let array = lhs.array.mapv(|a| a < rhs).into_shared();
        NdArrayTensor { array }
    }
}

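/// Bitwise operations on integer tensors, performed through an `i64`
/// intermediate representation.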
pub struct NdArrayBitOps<I: IntNdArrayElement>(PhantomData<I>);

impl<I: IntNdArrayElement> NdArrayBitOps<I> {
    pub(crate) fn bitand(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
        let (lhs, rhs) =
            dispatch_binary_simd!(I, VecBitAnd, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);

        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
            (a.elem::<i64>() & (b.elem::<i64>())).elem()
        })
    }

    pub(crate) fn bitand_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
        let lhs = dispatch_binary_scalar_simd!(
            I,
            VecBitAnd,
            lhs,
            rhs.elem(),
            i8,
            u8,
            i16,
            u16,
            i32,
            u32,
            i64,
            u64
        );

        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
            (a.elem::<i64>() & rhs.elem::<i64>()).elem()
        })
    }

    pub(crate) fn bitor(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
        let (lhs, rhs) =
            dispatch_binary_simd!(I, VecBitOr, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);

        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
            (a.elem::<i64>() | (b.elem::<i64>())).elem()
        })
    }

    pub(crate) fn bitor_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
        let lhs = dispatch_binary_scalar_simd!(
            I,
            VecBitOr,
            lhs,
            rhs.elem(),
            i8,
            u8,
            i16,
            u16,
            i32,
            u32,
            i64,
            u64
        );

        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
            (a.elem::<i64>() | rhs.elem::<i64>()).elem()
        })
    }

    pub(crate) fn bitxor(lhs: NdArrayTensor<I>, rhs: NdArrayTensor<I>) -> NdArrayTensor<I> {
        let (lhs, rhs) =
            dispatch_binary_simd!(I, VecBitXor, lhs, rhs, i8, u8, i16, u16, i32, u32, i64, u64);

        NdArrayMathOps::elementwise_op(lhs, rhs, |a: &I, b: &I| {
            (a.elem::<i64>() ^ (b.elem::<i64>())).elem()
        })
    }

    pub(crate) fn bitxor_scalar(lhs: NdArrayTensor<I>, rhs: I) -> NdArrayTensor<I> {
        let lhs = dispatch_binary_scalar_simd!(
            I,
            VecBitXor,
            lhs,
            rhs.elem(),
            i8,
            u8,
            i16,
            u16,
            i32,
            u32,
            i64,
            u64
        );

        NdArrayMathOps::elementwise_op_scalar(lhs, |a: I| {
            (a.elem::<i64>() ^ rhs.elem::<i64>()).elem()
        })
    }

    pub(crate) fn bitnot(tensor: NdArrayTensor<I>) -> NdArrayTensor<I> {
        let tensor =
            dispatch_unary_simd!(I, VecBitNot, tensor, i8, u8, i16, u16, i32, u32, i64, u64);

        NdArrayMathOps::elementwise_op_scalar(tensor, |a: I| (!a.elem::<i64>()).elem())
    }
}

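/// Logical operations on boolean tensors; the SIMD fast path treats the
/// underlying storage as `u8`.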
pub struct NdArrayBoolOps;

impl NdArrayBoolOps {
    pub(crate) fn equal(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
        #[cfg(feature = "simd")]
        let (lhs, rhs) = match try_cmp_simd::<bool, u8, VecEquals>(lhs, rhs) {
            Ok(out) => return out,
            Err(args) => args,
        };

        let output = Zip::from(&lhs.array)
            .and(&rhs.array)
            .map_collect(|&lhs_val, &rhs_val| lhs_val == rhs_val)
            .into_shared();
        NdArrayTensor::new(output)
    }

    pub(crate) fn and(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
        #[cfg(feature = "simd")]
        let (lhs, rhs) = match try_binary_simd::<bool, bool, u8, u8, VecBitAnd>(lhs, rhs) {
            Ok(out) => return out,
            Err(args) => args,
        };

        let output = Zip::from(&lhs.array)
            .and(&rhs.array)
            .map_collect(|&lhs_val, &rhs_val| lhs_val && rhs_val)
            .into_shared();
        NdArrayTensor::new(output)
    }

    pub(crate) fn or(lhs: NdArrayTensor<bool>, rhs: NdArrayTensor<bool>) -> NdArrayTensor<bool> {
        #[cfg(feature = "simd")]
        let (lhs, rhs) = match try_binary_simd::<bool, bool, u8, u8, VecBitOr>(lhs, rhs) {
            Ok(out) => return out,
            Err(args) => args,
        };

        let output = Zip::from(&lhs.array)
            .and(&rhs.array)
            .map_collect(|&lhs_val, &rhs_val| lhs_val || rhs_val)
            .into_shared();
        NdArrayTensor::new(output)
    }
}

enum CmpType {
    Min,
    Max,
}

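/// Shared implementation of `argmax`/`argmin`: folds over each lane along
/// `dim`, keeping the index of the first extremal element, and returns the
/// indices with the reduced dimension kept as size 1.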
fn arg<E: NdArrayElement, I: NdArrayElement>(
    tensor: NdArrayTensor<E>,
    dim: usize,
    cmp: CmpType,
) -> NdArrayTensor<I> {
    let mut reshape = tensor.array.shape().to_vec();
    reshape[dim] = 1;

    let output = tensor.array.map_axis(Axis(dim), |arr| {
        let (_e, idx) = arr.indexed_iter().fold((arr[0], 0usize), |acc, (idx, e)| {
            let cmp = match cmp {
                CmpType::Min => e < &acc.0,
                CmpType::Max => e > &acc.0,
            };

            if cmp { (*e, idx) } else { acc }
        });

        (idx as i64).elem()
    });

    let output = output.to_shape(Dim(reshape.as_slice())).unwrap();

    NdArrayTensor {
        array: output.into_shared(),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn should_generate_row_major_layout_for_cat() {
        let expected_shape: &[usize] = &[4, 6, 2];
        let expected_strides: &[isize] = &[12, 2, 1];
        let expected_array: NdArrayTensor<i32> = NdArrayTensor::from_data(TensorData::from([
            [[1, 0], [2, 0], [3, 0], [4, 0], [5, 0], [6, 0]],
            [[7, 0], [8, 0], [9, 0], [10, 0], [11, 0], [12, 0]],
            [[13, 0], [14, 0], [15, 0], [16, 0], [17, 0], [18, 0]],
            [[19, 0], [20, 0], [21, 0], [22, 0], [23, 0], [24, 0]],
        ]));

        // Reshape [4, 6] -> [4, 6, 1] so zeros can be concatenated on a new last dim.
        let array = NdArrayOps::reshape(
            NdArrayTensor::<i32>::from_data(TensorData::from([
                [1, 2, 3, 4, 5, 6],
                [7, 8, 9, 10, 11, 12],
                [13, 14, 15, 16, 17, 18],
                [19, 20, 21, 22, 23, 24],
            ])),
            Shape::from([4, 6, 1]),
        );
        let zeros = NdArrayTensor::<i32>::from_data(TensorData::zeros::<i32, _>([4, 6, 1]));
        // Concatenating along dim 2 must still yield a standard (row-major) layout.
        let array = NdArrayOps::cat([array, zeros].to_vec(), 2);

        assert!(array.array.is_standard_layout());
        assert_eq!(array.array.shape(), expected_shape);
        assert_eq!(array.array.strides(), expected_strides);
        assert_eq!(
            array.array.into_iter().collect::<Vec<_>>(),
            expected_array.array.into_iter().collect::<Vec<_>>(),
        );
    }
}