Skip to main content

ferrum_quantization/
dense.rs

1//! Dense linear projection — the baseline, uses `B::gemm` directly.
2//!
3//! Supports an optional learnable bias (Bert / Clip / many encoder models).
4//! When `bias` is set, `forward` lowers to `gemm + add_bias` (one extra
5//! dispatch on GPU backends, still part of the current command buffer).
6
7use ferrum_kernels::backend::Backend;
8
9use crate::traits::Linear;
10
11/// Dense linear projection.
12///
13/// Holds a single weight matrix laid out row-major as `[out_features, in_features]`.
14/// `forward` delegates to `B::gemm` plus (optional) `B::add_bias`.
15pub struct DenseLinear<B: Backend> {
16    weight: B::Buffer,
17    bias: Option<B::Buffer>,
18    in_features: usize,
19    out_features: usize,
20}
21
22impl<B: Backend> DenseLinear<B> {
23    /// Build a weight-only dense projection (no bias).
24    pub fn from_rows(weight_row_major: &[f32], out_features: usize, in_features: usize) -> Self {
25        debug_assert_eq!(
26            weight_row_major.len(),
27            out_features * in_features,
28            "DenseLinear weight length mismatch"
29        );
30        let weight = B::from_slice(weight_row_major);
31        Self {
32            weight,
33            bias: None,
34            in_features,
35            out_features,
36        }
37    }
38
39    /// Build a dense projection with a bias vector of length `out_features`.
40    pub fn from_rows_with_bias(
41        weight_row_major: &[f32],
42        bias: &[f32],
43        out_features: usize,
44        in_features: usize,
45    ) -> Self {
46        debug_assert_eq!(bias.len(), out_features, "DenseLinear bias length mismatch");
47        Self {
48            weight: B::from_slice(weight_row_major),
49            bias: Some(B::from_slice(bias)),
50            in_features,
51            out_features,
52        }
53    }
54
55    /// Construct by moving already-allocated `Backend` buffers.
56    pub fn from_buffer(weight: B::Buffer, out_features: usize, in_features: usize) -> Self {
57        Self {
58            weight,
59            bias: None,
60            in_features,
61            out_features,
62        }
63    }
64
65    pub fn with_bias(mut self, bias: B::Buffer) -> Self {
66        self.bias = Some(bias);
67        self
68    }
69
70    pub fn weight(&self) -> &B::Buffer {
71        &self.weight
72    }
73
74    pub fn bias(&self) -> Option<&B::Buffer> {
75        self.bias.as_ref()
76    }
77}
78
79impl<B: Backend> Linear<B> for DenseLinear<B> {
80    fn in_features(&self) -> usize {
81        self.in_features
82    }
83
84    fn out_features(&self) -> usize {
85        self.out_features
86    }
87
88    fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
89        B::gemm(
90            ctx,
91            input,
92            &self.weight,
93            out,
94            m,
95            self.out_features,
96            self.in_features,
97        );
98        if let Some(bias) = &self.bias {
99            B::add_bias(ctx, out, bias, m, self.out_features);
100        }
101    }
102}