pub mod args;
pub mod instructions;
pub mod primitives;
pub mod tune_key;
mod config;
mod error;
mod launch;
mod shared_sum;
mod strategy;
pub use config::*;
pub use error::*;
pub use instructions::ReduceFamily;
pub use instructions::ReduceInstruction;
pub use shared_sum::*;
pub use strategy::*;
use launch::*;
pub use launch::{ReduceParams, reduce_kernel};
#[cfg(feature = "export_tests")]
pub mod test;
use cubecl_core::prelude::*;
/// Reduces `input` along `axis` and writes the result into `output`, using the
/// reduce instruction family `Inst` configured by `inst_config`.
///
/// `output` must have exactly the shape of `input` with dimension `axis` set to `1`.
///
/// When `strategy` is `None`, a default is built with
/// `ReduceStrategy::new::<R>(client, true)`; otherwise the provided strategy is
/// validated against `client` before use.
///
/// # Errors
///
/// - [`ReduceError::InvalidAxis`] if `axis` is not a dimension of `input`.
/// - [`ReduceError::MismatchShape`] if `output` does not have the expected shape.
/// - Any error produced while validating a caller-supplied `strategy`.
/// - [`ReduceError::CubeCountTooLarge`] if the generated launch uses a static cube
///   count that exceeds `R::max_cube_count()` in any dimension.
pub fn reduce<R: Runtime, In: Numeric, Out: Numeric, Inst: ReduceFamily>(
    client: &ComputeClient<R::Server, R::Channel>,
    input: TensorHandleRef<R>,
    output: TensorHandleRef<R>,
    axis: usize,
    strategy: Option<ReduceStrategy>,
    inst_config: Inst::Config,
) -> Result<(), ReduceError> {
    validate_axis(input.shape.len(), axis)?;
    valid_output_shape(input.shape, output.shape, axis)?;

    // Build the default strategy lazily: it is only needed when the caller did
    // not supply one, and constructing it queries the client.
    let strategy = match strategy {
        Some(strategy) => strategy.validate::<R>(client)?,
        None => ReduceStrategy::new::<R>(client, true),
    };

    let config = ReduceConfig::generate::<R, In>(client, &input, &output, axis, &strategy);

    // Only `CubeCount::Static` launches are checked against the runtime's maximum
    // cube count here.
    if let CubeCount::Static(x, y, z) = config.cube_count {
        let (max_x, max_y, max_z) = R::max_cube_count();
        if x > max_x || y > max_y || z > max_z {
            return Err(ReduceError::CubeCountTooLarge);
        }
    }

    launch_reduce::<R, In, Out, Inst>(
        client,
        input,
        output,
        // `axis` was validated against the input rank above, so it is a small index.
        axis as u32,
        config,
        strategy,
        inst_config,
    );
    Ok(())
}
/// Checks that `axis` names a dimension of a tensor with `rank` dimensions.
///
/// Valid axes are `0..rank`. `axis == rank` is rejected because it would index
/// one past the last dimension.
///
/// # Errors
///
/// Returns [`ReduceError::InvalidAxis`] when `axis >= rank`.
fn validate_axis(rank: usize, axis: usize) -> Result<(), ReduceError> {
    if axis < rank {
        Ok(())
    } else {
        Err(ReduceError::InvalidAxis { axis, rank })
    }
}
/// Checks that `output_shape` equals `input_shape` with dimension `axis` set to `1`.
///
/// # Errors
///
/// - [`ReduceError::InvalidAxis`] if `axis` is not an index into `input_shape`.
///   Without this check the function would panic on out-of-bounds indexing.
/// - [`ReduceError::MismatchShape`] if `output_shape` differs from the expected
///   shape. Both the expected and actual shapes are reported.
fn valid_output_shape(
    input_shape: &[usize],
    output_shape: &[usize],
    axis: usize,
) -> Result<(), ReduceError> {
    let mut expected_shape = input_shape.to_vec();
    let Some(reduced_dim) = expected_shape.get_mut(axis) else {
        return Err(ReduceError::InvalidAxis {
            axis,
            rank: input_shape.len(),
        });
    };
    // The reduced dimension collapses to a single element.
    *reduced_dim = 1;

    if output_shape != expected_shape.as_slice() {
        return Err(ReduceError::MismatchShape {
            expected_shape,
            output_shape: output_shape.to_vec(),
        });
    }
    Ok(())
}