use crate::{Result, Shape, Tensor, TensorError};
/// Reduction operations supported by the GPU and CPU reduction paths.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ReductionOp {
    /// Sum of all elements.
    Sum,
    /// Product of all elements.
    Prod,
    /// Maximum element.
    Max,
    /// Minimum element.
    Min,
    /// Arithmetic mean.
    Mean,
    /// Variance.
    Variance,
    /// Standard deviation.
    StdDev,
    /// Sum of absolute values.
    L1Norm,
    /// Euclidean norm.
    L2Norm,
    /// Logical "any" (truthy if any element is truthy).
    Any,
    /// Logical "all" (truthy only if every element is truthy).
    All,
}

impl ReductionOp {
    /// Short lowercase identifier for this operation (kernel/debug naming).
    pub fn name(&self) -> &'static str {
        match self {
            Self::Sum => "sum",
            Self::Prod => "prod",
            Self::Max => "max",
            Self::Min => "min",
            Self::Mean => "mean",
            Self::Variance => "var",
            Self::StdDev => "std",
            Self::L1Norm => "l1_norm",
            Self::L2Norm => "l2_norm",
            Self::Any => "any",
            Self::All => "all",
        }
    }

    /// Neutral starting value for an `f32` accumulator of this reduction.
    pub fn identity_f32(&self) -> f32 {
        match self {
            // Multiplicative-style ops start at 1.
            Self::Prod | Self::All => 1.0,
            Self::Max => f32::NEG_INFINITY,
            Self::Min => f32::INFINITY,
            // Everything additive (and Any) starts at 0.
            Self::Sum
            | Self::Mean
            | Self::L1Norm
            | Self::L2Norm
            | Self::Any
            | Self::Variance
            | Self::StdDev => 0.0,
        }
    }

    /// Whether the op is implemented as a two-pass reduction here.
    /// NOTE(review): `Mean` is included even though it is commonly single-pass;
    /// this mirrors the original classification — confirm against the kernels.
    pub fn requires_two_passes(&self) -> bool {
        matches!(self, Self::Mean | Self::Variance | Self::StdDev)
    }
}
/// Tuning parameters for GPU reduction kernels.
#[derive(Debug, Clone)]
pub struct GpuReductionConfig {
    /// Elements handled per thread block.
    pub block_size: usize,
    /// Enable shared-memory staging of partial results.
    pub use_shared_memory: bool,
    /// Enable warp-level reduction primitives.
    pub use_warp_primitives: bool,
    /// Hard upper bound on the workgroup size a kernel may request.
    pub max_workgroup_size: usize,
}

