use crate::{Tensor, TensorElement};
use torsh_core::error::{Result, TorshError};
/// One indexing operation for a single tensor dimension, modeled after
/// NumPy/PyTorch indexing. A `&[TensorIndex]` describes a full indexing
/// expression (see `Tensor::index`).
#[derive(Debug, Clone)]
pub enum TensorIndex {
    /// Select one position (negative counts from the end); the indexed
    /// dimension is removed from the result.
    Index(i64),
    /// Half-open range `(start, stop, step)`; `None` means "use default".
    Range(Option<i64>, Option<i64>, Option<i64>),
    /// Keep the whole dimension unchanged.
    All,
    /// Select an explicit list of positions (fancy indexing).
    List(Vec<i64>),
    /// Keep positions where the boolean mask is `true`.
    Mask(Tensor<bool>),
    /// Expands to `All` for every dimension not otherwise indexed.
    Ellipsis,
    /// Insert a new dimension of size 1 (NumPy's `None` / `newaxis`).
    NewAxis,
}
impl TensorIndex {
pub fn range(start: Option<i64>, stop: Option<i64>) -> Self {
TensorIndex::Range(start, stop, None)
}
pub fn range_step(start: Option<i64>, stop: Option<i64>, step: i64) -> Self {
TensorIndex::Range(start, stop, Some(step))
}
}
impl<T: TensorElement> Tensor<T> {
    /// Index the tensor with a NumPy-style sequence of per-dimension indices.
    ///
    /// Validates the index list, expands `Ellipsis`, computes the output
    /// shape plus a `(start, stop, step)` slice descriptor per dimension, and
    /// then dispatches to basic (strided) or advanced (list/mask) extraction.
    pub fn index(&self, indices: &[TensorIndex]) -> Result<Self> {
        // `NewAxis` and `Ellipsis` do not consume an input dimension, so only
        // the remaining kinds are counted against `ndim`.
        let consuming_indices = indices
            .iter()
            .filter(|idx| !matches!(idx, TensorIndex::NewAxis | TensorIndex::Ellipsis))
            .count();
        if consuming_indices > self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Too many indices for tensor: tensor has {} dimensions but {} consuming indices were provided",
                self.ndim(),
                consuming_indices
            )));
        }
        // Replace `Ellipsis` with explicit `All` entries and pad missing
        // trailing dimensions with `All`.
        let expanded_indices = self.expand_ellipsis(indices)?;
        let mut output_shape = Vec::new();
        let mut slices = Vec::new();
        // Tracks which *input* dimension the current index applies to.
        let mut input_dim_idx = 0;
        for index in expanded_indices.iter() {
            // `NewAxis` contributes a size-1 output dim without consuming an
            // input dim, so it is handled before `dim_size` is looked up.
            if let TensorIndex::NewAxis = index {
                output_shape.push(1);
                slices.push((0, 1, 1));
                continue;
            }
            let dim_size = if input_dim_idx < self.ndim() {
                self.shape().dims()[input_dim_idx]
            } else {
                return Err(TorshError::InvalidArgument(format!(
                    "Index {} beyond tensor dimensions (tensor has {} dimensions)",
                    input_dim_idx,
                    self.ndim()
                )));
            };
            match index {
                TensorIndex::Index(idx) => {
                    // Negative indices count from the end of the dimension.
                    let idx = if *idx < 0 {
                        (dim_size as i64 + idx) as usize
                    } else {
                        *idx as usize
                    };
                    if idx >= dim_size {
                        return Err(TorshError::IndexOutOfBounds {
                            index: idx,
                            size: dim_size,
                        });
                    }
                    // Single-element slice; the dimension is dropped from the
                    // output shape (no `output_shape.push`).
                    slices.push((idx, idx + 1, 1));
                    input_dim_idx += 1;
                }
                TensorIndex::Range(start, stop, step) => {
                    let step = step.unwrap_or(1);
                    if step == 0 {
                        return Err(TorshError::InvalidArgument(
                            "Step cannot be zero".to_string(),
                        ));
                    }
                    // Clamp possibly-negative bounds into `[0, dim_size]`.
                    let start = start
                        .map(|s| {
                            if s < 0 {
                                (dim_size as i64 + s).max(0) as usize
                            } else {
                                s.min(dim_size as i64) as usize
                            }
                        })
                        .unwrap_or(0);
                    let stop = stop
                        .map(|s| {
                            if s < 0 {
                                (dim_size as i64 + s).max(0) as usize
                            } else {
                                s.min(dim_size as i64) as usize
                            }
                        })
                        .unwrap_or(dim_size);
                    // Ceiling division of the span by the step; the two
                    // branches handle positive vs. negative step rounding.
                    let size = if step > 0 {
                        ((stop as i64 - start as i64 + step - 1) / step).max(0) as usize
                    } else {
                        ((stop as i64 - start as i64 + step + 1) / step).max(0) as usize
                    };
                    output_shape.push(size);
                    // NOTE(review): a negative `step` wraps around in this
                    // `as usize` cast, so negative steps appear unsupported by
                    // the basic extraction path — confirm intended behavior.
                    slices.push((start, stop, step as usize));
                    input_dim_idx += 1;
                }
                TensorIndex::All => {
                    output_shape.push(dim_size);
                    slices.push((0, dim_size, 1));
                    input_dim_idx += 1;
                }
                TensorIndex::List(indices_list) => {
                    // Validate every (possibly negative) entry up front.
                    for &idx in indices_list {
                        let normalized_idx = if idx < 0 {
                            (dim_size as i64 + idx) as usize
                        } else {
                            idx as usize
                        };
                        if normalized_idx >= dim_size {
                            return Err(TorshError::IndexOutOfBounds {
                                index: normalized_idx,
                                size: dim_size,
                            });
                        }
                    }
                    output_shape.push(indices_list.len());
                    // Placeholder slice (step 0): lists are resolved by the
                    // advanced-indexing path, not by this descriptor.
                    slices.push((0, indices_list.len(), 0));
                    input_dim_idx += 1;
                }
                TensorIndex::Mask(mask) => {
                    if mask.ndim() != 1 {
                        return Err(TorshError::InvalidArgument(
                            "Boolean mask must be 1D for single dimension indexing".to_string(),
                        ));
                    }
                    if mask.numel() != dim_size {
                        return Err(TorshError::ShapeMismatch {
                            expected: vec![dim_size],
                            got: mask.shape().dims().to_vec(),
                        });
                    }
                    // The output extent along this dim is the number of
                    // `true` entries in the mask.
                    let mask_data = mask.to_vec()?;
                    let true_count = mask_data.iter().filter(|&&x| x).count();
                    output_shape.push(true_count);
                    // Placeholder slice (step 0): masks are resolved by the
                    // advanced-indexing path.
                    slices.push((0, true_count, 0));
                    input_dim_idx += 1;
                }
                TensorIndex::NewAxis => {
                    // Unreachable: handled at the top of the loop.
                    return Err(TorshError::InvalidArgument(
                        "NewAxis should be handled before this point".to_string(),
                    ));
                }
                TensorIndex::Ellipsis => {
                    // Unreachable: removed by `expand_ellipsis`.
                    return Err(TorshError::InvalidArgument(
                        "Ellipsis should be expanded before processing".to_string(),
                    ));
                }
            }
        }
        // A pure scalar selection still yields a 1-element tensor.
        if output_shape.is_empty() {
            output_shape.push(1);
        }
        // Lists and masks need per-element index resolution (advanced path);
        // everything else can be done with strided slices (basic path).
        if expanded_indices
            .iter()
            .any(|idx| matches!(idx, TensorIndex::List(_) | TensorIndex::Mask(_)))
        {
            self.extract_advanced_indexing(&expanded_indices, &output_shape)
        } else {
            self.extract_basic_indexing(&expanded_indices, &output_shape, &slices)
        }
    }
    /// Strided extraction for indexing expressions without lists or masks.
    ///
    /// `slices` holds one `(start, stop, step)` descriptor per expanded
    /// index (stop is unused here); `output_shape` is the already-computed
    /// result shape. For each flat output position, the corresponding input
    /// element is located by walking the slice descriptors.
    fn extract_basic_indexing(
        &self,
        indices: &[TensorIndex],
        output_shape: &[usize],
        slices: &[(usize, usize, usize)],
    ) -> Result<Self> {
        let input_data = self.to_vec()?;
        let output_size = output_shape.iter().product();
        let mut output_data = Vec::with_capacity(output_size);
        let input_strides = self.compute_strides();
        let output_strides = compute_strides_from_shape(output_shape);
        for out_idx in 0..output_size {
            // Decompose the flat output index into per-dimension coordinates.
            let mut out_indices = vec![0; output_shape.len()];
            let mut remaining = out_idx;
            for (i, &stride) in output_strides.iter().enumerate() {
                out_indices[i] = remaining / stride;
                remaining %= stride;
            }
            let mut input_flat_idx = 0;
            // Output and input dimension cursors advance independently:
            // NewAxis consumes an output dim only; Index consumes an input
            // dim only; everything else consumes one of each.
            let mut out_dim = 0;
            let mut input_dim = 0;
            for (slice_idx, &(start, _, step)) in slices.iter().enumerate() {
                if slice_idx < indices.len() && matches!(indices[slice_idx], TensorIndex::NewAxis) {
                    out_dim += 1;
                    continue;
                }
                if input_dim >= input_strides.len() {
                    break;
                }
                // Scalar `Index` entries contribute a fixed offset; ranges
                // scale the output coordinate by the step.
                let idx = if slice_idx < indices.len()
                    && matches!(indices[slice_idx], TensorIndex::Index(_))
                {
                    start
                } else {
                    start + out_indices[out_dim] * step
                };
                input_flat_idx += idx * input_strides[input_dim];
                if !(slice_idx < indices.len()
                    && matches!(indices[slice_idx], TensorIndex::Index(_)))
                {
                    out_dim += 1;
                }
                input_dim += 1;
            }
            output_data.push(input_data[input_flat_idx]);
        }
        Self::from_data(output_data, output_shape.to_vec(), self.device)
    }
    /// Extraction path used when any index is a `List` or `Mask`.
    ///
    /// Resolves each flat output position to one source element by mapping
    /// every output coordinate through its index specification.
    fn extract_advanced_indexing(
        &self,
        indices: &[TensorIndex],
        output_shape: &[usize],
    ) -> Result<Self> {
        let input_data = self.to_vec()?;
        let output_size = output_shape.iter().product();
        let mut output_data = Vec::with_capacity(output_size);
        let input_strides = self.compute_strides();
        let output_strides = compute_strides_from_shape(output_shape);
        for out_idx in 0..output_size {
            // Decompose the flat output index into per-dimension coordinates.
            let mut out_indices = vec![0; output_shape.len()];
            let mut remaining = out_idx;
            for (i, &stride) in output_strides.iter().enumerate() {
                out_indices[i] = remaining / stride;
                remaining %= stride;
            }
            let mut input_flat_idx = 0;
            let mut out_dim = 0;
            // NOTE(review): `dim_idx` is used both as the position within
            // `indices` and as the input-dimension number (`input_strides
            // [dim_idx]`); a `NewAxis` entry offsets the two, so mixed
            // NewAxis + list/mask indexing may misalign — confirm.
            for (dim_idx, index) in indices.iter().enumerate() {
                if dim_idx >= self.ndim() {
                    break;
                }
                let input_idx = match index {
                    TensorIndex::Index(idx) => {
                        // Negative scalar indices count from the end.
                        let dim_size = self.shape().dims()[dim_idx];
                        if *idx < 0 {
                            (dim_size as i64 + idx) as usize
                        } else {
                            *idx as usize
                        }
                    }
                    TensorIndex::Range(start, _stop, step) => {
                        let dim_size = self.shape().dims()[dim_idx];
                        let step = step.unwrap_or(1);
                        let start = start
                            .map(|s| {
                                if s < 0 {
                                    (dim_size as i64 + s).max(0) as usize
                                } else {
                                    s.min(dim_size as i64) as usize
                                }
                            })
                            .unwrap_or(0);
                        // NOTE(review): negative steps wrap in this cast,
                        // mirroring the limitation in `index` — confirm.
                        start + out_indices[out_dim] * (step as usize)
                    }
                    TensorIndex::All => out_indices[out_dim],
                    TensorIndex::List(indices_list) => {
                        // The output coordinate selects a slot in the list.
                        let list_idx = out_indices[out_dim];
                        if list_idx >= indices_list.len() {
                            return Err(TorshError::IndexOutOfBounds {
                                index: list_idx,
                                size: indices_list.len(),
                            });
                        }
                        let actual_idx = indices_list[list_idx];
                        let dim_size = self.shape().dims()[dim_idx];
                        if actual_idx < 0 {
                            (dim_size as i64 + actual_idx) as usize
                        } else {
                            actual_idx as usize
                        }
                    }
                    TensorIndex::Mask(mask) => {
                        // Locate the position of the N-th `true` in the mask.
                        // NOTE(review): `mask.to_vec()` runs once per output
                        // element — O(output * mask_len); consider hoisting
                        // outside the loop.
                        let mask_data = mask.to_vec()?;
                        let target_true_idx = out_indices[out_dim];
                        let mut true_count = 0;
                        let mut found_idx = None;
                        for (i, &mask_val) in mask_data.iter().enumerate() {
                            if mask_val {
                                if true_count == target_true_idx {
                                    found_idx = Some(i);
                                    break;
                                }
                                true_count += 1;
                            }
                        }
                        match found_idx {
                            Some(idx) => idx,
                            None => {
                                return Err(TorshError::IndexOutOfBounds {
                                    index: target_true_idx,
                                    size: true_count,
                                });
                            }
                        }
                    }
                    TensorIndex::NewAxis => {
                        // NOTE(review): skips without advancing `out_dim`,
                        // yet a NewAxis contributed a size-1 output dim
                        // earlier — later coordinates may shift; confirm.
                        continue;
                    }
                    TensorIndex::Ellipsis => {
                        // Unreachable in practice: `expand_ellipsis` removes
                        // these before extraction.
                        out_indices[out_dim]
                    }
                };
                input_flat_idx += input_idx * input_strides[dim_idx];
                // Scalar `Index` entries drop their dimension from the output.
                if !matches!(index, TensorIndex::Index(_) | TensorIndex::NewAxis) {
                    out_dim += 1;
                }
            }
            // Any trailing input dimensions not covered by `indices` pass
            // through unchanged.
            // NOTE(review): `self.ndim() - indices.len()` underflows when the
            // expanded index list is longer than ndim (possible with NewAxis
            // entries) — confirm.
            for stride in input_strides
                .iter()
                .skip(indices.len())
                .take(self.ndim() - indices.len())
            {
                if out_dim < out_indices.len() {
                    input_flat_idx += out_indices[out_dim] * stride;
                    out_dim += 1;
                }
            }
            // Final guard against any accumulated misalignment.
            if input_flat_idx >= input_data.len() {
                return Err(TorshError::IndexOutOfBounds {
                    index: input_flat_idx,
                    size: input_data.len(),
                });
            }
            output_data.push(input_data[input_flat_idx]);
        }
        Self::from_data(output_data, output_shape.to_vec(), self.device)
    }
