//! Fully-connected (dense) linear layer: `y = x * Wᵀ + b`.

use std::collections::HashMap;

use axonml_autograd::Variable;
use axonml_tensor::Tensor;

use crate::init::{kaiming_uniform, zeros};
use crate::module::Module;
use crate::parameter::Parameter;
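
/// A fully-connected layer that applies an affine transform to the last
/// dimension of its input: `y = x * Wᵀ + b`.
///
/// The weight has shape `[out_features, in_features]` and is initialized
/// with Kaiming-uniform values; the optional bias has shape
/// `[out_features]` and starts at zero.
///
/// # Examples
///
/// A minimal usage sketch, mirroring the tests at the bottom of this file
/// (marked `ignore` because the public crate path is not shown here):
///
/// ```ignore
/// let layer = Linear::new(3, 2);
/// let x = Variable::new(Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(), false);
/// let y = layer.forward(&x);
/// assert_eq!(y.shape(), vec![1, 2]);
/// ```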
pub struct Linear {
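    /// Learnable weight matrix of shape `[out_features, in_features]`.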
pub weight: Parameter,
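    /// Optional learnable bias vector of shape `[out_features]`.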
pub bias: Option<Parameter>,
in_features: usize,
out_features: usize,
}

impl Linear {
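    /// Creates a `Linear` layer with a bias term; shorthand for
    /// `with_bias(in_features, out_features, true)`.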
pub fn new(in_features: usize, out_features: usize) -> Self {
Self::with_bias(in_features, out_features, true)
}
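
    /// Creates a `Linear` layer, optionally without a bias term.
    ///
    /// The weight is Kaiming-uniform initialized; the bias, when present,
    /// starts at zero. Both are registered as trainable parameters.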
pub fn with_bias(in_features: usize, out_features: usize, bias: bool) -> Self {
let weight_data = kaiming_uniform(out_features, in_features);
let weight = Parameter::named("weight", weight_data, true);
let bias_param = if bias {
let bias_data = zeros(&[out_features]);
Some(Parameter::named("bias", bias_data, true))
} else {
None
};
Self {
weight,
bias: bias_param,
in_features,
out_features,
}
}
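
    /// Builds a `Linear` layer from existing tensors, inferring the
    /// dimensions from the weight's `[out_features, in_features]` shape.
    /// The weight is expected to be 2-dimensional.
    ///
    /// A minimal sketch using `Tensor::from_vec` as in the tests below:
    ///
    /// ```ignore
    /// let w = Tensor::from_vec(vec![0.0_f32; 6], &[2, 3]).unwrap();
    /// let layer = Linear::from_weights(w, None);
    /// assert_eq!((layer.in_features(), layer.out_features()), (3, 2));
    /// ```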
pub fn from_weights(weight: Tensor<f32>, bias: Option<Tensor<f32>>) -> Self {
let out_features = weight.shape()[0];
let in_features = weight.shape()[1];
Self {
weight: Parameter::named("weight", weight, true),
bias: bias.map(|b| Parameter::named("bias", b, true)),
in_features,
out_features,
}
}
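
    /// Returns the input feature dimension.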
pub fn in_features(&self) -> usize {
self.in_features
}
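
    /// Returns the output feature dimension.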
pub fn out_features(&self) -> usize {
self.out_features
}
}

impl Module for Linear {
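    /// Applies `y = x * Wᵀ + b` along the last dimension of `input`.
    ///
    /// Inputs with more than two dimensions are flattened to
    /// `[batch, in_features]` for the matmul and reshaped back afterwards.
    /// A shape-handling sketch (an `ignore`d extension of the 2-D tests
    /// below to a 3-D input):
    ///
    /// ```ignore
    /// let layer = Linear::new(4, 2);
    /// let x = Variable::new(Tensor::from_vec(vec![1.0; 24], &[2, 3, 4]).unwrap(), false);
    /// assert_eq!(layer.forward(&x).shape(), vec![2, 3, 2]);
    /// ```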
fn forward(&self, input: &Variable) -> Variable {
        let input_shape = input.shape();
        // Every leading dimension is treated as batch; the last dimension
        // must equal `in_features`.
        let batch_dims: Vec<usize> = input_shape[..input_shape.len() - 1].to_vec();
        let total_batch: usize = batch_dims.iter().product();
        // Flatten higher-rank inputs to 2-D so a single matmul suffices.
        let input_2d = if input_shape.len() > 2 {
            input.reshape(&[total_batch, self.in_features])
        } else {
            input.clone()
        };
        // The weight is stored as [out_features, in_features], so transpose
        // it to compute `x * Wᵀ`.
        let weight_var = self.weight.variable();
        let weight_t = weight_var.transpose(0, 1);
        let mut output = input_2d.matmul(&weight_t);
        if let Some(ref bias) = self.bias {
            let bias_var = bias.variable();
            output = output.add_var(&bias_var);
        }
        // Restore the original batch dimensions if the input was flattened.
        if input_shape.len() > 2 {
            let mut output_shape: Vec<usize> = batch_dims;
            output_shape.push(self.out_features);
            output.reshape(&output_shape)
        } else {
            output
        }
}
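
    /// Returns the weight followed by the bias, when one is present.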
fn parameters(&self) -> Vec<Parameter> {
let mut params = vec![self.weight.clone()];
if let Some(ref bias) = self.bias {
params.push(bias.clone());
}
params
}
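
    /// Returns the parameters keyed by their local names
    /// (`"weight"` and, when present, `"bias"`).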
fn named_parameters(&self) -> HashMap<String, Parameter> {
let mut params = HashMap::new();
params.insert("weight".to_string(), self.weight.clone());
if let Some(ref bias) = self.bias {
params.insert("bias".to_string(), bias.clone());
}
params
}

    fn name(&self) -> &'static str {
"Linear"
}
}
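
// Parameter values are omitted from the Debug output; only the layer's
// dimensions and the presence of a bias are shown.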
impl std::fmt::Debug for Linear {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("Linear")
.field("in_features", &self.in_features)
.field("out_features", &self.out_features)
.field("bias", &self.bias.is_some())
.finish()
}
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_linear_creation() {
        let linear = Linear::new(10, 5);
        assert_eq!(linear.in_features(), 10);
        assert_eq!(linear.out_features(), 5);
        assert!(linear.bias.is_some());
    }

    #[test]
    fn test_linear_no_bias() {
        let linear = Linear::with_bias(10, 5, false);
        assert!(linear.bias.is_none());
    }

    #[test]
    fn test_linear_forward() {
        let linear = Linear::new(3, 2);
        let input = Variable::new(
            Tensor::from_vec(vec![1.0, 2.0, 3.0], &[1, 3]).unwrap(),
            false,
        );
        let output = linear.forward(&input);
        assert_eq!(output.shape(), vec![1, 2]);
    }

    #[test]
    fn test_linear_batch_forward() {
        let linear = Linear::new(4, 2);
        let input = Variable::new(Tensor::from_vec(vec![1.0; 12], &[3, 4]).unwrap(), false);
        let output = linear.forward(&input);
        assert_eq!(output.shape(), vec![3, 2]);
    }

    #[test]
    fn test_linear_parameters() {
        let linear = Linear::new(10, 5);
        let params = linear.parameters();
        assert_eq!(params.len(), 2);
        let linear_no_bias = Linear::with_bias(10, 5, false);
        let params_no_bias = linear_no_bias.parameters();
        assert_eq!(params_no_bias.len(), 1);
    }

    #[test]
    fn test_linear_num_parameters() {
        let linear = Linear::new(10, 5);
        // 10 * 5 weight entries plus 5 bias entries.
        assert_eq!(linear.num_parameters(), 55);
    }
}