use metal::{Buffer, Device};
use crate::riir::moe::deferred;
use crate::riir::attn::linear_attn_forward::{
full_attn_layer_idx_for, linear_layer_idx_for, LayerForwardBuffers,
};
use crate::riir::snapshot::state::{LayerState, MlaKvCacheGpu};
use crate::riir::variants::{
AttnKind, LayerKind, Variant, GPU_KV_SEQ, MAX_SEQ_LEN, VARIANT,
};
/// First header word of every snapshot. The bytes 0x4D 0x46 0x4C 0x58 spell
/// ASCII "MFLX" when read most-significant byte first; the word itself is
/// written little-endian by `state_save`.
pub const SNAPSHOT_MAGIC: u32 = 0x4D464C58;
/// Format version written by `state_save`. `state_load` accepts versions 1 and 2.
pub const SNAPSHOT_VERSION: u32 = 2;
/// Number of `u32` words in the current (version 2) header.
pub const SNAPSHOT_HEADER_U32: usize = 10;
/// Number of `u32` words in a version 1 header; `state_load` only checks the
/// `kv_lora_rank` / `qk_rope_head_dim` words for version 2.
pub const SNAPSHOT_HEADER_V1_U32: usize = 8;
/// Errors reported while saving or loading an attention-state snapshot.
#[derive(Debug, thiserror::Error)]
pub enum StateSnapshotError {
/// The destination passed to `state_save` is smaller than `state_size` reports.
#[error("buffer too small: need {need} bytes, got {got}")]
BufferTooSmall { need: usize, got: usize },
/// The first header word is not `SNAPSHOT_MAGIC`.
#[error("snapshot magic mismatch (got 0x{got:08X}, want 0x{want:08X})")]
BadMagic { got: u32, want: u32 },
/// The version word is unknown, or it is a version the compiled variant
/// cannot consume (`state_load` rejects version 1 when the variant uses MLA).
#[error("snapshot version mismatch (got {got}, want {want})")]
BadVersion { got: u32, want: u32 },
/// A header field, or the kind of an in-memory layer state, disagrees with
/// the compiled variant.
#[error("snapshot shape mismatch on field '{field}' (got {got}, want {want})")]
ShapeMismatch {
field: &'static str,
got: u32,
want: u32,
},
/// A full-attention layer record carries a negative length.
#[error("snapshot layer {layer} has negative KV length {len}")]
NegativeLen { layer: usize, len: i32 },
/// A full-attention layer record carries a length above `MAX_SEQ_LEN`.
#[error(
"snapshot layer {layer} KV length {len} exceeds MAX_SEQ_LEN={max}"
)]
LenOverflow { layer: usize, len: i32, max: usize },
/// The snapshot ends before the header or the given layer's record is complete.
#[error("snapshot truncated at layer {layer}: need {need} more bytes, got {got}")]
Truncated {
layer: usize,
need: usize,
got: usize,
},
/// A required buffer pool, forward-buffer set, or per-layer cache buffer
/// was `None` when it was needed.
#[error("linear_buffers not initialized — call eval_prompt or memory_clear first")]
BuffersNotReady,
/// NOTE(review): not constructed by any code visible in this section —
/// confirm whether it is still used elsewhere.
#[error("snapshot v{SNAPSHOT_VERSION} doesn't support MLA layers (layer {layer})")]
MlaUnsupported { layer: usize },
}
/// Bytes occupied by one cached position of either K or V in a GQA layer:
/// `num_kv_heads * head_dim` f32 values.
#[inline]
fn full_attn_stride_bytes(v: &Variant) -> usize {
    let floats_per_position = v.num_kv_heads * v.head_dim;
    floats_per_position * std::mem::size_of::<f32>()
}
/// Size in bytes of the convolution state stored per linear-attention layer:
/// `(CONV_KERNEL_SIZE - 1) * linear_conv_dim()` f32 values.
#[inline]
fn linear_conv_bytes(v: &Variant) -> usize {
    let kept_taps = Variant::CONV_KERNEL_SIZE - 1;
    let float_count = kept_taps * v.linear_conv_dim();
    float_count * std::mem::size_of::<f32>()
}
/// Size in bytes of the second state block stored per linear-attention layer:
/// `linear_num_v_heads * LINEAR_VALUE_DIM * LINEAR_KEY_DIM` f32 values.
#[inline]
fn linear_ssm_bytes(v: &Variant) -> usize {
    let per_head = v.linear_num_v_heads * Variant::LINEAR_VALUE_DIM;
    let float_count = per_head * Variant::LINEAR_KEY_DIM;
    float_count * std::mem::size_of::<f32>()
}
/// Bytes needed for `len` positions of the MLA latent cache
/// (`kv_lora_rank` f32 values per position).
#[inline]
fn mla_latent_bytes(v: &Variant, len: usize) -> usize {
    let float_count = len * v.kv_lora_rank;
    float_count * std::mem::size_of::<f32>()
}
/// Bytes needed for `len` positions of the MLA rope-K cache
/// (`qk_rope_head_dim` f32 values per position).
#[inline]
fn mla_rope_k_bytes(v: &Variant, len: usize) -> usize {
    let float_count = len * v.qk_rope_head_dim;
    float_count * std::mem::size_of::<f32>()
}
/// Copies the first `n_f32` f32 values of a Metal buffer into `dst` as raw bytes.
///
/// `dst` must be exactly `n_f32 * size_of::<f32>()` bytes long.
///
/// # Panics
///
/// Panics if `dst` has the wrong length, if the Metal buffer is shorter than
/// the requested span, or if the buffer exposes no CPU-visible contents
/// pointer. These checks are unconditional because this is a safe function
/// wrapping a raw memcpy; a debug-only check would allow out-of-bounds access
/// in release builds.
fn read_buffer_bytes_n_f32(buf: &Buffer, dst: &mut [u8], n_f32: usize) {
    let bytes = n_f32 * std::mem::size_of::<f32>();
    assert_eq!(
        dst.len(),
        bytes,
        "destination length must equal n_f32 * size_of::<f32>()"
    );
    if bytes == 0 {
        return;
    }
    assert!(
        buf.length() as usize >= bytes,
        "Metal buffer is smaller than the requested read"
    );
    let src = buf.contents() as *const u8;
    assert!(!src.is_null(), "Metal buffer has no CPU-visible contents");
    // SAFETY: `src` is non-null and the buffer holds at least `bytes` bytes;
    // `dst` is exactly `bytes` long; a Metal allocation and a Rust slice
    // borrowed mutably by the caller cannot overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), bytes);
    }
}
/// Copies `src` (raw bytes of `n_f32` f32 values) to the start of a Metal buffer.
///
/// `src` must be exactly `n_f32 * size_of::<f32>()` bytes long.
///
/// # Panics
///
/// Panics if `src` has the wrong length, if the Metal buffer is shorter than
/// the span being written, or if the buffer exposes no CPU-visible contents
/// pointer. These checks are unconditional because this is a safe function
/// wrapping a raw memcpy; a debug-only check would allow out-of-bounds writes
/// in release builds.
fn write_buffer_bytes_n_f32(buf: &Buffer, src: &[u8], n_f32: usize) {
    let bytes = n_f32 * std::mem::size_of::<f32>();
    assert_eq!(
        src.len(),
        bytes,
        "source length must equal n_f32 * size_of::<f32>()"
    );
    if bytes == 0 {
        return;
    }
    assert!(
        buf.length() as usize >= bytes,
        "Metal buffer is smaller than the requested write"
    );
    let dst = buf.contents() as *mut u8;
    assert!(!dst.is_null(), "Metal buffer has no CPU-visible contents");
    // SAFETY: `dst` is non-null and the buffer holds at least `bytes` bytes;
    // `src` is exactly `bytes` long; a Metal allocation and a borrowed Rust
    // slice cannot overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(src.as_ptr(), dst, bytes);
    }
}
/// Forwards to `MlaKvCacheGpu::ensure_buffers` for the given device.
///
/// `state_load` calls this before writing into an MLA cache and then
/// `expect`s `latent_cache` and `rope_k_cache` to be `Some`, so it relies on
/// `ensure_buffers` populating both.
fn ensure_mla_buffers(cache: &mut MlaKvCacheGpu, device: &Device) {
cache.ensure_buffers(device);
}
/// Returns the number of bytes `state_save` needs for the given layer states.
///
/// The total is the version 2 header plus one record per layer, looking at no
/// more than the first `VARIANT.num_layers` entries of `layer_states`:
///
/// * full-attention layer: an `i32` length followed by the cached rows
///   (K and V for GQA; latent and rope-K for MLA). Negative lengths count as 0.
/// * linear-attention layer: the fixed-size conv and SSM state blocks.
pub fn state_size(layer_states: &[LayerState]) -> usize {
    let variant = VARIANT;
    let header_bytes = SNAPSHOT_HEADER_U32 * std::mem::size_of::<u32>();
    let gqa_row_bytes = full_attn_stride_bytes(&variant);
    let linear_record_bytes = linear_conv_bytes(&variant) + linear_ssm_bytes(&variant);

    let records_bytes: usize = layer_states
        .iter()
        .enumerate()
        .take(variant.num_layers)
        .map(|(idx, state)| match variant.layer_kind(idx) {
            LayerKind::LinearAttn => linear_record_bytes,
            LayerKind::FullAttn => {
                // The cached length comes from whichever cache the layer
                // actually holds; a linear state in this slot counts as empty.
                let positions = match state {
                    LayerState::FullAttn(kv) => kv.len.max(0) as usize,
                    LayerState::Mla(cache) => cache.len.max(0) as usize,
                    LayerState::LinearAttn(_) => 0,
                };
                let cache_bytes = match variant.attn_kind {
                    AttnKind::Gqa => 2 * positions * gqa_row_bytes,
                    AttnKind::Mla => {
                        mla_latent_bytes(&variant, positions)
                            + mla_rope_k_bytes(&variant, positions)
                    }
                };
                std::mem::size_of::<i32>() + cache_bytes
            }
        })
        .sum();

    header_bytes + records_bytes
}
pub fn state_save(
buf: &mut [u8],
layer_states: &[LayerState],
linear_buffers: Option<&LayerForwardBuffers>,
pool: Option<&crate::riir::backend::MetalBufferPool>,
) -> Result<usize, StateSnapshotError> {
use crate::riir::backend::BufferPool as _;
let v = VARIANT;
let need = state_size(layer_states);
if buf.len() < need {
return Err(StateSnapshotError::BufferTooSmall {
need,
got: buf.len(),
});
}
let fa_stride = full_attn_stride_bytes(&v);
let la_conv = linear_conv_bytes(&v);
let la_ssm = linear_ssm_bytes(&v);
let mut off = 0usize;
let header: [u32; SNAPSHOT_HEADER_U32] = [
SNAPSHOT_MAGIC,
SNAPSHOT_VERSION,
v.num_layers as u32,
v.full_attn_interval as u32,
v.num_kv_heads as u32,
v.head_dim as u32,
la_conv as u32,
la_ssm as u32,
v.kv_lora_rank as u32,
v.qk_rope_head_dim as u32,
];
for &word in header.iter() {
buf[off..off + 4].copy_from_slice(&word.to_le_bytes());
off += 4;
}
for (i, layer) in layer_states.iter().enumerate().take(v.num_layers) {
match v.layer_kind(i) {
LayerKind::FullAttn => match (v.attn_kind, layer) {
(AttnKind::Gqa, LayerState::FullAttn(kv)) => {
let len = kv.len.max(0);
buf[off..off + 4].copy_from_slice(&len.to_le_bytes());
off += 4;
if len > 0 {
let bytes = (len as usize) * fa_stride;
let n = bytes / std::mem::size_of::<f32>();
let p = pool.ok_or(
StateSnapshotError::BuffersNotReady,
)?;
let k_buf = p.handle(kv.k_id.ok_or(
StateSnapshotError::BuffersNotReady,
)?);
read_buffer_bytes_n_f32(
k_buf,
&mut buf[off..off + bytes],
n,
);
off += bytes;
let v_buf = p.handle(kv.v_id.ok_or(
StateSnapshotError::BuffersNotReady,
)?);
read_buffer_bytes_n_f32(
v_buf,
&mut buf[off..off + bytes],
n,
);
off += bytes;
}
}
(AttnKind::Mla, LayerState::Mla(cache)) => {
let len = cache.len.max(0);
buf[off..off + 4].copy_from_slice(&len.to_le_bytes());
off += 4;
if len > 0 {
let lat_bytes = mla_latent_bytes(&v, len as usize);
let rope_bytes =
mla_rope_k_bytes(&v, len as usize);
let lat_buf = cache.latent_cache.as_ref().ok_or(
StateSnapshotError::BuffersNotReady,
)?;
let rope_buf = cache.rope_k_cache.as_ref().ok_or(
StateSnapshotError::BuffersNotReady,
)?;
let n_lat = (len as usize) * v.kv_lora_rank;
let n_rope = (len as usize) * v.qk_rope_head_dim;
read_buffer_bytes_n_f32(
lat_buf,
&mut buf[off..off + lat_bytes],
n_lat,
);
off += lat_bytes;
read_buffer_bytes_n_f32(
rope_buf,
&mut buf[off..off + rope_bytes],
n_rope,
);
off += rope_bytes;
}
}
_ => {
return Err(StateSnapshotError::ShapeMismatch {
field: "layer_state_kind",
got: 0,
want: 1,
});
}
},
LayerKind::LinearAttn => {
let lb = linear_buffers
.ok_or(StateSnapshotError::BuffersNotReady)?;
let p = pool
.ok_or(StateSnapshotError::BuffersNotReady)?;
let linear_idx = linear_layer_idx_for(i)
.expect("layer_kind says LinearAttn");
read_buffer_bytes(
p.handle(lb.conv_state[linear_idx]),
&mut buf[off..off + la_conv],
);
off += la_conv;
read_buffer_bytes(
p.handle(lb.delta_state[linear_idx]),
&mut buf[off..off + la_ssm],
);
off += la_ssm;
}
}
}
debug_assert_eq!(off, need, "state_save wrote {off} bytes, expected {need}");
Ok(off)
}
pub fn state_load(
buf: &[u8],
layer_states: &mut [LayerState],
mut linear_buffers: Option<&mut LayerForwardBuffers>,
pool: Option<&crate::riir::backend::MetalBufferPool>,
device: &Device,
) -> Result<(), StateSnapshotError> {
use crate::riir::backend::BufferPool as _;
let v = VARIANT;
if buf.len() < 8 {
return Err(StateSnapshotError::Truncated {
layer: 0,
need: 8,
got: buf.len(),
});
}
let read_u32 = |off: usize| -> u32 {
u32::from_le_bytes(buf[off..off + 4].try_into().unwrap())
};
let magic = read_u32(0);
if magic != SNAPSHOT_MAGIC {
return Err(StateSnapshotError::BadMagic {
got: magic,
want: SNAPSHOT_MAGIC,
});
}
let version = read_u32(4);
let header_words = match version {
1 => SNAPSHOT_HEADER_V1_U32,
2 => SNAPSHOT_HEADER_U32,
_ => {
return Err(StateSnapshotError::BadVersion {
got: version,
want: SNAPSHOT_VERSION,
});
}
};
let header_bytes = header_words * std::mem::size_of::<u32>();
if buf.len() < header_bytes {
return Err(StateSnapshotError::Truncated {
layer: 0,
need: header_bytes,
got: buf.len(),
});
}
if version == 1 && v.attn_kind == AttnKind::Mla {
return Err(StateSnapshotError::BadVersion {
got: version,
want: SNAPSHOT_VERSION,
});
}
let check = |off: usize, field: &'static str, want: u32| -> Result<(), StateSnapshotError> {
let got = read_u32(off);
if got != want {
return Err(StateSnapshotError::ShapeMismatch { field, got, want });
}
Ok(())
};
check(8, "num_layers", v.num_layers as u32)?;
check(12, "full_attn_interval", v.full_attn_interval as u32)?;
check(16, "num_kv_heads", v.num_kv_heads as u32)?;
check(20, "head_dim", v.head_dim as u32)?;
let la_conv = linear_conv_bytes(&v);
let la_ssm = linear_ssm_bytes(&v);
check(24, "linear_conv_bytes", la_conv as u32)?;
check(28, "linear_ssm_bytes", la_ssm as u32)?;
if version == 2 {
check(32, "kv_lora_rank", v.kv_lora_rank as u32)?;
check(36, "qk_rope_head_dim", v.qk_rope_head_dim as u32)?;
}
let fa_stride = full_attn_stride_bytes(&v);
{
let mut q = header_bytes;
for i in 0..v.num_layers {
match v.layer_kind(i) {
LayerKind::FullAttn => {
if buf.len() - q < 4 {
return Err(StateSnapshotError::Truncated {
layer: i,
need: 4,
got: buf.len() - q,
});
}
let len = i32::from_le_bytes(
buf[q..q + 4].try_into().unwrap(),
);
q += 4;
if len < 0 {
return Err(StateSnapshotError::NegativeLen {
layer: i,
len,
});
}
if (len as usize) > MAX_SEQ_LEN {
return Err(StateSnapshotError::LenOverflow {
layer: i,
len,
max: MAX_SEQ_LEN,
});
}
let bytes = match v.attn_kind {
AttnKind::Gqa => 2 * (len as usize) * fa_stride,
AttnKind::Mla => {
mla_latent_bytes(&v, len as usize)
+ mla_rope_k_bytes(&v, len as usize)
}
};
if buf.len() - q < bytes {
return Err(StateSnapshotError::Truncated {
layer: i,
need: bytes,
got: buf.len() - q,
});
}
q += bytes;
}
LayerKind::LinearAttn => {
let bytes = la_conv + la_ssm;
if buf.len() - q < bytes {
return Err(StateSnapshotError::Truncated {
layer: i,
need: bytes,
got: buf.len() - q,
});
}
q += bytes;
}
}
}
}
let mut off = header_bytes;
for i in 0..v.num_layers {
match v.layer_kind(i) {
LayerKind::FullAttn => {
let len = i32::from_le_bytes(
buf[off..off + 4].try_into().unwrap(),
);
off += 4;
if v.attn_kind == AttnKind::Mla {
let cache = match &mut layer_states[i] {
LayerState::Mla(c) => c,
_ => {
return Err(
StateSnapshotError::ShapeMismatch {
field: "layer_state_kind",
got: 0,
want: 1,
},
);
}
};
ensure_mla_buffers(cache, device);
if len > 0 {
let lat_bytes = mla_latent_bytes(&v, len as usize);
let rope_bytes =
mla_rope_k_bytes(&v, len as usize);
let n_lat = (len as usize) * v.kv_lora_rank;
let n_rope = (len as usize) * v.qk_rope_head_dim;
let lat_buf = cache
.latent_cache
.as_ref()
.expect("ensure_mla_buffers just ran");
let rope_buf = cache
.rope_k_cache
.as_ref()
.expect("ensure_mla_buffers just ran");
write_buffer_bytes_n_f32(
lat_buf,
&buf[off..off + lat_bytes],
n_lat,
);
off += lat_bytes;
write_buffer_bytes_n_f32(
rope_buf,
&buf[off..off + rope_bytes],
n_rope,
);
off += rope_bytes;
}
cache.len = len;
continue;
}
let kv = match &mut layer_states[i] {
LayerState::FullAttn(kv) => kv,
LayerState::Mla(_) => {
unreachable!("attn_kind branch handled above")
}
LayerState::LinearAttn(_) => {
return Err(StateSnapshotError::ShapeMismatch {
field: "layer_state_kind",
got: 0,
want: 1,
});
}
};
if len > 0 {
let p = pool
.ok_or(StateSnapshotError::BuffersNotReady)?;
let bytes = (len as usize) * fa_stride;
let n = bytes / std::mem::size_of::<f32>();
let k_buf = p.handle(kv.k_id.ok_or(
StateSnapshotError::BuffersNotReady,
)?);
write_buffer_bytes_n_f32(
k_buf,
&buf[off..off + bytes],
n,
);
off += bytes;
let v_buf = p.handle(kv.v_id.ok_or(
StateSnapshotError::BuffersNotReady,
)?);
write_buffer_bytes_n_f32(
v_buf,
&buf[off..off + bytes],
n,
);
off += bytes;
}
let stride = v.num_kv_heads * v.head_dim;
kv.len = len;
if let Some(fa_idx) = full_attn_layer_idx_for(i) {
let mirror_len = (len as usize).min(GPU_KV_SEQ);
if mirror_len > 0 {
let lb = linear_buffers
.as_deref_mut()
.ok_or(StateSnapshotError::BuffersNotReady)?;
let n = mirror_len * stride;
let p = pool
.ok_or(StateSnapshotError::BuffersNotReady)?;
unsafe {
let k_dst = p
.handle(lb.gpu_kv_k[fa_idx])
.contents()
as *mut f32;
let v_dst = p
.handle(lb.gpu_kv_v[fa_idx])
.contents()
as *mut f32;
std::ptr::copy_nonoverlapping(
kv.k_slice(p, mirror_len).as_ptr(),
k_dst,
n,
);
std::ptr::copy_nonoverlapping(
kv.v_slice(p, mirror_len).as_ptr(),
v_dst,
n,
);
}
}
}
}
LayerKind::LinearAttn => {
let lb = linear_buffers
.as_deref_mut()
.ok_or(StateSnapshotError::BuffersNotReady)?;
let p = pool
.ok_or(StateSnapshotError::BuffersNotReady)?;
let linear_idx = linear_layer_idx_for(i)
.expect("layer_kind says LinearAttn");
write_buffer_bytes(
p.handle(lb.conv_state[linear_idx]),
&buf[off..off + la_conv],
);
off += la_conv;
write_buffer_bytes(
p.handle(lb.delta_state[linear_idx]),
&buf[off..off + la_ssm],
);
off += la_ssm;
}
}
}
Ok(())
}
/// Forwards the ring to `deferred::discard_deferred_experts_in`.
///
/// The parameter shares its name with the `deferred` module; the `deferred::`
/// path prefix still resolves to the module, while the bare `deferred`
/// argument is the parameter.
pub(in crate::riir) fn drain_deferred(deferred: &mut deferred::DeferredRing) {
deferred::discard_deferred_experts_in(deferred);
}
/// Fills `dst` with the first `dst.len()` bytes of a Metal buffer.
///
/// # Panics
///
/// Panics if the Metal buffer is shorter than `dst` or exposes no CPU-visible
/// contents pointer. The check is unconditional because this is a safe
/// function wrapping a raw memcpy; a debug-only check would allow an
/// out-of-bounds read in release builds.
fn read_buffer_bytes(buf: &Buffer, dst: &mut [u8]) {
    let n = dst.len();
    if n == 0 {
        return;
    }
    assert!(
        buf.length() as usize >= n,
        "Metal buffer is smaller than the requested read"
    );
    let src = buf.contents() as *const u8;
    assert!(!src.is_null(), "Metal buffer has no CPU-visible contents");
    // SAFETY: `src` is non-null and the buffer holds at least `n` bytes;
    // `dst` is exactly `n` bytes long; a Metal allocation and a Rust slice
    // borrowed mutably by the caller cannot overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(src, dst.as_mut_ptr(), n);
    }
}
/// Copies all of `src` to the start of a Metal buffer.
///
/// # Panics
///
/// Panics if the Metal buffer is shorter than `src` or exposes no CPU-visible
/// contents pointer. The check is unconditional because this is a safe
/// function wrapping a raw memcpy; a debug-only check would allow an
/// out-of-bounds write in release builds.
fn write_buffer_bytes(buf: &Buffer, src: &[u8]) {
    let n = src.len();
    if n == 0 {
        return;
    }
    assert!(
        buf.length() as usize >= n,
        "Metal buffer is smaller than the requested write"
    );
    let dst = buf.contents() as *mut u8;
    assert!(!dst.is_null(), "Metal buffer has no CPU-visible contents");
    // SAFETY: `dst` is non-null and the buffer holds at least `n` bytes;
    // `src` is exactly `n` bytes long; a Metal allocation and a borrowed Rust
    // slice cannot overlap.
    unsafe {
        std::ptr::copy_nonoverlapping(src.as_ptr(), dst, n);
    }
}