Skip to main content

rlx_runtime/
quantized_kv.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Block-quantized K/V cache — store decode-time history as `q8_0`,
17//! `q4_0`, or `q5_0` GGUF-encoded blocks instead of f32/f16. Memory
18//! saving vs f16 is roughly:
19//!
20//! | scheme | bits/elem | ratio vs f16 |
21//! |--------|-----------|--------------|
22//! | f16    | 16        | 1.0×         |
23//! | q8_0   | 8.5       | 0.53×        |
24//! | q5_0   | 5.5       | 0.34×        |
25//! | q4_0   | 4.5       | 0.28×        |
26//!
27//! Trade-off: quantization adds noise to attention scores. q8_0 is
28//! near-lossless for most decoder LMs; q4_0 typically costs ~0.3 ppl
29//! at 4× memory savings.
30//!
31//! Layout per layer
32//! ----------------
33//!
//! Each layer's K and V buffer is a flat `Vec<u8>` of `past_len`
//! quantized rows. Every "row" is `kv_dim` f32 elements when
//! dequantized; rows are stored back-to-back. `kv_dim` must be a
//! multiple of the scheme's block size (32 for the three quantized
//! schemes; the `f16` baseline has a block size of 1).
38//!
39//! On read, callers materialize a window of rows to f32 via
40//! [`dequant_rows`]. On write, freshly produced f32 K/V is quantized
41//! one row at a time via [`quant_rows`] before being appended. The
42//! quantization wrappers route to the `rlx_gguf::quantize` /
43//! `dequant_*` kernels for parity with on-disk GGUF blocks.
44
45use anyhow::{Result, anyhow, bail};
46use rlx_gguf::{GgmlType, quantize};
47
/// Storage scheme for cache rows.
///
/// Besides the `f16` baseline, this is restricted to the three
/// q-formats whose blocks are 32 elements wide and stable across
/// llama.cpp versions. The K-quants (Q4_K etc.) require 256-element
/// blocks, which doesn't compose cleanly with typical kv_dim values
/// (e.g. 128 head dim) so we don't expose them here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KvQuant {
    /// `f16` — plain f32→f16 narrowing with no block quantization.
    /// Values are rounded to half precision, so this is not bit-exact
    /// for arbitrary f32 input. Kept as a useful baseline; 2 bytes per
    /// element.
    F16,
    /// GGUF `q8_0`: 32-element blocks, 34 bytes per block.
    Q8_0,
    /// GGUF `q4_0`: 32-element blocks, 18 bytes per block.
    Q4_0,
    /// GGUF `q5_0`: 32-element blocks, 22 bytes per block.
    Q5_0,
}
62
63impl KvQuant {
64    /// On-disk block size in elements.
65    pub const fn block_elements(self) -> usize {
66        match self {
67            Self::F16 => 1,
68            Self::Q8_0 | Self::Q4_0 | Self::Q5_0 => 32,
69        }
70    }
71
72    /// On-disk block size in bytes.
73    pub const fn block_bytes(self) -> usize {
74        match self {
75            Self::F16 => 2,
76            Self::Q8_0 => 2 + 32,
77            Self::Q4_0 => 2 + 32 / 2,
78            Self::Q5_0 => 2 + 4 + 32 / 2,
79        }
80    }
81
82    fn ggml_type(self) -> Option<GgmlType> {
83        match self {
84            Self::F16 => None, // direct f16 path
85            Self::Q8_0 => Some(GgmlType::Q8_0),
86            Self::Q4_0 => Some(GgmlType::Q4_0),
87            Self::Q5_0 => Some(GgmlType::Q5_0),
88        }
89    }
90
91    /// Bytes required to store `n_elements` quantized.
92    pub fn bytes_for(self, n_elements: usize) -> Result<usize> {
93        let blk = self.block_elements();
94        if !n_elements.is_multiple_of(blk) {
95            bail!("{self:?}: element count {n_elements} not aligned to block size {blk}");
96        }
97        Ok((n_elements / blk) * self.block_bytes())
98    }
99}
100
/// One layer's quantized K/V buffers.
///
/// Rows are appended back-to-back. `past_len` is the number of *logical*
/// rows currently stored; the byte buffers carry `past_len × kv_dim`
/// elements' worth of quantized bytes.
///
/// NOTE(review): every field is public, so the relationship between
/// `past_len`, `kv_dim`, `scheme` and the buffer lengths is a convention
/// that external code can break by writing the fields directly.
#[derive(Debug, Clone)]
pub struct QuantizedKvLayer {
    /// Encoded K rows, oldest row first.
    pub k: Vec<u8>,
    /// Encoded V rows, oldest row first.
    pub v: Vec<u8>,
    /// Number of logical rows currently held in `k` and in `v`.
    pub past_len: usize,
    /// Elements per row once dequantized.
    pub kv_dim: usize,
    /// Encoding used for both buffers.
    pub scheme: KvQuant,
}
114
115impl QuantizedKvLayer {
116    pub fn new(kv_dim: usize, scheme: KvQuant) -> Result<Self> {
117        let blk = scheme.block_elements();
118        if !kv_dim.is_multiple_of(blk) {
119            bail!("kv_dim ({kv_dim}) must be a multiple of {scheme:?} block size ({blk})");
120        }
121        Ok(Self {
122            k: Vec::new(),
123            v: Vec::new(),
124            past_len: 0,
125            kv_dim,
126            scheme,
127        })
128    }
129
130    /// Append `rows` worth of K and V (f32, row-major) — quantizing
131    /// each row independently. Caller passes interleaved K then V
132    /// blocks via separate slices.
133    pub fn append_rows(&mut self, k_rows: &[f32], v_rows: &[f32]) -> Result<()> {
134        if k_rows.len() != v_rows.len() {
135            bail!(
136                "append_rows: k len {} != v len {}",
137                k_rows.len(),
138                v_rows.len()
139            );
140        }
141        if !k_rows.len().is_multiple_of(self.kv_dim) {
142            bail!(
143                "append_rows: byte count {} not aligned to kv_dim {}",
144                k_rows.len(),
145                self.kv_dim
146            );
147        }
148        let n_rows = k_rows.len() / self.kv_dim;
149        let k_bytes = quant_rows(k_rows, self.scheme)?;
150        let v_bytes = quant_rows(v_rows, self.scheme)?;
151        self.k.extend_from_slice(&k_bytes);
152        self.v.extend_from_slice(&v_bytes);
153        self.past_len += n_rows;
154        Ok(())
155    }
156
157    /// Dequantize all stored rows back to f32 (K, V).
158    pub fn read_all(&self) -> Result<(Vec<f32>, Vec<f32>)> {
159        let k = dequant_rows(&self.k, self.scheme, self.past_len * self.kv_dim)?;
160        let v = dequant_rows(&self.v, self.scheme, self.past_len * self.kv_dim)?;
161        Ok((k, v))
162    }
163
164    /// Dequantize the last `window` rows (or all rows if past_len ≤ window).
165    pub fn read_window(&self, window: usize) -> Result<(Vec<f32>, Vec<f32>)> {
166        if window >= self.past_len {
167            return self.read_all();
168        }
169        // Window slicing needs byte-precise offsets because rows are quantized.
170        // Each block holds N elements; rows are aligned to (kv_dim / block).
171        let blk = self.scheme.block_elements();
172        let blocks_per_row = self.kv_dim / blk;
173        let bytes_per_row = blocks_per_row * self.scheme.block_bytes();
174        let start_byte = (self.past_len - window) * bytes_per_row;
175        let n = window * self.kv_dim;
176        let k = dequant_rows(&self.k[start_byte..], self.scheme, n)?;
177        let v = dequant_rows(&self.v[start_byte..], self.scheme, n)?;
178        Ok((k, v))
179    }
180
181    /// Drop the oldest `n_rows` from this layer (sliding window).
182    pub fn drop_front(&mut self, n_rows: usize) -> Result<()> {
183        let n_rows = n_rows.min(self.past_len);
184        if n_rows == 0 {
185            return Ok(());
186        }
187        let blk = self.scheme.block_elements();
188        let blocks_per_row = self.kv_dim / blk;
189        let drop_bytes = n_rows * blocks_per_row * self.scheme.block_bytes();
190        self.k.drain(..drop_bytes);
191        self.v.drain(..drop_bytes);
192        self.past_len -= n_rows;
193        Ok(())
194    }
195
196    /// Memory used by both buffers (bytes).
197    pub fn bytes(&self) -> usize {
198        self.k.len() + self.v.len()
199    }
200}
201
/// All layers of a quantized KV cache.
#[derive(Debug, Clone)]
pub struct QuantizedKvCache {
    /// Per-layer buffers; the index in this vector is the layer index.
    pub layers: Vec<QuantizedKvLayer>,
}
207
208impl QuantizedKvCache {
209    pub fn new(n_layers: usize, kv_dim: usize, scheme: KvQuant) -> Result<Self> {
210        let layers = (0..n_layers)
211            .map(|_| QuantizedKvLayer::new(kv_dim, scheme))
212            .collect::<Result<Vec<_>>>()?;
213        Ok(Self { layers })
214    }
215
216    pub fn n_layers(&self) -> usize {
217        self.layers.len()
218    }
219
220    pub fn past_len(&self) -> usize {
221        self.layers.first().map(|l| l.past_len).unwrap_or(0)
222    }
223
224    /// Total bytes across all layers.
225    pub fn bytes(&self) -> usize {
226        self.layers.iter().map(|l| l.bytes()).sum()
227    }
228}
229
230// ─── quant / dequant entry points ────────────────────────────────────
231
232fn quant_rows(values: &[f32], scheme: KvQuant) -> Result<Vec<u8>> {
233    match scheme {
234        KvQuant::F16 => {
235            let mut out = Vec::with_capacity(values.len() * 2);
236            for &v in values {
237                let h = half::f16::from_f32(v);
238                out.extend_from_slice(&h.to_le_bytes());
239            }
240            Ok(out)
241        }
242        scheme => {
243            let ty = scheme
244                .ggml_type()
245                .ok_or_else(|| anyhow!("internal: missing ggml type for {scheme:?}"))?;
246            Ok(quantize(values, ty)?)
247        }
248    }
249}
250
251fn dequant_rows(bytes: &[u8], scheme: KvQuant, n: usize) -> Result<Vec<f32>> {
252    match scheme {
253        KvQuant::F16 => {
254            if bytes.len() < n * 2 {
255                bail!("F16 dequant: {} bytes < {} expected", bytes.len(), n * 2);
256            }
257            let mut out = Vec::with_capacity(n);
258            for chunk in bytes[..n * 2].chunks_exact(2) {
259                let h = half::f16::from_le_bytes([chunk[0], chunk[1]]);
260                out.push(h.to_f32());
261            }
262            Ok(out)
263        }
264        KvQuant::Q8_0 => {
265            let expected = scheme.bytes_for(n)?;
266            Ok(rlx_gguf::dequant_q8_0(&bytes[..expected], n)?)
267        }
268        KvQuant::Q4_0 => {
269            let expected = scheme.bytes_for(n)?;
270            Ok(rlx_gguf::dequant_q4_0(&bytes[..expected], n)?)
271        }
272        KvQuant::Q5_0 => {
273            // Q5_0 doesn't have a top-level `dequant_q5_0` export — call the
274            // private path via dequant_f32 by building a one-shot GgufFile?
275            // Cleaner: use the per-block helper exposed indirectly through
276            // bytes_for + a hand-written loop. For now we use the public
277            // `dequant_q8_0`-style path which is exposed; Q5_0 needs the
278            // same. Until the gguf crate exposes a public `dequant_q5_0`,
279            // route through the same block-by-block decoder lifted from
280            // ggml-quants.c.
281            decode_q5_0(bytes, n)
282        }
283    }
284}
285
286fn decode_q5_0(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
287    const QK5_0: usize = 32;
288    let blk_bytes = 2 + 4 + QK5_0 / 2;
289    if !n.is_multiple_of(QK5_0) {
290        bail!("Q5_0: n={n} not divisible by {QK5_0}");
291    }
292    let nb = n / QK5_0;
293    if bytes.len() < nb * blk_bytes {
294        bail!(
295            "Q5_0: expected {} bytes, got {}",
296            nb * blk_bytes,
297            bytes.len()
298        );
299    }
300    let mut out = Vec::with_capacity(n);
301    for i in 0..nb {
302        let off = i * blk_bytes;
303        let d = half::f16::from_le_bytes([bytes[off], bytes[off + 1]]).to_f32();
304        let qh = u32::from_le_bytes([
305            bytes[off + 2],
306            bytes[off + 3],
307            bytes[off + 4],
308            bytes[off + 5],
309        ]);
310        let qs = &bytes[off + 6..off + 6 + QK5_0 / 2];
311        for j in 0..QK5_0 / 2 {
312            let xh0 = (((qh >> j) & 1) as u8) << 4;
313            let v0 = ((qs[j] & 0x0F) | xh0) as i32 - 16;
314            out.push(d * v0 as f32);
315        }
316        for j in 0..QK5_0 / 2 {
317            let xh1 = (((qh >> (j + 16)) & 1) as u8) << 4;
318            let v1 = ((qs[j] >> 4) | xh1) as i32 - 16;
319            out.push(d * v1 as f32);
320        }
321    }
322    Ok(out)
323}
324
325// ─── mmap-backed storage (feature = "mmap-kv") ───────────────────────
326//
327// Memory-mapped storage trades a Vec<u8> for a file-backed (or anonymous)
328// `MmapMut` so the OS pages quantized blocks in/out on demand. Two
329// use cases:
330//
331// 1. **Long contexts** — KV history for 100k-token decode runs can
332//    exceed RAM. With mmap, the kernel evicts cold pages to swap or
333//    the backing file; reactivating them is a page fault, not a
334//    user-space read().
335//
336// 2. **Zero-copy GPU upload** — paged memory is friendlier to
337//    `cudaHostRegister` / Metal `MTLBuffer::newBufferWithBytesNoCopy`,
338//    which can pin and DMA without an extra staging copy.
339//
// Append semantics are emulated: the mapping is sized up front for
// `capacity_rows` rows of K plus the same for V (a file-backed layer
// calls `set_len` when it is opened), and `past_len` acts as the write
// head. There is no growth path, so for "grow as you go" the caller
// passes `capacity_rows = max_seq_len`.
344
345#[cfg(feature = "mmap-kv")]
346pub mod mmap {
347    use super::*;
348    use memmap2::{MmapMut, MmapOptions};
349    use std::fs::OpenOptions;
350    use std::path::{Path, PathBuf};
351
    /// Memory-mapped quantized K/V buffer (file-backed or anonymous).
    /// K and V share one mapping — V follows K. Layout:
    /// `[K bytes][V bytes]`, each region sized for `capacity_rows` rows.
    pub struct MmapKvLayer {
        /// The read/write mapping holding both regions.
        pub mmap: MmapMut,
        /// Number of rows written so far; also the write head, in rows.
        pub past_len: usize,
        /// Maximum number of rows each region can hold.
        pub capacity_rows: usize,
        /// Elements per row once dequantized.
        pub kv_dim: usize,
        /// Encoding used for both regions.
        pub scheme: KvQuant,
        /// Encoded size of one row in bytes.
        pub bytes_per_row: usize,
        /// Byte offset of the K region (the constructors set this to 0).
        pub k_offset: usize,
        /// Byte offset of the V region (the constructors set this to
        /// `capacity_rows * bytes_per_row`).
        pub v_offset: usize,
        /// Backing file path; `None` for anonymous mappings.
        pub path: Option<PathBuf>,
    }
365
366    impl MmapKvLayer {
367        /// File-backed mapping. Creates / truncates `path` to
368        /// `2 × capacity_rows × bytes_per_row` and maps it RW.
369        pub fn open<P: AsRef<Path>>(
370            path: P,
371            kv_dim: usize,
372            scheme: KvQuant,
373            capacity_rows: usize,
374        ) -> Result<Self> {
375            let blk = scheme.block_elements();
376            if !kv_dim.is_multiple_of(blk) {
377                bail!("kv_dim ({kv_dim}) must be a multiple of {scheme:?} block size ({blk})");
378            }
379            let bytes_per_row = (kv_dim / blk) * scheme.block_bytes();
380            let total = 2 * capacity_rows * bytes_per_row;
381            let file = OpenOptions::new()
382                .read(true)
383                .write(true)
384                .create(true)
385                .truncate(true)
386                .open(&path)?;
387            file.set_len(total as u64)?;
388            let mmap = unsafe { MmapOptions::new().len(total).map_mut(&file)? };
389            Ok(Self {
390                mmap,
391                past_len: 0,
392                capacity_rows,
393                kv_dim,
394                scheme,
395                bytes_per_row,
396                k_offset: 0,
397                v_offset: capacity_rows * bytes_per_row,
398                path: Some(path.as_ref().to_path_buf()),
399            })
400        }
401
402        /// Anonymous (private) mapping. Lives in swap-backed pages;
403        /// not persisted. Use when you want OS-level paging without
404        /// keeping a file around.
405        pub fn anonymous(kv_dim: usize, scheme: KvQuant, capacity_rows: usize) -> Result<Self> {
406            let blk = scheme.block_elements();
407            if !kv_dim.is_multiple_of(blk) {
408                bail!("kv_dim ({kv_dim}) must be a multiple of {scheme:?} block size ({blk})");
409            }
410            let bytes_per_row = (kv_dim / blk) * scheme.block_bytes();
411            let total = 2 * capacity_rows * bytes_per_row;
412            let mmap = MmapOptions::new().len(total).map_anon()?;
413            Ok(Self {
414                mmap,
415                past_len: 0,
416                capacity_rows,
417                kv_dim,
418                scheme,
419                bytes_per_row,
420                k_offset: 0,
421                v_offset: capacity_rows * bytes_per_row,
422                path: None,
423            })
424        }
425
426        /// Append `n_rows` × `kv_dim` worth of K and V floats (one row
427        /// at a time, quantized in place).
428        pub fn append_rows(&mut self, k_rows: &[f32], v_rows: &[f32]) -> Result<()> {
429            if k_rows.len() != v_rows.len() {
430                bail!("append_rows: k/v length mismatch");
431            }
432            if !k_rows.len().is_multiple_of(self.kv_dim) {
433                bail!("append_rows: byte count not aligned to kv_dim");
434            }
435            let n_rows = k_rows.len() / self.kv_dim;
436            if self.past_len + n_rows > self.capacity_rows {
437                bail!(
438                    "append_rows: would exceed capacity ({} + {} > {})",
439                    self.past_len,
440                    n_rows,
441                    self.capacity_rows
442                );
443            }
444            let kb = quant_rows(k_rows, self.scheme)?;
445            let vb = quant_rows(v_rows, self.scheme)?;
446            let k_start = self.k_offset + self.past_len * self.bytes_per_row;
447            let v_start = self.v_offset + self.past_len * self.bytes_per_row;
448            self.mmap[k_start..k_start + kb.len()].copy_from_slice(&kb);
449            self.mmap[v_start..v_start + vb.len()].copy_from_slice(&vb);
450            self.past_len += n_rows;
451            Ok(())
452        }
453
454        /// Dequantize all stored rows. Reads zero-copy from the page
455        /// cache — only touched pages are faulted in.
456        pub fn read_all(&self) -> Result<(Vec<f32>, Vec<f32>)> {
457            let n = self.past_len * self.kv_dim;
458            let k_end = self.k_offset + self.past_len * self.bytes_per_row;
459            let v_end = self.v_offset + self.past_len * self.bytes_per_row;
460            let k = dequant_rows(&self.mmap[self.k_offset..k_end], self.scheme, n)?;
461            let v = dequant_rows(&self.mmap[self.v_offset..v_end], self.scheme, n)?;
462            Ok((k, v))
463        }
464
465        /// Read the last `window` rows (page-fault-lazy on dequant).
466        pub fn read_window(&self, window: usize) -> Result<(Vec<f32>, Vec<f32>)> {
467            let window = window.min(self.past_len);
468            let start_row = self.past_len - window;
469            let n = window * self.kv_dim;
470            let k_start = self.k_offset + start_row * self.bytes_per_row;
471            let v_start = self.v_offset + start_row * self.bytes_per_row;
472            let k_end = k_start + window * self.bytes_per_row;
473            let v_end = v_start + window * self.bytes_per_row;
474            let k = dequant_rows(&self.mmap[k_start..k_end], self.scheme, n)?;
475            let v = dequant_rows(&self.mmap[v_start..v_end], self.scheme, n)?;
476            Ok((k, v))
477        }
478
479        /// Hint the kernel that we're about to read `window` rows
480        /// linearly — prefetches pages into the page cache (madvise
481        /// WILLNEED on supported platforms). Best-effort: failures are
482        /// logged but don't propagate.
483        pub fn prefetch_window(&self, window: usize) {
484            let window = window.min(self.past_len);
485            if window == 0 {
486                return;
487            }
488            let start_row = self.past_len - window;
489            let k_start = self.k_offset + start_row * self.bytes_per_row;
490            let v_start = self.v_offset + start_row * self.bytes_per_row;
491            let _ = self.mmap.advise_range(
492                memmap2::Advice::WillNeed,
493                k_start,
494                window * self.bytes_per_row,
495            );
496            let _ = self.mmap.advise_range(
497                memmap2::Advice::WillNeed,
498                v_start,
499                window * self.bytes_per_row,
500            );
501        }
502
503        /// Persist any dirty pages to the backing file. No-op for
504        /// anonymous mappings.
505        pub fn flush(&self) -> Result<()> {
506            self.mmap.flush()?;
507            Ok(())
508        }
509
510        pub fn bytes(&self) -> usize {
511            2 * self.past_len * self.bytes_per_row
512        }
513    }
514
    /// All-layer mmap-backed KV cache.
    pub struct MmapKvCache {
        /// Per-layer mappings; the index in this vector is the layer index.
        pub layers: Vec<MmapKvLayer>,
    }
519
520    impl MmapKvCache {
521        /// One file per layer under `dir`, named `kv_{i}.bin`.
522        pub fn open_dir<P: AsRef<Path>>(
523            dir: P,
524            n_layers: usize,
525            kv_dim: usize,
526            scheme: KvQuant,
527            capacity_rows: usize,
528        ) -> Result<Self> {
529            let dir = dir.as_ref();
530            std::fs::create_dir_all(dir)?;
531            let layers = (0..n_layers)
532                .map(|i| {
533                    MmapKvLayer::open(
534                        dir.join(format!("kv_{i}.bin")),
535                        kv_dim,
536                        scheme,
537                        capacity_rows,
538                    )
539                })
540                .collect::<Result<Vec<_>>>()?;
541            Ok(Self { layers })
542        }
543
544        pub fn anonymous(
545            n_layers: usize,
546            kv_dim: usize,
547            scheme: KvQuant,
548            capacity_rows: usize,
549        ) -> Result<Self> {
550            let layers = (0..n_layers)
551                .map(|_| MmapKvLayer::anonymous(kv_dim, scheme, capacity_rows))
552                .collect::<Result<Vec<_>>>()?;
553            Ok(Self { layers })
554        }
555
556        pub fn n_layers(&self) -> usize {
557            self.layers.len()
558        }
559
560        pub fn past_len(&self) -> usize {
561            self.layers.first().map(|l| l.past_len).unwrap_or(0)
562        }
563
564        /// Total bytes currently in use across all layers.
565        pub fn bytes(&self) -> usize {
566            self.layers.iter().map(|l| l.bytes()).sum()
567        }
568    }
569
570    #[cfg(test)]
571    mod tests {
572        use super::*;
573
574        #[test]
575        fn anonymous_q8_0_roundtrip() {
576            let kv_dim = 64;
577            let mut layer = MmapKvLayer::anonymous(kv_dim, KvQuant::Q8_0, 4).unwrap();
578            let data: Vec<f32> = (0..kv_dim).map(|i| (i as f32).sin()).collect();
579            layer.append_rows(&data, &data).unwrap();
580            let (k, v) = layer.read_all().unwrap();
581            assert_eq!(k.len(), kv_dim);
582            assert_eq!(v.len(), kv_dim);
583            // Q8_0 is high-fidelity.
584            for (a, b) in k.iter().zip(data.iter()) {
585                assert!((a - b).abs() < 0.02);
586            }
587        }
588
589        #[test]
590        fn file_backed_persists_and_reopens() {
591            let dir = tempfile::tempdir().unwrap();
592            let kv_dim = 32;
593            let path = dir.path().join("layer.bin");
594            {
595                let mut layer = MmapKvLayer::open(&path, kv_dim, KvQuant::F16, 8).unwrap();
596                let data: Vec<f32> = (0..kv_dim).map(|i| i as f32 * 0.5).collect();
597                layer.append_rows(&data, &data).unwrap();
598                layer.flush().unwrap();
599            }
600            // Re-open and verify K bytes are present (we know offset 0..len).
601            let bytes = std::fs::read(&path).unwrap();
602            assert!(!bytes.is_empty());
603            assert!(bytes.iter().any(|&b| b != 0));
604        }
605
606        #[test]
607        fn append_past_capacity_errors() {
608            let mut l = MmapKvLayer::anonymous(32, KvQuant::Q8_0, 2).unwrap();
609            let row = vec![0.5f32; 32];
610            l.append_rows(&row, &row).unwrap();
611            l.append_rows(&row, &row).unwrap();
612            assert!(l.append_rows(&row, &row).is_err());
613        }
614    }
615}
616
#[cfg(test)]
mod tests {
    use super::*;

