use crate::autograd::{GradFn, Var};
use crate::dtype::DType;
use crate::error::Result;
use crate::ops::{ReduceOps, ShapeOps};
use crate::runtime::{Runtime, RuntimeClient};
use crate::tensor::{Tensor, TensorId};
use std::sync::Arc;
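/// Backward pass for `reshape`: values pass through unchanged, so the
/// gradient is the upstream gradient reshaped back to the input's
/// original shape.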
pub struct ReshapeBackward<R: Runtime> {
input_id: TensorId,
input_shape: Vec<usize>,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> ReshapeBackward<R> {
pub fn new(
input_id: TensorId,
input_shape: Vec<usize>,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
input_shape,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for ReshapeBackward<R> {
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let grad = grad_output.reshape(&self.input_shape)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let reshaped = grad_output.tensor().reshape(&self.input_shape)?;
if grad_output.requires_grad() {
let grad_fn = ReshapeBackward::<R>::new(
grad_output.id(),
grad_output.shape().to_vec(),
grad_output.grad_fn().cloned(),
);
Ok(vec![Some(Var::from_op(reshaped, Arc::new(grad_fn)))])
} else {
Ok(vec![Some(Var::new(reshaped, false))])
}
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"ReshapeBackward"
}
}
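/// Backward pass for `t()` (transpose of the last two dims): the swap is
/// its own inverse, so the gradient is the upstream gradient transposed
/// back.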
pub struct TransposeBackward<R: Runtime> {
input_id: TensorId,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> TransposeBackward<R> {
pub fn new(input_id: TensorId, input_grad_fn: Option<Arc<dyn GradFn<R>>>) -> Self {
Self {
input_id,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for TransposeBackward<R> {
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let grad = grad_output.t()?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let transposed = grad_output.tensor().t()?;
if grad_output.requires_grad() {
let grad_fn =
TransposeBackward::<R>::new(grad_output.id(), grad_output.grad_fn().cloned());
Ok(vec![Some(Var::from_op(transposed, Arc::new(grad_fn)))])
} else {
Ok(vec![Some(Var::new(transposed, false))])
}
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"TransposeBackward"
}
}
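/// Backward pass for `permute`: the gradient is the upstream gradient
/// permuted by the inverse permutation, which is precomputed in `new`.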
pub struct PermuteBackward<R: Runtime> {
input_id: TensorId,
inverse_dims: Vec<usize>,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> PermuteBackward<R> {
pub fn new(
input_id: TensorId,
dims: &[usize],
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
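// Invert the permutation: forward maps input dim `dims[i]` to output
// dim `i`, so the gradient flows back via `inverse_dims[dims[i]] = i`
// (e.g. dims = [1, 2, 0] gives inverse_dims = [2, 0, 1]).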
let mut inverse_dims = vec![0; dims.len()];
for (i, &d) in dims.iter().enumerate() {
inverse_dims[d] = i;
}
Self {
input_id,
inverse_dims,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for PermuteBackward<R> {
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let grad = grad_output.permute(&self.inverse_dims)?;
Ok(vec![Some(grad)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
let permuted = grad_output.tensor().permute(&self.inverse_dims)?;
if grad_output.requires_grad() {
// The new op permutes by `inverse_dims`, so its own backward must apply
// the inverse of `inverse_dims` (i.e. the original `dims`); `new`
// computes that inverse. Reusing `inverse_dims` directly here would only
// be correct for involutive permutations.
let grad_fn = PermuteBackward::<R>::new(
grad_output.id(),
&self.inverse_dims,
grad_output.grad_fn().cloned(),
);
Ok(vec![Some(Var::from_op(permuted, Arc::new(grad_fn)))])
} else {
Ok(vec![Some(Var::new(permuted, false))])
}
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"PermuteBackward"
}
}
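/// Backward pass for `broadcast_to`/expand: every broadcast copy shares
/// one input element, so the gradient is the upstream gradient summed
/// over the broadcast dimensions (extra leading dims and stretched
/// size-1 dims).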
pub struct ExpandBackward<R: Runtime> {
input_id: TensorId,
input_shape: Vec<usize>,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> ExpandBackward<R> {
pub fn new(
input_id: TensorId,
input_shape: Vec<usize>,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
input_shape,
input_grad_fn,
}
}
}
impl<R: Runtime> GradFn<R> for ExpandBackward<R>
where
R::Client: RuntimeClient<R> + crate::ops::TensorOps<R> + ReduceOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let client = R::default_client(grad_output.device());
let output_shape = grad_output.shape();
let input_ndim = self.input_shape.len();
let output_ndim = output_shape.len();
let mut result = grad_output.clone();
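// Broadcasting can (1) prepend new leading dims and (2) stretch size-1
// dims. Undo (1) by summing the extra leading dims away entirely, then
// undo (2) by summing the stretched dims with keepdim so the result
// matches `input_shape`.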
if output_ndim > input_ndim {
let extra_dims: Vec<usize> = (0..(output_ndim - input_ndim)).collect();
result = client.sum(&result, &extra_dims, false)?;
}
let offset = output_ndim.saturating_sub(input_ndim);
let mut reduce_dims = Vec::new();
for (i, &input_dim) in self.input_shape.iter().enumerate() {
let output_idx = offset + i;
if output_idx < output_shape.len() && input_dim == 1 && output_shape[output_idx] > 1 {
reduce_dims.push(i);
}
}
if !reduce_dims.is_empty() {
result = client.sum(&result, &reduce_dims, true)?;
}
if result.shape() != self.input_shape.as_slice() {
result = result.reshape(&self.input_shape)?;
}
Ok(vec![Some(result)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
use crate::autograd::var_sum;
let client = R::default_client(grad_output.tensor().device());
let output_shape = grad_output.shape();
let input_ndim = self.input_shape.len();
let output_ndim = output_shape.len();
let mut result = grad_output.clone();
if output_ndim > input_ndim {
let extra_dims: Vec<usize> = (0..(output_ndim - input_ndim)).collect();
result = var_sum(&result, &extra_dims, false, &client)?;
}
let offset = output_ndim.saturating_sub(input_ndim);
let mut reduce_dims = Vec::new();
for (i, &input_dim) in self.input_shape.iter().enumerate() {
let output_idx = offset + i;
if output_idx < output_shape.len() && input_dim == 1 && output_shape[output_idx] > 1 {
reduce_dims.push(i);
}
}
if !reduce_dims.is_empty() {
result = var_sum(&result, &reduce_dims, true, &client)?;
}
if result.shape() != self.input_shape.as_slice() {
result = var_reshape(&result, &self.input_shape)?;
}
Ok(vec![Some(result)])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"ExpandBackward"
}
}
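/// Differentiable reshape on a `Var`, recording a `ReshapeBackward` node
/// when the input requires gradients.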
pub fn var_reshape<R: Runtime>(a: &Var<R>, shape: &[usize]) -> Result<Var<R>> {
let output = a.tensor().reshape(shape)?;
if a.requires_grad() {
let grad_fn = ReshapeBackward::<R>::new(a.id(), a.shape().to_vec(), a.grad_fn().cloned());
Ok(Var::from_op(output, Arc::new(grad_fn)))
} else {
Ok(Var::new(output, false))
}
}
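/// Differentiable transpose of the last two dims on a `Var`.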
pub fn var_transpose<R: Runtime>(a: &Var<R>) -> Result<Var<R>> {
let output = a.tensor().t()?;
if a.requires_grad() {
let grad_fn = TransposeBackward::<R>::new(a.id(), a.grad_fn().cloned());
Ok(Var::from_op(output, Arc::new(grad_fn)))
} else {
Ok(Var::new(output, false))
}
}
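/// Differentiable dimension permutation on a `Var`.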
pub fn var_permute<R: Runtime>(a: &Var<R>, dims: &[usize]) -> Result<Var<R>> {
let output = a.tensor().permute(dims)?;
if a.requires_grad() {
let grad_fn = PermuteBackward::<R>::new(a.id(), dims, a.grad_fn().cloned());
Ok(Var::from_op(output, Arc::new(grad_fn)))
} else {
Ok(Var::new(output, false))
}
}
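/// Differentiable broadcast on a `Var`; the backward pass sums the
/// gradient over the broadcast dimensions.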
pub fn var_broadcast_to<R: Runtime>(a: &Var<R>, shape: &[usize]) -> Result<Var<R>>
where
R::Client: RuntimeClient<R> + crate::ops::TensorOps<R> + ReduceOps<R>,
{
let output = a.tensor().broadcast_to(shape)?;
if a.requires_grad() {
let grad_fn = ExpandBackward::<R>::new(a.id(), a.shape().to_vec(), a.grad_fn().cloned());
Ok(Var::from_op(output, Arc::new(grad_fn)))
} else {
Ok(Var::new(output, false))
}
}
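/// Backward pass for `narrow`: the slice's gradient is placed back at its
/// original offset, zero-padded on both sides along the narrowed
/// dimension.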
pub struct NarrowBackward<R: Runtime> {
input_id: TensorId,
input_shape: Vec<usize>,
dim: usize,
start: usize,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
}
impl<R: Runtime> NarrowBackward<R> {
pub fn new(
input_id: TensorId,
input_shape: Vec<usize>,
dim: usize,
start: usize,
input_grad_fn: Option<Arc<dyn GradFn<R>>>,
) -> Self {
Self {
input_id,
input_shape,
dim,
start,
input_grad_fn,
}
}
}
impl<R: Runtime<DType = DType>> GradFn<R> for NarrowBackward<R>
where
R::Client: RuntimeClient<R> + crate::ops::TensorOps<R> + ShapeOps<R>,
{
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
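// Rebuild the full-size gradient as [zeros | grad | zeros] along `dim`,
// then concatenate the parts.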
let client = R::default_client(grad_output.device());
let length = grad_output.shape()[self.dim];
let orig_dim_size = self.input_shape[self.dim];
let end = self.start + length;
let mut parts: Vec<Tensor<R>> = Vec::new();
if self.start > 0 {
let mut pad_shape = self.input_shape.clone();
pad_shape[self.dim] = self.start;
parts.push(Tensor::<R>::zeros(
&pad_shape,
grad_output.dtype(),
grad_output.device(),
));
}
parts.push(grad_output.contiguous());
if end < orig_dim_size {
let mut pad_shape = self.input_shape.clone();
pad_shape[self.dim] = orig_dim_size - end;
parts.push(Tensor::<R>::zeros(
&pad_shape,
grad_output.dtype(),
grad_output.device(),
));
}
let refs: Vec<&Tensor<R>> = parts.iter().collect();
let grad_input = client.cat(&refs, self.dim as isize)?;
Ok(vec![Some(grad_input)])
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
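// Same construction as `backward`. The padding-and-cat is linear, but no
// grad_fn is recorded for it, so second-order gradients through `narrow`
// are not tracked here.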
let client = R::default_client(grad_output.tensor().device());
let length = grad_output.shape()[self.dim];
let orig_dim_size = self.input_shape[self.dim];
let end = self.start + length;
let mut parts: Vec<Tensor<R>> = Vec::new();
if self.start > 0 {
let mut pad_shape = self.input_shape.clone();
pad_shape[self.dim] = self.start;
parts.push(Tensor::<R>::zeros(
&pad_shape,
grad_output.tensor().dtype(),
grad_output.tensor().device(),
));
}
parts.push(grad_output.tensor().contiguous());
if end < orig_dim_size {
let mut pad_shape = self.input_shape.clone();
pad_shape[self.dim] = orig_dim_size - end;
parts.push(Tensor::<R>::zeros(
&pad_shape,
grad_output.tensor().dtype(),
grad_output.tensor().device(),
));
}
let refs: Vec<&Tensor<R>> = parts.iter().collect();
let grad_input = client.cat(&refs, self.dim as isize)?;
Ok(vec![Some(Var::new(grad_input, false))])
}
fn inputs(&self) -> &[TensorId] {
std::slice::from_ref(&self.input_id)
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
vec![self.input_grad_fn.clone()]
}
fn name(&self) -> &'static str {
"NarrowBackward"
}
}
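/// Backward pass for `cat`: the upstream gradient is split back along the
/// concatenation dimension into one slice per input, using the recorded
/// split sizes.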
pub struct CatBackward<R: Runtime> {
input_ids: Vec<TensorId>,
split_sizes: Vec<usize>,
dim: usize,
input_grad_fns: Vec<Option<Arc<dyn GradFn<R>>>>,
}
impl<R: Runtime> CatBackward<R> {
pub fn new(
input_ids: Vec<TensorId>,
split_sizes: Vec<usize>,
dim: usize,
input_grad_fns: Vec<Option<Arc<dyn GradFn<R>>>>,
) -> Self {
Self {
input_ids,
split_sizes,
dim,
input_grad_fns,
}
}
}
impl<R: Runtime> GradFn<R> for CatBackward<R> {
fn backward(&self, grad_output: &Tensor<R>) -> Result<Vec<Option<Tensor<R>>>> {
let mut grads = Vec::with_capacity(self.split_sizes.len());
let mut offset = 0;
for &size in &self.split_sizes {
let grad_slice = grad_output.narrow(self.dim as isize, offset, size)?;
grads.push(Some(grad_slice.contiguous()));
offset += size;
}
Ok(grads)
}
fn backward_var(&self, grad_output: &Var<R>) -> Result<Vec<Option<Var<R>>>> {
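// The slices are returned as plain non-differentiable `Var`s, so
// second-order gradients through `cat` are not tracked here.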
let mut grads = Vec::with_capacity(self.split_sizes.len());
let mut offset = 0;
for &size in &self.split_sizes {
let grad_slice = grad_output
.tensor()
.narrow(self.dim as isize, offset, size)?
.contiguous();
grads.push(Some(Var::new(grad_slice, false)));
offset += size;
}
Ok(grads)
}
fn inputs(&self) -> &[TensorId] {
&self.input_ids
}
fn input_grad_fns(&self) -> Vec<Option<Arc<dyn GradFn<R>>>> {
self.input_grad_fns.clone()
}
fn name(&self) -> &'static str {
"CatBackward"
}
}
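/// Differentiable narrow (a slice of `length` elements starting at `start`
/// along `dim`) on a `Var`.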
pub fn var_narrow<R: Runtime<DType = DType>>(
a: &Var<R>,
dim: isize,
start: usize,
length: usize,
) -> Result<Var<R>>
where
R::Client: RuntimeClient<R> + crate::ops::TensorOps<R> + ShapeOps<R>,
{
let dim_idx =
a.tensor()
.layout()
.normalize_dim(dim)
.ok_or(crate::error::Error::InvalidDimension {
dim,
ndim: a.ndim(),
})?;
let output = a.tensor().narrow(dim, start, length)?;
if a.requires_grad() {
let grad_fn = NarrowBackward::<R>::new(
a.id(),
a.shape().to_vec(),
dim_idx,
start,
a.grad_fn().cloned(),
);
Ok(Var::from_op(output, Arc::new(grad_fn)))
} else {
Ok(Var::new(output, false))
}
}
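/// Differentiable concatenation of `Var`s along `dim`. The result requires
/// gradients if any input does; errors on an empty input list.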
pub fn var_cat<R, C>(vars: &[&Var<R>], dim: isize, client: &C) -> Result<Var<R>>
where
R: Runtime,
C: RuntimeClient<R> + crate::ops::ShapeOps<R>,
{
if vars.is_empty() {
return Err(crate::error::Error::InvalidArgument {
arg: "vars",
reason: "var_cat requires at least one input".into(),
});
}
let tensors: Vec<&Tensor<R>> = vars.iter().map(|v| v.tensor()).collect();
let output = client.cat(&tensors, dim)?;
let any_requires_grad = vars.iter().any(|v| v.requires_grad());
if any_requires_grad {
let dim_idx = vars[0].tensor().layout().normalize_dim(dim).ok_or(
crate::error::Error::InvalidDimension {
dim,
ndim: vars[0].ndim(),
},
)?;
let input_ids: Vec<TensorId> = vars.iter().map(|v| v.id()).collect();
let split_sizes: Vec<usize> = vars.iter().map(|v| v.shape()[dim_idx]).collect();
let input_grad_fns: Vec<Option<Arc<dyn GradFn<R>>>> =
vars.iter().map(|v| v.grad_fn().cloned()).collect();
let grad_fn = CatBackward::<R>::new(input_ids, split_sizes, dim_idx, input_grad_fns);
Ok(Var::from_op(output, Arc::new(grad_fn)))
} else {
Ok(Var::new(output, false))
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::dtype::DType;
use crate::runtime::cpu::{CpuDevice, CpuRuntime};
#[test]
fn test_reshape_backward() {
let device = CpuDevice::new();
let input =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3, 2], DType::F32, &device);
let backward = ReshapeBackward::<CpuRuntime>::new(input.id(), vec![2, 3], None);
let grads = backward.backward(&grad_out).unwrap();
let grad = grads[0].as_ref().unwrap();
assert_eq!(grad.shape(), &[2, 3]);
}
#[test]
fn test_transpose_backward() {
let device = CpuDevice::new();
let input =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3, 2], DType::F32, &device);
let backward = TransposeBackward::<CpuRuntime>::new(input.id(), None);
let grads = backward.backward(&grad_out).unwrap();
let grad = grads[0].as_ref().unwrap();
assert_eq!(grad.shape(), &[2, 3]);
}
#[test]
fn test_var_reshape() {
let device = CpuDevice::new();
let tensor =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &device);
let x = Var::new(tensor, true);
let y = var_reshape(&x, &[3, 2]).unwrap();
assert_eq!(y.shape(), &[3, 2]);
assert!(y.requires_grad());
assert!(y.grad_fn().is_some());
assert_eq!(y.grad_fn().unwrap().name(), "ReshapeBackward");
}
#[test]
fn test_var_transpose() {
let device = CpuDevice::new();
let tensor =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &device);
let x = Var::new(tensor, true);
let y = var_transpose(&x).unwrap();
assert_eq!(y.shape(), &[3, 2]);
assert!(y.requires_grad());
assert!(y.grad_fn().is_some());
assert_eq!(y.grad_fn().unwrap().name(), "TransposeBackward");
}
#[test]
fn test_permute_backward() {
let device = CpuDevice::new();
let input = Tensor::<CpuRuntime>::ones(&[2, 3, 4], DType::F32, &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3, 4, 2], DType::F32, &device);
let backward = PermuteBackward::<CpuRuntime>::new(input.id(), &[1, 2, 0], None);
let grads = backward.backward(&grad_out).unwrap();
let grad = grads[0].as_ref().unwrap();
assert_eq!(grad.shape(), &[2, 3, 4]);
}
#[test]
fn test_var_permute() {
let device = CpuDevice::new();
let tensor = Tensor::<CpuRuntime>::ones(&[2, 3, 4], DType::F32, &device);
let x = Var::new(tensor, true);
let y = var_permute(&x, &[2, 0, 1]).unwrap();
assert_eq!(y.shape(), &[4, 2, 3]);
assert!(y.requires_grad());
assert!(y.grad_fn().is_some());
assert_eq!(y.grad_fn().unwrap().name(), "PermuteBackward");
}
#[test]
fn test_expand_backward() {
let device = CpuDevice::new();
let input = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[1, 3], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[2, 3], DType::F32, &device);
let backward = ExpandBackward::<CpuRuntime>::new(input.id(), vec![1, 3], None);
let grads = backward.backward(&grad_out).unwrap();
let grad = grads[0].as_ref().unwrap();
assert_eq!(grad.shape(), &[1, 3]);
let grad_data: Vec<f32> = grad.to_vec();
assert_eq!(grad_data, vec![2.0, 2.0, 2.0]);
}
#[test]
fn test_var_broadcast_to() {
let device = CpuDevice::new();
let tensor = Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &device);
let x = Var::new(tensor, true);
let y = var_broadcast_to(&x, &[2, 3]).unwrap();
assert_eq!(y.shape(), &[2, 3]);
assert!(y.requires_grad());
assert!(y.grad_fn().is_some());
assert_eq!(y.grad_fn().unwrap().name(), "ExpandBackward");
let y_contiguous = y.tensor().contiguous();
let y_data: Vec<f32> = y_contiguous.to_vec();
assert_eq!(y_data, vec![1.0, 2.0, 3.0, 1.0, 2.0, 3.0]);
}
#[test]
fn test_reshape_backward_scalar() {
let device = CpuDevice::new();
let input = Tensor::<CpuRuntime>::from_slice(&[5.0f32], &[], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[1], DType::F32, &device);
let backward = ReshapeBackward::<CpuRuntime>::new(input.id(), vec![], None);
let grads = backward.backward(&grad_out).unwrap();
let grad = grads[0].as_ref().unwrap();
assert_eq!(grad.shape(), &[] as &[usize]);
}
#[test]
fn test_transpose_backward_3d() {
let device = CpuDevice::new();
let input = Tensor::<CpuRuntime>::ones(&[2, 3, 4], DType::F32, &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[2, 4, 3], DType::F32, &device);
let backward = TransposeBackward::<CpuRuntime>::new(input.id(), None);
let grads = backward.backward(&grad_out).unwrap();
let grad = grads[0].as_ref().unwrap();
assert_eq!(grad.shape(), &[2, 3, 4]);
}
#[test]
fn test_expand_backward_multiple_dims() {
let device = CpuDevice::new();
let input = Tensor::<CpuRuntime>::from_slice(&[2.0f32], &[1, 1], &device);
let grad_out = Tensor::<CpuRuntime>::ones(&[3, 4], DType::F32, &device);
let backward = ExpandBackward::<CpuRuntime>::new(input.id(), vec![1, 1], None);
let grads = backward.backward(&grad_out).unwrap();
let grad = grads[0].as_ref().unwrap();
assert_eq!(grad.shape(), &[1, 1]);
let grad_data: Vec<f32> = grad.to_vec();
assert_eq!(grad_data, vec![12.0]);
}
#[test]
fn test_permute_backward_identity() {
let device = CpuDevice::new();
let input = Tensor::<CpuRuntime>::ones(&[2, 3], DType::F32, &device);
let grad_out =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[2, 3], &device);
let backward = PermuteBackward::<CpuRuntime>::new(input.id(), &[0, 1], None);
let grads = backward.backward(&grad_out).unwrap();
let grad = grads[0].as_ref().unwrap();
assert_eq!(grad.shape(), &[2, 3]);
let grad_data: Vec<f32> = grad.to_vec();
assert_eq!(grad_data, vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
}
#[test]
fn test_var_narrow() {
let device = CpuDevice::new();
let tensor =
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], &[6], &device);
let x = Var::new(tensor, true);
let y = var_narrow(&x, 0, 1, 3).unwrap();
assert_eq!(y.shape(), &[3]);
assert!(y.requires_grad());
assert_eq!(y.grad_fn().unwrap().name(), "NarrowBackward");
let y_data: Vec<f32> = y.tensor().to_vec();
assert_eq!(y_data, vec![2.0, 3.0, 4.0]);
}
#[test]
fn test_narrow_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let x = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0], &[5], &device),
true,
);
let y = var_narrow(&x, 0, 1, 3).unwrap();
let loss = crate::autograd::var_sum(&y, &[0], false, &client).unwrap();
let grads = crate::autograd::backward(&loss, &client).unwrap();
let grad_x: Vec<f32> = grads.get(x.id()).unwrap().to_vec();
assert_eq!(grad_x, vec![0.0, 1.0, 1.0, 1.0, 0.0]);
}
#[test]
fn test_var_cat() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let a = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device),
true,
);
let b = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0, 5.0], &[3], &device),
true,
);
let c = var_cat(&[&a, &b], 0, &client).unwrap();
assert_eq!(c.shape(), &[5]);
assert!(c.requires_grad());
assert_eq!(c.grad_fn().unwrap().name(), "CatBackward");
let c_data: Vec<f32> = c.tensor().to_vec();
assert_eq!(c_data, vec![1.0, 2.0, 3.0, 4.0, 5.0]);
}
#[test]
fn test_cat_backward() {
let device = CpuDevice::new();
let client = CpuRuntime::default_client(&device);
let a = Var::new(
Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &device),
true,
);
let b = Var::new(
Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0, 5.0], &[3], &device),
true,
);
let c = var_cat(&[&a, &b], 0, &client).unwrap();
let loss = crate::autograd::var_sum(&c, &[0], false, &client).unwrap();
let grads = crate::autograd::backward(&loss, &client).unwrap();
let grad_a: Vec<f32> = grads.get(a.id()).unwrap().to_vec();
let grad_b: Vec<f32> = grads.get(b.id()).unwrap().to_vec();
assert_eq!(grad_a, vec![1.0, 1.0]);
assert_eq!(grad_b, vec![1.0, 1.0, 1.0]);
}
}