use std::str::FromStr;
use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, CapExceededPayload, Error, InvariantViolationPayload,
OutOfRangePayload, ParsePayload, RankMismatchPayload, Result,
},
lm::cache::{KvCache, MaskMode, mask},
ops,
};
use smol_str::format_smolstr;
/// Largest slot count accepted when restoring an `ArraysCache` from
/// serialized meta state; larger values are rejected as corrupt/forged input.
pub const MAX_SLOT_COUNT: usize = 1_usize << 20;
fn parse_csv<T>(
s: &str,
what_context: &'static str,
what_input_kind: &'static str,
max_elems: usize,
) -> Result<Vec<T>>
where
T: FromStr,
T::Err: std::error::Error + Send + Sync + 'static,
{
if s.is_empty() {
return Ok(Vec::new());
}
let comma_count = s.bytes().filter(|&b| b == b',').count();
let upper_bound = comma_count.saturating_add(1);
if upper_bound > max_elems {
return Err(Error::CapExceeded(CapExceededPayload::new(
what_context,
"max_elems",
max_elems as u64,
upper_bound as u64,
)));
}
let mut out: Vec<T> = Vec::new();
out
.try_reserve_exact(upper_bound)
.map_err(|_| Error::OutOfMemory)?;
for p in s.split(',') {
let v = p.parse::<T>().map_err(|e| {
Error::Parse(ParsePayload::new(
what_context,
what_input_kind,
Box::new(e),
))
})?;
out.push(v);
}
Ok(out)
}
/// Fixed-size slot cache in which every slot optionally holds one `Array`.
///
/// Unlike a K/V cache it does not support `update`; callers read and write
/// slots directly via `get`/`set`. Optional per-batch-row vectors drive mask
/// construction for batched inputs.
pub struct ArraysCache {
    /// One entry per slot; `None` marks a slot that has not been written.
    cache: Vec<Option<Array>>,
    /// Per-batch-row left padding, present only for left-padded batches.
    left_padding: Option<Vec<i32>>,
    /// Per-batch-row lengths recorded by `prepare` and cleared by `finalize`.
    lengths: Option<Vec<i32>>,
    /// Whether this cache reports itself as a Mamba cache.
    is_mamba: bool,
}
impl ArraysCache {
    /// Creates a generic slot cache with `size` empty slots and no
    /// padding/length bookkeeping.
    pub fn new(size: usize) -> Self {
        Self {
            cache: (0..size).map(|_| None).collect(),
            left_padding: None,
            lengths: None,
            is_mamba: false,
        }
    }

    /// Creates a two-slot cache flagged as a Mamba cache (it reports
    /// `"MambaCache"` as its reference class name).
    pub fn mamba() -> Self {
        Self {
            cache: vec![None, None],
            left_padding: None,
            lengths: None,
            is_mamba: true,
        }
    }

    /// Creates a cache with `size` empty slots and the given per-row left
    /// padding. An empty `left_padding` slice means "no padding".
    pub fn with_left_padding(size: usize, left_padding: &[i32]) -> Self {
        let mut c = Self::new(size);
        if !left_padding.is_empty() {
            c.left_padding = Some(left_padding.to_vec());
        }
        c
    }

    /// Returns `true` if this cache was constructed as a Mamba cache.
    pub fn is_mamba(&self) -> bool {
        self.is_mamba
    }

    /// Rebuilds a cache from serialized `state` arrays and `meta` strings
    /// (the format written by `meta_state`), tagging it with `is_mamba`.
    fn build_from_serialized(state: Vec<Array>, meta: &[String], is_mamba: bool) -> Result<Self> {
        let mut c = Self::new(0);
        c.set_state(state)?;
        c.set_meta_state(meta)?;
        c.is_mamba = is_mamba;
        Ok(c)
    }

    /// Returns the array in slot `idx`, or `None` if the slot is empty or
    /// `idx` is out of range.
    pub fn get(&self, idx: usize) -> Option<&Array> {
        self.cache.get(idx).and_then(|s| s.as_ref())
    }

    /// Stores `value` in slot `idx`, replacing any previous array.
    ///
    /// # Errors
    /// `Error::OutOfRange` if `idx` is not less than the number of slots.
    pub fn set(&mut self, idx: usize, value: Array) -> Result<()> {
        match self.cache.get_mut(idx) {
            Some(slot) => {
                *slot = Some(value);
                Ok(())
            }
            None => Err(Error::OutOfRange(OutOfRangePayload::new(
                "ArraysCache::set: slot index (must be < cache size)",
                "must be < cache size",
                format_smolstr!("idx={idx}, size={}", self.cache.len()),
            ))),
        }
    }

