use super::helpers::*;
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::ScalarOps;
use crate::runtime::ensure_contiguous;
use crate::runtime::wgpu::shaders::reduce;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::tensor::Tensor;
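/// Reduces `a` over the given `dims` on the wgpu backend.
///
/// An empty `dims` slice requests a full reduction over all elements. Multiple
/// dims are handled by reducing one dim at a time with `keepdim = true` and
/// reshaping to the final shape at the end when `keepdim` is false.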
pub(crate) fn native_reduce_op(
client: &WgpuClient,
op: &'static str,
a: &Tensor<WgpuRuntime>,
dims: &[usize],
keepdim: bool,
) -> Result<Tensor<WgpuRuntime>> {
let shape = a.shape();
if dims.is_empty() {
return native_full_reduce(client, op, a);
}
if dims.len() > 1 {
// Reduce the highest dims first so the keepdim = false shape fix-up below
// can remove them without shifting the remaining indices.
let mut sorted_dims = dims.to_vec();
sorted_dims.sort_unstable_by(|x, y| y.cmp(x));
let mut result = a.clone();
for &dim in &sorted_dims {
result = native_single_dim_reduce(client, op, &result, dim, true)?;
}
if !keepdim {
let mut out_shape: Vec<usize> = shape.to_vec();
for &dim in &sorted_dims {
out_shape.remove(dim);
}
if out_shape.is_empty() {
out_shape.push(1);
}
result = result.reshape(&out_shape)?;
}
return Ok(result);
}
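// Single reduction dim: launch the kernel for that dim directly.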
let dim = dims[0];
native_single_dim_reduce(client, op, a, dim, keepdim)
}
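/// Reduces `a` along a single dimension by viewing it as
/// `[outer_size, reduce_size, inner_size]` and launching the reduce kernel
/// over the flattened output elements.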
fn native_single_dim_reduce(
client: &WgpuClient,
op: &'static str,
a: &Tensor<WgpuRuntime>,
dim: usize,
keepdim: bool,
) -> Result<Tensor<WgpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
let a_contig = ensure_contiguous(a);
let reduce_size = shape[dim];
let outer_size: usize = shape[..dim].iter().product();
let inner_size: usize = shape[dim + 1..].iter().product();
let numel_out = outer_size * inner_size;
let out_shape: Vec<usize> = if keepdim {
let mut s = shape.to_vec();
s[dim] = 1;
s
} else {
let mut s: Vec<usize> = shape[..dim].to_vec();
s.extend_from_slice(&shape[dim + 1..]);
if s.is_empty() {
s.push(1);
}
s
};
let out = alloc_output(client, &out_shape, dtype);
let a_buf = get_tensor_buffer(&a_contig)?;
let out_buf = get_tensor_buffer(&out)?;
let params = ReduceParams {
reduce_size: reduce_size as u32,
outer_size: outer_size.max(1) as u32,
inner_size: inner_size.max(1) as u32,
numel_out: numel_out.max(1) as u32,
};
let params_buf = create_params_buffer(client, &params);
reduce::launch_reduce_op(
client.pipeline_cache(),
client.wgpu_queue(),
op,
&a_buf,
&out_buf,
&params_buf,
numel_out.max(1),
dtype,
)?;
Ok(out)
}
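/// Reduces all elements of `a` to a single scalar tensor.
///
/// `mean` is lowered to a `sum` reduction followed by a scalar division by the
/// element count. Inputs spanning more than one 256-element workgroup are
/// reduced in two passes: per-workgroup partials first, then the partials.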
fn native_full_reduce(
client: &WgpuClient,
op: &'static str,
a: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
let dtype = a.dtype();
let a_contig = ensure_contiguous(a);
let numel = a.numel();
let is_mean = op == "mean";
let reduce_op = if is_mean { "sum" } else { op };
let workgroup_size = 256;
let num_workgroups = (numel + workgroup_size - 1) / workgroup_size;
if num_workgroups <= 1 {
let out = alloc_output(client, &[1], dtype);
let a_buf = get_tensor_buffer(&a_contig)?;
let out_buf = get_tensor_buffer(&out)?;
let params = FullReduceParams {
numel: numel as u32,
};
let params_buf = create_params_buffer(client, &params);
reduce::launch_full_reduce_op(
client.pipeline_cache(),
client.wgpu_queue(),
reduce_op,
&a_buf,
&out_buf,
&params_buf,
numel,
dtype,
)?;
if is_mean {
return client.div_scalar(&out, numel as f64);
}
return Ok(out);
}
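// First pass: reduce the input into one partial value per workgroup.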
let partial = alloc_output(client, &[num_workgroups], dtype);
let a_buf = get_tensor_buffer(&a_contig)?;
let partial_buf = get_tensor_buffer(&partial)?;
let params = FullReduceParams {
numel: numel as u32,
};
let params_buf = create_params_buffer(client, &params);
reduce::launch_full_reduce_op(
client.pipeline_cache(),
client.wgpu_queue(),
reduce_op,
&a_buf,
&partial_buf,
&params_buf,
numel,
dtype,
)?;
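// Second pass: reduce the per-workgroup partials down to the single output element.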
let out = alloc_output(client, &[1], dtype);
let out_buf = get_tensor_buffer(&out)?;
let params2 = FullReduceParams {
numel: num_workgroups as u32,
};
let params_buf2 = create_params_buffer(client, &params2);
reduce::launch_full_reduce_op(
client.pipeline_cache(),
client.wgpu_queue(),
reduce_op,
&partial_buf,
&out_buf,
&params_buf2,
num_workgroups,
dtype,
)?;
if is_mean {
return client.div_scalar(&out, numel as f64);
}
Ok(out)
}
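/// Softmax of `a` over dimension `dim` (negative values count from the end).
///
/// The kernel only handles the last dimension, so any other `dim` is moved to
/// the end with a permute, processed there, and permuted back afterwards.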
pub(crate) fn native_softmax(
client: &WgpuClient,
a: &Tensor<WgpuRuntime>,
dim: isize,
) -> Result<Tensor<WgpuRuntime>> {
let shape = a.shape();
let ndim = shape.len();
let dim = if dim < 0 {
(ndim as isize + dim) as usize
} else {
dim as usize
};
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
if dim != ndim - 1 {
let mut perm: Vec<usize> = (0..ndim).collect();
perm.remove(dim);
perm.push(dim);
let permuted = a.permute(&perm)?;
let permuted_contig = permuted.contiguous();
let result = native_softmax_last_dim(client, &permuted_contig)?;
let mut inv_perm = vec![0; ndim];
for (i, &p) in perm.iter().enumerate() {
inv_perm[p] = i;
}
return result.permute(&inv_perm);
}
native_softmax_last_dim(client, a)
}
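/// Softmax over the last dimension, launching the kernel over the
/// `batch_size` rows of the `[batch_size, dim_size]` view of `a`.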
fn native_softmax_last_dim(
client: &WgpuClient,
a: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
let a_contig = ensure_contiguous(a);
let dim = ndim - 1;
let batch_size: usize = shape[..dim].iter().product();
let dim_size = shape[dim];
let out = alloc_output(client, shape, dtype);
let a_buf = get_tensor_buffer(&a_contig)?;
let out_buf = get_tensor_buffer(&out)?;
let params = SoftmaxParams {
batch_size: batch_size.max(1) as u32,
dim_size: dim_size as u32,
};
let params_buf = create_params_buffer(client, &params);
reduce::launch_softmax_op(
client.pipeline_cache(),
client.wgpu_queue(),
&a_buf,
&out_buf,
&params_buf,
batch_size.max(1),
dtype,
)?;
Ok(out)
}
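/// Backward pass of softmax over `dim`, given the upstream gradient and the
/// forward output. Non-last dims use the same permute round-trip as the
/// forward pass.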
pub(crate) fn native_softmax_bwd(
client: &WgpuClient,
grad: &Tensor<WgpuRuntime>,
output: &Tensor<WgpuRuntime>,
dim: isize,
) -> Result<Tensor<WgpuRuntime>> {
let shape = grad.shape();
let ndim = shape.len();
let dim = if dim < 0 {
(ndim as isize + dim) as usize
} else {
dim as usize
};
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
if dim != ndim - 1 {
let mut perm: Vec<usize> = (0..ndim).collect();
perm.remove(dim);
perm.push(dim);
let grad_p = grad.permute(&perm)?.contiguous();
let output_p = output.permute(&perm)?.contiguous();
let result = native_softmax_bwd_last_dim(client, &grad_p, &output_p)?;
let mut inv_perm = vec![0; ndim];
for (i, &p) in perm.iter().enumerate() {
inv_perm[p] = i;
}
return result.permute(&inv_perm);
}
native_softmax_bwd_last_dim(client, grad, output)
}
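/// Softmax backward over the last dimension, launching the kernel over the
/// `batch_size` rows of the `[batch_size, dim_size]` view.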
fn native_softmax_bwd_last_dim(
client: &WgpuClient,
grad: &Tensor<WgpuRuntime>,
output: &Tensor<WgpuRuntime>,
) -> Result<Tensor<WgpuRuntime>> {
let shape = grad.shape();
let ndim = shape.len();
let dtype = grad.dtype();
let grad_contig = ensure_contiguous(grad);
let output_contig = ensure_contiguous(output);
let dim = ndim - 1;
let batch_size: usize = shape[..dim].iter().product();
let dim_size = shape[dim];
let d_input = alloc_output(client, shape, dtype);
let grad_buf = get_tensor_buffer(&grad_contig)?;
let output_buf = get_tensor_buffer(&output_contig)?;
let d_input_buf = get_tensor_buffer(&d_input)?;
let params = SoftmaxParams {
batch_size: batch_size.max(1) as u32,
dim_size: dim_size as u32,
};
let params_buf = create_params_buffer(client, &params);
reduce::launch_softmax_bwd_op(
client.pipeline_cache(),
client.wgpu_queue(),
&grad_buf,
&output_buf,
&d_input_buf,
&params_buf,
batch_size.max(1),
dtype,
)?;
Ok(d_input)
}
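/// Index-returning reduction (such as argmax/argmin) along `dim`: the output
/// tensor holds `I32` indices into the reduced dimension, laid out like the
/// output of `native_single_dim_reduce`.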
pub(crate) fn native_argreduce_op(
client: &WgpuClient,
op: &'static str,
a: &Tensor<WgpuRuntime>,
dim: usize,
keepdim: bool,
) -> Result<Tensor<WgpuRuntime>> {
let dtype = a.dtype();
let shape = a.shape();
let ndim = shape.len();
if dim >= ndim {
return Err(Error::InvalidDimension {
dim: dim as isize,
ndim,
});
}
let a_contig = ensure_contiguous(a);
let reduce_size = shape[dim];
let outer_size: usize = shape[..dim].iter().product();
let inner_size: usize = shape[dim + 1..].iter().product();
let numel_out = outer_size * inner_size;
let out_shape: Vec<usize> = if keepdim {
let mut s = shape.to_vec();
s[dim] = 1;
s
} else {
let mut s: Vec<usize> = shape[..dim].to_vec();
s.extend_from_slice(&shape[dim + 1..]);
if s.is_empty() {
s.push(1);
}
s
};
let out = alloc_output(client, &out_shape, DType::I32);
let a_buf = get_tensor_buffer(&a_contig)?;
let out_buf = get_tensor_buffer(&out)?;
let params = ArgReduceParams {
reduce_size: reduce_size as u32,
outer_size: outer_size.max(1) as u32,
inner_size: inner_size.max(1) as u32,
numel_out: numel_out.max(1) as u32,
};
let params_buf = create_params_buffer(client, &params);
reduce::launch_argreduce_op(
client.pipeline_cache(),
client.wgpu_queue(),
op,
&a_buf,
&out_buf,
&params_buf,
numel_out.max(1),
dtype,
)?;
Ok(out)
}