use crate::{
array::Array,
dtype::Dtype,
error::{
ArithmeticOverflowPayload, DtypeMismatchPayload, Error, InvariantViolationPayload,
LengthMismatchPayload, OutOfRangePayload, RankMismatchPayload, Result,
ShapePairMismatchPayload,
},
lm::cache::{
BatchPositionedKvCache, KvCache, MaskMode,
util::{KV_NDIM, concat_seq, nbytes, seq_len, slice_seq},
},
ops,
};
use smol_str::format_smolstr;
pub(crate) fn batch_head_dim(name: &str, a: &Array) -> Result<usize> {
let shape = a.shape();
if shape.len() != KV_NDIM {
let context: &'static str = match name {
"keys" => "batch_head_dim: batched KV cache expects 4-D keys [B, n_kv_heads, S, head_dim]",
"values" => {
"batch_head_dim: batched KV cache expects 4-D values [B, n_kv_heads, S, head_dim]"
}
_ => "batch_head_dim: batched KV cache expects 4-D [B, n_kv_heads, S, head_dim]",
};
return Err(Error::RankMismatch(RankMismatchPayload::new(
context,
shape.len() as u32,
shape.to_vec(),
)));
}
Ok(shape[KV_NDIM - 1])
}
pub(crate) fn validate_kv_compat(keys: &Array, values: &Array) -> Result<()> {
let ks = keys.shape();
let vs = values.shape();
if ks.len() != KV_NDIM {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"batched KV cache expects 4-D keys [B, n_kv_heads, S, head_dim]",
ks.len() as u32,
ks.to_vec(),
)));
}
if vs.len() != KV_NDIM {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"batched KV cache expects 4-D values [B, n_kv_heads, S, head_dim]",
vs.len() as u32,
vs.to_vec(),
)));
}
if ks[0] != vs[0] || ks[1] != vs[1] || ks[2] != vs[2] {
return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"batched KV cache: values shape must match keys on [B, n_kv_heads, S] (head_dim free; mlx-lm raises at `self.values[..., prev:_idx, :] = values`)",
vec![ks[0], ks[1], ks[2]],
vec![vs[0], vs[1], vs[2]],
)));
}
Ok(())
}
/// Builds a 1-D `i32` [`Array`] with one element per entry of `values`.
///
/// # Errors
///
/// Propagates whatever `Array::from_slice` reports.
pub(crate) fn ivec(values: &[i32]) -> Result<Array> {
    let len = values.len();
    let shape = (len,);
    Array::from_slice::<i32>(values, &shape)
}
/// Builds a 1-D `i32` [`Array`] holding the element-wise negation of `values`.
///
/// Negation wraps on overflow (`i32::MIN` maps to itself) instead of using
/// the unary `-` operator. Plain `-i32::MIN` panics when overflow checks are
/// enabled, and the infallible `BatchKvCache::new` can only recover from a
/// returned `Err`, not from a panic. Wrapping also matches the
/// `wrapping_add` arithmetic `finalize` applies to the host-side pad lengths.
///
/// # Errors
///
/// Propagates whatever [`ivec`] reports.
fn neg_ivec(values: &[i32]) -> Result<Array> {
    let negated: Vec<i32> = values.iter().map(|&l| l.wrapping_neg()).collect();
    ivec(&negated)
}
pub fn dynamic_roll(x: &Array, shifts: &Array, axis: i32) -> Result<Array> {
let xshape = x.shape();
if xshape.len() != KV_NDIM {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"dynamic_roll: x must be 4-D [B, n_kv_heads, S, head_dim]",
xshape.len() as u32,
xshape.to_vec(),
)));
}
if axis != (KV_NDIM as i32) - 2 {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"dynamic_roll: axis (must be the sequence axis)",
"must equal KV_NDIM - 2 (the sequence axis = 2)",
format_smolstr!("{axis}"),
)));
}
let sshape = shifts.shape();
if sshape.len() != 2 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"dynamic_roll: shifts must be rank 2 ([B, 1] or scalar broadcast [1, 1])",
sshape.len() as u32,
sshape.to_vec(),
)));
}
let valid_b = sshape[0] == xshape[0] || sshape[0] == 1;
if !valid_b || sshape[1] != 1 {
return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
"dynamic_roll: shifts must be [B, 1] or [1, 1] (scalar broadcast)",
vec![xshape[0], 1usize],
sshape.to_vec(),
)));
}
let n = xshape[KV_NDIM - 2];
if n == 0 {
return x.try_clone();
}
const F32_EXACT_INT_MAX: usize = 1usize << 24;
if n > F32_EXACT_INT_MAX {
return Err(Error::OutOfRange(OutOfRangePayload::new(
"dynamic_roll: sequence axis n (arange/cast through f32 would silently alias indices and produce wrong rolls)",
"must be <= 2^24 (f32 exact-integer limit)",
format_smolstr!("{n}"),
)));
}
let ar = ops::misc::astype(&Array::arange::<f32>(0.0, n as f32, 1.0)?, Dtype::I32)?;
let ar = ops::shape::expand_dims_axes(&ar, &[1])?; let sh = ops::shape::expand_dims_axes(shifts, &[2, 3])?; let diff = ops::arithmetic::subtract(&ar, &sh)?;
let nscalar = ops::misc::astype(&Array::full::<f32>(&(1usize,), n as f32)?, Dtype::I32)?;
let idx = ops::arithmetic::remainder(&diff, &nscalar)?; ops::indexing::take_along_axis(x, &idx, axis)
}
/// Key/value cache for a batch of sequences that carry per-row padding.
///
/// Device-side bookkeeping lives in `Array` fields; `pad_lengths` and
/// `right_padding_host` are plain host vectors kept alongside them.
pub struct BatchKvCache {
/// Cached keys; `None` until the first `update` or a four-array `set_state`.
keys: Option<Array>,
/// Cached values; set and cleared together with `keys`.
values: Option<Array>,
/// Per-row left padding as an array; built from the constructor argument and
/// increased by `finalize` when right padding is folded in.
left_padding: Array,
/// Host-side copy of the per-row padding lengths, exposed by `pad_lengths()`.
pad_lengths: Vec<i32>,
/// Per-row offset array; starts as the negated left padding and is advanced
/// by `update`, reduced by `trim` and `finalize`.
offset: Array,
/// Number of cached sequence positions; this is what `KvCache::offset` returns.
idx: usize,
/// Right padding staged by `prepare_right_padding`, consumed by `finalize`.
right_padding: Option<Array>,
/// Host-side copy of the staged right padding, set together with `right_padding`.
right_padding_host: Option<Vec<i32>>,
}
impl BatchKvCache {
pub fn new(left_padding: &[i32]) -> Self {
let lp = ivec(left_padding).unwrap_or_else(|_| empty_ivec());
let offset = neg_ivec(left_padding).unwrap_or_else(|_| empty_ivec());
Self {
keys: None,
values: None,
left_padding: lp,
pad_lengths: left_padding.to_vec(),
offset,
idx: 0,
right_padding: None,
right_padding_host: None,
}
}
pub fn pad_lengths(&self) -> &[i32] {
&self.pad_lengths
}
pub fn prepare_right_padding(&mut self, right_padding: &[i32]) -> Result<()> {
if right_padding.iter().copied().max().unwrap_or(0) > 0 {
let rp = ivec(right_padding)?;
self.right_padding = Some(rp);
self.right_padding_host = Some(right_padding.to_vec());
}
Ok(())
}
pub fn finalize(&mut self) -> Result<()> {
if let Some(padding) = &self.right_padding {
let new_pad_lengths = match self.right_padding_host.as_ref() {
None => self.pad_lengths.clone(),
Some(rp_host) if rp_host.len() == self.pad_lengths.len() => self
.pad_lengths
.iter()
.zip(rp_host)
.map(|(&a, &b)| a.wrapping_add(b))
.collect::<Vec<i32>>(),
Some(rp_host) if rp_host.len() == 1 => {
let b = rp_host[0];
self
.pad_lengths
.iter()
.map(|&a| a.wrapping_add(b))
.collect::<Vec<i32>>()
}
Some(rp_host) => {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"BatchKvCache::finalize: right_padding length (must equal pad_lengths length or be a length-1 scalar broadcast — refusing to commit a desynchronized pad_lengths host mirror)",
self.pad_lengths.len(),
rp_host.len(),
)));
}
};
let pad_col = ops::shape::expand_dims_axes(padding, &[1])?;
let rolled = match (&self.keys, &self.values) {
(Some(k), Some(v)) => Some((dynamic_roll(k, &pad_col, 2)?, dynamic_roll(v, &pad_col, 2)?)),
_ => None,
};
let new_offset = ops::arithmetic::subtract(&self.offset, padding)?;
let new_left_padding = ops::arithmetic::add(&self.left_padding, padding)?;
if let Some((nk, nv)) = rolled {
self.keys = Some(nk);
self.values = Some(nv);
}
self.offset = new_offset;
self.left_padding = new_left_padding;
self.pad_lengths = new_pad_lengths;
self.right_padding = None;
self.right_padding_host = None;
}
Ok(())
}
pub fn left_padding_arr(&self) -> Result<Array> {
self.left_padding.try_clone()
}
pub fn state_kv(&self) -> Result<(Array, Array)> {
match (&self.keys, &self.values) {
(Some(k), Some(v)) => Ok((slice_seq(k, 0, self.idx)?, slice_seq(v, 0, self.idx)?)),
_ => Err(Error::InvariantViolation(InvariantViolationPayload::new(
"BatchKvCache::state_kv",
"must be called on a non-empty cache (keys/values both Some)",
))),
}
}
}
/// Produces an empty `i32` array without ever failing.
///
/// Used as the fallback when `BatchKvCache::new` cannot build its arrays.
fn empty_ivec() -> Array {
    match Array::from_slice::<i32>(&[], &(0usize,)) {
        Ok(empty) => empty,
        // NOTE(review): last-resort path wrapping a raw handle. Presumably
        // `mlx_array_new` returns a fresh, valid handle that `Array` may own
        // and later release — confirm against the `mlxrs_sys` bindings.
        Err(_) => Array(unsafe { mlxrs_sys::mlx_array_new() }),
    }
}
impl KvCache for BatchKvCache {
fn offset(&self) -> usize {
self.idx
}
fn update(&mut self, keys: &Array, values: &Array) -> Result<(Array, Array)> {
let s = seq_len("keys", keys)?;
validate_kv_compat(keys, values)?;
let (k, v) = match (&self.keys, &self.values) {
(Some(pk), Some(pv)) => (concat_seq(pk, keys)?, concat_seq(pv, values)?),
_ => (keys.try_clone()?, values.try_clone()?),
};
let new_idx = self.idx.checked_add(s).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"BatchKvCache::update: _idx + S",
"usize",
[("_idx", self.idx as u64), ("S", s as u64)],
))
})?;
let s_scalar = ops::misc::astype(&Array::full::<f32>(&(1usize,), s as f32)?, Dtype::I32)?;
let new_offset = ops::arithmetic::add(&self.offset, &s_scalar)?;
let (rk, rv) = (k.try_clone()?, v.try_clone()?);
self.offset = new_offset;
self.idx = new_idx;
self.keys = Some(k);
self.values = Some(v);
Ok((rk, rv))
}
fn state(&self) -> Result<Vec<Array>> {
match (&self.keys, &self.values) {
(Some(k), Some(v)) => Ok(vec![
slice_seq(k, 0, self.idx)?,
slice_seq(v, 0, self.idx)?,
self.offset.try_clone()?,
self.left_padding.try_clone()?,
]),
_ => Ok(Vec::new()),
}
}
fn materialize(&mut self) -> Result<()> {
if let Some(k) = self.keys.as_mut() {
k.eval()?;
}
if let Some(v) = self.values.as_mut() {
v.eval()?;
}
self.offset.eval()?;
self.left_padding.eval()?;
if let Some(rp) = self.right_padding.as_mut() {
rp.eval()?;
}
Ok(())
}
fn set_state(&mut self, mut state: Vec<Array>) -> Result<()> {
match state.len() {
0 => {
let new_offset = ops::arithmetic::negative(&self.left_padding)?;
self.keys = None;
self.values = None;
self.idx = 0;
self.offset = new_offset;
self.right_padding = None;
self.right_padding_host = None;
Ok(())
}
4 => {
let left_padding = state.pop().unwrap();
let offset = state.pop().unwrap();
let values = state.pop().unwrap();
let keys = state.pop().unwrap();
let sk = seq_len("keys", &keys)?;
batch_head_dim("values", &values)?;
let lp_shape = left_padding.shape();
let kb = keys.shape()[0];
if lp_shape.len() != 1 {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"BatchKvCache::set_state: restored left_padding must be 1-D [B]",
lp_shape.len() as u32,
lp_shape.to_vec(),
)));
}
if lp_shape[0] != kb {
return Err(Error::LengthMismatch(LengthMismatchPayload::new(
"BatchKvCache::set_state: restored left_padding length vs keys batch dim",
kb,
lp_shape[0],
)));
}
let lp_dtype = left_padding.dtype()?;
if lp_dtype != Dtype::I32 {
return Err(Error::DtypeMismatch(DtypeMismatchPayload::new(
Dtype::I32,
lp_dtype,
)));
}
let mut lp_clone = left_padding.try_clone()?;
let new_pad_lengths = lp_clone.to_vec::<i32>()?;
self.keys = Some(keys);
self.values = Some(values);
self.offset = offset;
self.left_padding = left_padding;
self.pad_lengths = new_pad_lengths;
self.idx = sk;
self.right_padding = None;
self.right_padding_host = None;
Ok(())
}
n => Err(Error::OutOfRange(OutOfRangePayload::new(
"BatchKvCache::set_state: state array count",
"must be 0 or 4",
format_smolstr!("{n}"),
))),
}
}
fn is_trimmable(&self) -> bool {
true
}
fn trim(&mut self, n: usize) -> Result<usize> {
let trimmed = n.min(self.idx);
if trimmed == 0 {
return Ok(0);
}
let new_idx = self.idx - trimmed;
let nscalar = ops::misc::astype(&Array::full::<f32>(&(1usize,), trimmed as f32)?, Dtype::I32)?;
let new_offset = ops::arithmetic::subtract(&self.offset, &nscalar)?;
let sliced = match (&self.keys, &self.values) {
(Some(k), Some(v)) => Some((slice_seq(k, 0, new_idx)?, slice_seq(v, 0, new_idx)?)),
_ => None,
};
self.idx = new_idx;
self.offset = new_offset;
if let Some((nk, nv)) = sliced {
self.keys = Some(nk);
self.values = Some(nv);
}
Ok(trimmed)
}
fn make_mask(
&self,
n: usize,
window_size: Option<usize>,
_return_array: bool,
) -> Result<MaskMode> {
Ok(MaskMode::Array(create_causal_mask_batched(
n,
self.idx,
window_size,
None,
Some(&self.left_padding),
)?))
}
fn nbytes(&self) -> usize {
let mut total = 0;
if let Some(k) = &self.keys {
total += nbytes(k).unwrap_or(0);
}
if let Some(v) = &self.values {
total += nbytes(v).unwrap_or(0);
}
total
}
fn is_empty(&self) -> bool {
self.keys.is_none()
}
fn copy(&self) -> Result<Box<dyn KvCache>> {
Ok(Box::new(Self {
keys: match &self.keys {
Some(a) => Some(a.try_clone()?),
None => None,
},
values: match &self.values {
Some(a) => Some(a.try_clone()?),
None => None,
},
left_padding: self.left_padding.try_clone()?,
pad_lengths: self.pad_lengths.clone(),
offset: self.offset.try_clone()?,
idx: self.idx,
right_padding: match &self.right_padding {
Some(a) => Some(a.try_clone()?),
None => None,
},
right_padding_host: self.right_padding_host.clone(),
}))
}
fn as_batch_positioned(&self) -> Option<&dyn BatchPositionedKvCache> {
Some(self)
}
fn reference_class_name(&self) -> &'static str {
"BatchKVCache"
}
fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
self
}
fn from_serialized(&mut self, state: Vec<Array>, meta: &[String]) -> Result<()> {
let mut staged = BatchKvCache::new(&[]);
staged.set_state(state)?;
staged.set_meta_state(meta)?;
*self = staged;
Ok(())
}
}
impl BatchPositionedKvCache for BatchKvCache {
    /// Returns a clone of the per-row offset array.
    fn batch_offset(&self) -> Result<Array> {
        let per_row_offset = &self.offset;
        per_row_offset.try_clone()
    }
}
pub(crate) fn create_causal_mask_batched(
n: usize,
offset: usize,
window_size: Option<usize>,
right_padding: Option<&Array>,
left_padding: Option<&Array>,
) -> Result<Array> {
use crate::lm::cache::mask::{iarange, scalar_i32};
let total = offset.checked_add(n).ok_or_else(|| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"create_causal_mask_batched: offset + N",
"usize",
[("offset", offset as u64), ("N", n as u64)],
))
})?;
let rinds = iarange(0, total)?;
let linds = if offset != 0 {
iarange(offset, total)?
} else {
rinds.try_clone()?
};
let linds = ops::shape::expand_dims_axes(&linds, &[1])?;
let rinds = ops::shape::expand_dims_axes(&rinds, &[0])?;
let mut mask = ops::comparison::greater_equal(&linds, &rinds)?;
if let Some(w) = window_size
&& w < total
{
let w_i32 = i32::try_from(w).map_err(|_| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"create_causal_mask_batched: window_size exceeds i32::MAX (cannot fit into a scalar mask offset)",
"i32",
[("window_size", w as u64)],
))
})?;
let bound = ops::arithmetic::add(&rinds, &scalar_i32(w_i32)?)?;
let windowed = ops::comparison::less(&linds, &bound)?;
mask = ops::logical::logical_and(&mask, &windowed)?;
}
if let Some(rp) = right_padding {
let total_i32 = i32::try_from(total).map_err(|_| {
Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
"create_causal_mask_batched: total exceeds i32::MAX (cannot fit into a scalar mask offset)",
"i32",
[("total", total as u64)],
))
})?;
let total_s = scalar_i32(total_i32)?;
let bound = ops::arithmetic::subtract(&total_s, rp)?; let bound = ops::shape::expand_dims_axes(&bound, &[1, 2, 3])?; let term = ops::comparison::less(&rinds, &bound)?;
mask = ops::logical::logical_and(&mask, &term)?;
}
if let Some(lp) = left_padding {
let lp = ops::shape::expand_dims_axes(lp, &[1, 2, 3])?; let term = ops::comparison::greater_equal(&rinds, &lp)?; mask = ops::logical::logical_and(&mask, &term)?;
}
Ok(mask)
}
#[cfg(test)]
mod tests;