use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use ndarray::ArrayD;
use num_traits::Float;
use std::cmp::Ordering;
/// Boolean mask array consumed by the `conditional` operations (`where_`, `masked_*`).
type BoolMask = ArrayD<bool>;
/// Signed integer index array: input to gather/scatter/index_select and the index output of
/// topk/kthvalue/unique/histogram-style helpers.
type IndexArray = ArrayD<i64>;
/// Owned tensor shape (one extent per axis), e.g. the result of broadcasting two shapes.
type Shape = Vec<usize>;
mod broadcasting {
    use super::*;

    /// Extent of `shape` along output axis `axis` when `shape` is right-aligned against an
    /// output of rank `rank` (NumPy-style broadcasting). Leading axes that `shape` does not
    /// reach count as size 1.
    ///
    /// Requires `rank >= shape.len()` and `axis < rank`.
    fn aligned_dim(shape: &[usize], rank: usize, axis: usize) -> usize {
        let missing = rank - shape.len();
        if axis < missing {
            1
        } else {
            shape[axis - missing]
        }
    }

    /// Broadcast extent of one aligned axis pair, or `None` if the sizes are incompatible.
    ///
    /// Equal sizes pass through and a size of 1 stretches to the other side, so a 0-sized
    /// axis against a size-1 axis yields 0 (not 1, which `max` would give).
    fn merge_dims(a: usize, b: usize) -> Option<usize> {
        if a == b || b == 1 {
            Some(a)
        } else if a == 1 {
            Some(b)
        } else {
            None
        }
    }

    /// Returns `true` when `shape1` and `shape2` can be broadcast together: after
    /// right-aligning the shapes, every axis pair is equal or contains a 1.
    pub(super) fn can_broadcast(shape1: &[usize], shape2: &[usize]) -> bool {
        let rank = shape1.len().max(shape2.len());
        (0..rank).all(|axis| {
            merge_dims(aligned_dim(shape1, rank, axis), aligned_dim(shape2, rank, axis)).is_some()
        })
    }

    /// Computes the broadcast shape of `shape1` and `shape2`.
    ///
    /// # Errors
    /// Returns a shape-mismatch error naming both inputs when some aligned axis pair is
    /// neither equal nor contains a 1.
    pub(super) fn broadcast_shape(shape1: &[usize], shape2: &[usize]) -> RusTorchResult<Shape> {
        let rank = shape1.len().max(shape2.len());
        (0..rank)
            .map(|axis| {
                merge_dims(aligned_dim(shape1, rank, axis), aligned_dim(shape2, rank, axis))
                    .ok_or_else(|| RusTorchError::shape_mismatch(shape1, shape2))
            })
            .collect()
    }

    /// Maps a row-major flat index into an array of `target_shape` back to the flat index of
    /// the element it reads in an array of `original_shape`, where `original_shape`
    /// broadcasts to `target_shape`.
    ///
    /// Size-1 axes of `original_shape` always map to coordinate 0, and leading target axes
    /// missing from `original_shape` are ignored. Expects `original_shape.len() <=
    /// target_shape.len()` and `flat_idx < product(target_shape)` (so no target axis is 0);
    /// both hold whenever `target_shape` comes from [`broadcast_shape`].
    pub(super) fn broadcast_index(
        flat_idx: usize,
        original_shape: &[usize],
        target_shape: &[usize],
    ) -> usize {
        if original_shape == target_shape {
            return flat_idx;
        }
        let offset = target_shape.len().saturating_sub(original_shape.len());
        let mut remaining = flat_idx;
        let mut source_idx = 0;
        let mut source_stride = 1;
        // Walk axes innermost-first, peeling one target coordinate off per axis; no
        // per-call allocation, which matters because `where_` calls this per element.
        for (axis, &target_dim) in target_shape.iter().enumerate().rev() {
            let coord = remaining % target_dim;
            remaining /= target_dim;
            if axis < offset {
                continue;
            }
            let source_dim = original_shape[axis - offset];
            if source_dim != 1 {
                source_idx += coord * source_stride;
            }
            source_stride *= source_dim;
        }
        source_idx
    }
}
mod stride_calc {
    /// Row-major (C-order) strides for `shape`: the last axis has stride 1 and every earlier
    /// axis steps over the product of all later extents.
    pub(super) fn calculate_strides(shape: &[usize]) -> Vec<usize> {
        let mut strides = vec![1; shape.len()];
        let mut acc = 1;
        // Pair stride[i] with extent[i + 1] and fill from the innermost pair outward.
        for (slot, &next_extent) in strides.iter_mut().zip(shape.iter().skip(1)).rev() {
            acc *= next_extent;
            *slot = acc;
        }
        strides
    }

