hpt_traits/
tensor.rs

1use hpt_common::error::base::TensorError;
2use hpt_common::{
3    axis::axis::Axis, layout::layout::Layout, shape::shape::Shape, strides::strides::Strides,
4    utils::pointer::Pointer,
5};
6#[cfg(feature = "archsimd")]
7use hpt_types::arch_simd as simd;
8#[cfg(feature = "stdsimd")]
9use hpt_types::std_simd as simd;
10use hpt_types::{
11    dtype::TypeCommon,
12    into_scalar::Cast,
13    type_promote::{FloatOutBinary, FloatOutUnary, NormalOut, NormalOutUnary},
14};
15use std::fmt::Debug;
16use std::{borrow::Borrow, fmt::Display};
17
18#[cfg(target_feature = "avx2")]
19type BoolVector = simd::_256bit::boolx32::boolx32;
20#[cfg(target_feature = "avx512f")]
21type BoolVector = simd::_512bit::boolx64::boolx64;
22#[cfg(any(
23    all(not(target_feature = "avx2"), target_feature = "sse"),
24    target_arch = "arm",
25    target_arch = "aarch64",
26    target_feature = "neon"
27))]
28type BoolVector = simd::_128bit::boolx16::boolx16;
29
30/// A trait for getting information of a Tensor
31pub trait TensorInfo<T> {
32    /// Returns a pointer to the tensor's first data.
33    #[track_caller]
34    fn ptr(&self) -> Pointer<T>;
35
36    /// Returns the size of the tensor based on the shape
37    #[track_caller]
38    fn size(&self) -> usize;
39
40    /// Returns the shape of the tensor.
41    #[track_caller]
42    fn shape(&self) -> &Shape;
43
44    /// Returns the strides of the tensor.
45    #[track_caller]
46    fn strides(&self) -> &Strides;
47
48    /// Returns the layout of the tensor. Layout contains shape and strides.
49    #[track_caller]
50    fn layout(&self) -> &Layout;
51    /// Returns the root tensor, if any.
52    ///
53    /// if the tensor is a view, it will return the root tensor. Otherwise, it will return None.
54    #[track_caller]
55    fn parent(&self) -> Option<Pointer<T>>;
56
57    /// Returns the number of dimensions of the tensor.
58    #[track_caller]
59    fn ndim(&self) -> usize;
60
61    /// Returns whether the tensor is contiguous in memory. View or transpose tensors are not contiguous.
62    #[track_caller]
63    fn is_contiguous(&self) -> bool;
64
65    /// Returns the data type memory size in bytes.
66    #[track_caller]
67    fn elsize() -> usize {
68        size_of::<T>()
69    }
70}
71
72/// A trait for let the object like a tensor
73pub trait TensorLike<T>: Sized {
74    /// directly convert the tensor to raw slice
75    ///
76    /// # Note
77    ///
78    /// This function will return a raw slice of the tensor regardless of the shape and strides.
79    ///
80    /// if you do iteration on the view tensor, you may see unexpected results.
81    fn as_raw(&self) -> &[T];
82
83    /// directly convert the tensor to mutable raw slice
84    ///
85    /// # Note
86    ///
87    /// This function will return a mutable raw slice of the tensor regardless of the shape and strides.
88    ///
89    /// if you do iteration on the view tensor, you may see unexpected results.
90    fn as_raw_mut(&mut self) -> &mut [T];
91
92    /// Returns the tensor as a contiguous tensor.
93    ///
94    /// # Note
95    ///
96    /// This function will return a contiguous tensor. If the tensor is already contiguous, it will return a clone of the tensor.
97    ///
98    /// If the tensor is a view tensor, it will return a new tensor with the same data but with a contiguous layout.
99    fn contiguous(&self) -> Result<Self, TensorError>;
100
101    /// Returns the data type memory size in bytes.
102    fn elsize() -> usize {
103        size_of::<T>()
104    }
105}
106
107/// A trait defines a set of functions to create tensors.
108pub trait TensorCreator<T>
109where
110    Self: Sized,
111{
112    /// the output type of the creator
113    type Output;
114
115    /// Creates a tensor with uninitialized elements of the specified shape.
116    ///
117    /// This function allocates memory for a tensor of the given shape, but the values are uninitialized, meaning they may contain random data.
118    ///
119    /// # Arguments
120    ///
121    /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
122    ///
123    /// # Returns
124    ///
125    /// * A tensor with the specified shape, but with uninitialized data.
126    ///
127    /// # Panics
128    ///
129    /// * This function may panic if the requested shape is invalid or too large for available memory.
130    #[track_caller]
131    fn empty<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
132
133    /// Creates a tensor filled with zeros of the specified shape.
134    ///
135    /// This function returns a tensor where every element is initialized to `0`, with the shape defined by the input.
136    ///
137    /// # Arguments
138    ///
139    /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
140    ///
141    /// # Returns
142    ///
143    /// * A tensor filled with zeros, with the specified shape.
144    ///
145    /// # Panics
146    ///
147    /// * This function may panic if the requested shape is invalid or too large for available memory.
148    #[track_caller]
149    fn zeros<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
150
151    /// Creates a tensor filled with ones of the specified shape.
152    ///
153    /// This function returns a tensor where every element is initialized to `1`, with the shape defined by the input.
154    ///
155    /// # Arguments
156    ///
157    /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
158    ///
159    /// # Returns
160    ///
161    /// * A tensor filled with ones, with the specified shape.
162    ///
163    /// # Panics
164    ///
165    /// * This function may panic if the requested shape is invalid or too large for available memory.
166    #[track_caller]
167    fn ones<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>
168    where
169        u8: Cast<T>;
170
171    /// Creates a tensor with uninitialized elements, having the same shape as the input tensor.
172    ///
173    /// This function returns a tensor with the same shape as the calling tensor, but with uninitialized data.
174    ///
175    /// # Arguments
176    ///
177    /// This function takes no arguments.
178    ///
179    /// # Returns
180    ///
181    /// * A tensor with the same shape as the input, but with uninitialized data.
182    ///
183    /// # Panics
184    ///
185    /// * This function may panic if the shape is too large for available memory.
186    #[track_caller]
187    fn empty_like(&self) -> Result<Self::Output, TensorError>;
188
189    /// Creates a tensor filled with zeros, having the same shape as the input tensor.
190    ///
191    /// This function returns a tensor with the same shape as the calling tensor, with all elements initialized to `0`.
192    ///
193    /// # Arguments
194    ///
195    /// This function takes no arguments.
196    ///
197    /// # Returns
198    ///
199    /// * A tensor with the same shape as the input, filled with zeros.
200    ///
201    /// # Panics
202    ///
203    /// * This function may panic if the shape is too large for available memory.
204    #[track_caller]
205    fn zeros_like(&self) -> Result<Self::Output, TensorError>;
206
207    /// Creates a tensor filled with ones, having the same shape as the input tensor.
208    ///
209    /// This function returns a tensor with the same shape as the calling tensor, with all elements initialized to `1`.
210    ///
211    /// # Arguments
212    ///
213    /// This function takes no arguments.
214    ///
215    /// # Returns
216    ///
217    /// * A tensor with the same shape as the input, filled with ones.
218    ///
219    /// # Panics
220    ///
221    /// * This function may panic if the shape is too large for available memory.
222    #[track_caller]
223    fn ones_like(&self) -> Result<Self::Output, TensorError>
224    where
225        u8: Cast<T>;
226
227    /// Creates a tensor filled with a specified value, with the specified shape.
228    ///
229    /// This function returns a tensor where every element is set to `val`, with the shape defined by the input.
230    ///
231    /// # Arguments
232    ///
233    /// * `val` - The value to fill the tensor with.
234    /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
235    ///
236    /// # Returns
237    ///
238    /// * A tensor filled with `val`, with the specified shape.
239    ///
240    /// # Panics
241    ///
242    /// * This function may panic if the requested shape is invalid or too large for available memory.
243    #[track_caller]
244    fn full<S: Into<Shape>>(val: T, shape: S) -> Result<Self::Output, TensorError>;
245
246    /// Creates a tensor filled with a specified value, having the same shape as the input tensor.
247    ///
248    /// This function returns a tensor where every element is set to `val`, with the same shape as the calling tensor.
249    ///
250    /// # Arguments
251    ///
252    /// * `val` - The value to fill the tensor with.
253    ///
254    /// # Returns
255    ///
256    /// * A tensor with the same shape as the input, filled with `val`.
257    ///
258    /// # Panics
259    ///
260    /// * This function may panic if the shape is too large for available memory.
261    #[track_caller]
262    fn full_like(&self, val: T) -> Result<Self::Output, TensorError>;
263
264    /// Creates a tensor with values within a specified range.
265    ///
266    /// This function generates a 1D tensor with values ranging from `start` (inclusive) to `end` (exclusive).
267    ///
268    /// # Arguments
269    ///
270    /// * `start` - The start of the range.
271    /// * `end` - The end of the range.
272    ///
273    /// # Returns
274    ///
275    /// * A 1D tensor with values ranging from `start` to `end`.
276    ///
277    /// # Panics
278    ///
279    /// * This function will panic if `start` is greater than or equal to `end`, or if the range is too large for available memory.
280    #[track_caller]
281    fn arange<U>(start: U, end: U) -> Result<Self::Output, TensorError>
282    where
283        usize: Cast<T>,
284        U: Cast<i64> + Cast<T> + Copy;
285
286    /// Creates a tensor with values within a specified range with a given step size.
287    ///
288    /// This function generates a 1D tensor with values ranging from `start` (inclusive) to `end` (exclusive),
289    /// incremented by `step`.
290    ///
291    /// # Arguments
292    ///
293    /// * `start` - The start of the range.
294    /// * `end` - The end of the range (exclusive).
295    /// * `step` - The step size between consecutive values.
296    ///
297    /// # Returns
298    ///
299    /// * A 1D tensor with values from `start` to `end`, incremented by `step`.
300    ///
301    /// # Panics
302    ///
303    /// * This function will panic if `step` is zero or if the range and step values are incompatible.
304    #[track_caller]
305    fn arange_step(start: T, end: T, step: T) -> Result<Self::Output, TensorError>
306    where
307        T: Cast<f64> + Cast<usize>,
308        usize: Cast<T>;
309
310    /// Creates a 2D identity matrix with ones on a diagonal and zeros elsewhere.
311    ///
312    /// This function generates a matrix of size `n` by `m`, with ones on the `k`th diagonal (can be offset) and zeros elsewhere.
313    ///
314    /// # Arguments
315    ///
316    /// * `n` - The number of rows in the matrix.
317    /// * `m` - The number of columns in the matrix.
318    /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
319    ///
320    /// # Returns
321    ///
322    /// * A 2D identity matrix with ones on the specified diagonal.
323    ///
324    /// # Panics
325    ///
326    /// * This function will panic if `n` or `m` is zero, or if memory constraints are exceeded.
327    #[track_caller]
328    fn eye(n: usize, m: usize, k: usize) -> Result<Self::Output, TensorError>;
329
330    /// Creates a tensor with evenly spaced values between `start` and `end`.
331    ///
332    /// This function generates a 1D tensor of `num` values, linearly spaced between `start` and `end`.
333    /// If `include_end` is `true`, the `end` value will be included as the last element.
334    ///
335    /// # Arguments
336    ///
337    /// * `start` - The start of the range.
338    /// * `end` - The end of the range.
339    /// * `num` - The number of evenly spaced values to generate.
340    /// * `include_end` - Whether to include the `end` value in the generated tensor.
341    ///
342    /// # Returns
343    ///
344    /// * A 1D tensor with `num` linearly spaced values between `start` and `end`.
345    ///
346    /// # Panics
347    ///
348    /// * This function will panic if `num` is zero or if `num` is too large for available memory.
349    #[track_caller]
350    fn linspace<U>(
351        start: U,
352        end: U,
353        num: usize,
354        include_end: bool,
355    ) -> Result<Self::Output, TensorError>
356    where
357        U: Cast<f64> + Cast<T> + Copy,
358        usize: Cast<T>,
359        f64: Cast<T>;
360
361    /// Creates a tensor with logarithmically spaced values between `start` and `end`.
362    ///
363    /// This function generates a 1D tensor of `num` values spaced evenly on a log scale between `start` and `end`.
364    /// The spacing is based on the logarithm to the given `base`. If `include_end` is `true`, the `end` value will be included.
365    ///
366    /// # Arguments
367    ///
368    /// * `start` - The starting exponent (base `base`).
369    /// * `end` - The ending exponent (base `base`).
370    /// * `num` - The number of logarithmically spaced values to generate.
371    /// * `include_end` - Whether to include the `end` value in the generated tensor.
372    /// * `base` - The base of the logarithm.
373    ///
374    /// # Returns
375    ///
376    /// * A 1D tensor with `num` logarithmically spaced values between `start` and `end`.
377    ///
378    /// # Panics
379    ///
380    /// * This function will panic if `num` is zero or if `base` is less than or equal to zero.
381    #[track_caller]
382    fn logspace(
383        start: T,
384        end: T,
385        num: usize,
386        include_end: bool,
387        base: T,
388    ) -> Result<Self::Output, TensorError>
389    where
390        T: Cast<f64> + num::Float + NormalOut<T, Output = T>,
391        usize: Cast<T>,
392        f64: Cast<T>;
393
394    /// Creates a tensor with geometrically spaced values between `start` and `end`.
395    ///
396    /// This function generates a 1D tensor of `n` values spaced evenly on a geometric scale between `start` and `end`.
397    /// If `include_end` is `true`, the `end` value will be included.
398    ///
399    /// # Arguments
400    ///
401    /// * `start` - The starting value (must be positive).
402    /// * `end` - The ending value (must be positive).
403    /// * `n` - The number of geometrically spaced values to generate.
404    /// * `include_end` - Whether to include the `end` value in the generated tensor.
405    ///
406    /// # Returns
407    ///
408    /// * A 1D tensor with `n` geometrically spaced values between `start` and `end`.
409    ///
410    /// # Panics
411    ///
412    /// * This function will panic if `n` is zero, if `start` or `end` is negative, or if the values result in undefined behavior.
413    #[track_caller]
414    fn geomspace(
415        start: T,
416        end: T,
417        n: usize,
418        include_end: bool,
419    ) -> Result<Self::Output, TensorError>
420    where
421        f64: Cast<T>,
422        usize: Cast<T>,
423        T: Cast<f64>;
424
425    /// Creates a 2D triangular matrix of size `n` by `m`, with ones below or on the `k`th diagonal and zeros elsewhere.
426    ///
427    /// This function generates a matrix with a triangular structure, filled with ones and zeros, based on the diagonal offset and the `low_triangle` flag.
428    ///
429    /// # Arguments
430    ///
431    /// * `n` - The number of rows in the matrix.
432    /// * `m` - The number of columns in the matrix.
433    /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
434    /// * `low_triangle` - If `true`, the matrix will be lower triangular; otherwise, upper triangular.
435    ///
436    /// # Returns
437    ///
438    /// * A 2D triangular matrix of ones and zeros.
439    ///
440    /// # Panics
441    ///
442    /// * This function will panic if `n` or `m` is zero.
443    #[track_caller]
444    fn tri(n: usize, m: usize, k: i64, low_triangle: bool) -> Result<Self::Output, TensorError>
445    where
446        u8: Cast<T>;
447
448    /// Returns the lower triangular part of the matrix, with all elements above the `k`th diagonal set to zero.
449    ///
450    /// This function generates a tensor where the elements above the specified diagonal are set to zero.
451    ///
452    /// # Arguments
453    ///
454    /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
455    ///
456    /// # Returns
457    ///
458    /// * A tensor with its upper triangular part set to zero.
459    ///
460    /// # Panics
461    ///
462    /// * This function should not panic under normal conditions.
463    #[track_caller]
464    fn tril(&self, k: i64) -> Result<Self::Output, TensorError>
465    where
466        T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
467        <T as TypeCommon>::Vec: NormalOut<BoolVector, Output = <T as TypeCommon>::Vec>;
468
469    /// Returns the upper triangular part of the matrix, with all elements below the `k`th diagonal set to zero.
470    ///
471    /// This function generates a tensor where the elements below the specified diagonal are set to zero.
472    ///
473    /// # Arguments
474    ///
475    /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
476    ///
477    /// # Returns
478    ///
479    /// * A tensor with its lower triangular part set to zero.
480    ///
481    /// # Panics
482    ///
483    /// * This function should not panic under normal conditions.
484    #[track_caller]
485    fn triu(&self, k: i64) -> Result<Self::Output, TensorError>
486    where
487        T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
488        <T as TypeCommon>::Vec: NormalOut<BoolVector, Output = <T as TypeCommon>::Vec>;
489
490    /// Creates a 2D identity matrix of size `n` by `n`.
491    ///
492    /// This function generates a square matrix with ones on the main diagonal and zeros elsewhere.
493    ///
494    /// # Arguments
495    ///
496    /// * `n` - The size of the matrix (both rows and columns).
497    ///
498    /// # Returns
499    ///
500    /// * A 2D identity matrix of size `n`.
501    ///
502    /// # Panics
503    ///
504    /// * This function will panic if `n` is zero.
505    #[track_caller]
506    fn identity(n: usize) -> Result<Self::Output, TensorError>
507    where
508        u8: Cast<T>;
509}
510
511/// A trait for tensor memory allocation, this trait only used when we work with generic type
512pub trait TensorAlloc<Output = Self> {
513    /// The tensor data type.
514    type Meta;
515    /// Creates a tensor with the specified shape,
516    ///
517    /// # Note
518    ///
519    /// This function doesn't initialize the tensor's elements.
520    #[track_caller]
521    fn _empty<S: Into<Shape>>(shape: S) -> Result<Output, TensorError>
522    where
523        Self: Sized;
524}
525
526/// A trait typically for argmax and argmin functions.
527pub trait IndexReduce
528where
529    Self: Sized,
530{
531    /// The output tensor type.
532    type Output;
533
534    /// Returns the indices of the maximum values along the specified axis.
535    ///
536    /// The `argmax` function computes the index of the maximum value along the given axis for each slice of the tensor.
537    ///
538    /// # Parameters
539    ///
540    /// - `axis`: The axis along which to compute the index of the maximum value.
541    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
542    ///
543    /// # Returns
544    ///
545    /// - `anyhow::Result<Self::Output>`: A tensor containing the indices of the maximum values.
546    ///
547    /// # See Also
548    ///
549    /// - [`argmin`]: Returns the indices of the minimum values along the specified axis.
550    #[track_caller]
551    fn argmax<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
552
553    /// Returns the indices of the minimum values along the specified axis.
554    ///
555    /// The `argmin` function computes the index of the minimum value along the given axis for each slice of the tensor.
556    ///
557    /// # Parameters
558    ///
559    /// - `axis`: The axis along which to compute the index of the minimum value.
560    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
561    ///
562    /// # Returns
563    ///
564    /// - `anyhow::Result<Self::Output>`: A tensor containing the indices of the minimum values.
565    ///
566    /// # See Also
567    ///
568    /// - [`argmax`]: Returns the indices of the maximum values along the specified axis.
569    #[track_caller]
570    fn argmin<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
571}
572
573/// A trait for normal tensor reduction operations.
574pub trait NormalReduce<T>
575where
576    Self: Sized,
577{
578    /// The output tensor type.
579    type Output;
580
581    /// Computes the sum of the elements along the specified axis.
582    ///
583    /// The `sum` function computes the sum of elements along the specified axis of the tensor.
584    ///
585    /// # Parameters
586    ///
587    /// - `axis`: The axis along which to sum the elements.
588    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
589    ///
590    /// # Returns
591    ///
592    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements along the specified axis.
593    ///
594    /// # See Also
595    ///
596    /// - [`nansum`]: Computes the sum while ignoring NaN values.
597    #[track_caller]
598    fn sum<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
599
600    /// Computes the sum of the elements along the specified axis, storing the result in a pre-allocated tensor.
601    ///
602    /// The `sum_` function computes the sum of elements along the specified axis, and optionally initializes an output tensor to store the result.
603    ///
604    /// # Parameters
605    ///
606    /// - `axis`: The axis along which to sum the elements.
607    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
608    /// - `init_out`: Whether to initialize the output tensor.
609    /// - `out`: The tensor in which to store the result.
610    ///
611    /// # Returns
612    ///
613    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, with the result stored in the specified output tensor.
614    #[track_caller]
615    fn sum_<S: Into<Axis>, O>(
616        &self,
617        axis: S,
618        keep_dims: bool,
619        init_out: bool,
620        out: O,
621    ) -> Result<Self::Output, TensorError>
622    where
623        O: Borrow<Self::Output>;
624
625    // /// Computes the sum of the elements along the specified axis, with an initial value.
626    // ///
627    // /// The `sum_with_init` function computes the sum of elements along the specified axes, starting from a given initial value.
628    // ///
629    // /// # Parameters
630    // ///
631    // /// - `init_val`: The initial value to start the summation.
632    // /// - `axes`: The axes along which to sum the elements.
633    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
634    // ///
635    // /// # Returns
636    // ///
637    // /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements along the specified axes.
638    // #[track_caller]
639    // fn sum_with_init<S: Into<Axis>>(
640    //     &self,
641    //     init_val: T,
642    //     axes: S,
643    //     keep_dims: bool,
644    // ) -> anyhow::Result<Self::Output>;
645
646    /// Computes the product of the elements along the specified axis.
647    ///
648    /// The `prod` function computes the product of elements along the specified axis of the tensor.
649    ///
650    /// # Parameters
651    ///
652    /// - `axis`: The axis along which to compute the product.
653    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
654    ///
655    /// # Returns
656    ///
657    /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements along the specified axis.
658    ///
659    /// # See Also
660    ///
661    /// - [`nanprod`]: Computes the product while ignoring NaN values.
662    #[track_caller]
663    fn prod<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
664
665    // /// Computes the product of the elements along the specified axis, with an initial value.
666    // ///
667    // /// The `prod_with_init` function computes the product of elements along the specified axes, starting from a given initial value.
668    // ///
669    // /// # Parameters
670    // ///
671    // /// - `init_val`: The initial value to start the product computation.
672    // /// - `axes`: The axes along which to compute the product.
673    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
674    // ///
675    // /// # Returns
676    // ///
677    // /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements along the specified axes.
678    // #[track_caller]
679    // fn prod_with_init<S: Into<Axis>>(
680    //     &self,
681    //     init_val: T,
682    //     axes: S,
683    //     keep_dims: bool,
684    // ) -> anyhow::Result<Self::Output>;
685
686    /// Computes the minimum value along the specified axis.
687    ///
688    /// The `min` function returns the minimum value of the elements along the specified axis of the tensor.
689    ///
690    /// # Parameters
691    ///
692    /// - `axis`: The axis along which to compute the minimum value.
693    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
694    ///
695    /// # Returns
696    ///
697    /// - `anyhow::Result<Self>`: A tensor containing the minimum values along the specified axis.
698    ///
699    /// # See Also
700    ///
701    /// - [`max`]: Computes the maximum value along the specified axis.
702    #[track_caller]
703    fn min<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
704
705    // /// Computes the minimum value along the specified axis, with an initial value.
706    // ///
707    // /// The `min_with_init` function computes the minimum value along the specified axes, starting from a given initial value.
708    // ///
709    // /// # Parameters
710    // ///
711    // /// - `init_val`: The initial value to compare against.
712    // /// - `axes`: The axes along which to compute the minimum value.
713    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
714    // ///
715    // /// # Returns
716    // ///
717    // /// - `anyhow::Result<Self>`: A tensor containing the minimum values along the specified axes.
718    // #[track_caller]
719    // fn min_with_init<S: Into<Axis>>(
720    //     &self,
721    //     init_val: T,
722    //     axes: S,
723    //     keep_dims: bool,
724    // ) -> anyhow::Result<Self>;
725
726    /// Computes the maximum value along the specified axis.
727    ///
728    /// The `max` function returns the maximum value of the elements along the specified axis of the tensor.
729    ///
730    /// # Parameters
731    ///
732    /// - `axis`: The axis along which to compute the maximum value.
733    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
734    ///
735    /// # Returns
736    ///
737    /// - `anyhow::Result<Self>`: A tensor containing the maximum values along the specified axis.
738    ///
739    /// # See Also
740    ///
741    /// - [`min`]: Computes the minimum value along the specified axis.
742    #[track_caller]
743    fn max<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
744
745    // /// Computes the maximum value along the specified axis, with an initial value.
746    // ///
747    // /// The `max_with_init` function computes the maximum value along the specified axes, starting from a given initial value.
748    // ///
749    // /// # Parameters
750    // ///
751    // /// - `init_val`: The initial value to compare against.
752    // /// - `axes`: The axes along which to compute the maximum value.
753    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
754    // ///
755    // /// # Returns
756    // ///
757    // /// - `anyhow::Result<Self>`: A tensor containing the maximum values along the specified axes.
758    // #[track_caller]
759    // fn max_with_init<S: Into<Axis>>(
760    //     &self,
761    //     init_val: T,
762    //     axes: S,
763    //     keep_dims: bool,
764    // ) -> anyhow::Result<Self>;
765
766    /// Reduces the tensor along the specified axis using the L1 norm (sum of absolute values).
767    ///
768    /// The `reducel1` function computes the L1 norm (sum of absolute values) along the specified axis of the tensor.
769    ///
770    /// # Parameters
771    ///
772    /// - `axis`: The axis along which to reduce the tensor.
773    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
774    ///
775    /// # Returns
776    ///
777    /// - `anyhow::Result<Self::Output>`: A tensor with the L1 norm computed along the specified axis.
778    #[track_caller]
779    fn reducel1<S: Into<Axis>>(
780        &self,
781        axis: S,
782        keep_dims: bool,
783    ) -> Result<Self::Output, TensorError>;
784
785    /// Computes the sum of the squares of the elements along the specified axis.
786    ///
787    /// The `sum_square` function computes the sum of the squares of the elements along the specified axis of the tensor.
788    ///
789    /// # Parameters
790    ///
791    /// - `axis`: The axis along which to sum the squares.
792    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
793    ///
794    /// # Returns
795    ///
796    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of squares of elements along the specified axis.
797    #[track_caller]
798    fn sum_square<S: Into<Axis>>(
799        &self,
800        axis: S,
801        keep_dims: bool,
802    ) -> Result<Self::Output, TensorError>;
803}
804
805/// A trait for tensor reduction operations, the output must be a boolean tensor.
806pub trait EvalReduce {
807    /// The boolean tensor type.
808    type BoolOutput;
809    /// Returns `true` if all elements along the specified axis evaluate to `true`.
810    ///
811    /// The `all` function checks whether all elements along the specified axis evaluate to `true`.
812    ///
813    /// # Parameters
814    ///
815    /// - `axis`: The axis along which to check.
816    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
817    ///
818    /// # Returns
819    ///
820    /// - `anyhow::Result<Self::BoolOutput>`: A boolean tensor indicating whether all elements evaluate to `true`.
821    ///
822    /// # See Also
823    ///
824    /// - [`any`]: Returns `true` if any element along the specified axis evaluates to `true`.
825    #[track_caller]
826    fn all<S: Into<Axis>>(&self, axis: S, keep_dims: bool)
827        -> Result<Self::BoolOutput, TensorError>;
828
829    /// Returns `true` if any element along the specified axis evaluates to `true`.
830    ///
831    /// The `any` function checks whether any element along the specified axis evaluates to `true`.
832    ///
833    /// # Parameters
834    ///
835    /// - `axis`: The axis along which to check.
836    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
837    ///
838    /// # Returns
839    ///
840    /// - `anyhow::Result<Self::BoolOutput>`: A boolean tensor indicating whether any element evaluates to `true`.
841    ///
842    /// # See Also
843    ///
844    /// - [`all`]: Returns `true` if all elements along the specified axis evaluate to `true`.
845    #[track_caller]
846    fn any<S: Into<Axis>>(&self, axis: S, keep_dims: bool)
847        -> Result<Self::BoolOutput, TensorError>;
848}
849
850/// A trait for tensor reduction operations, the output must remain the same tensor type.
851pub trait NormalEvalReduce<T> {
852    /// the output tensor type.
853    type Output;
854    /// Computes the sum of the elements along the specified axis, ignoring NaN values.
855    ///
856    /// The `nansum` function computes the sum of elements along the specified axis, while ignoring NaN values in the tensor.
857    ///
858    /// # Parameters
859    ///
860    /// - `axis`: The axis along which to sum the elements.
861    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
862    ///
863    /// # Returns
864    ///
865    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, ignoring NaN values.
866    #[track_caller]
867    fn nansum<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
868
869    /// Computes the sum of the elements along the specified axis, with an initial value, ignoring NaN values.
870    ///
871    /// The `nansum_with_init` function computes the sum of elements along the specified axes, starting from a given initial value and ignoring NaN values.
872    ///
873    /// # Parameters
874    ///
875    /// - `init_val`: The initial value to start the summation.
876    /// - `axes`: The axes along which to sum the elements.
877    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
878    ///
879    /// # Returns
880    ///
881    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, ignoring NaN values.
882    #[track_caller]
883    fn nansum_<S: Into<Axis>, O>(
884        &self,
885        axis: S,
886        keep_dims: bool,
887        init_out: bool,
888        out: O,
889    ) -> Result<Self::Output, TensorError>
890    where
891        O: Borrow<Self::Output>;
892    // /// Computes the sum of the elements along the specified axis, with an initial value, ignoring NaN values.
893    // ///
894    // /// The `nansum_with_init` function computes the sum of elements along the specified axes, starting from a given initial value and ignoring NaN values.
895    // ///
896    // /// # Parameters
897    // ///
898    // /// - `init_val`: The initial value to start the summation.
899    // /// - `axes`: The axes along which to sum the elements.
900    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
901    // ///
902    // /// # Returns
903    // ///
904    // /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, ignoring NaN values.
905    // #[track_caller]
906    // fn nansum_with_init<S: Into<Axis>>(
907    //     &self,
908    //     init_val: T,
909    //     axes: S,
910    //     keep_dims: bool,
911    // ) -> anyhow::Result<Self::Output>;
912
913    /// Computes the product of the elements along the specified axis, ignoring NaN values.
914    ///
915    /// The `nanprod` function computes the product of elements along the specified axis, while ignoring NaN values in the tensor.
916    ///
917    /// # Parameters
918    ///
919    /// - `axis`: The axis along which to compute the product.
920    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
921    ///
922    /// # Returns
923    ///
924    /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements, ignoring NaN values.
925    #[track_caller]
926    fn nanprod<S: Into<Axis>>(&self, axis: S, keep_dims: bool)
927        -> Result<Self::Output, TensorError>;
928
929    // /// Computes the product of the elements along the specified axis, with an initial value, ignoring NaN values.
930    // ///
931    // /// The `nanprod_with_init` function computes the product of elements along the specified axes, starting from a given initial value and ignoring NaN values.
932    // ///
933    // /// # Parameters
934    // ///
935    // /// - `init_val`: The initial value to start the product computation.
936    // /// - `axes`: The axes along which to compute the product.
937    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
938    // ///
939    // /// # Returns
940    // ///
941    // /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements, ignoring NaN values.
942    // #[track_caller]
943    // fn nanprod_with_init<S: Into<Axis>>(
944    //     &self,
945    //     init_val: T,
946    //     axes: S,
947    //     keep_dims: bool,
948    // ) -> anyhow::Result<Self::Output>;
949}
950
951/// A trait for tensor reduction operations, the output must be a floating-point tensor.
952pub trait FloatReduce<T>
953where
954    Self: Sized,
955{
956    /// The output tensor type.
957    type Output;
958
959    /// Computes the mean of the elements along the specified axis.
960    ///
961    /// The `mean` function calculates the mean of the elements along the specified axis of the tensor.
962    ///
963    /// # Parameters
964    ///
965    /// - `axis`: The axis along which to compute the mean.
966    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
967    ///
968    /// # Returns
969    ///
970    /// - `anyhow::Result<Self::Output>`: A tensor containing the mean values along the specified axis.
971    #[track_caller]
972    fn mean<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
973
974    /// Reduces the tensor along the specified axis using the L2 norm (Euclidean norm).
975    ///
976    /// The `reducel2` function computes the L2 norm (Euclidean norm) along the specified axis of the tensor.
977    ///
978    /// # Parameters
979    ///
980    /// - `axis`: The axis along which to reduce the tensor.
981    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
982    ///
983    /// # Returns
984    ///
985    /// - `anyhow::Result<Self::Output>`: A tensor with the L2 norm computed along the specified axis.
986    #[track_caller]
987    fn reducel2<S: Into<Axis>>(
988        &self,
989        axis: S,
990        keep_dims: bool,
991    ) -> Result<Self::Output, TensorError>;
992
993    /// Reduces the tensor along the specified axis using the L3 norm.
994    ///
995    /// The `reducel3` function computes the L3 norm along the specified axis of the tensor.
996    ///
997    /// # Parameters
998    ///
999    /// - `axis`: The axis along which to reduce the tensor.
1000    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
1001    ///
1002    /// # Returns
1003    ///
1004    /// - `anyhow::Result<Self::Output>`: A tensor with the L3 norm computed along the specified axis.
1005    #[track_caller]
1006    fn reducel3<S: Into<Axis>>(
1007        &self,
1008        axis: S,
1009        keep_dims: bool,
1010    ) -> Result<Self::Output, TensorError>;
1011
1012    /// Computes the logarithm of the sum of exponentials of the elements along the specified axis.
1013    ///
1014    /// The `logsumexp` function calculates the logarithm of the sum of exponentials of the elements along the specified axis,
1015    /// which is useful for numerical stability in certain operations.
1016    ///
1017    /// # Parameters
1018    ///
1019    /// - `axis`: The axis along which to compute the logarithm of the sum of exponentials.
1020    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
1021    ///
1022    /// # Returns
1023    ///
1024    /// - `anyhow::Result<Self::Output>`: A tensor containing the log-sum-exp values along the specified axis.
1025    #[track_caller]
1026    fn logsumexp<S: Into<Axis>>(
1027        &self,
1028        axis: S,
1029        keep_dims: bool,
1030    ) -> Result<Self::Output, TensorError>;
1031}
1032
1033/// Common bounds for primitive types
1034pub trait CommonBounds
1035where
1036    <Self as TypeCommon>::Vec: Send + Sync + Copy,
1037    Self: Sync
1038        + Send
1039        + Clone
1040        + Copy
1041        + TypeCommon
1042        + 'static
1043        + Display
1044        + Debug
1045        + Cast<Self>
1046        + NormalOut<Self, Output = Self>
1047        + FloatOutUnary
1048        + NormalOut<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1049        + FloatOutBinary<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1050        + FloatOutBinary<Self>
1051        + NormalOut<
1052            <Self as FloatOutBinary<Self>>::Output,
1053            Output = <Self as FloatOutBinary<Self>>::Output,
1054        >
1055        + NormalOutUnary,
1056{
1057}
1058impl<T> CommonBounds for T
1059where
1060    <Self as TypeCommon>::Vec: Send + Sync + Copy,
1061    Self: Sync
1062        + Send
1063        + Clone
1064        + Copy
1065        + TypeCommon
1066        + 'static
1067        + Display
1068        + Debug
1069        + Cast<Self>
1070        + NormalOut<Self, Output = Self>
1071        + FloatOutUnary
1072        + NormalOut<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1073        + FloatOutBinary<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1074        + FloatOutBinary<Self>
1075        + FloatOutBinary<
1076            <Self as FloatOutBinary<Self>>::Output,
1077            Output = <Self as FloatOutBinary<Self>>::Output,
1078        >
1079        + NormalOut<
1080            <Self as FloatOutBinary<Self>>::Output,
1081            Output = <Self as FloatOutBinary<Self>>::Output,
1082        >
1083        + NormalOutUnary,
1084{
1085}