use ndarray::{array, Array1, ArrayView1};
/// Gradients use the same representation as tensor values: a 1-D `f64` array.
type Gradient = Array1<f64>;
/// Interface for tensor nodes that can take part in a forward/backward pass.
pub trait TensorTrait {
/// Runs a forward computation over `inputs`, returning a scalar result.
/// NOTE(review): signature duplicates `ForwardBackward::forward` — consider unifying the two traits.
fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<f64>>) -> f64;
/// Propagates `grad_output` backwards through this node.
fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<f64>);
/// Returns a read-only view of the tensor's current value.
fn get_value(&self) -> ArrayView1<f64>;
/// Returns a clone of the tensor's gradient, or `None` if none has been set.
fn get_grad(&self) -> Option<Gradient>;
}
/// A 1-D tensor holding a value and, after a backward pass, its gradient.
#[derive(Debug, Clone)]
pub struct Tensor {
// Current value of the tensor.
pub value: Array1<f64>,
// Gradient w.r.t. this tensor; `None` until a backward pass fills it in.
pub grad: Option<Gradient>,
}
impl Tensor {
pub fn new(value: Array1<f64>) -> Tensor {
Tensor { value, grad: None }
}
}
impl TensorTrait for Tensor {
    /// Borrows a read-only view of the stored value.
    fn get_value(&self) -> ArrayView1<f64> {
        self.value.view()
    }

    /// Clones out the gradient, if one has been set.
    fn get_grad(&self) -> Option<Gradient> {
        self.grad.as_ref().cloned()
    }

    /// Stub: a bare tensor performs no computation; always yields `0.0`.
    fn forward(&self, _ctx: &mut Context, _inputs: Vec<ArrayView1<f64>>) -> f64 {
        0.0
    }

    /// Stub: a bare tensor has nothing to propagate.
    fn backward(&self, _ctx: &mut Context, _grad_output: ArrayView1<f64>) {}
}
/// An operation with a differentiable forward pass.
pub trait ForwardBackward {
/// Computes the operation's scalar output from `inputs`; `ctx` is available
/// for stashing anything `backward` will need.
fn forward(&self, ctx: &mut Context, inputs: Vec<ArrayView1<f64>>) -> f64;
/// Propagates `grad_output` backwards, writing gradients into the tensors
/// saved in `ctx`.
fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<f64>);
}
/// Dot-product operation: forward computes `inputs[0] · inputs[1]`.
struct Dot;
/// Summation operation: forward computes the element sum of `inputs[0]`.
struct Sum;
impl ForwardBackward for Dot {
fn forward(&self, _ctx: &mut Context, inputs: Vec<ArrayView1<f64>>) -> f64 {
let input = &inputs[0];
let weight = &inputs[1];
input.dot(weight)
}
fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<f64>) {
if ctx.saved_tensors.is_empty() {
println!("Warning: saved_tensors is empty. Unable to compute gradients.");
return;
}
let mut input = ctx.saved_tensors[0].clone();
let mut weight = ctx.saved_tensors[1].clone();
let grad_input = grad_output.dot(&input.get_value().t());
let grad_weight = input.get_value().t().dot(&grad_output);
input.grad = Some(array![grad_input]);
weight.grad = Some(array![grad_weight]);
ctx.save_for_backward(vec![Box::new(*input), Box::new(*weight)]);
}
}
impl ForwardBackward for Sum {
fn forward(&self, _ctx: &mut Context, inputs: Vec<ArrayView1<f64>>) -> f64 {
let input = &inputs[0];
input.sum()
}
fn backward(&self, ctx: &mut Context, grad_output: ArrayView1<f64>) {
let mut input = ctx.saved_tensors[0].clone();
input.grad = Some(Array1::from(grad_output.map(|x| x * 1.0)));
ctx.save_for_backward(vec![Box::new(*input)]);
}
}
/// Per-operation scratch space: holds the tensors an operation saved during
/// `forward` for later use in `backward`.
pub struct Context {
// Tensors stashed via `save_for_backward`, in insertion order.
pub saved_tensors: Vec<Box<Tensor>>,
}
impl Context {
pub fn new() -> Context {
Context {
saved_tensors: Vec::new(),
}
}
pub fn save_for_backward(&mut self, tensors: Vec<Box<Tensor>>) {
self.saved_tensors.extend(tensors);
}
}
impl Default for Context {
fn default() -> Self {
Self::new()
}
}