impl Default for GpuReductionConfig {
    /// Conservative defaults: 256-element blocks, all fast paths enabled,
    /// workgroups capped at 1024 threads.
    fn default() -> Self {
        GpuReductionConfig {
            max_workgroup_size: 1024,
            use_warp_primitives: true,
            use_shared_memory: true,
            block_size: 256,
        }
    }
}
/// Reduces `tensor` along `axis` with `op`, using the GPU when available.
///
/// Falls back to the CPU implementation when the tensor is not on a GPU
/// device, and for the L1/L2 norm ops, which are computed on the CPU.
///
/// # Errors
/// Returns `TensorError::InvalidAxis` if `axis` is out of range for the
/// tensor's rank; propagates any error from the GPU execution path.
#[cfg(feature = "gpu")]
pub fn gpu_reduce_axis<T>(
    tensor: &Tensor<T>,
    axis: usize,
    op: ReductionOp,
    keep_dims: bool,
) -> Result<Tensor<T>>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    let shape = tensor.shape();
    if axis >= shape.rank() {
        return Err(TensorError::InvalidAxis {
            operation: "reduce".to_string(),
            axis: axis as i32,
            ndim: shape.rank(),
            context: None,
        });
    }
    if !tensor.device().is_gpu() {
        return super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims);
    }
    // The norm ops are serviced by the CPU implementation. Dispatch them
    // BEFORE launching a GPU reduction: the previous version executed the
    // GPU kernel first and then discarded its result for these two ops.
    if matches!(op, ReductionOp::L1Norm | ReductionOp::L2Norm) {
        return super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims);
    }
    let gpu_op = super::gpu_execution::GpuReductionOp::from(op);
    let result = super::gpu_execution::execute_gpu_reduction(tensor, axis, gpu_op, keep_dims)?;
    match op {
        // The GPU path presumably yields the variance for StdDev; finish
        // with an element-wise square root. TODO(review): confirm that
        // GpuReductionOp::from(StdDev) maps to a variance kernel.
        ReductionOp::StdDev => result.sqrt(),
        _ => Ok(result),
    }
}
/// Reduces every axis of `tensor` with `op`, producing a single scalar.
///
/// Uses the GPU path (repeatedly collapsing axis 0) when the tensor lives
/// on a GPU device; otherwise delegates to the CPU implementation.
///
/// # Errors
/// Propagates errors from the per-axis reduction, and reports an invalid
/// operation if the final reduction unexpectedly yields no data.
#[cfg(feature = "gpu")]
pub fn gpu_reduce_all<T>(tensor: &Tensor<T>, op: ReductionOp) -> Result<T>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    if !tensor.device().is_gpu() {
        return super::statistical::reduce_all_cpu(tensor, op);
    }
    // Collapse the leading axis once per original dimension; with
    // keep_dims == false each pass lowers the rank by one.
    let rank = tensor.shape().rank();
    let mut reduced = tensor.clone();
    for _ in 0..rank {
        reduced = gpu_reduce_axis(&reduced, 0, op, false)?;
    }
    let values = reduced.data();
    if values.is_empty() {
        return Err(TensorError::invalid_operation_simple(
            "Reduction produced empty result".to_string(),
        ));
    }
    Ok(values[0])
}
/// Sums `tensor` along `axis`, preferring the GPU when the tensor lives on
/// a GPU device and the `gpu` feature is enabled; otherwise reduces on CPU.
pub fn gpu_sum_axis<T>(tensor: &Tensor<T>, axis: usize, keep_dims: bool) -> Result<Tensor<T>>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    let op = ReductionOp::Sum;
    #[cfg(feature = "gpu")]
    {
        if tensor.device().is_gpu() {
            return gpu_reduce_axis(tensor, axis, op, keep_dims);
        }
    }
    // CPU fallback; also the only path when the `gpu` feature is off.
    super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims)
}
/// Computes the mean of `tensor` along `axis`, preferring the GPU when the
/// tensor lives on a GPU device and the `gpu` feature is enabled.
pub fn gpu_mean_axis<T>(tensor: &Tensor<T>, axis: usize, keep_dims: bool) -> Result<Tensor<T>>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    let op = ReductionOp::Mean;
    #[cfg(feature = "gpu")]
    {
        if tensor.device().is_gpu() {
            return gpu_reduce_axis(tensor, axis, op, keep_dims);
        }
    }
    // CPU fallback; also the only path when the `gpu` feature is off.
    super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims)
}
/// Computes the maximum of `tensor` along `axis`, preferring the GPU when
/// the tensor lives on a GPU device and the `gpu` feature is enabled.
pub fn gpu_max_axis<T>(tensor: &Tensor<T>, axis: usize, keep_dims: bool) -> Result<Tensor<T>>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    let op = ReductionOp::Max;
    #[cfg(feature = "gpu")]
    {
        if tensor.device().is_gpu() {
            return gpu_reduce_axis(tensor, axis, op, keep_dims);
        }
    }
    // CPU fallback; also the only path when the `gpu` feature is off.
    super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims)
}
/// Computes the minimum of `tensor` along `axis`, preferring the GPU when
/// the tensor lives on a GPU device and the `gpu` feature is enabled.
pub fn gpu_min_axis<T>(tensor: &Tensor<T>, axis: usize, keep_dims: bool) -> Result<Tensor<T>>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    let op = ReductionOp::Min;
    #[cfg(feature = "gpu")]
    {
        if tensor.device().is_gpu() {
            return gpu_reduce_axis(tensor, axis, op, keep_dims);
        }
    }
    // CPU fallback; also the only path when the `gpu` feature is off.
    super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims)
}
/// Computes the variance of `tensor` along `axis`, preferring the GPU when
/// the tensor lives on a GPU device and the `gpu` feature is enabled.
pub fn gpu_var_axis<T>(tensor: &Tensor<T>, axis: usize, keep_dims: bool) -> Result<Tensor<T>>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    let op = ReductionOp::Variance;
    #[cfg(feature = "gpu")]
    {
        if tensor.device().is_gpu() {
            return gpu_reduce_axis(tensor, axis, op, keep_dims);
        }
    }
    // CPU fallback; also the only path when the `gpu` feature is off.
    super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims)
}
/// Computes the standard deviation of `tensor` along `axis`, preferring the
/// GPU when the tensor lives on a GPU device and the `gpu` feature is enabled.
pub fn gpu_std_axis<T>(tensor: &Tensor<T>, axis: usize, keep_dims: bool) -> Result<Tensor<T>>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    let op = ReductionOp::StdDev;
    #[cfg(feature = "gpu")]
    {
        if tensor.device().is_gpu() {
            return gpu_reduce_axis(tensor, axis, op, keep_dims);
        }
    }
    // CPU fallback; also the only path when the `gpu` feature is off.
    super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims)
}
/// Computes the L1 norm of `tensor` along `axis`, preferring the GPU entry
/// point when the tensor lives on a GPU device and the `gpu` feature is
/// enabled (the norm itself may still be serviced by the CPU path).
pub fn gpu_l1_norm_axis<T>(tensor: &Tensor<T>, axis: usize, keep_dims: bool) -> Result<Tensor<T>>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    let op = ReductionOp::L1Norm;
    #[cfg(feature = "gpu")]
    {
        if tensor.device().is_gpu() {
            return gpu_reduce_axis(tensor, axis, op, keep_dims);
        }
    }
    // CPU fallback; also the only path when the `gpu` feature is off.
    super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims)
}
/// Computes the L2 norm of `tensor` along `axis`, preferring the GPU entry
/// point when the tensor lives on a GPU device and the `gpu` feature is
/// enabled (the norm itself may still be serviced by the CPU path).
pub fn gpu_l2_norm_axis<T>(tensor: &Tensor<T>, axis: usize, keep_dims: bool) -> Result<Tensor<T>>
where
    T: scirs2_core::num_traits::Float
        + Default
        + bytemuck::Pod
        + Send
        + Sync
        + 'static
        + scirs2_core::num_traits::FromPrimitive
        + scirs2_core::num_traits::ops::mul_add::MulAdd
        + scirs2_core::ndarray::ScalarOperand,
{
    let op = ReductionOp::L2Norm;
    #[cfg(feature = "gpu")]
    {
        if tensor.device().is_gpu() {
            return gpu_reduce_axis(tensor, axis, op, keep_dims);
        }
    }
    // CPU fallback; also the only path when the `gpu` feature is off.
    super::statistical::reduce_axis_cpu(tensor, axis, op, keep_dims)
}
/// Plans a tree (pairwise-halving) reduction on the GPU.
#[cfg(feature = "gpu")]
pub struct TreeReduction {
    config: GpuReductionConfig,
}