fn expand_ellipsis(&self, indices: &[TensorIndex]) -> Result<Vec<TensorIndex>> {
let mut expanded = Vec::new();
let mut found_ellipsis = false;
let non_expanding_indices = indices
.iter()
.filter(|idx| !matches!(idx, TensorIndex::Ellipsis | TensorIndex::NewAxis))
.count();
for index in indices {
match index {
TensorIndex::Ellipsis => {
if found_ellipsis {
return Err(TorshError::InvalidArgument(
"Only one ellipsis (...) is allowed per indexing operation".to_string(),
));
}
found_ellipsis = true;
let ellipsis_dims = if self.ndim() >= non_expanding_indices {
self.ndim() - non_expanding_indices
} else {
0
};
for _ in 0..ellipsis_dims {
expanded.push(TensorIndex::All);
}
}
_ => {
expanded.push(index.clone());
}
}
}
if !found_ellipsis {
let current_dims = expanded
.iter()
.filter(|idx| !matches!(idx, TensorIndex::NewAxis))
.count();
for _ in current_dims..self.ndim() {
expanded.push(TensorIndex::All);
}
}
Ok(expanded)
}
pub fn get_1d(&self, index: usize) -> Result<T> {
if self.ndim() != 1 {
return Err(TorshError::InvalidShape(
"get_1d() can only be used on 1D tensors".to_string(),
));
}
if index >= self.shape().dims()[0] {
return Err(TorshError::IndexOutOfBounds {
index,
size: self.shape().dims()[0],
});
}
let data = self.data()?;
Ok(data[index])
}
pub fn get_2d(&self, row: usize, col: usize) -> Result<T> {
if self.ndim() != 2 {
return Err(TorshError::InvalidShape(
"get_2d() can only be used on 2D tensors".to_string(),
));
}
let shape = self.shape();
if row >= shape.dims()[0] || col >= shape.dims()[1] {
return Err(TorshError::IndexOutOfBounds {
index: row * shape.dims()[1] + col,
size: shape.numel(),
});
}
let data = self.to_vec()?;
let index = row * shape.dims()[1] + col;
Ok(data[index])
}
pub fn get_3d(&self, x: usize, y: usize, z: usize) -> Result<T> {
if self.ndim() != 3 {
return Err(TorshError::InvalidShape(
"get_3d() can only be used on 3D tensors".to_string(),
));
}
let shape = self.shape();
if x >= shape.dims()[0] || y >= shape.dims()[1] || z >= shape.dims()[2] {
return Err(TorshError::IndexOutOfBounds {
index: x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z,
size: shape.numel(),
});
}
let data = self.to_vec()?;
let index = x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z;
Ok(data[index])
}
pub fn set_1d(&mut self, index: usize, value: T) -> Result<()> {
if self.ndim() != 1 {
return Err(TorshError::InvalidShape(
"set_1d() can only be used on 1D tensors".to_string(),
));
}
if index >= self.shape().dims()[0] {
return Err(TorshError::IndexOutOfBounds {
index,
size: self.shape().dims()[0],
});
}
let mut data = self.to_vec()?;
data[index] = value;
*self = Self::from_data(data, self.shape().dims().to_vec(), self.device())?;
Ok(())
}
pub fn set_2d(&mut self, row: usize, col: usize, value: T) -> Result<()> {
if self.ndim() != 2 {
return Err(TorshError::InvalidShape(
"set_2d() can only be used on 2D tensors".to_string(),
));
}
let shape = self.shape();
if row >= shape.dims()[0] || col >= shape.dims()[1] {
return Err(TorshError::IndexOutOfBounds {
index: row * shape.dims()[1] + col,
size: shape.numel(),
});
}
let mut data = self.to_vec()?;
let index = row * shape.dims()[1] + col;
data[index] = value;
*self = Self::from_data(data, self.shape().dims().to_vec(), self.device())?;
Ok(())
}
pub fn set_3d(&mut self, x: usize, y: usize, z: usize, value: T) -> Result<()> {
if self.ndim() != 3 {
return Err(TorshError::InvalidShape(
"set_3d() can only be used on 3D tensors".to_string(),
));
}
let shape = self.shape();
if x >= shape.dims()[0] || y >= shape.dims()[1] || z >= shape.dims()[2] {
return Err(TorshError::IndexOutOfBounds {
index: x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z,
size: shape.numel(),
});
}
let mut data = self.to_vec()?;
let index = x * shape.dims()[1] * shape.dims()[2] + y * shape.dims()[2] + z;
data[index] = value;
*self = Self::from_data(data, self.shape().dims().to_vec(), self.device())?;
Ok(())
}
pub fn select(&self, dim: i32, index: i64) -> Result<Self> {
let ndim = self.ndim() as i32;
let dim = if dim < 0 { ndim + dim } else { dim } as usize;
if dim >= self.ndim() {
return Err(TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
dim,
self.ndim()
)));
}
let dim_size = self.shape().dims()[dim] as i64;
let index = if index < 0 { dim_size + index } else { index };
if index < 0 || index >= dim_size {
return Err(TorshError::IndexOutOfBounds {
index: index as usize,
size: dim_size as usize,
});
}
let mut indices = Vec::new();
for d in 0..self.ndim() {
if d == dim {
indices.push(TensorIndex::Index(index));
} else {
indices.push(TensorIndex::All);
}
}
self.index(&indices)
}
pub fn slice_with_step(
&self,
dim: i32,
start: Option<i64>,
end: Option<i64>,
step: Option<i64>,
) -> Result<Self> {
let ndim = self.ndim() as i32;
let dim = if dim < 0 { ndim + dim } else { dim } as usize;
if dim >= self.ndim() {
return Err(TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
dim,
self.ndim()
)));
}
let mut indices = Vec::new();
for d in 0..self.ndim() {
if d == dim {
indices.push(TensorIndex::Range(start, end, step));
} else {
indices.push(TensorIndex::All);
}
}
self.index(&indices)
}
pub fn narrow(&self, dim: i32, start: i64, length: usize) -> Result<Self> {
let ndim = self.ndim() as i32;
let dim = if dim < 0 { ndim + dim } else { dim } as usize;
if dim >= self.ndim() {
return Err(TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
dim,
self.ndim()
)));
}
let dim_size = self.shape().dims()[dim] as i64;
let start = if start < 0 { dim_size + start } else { start };
if start < 0 || start >= dim_size {
return Err(TorshError::InvalidArgument(format!(
"Start index {start} out of range for dimension {dim} with size {dim_size}"
)));
}
let end = start + length as i64;
if end > dim_size {
return Err(TorshError::InvalidArgument(format!(
"End index {end} out of range for dimension {dim} with size {dim_size}"
)));
}
let mut indices = Vec::new();
for d in 0..self.ndim() {
if d == dim {
indices.push(TensorIndex::Range(Some(start), Some(end), None));
} else {
indices.push(TensorIndex::All);
}
}
self.index(&indices)
}
pub fn masked_select(&self, mask: &Tensor<bool>) -> Result<Self> {
if self.shape() != mask.shape() {
return Err(TorshError::ShapeMismatch {
expected: self.shape().dims().to_vec(),
got: mask.shape().dims().to_vec(),
});
}
let self_data = self.data()?;
let mask_data = mask.data()?;
let mut selected_data = Vec::new();
for (i, &mask_val) in mask_data.iter().enumerate() {
if mask_val {
selected_data.push(self_data[i]);
}
}
Self::from_data(
selected_data.clone(),
vec![selected_data.len()],
self.device,
)
}
pub fn take(&self, indices: &Tensor<i64>) -> Result<Self> {
let self_data = self.data()?;
let indices_data = indices.data()?;
let self_size = self.shape().numel();
let output_shape = indices.shape().dims().to_vec();
let output_size = indices.shape().numel();
let mut output_data = Vec::with_capacity(output_size);
for &idx in indices_data.iter() {
let idx = if idx < 0 {
(self_size as i64 + idx) as usize
} else {
idx as usize
};
if idx >= self_size {
return Err(TorshError::IndexOutOfBounds {
index: idx,
size: self_size,
});
}
output_data.push(self_data[idx]);
}
Self::from_data(output_data, output_shape, self.device)
}
pub fn put(&self, indices: &Tensor<i64>, values: &Self) -> Result<Self> {
let self_data = self.data()?;
let indices_data = indices.data()?;
let values_data = values.data()?;
if indices.shape() != values.shape() {
return Err(TorshError::ShapeMismatch {
expected: indices.shape().dims().to_vec(),
got: values.shape().dims().to_vec(),
});
}
let self_size = self.shape().numel();
let mut output_data = self_data.clone();
for (i, &idx) in indices_data.iter().enumerate() {
let idx = if idx < 0 {
(self_size as i64 + idx) as usize
} else {
idx as usize
};
if idx >= self_size {
return Err(TorshError::IndexOutOfBounds {
index: idx,
size: self_size,
});
}
output_data[idx] = values_data[i];
}
Self::from_data(output_data, self.shape().dims().to_vec(), self.device)
}
pub fn index_select(&self, dim: i32, index: &Tensor<i64>) -> Result<Self> {
let ndim = self.ndim() as i32;
let dim = if dim < 0 { ndim + dim } else { dim } as usize;
if dim >= self.ndim() {
return Err(TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
dim,
self.ndim()
)));
}
if index.ndim() != 1 {
return Err(TorshError::InvalidShape(
"index_select expects a 1D index tensor".to_string(),
));
}
let mut output_shape = self.shape().dims().to_vec();
output_shape[dim] = index.shape().dims()[0];
let output_size: usize = output_shape.iter().product();
let mut output_data = Vec::with_capacity(output_size);
let self_data = self.data()?;
let index_data = index.data()?;
let self_strides = self.compute_strides();
let _output_strides = Self::compute_strides_for_shape(&output_shape);
for out_idx in 0..output_size {
let mut indices = vec![0; self.ndim()];
let mut remaining = out_idx;
for i in (0..self.ndim()).rev() {
indices[i] = remaining % output_shape[i];
remaining /= output_shape[i];
}
let select_idx = indices[dim];
let selected_value = index_data[select_idx] as usize;
if selected_value >= self.shape().dims()[dim] {
return Err(TorshError::IndexOutOfBounds {
index: selected_value,
size: self.shape().dims()[dim],
});
}
indices[dim] = selected_value;
let src_flat_idx = indices
.iter()
.zip(&self_strides)
.map(|(idx, stride)| idx * stride)
.sum::<usize>();
output_data.push(self_data[src_flat_idx]);
}
Self::from_data(output_data, output_shape, self.device)
}
    /// Row-major (C-contiguous) strides for this tensor's current shape.
    pub(crate) fn compute_strides(&self) -> Vec<usize> {
        Self::compute_strides_for_shape(self.shape().dims())
    }
