ferrum_quantization/quant_linear.rs
1//! `QuantLinear<B>` — keeps Q4_K_M (or future k-quant) weights quantised
2//! in backend memory and dequants on-demand per `forward` call.
3//!
4//! Contrast with `GgufLinear<B>` which **eagerly** dequants Q4_K_M to
5//! fp32/fp16 at load time. That eager path inflates an 8B model from
6//! ~5 GB on disk to 16-32 GB in RAM — fine for safetensors-fp16 sources
7//! but wasteful for GGUF Q4_K_M and a non-starter for 30B-A3B on a
8//! 32 GB Mac.
9//!
//! The Q4 → fp16 conversion happens inside the backend's quantised GEMM
//! (`Backend::gemm_quant`, which is what `forward` calls), into a
//! transient buffer that is freed after the matmul. The memory footprint
//! is therefore the on-disk Q4 size plus a per-call transient of roughly
//! one weight matrix's worth of fp16.
14//!
//! Phase 1D scope: direct (un-fused) Q4_K_M projections —
//! `o_proj`, `down_proj`, `lm_head`, `embed_tokens`, etc. Fused
//! projections (`qkv_proj`, `gate_up_proj`) otherwise keep falling
//! through to `GgufLinear`'s eager-dequant path, where the loader's
//! split-fusion logic already concatenates the dequantised parts into one
//! dense weight. See [`QuantLinear::from_gguf_fused`] for a fused
//! constructor that keeps every part quantised instead.
20
21use ferrum_kernels::backend::{Backend, GgufQuantType};
22use ferrum_kernels::Linear;
23use ferrum_types::Result;
24
25/// Linear projection backed by a GGUF k-quant weight kept quantised in
26/// backend memory.
27///
28/// `forward` calls into `Backend::gemm_quant`, which dequants the
29/// weight into a transient fp16 buffer (Metal) or pre-dequanted fp32
30/// weights (CPU) and then runs the matmul. See `B::QuantStore` per
31/// backend for the storage format details.
32///
33/// Future k-quant flavours (Q5_K, Q6_K, Q8_0) plug in via the
34/// [`GgufQuantType`] discriminator passed to the constructor — no new
35/// `QuantLinear` type required.
36pub struct QuantLinear<B: Backend> {
37 store: B::QuantStore,
38 in_features: usize,
39 out_features: usize,
40}
41
42impl<B: Backend> QuantLinear<B> {
43 /// Build from raw GGUF block bytes.
44 ///
45 /// `kind`: which k-quant flavour the bytes encode (Q4_K, Q5_K, …).
46 /// `bytes`: the on-disk payload, sized by the kind's block layout.
47 pub fn from_gguf_bytes(
48 kind: GgufQuantType,
49 bytes: &[u8],
50 out_features: usize,
51 in_features: usize,
52 ) -> Result<Self> {
53 let store = B::load_quant(kind, bytes, out_features, in_features)?;
54 Ok(Self {
55 store,
56 in_features,
57 out_features,
58 })
59 }
60
61 /// Build a fused projection from multiple `(kind, bytes, rows)`
62 /// parts that share `in_features`. Each part stays in its own
63 /// QuantStore (no byte-concat); forward dispatches one matvec per
64 /// part. Used for Qwen3 `qkv_proj` when q+k are Q4_K and v is Q6_K
65 /// — the homogeneous fused-Q4 fast path would have to fall back
66 /// to eager-fp32, blowing 100 MB per layer.
67 pub fn from_gguf_fused(
68 parts: &[(GgufQuantType, &[u8], usize)],
69 in_features: usize,
70 ) -> Result<Self> {
71 let store = B::load_quant_fused(parts, in_features)?;
72 let out_features = parts.iter().map(|(_, _, n)| *n).sum();
73 Ok(Self {
74 store,
75 in_features,
76 out_features,
77 })
78 }
79
80 /// For tests / advanced callers that have already constructed a
81 /// `B::QuantStore` (e.g. through the Backend's own ingestion path).
82 pub fn from_store(store: B::QuantStore, out_features: usize, in_features: usize) -> Self {
83 Self {
84 store,
85 in_features,
86 out_features,
87 }
88 }
89
90 pub fn store(&self) -> &B::QuantStore {
91 &self.store
92 }
93}
94
95impl<B: Backend> Linear<B> for QuantLinear<B> {
96 fn in_features(&self) -> usize {
97 self.in_features
98 }
99
100 fn out_features(&self) -> usize {
101 self.out_features
102 }
103
104 fn forward(&self, ctx: &mut B::Context, input: &B::Buffer, out: &mut B::Buffer, m: usize) {
105 // Trait-level dispatch — the kernel choice (Metal compute kernel vs
106 // CPU dequant+gemm, and which k-quant flavour) is encapsulated in
107 // the backend's `QuantStore` enum.
108 B::gemm_quant(ctx, input, &self.store, out, m).expect("gemm_quant");
109 }
110}