rlx_ir/quant.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Quantization metadata as graph annotations (plan #57).
//!
//! Quantization info lives as per-tensor metadata on the IR rather than
//! in a parallel "quantized graph" type. Ops can read the scheme and
//! dispatch to fused-dequant kernels (the eventual #5 win) when one is
//! present, or fall through to the standard f32/f16 path when not.
//!
//! The metadata is held *outside* the [`crate::Node`] type itself, in a
//! [`crate::Graph`]-level [`QuantMap`]. This keeps `Node` small (otherwise
//! every node would pay for the rarely used quantization annotation) and
//! makes quant info easy to query / clear without rewriting nodes.
26
27use crate::NodeId;
28use std::collections::HashMap;
29
/// How a tensor is quantized. Mirrors the schemes RLX needs for LLM
/// inference on Apple Silicon: blockwise int8 (GPTQ-style), blockwise
/// int4, per-tensor fp8 (e4m3 / e5m2), the GGUF / llama.cpp block formats
/// (K-quants plus the legacy Q4_0 / Q8_0), and NVIDIA NVFP4.
///
/// The non-GGUF variants carry the parameters the dequantizer needs at
/// runtime (block size; whether a scale / zero-point side tensor exists).
/// Where those side tensors live relative to the weights is up to the
/// loader (#56). The GGUF variants embed their scales inside the packed
/// block bytes, so their layout is fixed by the format.
///
/// `Eq` and `Hash` are derived (every payload is a plain `u32`) so a
/// scheme can be used directly as a map / set key.
#[cfg_attr(feature = "serialize", derive(serde::Serialize, serde::Deserialize))]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum QuantScheme {
    /// Symmetric int8 with one scale per `block_size` elements.
    Int8Block { block_size: u32 },
    /// Asymmetric int8 with scale + zero-point per `block_size` elements.
    Int8BlockAsym { block_size: u32 },
    /// Int4 packed two-per-byte, scale per `block_size` elements.
    /// Unlike [`QuantScheme::GgufQ4K`], the block size is configurable and
    /// the scales are a separate tensor rather than embedded in the block.
    Int4Block { block_size: u32 },
    /// FP8 e4m3 (no scale). More mantissa precision, narrower range than e5m2.
    Fp8E4m3,
    /// FP8 e5m2 (no scale). Wider range, less precision than e4m3.
    Fp8E5m2,
    /// GGUF / llama.cpp Q4_K super-block (256 elements / 144 bytes).
    /// Packs an f16 super-scale + f16 super-min + 12 bytes of packed 6-bit
    /// sub-block scales/mins (8 of each) + 128 bytes of 4-bit quants
    /// (256 nibbles). Block layout is fixed by the format — there's no
    /// `block_size` knob.
    GgufQ4K,
    /// GGUF Q5_K (256 / 176 bytes). Adds a 32-byte high-bit plane on
    /// top of the Q4_K layout.
    GgufQ5K,
    /// GGUF Q6_K (256 / 210 bytes). Per-sub-block signed scales,
    /// no min term.
    GgufQ6K,
    /// GGUF Q8_K (256 / 292 bytes). Per-super-block f32 scale plus 256
    /// i8 quants and a 32-byte table of sixteen i16 partial sums that's
    /// only used by Q8_K × Q8_K matmul accumulation paths.
    GgufQ8K,
    /// GGUF Q2_K (256 / 84 bytes). 2-bit quants with per-sub-block scale/min.
    GgufQ2K,
    /// GGUF Q3_K (256 / 110 bytes). 3-bit quants with hmask high bit plane.
    GgufQ3K,
    /// GGUF Q4_0 (32 / 18 bytes). Legacy llama.cpp block: f16 scale +
    /// 16 bytes of nibbles (32 × 4-bit quants).
    GgufQ4_0,
    /// GGUF Q8_0 (32 / 34 bytes). Legacy block: f16 scale + 32×i8 quants.
    GgufQ8_0,
    /// NVIDIA FP4 (E2M1) block — fixed 16-element groups, FP8 E4M3 block
    /// scales, optional f32 global scale on input 3 (legacy `zp` slot).
    /// Used by FLUX.2 / MLX `nvfp4` checkpoints.
    Nvfp4Block,
}
79
80impl QuantScheme {
81 /// Bits per element after packing (×10 for K-quants since they
82 /// have fractional bit budgets — divide by 10 when comparing).
83 pub const fn bits_per_element_x10(self) -> u32 {
84 match self {
85 Self::Int8Block { .. } | Self::Int8BlockAsym { .. } => 80,
86 Self::Int4Block { .. } => 40,
87 Self::Fp8E4m3 | Self::Fp8E5m2 => 80,
88 // GGUF K-quants: header + per-element bits over a 256-element block.
89 Self::GgufQ4K => 45, // 144 bytes / 256 elems × 8 = 4.5 bpe
90 Self::GgufQ5K => 55, // 176 / 256 × 8 ≈ 5.5
91 Self::GgufQ6K => 66, // 210 / 256 × 8 ≈ 6.5625 → 66 (rounded)
92 Self::GgufQ8K => 91, // 292 / 256 × 8 ≈ 9.125 → 91
93 Self::GgufQ2K => 26, // 84 / 256 × 8 ≈ 2.625 → 26
94 Self::GgufQ3K => 34, // 110 / 256 × 8 ≈ 3.4375 → 34
95 Self::GgufQ4_0 => 45, // 18 / 32 × 8 = 4.5 bpe
96 Self::GgufQ8_0 => 85, // 34 / 32 × 8 = 8.5 bpe
97 Self::Nvfp4Block => 40,
98 }
99 }
100
101 /// Bits per element after packing (rounded down). Use
102 /// `bits_per_element_x10` for the K-quant fractional values.
103 pub const fn bits_per_element(self) -> u32 {
104 self.bits_per_element_x10() / 10
105 }
106
107 /// True if this scheme requires a per-block scale tensor on the side.
108 pub const fn has_scale(self) -> bool {
109 matches!(
110 self,
111 Self::Int8Block { .. }
112 | Self::Int8BlockAsym { .. }
113 | Self::Int4Block { .. }
114 | Self::Nvfp4Block
115 )
116 }
117
118 /// True for NVFP4 block scales stored as FP8 E4M3 bytes (not f32).
119 pub const fn scale_is_fp8(self) -> bool {
120 matches!(self, Self::Nvfp4Block)
121 }
122
123 /// Fixed NVFP4 group size along K (0 for other schemes).
124 pub const fn nvfp4_group_size(self) -> u32 {
125 match self {
126 Self::Nvfp4Block => crate::nvfp4::NVFP4_GROUP_SIZE as u32,
127 _ => 0,
128 }
129 }
130
131 /// True if this scheme requires a per-block zero-point.
132 pub const fn has_zero_point(self) -> bool {
133 matches!(self, Self::Int8BlockAsym { .. })
134 }
135
136 /// GGUF K-quant block size (256 elements) — meaningless for the
137 /// non-GGUF schemes (returns 0).
138 pub const fn gguf_block_size(self) -> u32 {
139 match self {
140 Self::GgufQ4K
141 | Self::GgufQ5K
142 | Self::GgufQ6K
143 | Self::GgufQ8K
144 | Self::GgufQ2K
145 | Self::GgufQ3K => 256,
146 Self::GgufQ4_0 | Self::GgufQ8_0 => 32,
147 _ => 0,
148 }
149 }
150
151 /// Bytes per GGUF super-block. 0 for non-GGUF schemes.
152 pub const fn gguf_block_bytes(self) -> u32 {
153 match self {
154 Self::GgufQ4K => 144, // f16 d + f16 dmin + 12 packed scales + 128 nibbles
155 Self::GgufQ5K => 176, // + 32-byte high-bit plane
156 Self::GgufQ6K => 210, // 128 ql + 64 qh + 16 i8 scales + f16 d
157 Self::GgufQ8K => 292, // f32 d + 256 i8 + 16 i16 bsums = 4 + 256 + 32
158 Self::GgufQ2K => 84, // f16 d + f16 dmin + 16 scales + 64 qs
159 Self::GgufQ3K => 110, // f16 d + 12 scales + 32 hmask + 64 qs
160 Self::GgufQ4_0 => 18, // f16 d + 16 packed nibbles
161 Self::GgufQ8_0 => 34, // f16 d + 32 i8 quants
162 _ => 0,
163 }
164 }
165
166 /// True for any GGUF-format block scheme. GGUF schemes carry
167 /// their scales / mins / sub-block metadata *inside* the packed
168 /// weight bytes — they don't need separate `scale` / `zp`
169 /// tensors fed alongside as the legacy `Int8Block` paths do.
170 pub const fn is_gguf(self) -> bool {
171 matches!(
172 self,
173 Self::GgufQ4K
174 | Self::GgufQ5K
175 | Self::GgufQ6K
176 | Self::GgufQ8K
177 | Self::GgufQ2K
178 | Self::GgufQ3K
179 | Self::GgufQ4_0
180 | Self::GgufQ8_0
181 )
182 }
183}
184
185impl std::fmt::Display for QuantScheme {
186 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
187 match self {
188 Self::Int8Block { block_size } => write!(f, "int8/{block_size}"),
189 Self::Int8BlockAsym { block_size } => write!(f, "int8a/{block_size}"),
190 Self::Int4Block { block_size } => write!(f, "int4/{block_size}"),
191 Self::Fp8E4m3 => write!(f, "fp8e4m3"),
192 Self::Fp8E5m2 => write!(f, "fp8e5m2"),
193 Self::GgufQ4K => write!(f, "gguf_q4k"),
194 Self::GgufQ5K => write!(f, "gguf_q5k"),
195 Self::GgufQ6K => write!(f, "gguf_q6k"),
196 Self::GgufQ8K => write!(f, "gguf_q8k"),
197 Self::GgufQ2K => write!(f, "gguf_q2k"),
198 Self::GgufQ3K => write!(f, "gguf_q3k"),
199 Self::GgufQ4_0 => write!(f, "gguf_q4_0"),
200 Self::GgufQ8_0 => write!(f, "gguf_q8_0"),
201 Self::Nvfp4Block => write!(f, "nvfp4/16"),
202 }
203 }
204}
205
206/// Per-graph map of quantized tensors. Lookup is O(1).
207#[derive(Debug, Clone, Default)]
208pub struct QuantMap {
209 map: HashMap<NodeId, QuantScheme>,
210}
211
212impl QuantMap {
213 pub fn new() -> Self {
214 Self::default()
215 }
216 pub fn get(&self, id: NodeId) -> Option<QuantScheme> {
217 self.map.get(&id).copied()
218 }
219 pub fn insert(&mut self, id: NodeId, scheme: QuantScheme) -> Option<QuantScheme> {
220 self.map.insert(id, scheme)
221 }
222 pub fn is_empty(&self) -> bool {
223 self.map.is_empty()
224 }
225 pub fn len(&self) -> usize {
226 self.map.len()
227 }
228 pub fn iter(&self) -> impl Iterator<Item = (&NodeId, &QuantScheme)> {
229 self.map.iter()
230 }
231}
232
#[cfg(test)]
mod tests {
    use super::*;

