1use std::sync::atomic::{AtomicU64, AtomicU8, Ordering};
39
/// Quantization level applied to the KV cache, declared from least to most
/// compact.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub enum KvCacheLevel {
    /// Full 16-bit precision (no compression).
    Fp16,
    /// 8-bit integer quantization.
    Q8,
    /// 8-bit floating point.
    Fp8,
    /// 4-bit integer quantization.
    Q4,
}

impl KvCacheLevel {
    /// Approximate memory use relative to [`Self::Fp16`] (1.0 = no savings).
    pub const fn memory_factor(self) -> f32 {
        match self {
            Self::Fp16 => 1.0,
            Self::Q8 | Self::Fp8 => 0.5,
            Self::Q4 => 0.25,
        }
    }

    /// Rank of this level; a larger ordinal means a more compact cache.
    pub const fn ordinal(self) -> u8 {
        // Variants are declared in ordinal order with default discriminants
        // (Fp16 = 0 .. Q4 = 3), so the cast yields the rank directly.
        self as u8
    }

    /// Short lowercase identifier suitable for logging or metrics labels.
    pub const fn tag(self) -> &'static str {
        match self {
            Self::Fp16 => "fp16",
            Self::Q8 => "q8",
            Self::Fp8 => "fp8",
            Self::Q4 => "q4",
        }
    }

    /// Inverse of [`Self::ordinal`]; out-of-range inputs saturate to `Q4`.
    fn from_ordinal(o: u8) -> Self {
        if o == 0 {
            Self::Fp16
        } else if o == 1 {
            Self::Q8
        } else if o == 2 {
            Self::Fp8
        } else {
            Self::Q4
        }
    }
}
114
/// Tunable parameters for [`KvCachePolicy`].
#[derive(Debug, Clone)]
pub struct KvCachePolicyConfig {
    /// Smoothed pressure at or above which Q8 is selected; must be in [0.0, 1.0].
    pub q8_threshold: f32,
    /// Smoothed pressure at or above which Q4 is selected; must be in
    /// [0.0, 1.0] and >= `q8_threshold`.
    pub q4_threshold: f32,
    /// Margin subtracted from a threshold before downgrading to a less
    /// compact level, to avoid thrashing; must be in [0.0, 1.0].
    pub hysteresis: f32,
    /// Smoothing factor for the pressure EWMA; must be in [0.0, 1.0].
    /// Higher values weight new samples more heavily (react faster).
    pub ewma_alpha: f32,
    /// Least compact level the policy may select (also the starting level).
    pub min_level: KvCacheLevel,
    /// Most compact level the policy may select; its ordinal must be
    /// >= `min_level`'s.
    pub max_level: KvCacheLevel,
}
138
139impl Default for KvCachePolicyConfig {
140 fn default() -> Self {
141 Self {
142 q8_threshold: 0.80,
143 q4_threshold: 0.95,
144 hysteresis: 0.05,
145 ewma_alpha: 0.20,
146 min_level: KvCacheLevel::Fp16,
147 max_level: KvCacheLevel::Q4,
148 }
149 }
150}
151
152impl KvCachePolicyConfig {
153 pub fn fp16_only() -> Self {
155 Self {
156 min_level: KvCacheLevel::Fp16,
157 max_level: KvCacheLevel::Fp16,
158 ..Self::default()
159 }
160 }
161
162 pub fn aggressive() -> Self {
164 Self {
165 q8_threshold: 0.50,
166 q4_threshold: 0.80,
167 hysteresis: 0.05,
168 ewma_alpha: 0.30,
169 min_level: KvCacheLevel::Q8,
170 max_level: KvCacheLevel::Q4,
171 }
172 }
173
174 fn validate(&self) -> Result<(), KvCachePolicyError> {
175 if !(0.0..=1.0).contains(&self.q8_threshold) {
176 return Err(KvCachePolicyError::InvalidConfig(
177 "q8_threshold must be in [0.0, 1.0]",
178 ));
179 }
180 if !(0.0..=1.0).contains(&self.q4_threshold) {
181 return Err(KvCachePolicyError::InvalidConfig(
182 "q4_threshold must be in [0.0, 1.0]",
183 ));
184 }
185 if self.q4_threshold < self.q8_threshold {
186 return Err(KvCachePolicyError::InvalidConfig(
187 "q4_threshold must be >= q8_threshold",
188 ));
189 }
190 if !(0.0..=1.0).contains(&self.hysteresis) {
191 return Err(KvCachePolicyError::InvalidConfig(
192 "hysteresis must be in [0.0, 1.0]",
193 ));
194 }
195 if !(0.0..=1.0).contains(&self.ewma_alpha) {
196 return Err(KvCachePolicyError::InvalidConfig(
197 "ewma_alpha must be in [0.0, 1.0]",
198 ));
199 }
200 if self.min_level.ordinal() > self.max_level.ordinal() {
201 return Err(KvCachePolicyError::InvalidConfig(
202 "min_level must be <= max_level (less compact)",
203 ));
204 }
205 Ok(())
206 }
207}
208
/// Errors produced when constructing a [`KvCachePolicy`].
#[derive(Debug, thiserror::Error)]
#[non_exhaustive]
pub enum KvCachePolicyError {
    /// A [`KvCachePolicyConfig`] field failed validation; the static message
    /// names the offending field and its allowed range.
    #[error("invalid kv-cache policy configuration: {0}")]
    InvalidConfig(&'static str),
}
216
/// Controller that selects a [`KvCacheLevel`] from observed memory pressure,
/// smoothing samples with an EWMA and applying hysteresis before downgrading.
///
/// All mutable state lives in atomics, so every method takes `&self` and the
/// policy can be shared across threads (e.g. behind an `Arc`).
#[derive(Debug)]
pub struct KvCachePolicy {
    config: KvCachePolicyConfig,
    // Ordinal of the currently selected level (see `KvCacheLevel::ordinal`).
    level: AtomicU8,
    // EWMA of observed pressure, stored as the raw bit pattern of an f64
    // (encoded/decoded via `f64::to_bits` / `f64::from_bits`).
    pressure_ewma: AtomicU64,
    // Number of `observe` calls since construction or the last `reset`.
    samples: AtomicU64,
    // Transitions to a more compact level (higher ordinal).
    upgrades: AtomicU64,
    // Transitions to a less compact level (lower ordinal).
    downgrades: AtomicU64,
}
238
239impl Default for KvCachePolicy {
240 fn default() -> Self {
241 Self::new(KvCachePolicyConfig::default()).expect("default config is valid")
242 }
243}
244
impl KvCachePolicy {
    /// Builds a policy from `config`, starting at `config.min_level` with an
    /// empty pressure history.
    ///
    /// # Errors
    /// Returns [`KvCachePolicyError::InvalidConfig`] if `config` fails
    /// validation.
    pub fn new(config: KvCachePolicyConfig) -> Result<Self, KvCachePolicyError> {
        config.validate()?;
        Ok(Self {
            level: AtomicU8::new(config.min_level.ordinal()),
            // 0u64 is the bit pattern of 0.0_f64.
            pressure_ewma: AtomicU64::new(0u64),
            samples: AtomicU64::new(0),
            upgrades: AtomicU64::new(0),
            downgrades: AtomicU64::new(0),
            config,
        })
    }

