use crate::errors::NeuroxResult;
use crate::{activations, ops, tensor::Tensor};
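/// A fully connected layer computing `activation(x·w + b)`.
///
/// `forward` caches its input and pre-activation so that `backward`
/// can compute parameter gradients and the gradient for the previous
/// layer without recomputation.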
pub struct Dense {
    /// Weight matrix, shape `(in_features, out_features)`.
    pub w: Tensor,
    /// Bias row vector, shape `(1, out_features)`.
    pub b: Tensor,
    pub activation: Activation,
    /// Input from the last forward pass, kept for backprop.
    input_cache: Option<Tensor>,
    /// Pre-activation `x·w + b` from the last forward pass.
    preact_cache: Option<Tensor>,
    /// Gradient of the loss w.r.t. `w`, set by `backward`.
    pub grad_w: Option<Tensor>,
    /// Gradient of the loss w.r.t. `b`, set by `backward`.
    pub grad_b: Option<Tensor>,
}
#[derive(Clone, Copy, Debug)]
pub enum Activation {
    ReLU,
    Sigmoid,
    Tanh,
    /// Identity: the pre-activation passes through unchanged.
    None,
}
impl Dense {
    /// Creates a layer with randomly initialized weights and biases.
    pub fn new(in_features: usize, out_features: usize, activation: Activation) -> Self {
        Dense {
            w: Tensor::random(in_features, out_features),
            b: Tensor::random(1, out_features),
            activation,
            input_cache: None,
            preact_cache: None,
            grad_w: None,
            grad_b: None,
        }
    }
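    /// Computes `activation(input·w + b)`, caching the input and the
    /// pre-activation so that `backward` can reuse them.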
    pub fn forward(&mut self, input: &Tensor) -> NeuroxResult<Tensor> {
        self.input_cache = Some(input.clone());
        // z = input·w + b, with the bias row broadcast across the batch.
        let z = ops::matmul(input, &self.w)?;
        let z = z.add_row_broadcast(&self.b)?;
        self.preact_cache = Some(z.clone());
        let out = match self.activation {
            Activation::ReLU => activations::relu(&z),
            Activation::Sigmoid => activations::sigmoid(&z),
            Activation::Tanh => activations::tanh(&z),
            Activation::None => z,
        };
        Ok(out)
    }
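    /// Backpropagates `grad_out` (dL/d_out) through the layer.
    ///
    /// With cached input `x` and pre-activation `z = x·w + b`:
    /// `dz = grad_out ⊙ activation'(z)`, `grad_w = xᵀ·dz`,
    /// `grad_b` is the column-wise sum of `dz`, and the returned
    /// gradient for the previous layer is `dL/dx = dz·wᵀ`.
    ///
    /// Panics if called before `forward`.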
    pub fn backward(&mut self, grad_out: &Tensor) -> NeuroxResult<Tensor> {
        let pre = self
            .preact_cache
            .as_ref()
            .expect("forward must be called before backward");
        // dz = grad_out ⊙ activation'(z)
        let dz = match self.activation {
            Activation::ReLU => {
                let g = activations::relu_grad(pre);
                ops::mul_elementwise(grad_out, &g)?
            }
            Activation::Sigmoid => {
                // sigmoid'(z) = s(z)·(1 - s(z)); recompute s(z) from the cached pre-activation.
                let out = activations::sigmoid(pre);
                let g = activations::sigmoid_grad_from_out(&out);
                ops::mul_elementwise(grad_out, &g)?
            }
            Activation::Tanh => {
                // tanh'(z) = 1 - tanh(z)², likewise derived from the activation output.
                let out = activations::tanh(pre);
                let g = activations::tanh_grad_from_out(&out);
                ops::mul_elementwise(grad_out, &g)?
            }
            Activation::None => grad_out.clone(),
        };
        let input = self
            .input_cache
            .as_ref()
            .expect("forward must be called before backward");
        // grad_w = xᵀ·dz
        let gw = ops::matmul(&input.transpose(), &dz)?;
        // grad_b = sum of dz over the batch dimension (rows).
        let mut gb = Tensor::zeros(1, dz.cols);
        for j in 0..dz.cols {
            let mut s = 0.0;
            for i in 0..dz.rows {
                s += dz.get(i, j);
            }
            gb.set(0, j, s);
        }
        // dL/dx = dz·wᵀ, the gradient flowing to the previous layer.
        let grad_input = ops::matmul(&dz, &self.w.transpose())?;
        self.grad_w = Some(gw);
        self.grad_b = Some(gb);
        Ok(grad_input)
    }
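    /// Applies one plain SGD step, `p -= lr * grad`, to `w` and `b`.
    /// Does nothing for a parameter whose gradient is not yet set.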
    pub fn apply_gradients(&mut self, lr: f32) {
        if let Some(gw) = &self.grad_w {
            for (w, g) in self.w.data.iter_mut().zip(&gw.data) {
                *w -= lr * g;
            }
        }
        if let Some(gb) = &self.grad_b {
            for (b, g) in self.b.data.iter_mut().zip(&gb.data) {
                *b -= lr * g;
            }
        }
    }
    /// Total number of trainable parameters in the layer.
    pub fn num_params(&self) -> usize {
        self.w.data.len() + self.b.data.len()
    }
}
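
// A minimal smoke-test sketch: it exercises the shapes produced by
// forward/backward and one SGD step. It assumes only the `Tensor` API
// already used above (`random`, `zeros`, `set`, and public `rows`/`cols`
// fields); adjust if the crate's actual API differs.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn forward_backward_shapes() {
        let mut layer = Dense::new(3, 2, Activation::ReLU);
        // Batch of 4 samples, 3 features each.
        let x = Tensor::random(4, 3);

        let out = layer.forward(&x).expect("forward failed");
        assert_eq!((out.rows, out.cols), (4, 2));

        // Pretend dL/d_out is all ones; built with `zeros` + `set`,
        // the only constructors used elsewhere in this file.
        let mut grad_out = Tensor::zeros(4, 2);
        for i in 0..grad_out.rows {
            for j in 0..grad_out.cols {
                grad_out.set(i, j, 1.0);
            }
        }

        let grad_in = layer.backward(&grad_out).expect("backward failed");
        assert_eq!((grad_in.rows, grad_in.cols), (4, 3));

        // Gradients are populated, so an SGD step can be applied.
        assert!(layer.grad_w.is_some() && layer.grad_b.is_some());
        layer.apply_gradients(0.01);
        assert_eq!(layer.num_params(), 3 * 2 + 2);
    }
}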