ferrum_quantization/gguf/linear.rs
1//! `GgufLinear<B>`: a GGUF-sourced linear projection that integrates with
2//! ferrum's `Linear<B>` trait.
3//!
4//! Phase 1B uses an **eager-dequant-at-load** strategy: when constructed from
5//! a candle `QTensor`, the quantized payload is decoded to fp32 once on CPU,
6//! then handed to `DenseLinear<B>` so the runtime path goes through the
7//! standard `B::gemm` kernel. This is the simplest correct path that works
8//! uniformly across CPU / Metal / CUDA without per-backend bridging code.
9//!
10//! Trade-off: we lose GGUF's memory advantage (Q4_K_M @ 4.5 bits/weight
11//! becomes fp32 @ 32 bits/weight in RAM) and we don't get fused
12//! dequant-matmul perf. Phase 1D will replace this with a real
13//! quantization-aware Linear that holds the QTensor and dispatches to
14//! Metal / CUDA Q4_K_M kernels.
15//!
16//! Why a dedicated `GgufLinear<B>` type instead of just returning
17//! `DenseLinear<B>`? So Phase 1D can swap the internals (eager dequant →
18//! lazy QMatMul) without churning the public API of any `WeightLoader`
19//! that already returns `Box<dyn Linear<B>>`.
20
21use candle_core::quantized::QTensor;
22use candle_core::{Device, Result as CandleResult};
23use ferrum_kernels::backend::Backend;
24
25use crate::dense::DenseLinear;
26use crate::traits::Linear;
27
28/// Linear projection backed by a GGUF-sourced quantized tensor.
29///
30/// Internally a `DenseLinear<B>` (Phase 1B), so the runtime path is the same
31/// as a plain dense weight. The distinct type lets later phases evolve the
32/// representation without changing call sites.
33pub struct GgufLinear<B: Backend> {
34 inner: DenseLinear<B>,
35}
36
37impl<B: Backend> GgufLinear<B> {
38 /// Build from a candle `QTensor` previously read out of a GGUF file.
39 ///
40 /// Expects a 2-D weight whose shape is `[out_features, in_features]`
41 /// (the GGUF convention for linear projections — rows are output
42 /// neurons). Errors if the rank is wrong or the dequant step fails.
43 pub fn from_qtensor(qt: &QTensor) -> CandleResult<Self> {
44 let dims = qt.shape().dims();
45 if dims.len() != 2 {
46 return Err(candle_core::Error::Msg(format!(
47 "GgufLinear: expected 2-D weight tensor, got rank {} (shape {:?})",
48 dims.len(),
49 dims
50 )));
51 }
52 let out_features = dims[0];
53 let in_features = dims[1];
54 let weights = dequantize_to_vec(qt)?;
55 Ok(Self {
56 inner: DenseLinear::<B>::from_rows(&weights, out_features, in_features),
57 })
58 }
59
60 /// Build with a bias vector. `bias_qt` must be a 1-D `[out_features]`
61 /// tensor — typical for Qwen2.5 / Bert / any model with attention bias.
62 pub fn from_qtensor_with_bias(qt: &QTensor, bias_qt: &QTensor) -> CandleResult<Self> {
63 let weight_dims = qt.shape().dims();
64 if weight_dims.len() != 2 {
65 return Err(candle_core::Error::Msg(format!(
66 "GgufLinear: expected 2-D weight, got rank {}",
67 weight_dims.len()
68 )));
69 }
70 let out_features = weight_dims[0];
71 let in_features = weight_dims[1];
72
73 let bias_dims = bias_qt.shape().dims();
74 if bias_dims.len() != 1 || bias_dims[0] != out_features {
75 return Err(candle_core::Error::Msg(format!(
76 "GgufLinear: bias shape {:?} doesn't match weight out_features {}",
77 bias_dims, out_features
78 )));
79 }
80
81 let weights = dequantize_to_vec(qt)?;
82 let bias = dequantize_to_vec(bias_qt)?;
83 Ok(Self {
84 inner: DenseLinear::<B>::from_rows_with_bias(
85 &weights,
86 &bias,
87 out_features,
88 in_features,
89 ),
90 })
91 }
92
93 /// Build directly from already-dequantized fp32 weights. Useful when the
94 /// caller has already paid the dequant cost (e.g. cached weights, or
95 /// constructing from synthetic data in tests).
96 pub fn from_dense_rows(
97 weight_row_major: &[f32],
98 out_features: usize,
99 in_features: usize,
100 ) -> Self {
101 Self {
102 inner: DenseLinear::<B>::from_rows(weight_row_major, out_features, in_features),
103 }
104 }
105}
106
107impl<B: Backend> Linear<B> for GgufLinear<B> {
108 fn in_features(&self) -> usize {
109 self.inner.in_features()
110 }
111
112 fn out_features(&self) -> usize {
113 self.inner.out_features()
114 }
115
116 fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
117 self.inner.forward(ctx, input, out, m);
118 }
119}
120
121/// Convenience: build a boxed `Linear<B>` from a `QTensor`. Useful for
122/// `WeightLoader` impls that want a uniform `Box<dyn Linear<B>>` output.
123pub fn linear_from_qtensor<B: Backend>(qt: &QTensor) -> CandleResult<Box<dyn Linear<B>>> {
124 Ok(Box::new(GgufLinear::<B>::from_qtensor(qt)?))
125}
126
127/// Dequantize on CPU, flatten to a contiguous `Vec<f32>` in row-major order.
128/// Pulled out so weight + bias paths share the same conversion.
129fn dequantize_to_vec(qt: &QTensor) -> CandleResult<Vec<f32>> {
130 let dense = qt.dequantize(&Device::Cpu)?;
131 let flat = dense.flatten_all()?;
132 flat.to_vec1::<f32>()
133}