    /// Splits a flat offset into one coordinate per axis using the given strides
    /// (outermost stride first).
    pub(super) fn flat_to_coords(flat_idx: usize, strides: &[usize]) -> Vec<usize> {
        let mut rest = flat_idx;
        strides
            .iter()
            .map(|&stride| {
                let coord = rest / stride;
                rest %= stride;
                coord
            })
            .collect()
    }

    /// Dot product of coordinates and strides, i.e. the flat offset of `coords`.
    pub(super) fn coords_to_flat(coords: &[usize], strides: &[usize]) -> usize {
        coords
            .iter()
            .zip(strides)
            .fold(0, |offset, (&coord, &stride)| offset + coord * stride)
    }
}
pub mod conditional {
    use super::broadcasting::{broadcast_index, broadcast_shape, can_broadcast};
    use super::*;

    /// Element-wise selection: returns `x` where `condition` is true and `y` elsewhere,
    /// broadcasting all three operands to a common shape.
    ///
    /// # Errors
    /// Returns a shape-mismatch error if any pair of operand shapes cannot be broadcast, and a
    /// tensor-op error if an operand's data is not accessible as a contiguous slice.
    pub fn where_<T: Float + 'static>(
        condition: &BoolMask,
        x: &Tensor<T>,
        y: &Tensor<T>,
    ) -> RusTorchResult<Tensor<T>> {
        let shapes = [condition.shape(), x.shape(), y.shape()];
        // Check every pair up front so the error names the offending pair of operands.
        for (i, &lhs) in shapes.iter().enumerate() {
            for &rhs in &shapes[i + 1..] {
                if !can_broadcast(lhs, rhs) {
                    return Err(RusTorchError::shape_mismatch(lhs, rhs));
                }
            }
        }
        let output_shape =
            broadcast_shape(&broadcast_shape(condition.shape(), x.shape())?, y.shape())?;
        let total_elements: usize = output_shape.iter().product();
        let condition_data = condition
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Condition mask data not accessible"))?;
        let x_data = x
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("X tensor data not accessible"))?;
        let y_data = y
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Y tensor data not accessible"))?;
        let result_data: Vec<T> = (0..total_elements)
            .map(|i| {
                let take_x = condition_data
                    .get(broadcast_index(i, condition.shape(), &output_shape))
                    .copied()
                    .unwrap_or(false);
                let (source, idx) = if take_x {
                    (x_data, broadcast_index(i, x.shape(), &output_shape))
                } else {
                    (y_data, broadcast_index(i, y.shape(), &output_shape))
                };
                // Defensive fallback: an out-of-range index yields zero instead of panicking.
                source.get(idx).copied().unwrap_or_else(T::zero)
            })
            .collect();
        Ok(Tensor::from_vec(result_data, output_shape))
    }

    /// Returns a 1-D tensor with the elements of `input` where `mask` is true, in flat order.
    ///
    /// # Errors
    /// Returns a shape-mismatch error unless `mask` has exactly `input`'s shape.
    pub fn masked_select<T: Float + 'static>(
        input: &Tensor<T>,
        mask: &BoolMask,
    ) -> RusTorchResult<Tensor<T>> {
        if input.shape() != mask.shape() {
            return Err(RusTorchError::shape_mismatch(input.shape(), mask.shape()));
        }
        let input_data = input
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Input tensor data not accessible"))?;
        let mask_data = mask
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Mask array data not accessible"))?;
        let true_count = mask_data.iter().filter(|&&val| val).count();
        let mut selected = Vec::with_capacity(true_count);
        selected.extend(
            input_data
                .iter()
                .zip(mask_data)
                .filter(|(_, &keep)| keep)
                .map(|(&value, _)| value),
        );
        let len = selected.len();
        Ok(Tensor::from_vec(selected, vec![len]))
    }

    /// In-place masked fill: sets every element of `input` where `mask` is true to `value`.
    ///
    /// # Errors
    /// Returns a shape-mismatch error unless `mask` has exactly `input`'s shape.
    pub fn masked_fill_<T: Float + 'static>(
        input: &mut Tensor<T>,
        mask: &BoolMask,
        value: T,
    ) -> RusTorchResult<()> {
        if input.shape() != mask.shape() {
            return Err(RusTorchError::shape_mismatch(input.shape(), mask.shape()));
        }
        let input_data = input.data.as_slice_mut().ok_or_else(|| {
            RusTorchError::tensor_op("Input tensor data not accessible for mutation")
        })?;
        let mask_data = mask
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Mask array data not accessible"))?;
        input_data
            .iter_mut()
            .zip(mask_data)
            .filter(|(_, &fill)| fill)
            .for_each(|(elem, _)| *elem = value);
        Ok(())
    }

    /// Out-of-place masked fill: returns a copy of `input` with `value` written wherever
    /// `mask` is true.
    ///
    /// # Errors
    /// Same as [`masked_fill_`]; in particular a mask whose shape differs from `input` is
    /// always rejected, even if it contains no `true` entries.
    pub fn masked_fill<T: Float + 'static>(
        input: &Tensor<T>,
        mask: &BoolMask,
        value: T,
    ) -> RusTorchResult<Tensor<T>> {
        // No all-false shortcut: it used to skip the shape check in `masked_fill_`.
        let mut result = input.clone();
        masked_fill_(&mut result, mask, value)?;
        Ok(result)
    }
}
pub mod indexing {
    use super::stride_calc::{calculate_strides, coords_to_flat, flat_to_coords};
    use super::*;

