// hpt/backends/cpu/tensor_external/arg_reduce.rs
use hpt_allocator::Cpu;
2use hpt_common::axis::axis::Axis;
3use hpt_common::error::base::TensorError;
4use hpt_traits::{ops::reduce::IndexReduce, tensor::CommonBounds};
5use hpt_types::type_promote::{Cmp, NormalOut};
6
7use crate::tensor::{DiffTensor, Tensor};
8
9impl<T: CommonBounds + NormalOut<Output = T> + Cmp<T, Output = bool>, const DEVICE: usize>
10 IndexReduce for Tensor<T, Cpu, DEVICE>
11{
12 type Output = Tensor<i64, Cpu, DEVICE>;
13
14 fn argmax<S: Into<Axis>>(
15 &self,
16 axis: S,
17 keep_dims: bool,
18 ) -> std::result::Result<Self::Output, TensorError> {
19 Ok(self.inner.argmax(axis, keep_dims)?.into())
20 }
21
22 fn argmin<S: Into<Axis>>(
23 &self,
24 axis: S,
25 keep_dims: bool,
26 ) -> std::result::Result<Self::Output, TensorError> {
27 Ok(self.inner.argmin(axis, keep_dims)?.into())
28 }
29}
30
31impl<T: CommonBounds + NormalOut<Output = T> + Cmp<T, Output = bool>, const DEVICE: usize>
32 IndexReduce for DiffTensor<T, Cpu, DEVICE>
33{
34 type Output = Tensor<i64, Cpu, DEVICE>;
35
36 fn argmax<S: Into<Axis>>(
37 &self,
38 axis: S,
39 keep_dims: bool,
40 ) -> std::result::Result<Self::Output, TensorError> {
41 Ok(self.inner.argmax(axis, keep_dims)?.into())
42 }
43
44 fn argmin<S: Into<Axis>>(
45 &self,
46 axis: S,
47 keep_dims: bool,
48 ) -> std::result::Result<Self::Output, TensorError> {
49 Ok(self.inner.argmin(axis, keep_dims)?.into())
50 }
51}