    /// Every GGUF variant, for table-driven checks.
    const GGUF_SCHEMES: [QuantScheme; 8] = [
        QuantScheme::GgufQ4K,
        QuantScheme::GgufQ5K,
        QuantScheme::GgufQ6K,
        QuantScheme::GgufQ8K,
        QuantScheme::GgufQ2K,
        QuantScheme::GgufQ3K,
        QuantScheme::GgufQ4_0,
        QuantScheme::GgufQ8_0,
    ];

    /// One sample of every non-GGUF variant.
    const NON_GGUF_SCHEMES: [QuantScheme; 6] = [
        QuantScheme::Int8Block { block_size: 32 },
        QuantScheme::Int8BlockAsym { block_size: 64 },
        QuantScheme::Int4Block { block_size: 32 },
        QuantScheme::Fp8E4m3,
        QuantScheme::Fp8E5m2,
        QuantScheme::Nvfp4Block,
    ];

    #[test]
    fn scheme_traits() {
        assert_eq!(
            QuantScheme::Int4Block { block_size: 32 }.bits_per_element(),
            4
        );
        assert!(QuantScheme::Int8BlockAsym { block_size: 64 }.has_zero_point());
        assert!(!QuantScheme::Fp8E4m3.has_scale());
        assert!(QuantScheme::Nvfp4Block.has_scale());
        assert!(QuantScheme::Nvfp4Block.scale_is_fp8());
    }

