candle_nn/linear.rs

//! Linear layer
//!
//! This layer applies a linear transformation to the incoming data, `y = x@w.t() + b`.
//! The bias is optional. The `forward` method applies the layer; it accepts input either with a
//! batch dimension (of shape `(b_sz, in_c)`) or without one (of shape `(in_c,)`), and the output
//! has shape `(b_sz, out_c)` or `(out_c,)` respectively.
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
//! use candle_nn::{Linear, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
//! let layer = Linear::new(w, None); // Use no bias.
//! let xs = Tensor::new(&[[10f32, 100.]], &Cpu)?;
//! let ys = layer.forward(&xs)?;
//! assert_eq!(ys.to_vec2::<f32>()?, &[[210.0, 430.0, 650.0]]);
//! # Ok(()) }
//! ```
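//!
//! The same layer can also carry a bias. Below is a minimal sketch with an arbitrary bias vector
//! chosen purely for illustration:
//!
//! ```rust
//! use candle::{Tensor, Device::Cpu};
//! use candle_nn::{Linear, Module};
//! # fn main() -> candle::Result<()> {
//!
//! let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
//! let b = Tensor::new(&[1f32, 2., 3.], &Cpu)?;
//! let layer = Linear::new(w, Some(b)); // Use a bias.
//! let xs = Tensor::new(&[[10f32, 100.]], &Cpu)?;
//! let ys = layer.forward(&xs)?;
//! assert_eq!(ys.to_vec2::<f32>()?, &[[211.0, 432.0, 653.0]]);
//! # Ok(()) }
//! ```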
use candle::{Result, Tensor};

#[derive(Clone, Debug)]
pub struct Linear {
    weight: Tensor,
    bias: Option<Tensor>,
}

impl Linear {
    pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
        Self { weight, bias }
    }

    pub fn weight(&self) -> &Tensor {
        &self.weight
    }

    pub fn bias(&self) -> Option<&Tensor> {
        self.bias.as_ref()
    }
}

impl super::Module for Linear {
    fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
        // When possible, we avoid using a broadcasted matmul as it is much slower
        // than the standard matmul for the cuda and cpu backends: for contiguous inputs we
        // flatten the batch dimensions, run a single 2d matmul and reshape the result back,
        // otherwise we fall back to the broadcasted version.
        let x = match *x.dims() {
            [b1, b2, m, k] => {
                if x.is_contiguous() {
                    let w = self.weight.t()?;
                    x.reshape((b1 * b2 * m, k))?
                        .matmul(&w)?
                        .reshape((b1, b2, m, ()))?
                } else {
                    let w = self.weight.broadcast_left((b1, b2))?.t()?;
                    x.matmul(&w)?
                }
            }
            [bsize, m, k] => {
                if x.is_contiguous() {
                    let w = self.weight.t()?;
                    x.reshape((bsize * m, k))?
                        .matmul(&w)?
                        .reshape((bsize, m, ()))?
                } else {
                    let w = self.weight.broadcast_left(bsize)?.t()?;
                    x.matmul(&w)?
                }
            }
            _ => {
                let w = self.weight.t()?;
                x.matmul(&w)?
            }
        };
        match &self.bias {
            None => Ok(x),
            Some(bias) => x.broadcast_add(bias),
        }
    }
}

/// Create or initialize a new linear layer.
///
/// This uses the default names `"weight"` and `"bias"` for the weight and bias parameters.
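///
/// A minimal usage sketch, assuming a zero-initialized `VarBuilder` (so the resulting weight and
/// bias tensors are all zeros):
///
/// ```rust
/// # fn main() -> candle::Result<()> {
/// use candle::{DType, Device::Cpu, Tensor};
/// use candle_nn::{linear, Module, VarBuilder};
///
/// let vb = VarBuilder::zeros(DType::F32, &Cpu);
/// let layer = linear(2, 3, vb)?;
/// // The weight is stored as `(out_dim, in_dim)`.
/// assert_eq!(layer.weight().dims(), &[3, 2]);
/// let xs = Tensor::new(&[[10f32, 100.]], &Cpu)?;
/// // With all-zero parameters the output is all zeros as well.
/// assert_eq!(layer.forward(&xs)?.to_vec2::<f32>()?, &[[0.0, 0.0, 0.0]]);
/// # Ok(()) }
/// ```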
pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    let bound = 1. / (in_dim as f64).sqrt();
    let init_bs = crate::Init::Uniform {
        lo: -bound,
        up: bound,
    };
    let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
    Ok(Linear::new(ws, Some(bs)))
}

/// Create or initialize a new linear layer without a bias.
pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
    let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
    let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
    Ok(Linear::new(ws, None))
}

/// Create or initialize a new linear layer, with or without a bias depending on the `bias` flag.
pub fn linear_b(
    in_dim: usize,
    out_dim: usize,
    bias: bool,
    vb: crate::VarBuilder,
) -> Result<Linear> {
    if bias {
        linear(in_dim, out_dim, vb)
    } else {
        linear_no_bias(in_dim, out_dim, vb)
    }
}
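
// A minimal sanity-check sketch of the batched (rank-3) code path in `forward`, assuming the CPU
// backend is available in the test environment; the values mirror the module-level example.
#[cfg(test)]
mod tests {
    use super::Linear;
    use crate::Module;
    use candle::{Device::Cpu, Result, Tensor};

    #[test]
    fn batched_forward() -> Result<()> {
        let w = Tensor::new(&[[1f32, 2.], [3., 4.], [5., 6.]], &Cpu)?;
        let layer = Linear::new(w, None);
        // Shape (2, 1, 2): two batch entries with a single row of two features each.
        let xs = Tensor::new(&[[[10f32, 100.]], [[1., 10.]]], &Cpu)?;
        let ys = layer.forward(&xs)?;
        assert_eq!(ys.dims(), &[2, 1, 3]);
        assert_eq!(
            ys.to_vec3::<f32>()?,
            &[[[210.0, 430.0, 650.0]], [[21.0, 43.0, 65.0]]]
        );
        Ok(())
    }
}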