use super::helpers::*;
use crate::error::{Error, Result};
use crate::runtime::ensure_contiguous;
use crate::runtime::wgpu::shaders::cumulative;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::tensor::Tensor;
/// Cumulative sum of `a` along `dim`. A negative `dim` counts from the last axis.
pub(crate) fn native_cumsum(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    dim: isize,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = a.dtype();
    let shape = a.shape();
    let ndim = shape.len();
    if ndim == 0 {
        return Err(Error::InvalidDimension { dim, ndim });
    }
    // Normalize a negative dim before casting, so an out-of-range negative dim
    // is reported as given rather than as a wrapped-around usize.
    let norm = if dim < 0 { ndim as isize + dim } else { dim };
    if norm < 0 || norm as usize >= ndim {
        return Err(Error::InvalidDimension { dim, ndim });
    }
    let dim = norm as usize;
    let a_contig = ensure_contiguous(a);
    // Collapse the shape around `dim`: dims before it form `outer_size`, dims
    // after it form `inner_size`, and `scan_size` is the scanned extent.
    let scan_size = shape[dim];
    let outer_size: usize = shape[..dim].iter().product();
    let inner_size: usize = shape[dim + 1..].iter().product();
    let out = alloc_output(client, shape, dtype);
    let a_buf = get_tensor_buffer(&a_contig)?;
    let out_buf = get_tensor_buffer(&out)?;
    // Fast path: the scanned elements are adjacent in memory, one scan line per
    // outer row; otherwise fall back to the strided kernel.
    if dim == ndim - 1 || inner_size == 1 {
        let params = CumsumParams {
            scan_size: scan_size as u32,
            outer_size: outer_size.max(1) as u32,
        };
        let params_buf = create_params_buffer(client, &params);
        cumulative::launch_cumsum(
            client.pipeline_cache(),
            client.wgpu_queue(),
            &a_buf,
            &out_buf,
            &params_buf,
            outer_size.max(1),
            dtype,
        )?;
    } else {
        let params = CumsumStridedParams {
            scan_size: scan_size as u32,
            outer_size: outer_size.max(1) as u32,
            inner_size: inner_size as u32,
        };
        let params_buf = create_params_buffer(client, &params);
        let total_inner = outer_size.max(1) * inner_size;
        cumulative::launch_cumsum_strided(
            client.pipeline_cache(),
            client.wgpu_queue(),
            &a_buf,
            &out_buf,
            &params_buf,
            total_inner,
            dtype,
        )?;
    }
    Ok(out)
}
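
// A minimal CPU sketch of the contiguous scan the `launch_cumsum` kernel is
// expected to perform: `outer_size` independent rows of length `scan_size`,
// each scanned left to right over a flat buffer. Illustrative reference under
// that assumption only; the WGSL shader in `cumulative` is authoritative.
#[cfg(test)]
mod cumsum_reference_sketch {
    fn reference_cumsum(data: &[f32], outer_size: usize, scan_size: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; data.len()];
        for row in 0..outer_size {
            // Each row starts a fresh running sum.
            let base = row * scan_size;
            let mut acc = 0.0f32;
            for i in 0..scan_size {
                acc += data[base + i];
                out[base + i] = acc;
            }
        }
        out
    }

    #[test]
    fn scans_each_row_independently() {
        // Shape [2, 3], dim = 1: two rows scanned separately.
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        assert_eq!(
            reference_cumsum(&data, 2, 3),
            vec![1.0, 3.0, 6.0, 4.0, 9.0, 15.0]
        );
    }
}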
/// Cumulative product of `a` along `dim`. A negative `dim` counts from the last
/// axis; the shape decomposition and kernel dispatch mirror `native_cumsum`.
pub(crate) fn native_cumprod(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    dim: isize,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = a.dtype();
    let shape = a.shape();
    let ndim = shape.len();
    if ndim == 0 {
        return Err(Error::InvalidDimension { dim, ndim });
    }
    // Same normalization as `native_cumsum`: validate before casting to usize.
    let norm = if dim < 0 { ndim as isize + dim } else { dim };
    if norm < 0 || norm as usize >= ndim {
        return Err(Error::InvalidDimension { dim, ndim });
    }
    let dim = norm as usize;
    let a_contig = ensure_contiguous(a);
    let scan_size = shape[dim];
    let outer_size: usize = shape[..dim].iter().product();
    let inner_size: usize = shape[dim + 1..].iter().product();
    let out = alloc_output(client, shape, dtype);
    let a_buf = get_tensor_buffer(&a_contig)?;
    let out_buf = get_tensor_buffer(&out)?;
    if dim == ndim - 1 || inner_size == 1 {
        let params = CumprodParams {
            scan_size: scan_size as u32,
            outer_size: outer_size.max(1) as u32,
        };
        let params_buf = create_params_buffer(client, &params);
        cumulative::launch_cumprod(
            client.pipeline_cache(),
            client.wgpu_queue(),
            &a_buf,
            &out_buf,
            &params_buf,
            outer_size.max(1),
            dtype,
        )?;
    } else {
        let params = CumprodStridedParams {
            scan_size: scan_size as u32,
            outer_size: outer_size.max(1) as u32,
            inner_size: inner_size as u32,
        };
        let params_buf = create_params_buffer(client, &params);
        let total_inner = outer_size.max(1) * inner_size;
        cumulative::launch_cumprod_strided(
            client.pipeline_cache(),
            client.wgpu_queue(),
            &a_buf,
            &out_buf,
            &params_buf,
            total_inner,
            dtype,
        )?;
    }
    Ok(out)
}
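
// A CPU sketch of the strided layout the `*_strided` kernels are assumed to
// walk: element (o, k, i) of a contiguous tensor lives at flat offset
// o * scan_size * inner_size + k * inner_size + i, so each of the
// outer_size * inner_size scan lines steps through memory with stride
// inner_size. Illustrative only; the WGSL shader is authoritative.
#[cfg(test)]
mod cumprod_strided_sketch {
    fn reference_cumprod_strided(data: &[f32], outer: usize, scan: usize, inner: usize) -> Vec<f32> {
        let mut out = vec![0.0f32; data.len()];
        for o in 0..outer {
            for i in 0..inner {
                // Base offset of this scan line; successive scan elements
                // are `inner` apart in the flat buffer.
                let base = o * scan * inner + i;
                let mut acc = 1.0f32;
                for k in 0..scan {
                    let idx = base + k * inner;
                    acc *= data[idx];
                    out[idx] = acc;
                }
            }
        }
        out
    }

    #[test]
    fn scans_along_a_middle_dim() {
        // Shape [2, 2, 2], dim = 1: outer = 2, scan = 2, inner = 2.
        let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        assert_eq!(
            reference_cumprod_strided(&data, 2, 2, 2),
            vec![1.0, 2.0, 3.0, 8.0, 5.0, 6.0, 35.0, 48.0]
        );
    }
}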
/// log(sum(exp(a))) reduced over `dims`. An empty `dims` is a no-op; multiple
/// dims are reduced one at a time, highest dim first.
pub(crate) fn native_logsumexp(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    dims: &[usize],
    keepdim: bool,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = a.dtype();
    let shape = a.shape();
    let ndim = shape.len();
    if !dtype.is_float() {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "logsumexp",
        });
    }
    if dims.is_empty() {
        return Ok(a.clone());
    }
    if dims.len() > 1 {
        // Reduce one dim at a time with keepdim so the remaining dim indices
        // stay valid; go highest dim first and drop duplicates, which would
        // otherwise be reduced twice and removed twice from the shape below.
        let mut sorted_dims = dims.to_vec();
        sorted_dims.sort_by(|a, b| b.cmp(a));
        sorted_dims.dedup();
        let mut result = a.clone();
        for &dim in &sorted_dims {
            result = native_logsumexp_single_dim(client, &result, dim, true)?;
        }
        if !keepdim {
            let mut out_shape: Vec<usize> = shape.to_vec();
            for &dim in &sorted_dims {
                out_shape.remove(dim);
            }
            if out_shape.is_empty() {
                out_shape.push(1);
            }
            result = result.reshape(&out_shape)?;
        }
        return Ok(result);
    }
    let dim = dims[0];
    if dim >= ndim {
        return Err(Error::InvalidDimension {
            dim: dim as isize,
            ndim,
        });
    }
    native_logsumexp_single_dim(client, a, dim, keepdim)
}
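
// Why the multi-dim path above sorts dims in descending order before calling
// `Vec::remove`: removing a higher index first leaves the lower indices
// untouched, so each removal still points at the dim it was meant for. A small
// shape-only illustration with a hypothetical shape, independent of the GPU path.
#[cfg(test)]
mod dim_removal_order_sketch {
    #[test]
    fn descending_removal_keeps_indices_stable() {
        let mut shape = vec![2usize, 3, 4];
        let mut dims = vec![0usize, 2];
        dims.sort_by(|a, b| b.cmp(a)); // [2, 0]
        for &d in &dims {
            shape.remove(d);
        }
        // Reducing dims 0 and 2 of [2, 3, 4] leaves [3]; ascending removal
        // would have shifted dim 2 out of range after removing dim 0.
        assert_eq!(shape, vec![3]);
    }
}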
/// logsumexp over a single, already-normalized (non-negative) `dim`.
fn native_logsumexp_single_dim(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    dim: usize,
    keepdim: bool,
) -> Result<Tensor<WgpuRuntime>> {
    let dtype = a.dtype();
    let shape = a.shape();
    let ndim = shape.len();
    if dim >= ndim {
        return Err(Error::InvalidDimension {
            dim: dim as isize,
            ndim,
        });
    }
    let a_contig = ensure_contiguous(a);
    let reduce_size = shape[dim];
    let outer_size: usize = shape[..dim].iter().product();
    let inner_size: usize = shape[dim + 1..].iter().product();
    // Keep the reduced dim as size 1, or drop it; a full reduction collapses to [1].
    let out_shape: Vec<usize> = if keepdim {
        let mut s = shape.to_vec();
        s[dim] = 1;
        s
    } else {
        let mut s: Vec<usize> = shape[..dim].to_vec();
        s.extend_from_slice(&shape[dim + 1..]);
        if s.is_empty() {
            s.push(1);
        }
        s
    };
    let out = alloc_output(client, &out_shape, dtype);
    let a_buf = get_tensor_buffer(&a_contig)?;
    let out_buf = get_tensor_buffer(&out)?;
    if dim == ndim - 1 || inner_size == 1 {
        let params = LogsumexpParams {
            reduce_size: reduce_size as u32,
            outer_size: outer_size.max(1) as u32,
        };
        let params_buf = create_params_buffer(client, &params);
        cumulative::launch_logsumexp(
            client.pipeline_cache(),
            client.wgpu_queue(),
            &a_buf,
            &out_buf,
            &params_buf,
            outer_size.max(1),
            dtype,
        )?;
    } else {
        let params = LogsumexpStridedParams {
            reduce_size: reduce_size as u32,
            outer_size: outer_size.max(1) as u32,
            inner_size: inner_size as u32,
        };
        let params_buf = create_params_buffer(client, &params);
        let total_inner = outer_size.max(1) * inner_size;
        cumulative::launch_logsumexp_strided(
            client.pipeline_cache(),
            client.wgpu_queue(),
            &a_buf,
            &out_buf,
            &params_buf,
            total_inner,
            dtype,
        )?;
    }
    Ok(out)
}
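
// The numerically stable identity the logsumexp kernels are assumed to
// implement: log(sum(exp(x))) = m + log(sum(exp(x - m))) with m = max(x),
// which avoids overflow for large inputs. A CPU sketch under that assumption;
// the WGSL shader in `cumulative` is the actual implementation.
#[cfg(test)]
mod logsumexp_reference_sketch {
    fn reference_logsumexp(row: &[f32]) -> f32 {
        // Shift by the row max so every exponent is <= 0 and cannot overflow.
        let m = row.iter().copied().fold(f32::NEG_INFINITY, f32::max);
        m + row.iter().map(|&x| (x - m).exp()).sum::<f32>().ln()
    }

    #[test]
    fn matches_the_closed_form_and_survives_large_inputs() {
        // logsumexp([0, 0]) = ln(2).
        assert!((reference_logsumexp(&[0.0, 0.0]) - 2.0f32.ln()).abs() < 1e-6);
        // Naive exp(1000) overflows f32; the shifted form stays finite.
        let v = reference_logsumexp(&[1000.0, 1000.0]);
        assert!((v - (1000.0 + 2.0f32.ln())).abs() < 1e-3);
    }
}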