    /// Level most recently selected by [`Self::observe`] (or the initial
    /// `min_level` if nothing has been observed yet).
    pub fn current_level(&self) -> KvCacheLevel {
        KvCacheLevel::from_ordinal(self.level.load(Ordering::Relaxed))
    }

    /// Current smoothed (EWMA) pressure; 0.0 before the first sample.
    pub fn pressure(&self) -> f64 {
        // The atomic holds the f64's raw bits; decode before returning.
        f64::from_bits(self.pressure_ewma.load(Ordering::Relaxed))
    }

    /// Number of samples recorded since construction or the last reset.
    pub fn samples(&self) -> u64 {
        self.samples.load(Ordering::Relaxed)
    }

    /// Number of transitions to a more compact level.
    pub fn upgrades(&self) -> u64 {
        self.upgrades.load(Ordering::Relaxed)
    }

    /// Number of transitions to a less compact level.
    pub fn downgrades(&self) -> u64 {
        self.downgrades.load(Ordering::Relaxed)
    }

    /// Records a pressure sample (clamped to [0.0, 1.0]), folds it into the
    /// EWMA, and returns the — possibly updated — cache level.
    pub fn observe(&self, pressure: f64) -> KvCacheLevel {
        let p = pressure.clamp(0.0, 1.0);

        let alpha = self.config.ewma_alpha as f64;
        let one_minus_alpha = 1.0 - alpha;
        // Lock-free read-modify-write of the EWMA bits: retry until our CAS
        // lands. The very first sample (samples == 0) seeds the average with
        // the raw value instead of blending with the 0.0 initial state.
        // compare_exchange_weak may fail spuriously, which the loop absorbs.
        loop {
            let current_bits = self.pressure_ewma.load(Ordering::Relaxed);
            let current = f64::from_bits(current_bits);
            let n = self.samples.load(Ordering::Relaxed);
            let new_val = if n == 0 {
                p
            } else {
                alpha * p + one_minus_alpha * current
            };
            if self
                .pressure_ewma
                .compare_exchange_weak(
                    current_bits,
                    new_val.to_bits(),
                    Ordering::Relaxed,
                    Ordering::Relaxed,
                )
                .is_ok()
            {
                break;
            }
        }
        // Incremented only after the EWMA update, so the n == 0 seeding above
        // keys off the pre-increment count.
        self.samples.fetch_add(1, Ordering::Relaxed);

        let smoothed = self.pressure();
        let current = self.current_level();
        let target = self.target_level(smoothed, current);

        if target != current {
            self.level.store(target.ordinal(), Ordering::Relaxed);
            // Higher ordinal = more compact = "upgrade" in compression terms.
            if target.ordinal() > current.ordinal() {
                self.upgrades.fetch_add(1, Ordering::Relaxed);
            } else {
                self.downgrades.fetch_add(1, Ordering::Relaxed);
            }
        }
        target
    }

