use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, CapExceededPayload, Error, InvariantViolationPayload,
LayerKeyedPayload, LengthMismatchPayload, ParsePayload, Result,
},
lm::cache::{KvCache, MaskMode, RopeOffset},
};
use smol_str::format_smolstr;
/// Composite cache owning an ordered list of child caches.
///
/// The composite itself stores no keys/values; callers address individual children through
/// [`CacheList::get`] / [`CacheList::get_mut`].
pub struct CacheList {
    /// Child caches, kept in the order they were passed to [`CacheList::new`].
    caches: Vec<Box<dyn KvCache>>,
}

impl CacheList {
    /// Builds a composite over `caches`, keeping their order.
    pub fn new(caches: Vec<Box<dyn KvCache>>) -> Self {
        CacheList { caches }
    }

    /// Number of child caches.
    // `is_empty` is already taken by `KvCache::is_empty`, which means "first child holds no
    // data"; the child-count check is therefore exposed as `is_child_list_empty` instead.
    #[allow(clippy::len_without_is_empty)]
    pub fn len(&self) -> usize {
        self.caches.len()
    }

    /// `true` when the composite has no child caches at all.
    pub fn is_child_list_empty(&self) -> bool {
        self.caches.is_empty()
    }

    /// Shared access to child `idx`, or `None` when `idx` is out of range.
    pub fn get(&self, idx: usize) -> Option<&dyn KvCache> {
        self.caches.get(idx).map(|boxed| &**boxed)
    }

