mod histogram;
mod mode;
mod moments;
mod quantile;
pub use histogram::histogram_impl;
pub use mode::mode_impl;
pub use moments::{kurtosis_impl, skew_impl};
pub use quantile::{median_impl, percentile_impl, quantile_impl};
use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::TypeConversionOps;
use crate::runtime::RuntimeClient;
use crate::runtime::common::statistics_common::compute_bin_edges_f64;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::tensor::Tensor;
/// Builds a 1-D tensor of `bins + 1` histogram bin edges for the range `[min_val, max_val]`.
///
/// The edges are computed in `f64` by `compute_bin_edges_f64` and narrowed to `f32`. They are
/// uploaded to the client's device as an F32 tensor of shape `[bins + 1]`. The tensor is cast
/// to `dtype` only when `dtype` is not already F32.
///
/// # Errors
/// Propagates any error returned by `client.cast` when a non-F32 `dtype` is requested.
pub(crate) fn create_bin_edges(
    client: &WgpuClient,
    min_val: f64,
    max_val: f64,
    bins: usize,
    dtype: DType,
) -> Result<Tensor<WgpuRuntime>> {
    // The upload path always goes through f32; other dtypes are produced by a device-side cast.
    let narrowed: Vec<f32> = compute_bin_edges_f64(min_val, max_val, bins)
        .iter()
        .copied()
        .map(|edge| edge as f32)
        .collect();
    let shape = [bins + 1];
    let f32_edges = Tensor::<WgpuRuntime>::from_slice(&narrowed, &shape, client.device());

    if dtype != DType::F32 {
        return client.cast(&f32_edges, dtype);
    }
    Ok(f32_edges)
}
/// Reads a single-element tensor back to the host and returns its value as `f64`.
///
/// Supported dtypes are F32, I32 and U32; each converts to `f64` without loss. The value is
/// copied from the tensor's GPU buffer into a staging buffer, the copy is submitted via
/// `client.submit_and_wait`, and the staging buffer is then read on the host.
///
/// # Errors
/// - `Error::InvalidArgument` if `t` does not hold exactly one element.
/// - `Error::UnsupportedDType` if `t`'s dtype is not F32, I32 or U32. This is reported before
///   any GPU work is issued.
/// - `Error::Internal` if the tensor's backing buffer cannot be looked up.
/// - Any error propagated from `client.read_buffer`.
pub(crate) fn tensor_to_f64(client: &WgpuClient, t: &Tensor<WgpuRuntime>) -> Result<f64> {
    use crate::runtime::wgpu::client::get_buffer;

    let dtype = t.dtype();
    if t.numel() != 1 {
        return Err(Error::InvalidArgument {
            arg: "tensor",
            reason: format!(
                "tensor_to_f64 requires a scalar (1-element) tensor, got numel={}",
                t.numel()
            ),
        });
    }

    // Validate the dtype BEFORE encoding the copy. wgpu requires buffer-copy sizes to be a
    // multiple of COPY_BUFFER_ALIGNMENT (4 bytes), so copying a 1- or 2-byte element would fail
    // validation instead of producing the intended UnsupportedDType error. Checking first also
    // avoids a wasted GPU round-trip.
    if !matches!(dtype, DType::F32 | DType::I32 | DType::U32) {
        return Err(Error::UnsupportedDType {
            dtype,
            op: "tensor_to_f64",
        });
    }

    let byte_len = dtype.size_in_bytes() as u64;
    let src_buffer = get_buffer(t.ptr())
        .ok_or_else(|| Error::Internal("Failed to get tensor buffer".to_string()))?;
    let staging = client.create_staging_buffer("scalar_staging", byte_len);

    let mut encoder = client
        .wgpu_device
        .create_command_encoder(&wgpu::CommandEncoderDescriptor {
            label: Some("scalar_copy"),
        });
    // NOTE(review): copies from byte offset 0 of the backing buffer. This assumes the tensor's
    // element starts at the beginning of its buffer — confirm that views with a non-zero offset
    // cannot reach this function.
    encoder.copy_buffer_to_buffer(&src_buffer, 0, &staging, 0, byte_len);
    client.submit_and_wait(encoder);

    let val = match dtype {
        DType::F32 => {
            let mut data = [0f32; 1];
            client.read_buffer(&staging, &mut data)?;
            f64::from(data[0])
        }
        DType::I32 => {
            let mut data = [0i32; 1];
            client.read_buffer(&staging, &mut data)?;
            f64::from(data[0])
        }
        DType::U32 => {
            let mut data = [0u32; 1];
            client.read_buffer(&staging, &mut data)?;
            f64::from(data[0])
        }
        _ => unreachable!("tensor_to_f64: dtype support was validated before the GPU copy"),
    };
    Ok(val)
}