use crate::nn::init::Initializer;
use crate::nn::module::Module;
use crate::tensor::{GraphContext, Tensor};
use std::cell::RefCell;
use std::rc::Rc;
/// Fully connected layer holding a weight tensor and an optional bias tensor.
pub struct Linear {
/// Weight parameter; the constructors in this file create it with shape
/// `[in_features, out_features]`.
pub weights: Tensor,
/// Bias parameter, or `None` when the layer was built without a bias
/// initializer. The constructors in this file create it with shape
/// `[1, out_features]`.
pub bias: Option<Tensor>,
/// Input feature count as passed to the constructor (first weight dimension).
pub in_features: usize,
/// Output feature count as passed to the constructor (second weight dimension).
pub out_features: usize,
}
impl Linear {
pub fn new(
context: &Rc<RefCell<GraphContext>>,
name: &str,
in_features: usize,
out_features: usize,
) -> Self {
Self::with_initializers(
context,
name,
in_features,
out_features,
Initializer::XavierUniform,
Some(Initializer::Zeros),
)
}
pub fn without_bias(
context: &Rc<RefCell<GraphContext>>,
name: &str,
in_features: usize,
out_features: usize,
) -> Self {
Self::with_initializers(
context,
name,
in_features,
out_features,
Initializer::XavierUniform,
None,
)
}
pub fn with_initializers(
context: &Rc<RefCell<GraphContext>>,
name: &str,
in_features: usize,
out_features: usize,
weight_init: Initializer,
bias_init: Option<Initializer>,
) -> Self {
let weights_name = format!("{}.weights", name);
let weights = Tensor::new_parameter_with_shape(
context,
&weights_name,
vec![in_features, out_features],
weight_init,
);
let bias = bias_init.map(|init| {
let bias_name = format!("{}.bias", name);
Tensor::new_parameter_with_shape(context, &bias_name, vec![1, out_features], init)
});
Self {
weights,
bias,
in_features,
out_features,
}
}
}
impl Module for Linear {
    /// Applies the layer: `inputs.dot(weights)`, plus the bias when one exists.
    ///
    /// NOTE(review): the bias is created with shape `[1, out_features]`, so the
    /// addition presumably broadcasts over the rows of the product — confirm
    /// against the `Tensor` addition implementation.
    fn forward(&self, inputs: &Tensor) -> Tensor {
        let projected = inputs.dot(&self.weights);
        if let Some(offset) = &self.bias {
            &projected + offset
        } else {
            projected
        }
    }

    /// Returns clones of the layer's parameter tensors: the weights first,
    /// followed by the bias if the layer has one.
    fn parameters(&self) -> Vec<Tensor> {
        std::iter::once(&self.weights)
            .chain(self.bias.as_ref())
            .cloned()
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a fresh, empty graph context for a single test.
    fn fresh_context() -> Rc<RefCell<GraphContext>> {
        Rc::new(RefCell::new(GraphContext::new()))
    }

    /// A default layer registers both parameters under the expected names
    /// and with the expected shapes.
    #[test]
    fn linear_registers_shapes() {
        let ctx = fresh_context();
        let _layer = Linear::new(&ctx, "fc1", 784, 128);

        let graph = ctx.borrow();

        let weights_meta = graph.parameter_meta("fc1.weights").unwrap();
        assert_eq!(weights_meta.shape, vec![784, 128]);

        let bias_meta = graph.parameter_meta("fc1.bias").unwrap();
        assert_eq!(bias_meta.shape, vec![1, 128]);
    }

    /// A bias-free layer exposes only its weights and registers no bias
    /// parameter in the context.
    #[test]
    fn linear_without_bias_has_no_bias_param() {
        let ctx = fresh_context();
        let layer = Linear::without_bias(&ctx, "fc2", 32, 16);

        assert!(layer.bias.is_none());
        assert_eq!(layer.parameters().len(), 1);
        assert!(ctx.borrow().parameter_meta("fc2.bias").is_none());
    }
}