    /// Returns the batch size.
    ///
    /// Resolution order: the leading dimension of the first populated slot,
    /// then the length of `left_padding`, then the length of `lengths`,
    /// falling back to `1`.
    ///
    /// # Errors
    /// `Error::RankMismatch` if the first populated slot is rank 0.
    pub fn batch_size(&self) -> Result<usize> {
        if let Some(slot) = self.cache.iter().flatten().next() {
            let shape = slot.shape();
            return match shape.first() {
                Some(&b) => Ok(b),
                None => Err(Error::RankMismatch(RankMismatchPayload::new(
                    "ArraysCache::batch_size: slot must have rank >= 1 (a leading axis to read as batch size)",
                    0,
                    shape.to_vec(),
                ))),
            };
        }
        if let Some(lp) = &self.left_padding {
            return Ok(lp.len());
        }
        if let Some(l) = &self.lengths {
            return Ok(l.len());
        }
        Ok(1)
    }

    /// Records per-row sequence lengths; `make_mask` uses them when no left
    /// padding is set.
    pub fn prepare(&mut self, lengths: &[i32]) {
        self.lengths = Some(lengths.to_vec());
    }

    /// Clears both the per-row lengths and the left padding.
    pub fn finalize(&mut self) {
        self.lengths = None;
        self.left_padding = None;
    }

    /// Shifts the per-row `lengths` and `left_padding` back by `n` positions.
    ///
    /// Values may legitimately become negative. They saturate at `i32::MIN`
    /// instead of wrapping: a wrap would turn a very negative value into a
    /// large positive one and invert the mask built by `make_mask`.
    ///
    /// # Errors
    /// `Error::ArithmeticOverflow` if `n` does not fit in an `i32`.
    pub fn advance(&mut self, n: usize) -> Result<()> {
        let n = i32::try_from(n).map_err(|_| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "ArraysCache::advance: N exceeds i32::MAX",
                "i32",
                [("N", n as u64)],
            ))
        })?;
        if let Some(l) = &mut self.lengths {
            for v in l.iter_mut() {
                *v = v.saturating_sub(n);
            }
        }
        if let Some(lp) = &mut self.left_padding {
            for v in lp.iter_mut() {
                *v = v.saturating_sub(n);
            }
        }
        Ok(())
    }

    /// Returns the per-row left padding, if any.
    pub fn left_padding(&self) -> Option<&[i32]> {
        self.left_padding.as_deref()
    }

    /// Returns the per-row lengths recorded by `prepare`, if any.
    pub fn lengths(&self) -> Option<&[i32]> {
        self.lengths.as_deref()
    }
}
impl KvCache for ArraysCache {
    /// Always `0`: slot caches keep no running token offset.
    fn offset(&self) -> usize {
        0
    }

    /// Unsupported for slot caches; always fails with
    /// `Error::InvariantViolation`. Use `get`/`set`/`state` instead.
    fn update(&mut self, _keys: &Array, _values: &Array) -> Result<(Array, Array)> {
        Err(Error::InvariantViolation(InvariantViolationPayload::new(
            "ArraysCache::update (generic slot cache, not K/V)",
            "must use get/set/state instead; update_and_fetch is unsupported",
        )))
    }

    /// Returns clones of the populated slots in slot order, skipping empty
    /// slots. The slot index of each array is recorded by `meta_state`.
    fn state(&self) -> Result<Vec<Array>> {
        self.cache.iter().flatten().map(|a| a.try_clone()).collect()
    }

    /// Calls `eval()` on every populated slot, stopping at the first error.
    fn materialize(&mut self) -> Result<()> {
        for slot in self.cache.iter_mut().flatten() {
            slot.eval()?;
        }
        Ok(())
    }

    /// Replaces the slots with `state`, one populated slot per array, in
    /// order. A subsequent `set_meta_state` restores the original slot layout,
    /// including empty slots.
    fn set_state(&mut self, state: Vec<Array>) -> Result<()> {
        self.cache = state.into_iter().map(Some).collect();
        Ok(())
    }