    #[test]
    fn gguf_block_geometry() {
        for &s in GGUF_SCHEMES.iter() {
            assert!(s.is_gguf(), "{s} should be GGUF");
            // GGUF scales / mins live inside the packed block bytes.
            assert!(!s.has_scale() && !s.has_zero_point(), "{s}");
            let (size, bytes) = (s.gguf_block_size(), s.gguf_block_bytes());
            assert!(size == 256 || size == 32, "{s}: block size {size}");
            // The bpe table must agree with the byte layout:
            // bytes * 8 / size bits per element, in tenths, rounded to nearest.
            assert_eq!((bytes * 80 + size / 2) / size, s.bits_per_element_x10(), "{s}");
        }
        for &s in NON_GGUF_SCHEMES.iter() {
            assert!(!s.is_gguf(), "{s}");
            assert_eq!(s.gguf_block_size(), 0, "{s}");
            assert_eq!(s.gguf_block_bytes(), 0, "{s}");
        }
        assert_eq!(QuantScheme::GgufQ8K.gguf_block_bytes(), 292);
    }

    #[test]
    fn display_tags() {
        assert_eq!(QuantScheme::Int8Block { block_size: 32 }.to_string(), "int8/32");
        assert_eq!(QuantScheme::Int8BlockAsym { block_size: 64 }.to_string(), "int8a/64");
        assert_eq!(QuantScheme::GgufQ4_0.to_string(), "gguf_q4_0");
        // The tag's group size must agree with the NVFP4 constant.
        assert_eq!(
            QuantScheme::Nvfp4Block.to_string(),
            format!("nvfp4/{}", QuantScheme::Nvfp4Block.nvfp4_group_size())
        );
    }

    #[test]
    fn quant_map_lookup() {
        let mut q = QuantMap::new();
        assert!(q.is_empty());
        let id = NodeId(7);
        assert_eq!(q.insert(id, QuantScheme::Int8Block { block_size: 32 }), None);
        assert_eq!(q.get(id), Some(QuantScheme::Int8Block { block_size: 32 }));
        assert_eq!(q.get(NodeId(99)), None);
        // Re-annotating replaces and returns the previous scheme.
        assert_eq!(
            q.insert(id, QuantScheme::GgufQ4K),
            Some(QuantScheme::Int8Block { block_size: 32 })
        );
        assert_eq!(q.get(id), Some(QuantScheme::GgufQ4K));
        assert_eq!(q.len(), 1);
        assert_eq!(q.iter().count(), 1);
    }
}