    /// Rejects `dim` unless it names an existing axis of a rank-`rank` tensor.
    fn check_dim(dim: usize, rank: usize) -> RusTorchResult<()> {
        if dim >= rank {
            return Err(RusTorchError::invalid_dimension(dim, rank.saturating_sub(1)));
        }
        Ok(())
    }

    /// Rejects the first index value outside `[0, dim_size)`.
    fn check_index_values(index_data: &[i64], dim: usize, dim_size: usize) -> RusTorchResult<()> {
        match index_data
            .iter()
            .find(|&&idx| idx < 0 || idx as usize >= dim_size)
        {
            Some(&idx) => Err(RusTorchError::tensor_op(format!(
                "Index {} out of bounds for dimension {} with size {}",
                idx, dim, dim_size
            ))),
            None => Ok(()),
        }
    }

    /// Checks the gather/scatter shape contract: `index` has the same rank as the input and,
    /// on every axis other than `dim`, is no larger than the input. Without this, index
    /// coordinates on those axes would silently wrap into neighbouring rows.
    fn check_index_shape(
        op: &str,
        index_shape: &[usize],
        input_shape: &[usize],
        dim: usize,
    ) -> RusTorchResult<()> {
        if index_shape.len() != input_shape.len() {
            return Err(RusTorchError::tensor_op(format!(
                "{}: index must have the same number of dimensions as the input ({} vs {})",
                op,
                index_shape.len(),
                input_shape.len()
            )));
        }
        for (axis, (&index_extent, &input_extent)) in
            index_shape.iter().zip(input_shape).enumerate()
        {
            if axis != dim && index_extent > input_extent {
                return Err(RusTorchError::tensor_op(format!(
                    "{}: index size {} exceeds input size {} at dimension {}",
                    op, index_extent, input_extent, axis
                )));
            }
        }
        Ok(())
    }

