use alloc::vec::Vec;
use core::ops::{Add, Mul, Sub};

use super::ops::{matmul, transpose};
use super::tensor::Tensor;
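
/// Gradients produced by the backward pass of a linear (fully connected)
/// layer: the gradient of the loss with respect to the weights, the bias,
/// and the layer input.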
pub struct LinearGradients<T> {
    pub weight_grad: Tensor<T>,
    pub bias_grad: Tensor<T>,
    pub input_grad: Tensor<T>,
}
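
/// Backward pass for a linear layer `y = W x + b`.
///
/// Given the upstream gradient `output_grad = dL/dy`, this computes:
///
/// * `weight_grad = dL/dy · xᵀ`
/// * `bias_grad = dL/dy` (the bias enters additively)
/// * `input_grad = Wᵀ · dL/dy`
///
/// `zero` is the additive identity of `T`, forwarded to `matmul`.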
pub fn linear_backward<T>(
    output_grad: &Tensor<T>,
    input: &Tensor<T>,
    weights: &Tensor<T>,
    zero: T,
) -> LinearGradients<T>
where
    T: Add<Output = T> + Mul<Output = T> + Clone,
{
    // dL/dW = dL/dy · xᵀ
    let input_t = transpose(input);
    let weight_grad = matmul(output_grad, &input_t, zero.clone());
    // dL/db = dL/dy, since the bias enters additively.
    let bias_grad = output_grad.clone();
    // dL/dx = Wᵀ · dL/dy
    let weights_t = transpose(weights);
    let input_grad = matmul(&weights_t, output_grad, zero);
    LinearGradients {
        weight_grad,
        bias_grad,
        input_grad,
    }
}
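
/// Backward pass for ReLU: the upstream gradient passes through wherever
/// the forward input was strictly greater than `zero` and is replaced by
/// `zero` elsewhere (the subgradient at 0 is taken to be 0).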
pub fn relu_backward<T>(output_grad: &Tensor<T>, input: &Tensor<T>, zero: T) -> Tensor<T>
where
    T: Clone + PartialOrd,
{
    assert_eq!(
        output_grad.shape, input.shape,
        "Shape mismatch in ReLU backward"
    );
    let grad_data: Vec<T> = output_grad
        .data
        .iter()
        .zip(input.data.iter())
        // Pass the gradient through where the input was positive; zero it otherwise.
        .map(|(g, x)| if *x > zero { g.clone() } else { zero.clone() })
        .collect();
    Tensor::new(grad_data, output_grad.shape.clone())
}
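
/// Squared-error loss between `predicted` and `target`.
///
/// With only `Add`, `Mul`, and `Sub` available on `T` there is no way to
/// divide by the element count, so this returns the *sum* of squared
/// errors rather than the mean; callers needing a true MSE must scale by
/// `1 / n` themselves.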
pub fn mse_loss<T>(predicted: &Tensor<T>, target: &Tensor<T>) -> T
where
    T: Sub<Output = T> + Mul<Output = T> + Add<Output = T> + Clone,
{
    assert_eq!(predicted.shape, target.shape, "Shape mismatch in MSE loss");
    assert!(!target.data.is_empty(), "MSE loss of an empty tensor");
    // Build a zero of type `T` as `t - t`, since this function takes no
    // `zero` parameter; this is why the tensor must be non-empty.
    let mut sum = target.data[0].clone() - target.data[0].clone();
    for (p, t) in predicted.data.iter().zip(target.data.iter()) {
        let diff = p.clone() - t.clone();
        sum = sum + (diff.clone() * diff);
    }
    sum
}
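
/// Gradient of the squared-error loss with respect to `predicted`,
/// element-wise `p - t`. This is the exact gradient of `0.5 *` the value
/// returned by `mse_loss`; the conventional factor of 2 is dropped and
/// can be folded into the learning rate.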
pub fn mse_loss_grad<T>(predicted: &Tensor<T>, target: &Tensor<T>) -> Tensor<T>
where
    T: Sub<Output = T> + Clone,
{
    assert_eq!(predicted.shape, target.shape, "Shape mismatch in MSE grad");
    let grad_data: Vec<T> = predicted
        .data
        .iter()
        .zip(target.data.iter())
        .map(|(p, t)| p.clone() - t.clone())
        .collect();
    Tensor::new(grad_data, predicted.shape.clone())
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScalarF4E4;
    // In a `no_std` crate the `vec!` macro comes from `alloc`.
    use alloc::vec;

    #[test]
    fn test_linear_backward() {
        let weights = Tensor::new(
            vec![
                ScalarF4E4::from(1.0),
                ScalarF4E4::from(2.0),
                ScalarF4E4::from(3.0),
                ScalarF4E4::from(4.0),
            ],
            vec![2, 2],
        );
        let input = Tensor::new(
            vec![ScalarF4E4::from(1.0), ScalarF4E4::from(2.0)],
            vec![2, 1],
        );
        let output_grad = Tensor::new(
            vec![ScalarF4E4::from(1.0), ScalarF4E4::from(1.0)],
            vec![2, 1],
        );
        let grads = linear_backward(&output_grad, &input, &weights, ScalarF4E4::ZERO);
        assert_eq!(grads.weight_grad.shape, weights.shape);
        assert_eq!(grads.input_grad.shape, input.shape);
    }

    #[test]
    fn test_relu_backward() {
        let input = Tensor::new(
            vec![
                ScalarF4E4::from(-1.0),
                ScalarF4E4::from(2.0),
                ScalarF4E4::from(-3.0),
                ScalarF4E4::from(4.0),
            ],
            vec![4],
        );
        let output_grad = Tensor::new(
            vec![
                ScalarF4E4::from(1.0),
                ScalarF4E4::from(1.0),
                ScalarF4E4::from(1.0),
                ScalarF4E4::from(1.0),
            ],
            vec![4],
        );
        let input_grad = relu_backward(&output_grad, &input, ScalarF4E4::ZERO);
        // Gradient is zeroed where the input was negative, passed through otherwise.
        assert_eq!(input_grad.data[0].to_f64(), 0.0);
        assert_eq!(input_grad.data[1].to_f64(), 1.0);
        assert_eq!(input_grad.data[2].to_f64(), 0.0);
        assert_eq!(input_grad.data[3].to_f64(), 1.0);
    }

    #[test]
    fn test_mse_loss_grad() {
        let predicted = Tensor::new(vec![ScalarF4E4::from(2.0), ScalarF4E4::from(3.0)], vec![2]);
        let target = Tensor::new(vec![ScalarF4E4::from(1.0), ScalarF4E4::from(2.0)], vec![2]);
        let grad = mse_loss_grad(&predicted, &target);
        assert_eq!(grad.data[0].to_f64(), 1.0);
        assert_eq!(grad.data[1].to_f64(), 1.0);
    }
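
    // A matching test for `mse_loss`, sketched under the same assumptions as
    // the tests above (`ScalarF4E4::from` / `to_f64` round-trip small integer
    // values exactly): (2-1)^2 + (3-2)^2 = 2.
    #[test]
    fn test_mse_loss() {
        let predicted = Tensor::new(vec![ScalarF4E4::from(2.0), ScalarF4E4::from(3.0)], vec![2]);
        let target = Tensor::new(vec![ScalarF4E4::from(1.0), ScalarF4E4::from(2.0)], vec![2]);
        let loss = mse_loss(&predicted, &target);
        // `mse_loss` returns the sum of squared errors, not the mean.
        assert_eq!(loss.to_f64(), 2.0);
    }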
}