#include <metal_stdlib>
using namespace metal;
/// Copy new K or V data from a source GPU buffer into the KV cache buffer
/// at the correct write position, with optional modulo wrapping for sliding
/// window (ring buffer) caches.
///
/// Source layout: [n_new, row_size] row-major.
/// Cache layout:  [cache_cap, row_size] token-major (presumably — confirm
/// against the host-side allocation).
///
/// Grid: 1D, total_elements = n_new * row_size. Each thread copies one element.
kernel void kv_cache_copy(
    device const half* src    [[buffer(0)]], // source K or V [n_new, row_size]
    device half* cache        [[buffer(1)]], // destination cache buffer
    constant uint& write_pos  [[buffer(2)]], // starting write position in cache
    constant uint& row_size   [[buffer(3)]], // n_kv_heads * head_dim
    constant uint& n_new      [[buffer(4)]], // number of new tokens
    constant uint& cache_cap  [[buffer(5)]], // window size (sliding) or max_seq_len (global)
    constant uint& is_sliding [[buffer(6)]], // 1 = use modulo wrapping, 0 = linear
    uint tid [[thread_position_in_grid]]
) {
    uint total_elements = n_new * row_size;
    if (tid >= total_elements) return;

    uint token_idx = tid / row_size; // which new token this element belongs to
    uint elem_idx  = tid % row_size; // offset inside the token's row

    uint dst_pos = is_sliding
        ? ((write_pos + token_idx) % cache_cap) // ring-buffer wrap
        : (write_pos + token_idx);              // linear append

    // Defensive bounds guard: in linear mode the caller must guarantee
    // write_pos + n_new <= cache_cap. If that contract is violated, drop the
    // write rather than corrupt memory past the end of the cache buffer.
    // (No-op in sliding mode, where the modulo already keeps dst_pos in range.)
    if (dst_pos >= cache_cap) return;

    cache[dst_pos * row_size + elem_idx] = src[token_idx * row_size + elem_idx];
}
/// Float32 variant of kv_cache_copy for F32 KV caches.
///
/// Identical logic to the half variant but operates on float data.
/// Used when the activation pipeline is F32 throughout (no bf16 casting).
///
/// Grid: 1D, total_elements = n_new * row_size. Each thread copies one element.
kernel void kv_cache_copy_f32(
    device const float* src   [[buffer(0)]], // source K or V [n_new, row_size]
    device float* cache       [[buffer(1)]], // destination cache buffer
    constant uint& write_pos  [[buffer(2)]], // starting write position in cache
    constant uint& row_size   [[buffer(3)]], // n_kv_heads * head_dim
    constant uint& n_new      [[buffer(4)]], // number of new tokens
    constant uint& cache_cap  [[buffer(5)]], // window size (sliding) or max_seq_len (global)
    constant uint& is_sliding [[buffer(6)]], // 1 = use modulo wrapping, 0 = linear
    uint tid [[thread_position_in_grid]]
) {
    uint total_elements = n_new * row_size;
    if (tid >= total_elements) return;

    uint token_idx = tid / row_size; // which new token this element belongs to
    uint elem_idx  = tid % row_size; // offset inside the token's row

    uint dst_pos = is_sliding
        ? ((write_pos + token_idx) % cache_cap) // ring-buffer wrap
        : (write_pos + token_idx);              // linear append

    // Defensive bounds guard: in linear mode the caller must guarantee
    // write_pos + n_new <= cache_cap. If that contract is violated, drop the
    // write rather than corrupt memory past the end of the cache buffer.
    // (No-op in sliding mode, where the modulo already keeps dst_pos in range.)
    if (dst_pos >= cache_cap) return;

    cache[dst_pos * row_size + elem_idx] = src[token_idx * row_size + elem_idx];
}
/// Batched KV cache copy — writes one token's K or V for ALL heads in a
/// single dispatch, replacing n_heads separate kv_cache_copy_f32 dispatches.
///
/// Source layout: [n_heads * head_dim] flat (one token, all heads).
/// Cache layout:  [n_heads, capacity, head_dim] head-major.
///
/// Grid: 2D — x = element within head (head_dim), y = head index (n_heads).
/// seq_pos must already be wrapped into [0, capacity) by the host.
kernel void kv_cache_copy_batch_f32(
    device const float* src  [[buffer(0)]], // [n_heads * head_dim] flat
    device float* cache      [[buffer(1)]], // [n_heads, capacity, head_dim]
    constant uint& n_heads   [[buffer(2)]], // number of KV heads
    constant uint& head_dim  [[buffer(3)]], // elements per head
    constant uint& capacity  [[buffer(4)]], // cache capacity (ring buffer size)
    constant uint& seq_pos   [[buffer(5)]], // write position (already wrapped)
    uint2 tid [[thread_position_in_grid]]   // x=elem, y=head
) {
    const uint elem = tid.x;
    const uint head = tid.y;
    if (elem >= head_dim || head >= n_heads) return;

    // Base of this head's row for the target token inside the head-major cache.
    device float* dst_row = cache + (head * capacity + seq_pos) * head_dim;
    dst_row[elem] = src[head * head_dim + elem];
}
/// Batched KV cache copy with F32→F16 cast — writes one token's K or V for
/// ALL heads in a single dispatch.
///
/// Source layout: [n_heads * head_dim] flat F32 (one token, all heads).
/// Cache layout:  [n_heads, capacity, head_dim] head-major F16.
///
/// Casting float → half on write halves cache memory bandwidth for SDPA reads.
/// Reference: llama.cpp stores KV cache in F16 for bandwidth-bound decode SDPA.
///
/// Grid: 2D — x = element within head (head_dim), y = head index (n_heads).
/// seq_pos must already be wrapped into [0, capacity) by the host.
kernel void kv_cache_copy_batch_f32_to_f16(
    device const float* src  [[buffer(0)]], // [n_heads * head_dim] flat F32
    device half* cache       [[buffer(1)]], // [n_heads, capacity, head_dim] F16
    constant uint& n_heads   [[buffer(2)]], // number of KV heads
    constant uint& head_dim  [[buffer(3)]], // elements per head
    constant uint& capacity  [[buffer(4)]], // cache capacity (ring buffer size)
    constant uint& seq_pos   [[buffer(5)]], // write position (already wrapped)
    uint2 tid [[thread_position_in_grid]]   // x=elem, y=head
) {
    const uint elem = tid.x;
    const uint head = tid.y;
    if (elem >= head_dim || head >= n_heads) return;

    // Base of this head's row for the target token inside the head-major cache.
    device half* dst_row = cache + (head * capacity + seq_pos) * head_dim;
    dst_row[elem] = half(src[head * head_dim + elem]); // narrowing cast on write
}
/// Multi-position, all-heads KV cache copy (F32 source → F32 cache).
///
/// Source layout: [n_tokens, n_heads, head_dim] — token-major, from head_norm+RoPE.
/// Cache layout:  [n_heads, capacity, head_dim] — head-major dense_kvs.
///
/// Writes positions [seq_pos_start, seq_pos_start + n_tokens) linearly
/// (no ring-buffer wrap — only used in prefill where seq_pos_start=0 and
/// n_tokens <= capacity).
///
/// Grid: 3D — x = elem within head, y = head, z = token.
kernel void kv_cache_copy_seq_f32(
    device const float* src        [[buffer(0)]], // [n_tokens, n_heads, head_dim] F32
    device float* cache            [[buffer(1)]], // [n_heads, capacity, head_dim] F32
    constant uint& n_heads         [[buffer(2)]],
    constant uint& head_dim        [[buffer(3)]],
    constant uint& capacity        [[buffer(4)]],
    constant uint& seq_pos_start   [[buffer(5)]],
    constant uint& n_tokens        [[buffer(6)]],
    uint3 tid [[thread_position_in_grid]]
) {
    if (tid.x >= head_dim || tid.y >= n_heads || tid.z >= n_tokens) return;

    const uint elem = tid.x;
    const uint head = tid.y;
    const uint tok  = tid.z;
    const uint token_stride = n_heads * head_dim; // src elements per token

    // Token-major read, head-major write (layout transpose happens here).
    const uint cache_pos = seq_pos_start + tok;
    cache[(head * capacity + cache_pos) * head_dim + elem] =
        src[tok * token_stride + head * head_dim + elem];
}
/// Multi-position, all-heads KV cache copy (F32 source → F16 cache).
///
/// Same layout/semantics as kv_cache_copy_seq_f32 but casts to half on write:
/// source is [n_tokens, n_heads, head_dim] token-major F32, cache is
/// [n_heads, capacity, head_dim] head-major F16. Positions
/// [seq_pos_start, seq_pos_start + n_tokens) are written linearly (no wrap).
///
/// Grid: 3D — x = elem within head, y = head, z = token.
kernel void kv_cache_copy_seq_f32_to_f16(
    device const float* src        [[buffer(0)]], // [n_tokens, n_heads, head_dim] F32
    device half* cache             [[buffer(1)]], // [n_heads, capacity, head_dim] F16
    constant uint& n_heads         [[buffer(2)]],
    constant uint& head_dim        [[buffer(3)]],
    constant uint& capacity        [[buffer(4)]],
    constant uint& seq_pos_start   [[buffer(5)]],
    constant uint& n_tokens        [[buffer(6)]],
    uint3 tid [[thread_position_in_grid]]
) {
    if (tid.x >= head_dim || tid.y >= n_heads || tid.z >= n_tokens) return;

    const uint elem = tid.x;
    const uint head = tid.y;
    const uint tok  = tid.z;
    const uint token_stride = n_heads * head_dim; // src elements per token

    // Token-major read, head-major write, with F32 → F16 narrowing on store.
    const uint cache_pos = seq_pos_start + tok;
    cache[(head * capacity + cache_pos) * head_dim + elem] =
        half(src[tok * token_stride + head * head_dim + elem]);
}