    /// Mutable access to child `idx`, or `None` when `idx` is out of range.
    ///
    /// The `'static` object bound mirrors the boxed children: `&mut` is invariant, so the
    /// trait-object lifetime cannot be shortened here.
    pub fn get_mut(&mut self, idx: usize) -> Option<&mut (dyn KvCache + 'static)> {
        self.caches.get_mut(idx).map(|boxed| &mut **boxed)
    }
}
impl KvCache for CacheList {
fn offset(&self) -> usize {
self.caches.iter().map(|c| c.offset()).max().unwrap_or(0)
}
fn rope_offset(&self) -> Result<RopeOffset> {
Ok(RopeOffset::Scalar(self.offset()))
}
fn update(&mut self, _keys: &Array, _values: &Array) -> Result<(Array, Array)> {
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"CacheList::update",
"is invalid — index a child via CacheList::get_mut and update that child",
)))
}
fn state(&self) -> Result<Vec<Array>> {
let mut out = Vec::new();
for c in &self.caches {
c.state_into(&mut out)?;
}
Ok(out)
}
fn state_into(&self, buf: &mut Vec<Array>) -> Result<()> {
for c in &self.caches {
c.state_into(buf)?;
}
Ok(())
}
fn set_state(&mut self, state: Vec<Array>) -> Result<()> {
let mut lengths = Vec::with_capacity(self.caches.len());
for c in &self.caches {
lengths.push(c.state_count()?);
}
let total: usize = lengths.iter().sum();
if total != state.len() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"CacheList::set_state: flattened state array count vs sum of children state_count",
total,
state.len(),
)));
}
let mut staged: Vec<Box<dyn KvCache>> = Vec::with_capacity(self.caches.len());
for c in &self.caches {
staged.push(c.copy()?);
}
let mut it = state.into_iter();
for (c, &len) in staged.iter_mut().zip(lengths.iter()) {
let chunk: Vec<Array> = it.by_ref().take(len).collect();
c.set_state(chunk)?;
}
self.caches = staged;
Ok(())
}
fn materialize(&mut self) -> Result<()> {
for c in &mut self.caches {
c.materialize()?;
}
Ok(())
}
fn meta_state(&self) -> Vec<String> {
let mut out = Vec::new();
self.meta_state_into(&mut out);
out
}
fn meta_state_into(&self, buf: &mut Vec<String>) {
buf.push(self.caches.len().to_string());
for c in &self.caches {
let class_name = c.reference_class_name();
let state_count = c
.state_count()
.or_else(|_| c.state().map(|s| s.len()))
.unwrap_or(0);
buf.push(class_name.to_string());
buf.push(state_count.to_string());
let count_slot = buf.len();
buf.push(String::new());
let before = buf.len();
c.meta_state_into(buf);
let appended = buf.len() - before;
buf[count_slot] = appended.to_string();
}
}
fn set_meta_state(&mut self, _m: &[String]) -> Result<()> {
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"CacheList::set_meta_state (direct call invalid)",
"must reconstruct via from_state(\"CacheList\", state, meta) (Swift: CacheList.fromState)",
)))
}
fn is_trimmable(&self) -> bool {
self.caches.iter().all(|c| c.is_trimmable())
}
fn trim(&mut self, n: usize) -> Result<usize> {
if !self.is_trimmable() {
return Ok(0);
}
let mut last = 0;
for c in &mut self.caches {
last = c.trim(n)?;
}
Ok(last)
}
fn make_mask(
&self,
_n: usize,
_window_size: Option<usize>,
_return_array: bool,
) -> Result<MaskMode> {
Err(Error::InvariantViolation(InvariantViolationPayload::new(
"CacheList::make_mask (composite is never masked directly)",
"must mask per child via CacheList::get (mlx-lm CacheList/_BaseCache define no make_mask; masking is per child)",
)))
}
fn nbytes(&self) -> usize {
self.caches.iter().map(|c| c.nbytes()).sum()
}
fn is_empty(&self) -> bool {
match self.caches.first() {
Some(c) => c.is_empty(),
None => true,
}
}
fn copy(&self) -> Result<Box<dyn KvCache>> {
let mut copied = Vec::with_capacity(self.caches.len());
for c in &self.caches {
copied.push(c.copy()?);
}
Ok(Box::new(Self { caches: copied }))
}
fn as_cache_list(&self) -> Option<&CacheList> {
Some(self)
}
fn as_cache_list_mut(&mut self) -> Option<&mut CacheList> {
Some(self)
}
fn state_count(&self) -> Result<usize> {
let mut total = 0usize;
for c in &self.caches {
total = total
.checked_add(c.state_count()?)
.ok_or(Error::ArithmeticOverflow(ArithmeticOverflowPayload::new(
"CacheList::state_count",
"usize",
)))?;
}
Ok(total)
}
fn reference_class_name(&self) -> &'static str {
"CacheList"
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
#[allow(clippy::wrong_self_convention)] fn from_serialized(&mut self, state: Vec<Array>, meta: &[String]) -> Result<()> {
let children = build_cache_list_children(state, meta, CACHE_LIST_MAX_NESTING_DEPTH)?;
*self = CacheList::new(children);
Ok(())
}
}
pub(crate) fn cache_list_from_state(
state: Vec<Array>,
meta: &[String],
) -> Result<Box<dyn KvCache>> {
cache_list_from_state_bounded(state, meta, CACHE_LIST_MAX_NESTING_DEPTH)
}
/// Maximum number of nested `CacheList` levels accepted when deserializing.
///
/// Chains deeper than this are rejected with `Error::CapExceeded` (as forged or corrupt input)
/// instead of recursing without bound.
const CACHE_LIST_MAX_NESTING_DEPTH: usize = 64;

