Skip to main content

oxibonsai_runtime/
kv_cache_policy.rs

//! KV cache compression policy controller.
//!
//! Adapts the KV cache precision based on cache pressure: as more sequences
//! accumulate, the cache transitions FP16 → INT8 (Q8) → INT4 (Q4) so the
//! same memory budget can accommodate longer contexts and more in-flight
//! requests.
//!
//! ## Design
//!
//! [`KvCachePolicy`] tracks an exponentially-weighted moving average of cache
//! occupancy. Crossing one of the configured thresholds upgrades the level;
//! falling below the threshold *minus* a hysteresis margin downgrades it,
//! preventing oscillation around boundaries.
//!
//! ## Levels
//!
//! | Level | Memory factor | Quality |
//! |-------|---------------|---------|
//! | `Fp16` | 1.0× | exact |
//! | `Q8`   | 0.5× | ~0.1% RMSE vs FP16 |
//! | `Fp8`  | 0.5× | not selected by the pressure thresholds; reachable only through `min_level` / `max_level` clamping |
//! | `Q4`   | 0.25× | ~1% RMSE vs FP16 |
//!
//! ## Usage
//!
//! ```
//! use oxibonsai_runtime::kv_cache_policy::{KvCachePolicy, KvCacheLevel};
//!
//! let policy = KvCachePolicy::default();
//! // 60 % pressure → still FP16 by default
//! assert_eq!(policy.observe(0.60), KvCacheLevel::Fp16);
//! // Sustained ~92 % pressure → upgrades to Q8 (above 0.80, below 0.95)
//! for _ in 0..20 {
//!     policy.observe(0.92);
//! }
//! assert_eq!(policy.current_level(), KvCacheLevel::Q8);
//! ```
37
38use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
39
40// ─── Levels ────────────────────────────────────────────────────────────────
41
/// Precision tier used for KV cache storage.
///
/// Variants are declared from highest precision / largest footprint down to
/// lowest precision / smallest footprint.
///
/// Compactness ranking, as reported by [`KvCacheLevel::ordinal`]:
/// `Fp16 (0) < Q8 (1) < Fp8 (2) < Q4 (3)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum KvCacheLevel {
    /// FP16 — full quality, the memory baseline (factor 1.0).
    Fp16,
    /// INT8 quantized — half the memory of FP16.
    Q8,
    /// FP8 quantized — half the memory of FP16 (same byte width as INT8),
    /// with a floating-point distribution that preserves more dynamic range
    /// for attention activations.
    Fp8,
    /// INT4 quantized — a quarter of the memory of FP16.
    Q4,
}

impl KvCacheLevel {
    /// Every tier, positioned at the index equal to its [`ordinal`](Self::ordinal).
    const BY_ORDINAL: [Self; 4] = [Self::Fp16, Self::Q8, Self::Fp8, Self::Q4];

    /// Memory footprint of this tier as a fraction of the FP16 footprint.
    ///
    /// | Level | Factor |
    /// |-------|--------|
    /// | Fp16  | 1.0    |
    /// | Q8    | 0.5    |
    /// | Fp8   | 0.5    |
    /// | Q4    | 0.25   |
    pub const fn memory_factor(self) -> f32 {
        match self {
            Self::Fp16 => 1.0,
            // Both one-byte formats cost the same.
            Self::Q8 | Self::Fp8 => 0.5,
            Self::Q4 => 0.25,
        }
    }

    /// Compactness rank: a larger value means a more compact (more
    /// aggressive) tier.
    ///
    /// Ranking: `Fp16=0 < Q8=1 < Fp8=2 < Q4=3`. `Fp8` is ranked between `Q8`
    /// and `Q4`: it uses one byte per value like `Q8`, but is treated as the
    /// intermediate step before INT4.
    pub const fn ordinal(self) -> u8 {
        match self {
            Self::Q4 => 3,
            Self::Fp8 => 2,
            Self::Q8 => 1,
            Self::Fp16 => 0,
        }
    }

