use std::sync::Arc;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
pub fn as_strided<T: Float>(
input: &Tensor<T>,
size: &[usize],
stride: &[isize],
storage_offset: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
input.as_strided(size, stride, storage_offset)
}
pub fn as_strided_copy<T: Float>(
input: &Tensor<T>,
size: &[usize],
stride: &[isize],
storage_offset: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
input.as_strided_copy(size, stride, storage_offset)
}
pub fn as_strided_scatter<T: Float>(
input: &Tensor<T>,
src: &Tensor<T>,
size: &[usize],
stride: &[isize],
storage_offset: Option<usize>,
) -> FerrotorchResult<Tensor<T>> {
input.as_strided_scatter(src, size, stride, storage_offset)
}
/// Computes how far a strided view can reach, relative to its starting offset.
///
/// Returns `(min_off, max_off)`: the most negative and the most positive
/// element offset addressed by a view with the given `shape` and `stride`.
/// `min_off` is always `<= 0` and `max_off` is always `>= 0`.
///
/// * A view with any zero-sized dimension addresses no elements, so `(0, 0)`
///   is returned regardless of the strides.
/// * Dimensions and strides are paired positionally; if the slices differ in
///   length the surplus entries of the longer one are ignored.
/// * All arithmetic saturates at the `i64` limits instead of overflowing, so
///   an absurdly large shape/stride combination yields an extreme extent
///   rather than a panic (debug builds) or a wrapped, misleadingly small value
///   (release builds).
fn stride_extent(shape: &[usize], stride: &[isize]) -> (i64, i64) {
    if shape.contains(&0) {
        return (0, 0);
    }
    shape
        .iter()
        .zip(stride)
        .fold((0i64, 0i64), |(min_off, max_off), (&dim, &step)| {
            // Clamp before casting so a dimension above `i64::MAX` cannot wrap
            // negative. `dim >= 1` here, so the subtraction cannot underflow.
            let last_index = (dim as u64).min(i64::MAX as u64) as i64 - 1;
            // `isize` is never wider than 64 bits, so this cast is lossless.
            let reach = last_index.saturating_mul(step as i64);
            if reach >= 0 {
                (min_off, max_off.saturating_add(reach))
            } else {
                (min_off.saturating_add(reach), max_off)
            }
        })
}
/// Checks that a strided view stays inside a storage of `storage_len` elements.
///
/// `op` is the operation name used as the prefix of every error message.
///
/// # Errors
///
/// Returns [`FerrotorchError::InvalidArgument`] when:
/// * `shape` and `stride` have different lengths;
/// * the view has a zero-sized dimension and `storage_offset` is greater than
///   `storage_len` (an empty view addresses no elements, so only the offset
///   itself is checked);
/// * the lowest offset reached by the view is negative;
/// * the highest offset reached by the view is `>= storage_len`.
fn validate_bounds(
    op: &'static str,
    shape: &[usize],
    stride: &[isize],
    storage_offset: usize,
    storage_len: usize,
) -> FerrotorchResult<()> {
    if shape.len() != stride.len() {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "{op}: shape and stride must have the same length (got {} vs {})",
                shape.len(),
                stride.len()
            ),
        });
    }
    if shape.contains(&0) {
        if storage_offset > storage_len {
            return Err(FerrotorchError::InvalidArgument {
                message: format!(
                    "{op}: storage_offset {storage_offset} > storage length {storage_len}"
                ),
            });
        }
        return Ok(());
    }
    let (min_off, max_off) = stride_extent(shape, stride);
    // Do the comparison in i128: every `usize` and every `i64` fits, so neither
    // the offset conversion nor the additions can wrap. With 64-bit arithmetic
    // a huge offset plus a positive extent could wrap negative in release
    // builds and slip past the upper-bound check.
    let base = storage_offset as i128;
    let lo = base + i128::from(min_off);
    let hi = base + i128::from(max_off);
    if lo < 0 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "{op}: storage_offset {storage_offset} with strides {stride:?} reaches negative \
                 offset {lo} (out of bounds)"
            ),
        });
    }
    if hi >= storage_len as i128 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!(
                "{op}: storage_offset {storage_offset} with shape {shape:?} and strides \
                 {stride:?} reaches offset {hi}, beyond storage length {storage_len}"
            ),
        });
    }
    Ok(())
}
impl<T: Float> Tensor<T> {
    /// Builds a strided view of this tensor with the given `size`, `stride`
    /// and `storage_offset`.
    ///
    /// When `storage_offset` is `None`, this tensor's own storage offset is
    /// used. The requested view is validated against `self.storage_len()`
    /// before anything is constructed.
    ///
    /// If gradient tracking is enabled and this tensor requires grad, the
    /// result is created through `stride_view_operation` with an
    /// [`AsStridedBackward`] node attached; otherwise it is created through
    /// `stride_view`.
    ///
    /// # Errors
    ///
    /// Returns [`FerrotorchError::InvalidArgument`] when `size` and `stride`
    /// differ in length or the view would reach outside the storage.
    pub fn as_strided(
        &self,
        size: &[usize],
        stride: &[isize],
        storage_offset: Option<usize>,
    ) -> FerrotorchResult<Tensor<T>> {
        let offset = match storage_offset {
            Some(explicit) => explicit,
            None => self.storage_offset(),
        };
        validate_bounds("as_strided", size, stride, offset, self.storage_len())?;

        let track_grad = crate::autograd::no_grad::is_grad_enabled() && self.requires_grad();
        if track_grad {
            let grad_fn = Arc::new(AsStridedBackward::new(
                self.clone(),
                size.to_vec(),
                stride.to_vec(),
                offset,
            ));
            Ok(self.stride_view_operation(size.to_vec(), stride.to_vec(), offset, grad_fn))
        } else {
            Ok(self.stride_view(size.to_vec(), stride.to_vec(), offset))
        }
    }

