burn_jit/kernel/reduce/
base.rs1use crate::{element::JitElement, ops::numeric::empty_device, tensor::JitTensor, JitRuntime};
2
3use super::autotune_reduce;
4
5pub use cubecl::reduce::instructions::{ArgMax, ArgMin, Mean, Prod, Sum};
6
7pub fn reduce<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
16 mut input: JitTensor<Run>,
17 strategy: ReduceStrategy,
18) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
19 input.shape = input.shape.flatten();
20 input.strides = vec![1];
21 reduce_dim::<Run, In, Out, Rd>(input, 0, strategy)
22}
23
24pub fn reduce_dim<Run: JitRuntime, In: JitElement, Out: JitElement, Rd: cubecl::reduce::Reduce>(
33 input: JitTensor<Run>,
34 dim: usize,
35 strategy: ReduceStrategy,
36) -> Result<JitTensor<Run>, cubecl::reduce::ReduceError> {
37 let client = input.client.clone();
38 let output = init_reduce_output::<Run, In, Out>(&input, dim).ok_or(
39 cubecl::reduce::ReduceError::InvalidAxis {
40 axis: dim,
41 rank: input.shape.num_dims(),
42 },
43 )?;
44 let result = match strategy {
45 ReduceStrategy::Unspecified => cubecl::reduce::reduce::<Run, In, Out, Rd>(
46 &client,
47 input.as_handle_ref(),
48 output.as_handle_ref(),
49 dim,
50 None,
51 ),
52 ReduceStrategy::Specific(strategy) => cubecl::reduce::reduce::<Run, In, Out, Rd>(
53 &client,
54 input.as_handle_ref(),
55 output.as_handle_ref(),
56 dim,
57 Some(strategy),
58 ),
59 #[cfg(feature = "autotune")]
60 ReduceStrategy::Autotune => {
61 autotune_reduce::<Run, In, Out, Rd>(&client, input, output.clone(), dim)
62 }
63 };
64 result.map(|_| output)
65}
66
67pub fn init_reduce_output<Run: JitRuntime, In: JitElement, Out: JitElement>(
70 input: &JitTensor<Run>,
71 dim: usize,
72) -> Option<JitTensor<Run>> {
73 (dim < input.shape.num_dims()).then(|| {
74 let mut shape_out = input.shape.clone();
75 shape_out.dims[dim] = 1;
76 empty_device::<Run, Out>(input.client.clone(), input.device.clone(), shape_out)
77 })
78}
79
80#[derive(Copy, Clone, Debug)]
82pub enum ReduceStrategy {
83 Unspecified,
86 Specific(cubecl::reduce::ReduceStrategy),
88 #[cfg(feature = "autotune")]
90 Autotune,
91}
92
93impl Default for ReduceStrategy {
94 fn default() -> Self {
95 #[cfg(feature = "autotune")]
96 return Self::Autotune;
97
98 #[cfg(not(feature = "autotune"))]
99 return Self::Unspecified;
100 }
101}