    /// Gathers values along `dim`: `out[i..][j][k..] = input[i..][index[i..][j][k..]][k..]`
    /// with the indexed position replacing the `dim` coordinate. The output has `index`'s
    /// shape.
    ///
    /// # Errors
    /// Invalid `dim`, an `index` whose rank differs from `input` or which is larger than
    /// `input` on a non-`dim` axis, out-of-range index values, or non-contiguous data.
    pub fn gather<T: Float + 'static>(
        input: &Tensor<T>,
        dim: usize,
        index: &IndexArray,
    ) -> RusTorchResult<Tensor<T>> {
        let input_shape = input.shape();
        check_dim(dim, input_shape.len())?;
        let index_shape = index.shape();
        check_index_shape("gather", index_shape, input_shape, dim)?;
        let input_data = input
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Input tensor data not accessible"))?;
        let index_data = index
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Index array data not accessible"))?;
        check_index_values(index_data, dim, input_shape[dim])?;
        let input_strides = calculate_strides(input_shape);
        let index_strides = calculate_strides(index_shape);
        let result_data = index_data
            .iter()
            .enumerate()
            .map(|(flat_idx, &target)| {
                let mut coords = flat_to_coords(flat_idx, &index_strides);
                coords[dim] = target as usize;
                input_data
                    .get(coords_to_flat(&coords, &input_strides))
                    .copied()
                    .ok_or_else(|| {
                        RusTorchError::tensor_op("gather: Calculated index exceeds tensor bounds")
                    })
            })
            .collect::<RusTorchResult<Vec<T>>>()?;
        Ok(Tensor::from_vec(result_data, index_shape.to_vec()))
    }

    /// In-place scatter along `dim`: writes `src[p]` into `input` at position `p` with its
    /// `dim` coordinate replaced by `index[p]`. Later writes to the same slot win.
    ///
    /// # Errors
    /// Invalid `dim`, `index`/`src` shape mismatch, an `index` whose rank differs from `input`
    /// or which is larger than `input` on a non-`dim` axis, out-of-range index values, or
    /// non-contiguous data.
    pub fn scatter_<T: Float + 'static>(
        input: &mut Tensor<T>,
        dim: usize,
        index: &IndexArray,
        src: &Tensor<T>,
    ) -> RusTorchResult<()> {
        // Owned copy so the shape stays usable while `input` is mutably borrowed below.
        let input_shape = input.shape().to_vec();
        check_dim(dim, input_shape.len())?;
        if index.shape() != src.shape() {
            return Err(RusTorchError::shape_mismatch(index.shape(), src.shape()));
        }
        let index_shape = index.shape();
        check_index_shape("scatter", index_shape, &input_shape, dim)?;
        let index_data = index
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Index array data not accessible"))?;
        let src_data = src
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Source tensor data not accessible"))?;
        check_index_values(index_data, dim, input_shape[dim])?;
        let input_data = input.data.as_slice_mut().ok_or_else(|| {
            RusTorchError::tensor_op("Input tensor data not accessible for mutation")
        })?;
        let input_strides = calculate_strides(&input_shape);
        let index_strides = calculate_strides(index_shape);
        for (flat_idx, (&target, &value)) in index_data.iter().zip(src_data).enumerate() {
            let mut coords = flat_to_coords(flat_idx, &index_strides);
            coords[dim] = target as usize;
            let slot = input_data
                .get_mut(coords_to_flat(&coords, &input_strides))
                .ok_or_else(|| {
                    RusTorchError::tensor_op("scatter: Calculated index exceeds tensor bounds")
                })?;
            *slot = value;
        }
        Ok(())
    }

    /// Out-of-place variant of [`scatter_`]: scatters into a clone of `input`.
    pub fn scatter<T: Float + 'static>(
        input: &Tensor<T>,
        dim: usize,
        index: &IndexArray,
        src: &Tensor<T>,
    ) -> RusTorchResult<Tensor<T>> {
        let mut result = input.clone();
        scatter_(&mut result, dim, index, src)?;
        Ok(result)
    }

    /// Selects whole slices along `dim` in the order given by `index` (all index entries
    /// are used in flat order). The output equals `input`'s shape with `dim` replaced by the
    /// number of index entries.
    ///
    /// # Errors
    /// Invalid `dim`, out-of-range index values, or non-contiguous data.
    pub fn index_select<T: Float + 'static>(
        input: &Tensor<T>,
        dim: usize,
        index: &IndexArray,
    ) -> RusTorchResult<Tensor<T>> {
        let input_shape = input.shape();
        check_dim(dim, input_shape.len())?;
        let input_data = input
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Input tensor data not accessible"))?;
        let index_data = index
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Index array data not accessible"))?;
        let dim_size = input_shape[dim];
        check_index_values(index_data, dim, dim_size)?;
        let mut output_shape = input_shape.to_vec();
        output_shape[dim] = index_data.len();
        let output_size: usize = output_shape.iter().product();
        let outer_size: usize = input_shape[..dim].iter().product();
        let inner_size: usize = input_shape[dim + 1..].iter().product();
        let outer_stride = dim_size * inner_size;
        let mut result_data = Vec::with_capacity(output_size);
        for outer_idx in 0..outer_size {
            for &selected in index_data {
                // Each selected position contributes one contiguous run of `inner_size` values.
                let start = outer_idx * outer_stride + selected as usize * inner_size;
                let run = input_data.get(start..start + inner_size).ok_or_else(|| {
                    RusTorchError::tensor_op("index_select: Calculated index exceeds tensor bounds")
                })?;
                result_data.extend_from_slice(run);
            }
        }
        Ok(Tensor::from_vec(result_data, output_shape))
    }
}
pub mod statistics {
    use super::*;

    /// Total order on floats used for every sort in this module: ordinary values compare
    /// numerically and NaN sorts after all of them (NaNs compare equal to each other).
    /// `partial_cmp().unwrap_or(Equal)` is not transitive with NaN, and `sort_by` may panic
    /// on a non-total comparator.
    fn nan_last_cmp<T: Float>(a: &T, b: &T) -> Ordering {
        match (a.is_nan(), b.is_nan()) {
            (false, false) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        }
    }

