use cubecl::{
prelude::*,
std::tensor::{
layout::{Coords1d, Coords2d, Layout, LayoutExpand},
r#virtual::VirtualTensor,
},
};
use crate::components::args::NumericVector;
/// Two-dimensional view over a linear reduce output.
///
/// Coordinates are `(write_index, k_iter)`. The `Layout` impl in this file maps them to the
/// linear source position `k_iter * k_stride + write_index * write_stride`. It reports its
/// shape as `(num_writes, accumulator_length)`.
#[derive(CubeType, Clone)]
pub struct ReduceOutputLayout {
    /// Multiplier applied to the second coordinate (`k_iter`) when computing a source position.
    /// NOTE(review): `build_reduce_output_layout` divides tensor strides by the vector size, so
    /// this is presumably expressed in vectors rather than scalars. Confirm.
    k_stride: usize,
    /// Multiplier applied to the first coordinate (`write_index`) when computing a source position.
    write_stride: usize,
    /// Extent of the first coordinate (`write_index`), used for `shape` and bounds checks.
    num_writes: usize,
    /// Extent of the second coordinate (`k_iter`), used for `shape` and bounds checks.
    accumulator_length: usize,
}
#[cube]
impl ReduceOutputLayout {
    /// Creates a layout from raw strides and extents.
    ///
    /// * `k_stride` – multiplier for the accumulator index (second coordinate).
    /// * `write_stride` – multiplier for the write index (first coordinate).
    /// * `num_writes` – number of valid write indices.
    /// * `accumulator_length` – number of valid accumulator indices.
    ///
    /// All four values are stored verbatim; no validation is performed.
    pub fn new(
        k_stride: usize,
        write_stride: usize,
        num_writes: usize,
        accumulator_length: usize,
    ) -> ReduceOutputLayout {
        ReduceOutputLayout {
            accumulator_length,
            num_writes,
            write_stride,
            k_stride,
        }
    }
}
#[cube]
impl Layout for ReduceOutputLayout {
    type Coordinates = Coords2d;
    type SourceCoordinates = Coords1d;

    /// Maps `(write_index, k_iter)` to the linear position
    /// `k_iter * k_stride + write_index * write_stride`.
    ///
    /// No bounds check is applied here; see `to_source_pos_checked`.
    fn to_source_pos(&self, coords: Self::Coordinates) -> Coords1d {
        let k_iter = coords.1 as usize;
        let write_index = coords.0 as usize;
        let k_offset = k_iter * self.k_stride;
        let write_offset = write_index * self.write_stride;
        k_offset + write_offset
    }

    /// Returns the same position as `to_source_pos`, paired with the result of
    /// `is_in_bounds` for the same coordinates.
    ///
    /// The position is computed even when the coordinates are out of bounds, so callers must
    /// honour the flag.
    fn to_source_pos_checked(&self, coords: Self::Coordinates) -> (Coords1d, bool) {
        let in_bounds = self.is_in_bounds(coords);
        let source_pos = self.to_source_pos(coords);
        (source_pos, in_bounds)
    }

    /// Extent of the view: `(num_writes, accumulator_length)`, narrowed to `u32`.
    fn shape(&self) -> Self::Coordinates {
        let write_extent = self.num_writes as u32;
        let k_extent = self.accumulator_length as u32;
        (write_extent, k_extent)
    }

    /// True when both coordinates are strictly below their respective extents.
    fn is_in_bounds(&self, pos: Self::Coordinates) -> bool {
        let write_in_range = pos.0 < self.num_writes as u32;
        let k_in_range = pos.1 < self.accumulator_length as u32;
        write_in_range && k_in_range
    }
}
/// Builds the `ReduceOutputLayout` used to address `output`.
///
/// * `reduce_axis` – axis whose stride (divided by the vector size) becomes the accumulator
///   stride when `accumulator_length != 1`.
/// * `out_vec_axis` – axis whose extent (divided by the vector size) gives the number of
///   vectored reductions.
/// * `accumulator_length` – comptime extent of the accumulator coordinate.
///
/// With `accumulator_length == 1`, the accumulator index can only be 0. In that case the layout
/// writes `shape(out_vec_axis) / vector_size` consecutive positions with a write stride of 1.
/// Otherwise the write extent and stride depend on whether `reduce_axis` and `out_vec_axis`
/// differ:
/// * if they differ, there are `shape(out_vec_axis) / vector_size` writes with stride 1;
/// * if they are equal, there is a single write with stride 0.
#[cube]
pub(crate) fn build_reduce_output_layout<Out: NumericVector>(
    output: &VirtualTensor<Out::T, Out::N, ReadWrite>,
    reduce_axis: usize,
    out_vec_axis: usize,
    #[comptime] accumulator_length: usize,
) -> ReduceOutputLayout {
    let vector_size = output.vector_size();
    let vectored_len = output.shape(out_vec_axis) / vector_size;

    if comptime![accumulator_length != 1] {
        // 1 when the reduced axis is not the vectorized output axis, otherwise 0. It is used
        // arithmetically (instead of branching) to pick both the write stride and the write count.
        let axes_differ = usize::cast_from(reduce_axis != out_vec_axis);
        let k_stride = output.stride(reduce_axis) / vector_size;
        let num_writes = axes_differ * vectored_len + (1 - axes_differ);
        ReduceOutputLayout::new(k_stride, axes_differ, num_writes, accumulator_length)
    } else {
        // A single accumulator slot: the k stride is never multiplied by a non-zero index.
        ReduceOutputLayout::new(vectored_len, 1, vectored_len, accumulator_length)
    }
}