use crate::{
array::Array,
error::{
ArithmeticOverflowPayload, Error, LengthMismatchPayload, OutOfRangePayload, ParsePayload,
Result,
},
lm::cache::{
KvCache, MaskMode, mask,
util::{KV_NDIM, concat_seq, nbytes, seq_len, seq_slice},
},
ops,
};
use smol_str::format_smolstr;
/// Key/value cache whose backing buffers can be trimmed from the front.
///
/// Write positions inside the buffers are computed as `offset - start_position`
/// (see `update`), so `offset` is an absolute position and `start_position` is the
/// amount that `maybe_trim_front` has already removed from the front.
pub struct ChunkedKvCache {
/// Backing key buffer; `None` until `update` or `set_state` stores one.
keys: Option<Array>,
/// Backing value buffer; stored and cleared together with `keys`.
values: Option<Array>,
/// Absolute position counter: advanced by `update`, reduced by `trim`,
/// and reset to the key sequence length by `set_state`.
offset: usize,
/// Length threshold used by `maybe_trim_front`; `None` makes that method a no-op.
chunk_size: Option<usize>,
/// Number of leading positions dropped so far by `maybe_trim_front`
/// (also restorable through `set_meta_state`).
start_position: usize,
}
/// Growth granularity along the sequence axis: `update` allocates new buffer space
/// in whole multiples of this many positions.
const CHUNKED_STEP: usize = 256;
impl ChunkedKvCache {
/// Creates an empty cache with no buffers and both position counters at zero.
///
/// `chunk_size` is the threshold later consulted by `maybe_trim_front`;
/// pass `None` to disable front trimming.
pub fn new(chunk_size: Option<usize>) -> Self {
    Self {
        chunk_size,
        offset: 0,
        start_position: 0,
        keys: None,
        values: None,
    }
}
pub fn maybe_trim_front(&mut self) -> Result<()> {
let chunk_size = match self.chunk_size {
Some(c) => c,
None => return Ok(()),
};
let (k, v) = match (&self.keys, &self.values) {
(Some(k), Some(v)) => (k, v),
_ => return Ok(()),
};
let buf_len = seq_len("keys", k)?;
let v_len = seq_len("values", v)?;
if buf_len >= chunk_size {
let added = buf_len - chunk_size;
let new_start = self.start_position.checked_add(added).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"ChunkedKvCache::maybe_trim_front: start_position + added",
"usize",
[
("start_position", self.start_position as u64),
("added", added as u64),
],
))
})?;
let k_start = if chunk_size == 0 {
0
} else {
buf_len - chunk_size
};
let v_start = if chunk_size == 0 {
0
} else {
v_len.saturating_sub(chunk_size)
};
let new_keys = seq_slice(k, k_start, buf_len)?;
let new_values = seq_slice(v, v_start, v_len)?;
self.start_position = new_start;
self.keys = Some(new_keys);
self.values = Some(new_values);
}
Ok(())
}
fn set_seq(name: &str, buf: &Array, a: usize, s: usize, new: &Array) -> Result<Array> {
let l = seq_len(name, buf)?;
let end = a.checked_add(s).ok_or_else(|| {
let context: &'static str = match name {
"keys" => "ChunkedKvCache::set_seq: keys write start + S",
"values" => "ChunkedKvCache::set_seq: values write start + S",
_ => "ChunkedKvCache::set_seq: write start + S",
};
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
context,
"usize",
[("start", a as u64), ("S", s as u64)],
))
})?;
if end > l {
let context: &'static str = match name {
"keys" => "ChunkedKvCache::set_seq: keys write window end (extends past buffer length)",
"values" => "ChunkedKvCache::set_seq: values write window end (extends past buffer length)",
_ => "ChunkedKvCache::set_seq: write window end (extends past buffer length)",
};
return Err(Error::OutOfRange(OutOfRangePayload::new(
context,
"must be <= buffer length L",
format_smolstr!("start={a}, end={end}, L={l}"),
)));
}
let new = super::util::broadcast_write_rhs(name, buf, a, end, new)?;
let head = seq_slice(buf, 0, a)?;
let tail = seq_slice(buf, end, l)?;
super::util::concat_parts(&[&head, &new, &tail])
}
}
impl KvCache for ChunkedKvCache {
/// Returns the current value of the `offset` field (advanced by `update`, reduced by `trim`).
fn offset(&self) -> usize {
self.offset
}
/// Writes `keys` / `values` into the cache at buffer index `offset - start_position`
/// and returns the populated prefix of both buffers.
///
/// When the write would not fit (or no buffers exist yet), fresh zero-filled space is
/// allocated in whole multiples of `CHUNKED_STEP` positions and appended to the
/// existing buffers. `offset` is then advanced by the incoming sequence length.
///
/// # Errors
///
/// Returns `Error::ArithmeticOverflow` when any position computation would overflow
/// or underflow `usize` (including `start_position > offset`), and propagates errors
/// from the array helpers. Cache fields are only assigned after every fallible step
/// has succeeded.
///
/// # Panics
///
/// Panics if `keys` is stored but `values` is not while no allocation is needed;
/// the methods in this file always store the two together.
fn update(&mut self, keys: &Array, values: &Array) -> Result<(Array, Array)> {
    let s = seq_len("keys", keys)?;
    // Only the error path matters here; the values length itself is unused.
    let _vs = seq_len("values", values)?;
    let ks = keys.shape();
    let vshape = values.shape();
    let (b, n_kv_heads, k_head_dim) = (ks[0], ks[1], ks[KV_NDIM - 1]);
    let v_head_dim = vshape[KV_NDIM - 1];
    // Buffer-relative index at which the new entries start.
    let prev = self
        .offset
        .checked_sub(self.start_position)
        .ok_or_else(|| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "ChunkedKvCache::update: offset - start_position (start_position must not exceed offset)",
                "usize",
                [
                    ("offset", self.offset as u64),
                    ("start_position", self.start_position as u64),
                ],
            ))
        })?;
    let cur_buf = match &self.keys {
        Some(k) => Some(seq_len("keys", k)?),
        None => None,
    };
    let prev_plus_s = prev.checked_add(s).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "ChunkedKvCache::update: prev + S",
            "usize",
            [("prev", prev as u64), ("S", s as u64)],
        ))
    })?;
    let need_alloc = match cur_buf {
        None => true,
        Some(buf_len) => prev_plus_s > buf_len,
    };
    let (buf_k, buf_v) = if need_alloc {
        // ceil(s / CHUNKED_STEP) without the intermediate `CHUNKED_STEP + s` sum,
        // so this step cannot overflow; the multiplication below stays checked.
        let n_steps = s.div_ceil(CHUNKED_STEP);
        let total = n_steps.checked_mul(CHUNKED_STEP).ok_or_else(|| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                "ChunkedKvCache::update: n_steps * step",
                "usize",
                [("n_steps", n_steps as u64), ("step", CHUNKED_STEP as u64)],
            ))
        })?;
        let new_k = ops::misc::astype(
            &Array::zeros::<f32>(&(b, n_kv_heads, total, k_head_dim))?,
            keys.dtype()?,
        )?;
        let new_v = ops::misc::astype(
            &Array::zeros::<f32>(&(b, n_kv_heads, total, v_head_dim))?,
            values.dtype()?,
        )?;
        match (&self.keys, &self.values) {
            (Some(pk), Some(pv)) => {
                // When the used length is not step-aligned, cut the old buffers back to
                // `prev` before appending the new block.
                let (bk_owned, bv_owned) = if prev % CHUNKED_STEP != 0 {
                    (Some(seq_slice(pk, 0, prev)?), Some(seq_slice(pv, 0, prev)?))
                } else {
                    (None, None)
                };
                let bk_ref: &Array = bk_owned.as_ref().unwrap_or(pk);
                let bv_ref: &Array = bv_owned.as_ref().unwrap_or(pv);
                (concat_seq(bk_ref, &new_k)?, concat_seq(bv_ref, &new_v)?)
            }
            _ => (new_k, new_v),
        }
    } else {
        let pk = self
            .keys
            .as_ref()
            .expect("need_alloc=false implies self.keys is Some");
        let pv = self
            .values
            .as_ref()
            .expect("need_alloc=false implies self.values is Some");
        (pk.try_clone()?, pv.try_clone()?)
    };
    let new_offset = self.offset.checked_add(s).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "ChunkedKvCache::update: offset + S",
            "usize",
            [("offset", self.offset as u64), ("S", s as u64)],
        ))
    })?;
    // Buffer-relative index one past the last written entry.
    let end = new_offset.checked_sub(self.start_position).ok_or_else(|| {
        Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
            "ChunkedKvCache::update: new_offset - start_position (start_position must not exceed new_offset)",
            "usize",
            [
                ("new_offset", new_offset as u64),
                ("start_position", self.start_position as u64),
            ],
        ))
    })?;
    let nk = Self::set_seq("keys", &buf_k, prev, s, keys)?;
    let nv = Self::set_seq("values", &buf_v, prev, s, values)?;
    let ret_k = seq_slice(&nk, 0, end)?;
    let ret_v = seq_slice(&nv, 0, end)?;
    // Commit state last so a failure above leaves the cache unchanged.
    self.offset = new_offset;
    self.keys = Some(nk);
    self.values = Some(nv);
    Ok((ret_k, ret_v))
}
/// Returns `[keys, values]` limited to the first `offset` sequence positions,
/// or an empty vector when no buffers are stored.
///
/// When `offset` equals the key buffer length the buffers are cloned as-is;
/// otherwise both are sliced to `0..offset`.
///
/// NOTE(review): `offset` is compared and sliced against the buffer length without
/// subtracting `start_position`; confirm this is intended after `maybe_trim_front`
/// has advanced `start_position`.
fn state(&self) -> Result<Vec<Array>> {
    let (Some(k), Some(v)) = (&self.keys, &self.values) else {
        return Ok(Vec::new());
    };
    let key_len = seq_len("keys", k)?;
    seq_len("values", v)?;

    let (out_k, out_v) = if self.offset == key_len {
        (k.try_clone()?, v.try_clone()?)
    } else {
        (seq_slice(k, 0, self.offset)?, seq_slice(v, 0, self.offset)?)
    };
    Ok(vec![out_k, out_v])
}
/// Calls `eval` on the key buffer and then the value buffer, whichever are present.
///
/// Stops at, and returns, the first error reported by `eval`.
fn materialize(&mut self) -> Result<()> {
    for slot in [&mut self.keys, &mut self.values] {
        if let Some(buffer) = slot {
            buffer.eval()?;
        }
    }
    Ok(())
}
fn set_state(&mut self, mut state: Vec<Array>) -> Result<()> {
match state.len() {
2 => {
let values = state.pop().unwrap();
let keys = state.pop().unwrap();
let sk = seq_len("keys", &keys)?;
seq_len("values", &values)?;
self.keys = Some(keys);
self.values = Some(values);
self.offset = sk;
Ok(())
}
n => Err(Error::LengthMismatch(LengthMismatchPayload::new(
"ChunkedKvCache::set_state: state arrays (its setter unpacks keys, values = v; empty/other is invalid)",
2,
n,
))),
}
}
/// Serializes the non-array state as `[chunk_size, start_position]`.
///
/// An absent `chunk_size` is written as the literal string `"None"`.
fn meta_state(&self) -> Vec<String> {
    let chunk_field = self
        .chunk_size
        .map_or_else(|| "None".to_string(), |c| c.to_string());
    let start_field = self.start_position.to_string();
    vec![chunk_field, start_field]
}
/// Restores `chunk_size` and `start_position` from the two strings produced by `meta_state`.
///
/// The first entry is either the literal `"None"` or a `usize`; the second is a `usize`.
///
/// # Errors
///
/// `Error::LengthMismatch` when `m` does not hold exactly two entries and
/// `Error::Parse` when either number fails to parse. Both fields are assigned only
/// after both entries have parsed successfully.
fn set_meta_state(&mut self, m: &[String]) -> Result<()> {
    let [chunk_field, start_field] = m else {
        return Err(Error::LengthMismatch(LengthMismatchPayload::new(
            "ChunkedKvCache::set_meta_state: meta_state values",
            2,
            m.len(),
        )));
    };

    let parsed_chunk = match chunk_field.as_str() {
        "None" => None,
        digits => {
            let value = digits
                .parse::<usize>()
                .map_err(|e: std::num::ParseIntError| {
                    Error::Parse(ParsePayload::new(
                        "ChunkedKvCache::set_meta_state: chunk_size",
                        "usize",
                        Box::new(e),
                    ))
                })?;
            Some(value)
        }
    };

    let parsed_start = start_field
        .parse::<usize>()
        .map_err(|e: std::num::ParseIntError| {
            Error::Parse(ParsePayload::new(
                "ChunkedKvCache::set_meta_state: start_position",
                "usize",
                Box::new(e),
            ))
        })?;

    self.chunk_size = parsed_chunk;
    self.start_position = parsed_start;
    Ok(())
}
/// Always reports `true`; see `trim` for the trimming behavior.
fn is_trimmable(&self) -> bool {
true
}
fn trim(&mut self, n: usize) -> Result<usize> {
let span = self
.offset
.checked_sub(self.start_position)
.ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"ChunkedKvCache::trim: offset - start_position (start_position must not exceed offset)",
"usize",
[
("offset", self.offset as u64),
("start_position", self.start_position as u64),
],
))
})?;
let trimmed = span.min(n);
self.offset -= trimmed;
Ok(trimmed)
}
/// Builds an attention mask by delegating to `mask::create_attention_mask`,
/// forwarding `n`, this cache's `offset()`, `return_array` and `window_size`.
fn make_mask(
&self,
n: usize,
window_size: Option<usize>,
return_array: bool,
) -> Result<MaskMode> {
mask::create_attention_mask(n, self.offset(), return_array, window_size)
}
/// Sums the byte sizes reported by the `nbytes` helper for the stored buffers.
///
/// A missing buffer, or a buffer whose size query fails, contributes zero.
fn nbytes(&self) -> usize {
    [&self.keys, &self.values]
        .iter()
        .filter_map(|slot| slot.as_ref())
        .map(|buffer| nbytes(buffer).unwrap_or(0))
        .sum()
}
/// Returns `true` while no key buffer is stored (only `keys` is inspected).
fn is_empty(&self) -> bool {
self.keys.is_none()
}
/// Returns a boxed duplicate of this cache.
///
/// Present buffers are duplicated with `try_clone`; the scalar fields are copied.
///
/// # Errors
///
/// Propagates the first `try_clone` failure (keys are cloned before values).
fn copy(&self) -> Result<Box<dyn KvCache>> {
    let keys = self.keys.as_ref().map(|a| a.try_clone()).transpose()?;
    let values = self.values.as_ref().map(|a| a.try_clone()).transpose()?;
    let duplicate = Self {
        keys,
        values,
        offset: self.offset,
        chunk_size: self.chunk_size,
        start_position: self.start_position,
    };
    Ok(Box::new(duplicate))
}
/// Returns the fixed identifier `"ChunkedKVCache"`.
// NOTE(review): presumably the class name used by the reference implementation this
// cache mirrors (note the different capitalization from the Rust type) — confirm.
fn reference_class_name(&self) -> &'static str {
"ChunkedKVCache"
}
/// Exposes `self` as `&mut dyn Any` so callers can downcast to the concrete cache type.
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn from_serialized(&mut self, state: Vec<Array>, meta: &[String]) -> Result<()> {
let mut staged = ChunkedKvCache::new(None);
staged.set_state(state)?;
staged.set_meta_state(meta)?;
*self = staged;
Ok(())
}
}
#[cfg(test)]
mod tests;