/// Rebuilds a boxed `CacheList` from `state`/`meta`, allowing at most `depth_budget` `CacheList`
/// levels counting this one (a budget of 0 is rejected immediately).
fn cache_list_from_state_bounded(
    state: Vec<Array>,
    meta: &[String],
    depth_budget: usize,
) -> Result<Box<dyn KvCache>> {
    build_cache_list_children(state, meta, depth_budget)
        .map(|children| Box::new(CacheList::new(children)) as Box<dyn KvCache>)
}
fn build_cache_list_children(
state: Vec<Array>,
meta: &[String],
depth_budget: usize,
) -> Result<Vec<Box<dyn KvCache>>> {
let Some(child_depth_budget) = depth_budget.checked_sub(1) else {
return Err(Error::CapExceeded(CapExceededPayload::new(
"CacheList::from_state: nesting depth (deeper chain rejected as a forged/corrupt prompt cache, not a stack-overflow abort)",
"CACHE_LIST_MAX_NESTING_DEPTH",
CACHE_LIST_MAX_NESTING_DEPTH as u64,
CACHE_LIST_MAX_NESTING_DEPTH as u64,
)));
};
let first = meta.first().ok_or_else(|| {
Error::InvariantViolation(InvariantViolationPayload::new(
"CacheList::from_state: meta_state",
"must be non-empty (first element is child count)",
))
})?;
let child_count: usize = first.parse().map_err(|e: std::num::ParseIntError| {
Error::Parse(ParsePayload::new(
"CacheList::from_state: child count",
"usize",
Box::new(e),
))
})?;
let max_children = meta.len().saturating_sub(1) / 3;
if child_count > max_children {
return Err(Error::CapExceeded(CapExceededPayload::new(
"CacheList::from_state: child count (3 framing fields per child)",
"max_children_for_meta",
max_children as u64,
child_count as u64,
)));
}
let mut children: Vec<Box<dyn KvCache>> = Vec::new();
let mut meta_idx = 1usize; let mut state_it = state.into_iter();
let mut state_remaining = state_it.len();
for child in 0..child_count {
let layer = |inner: Error| -> Error {
Error::LayerKeyed(LayerKeyedPayload::new(
format_smolstr!("child {child}"),
inner,
))
};
if meta_idx + 2 >= meta.len() {
return Err(layer(Error::InvariantViolation(
InvariantViolationPayload::new(
"CacheList::from_state: meta_state truncated at child frame (need class/state/meta counts)",
"must have at least 3 meta entries remaining for each child frame",
),
)));
}
let class_name: &str = &meta[meta_idx];
let state_count: usize = meta[meta_idx + 1]
.parse()
.map_err(|e: std::num::ParseIntError| {
layer(Error::Parse(ParsePayload::new(
"CacheList::from_state: child stateCount",
"usize",
Box::new(e),
)))
})?;
let meta_count: usize = meta[meta_idx + 2]
.parse()
.map_err(|e: std::num::ParseIntError| {
layer(Error::Parse(ParsePayload::new(
"CacheList::from_state: child metaCount",
"usize",
Box::new(e),
)))
})?;
meta_idx += 3;
let meta_end = meta_idx.checked_add(meta_count).ok_or_else(|| {
layer(Error::ArithmeticOverflow(
ArithmeticOverflowPayload::with_operands(
"CacheList::from_state: meta_idx + metaCount",
"usize",
[
("meta_idx", meta_idx as u64),
("metaCount", meta_count as u64),
],
),
))
})?;
if meta_end > meta.len() {
return Err(layer(Error::LengthMismatch(LengthMismatchPayload::new(
"CacheList::from_state: child metaCount exceeds remaining meta values",
meta.len().saturating_sub(meta_idx),
meta_count,
))));
}
let child_meta = &meta[meta_idx..meta_end];
meta_idx = meta_end;
if state_count > state_remaining {
return Err(layer(Error::LengthMismatch(LengthMismatchPayload::new(
"CacheList::from_state: child stateCount exceeds remaining state arrays",
state_remaining,
state_count,
))));
}
let child_state: Vec<Array> = state_it.by_ref().take(state_count).collect();
state_remaining -= state_count;
let child_cache = if class_name == "CacheList" {
cache_list_from_state_bounded(child_state, child_meta, child_depth_budget)?
} else {
super::from_state(class_name, child_state, child_meta)?
};
children.push(child_cache);
}
if state_remaining != 0 {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"CacheList::from_state: state array consumption after all children (framing/payload mismatch)",
0,
state_remaining,
)));
}
if meta_idx != meta.len() {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"CacheList::from_state: meta value consumption after all children (framing/payload mismatch)",
meta.len(),
meta_idx,
)));
}
Ok(children)
}