#[cfg(feature = "gpu")]
impl TreeReduction {
    /// Creates a planner from the given kernel configuration.
    pub fn new(config: GpuReductionConfig) -> Self {
        Self { config }
    }

    /// Number of halving steps needed to reduce `input_size` elements to
    /// one, i.e. `ceil(log2(input_size))`; inputs of 0 or 1 need no steps.
    ///
    /// Computed with exact integer arithmetic rather than the previous
    /// `f64::log2().ceil()` round trip, which is subject to floating-point
    /// rounding for very large sizes and relied on `-inf as usize`
    /// saturation for an input of 0.
    pub fn num_steps(&self, input_size: usize) -> usize {
        if input_size <= 1 {
            return 0;
        }
        // ceil(log2(n)) equals the bit width of (n - 1) for n >= 2.
        (usize::BITS - (input_size - 1).leading_zeros()) as usize
    }

    /// Threads per workgroup: the configured block size, clamped by both
    /// the input size and the configured workgroup-size limit.
    pub fn workgroup_size(&self, input_size: usize) -> usize {
        self.config
            .block_size
            .min(input_size)
            .min(self.config.max_workgroup_size)
    }
}
/// Describes warp-level reduction capability, assuming 32-lane warps.
#[cfg(feature = "gpu")]
pub struct WarpReduction {
    // Lanes per warp; fixed at 32 by `new`.
    warp_size: usize,
}

