use std::collections::HashMap;
use crate::{Module, Parameters};
use rand_distr::{Distribution, StandardNormal};
use zenu_autograd::{
creator::{rand::normal, zeros::zeros},
functions::{matmul::matmul, transpose::transpose},
Variable,
};
use zenu_matrix::{device::Device, num::Num};
/// A fully connected (affine) layer computing `y = x · Wᵀ (+ b)`.
///
/// `weight` has shape `[out_features, in_features]`; `bias`, when present,
/// has shape `[out_features]`.
pub struct Linear<T: Num, D: Device> {
// Input dimensionality; kept so the layer can be rebuilt on another device in `to()`.
in_features: usize,
// Output dimensionality; kept for the same reason as `in_features`.
out_features: usize,
/// Trainable weight matrix, registered under the name "linear.weight".
pub weight: Variable<T, D>,
/// Optional trainable bias vector; `None` when constructed with `use_bias = false`.
pub bias: Option<Variable<T, D>>,
}
impl<T: Num, D: Device> Module<T, D> for Linear<T, D> {
    type Input = Variable<T, D>;
    type Output = Variable<T, D>;

    /// Forward pass: projects `input` by the transposed weight and, when a
    /// bias is configured, adds it to the result.
    fn call(&self, input: Variable<T, D>) -> Variable<T, D> {
        // y = x · Wᵀ
        let projected = matmul(input, transpose(self.weight.clone()));
        match &self.bias {
            Some(bias) => {
                // Tag the pre-bias activation so it is identifiable in the graph.
                projected.set_name("linear.intermediate_output");
                projected + bias.clone()
            }
            None => projected,
        }
    }
}
impl<T: Num, D: Device> Parameters<T, D> for Linear<T, D> {
    /// Returns the single trainable weight, keyed by its registered name.
    fn weights(&self) -> HashMap<String, Variable<T, D>> {
        HashMap::from([("linear.weight".to_string(), self.weight.clone())])
    }

    /// Returns the bias keyed by its registered name, or an empty map when
    /// the layer was built without a bias.
    fn biases(&self) -> HashMap<String, Variable<T, D>> {
        self.bias
            .iter()
            .map(|b| ("linear.bias".to_string(), b.clone()))
            .collect()
    }
}
impl<T: Num, D: Device> Linear<T, D> {
    /// Creates a linear layer with weights drawn from `N(0, 1)` and rescaled
    /// by `1 / sqrt(in_features)`; the bias (when requested) starts at zero.
    ///
    /// Both parameters are marked trainable and given graph names
    /// ("linear.weight" / "linear.bias").
    #[must_use]
    pub fn new(in_features: usize, out_features: usize, use_bias: bool) -> Self
    where
        StandardNormal: Distribution<T>,
    {
        // Sample W ~ N(0, 1) with shape [out_features, in_features], then
        // scale the whole matrix down by sqrt(in_features).
        let weight = normal(T::zero(), T::one(), None, [out_features, in_features]);
        let scale = T::from_usize(in_features).sqrt();
        weight.get_data_mut().to_ref_mut().div_scalar_assign(scale);
        weight.set_is_train(true);
        weight.set_name("linear.weight");

        // Zero-initialized bias, only when requested.
        let bias = use_bias.then(|| {
            let b = zeros([out_features]);
            b.set_name("linear.bias");
            b.set_is_train(true);
            b
        });

        Self {
            in_features,
            out_features,
            weight,
            bias,
        }
    }

    /// Transfers every parameter of this layer to device `Dout`, consuming
    /// the original layer.
    #[must_use]
    pub fn to<Dout: Device>(self) -> Linear<T, Dout> {
        Linear {
            in_features: self.in_features,
            out_features: self.out_features,
            weight: self.weight.to(),
            bias: self.bias.map(|b| b.to()),
        }
    }
}