use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use ndarray::{ArrayD, IxDyn};
use num_traits::Float;
use std::sync::Arc;
/// Selects which code path a shape operation in this module may use.
///
/// Derives `Copy`/`PartialEq`/`Debug` so a mode can be passed by value,
/// compared and printed freely.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ShapeMode {
    /// Always produce a new owned tensor.
    Owned,
    /// Try the contiguity-checked path first and fall back to `Owned`.
    ViewOrOwned,
    /// Only use the contiguity-checked path; error for non-contiguous input.
    /// NOTE(review): in this module that path still clones the array data.
    ViewOnly,
}
/// Internal helper that builds a flat, row-major buffer by gathering values
/// from a source array through an index-mapping closure.
struct TensorProcessor;
impl TensorProcessor {
    /// Gathers one value per index tuple of `target_shape`.
    ///
    /// Target tuples are visited in row-major order (last axis fastest). Each
    /// tuple is mapped through `transform_fn` and the element of `data` at the
    /// returned index is appended to the result.
    ///
    /// # Errors
    /// Returns an index-out-of-bounds error (built from the offending source
    /// index and `source_shape`) the first time `data.get` cannot resolve an
    /// index produced by `transform_fn`.
    fn process_with_transform<T, F>(
        data: &ArrayD<T>,
        source_shape: &[usize],
        target_shape: &[usize],
        transform_fn: F,
    ) -> RusTorchResult<Vec<T>>
    where
        T: Float + Clone,
        F: Fn(&[usize]) -> Vec<usize> + Copy,
    {
        let expected_len: usize = target_shape.iter().product();
        let mut gathered = Vec::with_capacity(expected_len);
        let mut cursor = vec![0; target_shape.len()];
        Self::recursive_process(
            data,
            &mut gathered,
            source_shape,
            target_shape,
            transform_fn,
            &mut cursor,
            0,
        )?;
        Ok(gathered)
    }