    /// Collects `(value, position_along_dim)` for the row-major 1-D lane through `input_data`
    /// at `(outer_idx, inner_idx)`. Positions past the end of `input_data` are skipped.
    fn lane<T: Float>(
        input_data: &[T],
        outer_idx: usize,
        dim_size: usize,
        inner_size: usize,
        inner_idx: usize,
    ) -> Vec<(T, usize)> {
        let base = outer_idx * dim_size * inner_size + inner_idx;
        (0..dim_size)
            .filter_map(|pos| input_data.get(base + pos * inner_size).map(|&v| (v, pos)))
            .collect()
    }

    /// Returns the `k` largest (or smallest) values along `dim` and their positions.
    ///
    /// Output shape is `input`'s shape with `dim` replaced by `k`. With `sorted`, values are
    /// ordered best-first; with `!sorted && largest` the selected values come out ascending.
    /// NaN is treated as larger than every number. Ties keep their original order.
    ///
    /// # Errors
    /// Invalid `dim`, `k` larger than the `dim` extent, or non-contiguous data.
    pub fn topk_util<T: Float + 'static>(
        input: &Tensor<T>,
        k: usize,
        dim: usize,
        largest: bool,
        sorted: bool,
    ) -> RusTorchResult<(Tensor<T>, IndexArray)> {
        let input_shape = input.shape();
        if dim >= input_shape.len() {
            return Err(RusTorchError::invalid_dimension(
                dim,
                input_shape.len().saturating_sub(1),
            ));
        }
        if k > input_shape[dim] {
            return Err(RusTorchError::tensor_op(format!(
                "k ({}) cannot be larger than dimension size ({})",
                k, input_shape[dim]
            )));
        }
        let mut output_shape = input_shape.to_vec();
        output_shape[dim] = k;
        if k == 0 {
            return Ok((
                Tensor::from_vec(Vec::new(), output_shape.clone()),
                ArrayD::from_shape_vec(output_shape, Vec::new()).map_err(|_| {
                    RusTorchError::tensor_op("Failed to create empty indices array")
                })?,
            ));
        }
        let input_data = input
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Input tensor data not accessible"))?;
        let outer_size: usize = input_shape[..dim].iter().product();
        let dim_size = input_shape[dim];
        let inner_size: usize = input_shape[dim + 1..].iter().product();
        let output_size: usize = output_shape.iter().product();
        let mut values = vec![T::zero(); output_size];
        let mut indices = vec![0i64; output_size];
        for outer_idx in 0..outer_size {
            for inner_idx in 0..inner_size {
                let mut slice_data = lane(input_data, outer_idx, dim_size, inner_size, inner_idx);
                if largest {
                    slice_data.sort_by(|a, b| nan_last_cmp(&b.0, &a.0));
                } else {
                    slice_data.sort_by(|a, b| nan_last_cmp(&a.0, &b.0));
                }
                let take_count = k.min(slice_data.len());
                let selected = &mut slice_data[..take_count];
                if !sorted && largest {
                    selected.reverse();
                }
                // Write into the [outer][k][inner] layout implied by `output_shape`; pushing
                // lane by lane would produce [outer][inner][k] and transpose the result
                // whenever `dim` is not the last axis.
                for (rank, &(value, position)) in selected.iter().enumerate() {
                    let out = (outer_idx * k + rank) * inner_size + inner_idx;
                    values[out] = value;
                    indices[out] = position as i64;
                }
            }
        }
        Ok((
            Tensor::from_vec(values, output_shape.clone()),
            ArrayD::from_shape_vec(output_shape, indices)
                .map_err(|_| RusTorchError::tensor_op("Failed to create indices array"))?,
        ))
    }

    /// Returns the value of 0-based rank `k` (the `k`-th smallest, counting from 0) along
    /// `dim`, plus its position. `keepdim` keeps `dim` as size 1; otherwise it is removed.
    /// NaN sorts after every number.
    ///
    /// # Errors
    /// Invalid `dim`, `k >=` the `dim` extent, or non-contiguous data.
    pub fn kthvalue<T: Float + 'static>(
        input: &Tensor<T>,
        k: usize,
        dim: usize,
        keepdim: bool,
    ) -> RusTorchResult<(Tensor<T>, IndexArray)> {
        let input_shape = input.shape();
        if dim >= input_shape.len() {
            return Err(RusTorchError::invalid_dimension(
                dim,
                input_shape.len().saturating_sub(1),
            ));
        }
        if k >= input_shape[dim] {
            return Err(RusTorchError::tensor_op(format!(
                "k ({}) must be less than dimension size ({})",
                k, input_shape[dim]
            )));
        }
        let input_data = input
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Input tensor data not accessible"))?;
        let outer_size: usize = input_shape[..dim].iter().product();
        let dim_size = input_shape[dim];
        let inner_size: usize = input_shape[dim + 1..].iter().product();
        let mut output_shape = input_shape.to_vec();
        if keepdim {
            output_shape[dim] = 1;
        } else {
            output_shape.remove(dim);
        }
        let output_size: usize = output_shape.iter().product();
        let mut values = Vec::with_capacity(output_size);
        let mut indices = Vec::with_capacity(output_size);
        // One result per (outer, inner) lane, so push order already matches the layout.
        for outer_idx in 0..outer_size {
            for inner_idx in 0..inner_size {
                let mut slice_data = lane(input_data, outer_idx, dim_size, inner_size, inner_idx);
                if slice_data.is_empty() {
                    return Err(RusTorchError::tensor_op(
                        "kthvalue: No elements found in slice",
                    ));
                }
                slice_data.sort_by(|a, b| nan_last_cmp(&a.0, &b.0));
                let &(value, position) = slice_data.get(k).ok_or_else(|| {
                    RusTorchError::tensor_op("kthvalue: k index exceeds available elements")
                })?;
                values.push(value);
                indices.push(position as i64);
            }
        }
        Ok((
            Tensor::from_vec(values, output_shape.clone()),
            ArrayD::from_shape_vec(output_shape, indices)
                .map_err(|_| RusTorchError::tensor_op("Failed to create indices array"))?,
        ))
    }

    /// Linear-interpolated quantiles of `input` for every value in `q`.
    ///
    /// With `dim = Some(d)`, quantiles are taken along `d`, and axis `d` of the output holds
    /// one entry per `q` value. With `dim = None`, the whole tensor is flattened and a 1-D
    /// tensor of `q.len()` values is returned.
    ///
    /// # Errors
    /// A `q` value outside `[0, 1]` or NaN, an invalid `dim`, or non-contiguous data.
    pub fn quantile_util<T: Float + 'static + std::fmt::Display>(
        input: &Tensor<T>,
        q: &Tensor<T>,
        dim: Option<usize>,
        keepdim: bool,
    ) -> RusTorchResult<Tensor<T>> {
        let q_data = q
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Quantile tensor data not accessible"))?;
        // NaN fails both range comparisons, so it must be rejected explicitly.
        if let Some(&bad) = q_data
            .iter()
            .find(|&&q_val| q_val.is_nan() || q_val < T::zero() || q_val > T::one())
        {
            return Err(RusTorchError::tensor_op(format!(
                "Quantile values must be in [0, 1], got {}",
                bad
            )));
        }
        let input_data = input
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::tensor_op("Input tensor data not accessible"))?;
        let q_len = q_data.len();
        match dim {
            Some(dim_val) => {
                let input_shape = input.shape();
                if dim_val >= input_shape.len() {
                    return Err(RusTorchError::invalid_dimension(
                        dim_val,
                        input_shape.len().saturating_sub(1),
                    ));
                }
                let outer_size: usize = input_shape[..dim_val].iter().product();
                let dim_size = input_shape[dim_val];
                let inner_size: usize = input_shape[dim_val + 1..].iter().product();
                let mut output_shape = input_shape.to_vec();
                // NOTE(review): both branches yield the same shape (dim replaced by q_len);
                // `keepdim` currently has no effect. Kept as-is for caller compatibility.
                if keepdim {
                    output_shape[dim_val] = q_len;
                } else {
                    output_shape.remove(dim_val);
                    output_shape.insert(dim_val, q_len);
                }
                let mut result_data = vec![T::zero(); outer_size * q_len * inner_size];
                for outer_idx in 0..outer_size {
                    for inner_idx in 0..inner_size {
                        let mut slice_values: Vec<T> =
                            lane(input_data, outer_idx, dim_size, inner_size, inner_idx)
                                .into_iter()
                                .map(|(value, _)| value)
                                .collect();
                        slice_values.sort_by(|a, b| nan_last_cmp(a, b));
                        // Place results in the [outer][q][inner] layout of `output_shape`.
                        for (qi, &q_val) in q_data.iter().enumerate() {
                            result_data[(outer_idx * q_len + qi) * inner_size + inner_idx] =
                                compute_quantile(&slice_values, q_val);
                        }
                    }
                }
                Ok(Tensor::from_vec(result_data, output_shape))
            }
            None => {
                let mut sorted_data = input_data.to_vec();
                sorted_data.sort_by(|a, b| nan_last_cmp(a, b));
                let result_data: Vec<T> = q_data
                    .iter()
                    .map(|&q_val| compute_quantile(&sorted_data, q_val))
                    .collect();
                Ok(Tensor::from_vec(result_data, vec![q_len]))
            }
        }
    }

    /// Linear interpolation between the two order statistics bracketing position
    /// `q * (n - 1)` of an already sorted slice. Empty input yields zero; a single element
    /// is returned as-is.
    fn compute_quantile<T: Float>(sorted_data: &[T], q: T) -> T {
        if sorted_data.is_empty() {
            return T::zero();
        }
        if sorted_data.len() == 1 {
            return sorted_data[0];
        }
        let n = sorted_data.len();
        let index = q * T::from(n - 1).unwrap();
        let lower_idx = index.floor().to_usize().unwrap_or(0).min(n - 1);
        let upper_idx = index.ceil().to_usize().unwrap_or(0).min(n - 1);
        if lower_idx == upper_idx {
            sorted_data[lower_idx]
        } else {
            let fraction = index - T::from(lower_idx).unwrap();
            let lower_val = sorted_data[lower_idx];
            let upper_val = sorted_data[upper_idx];
            lower_val + fraction * (upper_val - lower_val)
        }
    }
}
pub mod advanced {
    use super::*;

