1use rlx_ir::quant::QuantScheme;
23use std::collections::HashMap;
24use std::sync::{Arc, OnceLock, RwLock};
25
/// Lookup key identifying one dequantized weight in the cache.
///
/// Two requests map to the same entry only when their shape, scheme tag and
/// content hash all agree.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct DequantKey {
    /// First weight dimension, stored as `u32`.
    k: u32,
    /// Second weight dimension, stored as `u32`.
    n: u32,
    /// Compact numeric tag standing in for the quantization scheme.
    scheme: u8,
    /// 64-bit hash of the packed weight bytes.
    bytes_hash: u64,
}
34
/// Hashes the packed weight bytes into a single `u64`.
///
/// The slice is fed through the standard library's `DefaultHasher` using the
/// slice `Hash` impl, so equal byte contents always yield equal hashes within
/// one build of the program.
fn weight_bytes_hash(w_bytes: &[u8]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut state = DefaultHasher::new();
    Hash::hash(w_bytes, &mut state);
    Hasher::finish(&state)
}
41
42fn scheme_tag(scheme: QuantScheme) -> u8 {
43 match scheme {
44 QuantScheme::GgufQ4K => 1,
45 QuantScheme::GgufQ5K => 2,
46 QuantScheme::GgufQ6K => 3,
47 QuantScheme::GgufQ8K => 4,
48 QuantScheme::GgufQ4_0 => 5,
49 QuantScheme::GgufQ8_0 => 6,
50 QuantScheme::GgufQ2K => 7,
51 QuantScheme::GgufQ3K => 8,
52 QuantScheme::GgufIQ4NL => 9,
53 QuantScheme::GgufIQ4XS => 10,
54 QuantScheme::GgufIQ2XXS => 11,
55 QuantScheme::GgufIQ2XS => 12,
56 QuantScheme::GgufIQ2S => 13,
57 QuantScheme::GgufIQ3XXS => 14,
58 QuantScheme::GgufIQ3S => 15,
59 QuantScheme::GgufIQ1S => 16,
60 QuantScheme::GgufIQ1M => 17,
61 QuantScheme::GgufTQ1_0 => 18,
62 QuantScheme::GgufTQ2_0 => 19,
63 QuantScheme::GgufMXFP4 => 20,
64 QuantScheme::GgufNVFP4 => 21,
65 _ => 255,
66 }
67}
68
69fn dequant_gguf(w_bytes: &[u8], k: usize, n: usize, scheme: QuantScheme) -> Vec<f32> {
70 let n_elems = k * n;
71 match scheme {
72 QuantScheme::GgufQ4K => rlx_gguf::dequant_q4_k(w_bytes, n_elems),
73 QuantScheme::GgufQ5K => rlx_gguf::dequant_q5_k(w_bytes, n_elems),
74 QuantScheme::GgufQ6K => rlx_gguf::dequant_q6_k(w_bytes, n_elems),
75 QuantScheme::GgufQ8K => rlx_gguf::dequant_q8_k(w_bytes, n_elems),
76 QuantScheme::GgufQ2K => rlx_gguf::dequant_q2_k(w_bytes, n_elems),
77 QuantScheme::GgufQ3K => rlx_gguf::dequant_q3_k(w_bytes, n_elems),
78 QuantScheme::GgufQ4_0 => rlx_gguf::dequant_q4_0(w_bytes, n_elems),
79 QuantScheme::GgufQ8_0 => rlx_gguf::dequant_q8_0(w_bytes, n_elems),
80 QuantScheme::GgufIQ4NL => rlx_gguf::iq_dequant::dequant_iq4_nl(w_bytes, n_elems),
81 QuantScheme::GgufIQ4XS => rlx_gguf::iq_dequant::dequant_iq4_xs(w_bytes, n_elems),
82 QuantScheme::GgufIQ2XXS => rlx_gguf::iq_dequant::dequant_iq2_xxs(w_bytes, n_elems),
83 QuantScheme::GgufIQ2XS => rlx_gguf::iq_dequant::dequant_iq2_xs(w_bytes, n_elems),
84 QuantScheme::GgufIQ2S => rlx_gguf::iq_dequant::dequant_iq2_s(w_bytes, n_elems),
85 QuantScheme::GgufIQ3XXS => rlx_gguf::iq_dequant::dequant_iq3_xxs(w_bytes, n_elems),
86 QuantScheme::GgufIQ3S => rlx_gguf::iq_dequant::dequant_iq3_s(w_bytes, n_elems),
87 QuantScheme::GgufIQ1S => rlx_gguf::iq_dequant::dequant_iq1_s(w_bytes, n_elems),
88 QuantScheme::GgufIQ1M => rlx_gguf::iq_dequant::dequant_iq1_m(w_bytes, n_elems),
89 QuantScheme::GgufTQ1_0 => rlx_gguf::tq_dequant::dequant_tq1_0(w_bytes, n_elems),
90 QuantScheme::GgufTQ2_0 => rlx_gguf::tq_dequant::dequant_tq2_0(w_bytes, n_elems),
91 QuantScheme::GgufMXFP4 => rlx_gguf::mx_dequant::dequant_mxfp4(w_bytes, n_elems),
92 QuantScheme::GgufNVFP4 => rlx_gguf::mx_dequant::dequant_nvfp4(w_bytes, n_elems),
93 other => panic!("dequant_cache: unsupported GGUF scheme {other:?}"),
94 }
95 .expect("GGUF dequant failed")
96}
97
98static CACHE: OnceLock<RwLock<HashMap<DequantKey, Arc<[f32]>>>> = OnceLock::new();
99
100fn cache_enabled() -> bool {
101 !matches!(
102 rlx_ir::env::var("RLX_DEQUANT_CACHE").as_deref(),
103 Some("0") | Some("false") | Some("off")
104 )
105}
106
107pub fn gguf_weight_f32(
109 _w_off: usize,
110 w_bytes: &[u8],
111 k: usize,
112 n: usize,
113 scheme: QuantScheme,
114) -> Arc<[f32]> {
115 if !cache_enabled() {
116 return Arc::from(dequant_gguf(w_bytes, k, n, scheme).into_boxed_slice());
117 }
118 let key = DequantKey {
119 k: k as u32,
120 n: n as u32,
121 scheme: scheme_tag(scheme),
122 bytes_hash: weight_bytes_hash(w_bytes),
123 };
124 let cache = CACHE.get_or_init(|| RwLock::new(HashMap::new()));
125 if let Some(hit) = cache.read().expect("dequant cache poisoned").get(&key) {
126 return Arc::clone(hit);
127 }
128 let dense = dequant_gguf(w_bytes, k, n, scheme);
129 let arc: Arc<[f32]> = Arc::from(dense.into_boxed_slice());
130 cache
131 .write()
132 .expect("dequant cache poisoned")
133 .insert(key, Arc::clone(&arc));
134 arc
135}
136
137pub fn clear_dequant_cache() {
139 if let Some(c) = CACHE.get() {
140 c.write().expect("dequant cache poisoned").clear();
141 }
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Element count of the synthetic block used by the test below.
    const QK_K: usize = 256;

    /// Builds the packed bytes handed to the Q4_K dequantizer: two
    /// little-endian f16 `1.0` values, 12 scale bytes (the first four `0x01`,
    /// the rest zero), then `QK_K / 2` bytes of `0x77`.
    // NOTE(review): presumably this is one complete Q4_K block — confirm against rlx_gguf.
    fn synthetic_block() -> Vec<u8> {
        let one = half::f16::from_f32(1.0).to_le_bytes();
        let mut scales = [0u8; 12];
        scales[..4].fill(0x01);

        let mut bytes = Vec::new();
        bytes.extend_from_slice(&one);
        bytes.extend_from_slice(&one);
        bytes.extend_from_slice(&scales);
        bytes.resize(bytes.len() + QK_K / 2, 0x77);
        bytes
    }

    /// Identical bytes must return the same cached allocation regardless of
    /// the offset argument, while different bytes must produce a distinct one.
    #[test]
    fn gguf_dequant_cache_hits_on_second_lookup() {
        clear_dequant_cache();

        let (k, n) = (256, 1);
        let w_off = 4096;
        let packed = synthetic_block();
        let hash = weight_bytes_hash(&packed);

        let first = gguf_weight_f32(w_off, &packed, k, n, QuantScheme::GgufQ4K);
        let second = gguf_weight_f32(w_off + 999, &packed, k, n, QuantScheme::GgufQ4K);
        assert!(
            Arc::ptr_eq(&first, &second),
            "same bytes at different offsets should hit"
        );

        // Flip one bit so the content hash, and therefore the key, changes.
        let mut other = packed.clone();
        other[0] ^= 0x01;
        let third = gguf_weight_f32(w_off, &other, k, n, QuantScheme::GgufQ4K);
        assert!(
            !Arc::ptr_eq(&first, &third),
            "different bytes should miss: {hash}"
        );
    }
}