use std::any::TypeId;
use std::sync::Arc;
use crate::autograd::no_grad::is_grad_enabled;
use crate::dtype::Float;
use crate::error::{FerrotorchError, FerrotorchResult};
use crate::gpu_dispatch::gpu_backend;
use crate::ops::linalg::{self, mm, transpose};
use crate::storage::TensorStorage;
use crate::tensor::{GradFn, Tensor};
/// Returns `true` when the generic float type `T` is exactly `f32`.
///
/// The GPU kernels used in this module are f32-specific, so dispatch sites
/// perform this runtime type check before taking a device code path.
#[inline]
fn is_f32<T: Float>() -> bool {
    let f32_id = TypeId::of::<f32>();
    TypeId::of::<T>() == f32_id
}
/// Device-side (f32-only) gradients for the 2-D matmul `C = A @ B`,
/// with `A: (m,k)`, `B: (k,n)`, `grad_output: (m,n)`.
///
/// Computes `grad_A = grad_C @ B^T` and `grad_B = A^T @ grad_C` entirely on
/// the GPU via the backend's transpose and matmul kernels. Each returned slot
/// is `None` when the corresponding input does not require grad.
///
/// NOTE(review): the backend calls are f32-specific; callers gate on
/// `is_f32::<T>()` before reaching this function — confirm no other call
/// sites exist.
fn mm_backward_gpu<T: Float>(
    grad_output: &Tensor<T>,
    a: &Tensor<T>,
    b: &Tensor<T>,
) -> FerrotorchResult<(Option<Tensor<T>>, Option<Tensor<T>>)> {
    let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
    let go_h = grad_output.gpu_handle()?;
    let m = grad_output.shape()[0];
    let n = grad_output.shape()[1];
    let grad_a = if a.requires_grad() {
        // Inner dimension comes from B's rows.
        let k = b.shape()[0];
        let b_h = b.gpu_handle()?;
        // B^T: (k,n) -> (n,k); then grad_C (m,n) @ B^T (n,k) -> (m,k).
        let bt_h = backend.transpose_2d_f32(b_h, k, n)?;
        let result_h = backend.matmul_f32(go_h, &bt_h, m, n, k)?;
        Some(Tensor::from_storage(
            TensorStorage::gpu(result_h),
            vec![m, k],
            false,
        )?)
    } else {
        None
    };
    let grad_b = if b.requires_grad() {
        let k = a.shape()[1];
        let a_h = a.gpu_handle()?;
        // A^T: (m,k) -> (k,m); then A^T (k,m) @ grad_C (m,n) -> (k,n).
        let at_h = backend.transpose_2d_f32(a_h, m, k)?;
        let result_h = backend.matmul_f32(&at_h, go_h, k, m, n)?;
        Some(Tensor::from_storage(
            TensorStorage::gpu(result_h),
            vec![k, n],
            false,
        )?)
    } else {
        None
    };
    Ok((grad_a, grad_b))
}
/// Autograd node for the 2-D matrix multiplication `C = A @ B`.
/// Stores clones of both forward inputs so gradients can be computed later.
#[derive(Debug)]
pub struct MmBackward<T: Float> {
    a: Tensor<T>,
    b: Tensor<T>,
}
impl<T: Float> MmBackward<T> {
    /// Captures the forward operands for the backward pass.
    pub fn new(a: Tensor<T>, b: Tensor<T>) -> Self {
        Self { a, b }
    }
}
impl<T: Float> GradFn<T> for MmBackward<T> {
    /// For `C = A @ B`: `grad_A = grad_C @ B^T`, `grad_B = A^T @ grad_C`.
    ///
    /// Takes the all-GPU fast path when the upstream gradient is on the
    /// device and `T` is f32; otherwise pulls everything to the host, uses
    /// the raw CPU kernels, and moves results back to `grad_output`'s device.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        if grad_output.is_cuda() && is_f32::<T>() {
            let (ga, gb) = mm_backward_gpu(grad_output, &self.a, &self.b)?;
            return Ok(vec![ga, gb]);
        }
        // CPU fallback: copy any CUDA operand to host first.
        let device = grad_output.device();
        let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
        let cpu_a = if self.a.is_cuda() { self.a.cpu()? } else { self.a.clone() };
        let cpu_b = if self.b.is_cuda() { self.b.cpu()? } else { self.b.clone() };
        let grad_a = if self.a.requires_grad() {
            let gc_data = cpu_go.data()?;
            let b_data = cpu_b.data()?;
            let m = grad_output.shape()[0];
            let n = grad_output.shape()[1];
            let k = self.b.shape()[0];
            // mm_raw_bt: grad_C (m,n) times B^T, with B stored row-major as (k,n).
            let result = crate::ops::linalg::mm_raw_bt(&gc_data, &b_data, m, n, k);
            let t = Tensor::from_storage(TensorStorage::cpu(result), vec![m, k], false)?;
            // Hand the gradient back on the device the upstream grad arrived on.
            Some(if device.is_cuda() { t.to(device)? } else { t })
        } else {
            None
        };
        let grad_b = if self.b.requires_grad() {
            let a_data = cpu_a.data()?;
            let gc_data = cpu_go.data()?;
            let m = self.a.shape()[0];
            let k = self.a.shape()[1];
            let n = grad_output.shape()[1];
            // mm_raw_at: A^T (k,m) times grad_C (m,n) -> (k,n).
            let result = crate::ops::linalg::mm_raw_at(&a_data, &gc_data, k, m, n);
            let t = Tensor::from_storage(TensorStorage::cpu(result), vec![k, n], false)?;
            Some(if device.is_cuda() { t.to(device)? } else { t })
        } else {
            None
        };
        Ok(vec![grad_a, grad_b])
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.a, &self.b]
    }
    fn name(&self) -> &'static str {
        "MmBackward"
    }
}
/// Autograd node for matrix-vector product `y = A @ x`.
#[derive(Debug)]
pub struct MvBackward<T: Float> {
    a: Tensor<T>,
    x: Tensor<T>,
}
impl<T: Float> MvBackward<T> {
    /// Captures the forward operands for the backward pass.
    pub fn new(a: Tensor<T>, x: Tensor<T>) -> Self {
        Self { a, x }
    }
}
impl<T: Float> GradFn<T> for MvBackward<T> {
    /// For `y = A @ x` with `A: (m,k)`, `x: (k,)`:
    /// `grad_A = ybar ⊗ x` (outer product, shape (m,k)) and
    /// `grad_x = A^T @ ybar` (shape (k,)).
    ///
    /// Always computes on the CPU, then moves the results to
    /// `grad_output`'s device when that device is CUDA.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let device = grad_output.device();
        let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
        let cpu_a = if self.a.is_cuda() { self.a.cpu()? } else { self.a.clone() };
        let cpu_x = if self.x.is_cuda() { self.x.cpu()? } else { self.x.clone() };
        let grad_a = if self.a.requires_grad() {
            let grad_data = cpu_go.data()?;
            let x_data = cpu_x.data()?;
            // Dimensions inferred from the flat buffers: m = |ybar|, k = |x|.
            let m = grad_data.len();
            let k = x_data.len();
            // Outer product: grad_A[i][j] = ybar[i] * x[j].
            let mut outer = vec![<T as num_traits::Zero>::zero(); m * k];
            for i in 0..m {
                for j in 0..k {
                    outer[i * k + j] = grad_data[i] * x_data[j];
                }
            }
            let t = Tensor::from_storage(
                TensorStorage::cpu(outer),
                vec![m, k],
                false,
            )?;
            Some(if device.is_cuda() { t.to(device)? } else { t })
        } else {
            None
        };
        let grad_x = if self.x.requires_grad() {
            // grad_x = A^T @ ybar via the generic transpose + mv helpers.
            let at = transpose(&cpu_a)?;
            let t = linalg::mv(&at, &cpu_go)?;
            Some(if device.is_cuda() { t.to(device)? } else { t })
        } else {
            None
        };
        Ok(vec![grad_a, grad_x])
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.a, &self.x]
    }
    fn name(&self) -> &'static str {
        "MvBackward"
    }
}
/// Autograd node for the 1-D dot product `s = a · b`.
#[derive(Debug)]
pub struct DotBackward<T: Float> {
    a: Tensor<T>,
    b: Tensor<T>,
}
impl<T: Float> DotBackward<T> {
    /// Captures the forward operands for the backward pass.
    pub fn new(a: Tensor<T>, b: Tensor<T>) -> Self {
        Self { a, b }
    }
}
impl<T: Float> GradFn<T> for DotBackward<T> {
    /// For `s = a · b`: `grad_a = sbar * b` and `grad_b = sbar * a`, where
    /// `sbar` is the incoming scalar gradient. Computed on CPU, then moved
    /// back to `grad_output`'s device when that device is CUDA.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let device = grad_output.device();
        let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
        let s = cpu_go.item()?;
        // Scales `src` by the upstream scalar into a tensor shaped `shape`,
        // relocating it to `device` when needed.
        let scaled = |src: &Tensor<T>, shape: &[usize]| -> FerrotorchResult<Tensor<T>> {
            let cpu_src = if src.is_cuda() { src.cpu()? } else { src.clone() };
            let src_data = cpu_src.data()?;
            let values: Vec<T> = src_data.iter().map(|&v| s * v).collect();
            let t = Tensor::from_storage(TensorStorage::cpu(values), shape.to_vec(), false)?;
            Ok(if device.is_cuda() { t.to(device)? } else { t })
        };
        let grad_a = if self.a.requires_grad() {
            Some(scaled(&self.b, self.a.shape())?)
        } else {
            None
        };
        let grad_b = if self.b.requires_grad() {
            Some(scaled(&self.a, self.b.shape())?)
        } else {
            None
        };
        Ok(vec![grad_a, grad_b])
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.a, &self.b]
    }
    fn name(&self) -> &'static str {
        "DotBackward"
    }
}
/// Autograd node for batched matrix multiplication `C[i] = A[i] @ B[i]`.
#[derive(Debug)]
pub struct BmmBackward<T: Float> {
    a: Tensor<T>,
    b: Tensor<T>,
}
impl<T: Float> BmmBackward<T> {
    /// Captures the forward operands for the backward pass.
    pub fn new(a: Tensor<T>, b: Tensor<T>) -> Self {
        Self { a, b }
    }
}
impl<T: Float> GradFn<T> for BmmBackward<T> {
    /// Batched matmul gradients. Per batch slice `bi`:
    /// `grad_A[bi] = grad_C[bi] @ B[bi]^T` and `grad_B[bi] = A[bi]^T @ grad_C[bi]`.
    ///
    /// Shapes: `A: (batch,m,k)`, `B: (batch,k,n)`, `grad_C: (batch,m,n)`.
    /// Always computed on CPU, then moved to `grad_output`'s device if CUDA.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let device = grad_output.device();
        let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
        let cpu_a = if self.a.is_cuda() { self.a.cpu()? } else { self.a.clone() };
        let cpu_b = if self.b.is_cuda() { self.b.cpu()? } else { self.b.clone() };
        let batch = self.a.shape()[0];
        let m = self.a.shape()[1];
        let k = self.a.shape()[2];
        let n = self.b.shape()[2];
        let grad_a = if self.a.requires_grad() {
            let grad_data = cpu_go.data()?;
            let b_data = cpu_b.data()?;
            let mut result = vec![<T as num_traits::Zero>::zero(); batch * m * k];
            for bi in 0..batch {
                // Flat offsets of batch slice `bi` in each buffer.
                let g_off = bi * m * n;
                let b_off = bi * k * n;
                let r_off = bi * m * k;
                for i in 0..m {
                    for j in 0..k {
                        let mut acc = <T as num_traits::Zero>::zero();
                        for p in 0..n {
                            // b_data[b_off + j*n + p] is B[bi][j][p]; reading it
                            // this way realizes B^T without materializing it.
                            acc = acc + grad_data[g_off + i * n + p] * b_data[b_off + j * n + p];
                        }
                        result[r_off + i * k + j] = acc;
                    }
                }
            }
            let t = Tensor::from_storage(
                TensorStorage::cpu(result),
                vec![batch, m, k],
                false,
            )?;
            Some(if device.is_cuda() { t.to(device)? } else { t })
        } else {
            None
        };
        let grad_b = if self.b.requires_grad() {
            let a_data = cpu_a.data()?;
            let grad_data = cpu_go.data()?;
            let mut result = vec![<T as num_traits::Zero>::zero(); batch * k * n];
            for bi in 0..batch {
                let a_off = bi * m * k;
                let g_off = bi * m * n;
                let r_off = bi * k * n;
                for i in 0..k {
                    for j in 0..n {
                        let mut acc = <T as num_traits::Zero>::zero();
                        for p in 0..m {
                            // a_data[a_off + p*k + i] is A[bi][p][i], i.e. A^T on the fly.
                            acc = acc + a_data[a_off + p * k + i] * grad_data[g_off + p * n + j];
                        }
                        result[r_off + i * n + j] = acc;
                    }
                }
            }
            let t = Tensor::from_storage(
                TensorStorage::cpu(result),
                vec![batch, k, n],
                false,
            )?;
            Some(if device.is_cuda() { t.to(device)? } else { t })
        } else {
            None
        };
        Ok(vec![grad_a, grad_b])
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.a, &self.b]
    }
    fn name(&self) -> &'static str {
        "BmmBackward"
    }
}
/// Autograd node for the general, shape-dispatching `matmul`.
/// Delegates to the specialized backward nodes based on operand rank.
#[derive(Debug)]
pub struct MatmulBackward<T: Float> {
    a: Tensor<T>,
    b: Tensor<T>,
}
impl<T: Float> MatmulBackward<T> {
    /// Captures the forward operands for the backward pass.
    pub fn new(a: Tensor<T>, b: Tensor<T>) -> Self {
        Self { a, b }
    }
}
impl<T: Float> GradFn<T> for MatmulBackward<T> {
    /// Dispatches on operand rank, mirroring the forward `matmul` dispatch:
    /// (2,2) -> MmBackward, (2,1) -> MvBackward, (1,1) -> DotBackward,
    /// (1,2) handled inline, everything else -> broadcast path.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        match (self.a.ndim(), self.b.ndim()) {
            (2, 2) => {
                let inner = MmBackward::new(self.a.clone(), self.b.clone());
                inner.backward(grad_output)
            }
            (2, 1) => {
                let inner = MvBackward::new(self.a.clone(), self.b.clone());
                inner.backward(grad_output)
            }
            (1, 1) => {
                let inner = DotBackward::new(self.a.clone(), self.b.clone());
                inner.backward(grad_output)
            }
            (1, 2) => {
                // Vector-matrix: y = x @ B with x: (k,), B: (k,n), y: (n,).
                // grad_x = B @ ybar; grad_B = x ⊗ ybar.
                let device = grad_output.device();
                let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
                let cpu_a = if self.a.is_cuda() { self.a.cpu()? } else { self.a.clone() };
                let cpu_b = if self.b.is_cuda() { self.b.cpu()? } else { self.b.clone() };
                let grad_a = if self.a.requires_grad() {
                    // B (k,n) @ ybar (n,) -> (k,), equal to ybar @ B^T.
                    let t = linalg::mv(&cpu_b, &cpu_go)?;
                    Some(if device.is_cuda() { t.to(device)? } else { t })
                } else {
                    None
                };
                let grad_b = if self.b.requires_grad() {
                    let a_data = cpu_a.data()?;
                    let grad_data = cpu_go.data()?;
                    let k = a_data.len();
                    let n = grad_data.len();
                    // Outer product: grad_B[ki][ni] = x[ki] * ybar[ni].
                    let mut outer = vec![<T as num_traits::Zero>::zero(); k * n];
                    for ki in 0..k {
                        for ni in 0..n {
                            outer[ki * n + ni] = a_data[ki] * grad_data[ni];
                        }
                    }
                    let t = Tensor::from_storage(
                        TensorStorage::cpu(outer),
                        vec![k, n],
                        false,
                    )?;
                    Some(if device.is_cuda() { t.to(device)? } else { t })
                } else {
                    None
                };
                Ok(vec![grad_a, grad_b])
            }
            _ => {
                broadcast_matmul_backward(&self.a, &self.b, grad_output)
            }
        }
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.a, &self.b]
    }
    fn name(&self) -> &'static str {
        "MatmulBackward"
    }
}
/// Gradients for the broadcasting (>2-D) matmul case:
/// `grad_A = reduce(grad_C @ B^T)` and `grad_B = reduce(A^T @ grad_C)`,
/// where `reduce` sums over the dimensions that were broadcast in forward.
///
/// NOTE(review): `swap_last_two` always builds a CPU tensor via `t.data()`,
/// so for CUDA inputs `linalg::matmul` receives mixed-device operands —
/// verify this path is only reached with CPU tensors or that matmul handles it.
fn broadcast_matmul_backward<T: Float>(
    a: &Tensor<T>,
    b: &Tensor<T>,
    grad_output: &Tensor<T>,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
    // Transposes the last two dimensions of a (batched) tensor on CPU,
    // treating all leading dims as a flat batch of (rows x cols) matrices.
    let swap_last_two = |t: &Tensor<T>| -> FerrotorchResult<Tensor<T>> {
        let shape = t.shape();
        let nd = shape.len();
        if nd < 2 {
            return Err(FerrotorchError::InvalidArgument {
                message: "Cannot transpose last two dims of tensor with ndim < 2".into(),
            });
        }
        let data = t.data()?;
        let rows = shape[nd - 2];
        let cols = shape[nd - 1];
        let mat_size = rows * cols;
        // Number of leading matrices; `.max(1)` covers the exactly-2-D case.
        let n_mats: usize = shape[..nd - 2].iter().product::<usize>().max(1);
        let mut out = vec![<T as num_traits::Zero>::zero(); data.len()];
        for m in 0..n_mats {
            let off = m * mat_size;
            for i in 0..rows {
                for j in 0..cols {
                    out[off + j * rows + i] = data[off + i * cols + j];
                }
            }
        }
        let mut out_shape = shape.to_vec();
        out_shape[nd - 2] = cols;
        out_shape[nd - 1] = rows;
        Tensor::from_storage(TensorStorage::cpu(out), out_shape, false)
    };
    // Sums `grad` down to `target` shape: leading extra dims are summed away,
    // and size-1 target dims accumulate over the broadcast axis.
    // NOTE(review): assumes grad has at least as many dims as target
    // (`grad_nd - target_nd` would underflow otherwise) — holds for matmul
    // broadcasting, where the output rank >= each input rank.
    let reduce_to_shape = |grad: Tensor<T>, target: &[usize]| -> FerrotorchResult<Tensor<T>> {
        let grad_shape = grad.shape().to_vec();
        if grad_shape == target {
            return Ok(grad);
        }
        let grad_nd = grad_shape.len();
        let target_nd = target.len();
        let offset = grad_nd - target_nd;
        let grad_data = grad.data()?;
        let target_size: usize = target.iter().product::<usize>().max(1);
        let mut result = vec![<T as num_traits::Zero>::zero(); target_size];
        let grad_total: usize = grad_shape.iter().product::<usize>().max(1);
        // Row-major strides for the source gradient.
        let mut grad_strides = vec![1usize; grad_nd];
        for i in (0..grad_nd.saturating_sub(1)).rev() {
            grad_strides[i] = grad_strides[i + 1] * grad_shape[i + 1];
        }
        // Row-major strides for the reduction target.
        let mut target_strides = vec![1usize; target_nd];
        if target_nd > 0 {
            for i in (0..target_nd.saturating_sub(1)).rev() {
                target_strides[i] = target_strides[i + 1] * target[i + 1];
            }
        }
        for flat in 0..grad_total {
            // Decode the flat index into per-dim coordinates (rightmost first)
            // and map them onto the target, collapsing broadcast dims to 0.
            let mut remaining = flat;
            let mut target_flat = 0usize;
            for d in (0..grad_nd).rev() {
                let coord = remaining % grad_shape[d];
                remaining /= grad_shape[d];
                if d >= offset {
                    let td = d - offset;
                    let target_coord = if target[td] == 1 { 0 } else { coord };
                    target_flat += target_coord * target_strides[td];
                }
            }
            result[target_flat] = result[target_flat] + grad_data[flat];
        }
        Tensor::from_storage(TensorStorage::cpu(result), target.to_vec(), false)
    };
    let grad_a = if a.requires_grad() {
        let bt = swap_last_two(b)?;
        let full_grad = linalg::matmul(grad_output, &bt)?;
        Some(reduce_to_shape(full_grad, a.shape())?)
    } else {
        None
    };
    let grad_b = if b.requires_grad() {
        let at = swap_last_two(a)?;
        let full_grad = linalg::matmul(&at, grad_output)?;
        Some(reduce_to_shape(full_grad, b.shape())?)
    } else {
        None
    };
    Ok(vec![grad_a, grad_b])
}
/// Differentiable 2-D matrix multiply: `C = A @ B` with `A: (m,k)`, `B: (k,n)`.
///
/// Dispatches to the GPU backend for CUDA inputs, otherwise uses the raw CPU
/// kernel, attaching `MmBackward` when autograd is enabled and either input
/// requires grad.
///
/// # Errors
/// - `DeviceMismatch` when the operands live on different devices.
/// - `ShapeMismatch` when the inner dimensions disagree. (Fix: this check
///   previously ran only on the CPU path, so a CUDA matmul with bad shapes
///   went straight to the kernel; it is now validated on both paths.)
pub fn mm_differentiable<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if a.device() != b.device() {
        return Err(FerrotorchError::DeviceMismatch { expected: a.device(), got: b.device() });
    }
    let m = a.shape()[0];
    let k = a.shape()[1];
    let n = b.shape()[1];
    if k != b.shape()[0] {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "mm: inner dimensions mismatch: ({},{}) @ ({},{})",
                m, k, b.shape()[0], n
            ),
        });
    }
    // Forward pass: produce storage on the operands' device.
    let storage = if a.is_cuda() {
        let backend = crate::gpu_dispatch::gpu_backend()
            .ok_or(FerrotorchError::DeviceUnavailable)?;
        let handle = backend.matmul_f32(a.gpu_handle()?, b.gpu_handle()?, m, k, n)?;
        TensorStorage::gpu(handle)
    } else {
        let a_data = a.data()?;
        let b_data = b.data()?;
        let result_vec = linalg::mm_raw(a_data, b_data, m, k, n);
        TensorStorage::cpu(result_vec)
    };
    let shape = vec![m, n];
    if is_grad_enabled() && (a.requires_grad() || b.requires_grad()) {
        let grad_fn = Arc::new(MmBackward::new(a.clone(), b.clone()));
        Tensor::from_operation(storage, shape, grad_fn)
    } else {
        Tensor::from_storage(storage, shape, false)
    }
}
/// Autograd node for `C = A @ B^T` (the `mm_bt` fused op), with
/// `A: (m,k)` and `B: (n,k)`.
#[derive(Debug)]
struct MmBtBackward<T: Float> {
a: Tensor<T>, b: Tensor<T>, }
impl<T: Float> MmBtBackward<T> {
    /// Captures the forward operands for the backward pass.
    fn new(a: Tensor<T>, b: Tensor<T>) -> Self {
        Self { a, b }
    }
}
impl<T: Float> GradFn<T> for MmBtBackward<T> {
    /// For `C = A @ B^T` with `A: (m,k)`, `B: (n,k)`, `grad_C: (m,n)`:
    /// `grad_A = grad_C @ B` (no transpose needed, B is already (n,k)) and
    /// `grad_B = grad_C^T @ A` (shape (n,k)).
    ///
    /// Uses the GPU backend directly when the upstream gradient is on the
    /// device and `T` is f32; otherwise computes on CPU and moves back.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let device = grad_output.device();
        if grad_output.is_cuda() && is_f32::<T>() {
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let go_h = grad_output.gpu_handle()?;
            let m = grad_output.shape()[0];
            let n = grad_output.shape()[1];
            let grad_a = if self.a.requires_grad() {
                let k = self.b.shape()[1];
                let b_h = self.b.gpu_handle()?;
                // grad_C (m,n) @ B (n,k) -> (m,k).
                let result_h = backend.matmul_f32(go_h, b_h, m, n, k)?;
                Some(Tensor::from_storage(
                    TensorStorage::gpu(result_h), vec![m, k], false,
                )?)
            } else { None };
            let grad_b = if self.b.requires_grad() {
                let k = self.a.shape()[1];
                // grad_C^T (n,m) @ A (m,k) -> (n,k).
                let got_h = backend.transpose_2d_f32(go_h, m, n)?;
                let a_h = self.a.gpu_handle()?;
                let result_h = backend.matmul_f32(&got_h, a_h, n, m, k)?;
                Some(Tensor::from_storage(
                    TensorStorage::gpu(result_h), vec![n, k], false,
                )?)
            } else { None };
            return Ok(vec![grad_a, grad_b]);
        }
        let grad_a = if self.a.requires_grad() {
            // Plain mm handles any remaining device combination.
            Some(mm(grad_output, &self.b)?)
        } else {
            None
        };
        let grad_b = if self.b.requires_grad() {
            let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
            let cpu_a = if self.a.is_cuda() { self.a.cpu()? } else { self.a.clone() };
            let gc_data = cpu_go.data()?;
            let a_data = cpu_a.data()?;
            let m = grad_output.shape()[0];
            let n = grad_output.shape()[1];
            let k = self.a.shape()[1];
            // mm_raw_at: grad_C^T (n,m) @ A (m,k) -> (n,k).
            let result = crate::ops::linalg::mm_raw_at(&gc_data, &a_data, n, m, k);
            let t = Tensor::from_storage(TensorStorage::cpu(result), vec![n, k], false)?;
            Some(if device.is_cuda() { t.to(device)? } else { t })
        } else {
            None
        };
        Ok(vec![grad_a, grad_b])
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        vec![&self.a, &self.b]
    }
    fn name(&self) -> &'static str {
        "MmBtBackward"
    }
}
/// Differentiable `C = A @ B^T` with `A: (m,k)`, `B: (n,k)` -> `C: (m,n)`.
///
/// Validates B's shape up front, then dispatches to the GPU backend
/// (transpose + matmul) for CUDA inputs or the raw CPU kernel otherwise.
/// Attaches `MmBtBackward` when autograd is enabled and a grad is required.
///
/// NOTE(review): the CUDA branch calls f32-specific backend kernels without
/// checking `is_f32::<T>()` — confirm CUDA tensors can only be f32 here.
pub fn mm_bt_differentiable<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    let m = a.shape()[0];
    let k = a.shape()[1];
    let n = b.shape()[0];
    if b.ndim() != 2 || b.shape()[1] != k {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "mm_bt: A is ({},{}) but B is {:?} (expected ({},{}))",
                m, k, b.shape(), n, k
            ),
        });
    }
    if a.is_cuda() {
        let backend = crate::gpu_dispatch::gpu_backend()
            .ok_or(FerrotorchError::DeviceUnavailable)?;
        // B^T: (n,k) -> (k,n), then A (m,k) @ B^T (k,n) -> (m,n).
        let bt_handle = backend.transpose_2d_f32(b.gpu_handle()?, n, k)?;
        let handle = backend.matmul_f32(a.gpu_handle()?, &bt_handle, m, k, n)?;
        let storage = TensorStorage::gpu(handle);
        let shape = vec![m, n];
        return if is_grad_enabled() && (a.requires_grad() || b.requires_grad()) {
            let grad_fn = Arc::new(MmBtBackward::new(a.clone(), b.clone()));
            Tensor::from_operation(storage, shape, grad_fn)
        } else {
            Tensor::from_storage(storage, shape, false)
        };
    }
    let a_data = a.data()?;
    let b_data = b.data()?;
    let result_vec = linalg::mm_raw_bt(a_data, b_data, m, k, n);
    let storage = TensorStorage::cpu(result_vec);
    let shape = vec![m, n];
    if is_grad_enabled() && (a.requires_grad() || b.requires_grad()) {
        let grad_fn = Arc::new(MmBtBackward::new(a.clone(), b.clone()));
        Tensor::from_operation(storage, shape, grad_fn)
    } else {
        Tensor::from_storage(storage, shape, false)
    }
}
/// Autograd node for the fused linear layer `output = input @ weight^T + bias`.
/// NOTE(review): `has_bias` appears redundant with `bias.is_some()` — both are
/// always set together in `linear_fused`; consider consolidating.
#[derive(Debug)]
struct LinearFusedBackward<T: Float> {
input: Tensor<T>, weight: Tensor<T>, has_bias: bool,
bias: Option<Tensor<T>>, }
impl<T: Float> GradFn<T> for LinearFusedBackward<T> {
    /// For `Y = X @ W^T + b` with `X: (m,k)`, `W: (n,k)`, `b: (n,)`:
    /// `grad_X = grad_Y @ W`, `grad_W = grad_Y^T @ X`, and `grad_b` is the
    /// column-wise sum of `grad_Y`.
    ///
    /// Returns three slots `[grad_input, grad_weight, grad_bias]`.
    /// NOTE(review): when `bias` is None, `inputs()` lists only two tensors
    /// while this still returns three slots (the last None) — verify the
    /// autograd engine tolerates a trailing extra gradient slot.
    fn backward(&self, grad_output: &Tensor<T>) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
        let device = grad_output.device();
        let m = grad_output.shape()[0];
        let n = grad_output.shape()[1];
        if grad_output.is_cuda() && is_f32::<T>() {
            let backend = gpu_backend().ok_or(FerrotorchError::DeviceUnavailable)?;
            let go_h = grad_output.gpu_handle()?;
            let grad_input = if self.input.requires_grad() {
                let k = self.weight.shape()[1];
                let w_h = self.weight.gpu_handle()?;
                // grad_Y (m,n) @ W (n,k) -> (m,k).
                let result_h = backend.matmul_f32(go_h, w_h, m, n, k)?;
                Some(Tensor::from_storage(
                    TensorStorage::gpu(result_h), vec![m, k], false,
                )?)
            } else { None };
            let grad_weight = if self.weight.requires_grad() {
                let k = self.input.shape()[1];
                // grad_Y^T (n,m) @ X (m,k) -> (n,k).
                let got_h = backend.transpose_2d_f32(go_h, m, n)?;
                let inp_h = self.input.gpu_handle()?;
                let result_h = backend.matmul_f32(&got_h, inp_h, n, m, k)?;
                Some(Tensor::from_storage(
                    TensorStorage::gpu(result_h), vec![n, k], false,
                )?)
            } else { None };
            let grad_bias = if self.has_bias {
                if let Some(ref b) = self.bias {
                    if b.requires_grad() {
                        // Bias reduction is done on the host even in the GPU
                        // path, then copied back to the original device.
                        let cpu_go = grad_output.cpu()?;
                        let gc_data = cpu_go.data()?;
                        let zero = <T as num_traits::Zero>::zero();
                        let mut gb = vec![zero; n];
                        for i in 0..m {
                            let row = i * n;
                            for j in 0..n {
                                gb[j] = gb[j] + gc_data[row + j];
                            }
                        }
                        let t = Tensor::from_storage(TensorStorage::cpu(gb), vec![n], false)?;
                        Some(t.to(device)?)
                    } else { None }
                } else { None }
            } else { None };
            return Ok(vec![grad_input, grad_weight, grad_bias]);
        }
        // CPU fallback.
        let cpu_go = if grad_output.is_cuda() { grad_output.cpu()? } else { grad_output.clone() };
        let gc_data = cpu_go.data()?;
        let grad_input = if self.input.requires_grad() {
            let cpu_w = if self.weight.is_cuda() { self.weight.cpu()? } else { self.weight.clone() };
            let w_data = cpu_w.data()?;
            let k = self.weight.shape()[1];
            // grad_Y (m,n) @ W (n,k) -> (m,k).
            let result = crate::ops::linalg::mm_raw(&gc_data, &w_data, m, n, k);
            let t = Tensor::from_storage(TensorStorage::cpu(result), vec![m, k], false)?;
            Some(if device.is_cuda() { t.to(device)? } else { t })
        } else { None };
        let grad_weight = if self.weight.requires_grad() {
            let cpu_inp = if self.input.is_cuda() { self.input.cpu()? } else { self.input.clone() };
            let a_data = cpu_inp.data()?;
            let k = self.input.shape()[1];
            // mm_raw_at: grad_Y^T (n,m) @ X (m,k) -> (n,k).
            let result = crate::ops::linalg::mm_raw_at(&gc_data, &a_data, n, m, k);
            let t = Tensor::from_storage(TensorStorage::cpu(result), vec![n, k], false)?;
            Some(if device.is_cuda() { t.to(device)? } else { t })
        } else { None };
        let grad_bias = if self.has_bias {
            if let Some(ref b) = self.bias {
                if b.requires_grad() {
                    // Sum grad_Y over the batch (rows) for each output column.
                    let zero = <T as num_traits::Zero>::zero();
                    let mut gb = vec![zero; n];
                    for i in 0..m {
                        let row = i * n;
                        for j in 0..n {
                            gb[j] = gb[j] + gc_data[row + j];
                        }
                    }
                    let t = Tensor::from_storage(TensorStorage::cpu(gb), vec![n], false)?;
                    Some(if device.is_cuda() { t.to(device)? } else { t })
                } else { None }
            } else { None }
        } else { None };
        Ok(vec![grad_input, grad_weight, grad_bias])
    }
    fn inputs(&self) -> Vec<&Tensor<T>> {
        let mut v = vec![&self.input, &self.weight];
        if let Some(ref b) = self.bias {
            v.push(b);
        }
        v
    }
    fn name(&self) -> &'static str {
        "LinearFusedBackward"
    }
}
/// Fused linear layer forward: `output = input @ weight^T + bias`.
///
/// `input: (m,k)`, `weight: (n,k)`, optional `bias: (n,)` -> `(m,n)`.
/// Dispatches to the GPU backend for CUDA inputs (transpose + matmul +
/// broadcast add), otherwise runs the raw CPU kernel, attaching
/// `LinearFusedBackward` when autograd is enabled and any operand
/// requires grad.
///
/// # Errors
/// Returns `ShapeMismatch` when `weight` is not `(n,k)` for the given input,
/// or when `bias` does not have exactly `n` elements. (Fix: there was no
/// validation before — a mismatched weight/bias indexed out of bounds on the
/// CPU path or launched a kernel with wrong dims on the GPU path.)
pub fn linear_fused<T: Float>(
    input: &Tensor<T>,
    weight: &Tensor<T>,
    bias: Option<&Tensor<T>>,
) -> FerrotorchResult<Tensor<T>> {
    let m = input.shape()[0];
    let k = input.shape()[1];
    let n = weight.shape()[0];
    if weight.ndim() != 2 || weight.shape()[1] != k {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!(
                "linear: input is ({},{}) but weight is {:?} (expected ({},{}))",
                m, k, weight.shape(), n, k
            ),
        });
    }
    if let Some(b) = bias {
        if b.numel() != n {
            return Err(FerrotorchError::ShapeMismatch {
                message: format!(
                    "linear: bias has {} elements but output width is {}",
                    b.numel(), n
                ),
            });
        }
    }
    if input.is_cuda() {
        let backend = crate::gpu_dispatch::gpu_backend()
            .ok_or(FerrotorchError::DeviceUnavailable)?;
        // W^T: (n,k) -> (k,n), then X (m,k) @ W^T (k,n) -> (m,n).
        let wt_handle = backend.transpose_2d_f32(weight.gpu_handle()?, n, k)?;
        let mut result_handle = backend.matmul_f32(input.gpu_handle()?, &wt_handle, m, k, n)?;
        if let Some(b) = bias {
            // Broadcast-add the (n,) bias across all m rows on-device.
            let out_shape = vec![m, n];
            let b_shape = vec![n];
            result_handle = backend.broadcast_add_f32(
                &result_handle, b.gpu_handle()?,
                &out_shape, &b_shape, &out_shape,
            )?;
        }
        let storage = TensorStorage::gpu(result_handle);
        let shape = vec![m, n];
        let needs_grad = is_grad_enabled() && (input.requires_grad() || weight.requires_grad()
            || bias.map_or(false, |b| b.requires_grad()));
        return if needs_grad {
            let grad_fn = Arc::new(LinearFusedBackward {
                input: input.clone(),
                weight: weight.clone(),
                has_bias: bias.is_some(),
                bias: bias.map(|b| b.clone()),
            });
            Tensor::from_operation(storage, shape, grad_fn)
        } else {
            Tensor::from_storage(storage, shape, false)
        };
    }
    // CPU path: X @ W^T via the fused raw kernel, then add bias in place.
    let a_data = input.data()?;
    let w_data = weight.data()?;
    let mut result_vec = linalg::mm_raw_bt(&a_data, &w_data, m, k, n);
    if let Some(b) = bias {
        let b_data = b.data()?;
        for i in 0..m {
            let row = i * n;
            for j in 0..n {
                result_vec[row + j] = result_vec[row + j] + b_data[j];
            }
        }
    }
    let storage = TensorStorage::cpu(result_vec);
    let shape = vec![m, n];
    let needs_grad = is_grad_enabled() && (input.requires_grad() || weight.requires_grad()
        || bias.map_or(false, |b| b.requires_grad()));
    if needs_grad {
        let grad_fn = Arc::new(LinearFusedBackward {
            input: input.clone(),
            weight: weight.clone(),
            has_bias: bias.is_some(),
            bias: bias.map(|b| b.clone()),
        });
        Tensor::from_operation(storage, shape, grad_fn)
    } else {
        Tensor::from_storage(storage, shape, false)
    }
}
pub fn mv_differentiable<T: Float>(a: &Tensor<T>, x: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let needs_grad = is_grad_enabled() && (a.requires_grad() || x.requires_grad());
let a_data = a.data()?;
let x_data = x.data()?;
let m = a.shape()[0];
let k = a.shape()[1];
let zero = <T as num_traits::Zero>::zero();
let mut result_vec = vec![zero; m];
for i in 0..m {
let mut acc = zero;
let row = i * k;
for p in 0..k {
acc = acc + a_data[row + p] * x_data[p];
}
result_vec[i] = acc;
}
let storage = TensorStorage::cpu(result_vec);
let shape = vec![m];
if needs_grad {
let grad_fn = Arc::new(MvBackward::new(a.clone(), x.clone()));
Tensor::from_operation(storage, shape, grad_fn)
} else {
Tensor::from_storage(storage, shape, false)
}
}
pub fn dot_differentiable<T: Float>(
a: &Tensor<T>,
b: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
let needs_grad = is_grad_enabled() && (a.requires_grad() || b.requires_grad());
let a_data = a.data()?;
let b_data = b.data()?;
let result_val = a_data
.iter()
.zip(b_data.iter())
.fold(<T as num_traits::Zero>::zero(), |acc, (&x, &y)| acc + x * y);
let storage = TensorStorage::cpu(vec![result_val]);
let shape = vec![];
if needs_grad {
let grad_fn = Arc::new(DotBackward::new(a.clone(), b.clone()));
Tensor::from_operation(storage, shape, grad_fn)
} else {
Tensor::from_storage(storage, shape, false)
}
}
pub fn bmm_differentiable<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
let needs_grad = is_grad_enabled() && (a.requires_grad() || b.requires_grad());
let a_data = a.data()?;
let b_data = b.data()?;
let batch = a.shape()[0];
let m = a.shape()[1];
let k = a.shape()[2];
let n = b.shape()[2];
let zero = <T as num_traits::Zero>::zero();
let slice_a = m * k;
let slice_b = k * n;
let slice_c = m * n;
let mut result_vec = vec![zero; batch * slice_c];
for bi in 0..batch {
let a_off = bi * slice_a;
let b_off = bi * slice_b;
let c_off = bi * slice_c;
let batch_result = linalg::mm_raw(&a_data[a_off..a_off + slice_a], &b_data[b_off..b_off + slice_b], m, k, n);
result_vec[c_off..c_off + slice_c].copy_from_slice(&batch_result);
}
let storage = TensorStorage::cpu(result_vec);
let shape = vec![batch, m, n];
if needs_grad {
let grad_fn = Arc::new(BmmBackward::new(a.clone(), b.clone()));
Tensor::from_operation(storage, shape, grad_fn)
} else {
Tensor::from_storage(storage, shape, false)
}
}
/// Differentiable, rank-dispatching matmul (PyTorch `matmul`-style):
/// 1-D x 1-D -> dot, 2-D x 1-D -> mv, 2-D x 2-D -> mm, matching 3-D -> bmm,
/// anything else -> the general `linalg::matmul` with `MatmulBackward`.
/// CUDA 2-D x 2-D inputs take a dedicated GPU fast path.
pub fn matmul_differentiable<T: Float>(
    a: &Tensor<T>,
    b: &Tensor<T>,
) -> FerrotorchResult<Tensor<T>> {
    if a.device() != b.device() {
        return Err(FerrotorchError::DeviceMismatch { expected: a.device(), got: b.device() });
    }
    if a.is_cuda() && a.ndim() == 2 && b.ndim() == 2 {
        // GPU fast path for plain 2-D x 2-D.
        let backend = crate::gpu_dispatch::gpu_backend()
            .ok_or(FerrotorchError::DeviceUnavailable)?;
        let (m, k) = (a.shape()[0], a.shape()[1]);
        let n = b.shape()[1];
        let handle = backend.matmul_f32(a.gpu_handle()?, b.gpu_handle()?, m, k, n)?;
        let storage = TensorStorage::gpu(handle);
        let shape = vec![m, n];
        return if is_grad_enabled() && (a.requires_grad() || b.requires_grad()) {
            let grad_fn = Arc::new(MatmulBackward::new(a.clone(), b.clone()));
            Tensor::from_operation(storage, shape, grad_fn)
        } else {
            Tensor::from_storage(storage, shape, false)
        };
    }
    match (a.ndim(), b.ndim()) {
        (1, 1) => dot_differentiable(a, b),
        (2, 1) => mv_differentiable(a, b),
        (2, 2) => mm_differentiable(a, b),
        (3, 3) if a.shape()[0] == b.shape()[0] => bmm_differentiable(a, b),
        _ => {
            // Broadcasting / mixed-rank case: generic forward, then attach
            // the node when gradients are needed.
            let result = linalg::matmul(a, b)?;
            if is_grad_enabled() && (a.requires_grad() || b.requires_grad()) {
                let grad_fn = Arc::new(MatmulBackward::new(a.clone(), b.clone()));
                let (storage, shape) = result.into_storage_and_shape()?;
                Tensor::from_operation(storage, shape, grad_fn)
            } else {
                Ok(result)
            }
        }
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::storage::TensorStorage;
    /// Builds a CPU f32 tensor with `requires_grad = true` (an autograd leaf).
    fn leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), true).unwrap()
    }
    /// Builds a CPU f32 tensor that does NOT require grad.
    fn no_grad_leaf(data: &[f32], shape: &[usize]) -> Tensor<f32> {
        Tensor::from_storage(TensorStorage::cpu(data.to_vec()), shape.to_vec(), false).unwrap()
    }