    /// Like [`Tensor::as_strided`], but returns a freshly materialized tensor
    /// of shape `size` instead of the view itself.
    ///
    /// The view is built first (so the same validation and errors apply). A
    /// CUDA view is materialized by `materialize_strided_cuda`; otherwise the
    /// view's elements are read with `data_vec` and wrapped in new CPU storage.
    pub fn as_strided_copy(
        &self,
        size: &[usize],
        stride: &[isize],
        storage_offset: Option<usize>,
    ) -> FerrotorchResult<Tensor<T>> {
        let strided = self.as_strided(size, stride, storage_offset)?;
        if strided.is_cuda() {
            materialize_strided_cuda(&strided)
        } else {
            let elements = strided.data_vec()?;
            Tensor::from_storage(TensorStorage::cpu(elements), size.to_vec(), false)
        }
    }

    /// Returns a new tensor with this tensor's shape whose elements are this
    /// tensor's elements, except that the positions addressed by the strided
    /// view (`size`, `stride`, `storage_offset`) are overwritten with `src`.
    ///
    /// `storage_offset` defaults to `0`, and the view is validated against
    /// `self.numel()`: offsets index into the element buffer returned by
    /// `data_vec`. `src` is consumed in row-major order of `size`. When the
    /// view addresses a position more than once, later elements of `src`
    /// overwrite earlier ones.
    ///
    /// # Errors
    ///
    /// * [`FerrotorchError::InvalidArgument`] if the view is out of bounds or
    ///   `size`/`stride` differ in length.
    /// * [`FerrotorchError::ShapeMismatch`] if `src`'s shape is not `size`.
    /// * [`FerrotorchError::DeviceMismatch`] if exactly one of `self`/`src`
    ///   is a CUDA tensor.
    pub fn as_strided_scatter(
        &self,
        src: &Tensor<T>,
        size: &[usize],
        stride: &[isize],
        storage_offset: Option<usize>,
    ) -> FerrotorchResult<Tensor<T>> {
        let offset = storage_offset.unwrap_or(0);
        validate_bounds("as_strided_scatter", size, stride, offset, self.numel())?;

        // Element-wise comparison that also fails on differing rank.
        let shapes_agree = size.iter().eq(src.shape());
        if !shapes_agree {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "as_strided_scatter: src shape {:?} does not match requested view shape {size:?}",
                    src.shape()
                ),
            });
        }

        let on_cuda = self.is_cuda();
        if on_cuda != src.is_cuda() {
            return Err(FerrotorchError::DeviceMismatch {
                expected: self.device(),
                got: src.device(),
            });
        }
        if on_cuda {
            return scatter_on_cuda(self, src, size, stride, offset);
        }

        let mut buf = self.data_vec()?;
        let src_data = src.data_vec()?;
        let numel: usize = size.iter().product();
        if numel == 0 {
            return Tensor::from_storage(TensorStorage::cpu(buf), self.shape().to_vec(), false);
        }

        // Multi-index into the view, advanced like an odometer (last dimension
        // moves fastest) so `src` is consumed in row-major order.
        let mut indices = vec![0usize; size.len()];
        // The counter doubles as the read position in `src_data`.
        #[allow(clippy::needless_range_loop)]
        for src_i in 0..numel {
            let flat = indices
                .iter()
                .zip(stride)
                .fold(offset as i64, |acc, (&index, &step)| {
                    acc + index as i64 * step as i64
                });
            buf[flat as usize] = src_data[src_i];

            for (index, &extent) in indices.iter_mut().zip(size).rev() {
                *index += 1;
                if *index < extent {
                    break;
                }
                // This dimension wrapped; carry into the next-slower one.
                *index = 0;
            }
        }
        Tensor::from_storage(TensorStorage::cpu(buf), self.shape().to_vec(), false)
    }
}
/// Autograd node for `Tensor::as_strided`.
///
/// Stores the tensor the view was taken from together with the view
/// parameters, so the backward pass can scatter the incoming gradient back
/// into a tensor of the input's shape.
#[derive(Debug)]
pub struct AsStridedBackward<T: Float> {
/// The tensor the strided view was created from.
input: Tensor<T>,
/// Shape of the strided view.
size: Vec<usize>,
/// Per-dimension strides of the view, in elements; may be negative or zero.
stride: Vec<isize>,
/// Offset of the view's first element, as resolved by `as_strided`.
storage_offset: usize,
}
impl<T: Float> AsStridedBackward<T> {
pub fn new(
input: Tensor<T>,
size: Vec<usize>,
stride: Vec<isize>,
storage_offset: usize,
) -> Self {
Self {
input,
size,
stride,
storage_offset,
}
}
}
impl<T: Float> GradFn<T> for AsStridedBackward<T> {
    /// Propagates `grad_output` back to the viewed input.
    ///
    /// Returns a single-entry vector. The entry is `None` when the input does
    /// not require grad. Otherwise it is a zero tensor of the input's shape
    /// with `grad_output` scattered into the positions addressed by the view.
    ///
    /// The scatter goes through `as_strided_scatter`, whose CPU path assigns
    /// values rather than adding them. Positions the view addresses more than
    /// once therefore keep only the last contribution written, not their sum.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if self.input.requires_grad() {
            let base = crate::creation::zeros::<T>(self.input.shape())?;
            let scattered = base.as_strided_scatter(
                grad_output,
                &self.size,
                &self.stride,
                Some(self.storage_offset),
            )?;
            Ok(vec![Some(scattered)])
        } else {
            Ok(vec![None])
        }
    }

    /// The single tensor this node depends on: the viewed input.
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let source = &self.input;
        vec![source]
    }

    /// Name reported for this node.
    fn name(&self) -> &'static str {
        "AsStridedBackward"
    }
}
/// Materializes a GPU-resident strided `view` into new GPU storage.
///
/// Passes the view's GPU buffer, shape, strides and storage offset to the
/// backend's `strided_copy_f32` / `strided_copy_f64` routine, chosen by the
/// element type `T`, and wraps the returned handle in a tensor with the
/// view's shape.
///
/// # Errors
///
/// * [`FerrotorchError::DeviceUnavailable`] if no GPU backend is registered
///   or the view's storage has no GPU handle.
/// * [`FerrotorchError::NotImplementedOnCuda`] if `T` is neither `f32` nor
///   `f64`.
/// * Any error returned by the backend copy itself.
fn materialize_strided_cuda<T: Float>(view: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let storage = view.storage();
    let source = storage
        .gpu_handle()
        .ok_or(FerrotorchError::DeviceUnavailable)?;

    let shape = view.shape().to_vec();
    let strides = view.strides().to_vec();
    let start = view.storage_offset();

    // Dispatch on the concrete element type; only f32 and f64 have kernels.
    let elem = TypeId::of::<T>();
    let copied = if elem == TypeId::of::<f32>() {
        backend.strided_copy_f32(source, &shape, &strides, start)?
    } else if elem == TypeId::of::<f64>() {
        backend.strided_copy_f64(source, &shape, &strides, start)?
    } else {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "as_strided_copy",
        });
    };

    Tensor::from_storage(TensorStorage::gpu(copied), shape, false)
}
/// GPU implementation of `as_strided_scatter`.
///
/// First copies `base` into a new buffer with the backend's `strided_copy_*`
/// routine, using `base`'s own shape, strides and storage offset. It then
/// scatters `src` into that buffer at the positions described by `size`,
/// `stride` and `offset` with `strided_scatter_*`. The result has `base`'s
/// shape.
///
/// # Errors
///
/// * [`FerrotorchError::DeviceUnavailable`] if no GPU backend is registered
///   or either tensor's storage has no GPU handle.
/// * [`FerrotorchError::NotImplementedOnCuda`] if `T` is neither `f32` nor
///   `f64`.
/// * Any error returned by the backend routines.
fn scatter_on_cuda<T: Float>(
    base: &Tensor<T>,
    src: &Tensor<T>,
    size: &[usize],
    stride: &[isize],
    offset: usize,
) -> FerrotorchResult<Tensor<T>> {
    use std::any::TypeId;

    let backend = crate::gpu_dispatch::gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let base_buf = base
        .storage()
        .gpu_handle()
        .ok_or(FerrotorchError::DeviceUnavailable)?;
    let src_buf = src
        .storage()
        .gpu_handle()
        .ok_or(FerrotorchError::DeviceUnavailable)?;

    let out_shape = base.shape().to_vec();
    let base_strides = base.strides().to_vec();
    let base_offset = base.storage_offset();

    // Resolve the element type once; anything other than f32/f64 is rejected
    // before any backend work is started.
    let elem = TypeId::of::<T>();
    let is_f32 = elem == TypeId::of::<f32>();
    if !is_f32 && elem != TypeId::of::<f64>() {
        return Err(FerrotorchError::NotImplementedOnCuda {
            op: "as_strided_scatter",
        });
    }

    let mut dst_handle = if is_f32 {
        backend.strided_copy_f32(base_buf, &out_shape, &base_strides, base_offset)?
    } else {
        backend.strided_copy_f64(base_buf, &out_shape, &base_strides, base_offset)?
    };
    if is_f32 {
        backend.strided_scatter_f32(src_buf, &mut dst_handle, size, stride, offset)?;
    } else {
        backend.strided_scatter_f64(src_buf, &mut dst_handle, size, stride, offset)?;
    }

    Tensor::from_storage(TensorStorage::gpu(dst_handle), out_shape, false)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::creation::{tensor, zeros};
    use crate::storage::TensorStorage;

    /// Builds an `f64` tensor over CPU storage holding a copy of `data`, with
    /// the given `shape`.
    fn t(data: &[f64], shape: &[usize]) -> Tensor<f64> {
        let storage = TensorStorage::cpu(data.to_vec());
        Tensor::from_storage(storage, shape.to_vec(), false).unwrap()
    }

    /// Row-major strides over a flat buffer behave like a 2x3 reshape.
    #[test]
    fn as_strided_reshape_to_2x3() {
        let base = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]);
        let view = base.as_strided(&[2, 3], &[3, 1], None).unwrap();
        assert_eq!(view.shape(), &[2, 3]);
        assert_eq!(
            view.data_vec().unwrap(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    /// Strides `[1, 1]` produce overlapping sliding windows.
    #[test]
    fn as_strided_overlapping_sliding_window() {
        let base = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[5]);
        let windows = base.as_strided(&[3, 3], &[1, 1], None).unwrap();
        assert_eq!(windows.shape(), &[3, 3]);
        assert_eq!(
            windows.data_vec().unwrap(),
            vec![1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0]
        );
    }

    /// A stride of `-1` starting at the last element walks the data backwards.
    #[test]
    fn as_strided_negative_stride_reverses() {
        let base = t(&[1.0, 2.0, 3.0, 4.0], &[4]);
        let reversed = base.as_strided(&[4], &[-1], Some(3)).unwrap();
        assert_eq!(reversed.data_vec().unwrap(), vec![4.0, 3.0, 2.0, 1.0]);
    }

    /// A stride of `0` repeats the element at the storage offset.
    #[test]
    fn as_strided_zero_stride_broadcast() {
        let base = t(&[7.0, 8.0, 9.0], &[3]);
        let repeated = base.as_strided(&[5], &[0], Some(1)).unwrap();
        assert_eq!(
            repeated.data_vec().unwrap(),
            vec![8.0, 8.0, 8.0, 8.0, 8.0]
        );
    }

    /// A view reaching past the end of the storage is rejected.
    #[test]
    fn as_strided_rejects_out_of_bounds() {
        let base = t(&[1.0, 2.0, 3.0], &[3]);
        let err = base.as_strided(&[4], &[1], Some(0)).unwrap_err();
        assert!(
            matches!(err, FerrotorchError::InvalidArgument { .. }),
            "expected InvalidArgument, got {err:?}"
        );
    }

    /// A negative stride that would reach below offset zero is rejected.
    #[test]
    fn as_strided_rejects_negative_reach() {
        let base = t(&[1.0, 2.0, 3.0], &[3]);
        let err = base.as_strided(&[3], &[-1], Some(1)).unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    /// `size` and `stride` must have the same number of entries.
    #[test]
    fn as_strided_rejects_size_stride_length_mismatch() {
        let base = t(&[1.0, 2.0, 3.0, 4.0], &[4]);
        let err = base.as_strided(&[2, 2], &[1], None).unwrap_err();
        assert!(matches!(err, FerrotorchError::InvalidArgument { .. }));
    }

    /// A zero-sized dimension makes any strides acceptable; the view is empty.
    #[test]
    fn as_strided_zero_size_dim_is_valid() {
        let base = t(&[1.0, 2.0, 3.0], &[3]);
        let empty = base.as_strided(&[0, 5], &[100, 100], Some(0)).unwrap();
        assert_eq!(empty.shape(), &[0, 5]);
        assert_eq!(empty.data_vec().unwrap(), Vec::<f64>::new());
    }

    /// The view reports the same storage length as the tensor it came from.
    #[test]
    fn as_strided_shares_storage() {
        let base = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]);
        let every_other = base.as_strided(&[3], &[2], Some(0)).unwrap();
        assert_eq!(every_other.data_vec().unwrap(), vec![1.0, 3.0, 5.0]);
        assert_eq!(every_other.storage().len(), base.storage().len());
    }

    /// The copy variant yields a contiguous tensor with the view's contents.
    #[test]
    fn as_strided_copy_makes_contiguous_2x3() {
        let base = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]);
        let copied = base.as_strided_copy(&[2, 3], &[3, 1], None).unwrap();
        assert_eq!(copied.shape(), &[2, 3]);
        assert!(copied.is_contiguous());
        assert_eq!(
            copied.data_vec().unwrap(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    /// Overlapping windows are duplicated into the contiguous copy.
    #[test]
    fn as_strided_copy_collects_overlapping_window() {
        let base = t(&[1.0, 2.0, 3.0, 4.0, 5.0], &[5]);
        let copied = base.as_strided_copy(&[3, 3], &[1, 1], None).unwrap();
        assert!(copied.is_contiguous());
        assert_eq!(
            copied.data_vec().unwrap(),
            vec![1.0, 2.0, 3.0, 2.0, 3.0, 4.0, 3.0, 4.0, 5.0]
        );
    }

    /// Scatter writes `src` into exactly the positions the view addresses.
    #[test]
    fn as_strided_scatter_writes_into_view_positions() {
        let target = t(&[0.0, 0.0, 0.0, 0.0, 0.0, 0.0], &[6]);
        let values = t(&[10.0, 20.0, 30.0], &[3]);
        let result = target
            .as_strided_scatter(&values, &[3], &[2], Some(0))
            .unwrap();
        assert_eq!(
            result.data_vec().unwrap(),
            vec![10.0, 0.0, 20.0, 0.0, 30.0, 0.0]
        );
    }

    /// Positions outside the view keep the destination's original values.
    #[test]
    fn as_strided_scatter_preserves_non_view_positions() {
        let target = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[6]);
        let values = t(&[100.0, 200.0], &[2]);
        let result = target
            .as_strided_scatter(&values, &[2], &[2], Some(1))
            .unwrap();
        assert_eq!(
            result.data_vec().unwrap(),
            vec![1.0, 100.0, 3.0, 200.0, 5.0, 6.0]
        );
    }

    /// A 2-D source can be scattered into a 1-D destination through a 2-D view.
    #[test]
    fn as_strided_scatter_2d_view_into_1d_dst() {
        let target = zeros::<f64>(&[6]).unwrap();
        let values = t(&[1.0, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3]);
        let result = target
            .as_strided_scatter(&values, &[2, 3], &[3, 1], Some(0))
            .unwrap();
        assert_eq!(
            result.data_vec().unwrap(),
            vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );
    }

    /// `src` must have exactly the requested view shape.
    #[test]
    fn as_strided_scatter_rejects_shape_mismatch() {
        let target = zeros::<f64>(&[5]).unwrap();
        let values = t(&[1.0, 2.0, 3.0], &[3]);
        let err = target
            .as_strided_scatter(&values, &[2], &[1], Some(0))
            .unwrap_err();
        assert!(matches!(err, FerrotorchError::ShapeMismatch { .. }));
    }

    /// Summing a non-overlapping view gives a gradient of ones in the input's
    /// shape.
    #[test]
    fn as_strided_backward_scatters_into_input_shape() {
        use crate::autograd::backward;

        let input = tensor(&[1.0_f64, 2.0, 3.0, 4.0, 5.0, 6.0])
            .unwrap()
            .requires_grad_(true);
        let view = input.as_strided(&[2, 3], &[3, 1], None).unwrap();
        let total = view.sum_all().unwrap();
        backward(&total).unwrap();

        let grad = input.grad().unwrap().expect("input should have a gradient");
        assert_eq!(grad.data_vec().unwrap(), vec![1.0; 6]);
    }

    /// With an overlapping view the backward scatter overwrites rather than
    /// accumulates, so every input position ends up with a gradient of one.
    #[test]
    fn as_strided_backward_overlapping_view_last_write_wins() {
        use crate::autograd::backward;

        let input = tensor(&[1.0_f64, 2.0, 3.0, 4.0, 5.0])
            .unwrap()
            .requires_grad_(true);
        let windows = input.as_strided(&[3, 3], &[1, 1], None).unwrap();
        let dense = windows.contiguous().unwrap();
        let total = dense.sum_all().unwrap();
        backward(&total).unwrap();

        let grad = input.grad().unwrap().expect("input should have a gradient");
        assert_eq!(grad.data_vec().unwrap(), vec![1.0; 5]);
    }
}