use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{SortingOps, TypeConversionOps, compute_reduce_strides, reduce_dim_output_shape};
use crate::runtime::wgpu::client::get_buffer;
use crate::runtime::wgpu::shaders::launch_mode_dim;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::runtime::{RuntimeClient, ensure_contiguous, normalize_dim};
use crate::tensor::Tensor;
use wgpu::util::DeviceExt;
pub fn mode_impl(
client: &WgpuClient,
a: &Tensor<WgpuRuntime>,
dim: Option<isize>,
keepdim: bool,
) -> Result<(Tensor<WgpuRuntime>, Tensor<WgpuRuntime>)> {
let dtype = a.dtype();
let native_supported = matches!(dtype, DType::F32 | DType::I32 | DType::U32);
if !native_supported {
let a_f32 = client.cast(a, DType::F32)?;
let (values_f32, counts) = mode_impl(client, &a_f32, dim, keepdim)?;
let values = client.cast(&values_f32, dtype)?;
return Ok((values, counts));
}
if dim.is_none() {
let numel = a.numel();
if numel == 0 {
let out_shape = if keepdim { vec![1; a.ndim()] } else { vec![] };
let values = Tensor::<WgpuRuntime>::empty(&out_shape, dtype, client.device());
let counts = Tensor::<WgpuRuntime>::empty(&out_shape, DType::I32, client.device());
return Ok((values, counts));
}
let flat = a.reshape(&[numel])?;
return mode_impl(client, &flat, Some(0), keepdim);
}
let dim_val = dim.unwrap();
let shape = a.shape();
let ndim = shape.len();
if ndim == 0 {
let counts = Tensor::<WgpuRuntime>::full_scalar(&[], DType::I32, 1.0, client.device());
return Ok((a.clone(), counts));
}
let dim_idx = normalize_dim(dim_val, ndim)?;
let dim_size = shape[dim_idx];
if dim_size == 0 {
let out_shape = reduce_dim_output_shape(shape, dim_idx, keepdim);
let values = Tensor::<WgpuRuntime>::empty(&out_shape, dtype, client.device());
let counts = Tensor::<WgpuRuntime>::empty(&out_shape, DType::I32, client.device());
return Ok((values, counts));
}
let sorted = client.sort(a, dim_val, false)?;
let out_shape = reduce_dim_output_shape(shape, dim_idx, keepdim);
let (outer_size, reduce_size, inner_size) = compute_reduce_strides(shape, dim_idx);
let num_outputs = outer_size * inner_size;
let sorted_contig = ensure_contiguous(&sorted);
let mode_values = Tensor::<WgpuRuntime>::empty(&out_shape, dtype, client.device());
let mode_counts = Tensor::<WgpuRuntime>::empty(&out_shape, DType::I32, client.device());
let sorted_buf = get_buffer(sorted_contig.ptr())
.ok_or_else(|| Error::Internal("Failed to get sorted buffer".to_string()))?;
let values_buf = get_buffer(mode_values.ptr())
.ok_or_else(|| Error::Internal("Failed to get mode_values buffer".to_string()))?;
let counts_buf = get_buffer(mode_counts.ptr())
.ok_or_else(|| Error::Internal("Failed to get mode_counts buffer".to_string()))?;
#[repr(C)]
#[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
struct ModeParams {
outer_size: u32,
reduce_size: u32,
inner_size: u32,
_pad: u32,
}
let params = ModeParams {
outer_size: outer_size as u32,
reduce_size: reduce_size as u32,
inner_size: inner_size as u32,
_pad: 0,
};
let params_buf = client
.wgpu_device()
.create_buffer_init(&wgpu::util::BufferInitDescriptor {
label: Some("mode_params"),
contents: bytemuck::bytes_of(¶ms),
usage: wgpu::BufferUsages::UNIFORM,
});
launch_mode_dim(
client.pipeline_cache(),
client.wgpu_queue(),
&*sorted_buf,
&*values_buf,
&*counts_buf,
¶ms_buf,
num_outputs,
dtype,
)?;
let mode_counts_i64 = client.cast(&mode_counts, DType::I64)?;
Ok((mode_values, mode_counts_i64))
}