    /// Maps the smoothed pressure to a level, holding the current level
    /// inside its hysteresis band, then clamps the result into
    /// `[min_level, max_level]`.
    fn target_level(&self, smoothed: f64, current: KvCacheLevel) -> KvCacheLevel {
        let q8 = self.config.q8_threshold as f64;
        let q4 = self.config.q4_threshold as f64;
        let h = self.config.hysteresis as f64;

        // Raw threshold mapping with no memory of the current level. Note
        // that Fp8 is never chosen by the pressure ladder here.
        let raw = if smoothed >= q4 {
            KvCacheLevel::Q4
        } else if smoothed >= q8 {
            KvCacheLevel::Q8
        } else {
            KvCacheLevel::Fp16
        };

        let target = match (current, raw) {
            // Leaving Q4 requires pressure to fall below q4 - h; once it
            // does, the raw level is adopted directly (possibly Fp16).
            (KvCacheLevel::Q4, KvCacheLevel::Q8) | (KvCacheLevel::Q4, KvCacheLevel::Fp16) => {
                if smoothed < q4 - h {
                    raw
                } else {
                    KvCacheLevel::Q4
                }
            }
            // Leaving Q8 for Fp16 requires pressure below q8 - h.
            (KvCacheLevel::Q8, KvCacheLevel::Fp16) => {
                if smoothed < q8 - h {
                    KvCacheLevel::Fp16
                } else {
                    KvCacheLevel::Q8
                }
            }
            // Upgrades (and no-change cases) take effect immediately.
            _ => raw,
        };

        // Honor the configured bounds regardless of what the ladder chose.
        let min_o = self.config.min_level.ordinal();
        let max_o = self.config.max_level.ordinal();
        let clamped = target.ordinal().clamp(min_o, max_o);
        KvCacheLevel::from_ordinal(clamped)
    }

