use std::collections::HashMap;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
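/// Returns the unique elements of `input` as a 1-D tensor, plus optional
/// inverse indices (mapping each input element to its unique value) and
/// optional per-value counts, similar in spirit to `torch.unique`. When
/// `dim` is given, deduplication runs over whole slices along that
/// dimension instead of over flattened elements.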
pub fn unique(
input: &Tensor,
sorted: bool,
return_inverse: bool,
return_counts: bool,
dim: Option<i32>,
) -> TorshResult<(Tensor, Option<Tensor>, Option<Tensor>)> {
if let Some(d) = dim {
return unique_dim(input, d, sorted, return_inverse, return_counts);
}
let flattened = input.view(&[-1])?;
    let values: Vec<f32> = flattened.to_vec()?;
    let len = values.len();
    // Key each value by its bit pattern so distinct floats stay distinct;
    // -0.0 is normalized to 0.0 so the two signed zeros compare equal.
    // The map stores (index into unique_values, occurrence count).
    let mut unique_map: HashMap<u32, (usize, usize)> = HashMap::new();
    let mut unique_values = Vec::new();
    let mut inverse_indices_vec = vec![0usize; len];
    for (idx, &val) in values.iter().enumerate() {
        let key = (if val == 0.0 { 0.0f32 } else { val }).to_bits();
        match unique_map.get_mut(&key) {
            Some((unique_idx, count)) => {
                *count += 1;
                inverse_indices_vec[idx] = *unique_idx;
            }
            None => {
                let unique_idx = unique_values.len();
                unique_map.insert(key, (unique_idx, 1));
                unique_values.push(val);
                inverse_indices_vec[idx] = unique_idx;
            }
        }
    }
    // old_to_new maps a value's first-seen unique index to its final
    // position; it stays the identity when `sorted` is false and is reused
    // below when placing counts.
    let mut old_to_new: Vec<usize> = (0..unique_values.len()).collect();
    if sorted {
        let mut order: Vec<usize> = (0..unique_values.len()).collect();
        // total_cmp is a total order on f32, so the sort cannot panic on NaN.
        order.sort_by(|&a, &b| unique_values[a].total_cmp(&unique_values[b]));
        let sorted_unique: Vec<f32> = order.iter().map(|&i| unique_values[i]).collect();
        unique_values = sorted_unique;
        for (new_idx, &old_idx) in order.iter().enumerate() {
            old_to_new[old_idx] = new_idx;
        }
        if return_inverse {
            for idx in &mut inverse_indices_vec {
                *idx = old_to_new[*idx];
            }
        }
    }
let unique_len = unique_values.len();
    let output = Tensor::from_vec(unique_values, &[unique_len])?;
let inverse_indices = if return_inverse {
let inverse_data: Vec<f32> = inverse_indices_vec.iter().map(|&i| i as f32).collect();
Some(Tensor::from_vec(inverse_data, &[len])?)
} else {
None
};
    let counts = if return_counts {
        let mut counts_vec = vec![0usize; unique_len];
        // unique_map stores (first-seen unique index, count); old_to_new
        // translates that index into the final (possibly sorted) position.
        for &(idx, count) in unique_map.values() {
            counts_vec[old_to_new[idx]] = count;
        }
        let counts_data: Vec<f32> = counts_vec.iter().map(|&c| c as f32).collect();
        Some(Tensor::from_vec(counts_data, &[unique_len])?)
} else {
None
};
Ok((output, inverse_indices, counts))
}
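/// `unique` over slices along `dim`: permutes the target dimension to the
/// front, deduplicates whole slices, then permutes the result back.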
fn unique_dim(
input: &Tensor,
dim: i32,
sorted: bool,
return_inverse: bool,
return_counts: bool,
) -> TorshResult<(Tensor, Option<Tensor>, Option<Tensor>)> {
let shape = input.shape();
let ndim = shape.ndim() as i32;
let dim = if dim < 0 { ndim + dim } else { dim };
if dim < 0 || dim >= ndim {
return Err(TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
dim, ndim
)));
}
let dim_usize = dim as usize;
let dims = shape.dims();
let dim_size = dims[dim_usize];
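    // Permute the target dimension to the front so the slices to compare
    // become leading blocks.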
let mut perm: Vec<i32> = (0..ndim).collect();
perm.swap(0, dim_usize);
let permuted = input.permute(&perm)?;
let permuted_shape = permuted.shape();
let permuted_dims = permuted_shape.dims();
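    // Materialize each slice along the (now leading) target dimension as a
    // flat Vec<f32>.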
let mut slices: Vec<Vec<f32>> = Vec::new();
for i in 0..dim_size {
let slice = permuted.narrow(0, i as i64, 1)?;
let slice_data: Vec<f32> = slice.to_vec()?;
slices.push(slice_data);
}
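    // Deduplicate by linear scan: each slice either matches a previously
    // seen unique slice (recording that unique index) or starts a new one.
    // O(u * n * m) for n slices of m elements with u unique slices.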
let mut unique_slice_indices = Vec::new();
let mut inverse_indices_vec = vec![0usize; dim_size];
let mut counts_vec = Vec::new();
for (idx, slice_data) in slices.iter().enumerate() {
let mut found = false;
for (unique_idx, &unique_slice_idx) in unique_slice_indices.iter().enumerate() {
if slices_equal(&slices[unique_slice_idx], slice_data) {
inverse_indices_vec[idx] = unique_idx;
counts_vec[unique_idx] += 1;
found = true;
break;
}
}
if !found {
let unique_idx = unique_slice_indices.len();
unique_slice_indices.push(idx);
inverse_indices_vec[idx] = unique_idx;
counts_vec.push(1);
}
}
    if sorted {
        // Order the unique slices lexicographically by their contents; the
        // deduplication above leaves them in first-occurrence order.
        let mut order: Vec<usize> = (0..unique_slice_indices.len()).collect();
        order.sort_by(|&a, &b| {
            let sa = &slices[unique_slice_indices[a]];
            let sb = &slices[unique_slice_indices[b]];
            sa.iter()
                .zip(sb.iter())
                .map(|(x, y)| x.total_cmp(y))
                .find(|ord| ord.is_ne())
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        let mut old_to_new = vec![0usize; order.len()];
        for (new_idx, &old_idx) in order.iter().enumerate() {
            old_to_new[old_idx] = new_idx;
        }
        let reordered: Vec<usize> = order.iter().map(|&i| unique_slice_indices[i]).collect();
        unique_slice_indices = reordered;
        for idx in &mut inverse_indices_vec {
            *idx = old_to_new[*idx];
        }
        let old_counts = counts_vec.clone();
        for (new_idx, &old_idx) in order.iter().enumerate() {
            counts_vec[new_idx] = old_counts[old_idx];
        }
    }
let unique_count = unique_slice_indices.len();
let mut output_shape = permuted_dims.to_vec();
output_shape[0] = unique_count;
let mut output_data = Vec::new();
for &slice_idx in &unique_slice_indices {
output_data.extend(&slices[slice_idx]);
}
let output_permuted = Tensor::from_vec(output_data, &output_shape)?;
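    // Build the inverse permutation so the deduplicated axis moves back to
    // its original position at `dim`.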
let mut inv_perm: Vec<i32> = vec![0; ndim as usize];
for (i, &p) in perm.iter().enumerate() {
inv_perm[p as usize] = i as i32;
}
let output = output_permuted.permute(&inv_perm)?;
let inverse_indices = if return_inverse {
let inverse_data: Vec<f32> = inverse_indices_vec.iter().map(|&i| i as f32).collect();
Some(Tensor::from_vec(inverse_data, &[dim_size])?)
} else {
None
};
let counts = if return_counts {
let counts_data: Vec<f32> = counts_vec.iter().map(|&c| c as f32).collect();
Some(Tensor::from_vec(counts_data, &[unique_count])?)
} else {
None
};
Ok((output, inverse_indices, counts))
}
fn slices_equal(a: &[f32], b: &[f32]) -> bool {
    if a.len() != b.len() {
        return false;
    }
    // Exact bitwise equality with -0.0 normalized to 0.0, matching the
    // keying used by the flattened `unique` path above.
    a.iter().zip(b.iter()).all(|(x, y)| {
        let xb = (if *x == 0.0 { 0.0f32 } else { *x }).to_bits();
        let yb = (if *y == 0.0 { 0.0f32 } else { *y }).to_bits();
        xb == yb
    })
}
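/// Counts occurrences of each non-negative integer value in a 1-D tensor.
/// With `weights`, each occurrence contributes its weight instead of 1;
/// `minlength` pads the result with trailing zero bins if needed.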
pub fn bincount(
input: &Tensor,
weights: Option<&Tensor>,
minlength: Option<usize>,
) -> TorshResult<Tensor> {
if input.ndim() != 1 {
return Err(TorshError::dimension_error(
"input must be 1-dimensional",
"bincount",
));
}
if let Some(w) = weights {
if w.shape() != input.shape() {
return Err(TorshError::ShapeMismatch {
expected: input.shape().dims().to_vec(),
got: w.shape().dims().to_vec(),
});
}
}
    let values: Vec<f32> = input.to_vec()?;
let mut max_val = 0i32;
for &val in &values {
if val < 0.0 || val.fract() != 0.0 {
return Err(TorshError::InvalidArgument(
"bincount: input must contain only non-negative integers".to_string(),
));
}
max_val = max_val.max(val as i32);
}
let output_size = if let Some(min_len) = minlength {
(max_val as usize + 1).max(min_len)
} else {
max_val as usize + 1
};
let mut counts = vec![0.0f32; output_size];
if let Some(weights_tensor) = weights {
        let weights_values: Vec<f32> = weights_tensor.to_vec()?;
for (i, &val) in values.iter().enumerate() {
let idx = val as usize;
counts[idx] += weights_values[i];
}
} else {
for &val in &values {
let idx = val as usize;
counts[idx] += 1.0;
}
}
Tensor::from_vec(counts, &[output_size])
}
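/// Computes a histogram with `bins` equal-width bins over `[min, max]`
/// (defaulting to the data range) and returns `(hist, bin_edges)`. With
/// `density`, counts are normalized so the histogram integrates to 1 over
/// the binned range.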
pub fn histogram(
input: &Tensor,
bins: usize,
min: Option<f32>,
max: Option<f32>,
density: bool,
) -> TorshResult<(Tensor, Tensor)> {
if bins == 0 {
return Err(TorshError::InvalidArgument(
"histogram: bins must be > 0".to_string(),
));
}
let flattened = input.view(&[-1])?;
    let values: Vec<f32> = flattened.to_vec()?;
let min_val = min.unwrap_or_else(|| values.iter().cloned().fold(f32::INFINITY, f32::min));
let max_val = max.unwrap_or_else(|| values.iter().cloned().fold(f32::NEG_INFINITY, f32::max));
if min_val > max_val {
return Err(TorshError::InvalidArgument(
"histogram: min must be less than or equal to max".to_string(),
));
}
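    // All values equal (or min == max): fall back to a unit range so the
    // bin width stays finite; every in-range value then lands in bin 0.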
let range = if max_val == min_val {
1.0
} else {
max_val - min_val
};
let mut bin_edges = vec![0.0f32; bins + 1];
for i in 0..=bins {
bin_edges[i] = min_val + (i as f32 / bins as f32) * range;
}
let mut hist = vec![0.0f32; bins];
for &val in &values {
if val >= min_val && val <= max_val {
let mut bin_idx = ((val - min_val) / range * bins as f32) as usize;
if bin_idx >= bins {
bin_idx = bins - 1;
}
hist[bin_idx] += 1.0;
}
}
    if density {
        // Normalize so the histogram integrates to 1 over [min_val, max_val];
        // only in-range values were counted, so divide by the bin total (the
        // guard also avoids NaN when nothing fell in range).
        let bin_width = range / bins as f32;
        let total_count: f32 = hist.iter().sum();
        if total_count > 0.0 {
            for h in &mut hist {
                *h /= total_count * bin_width;
            }
        }
    }
let hist_tensor = Tensor::from_vec(hist, &[bins])?;
let edges_tensor = Tensor::from_vec(bin_edges, &[bins + 1])?;
Ok((hist_tensor, edges_tensor))
}
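/// Histogram over explicit, monotonically non-decreasing `bin_edges`: each
/// in-range value is placed by binary search, and values equal to the
/// right-most edge fall into the last bin.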
pub fn histogram_with_edges(
input: &Tensor,
bin_edges: &Tensor,
density: bool,
) -> TorshResult<Tensor> {
if bin_edges.ndim() != 1 {
return Err(TorshError::dimension_error(
"bin_edges must be 1-dimensional",
"histogram",
));
}
let num_edges = bin_edges.shape().numel();
if num_edges < 2 {
return Err(TorshError::InvalidArgument(
"histogram: bin_edges must have at least 2 elements".to_string(),
));
}
let bins = num_edges - 1;
let edges_data = bin_edges.to_vec()?;
let edges: Vec<f32> = edges_data;
for i in 1..edges.len() {
if edges[i] < edges[i - 1] {
return Err(TorshError::InvalidArgument(
"histogram: bin_edges must be monotonically increasing".to_string(),
));
}
}
let flattened = input.view(&[-1])?;
    let values: Vec<f32> = flattened.to_vec()?;
let mut hist = vec![0.0f32; bins];
for &val in &values {
if val >= edges[0] && val <= edges[bins] {
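            // Binary search for the first bin i with val < edges[i + 1]; a
            // value equal to the final edge yields left == bins and is
            // clamped into the last bin below.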
let mut left = 0;
let mut right = bins;
while left < right {
let mid = left + (right - left) / 2;
if val < edges[mid + 1] {
right = mid;
} else {
left = mid + 1;
}
}
if left >= bins {
left = bins - 1;
}
hist[left] += 1.0;
}
}
    if density {
        // Normalize by the in-range total and each bin's own width so the
        // histogram integrates to 1 over the edge range.
        let total_count: f32 = hist.iter().sum();
        if total_count > 0.0 {
            for i in 0..bins {
                let bin_width = edges[i + 1] - edges[i];
                hist[i] /= total_count * bin_width;
            }
        }
    }
Tensor::from_vec(hist, &[bins])
}
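/// Convenience wrapper around `bincount` with no weights and no `minlength`.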
pub fn value_counts(input: &Tensor) -> TorshResult<Tensor> {
bincount(input, None, None)
}
#[cfg(test)]
mod tests {
use super::*;
use torsh_tensor::tensor;
#[test]
fn test_unique_basic() {
let input = tensor![3.0, 1.0, 2.0, 1.0, 3.0, 2.0].unwrap();
let (output, inverse, counts) = unique(&input, true, true, true, None).unwrap();
let unique_data = output.data().expect("tensor should have data");
let unique_vals: Vec<f32> = unique_data.clone();
assert_eq!(unique_vals, vec![1.0, 2.0, 3.0]);
if let Some(counts_tensor) = counts {
let counts_data = counts_tensor.data().expect("tensor should have data");
let counts_vals: Vec<f32> = counts_data.clone();
assert_eq!(counts_vals, vec![2.0, 2.0, 2.0]);
}
if let Some(inv) = inverse {
let inv_data = inv.data().expect("tensor should have data");
let inv_vals: Vec<f32> = inv_data.clone();
let reconstructed: Vec<f32> = inv_vals
.iter()
.map(|&idx| unique_vals[idx as usize])
.collect();
assert_eq!(reconstructed, vec![3.0, 1.0, 2.0, 1.0, 3.0, 2.0]);
}
}
#[test]
fn test_bincount_basic() {
let input = tensor![0.0, 1.0, 1.0, 3.0, 2.0, 1.0, 3.0].unwrap();
let output = bincount(&input, None, None).unwrap();
let data = output.data().expect("tensor should have data");
let counts: Vec<f32> = data.clone();
assert_eq!(counts, vec![1.0, 3.0, 1.0, 2.0]);
}
#[test]
fn test_bincount_weighted() {
let input = tensor![0.0, 1.0, 1.0, 2.0, 2.0, 2.0].unwrap();
let weights = tensor![1.0, 2.0, 3.0, 4.0, 5.0, 6.0].unwrap();
let output = bincount(&input, Some(&weights), None).unwrap();
let data = output.data().expect("tensor should have data");
let weighted_counts: Vec<f32> = data.clone();
assert_eq!(weighted_counts, vec![1.0, 5.0, 15.0]);
}
#[test]
fn test_histogram_basic() {
let input = tensor![1.0, 2.0, 3.0, 4.0, 5.0].unwrap();
let (hist, edges) = histogram(&input, 5, Some(1.0), Some(5.0), false).unwrap();
let hist_data = hist.data().expect("tensor should have data");
let hist_vals: Vec<f32> = hist_data.clone();
assert_eq!(hist_vals, vec![1.0, 1.0, 1.0, 1.0, 1.0]);
let edges_data = edges.data().expect("tensor should have data");
let edges_vals: Vec<f32> = edges_data.clone();
        assert_eq!(edges_vals.len(), 6);
    }
#[test]
fn test_histogram_density() {
let input = tensor![1.0, 2.0, 3.0, 4.0, 5.0].unwrap();
let (hist, _) = histogram(&input, 5, Some(1.0), Some(5.0), true).unwrap();
let hist_data = hist.data().expect("tensor should have data");
let hist_vals: Vec<f32> = hist_data.clone();
        // Bins cover [1, 5], so each of the 5 bins has width 4/5.
        let bin_width = 4.0 / 5.0;
        let integral: f32 = hist_vals.iter().sum::<f32>() * bin_width;
assert!((integral - 1.0).abs() < 1e-6);
}
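    // Sketch tests for paths not covered above; they rely only on APIs
    // already exercised in this module (`Tensor::from_vec`, the `tensor!`
    // macro, and `data()`).
    #[test]
    fn test_unique_dim_rows() {
        // Rows 0 and 1 are identical, so unique over dim 0 keeps two rows.
        let input = Tensor::from_vec(vec![1.0, 2.0, 1.0, 2.0, 3.0, 4.0], &[3, 2]).unwrap();
        let (output, inverse, counts) = unique(&input, true, true, true, Some(0)).unwrap();
        assert_eq!(output.shape().dims(), &[2, 2]);
        if let Some(inv) = inverse {
            let inv_vals: Vec<f32> = inv.data().expect("tensor should have data").clone();
            assert_eq!(inv_vals, vec![0.0, 0.0, 1.0]);
        }
        if let Some(counts_tensor) = counts {
            let counts_vals: Vec<f32> = counts_tensor.data().expect("tensor should have data").clone();
            assert_eq!(counts_vals, vec![2.0, 1.0]);
        }
    }
    #[test]
    fn test_bincount_minlength() {
        let input = tensor![0.0, 1.0].unwrap();
        let output = bincount(&input, None, Some(5)).unwrap();
        let counts: Vec<f32> = output.data().expect("tensor should have data").clone();
        assert_eq!(counts, vec![1.0, 1.0, 0.0, 0.0, 0.0]);
    }
    #[test]
    fn test_histogram_with_edges_basic() {
        // Uneven bins [0, 1), [1, 2), [2, 4]; 3.5 lands in the widest bin.
        let input = tensor![0.5, 1.5, 2.5, 3.5].unwrap();
        let edges = tensor![0.0, 1.0, 2.0, 4.0].unwrap();
        let hist = histogram_with_edges(&input, &edges, false).unwrap();
        let hist_vals: Vec<f32> = hist.data().expect("tensor should have data").clone();
        assert_eq!(hist_vals, vec![1.0, 1.0, 2.0]);
    }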
}