    /// Total order on floats: numbers compare numerically, NaN sorts after all of them.
    /// Needed because `sort_by` may panic on a non-total comparator.
    fn nan_last_cmp<T: Float>(a: &T, b: &T) -> Ordering {
        match (a.is_nan(), b.is_nan()) {
            (false, false) => a.partial_cmp(b).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        }
    }

    /// Groups `data` into unique values and returns `(unique_values, inverse, counts)`, where
    /// `inverse[i]` is the group of `data[i]`.
    ///
    /// Values within an absolute tolerance of 1e-7 of a group's first (smallest) member join
    /// that group; each NaN forms its own group. With `sorted` the groups are ascending,
    /// otherwise they appear in order of first occurrence in `data`.
    fn group_unique<T: Float>(data: &[T], sorted: bool) -> (Vec<T>, Vec<i64>, Vec<i64>) {
        let tolerance = T::from(1e-7).unwrap_or_else(T::epsilon);
        // Stable sort, so equal values keep their original relative order.
        let mut order: Vec<usize> = (0..data.len()).collect();
        order.sort_by(|&a, &b| nan_last_cmp(&data[a], &data[b]));

        let mut values: Vec<T> = Vec::new();
        let mut counts: Vec<i64> = Vec::new();
        let mut first_seen: Vec<usize> = Vec::new();
        let mut group_of = vec![0usize; data.len()];
        for &pos in &order {
            let value = data[pos];
            if matches!(values.last(), Some(&rep) if (value - rep).abs() < tolerance) {
                let group = values.len() - 1;
                counts[group] += 1;
                first_seen[group] = first_seen[group].min(pos);
            } else {
                values.push(value);
                counts.push(1);
                first_seen.push(pos);
            }
            group_of[pos] = values.len() - 1;
        }

        if !sorted {
            // Reorder groups by first occurrence and remap the inverse accordingly.
            let mut by_first: Vec<usize> = (0..values.len()).collect();
            by_first.sort_by_key(|&group| first_seen[group]);
            let mut new_slot = vec![0usize; by_first.len()];
            for (slot, &group) in by_first.iter().enumerate() {
                new_slot[group] = slot;
            }
            values = by_first.iter().map(|&group| values[group]).collect();
            counts = by_first.iter().map(|&group| counts[group]).collect();
            for group in &mut group_of {
                *group = new_slot[*group];
            }
        }

        let inverse = group_of.into_iter().map(|group| group as i64).collect();
        (values, inverse, counts)
    }