    /// Depth-first walk over axes `dim..` of `target_shape`; `indices[..dim]`
    /// are already fixed by the caller. Appends gathered values to `output`.
    fn recursive_process<T, F>(
        data: &ArrayD<T>,
        output: &mut Vec<T>,
        source_shape: &[usize],
        target_shape: &[usize],
        transform_fn: F,
        indices: &mut [usize],
        dim: usize,
    ) -> RusTorchResult<()>
    where
        T: Float + Clone,
        F: Fn(&[usize]) -> Vec<usize> + Copy,
    {
        if dim != target_shape.len() {
            for position in 0..target_shape[dim] {
                indices[dim] = position;
                Self::recursive_process(
                    data,
                    output,
                    source_shape,
                    target_shape,
                    transform_fn,
                    indices,
                    dim + 1,
                )?;
            }
            return Ok(());
        }
        // Every axis is fixed: `indices` is one complete target index.
        let source_indices = transform_fn(indices);
        let value = data
            .get(source_indices.as_slice())
            .copied()
            .ok_or_else(|| RusTorchError::index_out_of_bounds(&source_indices, source_shape))?;
        output.push(value);
        Ok(())
    }
}
/// Stateless dimension and shape checks for shape operations.
struct ShapeValidator;
impl ShapeValidator {
    /// Ensures `dim` is a valid axis index of `tensor`.
    ///
    /// # Errors
    /// Returns an invalid-dimension error when `dim >= tensor.shape().len()`.
    fn validate_dimension<T: Float + 'static>(
        tensor: &Tensor<T>,
        dim: usize,
    ) -> RusTorchResult<()> {
        let ndim = tensor.shape().len();
        if dim >= ndim {
            // saturating_sub: a 0-d tensor has no axes, and `0 - 1` would underflow.
            return Err(RusTorchError::invalid_dimension(dim, ndim.saturating_sub(1)));
        }
        Ok(())
    }

    /// Checks NumPy-style broadcast compatibility of two shapes.
    ///
    /// Shapes are aligned at their trailing axis. Each aligned pair must be
    /// equal or contain a 1. Leading axes present in only one shape are
    /// compatible, because the shorter shape is implicitly padded with 1s.
    ///
    /// # Errors
    /// Returns a shape-mismatch error for the first incompatible pair.
    fn validate_broadcast_compatibility(shape1: &[usize], shape2: &[usize]) -> RusTorchResult<()> {
        // `zip` stops at the shorter shape, which is exactly the set of axes
        // that can conflict; the rest are paired with implicit 1s.
        let compatible = shape1
            .iter()
            .rev()
            .zip(shape2.iter().rev())
            .all(|(&dim1, &dim2)| dim1 == dim2 || dim1 == 1 || dim2 == 1);
        if !compatible {
            return Err(RusTorchError::shape_mismatch(shape1, shape2));
        }
        Ok(())
    }

    /// Converts a possibly-negative axis index (`-1` = last axis) to a
    /// non-negative index below `ndim`.
    ///
    /// # Errors
    /// Returns an invalid-dimension error when `dim` is outside `[-ndim, ndim)`.
    fn normalize_dimension(dim: isize, ndim: usize) -> RusTorchResult<usize> {
        let max = ndim.saturating_sub(1);
        // checked_sub: a dim more negative than -ndim must error rather than
        // wrap to a huge usize.
        let normalized = if dim < 0 {
            ndim.checked_sub(dim.unsigned_abs())
        } else {
            Some(dim as usize)
        };
        match normalized {
            Some(axis) if axis < ndim => Ok(axis),
            Some(axis) => Err(RusTorchError::invalid_dimension(axis, max)),
            None => Err(RusTorchError::invalid_dimension(dim.unsigned_abs(), max)),
        }
    }
}
impl<T: Float + Clone + 'static> Tensor<T> {
/// Returns a new owned tensor with every length-1 axis removed.
///
/// When every axis has length 1, the result has shape `[1]`.
///
/// # Panics
/// Panics with "Owned squeeze should never fail" if the owned squeeze path
/// returns an error.
pub fn squeeze(&self) -> Self {
    let outcome = self.squeeze_with_mode(ShapeMode::Owned);
    outcome.expect("Owned squeeze should never fail")
}
/// Removes length-1 axes. For contiguous tensors the contiguity-checked path
/// is tried first; otherwise, or if that path fails, the owned path is used.
///
/// # Errors
/// Propagates the error of the owned fallback path.
pub fn squeeze_view(&self) -> RusTorchResult<Self> {
    let mode = ShapeMode::ViewOrOwned;
    self.squeeze_with_mode(mode)
}
/// Removes every length-1 axis by replacing `self.data` with a reshaped clone.
///
/// When every axis has length 1, the new shape is `[1]`.
///
/// # Errors
/// `InvalidOperation` if reshaping the cloned data fails (the message
/// attributes this to layout constraints). `self` is left unchanged then.
pub fn squeeze_inplace(&mut self) -> RusTorchResult<()> {
    let mut target: Vec<usize> = self.data.shape().to_vec();
    target.retain(|&extent| extent != 1);
    if target.is_empty() {
        target.push(1);
    }
    let reshaped = self
        .data
        .clone()
        .into_shape_with_order(target)
        .map_err(|_| RusTorchError::InvalidOperation {
            operation: "squeeze_inplace".to_string(),
            message: "Cannot perform in-place squeeze due to layout constraints".to_string(),
        })?;
    self.data = reshaped;
    Ok(())
}
/// Removes axis `dim`, which must have length 1. Removing the only axis
/// yields shape `[1]`.
///
/// # Errors
/// - `InvalidDimension` if `dim` is not an axis of the tensor.
/// - `InvalidOperation` if axis `dim` does not have length 1, or if the
///   reshape fails.
pub fn squeeze_dim(&self, dim: usize) -> RusTorchResult<Self> {
    let current_shape = self.data.shape();
    if dim >= current_shape.len() {
        return Err(RusTorchError::InvalidDimension(format!(
            "Invalid dimension {} (max: {})",
            dim,
            // saturating_sub: a 0-d tensor would otherwise underflow here.
            current_shape.len().saturating_sub(1)
        )));
    }
    if current_shape[dim] != 1 {
        return Err(RusTorchError::InvalidOperation {
            operation: "squeeze_dim".to_string(),
            message: format!(
                "Cannot squeeze dimension {} with size {}",
                dim, current_shape[dim]
            ),
        });
    }
    let mut new_shape = current_shape.to_vec();
    new_shape.remove(dim);
    if new_shape.is_empty() {
        new_shape.push(1);
    }
    // `into_shape_with_order` rejects non-standard layouts (e.g. transposed
    // data). Copy into row-major order first so any layout can be squeezed.
    let reshaped_data = self
        .data
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order(new_shape)
        .map_err(|_| RusTorchError::InvalidOperation {
            operation: "squeeze_dim".to_string(),
            message: "Failed to reshape tensor".to_string(),
        })?;
    Ok(Tensor::new(reshaped_data))
}
/// Returns a new owned tensor with a length-1 axis inserted at `dim`.
/// `dim` may equal the current rank, which appends a trailing axis.
///
/// # Errors
/// Propagates errors from the owned unsqueeze path (invalid `dim`, reshape failure).
pub fn unsqueeze(&self, dim: usize) -> RusTorchResult<Self> {
    let mode = ShapeMode::Owned;
    self.unsqueeze_with_mode(dim, mode)
}
/// Inserts a length-1 axis at `dim`. For contiguous tensors the
/// contiguity-checked path is tried first, with the owned path as fallback.
///
/// # Errors
/// `InvalidDimension` if `dim` exceeds the rank; otherwise propagates owned-path errors.
pub fn unsqueeze_view(&self, dim: usize) -> RusTorchResult<Self> {
    let mode = ShapeMode::ViewOrOwned;
    self.unsqueeze_with_mode(dim, mode)
}
/// Inserts a length-1 axis at `dim` by replacing `self.data` with a reshaped
/// clone. `dim` may equal the current rank.
///
/// # Errors
/// - `InvalidDimension` if `dim` exceeds the current rank.
/// - `InvalidOperation` if reshaping the clone fails (attributed to layout
///   constraints). `self` is left unchanged in that case.
pub fn unsqueeze_inplace(&mut self, dim: usize) -> RusTorchResult<()> {
    let rank = self.data.shape().len();
    if dim > rank {
        return Err(RusTorchError::InvalidDimension(format!(
            "Invalid dimension {} (max: {})",
            dim, rank
        )));
    }
    let mut target = self.data.shape().to_vec();
    target.insert(dim, 1);
    let reshaped = self
        .data
        .clone()
        .into_shape_with_order(target)
        .map_err(|_| RusTorchError::InvalidOperation {
            operation: "unsqueeze_inplace".to_string(),
            message: "Cannot perform in-place unsqueeze due to layout constraints".to_string(),
        })?;
    self.data = reshaped;
    Ok(())
}
/// Broadcasts the tensor to `target_shape`, materialising a new owned buffer.
///
/// # Errors
/// `InvalidOperation` if `target_shape` is not a valid expansion of this shape.
pub fn expand_owned(&self, target_shape: &[usize]) -> RusTorchResult<Self> {
    let mode = ShapeMode::Owned;
    self.expand_with_mode(target_shape, mode)
}
/// Broadcasts the tensor to `target_shape` and wraps the materialised result
/// in an `Arc`, so it can be shared cheaply.
///
/// # Errors
/// `InvalidOperation` if `target_shape` is not a valid expansion of this shape.
pub fn expand_shared(&self, target_shape: &[usize]) -> RusTorchResult<Arc<Self>> {
    self.expand_with_mode(target_shape, ShapeMode::ViewOrOwned)
        .map(Arc::new)
}
/// Validates `target_shape` and returns a deferred expansion that computes
/// elements on demand. `self` is cloned into an `Arc` owned by the result.
///
/// # Errors
/// `InvalidOperation` if `target_shape` is not a valid expansion of this shape.
pub fn expand_lazy(&self, target_shape: &[usize]) -> RusTorchResult<LazyExpandedTensor<T>> {
    self.validate_expansion(target_shape)?;
    let source = Arc::new(self.clone());
    let target_shape = target_shape.to_vec();
    Ok(LazyExpandedTensor {
        source,
        target_shape,
    })
}
/// Returns a new 1-D tensor holding all elements in logical row-major order.
///
/// # Panics
/// Only if the reshape to `[len]` fails, which cannot happen for the
/// standard-layout copy made here.
pub fn flatten_owned(&self) -> Self {
    let total_elements = self.data.len();
    // `into_shape_with_order` rejects non-standard layouts (e.g. transposed
    // data), so copy into row-major order first; otherwise the `expect`
    // below would panic for such tensors.
    let flattened_data = self
        .data
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order(vec![total_elements])
        .expect("Flatten should always succeed");
    Tensor::new(flattened_data)
}
/// Merges axes `start_dim..=end_dim` into one axis. `end_dim` defaults to the
/// last axis.
///
/// # Errors
/// - `InvalidDimension` if the range is empty, reversed or out of bounds
///   (including any call on a 0-d tensor).
/// - `InvalidOperation` if the reshape fails.
pub fn flatten_range(&self, start_dim: usize, end_dim: Option<usize>) -> RusTorchResult<Self> {
    let shape = self.shape();
    // saturating_sub: for a 0-d tensor `len - 1` would underflow and panic;
    // the range check below rejects that case instead.
    let end_dim = end_dim.unwrap_or_else(|| shape.len().saturating_sub(1));
    if start_dim >= shape.len() || end_dim >= shape.len() || start_dim > end_dim {
        return Err(RusTorchError::InvalidDimension(format!(
            "Invalid dimension range [{}, {}] for tensor with {} dimensions",
            start_dim,
            end_dim,
            shape.len()
        )));
    }
    let flattened_size: usize = shape[start_dim..=end_dim].iter().product();
    let mut new_shape = Vec::with_capacity(shape.len() - (end_dim - start_dim));
    new_shape.extend_from_slice(&shape[..start_dim]);
    new_shape.push(flattened_size);
    new_shape.extend_from_slice(&shape[end_dim + 1..]);
    // Normalise to row-major layout: `into_shape_with_order` rejects
    // non-standard layouts.
    let reshaped_data = self
        .data
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order(new_shape)
        .map_err(|_| RusTorchError::InvalidOperation {
            operation: "flatten_range".to_string(),
            message: "Failed to flatten dimension range".to_string(),
        })?;
    Ok(Tensor::new(reshaped_data))
}
/// Reshapes `self.data` to one axis, replacing it with a reshaped clone.
///
/// # Errors
/// `InvalidOperation` if reshaping the clone fails (attributed to layout
/// constraints). `self` is left unchanged in that case.
pub fn flatten_inplace(&mut self) -> RusTorchResult<()> {
    let element_count = self.data.len();
    let flat = self
        .data
        .clone()
        .into_shape_with_order(vec![element_count])
        .map_err(|_| RusTorchError::InvalidOperation {
            operation: "flatten_inplace".to_string(),
            message: "Cannot perform in-place flatten due to layout constraints".to_string(),
        })?;
    self.data = flat;
    Ok(())
}
/// Flattens a contiguous tensor to one axis; refuses non-contiguous input.
///
/// NOTE(review): the data is still cloned before reshaping, so the result
/// does not share storage with `self`.
///
/// # Errors
/// `InvalidOperation` if the tensor is not contiguous or the reshape fails.
pub fn flatten_view(&self) -> RusTorchResult<Self> {
    if !self.is_contiguous() {
        return Err(RusTorchError::InvalidOperation {
            operation: "flatten_view".to_string(),
            message: "Cannot create zero-copy view from non-contiguous tensor".to_string(),
        });
    }
    let element_count = self.data.len();
    let flat = self
        .data
        .clone()
        .into_shape_with_order(vec![element_count])
        .map_err(|_| RusTorchError::InvalidOperation {
            operation: "flatten_view".to_string(),
            message: "Cannot create view due to non-contiguous layout".to_string(),
        })?;
    Ok(Tensor::new(flat))
}
/// Shared implementation of the squeeze family. Removes every length-1 axis,
/// collapsing to `[1]` when nothing remains.
///
/// - `Owned`: copies into row-major layout first, so it works for any layout.
/// - `ViewOrOwned`: tries `ViewOnly` for contiguous tensors, else `Owned`.
/// - `ViewOnly`: errors for non-contiguous tensors; otherwise clones and reshapes.
///
/// # Errors
/// `InvalidOperation` when the selected path cannot produce the result.
fn squeeze_with_mode(&self, mode: ShapeMode) -> RusTorchResult<Self> {
    let current_shape = self.data.shape();
    let new_shape: Vec<usize> = current_shape
        .iter()
        .filter(|&&dim| dim != 1)
        .copied()
        .collect();
    let final_shape = if new_shape.is_empty() {
        vec![1]
    } else {
        new_shape
    };
    match mode {
        ShapeMode::Owned => {
            // `into_shape_with_order` rejects non-standard layouts. Normalising
            // first makes the owned path infallible for a valid element count,
            // which `squeeze()` relies on via `expect`.
            let reshaped = self
                .data
                .as_standard_layout()
                .into_owned()
                .into_shape_with_order(final_shape)
                .map_err(|_| RusTorchError::InvalidOperation {
                    operation: "squeeze".to_string(),
                    message: "Failed to squeeze tensor".to_string(),
                })?;
            Ok(Tensor::new(reshaped))
        }
        ShapeMode::ViewOrOwned => {
            if self.is_contiguous() {
                self.squeeze_with_mode(ShapeMode::ViewOnly)
                    .or_else(|_| self.squeeze_with_mode(ShapeMode::Owned))
            } else {
                self.squeeze_with_mode(ShapeMode::Owned)
            }
        }
        ShapeMode::ViewOnly => {
            if !self.is_contiguous() {
                return Err(RusTorchError::InvalidOperation {
                    operation: "squeeze_view".to_string(),
                    message: "Cannot create view from non-contiguous tensor".to_string(),
                });
            }
            let reshaped = self
                .data
                .clone()
                .into_shape_with_order(final_shape)
                .map_err(|_| RusTorchError::InvalidOperation {
                    operation: "squeeze_view".to_string(),
                    message: "Failed to create view".to_string(),
                })?;
            Ok(Tensor::new(reshaped))
        }
    }
}
/// Shared implementation of the unsqueeze family: inserts a length-1 axis at
/// `dim`, where `dim` may equal the current rank.
///
/// - `Owned`: copies into row-major layout first, so it works for any layout.
/// - `ViewOrOwned`: tries `ViewOnly` for contiguous tensors, else `Owned`.
/// - `ViewOnly`: errors for non-contiguous tensors; otherwise clones and reshapes.
///
/// # Errors
/// `InvalidDimension` if `dim` exceeds the rank; `InvalidOperation` when the
/// selected path cannot produce the result.
fn unsqueeze_with_mode(&self, dim: usize, mode: ShapeMode) -> RusTorchResult<Self> {
    let mut new_shape = self.data.shape().to_vec();
    if dim > new_shape.len() {
        return Err(RusTorchError::InvalidDimension(format!(
            "Invalid dimension {} (max: {})",
            dim,
            new_shape.len()
        )));
    }
    new_shape.insert(dim, 1);
    match mode {
        ShapeMode::Owned => {
            // `into_shape_with_order` rejects non-standard layouts, so
            // normalise to row-major before reshaping.
            let reshaped = self
                .data
                .as_standard_layout()
                .into_owned()
                .into_shape_with_order(new_shape)
                .map_err(|_| RusTorchError::InvalidOperation {
                    operation: "unsqueeze".to_string(),
                    message: "Failed to unsqueeze tensor".to_string(),
                })?;
            Ok(Tensor::new(reshaped))
        }
        ShapeMode::ViewOrOwned => {
            if self.is_contiguous() {
                self.unsqueeze_with_mode(dim, ShapeMode::ViewOnly)
                    .or_else(|_| self.unsqueeze_with_mode(dim, ShapeMode::Owned))
            } else {
                self.unsqueeze_with_mode(dim, ShapeMode::Owned)
            }
        }
        ShapeMode::ViewOnly => {
            if !self.is_contiguous() {
                return Err(RusTorchError::InvalidOperation {
                    operation: "unsqueeze_view".to_string(),
                    message: "Cannot create view from non-contiguous tensor".to_string(),
                });
            }
            let reshaped = self
                .data
                .clone()
                .into_shape_with_order(new_shape)
                .map_err(|_| RusTorchError::InvalidOperation {
                    operation: "unsqueeze_view".to_string(),
                    message: "Failed to create view".to_string(),
                })?;
            Ok(Tensor::new(reshaped))
        }
    }
}
/// Validates `target_shape` and expands according to `mode`.
///
/// `Owned` and `ViewOrOwned` both materialise a new buffer; `ViewOnly`
/// always errors, because expansion is never done as a view here.
///
/// # Errors
/// `InvalidOperation` for an invalid target shape or for `ViewOnly`.
fn expand_with_mode(&self, target_shape: &[usize], mode: ShapeMode) -> RusTorchResult<Self> {
    self.validate_expansion(target_shape)?;
    match mode {
        ShapeMode::ViewOnly => Err(RusTorchError::InvalidOperation {
            operation: "expand_view".to_string(),
            message: "Expand operation cannot be performed as zero-copy view".to_string(),
        }),
        ShapeMode::Owned | ShapeMode::ViewOrOwned => self.expand_impl(target_shape),
    }
}
/// Materialises the broadcast of `self` to `target_shape` (assumed already
/// validated) into a new row-major tensor.
fn expand_impl(&self, target_shape: &[usize]) -> RusTorchResult<Self> {
    let capacity: usize = target_shape.iter().product();
    let mut buffer = Vec::with_capacity(capacity);
    let origin = vec![0; target_shape.len()];
    self.expand_recursive(&mut buffer, target_shape, &origin, 0)?;
    Ok(Tensor::from_vec(buffer, target_shape.to_vec()))
}
/// Checks that `self` can be expanded to `target_shape`.
///
/// `target_shape` must have at least as many axes as `self`. After aligning
/// the trailing axes, each source extent must be 1 or equal to the target extent.
///
/// # Errors
/// `InvalidOperation` describing the first violating axis (indexed in
/// `target_shape` coordinates).
fn validate_expansion(&self, target_shape: &[usize]) -> RusTorchResult<()> {
    let self_shape = self.shape();
    if self_shape.len() > target_shape.len() {
        return Err(RusTorchError::InvalidOperation {
            operation: "expand".to_string(),
            message: format!(
                "Target shape must have at least {} dimensions, got {}",
                self_shape.len(),
                target_shape.len()
            ),
        });
    }
    let ndim_diff = target_shape.len() - self_shape.len();
    let aligned = &target_shape[ndim_diff..];
    let offending = self_shape
        .iter()
        .zip(aligned)
        .position(|(&self_dim, &target_dim)| self_dim != 1 && self_dim != target_dim);
    match offending {
        None => Ok(()),
        Some(axis) => Err(RusTorchError::InvalidOperation {
            operation: "expand".to_string(),
            message: format!(
                "Cannot expand dimension {} from {} to {} (must be 1 or equal)",
                axis + ndim_diff,
                self_shape[axis],
                aligned[axis]
            ),
        }),
    }
}
/// Depth-first walk over the target index space (row-major). At each leaf
/// it appends the source element that the target index broadcasts from.
fn expand_recursive(
    &self,
    output: &mut Vec<T>,
    target_shape: &[usize],
    indices: &[usize],
    dim: usize,
) -> RusTorchResult<()> {
    if dim != target_shape.len() {
        let mut cursor = indices.to_vec();
        for position in 0..target_shape[dim] {
            cursor[dim] = position;
            self.expand_recursive(output, target_shape, &cursor, dim + 1)?;
        }
        return Ok(());
    }
    let source = self.compute_source_indices(indices)?;
    match self.data.get(source.as_slice()) {
        Some(&value) => {
            output.push(value);
            Ok(())
        }
        None => Err(RusTorchError::index_out_of_bounds(&[], &[])),
    }
}
fn compute_source_indices(&self, target_indices: &[usize]) -> RusTorchResult<Vec<usize>> {
let self_shape = self.shape();
let ndim_diff = target_indices.len() - self_shape.len();
let mut source_indices = Vec::new();
for (i, &target_idx) in target_indices.iter().skip(ndim_diff).enumerate() {
let self_dim = self_shape[i];
if self_dim == 1 {
source_indices.push(0);
} else {
source_indices.push(target_idx % self_dim);
}
}
Ok(source_indices)
}
/// Depth-first walk over the `repeat` output index space (row-major). At each
/// leaf it appends the source element chosen by
/// `compute_repeat_source_indices`.
fn repeat_recursive(
    &self,
    output: &mut Vec<T>,
    output_shape: &[usize],
    repeats: &[usize],
    indices: &[usize],
    dim: usize,
) -> RusTorchResult<()> {
    if dim != output_shape.len() {
        let mut cursor = indices.to_vec();
        for position in 0..output_shape[dim] {
            cursor[dim] = position;
            self.repeat_recursive(output, output_shape, repeats, &cursor, dim + 1)?;
        }
        return Ok(());
    }
    let source = self.compute_repeat_source_indices(indices, output_shape, repeats)?;
    match self.data.get(source.as_slice()) {
        Some(&value) => {
            output.push(value);
            Ok(())
        }
        None => Err(RusTorchError::index_out_of_bounds(&[], &[])),
    }
}
/// Maps an index of the `repeat` output back to a source index.
///
/// Leading output axes beyond `self`'s rank are skipped. For every other axis
/// the source index is `output_idx / repeats[axis]`, so each element is
/// repeated consecutively (e.g. `[1, 2]` with repeats `[3]` gives
/// `[1, 1, 1, 2, 2, 2]`).
///
/// NOTE(review): PyTorch's `Tensor.repeat` tiles instead
/// (`[1, 2, 1, 2, 1, 2]`); this module's tests pin the consecutive behaviour.
/// Confirm which is intended.
///
/// A zero repeat count makes the output axis empty, so this is never called
/// with a zero divisor.
fn compute_repeat_source_indices(
    &self,
    output_indices: &[usize],
    output_shape: &[usize],
    repeats: &[usize],
) -> RusTorchResult<Vec<usize>> {
    let self_rank = self.shape().len();
    let ndim_diff = output_shape.len().saturating_sub(self_rank);
    let source_indices = output_indices
        .iter()
        .enumerate()
        .skip(ndim_diff)
        .take(self_rank)
        .map(|(axis, &output_idx)| output_idx / repeats[axis])
        .collect();
    Ok(source_indices)
}
/// Repeats each slice along axis `dim` `repeats` times consecutively. The
/// output extent of `dim` becomes `shape[dim] * repeats`.
///
/// # Errors
/// `InvalidDimension` if `dim` is not an axis (including any `dim` on a 0-d
/// tensor, which previously underflowed while formatting the message).
fn repeat_interleave_along_dim(&self, repeats: usize, dim: usize) -> RusTorchResult<Self> {
    let shape = self.shape();
    if dim >= shape.len() {
        return Err(RusTorchError::InvalidDimension(format!(
            "Invalid dimension {} (max: {})",
            dim,
            shape.len().saturating_sub(1)
        )));
    }
    let mut output_shape = shape.to_vec();
    output_shape[dim] *= repeats;
    let total_elements: usize = output_shape.iter().product();
    let mut output_data = Vec::with_capacity(total_elements);
    let mut indices = vec![0; output_shape.len()];
    self.repeat_interleave_recursive(
        &mut output_data,
        &output_shape,
        repeats,
        dim,
        &mut indices,
        0,
    )?;
    Ok(Tensor::from_vec(output_data, output_shape))
}
/// Depth-first walk over the interleaved output index space. At each leaf the
/// index along `target_dim` is divided by `repeats` to find the source element.
fn repeat_interleave_recursive(
    &self,
    output: &mut Vec<T>,
    output_shape: &[usize],
    repeats: usize,
    target_dim: usize,
    indices: &mut [usize],
    dim: usize,
) -> RusTorchResult<()> {
    if dim != output_shape.len() {
        for position in 0..output_shape[dim] {
            indices[dim] = position;
            self.repeat_interleave_recursive(
                output,
                output_shape,
                repeats,
                target_dim,
                indices,
                dim + 1,
            )?;
        }
        return Ok(());
    }
    let source: Vec<usize> = indices
        .iter()
        .enumerate()
        .map(|(axis, &idx)| if axis == target_dim { idx / repeats } else { idx })
        .collect();
    match self.data.get(source.as_slice()) {
        Some(&value) => {
            output.push(value);
            Ok(())
        }
        None => Err(RusTorchError::index_out_of_bounds(&[], &[])),
    }
}
/// Rolls elements along `dim` by `shift` positions toward higher indices.
/// `dim` is assumed valid; it indexes the shape directly.
///
/// # Errors
/// `InvalidOperation` if `shift` is not smaller than the axis length.
fn roll_along_dimension(&self, shift: usize, dim: usize) -> RusTorchResult<Self> {
    let shape = self.shape();
    if shift >= shape[dim] {
        return Err(RusTorchError::InvalidOperation {
            operation: "roll".to_string(),
            message: "Shift amount exceeds dimension size".to_string(),
        });
    }
    let mut rolled = Vec::with_capacity(self.data.len());
    let mut cursor = vec![0; shape.len()];
    self.roll_recursive(&mut rolled, shape, shift, dim, &mut cursor, 0)?;
    Ok(Tensor::from_vec(rolled, shape.to_vec()))
}
/// Depth-first walk over the output index space. At each leaf the index along
/// `target_dim` is mapped to `(idx + size - shift) % size` in the source.
fn roll_recursive(
    &self,
    output: &mut Vec<T>,
    shape: &[usize],
    shift: usize,
    target_dim: usize,
    indices: &mut [usize],
    dim: usize,
) -> RusTorchResult<()> {
    if dim != shape.len() {
        for position in 0..shape[dim] {
            indices[dim] = position;
            self.roll_recursive(output, shape, shift, target_dim, indices, dim + 1)?;
        }
        return Ok(());
    }
    let source: Vec<usize> = indices
        .iter()
        .enumerate()
        .map(|(axis, &idx)| {
            if axis == target_dim {
                let dim_size = shape[target_dim];
                (idx + dim_size - shift) % dim_size
            } else {
                idx
            }
        })
        .collect();
    match self.data.get(source.as_slice()) {
        Some(&value) => {
            output.push(value);
            Ok(())
        }
        None => Err(RusTorchError::index_out_of_bounds(&[], &[])),
    }
}
/// Performs a single quarter turn in the (`dim0`, `dim1`) plane. The output
/// shape has those two extents swapped.
fn rot90_once(&self, dim0: usize, dim1: usize) -> RusTorchResult<Self> {
    let original_shape = self.shape();
    let mut rotated_shape = original_shape.to_vec();
    rotated_shape.swap(dim0, dim1);
    let mut buffer = Vec::with_capacity(self.data.len());
    let mut cursor = vec![0; original_shape.len()];
    self.rot90_recursive(
        &mut buffer,
        original_shape,
        &rotated_shape,
        dim0,
        dim1,
        &mut cursor,
        0,
    )?;
    Ok(Tensor::from_vec(buffer, rotated_shape))
}
/// Depth-first walk over the rotated index space. At each leaf it maps the
/// output index `(i, j)` in the (`dim0`, `dim1`) plane to the source index
/// `(original_shape[dim0] - 1 - j, i)`.
///
/// `i` ranges over `original_shape[dim1]` and `j` over `original_shape[dim0]`,
/// so the reflection of `j` must use `original_shape[dim0]`. Using
/// `original_shape[dim1]` underflowed or went out of bounds for non-square
/// planes; square planes are unaffected by the fix.
fn rot90_recursive(
    &self,
    output: &mut Vec<T>,
    original_shape: &[usize],
    new_shape: &[usize],
    dim0: usize,
    dim1: usize,
    indices: &mut [usize],
    dim: usize,
) -> RusTorchResult<()> {
    if dim == new_shape.len() {
        let mut source_indices = indices.to_vec();
        let out_i = indices[dim0];
        let out_j = indices[dim1];
        source_indices[dim0] = original_shape[dim0] - 1 - out_j;
        source_indices[dim1] = out_i;
        if let Some(&value) = self.data.get(source_indices.as_slice()) {
            output.push(value);
        } else {
            return Err(RusTorchError::index_out_of_bounds(&[], &[]));
        }
        return Ok(());
    }
    for i in 0..new_shape[dim] {
        indices[dim] = i;
        self.rot90_recursive(
            output,
            original_shape,
            new_shape,
            dim0,
            dim1,
            indices,
            dim + 1,
        )?;
    }
    Ok(())
}
fn flip_single_dim(&self, dim: usize) -> RusTorchResult<Self> {
let shape = self.shape();
let dim_size = shape[dim];
let mut output_data = Vec::with_capacity(self.data.len());
let mut indices = vec![0; shape.len()];
self.flip_recursive(&mut output_data, shape, dim, &mut indices, 0)?;
Ok(Tensor::from_vec(output_data, shape.to_vec()))
}
/// Depth-first walk over the output index space. At each leaf the index along
/// `flip_dim` is mirrored (`size - 1 - idx`) to find the source element.
fn flip_recursive(
    &self,
    output: &mut Vec<T>,
    shape: &[usize],
    flip_dim: usize,
    indices: &mut [usize],
    dim: usize,
) -> RusTorchResult<()> {
    if dim != shape.len() {
        for position in 0..shape[dim] {
            indices[dim] = position;
            self.flip_recursive(output, shape, flip_dim, indices, dim + 1)?;
        }
        return Ok(());
    }
    let source: Vec<usize> = indices
        .iter()
        .enumerate()
        .map(|(axis, &idx)| {
            if axis == flip_dim {
                shape[flip_dim] - 1 - idx
            } else {
                idx
            }
        })
        .collect();
    match self.data.get(source.as_slice()) {
        Some(&value) => {
            output.push(value);
            Ok(())
        }
        None => Err(RusTorchError::index_out_of_bounds(&[], &[])),
    }
}
/// Returns `true` when the underlying array is in standard (row-major)
/// layout, as reported by ndarray's `is_standard_layout`.
pub fn is_contiguous(&self) -> bool {
    let array = &self.data;
    array.is_standard_layout()
}
/// Returns `true` if the two shapes are broadcast-compatible.
///
/// Shapes are aligned at the trailing axis. Each aligned pair must be equal or
/// contain a 1. Axes present in only one shape always pass, because the
/// shorter shape is implicitly padded with leading 1s.
pub fn can_broadcast_with(&self, other: &Self) -> bool {
    self.shape()
        .iter()
        .rev()
        .zip(other.shape().iter().rev())
        .all(|(&mine, &theirs)| mine == 1 || theirs == 1 || mine == theirs)
}
/// Broadcasts `self` and `other` to their common shape and returns both.
///
/// A length-1 axis stretches to the other side's extent, including 0. An
/// operand already at the common shape is cloned rather than expanded.
///
/// # Errors
/// `InvalidOperation` if the shapes are not broadcast-compatible.
pub fn broadcast_with(&self, other: &Self) -> RusTorchResult<(Self, Self)> {
    if !self.can_broadcast_with(other) {
        return Err(RusTorchError::InvalidOperation {
            operation: "broadcast".to_string(),
            message: format!(
                "Cannot broadcast shapes {:?} and {:?}",
                self.shape(),
                other.shape()
            ),
        });
    }
    let self_shape = self.shape();
    let other_shape = other.shape();
    let max_dims = self_shape.len().max(other_shape.len());
    let mut broadcast_shape = Vec::with_capacity(max_dims);
    for i in 0..max_dims {
        let self_dim = if i < max_dims - self_shape.len() {
            1
        } else {
            self_shape[i - (max_dims - self_shape.len())]
        };
        let other_dim = if i < max_dims - other_shape.len() {
            1
        } else {
            other_shape[i - (max_dims - other_shape.len())]
        };
        // The compatibility check guarantees the dims are equal or one is 1.
        // `max` would turn a (0, 1) pair into 1; the 1 must stretch to 0.
        broadcast_shape.push(if self_dim == 1 { other_dim } else { self_dim });
    }
    let broadcasted_self = if self.shape() == broadcast_shape.as_slice() {
        self.clone()
    } else {
        self.expand_owned(&broadcast_shape)?
    };
    let broadcasted_other = if other.shape() == broadcast_shape.as_slice() {
        other.clone()
    } else {
        other.expand_owned(&broadcast_shape)?
    };
    Ok((broadcasted_self, broadcasted_other))
}
pub fn expand_dims(&self, axis: usize) -> RusTorchResult<Self> {
self.unsqueeze(axis)
}
/// Broadcasts `self` to the shape of `other`, materialising a new buffer.
///
/// # Errors
/// `InvalidOperation` if `other`'s shape is not a valid expansion of `self`'s.
pub fn expand_as(&self, other: &Self) -> RusTorchResult<Self> {
    let target = other.shape();
    self.expand_owned(target)
}
/// Splits axis `dim` into several axes of the given `sizes`. Their product
/// must equal the current extent of `dim`.
///
/// # Errors
/// - `InvalidDimension` if `dim` is not an axis.
/// - `InvalidOperation` if the product of `sizes` does not match, or if the
///   reshape fails.
pub fn unflatten(&self, dim: usize, sizes: &[usize]) -> RusTorchResult<Self> {
    let current_shape = self.shape();
    if dim >= current_shape.len() {
        return Err(RusTorchError::InvalidDimension(format!(
            "Invalid dimension {} (max: {})",
            dim,
            // saturating_sub: avoid underflow for 0-d tensors.
            current_shape.len().saturating_sub(1)
        )));
    }
    let sizes_product: usize = sizes.iter().product();
    if sizes_product != current_shape[dim] {
        return Err(RusTorchError::InvalidOperation {
            operation: "unflatten".to_string(),
            message: format!(
                "Cannot unflatten dimension of size {} into sizes {:?} (product: {})",
                current_shape[dim], sizes, sizes_product
            ),
        });
    }
    let mut new_shape = Vec::with_capacity(current_shape.len() - 1 + sizes.len());
    new_shape.extend_from_slice(&current_shape[..dim]);
    new_shape.extend_from_slice(sizes);
    new_shape.extend_from_slice(&current_shape[dim + 1..]);
    // Normalise to row-major layout: `into_shape_with_order` rejects
    // non-standard layouts.
    let reshaped_data = self
        .data
        .as_standard_layout()
        .into_owned()
        .into_shape_with_order(new_shape)
        .map_err(|_| RusTorchError::InvalidOperation {
            operation: "unflatten".to_string(),
            message: "Failed to unflatten tensor".to_string(),
        })?;
    Ok(Tensor::new(reshaped_data))
}
/// Repeats the tensor along each axis by the matching count in `repeats`.
///
/// If `repeats` is longer than the rank, the shape is left-padded with 1s.
/// If it is shorter, `repeats` is left-padded with 1s. Each output axis has
/// extent `shape[i] * repeats[i]`. Elements are repeated consecutively
/// (source index = output index / repeat count).
/// NOTE(review): unlike PyTorch's tiling `repeat`; confirm intent.
///
/// # Errors
/// Propagates index errors from the element gather.
pub fn repeat(&self, repeats: &[usize]) -> RusTorchResult<Self> {
    let current_shape = self.shape();
    let (adjusted_shape, adjusted_repeats) = match repeats.len().cmp(&current_shape.len()) {
        std::cmp::Ordering::Greater => {
            let mut padded_shape = vec![1; repeats.len() - current_shape.len()];
            padded_shape.extend_from_slice(current_shape);
            (padded_shape, repeats.to_vec())
        }
        std::cmp::Ordering::Less => {
            let mut padded_repeats = vec![1; current_shape.len() - repeats.len()];
            padded_repeats.extend_from_slice(repeats);
            (current_shape.to_vec(), padded_repeats)
        }
        std::cmp::Ordering::Equal => (current_shape.to_vec(), repeats.to_vec()),
    };
    let output_shape: Vec<usize> = adjusted_shape
        .iter()
        .zip(&adjusted_repeats)
        .map(|(&extent, &count)| extent * count)
        .collect();
    let mut output_data = Vec::with_capacity(output_shape.iter().product());
    let origin = vec![0; output_shape.len()];
    self.repeat_recursive(&mut output_data, &output_shape, &adjusted_repeats, &origin, 0)?;
    Ok(Tensor::from_vec(output_data, output_shape))
}
pub fn repeat_interleave_scalar(
&self,
repeats: usize,
dim: Option<usize>,
) -> RusTorchResult<Self> {
match dim {
Some(d) => self.repeat_interleave_along_dim(repeats, d),
None => {
let flattened = self.flatten_owned();
let mut output_data = Vec::new();
for &value in flattened.data.iter() {
for _ in 0..repeats {
output_data.push(value);
}
}
let output_len = output_data.len();
Ok(Tensor::from_vec(output_data, vec![output_len]))
}
}
}
/// Cyclically shifts elements by `shifts` positions (negative shifts roll
/// toward lower indices).
///
/// With `Some(dim)` only that axis is rolled. With `None` the tensor is
/// rolled as if flattened, and the original shape is restored.
/// Rolling an empty axis or an empty tensor returns a clone.
///
/// # Errors
/// `InvalidDimension` if `dim` is not an axis.
pub fn roll_1d(&self, shifts: isize, dim: Option<usize>) -> RusTorchResult<Self> {
    let shape = self.shape();
    match dim {
        Some(d) => {
            if d >= shape.len() {
                return Err(RusTorchError::InvalidDimension(format!(
                    "Invalid dimension {} (max: {})",
                    d,
                    shape.len().saturating_sub(1)
                )));
            }
            let dim_size = shape[d] as isize;
            // `% 0` would panic; rolling an empty axis is a no-op.
            if dim_size == 0 {
                return Ok(self.clone());
            }
            let effective_shift = shifts.rem_euclid(dim_size);
            if effective_shift == 0 {
                return Ok(self.clone());
            }
            self.roll_along_dimension(effective_shift as usize, d)
        }
        None => {
            let flattened = self.flatten_owned();
            let data = flattened
                .data
                .as_slice()
                .expect("flatten_owned yields a standard-layout tensor");
            let len = data.len() as isize;
            // `% 0` would panic; rolling an empty tensor is a no-op.
            if len == 0 {
                return Ok(self.clone());
            }
            let effective_shift = shifts.rem_euclid(len);
            if effective_shift == 0 {
                return Ok(self.clone());
            }
            let mut output_data = data.to_vec();
            output_data.rotate_right(effective_shift as usize);
            let rolled_flat = Tensor::from_vec(output_data, vec![data.len()]);
            rolled_flat.view_shape(shape)
        }
    }
}
/// Rotates by `k` quarter turns in the plane spanned by `dims[0]` and
/// `dims[1]`. `k` is taken modulo 4, so negative values rotate the other way.
///
/// NOTE(review): one quarter turn maps output `(i, j)` to input
/// `(n - 1 - j, i)` in that plane (see `rot90_recursive`). For positive `k`
/// that is the opposite direction to `torch.rot90`; confirm intent.
///
/// # Errors
/// - `InvalidOperation` unless exactly two distinct dims are given.
/// - `InvalidDimension` if either dim is out of range.
pub fn rot90(&self, k: isize, dims: &[usize]) -> RusTorchResult<Self> {
    if dims.len() != 2 {
        return Err(RusTorchError::InvalidOperation {
            operation: "rot90".to_string(),
            message: "rot90 requires exactly 2 dimensions".to_string(),
        });
    }
    let shape = self.shape();
    let (dim0, dim1) = (dims[0], dims[1]);
    if dim0 >= shape.len() || dim1 >= shape.len() {
        return Err(RusTorchError::InvalidDimension(format!(
            "Invalid dimensions [{}, {}] (max: {})",
            dim0,
            dim1,
            // saturating_sub: avoid underflow for 0-d tensors.
            shape.len().saturating_sub(1)
        )));
    }
    if dim0 == dim1 {
        return Err(RusTorchError::InvalidOperation {
            operation: "rot90".to_string(),
            message: "Rotation dimensions must be different".to_string(),
        });
    }
    match k.rem_euclid(4) {
        0 => Ok(self.clone()),
        turns => (1..turns).try_fold(self.rot90_once(dim0, dim1)?, |rotated, _| {
            rotated.rot90_once(dim0, dim1)
        }),
    }
}
/// Reverses element order along each axis in `dims`, applied in sequence.
/// An empty `dims` returns a clone. Listing an axis twice flips it back.
///
/// # Errors
/// `InvalidDimension` for the first entry of `dims` that is not an axis.
pub fn flip(&self, dims: &[usize]) -> RusTorchResult<Self> {
    let ndim = self.shape().len();
    if let Some(&bad) = dims.iter().find(|&&dim| dim >= ndim) {
        return Err(RusTorchError::InvalidDimension(format!(
            "Invalid dimension {} (max: {})",
            bad,
            // saturating_sub: avoid underflow for 0-d tensors.
            ndim.saturating_sub(1)
        )));
    }
    dims.iter()
        .try_fold(self.clone(), |flipped, &dim| flipped.flip_single_dim(dim))
}
/// Reverses element order along the last axis; requires at least 2 axes.
///
/// NOTE(review): `numpy.fliplr` / `torch.fliplr` reverse axis 1. This
/// reverses the last axis, which only coincides for 2-D tensors.
///
/// # Errors
/// `InvalidOperation` for tensors with fewer than 2 axes.
pub fn fliplr(&self) -> RusTorchResult<Self> {
    let rank = self.shape().len();
    if rank >= 2 {
        return self.flip(&[rank - 1]);
    }
    Err(RusTorchError::InvalidOperation {
        operation: "fliplr".to_string(),
        message: "fliplr requires at least 2D tensor".to_string(),
    })
}
/// Reverses element order along axis 0; requires at least 1 axis.
///
/// # Errors
/// `InvalidOperation` for 0-d tensors.
pub fn flipud(&self) -> RusTorchResult<Self> {
    let has_axis = !self.shape().is_empty();
    if has_axis {
        return self.flip(&[0]);
    }
    Err(RusTorchError::InvalidOperation {
        operation: "flipud".to_string(),
        message: "flipud requires at least 1D tensor".to_string(),
    })
}
}
/// A deferred broadcast of a source tensor to `target_shape`.
///
/// Elements are computed on demand; the full buffer is only built when the
/// expansion is materialised. Cloning is cheap: it bumps the `Arc` refcount
/// and copies the shape vector.
#[derive(Clone)]
pub struct LazyExpandedTensor<T: Float> {
    /// The tensor being expanded, shared via `Arc`.
    source: Arc<Tensor<T>>,
    /// Shape the source is virtually expanded to (validated by `expand_lazy`).
    target_shape: Vec<usize>,
}
impl<T: Float + Clone + 'static> LazyExpandedTensor<T> {
    /// Builds the fully expanded tensor as a new owned buffer.
    ///
    /// # Errors
    /// Propagates errors from `expand_owned`.
    pub fn materialize(&self) -> RusTorchResult<Tensor<T>> {
        self.source.expand_owned(&self.target_shape)
    }

    /// The (virtual) shape of the expanded tensor.
    pub fn shape(&self) -> &[usize] {
        &self.target_shape
    }

    /// Reads one element of the expanded tensor without materialising it.
    ///
    /// # Errors
    /// Returns an index-out-of-bounds error if `indices` has the wrong number of
    /// axes or any index is not below the matching target extent. Without the
    /// per-axis check, out-of-range indices were silently wrapped by the modulo
    /// in `compute_source_indices` and returned a wrong element.
    pub fn get(&self, indices: &[usize]) -> RusTorchResult<T> {
        let in_bounds = indices.len() == self.target_shape.len()
            && indices
                .iter()
                .zip(&self.target_shape)
                .all(|(&idx, &extent)| idx < extent);
        if !in_bounds {
            return Err(RusTorchError::index_out_of_bounds(
                indices,
                &self.target_shape,
            ));
        }
        let source_indices = self.source.compute_source_indices(indices)?;
        self.source
            .data
            .get(source_indices.as_slice())
            .copied()
            .ok_or_else(|| RusTorchError::index_out_of_bounds(&[], &[]))
    }
}
impl<T: Float + Clone + 'static> Tensor<T> {
    /// Reshapes to `shape`, which must have the same element count.
    ///
    /// The data is cloned, then reshaped in row-major order.
    ///
    /// # Errors
    /// `InvalidOperation` if the element counts differ or the reshape fails.
    pub fn view_shape(&self, shape: &[usize]) -> RusTorchResult<Self> {
        let current_elements = self.data.len();
        let requested_elements: usize = shape.iter().product();
        if requested_elements != current_elements {
            return Err(RusTorchError::InvalidOperation {
                operation: "view_shape".to_string(),
                message: format!(
                    "Shape {:?} is invalid for tensor with {} elements",
                    shape, current_elements
                ),
            });
        }
        let reshaped = self
            .data
            .clone()
            .into_shape_with_order(shape.to_vec())
            .map_err(|_| RusTorchError::InvalidOperation {
                operation: "view_shape".to_string(),
                message: "Failed to create view with specified shape".to_string(),
            })?;
        Ok(Tensor::new(reshaped))
    }

    /// Starts a chained shape-manipulation builder that takes ownership of `self`.
    pub fn shape_builder(self) -> ShapeBuilder<T> {
        let builder = ShapeBuilder::new(self);
        builder
    }
}
/// Shape operations restricted to the contiguity-checked path. They report an
/// error instead of falling back to a copying path.
///
/// NOTE(review): the `Tensor` implementation in this module still clones the
/// array data on that path, so "zero-alloc" describes intent. Confirm before
/// relying on it for performance.
pub trait ZeroAllocShapeOps<T: Float> {
    /// Inserts a length-1 axis at `dim`. Expected to error for non-contiguous
    /// tensors or an out-of-range `dim`.
    fn try_unsqueeze_view(&self, dim: usize) -> RusTorchResult<Tensor<T>>;
    /// Removes every length-1 axis. Expected to error for non-contiguous tensors.
    fn try_squeeze_view(&self) -> RusTorchResult<Tensor<T>>;
}
impl<T: Float + Clone + 'static> ZeroAllocShapeOps<T> for Tensor<T> {
    /// Delegates to the `ViewOnly` squeeze path.
    fn try_squeeze_view(&self) -> RusTorchResult<Tensor<T>> {
        let mode = ShapeMode::ViewOnly;
        self.squeeze_with_mode(mode)
    }

    /// Delegates to the `ViewOnly` unsqueeze path.
    fn try_unsqueeze_view(&self, dim: usize) -> RusTorchResult<Tensor<T>> {
        let mode = ShapeMode::ViewOnly;
        self.unsqueeze_with_mode(dim, mode)
    }
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_ownership_patterns_squeeze() {
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3, 1]);
let squeezed_owned = tensor.squeeze();
assert_eq!(squeezed_owned.shape(), &[3]);
let squeezed_view = tensor.squeeze_view().unwrap();
assert_eq!(squeezed_view.shape(), &[3]);
let mut tensor_mut = tensor.clone();
tensor_mut.squeeze_inplace().unwrap();
assert_eq!(tensor_mut.shape(), &[3]);
}
#[test]
fn test_ownership_patterns_unsqueeze() {
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
let unsqueezed = tensor.unsqueeze(0).unwrap();
assert_eq!(unsqueezed.shape(), &[1, 3]);
let unsqueezed_view = tensor.unsqueeze_view(1).unwrap();
assert_eq!(unsqueezed_view.shape(), &[3, 1]);
let mut tensor_mut = tensor.clone();
tensor_mut.unsqueeze_inplace(0).unwrap();
assert_eq!(tensor_mut.shape(), &[1, 3]);
}
#[test]
fn test_expand_shared_ownership() {
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
let expanded_shared = tensor.expand_shared(&[4, 3]).unwrap();
assert_eq!(expanded_shared.shape(), &[4, 3]);
let ref1 = Arc::clone(&expanded_shared);
let ref2 = Arc::clone(&expanded_shared);
assert_eq!(ref1.shape(), ref2.shape());
}
#[test]
fn test_lazy_expand() {
let tensor = Tensor::from_vec(vec![1.0, 2.0], vec![1, 2]);
let lazy_expanded = tensor.expand_lazy(&[3, 2]).unwrap();
assert_eq!(lazy_expanded.shape(), &[3, 2]);
assert_eq!(lazy_expanded.get(&[0, 0]).unwrap(), 1.0);
assert_eq!(lazy_expanded.get(&[1, 1]).unwrap(), 2.0);
assert_eq!(lazy_expanded.get(&[2, 0]).unwrap(), 1.0);
let materialized = lazy_expanded.materialize().unwrap();
assert_eq!(materialized.shape(), &[3, 2]);
}
#[test]
fn test_shape_builder_chain() {
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![1, 2, 2, 1]);
let result = tensor
.shape_builder()
.squeeze()
.unwrap() .unsqueeze(0)
.unwrap() .expand(&[3, 2, 2])
.unwrap() .flatten()
.unwrap() .build();
assert_eq!(result.shape(), &[12]);
assert_eq!(result.numel(), 12);
}
#[test]
fn test_flatten_variants() {
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
let flattened = tensor.flatten_owned();
assert_eq!(flattened.shape(), &[4]);
let partial = tensor.flatten_range(0, Some(1)).unwrap();
assert_eq!(partial.shape(), &[4]);
if tensor.is_contiguous() {
let flattened_view = tensor.flatten_view().unwrap();
assert_eq!(flattened_view.shape(), &[4]);
}
let mut tensor_mut = tensor.clone();
tensor_mut.flatten_inplace().unwrap();
assert_eq!(tensor_mut.shape(), &[4]);
}
#[test]
fn test_zero_alloc_traits() {
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![1, 3]);
if tensor.is_contiguous() {
let squeezed_view = tensor.try_squeeze_view().unwrap();
assert_eq!(squeezed_view.shape(), &[3]);
let unsqueezed_view = tensor.try_unsqueeze_view(0).unwrap();
assert_eq!(unsqueezed_view.shape(), &[1, 1, 3]);
}
}
#[test]
fn test_expand_as() {
let tensor = Tensor::from_vec(vec![1.0, 2.0], vec![1, 2]);
let target = Tensor::from_vec(vec![0.0; 6], vec![3, 2]);
let expanded = tensor.expand_as(&target).unwrap();
assert_eq!(expanded.shape(), target.shape());
assert_eq!(expanded.shape(), &[3, 2]);
let expanded_data = expanded.data.as_slice().unwrap();
assert_eq!(expanded_data, &[1.0, 2.0, 1.0, 2.0, 1.0, 2.0]);
}
#[test]
fn test_unflatten() {
let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![6]);
let unflattened = tensor.unflatten(0, &[2, 3]).unwrap();
assert_eq!(unflattened.shape(), &[2, 3]);
let unflattened2 = tensor.unflatten(0, &[3, 2]).unwrap();
assert_eq!(unflattened2.shape(), &[3, 2]);
let tensor2d = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
let unflattened3d = tensor2d.unflatten(1, &[1, 2]).unwrap();
assert_eq!(unflattened3d.shape(), &[2, 1, 2]);
let result = tensor.unflatten(0, &[2, 4]); assert!(result.is_err());
}
#[test]
fn test_repeat() {
let tensor = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
let repeated = tensor.repeat(&[3]).unwrap();
assert_eq!(repeated.shape(), &[6]);
assert_eq!(
repeated.data.as_slice().unwrap(),
&[1.0, 1.0, 1.0, 2.0, 2.0, 2.0]
);
let tensor2d = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
let repeated2d = tensor2d.repeat(&[2, 3]).unwrap();
assert_eq!(repeated2d.shape(), &[4, 6]);
let repeated_padded = tensor.repeat(&[2, 3]).unwrap();
assert_eq!(repeated_padded.shape(), &[2, 6]);
}
#[test]
fn test_repeat_interleave() {
    let source = Tensor::from_vec(vec![1.0, 2.0, 3.0], vec![3]);
    let expected = [1.0, 1.0, 2.0, 2.0, 3.0, 3.0];

    // For a 1-D input, interleaving along axis 0 and with no axis given are
    // both expected to duplicate every element in place.
    for &axis in &[Some(0), None] {
        let interleaved = source.repeat_interleave_scalar(2, axis).unwrap();
        assert_eq!(interleaved.shape(), &[6]);
        assert_eq!(interleaved.data.as_slice().unwrap(), &expected);
    }
}
#[test]
fn test_roll() {
    let seq = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4]);

    // A positive shift moves the last element around to the front.
    let forward = seq.roll_1d(1, Some(0)).unwrap();
    assert_eq!(forward.shape(), &[4]);
    assert_eq!(forward.data.as_slice().unwrap(), &[4.0, 1.0, 2.0, 3.0]);

    // A negative shift moves the first element around to the back.
    let backward = seq.roll_1d(-1, Some(0)).unwrap();
    assert_eq!(backward.data.as_slice().unwrap(), &[2.0, 3.0, 4.0, 1.0]);

    // Rolling along an axis keeps the shape of a 2-D input.
    let grid = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    let grid_rolled = grid.roll_1d(1, Some(0)).unwrap();
    assert_eq!(grid_rolled.shape(), &[2, 2]);

    // With no axis given, the 1-D input is expected to match the axis-0 result.
    let unaxed = seq.roll_1d(1, None).unwrap();
    assert_eq!(unaxed.shape(), &[4]);
    assert_eq!(unaxed.data.as_slice().unwrap(), &[4.0, 1.0, 2.0, 3.0]);
}
#[test]
fn test_rot90() {
    let square = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    let plane = [0, 1];

    // A square input keeps its shape for one, two and three quarter turns.
    for quarter_turns in 1..=3 {
        let turned = square.rot90(quarter_turns, &plane).unwrap();
        assert_eq!(turned.shape(), &[2, 2]);
    }

    // Four quarter turns are a full rotation and must reproduce the input.
    let full_circle = square.rot90(4, &plane).unwrap();
    assert_eq!(full_circle.shape(), &[2, 2]);
    assert_eq!(
        full_circle.data.as_slice().unwrap(),
        square.data.as_slice().unwrap()
    );

    // Negative turn counts are accepted.
    let counter = square.rot90(-1, &plane).unwrap();
    assert_eq!(counter.shape(), &[2, 2]);

    // Rejected axis arguments: a single axis, a repeated axis, an out-of-range axis.
    assert!(square.rot90(1, &[0]).is_err());
    assert!(square.rot90(1, &[0, 0]).is_err());
    assert!(square.rot90(1, &[0, 5]).is_err());
}
#[test]
fn test_flip() {
    let square = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);

    // Flipping any valid combination of axes preserves the shape.
    let axis_sets: [&[usize]; 3] = [&[0], &[1], &[0, 1]];
    for &axes in axis_sets.iter() {
        let flipped = square.flip(axes).unwrap();
        assert_eq!(flipped.shape(), &[2, 2]);
    }

    // Flipping no axes leaves the data untouched.
    let identity = square.flip(&[]).unwrap();
    assert_eq!(
        identity.data.as_slice().unwrap(),
        square.data.as_slice().unwrap()
    );

    // An axis beyond the tensor's rank is rejected.
    assert!(square.flip(&[5]).is_err());
}
#[test]
fn test_fliplr_flipud() {
    let square = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2]);
    let left_right = square.fliplr().unwrap();
    assert_eq!(left_right.shape(), &[2, 2]);
    let up_down = square.flipud().unwrap();
    assert_eq!(up_down.shape(), &[2, 2]);

    // A 1-D tensor is expected to be rejected by fliplr...
    let line = Tensor::from_vec(vec![1.0, 2.0], vec![2]);
    assert!(line.fliplr().is_err());

    // ...and a 0-D tensor is expected to be rejected by flipud.
    let scalar = Tensor::from_vec(vec![1.0], vec![]);
    assert!(scalar.flipud().is_err());
}
#[test]
fn test_complex_shape_operations() {
    let matrix = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);

    // [2, 3] -unsqueeze(0)-> [1, 2, 3] -expand-> [2, 2, 3] -flatten-> [12]
    let builder = matrix.clone().shape_builder();
    let builder = builder.unsqueeze(0).unwrap();
    let builder = builder.expand(&[2, 2, 3]).unwrap();
    let chained = builder.flatten().unwrap().build();
    assert_eq!(chained.shape(), &[12]);
    assert_eq!(chained.numel(), 12);

    // flatten followed by unflatten restores the original [2, 3] shape.
    let restored = matrix.flatten_owned().unflatten(0, &[2, 3]).unwrap();
    assert_eq!(restored.shape(), &[2, 3]);

    // A [1, 2] row broadcast to [4, 2] repeats the row four times.
    let row = Tensor::from_vec(vec![1.0, 2.0], vec![1, 2]);
    let shape_source = Tensor::from_vec(vec![0.0; 8], vec![4, 2]);
    let broadcast = row.expand_as(&shape_source).unwrap();
    assert_eq!(broadcast.shape(), &[4, 2]);
    assert_eq!(
        broadcast.data.as_slice().unwrap(),
        &[1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0]
    );
}
}
/// Fluent builder that threads an owned tensor through a chain of shape operations.
///
/// Each step consumes the builder and hands back `RusTorchResult<Self>`, so a chain
/// can be written with `?` or `.unwrap()` between steps. Call `build()` at the end
/// to take the resulting tensor back out.
#[must_use = "dropping a ShapeBuilder discards the transformed tensor; call `build()` to retrieve it"]
pub struct ShapeBuilder<T: Float> {
    // The tensor in its current state, after all steps applied so far.
    tensor: Tensor<T>,
}
impl<T: Float + Clone + 'static> ShapeBuilder<T> {
pub fn new(tensor: Tensor<T>) -> Self {
Self { tensor }
}
pub fn squeeze(mut self) -> RusTorchResult<Self> {
self.tensor = self.tensor.squeeze();
Ok(self)
}
pub fn squeeze_dim(mut self, dim: usize) -> RusTorchResult<Self> {
self.tensor = self.tensor.squeeze_dim(dim)?;
Ok(self)
}
pub fn unsqueeze(mut self, dim: usize) -> RusTorchResult<Self> {
self.tensor = self.tensor.unsqueeze(dim)?;
Ok(self)
}
pub fn expand_as(mut self, other: &Tensor<T>) -> RusTorchResult<Self> {
self.tensor = self.tensor.expand_as(other)?;
Ok(self)
}
pub fn expand(mut self, shape: &[usize]) -> RusTorchResult<Self> {
self.tensor = self.tensor.expand_owned(shape)?;
Ok(self)
}
pub fn flatten(mut self) -> RusTorchResult<Self> {
self.tensor = self.tensor.flatten_owned();
Ok(self)
}
pub fn flatten_range(
mut self,
start_dim: usize,
end_dim: Option<usize>,
) -> RusTorchResult<Self> {
self.tensor = self.tensor.flatten_range(start_dim, end_dim)?;
Ok(self)
}
pub fn unflatten(mut self, dim: usize, sizes: &[usize]) -> RusTorchResult<Self> {
self.tensor = self.tensor.unflatten(dim, sizes)?;
Ok(self)
}
pub fn repeat(mut self, repeats: &[usize]) -> RusTorchResult<Self> {
self.tensor = self.tensor.repeat(repeats)?;
Ok(self)
}
pub fn repeat_interleave(mut self, repeats: usize, dim: Option<usize>) -> RusTorchResult<Self> {
self.tensor = self.tensor.repeat_interleave_scalar(repeats, dim)?;
Ok(self)
}
pub fn roll(mut self, shifts: isize, dim: Option<usize>) -> RusTorchResult<Self> {
self.tensor = self.tensor.roll_1d(shifts, dim)?;
Ok(self)
}
pub fn rot90(mut self, k: isize, dims: &[usize]) -> RusTorchResult<Self> {
self.tensor = self.tensor.rot90(k, dims)?;
Ok(self)
}
pub fn flip(mut self, dims: &[usize]) -> RusTorchResult<Self> {
self.tensor = self.tensor.flip(dims)?;
Ok(self)
}
pub fn fliplr(mut self) -> RusTorchResult<Self> {
self.tensor = self.tensor.fliplr()?;
Ok(self)
}
pub fn flipud(mut self) -> RusTorchResult<Self> {
self.tensor = self.tensor.flipud()?;
Ok(self)
}
pub fn view_shape(mut self, shape: &[usize]) -> RusTorchResult<Self> {
self.tensor = self.tensor.view_shape(shape)?;
Ok(self)
}
pub fn build(self) -> Tensor<T> {
self.tensor
}
pub fn current_shape(&self) -> &[usize] {
self.tensor.shape()
}
pub fn peek(&self) -> &Tensor<T> {
&self.tensor
}
}
/// Applies a chain of `ShapeBuilder` operations to a tensor in one expression.
///
/// Each operation is a `ShapeBuilder` method name, optionally followed by its
/// arguments in parentheses. The whole invocation evaluates to a
/// `RusTorchResult` wrapping whatever `build()` returns. The first failing step
/// stops the chain, and its error becomes the macro's value. The caller
/// decides how to handle it (`?`, `unwrap`, `match`, ...). A failing step does
/// not return from the enclosing function by itself.
///
/// ```ignore
/// let flat = shape_ops!(tensor, squeeze, unsqueeze(0), flatten)?;
/// ```
#[macro_export]
macro_rules! shape_ops {
    ($tensor:expr, $($op:ident$(($($arg:expr),*))?),+ $(,)?) => {{
        // `loop` + `break value` gives a local early-exit scope. A failing step
        // therefore yields an `Err` here instead of `?` returning from whatever
        // function the macro happens to be expanded in. The loop body always
        // breaks, so it runs exactly once.
        #[allow(clippy::never_loop)]
        let result: $crate::error::RusTorchResult<_> = loop {
            let mut builder = $tensor.shape_builder();
            $(
                builder = match builder.$op($($($arg),*)?) {
                    Ok(next) => next,
                    Err(err) => {
                        break Err(<$crate::error::RusTorchError as From<_>>::from(err))
                    }
                };
            )+
            break Ok(builder.build());
        };
        result
    }};
}
pub trait ShapeOps<T: Float> {
fn shapes(self) -> ShapeBuilder<T>;
}
impl<T: Float + Clone + 'static> ShapeOps<T> for Tensor<T> {
fn shapes(self) -> ShapeBuilder<T> {
ShapeBuilder::new(self)
}
}
#[cfg(test)]
mod builder_tests {
    use super::*;