#[cfg(feature = "gpu")]
impl WarpReduction {
    /// Creates a descriptor with the common 32-lane warp width.
    pub fn new() -> Self {
        WarpReduction { warp_size: 32 }
    }

    /// Lanes per warp.
    pub fn warp_size(&self) -> usize {
        self.warp_size
    }

    /// True when `size` fills at least one whole warp, so warp-level
    /// primitives are applicable.
    pub fn can_use_warp_primitives(&self, size: usize) -> bool {
        self.warp_size <= size
    }
}

#[cfg(feature = "gpu")]
impl Default for WarpReduction {
    fn default() -> Self {
        Self::new()
    }
}
/// Plan for reducing a large buffer in fixed-size stages.
pub struct MultiStageReduction {
    /// Maximum number of elements processed per stage.
    pub stage_size: usize,
    /// Total number of stages needed to cover the input.
    pub num_stages: usize,
}

impl MultiStageReduction {
    /// Splits `total_elements` into `ceil(total_elements / max_stage_size)`
    /// stages of at most `max_stage_size` elements each.
    ///
    /// # Panics
    /// Panics (division by zero) if `max_stage_size` is 0.
    pub fn new(total_elements: usize, max_stage_size: usize) -> Self {
        // Ceiling division.
        let num_stages = (total_elements + max_stage_size - 1) / max_stage_size;
        Self {
            stage_size: max_stage_size,
            num_stages,
        }
    }

    /// Element count budgeted for `stage`.
    ///
    /// NOTE(review): the original if/else returned `stage_size` from BOTH
    /// branches, so the dead branch has been collapsed. The total element
    /// count is not stored on this struct, so the true (possibly smaller)
    /// size of the final stage cannot be computed here — callers must clamp
    /// the last stage themselves.
    pub fn stage_size_for(&self, stage: usize) -> usize {
        let _ = stage; // retained for interface stability
        self.stage_size
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Identity elements for the basic accumulating ops.
    #[test]
    fn test_reduction_op_identity() {
        assert_eq!(ReductionOp::Sum.identity_f32(), 0.0);
        assert_eq!(ReductionOp::Prod.identity_f32(), 1.0);
        assert_eq!(ReductionOp::Max.identity_f32(), f32::NEG_INFINITY);
        assert_eq!(ReductionOp::Min.identity_f32(), f32::INFINITY);
    }

    // Statistical ops are classified as two-pass; plain sums are not.
    #[test]
    fn test_reduction_op_two_passes() {
        assert!(ReductionOp::Variance.requires_two_passes());
        assert!(ReductionOp::StdDev.requires_two_passes());
        assert!(!ReductionOp::Sum.requires_two_passes());
    }

    // 10M elements in 1M-element stages -> exactly 10 stages.
    #[test]
    fn test_multi_stage_reduction() {
        let reduction = MultiStageReduction::new(10_000_000, 1_000_000);
        assert_eq!(reduction.num_stages, 10);
        assert_eq!(reduction.stage_size, 1_000_000);
    }

    // ceil(log2(n)) step counts for power-of-two inputs.
    #[cfg(feature = "gpu")]
    #[test]
    fn test_tree_reduction() {
        let config = GpuReductionConfig::default();
        let tree = TreeReduction::new(config);
        assert_eq!(tree.num_steps(1024), 10);
        assert_eq!(tree.num_steps(256), 8);
    }

    // Warp primitives apply only at or above one full warp (32 lanes).
    #[cfg(feature = "gpu")]
    #[test]
    fn test_warp_reduction() {
        let warp = WarpReduction::new();
        assert_eq!(warp.warp_size(), 32);
        assert!(warp.can_use_warp_primitives(64));
        assert!(warp.can_use_warp_primitives(32));
        assert!(!warp.can_use_warp_primitives(16));
    }
}