    /// Cosine similarity of two equally long vectors; the small epsilon
    /// keeps the division finite for all-zero input.
    fn cosine(a: &[f32], b: &[f32]) -> f32 {
        let mut dot = 0.0f32;
        let mut na = 0.0f32;
        let mut nb = 0.0f32;
        for (x, y) in a.iter().zip(b.iter()) {
            dot += x * y;
            na += x * x;
            nb += y * y;
        }
        dot / (na.sqrt() * nb.sqrt() + 1e-12)
    }

    #[test]
    fn block_size_invariants() {
        assert_eq!(KvQuant::F16.block_bytes(), 2);
        assert_eq!(KvQuant::Q8_0.block_bytes(), 34);
        assert_eq!(KvQuant::Q4_0.block_bytes(), 18);
        assert_eq!(KvQuant::Q5_0.block_bytes(), 22);
    }

    #[test]
    fn f16_roundtrip_exact() {
        let kv_dim = 64;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::F16).unwrap();
        let k_row: Vec<f32> = (0..kv_dim).map(|i| (i as f32) * 0.1).collect();
        let v_row: Vec<f32> = (0..kv_dim).map(|i| (i as f32) * 0.2).collect();
        layer.append_rows(&k_row, &v_row).unwrap();
        let (k, v) = layer.read_all().unwrap();
        for i in 0..kv_dim {
            // f16 round-trip is bounded ~1e-3 relative for small magnitudes.
            assert!((k[i] - k_row[i]).abs() < 0.01);
            assert!((v[i] - v_row[i]).abs() < 0.01);
        }
    }

    #[test]
    fn q8_0_roundtrip_high_fidelity() {
        let kv_dim = 64;
        let n_rows = 4;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::Q8_0).unwrap();
        let total = n_rows * kv_dim;
        let k_data: Vec<f32> = (0..total).map(|i| (i as f32).sin()).collect();
        let v_data: Vec<f32> = (0..total).map(|i| (i as f32).cos()).collect();
        layer.append_rows(&k_data, &v_data).unwrap();
        assert_eq!(layer.past_len, n_rows);
        let (k, v) = layer.read_all().unwrap();
        assert!(cosine(&k, &k_data) > 0.999, "Q8_0 K cosine too low");
        assert!(cosine(&v, &v_data) > 0.999, "Q8_0 V cosine too low");
    }

    #[test]
    fn q4_0_roundtrip_lossy_but_close() {
        let kv_dim = 64;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::Q4_0).unwrap();
        let k: Vec<f32> = (0..kv_dim).map(|i| (i as f32 * 0.05).tanh()).collect();
        let v: Vec<f32> = (0..kv_dim).map(|i| (i as f32 * 0.07).tanh()).collect();
        layer.append_rows(&k, &v).unwrap();
        let (kr, vr) = layer.read_all().unwrap();
        assert!(cosine(&kr, &k) > 0.99);
        assert!(cosine(&vr, &v) > 0.99);
    }

    #[test]
    fn q5_0_roundtrip_better_than_q4() {
        let kv_dim = 64;
        let mut q4 = QuantizedKvLayer::new(kv_dim, KvQuant::Q4_0).unwrap();
        let mut q5 = QuantizedKvLayer::new(kv_dim, KvQuant::Q5_0).unwrap();
        let k: Vec<f32> = (0..kv_dim).map(|i| (i as f32 * 0.1).sin() * 3.0).collect();
        let v: Vec<f32> = (0..kv_dim).map(|i| (i as f32 * 0.13).cos() * 3.0).collect();
        q4.append_rows(&k, &v).unwrap();
        q5.append_rows(&k, &v).unwrap();
        let (k4, _) = q4.read_all().unwrap();
        let (k5, _) = q5.read_all().unwrap();
        let cos4 = cosine(&k4, &k);
        let cos5 = cosine(&k5, &k);
        assert!(cos5 >= cos4 - 1e-3, "Q5_0 should not be worse than Q4_0");
    }

    #[test]
    fn sliding_window_drops_oldest() {
        let kv_dim = 32;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::Q8_0).unwrap();
        for r in 0..5 {
            let v: Vec<f32> = (0..kv_dim).map(|i| (i + r * 100) as f32).collect();
            layer.append_rows(&v, &v).unwrap();
        }
        assert_eq!(layer.past_len, 5);
        layer.drop_front(2).unwrap();
        assert_eq!(layer.past_len, 3);
        let (k, _v) = layer.read_window(3).unwrap();
        // First kept row is original row 2 → starts with value 200.
        assert!((k[0] - 200.0).abs() < 1.0);
        // A window smaller than past_len takes the byte-offset path:
        // the two newest rows are original rows 3 and 4.
        let (k, _v) = layer.read_window(2).unwrap();
        assert_eq!(k.len(), 2 * kv_dim);
        assert!((k[0] - 300.0).abs() < 3.0);
        assert!((k[kv_dim] - 400.0).abs() < 3.0);
    }

    #[test]
    fn read_window_returns_newest_rows() {
        // Integers below 2048 are exactly representable in f16, so the
        // F16 scheme allows exact comparisons on windowed reads.
        let kv_dim = 32;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::F16).unwrap();
        for r in 0..5 {
            let k: Vec<f32> = (0..kv_dim).map(|i| (i + r * 100) as f32).collect();
            let v: Vec<f32> = k.iter().map(|x| x + 1000.0).collect();
            layer.append_rows(&k, &v).unwrap();
        }
        let (k, v) = layer.read_window(2).unwrap();
        assert_eq!(k.len(), 2 * kv_dim);
        assert_eq!(v.len(), 2 * kv_dim);
        assert_eq!(k[0], 300.0);
        assert_eq!(k[kv_dim], 400.0);
        assert_eq!(v[0], 1300.0);
        assert_eq!(v[2 * kv_dim - 1], 1431.0);
        // A window at least as large as past_len returns everything.
        let (k_all, _) = layer.read_window(99).unwrap();
        assert_eq!(k_all.len(), 5 * kv_dim);
        assert_eq!(k_all[0], 0.0);
    }

    #[test]
    fn append_rows_rejects_bad_shapes() {
        let kv_dim = 32;
        let mut layer = QuantizedKvLayer::new(kv_dim, KvQuant::Q8_0).unwrap();
        let full = vec![0.25f32; kv_dim];
        let short = vec![0.25f32; kv_dim - 1];
        // K and V lengths differ.
        assert!(layer.append_rows(&full, &short).is_err());
        // Equal lengths that are not a whole number of rows.
        assert!(layer.append_rows(&short, &short).is_err());
        // Rejected appends leave the layer untouched.
        assert_eq!(layer.past_len, 0);
        assert_eq!(layer.bytes(), 0);
    }

    #[test]
    fn kv_dim_must_align_to_block_size() {
        // kv_dim=24 < 32 → not aligned for Q8_0/Q4_0/Q5_0.
        assert!(QuantizedKvLayer::new(24, KvQuant::Q8_0).is_err());
        assert!(QuantizedKvLayer::new(24, KvQuant::Q4_0).is_err());
        assert!(QuantizedKvLayer::new(24, KvQuant::Q5_0).is_err());
        // f16 has block 1 so any dim works.
        assert!(QuantizedKvLayer::new(24, KvQuant::F16).is_ok());
    }

    #[test]
    fn cache_memory_decreases_with_quantization() {
        let kv_dim = 128;
        let n_layers = 4;
        let n_rows = 16;
        let data: Vec<f32> = (0..kv_dim).map(|i| (i as f32) * 0.01).collect();
        let mut f16 = QuantizedKvCache::new(n_layers, kv_dim, KvQuant::F16).unwrap();
        let mut q8 = QuantizedKvCache::new(n_layers, kv_dim, KvQuant::Q8_0).unwrap();
        let mut q4 = QuantizedKvCache::new(n_layers, kv_dim, KvQuant::Q4_0).unwrap();
        for _ in 0..n_rows {
            for l in 0..n_layers {
                f16.layers[l].append_rows(&data, &data).unwrap();
                q8.layers[l].append_rows(&data, &data).unwrap();
                q4.layers[l].append_rows(&data, &data).unwrap();
            }
        }
        assert!(q8.bytes() < f16.bytes());
        assert!(q4.bytes() < q8.bytes());
    }
}