hpt_traits/tensor.rs
1use hpt_common::error::base::TensorError;
2use hpt_common::{
3 axis::axis::Axis, layout::layout::Layout, shape::shape::Shape, strides::strides::Strides,
4 utils::pointer::Pointer,
5};
6use hpt_types::arch_simd as simd;
7use hpt_types::{
8 dtype::TypeCommon,
9 into_scalar::Cast,
10 type_promote::{FloatOutBinary, FloatOutUnary, NormalOut, NormalOutUnary},
11};
12use std::fmt::Debug;
13use std::{borrow::BorrowMut, fmt::Display};
14
15#[cfg(target_feature = "avx2")]
16type BoolVector = simd::_256bit::boolx32::boolx32;
17#[cfg(any(
18 all(not(target_feature = "avx2"), target_feature = "sse"),
19 target_arch = "arm",
20 target_arch = "aarch64",
21 target_feature = "neon"
22))]
23type BoolVector = simd::_128bit::boolx16::boolx16;
24
25/// A trait for getting information of a Tensor
26pub trait TensorInfo<T> {
27 /// Returns a pointer to the tensor's first data.
28 #[track_caller]
29 fn ptr(&self) -> Pointer<T>;
30
31 /// Returns the size of the tensor based on the shape
32 #[track_caller]
33 fn size(&self) -> usize;
34
35 /// Returns the shape of the tensor.
36 #[track_caller]
37 fn shape(&self) -> &Shape;
38
39 /// Returns the strides of the tensor.
40 #[track_caller]
41 fn strides(&self) -> &Strides;
42
43 /// Returns the layout of the tensor. Layout contains shape and strides.
44 #[track_caller]
45 fn layout(&self) -> &Layout;
46 /// Returns the root tensor, if any.
47 ///
48 /// if the tensor is a view, it will return the root tensor. Otherwise, it will return None.
49 #[track_caller]
50 fn parent(&self) -> Option<Pointer<T>>;
51
52 /// Returns the number of dimensions of the tensor.
53 #[track_caller]
54 fn ndim(&self) -> usize;
55
56 /// Returns whether the tensor is contiguous in memory. View or transpose tensors are not contiguous.
57 #[track_caller]
58 fn is_contiguous(&self) -> bool;
59
60 /// Returns the data type memory size in bytes.
61 #[track_caller]
62 fn elsize() -> usize {
63 size_of::<T>()
64 }
65}
66
67/// A trait for let the object like a tensor
68pub trait TensorLike<T>: Sized {
69 /// directly convert the tensor to raw slice
70 ///
71 /// # Note
72 ///
73 /// This function will return a raw slice of the tensor regardless of the shape and strides.
74 ///
75 /// if you do iteration on the view tensor, you may see unexpected results.
76 fn as_raw(&self) -> &[T];
77
78 /// directly convert the tensor to mutable raw slice
79 ///
80 /// # Note
81 ///
82 /// This function will return a mutable raw slice of the tensor regardless of the shape and strides.
83 ///
84 /// if you do iteration on the view tensor, you may see unexpected results.
85 fn as_raw_mut(&mut self) -> &mut [T];
86
87 /// Returns the tensor as a contiguous tensor.
88 ///
89 /// # Note
90 ///
91 /// This function will return a contiguous tensor. If the tensor is already contiguous, it will return a clone of the tensor.
92 ///
93 /// If the tensor is a view tensor, it will return a new tensor with the same data but with a contiguous layout.
94 fn contiguous(&self) -> Result<Self, TensorError>;
95
96 /// Returns the data type memory size in bytes.
97 fn elsize() -> usize {
98 size_of::<T>()
99 }
100}
101
102/// A trait defines a set of functions to create tensors.
103pub trait TensorCreator<T>
104where
105 Self: Sized,
106{
107 /// the output type of the creator
108 type Output;
109
110 /// Creates a tensor with uninitialized elements of the specified shape.
111 ///
112 /// This function allocates memory for a tensor of the given shape, but the values are uninitialized, meaning they may contain random data.
113 ///
114 /// # Arguments
115 ///
116 /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
117 ///
118 /// # Returns
119 ///
120 /// * A tensor with the specified shape, but with uninitialized data.
121 ///
122 /// # Panics
123 ///
124 /// * This function may panic if the requested shape is invalid or too large for available memory.
125 #[track_caller]
126 fn empty<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
127
128 /// Creates a tensor filled with zeros of the specified shape.
129 ///
130 /// This function returns a tensor where every element is initialized to `0`, with the shape defined by the input.
131 ///
132 /// # Arguments
133 ///
134 /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
135 ///
136 /// # Returns
137 ///
138 /// * A tensor filled with zeros, with the specified shape.
139 ///
140 /// # Panics
141 ///
142 /// * This function may panic if the requested shape is invalid or too large for available memory.
143 #[track_caller]
144 fn zeros<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>;
145
146 /// Creates a tensor filled with ones of the specified shape.
147 ///
148 /// This function returns a tensor where every element is initialized to `1`, with the shape defined by the input.
149 ///
150 /// # Arguments
151 ///
152 /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
153 ///
154 /// # Returns
155 ///
156 /// * A tensor filled with ones, with the specified shape.
157 ///
158 /// # Panics
159 ///
160 /// * This function may panic if the requested shape is invalid or too large for available memory.
161 #[track_caller]
162 fn ones<S: Into<Shape>>(shape: S) -> Result<Self::Output, TensorError>
163 where
164 u8: Cast<T>;
165
166 /// Creates a tensor with uninitialized elements, having the same shape as the input tensor.
167 ///
168 /// This function returns a tensor with the same shape as the calling tensor, but with uninitialized data.
169 ///
170 /// # Arguments
171 ///
172 /// This function takes no arguments.
173 ///
174 /// # Returns
175 ///
176 /// * A tensor with the same shape as the input, but with uninitialized data.
177 ///
178 /// # Panics
179 ///
180 /// * This function may panic if the shape is too large for available memory.
181 #[track_caller]
182 fn empty_like(&self) -> Result<Self::Output, TensorError>;
183
184 /// Creates a tensor filled with zeros, having the same shape as the input tensor.
185 ///
186 /// This function returns a tensor with the same shape as the calling tensor, with all elements initialized to `0`.
187 ///
188 /// # Arguments
189 ///
190 /// This function takes no arguments.
191 ///
192 /// # Returns
193 ///
194 /// * A tensor with the same shape as the input, filled with zeros.
195 ///
196 /// # Panics
197 ///
198 /// * This function may panic if the shape is too large for available memory.
199 #[track_caller]
200 fn zeros_like(&self) -> Result<Self::Output, TensorError>;
201
202 /// Creates a tensor filled with ones, having the same shape as the input tensor.
203 ///
204 /// This function returns a tensor with the same shape as the calling tensor, with all elements initialized to `1`.
205 ///
206 /// # Arguments
207 ///
208 /// This function takes no arguments.
209 ///
210 /// # Returns
211 ///
212 /// * A tensor with the same shape as the input, filled with ones.
213 ///
214 /// # Panics
215 ///
216 /// * This function may panic if the shape is too large for available memory.
217 #[track_caller]
218 fn ones_like(&self) -> Result<Self::Output, TensorError>
219 where
220 u8: Cast<T>;
221
222 /// Creates a tensor filled with a specified value, with the specified shape.
223 ///
224 /// This function returns a tensor where every element is set to `val`, with the shape defined by the input.
225 ///
226 /// # Arguments
227 ///
228 /// * `val` - The value to fill the tensor with.
229 /// * `shape` - The desired shape of the tensor. The type `S` must implement `Into<Shape>`.
230 ///
231 /// # Returns
232 ///
233 /// * A tensor filled with `val`, with the specified shape.
234 ///
235 /// # Panics
236 ///
237 /// * This function may panic if the requested shape is invalid or too large for available memory.
238 #[track_caller]
239 fn full<S: Into<Shape>>(val: T, shape: S) -> Result<Self::Output, TensorError>;
240
241 /// Creates a tensor filled with a specified value, having the same shape as the input tensor.
242 ///
243 /// This function returns a tensor where every element is set to `val`, with the same shape as the calling tensor.
244 ///
245 /// # Arguments
246 ///
247 /// * `val` - The value to fill the tensor with.
248 ///
249 /// # Returns
250 ///
251 /// * A tensor with the same shape as the input, filled with `val`.
252 ///
253 /// # Panics
254 ///
255 /// * This function may panic if the shape is too large for available memory.
256 #[track_caller]
257 fn full_like(&self, val: T) -> Result<Self::Output, TensorError>;
258
259 /// Creates a tensor with values within a specified range.
260 ///
261 /// This function generates a 1D tensor with values ranging from `start` (inclusive) to `end` (exclusive).
262 ///
263 /// # Arguments
264 ///
265 /// * `start` - The start of the range.
266 /// * `end` - The end of the range.
267 ///
268 /// # Returns
269 ///
270 /// * A 1D tensor with values ranging from `start` to `end`.
271 ///
272 /// # Panics
273 ///
274 /// * This function will panic if `start` is greater than or equal to `end`, or if the range is too large for available memory.
275 #[track_caller]
276 fn arange<U>(start: U, end: U) -> Result<Self::Output, TensorError>
277 where
278 usize: Cast<T>,
279 U: Cast<i64> + Cast<T> + Copy;
280
281 /// Creates a tensor with values within a specified range with a given step size.
282 ///
283 /// This function generates a 1D tensor with values ranging from `start` (inclusive) to `end` (exclusive),
284 /// incremented by `step`.
285 ///
286 /// # Arguments
287 ///
288 /// * `start` - The start of the range.
289 /// * `end` - The end of the range (exclusive).
290 /// * `step` - The step size between consecutive values.
291 ///
292 /// # Returns
293 ///
294 /// * A 1D tensor with values from `start` to `end`, incremented by `step`.
295 ///
296 /// # Panics
297 ///
298 /// * This function will panic if `step` is zero or if the range and step values are incompatible.
299 #[track_caller]
300 fn arange_step(start: T, end: T, step: T) -> Result<Self::Output, TensorError>
301 where
302 T: Cast<f64> + Cast<usize>,
303 usize: Cast<T>;
304
305 /// Creates a 2D identity matrix with ones on a diagonal and zeros elsewhere.
306 ///
307 /// This function generates a matrix of size `n` by `m`, with ones on the `k`th diagonal (can be offset) and zeros elsewhere.
308 ///
309 /// # Arguments
310 ///
311 /// * `n` - The number of rows in the matrix.
312 /// * `m` - The number of columns in the matrix.
313 /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
314 ///
315 /// # Returns
316 ///
317 /// * A 2D identity matrix with ones on the specified diagonal.
318 ///
319 /// # Panics
320 ///
321 /// * This function will panic if `n` or `m` is zero, or if memory constraints are exceeded.
322 #[track_caller]
323 fn eye(n: usize, m: usize, k: usize) -> Result<Self::Output, TensorError>;
324
325 /// Creates a tensor with evenly spaced values between `start` and `end`.
326 ///
327 /// This function generates a 1D tensor of `num` values, linearly spaced between `start` and `end`.
328 /// If `include_end` is `true`, the `end` value will be included as the last element.
329 ///
330 /// # Arguments
331 ///
332 /// * `start` - The start of the range.
333 /// * `end` - The end of the range.
334 /// * `num` - The number of evenly spaced values to generate.
335 /// * `include_end` - Whether to include the `end` value in the generated tensor.
336 ///
337 /// # Returns
338 ///
339 /// * A 1D tensor with `num` linearly spaced values between `start` and `end`.
340 ///
341 /// # Panics
342 ///
343 /// * This function will panic if `num` is zero or if `num` is too large for available memory.
344 #[track_caller]
345 fn linspace<U>(
346 start: U,
347 end: U,
348 num: usize,
349 include_end: bool,
350 ) -> Result<Self::Output, TensorError>
351 where
352 U: Cast<f64> + Cast<T> + Copy,
353 usize: Cast<T>,
354 f64: Cast<T>;
355
356 /// Creates a tensor with logarithmically spaced values between `start` and `end`.
357 ///
358 /// This function generates a 1D tensor of `num` values spaced evenly on a log scale between `start` and `end`.
359 /// The spacing is based on the logarithm to the given `base`. If `include_end` is `true`, the `end` value will be included.
360 ///
361 /// # Arguments
362 ///
363 /// * `start` - The starting exponent (base `base`).
364 /// * `end` - The ending exponent (base `base`).
365 /// * `num` - The number of logarithmically spaced values to generate.
366 /// * `include_end` - Whether to include the `end` value in the generated tensor.
367 /// * `base` - The base of the logarithm.
368 ///
369 /// # Returns
370 ///
371 /// * A 1D tensor with `num` logarithmically spaced values between `start` and `end`.
372 ///
373 /// # Panics
374 ///
375 /// * This function will panic if `num` is zero or if `base` is less than or equal to zero.
376 #[track_caller]
377 fn logspace(
378 start: T,
379 end: T,
380 num: usize,
381 include_end: bool,
382 base: T,
383 ) -> Result<Self::Output, TensorError>
384 where
385 T: Cast<f64> + num::Float + NormalOut<T, Output = T>,
386 usize: Cast<T>,
387 f64: Cast<T>;
388
389 /// Creates a tensor with geometrically spaced values between `start` and `end`.
390 ///
391 /// This function generates a 1D tensor of `n` values spaced evenly on a geometric scale between `start` and `end`.
392 /// If `include_end` is `true`, the `end` value will be included.
393 ///
394 /// # Arguments
395 ///
396 /// * `start` - The starting value (must be positive).
397 /// * `end` - The ending value (must be positive).
398 /// * `n` - The number of geometrically spaced values to generate.
399 /// * `include_end` - Whether to include the `end` value in the generated tensor.
400 ///
401 /// # Returns
402 ///
403 /// * A 1D tensor with `n` geometrically spaced values between `start` and `end`.
404 ///
405 /// # Panics
406 ///
407 /// * This function will panic if `n` is zero, if `start` or `end` is negative, or if the values result in undefined behavior.
408 #[track_caller]
409 fn geomspace(
410 start: T,
411 end: T,
412 n: usize,
413 include_end: bool,
414 ) -> Result<Self::Output, TensorError>
415 where
416 f64: Cast<T>,
417 usize: Cast<T>,
418 T: Cast<f64>;
419
420 /// Creates a 2D triangular matrix of size `n` by `m`, with ones below or on the `k`th diagonal and zeros elsewhere.
421 ///
422 /// This function generates a matrix with a triangular structure, filled with ones and zeros, based on the diagonal offset and the `low_triangle` flag.
423 ///
424 /// # Arguments
425 ///
426 /// * `n` - The number of rows in the matrix.
427 /// * `m` - The number of columns in the matrix.
428 /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
429 /// * `low_triangle` - If `true`, the matrix will be lower triangular; otherwise, upper triangular.
430 ///
431 /// # Returns
432 ///
433 /// * A 2D triangular matrix of ones and zeros.
434 ///
435 /// # Panics
436 ///
437 /// * This function will panic if `n` or `m` is zero.
438 #[track_caller]
439 fn tri(n: usize, m: usize, k: i64, low_triangle: bool) -> Result<Self::Output, TensorError>
440 where
441 u8: Cast<T>;
442
443 /// Returns the lower triangular part of the matrix, with all elements above the `k`th diagonal set to zero.
444 ///
445 /// This function generates a tensor where the elements above the specified diagonal are set to zero.
446 ///
447 /// # Arguments
448 ///
449 /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
450 ///
451 /// # Returns
452 ///
453 /// * A tensor with its upper triangular part set to zero.
454 ///
455 /// # Panics
456 ///
457 /// * This function should not panic under normal conditions.
458 #[track_caller]
459 fn tril(&self, k: i64) -> Result<Self::Output, TensorError>
460 where
461 T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
462 <T as TypeCommon>::Vec: NormalOut<BoolVector, Output = <T as TypeCommon>::Vec>;
463
464 /// Returns the upper triangular part of the matrix, with all elements below the `k`th diagonal set to zero.
465 ///
466 /// This function generates a tensor where the elements below the specified diagonal are set to zero.
467 ///
468 /// # Arguments
469 ///
470 /// * `k` - The diagonal offset (0 for main diagonal, positive for upper diagonals, negative for lower diagonals).
471 ///
472 /// # Returns
473 ///
474 /// * A tensor with its lower triangular part set to zero.
475 ///
476 /// # Panics
477 ///
478 /// * This function should not panic under normal conditions.
479 #[track_caller]
480 fn triu(&self, k: i64) -> Result<Self::Output, TensorError>
481 where
482 T: NormalOut<bool, Output = T> + Cast<T> + TypeCommon,
483 <T as TypeCommon>::Vec: NormalOut<BoolVector, Output = <T as TypeCommon>::Vec>;
484
485 /// Creates a 2D identity matrix of size `n` by `n`.
486 ///
487 /// This function generates a square matrix with ones on the main diagonal and zeros elsewhere.
488 ///
489 /// # Arguments
490 ///
491 /// * `n` - The size of the matrix (both rows and columns).
492 ///
493 /// # Returns
494 ///
495 /// * A 2D identity matrix of size `n`.
496 ///
497 /// # Panics
498 ///
499 /// * This function will panic if `n` is zero.
500 #[track_caller]
501 fn identity(n: usize) -> Result<Self::Output, TensorError>
502 where
503 u8: Cast<T>;
504}
505
506/// A trait for tensor memory allocation, this trait only used when we work with generic type
507pub trait TensorAlloc<Output = Self> {
508 /// The tensor data type.
509 type Meta;
510 /// Creates a tensor with the specified shape,
511 ///
512 /// # Note
513 ///
514 /// This function doesn't initialize the tensor's elements.
515 #[track_caller]
516 fn _empty<S: Into<Shape>>(shape: S) -> Result<Output, TensorError>
517 where
518 Self: Sized;
519}
520
521/// A trait typically for argmax and argmin functions.
522pub trait IndexReduce
523where
524 Self: Sized,
525{
526 /// The output tensor type.
527 type Output;
528
529 /// Returns the indices of the maximum values along the specified axis.
530 ///
531 /// The `argmax` function computes the index of the maximum value along the given axis for each slice of the tensor.
532 ///
533 /// # Parameters
534 ///
535 /// - `axis`: The axis along which to compute the index of the maximum value.
536 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
537 ///
538 /// # Returns
539 ///
540 /// - `anyhow::Result<Self::Output>`: A tensor containing the indices of the maximum values.
541 ///
542 /// # See Also
543 ///
544 /// - [`argmin`]: Returns the indices of the minimum values along the specified axis.
545 #[track_caller]
546 fn argmax<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
547
548 /// Returns the indices of the minimum values along the specified axis.
549 ///
550 /// The `argmin` function computes the index of the minimum value along the given axis for each slice of the tensor.
551 ///
552 /// # Parameters
553 ///
554 /// - `axis`: The axis along which to compute the index of the minimum value.
555 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
556 ///
557 /// # Returns
558 ///
559 /// - `anyhow::Result<Self::Output>`: A tensor containing the indices of the minimum values.
560 ///
561 /// # See Also
562 ///
563 /// - [`argmax`]: Returns the indices of the maximum values along the specified axis.
564 #[track_caller]
565 fn argmin<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
566}
567
568/// A trait for normal tensor reduction operations.
569pub trait NormalReduce<T>
570where
571 Self: Sized,
572{
573 /// The output tensor type.
574 type Output;
575
576 /// Computes the sum of the elements along the specified axis.
577 ///
578 /// The `sum` function computes the sum of elements along the specified axis of the tensor.
579 ///
580 /// # Parameters
581 ///
582 /// - `axis`: The axis along which to sum the elements.
583 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
584 ///
585 /// # Returns
586 ///
587 /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements along the specified axis.
588 ///
589 /// # See Also
590 ///
591 /// - [`nansum`]: Computes the sum while ignoring NaN values.
592 #[track_caller]
593 fn sum<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
594
595 /// Computes the sum of the elements along the specified axis, storing the result in a pre-allocated tensor.
596 ///
597 /// The `sum_` function computes the sum of elements along the specified axis, and optionally initializes an output tensor to store the result.
598 ///
599 /// # Parameters
600 ///
601 /// - `axis`: The axis along which to sum the elements.
602 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
603 /// - `init_out`: Whether to initialize the output tensor.
604 /// - `out`: The tensor in which to store the result.
605 ///
606 /// # Returns
607 ///
608 /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, with the result stored in the specified output tensor.
609 #[track_caller]
610 fn sum_<S: Into<Axis>, O>(
611 &self,
612 axis: S,
613 keep_dims: bool,
614 init_out: bool,
615 out: O,
616 ) -> Result<Self::Output, TensorError>
617 where
618 O: BorrowMut<Self::Output>;
619
620 // /// Computes the sum of the elements along the specified axis, with an initial value.
621 // ///
622 // /// The `sum_with_init` function computes the sum of elements along the specified axes, starting from a given initial value.
623 // ///
624 // /// # Parameters
625 // ///
626 // /// - `init_val`: The initial value to start the summation.
627 // /// - `axes`: The axes along which to sum the elements.
628 // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
629 // ///
630 // /// # Returns
631 // ///
632 // /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements along the specified axes.
633 // #[track_caller]
634 // fn sum_with_init<S: Into<Axis>>(
635 // &self,
636 // init_val: T,
637 // axes: S,
638 // keep_dims: bool,
639 // ) -> anyhow::Result<Self::Output>;
640
641 /// Computes the product of the elements along the specified axis.
642 ///
643 /// The `prod` function computes the product of elements along the specified axis of the tensor.
644 ///
645 /// # Parameters
646 ///
647 /// - `axis`: The axis along which to compute the product.
648 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
649 ///
650 /// # Returns
651 ///
652 /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements along the specified axis.
653 ///
654 /// # See Also
655 ///
656 /// - [`nanprod`]: Computes the product while ignoring NaN values.
657 #[track_caller]
658 fn prod<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
659
660 // /// Computes the product of the elements along the specified axis, with an initial value.
661 // ///
662 // /// The `prod_with_init` function computes the product of elements along the specified axes, starting from a given initial value.
663 // ///
664 // /// # Parameters
665 // ///
666 // /// - `init_val`: The initial value to start the product computation.
667 // /// - `axes`: The axes along which to compute the product.
668 // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
669 // ///
670 // /// # Returns
671 // ///
672 // /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements along the specified axes.
673 // #[track_caller]
674 // fn prod_with_init<S: Into<Axis>>(
675 // &self,
676 // init_val: T,
677 // axes: S,
678 // keep_dims: bool,
679 // ) -> anyhow::Result<Self::Output>;
680
681 /// Computes the minimum value along the specified axis.
682 ///
683 /// The `min` function returns the minimum value of the elements along the specified axis of the tensor.
684 ///
685 /// # Parameters
686 ///
687 /// - `axis`: The axis along which to compute the minimum value.
688 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
689 ///
690 /// # Returns
691 ///
692 /// - `anyhow::Result<Self>`: A tensor containing the minimum values along the specified axis.
693 ///
694 /// # See Also
695 ///
696 /// - [`max`]: Computes the maximum value along the specified axis.
697 #[track_caller]
698 fn min<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
699
700 // /// Computes the minimum value along the specified axis, with an initial value.
701 // ///
702 // /// The `min_with_init` function computes the minimum value along the specified axes, starting from a given initial value.
703 // ///
704 // /// # Parameters
705 // ///
706 // /// - `init_val`: The initial value to compare against.
707 // /// - `axes`: The axes along which to compute the minimum value.
708 // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
709 // ///
710 // /// # Returns
711 // ///
712 // /// - `anyhow::Result<Self>`: A tensor containing the minimum values along the specified axes.
713 // #[track_caller]
714 // fn min_with_init<S: Into<Axis>>(
715 // &self,
716 // init_val: T,
717 // axes: S,
718 // keep_dims: bool,
719 // ) -> anyhow::Result<Self>;
720
721 /// Computes the maximum value along the specified axis.
722 ///
723 /// The `max` function returns the maximum value of the elements along the specified axis of the tensor.
724 ///
725 /// # Parameters
726 ///
727 /// - `axis`: The axis along which to compute the maximum value.
728 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
729 ///
730 /// # Returns
731 ///
732 /// - `anyhow::Result<Self>`: A tensor containing the maximum values along the specified axis.
733 ///
734 /// # See Also
735 ///
736 /// - [`min`]: Computes the minimum value along the specified axis.
737 #[track_caller]
738 fn max<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
739
740 // /// Computes the maximum value along the specified axis, with an initial value.
741 // ///
742 // /// The `max_with_init` function computes the maximum value along the specified axes, starting from a given initial value.
743 // ///
744 // /// # Parameters
745 // ///
746 // /// - `init_val`: The initial value to compare against.
747 // /// - `axes`: The axes along which to compute the maximum value.
748 // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
749 // ///
750 // /// # Returns
751 // ///
752 // /// - `anyhow::Result<Self>`: A tensor containing the maximum values along the specified axes.
753 // #[track_caller]
754 // fn max_with_init<S: Into<Axis>>(
755 // &self,
756 // init_val: T,
757 // axes: S,
758 // keep_dims: bool,
759 // ) -> anyhow::Result<Self>;
760
761 /// Reduces the tensor along the specified axis using the L1 norm (sum of absolute values).
762 ///
763 /// The `reducel1` function computes the L1 norm (sum of absolute values) along the specified axis of the tensor.
764 ///
765 /// # Parameters
766 ///
767 /// - `axis`: The axis along which to reduce the tensor.
768 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
769 ///
770 /// # Returns
771 ///
772 /// - `anyhow::Result<Self::Output>`: A tensor with the L1 norm computed along the specified axis.
773 #[track_caller]
774 fn reducel1<S: Into<Axis>>(
775 &self,
776 axis: S,
777 keep_dims: bool,
778 ) -> Result<Self::Output, TensorError>;
779
780 /// Computes the sum of the squares of the elements along the specified axis.
781 ///
782 /// The `sum_square` function computes the sum of the squares of the elements along the specified axis of the tensor.
783 ///
784 /// # Parameters
785 ///
786 /// - `axis`: The axis along which to sum the squares.
787 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
788 ///
789 /// # Returns
790 ///
791 /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of squares of elements along the specified axis.
792 #[track_caller]
793 fn sum_square<S: Into<Axis>>(
794 &self,
795 axis: S,
796 keep_dims: bool,
797 ) -> Result<Self::Output, TensorError>;
798}
799
800/// A trait for tensor reduction operations, the output must be a boolean tensor.
801pub trait EvalReduce {
802 /// The boolean tensor type.
803 type BoolOutput;
804 /// Returns `true` if all elements along the specified axis evaluate to `true`.
805 ///
806 /// The `all` function checks whether all elements along the specified axis evaluate to `true`.
807 ///
808 /// # Parameters
809 ///
810 /// - `axis`: The axis along which to check.
811 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
812 ///
813 /// # Returns
814 ///
815 /// - `anyhow::Result<Self::BoolOutput>`: A boolean tensor indicating whether all elements evaluate to `true`.
816 ///
817 /// # See Also
818 ///
819 /// - [`any`]: Returns `true` if any element along the specified axis evaluates to `true`.
820 #[track_caller]
821 fn all<S: Into<Axis>>(&self, axis: S, keep_dims: bool)
822 -> Result<Self::BoolOutput, TensorError>;
823
824 /// Returns `true` if any element along the specified axis evaluates to `true`.
825 ///
826 /// The `any` function checks whether any element along the specified axis evaluates to `true`.
827 ///
828 /// # Parameters
829 ///
830 /// - `axis`: The axis along which to check.
831 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
832 ///
833 /// # Returns
834 ///
835 /// - `anyhow::Result<Self::BoolOutput>`: A boolean tensor indicating whether any element evaluates to `true`.
836 ///
837 /// # See Also
838 ///
839 /// - [`all`]: Returns `true` if all elements along the specified axis evaluate to `true`.
840 #[track_caller]
841 fn any<S: Into<Axis>>(&self, axis: S, keep_dims: bool)
842 -> Result<Self::BoolOutput, TensorError>;
843}
844
845/// A trait for tensor reduction operations, the output must remain the same tensor type.
846pub trait NormalEvalReduce<T> {
847 /// the output tensor type.
848 type Output;
849 /// Computes the sum of the elements along the specified axis, ignoring NaN values.
850 ///
851 /// The `nansum` function computes the sum of elements along the specified axis, while ignoring NaN values in the tensor.
852 ///
853 /// # Parameters
854 ///
855 /// - `axis`: The axis along which to sum the elements.
856 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
857 ///
858 /// # Returns
859 ///
860 /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, ignoring NaN values.
861 #[track_caller]
862 fn nansum<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
863
864 /// Computes the sum of the elements along the specified axis, with an initial value, ignoring NaN values.
865 ///
866 /// The `nansum_with_init` function computes the sum of elements along the specified axes, starting from a given initial value and ignoring NaN values.
867 ///
868 /// # Parameters
869 ///
870 /// - `init_val`: The initial value to start the summation.
871 /// - `axes`: The axes along which to sum the elements.
872 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
873 ///
874 /// # Returns
875 ///
876 /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, ignoring NaN values.
877 #[track_caller]
878 fn nansum_<S: Into<Axis>, O>(
879 &self,
880 axis: S,
881 keep_dims: bool,
882 init_out: bool,
883 out: O,
884 ) -> Result<Self::Output, TensorError>
885 where
886 O: BorrowMut<Self::Output>;
887 // /// Computes the sum of the elements along the specified axis, with an initial value, ignoring NaN values.
888 // ///
889 // /// The `nansum_with_init` function computes the sum of elements along the specified axes, starting from a given initial value and ignoring NaN values.
890 // ///
891 // /// # Parameters
892 // ///
893 // /// - `init_val`: The initial value to start the summation.
894 // /// - `axes`: The axes along which to sum the elements.
895 // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
896 // ///
897 // /// # Returns
898 // ///
899 // /// - `anyhow::Result<Self::Output>`: A tensor containing the sum of elements, ignoring NaN values.
900 // #[track_caller]
901 // fn nansum_with_init<S: Into<Axis>>(
902 // &self,
903 // init_val: T,
904 // axes: S,
905 // keep_dims: bool,
906 // ) -> anyhow::Result<Self::Output>;
907
908 /// Computes the product of the elements along the specified axis, ignoring NaN values.
909 ///
910 /// The `nanprod` function computes the product of elements along the specified axis, while ignoring NaN values in the tensor.
911 ///
912 /// # Parameters
913 ///
914 /// - `axis`: The axis along which to compute the product.
915 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
916 ///
917 /// # Returns
918 ///
919 /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements, ignoring NaN values.
920 #[track_caller]
921 fn nanprod<S: Into<Axis>>(&self, axis: S, keep_dims: bool)
922 -> Result<Self::Output, TensorError>;
923
924 // /// Computes the product of the elements along the specified axis, with an initial value, ignoring NaN values.
925 // ///
926 // /// The `nanprod_with_init` function computes the product of elements along the specified axes, starting from a given initial value and ignoring NaN values.
927 // ///
928 // /// # Parameters
929 // ///
930 // /// - `init_val`: The initial value to start the product computation.
931 // /// - `axes`: The axes along which to compute the product.
932 // /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
933 // ///
934 // /// # Returns
935 // ///
936 // /// - `anyhow::Result<Self::Output>`: A tensor containing the product of elements, ignoring NaN values.
937 // #[track_caller]
938 // fn nanprod_with_init<S: Into<Axis>>(
939 // &self,
940 // init_val: T,
941 // axes: S,
942 // keep_dims: bool,
943 // ) -> anyhow::Result<Self::Output>;
944}
945
946/// A trait for tensor reduction operations, the output must be a floating-point tensor.
947pub trait FloatReduce<T>
948where
949 Self: Sized,
950{
951 /// The output tensor type.
952 type Output;
953
954 /// Computes the mean of the elements along the specified axis.
955 ///
956 /// The `mean` function calculates the mean of the elements along the specified axis of the tensor.
957 ///
958 /// # Parameters
959 ///
960 /// - `axis`: The axis along which to compute the mean.
961 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
962 ///
963 /// # Returns
964 ///
965 /// - `anyhow::Result<Self::Output>`: A tensor containing the mean values along the specified axis.
966 #[track_caller]
967 fn mean<S: Into<Axis>>(&self, axis: S, keep_dims: bool) -> Result<Self::Output, TensorError>;
968
969 /// Reduces the tensor along the specified axis using the L2 norm (Euclidean norm).
970 ///
971 /// The `reducel2` function computes the L2 norm (Euclidean norm) along the specified axis of the tensor.
972 ///
973 /// # Parameters
974 ///
975 /// - `axis`: The axis along which to reduce the tensor.
976 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
977 ///
978 /// # Returns
979 ///
980 /// - `anyhow::Result<Self::Output>`: A tensor with the L2 norm computed along the specified axis.
981 #[track_caller]
982 fn reducel2<S: Into<Axis>>(
983 &self,
984 axis: S,
985 keep_dims: bool,
986 ) -> Result<Self::Output, TensorError>;
987
988 /// Reduces the tensor along the specified axis using the L3 norm.
989 ///
990 /// The `reducel3` function computes the L3 norm along the specified axis of the tensor.
991 ///
992 /// # Parameters
993 ///
994 /// - `axis`: The axis along which to reduce the tensor.
995 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
996 ///
997 /// # Returns
998 ///
999 /// - `anyhow::Result<Self::Output>`: A tensor with the L3 norm computed along the specified axis.
1000 #[track_caller]
1001 fn reducel3<S: Into<Axis>>(
1002 &self,
1003 axis: S,
1004 keep_dims: bool,
1005 ) -> Result<Self::Output, TensorError>;
1006
1007 /// Computes the logarithm of the sum of exponentials of the elements along the specified axis.
1008 ///
1009 /// The `logsumexp` function calculates the logarithm of the sum of exponentials of the elements along the specified axis,
1010 /// which is useful for numerical stability in certain operations.
1011 ///
1012 /// # Parameters
1013 ///
1014 /// - `axis`: The axis along which to compute the logarithm of the sum of exponentials.
1015 /// - `keep_dims`: Whether to retain the reduced dimensions in the result.
1016 ///
1017 /// # Returns
1018 ///
1019 /// - `anyhow::Result<Self::Output>`: A tensor containing the log-sum-exp values along the specified axis.
1020 #[track_caller]
1021 fn logsumexp<S: Into<Axis>>(
1022 &self,
1023 axis: S,
1024 keep_dims: bool,
1025 ) -> Result<Self::Output, TensorError>;
1026}
1027
1028/// Common bounds for primitive types
1029pub trait CommonBounds
1030where
1031 <Self as TypeCommon>::Vec: Send + Sync + Copy,
1032 Self: Sync
1033 + Send
1034 + Clone
1035 + Copy
1036 + TypeCommon
1037 + 'static
1038 + Display
1039 + Debug
1040 + Cast<Self>
1041 + NormalOut<Self, Output = Self>
1042 + FloatOutUnary
1043 + NormalOut<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1044 + FloatOutBinary<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1045 + FloatOutBinary<Self>
1046 + NormalOut<
1047 <Self as FloatOutBinary<Self>>::Output,
1048 Output = <Self as FloatOutBinary<Self>>::Output,
1049 >
1050 + NormalOutUnary
1051 + Cast<f64>,
1052{
1053}
1054impl<T> CommonBounds for T
1055where
1056 <Self as TypeCommon>::Vec: Send + Sync + Copy,
1057 Self: Sync
1058 + Send
1059 + Clone
1060 + Copy
1061 + TypeCommon
1062 + 'static
1063 + Display
1064 + Debug
1065 + Cast<Self>
1066 + NormalOut<Self, Output = Self>
1067 + FloatOutUnary
1068 + NormalOut<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1069 + FloatOutBinary<<Self as FloatOutUnary>::Output, Output = <Self as FloatOutUnary>::Output>
1070 + FloatOutBinary<Self>
1071 + FloatOutBinary<
1072 <Self as FloatOutBinary<Self>>::Output,
1073 Output = <Self as FloatOutBinary<Self>>::Output,
1074 >
1075 + NormalOut<
1076 <Self as FloatOutBinary<Self>>::Output,
1077 Output = <Self as FloatOutBinary<Self>>::Output,
1078 >
1079 + NormalOutUnary
1080 + Cast<f64>,
1081{
1082}