1use candle::{Result, Tensor};
21
22#[derive(Clone, Debug)]
23pub struct Linear {
24 weight: Tensor,
25 bias: Option<Tensor>,
26}
27
28impl Linear {
29 pub fn new(weight: Tensor, bias: Option<Tensor>) -> Self {
30 Self { weight, bias }
31 }
32
33 pub fn weight(&self) -> &Tensor {
34 &self.weight
35 }
36
37 pub fn bias(&self) -> Option<&Tensor> {
38 self.bias.as_ref()
39 }
40}
41
42impl super::Module for Linear {
43 fn forward(&self, x: &Tensor) -> candle::Result<Tensor> {
44 let x = match *x.dims() {
47 [b1, b2, m, k] => {
48 if x.is_contiguous() {
49 let w = self.weight.t()?;
50 x.reshape((b1 * b2 * m, k))?
51 .matmul(&w)?
52 .reshape((b1, b2, m, ()))?
53 } else {
54 let w = self.weight.broadcast_left((b1, b2))?.t()?;
55 x.matmul(&w)?
56 }
57 }
58 [bsize, m, k] => {
59 if x.is_contiguous() {
60 let w = self.weight.t()?;
61 x.reshape((bsize * m, k))?
62 .matmul(&w)?
63 .reshape((bsize, m, ()))?
64 } else {
65 let w = self.weight.broadcast_left(bsize)?.t()?;
66 x.matmul(&w)?
67 }
68 }
69 _ => {
70 let w = self.weight.t()?;
71 x.matmul(&w)?
72 }
73 };
74 match &self.bias {
75 None => Ok(x),
76 Some(bias) => x.broadcast_add(bias),
77 }
78 }
79}
80
81pub fn linear(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
85 let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
86 let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
87 let bound = 1. / (in_dim as f64).sqrt();
88 let init_bs = crate::Init::Uniform {
89 lo: -bound,
90 up: bound,
91 };
92 let bs = vb.get_with_hints(out_dim, "bias", init_bs)?;
93 Ok(Linear::new(ws, Some(bs)))
94}
95
96pub fn linear_no_bias(in_dim: usize, out_dim: usize, vb: crate::VarBuilder) -> Result<Linear> {
98 let init_ws = crate::init::DEFAULT_KAIMING_NORMAL;
99 let ws = vb.get_with_hints((out_dim, in_dim), "weight", init_ws)?;
100 Ok(Linear::new(ws, None))
101}
102
103pub fn linear_b(
104 in_dim: usize,
105 out_dim: usize,
106 bias: bool,
107 vb: crate::VarBuilder,
108) -> Result<Linear> {
109 if bias {
110 linear(in_dim, out_dim, vb)
111 } else {
112 linear_no_bias(in_dim, out_dim, vb)
113 }
114}