// hpt_traits/tensor.rs

1use hpt_common::error::base::TensorError;
2use hpt_common::{
3    axis::axis::Axis, layout::layout::Layout, shape::shape::Shape, strides::strides::Strides,
4    utils::pointer::Pointer,
5};
6use hpt_types::arch_simd as simd;
7use hpt_types::{
8    dtype::TypeCommon,
9    into_scalar::Cast,
10    type_promote::{FloatOutBinary, FloatOutUnary, NormalOut, NormalOutUnary},
11};
12use std::fmt::Debug;
13use std::{borrow::BorrowMut, fmt::Display};
14
15#[cfg(target_feature = "avx2")]
16type BoolVector = simd::_256bit::boolx32::boolx32;
17#[cfg(any(
18    all(not(target_feature = "avx2"), target_feature = "sse"),
19    target_arch = "arm",
20    target_arch = "aarch64",
21    target_feature = "neon"
22))]
23type BoolVector = simd::_128bit::boolx16::boolx16;
24
25/// A trait for getting information of a Tensor
26pub trait TensorInfo<T> {
27    /// Returns a pointer to the tensor's first data.
28    #[track_caller]
29    fn ptr(&self) -> Pointer<T>;
30
31    /// Returns the size of the tensor based on the shape
32    #[track_caller]
33    fn size(&self) -> usize;
34
35    /// Returns the shape of the tensor.
36    #[track_caller]
37    fn shape(&self) -> &Shape;
38
39    /// Returns the strides of the tensor.
40    #[track_caller]
41    fn strides(&self) -> &Strides;
42
43    /// Returns the layout of the tensor. Layout contains shape and strides.
44    #[track_caller]
45    fn layout(&self) -> &Layout;
46    /// Returns the root tensor, if any.
47    ///
48    /// if the tensor is a view, it will return the root tensor. Otherwise, it will return None.
49    #[track_caller]
50    fn parent(&self) -> Option<Pointer<T>>;
51
52    /// Returns the number of dimensions of the tensor.
53    #[track_caller]
54    fn ndim(&self) -> usize;
55
56    /// Returns whether the tensor is contiguous in memory. View or transpose tensors are not contiguous.
57    #[track_caller]
58    fn is_contiguous(&self) -> bool;
59
60    /// Returns the data type memory size in bytes.
61    #[track_caller]
62    fn elsize() -> usize {
63        size_of::<T>()
64    }
65}
66
67/// A trait for let the object like a tensor
68pub trait TensorLike<T>: Sized {
69    /// directly convert the tensor to raw slice
70    ///
71    /// # Note
72    ///
73    /// This function will return a raw slice of the tensor regardless of the shape and strides.
74    ///
75    /// if you do iteration on the view tensor, you may see unexpected results.
76    fn as_raw(&self) -> &[T];
77
78    /// directly convert the tensor to mutable raw slice
79    ///
80    /// # Note
81    ///
82    /// This function will return a mutable raw slice of the tensor regardless of the shape and strides.
83    ///
84    /// if you do iteration on the view tensor, you may see unexpected results.
85    fn as_raw_mut(&mut self) -> &mut [T];
86
87    /// Returns the tensor as a contiguous tensor.
88    ///
89    /// # Note
90    ///
91    /// This function will return a contiguous tensor. If the tensor is already contiguous, it will return a clone of the tensor.
92    ///
93    /// If the tensor is a view tensor, it will return a new tensor with the same data but with a contiguous layout.
94    fn contiguous(&self) -> Result<Self, TensorError>;
95
96    /// Returns the data type memory size in bytes.
97    fn elsize() -> usize {
98        size_of::<T>()
99    }
100}
101
102/// A trait defines a set of functions to create tensors.
103pub trait TensorCreator<T>
104where
105    Self: Sized,
106{
107    /// the output type of the creator
108    type Output;
109
110    /// Creates a tensor with uninitialized elements of the specified shape.
111    ///
112    /// This function allocates memory for a tensor of the given shape, but the values are uninitialized, meaning they may contain random data.
113    ///
114    /// # Arguments
115    ///
116    /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
117    ///
118    /// # Returns
119    ///
120    /// * A tensor with the specified shape, but with uninitialized data.
121    ///
122    /// # Panics
123    ///
124    /// * This function may panic if the requested shape is invalid or too large for available memory.
125    #[track_caller]
126    fn empty<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
127
128    /// Creates a tensor filled with zeros of the specified shape.
129    ///
130    /// This function returns a tensor where every element is initialized to `0`, with the shape defined by the input.
131    ///
132    /// # Arguments
133    ///
134    /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
135    ///
136    /// # Returns
137    ///
138    /// * A tensor filled with zeros, with the specified shape.
139    ///
140    /// # Panics
141    ///
142    /// * This function may panic if the requested shape is invalid or too large for available memory.
143    #[track_caller]
144    fn zeros<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
145
146    /// Creates a tensor filled with ones of the specified shape.
147    ///
148    /// This function returns a tensor where every element is initialized to `1`, with the shape defined by the input.
149    ///
150    /// # Arguments
151    ///
152    /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
153    ///
154    /// # Returns
155    ///
156    /// * A tensor filled with ones, with the specified shape.
157    ///
158    /// # Panics
159    ///
160    /// * This function may panic if the requested shape is invalid or too large for available memory.
161    #[track_caller]
162    fn ones<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>
163    where
164        u8: Cast<T>;
165
166    /// Creates a tensor with uninitialized elements, having the same shape as the input tensor.
167    ///
168    /// This function returns a tensor with the same shape as the calling tensor, but with uninitialized data.
169    ///
170    /// # Arguments
171    ///
172    /// This function takes no arguments.
173    ///
174    /// # Returns
175    ///
176    /// * A tensor with the same shape as the input, but with uninitialized data.
177    ///
178    /// # Panics
179    ///
180    /// * This function may panic if the shape is too large for available memory.
181    #[track_caller]
182    fn empty_like(&self) -> Result<Self::Output, TensorError>;
183
184    /// Creates a tensor filled with zeros, having the same shape as the input tensor.
185    ///
186    /// This function returns a tensor with the same shape as the calling tensor, with all elements initialized to `0`.
187    ///
188    /// # Arguments
189    ///
190    /// This function takes no arguments.
191    ///
192    /// # Returns
193    ///
194    /// * A tensor with the same shape as the input, filled with zeros.
195    ///
196    /// # Panics
197    ///
198    /// * This function may panic if the shape is too large for available memory.
199    #[track_caller]
200    fn zeros_like(&self) -> Result<Self::Output, TensorError>;
201
202    /// Creates a tensor filled with ones, having the same shape as the input tensor.
203    ///
204    /// This function returns a tensor with the same shape as the calling tensor, with all elements initialized to `1`.
205    ///
206    /// # Arguments
207    ///
208    /// This function takes no arguments.
209    ///
210    /// # Returns
211    ///
212    /// * A tensor with the same shape as the input, filled with ones.
213    ///
214    /// # Panics
215    ///
216    /// * This function may panic if the shape is too large for available memory.
217    #[track_caller]
218    fn ones_like(&self) -> Result<Self::Output, TensorError>
219    where
220        u8: Cast<T>;
221
222    /// Creates a tensor filled with a specified value, with the specified shape.
223    ///
224    /// This function returns a tensor where every element is set to `val`, with the shape defined by the input.
225    ///
226    /// # Arguments
227    ///
228    /// * `val` - The value to fill the tensor with.
229    /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
230    ///
231    /// # Returns
232    ///
233    /// * A tensor filled with `val`, with the specified shape.
234    ///
235    /// # Panics
236    ///
237    /// * This function may panic if the requested shape is invalid or too large for available memory.
238    #[track_caller]
239    fn full<S: Into<Shape>>(val: T, shape: S) -> Result<Self::Output, TensorError>;
240
241    /// Creates a tensor filled with a specified value, having the same shape as the input tensor.
242    ///
243    /// This function returns a tensor where every element is set to `val`, with the same shape as the calling tensor.
244    ///
245    /// # Arguments
246    ///
247    /// * `val` - The value to fill the tensor with.
248    ///
249    /// # Returns
250    ///
251    /// * A tensor with the same shape as the input, filled with `val`.
252    ///
253    /// # Panics
254    ///
255    /// * This function may panic if the shape is too large for available memory.
256    #[track_caller]
257    fn full_like(&self, val: T) -> Result<Self::Output, TensorError>;
258
259    /// Creates a tensor with values within a specified range.
260    ///
261    /// This function generates a 1D tensor with values ranging from `start` (inclusive) to `end` (exclusive).
262    ///
263    /// # Arguments
264    ///
265    /// * `start` - The start of the range.
266    /// * `end` - The end of the range.
267    ///
268    /// # Returns
269    ///
270    /// * A 1D tensor with values ranging from `start` to `end`.
271    ///
272    /// # Panics
273    ///
274    /// * This function will panic if `start` is greater than or equal to `end`, or if the range is too large for available memory.
275    #[track_caller]
276    fn arange<U>(start: U, end: U) -> Result<Self::Output, TensorError>
277    where
278        usize: Cast<T>,
279        U: Cast<i64> + Cast<T> + Copy;
280
281    /// Creates a tensor with values within a specified range with a given step size.
282    ///
283    /// This function generates a 1D tensor with values ranging from `start` (inclusive) to `end` (exclusive),
284    /// incremented by `step`.
285    ///
286    /// # Arguments
287    ///
288    /// * `start` - The start of the range.
289    /// * `end` - The end of the range (exclusive).
290    /// * `step` - The step size between consecutive values.
291    ///
292    /// # Returns
293    ///
294    /// * A 1D tensor with values from `start` to `end`, incremented by `step`.
295    ///
296    /// # Panics
297    ///
298    /// * This function will panic if `step` is zero or if the range and step values are incompatible.
299    #[track_caller]
300    fn arange_step(start: T, end: T, step: T) -> Result<Self::Output, TensorError>
301    where
302        T: Cast<f64> + Cast<usize>,
303        usize: Cast<T>;
304
305    /// Creates a 2D identity matrix with ones on a diagonal and zeros elsewhere.
306    ///
307    /// This function generates a matrix of size `n` by `m`, with ones on the `k`th diagonal (can be offset) and zeros elsewhere.
308    ///
309    /// # Arguments
310    ///
311    /// * `n` - The number of rows in the matrix.
312    /// * `m` - The number of columns in the matrix.
313    /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
314    ///
315    /// # Returns
316    ///
317    /// * A 2D identity matrix with ones on the specified diagonal.
318    ///
319    /// # Panics
320    ///
321    /// * This function will panic if `n` or `m` is zero, or if memory constraints are exceeded.
322    #[track_caller]
323    fn eye(n: usize, m: usize, k: usize) -> Result<Self::Output, TensorError>;
324
325    /// Creates a tensor with evenly spaced values between `start` and `end`.
326    ///
327    /// This function generates a 1D tensor of `num` values, linearly spaced between `start` and `end`.
328    /// If `include_end` is `true`, the `end` value will be included as the last element.
329    ///
330    /// # Arguments
331    ///
332    /// * `start` - The start of the range.
333    /// * `end` - The end of the range.
334    /// * `num` - The number of evenly spaced values to generate.
335    /// * `include_end` - Whether to include the `end` value in the generated tensor.
336    ///
337    /// # Returns
338    ///
339    /// * A 1D tensor with `num` linearly spaced values between `start` and `end`.
340    ///
341    /// # Panics
342    ///
343    /// * This function will panic if `num` is zero or if `num` is too large for available memory.
344    #[track_caller]
345    fn linspace<U>(
346        start: U,
347        end: U,
348        num: usize,
349        include_end: bool,
350    ) -> Result<Self::Output, TensorError>
351    where
352        U: Cast<f64> + Cast<T> + Copy,
353        usize: Cast<T>,
354        f64: Cast<T>;
355
356    /// Creates a tensor with logarithmically spaced values between `start` and `end`.
357    ///
358    /// This function generates a 1D tensor of `num` values spaced evenly on a log scale between `start` and `end`.
359    /// The spacing is based on the logarithm to the given `base`. If `include_end` is `true`, the `end` value will be included.
360    ///
361    /// # Arguments
362    ///
363    /// * `start` - The starting exponent (base `base`).
364    /// * `end` - The ending exponent (base `base`).
365    /// * `num` - The number of logarithmically spaced values to generate.
366    /// * `include_end` - Whether to include the `end` value in the generated tensor.
367    /// * `base` - The base of the logarithm.
368    ///
369    /// # Returns
370    ///
371    /// * A 1D tensor with `num` logarithmically spaced values between `start` and `end`.
372    ///
373    /// # Panics
374    ///
375    /// * This function will panic if `num` is zero or if `base` is less than or equal to zero.
376    #[track_caller]
377    fn logspace(
378        start: T,
379        end: T,
380        num: usize,
381        include_end: bool,
382        base: T,
383    ) -> Result<Self::Output, TensorError>
384    where
385        T: Cast<f64> + num::Float + NormalOut<T, Output = T>,
386        usize: Cast<T>,
387        f64: Cast<T>;
388
389    /// Creates a tensor with geometrically spaced values between `start` and `end`.
390    ///
391    /// This function generates a 1D tensor of `n` values spaced evenly on a geometric scale between `start` and `end`.
392    /// If `include_end` is `true`, the `end` value will be included.
393    ///
394    /// # Arguments
395    ///
396    /// * `start` - The starting value (must be positive).
397    /// * `end` - The ending value (must be positive).
398    /// * `n` - The number of geometrically spaced values to generate.
399    /// * `include_end` - Whether to include the `end` value in the generated tensor.
400    ///
401    /// # Returns
402    ///
403    /// * A 1D tensor with `n` geometrically spaced values between `start` and `end`.
404    ///
405    /// # Panics
406    ///
407    /// * This function will panic if `n` is zero, if `start` or `end` is negative, or if the values result in undefined behavior.
408    #[track_caller]
409    fn geomspace(
410        start: T,
411        end: T,
412        n: usize,
413        include_end: bool,
414    ) -> Result<Self::Output, TensorError>
415    where
416        f64: Cast<T>,
417        usize: Cast<T>,
418        T: Cast<f64>;
419
420    /// Creates a 2D triangular matrix of size `n` by `m`, with ones below or on the `k`th diagonal and zeros elsewhere.
421    ///
422    /// This function generates a matrix with a triangular structure, filled with ones and zeros, based on the diagonal offset and the `low_triangle` flag.
423    ///
424    /// # Arguments
425    ///
426    /// * `n` - The number of rows in the matrix.
427    /// * `m` - The number of columns in the matrix.
428    /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
429    /// * `low_triangle` - If `true`, the matrix will be lower triangular; otherwise, upper triangular.
430    ///
431    /// # Returns
432    ///
433    /// * A 2D triangular matrix of ones and zeros.
434    ///
435    /// # Panics
436    ///
437    /// * This function will panic if `n` or `m` is zero.
438    #[track_caller]
439    fn tri(n: usize, m: usize, k: i64, low_triangle: bool) -> Result<Self::Output, TensorError>
440    where
441        u8: Cast<T>;
442
443    /// Returns the lower triangular part of the matrix, with all elements above the `k`th diagonal set to zero.
444    ///
445    /// This function generates a tensor where the elements above the specified diagonal are set to zero.
446    ///
447    /// # Arguments
448    ///
449    /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
450    ///
451    /// # Returns
452    ///
453    /// * A tensor with its upper triangular part set to zero.
454    ///
455    /// # Panics
456    ///
457    /// * This function should not panic under normal conditions.
458    #[track_caller]
459    fn tril(&self, k: i64) -> Result<Self::Output, TensorError>
460    where
461        T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
462        <T as TypeCommon>::Vec: NormalOut<BoolVector, Output = <T as TypeCommon>::Vec>;
463
464    /// Returns the upper triangular part of the matrix, with all elements below the `k`th diagonal set to zero.
465    ///
466    /// This function generates a tensor where the elements below the specified diagonal are set to zero.
467    ///
468    /// # Arguments
469    ///
470    /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
471    ///
472    /// # Returns
473    ///
474    /// * A tensor with its lower triangular part set to zero.
475    ///
476    /// # Panics
477    ///
478    /// * This function should not panic under normal conditions.
479    #[track_caller]
480    fn triu(&self, k: i64) -> Result<Self::Output, TensorError>
481    where
482        T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
483        <T as TypeCommon>::Vec: NormalOut<BoolVector, Output = <T as TypeCommon>::Vec>;
484
485    /// Creates a 2D identity matrix of size `n` by `n`.
486    ///
487    /// This function generates a square matrix with ones on the main diagonal and zeros elsewhere.
488    ///
489    /// # Arguments
490    ///
491    /// * `n` - The size of the matrix (both rows and columns).
492    ///
493    /// # Returns
494    ///
495    /// * A 2D identity matrix of size `n`.
496    ///
497    /// # Panics
498    ///
499    /// * This function will panic if `n` is zero.
500    #[track_caller]
501    fn identity(n: usize) -> Result<Self::Output, TensorError>
502    where
503        u8: Cast<T>;
504}
505
506/// A trait for tensor memory allocation, this trait only used when we work with generic type
507pub trait TensorAlloc<Output = Self> {
508    /// The tensor data type.
509    type Meta;
510    /// Creates a tensor with the specified shape,
511    ///
512    /// # Note
513    ///
514    /// This function doesn't initialize the tensor's elements.
515    #[track_caller]
516    fn _empty<S: Into<Shape>>(shape: S) -> Result<Output, TensorError>
517    where
518        Self: Sized;
519}
520
521/// A trait typically for argmax and argmin functions.
522pub trait IndexReduce
523where
524    Self: Sized,
525{
526    /// The output tensor type.
527    type Output;
528
529    /// Returns the indices of the maximum values along the specified axis.
530    ///
531    /// The `argmax` function computes the index of the maximum value along the given axis for each slice of the tensor.
532    ///
533    /// # Parameters
534    ///
535    /// - `axis`: The axis along which to compute the index of the maximum value.
536    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
537    ///
538    /// # Returns
539    ///
540    /// - `anyhow::Result<Self::Output>`: A tensor containing the indices of the maximum values.
541    ///
542    /// # See Also
543    ///
544    /// - [`argmin`]: Returns the indices of the minimum values along the specified axis.
545    #[track_caller]
546    fn argmax<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
547
548    /// Returns the indices of the minimum values along the specified axis.
549    ///
550    /// The `argmin` function computes the index of the minimum value along the given axis for each slice of the tensor.
551    ///
552    /// # Parameters
553    ///
554    /// - `axis`: The axis along which to compute the index of the minimum value.
555    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
556    ///
557    /// # Returns
558    ///
559    /// - `anyhow::Result<Self::Output>`: A tensor containing the indices of the minimum values.
560    ///
561    /// # See Also
562    ///
563    /// - [`argmax`]: Returns the indices of the maximum values along the specified axis.
564    #[track_caller]
565    fn argmin<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
566}
567
568/// A trait for normal tensor reduction operations.
569pub trait NormalReduce<T>
570where
571    Self: Sized,
572{
573    /// The output tensor type.
574    type Output;
575
576    /// Computes the sum of the elements along the specified axis.
577    ///
578    /// The `sum` function computes the sum of elements along the specified axis of the tensor.
579    ///
580    /// # Parameters
581    ///
582    /// - `axis`: The axis along which to sum the elements.
583    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
584    ///
585    /// # Returns
586    ///
587    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements along the specified axis.
588    ///
589    /// # See Also
590    ///
591    /// - [`nansum`]: Computes the sum while ignoring NaN values.
592    #[track_caller]
593    fn sum<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
594
595    /// Computes the sum of the elements along the specified axis, storing the result in a pre-allocated tensor.
596    ///
597    /// The `sum_` function computes the sum of elements along the specified axis, and optionally initializes an output tensor to store the result.
598    ///
599    /// # Parameters
600    ///
601    /// - `axis`: The axis along which to sum the elements.
602    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
603    /// - `init_out`: Whether to initialize the output tensor.
604    /// - `out`: The tensor in which to store the result.
605    ///
606    /// # Returns
607    ///
608    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, with the result stored in the specified output tensor.
609    #[track_caller]
610    fn sum_<S: Into<Axis>, O>(
611        &self,
612        axis: S,
613        keep_dims: bool,
614        init_out: bool,
615        out: O,
616    ) -> Result<Self::Output, TensorError>
617    where
618        O: BorrowMut<Self::Output>;
619
620    // /// Computes the sum of the elements along the specified axis, with an initial value.
621    // ///
622    // /// The `sum_with_init` function computes the sum of elements along the specified axes, starting from a given initial value.
623    // ///
624    // /// # Parameters
625    // ///
626    // /// - `init_val`: The initial value to start the summation.
627    // /// - `axes`: The axes along which to sum the elements.
628    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
629    // ///
630    // /// # Returns
631    // ///
632    // /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements along the specified axes.
633    // #[track_caller]
634    // fn sum_with_init<S: Into<Axis>>(
635    //     &self,
636    //     init_val: T,
637    //     axes: S,
638    //     keep_dims: bool,
639    // ) -> anyhow::Result<Self::Output>;
640
641    /// Computes the product of the elements along the specified axis.
642    ///
643    /// The `prod` function computes the product of elements along the specified axis of the tensor.
644    ///
645    /// # Parameters
646    ///
647    /// - `axis`: The axis along which to compute the product.
648    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
649    ///
650    /// # Returns
651    ///
652    /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements along the specified axis.
653    ///
654    /// # See Also
655    ///
656    /// - [`nanprod`]: Computes the product while ignoring NaN values.
657    #[track_caller]
658    fn prod<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
659
660    // /// Computes the product of the elements along the specified axis, with an initial value.
661    // ///
662    // /// The `prod_with_init` function computes the product of elements along the specified axes, starting from a given initial value.
663    // ///
664    // /// # Parameters
665    // ///
666    // /// - `init_val`: The initial value to start the product computation.
667    // /// - `axes`: The axes along which to compute the product.
668    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
669    // ///
670    // /// # Returns
671    // ///
672    // /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements along the specified axes.
673    // #[track_caller]
674    // fn prod_with_init<S: Into<Axis>>(
675    //     &self,
676    //     init_val: T,
677    //     axes: S,
678    //     keep_dims: bool,
679    // ) -> anyhow::Result<Self::Output>;
680
681    /// Computes the minimum value along the specified axis.
682    ///
683    /// The `min` function returns the minimum value of the elements along the specified axis of the tensor.
684    ///
685    /// # Parameters
686    ///
687    /// - `axis`: The axis along which to compute the minimum value.
688    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
689    ///
690    /// # Returns
691    ///
692    /// - `anyhow::Result<Self>`: A tensor containing the minimum values along the specified axis.
693    ///
694    /// # See Also
695    ///
696    /// - [`max`]: Computes the maximum value along the specified axis.
697    #[track_caller]
698    fn min<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
699
700    // /// Computes the minimum value along the specified axis, with an initial value.
701    // ///
702    // /// The `min_with_init` function computes the minimum value along the specified axes, starting from a given initial value.
703    // ///
704    // /// # Parameters
705    // ///
706    // /// - `init_val`: The initial value to compare against.
707    // /// - `axes`: The axes along which to compute the minimum value.
708    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
709    // ///
710    // /// # Returns
711    // ///
712    // /// - `anyhow::Result<Self>`: A tensor containing the minimum values along the specified axes.
713    // #[track_caller]
714    // fn min_with_init<S: Into<Axis>>(
715    //     &self,
716    //     init_val: T,
717    //     axes: S,
718    //     keep_dims: bool,
719    // ) -> anyhow::Result<Self>;
720
721    /// Computes the maximum value along the specified axis.
722    ///
723    /// The `max` function returns the maximum value of the elements along the specified axis of the tensor.
724    ///
725    /// # Parameters
726    ///
727    /// - `axis`: The axis along which to compute the maximum value.
728    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
729    ///
730    /// # Returns
731    ///
732    /// - `anyhow::Result<Self>`: A tensor containing the maximum values along the specified axis.
733    ///
734    /// # See Also
735    ///
736    /// - [`min`]: Computes the minimum value along the specified axis.
737    #[track_caller]
738    fn max<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
739
740    // /// Computes the maximum value along the specified axis, with an initial value.
741    // ///
742    // /// The `max_with_init` function computes the maximum value along the specified axes, starting from a given initial value.
743    // ///
744    // /// # Parameters
745    // ///
746    // /// - `init_val`: The initial value to compare against.
747    // /// - `axes`: The axes along which to compute the maximum value.
748    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
749    // ///
750    // /// # Returns
751    // ///
752    // /// - `anyhow::Result<Self>`: A tensor containing the maximum values along the specified axes.
753    // #[track_caller]
754    // fn max_with_init<S: Into<Axis>>(
755    //     &self,
756    //     init_val: T,
757    //     axes: S,
758    //     keep_dims: bool,
759    // ) -> anyhow::Result<Self>;
760
761    /// Reduces the tensor along the specified axis using the L1 norm (sum of absolute values).
762    ///
763    /// The `reducel1` function computes the L1 norm (sum of absolute values) along the specified axis of the tensor.
764    ///
765    /// # Parameters
766    ///
767    /// - `axis`: The axis along which to reduce the tensor.
768    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
769    ///
770    /// # Returns
771    ///
772    /// - `anyhow::Result<Self::Output>`: A tensor with the L1 norm computed along the specified axis.
773    #[track_caller]
774    fn reducel1<S: Into<Axis>>(
775        &self,
776        axis: S,
777        keep_dims: bool,
778    ) -> Result<Self::Output, TensorError>;
779
780    /// Computes the sum of the squares of the elements along the specified axis.
781    ///
782    /// The `sum_square` function computes the sum of the squares of the elements along the specified axis of the tensor.
783    ///
784    /// # Parameters
785    ///
786    /// - `axis`: The axis along which to sum the squares.
787    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
788    ///
789    /// # Returns
790    ///
791    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of squares of elements along the specified axis.
792    #[track_caller]
793    fn sum_square<S: Into<Axis>>(
794        &self,
795        axis: S,
796        keep_dims: bool,
797    ) -> Result<Self::Output, TensorError>;
798}
799
800/// A trait for tensor reduction operations, the output must be a boolean tensor.
801pub trait EvalReduce {
802    /// The boolean tensor type.
803    type BoolOutput;
804    /// Returns `true` if all elements along the specified axis evaluate to `true`.
805    ///
806    /// The `all` function checks whether all elements along the specified axis evaluate to `true`.
807    ///
808    /// # Parameters
809    ///
810    /// - `axis`: The axis along which to check.
811    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
812    ///
813    /// # Returns
814    ///
815    /// - `anyhow::Result<Self::BoolOutput>`: A boolean tensor indicating whether all elements evaluate to `true`.
816    ///
817    /// # See Also
818    ///
819    /// - [`any`]: Returns `true` if any element along the specified axis evaluates to `true`.
820    #[track_caller]
821    fn all<S: Into<Axis>>(&self, axis: S, keep_dims: bool)
822        -> Result<Self::BoolOutput, TensorError>;
823
824    /// Returns `true` if any element along the specified axis evaluates to `true`.
825    ///
826    /// The `any` function checks whether any element along the specified axis evaluates to `true`.
827    ///
828    /// # Parameters
829    ///
830    /// - `axis`: The axis along which to check.
831    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
832    ///
833    /// # Returns
834    ///
835    /// - `anyhow::Result<Self::BoolOutput>`: A boolean tensor indicating whether any element evaluates to `true`.
836    ///
837    /// # See Also
838    ///
839    /// - [`all`]: Returns `true` if all elements along the specified axis evaluate to `true`.
840    #[track_caller]
841    fn any<S: Into<Axis>>(&self, axis: S, keep_dims: bool)
842        -> Result<Self::BoolOutput, TensorError>;
843}
844
845/// A trait for tensor reduction operations, the output must remain the same tensor type.
846pub trait NormalEvalReduce<T> {
847    /// the output tensor type.
848    type Output;
849    /// Computes the sum of the elements along the specified axis, ignoring NaN values.
850    ///
851    /// The `nansum` function computes the sum of elements along the specified axis, while ignoring NaN values in the tensor.
852    ///
853    /// # Parameters
854    ///
855    /// - `axis`: The axis along which to sum the elements.
856    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
857    ///
858    /// # Returns
859    ///
860    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, ignoring NaN values.
861    #[track_caller]
862    fn nansum<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
863
864    /// Computes the sum of the elements along the specified axis, with an initial value, ignoring NaN values.
865    ///
866    /// The `nansum_with_init` function computes the sum of elements along the specified axes, starting from a given initial value and ignoring NaN values.
867    ///
868    /// # Parameters
869    ///
870    /// - `init_val`: The initial value to start the summation.
871    /// - `axes`: The axes along which to sum the elements.
872    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
873    ///
874    /// # Returns
875    ///
876    /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, ignoring NaN values.
877    #[track_caller]
878    fn nansum_<S: Into<Axis>, O>(
879        &self,
880        axis: S,
881        keep_dims: bool,
882        init_out: bool,
883        out: O,
884    ) -> Result<Self::Output, TensorError>
885    where
886        O: BorrowMut<Self::Output>;
887    // /// Computes the sum of the elements along the specified axis, with an initial value, ignoring NaN values.
888    // ///
889    // /// The `nansum_with_init` function computes the sum of elements along the specified axes, starting from a given initial value and ignoring NaN values.
890    // ///
891    // /// # Parameters
892    // ///
893    // /// - `init_val`: The initial value to start the summation.
894    // /// - `axes`: The axes along which to sum the elements.
895    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
896    // ///
897    // /// # Returns
898    // ///
899    // /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, ignoring NaN values.
900    // #[track_caller]
901    // fn nansum_with_init<S: Into<Axis>>(
902    //     &self,
903    //     init_val: T,
904    //     axes: S,
905    //     keep_dims: bool,
906    // ) -> anyhow::Result<Self::Output>;
907
908    /// Computes the product of the elements along the specified axis, ignoring NaN values.
909    ///
910    /// The `nanprod` function computes the product of elements along the specified axis, while ignoring NaN values in the tensor.
911    ///
912    /// # Parameters
913    ///
914    /// - `axis`: The axis along which to compute the product.
915    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
916    ///
917    /// # Returns
918    ///
919    /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements, ignoring NaN values.
920    #[track_caller]
921    fn nanprod<S: Into<Axis>>(&self, axis: S, keep_dims: bool)
922        -> Result<Self::Output, TensorError>;
923
924    // /// Computes the product of the elements along the specified axis, with an initial value, ignoring NaN values.
925    // ///
926    // /// The `nanprod_with_init` function computes the product of elements along the specified axes, starting from a given initial value and ignoring NaN values.
927    // ///
928    // /// # Parameters
929    // ///
930    // /// - `init_val`: The initial value to start the product computation.
931    // /// - `axes`: The axes along which to compute the product.
932    // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
933    // ///
934    // /// # Returns
935    // ///
936    // /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements, ignoring NaN values.
937    // #[track_caller]
938    // fn nanprod_with_init<S: Into<Axis>>(
939    //     &self,
940    //     init_val: T,
941    //     axes: S,
942    //     keep_dims: bool,
943    // ) -> anyhow::Result<Self::Output>;
944}
945
946/// A trait for tensor reduction operations, the output must be a floating-point tensor.
947pub trait FloatReduce<T>
948where
949    Self: Sized,
950{
951    /// The output tensor type.
952    type Output;
953
954    /// Computes the mean of the elements along the specified axis.
955    ///
956    /// The `mean` function calculates the mean of the elements along the specified axis of the tensor.
957    ///
958    /// # Parameters
959    ///
960    /// - `axis`: The axis along which to compute the mean.
961    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
962    ///
963    /// # Returns
964    ///
965    /// - `anyhow::Result<Self::Output>`: A tensor containing the mean values along the specified axis.
966    #[track_caller]
967    fn mean<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
968
969    /// Reduces the tensor along the specified axis using the L2 norm (Euclidean norm).
970    ///
971    /// The `reducel2` function computes the L2 norm (Euclidean norm) along the specified axis of the tensor.
972    ///
973    /// # Parameters
974    ///
975    /// - `axis`: The axis along which to reduce the tensor.
976    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
977    ///
978    /// # Returns
979    ///
980    /// - `anyhow::Result<Self::Output>`: A tensor with the L2 norm computed along the specified axis.
981    #[track_caller]
982    fn reducel2<S: Into<Axis>>(
983        &self,
984        axis: S,
985        keep_dims: bool,
986    ) -> Result<Self::Output, TensorError>;
987
988    /// Reduces the tensor along the specified axis using the L3 norm.
989    ///
990    /// The `reducel3` function computes the L3 norm along the specified axis of the tensor.
991    ///
992    /// # Parameters
993    ///
994    /// - `axis`: The axis along which to reduce the tensor.
995    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
996    ///
997    /// # Returns
998    ///
999    /// - `anyhow::Result<Self::Output>`: A tensor with the L3 norm computed along the specified axis.
1000    #[track_caller]
1001    fn reducel3<S: Into<Axis>>(
1002        &self,
1003        axis: S,
1004        keep_dims: bool,
1005    ) -> Result<Self::Output, TensorError>;
1006
1007    /// Computes the logarithm of the sum of exponentials of the elements along the specified axis.
1008    ///
1009    /// The `logsumexp` function calculates the logarithm of the sum of exponentials of the elements along the specified axis,
1010    /// which is useful for numerical stability in certain operations.
1011    ///
1012    /// # Parameters
1013    ///
1014    /// - `axis`: The axis along which to compute the logarithm of the sum of exponentials.
1015    /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
1016    ///
1017    /// # Returns
1018    ///
1019    /// - `anyhow::Result<Self::Output>`: A tensor containing the log-sum-exp values along the specified axis.
1020    #[track_caller]
1021    fn logsumexp<S: Into<Axis>>(
1022        &self,
1023        axis: S,
1024        keep_dims: bool,
1025    ) -> Result<Self::Output, TensorError>;
1026}
1027
1028/// Common bounds for primitive types
1029pub trait CommonBounds
1030where
1031    <Self as TypeCommon>::Vec: Send + Sync + Copy,
1032    Self: Sync
1033        + Send
1034        + Clone
1035        + Copy
1036        + TypeCommon
1037        + 'static
1038        + Display
1039        + Debug
1040        + Cast<Self>
1041        + NormalOut<Self, Output = Self>
1042        + FloatOutUnary
1043        + NormalOut<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1044        + FloatOutBinary<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1045        + FloatOutBinary<Self>
1046        + NormalOut<
1047            <Self as FloatOutBinary<Self>>::Output,
1048            Output = <Self as FloatOutBinary<Self>>::Output,
1049        >
1050        + NormalOutUnary
1051        + Cast<f64>,
1052{
1053}
1054impl<T> CommonBounds for T
1055where
1056    <Self as TypeCommon>::Vec: Send + Sync + Copy,
1057    Self: Sync
1058        + Send
1059        + Clone
1060        + Copy
1061        + TypeCommon
1062        + 'static
1063        + Display
1064        + Debug
1065        + Cast<Self>
1066        + NormalOut<Self, Output = Self>
1067        + FloatOutUnary
1068        + NormalOut<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1069        + FloatOutBinary<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1070        + FloatOutBinary<Self>
1071        + FloatOutBinary<
1072            <Self as FloatOutBinary<Self>>::Output,
1073            Output = <Self as FloatOutBinary<Self>>::Output,
1074        >
1075        + NormalOut<
1076            <Self as FloatOutBinary<Self>>::Output,
1077            Output = <Self as FloatOutBinary<Self>>::Output,
1078        >
1079        + NormalOutUnary
1080        + Cast<f64>,
1081{
1082}