1use rlx_ir::quant::QuantScheme;
23use std::collections::HashMap;
24use std::sync::{Arc, OnceLock, RwLock};
25
/// Cache key identifying one dequantized weight.
///
/// Combines the weight's dimensions, a compact scheme tag and a hash of its raw
/// quantized bytes. The byte offset is intentionally absent, so identical bytes found
/// at different offsets map to the same entry.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
struct DequantKey {
    /// First weight dimension, stored as `u32`.
    k: u32,
    /// Second weight dimension, stored as `u32`.
    n: u32,
    /// One-byte tag identifying the quantization scheme.
    scheme: u8,
    /// 64-bit hash of the packed quantized bytes.
    bytes_hash: u64,
}
34
/// Hashes the raw quantized bytes of a weight with the standard library's
/// `DefaultHasher`.
///
/// The hash is only meaningful within one process (the algorithm is unspecified across
/// Rust releases), which is sufficient for an in-memory cache key.
fn weight_bytes_hash(w_bytes: &[u8]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{BuildHasher, BuildHasherDefault};

    BuildHasherDefault::<DefaultHasher>::default().hash_one(w_bytes)
}
41
42fn scheme_tag(scheme: QuantScheme) -> u8 {
43 match scheme {
44 QuantScheme::GgufQ4K => 1,
45 QuantScheme::GgufQ5K => 2,
46 QuantScheme::GgufQ6K => 3,
47 QuantScheme::GgufQ8K => 4,
48 QuantScheme::GgufQ4_0 => 5,
49 QuantScheme::GgufQ8_0 => 6,
50 _ => 255,
51 }
52}
53
54fn dequant_gguf(w_bytes: &[u8], k: usize, n: usize, scheme: QuantScheme) -> Vec<f32> {
55 match scheme {
56 QuantScheme::GgufQ4K => rlx_gguf::dequant_q4_k(w_bytes, k * n),
57 QuantScheme::GgufQ5K => rlx_gguf::dequant_q5_k(w_bytes, k * n),
58 QuantScheme::GgufQ6K => rlx_gguf::dequant_q6_k(w_bytes, k * n),
59 QuantScheme::GgufQ8K => rlx_gguf::dequant_q8_k(w_bytes, k * n),
60 QuantScheme::GgufQ4_0 => rlx_gguf::dequant_q4_0(w_bytes, k * n),
61 QuantScheme::GgufQ8_0 => rlx_gguf::dequant_q8_0(w_bytes, k * n),
62 other => panic!("dequant_cache: unsupported GGUF scheme {other:?}"),
63 }
64 .expect("GGUF dequant failed")
65}
66
/// Process-wide memo of dequantized GGUF weights keyed by [`DequantKey`].
///
/// Initialized lazily on the first cached lookup and guarded by an `RwLock`, so
/// concurrent readers can share hits. Values are `Arc<[f32]>`, so handing out a cached
/// buffer is a reference-count bump rather than a copy.
static CACHE: OnceLock<RwLock<HashMap<DequantKey, Arc<[f32]>>>> = OnceLock::new();
68
69fn cache_enabled() -> bool {
70 !matches!(
71 rlx_ir::env::var("RLX_DEQUANT_CACHE").as_deref(),
72 Some("0") | Some("false") | Some("off")
73 )
74}
75
76pub fn gguf_weight_f32(
78 _w_off: usize,
79 w_bytes: &[u8],
80 k: usize,
81 n: usize,
82 scheme: QuantScheme,
83) -> Arc<[f32]> {
84 if !cache_enabled() {
85 return Arc::from(dequant_gguf(w_bytes, k, n, scheme).into_boxed_slice());
86 }
87 let key = DequantKey {
88 k: k as u32,
89 n: n as u32,
90 scheme: scheme_tag(scheme),
91 bytes_hash: weight_bytes_hash(w_bytes),
92 };
93 let cache = CACHE.get_or_init(|| RwLock::new(HashMap::new()));
94 if let Some(hit) = cache.read().expect("dequant cache poisoned").get(&key) {
95 return Arc::clone(hit);
96 }
97 let dense = dequant_gguf(w_bytes, k, n, scheme);
98 let arc: Arc<[f32]> = Arc::from(dense.into_boxed_slice());
99 cache
100 .write()
101 .expect("dequant cache poisoned")
102 .insert(key, Arc::clone(&arc));
103 arc
104}
105
106pub fn clear_dequant_cache() {
108 if let Some(c) = CACHE.get() {
109 c.write().expect("dequant cache poisoned").clear();
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of values in one Q4_K super-block.
    const QK_K: usize = 256;

    /// Builds a single Q4_K super-block fixture, laid out as:
    /// - two little-endian f16 headers of 1.0;
    /// - 12 packed scale bytes, the first four set to `0x01`;
    /// - `QK_K / 2` quant bytes of `0x77`.
    fn q4k_block() -> Vec<u8> {
        let one = half::f16::from_f32(1.0).to_le_bytes();
        let mut scales = [0u8; 12];
        scales[..4].fill(0x01);
        let mut block = Vec::with_capacity(2 * one.len() + scales.len() + QK_K / 2);
        block.extend_from_slice(&one);
        block.extend_from_slice(&one);
        block.extend_from_slice(&scales);
        block.extend(std::iter::repeat_n(0x77u8, QK_K / 2));
        block
    }

    #[test]
    fn gguf_dequant_cache_hits_on_second_lookup() {
        // Pointer identity only holds when caching is on. With RLX_DEQUANT_CACHE
        // disabling it, every call returns a fresh buffer by design.
        if !cache_enabled() {
            eprintln!("skipping: RLX_DEQUANT_CACHE disables the dequant cache");
            return;
        }
        clear_dequant_cache();
        let packed = q4k_block();
        let (k, n, w_off) = (QK_K, 1, 4096);

        let a = gguf_weight_f32(w_off, &packed, k, n, QuantScheme::GgufQ4K);
        assert_eq!(a.len(), k * n);
        let b = gguf_weight_f32(w_off + 999, &packed, k, n, QuantScheme::GgufQ4K);
        assert!(
            Arc::ptr_eq(&a, &b),
            "same bytes at different offsets should hit"
        );

        let mut other = packed.clone();
        other[0] ^= 0x01;
        // Precondition for the miss below: the key must actually differ.
        assert_ne!(weight_bytes_hash(&packed), weight_bytes_hash(&other));
        let c = gguf_weight_f32(w_off, &other, k, n, QuantScheme::GgufQ4K);
        assert!(!Arc::ptr_eq(&a, &c), "different bytes should miss");
        let d = gguf_weight_f32(w_off, &other, k, n, QuantScheme::GgufQ4K);
        assert!(
            Arc::ptr_eq(&c, &d),
            "second lookup of the new bytes should hit"
        );
    }
}