use crate::{
array::Array,
dtype::Dtype,
error::{
ArithmeticOverflowPayload, EmptyInputPayload, Error, InvariantViolationPayload,
RankMismatchPayload, Result, ShapePairMismatchPayload,
},
ops,
};
/// Rank every KV-cache tensor handled by this module must have; the error
/// messages below describe the layout as `[B, n_kv_heads, S, head_dim]`.
pub(crate) const KV_NDIM: usize = 4;
/// Sequence axis `S` of a `[B, n_kv_heads, S, head_dim]` tensor, counted from
/// the end (equivalent to axis index `KV_NDIM - 2`).
pub(crate) const SEQ_AXIS: i32 = -2;
pub(crate) fn seq_len(name: &str, a: &Array) -> Result<usize> {
let shape = a.shape();
if shape.len() != KV_NDIM {
let context: &'static str = match name {
"keys" => "seq_len: KV cache expects 4-D keys [B, n_kv_heads, S, head_dim]",
"values" => "seq_len: KV cache expects 4-D values [B, n_kv_heads, S, head_dim]",
_ => "seq_len: KV cache expects 4-D [B, n_kv_heads, S, head_dim]",
};
return Err(Error::RankMismatch(RankMismatchPayload::new(
context,
shape.len() as u32,
shape.to_vec(),
)));
}
Ok(shape[KV_NDIM - 2])
}
pub(crate) fn head_dim(name: &str, a: &Array) -> Result<usize> {
let shape = a.shape();
if shape.len() != KV_NDIM {
let context: &'static str = match name {
"keys" => "head_dim: KV cache expects 4-D keys [B, n_kv_heads, S, head_dim]",
"values" => "head_dim: KV cache expects 4-D values [B, n_kv_heads, S, head_dim]",
_ => "head_dim: KV cache expects 4-D [B, n_kv_heads, S, head_dim]",
};
return Err(Error::RankMismatch(RankMismatchPayload::new(
context,
shape.len() as u32,
shape.to_vec(),
)));
}
Ok(shape[KV_NDIM - 1])
}
pub(crate) fn broadcast_write_rhs(
name: &str,
buf: &Array,
a: usize,
end: usize,
new: &Array,
) -> Result<Array> {
let bs = buf.shape();
let ns = new.shape();
if bs.len() != KV_NDIM {
let context: &'static str = match name {
"keys" => "broadcast_write_rhs: KV cache expects 4-D keys [B, n_kv_heads, S, head_dim]",
"values" => "broadcast_write_rhs: KV cache expects 4-D values [B, n_kv_heads, S, head_dim]",
_ => "broadcast_write_rhs: KV cache expects 4-D [B, n_kv_heads, S, head_dim]",
};
return Err(Error::RankMismatch(RankMismatchPayload::new(
context,
bs.len() as u32,
bs.to_vec(),
)));
}
if ns.len() != KV_NDIM {
let context: &'static str = match name {
"keys" => {
"broadcast_write_rhs: KV cache expects 4-D keys write RHS [B, n_kv_heads, S, head_dim]"
}
"values" => {
"broadcast_write_rhs: KV cache expects 4-D values write RHS [B, n_kv_heads, S, head_dim]"
}
_ => "broadcast_write_rhs: KV cache expects 4-D write RHS [B, n_kv_heads, S, head_dim]",
};
return Err(Error::RankMismatch(RankMismatchPayload::new(
context,
ns.len() as u32,
ns.to_vec(),
)));
}
let win = end.checked_sub(a).ok_or_else(|| {
let context: &'static str = match name {
"keys" => "set_seq: keys write end < start",
"values" => "set_seq: values write end < start",
_ => "set_seq: write end < start",
};
Error::InvariantViolation(InvariantViolationPayload::new(
context,
"must satisfy end >= start",
))
})?;
for axis in 0..KV_NDIM {
let target = if axis == KV_NDIM - 2 { win } else { bs[axis] };
let got = ns[axis];
if got != target && got != 1 {
let expected: Vec<usize> = (0..KV_NDIM)
.map(|i| if i == KV_NDIM - 2 { win } else { bs[i] })
.collect();
let context: &'static str = match name {
"keys" => {
"broadcast_write_rhs: keys write RHS non-broadcastable (mlx-lm slice-assignment raises on non-broadcastable non-seq axes; seq-axis target is the slice window length)"
}
"values" => {
"broadcast_write_rhs: values write RHS non-broadcastable (mlx-lm slice-assignment raises on non-broadcastable non-seq axes; seq-axis target is the slice window length)"
}
_ => {
"broadcast_write_rhs: write RHS non-broadcastable (mlx-lm slice-assignment raises on non-broadcastable non-seq axes; seq-axis target is the slice window length)"
}
};
return Err(Error::ShapePairMismatch(ShapePairMismatchPayload::new(
context,
expected,
ns.to_vec(),
)));
}
}
let target_shape: Vec<usize> = (0..KV_NDIM)
.map(|axis| if axis == KV_NDIM - 2 { win } else { bs[axis] })
.collect();
ops::shape::broadcast_to(new, &target_shape.as_slice())
}
/// Slices `a` to `[start, end)` along the sequence axis (index `KV_NDIM - 2`),
/// keeping every other axis whole, with unit strides.
///
/// No clamping is done here; `start` and `end` are forwarded as-is after
/// conversion to `i32`.
///
/// # Errors
/// - [`Error::RankMismatch`] if `a` is not 4-D.
/// - [`Error::ArithmeticOverflow`] if any dimension of `a`, then `start`, then
///   `end` (checked in that order) does not fit in `i32`.
/// - Any error returned by `ops::indexing::slice`.
pub(crate) fn slice_seq(a: &Array, start: usize, end: usize) -> Result<Array> {
    let shape = a.shape();
    if shape.len() != KV_NDIM {
        return Err(Error::RankMismatch(RankMismatchPayload::new(
            "slice_seq: expects 4-D array [B, n_kv_heads, S, head_dim]",
            shape.len() as u32,
            shape.to_vec(),
        )));
    }

    // The slice op takes i32 bounds; reject anything that would not fit.
    let to_i32 = |value: usize, context: &'static str, operand: &'static str| -> Result<i32> {
        i32::try_from(value).map_err(|_| {
            Error::ArithmeticOverflow(ArithmeticOverflowPayload::with_operands(
                context,
                "i32",
                [(operand, value as u64)],
            ))
        })
    };

    let mut stops = shape
        .iter()
        .map(|&dim| to_i32(dim, "slice_seq: shape dim exceeds i32::MAX", "dim"))
        .collect::<Result<Vec<i32>>>()?;
    let seq_start = to_i32(start, "slice_seq: start offset exceeds i32::MAX", "start")?;
    let seq_stop = to_i32(end, "slice_seq: end offset exceeds i32::MAX", "end")?;

    let mut starts = vec![0i32; KV_NDIM];
    starts[KV_NDIM - 2] = seq_start;
    stops[KV_NDIM - 2] = seq_stop;
    let strides = vec![1i32; KV_NDIM];

    ops::indexing::slice(a, &starts, &stops, &strides)
}
/// Concatenates `a` followed by `b` along the sequence axis ([`SEQ_AXIS`]).
///
/// No rank or shape validation happens here; any error comes from
/// `ops::shape::concatenate`.
pub(crate) fn concat_seq(a: &Array, b: &Array) -> Result<Array> {
    let ordered = [a, b];
    ops::shape::concatenate(&ordered, SEQ_AXIS)
}
/// Slices `a` to `[start, end)` along the sequence axis after clamping the
/// range into bounds.
///
/// `end` is clamped to the sequence length and `start` is clamped to `end`, so
/// out-of-range or inverted ranges yield a (possibly empty) slice instead of
/// an error.
///
/// # Errors
/// Returns [`Error::RankMismatch`] (produced by [`slice_seq`]) if `a` is not
/// 4-D. Otherwise it returns whatever [`slice_seq`] returns.
pub(crate) fn seq_slice(a: &Array, start: usize, end: usize) -> Result<Array> {
    let shape = a.shape();
    if shape.len() != KV_NDIM {
        // Indexing the sequence axis below would panic for rank < 3. Let
        // `slice_seq` report its usual rank-mismatch error instead.
        return slice_seq(a, start, end);
    }
    let len = shape[KV_NDIM - 2];
    let end = end.min(len);
    let start = start.min(end);
    slice_seq(a, start, end)
}
/// Width in bytes of a single element of dtype `d`.
///
/// The match is exhaustive on purpose, so adding a `Dtype` variant forces a
/// size to be chosen here.
fn dtype_size(d: Dtype) -> usize {
    match d {
        // 64-bit types (Complex64 is also 8 bytes total here).
        Dtype::U64 | Dtype::I64 | Dtype::F64 | Dtype::Complex64 => 8,
        // 32-bit types.
        Dtype::U32 | Dtype::I32 | Dtype::F32 => 4,
        // 16-bit types, including both half-precision floats.
        Dtype::U16 | Dtype::I16 | Dtype::F16 | Dtype::BF16 => 2,
        // Byte-sized types (Bool is stored as one byte).
        Dtype::Bool | Dtype::U8 | Dtype::I8 => 1,
    }
}
/// Total number of bytes occupied by `a`'s elements: element count times the
/// per-element width of its dtype.
///
/// # Errors
/// Propagates any error from `Array::dtype`.
pub(crate) fn nbytes(a: &Array) -> Result<usize> {
    let count = a.size();
    let width = dtype_size(a.dtype()?);
    Ok(count * width)
}
pub(crate) fn concat_parts(parts: &[&Array]) -> Result<Array> {
let non_empty: Vec<&Array> = parts
.iter()
.copied()
.filter(|a| {
let shape = a.shape();
shape.len() != KV_NDIM || shape[KV_NDIM - 2] > 0
})
.collect();
let rank_checked = |a: &Array| -> Result<Array> {
let shape = a.shape();
if shape.len() != KV_NDIM {
return Err(Error::RankMismatch(RankMismatchPayload::new(
"concat_parts: KV cache concat expects 4-D [B, n_kv_heads, S, head_dim] parts",
shape.len() as u32,
shape.to_vec(),
)));
}
a.try_clone()
};
match non_empty.as_slice() {
[] => match parts.first() {
Some(first) => rank_checked(first),
None => Err(Error::EmptyInput(EmptyInputPayload::new(
"concat_parts: parts",
))),
},
[one] => rank_checked(one),
many => ops::shape::concatenate(many, SEQ_AXIS),
}
}
#[cfg(test)]
mod tests;