use torsh_core::Result as TorshResult;
use torsh_tensor::{creation::zeros, Tensor};
/// Converts flat (linear) indices into per-dimension coordinate tensors,
/// the inverse of [`ravel_multi_index`] (row-major / C order).
///
/// # Arguments
/// * `indices` - 1-dimensional tensor of flat indices (held as f32 values).
/// * `shape` - target shape to unravel into; must be non-empty.
///
/// # Returns
/// One coordinate tensor per dimension of `shape`, each the same length as
/// `indices`.
///
/// # Errors
/// Returns an invalid-argument error if `indices` is not 1-dimensional,
/// `shape` is empty, or any index is non-finite, negative, or not less than
/// the product of `shape`.
pub fn unravel_index(indices: &Tensor, shape: &[usize]) -> TorshResult<Vec<Tensor>> {
    let indices_shape = indices.shape();
    if indices_shape.ndim() != 1 {
        return Err(torsh_core::TorshError::invalid_argument_with_context(
            "Indices must be 1-dimensional",
            "unravel_index",
        ));
    }
    if shape.is_empty() {
        return Err(torsh_core::TorshError::invalid_argument_with_context(
            "Shape cannot be empty",
            "unravel_index",
        ));
    }
    let total_size: usize = shape.iter().product();
    let indices_data = indices.to_vec()?;
    for &idx in &indices_data {
        // Reject NaN/inf explicitly: `NaN < 0.0` is false and `NaN as usize`
        // saturates to 0, so a NaN would otherwise slip through the bounds
        // check and silently map to coordinate 0.
        if !idx.is_finite() || idx < 0.0 || idx as usize >= total_size {
            return Err(torsh_core::TorshError::invalid_argument_with_context(
                &format!("Index {} out of bounds for total size {}", idx, total_size),
                "unravel_index",
            ));
        }
    }
    // Row-major strides: strides[i] = product of shape[i+1..].
    let mut strides = vec![1; shape.len()];
    for i in (0..shape.len().saturating_sub(1)).rev() {
        strides[i] = strides[i + 1] * shape[i + 1];
    }
    let num_indices = indices_shape.dims()[0];
    let mut results = Vec::with_capacity(shape.len());
    for _ in 0..shape.len() {
        results.push(zeros(&[num_indices])?);
    }
    // Reuse the data already fetched via `to_vec` above instead of issuing a
    // per-element `get` call back into the tensor.
    for (i, &idx) in indices_data.iter().enumerate() {
        let mut remaining = idx as usize;
        for (j, &stride) in strides.iter().enumerate() {
            let coord = remaining / stride;
            results[j].set(&[i], coord as f32)?;
            remaining %= stride;
        }
    }
    Ok(results)
}
/// Computes row-major (C-order) strides for the given shape.
///
/// The last dimension has stride 1, and each earlier dimension's stride is
/// the product of all dimension sizes after it. An empty shape yields an
/// empty stride vector.
pub fn compute_strides(shape: &[usize]) -> Vec<usize> {
    // Walk the shape back-to-front, emitting the running product *before*
    // folding in the current dimension, then flip into front-to-back order.
    let mut strides: Vec<usize> = shape
        .iter()
        .rev()
        .scan(1usize, |running, &dim| {
            let stride = *running;
            *running *= dim;
            Some(stride)
        })
        .collect();
    strides.reverse();
    strides
}
/// Converts per-dimension coordinate tensors into flat (row-major) indices,
/// the inverse of [`unravel_index`].
///
/// # Arguments
/// * `coords` - one 1-dimensional coordinate tensor per dimension of `shape`;
///   all must have the same length.
/// * `shape` - shape the coordinates refer to.
///
/// # Returns
/// A 1-dimensional tensor of flat indices, the same length as each
/// coordinate tensor.
///
/// # Errors
/// Returns an invalid-argument error if `coords` is empty, its length does
/// not match `shape`, the coordinate tensors are not 1-dimensional with
/// matching shapes, or any coordinate is non-finite, negative, or out of
/// bounds for its dimension.
pub fn ravel_multi_index(coords: &[Tensor], shape: &[usize]) -> TorshResult<Tensor> {
    if coords.len() != shape.len() {
        return Err(torsh_core::TorshError::invalid_argument_with_context(
            "Number of coordinate tensors must match number of dimensions",
            "ravel_multi_index",
        ));
    }
    if coords.is_empty() {
        return Err(torsh_core::TorshError::invalid_argument_with_context(
            "Coordinate tensors cannot be empty",
            "ravel_multi_index",
        ));
    }
    let coord_shape = coords[0].shape();
    for (i, coord) in coords.iter().enumerate() {
        if coord.shape().dims() != coord_shape.dims() {
            return Err(torsh_core::TorshError::invalid_argument_with_context(
                &format!(
                    "All coordinate tensors must have the same shape, but coordinate {} differs",
                    i
                ),
                "ravel_multi_index",
            ));
        }
        if coord.shape().ndim() != 1 {
            return Err(torsh_core::TorshError::invalid_argument_with_context(
                "All coordinate tensors must be 1-dimensional",
                "ravel_multi_index",
            ));
        }
    }
    let num_indices = coord_shape.dims()[0];
    let strides = compute_strides(shape);
    let result = zeros(&[num_indices])?;
    for i in 0..num_indices {
        let mut flat_idx = 0usize;
        for (dim, coord_tensor) in coords.iter().enumerate() {
            let raw = coord_tensor.get(&[i])?;
            // Validate the raw f32 BEFORE casting: a negative value (or NaN)
            // cast with `as usize` saturates to 0 and would silently pass the
            // bounds check, yielding a wrong flat index instead of an error.
            if !raw.is_finite() || raw < 0.0 || raw as usize >= shape[dim] {
                return Err(torsh_core::TorshError::invalid_argument_with_context(
                    &format!(
                        "Coordinate {} out of bounds for dimension {} with size {}",
                        raw, dim, shape[dim]
                    ),
                    "ravel_multi_index",
                ));
            }
            flat_idx += (raw as usize) * strides[dim];
        }
        result.set(&[i], flat_idx as f32)?;
    }
    Ok(result)
}
// Unit tests for flat-index <-> multi-dimensional coordinate conversions.
#[cfg(test)]
mod tests {
use super::*;
use torsh_core::DeviceType;
// Flat indices 0..3 in a 2x2 grid map to (row, col) pairs
// (0,0), (0,1), (1,0), (1,1) in row-major order.
#[test]
fn test_unravel_index_2d() -> TorshResult<()> {
let indices = Tensor::from_data(vec![0.0f32, 1.0, 2.0, 3.0], vec![4], DeviceType::Cpu)?;
let shape = vec![2, 2];
let result = unravel_index(&indices, &shape)?;
// One coordinate tensor per dimension, each as long as `indices`.
assert_eq!(result.len(), 2); assert_eq!(result[0].shape().dims(), &[4]); assert_eq!(result[1].shape().dims(), &[4]);
assert_eq!(result[0].get(&[0])?, 0.0);
assert_eq!(result[1].get(&[0])?, 0.0);
assert_eq!(result[0].get(&[1])?, 0.0);
assert_eq!(result[1].get(&[1])?, 1.0);
assert_eq!(result[0].get(&[2])?, 1.0);
assert_eq!(result[1].get(&[2])?, 0.0);
assert_eq!(result[0].get(&[3])?, 1.0);
assert_eq!(result[1].get(&[3])?, 1.0);
Ok(())
}
// In shape [3, 4, 2] the row-major strides are [8, 2, 1],
// so flat index 7 = 0*8 + 3*2 + 1*1 -> coordinates (0, 3, 1).
#[test]
fn test_unravel_index_3d() -> TorshResult<()> {
let indices = Tensor::from_data(vec![0.0f32, 7.0, 15.0, 23.0], vec![4], DeviceType::Cpu)?;
let shape = vec![3, 4, 2]; let result = unravel_index(&indices, &shape)?;
assert_eq!(result.len(), 3);
assert_eq!(result[0].get(&[1])?, 0.0); assert_eq!(result[1].get(&[1])?, 3.0); assert_eq!(result[2].get(&[1])?, 1.0);
Ok(())
}
// Row-major strides: innermost dimension has stride 1, and an empty
// shape yields an empty stride vector.
#[test]
fn test_compute_strides() {
let shape = vec![3, 4, 2];
let strides = compute_strides(&shape);
assert_eq!(strides, vec![8, 2, 1]);
let shape = vec![5];
let strides = compute_strides(&shape);
assert_eq!(strides, vec![1]);
let empty_shape: Vec<usize> = vec![];
let strides = compute_strides(&empty_shape);
assert_eq!(strides, Vec::<usize>::new());
}
// (row, col) pairs over a 2x2 grid flatten to 0..3 in row-major order.
#[test]
fn test_ravel_multi_index() -> TorshResult<()> {
let row_coords = Tensor::from_data(vec![0.0, 0.0, 1.0, 1.0], vec![4], DeviceType::Cpu)?;
let col_coords = Tensor::from_data(vec![0.0, 1.0, 0.0, 1.0], vec![4], DeviceType::Cpu)?;
let coords = vec![row_coords, col_coords];
let shape = vec![2, 2];
let flat_indices = ravel_multi_index(&coords, &shape)?;
assert_eq!(flat_indices.get(&[0])?, 0.0);
assert_eq!(flat_indices.get(&[1])?, 1.0);
assert_eq!(flat_indices.get(&[2])?, 2.0);
assert_eq!(flat_indices.get(&[3])?, 3.0);
Ok(())
}
// ravel_multi_index(unravel_index(x)) should reproduce x exactly.
#[test]
fn test_unravel_ravel_roundtrip() -> TorshResult<()> {
let original_indices =
Tensor::from_data(vec![0.0, 5.0, 10.0, 15.0], vec![4], DeviceType::Cpu)?;
let shape = vec![4, 4];
let coords = unravel_index(&original_indices, &shape)?;
let reconstructed = ravel_multi_index(&coords, &shape)?;
for i in 0..4 {
assert_eq!(original_indices.get(&[i])?, reconstructed.get(&[i])?);
}
Ok(())
}
// Error paths: non-1-D indices, index out of bounds (4 in a 2x2 grid
// of total size 4), and an empty shape must all be rejected.
#[test]
fn test_unravel_index_error_cases() {
let indices_2d =
Tensor::from_data(vec![0.0f32, 1.0, 2.0, 3.0], vec![2, 2], DeviceType::Cpu)
.expect("Tensor should succeed");
let shape = vec![2, 2];
assert!(unravel_index(&indices_2d, &shape).is_err());
let indices = Tensor::from_data(vec![4.0f32], vec![1], DeviceType::Cpu)
.expect("Tensor should succeed");
let shape = vec![2, 2]; assert!(unravel_index(&indices, &shape).is_err());
let indices = Tensor::from_data(vec![0.0f32], vec![1], DeviceType::Cpu)
.expect("Tensor should succeed");
let empty_shape: Vec<usize> = vec![];
assert!(unravel_index(&indices, &empty_shape).is_err());
}
// Error paths: coords/shape length mismatch, and a coordinate (2) equal
// to its dimension size (2) — i.e. out of bounds — must be rejected.
#[test]
fn test_ravel_multi_index_error_cases() {
let coord = Tensor::from_data(vec![0.0f32], vec![1], DeviceType::Cpu)
.expect("Tensor should succeed");
let coords = vec![coord];
let shape = vec![2, 2]; assert!(ravel_multi_index(&coords, &shape).is_err());
let coord = Tensor::from_data(vec![2.0f32], vec![1], DeviceType::Cpu)
.expect("Tensor should succeed");
let coords = vec![coord];
let shape = vec![2]; assert!(ravel_multi_index(&coords, &shape).is_err());
}
// A 1-D shape is the identity case: flat index == coordinate.
#[test]
fn test_edge_case_1d_tensor() -> TorshResult<()> {
let indices = Tensor::from_data(vec![0.0, 1.0, 2.0], vec![3], DeviceType::Cpu)?;
let shape = vec![5];
let result = unravel_index(&indices, &shape)?;
assert_eq!(result.len(), 1);
assert_eq!(result[0].get(&[0])?, 0.0);
assert_eq!(result[0].get(&[1])?, 1.0);
assert_eq!(result[0].get(&[2])?, 2.0);
Ok(())
}
// Shape [5, 4, 3] has strides [12, 3, 1]:
// 59 = 4*12 + 3*3 + 2 -> (4, 3, 2); 35 = 2*12 + 3*3 + 2 -> (2, 3, 2).
#[test]
fn test_large_tensor_indexing() -> TorshResult<()> {
let indices = Tensor::from_data(vec![0.0, 59.0, 35.0], vec![3], DeviceType::Cpu)?;
let shape = vec![5, 4, 3]; let result = unravel_index(&indices, &shape)?;
assert_eq!(result[0].get(&[1])?, 4.0); assert_eq!(result[1].get(&[1])?, 3.0); assert_eq!(result[2].get(&[1])?, 2.0);
assert_eq!(result[0].get(&[2])?, 2.0); assert_eq!(result[1].get(&[2])?, 3.0); assert_eq!(result[2].get(&[2])?, 2.0);
Ok(())
}
}