use crate::dtype::Float;
use crate::dtype_dispatch::{is_f32, is_f64};
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::Tensor;
/// Computes, for every element of `values`, its insertion index into the 1-D
/// `boundaries` tensor.
///
/// * `right == false`: the index is the number of leading boundaries `b` for
///   which `b >= v` does not hold.
/// * `right == true`: the index is the number of leading boundaries `b` for
///   which `b > v` does not hold.
///
/// The predicates are deliberately written as negations so that a NaN value
/// fails every comparison and maps to `boundaries.len()`.
///
/// The binary search assumes `boundaries` is sorted ascending; this is not
/// verified. The result has one index per element of `values`.
///
/// When both tensors are on CUDA and `T` is `f32` or `f64`, the search is
/// delegated to the GPU backend and the indices (8 bytes each, little-endian
/// `i64`) are copied back to the host.
///
/// # Errors
///
/// * `InvalidArgument` if `boundaries` is not 1-D, if the GPU returns fewer
///   than 8 bytes per value, or if a GPU index does not fit in `usize`
///   (for example, it is negative).
/// * `DeviceUnavailable` if the GPU path is selected but `gpu_backend()`
///   returns `None`.
/// * `NotImplementedOnCuda` if any operand is on CUDA and the GPU path does
///   not apply (mixed devices or an unsupported dtype).
/// * Any error propagated from tensor or backend access.
pub fn searchsorted<T: Float>(
    boundaries: &Tensor<T>,
    values: &Tensor<T>,
    right: bool,
) -> FerrotorchResult<Vec<usize>> {
    if boundaries.ndim() != 1 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "searchsorted: boundaries must be 1-D, got shape {:?}",
                boundaries.shape()
            ),
        });
    }
    if boundaries.is_cuda() && values.is_cuda() && (is_f32::<T>() || is_f64::<T>()) {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let values = values.contiguous()?;
        let boundaries = boundaries.contiguous()?;
        let idx_handle =
            backend.searchsorted_1d(values.gpu_handle()?, boundaries.gpu_handle()?, right)?;
        let bytes = backend.gpu_to_cpu(&idx_handle)?;
        let n = values.numel();
        // Saturate instead of overflowing: an absurd `n` then simply fails the
        // length check below.
        let expected = n.saturating_mul(8);
        if bytes.len() < expected {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "searchsorted: GPU returned {} bytes, expected >= {} (8 per index)",
                    bytes.len(),
                    expected
                ),
            });
        }
        return bytes
            .chunks_exact(8)
            .take(n)
            .map(|chunk| {
                let mut buf = [0u8; 8];
                buf.copy_from_slice(chunk);
                let raw = i64::from_le_bytes(buf);
                // Checked conversion: a negative index must surface as an
                // error rather than wrap to a huge `usize`. The fully
                // qualified path keeps this independent of the prelude.
                <usize as std::convert::TryFrom<i64>>::try_from(raw).map_err(|_| {
                    FerrotorchError::InvalidArgument {
                        message: format!("searchsorted: GPU returned invalid index {raw}"),
                    }
                })
            })
            .collect();
    }
    if boundaries.is_cuda() || values.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "searchsorted" });
    }
    let bounds = boundaries.data()?;
    let vals = values.data_vec()?;
    #[allow(
        clippy::neg_cmp_op_on_partial_ord,
        reason = "matches pytorch Bucketization.cu:33,51 NaN advance semantics; \
                  `!(b >= v)` differs from `b < v` for NaN val (advances to len)"
    )]
    let result: Vec<usize> = vals
        .iter()
        .map(|v| {
            if right {
                bounds.partition_point(|b| !(*b > *v))
            } else {
                bounds.partition_point(|b| !(*b >= *v))
            }
        })
        .collect();
    Ok(result)
}
/// Returns, for each element of `input`, its bucket index with respect to the
/// 1-D `boundaries` tensor.
///
/// This is a thin wrapper around [`searchsorted`]: it forwards the same
/// operands with the argument order swapped (`boundaries` first, `input` as
/// the values to look up) and passes `right` through unchanged.
///
/// # Errors
///
/// Returns whatever error [`searchsorted`] returns for these operands.
pub fn bucketize<T: Float>(
input: &Tensor<T>,
boundaries: &Tensor<T>,
right: bool,
) -> FerrotorchResult<Vec<usize>> {
searchsorted(boundaries, input, right)
}
/// Returns the distinct elements of `input` in ascending order, together with
/// inverse indices and per-value counts.
///
/// The result is `(values, inverse, counts)` where:
///
/// * `values` is a 1-D tensor holding the distinct elements,
/// * `inverse[i]` is the position in `values` of the `i`-th input element,
/// * `counts[j]` is how many input elements map to `values[j]`.
///
/// The CPU path works on the elements returned by `data_vec()` and ignores
/// the input's shape. Elements are ordered with `nan_is_max_cmp`, so NaNs come
/// last; because distinctness is tested with `!=`, every NaN forms its own
/// entry.
///
/// CUDA inputs of dtype `f32`/`f64` are handled by the GPU backend's
/// `unique_1d`.
///
/// # Errors
///
/// * `DeviceUnavailable` if the input is on CUDA but `gpu_backend()` returns
///   `None`.
/// * `NotImplementedOnCuda` for CUDA inputs of any other dtype.
/// * Any error propagated from tensor or backend access.
pub fn unique<T: Float>(
    input: &Tensor<T>,
) -> FerrotorchResult<(Tensor<T>, Vec<usize>, Vec<usize>)> {
    if input.is_cuda() {
        if !(is_f32::<T>() || is_f64::<T>()) {
            return Err(FerrotorchError::NotImplementedOnCuda { op: "unique" });
        }
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let dense = input.contiguous()?;
        let n = dense.numel();
        let (handle, inverse, counts) = backend.unique_1d(dense.gpu_handle()?, n)?;
        let distinct_len = handle.len();
        let values = Tensor::from_storage(TensorStorage::gpu(handle), vec![distinct_len], false)?;
        return Ok((values, inverse, counts));
    }

    let data = input.data_vec()?;

    // Visit the elements in ascending order. The sort is stable, so elements
    // that compare equal keep their original relative order.
    let mut order: Vec<usize> = (0..data.len()).collect();
    order.sort_by(|&a, &b| nan_is_max_cmp(data[a], data[b]));

    let mut distinct: Vec<T> = Vec::new();
    let mut counts: Vec<usize> = Vec::new();
    let mut inverse = vec![0usize; data.len()];
    for &src in &order {
        let value = data[src];
        // A new group starts at the very first element and whenever the value
        // differs from the most recently recorded one.
        let starts_group = distinct.last().map_or(true, |prev| value != *prev);
        if starts_group {
            distinct.push(value);
            counts.push(0);
        }
        let slot = distinct.len() - 1;
        inverse[src] = slot;
        counts[slot] += 1;
    }

    let distinct_len = distinct.len();
    let values = Tensor::from_storage(TensorStorage::cpu(distinct), vec![distinct_len], false)?;
    Ok((values, inverse, counts))
}
/// Collapses runs of consecutive equal elements of `input`.
///
/// The result is `(values, inverse, counts)` where:
///
/// * `values` is a 1-D tensor with one entry per run, in input order,
/// * `inverse[i]` is the index of the run that input element `i` belongs to,
/// * `counts[j]` is the length of run `j`.
///
/// Unlike [`unique`], no sorting is performed, so a value that reappears later
/// starts a new run. Runs are detected with `==`, so a NaN never extends a run.
///
/// CUDA inputs of dtype `f32`/`f64` are handled by the GPU backend's
/// `unique_consecutive_1d`.
///
/// # Errors
///
/// * `DeviceUnavailable` if the input is on CUDA but `gpu_backend()` returns
///   `None`.
/// * `NotImplementedOnCuda` for CUDA inputs of any other dtype.
/// * Any error propagated from tensor or backend access.
pub fn unique_consecutive<T: Float>(
    input: &Tensor<T>,
) -> FerrotorchResult<(Tensor<T>, Vec<usize>, Vec<usize>)> {
    if input.is_cuda() {
        if !(is_f32::<T>() || is_f64::<T>()) {
            return Err(FerrotorchError::NotImplementedOnCuda {
                op: "unique_consecutive",
            });
        }
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let dense = input.contiguous()?;
        let n = dense.numel();
        let (handle, inverse, counts) = backend.unique_consecutive_1d(dense.gpu_handle()?, n)?;
        let run_count = handle.len();
        let values = Tensor::from_storage(TensorStorage::gpu(handle), vec![run_count], false)?;
        return Ok((values, inverse, counts));
    }

    let data = input.data_vec()?;
    let Some(&first) = data.first() else {
        // Empty input: empty values tensor of shape [0], no indices, no counts.
        return Ok((
            Tensor::from_storage(TensorStorage::cpu(vec![]), vec![0], false)?,
            vec![],
            vec![],
        ));
    };

    let mut runs: Vec<T> = vec![first];
    let mut counts: Vec<usize> = vec![1];
    let mut inverse: Vec<usize> = Vec::with_capacity(data.len());
    inverse.push(0);
    // Compare each element with its predecessor; `pair[1]` is the current one.
    for pair in data.windows(2) {
        if pair[1] == pair[0] {
            if let Some(current) = counts.last_mut() {
                *current += 1;
            }
        } else {
            runs.push(pair[1]);
            counts.push(1);
        }
        inverse.push(runs.len() - 1);
    }

    let run_count = runs.len();
    let values = Tensor::from_storage(TensorStorage::cpu(runs), vec![run_count], false)?;
    Ok((values, inverse, counts))
}
/// Computes a histogram of `input` with `bins` equal-width bins over
/// `[min_val, max_val]`, returning the counts as a 1-D tensor of length `bins`.
///
/// Range selection:
///
/// * If `min_val == max_val`, the range is inferred from the data (minimum and
///   maximum of the elements convertible to `f64`; NaNs never win a
///   comparison and are therefore ignored). If either inferred bound is not
///   finite, the supplied `min_val`/`max_val` are used instead. If the
///   resulting bounds are still equal, or the input is empty, the range is
///   widened by 1 on each side.
/// * Otherwise `min_val` and `max_val` are used as given.
///
/// Elements outside the closed range (including NaN) are skipped. An element
/// equal to the upper bound is counted in the last bin.
///
/// CUDA inputs of dtype `f32`/`f64` are binned by the GPU backend's
/// `histc_1d`, using the range resolved as described above.
///
/// # Errors
///
/// * `InvalidArgument` if `bins == 0` or `min_val > max_val`.
/// * `DeviceUnavailable` if the input is on CUDA but `gpu_backend()` returns
///   `None`.
/// * `NotImplementedOnCuda` for CUDA inputs of any other dtype.
/// * Any error propagated from tensor or backend access.
#[allow(
    clippy::float_cmp,
    reason = "exact `==` mirrors upstream's `min == max` / `minvalue == maxvalue` \
              degenerate-range checks (aten/src/ATen/native/cuda/SummaryOps.cu:328,333); \
              the bit-exact comparison IS the upstream contract for range inference"
)]
pub fn histc<T: Float>(
    input: &Tensor<T>,
    bins: usize,
    min_val: f64,
    max_val: f64,
) -> FerrotorchResult<Tensor<T>> {
    if bins == 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: "histc: bins must be > 0".into(),
        });
    }
    if min_val > max_val {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("histc: min ({min_val}) must be <= max ({max_val})"),
        });
    }

    // Resolve the effective histogram range.
    let (lo, hi) = if min_val == max_val {
        if input.numel() > 0 {
            let samples = input.data_vec()?;
            // Running min/max using strict comparisons so NaN is never picked.
            let (mut lo, mut hi) = samples
                .iter()
                .filter_map(|v| num_traits::ToPrimitive::to_f64(v))
                .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), f| {
                    (if f < lo { f } else { lo }, if f > hi { f } else { hi })
                });
            if !(lo.is_finite() && hi.is_finite()) {
                lo = min_val;
                hi = max_val;
            }
            if lo == hi {
                (lo - 1.0, hi + 1.0)
            } else {
                (lo, hi)
            }
        } else {
            (min_val - 1.0, max_val + 1.0)
        }
    } else {
        (min_val, max_val)
    };

    if input.is_cuda() {
        if !(is_f32::<T>() || is_f64::<T>()) {
            return Err(FerrotorchError::NotImplementedOnCuda { op: "histc" });
        }
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let dense = input.contiguous()?;
        let counts_handle = backend.histc_1d(dense.gpu_handle()?, bins, lo, hi)?;
        return Tensor::from_storage(TensorStorage::gpu(counts_handle), vec![bins], false);
    }

    let samples = input.data_vec()?;
    let mut counts = vec![<T as num_traits::Zero>::zero(); bins];
    let bin_width = (hi - lo) / bins as f64;
    let last_bin = bins - 1;
    let in_range = samples
        .iter()
        .filter_map(|v| num_traits::ToPrimitive::to_f64(v))
        .filter(|f| *f >= lo && *f <= hi);
    for f in in_range {
        // Clamp so that a sample equal to `hi` falls into the final bin.
        let bin = (((f - lo) / bin_width) as usize).min(last_bin);
        counts[bin] += <T as num_traits::One>::one();
    }
    Tensor::from_storage(TensorStorage::cpu(counts), vec![bins], false)
}
/// Indexing convention accepted by `meshgrid_indexing`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MeshIndexing {
/// Output dimensions follow the order of the input tensors. This is the
/// convention `meshgrid` uses.
Ij,
/// Same as `Ij` except that, when at least two tensors are given, the first
/// two inputs are swapped before the grids are built and the first two
/// resulting grids are swapped back afterwards.
Xy,
}
/// Builds coordinate grids from 1-D tensors using `MeshIndexing::Ij`.
///
/// Equivalent to `meshgrid_indexing(tensors, MeshIndexing::Ij)`; see that
/// function for the validation performed and the errors returned.
pub fn meshgrid<T: Float>(tensors: &[Tensor<T>]) -> FerrotorchResult<Vec<Tensor<T>>> {
meshgrid_indexing(tensors, MeshIndexing::Ij)
}
/// Builds one coordinate grid per input tensor.
///
/// Every input must be 1-D. With `MeshIndexing::Ij`, each output has shape
/// `[len(t0), len(t1), ...]`, and grid `d` holds, at every position, the
/// element of input `d` selected by that position's coordinate along
/// dimension `d`.
///
/// With `MeshIndexing::Xy` and at least two inputs, the first two inputs are
/// swapped, the `Ij` grids are built, and the first two grids are swapped
/// back; the first two output dimensions are therefore exchanged relative to
/// `Ij`. With fewer than two inputs `Xy` behaves like `Ij`.
///
/// An empty `tensors` slice yields an empty vector.
///
/// When all inputs are on CUDA and `T` is `f32` or `f64`, each grid is
/// produced by the GPU backend's `meshgrid_grid`.
///
/// # Errors
///
/// * `InvalidArgument` if any input is not 1-D or if the inputs are on
///   different devices.
/// * `DeviceUnavailable` if the GPU path is selected but `gpu_backend()`
///   returns `None`.
/// * `NotImplementedOnCuda` if all inputs are on CUDA with an unsupported
///   dtype.
/// * Any error propagated from tensor or backend access.
pub fn meshgrid_indexing<T: Float>(
    tensors: &[Tensor<T>],
    indexing: MeshIndexing,
) -> FerrotorchResult<Vec<Tensor<T>>> {
    if tensors.is_empty() {
        return Ok(vec![]);
    }
    if indexing == MeshIndexing::Xy && tensors.len() >= 2 {
        // Xy is Ij with the first two inputs (and outputs) exchanged.
        let mut swapped: Vec<Tensor<T>> = Vec::with_capacity(tensors.len());
        swapped.push(tensors[1].clone());
        swapped.push(tensors[0].clone());
        swapped.extend(tensors[2..].iter().cloned());
        let mut grids = meshgrid_indexing(&swapped, MeshIndexing::Ij)?;
        grids.swap(0, 1);
        return Ok(grids);
    }
    let all_cuda = tensors.iter().all(|t| t.is_cuda());
    for t in tensors {
        if t.ndim() != 1 {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "meshgrid: all inputs must be 1-D, got shape {:?}",
                    t.shape()
                ),
            });
        }
        // `all_cuda` is false as soon as one input is on the host, so any
        // CUDA input in a mixed set trips this check.
        if t.is_cuda() != all_cuda {
            return Err(FerrotorchError::InvalidArgument {
                message: "meshgrid: all inputs must be on the same device".into(),
            });
        }
    }
    let shapes: Vec<usize> = tensors.iter().map(|t| t.shape()[0]).collect();
    let ndim = shapes.len();
    let total: usize = shapes.iter().product();
    if all_cuda && (is_f32::<T>() || is_f64::<T>()) {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let mut result = Vec::with_capacity(ndim);
        for (dim, t) in tensors.iter().enumerate() {
            let inner: usize = shapes[dim + 1..].iter().product();
            let t = t.contiguous()?;
            let grid_handle = backend.meshgrid_grid(t.gpu_handle()?, total, inner, shapes[dim])?;
            result.push(Tensor::from_storage(
                TensorStorage::gpu(grid_handle),
                shapes.clone(),
                false,
            )?);
        }
        return Ok(result);
    }
    if all_cuda {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "meshgrid" });
    }
    let mut result = Vec::with_capacity(ndim);
    for (dim, t) in tensors.iter().enumerate() {
        let data = t.data()?;
        let len = shapes[dim];
        // Number of output elements spanned by one step along dimension `dim`.
        let inner: usize = shapes[dim + 1..].iter().product();
        // `total > 0` implies `inner > 0` and `len > 0`, so the division and
        // modulo below are only evaluated with non-zero operands.
        let grid: Vec<T> = (0..total)
            .map(|flat| data[(flat / inner) % len])
            .collect();
        result.push(Tensor::from_storage(
            TensorStorage::cpu(grid),
            shapes.clone(),
            false,
        )?);
    }
    Ok(result)
}
/// Total ordering for floats in which NaN compares greater than every other
/// value and equal to another NaN.
///
/// Non-NaN operands are ordered by `partial_cmp`; `Ordering::Equal` is the
/// fallback should `partial_cmp` return `None`.
fn nan_is_max_cmp<T: Float>(lhs: T, rhs: T) -> std::cmp::Ordering {
    use std::cmp::Ordering;

    if lhs.is_nan() {
        return if rhs.is_nan() {
            Ordering::Equal
        } else {
            Ordering::Greater
        };
    }
    if rhs.is_nan() {
        return Ordering::Less;
    }
    lhs.partial_cmp(&rhs).unwrap_or(Ordering::Equal)
}
/// Selects the `k` largest (or smallest) elements along the last dimension of
/// `input`.
///
/// Returns `(values, indices)`. `values` has the input's shape with the last
/// dimension replaced by `k`; `indices` is the flattened, row-major list of
/// the selected positions within each row of the last dimension.
///
/// Elements are ranked with `nan_is_max_cmp`, so NaN counts as the largest
/// value: NaNs come first when `largest` is `true` and last otherwise. The
/// CPU sort is stable, so ties keep their original order.
///
/// A last dimension of size 0 (which forces `k == 0`) yields empty results on
/// the CPU path instead of dividing by zero.
///
/// CUDA inputs of dtype `f32`/`f64` are handled by the GPU backend's
/// `topk_1d`; the indices (8 bytes each, little-endian `i64`) are copied back
/// to the host.
///
/// # Errors
///
/// * `InvalidArgument` if `input` has no dimensions, if `k` exceeds the last
///   dimension, if the GPU returns fewer than 8 bytes per index, or if a GPU
///   index does not fit in `usize` (for example, it is negative).
/// * `DeviceUnavailable` if the input is on CUDA but `gpu_backend()` returns
///   `None`.
/// * `NotImplementedOnCuda` for CUDA inputs of any other dtype.
/// * Any error propagated from tensor or backend access.
pub fn topk<T: Float>(
    input: &Tensor<T>,
    k: usize,
    largest: bool,
) -> FerrotorchResult<(Tensor<T>, Vec<usize>)> {
    let shape = input.shape();
    let Some(&last_dim) = shape.last() else {
        return Err(FerrotorchError::InvalidArgument {
            message: "topk: input must have at least 1 dimension".into(),
        });
    };
    if k > last_dim {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("topk: k ({k}) > last dimension size ({last_dim})"),
        });
    }
    // Output shape: the input shape with the last dimension replaced by `k`.
    let mut out_shape = shape.to_vec();
    if let Some(last) = out_shape.last_mut() {
        *last = k;
    }

    if input.is_cuda() && (is_f32::<T>() || is_f64::<T>()) {
        let backend =
            crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
        let input = input.contiguous()?;
        // `checked_div` avoids a divide-by-zero panic for an empty last dim.
        let outer = input.numel().checked_div(last_dim).unwrap_or(0);
        let (val_handle, idx_handle) =
            backend.topk_1d(input.gpu_handle()?, outer, last_dim, k, largest)?;
        let bytes = backend.gpu_to_cpu(&idx_handle)?;
        let n = outer * k;
        let expected = n.saturating_mul(8);
        if bytes.len() < expected {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "topk: GPU returned {} bytes for indices, expected >= {} (8 per index)",
                    bytes.len(),
                    expected
                ),
            });
        }
        let out_indices = bytes
            .chunks_exact(8)
            .take(n)
            .map(|chunk| {
                let mut buf = [0u8; 8];
                buf.copy_from_slice(chunk);
                let raw = i64::from_le_bytes(buf);
                // Checked conversion: a negative index must surface as an
                // error rather than wrap to a huge `usize`.
                <usize as std::convert::TryFrom<i64>>::try_from(raw).map_err(|_| {
                    FerrotorchError::InvalidArgument {
                        message: format!("topk: GPU returned invalid index {raw}"),
                    }
                })
            })
            .collect::<FerrotorchResult<Vec<usize>>>()?;
        let values = Tensor::from_storage(
            crate::storage::TensorStorage::gpu(val_handle),
            out_shape,
            false,
        )?;
        return Ok((values, out_indices));
    }
    if input.is_cuda() {
        return Err(FerrotorchError::NotImplementedOnCuda { op: "topk" });
    }

    let data = input.data_vec()?;
    // Zero rows when the last dimension is empty (no data to slice).
    let outer = data.len().checked_div(last_dim).unwrap_or(0);
    let mut out_values = Vec::with_capacity(outer * k);
    let mut out_indices = Vec::with_capacity(outer * k);
    for o in 0..outer {
        let row = &data[o * last_dim..(o + 1) * last_dim];
        let mut order: Vec<usize> = (0..last_dim).collect();
        // Stable sort; swapping the operands gives descending order while
        // keeping ties in their original order.
        if largest {
            order.sort_by(|&a, &b| nan_is_max_cmp(row[b], row[a]));
        } else {
            order.sort_by(|&a, &b| nan_is_max_cmp(row[a], row[b]));
        }
        for &i in &order[..k] {
            out_values.push(row[i]);
            out_indices.push(i);
        }
    }
    let values = Tensor::from_storage(TensorStorage::cpu(out_values), out_shape, false)?;
    Ok((values, out_indices))
}
// Unit tests for the CPU paths of the search / unique / histogram / meshgrid /
// topk operations defined in this file.
#[cfg(test)]
#[allow(
clippy::excessive_precision,
clippy::float_cmp,
reason = "oracle expected values from live torch 2.11; full precision intentional (rounds to dtype at compile time); float comparisons are deliberately exact byte-for-byte parity checks"
)]
mod tests {
use super::*;
// Builds a 1-D CPU `f32` tensor from a slice.
fn tensor_1d(data: &[f32]) -> Tensor<f32> {
Tensor::from_storage(TensorStorage::cpu(data.to_vec()), vec![data.len()], false).unwrap()
}
// right = true: a value equal to a boundary (3.0) is placed after it.
#[test]
fn test_searchsorted_right() {
let bounds = tensor_1d(&[1.0, 3.0, 5.0, 7.0]);
let values = tensor_1d(&[0.0, 2.0, 3.0, 6.0, 8.0]);
let result = searchsorted(&bounds, &values, true).unwrap();
assert_eq!(result, vec![0, 1, 2, 3, 4]);
}
// right = false: a value equal to a boundary maps to that boundary's index.
#[test]
fn test_searchsorted_left() {
let bounds = tensor_1d(&[1.0, 3.0, 5.0, 7.0]);
let values = tensor_1d(&[1.0, 3.0, 5.0, 7.0]);
let result = searchsorted(&bounds, &values, false).unwrap();
assert_eq!(result, vec![0, 1, 2, 3]);
}
// With no boundaries every value maps to index 0.
#[test]
fn test_searchsorted_empty_bounds() {
let bounds = tensor_1d(&[]);
let values = tensor_1d(&[1.0, 2.0]);
let result = searchsorted(&bounds, &values, true).unwrap();
assert_eq!(result, vec![0, 0]);
}
// Values below, between and above the boundaries get indices 0..=len.
#[test]
fn test_bucketize() {
let bounds = tensor_1d(&[0.0, 1.0, 2.0, 3.0]);
let input = tensor_1d(&[-0.5, 0.5, 1.5, 2.5, 3.5]);
let result = bucketize(&input, &bounds, false).unwrap();
assert_eq!(result, vec![0, 1, 2, 3, 4]);
}
// Output is sorted, counts match, and `inverse` reconstructs the input.
#[test]
#[allow(clippy::float_cmp)]
fn test_unique_sorted() {
let input = tensor_1d(&[3.0, 1.0, 2.0, 1.0, 3.0, 2.0]);
let (unique, inverse, counts) = unique(&input).unwrap();
let unique_data = unique.data().unwrap();
assert_eq!(unique_data, &[1.0, 2.0, 3.0]);
assert_eq!(counts, vec![2, 2, 2]);
let input_data = input.data().unwrap();
for i in 0..6 {
assert_eq!(unique_data[inverse[i]], input_data[i]);
}
}
// Empty input produces empty values, inverse and counts.
#[test]
fn test_unique_empty() {
let input = tensor_1d(&[]);
let (unique, inverse, counts) = unique(&input).unwrap();
assert_eq!(unique.numel(), 0);
assert!(inverse.is_empty());
assert!(counts.is_empty());
}
// A constant input collapses to a single value with the full count.
#[test]
fn test_unique_all_same() {
let input = tensor_1d(&[5.0, 5.0, 5.0]);
let (unique, _inverse, counts) = unique(&input).unwrap();
assert_eq!(unique.data().unwrap(), &[5.0]);
assert_eq!(counts, vec![3]);
}
// Runs are collapsed in place; the trailing 1.0s form a separate run.
#[test]
fn test_unique_consecutive_basic() {
let input = tensor_1d(&[1.0, 1.0, 2.0, 2.0, 2.0, 3.0, 1.0, 1.0]);
let (output, inverse, counts) = unique_consecutive(&input).unwrap();
let out_data = output.data().unwrap();
assert_eq!(out_data, &[1.0, 2.0, 3.0, 1.0]);
assert_eq!(counts, vec![2, 3, 1, 2]);
assert_eq!(inverse, vec![0, 0, 1, 1, 1, 2, 3, 3]);
}
// Without adjacent duplicates the input is returned unchanged.
#[test]
fn test_unique_consecutive_no_duplicates() {
let input = tensor_1d(&[1.0, 2.0, 3.0]);
let (output, _inverse, counts) = unique_consecutive(&input).unwrap();
assert_eq!(output.data().unwrap(), &[1.0, 2.0, 3.0]);
assert_eq!(counts, vec![1, 1, 1]);
}
// Empty input produces empty values, inverse and counts.
#[test]
fn test_unique_consecutive_empty() {
let input = tensor_1d(&[]);
let (output, inverse, counts) = unique_consecutive(&input).unwrap();
assert_eq!(output.numel(), 0);
assert!(inverse.is_empty());
assert!(counts.is_empty());
}
// Four unit-width bins over [0, 4]; 1.5 appears twice.
#[test]
fn test_histc_basic() {
let input = tensor_1d(&[0.5, 1.5, 2.5, 3.5, 1.5]);
let hist = histc(&input, 4, 0.0, 4.0).unwrap();
let data = hist.data().unwrap();
assert_eq!(data, &[1.0, 2.0, 1.0, 1.0]);
}
// Samples below `min` or above `max` are not counted.
#[test]
fn test_histc_skips_out_of_range() {
let input = tensor_1d(&[-1.0, 5.0, 0.5]);
let hist = histc(&input, 2, 0.0, 2.0).unwrap();
let data = hist.data().unwrap();
assert_eq!(data, &[1.0, 0.0]);
}
// NaN samples are not counted.
#[test]
fn test_histc_skips_nan() {
let input = tensor_1d(&[0.5, f32::NAN, 1.5]);
let hist = histc(&input, 2, 0.0, 2.0).unwrap();
assert_eq!(hist.data().unwrap(), &[1.0, 1.0]);
}
// min == max: the range is inferred as [1, 5]; the maximum lands in the last bin.
#[test]
fn test_histc_default_minmax_infers_range() {
let input = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0]);
let hist = histc(&input, 4, 0.0, 0.0).unwrap();
assert_eq!(hist.data().unwrap(), &[1.0, 1.0, 1.0, 2.0]);
}
// min == max and constant data: the inferred range is widened to [2, 4].
#[test]
fn test_histc_default_minmax_all_equal_widens() {
let input = tensor_1d(&[3.0, 3.0, 3.0]);
let hist = histc(&input, 4, 0.0, 0.0).unwrap();
assert_eq!(hist.data().unwrap(), &[0.0, 0.0, 3.0, 0.0]);
}
// Xy indexing: output shape is [len(y), len(x)]; x varies along the last dim.
#[test]
fn test_meshgrid_xy() {
let x = tensor_1d(&[1.0, 2.0, 3.0]);
let y = tensor_1d(&[4.0, 5.0]);
let grids = meshgrid_indexing(&[x, y], MeshIndexing::Xy).unwrap();
assert_eq!(grids.len(), 2);
assert_eq!(grids[0].shape(), &[2, 3]);
assert_eq!(grids[1].shape(), &[2, 3]);
assert_eq!(grids[0].data().unwrap(), &[1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
assert_eq!(grids[1].data().unwrap(), &[4.0, 4.0, 4.0, 5.0, 5.0, 5.0]);
}
// `meshgrid` and `meshgrid_indexing(.., Ij)` produce identical grids.
#[test]
fn test_meshgrid_ij_default_unchanged() {
let x = tensor_1d(&[1.0, 2.0, 3.0]);
let y = tensor_1d(&[4.0, 5.0]);
let a = meshgrid(&[x.clone(), y.clone()]).unwrap();
let b = meshgrid_indexing(&[x, y], MeshIndexing::Ij).unwrap();
assert_eq!(a[0].data().unwrap(), b[0].data().unwrap());
assert_eq!(a[1].data().unwrap(), b[1].data().unwrap());
assert_eq!(a[0].shape(), &[3, 2]);
}
// Ij indexing: output shape is [len(x), len(y)]; x varies along the first dim.
#[test]
fn test_meshgrid_2d() {
let x = tensor_1d(&[1.0, 2.0, 3.0]);
let y = tensor_1d(&[4.0, 5.0]);
let grids = meshgrid(&[x, y]).unwrap();
assert_eq!(grids.len(), 2);
assert_eq!(grids[0].shape(), &[3, 2]);
assert_eq!(grids[1].shape(), &[3, 2]);
let gx = grids[0].data().unwrap();
assert_eq!(gx, &[1.0, 1.0, 2.0, 2.0, 3.0, 3.0]);
let gy = grids[1].data().unwrap();
assert_eq!(gy, &[4.0, 5.0, 4.0, 5.0, 4.0, 5.0]);
}
// The three largest values are returned in descending order with their indices.
#[test]
fn test_topk_largest() {
let input = tensor_1d(&[3.0, 1.0, 4.0, 1.0, 5.0, 9.0]);
let (values, indices) = topk(&input, 3, true).unwrap();
let vdata = values.data().unwrap();
assert_eq!(vdata, &[9.0, 5.0, 4.0]);
assert_eq!(indices, vec![5, 4, 2]);
}
// Tied smallest values keep their original order (indices 1 then 3).
#[test]
fn test_topk_smallest() {
let input = tensor_1d(&[3.0, 1.0, 4.0, 1.0, 5.0]);
let (values, indices) = topk(&input, 2, false).unwrap();
let vdata = values.data().unwrap();
assert_eq!(vdata, &[1.0, 1.0]);
assert_eq!(indices, vec![1, 3]);
}
// k larger than the last dimension is rejected.
#[test]
fn test_topk_k_exceeds_dim() {
let input = tensor_1d(&[1.0, 2.0]);
let result = topk(&input, 5, true);
assert!(result.is_err());
}
// largest = true: NaN ranks above every number, so both NaNs come first.
#[test]
fn test_topk_largest_nan_is_top() {
let input = tensor_1d(&[3.0, f32::NAN, 1.0, 5.0, f32::NAN, 2.0]);
let (values, indices) = topk(&input, 4, true).unwrap();
let vdata = values.data().unwrap();
assert!(
vdata[0].is_nan() && vdata[1].is_nan(),
"NaNs first: {vdata:?}"
);
assert_eq!(vdata[2], 5.0);
assert_eq!(vdata[3], 3.0);
assert_eq!(indices, vec![1, 4, 3, 0]);
}
// largest = false: NaNs rank last, so they only appear once k exceeds the
// number of non-NaN elements.
#[test]
fn test_topk_smallest_nan_is_last() {
let input = tensor_1d(&[3.0, f32::NAN, 1.0, 5.0, f32::NAN, 2.0]);
let (v4, i4) = topk(&input, 4, false).unwrap();
let v4d = v4.data().unwrap();
assert!(v4d.iter().all(|v| !v.is_nan()), "no NaN at k=4: {v4d:?}");
assert_eq!(v4d, &[1.0, 2.0, 3.0, 5.0]);
assert_eq!(i4, vec![2, 5, 0, 3]);
let (v6, i6) = topk(&input, 6, false).unwrap();
let v6d = v6.data().unwrap();
assert_eq!(&v6d[..4], &[1.0, 2.0, 3.0, 5.0]);
assert!(v6d[4].is_nan() && v6d[5].is_nan(), "NaNs last: {v6d:?}");
assert_eq!(i6, vec![2, 5, 0, 3, 1, 4]);
}
}