    /// Short lowercase label for this tier.
    pub const fn tag(self) -> &'static str {
        match self {
            Self::Q4 => "q4",
            Self::Fp8 => "fp8",
            Self::Q8 => "q8",
            Self::Fp16 => "fp16",
        }
    }

    /// Inverse of [`ordinal`](Self::ordinal).
    ///
    /// Any value past the last known rank maps to `Q4`, the most compact tier.
    fn from_ordinal(o: u8) -> Self {
        Self::BY_ORDINAL
            .get(usize::from(o))
            .copied()
            .unwrap_or(Self::Q4)
    }
}
114
115// ─── Configuration ─────────────────────────────────────────────────────────
116
117/// Configuration for [`KvCachePolicy`].
118///
119/// Default values: upgrade to Q8 above 80 % cache occupancy, upgrade to Q4
120/// above 95 %; hysteresis margin 5 %; EWMA factor 0.20.
121#[derive(Debug, Clone)]
122pub struct KvCachePolicyConfig {
123    /// Cache occupancy threshold (0.0..=1.0) above which we upgrade FP16 → Q8.
124    pub q8_threshold: f32,
125    /// Cache occupancy threshold above which we upgrade Q8 → Q4.
126    pub q4_threshold: f32,
127    /// Symmetric hysteresis margin: a downgrade fires only after pressure
128    /// drops below `threshold - hysteresis`.
129    pub hysteresis: f32,
130    /// EWMA smoothing factor (`alpha` in `s_t = alpha * x_t + (1-alpha)*s_{t-1}`).
131    /// Higher = more reactive, lower = more stable.
132    pub ewma_alpha: f32,
133    /// Initial / minimum tier — set to `Fp16` to allow downgrade.
134    pub min_level: KvCacheLevel,
135    /// Maximum tier — set to `Q4` to allow full compression range.
136    pub max_level: KvCacheLevel,
137}
138
139impl Default for KvCachePolicyConfig {
140    fn default() -> Self {
141        Self {
142            q8_threshold: 0.80,
143            q4_threshold: 0.95,
144            hysteresis: 0.05,
145            ewma_alpha: 0.20,
146            min_level: KvCacheLevel::Fp16,
147            max_level: KvCacheLevel::Q4,
148        }
149    }
150}
151
152impl KvCachePolicyConfig {
153    /// Conservative profile — never upgrades from FP16.
154    pub fn fp16_only() -> Self {
155        Self {
156            min_level: KvCacheLevel::Fp16,
157            max_level: KvCacheLevel::Fp16,
158            ..Self::default()
159        }
160    }
161
162    /// Aggressive profile — starts at Q8 and reaches Q4 sooner.
163    pub fn aggressive() -> Self {
164        Self {
165            q8_threshold: 0.50,
166            q4_threshold: 0.80,
167            hysteresis: 0.05,
168            ewma_alpha: 0.30,
169            min_level: KvCacheLevel::Q8,
170            max_level: KvCacheLevel::Q4,
171        }
172    }
173
174    fn validate(&self) -> Result<(), KvCachePolicyError> {
175        if !(0.0..=1.0).contains(&self.q8_threshold) {
176            return Err(KvCachePolicyError::InvalidConfig(
177                "q8_threshold must be in [0.0, 1.0]",
178            ));
179        }
180        if !(0.0..=1.0).contains(&self.q4_threshold) {
181            return Err(KvCachePolicyError::InvalidConfig(
182                "q4_threshold must be in [0.0, 1.0]",
183            ));
184        }
185        if self.q4_threshold < self.q8_threshold {
186            return Err(KvCachePolicyError::InvalidConfig(
187                "q4_threshold must be >= q8_threshold",
188            ));
189        }
190        if !(0.0..=1.0).contains(&self.hysteresis) {
191            return Err(KvCachePolicyError::InvalidConfig(
192                "hysteresis must be in [0.0, 1.0]",
193            ));
194        }
195        if !(0.0..=1.0).contains(&self.ewma_alpha) {
196            return Err(KvCachePolicyError::InvalidConfig(
197                "ewma_alpha must be in [0.0, 1.0]",
198            ));
199        }
200        if self.min_level.ordinal() > self.max_level.ordinal() {
201            return Err(KvCachePolicyError::InvalidConfig(
202                "min_level must be <= max_level (less compact)",
203            ));
204        }
205        Ok(())
206    }
207}
208
209/// Errors raised by [`KvCachePolicy`].
210#[derive(Debug, thiserror::Error)]
211#[non_exhaustive]
212pub enum KvCachePolicyError {
213    #[error("invalid kv-cache policy configuration: {0}")]
214    InvalidConfig(&'static str),
215}
216
217// ─── Policy controller ─────────────────────────────────────────────────────
218
219/// Stateful KV-cache compression policy.
220///
221/// Thread-safe: the current level is stored in an [`AtomicU8`] so concurrent
222/// observers can read without locking. The pressure EWMA is also stored
223/// atomically (as `u64`-encoded `f64` bits).
224#[derive(Debug)]
225pub struct KvCachePolicy {
226    config: KvCachePolicyConfig,
227    /// Current level encoded as `u8` for atomic load/store.
228    level: AtomicU8,
229    /// EWMA of observed pressure, `f64` bits stored as `u64`.
230    pressure_ewma: AtomicU64,
231    /// Number of observations since construction (also acts as warmup gate).
232    samples: AtomicU64,
233    /// Total upgrades fired (for telemetry).
234    upgrades: AtomicU64,
235    /// Total downgrades fired (for telemetry).
236    downgrades: AtomicU64,
237}
238
239impl Default for KvCachePolicy {
240    fn default() -> Self {
241        Self::new(KvCachePolicyConfig::default()).expect("default config is valid")
242    }
243}
244
245impl KvCachePolicy {
246    /// Construct a new policy.
247    ///
248    /// Returns an error if the config is invalid (out-of-range thresholds,
249    /// inverted hysteresis, or `min_level > max_level`).
250    pub fn new(config: KvCachePolicyConfig) -> Result<Self, KvCachePolicyError> {
251        config.validate()?;
252        Ok(Self {
253            level: AtomicU8::new(config.min_level.ordinal()),
254            pressure_ewma: AtomicU64::new(0u64),
255            samples: AtomicU64::new(0),
256            upgrades: AtomicU64::new(0),
257            downgrades: AtomicU64::new(0),
258            config,
259        })
260    }
261
262    /// Read the current level.
263    pub fn current_level(&self) -> KvCacheLevel {
264        KvCacheLevel::from_ordinal(self.level.load(Ordering::Relaxed))
265    }
266
267    /// Read the smoothed pressure (EWMA).
268    pub fn pressure(&self) -> f64 {
269        f64::from_bits(self.pressure_ewma.load(Ordering::Relaxed))
270    }
271
272    /// Number of observations recorded so far.
273    pub fn samples(&self) -> u64 {
274        self.samples.load(Ordering::Relaxed)
275    }
276
277    /// Number of upgrades fired since construction.
278    pub fn upgrades(&self) -> u64 {
279        self.upgrades.load(Ordering::Relaxed)
280    }
281
282    /// Number of downgrades fired since construction.
283    pub fn downgrades(&self) -> u64 {
284        self.downgrades.load(Ordering::Relaxed)
285    }
286
287    /// Record a new pressure observation and return the (possibly updated)
288    /// active level.
289    ///
290    /// `pressure` is expected in `[0.0, 1.0]`; values are clamped to that
291    /// range before being fed into the EWMA.
292    pub fn observe(&self, pressure: f64) -> KvCacheLevel {
293        let p = pressure.clamp(0.0, 1.0);
294
295        // Update EWMA (CAS loop on the f64-as-u64 bits).
296        let alpha = self.config.ewma_alpha as f64;
297        let one_minus_alpha = 1.0 - alpha;
298        loop {
299            let current_bits = self.pressure_ewma.load(Ordering::Relaxed);
300            let current = f64::from_bits(current_bits);
301            let n = self.samples.load(Ordering::Relaxed);
302            let new_val = if n == 0 {
303                p
304            } else {
305                alpha * p + one_minus_alpha * current
306            };
307            if self
308                .pressure_ewma
309                .compare_exchange_weak(
310                    current_bits,
311                    new_val.to_bits(),
312                    Ordering::Relaxed,
313                    Ordering::Relaxed,
314                )
315                .is_ok()
316            {
317                break;
318            }
319        }
320        self.samples.fetch_add(1, Ordering::Relaxed);
321
322        // Decide tier from smoothed pressure.
323        let smoothed = self.pressure();
324        let current = self.current_level();
325        let target = self.target_level(smoothed, current);
326
327        if target != current {
328            self.level.store(target.ordinal(), Ordering::Relaxed);
329            if target.ordinal() > current.ordinal() {
330                self.upgrades.fetch_add(1, Ordering::Relaxed);
331            } else {
332                self.downgrades.fetch_add(1, Ordering::Relaxed);
333            }
334        }
335        target
336    }
337
338    /// Decide the target tier given smoothed pressure and the current tier.
339    ///
340    /// Pure function — no side effects, useful for tests.
341    fn target_level(&self, smoothed: f64, current: KvCacheLevel) -> KvCacheLevel {
342        let q8 = self.config.q8_threshold as f64;
343        let q4 = self.config.q4_threshold as f64;
344        let h = self.config.hysteresis as f64;
345
346        let raw = if smoothed >= q4 {
347            KvCacheLevel::Q4
348        } else if smoothed >= q8 {
349            KvCacheLevel::Q8
350        } else {
351            KvCacheLevel::Fp16
352        };
353
354        // Apply hysteresis: only allow downgrade if pressure has dropped
355        // below the *previous* tier's threshold by at least `h`.
356        let target = match (current, raw) {
357            (KvCacheLevel::Q4, KvCacheLevel::Q8) | (KvCacheLevel::Q4, KvCacheLevel::Fp16) => {
358                if smoothed < q4 - h {
359                    raw
360                } else {
361                    KvCacheLevel::Q4
362                }
363            }
364            (KvCacheLevel::Q8, KvCacheLevel::Fp16) => {
365                if smoothed < q8 - h {
366                    KvCacheLevel::Fp16
367                } else {
368                    KvCacheLevel::Q8
369                }
370            }
371            _ => raw,
372        };
373
374        // Clamp to [min, max].
375        let min_o = self.config.min_level.ordinal();
376        let max_o = self.config.max_level.ordinal();
377        let clamped = target.ordinal().clamp(min_o, max_o);
378        KvCacheLevel::from_ordinal(clamped)
379    }
380
381    /// Reset the EWMA, sample counter, and tier to the configured minimum.
382    /// Counters for upgrades/downgrades are also reset.
383    pub fn reset(&self) {
384        self.pressure_ewma.store(0u64, Ordering::Relaxed);
385        self.samples.store(0, Ordering::Relaxed);
386        self.upgrades.store(0, Ordering::Relaxed);
387        self.downgrades.store(0, Ordering::Relaxed);
388        self.level
389            .store(self.config.min_level.ordinal(), Ordering::Relaxed);
390    }
391
392    /// Return the configuration this policy was built with.
393    pub fn config(&self) -> &KvCachePolicyConfig {
394        &self.config
395    }
396}
397
398// ─── Tests ─────────────────────────────────────────────────────────────────
399
#[cfg(test)]
mod tests {
    //! Unit tests for the level enum, config validation, and the policy's
    //! upgrade / downgrade / hysteresis behaviour.

    use super::*;

    #[test]
    fn level_memory_factor() {
        assert!((KvCacheLevel::Fp16.memory_factor() - 1.0).abs() < f32::EPSILON);
        assert!((KvCacheLevel::Q8.memory_factor() - 0.5).abs() < f32::EPSILON);
        assert!((KvCacheLevel::Fp8.memory_factor() - 0.5).abs() < f32::EPSILON);
        assert!((KvCacheLevel::Q4.memory_factor() - 0.25).abs() < f32::EPSILON);
    }

    #[test]
    fn level_ordinal_monotonic() {
        assert!(KvCacheLevel::Fp16.ordinal() < KvCacheLevel::Q8.ordinal());
        assert!(KvCacheLevel::Q8.ordinal() < KvCacheLevel::Fp8.ordinal());
        assert!(KvCacheLevel::Fp8.ordinal() < KvCacheLevel::Q4.ordinal());
    }

    #[test]
    fn level_ordinal_round_trips() {
        for level in [
            KvCacheLevel::Fp16,
            KvCacheLevel::Q8,
            KvCacheLevel::Fp8,
            KvCacheLevel::Q4,
        ] {
            assert_eq!(KvCacheLevel::from_ordinal(level.ordinal()), level);
        }
        // Unknown ordinals fall back to the most compact tier.
        assert_eq!(KvCacheLevel::from_ordinal(u8::MAX), KvCacheLevel::Q4);
    }

    #[test]
    fn default_policy_starts_at_fp16() {
        let p = KvCachePolicy::default();
        assert_eq!(p.current_level(), KvCacheLevel::Fp16);
        assert_eq!(p.samples(), 0);
        assert_eq!(p.upgrades(), 0);
        assert_eq!(p.downgrades(), 0);
        assert!(p.pressure() < f64::EPSILON);
    }

    #[test]
    fn validate_rejects_inverted_thresholds() {
        let cfg = KvCachePolicyConfig {
            q8_threshold: 0.9,
            q4_threshold: 0.5,
            ..Default::default()
        };
        let err = KvCachePolicy::new(cfg).unwrap_err();
        assert!(matches!(err, KvCachePolicyError::InvalidConfig(_)));
    }

    #[test]
    fn validate_rejects_min_greater_than_max() {
        let cfg = KvCachePolicyConfig {
            min_level: KvCacheLevel::Q4,
            max_level: KvCacheLevel::Fp16,
            ..Default::default()
        };
        assert!(KvCachePolicy::new(cfg).is_err());
    }

    #[test]
    fn validate_rejects_out_of_range() {
        let cfg = KvCachePolicyConfig {
            q8_threshold: 1.5,
            ..Default::default()
        };
        assert!(KvCachePolicy::new(cfg).is_err());
    }

    #[test]
    fn low_pressure_stays_fp16() {
        let p = KvCachePolicy::default();
        for _ in 0..50 {
            assert_eq!(p.observe(0.10), KvCacheLevel::Fp16);
        }
    }

    #[test]
    fn sustained_high_pressure_upgrades_to_q8_then_q4() {
        let p = KvCachePolicy::default();
        // Sustain ~85 % pressure — should reach Q8 but not Q4.
        for _ in 0..40 {
            p.observe(0.85);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q8);

        // Push to ~98 % — should reach Q4.
        for _ in 0..40 {
            p.observe(0.98);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q4);
        assert!(p.upgrades() >= 2);
    }

    #[test]
    fn pressure_drop_downgrades_after_hysteresis() {
        let p = KvCachePolicy::default();
        for _ in 0..40 {
            p.observe(0.98);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q4);

        // Drop to 0.93: below the Q4 threshold (0.95) but above
        // q4_threshold - hysteresis (0.90), so the policy is expected to
        // hold at Q4. The assertion is kept lenient (Q4 or Q8) so it does
        // not depend on exact EWMA dynamics.
        for _ in 0..40 {
            p.observe(0.93);
        }
        let after_partial = p.current_level();
        assert!(matches!(after_partial, KvCacheLevel::Q4 | KvCacheLevel::Q8));

        // Now drop hard to 0.05 — should reach Fp16.
        for _ in 0..200 {
            p.observe(0.05);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Fp16);
        assert!(p.downgrades() >= 1);
    }

    #[test]
    fn hysteresis_prevents_thrashing() {
        let p = KvCachePolicy::default();
        // Push above q8 threshold to trigger upgrade.
        for _ in 0..40 {
            p.observe(0.85);
        }
        let before = p.upgrades();
        assert!(before >= 1);
        // Now oscillate just above and below the threshold.
        for i in 0..40 {
            // Stays around 0.78 .. 0.82 — within hysteresis band of q8 = 0.80.
            let v = if i % 2 == 0 { 0.78 } else { 0.82 };
            p.observe(v);
        }
        // We allow at most a small number of additional level changes.
        // Without hysteresis we'd see ~20 transitions.
        let total_changes = p.upgrades() + p.downgrades();
        assert!(
            total_changes < 10,
            "hysteresis should suppress oscillation; saw {total_changes} transitions"
        );
    }

    #[test]
    fn reset_clears_state() {
        let p = KvCachePolicy::default();
        for _ in 0..50 {
            p.observe(0.99);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q4);
        p.reset();
        assert_eq!(p.current_level(), KvCacheLevel::Fp16);
        assert_eq!(p.samples(), 0);
        assert!(p.pressure() < f64::EPSILON);
    }

    #[test]
    fn fp16_only_profile_never_upgrades() {
        let p = KvCachePolicy::new(KvCachePolicyConfig::fp16_only()).expect("valid config");
        for _ in 0..200 {
            assert_eq!(p.observe(1.0), KvCacheLevel::Fp16);
        }
        assert_eq!(p.upgrades(), 0);
    }

    #[test]
    fn aggressive_profile_starts_at_q8() {
        let p = KvCachePolicy::new(KvCachePolicyConfig::aggressive()).expect("valid config");
        assert_eq!(p.current_level(), KvCacheLevel::Q8);
        for _ in 0..30 {
            p.observe(0.95);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q4);
    }

    #[test]
    fn max_level_fp8_caps_compression() {
        // With the ceiling at Fp8, pressure that would select Q4 is clamped
        // down to Fp8 instead.
        let cfg = KvCachePolicyConfig {
            max_level: KvCacheLevel::Fp8,
            ..Default::default()
        };
        let p = KvCachePolicy::new(cfg).expect("valid config");
        for _ in 0..50 {
            p.observe(0.99);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Fp8);
    }

    #[test]
    fn observed_pressure_is_clamped() {
        let p = KvCachePolicy::default();
        // Out-of-range values must not break the EWMA.
        p.observe(-1.0);
        assert!(p.pressure() >= 0.0);
        p.observe(2.0);
        assert!(p.pressure() <= 1.0 + 1e-6);
    }

    #[test]
    fn level_tag_strings() {
        assert_eq!(KvCacheLevel::Fp16.tag(), "fp16");
        assert_eq!(KvCacheLevel::Q8.tag(), "q8");
        assert_eq!(KvCacheLevel::Fp8.tag(), "fp8");
        assert_eq!(KvCacheLevel::Q4.tag(), "q4");
    }

    #[test]
    fn concurrent_observe_is_safe() {
        use std::sync::Arc;
        use std::thread;

        let p = Arc::new(KvCachePolicy::default());
        let mut handles = Vec::new();
        for tid in 0..8 {
            let p = Arc::clone(&p);
            handles.push(thread::spawn(move || {
                for i in 0..100 {
                    let v = ((tid + i) % 100) as f64 / 100.0;
                    p.observe(v);
                }
            }));
        }
        for h in handles {
            h.join().expect("worker thread panicked");
        }
        assert_eq!(p.samples(), 8 * 100);
    }
}