//! Cumulative scan kernels over a single dimension of a [`FlexTensor`]:
//! sum, product, minimum, and maximum, all computed on contiguous storage.

use alloc::vec;

use burn_backend::Element;
use burn_std::Bytes;
use bytemuck::Pod;
use num_traits::{Bounded, Num};

use crate::{FlexTensor, Layout};
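
/// Cumulative sum of `tensor` along dimension `dim`.
///
/// Scans with `+` starting from `E::zero()`; along a 1-D tensor,
/// `[1, 2, 3]` becomes `[1, 3, 6]`.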
pub fn cumsum<E: Element + Pod + Default + Copy + Num>(
tensor: FlexTensor,
dim: usize,
) -> FlexTensor {
cumulative_op(tensor, dim, E::zero(), |acc, val| acc + val)
}
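
/// Cumulative product of `tensor` along dimension `dim`.
///
/// Scans with `*` starting from `E::one()`; along a 1-D tensor,
/// `[1, 2, 3]` becomes `[1, 2, 6]`.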
pub fn cumprod<E: Element + Pod + Default + Copy + Num>(
tensor: FlexTensor,
dim: usize,
) -> FlexTensor {
cumulative_op(tensor, dim, E::one(), |acc, val| acc * val)
}
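
/// Shared scan kernel: folds `op` along dimension `dim`, writing every partial
/// accumulator to the output. `init` must be the identity of `op` (or a
/// saturating bound for min/max scans).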
fn cumulative_op<E: Element + Pod + Default + Copy, F>(
tensor: FlexTensor,
dim: usize,
init: E,
op: F,
) -> FlexTensor
where
F: Fn(E, E) -> E,
{
    // Work on contiguous storage so the flat indexing below is valid.
    let tensor = tensor.to_contiguous();
    let shape = tensor.layout().shape().clone();
    let ndims = shape.num_dims();
    assert!(dim < ndims, "dim {dim} out of bounds for {ndims} dimensions");

    let data: &[E] = tensor.storage();
    let total_size = shape.num_elements();
    let mut result = vec![E::default(); total_size];

    // Flat index of element (outer, i, j) is `outer * block_size + i * inner_size + j`,
    // where `i` walks the scanned dimension and `j` walks the dimensions after it.
    let dim_size = shape[dim];
    let inner_size: usize = shape[dim + 1..].iter().product();
    let outer_size: usize = shape[..dim].iter().product();
    let block_size = dim_size * inner_size;

    if inner_size == 1 {
        // Innermost dimension: each scan runs over a contiguous slice, so a
        // single scalar accumulator suffices.
        for outer in 0..outer_size {
            let base = outer * dim_size;
            let mut acc = init;
            for i in 0..dim_size {
                acc = op(acc, data[base + i]);
                result[base + i] = acc;
            }
        }
    } else {
        // General case: keep one running accumulator per inner position and
        // advance all of them one step of the scanned dimension at a time.
        let mut acc = vec![init; inner_size];
        for outer in 0..outer_size {
            let base = outer * block_size;
            acc.fill(init);
            for i in 0..dim_size {
                let offset = base + i * inner_size;
                for j in 0..inner_size {
                    acc[j] = op(acc[j], data[offset + j]);
                    result[offset + j] = acc[j];
                }
            }
        }
    }

    let bytes = Bytes::from_elems(result);
    FlexTensor::new(bytes, Layout::contiguous(shape), E::dtype())
}
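
/// Same scan as [`cumulative_op`], but accumulated in `f32`.
///
/// Intended for reduced-precision element types (e.g. f16/bf16): each element
/// is widened with `to_f32`, scanned in `f32`, and narrowed back with `from_f32`.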
fn cumulative_op_half<E: Element + Pod + Default + Copy, F>(
tensor: FlexTensor,
dim: usize,
init: f32,
op: F,
to_f32: fn(E) -> f32,
from_f32: fn(f32) -> E,
) -> FlexTensor
where
F: Fn(f32, f32) -> f32,
{
    let tensor = tensor.to_contiguous();
    let shape = tensor.layout().shape().clone();
    let ndims = shape.num_dims();
    assert!(dim < ndims, "dim {dim} out of bounds for {ndims} dimensions");

    let data: &[E] = tensor.storage();
    let total_size = shape.num_elements();
    let mut result = vec![E::default(); total_size];

    let dim_size = shape[dim];
    let inner_size: usize = shape[dim + 1..].iter().product();
    let outer_size: usize = shape[..dim].iter().product();
    let block_size = dim_size * inner_size;

    // Same index decomposition as `cumulative_op`; the accumulators stay in
    // f32 and only the stored results are narrowed back to `E`.
    let mut acc = vec![init; inner_size];
    for outer in 0..outer_size {
        let base = outer * block_size;
        acc.fill(init);
        for i in 0..dim_size {
            let offset = base + i * inner_size;
            for j in 0..inner_size {
                acc[j] = op(acc[j], to_f32(data[offset + j]));
                result[offset + j] = from_f32(acc[j]);
            }
        }
    }

    let bytes = Bytes::from_elems(result);
    FlexTensor::new(bytes, Layout::contiguous(shape), E::dtype())
}
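
// Concrete-dtype entry points: thin wrappers around the generic kernels,
// presumably selected by the caller based on the tensor's dtype.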
pub fn cumsum_f32(tensor: FlexTensor, dim: usize) -> FlexTensor {
cumsum::<f32>(tensor, dim)
}
pub fn cumsum_f64(tensor: FlexTensor, dim: usize) -> FlexTensor {
cumsum::<f64>(tensor, dim)
}
pub fn cumprod_f32(tensor: FlexTensor, dim: usize) -> FlexTensor {
cumprod::<f32>(tensor, dim)
}
pub fn cumprod_f64(tensor: FlexTensor, dim: usize) -> FlexTensor {
cumprod::<f64>(tensor, dim)
}
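
/// Cumulative minimum along `dim` for totally ordered (integer) element types.
/// The scan starts from `E::max_value()`, the identity of `min`.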
pub fn cummin<E: Element + Pod + Default + Copy + Ord + Bounded>(
tensor: FlexTensor,
dim: usize,
) -> FlexTensor {
cumulative_op(tensor, dim, E::max_value(), Ord::min)
}
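
// Float minimum scans: a NaN input poisons the accumulator, so every later
// output along that scan line is NaN.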
pub fn cummin_f32(tensor: FlexTensor, dim: usize) -> FlexTensor {
cumulative_op(tensor, dim, f32::INFINITY, |acc, val| {
if val.is_nan() || val < acc { val } else { acc }
})
}
pub fn cummin_f64(tensor: FlexTensor, dim: usize) -> FlexTensor {
cumulative_op(tensor, dim, f64::INFINITY, |acc, val| {
if val.is_nan() || val < acc { val } else { acc }
})
}
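
/// Cumulative maximum along `dim` for totally ordered (integer) element types.
/// The scan starts from `E::min_value()`, the identity of `max`.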
pub fn cummax<E: Element + Pod + Default + Copy + Ord + Bounded>(
tensor: FlexTensor,
dim: usize,
) -> FlexTensor {
cumulative_op(tensor, dim, E::min_value(), Ord::max)
}
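
// Float maximum scans: NaN poisons the accumulator, mirroring the minimum scans above.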
pub fn cummax_f32(tensor: FlexTensor, dim: usize) -> FlexTensor {
cumulative_op(tensor, dim, f32::NEG_INFINITY, |acc, val| {
if val.is_nan() || val > acc { val } else { acc }
})
}
pub fn cummax_f64(tensor: FlexTensor, dim: usize) -> FlexTensor {
cumulative_op(tensor, dim, f64::NEG_INFINITY, |acc, val| {
if val.is_nan() || val > acc { val } else { acc }
})
}
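
// Reduced-precision (`_half`) variants: elements are widened to f32 for the scan
// and narrowed back on store, via the caller-supplied conversion functions.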
pub fn cumsum_half<E: Element + Pod + Default + Copy>(
tensor: FlexTensor,
dim: usize,
to_f32: fn(E) -> f32,
from_f32: fn(f32) -> E,
) -> FlexTensor {
cumulative_op_half(tensor, dim, 0.0, |acc, val| acc + val, to_f32, from_f32)
}
pub fn cumprod_half<E: Element + Pod + Default + Copy>(
tensor: FlexTensor,
dim: usize,
to_f32: fn(E) -> f32,
from_f32: fn(f32) -> E,
) -> FlexTensor {
cumulative_op_half(tensor, dim, 1.0, |acc, val| acc * val, to_f32, from_f32)
}
pub fn cummin_half<E: Element + Pod + Default + Copy>(
tensor: FlexTensor,
dim: usize,
to_f32: fn(E) -> f32,
from_f32: fn(f32) -> E,
) -> FlexTensor {
cumulative_op_half(
tensor,
dim,
f32::INFINITY,
|acc, val| if val.is_nan() || val < acc { val } else { acc },
to_f32,
from_f32,
)
}
pub fn cummax_half<E: Element + Pod + Default + Copy>(
tensor: FlexTensor,
dim: usize,
to_f32: fn(E) -> f32,
from_f32: fn(f32) -> E,
) -> FlexTensor {
cumulative_op_half(
tensor,
dim,
f32::NEG_INFINITY,
|acc, val| if val.is_nan() || val > acc { val } else { acc },
to_f32,
from_f32,
)
}
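
// What follows is an illustrative sketch, not part of the kernel API: two small
// self-contained tests that exercise the same index decomposition and the NaN
// handling used above on plain arrays, without constructing a `FlexTensor`.
// The module name, the fixed shapes, and the inputs are arbitrary choices for
// the example.
#[cfg(test)]
mod cumulative_scan_sketch {
    #[test]
    fn cumsum_along_middle_dim_matches_manual_result() {
        // Shape [2, 2, 3], scanning dim = 1, data = 1..=12 in row-major order.
        let shape = [2usize, 2, 3];
        let dim = 1;
        let data: [f32; 12] = core::array::from_fn(|k| (k + 1) as f32);

        let inner: usize = shape[dim + 1..].iter().product();
        let outer: usize = shape[..dim].iter().product();
        let block = shape[dim] * inner;

        let mut result = [0.0f32; 12];
        for o in 0..outer {
            let base = o * block;
            // One running accumulator per inner position (inner == 3 here).
            let mut acc = [0.0f32; 3];
            for i in 0..shape[dim] {
                let offset = base + i * inner;
                for j in 0..inner {
                    acc[j] += data[offset + j];
                    result[offset + j] = acc[j];
                }
            }
        }

        assert_eq!(result, [1., 2., 3., 5., 7., 9., 7., 8., 9., 17., 19., 21.]);
    }

    #[test]
    fn min_scan_propagates_nan() {
        // Mirrors the closure used by `cummin_f32`: once a NaN is seen, the
        // accumulator stays NaN for the remainder of the scan.
        let data = [3.0f32, f32::NAN, 1.0];
        let mut acc = f32::INFINITY;
        let scanned: [f32; 3] = core::array::from_fn(|k| {
            let val = data[k];
            acc = if val.is_nan() || val < acc { val } else { acc };
            acc
        });
        assert!(scanned[0] == 3.0 && scanned[1].is_nan() && scanned[2].is_nan());
    }
}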