use super::super::kernels;
use super::super::{CpuClient, CpuRuntime};
use crate::dispatch_dtype;
use crate::error::{Error, Result};
use crate::ops::reduce_output_shape;
use crate::runtime::ensure_contiguous;
use crate::tensor::Tensor;
#[inline]
fn normalize_dim(ndim: usize, dim: isize) -> Option<usize> {
    // Resolve a possibly-negative dimension index into `0..ndim`.
    // Negative values count back from the end (Python-style); anything
    // out of range in either direction yields `None`.
    let resolved = if dim < 0 { dim + ndim as isize } else { dim };
    (0..ndim as isize)
        .contains(&resolved)
        .then(|| resolved as usize)
}
pub fn cumsum_impl(
client: &CpuClient,
a: &Tensor<CpuRuntime>,
dim: isize,
) -> Result<Tensor<CpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
let dim_idx = normalize_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?;
if a.numel() == 0 {
return Ok(Tensor::<CpuRuntime>::empty(shape, dtype, &client.device));
}
let a_contig = ensure_contiguous(a);
let out = Tensor::<CpuRuntime>::empty(shape, dtype, &client.device);
let scan_size = shape[dim_idx];
let outer_size: usize = shape[..dim_idx].iter().product();
let outer_size = outer_size.max(1);
let inner_size: usize = shape[dim_idx + 1..].iter().product();
let inner_size = inner_size.max(1);
let a_ptr = a_contig.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
if inner_size == 1 {
kernels::cumsum_kernel(
a_ptr as *const T,
out_ptr as *mut T,
scan_size,
outer_size,
);
} else {
kernels::cumsum_strided_kernel(
a_ptr as *const T,
out_ptr as *mut T,
scan_size,
outer_size,
inner_size,
);
}
}
}, "cumsum");
Ok(out)
}
pub fn cumprod_impl(
client: &CpuClient,
a: &Tensor<CpuRuntime>,
dim: isize,
) -> Result<Tensor<CpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
let dim_idx = normalize_dim(ndim, dim).ok_or(Error::InvalidDimension { dim, ndim })?;
if a.numel() == 0 {
return Ok(Tensor::<CpuRuntime>::empty(shape, dtype, &client.device));
}
let a_contig = ensure_contiguous(a);
let out = Tensor::<CpuRuntime>::empty(shape, dtype, &client.device);
let scan_size = shape[dim_idx];
let outer_size: usize = shape[..dim_idx].iter().product();
let outer_size = outer_size.max(1);
let inner_size: usize = shape[dim_idx + 1..].iter().product();
let inner_size = inner_size.max(1);
let a_ptr = a_contig.ptr();
let out_ptr = out.ptr();
dispatch_dtype!(dtype, T => {
unsafe {
if inner_size == 1 {
kernels::cumprod_kernel(
a_ptr as *const T,
out_ptr as *mut T,
scan_size,
outer_size,
);
} else {
kernels::cumprod_strided_kernel(
a_ptr as *const T,
out_ptr as *mut T,
scan_size,
outer_size,
inner_size,
);
}
}
}, "cumprod");
Ok(out)
}
/// Log-sum-exp reduction of `a` over the given dimensions.
///
/// Multi-dimension reductions are performed one axis at a time; with
/// `dims` empty the input is returned unchanged (as a clone).
///
/// # Errors
/// `UnsupportedDType` for non-float tensors; `InvalidDimension` if any
/// entry of `dims` is out of range.
pub fn logsumexp_impl(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    dims: &[usize],
    keepdim: bool,
) -> Result<Tensor<CpuRuntime>> {
    let dtype = a.dtype();
    let shape = a.shape();
    let ndim = shape.len();
    if !dtype.is_float() {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "logsumexp",
        });
    }
    for &d in dims {
        if d >= ndim {
            return Err(Error::InvalidDimension {
                dim: d as isize,
                ndim,
            });
        }
    }
    // Fast path: a single reduction over the last axis of a contiguous
    // tensor maps directly onto the dense kernel with no layout work.
    if dims.len() == 1 && dims[0] == ndim - 1 && a.is_contiguous() {
        let reduce_size = shape[ndim - 1];
        let outer_size: usize = shape[..ndim - 1].iter().product();
        let outer_size = outer_size.max(1);
        let out_shape = reduce_output_shape(shape, dims, keepdim);
        let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);
        let a_ptr = a.ptr();
        let out_ptr = out.ptr();
        dispatch_dtype!(dtype, T => {
            // SAFETY: `a` is contiguous with outer_size * reduce_size
            // elements of T; `out` was allocated for the reduced shape.
            unsafe {
                kernels::logsumexp_kernel(
                    a_ptr as *const T,
                    out_ptr as *mut T,
                    reduce_size,
                    outer_size,
                );
            }
        }, "logsumexp");
        return Ok(out);
    }
    // No dims requested: the reduction is the identity.
    if dims.is_empty() {
        return Ok(a.clone());
    }
    let a_contig = ensure_contiguous(a);
    // Reduce the highest dimension first so that the remaining dim
    // indices stay valid as axes are dropped (when keepdim is false).
    let mut sorted_dims: Vec<usize> = dims.to_vec();
    sorted_dims.sort_unstable();
    sorted_dims.reverse();
    let mut current = a_contig;
    for &dim in &sorted_dims {
        current = logsumexp_single_dim(client, &current, dim, keepdim)?;
    }
    Ok(current)
}
/// Log-sum-exp of `a` over a single non-negative dimension; the
/// per-axis step used by `logsumexp_impl`. Callers pass a contiguous
/// tensor (via `ensure_contiguous`).
///
/// # Errors
/// `InvalidDimension` when `dim >= a.shape().len()`.
fn logsumexp_single_dim(
    client: &CpuClient,
    a: &Tensor<CpuRuntime>,
    dim: usize,
    keepdim: bool,
) -> Result<Tensor<CpuRuntime>> {
    let dtype = a.dtype();
    let shape = a.shape();
    let ndim = shape.len();
    if dim >= ndim {
        return Err(Error::InvalidDimension {
            dim: dim as isize,
            ndim,
        });
    }
    // View the shape as outer × reduce × inner around `dim`;
    // empty prefix/suffix products collapse to 1.
    let reduce_size = shape[dim];
    let outer_size: usize = shape[..dim].iter().product();
    let outer_size = outer_size.max(1);
    let inner_size: usize = shape[dim + 1..].iter().product();
    let inner_size = inner_size.max(1);
    // Output shape: the reduced axis becomes 1 (keepdim) or is dropped.
    let out_shape: Vec<usize> = if keepdim {
        shape
            .iter()
            .enumerate()
            .map(|(i, &s)| if i == dim { 1 } else { s })
            .collect()
    } else {
        shape
            .iter()
            .enumerate()
            .filter_map(|(i, &s)| if i != dim { Some(s) } else { None })
            .collect()
    };
    let out = Tensor::<CpuRuntime>::empty(&out_shape, dtype, &client.device);
    let a_ptr = a.ptr();
    let out_ptr = out.ptr();
    dispatch_dtype!(dtype, T => {
        // SAFETY: `a` holds outer_size * reduce_size * inner_size elements
        // of T and `out` holds outer_size * inner_size, matching the sizes
        // passed to the kernels.
        unsafe {
            if inner_size == 1 {
                // Reducing the innermost axis: elements are adjacent.
                kernels::logsumexp_kernel(
                    a_ptr as *const T,
                    out_ptr as *mut T,
                    reduce_size,
                    outer_size,
                );
            } else {
                // Fixed: `inner_size` was previously passed three times;
                // the strided kernel takes it once (as in the cumsum /
                // cumprod strided calls above).
                kernels::logsumexp_strided_kernel(
                    a_ptr as *const T,
                    out_ptr as *mut T,
                    reduce_size,
                    outer_size,
                    inner_size,
                );
            }
        }
    }, "logsumexp");
    Ok(out)
}