use crate::{
array::Array,
error::{
Error, InvariantViolationPayload, LayerKeyedPayload, LengthMismatchPayload, OutOfRangePayload,
Result, UnknownEnumValuePayload,
},
};
use smol_str::format_smolstr;
pub mod arrays;
pub mod batch;
pub mod batch_rotating;
mod cache_list;
mod chunked;
mod mask;
pub mod persist;
pub mod prompt;
mod quantized;
mod rotating;
mod standard;
mod util;
pub use arrays::*;
pub use batch::*;
pub use batch_rotating::*;
pub use cache_list::CacheList;
pub use chunked::*;
pub use mask::{create_attention_mask, create_causal_mask};
pub use persist::*;
pub use prompt::*;
pub use quantized::*;
pub use rotating::RotatingKvCache;
pub use standard::StandardKvCache;
/// `keep` argument passed to [`RotatingKvCache::new`] by [`make_prompt_cache`].
///
/// NOTE(review): presumably the number of leading positions a rotating cache
/// never evicts — confirm against `RotatingKvCache`.
pub const ROTATING_DEFAULT_KEEP: i32 = 4;
/// Position offset to feed into rotary embeddings, as returned by
/// [`KvCache::rope_offset`].
#[derive(derive_more::IsVariant, derive_more::Unwrap, derive_more::TryUnwrap)]
#[unwrap(ref, ref_mut)]
#[try_unwrap(ref, ref_mut)]
pub enum RopeOffset {
    /// Single offset taken from [`KvCache::offset`].
    Scalar(usize),
    /// Array produced by [`BatchPositionedKvCache::batch_offset`]
    /// (presumably one offset per batch entry).
    Batch(Array),
}
/// Attention-mask description returned by [`KvCache::make_mask`].
///
/// NOTE(review): the exact meaning of each variant is implementor-defined; the
/// notes below are inferred from the variant shapes.
#[derive(derive_more::IsVariant, derive_more::Unwrap, derive_more::TryUnwrap)]
#[unwrap(ref, ref_mut)]
#[try_unwrap(ref, ref_mut)]
pub enum MaskMode {
    /// Presumably: no mask needs to be applied.
    None,
    /// Presumably: a plain causal mask suffices and is not materialized here.
    Causal,
    /// An explicit mask array.
    Array(Array),
}
/// Three-part quantized tensor used by [`QuantizedKvCache`].
///
/// NOTE(review): presumably `(packed data, scales, optional biases)` — confirm
/// against the quantized cache implementation.
pub type QTriple = (Array, Array, Option<Array>);
/// Interface implemented by every per-layer key/value cache in this module.
///
/// Implementors must provide the update, state, mask and bookkeeping
/// primitives. The provided methods give conservative defaults: no meta-state,
/// not trimmable, and no capability downcasts. An implementor overrides them
/// when it supports the corresponding feature.
pub trait KvCache {
    /// Scalar position offset of the cache.
    ///
    /// The default [`KvCache::rope_offset`] reports this value for caches that
    /// are not batch-positioned.
    fn offset(&self) -> usize;
    /// Offset to feed into rotary position embeddings.
    ///
    /// Returns [`RopeOffset::Batch`] holding [`BatchPositionedKvCache::batch_offset`]
    /// when [`KvCache::as_batch_positioned`] yields a view. Otherwise it returns
    /// [`RopeOffset::Scalar`] wrapping [`KvCache::offset`].
    ///
    /// # Errors
    /// Propagates any error from `batch_offset`.
    fn rope_offset(&self) -> Result<RopeOffset> {
        match self.as_batch_positioned() {
            Some(bp) => Ok(RopeOffset::Batch(bp.batch_offset()?)),
            None => Ok(RopeOffset::Scalar(self.offset())),
        }
    }
    /// Capacity cap of the cache. `None` by default (presumably "no cap");
    /// bounded caches are expected to override this.
    fn max_size(&self) -> Option<usize> {
        None
    }
    /// Feeds new `keys`/`values` into the cache and returns a `(keys, values)`
    /// pair.
    ///
    /// NOTE(review): presumably the full arrays to attend over — confirm per
    /// implementor.
    fn update(&mut self, keys: &Array, values: &Array) -> Result<(Array, Array)>;
    /// Array state of the cache.
    ///
    /// [`KvCache::from_serialized`] relies on `set_state(state())` restoring the
    /// cache when it rolls back.
    fn state(&self) -> Result<Vec<Array>>;
    /// Appends [`KvCache::state`] to `buf`.
    ///
    /// The default builds the intermediate `Vec` through `state()`; implementors
    /// may override it to push directly into `buf`.
    fn state_into(&self, buf: &mut Vec<Array>) -> Result<()> {
        buf.extend(self.state()?);
        Ok(())
    }
    /// Replaces the cache's array state with `state`.
    fn set_state(&mut self, state: Vec<Array>) -> Result<()>;
    /// NOTE(review): presumably forces evaluation of lazily built arrays —
    /// confirm per implementor.
    fn materialize(&mut self) -> Result<()>;
    /// String metadata that accompanies [`KvCache::state`]. Empty by default.
    fn meta_state(&self) -> Vec<String> {
        Vec::new()
    }
    /// Appends [`KvCache::meta_state`] to `buf`. The default goes through
    /// `meta_state()`.
    fn meta_state_into(&self, buf: &mut Vec<String>) {
        buf.extend(self.meta_state());
    }
    /// Applies string metadata previously produced by [`KvCache::meta_state`].
    ///
    /// The default is for caches without metadata and accepts only an empty
    /// slice.
    ///
    /// # Errors
    /// `Error::LengthMismatch` (expected `0`, actual `m.len()`) when `m` is
    /// non-empty.
    fn set_meta_state(&mut self, m: &[String]) -> Result<()> {
        if !m.is_empty() {
            return Err(Error::LengthMismatch(LengthMismatchPayload::new(
                "KvCache::set_meta_state: meta_state value count for a no-meta cache (mirrors mlx-lm `_BaseCache.meta_state` setter cache.py:142-145)",
                0,
                m.len(),
            )));
        }
        Ok(())
    }
    /// Restores `state` and then `meta`, rolling back on failure.
    ///
    /// The current `state()` and `meta_state()` are snapshotted first. If
    /// `set_state(state)` or `set_meta_state(meta)` fails, the snapshot is
    /// written back and the setter's error is returned.
    ///
    /// # Errors
    /// * Propagates a failure of the initial `state()` snapshot (nothing has
    ///   been modified at that point).
    /// * Otherwise returns the failing setter's error.
    /// * If writing the snapshot back also fails, the setter's error is wrapped
    ///   in `Error::LayerKeyed`, with the rollback failure described in the
    ///   message.
    // `wrong_self_convention` expects `from_*` methods to be `self`-less
    // constructors; this one deliberately mutates `self` in place instead.
    #[allow(clippy::wrong_self_convention)] fn from_serialized(&mut self, state: Vec<Array>, meta: &[String]) -> Result<()> {
        let snapshot_state = self.state()?;
        let snapshot_meta = self.meta_state();
        // Restores the snapshot and returns the error to report. `e` stays the
        // primary error in every path.
        let rollback = |cache: &mut Self,
                        e: Error,
                        snap_state: Vec<Array>,
                        snap_meta: Vec<String>|
         -> Error {
            // Array state is restored first. If that fails, the meta rollback
            // is not attempted.
            if let Err(rb_state_err) = cache.set_state(snap_state) {
                return Error::LayerKeyed(LayerKeyedPayload::new(
                    format_smolstr!(
                        "KvCache::from_serialized: rollback failed (set_state on snapshot: {rb_state_err})"
                    ),
                    e,
                ));
            }
            if let Err(rb_meta_err) = cache.set_meta_state(&snap_meta) {
                return Error::LayerKeyed(LayerKeyedPayload::new(
                    format_smolstr!(
                        "KvCache::from_serialized: rollback failed (set_meta_state on snapshot: {rb_meta_err})"
                    ),
                    e,
                ));
            }
            e
        };
        // The meta-state is applied only after the array state succeeded.
        match self.set_state(state) {
            Ok(()) => match self.set_meta_state(meta) {
                Ok(()) => Ok(()),
                Err(e) => Err(rollback(self, e, snapshot_state, snapshot_meta)),
            },
            Err(e) => Err(rollback(self, e, snapshot_state, snapshot_meta)),
        }
    }
    /// Whether [`KvCache::trim`] can remove positions. `false` by default.
    fn is_trimmable(&self) -> bool {
        false
    }
    /// Requests removal of `_n` positions.
    ///
    /// The default removes nothing and returns `Ok(0)` (presumably the count
    /// actually trimmed).
    fn trim(&mut self, _n: usize) -> Result<usize> {
        Ok(0)
    }
    /// Builds the attention mask for `n` incoming positions.
    ///
    /// NOTE(review): `window_size`/`return_array` semantics are implementor-defined;
    /// `return_array` presumably requests a materialized [`MaskMode::Array`].
    fn make_mask(&self, n: usize, window_size: Option<usize>, return_array: bool)
    -> Result<MaskMode>;
    /// Memory footprint in bytes (presumably of the stored arrays).
    fn nbytes(&self) -> usize;
    /// Whether the cache's buffer is empty.
    fn is_empty(&self) -> bool;
    /// Copy of this cache as a new boxed trait object. How deep the copy goes
    /// is implementor-defined.
    fn copy(&self) -> Result<Box<dyn KvCache>>;
    /// Downcast hook to the quantized API. `None` unless overridden.
    fn as_quantized(&self) -> Option<&dyn QuantizedKvCache> {
        None
    }
    /// Mutable variant of [`KvCache::as_quantized`]. `None` unless overridden.
    fn as_quantized_mut(&mut self) -> Option<&mut dyn QuantizedKvCache> {
        None
    }
    /// Downcast hook to per-batch positioning. The default
    /// [`KvCache::rope_offset`] consults it; `None` unless overridden.
    fn as_batch_positioned(&self) -> Option<&dyn BatchPositionedKvCache> {
        None
    }
    /// Downcast hook to [`CacheList`]. `None` unless overridden.
    fn as_cache_list(&self) -> Option<&CacheList> {
        None
    }
    /// Mutable variant of [`KvCache::as_cache_list`]. `None` unless overridden.
    fn as_cache_list_mut(&mut self) -> Option<&mut CacheList> {
        None
    }
    /// Exposes `self` as `Any` so callers can downcast to the concrete type.
    fn as_any_mut(&mut self) -> &mut dyn std::any::Any;
    /// Number of arrays in [`KvCache::state`].
    ///
    /// The default materializes the state to count it; implementors can
    /// override this with a cheaper count.
    fn state_count(&self) -> Result<usize> {
        self.state().map(|s| s.len())
    }
    /// Name of the reference class this cache mirrors.
    ///
    /// NOTE(review): presumably the mlx-lm class name (compare
    /// [`KvCacheKind::as_str`]).
    fn reference_class_name(&self) -> &'static str;
}
/// Extension of [`KvCache`] for caches that store quantized keys and values.
///
/// It is reachable from a `dyn KvCache` through [`KvCache::as_quantized`] and
/// [`KvCache::as_quantized_mut`].
pub trait QuantizedKvCache: KvCache {
    /// Quantization group size reported by the implementor.
    fn group_size(&self) -> i32;
    /// Quantization bit width reported by the implementor.
    fn bits(&self) -> i32;
    /// Quantized counterpart of [`KvCache::update`]: returns the key and value
    /// [`QTriple`]s.
    fn update_quantized(&mut self, keys: &Array, values: &Array) -> Result<(QTriple, QTriple)>;
    /// Current key/value [`QTriple`]s, or `None` (presumably when nothing is
    /// cached yet).
    fn quantized_state(&self) -> Result<Option<(QTriple, QTriple)>>;
}
/// Extension of [`KvCache`] for caches that track positions per batch entry.
///
/// It is surfaced through [`KvCache::as_batch_positioned`] and consumed by the
/// default [`KvCache::rope_offset`].
pub trait BatchPositionedKvCache: KvCache {
    /// Offsets as an array, reported as [`RopeOffset::Batch`] by `rope_offset`.
    /// NOTE(review): presumably one offset per batch row — confirm shape.
    fn batch_offset(&self) -> Result<Array>;
}
#[derive(Debug, Clone, PartialEq, Eq, derive_more::Display, derive_more::IsVariant)]
#[display("{}", self.as_str())]
pub enum KvCacheKind {
KvCache,
RotatingKvCache,
ChunkedKvCache,
QuantizedKvCache,
CacheList,
BatchKvCache,
BatchRotatingKvCache,
ArraysCache,
MambaCache,
}
impl KvCacheKind {
pub const fn as_str(&self) -> &'static str {
match self {
Self::KvCache => "KVCache",
Self::RotatingKvCache => "RotatingKVCache",
Self::ChunkedKvCache => "ChunkedKVCache",
Self::QuantizedKvCache => "QuantizedKVCache",
Self::CacheList => "CacheList",
Self::BatchKvCache => "BatchKVCache",
Self::BatchRotatingKvCache => "BatchRotatingKVCache",
Self::ArraysCache => "ArraysCache",
Self::MambaCache => "MambaCache",
}
}
pub fn parse(kind: &str) -> Result<Self> {
match kind {
"KVCache" | "ConcatenateKVCache" | "KVCacheSimple" | "StandardKvCache" => Ok(Self::KvCache),
"RotatingKVCache" | "RotatingKvCache" => Ok(Self::RotatingKvCache),
"ChunkedKVCache" | "ChunkedKvCache" => Ok(Self::ChunkedKvCache),
"QuantizedKVCache" | "StandardQuantizedKvCache" => Ok(Self::QuantizedKvCache),
"CacheList" => Ok(Self::CacheList),
"BatchKVCache" | "BatchKvCache" => Ok(Self::BatchKvCache),
"BatchRotatingKVCache" | "BatchRotatingKvCache" => Ok(Self::BatchRotatingKvCache),
"ArraysCache" => Ok(Self::ArraysCache),
"MambaCache" => Ok(Self::MambaCache),
other => Err(Error::UnknownEnumValue(UnknownEnumValuePayload::new(
"KvCacheKind",
other,
&[
"KVCache",
"ConcatenateKVCache",
"KVCacheSimple",
"StandardKvCache",
"RotatingKVCache",
"RotatingKvCache",
"ChunkedKVCache",
"ChunkedKvCache",
"QuantizedKVCache",
"StandardQuantizedKvCache",
"CacheList",
"BatchKVCache",
"BatchKvCache",
"BatchRotatingKVCache",
"BatchRotatingKvCache",
"ArraysCache",
"MambaCache",
],
))),
}
}
}
pub fn from_state(kind: &str, state: Vec<Array>, meta: &[String]) -> Result<Box<dyn KvCache>> {
match KvCacheKind::parse(kind)? {
KvCacheKind::KvCache => {
let mut c = StandardKvCache::new();
c.set_state(state)?;
c.set_meta_state(meta)?;
Ok(Box::new(c))
}
KvCacheKind::RotatingKvCache => {
let mut c = RotatingKvCache::new(0, 0);
c.set_state(state)?;
c.set_meta_state(meta)?;
if c.is_empty() && (c.offset() != 0 || c.idx() != 0) {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"RotatingKvCache::from_state: empty state with non-zero offset/idx",
"must satisfy offset=0 AND idx=0 when buffer is empty",
)));
}
Ok(Box::new(c))
}
KvCacheKind::ChunkedKvCache => {
let mut c = ChunkedKvCache::new(None);
c.set_state(state)?;
c.set_meta_state(meta)?;
Ok(Box::new(c))
}
KvCacheKind::QuantizedKvCache => {
let mut c = StandardQuantizedKvCache::new_unchecked(0, 0);
c.set_state(state)?;
c.set_meta_state(meta)?;
if c.is_empty() && c.offset() != 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"QuantizedKvCache::from_state: empty state with non-zero offset",
"must satisfy offset=0 when buffer is empty",
)));
}
c.enforce_offset_len_invariant()?;
Ok(Box::new(c))
}
KvCacheKind::CacheList => cache_list::cache_list_from_state(state, meta),
KvCacheKind::BatchKvCache => {
let mut c = BatchKvCache::new(&[]);
c.set_state(state)?;
c.set_meta_state(meta)?;
Ok(Box::new(c))
}
KvCacheKind::BatchRotatingKvCache => {
let mut c = BatchRotatingKvCache::new(0, &[]);
c.set_state(state)?;
c.set_meta_state(meta)?;
if c.is_empty() {
let offset = c.offset();
let idx = c.ring_idx();
let rotated = c.is_rotated();
if offset != 0 || idx != 0 || rotated {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"BatchRotatingKvCache::from_serialized: empty buffer (keys=None) requires fully-fresh meta",
"must satisfy offset=0 AND _idx=0 AND rotated=false",
format_smolstr!("offset={offset}, _idx={idx}, rotated={rotated}"),
)));
}
} else {
let l = c.buf_seq_len()?.unwrap_or(0);
let max_size = c.max_window();
let idx = c.ring_idx();
let offset = c.offset();
let rotated = c.is_rotated();
if max_size == 0 {
return Err(Error::InvariantViolation(InvariantViolationPayload::new(
"BatchRotatingKvCache::from_serialized: max_size",
"must be >= 1 for a non-empty buffer (max_size=0 is only the pre-setter placeholder)",
)));
} else if idx > l {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"BatchRotatingKvCache::from_serialized: _idx (write cursor must not exceed physical buffer seq-len L)",
"must satisfy _idx <= L",
format_smolstr!("_idx={idx}, L={l}"),
)));
} else if rotated && l != max_size {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"BatchRotatingKvCache::from_serialized: rotated=true requires L == max_size",
max_size,
l,
)));
} else if l > offset {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"BatchRotatingKvCache::from_serialized: L (keys seq-len; mlx-lm getter emits keys[:_offset,:], so L <= _offset always)",
"must satisfy L <= _offset",
format_smolstr!("L={l}, _offset={offset}"),
)));
}
}
Ok(Box::new(c))
}
KvCacheKind::ArraysCache => arrays::from_state_arrays(state, meta, false),
KvCacheKind::MambaCache => arrays::from_state_arrays(state, meta, true),
}
}
/// Layer-count and windowing parameters consumed by [`make_prompt_cache`].
///
/// The standard traits are derived so configurations can be logged, cloned and
/// compared like other plain-data public types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CacheConfig {
    /// Number of per-layer caches to create.
    pub num_hidden_layers: usize,
    /// Sliding-attention window size. `None` gives every layer an unbounded
    /// [`StandardKvCache`].
    pub sliding_window: Option<i32>,
}
/// Creates one fresh cache per hidden layer, ready for a new prompt.
///
/// When `cfg.sliding_window` holds a positive size, every layer gets a
/// [`RotatingKvCache`] with that capacity and [`ROTATING_DEFAULT_KEEP`] as its
/// `keep` argument. Otherwise every layer gets an unbounded [`StandardKvCache`].
///
/// A window of zero or less is treated like `None` rather than producing
/// `RotatingKvCache::new(0, _)`. A zero `max_size` is what `from_state` uses as
/// a pre-restore placeholder (`RotatingKvCache::new(0, 0)`), and the
/// batch-rotating restore path rejects it for any non-empty buffer. It is not a
/// usable live window.
pub fn make_prompt_cache(cfg: &CacheConfig) -> Vec<Box<dyn KvCache>> {
    // Resolve the window once for all layers. Only strictly positive sizes
    // survive the filter, so the `as usize` cast never sees a negative value
    // (a positive i32 fits in usize on every >= 32-bit target).
    let window = cfg.sliding_window.filter(|&w| w > 0).map(|w| w as usize);
    let keep = ROTATING_DEFAULT_KEEP.max(0) as usize;
    (0..cfg.num_hidden_layers)
        .map(|_| -> Box<dyn KvCache> {
            match window {
                Some(max_size) => Box::new(RotatingKvCache::new(max_size, keep)),
                None => Box::new(StandardKvCache::new()),
            }
        })
        .collect()
}