fn assert_close(actual: &[f32], expected: &[f32], tol: f32) {
assert_eq!(
actual.len(),
expected.len(),
"length mismatch: {} vs {}",
actual.len(),
expected.len()
);
for (i, (&a, &e)) in actual.iter().zip(expected.iter()).enumerate() {
assert!(
(a - e).abs() < tol,
"index {i}: {a} vs {e} (diff {})",
(a - e).abs()
);
}
}
#[test]
fn test_mm_backward_both_grads() {
let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = leaf(&[5.0, 6.0, 7.0, 8.0], &[2, 2]);
let c = mm_differentiable(&a, &b).unwrap();
assert_eq!(c.shape(), &[2, 2]);
let c_data = c.data().unwrap();
let loss_val: f32 = c_data.iter().sum();
#[derive(Debug)]
struct SumBackward<T: Float> {
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SumBackward<T> {
fn backward(
&self,
_grad_output: &Tensor<T>,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let ones = vec![<T as num_traits::One>::one(); self.input.numel()];
let g = Tensor::from_storage(
TensorStorage::cpu(ones),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(g)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![loss_val]),
vec![],
Arc::new(SumBackward { input: c }),
)
.unwrap();
loss.backward().unwrap();
let a_grad = a.grad().unwrap().expect("a should have grad");
let b_grad = b.grad().unwrap().expect("b should have grad");
assert_eq!(a_grad.shape(), &[2, 2]);
assert_eq!(b_grad.shape(), &[2, 2]);
assert_close(a_grad.data().unwrap(), &[11.0, 15.0, 11.0, 15.0], 1e-5);
assert_close(b_grad.data().unwrap(), &[4.0, 4.0, 6.0, 6.0], 1e-5);
}
#[test]
fn test_mm_backward_one_requires_grad() {
let a = leaf(&[1.0, 0.0, 0.0, 1.0], &[2, 2]); let b = no_grad_leaf(&[2.0, 3.0, 4.0, 5.0], &[2, 2]);
let c = mm_differentiable(&a, &b).unwrap();
assert!(c.grad_fn().is_some());
let grad_out = no_grad_leaf(&[1.0, 1.0, 1.0, 1.0], &[2, 2]);
let grads = c.grad_fn().unwrap().backward(&grad_out).unwrap();
assert!(grads[0].is_some());
assert!(grads[1].is_none());
let ga = grads[0].as_ref().unwrap();
assert_close(ga.data().unwrap(), &[5.0, 9.0, 5.0, 9.0], 1e-5);
}
#[test]
fn test_dot_backward() {
let a = leaf(&[1.0, 2.0, 3.0], &[3]);
let b = leaf(&[4.0, 5.0, 6.0], &[3]);
let s = dot_differentiable(&a, &b).unwrap();
assert!(s.is_scalar());
assert!((s.item().unwrap() - 32.0).abs() < 1e-5);
s.backward().unwrap();
let a_grad = a.grad().unwrap().expect("a should have grad");
let b_grad = b.grad().unwrap().expect("b should have grad");
assert_eq!(a_grad.shape(), &[3]);
assert_eq!(b_grad.shape(), &[3]);
assert_close(a_grad.data().unwrap(), &[4.0, 5.0, 6.0], 1e-5);
assert_close(b_grad.data().unwrap(), &[1.0, 2.0, 3.0], 1e-5);
}
#[test]
fn test_dot_backward_one_requires_grad() {
let a = leaf(&[2.0, 3.0], &[2]);
let b = no_grad_leaf(&[4.0, 5.0], &[2]);
let s = dot_differentiable(&a, &b).unwrap();
let grad_out = no_grad_leaf(&[1.0], &[]);
let grads = s.grad_fn().unwrap().backward(&grad_out).unwrap();
assert!(grads[0].is_some());
assert!(grads[1].is_none());
assert_close(
grads[0].as_ref().unwrap().data().unwrap(),
&[4.0, 5.0],
1e-5,
);
}
#[test]
fn test_mv_backward() {
let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let x = leaf(&[5.0, 6.0], &[2]);
let y = mv_differentiable(&a, &x).unwrap();
assert_eq!(y.shape(), &[2]);
let y_data = y.data().unwrap();
let loss_val: f32 = y_data.iter().sum();
#[derive(Debug)]
struct SumBackward1D<T: Float> {
input: Tensor<T>,
}
impl<T: Float> GradFn<T> for SumBackward1D<T> {
fn backward(
&self,
_grad_output: &Tensor<T>,
) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
let ones = vec![<T as num_traits::One>::one(); self.input.numel()];
let g = Tensor::from_storage(
TensorStorage::cpu(ones),
self.input.shape().to_vec(),
false,
)?;
Ok(vec![Some(g)])
}
fn inputs(&self) -> Vec<&Tensor<T>> {
vec![&self.input]
}
fn name(&self) -> &'static str {
"SumBackward"
}
}
let loss = Tensor::from_operation(
TensorStorage::cpu(vec![loss_val]),
vec![],
Arc::new(SumBackward1D { input: y }),
)
.unwrap();
loss.backward().unwrap();
let a_grad = a.grad().unwrap().expect("a should have grad");
let x_grad = x.grad().unwrap().expect("x should have grad");
assert_eq!(a_grad.shape(), &[2, 2]);
assert_eq!(x_grad.shape(), &[2]);
assert_close(a_grad.data().unwrap(), &[5.0, 6.0, 5.0, 6.0], 1e-5);
assert_close(x_grad.data().unwrap(), &[4.0, 6.0], 1e-5);
}
#[test]
fn test_matmul_backward_dispatches_to_dot() {
let a = leaf(&[1.0, 2.0], &[2]);
let b = leaf(&[3.0, 4.0], &[2]);
let s = matmul_differentiable(&a, &b).unwrap();
assert!(s.is_scalar());
assert!((s.item().unwrap() - 11.0).abs() < 1e-5);
s.backward().unwrap();
let a_grad = a.grad().unwrap().unwrap();
let b_grad = b.grad().unwrap().unwrap();
assert_close(a_grad.data().unwrap(), &[3.0, 4.0], 1e-5);
assert_close(b_grad.data().unwrap(), &[1.0, 2.0], 1e-5);
}
#[test]
fn test_matmul_backward_dispatches_to_mm() {
let a = leaf(&[1.0, 2.0, 3.0, 4.0], &[2, 2]);
let b = leaf(&[1.0, 0.0, 0.0, 1.0], &[2, 2]);
let c = matmul_differentiable(&a, &b).unwrap();
let grad_out = no_grad_leaf(&[1.0, 1.0, 1.0, 1.0], &[2, 2]);
let grads = c.grad_fn().unwrap().backward(&grad_out).unwrap();
assert_close(
grads[0].as_ref().unwrap().data().unwrap(),
&[1.0, 1.0, 1.0, 1.0],
1e-5,
);
assert_close(
grads[1].as_ref().unwrap().data().unwrap(),
&[4.0, 4.0, 6.0, 6.0],
1e-5,
);
}
#[test]
fn test_matmul_backward_vm() {
    // vector @ matrix: [2] x [2,3] -> [3]; sum the output into a scalar
    // loss via a hand-rolled grad-fn and check grads of both operands.
    let vec_in = leaf(&[1.0, 2.0], &[2]);
    let mat_in = leaf(&[3.0, 4.0, 5.0, 6.0, 7.0, 8.0], &[2, 3]);
    let out = matmul_differentiable(&vec_in, &mat_in).unwrap();
    assert_eq!(out.shape(), &[3]);
    let total: f32 = out.data().unwrap().iter().sum();
    // Minimal sum() node: backward emits all-ones of the input's shape.
    #[derive(Debug)]
    struct OnesLikeBackward<T: Float> {
        input: Tensor<T>,
    }
    impl<T: Float> GradFn<T> for OnesLikeBackward<T> {
        fn backward(
            &self,
            _grad_output: &Tensor<T>,
        ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
            let filled = vec![<T as num_traits::One>::one(); self.input.numel()];
            let grad = Tensor::from_storage(
                TensorStorage::cpu(filled),
                self.input.shape().to_vec(),
                false,
            )?;
            Ok(vec![Some(grad)])
        }
        fn inputs(&self) -> Vec<&Tensor<T>> {
            vec![&self.input]
        }
        fn name(&self) -> &'static str {
            "SumBackward"
        }
    }
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(OnesLikeBackward { input: out }),
    )
    .unwrap();
    loss.backward().unwrap();
    let vg = vec_in.grad().unwrap().expect("a should have grad");
    let mg = mat_in.grad().unwrap().expect("b should have grad");
    assert_eq!(vg.shape(), &[2]);
    assert_eq!(mg.shape(), &[2, 3]);
    // d/dv sum(v@M) = row sums of M; d/dM = outer(v, ones).
    assert_close(vg.data().unwrap(), &[12.0, 21.0], 1e-5);
    assert_close(
        mg.data().unwrap(),
        &[1.0, 1.0, 1.0, 2.0, 2.0, 2.0],
        1e-5,
    );
}
#[test]
fn test_bmm_backward_both_grads() {
    // Batch of two 2x2 matmuls; both operands track gradients and the
    // loss is the plain sum of all output elements.
    #[rustfmt::skip]
    let lhs = leaf(&[
        1.0, 2.0, 3.0, 4.0, 1.0, 0.0, 0.0, 1.0, ], &[2, 2, 2]);
    #[rustfmt::skip]
    let rhs = leaf(&[
        5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, ], &[2, 2, 2]);
    let prod = bmm_differentiable(&lhs, &rhs).unwrap();
    assert_eq!(prod.shape(), &[2, 2, 2]);
    let total: f32 = prod.data().unwrap().iter().sum();
    // Hand-rolled sum() node whose backward emits all-ones of the input shape.
    #[derive(Debug)]
    struct BatchOnesBackward<T: Float> {
        input: Tensor<T>,
    }
    impl<T: Float> GradFn<T> for BatchOnesBackward<T> {
        fn backward(
            &self,
            _grad_output: &Tensor<T>,
        ) -> FerrotorchResult<Vec<Option<Tensor<T>>>> {
            let filled = vec![<T as num_traits::One>::one(); self.input.numel()];
            Ok(vec![Some(Tensor::from_storage(
                TensorStorage::cpu(filled),
                self.input.shape().to_vec(),
                false,
            )?)])
        }
        fn inputs(&self) -> Vec<&Tensor<T>> {
            vec![&self.input]
        }
        fn name(&self) -> &'static str {
            "SumBackward"
        }
    }
    let loss = Tensor::from_operation(
        TensorStorage::cpu(vec![total]),
        vec![],
        Arc::new(BatchOnesBackward { input: prod }),
    )
    .unwrap();
    loss.backward().unwrap();
    let lg = lhs.grad().unwrap().expect("a should have grad");
    let rg = rhs.grad().unwrap().expect("b should have grad");
    assert_eq!(lg.shape(), &[2, 2, 2]);
    assert_eq!(rg.shape(), &[2, 2, 2]);
    // Per batch: dA = ones @ B^T, dB = A^T @ ones.
    #[rustfmt::skip]
    let want_da: &[f32] = &[
        11.0, 15.0, 11.0, 15.0, 19.0, 23.0, 19.0, 23.0, ];
    #[rustfmt::skip]
    let want_db: &[f32] = &[
        4.0, 4.0, 6.0, 6.0, 1.0, 1.0, 1.0, 1.0, ];
    assert_close(lg.data().unwrap(), want_da, 1e-5);
    assert_close(rg.data().unwrap(), want_db, 1e-5);
}
#[test]
fn test_bmm_backward_batch_size_1() {
    // Degenerate batch of one: exercise the bmm grad_fn directly with a
    // ones grad_output and check shapes and values of both grads.
    let lhs = leaf(&[1.0, 2.0, 3.0, 4.0], &[1, 2, 2]);
    let rhs = leaf(&[5.0, 6.0, 7.0, 8.0], &[1, 2, 2]);
    let prod = bmm_differentiable(&lhs, &rhs).unwrap();
    let go = no_grad_leaf(&[1.0, 1.0, 1.0, 1.0], &[1, 2, 2]);
    let grads = prod.grad_fn().unwrap().backward(&go).unwrap();
    assert!(grads[0].is_some());
    assert!(grads[1].is_some());
    let lhs_grad = grads[0].as_ref().unwrap();
    let rhs_grad = grads[1].as_ref().unwrap();
    assert_eq!(lhs_grad.shape(), &[1, 2, 2]);
    assert_eq!(rhs_grad.shape(), &[1, 2, 2]);
    // dA = ones @ B^T; dB = A^T @ ones.
    assert_close(lhs_grad.data().unwrap(), &[11.0, 15.0, 11.0, 15.0], 1e-5);
    assert_close(rhs_grad.data().unwrap(), &[4.0, 4.0, 6.0, 6.0], 1e-5);
}
#[test]
fn test_bmm_backward_one_requires_grad() {
    // Only the left operand tracks gradients; backward must return
    // Some for slot 0 and None for slot 1.
    let tracked = leaf(&[1.0, 0.0, 0.0, 1.0], &[1, 2, 2]);
    let frozen = no_grad_leaf(&[2.0, 3.0, 4.0, 5.0], &[1, 2, 2]);
    let prod = bmm_differentiable(&tracked, &frozen).unwrap();
    assert!(prod.grad_fn().is_some());
    let go = no_grad_leaf(&[1.0, 1.0, 1.0, 1.0], &[1, 2, 2]);
    let grads = prod.grad_fn().unwrap().backward(&go).unwrap();
    assert!(grads[0].is_some());
    assert!(grads[1].is_none());
    // dA = ones @ B^T: each row is B's row sums [2+3, 4+5].
    let tracked_grad = grads[0].as_ref().unwrap();
    assert_close(tracked_grad.data().unwrap(), &[5.0, 9.0, 5.0, 9.0], 1e-5);
}
#[test]
fn test_no_grad_skips_backward() {
    // Inside a no_grad scope, differentiable ops must not attach a grad_fn.
    let x = leaf(&[1.0, 2.0, 3.0], &[3]);
    let y = leaf(&[4.0, 5.0, 6.0], &[3]);
    let out = crate::autograd::no_grad::no_grad(|| dot_differentiable(&x, &y).unwrap());
    assert!(out.grad_fn().is_none());
}
#[test]
fn test_matmul_backward_3d_3d_numerical() {
    // Central-difference gradient check for batched 3-D x 3-D matmul:
    // perturb each input element by +/- eps and compare the numerical
    // derivative of sum(a @ b) against the analytic autograd value.
    let eps = 1e-3f32;
    let a_data: Vec<f32> = (0..12).map(|i| (i as f32) * 0.1 + 0.1).collect();
    let b_data: Vec<f32> = (0..12).map(|i| (i as f32) * 0.1 + 0.5).collect();
    let a = leaf(&a_data, &[2, 2, 3]);
    let b = leaf(&b_data, &[2, 3, 2]);
    let c = matmul_differentiable(&a, &b).unwrap();
    let loss = crate::grad_fns::reduction::sum(&c).unwrap();
    loss.backward().unwrap();
    let analytic_a = a.grad().unwrap().unwrap().data().unwrap().to_vec();
    let analytic_b = b.grad().unwrap().unwrap().data().unwrap().to_vec();
    // Recompute the scalar loss from raw buffers without touching autograd.
    let eval_loss = |aa: &[f32], bb: &[f32]| -> f32 {
        crate::autograd::no_grad::no_grad(|| {
            let ta = no_grad_leaf(aa, &[2, 2, 3]);
            let tb = no_grad_leaf(bb, &[2, 3, 2]);
            let prod = linalg::matmul(&ta, &tb).unwrap();
            crate::grad_fns::reduction::sum(&prod).unwrap().item().unwrap()
        })
    };
    for idx in 0..a_data.len() {
        let mut plus = a_data.clone();
        plus[idx] += eps;
        let mut minus = a_data.clone();
        minus[idx] -= eps;
        let numerical =
            (eval_loss(&plus, &b_data) - eval_loss(&minus, &b_data)) / (2.0 * eps);
        assert!(
            (numerical - analytic_a[idx]).abs() < 5e-2,
            "grad_a[{idx}]: numerical={numerical}, analytic={}, diff={}",
            analytic_a[idx],
            (numerical - analytic_a[idx]).abs()
        );
    }
    for idx in 0..b_data.len() {
        let mut plus = b_data.clone();
        plus[idx] += eps;
        let mut minus = b_data.clone();
        minus[idx] -= eps;
        let numerical =
            (eval_loss(&a_data, &plus) - eval_loss(&a_data, &minus)) / (2.0 * eps);
        assert!(
            (numerical - analytic_b[idx]).abs() < 5e-2,
            "grad_b[{idx}]: numerical={numerical}, analytic={}, diff={}",
            analytic_b[idx],
            (numerical - analytic_b[idx]).abs()
        );
    }
}
#[test]
fn test_matmul_backward_3d_2d_broadcast_numerical() {
    // Gradient check for the broadcast case [2,3,4] @ [4,2]: the 2-D rhs
    // is shared across batches, so its grad accumulates over the batch dim.
    let eps = 1e-4f32;
    let a_data: Vec<f32> = (0..24).map(|i| (i as f32) * 0.05 + 0.1).collect();
    let b_data: Vec<f32> = (0..8).map(|i| (i as f32) * 0.1 + 0.2).collect();
    let a = leaf(&a_data, &[2, 3, 4]);
    let b = leaf(&b_data, &[4, 2]);
    let c = matmul_differentiable(&a, &b).unwrap();
    let loss = crate::grad_fns::reduction::sum(&c).unwrap();
    loss.backward().unwrap();
    let analytic_a = a.grad().unwrap().unwrap().data().unwrap().to_vec();
    let analytic_b = b.grad().unwrap().unwrap().data().unwrap().to_vec();
    // Grads must keep the unbroadcast input shapes.
    assert_eq!(a.grad().unwrap().unwrap().shape(), &[2, 3, 4]);
    assert_eq!(b.grad().unwrap().unwrap().shape(), &[4, 2]);
    // Recompute the scalar loss from raw buffers without touching autograd.
    let eval_loss = |aa: &[f32], bb: &[f32]| -> f32 {
        crate::autograd::no_grad::no_grad(|| {
            let ta = no_grad_leaf(aa, &[2, 3, 4]);
            let tb = no_grad_leaf(bb, &[4, 2]);
            let prod = linalg::matmul(&ta, &tb).unwrap();
            crate::grad_fns::reduction::sum(&prod).unwrap().item().unwrap()
        })
    };
    for idx in 0..b_data.len() {
        let mut plus = b_data.clone();
        plus[idx] += eps;
        let mut minus = b_data.clone();
        minus[idx] -= eps;
        let numerical =
            (eval_loss(&a_data, &plus) - eval_loss(&a_data, &minus)) / (2.0 * eps);
        assert!(
            (numerical - analytic_b[idx]).abs() < 1e-2,
            "grad_b[{idx}]: numerical={numerical}, analytic={}, diff={}",
            analytic_b[idx],
            (numerical - analytic_b[idx]).abs()
        );
    }
    // Spot-check only the first four lhs elements to keep the test fast.
    for idx in 0..4 {
        let mut plus = a_data.clone();
        plus[idx] += eps;
        let mut minus = a_data.clone();
        minus[idx] -= eps;
        let numerical =
            (eval_loss(&plus, &b_data) - eval_loss(&minus, &b_data)) / (2.0 * eps);
        assert!(
            (numerical - analytic_a[idx]).abs() < 1e-2,
            "grad_a[{idx}]: numerical={numerical}, analytic={}, diff={}",
            analytic_a[idx],
            (numerical - analytic_a[idx]).abs()
        );
    }
}
#[test]
fn test_matmul_backward_batch_broadcast_1_vs_n() {
    // A batch dim of 1 broadcasts against a batch dim of 2; each grad
    // must come back with its input's (unbroadcast) shape.
    let lhs_data: Vec<f32> = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let rhs_data: Vec<f32> = (0..12).map(|i| (i as f32) + 1.0).collect();
    let lhs = leaf(&lhs_data, &[1, 2, 3]);
    let rhs = leaf(&rhs_data, &[2, 3, 2]);
    let prod = matmul_differentiable(&lhs, &rhs).unwrap();
    assert_eq!(prod.shape(), &[2, 2, 2]);
    let loss = crate::grad_fns::reduction::sum(&prod).unwrap();
    loss.backward().unwrap();
    assert_eq!(lhs.grad().unwrap().unwrap().shape(), &[1, 2, 3]);
    assert_eq!(rhs.grad().unwrap().unwrap().shape(), &[2, 3, 2]);
}
}
/// Batched matrix multiplication: `[batch, m, k] x [batch, k, n] -> [batch, m, n]`.
///
/// Validates that both operands are 3-D, live on the same device, and have
/// matching batch and inner dimensions. For `f32` tensors on CUDA with an
/// available backend, dispatches to the GPU `bmm_f32` kernel; otherwise
/// falls back to a naive CPU contraction.
///
/// # Errors
/// - `InvalidArgument` if either input is not 3-D.
/// - `DeviceMismatch` if the operands live on different devices.
/// - `ShapeMismatch` if batch or inner dimensions disagree.
pub fn bmm<T: Float>(a: &Tensor<T>, b: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if a.ndim() != 3 || b.ndim() != 3 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("bmm requires 3-D tensors, got {:?} and {:?}", a.shape(), b.shape()),
        });
    }
    if a.device() != b.device() {
        return Err(FerrotorchError::DeviceMismatch { expected: a.device(), got: b.device() });
    }
    let batch = a.shape()[0];
    let m = a.shape()[1];
    let k = a.shape()[2];
    let n = b.shape()[2];
    if b.shape()[0] != batch || b.shape()[1] != k {
        return Err(FerrotorchError::ShapeMismatch {
            message: format!("bmm: a is [{batch},{m},{k}], b is {:?}", b.shape()),
        });
    }
    let out_shape = vec![batch, m, n];
    // FIX: the GPU kernel operates on f32 buffers only, so gate the dispatch
    // on the element type as well (mirrors the is_f32 gate used by the mm
    // backward GPU path); previously any Float T was fed to bmm_f32.
    if a.is_cuda() && std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
        if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
            let handle = backend.bmm_f32(
                a.gpu_handle()?, b.gpu_handle()?,
                batch, m, k, n,
            )?;
            return Tensor::from_storage(TensorStorage::gpu(handle), out_shape, false);
        }
    }
    // CPU fallback: naive O(batch * m * n * k) triple loop per batch.
    let a_data = a.data()?;
    let b_data = b.data()?;
    let mut out = Vec::with_capacity(batch * m * n);
    for bi in 0..batch {
        let a_off = bi * m * k;
        let b_off = bi * k * n;
        for i in 0..m {
            for j in 0..n {
                let mut sum = <T as num_traits::Zero>::zero();
                for p in 0..k {
                    sum = sum + a_data[a_off + i * k + p] * b_data[b_off + p * n + j];
                }
                out.push(sum);
            }
        }
    }
    Tensor::from_storage(TensorStorage::cpu(out), out_shape, false)
}
/// Permutes a 4-D tensor from `[d0, d1, d2, d3]` to `[d0, d2, d1, d3]`
/// (axis order `(0, 2, 1, 3)`), returning a contiguous copy.
///
/// For `f32` tensors on CUDA with an available backend, dispatches to the
/// GPU `permute_0213_f32` kernel; otherwise performs an element-by-element
/// CPU copy.
///
/// # Errors
/// Returns `InvalidArgument` when `input` is not 4-D.
pub fn permute_0213<T: Float>(input: &Tensor<T>) -> FerrotorchResult<Tensor<T>> {
    if input.ndim() != 4 {
        return Err(FerrotorchError::InvalidArgument {
            message: format!("permute_0213 requires 4-D tensor, got {:?}", input.shape()),
        });
    }
    let d0 = input.shape()[0];
    let d1 = input.shape()[1];
    let d2 = input.shape()[2];
    let d3 = input.shape()[3];
    let out_shape = vec![d0, d2, d1, d3];
    // FIX: the GPU kernel only understands f32 buffers, so gate the dispatch
    // on the element type (same convention as the mm backward GPU path);
    // previously any Float T was fed to permute_0213_f32.
    if input.is_cuda() && std::any::TypeId::of::<T>() == std::any::TypeId::of::<f32>() {
        if let Some(backend) = crate::gpu_dispatch::gpu_backend() {
            let handle = backend.permute_0213_f32(
                input.gpu_handle()?, d0, d1, d2, d3,
            )?;
            return Tensor::from_storage(TensorStorage::gpu(handle), out_shape, false);
        }
    }
    let data = input.data()?;
    let total = d0 * d1 * d2 * d3;
    let mut out = vec![<T as num_traits::Zero>::zero(); total];
    for i0 in 0..d0 {
        for i1 in 0..d1 {
            for i2 in 0..d2 {
                for i3 in 0..d3 {
                    // Row-major index into the input vs. the output with
                    // axes 1 and 2 swapped.
                    let in_idx = ((i0 * d1 + i1) * d2 + i2) * d3 + i3;
                    let out_idx = ((i0 * d2 + i2) * d1 + i1) * d3 + i3;
                    out[out_idx] = data[in_idx];
                }
            }
        }
    }
    Tensor::from_storage(TensorStorage::cpu(out), out_shape, false)
}