hpt/backends/cpu/tensor_external/arg_reduce.rs

1use hpt_allocator::Cpu;
2use hpt_common::axis::axis::Axis;
3use hpt_common::error::base::TensorError;
4use hpt_traits::{ops::reduce::IndexReduce, tensor::CommonBounds};
5use hpt_types::type_promote::{Cmp, NormalOut};
6
7use crate::tensor::{DiffTensor, Tensor};
8
9impl<T: CommonBounds + NormalOut<Output = T> + Cmp<T, Output = bool>, const DEVICE: usize>
10    IndexReduce for Tensor<T, Cpu, DEVICE>
11{
12    type Output = Tensor<i64, Cpu, DEVICE>;
13
14    fn argmax<S: Into<Axis>>(
15        &self,
16        axis: S,
17        keep_dims: bool,
18    ) -> std::result::Result<Self::Output, TensorError> {
19        Ok(self.inner.argmax(axis, keep_dims)?.into())
20    }
21
22    fn argmin<S: Into<Axis>>(
23        &self,
24        axis: S,
25        keep_dims: bool,
26    ) -> std::result::Result<Self::Output, TensorError> {
27        Ok(self.inner.argmin(axis, keep_dims)?.into())
28    }
29}
30
31impl<T: CommonBounds + NormalOut<Output = T> + Cmp<T, Output = bool>, const DEVICE: usize>
32    IndexReduce for DiffTensor<T, Cpu, DEVICE>
33{
34    type Output = Tensor<i64, Cpu, DEVICE>;
35
36    fn argmax<S: Into<Axis>>(
37        &self,
38        axis: S,
39        keep_dims: bool,
40    ) -> std::result::Result<Self::Output, TensorError> {
41        Ok(self.inner.argmax(axis, keep_dims)?.into())
42    }
43
44    fn argmin<S: Into<Axis>>(
45        &self,
46        axis: S,
47        keep_dims: bool,
48    ) -> std::result::Result<Self::Output, TensorError> {
49        Ok(self.inner.argmin(axis, keep_dims)?.into())
50    }
51}