use crate::dtype::DType;
use crate::error::{Error, Result};
use crate::ops::{ReduceOps, ScalarOps, TypeConversionOps, UnaryOps, UtilityOps};
use crate::runtime::RuntimeClient;
use crate::runtime::wgpu::{WgpuClient, WgpuRuntime};
use crate::tensor::Tensor;
use super::{create_bin_edges, tensor_to_f64};
/// Computes a histogram of all elements of `a` on the WGPU backend.
///
/// The input is flattened first, so its shape does not matter. Each element is
/// cast to `F32`, assigned to one of `bins` equal-width bins spanning
/// `[min, max]`, and counted.
///
/// # Arguments
///
/// * `client` - WGPU client used to run every tensor operation.
/// * `a` - Input tensor of any shape.
/// * `bins` - Number of equal-width bins; must be positive.
/// * `range` - Optional `(min, max)` bounds for the bins. When `None`, the bounds
///   are the minimum and maximum of `a`. If those are equal, the range is widened
///   to `(value - 0.5, value + 0.5)` so the bins have non-zero width. For empty
///   input without a range, `(0.0, 1.0)` is used.
///
/// # Returns
///
/// `(hist, edges)`:
/// * `hist` is an `I64` tensor of shape `[bins]` holding the count for each bin.
/// * `edges` is produced by `create_bin_edges` for the resolved range and the
///   input dtype.
///
/// An element equal to `max` is counted in the last bin.
/// NOTE(review): elements outside an explicit `range` are clamped into the
/// first/last bin rather than discarded (NumPy discards them) — confirm this is
/// the intended contract.
///
/// # Errors
///
/// * `Error::InvalidArgument` (`bins`) if `bins == 0`.
/// * `Error::InvalidArgument` (`range`) if `range` is given with a non-finite
///   bound or with `min >= max`. This is checked for empty input too.
/// * `Error::InvalidArgument` (`a`) if the auto-detected range is not finite,
///   i.e. the data's minimum or maximum is NaN or infinite.
/// * Any error propagated from the underlying tensor operations.
pub fn histogram_impl(
    client: &WgpuClient,
    a: &Tensor<WgpuRuntime>,
    bins: usize,
    range: Option<(f64, f64)>,
) -> Result<(Tensor<WgpuRuntime>, Tensor<WgpuRuntime>)> {
    if bins == 0 {
        return Err(Error::InvalidArgument {
            arg: "bins",
            reason: "Number of bins must be positive".to_string(),
        });
    }
    // Validate before the empty-input shortcut so empty and non-empty inputs
    // reject exactly the same arguments.
    if let Some((min, max)) = range {
        validate_histogram_range(min, max)?;
    }

    let dtype = a.dtype();
    let numel = a.numel();
    if numel == 0 {
        let (min_val, max_val) = range.unwrap_or((0.0, 1.0));
        let hist = Tensor::<WgpuRuntime>::zeros(&[bins], DType::I64, client.device());
        let edges = create_bin_edges(client, min_val, max_val, bins, dtype)?;
        return Ok((hist, edges));
    }

    let flat = a.reshape(&[numel])?;
    let (min_val, max_val) = match range {
        Some(bounds) => bounds,
        None => autodetect_histogram_range(client, &flat)?,
    };

    // Binning is done in F32. `flat` is not needed afterwards, so it is moved
    // rather than cloned when no cast is required.
    let flat_f32 = if dtype == DType::F32 {
        flat
    } else {
        client.cast(&flat, DType::F32)?
    };

    // bin index = floor((x - min) / width), clamped to [0, bins - 1]. The clamp
    // puts x == max (normalized value exactly `bins`) into the last bin.
    let bin_width = (max_val - min_val) / bins as f64;
    let shifted = client.sub_scalar(&flat_f32, min_val)?;
    let normalized = client.div_scalar(&shifted, bin_width)?;
    let floored = client.floor(&normalized)?;
    let bin_indices = client.clamp(&floored, 0.0, (bins - 1) as f64)?;
    let bin_indices_i64 = client.cast(&bin_indices, DType::I64)?;

    // Counting: one-hot encode every element (numel x bins) and sum over the
    // element axis to get a count per bin.
    let one_hot_matrix = client.one_hot(&bin_indices_i64, bins)?;
    let hist = client.sum(&one_hot_matrix, &[0], false)?;
    let hist = client.cast(&hist, DType::I64)?;

    let edges = create_bin_edges(client, min_val, max_val, bins, dtype)?;
    Ok((hist, edges))
}

/// Rejects a caller-supplied histogram range whose bounds are not finite or
/// that does not satisfy `min < max`.
fn validate_histogram_range(min: f64, max: f64) -> Result<()> {
    // `min >= max` alone is false for NaN bounds, so check finiteness explicitly.
    if !(min.is_finite() && max.is_finite()) {
        return Err(Error::InvalidArgument {
            arg: "range",
            reason: format!("Range bounds must be finite, got ({}, {})", min, max),
        });
    }
    if min >= max {
        return Err(Error::InvalidArgument {
            arg: "range",
            reason: format!("Range min ({}) must be less than max ({})", min, max),
        });
    }
    Ok(())
}

/// Derives the histogram range from the minimum and maximum of `flat`.
///
/// If the two extremes are equal, they are widened by 0.5 on each side so every
/// bin has a non-zero width.
fn autodetect_histogram_range(
    client: &WgpuClient,
    flat: &Tensor<WgpuRuntime>,
) -> Result<(f64, f64)> {
    let min_tensor = client.min(flat, &[], false)?;
    let max_tensor = client.max(flat, &[], false)?;
    let min_val = tensor_to_f64(client, &min_tensor)?;
    let max_val = tensor_to_f64(client, &max_tensor)?;

    // A NaN or infinite extreme would make the bin width NaN or infinite and
    // silently corrupt every count.
    if !(min_val.is_finite() && max_val.is_finite()) {
        return Err(Error::InvalidArgument {
            arg: "a",
            reason: format!(
                "Autodetected range ({}, {}) is not finite; input contains NaN or infinity",
                min_val, max_val
            ),
        });
    }

    // Exact comparison, as NumPy does. An absolute epsilon would wrongly
    // collapse distinct small-magnitude extremes such as 1e-20 and 2e-20.
    if min_val == max_val {
        Ok((min_val - 0.5, max_val + 0.5))
    } else {
        Ok((min_val, max_val))
    }
}