1use anyhow::{Result, anyhow, bail};
46use rlx_gguf::{GgmlType, quantize};
47
/// Storage encoding used for cached K/V activations.
///
/// `F16` stores each element as a little-endian IEEE half float; the other
/// variants use the ggml block formats of the same name.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KvQuant {
    /// Unblocked half precision: 2 bytes per element.
    F16,
    /// ggml Q8_0: 32-element blocks of 34 bytes.
    Q8_0,
    /// ggml Q4_0: 32-element blocks of 18 bytes.
    Q4_0,
    /// ggml Q5_0: 32-element blocks of 22 bytes.
    Q5_0,
}
62
63impl KvQuant {
64 pub const fn block_elements(self) -> usize {
66 match self {
67 Self::F16 => 1,
68 Self::Q8_0 | Self::Q4_0 | Self::Q5_0 => 32,
69 }
70 }
71
72 pub const fn block_bytes(self) -> usize {
74 match self {
75 Self::F16 => 2,
76 Self::Q8_0 => 2 + 32,
77 Self::Q4_0 => 2 + 32 / 2,
78 Self::Q5_0 => 2 + 4 + 32 / 2,
79 }
80 }
81
82 fn ggml_type(self) -> Option<GgmlType> {
83 match self {
84 Self::F16 => None, Self::Q8_0 => Some(GgmlType::Q8_0),
86 Self::Q4_0 => Some(GgmlType::Q4_0),
87 Self::Q5_0 => Some(GgmlType::Q5_0),
88 }
89 }
90
91 pub fn bytes_for(self, n_elements: usize) -> Result<usize> {
93 let blk = self.block_elements();
94 if !n_elements.is_multiple_of(blk) {
95 bail!("{self:?}: element count {n_elements} not aligned to block size {blk}");
96 }
97 Ok((n_elements / blk) * self.block_bytes())
98 }
99}
100
/// In-memory K/V cache for one layer, stored in a [`KvQuant`] encoding.
///
/// Rows of `kv_dim` elements are encoded and appended back-to-back to `k` and
/// `v`; `past_len` counts the rows currently held.
#[derive(Debug, Clone)]
pub struct QuantizedKvLayer {
    /// Encoded K rows, oldest first (`drop_front` removes from the start).
    pub k: Vec<u8>,
    /// Encoded V rows, laid out exactly like `k`.
    pub v: Vec<u8>,
    /// Number of rows currently stored in `k` / `v`.
    pub past_len: usize,
    /// Elements per row; `new` checks it is a multiple of the scheme's block size.
    pub kv_dim: usize,
    /// Encoding used for both `k` and `v`.
    pub scheme: KvQuant,
}
114
115impl QuantizedKvLayer {
116 pub fn new(kv_dim: usize, scheme: KvQuant) -> Result<Self> {
117 let blk = scheme.block_elements();
118 if !kv_dim.is_multiple_of(blk) {
119 bail!("kv_dim ({kv_dim}) must be a multiple of {scheme:?} block size ({blk})");
120 }
121 Ok(Self {
122 k: Vec::new(),
123 v: Vec::new(),
124 past_len: 0,
125 kv_dim,
126 scheme,
127 })
128 }
129
130 pub fn append_rows(&mut self, k_rows: &[f32], v_rows: &[f32]) -> Result<()> {
134 if k_rows.len() != v_rows.len() {
135 bail!(
136 "append_rows: k len {} != v len {}",
137 k_rows.len(),
138 v_rows.len()
139 );
140 }
141 if !k_rows.len().is_multiple_of(self.kv_dim) {
142 bail!(
143 "append_rows: byte count {} not aligned to kv_dim {}",
144 k_rows.len(),
145 self.kv_dim
146 );
147 }
148 let n_rows = k_rows.len() / self.kv_dim;
149 let k_bytes = quant_rows(k_rows, self.scheme)?;
150 let v_bytes = quant_rows(v_rows, self.scheme)?;
151 self.k.extend_from_slice(&k_bytes);
152 self.v.extend_from_slice(&v_bytes);
153 self.past_len += n_rows;
154 Ok(())
155 }
156
157 pub fn read_all(&self) -> Result<(Vec<f32>, Vec<f32>)> {
159 let k = dequant_rows(&self.k, self.scheme, self.past_len * self.kv_dim)?;
160 let v = dequant_rows(&self.v, self.scheme, self.past_len * self.kv_dim)?;
161 Ok((k, v))
162 }
163
164 pub fn read_window(&self, window: usize) -> Result<(Vec<f32>, Vec<f32>)> {
166 if window >= self.past_len {
167 return self.read_all();
168 }
169 let blk = self.scheme.block_elements();
172 let blocks_per_row = self.kv_dim / blk;
173 let bytes_per_row = blocks_per_row * self.scheme.block_bytes();
174 let start_byte = (self.past_len - window) * bytes_per_row;
175 let n = window * self.kv_dim;
176 let k = dequant_rows(&self.k[start_byte..], self.scheme, n)?;
177 let v = dequant_rows(&self.v[start_byte..], self.scheme, n)?;
178 Ok((k, v))
179 }
180
181 pub fn drop_front(&mut self, n_rows: usize) -> Result<()> {
183 let n_rows = n_rows.min(self.past_len);
184 if n_rows == 0 {
185 return Ok(());
186 }
187 let blk = self.scheme.block_elements();
188 let blocks_per_row = self.kv_dim / blk;
189 let drop_bytes = n_rows * blocks_per_row * self.scheme.block_bytes();
190 self.k.drain(..drop_bytes);
191 self.v.drain(..drop_bytes);
192 self.past_len -= n_rows;
193 Ok(())
194 }
195
196 pub fn bytes(&self) -> usize {
198 self.k.len() + self.v.len()
199 }
200}
201
/// In-memory K/V cache for a whole model: one [`QuantizedKvLayer`] per layer.
#[derive(Debug, Clone)]
pub struct QuantizedKvCache {
    /// Per-layer caches, indexed by layer number.
    pub layers: Vec<QuantizedKvLayer>,
}
207
208impl QuantizedKvCache {
209 pub fn new(n_layers: usize, kv_dim: usize, scheme: KvQuant) -> Result<Self> {
210 let layers = (0..n_layers)
211 .map(|_| QuantizedKvLayer::new(kv_dim, scheme))
212 .collect::<Result<Vec<_>>>()?;
213 Ok(Self { layers })
214 }
215
216 pub fn n_layers(&self) -> usize {
217 self.layers.len()
218 }
219
220 pub fn past_len(&self) -> usize {
221 self.layers.first().map(|l| l.past_len).unwrap_or(0)
222 }
223
224 pub fn bytes(&self) -> usize {
226 self.layers.iter().map(|l| l.bytes()).sum()
227 }
228}
229
230fn quant_rows(values: &[f32], scheme: KvQuant) -> Result<Vec<u8>> {
233 match scheme {
234 KvQuant::F16 => {
235 let mut out = Vec::with_capacity(values.len() * 2);
236 for &v in values {
237 let h = half::f16::from_f32(v);
238 out.extend_from_slice(&h.to_le_bytes());
239 }
240 Ok(out)
241 }
242 scheme => {
243 let ty = scheme
244 .ggml_type()
245 .ok_or_else(|| anyhow!("internal: missing ggml type for {scheme:?}"))?;
246 Ok(quantize(values, ty)?)
247 }
248 }
249}
250
251fn dequant_rows(bytes: &[u8], scheme: KvQuant, n: usize) -> Result<Vec<f32>> {
252 match scheme {
253 KvQuant::F16 => {
254 if bytes.len() < n * 2 {
255 bail!("F16 dequant: {} bytes < {} expected", bytes.len(), n * 2);
256 }
257 let mut out = Vec::with_capacity(n);
258 for chunk in bytes[..n * 2].chunks_exact(2) {
259 let h = half::f16::from_le_bytes([chunk[0], chunk[1]]);
260 out.push(h.to_f32());
261 }
262 Ok(out)
263 }
264 KvQuant::Q8_0 => {
265 let expected = scheme.bytes_for(n)?;
266 Ok(rlx_gguf::dequant_q8_0(&bytes[..expected], n)?)
267 }
268 KvQuant::Q4_0 => {
269 let expected = scheme.bytes_for(n)?;
270 Ok(rlx_gguf::dequant_q4_0(&bytes[..expected], n)?)
271 }
272 KvQuant::Q5_0 => {
273 decode_q5_0(bytes, n)
282 }
283 }
284}
285
286fn decode_q5_0(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
287 const QK5_0: usize = 32;
288 let blk_bytes = 2 + 4 + QK5_0 / 2;
289 if !n.is_multiple_of(QK5_0) {
290 bail!("Q5_0: n={n} not divisible by {QK5_0}");
291 }
292 let nb = n / QK5_0;
293 if bytes.len() < nb * blk_bytes {
294 bail!(
295 "Q5_0: expected {} bytes, got {}",
296 nb * blk_bytes,
297 bytes.len()
298 );
299 }
300 let mut out = Vec::with_capacity(n);
301 for i in 0..nb {
302 let off = i * blk_bytes;
303 let d = half::f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
304 let qh = u32::from_le_bytes([
305 bytes[off + 2],
306 bytes[off + 3],
307 bytes[off + 4],
308 bytes[off + 5],
309 ]);
310 let qs = &bytes[off + 6..off + 6 + QK5_0 / 2];
311 for j in 0..QK5_0 / 2 {
312 let xh0 = (((qh >> j) & 1) as u8) << 4;
313 let v0 = ((qs[j] & 0x0F) | xh0) as i32 - 16;
314 out.push(d * v0 as f32);
315 }
316 for j in 0..QK5_0 / 2 {
317 let xh1 = (((qh >> (j + 16)) & 1) as u8) << 4;
318 let v1 = ((qs[j] >> 4) | xh1) as i32 - 16;
319 out.push(d * v1 as f32);
320 }
321 }
322 Ok(out)
323}
324
325#[cfg(feature = "mmap-kv")]
346pub mod mmap {
347 use super::*;
348 use memmap2::{MmapMut, MmapOptions};
349 use std::fs::OpenOptions;
350 use std::path::{Path, PathBuf};
351
352 pub struct MmapKvLayer {
355 pub mmap: MmapMut,
356 pub past_len: usize,
357 pub capacity_rows: usize,
358 pub kv_dim: usize,
359 pub scheme: KvQuant,
360 pub bytes_per_row: usize,
361 pub k_offset: usize,
362 pub v_offset: usize,
363 pub path: Option<PathBuf>,
364 }
365
366 impl MmapKvLayer {
367 pub fn open<P: AsRef<Path>>(
370 path: P,
371 kv_dim: usize,
372 scheme: KvQuant,
373 capacity_rows: usize,
374 ) -> Result<Self> {
375 let blk = scheme.block_elements();
376 if !kv_dim.is_multiple_of(blk) {
377 bail!("kv_dim ({kv_dim}) must be a multiple of {scheme:?} block size ({blk})");
378 }
379 let bytes_per_row = (kv_dim / blk) * scheme.block_bytes();
380 let total = 2 * capacity_rows * bytes_per_row;
381 let file = OpenOptions::new()
382 .read(true)
383 .write(true)
384 .create(true)
385 .truncate(true)
386 .open(&path)?;
387 file.set_len(total as u64)?;
388 let mmap = unsafe { MmapOptions::new().len(total).map_mut(&file)? };
389 Ok(Self {
390 mmap,
391 past_len: 0,
392 capacity_rows,
393 kv_dim,
394 scheme,
395 bytes_per_row,
396 k_offset: 0,
397 v_offset: capacity_rows * bytes_per_row,
398 path: Some(path.as_ref().to_path_buf()),
399 })
400 }
401
402 pub fn anonymous(kv_dim: usize, scheme: KvQuant, capacity_rows: usize) -> Result<Self> {
406 let blk = scheme.block_elements();
407 if !kv_dim.is_multiple_of(blk) {
408 bail!("kv_dim ({kv_dim}) must be a multiple of {scheme:?} block size ({blk})");
409 }
410 let bytes_per_row = (kv_dim / blk) * scheme.block_bytes();
411 let total = 2 * capacity_rows * bytes_per_row;
412 let mmap = MmapOptions::new().len(total).map_anon()?;
413 Ok(Self {
414 mmap,
415 past_len: 0,
416 capacity_rows,
417 kv_dim,
418 scheme,
419 bytes_per_row,
420 k_offset: 0,
421 v_offset: capacity_rows * bytes_per_row,
422 path: None,
423 })
424 }
425
426 pub fn append_rows(&mut self, k_rows: &[f32], v_rows: &[f32]) -> Result<()> {
429 if k_rows.len() != v_rows.len() {
430 bail!("append_rows: k/v length mismatch");
431 }
432 if !k_rows.len().is_multiple_of(self.kv_dim) {
433 bail!("append_rows: byte count not aligned to kv_dim");
434 }
435 let n_rows = k_rows.len() / self.kv_dim;
436 if self.past_len + n_rows > self.capacity_rows {
437 bail!(
438 "append_rows: would exceed capacity ({} + {} > {})",
439 self.past_len,
440 n_rows,
441 self.capacity_rows
442 );
443 }
444 let kb = quant_rows(k_rows, self.scheme)?;
445 let vb = quant_rows(v_rows, self.scheme)?;
446 let k_start = self.k_offset + self.past_len * self.bytes_per_row;
447 let v_start = self.v_offset + self.past_len * self.bytes_per_row;
448 self.mmap[k_start..k_start + kb.len()].copy_from_slice(&kb);
449 self.mmap[v_start..v_start + vb.len()].copy_from_slice(&vb);
450 self.past_len += n_rows;
451 Ok(())
452 }
453
454 pub fn read_all(&self) -> Result<(Vec<f32>, Vec<f32>)> {
457 let n = self.past_len * self.kv_dim;
458 let k_end = self.k_offset + self.past_len * self.bytes_per_row;
459 let v_end = self.v_offset + self.past_len * self.bytes_per_row;
460 let k = dequant_rows(&self.mmap[self.k_offset..k_end], self.scheme, n)?;
461 let v = dequant_rows(&self.mmap[self.v_offset..v_end], self.scheme, n)?;
462 Ok((k, v))
463 }
464
465 pub fn read_window(&self, window: usize) -> Result<(Vec<f32>, Vec<f32>)> {
467 let window = window.min(self.past_len);
468 let start_row = self.past_len - window;
469 let n = window * self.kv_dim;
470 let k_start = self.k_offset + start_row * self.bytes_per_row;
471 let v_start = self.v_offset + start_row * self.bytes_per_row;
472 let k_end = k_start + window * self.bytes_per_row;
473 let v_end = v_start + window * self.bytes_per_row;
474 let k = dequant_rows(&self.mmap[k_start..k_end], self.scheme, n)?;
475 let v = dequant_rows(&self.mmap[v_start..v_end], self.scheme, n)?;
476 Ok((k, v))
477 }
478
479 pub fn prefetch_window(&self, window: usize) {
484 let window = window.min(self.past_len);
485 if window == 0 {
486 return;
487 }
488 let start_row = self.past_len - window;
489 let k_start = self.k_offset + start_row * self.bytes_per_row;
490 let v_start = self.v_offset + start_row * self.bytes_per_row;
491 let _ = self.mmap.advise_range(
492 memmap2::Advice::WillNeed,
493 k_start,
494 window * self.bytes_per_row,
495 );
496 let _ = self.mmap.advise_range(
497 memmap2::Advice::WillNeed,
498 v_start,
499 window * self.bytes_per_row,
500 );
501 }
502
503 pub fn flush(&self) -> Result<()> {
506 self.mmap.flush()?;
507 Ok(())
508 }
509
510 pub fn bytes(&self) -> usize {
511 2 * self.past_len * self.bytes_per_row
512 }
513 }
514
    /// Memory-mapped K/V cache for a whole model: one [`MmapKvLayer`] per layer.
    pub struct MmapKvCache {
        /// Per-layer caches, indexed by layer number.
        pub layers: Vec<MmapKvLayer>,
    }
519
520 impl MmapKvCache {
521 pub fn open_dir<P: AsRef<Path>>(
523 dir: P,
524 n_layers: usize,
525 kv_dim: usize,
526 scheme: KvQuant,
527 capacity_rows: usize,
528 ) -> Result<Self> {
529 let dir = dir.as_ref();
530 std::fs::create_dir_all(dir)?;
531 let layers = (0..n_layers)
532 .map(|i| {
533 MmapKvLayer::open(
534 dir.join(format!("kv_{i}.bin")),
535 kv_dim,
536 scheme,
537 capacity_rows,
538 )
539 })
540 .collect::<Result<Vec<_>>>()?;
541 Ok(Self { layers })
542 }
543
544 pub fn anonymous(
545 n_layers: usize,
546 kv_dim: usize,
547 scheme: KvQuant,
548 capacity_rows: usize,
549 ) -> Result<Self> {
550 let layers = (0..n_layers)
551 .map(|_| MmapKvLayer::anonymous(kv_dim, scheme, capacity_rows))
552 .collect::<Result<Vec<_>>>()?;
553 Ok(Self { layers })
554 }
555
556 pub fn n_layers(&self) -> usize {
557 self.layers.len()
558 }
559
560 pub fn past_len(&self) -> usize {
561 self.layers.first().map(|l| l.past_len).unwrap_or(0)
562 }
563
564 pub fn bytes(&self) -> usize {
566 self.layers.iter().map(|l| l.bytes()).sum()
567 }
568 }
569
570 #[cfg(test)]
571 mod tests {
572 use super::*;
573
574 #[test]
575 fn anonymous_q8_0_roundtrip() {
576 let kv_dim = 64;
577 let mut layer = MmapKvLayer::anonymous(kv_dim, KvQuant::Q8_0, 4).unwrap();
578 let data: Vec<f32> = (0..kv_dim).map(|i| (i as f32).sin()).collect();
579 layer.append_rows(&data, &data).unwrap();
580 let (k, v) = layer.read_all().unwrap();
581 assert_eq!(k.len(), kv_dim);
582 assert_eq!(v.len(), kv_dim);
583 for (a, b) in k.iter().zip(data.iter()) {
585 assert!((a - b).abs() < 0.02);
586 }
587 }
588
589 #[test]
590 fn file_backed_persists_and_reopens() {
591 let dir = tempfile::tempdir().unwrap();
592 let kv_dim = 32;
593 let path = dir.path().join("layer.bin");
594 {
595 let mut layer = MmapKvLayer::open(&path, kv_dim, KvQuant::F16, 8).unwrap();
596 let data: Vec<f32> = (0..kv_dim).map(|i| i as f32 * 0.5).collect();
597 layer.append_rows(&data, &data).unwrap();
598 layer.flush().unwrap();
599 }
600 let bytes = std::fs::read(&path).unwrap();
602 assert!(!bytes.is_empty());
603 assert!(bytes.iter().any(|&b| b != 0));
604 }
605
606 #[test]
607 fn append_past_capacity_errors() {
608 let mut l = MmapKvLayer::anonymous(32, KvQuant::Q8_0, 2).unwrap();
609 let row = vec![0.5f32; 32];
610 l.append_rows(&row, &row).unwrap();
611 l.append_rows(&row, &row).unwrap();
612 assert!(l.append_rows(&row, &row).is_err());
613 }
614 }
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine similarity of two equal-length vectors; the epsilon guards a zero norm.
    fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let mut dot = 0.0f32;
        let mut na = 0.0f32;
        let mut nb = 0.0f32;
        for (x, y) in a.iter().zip(b.iter()) {
            dot += x * y;
            na += x * x;
            nb += y * y;
        }
        dot / (na.sqrt() * nb.sqrt() + 1e-12)
    }

    #[test]
    fn block_size_invariants() {
        assert_eq!(KvQuant::F16.block_bytes(), 2);
        assert_eq!(KvQuant::Q8_0.block_bytes(), 34);
        assert_eq!(KvQuant::Q4_0.block_bytes(), 18);
        assert_eq!(KvQuant::Q5_0.block_bytes(), 22);
        assert_eq!(KvQuant::Q4_0.bytes_for(64).unwrap(), 36);
        assert_eq!(KvQuant::F16.bytes_for(3).unwrap(), 6);
        assert!(KvQuant::Q8_0.bytes_for(33).is_err());
    }

    #[test]
    fn f16_roundtrip_exact() {
        let kv_dim = 64;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::F16).unwrap();
        let k_row: Vec<f32> = (0..kv_dim).map(|i| (i as f32) * 0.1).collect();
        let v_row: Vec<f32> = (0..kv_dim).map(|i| (i as f32) * 0.2).collect();
        layer.append_rows(&k_row, &v_row).unwrap();
        let (k, v) = layer.read_all().unwrap();
        for i in 0..kv_dim {
            assert!((k[i] - k_row[i]).abs() < 0.01);
            assert!((v[i] - v_row[i]).abs() < 0.01);
        }
    }

    #[test]
    fn q8_0_roundtrip_high_fidelity() {
        let kv_dim = 64;
        let n_rows = 4;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::Q8_0).unwrap();
        let total = n_rows * kv_dim;
        let k_data: Vec<f32> = (0..total).map(|i| (i as f32).sin()).collect();
        let v_data: Vec<f32> = (0..total).map(|i| (i as f32).cos()).collect();
        layer.append_rows(&k_data, &v_data).unwrap();
        assert_eq!(layer.past_len, n_rows);
        let (k, v) = layer.read_all().unwrap();
        assert!(cosine(&k, &k_data) > 0.999, "Q8_0 K cosine too low");
        assert!(cosine(&v, &v_data) > 0.999, "Q8_0 V cosine too low");
    }

    #[test]
    fn q4_0_roundtrip_lossy_but_close() {
        let kv_dim = 64;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::Q4_0).unwrap();
        let k: Vec<f32> = (0..kv_dim).map(|i| (i as f32 * 0.05).tanh()).collect();
        let v: Vec<f32> = (0..kv_dim).map(|i| (i as f32 * 0.07).tanh()).collect();
        layer.append_rows(&k, &v).unwrap();
        let (kr, vr) = layer.read_all().unwrap();
        assert!(cosine(&kr, &k) > 0.99);
        assert!(cosine(&vr, &v) > 0.99);
    }

    #[test]
    fn q5_0_roundtrip_better_than_q4() {
        let kv_dim = 64;
        let mut q4 = QuantizedKvLayer::new(kv_dim, KvQuant::Q4_0).unwrap();
        let mut q5 = QuantizedKvLayer::new(kv_dim, KvQuant::Q5_0).unwrap();
        let k: Vec<f32> = (0..kv_dim).map(|i| (i as f32 * 0.1).sin() * 3.0).collect();
        let v: Vec<f32> = (0..kv_dim).map(|i| (i as f32 * 0.13).cos() * 3.0).collect();
        q4.append_rows(&k, &v).unwrap();
        q5.append_rows(&k, &v).unwrap();
        let (k4, _) = q4.read_all().unwrap();
        let (k5, _) = q5.read_all().unwrap();
        let cos4 = cosine(&k4, &k);
        let cos5 = cosine(&k5, &k);
        assert!(cos5 >= cos4 - 1e-3, "Q5_0 should not be worse than Q4_0");
    }

    #[test]
    fn sliding_window_drops_oldest() {
        let kv_dim = 32;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::Q8_0).unwrap();
        for r in 0..5 {
            let v: Vec<f32> = (0..kv_dim).map(|i| (i + r * 100) as f32).collect();
            layer.append_rows(&v, &v).unwrap();
        }
        assert_eq!(layer.past_len, 5);
        layer.drop_front(2).unwrap();
        assert_eq!(layer.past_len, 3);
        let (k, _v) = layer.read_window(3).unwrap();
        assert_eq!(k.len(), 3 * kv_dim);
        // The oldest surviving row is original row 2 (values starting at 200).
        assert!((k[0] - 200.0).abs() < 1.0);
    }

    #[test]
    fn read_window_returns_latest_rows() {
        // F16 holds these integers exactly, so the byte-offset math can be checked exactly.
        let kv_dim = 32;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::F16).unwrap();
        for r in 0..4 {
            let row: Vec<f32> = (0..kv_dim).map(|i| (r * 100 + i) as f32).collect();
            layer.append_rows(&row, &row).unwrap();
        }
        let (k, v) = layer.read_window(2).unwrap();
        assert_eq!(k.len(), 2 * kv_dim);
        assert_eq!(v.len(), 2 * kv_dim);
        assert_eq!(k[0], 200.0);
        assert_eq!(k[kv_dim - 1], 231.0);
        assert_eq!(v[kv_dim], 300.0);
    }

    #[test]
    fn drop_front_clamps_to_past_len() {
        let kv_dim = 32;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::Q4_0).unwrap();
        let row = vec![0.25f32; kv_dim];
        layer.append_rows(&row, &row).unwrap();
        layer.drop_front(10).unwrap();
        assert_eq!(layer.past_len, 0);
        assert_eq!(layer.bytes(), 0);
    }

    #[test]
    fn kv_dim_must_align_to_block_size() {
        assert!(QuantizedKvLayer::new(24, KvQuant::Q8_0).is_err());
        assert!(QuantizedKvLayer::new(24, KvQuant::Q4_0).is_err());
        assert!(QuantizedKvLayer::new(24, KvQuant::F16).is_ok());
    }

    #[test]
    fn cache_memory_decreases_with_quantization() {
        let kv_dim = 128;
        let n_layers = 4;
        let n_rows = 16;
        let data: Vec<f32> = (0..kv_dim).map(|i| (i as f32) * 0.01).collect();
        let mut f16 = QuantizedKvCache::new(n_layers, kv_dim, KvQuant::F16).unwrap();
        let mut q8 = QuantizedKvCache::new(n_layers, kv_dim, KvQuant::Q8_0).unwrap();
        let mut q4 = QuantizedKvCache::new(n_layers, kv_dim, KvQuant::Q4_0).unwrap();
        for _ in 0..n_rows {
            for l in 0..n_layers {
                f16.layers[l].append_rows(&data, &data).unwrap();
                q8.layers[l].append_rows(&data, &data).unwrap();
                q4.layers[l].append_rows(&data, &data).unwrap();
            }
        }
        assert!(q8.bytes() < f16.bytes());
        assert!(q4.bytes() < q8.bytes());
    }
}