ferrum_quantization/
dense.rs1use ferrum_kernels::backend::Backend;
8
9use crate::traits::Linear;
10
11pub struct DenseLinear<B: Backend> {
16 weight: B::Buffer,
17 bias: Option<B::Buffer>,
18 in_features: usize,
19 out_features: usize,
20}
21
22impl<B: Backend> DenseLinear<B> {
23 pub fn from_rows(weight_row_major: &[f32], out_features: usize, in_features: usize) -> Self {
25 debug_assert_eq!(
26 weight_row_major.len(),
27 out_features * in_features,
28 "DenseLinear weight length mismatch"
29 );
30 let weight = B::from_slice(weight_row_major);
31 Self {
32 weight,
33 bias: None,
34 in_features,
35 out_features,
36 }
37 }
38
39 pub fn from_rows_with_bias(
41 weight_row_major: &[f32],
42 bias: &[f32],
43 out_features: usize,
44 in_features: usize,
45 ) -> Self {
46 debug_assert_eq!(bias.len(), out_features, "DenseLinear bias length mismatch");
47 Self {
48 weight: B::from_slice(weight_row_major),
49 bias: Some(B::from_slice(bias)),
50 in_features,
51 out_features,
52 }
53 }
54
55 pub fn from_buffer(weight: B::Buffer, out_features: usize, in_features: usize) -> Self {
57 Self {
58 weight,
59 bias: None,
60 in_features,
61 out_features,
62 }
63 }
64
65 pub fn with_bias(mut self, bias: B::Buffer) -> Self {
66 self.bias = Some(bias);
67 self
68 }
69
70 pub fn weight(&self) -> &B::Buffer {
71 &self.weight
72 }
73
74 pub fn bias(&self) -> Option<&B::Buffer> {
75 self.bias.as_ref()
76 }
77}
78
79impl<B: Backend> Linear<B> for DenseLinear<B> {
80 fn in_features(&self) -> usize {
81 self.in_features
82 }
83
84 fn out_features(&self) -> usize {
85 self.out_features
86 }
87
88 fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
89 B::gemm(
90 ctx,
91 input,
92 &self.weight,
93 out,
94 m,
95 self.out_features,
96 self.in_features,
97 );
98 if let Some(bias) = &self.bias {
99 B::add_bias(ctx, out, bias, m, self.out_features);
100 }
101 }
102}