// tch_plus/nn/linear.rs

//! A linear fully-connected layer.
use crate::Tensor;
use std::borrow::Borrow;

/// Configuration for a linear layer.
#[derive(Debug, Clone, Copy)]
pub struct LinearConfig {
    /// Initialization scheme for the weight matrix.
    pub ws_init: super::Init,
    /// Optional initialization scheme for the bias vector; when `None` and
    /// `bias` is true, a uniform init bounded by `1/sqrt(in_dim)` is used.
    pub bs_init: Option<super::Init>,
    /// Whether the layer includes a bias term.
    pub bias: bool,
}

13impl Default for LinearConfig {
14    fn default() -> Self {
15        LinearConfig { ws_init: super::init::DEFAULT_KAIMING_UNIFORM, bs_init: None, bias: true }
16    }
17}
18
/// A linear fully-connected layer.
///
/// Holds the learned parameters; the forward pass computes
/// `xs * ws^T + bs` (see the `Module` impl).
#[derive(Debug)]
pub struct Linear {
    /// Weight matrix of shape `[out_dim, in_dim]`.
    pub ws: Tensor,
    /// Optional bias vector of shape `[out_dim]`.
    pub bs: Option<Tensor>,
}

26/// Creates a new linear layer.
27pub fn linear<'a, T: Borrow<super::Path<'a>>>(
28    vs: T,
29    in_dim: i64,
30    out_dim: i64,
31    c: LinearConfig,
32) -> Linear {
33    let vs = vs.borrow();
34    let bs = if c.bias {
35        let bs_init = c.bs_init.unwrap_or_else(|| {
36            let bound = 1.0 / (in_dim as f64).sqrt();
37            super::Init::Uniform { lo: -bound, up: bound }
38        });
39        Some(vs.var("bias", &[out_dim], bs_init))
40    } else {
41        None
42    };
43
44    Linear { ws: vs.var("weight", &[out_dim, in_dim], c.ws_init), bs }
45}
46
47impl super::module::Module for Linear {
48    fn forward(&self, xs: &Tensor) -> Tensor {
49        xs.linear(&self.ws, self.bs.as_ref())
50    }
51}
52
#[test]
fn matches_pytorch() {
    use crate::nn::Module;

    // Fixtures generated by a PyTorch reference implementation.
    let input = Tensor::read_npy("tests/linear/in.npy").unwrap();
    let expected_output = Tensor::read_npy("tests/linear/out.npy").unwrap();
    let ws = Tensor::read_npy("tests/linear/ws.npy").unwrap();
    let bs = Some(Tensor::read_npy("tests/linear/bs.npy").unwrap());

    // Reference result via an explicit matmul, adding the bias when present.
    let mut original_output = input.matmul(&ws.tr());
    if let Some(bias) = &bs {
        original_output = original_output + bias;
    }

    let linear = Linear { ws, bs };
    let output = linear.forward(&input);

    // Distance to the expected output, as a scalar f32.
    let dist = |t: &Tensor| -> f32 { (t - &expected_output).norm().try_into().unwrap() };
    let delta_output = dist(&output);
    let delta_original = dist(&original_output);

    // The `matmul()` implementation is close, but `linear()` is at least as close or closer.
    assert!(output.allclose(&expected_output, 1e-5, 1e-8, false));
    assert!(delta_output <= delta_original);
}