    /// Serializes the slot layout as `[slotCount, presentSlots, leftPadding?]`.
    ///
    /// `presentSlots` is a comma-separated list of the indices of populated
    /// slots. `leftPadding` is the comma-separated per-row left padding and is
    /// emitted only when padding is set.
    fn meta_state(&self) -> Vec<String> {
        let present: Vec<String> = self
            .cache
            .iter()
            .enumerate()
            .filter_map(|(i, s)| s.as_ref().map(|_| i.to_string()))
            .collect();
        let mut out = vec![self.cache.len().to_string(), present.join(",")];
        if let Some(lp) = &self.left_padding {
            out.push(lp.iter().map(i32::to_string).collect::<Vec<_>>().join(","));
        }
        out
    }

    /// Restores the slot layout and left padding from `m`, which uses the
    /// `[slotCount, presentSlots, leftPadding?]` format written by `meta_state`.
    ///
    /// Call this after `set_state`. The arrays currently held are
    /// redistributed, in order, to the slot indices listed in `presentSlots`.
    /// Entries that are `>= slotCount`, or that have no corresponding array,
    /// are ignored. An empty `m`, or a single empty string, leaves the cache
    /// untouched. An empty `leftPadding` field means "no padding".
    ///
    /// # Errors
    /// - `Error::Parse` for a malformed slot count or list entry.
    /// - `Error::InvariantViolation` if `m` has fewer than two entries.
    /// - `Error::CapExceeded` if the slot count exceeds `MAX_SLOT_COUNT`, or a
    ///   list is longer than its cap.
    /// - `Error::ArithmeticOverflow` / `Error::OutOfMemory` if the slot vector
    ///   cannot be allocated.
    ///
    /// On error the cache is left unchanged.
    fn set_meta_state(&mut self, m: &[String]) -> Result<()> {
        // Fail-fast bound for the serialized left-padding list. Its length is
        // the batch size, which is independent of the slot count (e.g. a
        // 2-slot Mamba cache may serve a batch of 8). It therefore gets its own
        // generous cap instead of `slot_count`.
        const MAX_LEFT_PADDING_LEN: usize = 1 << 20;

        if m.is_empty() || (m.len() == 1 && m[0].is_empty()) {
            return Ok(());
        }
        let slot_count: usize = m[0].parse().map_err(|e: std::num::ParseIntError| {
            Error::Parse(ParsePayload::new(
                "ArraysCache meta_state slotCount",
                "usize",
                Box::new(e),
            ))
        })?;
        if m.len() < 2 {
            return Err(Error::InvariantViolation(InvariantViolationPayload::new(
                "ArraysCache::set_meta_state: slot-aware meta_state",
                "must be [slotCount, presentSlots, leftPadding?] (length >= 2)",
            )));
        }
        if slot_count > MAX_SLOT_COUNT {
            return Err(Error::CapExceeded(CapExceededPayload::new(
                "ArraysCache::set_meta_state: slot_count (realistic SSM/Mamba caches use <= 64 slots; the cap fails fast on forged/corrupt meta_state)",
                "MAX_SLOT_COUNT",
                MAX_SLOT_COUNT as u64,
                slot_count as u64,
            )));
        }
        let elem_size = std::mem::size_of::<Option<Array>>().max(1);
        if slot_count > (isize::MAX as usize) / elem_size {
            return Err(Error::ArithmeticOverflow(
                ArithmeticOverflowPayload::with_operands(
                    "ArraysCache::set_meta_state: slot_count * sizeof::<Option<Array>>() (capacity overflow exceeds isize::MAX)",
                    "isize",
                    [
                        ("slot_count", slot_count as u64),
                        ("elem_size", elem_size as u64),
                    ],
                ),
            ));
        }
        // Each populated slot appears at most once, so `slot_count` bounds this list.
        let present = parse_csv::<usize>(
            &m[1],
            "ArraysCache meta_state presentSlots",
            "CSV<usize>",
            slot_count,
        )?;
        // An empty field means "no padding" (matching `with_left_padding`),
        // not a zero-row padding vector.
        let left_padding = match m.get(2).filter(|s| !s.is_empty()) {
            Some(s) => Some(parse_csv::<i32>(
                s,
                "ArraysCache meta_state leftPadding",
                "CSV<i32>",
                MAX_LEFT_PADDING_LEN,
            )?),
            None => None,
        };
        let mut rebuilt: Vec<Option<Array>> = Vec::new();
        rebuilt
            .try_reserve_exact(slot_count)
            .map_err(|_| Error::OutOfMemory)?;
        rebuilt.resize_with(slot_count, || None);
        // All fallible work is done; only now is the current state consumed.
        let mut arrays: Vec<Option<Array>> = std::mem::take(&mut self.cache);
        for (array_idx, &slot_idx) in present.iter().enumerate() {
            if slot_idx < slot_count
                && let Some(a) = arrays.get_mut(array_idx).and_then(Option::take)
            {
                rebuilt[slot_idx] = Some(a);
            }
        }
        self.cache = rebuilt;
        self.left_padding = left_padding;
        Ok(())
    }

