rai_nn/
linear.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
use rai_core::{AsDevice, Shape, Tensor, Type};
use rai_derive::Module;

use crate::init::{self, Init, DEFAULT_KAIMING_NORMAL};

/// A fully connected (affine) layer.
///
/// `fwd` computes `x.matmul(weightᵀ)` and, when a bias is present, adds it to
/// the result.
#[derive(Clone, Debug, Module)]
// NOTE(review): presumably tells the `Module` derive which path to use for
// `rai_core` items — confirm against `rai_derive`.
#[module(crate = rai_core)]
pub struct Linear {
    /// Weight matrix. `new` / `new_with_init` create it with shape
    /// `[out_features, in_features]`; `new_with_params` stores whatever tensor
    /// it is given.
    weight: Tensor,
    /// Optional bias. `new` / `new_with_init` create it with shape
    /// `[out_features]`; `None` means the layer adds no bias term.
    bias: Option<Tensor>,
}

impl Linear {
    /// Creates a linear layer with the default initialisation scheme.
    ///
    /// The weight (shape `[out_features, in_features]`) is initialised with
    /// [`DEFAULT_KAIMING_NORMAL`]. When `has_bias` is `true`, the bias (shape
    /// `[out_features]`) is initialised with `init::Uniform::new(-bound, bound)`
    /// where `bound = 1 / sqrt(in_features)`; otherwise the layer has no bias.
    ///
    /// NOTE(review): with `in_features == 0` the bound is infinite — confirm how
    /// `init::Uniform` handles that before relying on zero-width layers.
    #[inline]
    pub fn new(
        in_features: usize,
        out_features: usize,
        has_bias: bool,
        dtype: impl Type,
        device: impl AsDevice,
    ) -> Self {
        let bound = 1. / (in_features as f64).sqrt();
        // `bool::then` only builds the initialiser when a bias is requested.
        let bias_init = has_bias.then(|| init::Uniform::new(-bound, bound));
        Self::new_with_init(
            in_features,
            out_features,
            dtype,
            device,
            DEFAULT_KAIMING_NORMAL,
            bias_init,
        )
    }

    /// Creates a linear layer using caller-supplied initialisers.
    ///
    /// The weight is created by `weight_init` with shape
    /// `[out_features, in_features]`. If `bias_init` is `Some`, a bias of shape
    /// `[out_features]` is created with it; if it is `None`, the layer has no
    /// bias. Both tensors use the same `dtype` and device.
    pub fn new_with_init(
        in_features: usize,
        out_features: usize,
        dtype: impl Type,
        device: impl AsDevice,
        weight_init: impl Init,
        bias_init: Option<impl Init>,
    ) -> Self {
        // Resolve the device once and reuse it for both parameters.
        let device = device.device();
        let weight = weight_init.new_tensor([out_features, in_features], dtype, device);
        let bias = bias_init.map(|init| init.new_tensor([out_features], dtype, device));
        Self { weight, bias }
    }

    /// Builds a linear layer from existing parameter tensors.
    ///
    /// No shape checks are performed here. `fwd` transposes `weight` and adds
    /// `bias` to the product, so presumably `weight` should be
    /// `[out_features, in_features]` and `bias` `[out_features]` — the caller
    /// is responsible for that.
    pub fn new_with_params(weight: Tensor, bias: Option<Tensor>) -> Self {
        Self { weight, bias }
    }

    /// Returns the weight tensor.
    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    /// Returns the bias tensor, or `None` if the layer was built without one.
    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }

    /// Applies the layer to `x`, returning `x.matmul(wᵀ)` plus the bias if one
    /// is present.
    ///
    /// For rank-3 (`[b, _, _]`) and rank-4 (`[b1, b2, _, _]`) inputs, the weight
    /// is first broadcast over the leading batch dimension(s) of `x` with
    /// `broadcast_left` and then transposed. All other ranks use the plain
    /// transpose of the weight.
    ///
    /// NOTE(review): inputs of rank > 4 take the non-broadcast path — confirm
    /// that `matmul` handles them.
    pub fn fwd(&self, x: &Tensor) -> Tensor {
        // todo: move the broadcast checking to matmul?
        let w = &match x.shape() {
            [b1, b2, _, _] => self.weight.broadcast_left([*b1, *b2]).t(),
            [b, _, _] => self.weight.broadcast_left([*b]).t(),
            _ => self.weight.t(),
        };
        match &self.bias {
            Some(bias) => x.matmul(w) + bias,
            None => x.matmul(w),
        }
    }
}