use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{
BinaryOps, IndexingOps, ScalarOps, SortingOps, TypeConversionOps, reduce_dim_output_shape,
};
use crate::runtime::common::statistics_common::Interpolation;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::runtime::{RuntimeClient, normalize_dim};
use crate::tensor::Tensor;
/// Computes the `q`-th quantile of `a` along `dim` (or over all elements when
/// `dim` is `None`) on the wgpu backend, interpolating linearly between the
/// two nearest order statistics.
///
/// # Arguments
/// * `client` - wgpu runtime client used to launch the component ops.
/// * `a` - input tensor.
/// * `q` - quantile in `[0, 1]`.
/// * `dim` - dimension to reduce; `None` reduces over all elements.
/// * `keepdim` - retain the reduced dimension(s) with size 1.
/// * `interpolation` - interpolation-mode string, validated via
///   [`Interpolation::parse`].
///
/// # Errors
/// Returns [`Error::InvalidArgument`] when `q` is outside `[0, 1]` (or NaN),
/// or when `interpolation` is not a recognized mode; propagates errors from
/// the underlying cast/sort/index/arithmetic ops.
pub fn quantile_impl(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    q: f64,
    dim: Option<isize>,
    keepdim: bool,
    interpolation: &str,
) -> Result<Tensor<WgpuRuntime>> {
    // `contains` is false for NaN, so NaN q is rejected here as well.
    if !(0.0..=1.0).contains(&q) {
        return Err(Error::InvalidArgument {
            arg: "q",
            reason: format!("Quantile q must be in [0, 1], got {}", q),
        });
    }
    // Validate the interpolation string up front so bad input fails fast.
    // NOTE(review): the parsed mode is never consumed below —
    // `compute_quantile_indices` only receives `q` and the dim size, so every
    // mode currently behaves as "linear". Confirm whether non-linear modes
    // are handled elsewhere or still TODO.
    let _interp = Interpolation::parse(interpolation)?;
    let dtype = a.dtype();
    // The wgpu path computes in F32; round-trip F64 inputs through F32.
    if dtype == DType::F64 {
        let a_f32 = client.cast(a, DType::F32)?;
        let out_f32 = quantile_impl(client, &a_f32, q, dim, keepdim, interpolation)?;
        return client.cast(&out_f32, DType::F64);
    }
    // Full reduction: flatten, then reduce along the single remaining axis.
    if dim.is_none() {
        let numel = a.numel();
        if numel == 0 {
            // Empty input: keepdim yields shape [1; ndim], otherwise 0-d.
            let out_shape = if keepdim { vec![1; a.ndim()] } else { vec![] };
            return Ok(Tensor::<WgpuRuntime>::empty(
                &out_shape,
                dtype,
                client.device(),
            ));
        }
        let flat = a.reshape(&[numel])?;
        // Reduce without keepdim, then restore shape [1; ndim] if requested,
        // matching the empty-input branch above. (Recursing on the flattened
        // tensor with keepdim=true would return shape [1] regardless of the
        // input rank.)
        let reduced = quantile_impl(client, &flat, q, Some(0), false, interpolation)?;
        if keepdim && a.ndim() > 0 {
            return reduced.reshape(&vec![1; a.ndim()]);
        }
        return Ok(reduced);
    }
    let dim_val = dim.unwrap();
    let shape = a.shape();
    let ndim = shape.len();
    // The quantile of a 0-d tensor is the value itself.
    if ndim == 0 {
        return Ok(a.clone());
    }
    let dim_idx = normalize_dim(dim_val, ndim)?;
    let dim_size = shape[dim_idx];
    if dim_size == 0 {
        // Reducing over an empty dimension yields an empty result.
        let out_shape = reduce_dim_output_shape(shape, dim_idx, keepdim);
        return Ok(Tensor::<WgpuRuntime>::empty(
            &out_shape,
            dtype,
            client.device(),
        ));
    }
    // Quantiles are order statistics: sort ascending along the reduction dim.
    let sorted = client.sort(a, dim_val, false)?;
    let out_shape = reduce_dim_output_shape(shape, dim_idx, keepdim);
    // result = sorted[floor_idx] + frac * (sorted[ceil_idx] - sorted[floor_idx])
    let (floor_idx, ceil_idx, frac) =
        crate::runtime::common::statistics_common::compute_quantile_indices(q, dim_size);
    let out_numel = out_shape.iter().product::<usize>();
    if out_numel == 0 {
        return Ok(Tensor::<WgpuRuntime>::empty(
            &out_shape,
            dtype,
            client.device(),
        ));
    }
    // Gather the two bracketing order statistics via single-index selects.
    let floor_idx_tensor =
        Tensor::<WgpuRuntime>::from_slice(&[floor_idx as i32], &[1], client.device());
    let ceil_idx_tensor =
        Tensor::<WgpuRuntime>::from_slice(&[ceil_idx as i32], &[1], client.device());
    let floor_vals = client.index_select(&sorted, dim_idx, &floor_idx_tensor)?;
    let ceil_vals = client.index_select(&sorted, dim_idx, &ceil_idx_tensor)?;
    // index_select keeps the reduced dim with size 1; drop it here and
    // re-insert it at the end if keepdim was requested.
    let mut floor_shape = Vec::with_capacity(shape.len() - 1);
    for (i, &s) in shape.iter().enumerate() {
        if i != dim_idx {
            floor_shape.push(s);
        }
    }
    let ceil_shape = floor_shape.clone();
    let floor_vals = floor_vals.reshape(&floor_shape)?;
    let ceil_vals = ceil_vals.reshape(&ceil_shape)?;
    // Interpolate in F32 (e.g. for integer inputs), casting back afterwards.
    let floor_f32 = if dtype != DType::F32 {
        client.cast(&floor_vals, DType::F32)?
    } else {
        floor_vals
    };
    let ceil_f32 = if dtype != DType::F32 {
        client.cast(&ceil_vals, DType::F32)?
    } else {
        ceil_vals
    };
    let diff = client.sub(&ceil_f32, &floor_f32)?;
    let scaled_diff = client.mul_scalar(&diff, frac)?;
    let result_f32 = client.add(&floor_f32, &scaled_diff)?;
    let result = if dtype != DType::F32 {
        client.cast(&result_f32, dtype)?
    } else {
        result_f32
    };
    if keepdim {
        // Restore the reduced dimension with size 1.
        let mut final_shape = result.shape().to_vec();
        final_shape.insert(dim_idx, 1);
        result.reshape(&final_shape)
    } else {
        Ok(result)
    }
}
/// Computes the `p`-th percentile of `a` along `dim` (or over all elements
/// when `dim` is `None`).
///
/// A percentile is simply a quantile scaled by 100, so after validating `p`
/// this delegates to [`quantile_impl`] with `q = p / 100` and linear
/// interpolation.
///
/// # Errors
/// Returns [`Error::InvalidArgument`] when `p` is outside `[0, 100]` (or
/// NaN); propagates any error from [`quantile_impl`].
pub fn percentile_impl(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    p: f64,
    dim: Option<isize>,
    keepdim: bool,
) -> Result<Tensor<WgpuRuntime>> {
    // Negated conjunction so NaN (which fails both comparisons) is rejected.
    let in_range = p >= 0.0 && p <= 100.0;
    if !in_range {
        return Err(Error::InvalidArgument {
            arg: "p",
            reason: format!("Percentile p must be in [0, 100], got {}", p),
        });
    }
    quantile_impl(client, a, p / 100.0, dim, keepdim, "linear")
}
/// Computes the median of `a` along `dim` (or over all elements when `dim`
/// is `None`).
///
/// The median is the 0.5 quantile with linear interpolation, so this is a
/// thin wrapper over [`quantile_impl`].
pub fn median_impl(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    dim: Option<isize>,
    keepdim: bool,
) -> Result<Tensor<WgpuRuntime>> {
    // The median is the midpoint quantile.
    const MEDIAN_Q: f64 = 0.5;
    quantile_impl(client, a, MEDIAN_Q, dim, keepdim, "linear")
}