use crate::tensor::Tensor;
use num_traits::Float;
/// A differentiable operation that can participate in an autograd graph.
///
/// Implementors supply the forward computation and the rule for propagating
/// an upstream gradient back to the operation's inputs. The `Send + Sync`
/// supertraits allow a `Function` value to be shared across threads.
pub trait Function<T>: Send + Sync
where
    T: Float + Send + Sync + 'static,
{
    /// Computes this operation's output tensor from `inputs`.
    fn forward(&self, inputs: &[&Tensor<T>]) -> Tensor<T>;

    /// Propagates `grad_output` (the gradient w.r.t. this operation's output)
    /// back to the inputs.
    ///
    /// The implementations in this module return one entry per input, in the
    /// same order as `inputs`. Presumably `None` marks an input that receives
    /// no gradient. NOTE(review): no implementation here returns `None`, so
    /// confirm against the graph code.
    fn backward(&self, grad_output: &Tensor<T>, inputs: &[&Tensor<T>]) -> Vec<Option<Tensor<T>>>;
}
/// Element-wise addition: `inputs[0] + inputs[1]`.
#[derive(Debug)]
pub struct AddFunction;

impl<T> Function<T> for AddFunction
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Returns the sum of the first two input tensors.
    fn forward(&self, inputs: &[&Tensor<T>]) -> Tensor<T> {
        let (lhs, rhs) = (inputs[0], inputs[1]);
        lhs + rhs
    }

    /// Both partial derivatives of `a + b` are 1, so each input receives an
    /// unchanged copy of the upstream gradient.
    ///
    /// NOTE(review): if `Tensor`'s `+` broadcasts mismatched shapes, these
    /// gradients would need reducing back to each input's shape. Confirm.
    fn backward(&self, grad_output: &Tensor<T>, _inputs: &[&Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let pass_through = || Some(grad_output.clone());
        vec![pass_through(), pass_through()]
    }
}
/// Element-wise subtraction: `inputs[0] - inputs[1]`.
#[derive(Debug)]
pub struct SubFunction;

impl<T> Function<T> for SubFunction
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Returns the first input minus the second.
    fn forward(&self, inputs: &[&Tensor<T>]) -> Tensor<T> {
        let (minuend, subtrahend) = (inputs[0], inputs[1]);
        minuend - subtrahend
    }

    /// `d(a - b)/da = 1` and `d(a - b)/db = -1`. The minuend gets the
    /// upstream gradient as-is and the subtrahend gets its negation.
    fn backward(&self, grad_output: &Tensor<T>, _inputs: &[&Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let grad_minuend = grad_output.clone();
        let grad_subtrahend = -grad_output;
        vec![Some(grad_minuend), Some(grad_subtrahend)]
    }
}
/// Element-wise multiplication: `inputs[0] * inputs[1]`.
#[derive(Debug)]
pub struct MulFunction;

impl<T> Function<T> for MulFunction
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Returns the element-wise product of the first two inputs.
    fn forward(&self, inputs: &[&Tensor<T>]) -> Tensor<T> {
        let (lhs, rhs) = (inputs[0], inputs[1]);
        lhs * rhs
    }

    /// Product rule: `d(a*b)/da = b` and `d(a*b)/db = a`. Each input's
    /// gradient is the upstream gradient scaled by the *other* input.
    fn backward(&self, grad_output: &Tensor<T>, inputs: &[&Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        vec![
            Some(grad_output * inputs[1]),
            Some(grad_output * inputs[0]),
        ]
    }
}
/// Matrix multiplication: `inputs[0].matmul(inputs[1])`.
///
/// # Panics
/// Panics if `Tensor::matmul` or `Tensor::transpose` returns an error.
#[derive(Debug)]
pub struct MatMulFunction;

impl<T> Function<T> for MatMulFunction
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Returns `inputs[0] · inputs[1]`.
    fn forward(&self, inputs: &[&Tensor<T>]) -> Tensor<T> {
        let (lhs, rhs) = (inputs[0], inputs[1]);
        lhs.matmul(rhs).expect("Matrix multiplication failed")
    }

    /// For `C = A · B`, computes `dL/dA = dL/dC · Bᵀ` and
    /// `dL/dB = Aᵀ · dL/dC`.
    fn backward(&self, grad_output: &Tensor<T>, inputs: &[&Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let rhs_t = inputs[1].transpose().expect("Transpose failed");
        let grad_lhs = grad_output.matmul(&rhs_t).expect("MatMul failed");

        let lhs_t = inputs[0].transpose().expect("Transpose failed");
        let grad_rhs = lhs_t.matmul(grad_output).expect("MatMul failed");

        vec![Some(grad_lhs), Some(grad_rhs)]
    }
}
/// Full reduction: sums every element of `inputs[0]` into a one-element
/// tensor of shape `[1]`.
#[derive(Debug)]
pub struct SumFunction;

impl<T> Function<T> for SumFunction
where
    T: Float + Send + Sync + 'static + ndarray::ScalarOperand + num_traits::FromPrimitive,
{
    /// Returns `[sum(inputs[0])]` with shape `[1]`.
    fn forward(&self, inputs: &[&Tensor<T>]) -> Tensor<T> {
        let total = inputs[0].sum();
        Tensor::from_vec(vec![total], vec![1])
    }

    /// Every input element contributes to the sum with coefficient 1.
    /// The input gradient is therefore a tensor shaped like `inputs[0]`,
    /// filled with the upstream scalar.
    ///
    /// Only the first element of `grad_output` is read. If `grad_output` is
    /// empty, the gradient is filled with zero.
    fn backward(&self, grad_output: &Tensor<T>, inputs: &[&Tensor<T>]) -> Vec<Option<Tensor<T>>> {
        let mut grad_input = Tensor::ones(inputs[0].shape());
        let upstream = grad_output
            .as_array()
            .iter()
            .next()
            .copied()
            .unwrap_or(T::zero());
        // Overwrite every element with the upstream scalar.
        grad_input.as_array_mut().fill(upstream);
        vec![Some(grad_input)]
    }
}