    /// Unique elements of the flattened `input`, optionally with inverse indices (shaped like
    /// `input`) and per-unique counts.
    ///
    /// `sorted` returns the unique values ascending; otherwise they come in order of first
    /// occurrence. Values closer than 1e-7 are merged.
    ///
    /// # Errors
    /// Non-contiguous data, or `dim = Some(_)` (not yet implemented).
    pub fn unique<T: Float + 'static>(
        input: &Tensor<T>,
        sorted: bool,
        return_inverse: bool,
        return_counts: bool,
        dim: Option<usize>,
    ) -> RusTorchResult<(Tensor<T>, Option<ArrayD<i64>>, Option<ArrayD<i64>>)> {
        let input_data = input
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::invalid_parameter("Input tensor data not accessible"))?;
        if dim.is_some() {
            return Err(RusTorchError::UnsupportedOperation(
                "Unique along specific dimension not yet implemented",
            ));
        }
        let (unique_values, inverse_indices, counts) = group_unique(input_data, sorted);
        let inverse_tensor = if return_inverse {
            Some(
                ArrayD::from_shape_vec(input.shape().to_vec(), inverse_indices).map_err(|_| {
                    RusTorchError::invalid_parameter(
                        "Invalid shape for inverse indices".to_string(),
                    )
                })?,
            )
        } else {
            None
        };
        let counts_tensor = if return_counts {
            Some(
                ArrayD::from_shape_vec(vec![unique_values.len()], counts).map_err(|_| {
                    RusTorchError::invalid_parameter("Invalid shape for counts".to_string())
                })?,
            )
        } else {
            None
        };
        let unique_len = unique_values.len();
        Ok((
            Tensor::from_vec(unique_values, vec![unique_len]),
            inverse_tensor,
            counts_tensor,
        ))
    }

    /// Counts values of `input` into `bins` equal-width bins and returns `(counts, edges)`
    /// with `bins + 1` edges.
    ///
    /// The range defaults to the min/max of the non-NaN data. Values outside the range and
    /// NaNs are not counted, and the maximum falls in the last bin. If all data are equal
    /// (and no range is given), every edge equals that value and all counts land in the last
    /// bin.
    ///
    /// # Errors
    /// Empty input, `bins == 0`, an invalid or NaN `range`, input containing only NaN when no
    /// range is given, non-contiguous data, or `density = true` (unsupported).
    pub fn histogram<T: Float + 'static>(
        input: &Tensor<T>,
        bins: usize,
        range: Option<(T, T)>,
        density: bool,
    ) -> RusTorchResult<(ArrayD<i64>, Tensor<T>)> {
        let input_data = input
            .data
            .as_slice()
            .ok_or_else(|| RusTorchError::invalid_parameter("Input tensor data not accessible"))?;
        if input_data.is_empty() {
            return Err(RusTorchError::invalid_parameter(
                "Cannot compute histogram of empty tensor",
            ));
        }
        if bins == 0 {
            return Err(RusTorchError::invalid_parameter(
                "Number of bins must be positive",
            ));
        }
        let (min_val, max_val) = match range {
            Some((min, max)) => {
                // A NaN bound would pass `min >= max` and poison every bin edge.
                if min.is_nan() || max.is_nan() || min >= max {
                    return Err(RusTorchError::invalid_parameter(
                        "Range min must be less than max",
                    ));
                }
                (min, max)
            }
            None => {
                // Skip NaN so the inferred range does not depend on where NaNs appear.
                let mut non_nan = input_data.iter().copied().filter(|v| !v.is_nan());
                let first = non_nan.next().ok_or_else(|| {
                    RusTorchError::invalid_parameter(
                        "Cannot compute histogram range: input contains only NaN values",
                    )
                })?;
                non_nan.fold((first, first), |(lo, hi), v| (lo.min(v), hi.max(v)))
            }
        };
        if density {
            return Err(RusTorchError::tensor_op(
                "histogram: Density mode not supported with integer counts, use raw counts and normalize separately"
            ));
        }
        let bin_width = (max_val - min_val) / T::from(bins).unwrap();
        let bin_edges: Vec<T> = (0..=bins)
            .map(|i| min_val + T::from(i).unwrap() * bin_width)
            .collect();
        let mut bin_counts = vec![0i64; bins];
        for &value in input_data {
            // Positive form on purpose: NaN fails both comparisons and is skipped.
            if value >= min_val && value <= max_val {
                let bin_idx = if value == max_val {
                    bins - 1
                } else {
                    ((value - min_val) / bin_width)
                        .floor()
                        .to_usize()
                        .unwrap_or(0)
                        .min(bins - 1)
                };
                bin_counts[bin_idx] += 1;
            }
        }
        Ok((
            ArrayD::from_shape_vec(vec![bins], bin_counts).map_err(|_| {
                RusTorchError::invalid_parameter("Invalid shape for histogram counts".to_string())
            })?,
            Tensor::from_vec(bin_edges, vec![bins + 1]),
        ))
    }
}