#[cfg(feature = "autotune")]
use super::{autotune_reduce, autotune_sum};
use crate::{
CubeRuntime,
ops::numeric::{empty_device_contiguous_dtype, zeros_client},
tensor::CubeTensor,
};
use burn_backend::{DType, TensorMetadata};
use burn_std::Metadata;
use cubecl::{
AutotuneKey,
client::ComputeClient,
features::AtomicUsage,
ir::{StorageType, Type},
};
use cubek::reduce::{
ReduceDtypes, ReduceError, ReduceStrategy,
components::instructions::ReduceOperationConfig,
launch::{RoutineStrategy, VectorizationStrategy},
routines::{BlueprintStrategy, unit::UnitStrategy},
shared_sum,
};
use serde::{Deserialize, Serialize};
/// Autotune key identifying a full-tensor sum problem, used to cache the
/// fastest [`SumStrategy`] per (dtype, size) class.
#[derive(Hash, Eq, PartialEq, Debug, Clone, Serialize, Deserialize, AutotuneKey)]
pub struct SumAutotuneKey {
    /// Element data type of the tensor being summed.
    dtype: burn_backend::DType,
    /// Total element count.
    // NOTE(review): `anchor` presumably buckets nearby lengths into the same
    // key so tuning results are shared across similar sizes — confirm against
    // the cubecl `AutotuneKey` macro docs.
    #[autotune(anchor)]
    length: usize,
}
/// Returns `true` when the compute client advertises atomic-add support for
/// the atomic storage type corresponding to `dtype`.
fn supports_atomic_add<R: CubeRuntime>(client: &ComputeClient<R>, dtype: DType) -> bool {
    let atomic_ty = Type::new(StorageType::Atomic(dtype.into()));
    let usage = client.properties().atomic_type_usage(atomic_ty);
    usage.contains(AtomicUsage::Add)
}
/// Sums `tensor` with `strategy`, but silently downgrades a one-shot strategy
/// to the default chained strategy when the device lacks atomic-add support
/// for the tensor's dtype (the one-shot kernel depends on it).
pub fn sum_fallback<R: CubeRuntime>(
    tensor: CubeTensor<R>,
    mut strategy: SumStrategy,
) -> Result<CubeTensor<R>, ReduceError> {
    let wants_one_shot = matches!(strategy, SumStrategy::OneShot(_));
    if wants_one_shot && !supports_atomic_add(&tensor.client, tensor.dtype) {
        strategy = SumStrategy::Chained(Default::default());
    }
    sum(tensor, strategy)
}
/// Reduces `tensor` to a single-element tensor holding the sum of all its
/// values, dispatching on the requested [`SumStrategy`].
pub fn sum<Run: CubeRuntime>(
    tensor: CubeTensor<Run>,
    strategy: SumStrategy,
) -> Result<CubeTensor<Run>, ReduceError> {
    let client = tensor.client.clone();
    let device = tensor.device.clone();
    match strategy {
        SumStrategy::OneShot(cube_count) => {
            let dtype = tensor.dtype;
            // Zero-initialized scalar output; `shared_sum` writes the result
            // into it (presumably via atomic accumulation — hence the
            // atomic-add check in `sum_fallback`).
            let output = zeros_client(client.clone(), device, [1].into(), dtype);
            shared_sum::<Run>(
                &client,
                tensor.binding(),
                output.clone().binding(),
                cube_count,
                dtype.into(),
            )?;
            Ok(output)
        }
        SumStrategy::Chained(strategy) => {
            // Fold every axis one at a time with the generic reduce kernels.
            reduce::<Run>(tensor, None, strategy, ReduceOperationConfig::Sum)
        }
        #[cfg(feature = "autotune")]
        SumStrategy::Autotune => Ok(autotune_sum::<Run>(&client, tensor)),
    }
}
/// Strategy selecting how a full-tensor sum is executed.
pub enum SumStrategy {
    /// Single kernel launch over the whole tensor; the `u32` is the cube
    /// count forwarded to `shared_sum`. Requires atomic-add support for the
    /// dtype (see `sum_fallback`).
    OneShot(u32),
    /// Reduce one axis at a time with the generic reduce kernels.
    Chained(KernelReduceStrategy),
    /// Benchmark the candidates at runtime and pick the fastest.
    #[cfg(feature = "autotune")]
    Autotune,
}
impl Default for SumStrategy {
    fn default() -> Self {
        // Prefer runtime autotuning when the feature is compiled in;
        // otherwise fall back to a one-shot sum with a fixed cube count of 4.
        #[cfg(feature = "autotune")]
        return Self::Autotune;
        #[cfg(not(feature = "autotune"))]
        return Self::OneShot(4);
    }
}
/// Collapses every axis of `tensor` via [`reduce_dim`], yielding a rank-one
/// tensor of shape `[1]`. Axes are processed in ascending order of extent.
pub fn reduce<Run: CubeRuntime>(
    mut tensor: CubeTensor<Run>,
    output_dtype: Option<DType>,
    strategy: KernelReduceStrategy,
    config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
    let axes = argsort(tensor.meta.shape());
    tensor = axes.into_iter().try_fold(tensor, |acc, axis| {
        reduce_dim::<Run>(acc, output_dtype, axis, strategy.clone(), config)
    })?;
    // Every axis now has extent 1; flatten the metadata to a single element.
    *tensor.meta = Metadata::new([1], [1]);
    Ok(tensor)
}
/// Returns the axis indices of `shape` ordered by ascending extent.
///
/// The sort is intentionally stable: axes with equal extents keep their
/// original relative order, so the chained-reduce axis order is deterministic.
fn argsort(shape: &[usize]) -> Vec<usize> {
    let mut indices: Vec<usize> = (0..shape.len()).collect();
    // Key by value, not by reference: `usize` is `Copy`, so the original
    // `|&i| &shape[i]` key was a needless borrow (clippy: needless indirection).
    indices.sort_by_key(|&i| shape[i]);
    indices
}
/// Reduces `input` along axis `dim`, writing into a freshly allocated output
/// tensor whose extent along `dim` is 1.
///
/// `output_dtype` overrides the output element type; it is mandatory for
/// `ArgMax`/`ArgMin` (which produce indices rather than values) and optional
/// otherwise.
///
/// # Errors
/// Returns [`ReduceError::InvalidAxis`] when `dim` is out of range, or any
/// error surfaced by the underlying reduce kernel launch.
pub fn reduce_dim<Run: CubeRuntime>(
    input: CubeTensor<Run>,
    output_dtype: Option<DType>,
    dim: usize,
    strategy: KernelReduceStrategy,
    config: ReduceOperationConfig,
) -> Result<CubeTensor<Run>, cubek::reduce::ReduceError> {
    // ArgMax/ArgMin output indices, so the caller must choose the index dtype.
    // Fixed message: the assertion requires `Some` *when* the config is
    // ArgMax/ArgMin — the old "only when" wording stated the converse.
    debug_assert!(
        !matches!(
            config,
            ReduceOperationConfig::ArgMax | ReduceOperationConfig::ArgMin
        ) || output_dtype.is_some(),
        "The `output_dtype` has to be `Some` when the `config` is `ArgMax` or `ArgMin`."
    );
    let dtypes = config.precision(input.dtype.into(), output_dtype.map(Into::into));
    let client = input.client.clone();
    let output = init_reduce_output::<Run>(&input, dim, &dtypes).ok_or(
        cubek::reduce::ReduceError::InvalidAxis {
            axis: dim,
            rank: input.meta.num_dims(),
        },
    )?;
    let result = match strategy {
        // No strategy requested: default to the unit routine with an
        // inferred blueprint and no parallel output vectorization.
        KernelReduceStrategy::Unspecified => cubek::reduce::reduce::<Run>(
            &client,
            input.binding(),
            output.clone().binding(),
            dim,
            ReduceStrategy {
                routine: RoutineStrategy::Unit(BlueprintStrategy::Inferred(UnitStrategy)),
                vectorization: VectorizationStrategy {
                    parallel_output_vectorization: false,
                },
            },
            config,
            dtypes,
        ),
        KernelReduceStrategy::Specific(strategy) => cubek::reduce::reduce::<Run>(
            &client,
            input.binding(),
            output.clone().binding(),
            dim,
            strategy,
            config,
            dtypes,
        ),
        #[cfg(feature = "autotune")]
        KernelReduceStrategy::Autotune => {
            // The autotuner launches the selected kernel itself; treat the
            // dispatch as success here.
            autotune_reduce::<Run>(&client, input, output.clone(), dim, config, dtypes);
            Ok(())
        }
    };
    result.map(|_| output)
}
/// Allocates a contiguous output tensor for a reduction of `input` along
/// `dim`: same shape as the input except the reduced axis collapses to 1.
///
/// Returns `None` when `dim` is not a valid axis of `input`.
pub fn init_reduce_output<Run: CubeRuntime>(
    input: &CubeTensor<Run>,
    dim: usize,
    dtypes: &ReduceDtypes,
) -> Option<CubeTensor<Run>> {
    if dim >= input.meta.num_dims() {
        return None;
    }
    let mut out_shape = input.shape();
    out_shape[dim] = 1;
    Some(empty_device_contiguous_dtype(
        input.client.clone(),
        input.device.clone(),
        out_shape,
        dtypes.output.elem_type().into(),
    ))
}
/// Strategy selecting how a single-axis reduce kernel is launched.
#[derive(Clone, Debug)]
pub enum KernelReduceStrategy {
    /// Let `reduce_dim` pick a default (unit routine, inferred blueprint).
    Unspecified,
    /// Use exactly this cubek reduce strategy.
    Specific(cubek::reduce::launch::ReduceStrategy),
    /// Benchmark the candidates at runtime and pick the fastest.
    #[cfg(feature = "autotune")]
    Autotune,
}
impl Default for KernelReduceStrategy {
    fn default() -> Self {
        // Prefer runtime autotuning when the feature is compiled in;
        // otherwise let the reduce entry point choose its default strategy.
        #[cfg(feature = "autotune")]
        return Self::Autotune;
        #[cfg(not(feature = "autotune"))]
        return Self::Unspecified;
    }
}