#![allow(
clippy::type_complexity,
reason = "Too sensitive, triggers on tuple of vector."
)]
pub mod components;
pub mod launch;
pub mod routines;
mod error;
#[cfg(feature = "cpu-reference")]
pub mod cpu_reference;
pub use crate::launch::ReduceStrategy;
use crate::{components::instructions::ReduceOperationConfig, launch::launch_reduce};
pub use components::{
args::init_tensors,
config::*,
instructions::{ReduceFamily, ReduceInstruction},
precision::ReducePrecision,
};
use cubecl::prelude::*;
pub use error::*;
pub use launch::{ReduceDtypes, reduce_kernel};
pub use routines::shared_sum::shared_sum;
/// Validates a reduction request and forwards it to [`launch_reduce`].
///
/// Two checks run before anything is launched:
///
/// 1. `axis` is checked against the rank of `input` (the length of its shape).
/// 2. The shape of `output` must equal the shape of `input` with the entry at
///    `axis` replaced by `k` for [`ReduceOperationConfig::TopK`] /
///    [`ReduceOperationConfig::ArgTopK`], or by `1` for every other operation.
///
/// `client`, `strategy` and `dtypes` are not inspected here; they are passed
/// through to [`launch_reduce`] unchanged.
///
/// # Errors
///
/// Returns the [`ReduceError`] produced by either validation step, or whatever
/// error [`launch_reduce`] reports.
pub fn reduce<R: Runtime>(
    client: &ComputeClient<R>,
    input: TensorBinding<R>,
    output: TensorBinding<R>,
    axis: usize,
    strategy: ReduceStrategy,
    operation: ReduceOperationConfig,
    dtypes: ReduceDtypes,
) -> Result<(), ReduceError> {
    validate_axis(input.shape.len(), axis)?;

    // Only the top-k flavours keep more than one element along the reduced axis.
    let top_k = match operation {
        ReduceOperationConfig::ArgTopK(k) | ReduceOperationConfig::TopK(k) => Some(k),
        _ => None,
    };
    validate_shapes(&input.shape, &output.shape, axis, top_k)?;

    launch_reduce::<R>(client, input, output, axis, strategy, dtypes, operation)
}
/// Checks that `axis` refers to an existing dimension of a tensor with `rank`
/// dimensions.
///
/// Valid axes are `0..rank`. In particular `axis == rank` is rejected: it is
/// one past the last dimension, and indexing a shape with it would panic.
///
/// # Errors
///
/// Returns [`ReduceError::InvalidAxis`] when `axis >= rank`.
fn validate_axis(rank: usize, axis: usize) -> Result<(), ReduceError> {
    if axis >= rank {
        return Err(ReduceError::InvalidAxis { axis, rank });
    }
    Ok(())
}
/// Checks that `output_shape` is the shape expected when `input_shape` is
/// reduced along `axis`.
///
/// The expected shape is `input_shape` with the entry at `axis` replaced by
/// `k`, where a missing `k` (`None`) is treated as `1`.
///
/// # Errors
///
/// * [`ReduceError::InvalidAxis`] if `axis` is not a valid index into
///   `input_shape` (reported instead of panicking on the out-of-bounds index).
/// * [`ReduceError::ReduceAxisTooSmall`] if the input has fewer than `k`
///   elements along `axis`.
/// * [`ReduceError::MismatchOutputShape`] if `output_shape` differs from the
///   expected shape.
fn validate_shapes(
    input_shape: &[usize],
    output_shape: &[usize],
    axis: usize,
    k: Option<usize>,
) -> Result<(), ReduceError> {
    // Checked lookup: an out-of-range axis becomes an error rather than a panic.
    let Some(&axis_length) = input_shape.get(axis) else {
        return Err(ReduceError::InvalidAxis {
            axis,
            rank: input_shape.len(),
        });
    };

    let k = k.unwrap_or(1);
    if axis_length < k {
        return Err(ReduceError::ReduceAxisTooSmall { axis_length, k });
    }

    // Allocate the expected shape only once the cheap checks have passed.
    let mut expected_shape = input_shape.to_vec();
    expected_shape[axis] = k;

    if output_shape != expected_shape {
        return Err(ReduceError::MismatchOutputShape {
            expected_shape,
            output_shape: output_shape.to_vec(),
        });
    }
    Ok(())
}