use super::ops::{matmul, relu};
use super::tensor::Tensor;
use core::ops::{Add, Mul};
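/// A fully connected layer computing `output = weights * input + bias`.
///
/// `weights` has shape `[out_features, in_features]` and `bias` has shape
/// `[out_features]`; the bias is broadcast across the columns of the output.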
pub struct Linear<T> {
    pub weights: Tensor<T>,
    pub bias: Tensor<T>,
}
impl<T> Linear<T>
where
    T: Add<Output = T> + Mul<Output = T> + Clone,
{
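    /// Builds the layer, validating that `weights` is a matrix, `bias` is a
    /// vector, and that their output dimensions agree.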
    pub fn new(weights: Tensor<T>, bias: Tensor<T>) -> Self {
        assert!(weights.is_matrix(), "Weights must be a matrix");
        assert_eq!(bias.ndim(), 1, "Bias must be a vector");
        assert_eq!(
            weights.shape[0], bias.shape[0],
            "Bias dimension must match output dimension"
        );
        Linear { weights, bias }
    }
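    /// Applies the layer: `weights * input`, with `bias` added to every
    /// column of the result. `zero` is the additive identity threaded
    /// through to `matmul`.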
    pub fn forward(&self, input: &Tensor<T>, zero: T) -> Tensor<T> {
        let mut output = matmul(&self.weights, input, zero);
        // Broadcast the bias across output columns. Assumes row-major
        // storage: element `i` lies in row `i / cols`, so every row gets its
        // matching bias entry even when `input` carries a batch of columns.
        let cols = if output.shape.len() > 1 { output.shape[1] } else { 1 };
        for (i, out_val) in output.data.iter_mut().enumerate() {
            let bias_idx = i / cols;
            *out_val = out_val.clone() + self.bias.data[bias_idx].clone();
        }
        output
    }
}
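/// A two-layer perceptron: `layer1`, then ReLU, then `layer2`.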
pub struct SimpleNet<T> {
    pub layer1: Linear<T>,
    pub layer2: Linear<T>,
}
impl<T> SimpleNet<T>
where
    T: Add<Output = T> + Mul<Output = T> + Clone + PartialOrd,
{
    pub fn new(layer1: Linear<T>, layer2: Linear<T>) -> Self {
        SimpleNet { layer1, layer2 }
    }
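    /// Runs the full forward pass: `layer2(relu(layer1(input)))`. `zero` is
    /// the additive identity, also used by `relu` as its comparison point.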
    pub fn forward(&self, input: &Tensor<T>, zero: T) -> Tensor<T> {
        let h = self.layer1.forward(input, zero.clone());
        let h = relu(&h, zero.clone());
        self.layer2.forward(&h, zero)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::ScalarF4E4;
    #[test]
    fn test_linear_layer() {
        let weights = Tensor::new(
            vec![
                ScalarF4E4::from(1.0),
                ScalarF4E4::from(2.0),
                ScalarF4E4::from(3.0),
                ScalarF4E4::from(4.0),
                ScalarF4E4::from(5.0),
                ScalarF4E4::from(6.0),
            ],
            vec![2, 3],
        );
        let bias = Tensor::new(vec![ScalarF4E4::from(0.5), ScalarF4E4::from(1.0)], vec![2]);
        let layer = Linear::new(weights, bias);
        let input = Tensor::new(
            vec![
                ScalarF4E4::from(1.0),
                ScalarF4E4::from(2.0),
                ScalarF4E4::from(3.0),
            ],
            vec![3, 1],
        );
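        // Row-major weights give rows [1, 2, 3] and [4, 5, 6], so the
        // expected output is [1*1 + 2*2 + 3*3 + 0.5, 4*1 + 5*2 + 6*3 + 1.0]
        // = [14.5, 33.0].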
        let output = layer.forward(&input, ScalarF4E4::ZERO);
        assert!((output.data[0].to_f64() - 14.5).abs() < 0.1);
        assert!((output.data[1].to_f64() - 33.0).abs() < 0.1);
    }
    #[test]
    fn test_simple_net() {
        let w1 = Tensor::new(
            vec![
                ScalarF4E4::from(0.1),
                ScalarF4E4::from(0.2),
                ScalarF4E4::from(0.3),
                ScalarF4E4::from(0.4),
                ScalarF4E4::from(0.5),
                ScalarF4E4::from(0.6),
            ],
            vec![3, 2],
        );
        let b1 = Tensor::new(
            vec![
                ScalarF4E4::from(0.0),
                ScalarF4E4::from(0.0),
                ScalarF4E4::from(0.0),
            ],
            vec![3],
        );
        let w2 = Tensor::new(
            vec![
                ScalarF4E4::from(0.1),
                ScalarF4E4::from(0.2),
                ScalarF4E4::from(0.3),
                ScalarF4E4::from(0.4),
                ScalarF4E4::from(0.5),
                ScalarF4E4::from(0.6),
            ],
            vec![2, 3],
        );
        let b2 = Tensor::new(vec![ScalarF4E4::from(0.0), ScalarF4E4::from(0.0)], vec![2]);
        let layer1 = Linear::new(w1, b1);
        let layer2 = Linear::new(w2, b2);
        let net = SimpleNet::new(layer1, layer2);
        let input = Tensor::new(
            vec![ScalarF4E4::from(1.0), ScalarF4E4::from(2.0)],
            vec![2, 1],
        );
        let output = net.forward(&input, ScalarF4E4::ZERO);
        assert_eq!(output.shape[0], 2);
    }
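    // A hedged sketch of a batched check: it assumes `Tensor` stores data
    // row-major and that `matmul` accepts an `[in, batch]` right-hand side,
    // neither of which is guaranteed by this module alone.
    #[test]
    fn test_linear_layer_batched_bias() {
        // Identity weights isolate the bias broadcast.
        let weights = Tensor::new(
            vec![
                ScalarF4E4::from(1.0),
                ScalarF4E4::from(0.0),
                ScalarF4E4::from(0.0),
                ScalarF4E4::from(1.0),
            ],
            vec![2, 2],
        );
        let bias = Tensor::new(vec![ScalarF4E4::from(1.0), ScalarF4E4::from(2.0)], vec![2]);
        let layer = Linear::new(weights, bias);
        // Two column vectors [1, 2] and [3, 4], stored row-major.
        let input = Tensor::new(
            vec![
                ScalarF4E4::from(1.0),
                ScalarF4E4::from(3.0),
                ScalarF4E4::from(2.0),
                ScalarF4E4::from(4.0),
            ],
            vec![2, 2],
        );
        let output = layer.forward(&input, ScalarF4E4::ZERO);
        // Expected rows: [1, 3] + 1 = [2, 4] and [2, 4] + 2 = [4, 6].
        let expected = [2.0, 4.0, 4.0, 6.0];
        for (got, want) in output.data.iter().zip(expected) {
            assert!((got.to_f64() - want).abs() < 0.1);
        }
    }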
}