// rstsr_core/storage/reduction.rs

use crate::prelude_dev::*;

3#[allow(clippy::type_complexity)]
4#[duplicate_item(
5    OpReduceAPI   func           func_all    ;
6   [OpSumAPI   ] [sum_axes    ] [sum_all    ];
7   [OpMinAPI   ] [min_axes    ] [min_all    ];
8   [OpMaxAPI   ] [max_axes    ] [max_all    ];
9   [OpProdAPI  ] [prod_axes   ] [prod_all   ];
10   [OpMeanAPI  ] [mean_axes   ] [mean_all   ];
11   [OpVarAPI   ] [var_axes    ] [var_all    ];
12   [OpStdAPI   ] [std_axes    ] [std_all    ];
13   [OpL2NormAPI] [l2_norm_axes] [l2_norm_all];
14   [OpArgMinAPI] [argmin_axes ] [argmin_all ];
15   [OpArgMaxAPI] [argmax_axes ] [argmax_all ];
16   [OpAllAPI   ] [all_axes    ] [all_all    ];
17   [OpAnyAPI   ] [any_axes    ] [any_all    ];
18   [OpCountNonZeroAPI] [count_nonzero_axes] [count_nonzero_all];
19)]
20pub trait OpReduceAPI<T, D>
21where
22    D: DimAPI,
23    Self: DeviceAPI<T> + DeviceAPI<Self::TOut>,
24{
25    type TOut;
26    fn func_all(&self, a: &<Self as DeviceRawAPI<T>>::Raw, la: &Layout<D>) -> Result<Self::TOut>;
27    fn func(
28        &self,
29        a: &<Self as DeviceRawAPI<T>>::Raw,
30        la: &Layout<D>,
31        axes: &[isize],
32    ) -> Result<(Storage<DataOwned<<Self as DeviceRawAPI<Self::TOut>>::Raw>, Self::TOut, Self>, Layout<IxD>)>;
33}
34
35#[allow(clippy::type_complexity)]
36#[duplicate_item(
37    OpReduceAPI            func                    func_all             ;
38   [OpUnraveledArgMinAPI] [unraveled_argmin_axes] [unraveled_argmin_all];
39   [OpUnraveledArgMaxAPI] [unraveled_argmax_axes] [unraveled_argmax_all];
40)]
41pub trait OpReduceAPI<T, D>
42where
43    D: DimAPI,
44    Self: DeviceAPI<T>,
45{
46    fn func_all(&self, a: &<Self as DeviceRawAPI<T>>::Raw, la: &Layout<D>) -> Result<D>;
47    fn func(
48        &self,
49        a: &<Self as DeviceRawAPI<T>>::Raw,
50        la: &Layout<D>,
51        axes: &[isize],
52    ) -> Result<(Storage<DataOwned<<Self as DeviceRawAPI<IxD>>::Raw>, IxD, Self>, Layout<IxD>)>
53    where
54        Self: DeviceAPI<IxD>;
55}
56
57#[allow(clippy::type_complexity)]
58pub trait OpSumBoolAPI<D>
59where
60    D: DimAPI,
61    Self: DeviceAPI<bool> + DeviceAPI<usize>,
62{
63    fn sum_all(&self, a: &<Self as DeviceRawAPI<bool>>::Raw, la: &Layout<D>) -> Result<usize>;
64    fn sum_axes(
65        &self,
66        a: &<Self as DeviceRawAPI<bool>>::Raw,
67        la: &Layout<D>,
68        axes: &[isize],
69    ) -> Result<(Storage<DataOwned<<Self as DeviceRawAPI<usize>>::Raw>, usize, Self>, Layout<IxD>)>;
70}
71
72#[allow(clippy::type_complexity)]
73pub trait OpAllCloseAPI<TA, TB, TE, D>
74where
75    D: DimAPI,
76    Self: DeviceAPI<TA> + DeviceAPI<TB> + DeviceAPI<bool>,
77{
78    fn allclose_all(
79        &self,
80        a: &<Self as DeviceRawAPI<TA>>::Raw,
81        la: &Layout<D>,
82        b: &<Self as DeviceRawAPI<TB>>::Raw,
83        lb: &Layout<D>,
84        isclose_args: &IsCloseArgs<TE>,
85    ) -> Result<bool>;
86    fn allclose_axes(
87        &self,
88        a: &<Self as DeviceRawAPI<TA>>::Raw,
89        la: &Layout<D>,
90        b: &<Self as DeviceRawAPI<TB>>::Raw,
91        lb: &Layout<D>,
92        axes: &[isize],
93        isclose_args: &IsCloseArgs<TE>,
94    ) -> Result<(Storage<DataOwned<<Self as DeviceRawAPI<bool>>::Raw>, bool, Self>, Layout<IxD>)>;
95}