    /// Builds the mask for a window of `n` positions.
    ///
    /// With left padding, position `p` of row `r` is kept when
    /// `p >= left_padding[r]`. Otherwise, with lengths, it is kept when
    /// `p < lengths[r]`. With neither, `MaskMode::None` is returned. Left
    /// padding takes precedence. `_window_size` and `_return_array` are ignored.
    fn make_mask(
        &self,
        n: usize,
        _window_size: Option<usize>,
        _return_array: bool,
    ) -> Result<MaskMode> {
        // Reshape a per-row vector into a (rows, 1) column so it broadcasts
        // against the position range.
        let col = |v: &[i32]| -> Result<Array> { Array::from_slice::<i32>(v, &(v.len(), 1usize)) };
        if let Some(lp) = &self.left_padding {
            let pos = mask::iarange(0, n)?;
            return Ok(MaskMode::Array(ops::comparison::greater_equal(
                &pos,
                &col(lp)?,
            )?));
        }
        if let Some(l) = &self.lengths {
            let pos = mask::iarange(0, n)?;
            return Ok(MaskMode::Array(ops::comparison::less(&pos, &col(l)?)?));
        }
        Ok(MaskMode::None)
    }

    /// Total byte size of the populated slots. A slot whose size cannot be
    /// determined contributes `0` (best effort, since the signature is infallible).
    fn nbytes(&self) -> usize {
        self
            .cache
            .iter()
            .flatten()
            .map(|a| super::util::nbytes(a).unwrap_or(0))
            .sum()
    }

    /// Reports emptiness from the first slot only: `true` when there are no
    /// slots or slot 0 has not been written.
    fn is_empty(&self) -> bool {
        match self.cache.first() {
            Some(slot) => slot.is_none(),
            None => true,
        }
    }

    /// Returns a copy in which every populated slot is cloned with
    /// `try_clone`; padding, lengths and the Mamba flag are copied as well.
    fn copy(&self) -> Result<Box<dyn KvCache>> {
        let mut cache = Vec::with_capacity(self.cache.len());
        for slot in &self.cache {
            cache.push(match slot {
                Some(a) => Some(a.try_clone()?),
                None => None,
            });
        }
        Ok(Box::new(Self {
            cache,
            left_padding: self.left_padding.clone(),
            lengths: self.lengths.clone(),
            is_mamba: self.is_mamba,
        }))
    }

    /// `"MambaCache"` for Mamba caches, `"ArraysCache"` otherwise.
    fn reference_class_name(&self) -> &'static str {
        if self.is_mamba {
            "MambaCache"
        } else {
            "ArraysCache"
        }
    }

    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
        self
    }

    /// Replaces `self` with a cache rebuilt from `state`/`meta`, keeping the
    /// current Mamba flag. On error `self` is left unchanged.
    // A `from_*` name conventionally denotes a constructor without `self`;
    // this trait method deliberately rebuilds `self` in place.
    #[allow(clippy::wrong_self_convention)]
    fn from_serialized(&mut self, state: Vec<Array>, meta: &[String]) -> Result<()> {
        *self = ArraysCache::build_from_serialized(state, meta, self.is_mamba)?;
        Ok(())
    }
}
/// Rebuilds a boxed `ArraysCache` from serialized state arrays and meta
/// strings, tagging it as a Mamba cache when `is_mamba` is set.
///
/// # Errors
/// Propagates any error from restoring the state or meta state.
pub(super) fn from_state_arrays(
    state: Vec<Array>,
    meta: &[String],
    is_mamba: bool,
) -> Result<Box<dyn KvCache>> {
    let restored = ArraysCache::build_from_serialized(state, meta, is_mamba)?;
    let boxed: Box<dyn KvCache> = Box::new(restored);
    Ok(boxed)
}