pub(crate) fn compute_strides_for_shape(shape: &[usize]) -> Vec<usize> {
let mut strides = vec![1; shape.len()];
for i in (0..shape.len() - 1).rev() {
strides[i] = strides[i + 1] * shape[i + 1];
}
strides
}
}
/// Row-major (C-contiguous) strides for `shape`; free-function twin of
/// `Tensor::compute_strides_for_shape`. An empty shape yields an empty vector.
fn compute_strides_from_shape(shape: &[usize]) -> Vec<usize> {
    let mut strides = vec![1; shape.len()];
    // Iterate `1..len` instead of `0..len - 1`: the latter underflows (and
    // panics) for an empty shape.
    for i in (1..shape.len()).rev() {
        strides[i - 1] = strides[i] * shape[i];
    }
    strides
}
/// Builds a `Vec<TensorIndex>` of scalar `Index` entries, e.g.
/// `idx![1, 2, 3]`. The single variadic arm also covers the one-expression
/// case, so a separate `($idx:expr)` arm would be redundant.
#[macro_export]
macro_rules! idx {
    ($($idx:expr),+ $(,)?) => {
        vec![$(TensorIndex::Index($idx)),+]
    };
}
/// Slice-construction macro mirroring Python/NumPy slice syntax.
#[macro_export]
macro_rules! s {
    // `s![..]` — keep the whole dimension.
    (..) => {
        TensorIndex::All
    };
    // `s![..stop]` — from the beginning up to `stop` (exclusive).
    (.. $stop:expr) => {
        TensorIndex::range(None, Some($stop))
    };
    // `s![start, stop]` — half-open range with unit step.
    ($start:expr, $stop:expr) => {
        TensorIndex::range(Some($start), Some($stop))
    };
    // `s![start, stop, step]` — half-open range with an explicit step.
    ($start:expr, $stop:expr, $step:expr) => {
        TensorIndex::range_step(Some($start), Some($stop), $step)
    };
    // `s![ellipsis]` — expand to `All` for all unindexed dimensions.
    (ellipsis) => {
        TensorIndex::Ellipsis
    };
    // `s![None]` — insert a new size-1 axis (NumPy's `None`).
    (None) => {
        TensorIndex::NewAxis
    };
}
/// Builds a `TensorIndex::List` for fancy indexing, e.g. `fancy_idx![0, 2, 1]`.
#[macro_export]
macro_rules! fancy_idx {
    [$($idx:expr),+ $(,)?] => {
        TensorIndex::List(vec![$($idx),+])
    };
}
/// Wraps a boolean tensor as a `TensorIndex::Mask`, e.g. `mask_idx![mask]`.
/// Note: the macro takes the mask by value.
#[macro_export]
macro_rules! mask_idx {
    [$mask:expr] => {
        TensorIndex::Mask($mask)
    };
}
impl<T: TensorElement> Tensor<T> {
pub fn index_with_list(&self, dim: i32, indices: &[i64]) -> Result<Self> {
let ndim = self.ndim() as i32;
let dim = if dim < 0 { ndim + dim } else { dim } as usize;
if dim >= self.ndim() {
return Err(TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
dim,
self.ndim()
)));
}
let mut index_spec = vec![TensorIndex::All; self.ndim()];
index_spec[dim] = TensorIndex::List(indices.to_vec());
self.index(&index_spec)
}
pub fn index_with_mask(&self, dim: i32, mask: &Tensor<bool>) -> Result<Self> {
let ndim = self.ndim() as i32;
let dim = if dim < 0 { ndim + dim } else { dim } as usize;
if dim >= self.ndim() {
return Err(TorshError::InvalidArgument(format!(
"Dimension {} out of range for tensor with {} dimensions",
dim,
self.ndim()
)));
}
let mut index_spec = vec![TensorIndex::All; self.ndim()];
index_spec[dim] = TensorIndex::Mask(mask.clone());
self.index(&index_spec)
}
pub fn mask_select(&self, mask: &Tensor<bool>) -> Result<Self> {
if self.shape() != mask.shape() {
return Err(TorshError::ShapeMismatch {
expected: self.shape().dims().to_vec(),
got: mask.shape().dims().to_vec(),
});
}
let self_data = self.data()?;
let mask_data = mask.data()?;
let mut selected_data = Vec::new();
for (i, &mask_val) in mask_data.iter().enumerate() {
if mask_val {
selected_data.push(self_data[i]);
}
}
Self::from_data(
selected_data.clone(),
vec![selected_data.len()],
self.device,
)
}
pub fn where_condition<F>(&self, condition: F) -> Result<Tensor<bool>>
where
F: Fn(&T) -> bool,
T: Clone,
{
let data = self.data()?;
let mask_data: Vec<bool> = data.iter().map(condition).collect();
Tensor::from_data(mask_data, self.shape().dims().to_vec(), self.device)
}
    /// Scatter `src` values into a copy of `self` along `dim`: for each
    /// position `p` in `index`/`src`, the value `src[p]` is written to
    /// `self` at coordinate `p` with `p[dim]` replaced by `index[p]`
    /// (like `torch.Tensor.scatter`). Negative entries in `index` count
    /// from the end of the dimension.
    ///
    /// # Errors
    /// `InvalidArgument` for a bad dimension or mismatched rank,
    /// `ShapeMismatch` when `index` and `src` differ, and
    /// `IndexOutOfBounds` for out-of-range scatter targets.
    pub fn scatter_indexed(&self, dim: i32, index: &Tensor<i64>, src: &Self) -> Result<Self> {
        // Normalize a possibly negative dimension.
        let ndim = self.ndim() as i32;
        let dim = if dim < 0 { ndim + dim } else { dim } as usize;
        if dim >= self.ndim() {
            return Err(TorshError::InvalidArgument(format!(
                "Dimension {} out of range for tensor with {} dimensions",
                dim,
                self.ndim()
            )));
        }
        // Bindings keep the shape guards alive while their dims are borrowed.
        let self_shape_binding = self.shape();
        let self_shape = self_shape_binding.dims();
        let index_shape_binding = index.shape();
        let index_shape = index_shape_binding.dims();
        let src_shape_binding = src.shape();
        let src_shape = src_shape_binding.dims();
        if index_shape != src_shape {
            return Err(TorshError::ShapeMismatch {
                expected: index_shape.to_vec(),
                got: src_shape.to_vec(),
            });
        }
        if index_shape.len() != self_shape.len() {
            return Err(TorshError::InvalidArgument(
                "Index tensor must have same number of dimensions as input tensor".to_string(),
            ));
        }
        // Scatter is non-destructive: work on a copy of self's data.
        let mut result_data = self.data()?.clone();
        let index_data = index.data()?;
        let src_data = src.data()?;
        let self_strides = self.compute_strides();
        let index_size = index_shape.iter().product();
        for flat_idx in 0..index_size {
            // Decompose the flat index into coordinates within `index`
            // (built reversed, then flipped to row-major order).
            let mut coords = Vec::new();
            let mut temp_idx = flat_idx;
            for &dim_size in index_shape.iter().rev() {
                coords.push(temp_idx % dim_size);
                temp_idx /= dim_size;
            }
            coords.reverse();
            // Resolve the (possibly negative) scatter target along `dim`.
            let scatter_idx = index_data[flat_idx];
            let dim_size = self_shape[dim] as i64;
            let scatter_idx = if scatter_idx < 0 {
                dim_size + scatter_idx
            } else {
                scatter_idx
            };
            if scatter_idx < 0 || scatter_idx >= dim_size {
                return Err(TorshError::IndexOutOfBounds {
                    index: scatter_idx as usize,
                    size: dim_size as usize,
                });
            }
            // Redirect the coordinate along `dim`, then flatten against
            // self's strides to find the destination slot.
            coords[dim] = scatter_idx as usize;
            let mut dest_idx = 0;
            for (coord, &stride) in coords.iter().zip(self_strides.iter()) {
                dest_idx += coord * stride;
            }
            result_data[dest_idx] = src_data[flat_idx];
        }
        Self::from_data(result_data, self_shape.to_vec(), self.device)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::{tensor_2d, zeros};

    #[test]
    fn test_index_macros() {
        assert_eq!(idx![5].len(), 1);
        assert_eq!(idx![1, 2, 3].len(), 3);
        let _all = s![..];
        let _range = s![1, 5];
        let _range_step = s![1, 10, 2];
        let _to = s![..7];
        let _fancy = fancy_idx![0, 2, 1];
        let _ellipsis = s![ellipsis];
        let _newaxis = s![None];
    }

    #[test]
    fn test_get_set() {
        let tensor = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0]])
            .expect("tensor creation should succeed");
        for (pos, want) in [([0, 0], 1.0), ([0, 1], 2.0), ([1, 2], 6.0)] {
            assert_eq!(tensor.get(&pos).expect("data access should succeed"), want);
        }
        tensor
            .set(&[1, 1], 10.0)
            .expect("data access should succeed");
        assert_eq!(
            tensor.get(&[1, 1]).expect("data access should succeed"),
            10.0
        );
        assert!(tensor.get(&[2, 0]).is_err());
        assert!(tensor.set(&[0, 3], 0.0).is_err());
    }

    #[test]
    fn test_gather() {
        let tensor = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]])
            .expect("tensor creation should succeed");
        let indices = tensor_2d(&[&[0i64, 2, 1], &[1, 0, 2], &[2, 1, 0]])
            .expect("tensor creation should succeed");
        let result = tensor.gather(1, &indices).expect("gather should succeed");
        let expected = [
            ([0, 0], 1.0),
            ([0, 1], 3.0),
            ([0, 2], 2.0),
            ([1, 0], 5.0),
            ([1, 1], 4.0),
            ([1, 2], 6.0),
            ([2, 0], 9.0),
            ([2, 1], 8.0),
            ([2, 2], 7.0),
        ];
        for (pos, want) in expected {
            assert_eq!(result.get(&pos).expect("data access should succeed"), want);
        }
    }

    #[test]
    fn test_scatter() {
        let tensor = zeros::<f32>(&[3, 3]).expect("tensor creation should succeed");
        let indices = tensor_2d(&[&[0i64, 2, 1], &[1, 0, 2], &[2, 1, 0]])
            .expect("tensor creation should succeed");
        let src = tensor_2d(&[&[1.0, 2.0, 3.0], &[4.0, 5.0, 6.0], &[7.0, 8.0, 9.0]])
            .expect("tensor creation should succeed");
        let result = tensor
            .scatter(1, &indices, &src)
            .expect("scatter should succeed");
        let expected = [
            ([0, 0], 1.0),
            ([0, 1], 3.0),
            ([0, 2], 2.0),
            ([1, 0], 5.0),
            ([1, 1], 4.0),
            ([1, 2], 6.0),
            ([2, 0], 9.0),
            ([2, 1], 8.0),
            ([2, 2], 7.0),
        ];
        for (pos, want) in expected {
            assert_eq!(result.get(&pos).expect("data access should succeed"), want);
        }
    }

    #[test]
    fn test_index_select() {
        let tensor = tensor_2d(&[
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            &[9.0, 10.0, 11.0, 12.0],
        ])
        .expect("tensor creation should succeed");
        let row_indices =
            crate::creation::tensor_1d(&[0i64, 2]).expect("tensor creation should succeed");
        let result = tensor
            .index_select(0, &row_indices)
            .expect("index_select should succeed");
        assert_eq!(result.shape().dims(), &[2, 4]);
        for (pos, want) in [([0, 0], 1.0), ([0, 3], 4.0), ([1, 0], 9.0), ([1, 3], 12.0)] {
            assert_eq!(result.get(&pos).expect("data access should succeed"), want);
        }
        let col_indices =
            crate::creation::tensor_1d(&[1i64, 3]).expect("tensor creation should succeed");
        let result = tensor
            .index_select(1, &col_indices)
            .expect("index_select should succeed");
        assert_eq!(result.shape().dims(), &[3, 2]);
        for (pos, want) in [([0, 0], 2.0), ([0, 1], 4.0), ([2, 0], 10.0), ([2, 1], 12.0)] {
            assert_eq!(result.get(&pos).expect("data access should succeed"), want);
        }
    }

    #[test]
    fn test_list_indexing() {
        let tensor = tensor_2d(&[
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            &[9.0, 10.0, 11.0, 12.0],
        ])
        .expect("tensor creation should succeed");
        let result = tensor
            .index(&[TensorIndex::List(vec![0, 2]), TensorIndex::All])
            .expect("indexing should succeed");
        assert_eq!(result.shape().dims(), &[2, 4]);
        for (pos, want) in [([0, 0], 1.0), ([0, 3], 4.0), ([1, 0], 9.0), ([1, 3], 12.0)] {
            assert_eq!(result.get(&pos).expect("data access should succeed"), want);
        }
        let result2 = tensor
            .index_with_list(0, &[0, 2])
            .expect("index_with_list should succeed");
        assert_eq!(result.shape(), result2.shape());
        assert_eq!(
            result.get(&[0, 0]).expect("data access should succeed"),
            result2.get(&[0, 0]).expect("data access should succeed")
        );
    }

    #[test]
    fn test_boolean_mask_indexing() {
        use crate::creation::tensor_1d;
        let tensor =
            tensor_1d(&[10.0, 20.0, 30.0, 40.0, 50.0]).expect("tensor creation should succeed");
        let mask = Tensor::from_data(
            vec![true, false, true, false, true],
            vec![5],
            crate::DeviceType::Cpu,
        )
        .expect("tensor creation should succeed");
        // Both entry points should yield the same selection.
        let selections = [
            tensor.mask_select(&mask).expect("mask_select should succeed"),
            tensor
                .index_with_mask(0, &mask)
                .expect("index_with_mask should succeed"),
        ];
        for result in &selections {
            assert_eq!(result.shape().dims(), &[3]);
            for (i, want) in [(0, 10.0), (1, 30.0), (2, 50.0)] {
                assert_eq!(result.get(&[i]).expect("data access should succeed"), want);
            }
        }
    }

    #[test]
    fn test_where_condition() {
        use crate::creation::tensor_1d;
        let tensor = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0]).expect("tensor creation should succeed");
        let mask = tensor
            .where_condition(|&x| x > 3.0)
            .expect("where_condition should succeed");
        {
            let mask_data = mask.data().expect("data access should succeed");
            for (i, want) in [false, false, false, true, true].into_iter().enumerate() {
                assert_eq!(mask_data[i], want);
            }
        }
        let selected = tensor
            .mask_select(&mask)
            .expect("mask_select should succeed");
        assert_eq!(selected.shape().dims(), &[2]);
        assert_eq!(selected.get(&[0]).expect("data access should succeed"), 4.0);
        assert_eq!(selected.get(&[1]).expect("data access should succeed"), 5.0);
    }

    #[test]
    fn test_newaxis_indexing() {
        use crate::creation::tensor_1d;
        let tensor = tensor_1d(&[1.0, 2.0, 3.0]).expect("tensor creation should succeed");
        let result = tensor
            .index(&[TensorIndex::NewAxis, TensorIndex::All])
            .expect("indexing should succeed");
        assert_eq!(result.shape().dims(), &[1, 3]);
        let result = tensor
            .index(&[TensorIndex::All, TensorIndex::NewAxis])
            .expect("indexing should succeed");
        assert_eq!(result.shape().dims(), &[3, 1]);
        let result = tensor
            .index(&[
                TensorIndex::NewAxis,
                TensorIndex::All,
                TensorIndex::NewAxis,
                TensorIndex::NewAxis,
            ])
            .expect("indexing should succeed");
        assert_eq!(result.shape().dims(), &[1, 3, 1, 1]);
    }

    #[test]
    fn test_ellipsis_indexing() {
        let tensor =
            crate::creation::zeros::<f32>(&[2, 3, 4]).expect("tensor creation should succeed");
        // The ellipsis should expand to cover the remaining two dimensions
        // regardless of which leading slice is selected.
        for first in [0, 1] {
            let result = tensor
                .index(&[TensorIndex::Index(first), TensorIndex::Ellipsis])
                .expect("indexing should succeed");
            assert_eq!(result.shape().dims(), &[3, 4]);
        }
    }

    #[test]
    fn test_complex_indexing() {
        let tensor = tensor_2d(&[
            &[1.0, 2.0, 3.0, 4.0],
            &[5.0, 6.0, 7.0, 8.0],
            &[9.0, 10.0, 11.0, 12.0],
            &[13.0, 14.0, 15.0, 16.0],
        ])
        .expect("operation should succeed");
        let result = tensor
            .index(&[
                TensorIndex::List(vec![0, 2, 3]),
                TensorIndex::Range(Some(1), Some(4), None),
            ])
            .expect("indexing should succeed");
        assert_eq!(result.shape().dims(), &[3, 3]);
        for (pos, want) in [([0, 0], 2.0), ([1, 0], 10.0), ([2, 2], 16.0)] {
            assert_eq!(result.get(&pos).expect("data access should succeed"), want);
        }
    }

    #[test]
    fn test_negative_indexing() {
        use crate::creation::tensor_1d;
        let tensor = tensor_1d(&[1.0, 2.0, 3.0, 4.0, 5.0]).expect("tensor creation should succeed");
        let result = tensor
            .index(&[TensorIndex::Index(-1)])
            .expect("indexing should succeed");
        assert_eq!(result.numel(), 1);
        assert_eq!(result.item().expect("item extraction should succeed"), 5.0);
        let result = tensor
            .index(&[TensorIndex::Range(Some(-3), Some(-1), None)])
            .expect("indexing should succeed");
        assert_eq!(result.shape().dims(), &[2]);
        assert_eq!(result.get(&[0]).expect("data access should succeed"), 3.0);
        assert_eq!(result.get(&[1]).expect("data access should succeed"), 4.0);
        let result = tensor
            .index(&[TensorIndex::List(vec![-1, -2, 0])])
            .expect("indexing should succeed");
        assert_eq!(result.shape().dims(), &[3]);
        for (i, want) in [(0, 5.0), (1, 4.0), (2, 1.0)] {
            assert_eq!(result.get(&[i]).expect("data access should succeed"), want);
        }
    }
}