    /// Clears all counters and the EWMA, and returns the level to
    /// `config.min_level`. Not atomic as a whole: concurrent `observe` calls
    /// may interleave with the individual stores.
    pub fn reset(&self) {
        self.pressure_ewma.store(0u64, Ordering::Relaxed);
        self.samples.store(0, Ordering::Relaxed);
        self.upgrades.store(0, Ordering::Relaxed);
        self.downgrades.store(0, Ordering::Relaxed);
        self.level
            .store(self.config.min_level.ordinal(), Ordering::Relaxed);
    }

    /// Read-only access to the configuration supplied at construction.
    pub fn config(&self) -> &KvCachePolicyConfig {
        &self.config
    }
}
397
// Unit tests. Several tests call `observe` in a loop so the EWMA
// (alpha = 0.20 by default) converges close to the sampled value; the
// iteration counts (40, 50, 200) are chosen to be comfortably past
// convergence rather than exact break-even points.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn level_memory_factor() {
        assert!((KvCacheLevel::Fp16.memory_factor() - 1.0).abs() < f32::EPSILON);
        assert!((KvCacheLevel::Q8.memory_factor() - 0.5).abs() < f32::EPSILON);
        assert!((KvCacheLevel::Q4.memory_factor() - 0.25).abs() < f32::EPSILON);
    }

    #[test]
    fn level_ordinal_monotonic() {
        // Ordinals order levels from least to most compact.
        assert!(KvCacheLevel::Fp16.ordinal() < KvCacheLevel::Q8.ordinal());
        assert!(KvCacheLevel::Q8.ordinal() < KvCacheLevel::Q4.ordinal());
    }

    #[test]
    fn default_policy_starts_at_fp16() {
        let p = KvCachePolicy::default();
        assert_eq!(p.current_level(), KvCacheLevel::Fp16);
        assert_eq!(p.samples(), 0);
        assert_eq!(p.upgrades(), 0);
        assert_eq!(p.downgrades(), 0);
        assert!(p.pressure() < f64::EPSILON);
    }

    #[test]
    fn validate_rejects_inverted_thresholds() {
        // q4_threshold below q8_threshold is inconsistent.
        let cfg = KvCachePolicyConfig {
            q8_threshold: 0.9,
            q4_threshold: 0.5,
            ..Default::default()
        };
        let err = KvCachePolicy::new(cfg).unwrap_err();
        assert!(matches!(err, KvCachePolicyError::InvalidConfig(_)));
    }

    #[test]
    fn validate_rejects_min_greater_than_max() {
        let cfg = KvCachePolicyConfig {
            min_level: KvCacheLevel::Q4,
            max_level: KvCacheLevel::Fp16,
            ..Default::default()
        };
        assert!(KvCachePolicy::new(cfg).is_err());
    }

    #[test]
    fn validate_rejects_out_of_range() {
        // Thresholds are fractions and must lie in [0.0, 1.0].
        let cfg = KvCachePolicyConfig {
            q8_threshold: 1.5,
            ..Default::default()
        };
        assert!(KvCachePolicy::new(cfg).is_err());
    }

    #[test]
    fn low_pressure_stays_fp16() {
        let p = KvCachePolicy::default();
        // 0.10 is far below the 0.80 Q8 threshold.
        for _ in 0..50 {
            assert_eq!(p.observe(0.10), KvCacheLevel::Fp16);
        }
    }

    #[test]
    fn sustained_high_pressure_upgrades_to_q8_then_q4() {
        let p = KvCachePolicy::default();
        // 0.85 sits between the Q8 (0.80) and Q4 (0.95) thresholds.
        for _ in 0..40 {
            p.observe(0.85);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q8);

        // 0.98 exceeds the Q4 threshold.
        for _ in 0..40 {
            p.observe(0.98);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q4);
        assert!(p.upgrades() >= 2);
    }

    #[test]
    fn pressure_drop_downgrades_after_hysteresis() {
        let p = KvCachePolicy::default();
        for _ in 0..40 {
            p.observe(0.98);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q4);

        // 0.93 lies inside the Q4 hysteresis band (>= 0.95 - 0.05), so the
        // level may hold at Q4 or relax to Q8 depending on the EWMA path.
        for _ in 0..40 {
            p.observe(0.93);
        }
        let after_partial = p.current_level();
        assert!(matches!(after_partial, KvCacheLevel::Q4 | KvCacheLevel::Q8));

        // A deep, sustained drop must eventually return to Fp16.
        for _ in 0..200 {
            p.observe(0.05);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Fp16);
        assert!(p.downgrades() >= 1);
    }

    #[test]
    fn hysteresis_prevents_thrashing() {
        let p = KvCachePolicy::default();
        for _ in 0..40 {
            p.observe(0.85);
        }
        let before = p.upgrades();
        assert!(before >= 1);
        // Oscillate around the 0.80 threshold; smoothing plus hysteresis
        // should keep the level from flapping on every sample.
        for i in 0..40 {
            let v = if i % 2 == 0 { 0.78 } else { 0.82 };
            p.observe(v);
        }
        let total_changes = p.upgrades() + p.downgrades();
        assert!(
            total_changes < 10,
            "hysteresis should suppress oscillation; saw {total_changes} transitions"
        );
    }

    #[test]
    fn reset_clears_state() {
        let p = KvCachePolicy::default();
        for _ in 0..50 {
            p.observe(0.99);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q4);
        p.reset();
        assert_eq!(p.current_level(), KvCacheLevel::Fp16);
        assert_eq!(p.samples(), 0);
        assert!(p.pressure() < f64::EPSILON);
    }

    #[test]
    fn fp16_only_profile_never_upgrades() {
        // With min == max == Fp16, clamping pins the level regardless of pressure.
        let p = KvCachePolicy::new(KvCachePolicyConfig::fp16_only()).expect("valid config");
        for _ in 0..200 {
            assert_eq!(p.observe(1.0), KvCacheLevel::Fp16);
        }
        assert_eq!(p.upgrades(), 0);
    }

    #[test]
    fn aggressive_profile_starts_at_q8() {
        // Aggressive profile sets min_level = Q8, which is the starting level.
        let p = KvCachePolicy::new(KvCachePolicyConfig::aggressive()).expect("valid config");
        assert_eq!(p.current_level(), KvCacheLevel::Q8);
        for _ in 0..30 {
            p.observe(0.95);
        }
        assert_eq!(p.current_level(), KvCacheLevel::Q4);
    }

    #[test]
    fn observed_pressure_is_clamped() {
        let p = KvCachePolicy::default();
        // Out-of-range samples are clamped into [0.0, 1.0] before smoothing.
        p.observe(-1.0);
        assert!(p.pressure() >= 0.0);
        p.observe(2.0);
        assert!(p.pressure() <= 1.0 + 1e-6);
    }

    #[test]
    fn level_tag_strings() {
        assert_eq!(KvCacheLevel::Fp16.tag(), "fp16");
        assert_eq!(KvCacheLevel::Q8.tag(), "q8");
        assert_eq!(KvCacheLevel::Q4.tag(), "q4");
    }

    #[test]
    fn concurrent_observe_is_safe() {
        use std::sync::Arc;
        use std::thread;

        // Eight threads observing concurrently; every sample must be counted
        // exactly once despite the lock-free EWMA update.
        let p = Arc::new(KvCachePolicy::default());
        let mut handles = Vec::new();
        for tid in 0..8 {
            let p = Arc::clone(&p);
            handles.push(thread::spawn(move || {
                for i in 0..100 {
                    let v = ((tid + i) % 100) as f64 / 100.0;
                    p.observe(v);
                }
            }));
        }
        for h in handles {
            h.join().expect("worker thread panicked");
        }
        assert_eq!(p.samples(), 8 * 100);
    }
}