    /// The four-element `[2, 2, 1]` tensor shared by several tests below.
    fn sample_with_trailing_unit_axis() -> Tensor<f64> {
        Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![2, 2, 1])
    }

    #[test]
    fn test_builder_pattern_basic() {
        // squeeze, unsqueeze(1) and flatten should leave a 1-D tensor of all four elements.
        let after_squeeze = sample_with_trailing_unit_axis()
            .shape_builder()
            .squeeze()
            .unwrap();
        let after_unsqueeze = after_squeeze.unsqueeze(1).unwrap();
        let flat = after_unsqueeze.flatten().unwrap().build();
        assert_eq!(flat.shape(), &[4]);
    }

    #[test]
    fn test_fluent_interface() {
        // Squeezing [4, 1] and re-inserting the axis at position 1 round-trips the shape.
        let column = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], vec![4, 1]);
        let squeezed = column.shapes().squeeze().unwrap();
        let round_trip = squeezed.unsqueeze(1).unwrap().build();
        assert_eq!(round_trip.shape(), &[4, 1]);
    }

    #[test]
    fn test_shape_ops_macro() -> RusTorchResult<()> {
        let flat = shape_ops!(sample_with_trailing_unit_axis(), squeeze, flatten)?;
        assert_eq!(flat.shape(), &[4]);
        Ok(())
    }

    #[test]
    fn test_builder_peek() {
        let builder = sample_with_trailing_unit_axis()
            .shape_builder()
            .squeeze()
            .unwrap();
        // Inspecting the intermediate shape borrows the builder without consuming it.
        assert_eq!(builder.current_shape(), &[2, 2]);
        let flat = builder.flatten().unwrap().build();
        assert_eq!(flat.shape(), &[4]);
    }
}