use anyhow::{Context, Result};
use rlx_ir::{Graph, hir::HirModule};
use rlx_runtime::compile_cache::{
BucketedCompileCache, CacheRunInput, CompileCache, pad_rows, pad_rows_into, slice_rows,
};
use rlx_runtime::kv_cache::LayerKvCache;
use rlx_runtime::{CompileOptions, CompiledGraph, Device};
use std::collections::HashMap;
/// Result of one decode step: `(logits, per-layer K buffers, per-layer V buffers)`.
pub type DecodeLogitsKv = (Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>);
/// Optional map from parameter name to `(packed bytes, quantization scheme, shape)`.
/// The bytes are uploaded as `U8` params when a graph-layers decode bucket is first built;
/// the scheme and shape are carried along but not read in this file.
type PackedUploadMap<'a> =
Option<&'a HashMap<String, (Vec<u8>, rlx_ir::quant::QuantScheme, Vec<usize>)>>;
pub use rlx_runtime::kv_cache::LayerKvCache as KvCacheState;
/// Builds the compile-cache key for a prefill graph.
///
/// `batch` occupies the upper 32 bits and `seq` the lower bits. `seq` is not masked,
/// so values of `seq` at or above `2^32` would overlap the batch bits.
pub fn prefill_cache_key(batch: usize, seq: usize) -> u64 {
    let high = (batch as u64) << 32;
    let low = seq as u64;
    high | low
}
/// Reports whether packed prefill may restrict execution to the active extent.
///
/// Returns `false` when `RLX_DISABLE_ACTIVE_EXTENT=1`. Otherwise returns `true` only
/// for the CPU, Metal and MLX devices.
pub fn packed_prefill_active_extent_enabled(device: Device) -> bool {
    let disabled_by_env = rlx_ir::env::var("RLX_DISABLE_ACTIVE_EXTENT").as_deref() == Some("1");
    let supported_device = matches!(device, Device::Cpu | Device::Metal | Device::Mlx);
    !disabled_by_env && supported_device
}
/// Runs a packed prefill graph compiled for `upper_seq` rows with `actual_seq` real rows.
///
/// When the active-extent optimisation is enabled for `device` and the prompt only
/// partially fills the bucket (`0 < actual_seq < upper_seq`), the graph is told the
/// active extent for the duration of the run. The extent is cleared afterwards.
pub fn run_packed_prefill(
    compiled: &mut CompiledGraph,
    device: Device,
    actual_seq: usize,
    upper_seq: usize,
    inputs: &[(&str, &[f32])],
) -> Vec<Vec<f32>> {
    let restrict = packed_prefill_active_extent_enabled(device)
        && (1..upper_seq).contains(&actual_seq);
    if !restrict {
        return compiled.run(inputs);
    }
    compiled.set_active_extent(Some((actual_seq, upper_seq)));
    let outputs = compiled.run(inputs);
    compiled.set_active_extent(None);
    outputs
}
/// Infers how many sequence rows a prefill's KV outputs actually contain.
///
/// Inspects `outputs[1]` (the first layer's K output) against `kv_dims[0]`:
/// - if its length is exactly `batch * upper * kd` and not `batch * actual * kd`, the
///   outputs were produced at the padded bucket size and `upper` is returned;
/// - in every other case, including missing outputs or empty `kv_dims`, `actual` is returned.
pub fn infer_prefill_kv_seq(
    outputs: &[Vec<f32>],
    batch: usize,
    kv_dims: &[usize],
    actual: usize,
    upper: usize,
) -> usize {
    let (first_k, kd) = match (outputs.get(1), kv_dims.first()) {
        (Some(k), Some(&kd)) => (k, kd),
        _ => return actual,
    };
    let actual_len = batch * actual * kd;
    let upper_len = batch * upper * kd;
    let sized_to_upper_only = first_k.len() != actual_len && first_k.len() == upper_len;
    if sized_to_upper_only {
        upper
    } else {
        actual
    }
}
/// Lists the row width of each decode output: `0` for output 0 (logits, not row-sliced),
/// followed by `kd` twice (K then V) for every entry of `kv_dims`.
fn output_inners_for_kv(kv_dims: &[usize]) -> Vec<usize> {
    let per_layer = kv_dims.iter().flat_map(|&kd| [kd, kd]);
    std::iter::once(0usize).chain(per_layer).collect()
}
fn run_bucketed_decode_on_compiled_metal_rows(
compiled: &mut CompiledGraph,
upper: usize,
past_seq: usize,
specs: &[CacheRunInput<'_>],
kv: &LayerKvCache,
kv_dims: &[usize],
) -> Result<DecodeLogitsKv> {
let num_layers = kv_dims.len();
let pairs: Vec<(&str, &[f32])> = specs.iter().map(|inp| (inp.name, inp.data)).collect();
let outs = compiled.run_read_outputs(&pairs, Some(&[0]));
let logits = outs
.into_iter()
.next()
.context("bucketed decode logits missing")?;
let mut new_k = Vec::with_capacity(num_layers);
let mut new_v = Vec::with_capacity(num_layers);
for layer in 0..num_layers {
let kd = kv_dims[layer];
let row_k = compiled
.read_output_row(1 + 2 * layer, upper, kd)
.with_context(|| format!("Metal K row read layer {layer}"))?;
let row_v = compiled
.read_output_row(2 + 2 * layer, upper, kd)
.with_context(|| format!("Metal V row read layer {layer}"))?;
let need = (past_seq + 1) * kd;
let mut k_out = Vec::with_capacity(need);
k_out.extend_from_slice(&kv.layers_k[layer]);
k_out.extend_from_slice(&row_k);
let mut v_out = Vec::with_capacity(need);
v_out.extend_from_slice(&kv.layers_v[layer]);
v_out.extend_from_slice(&row_v);
new_k.push(k_out);
new_v.push(v_out);
}
Ok((logits, new_k, new_v))
}
/// Reports whether `RLX_GEMMA_METAL_FULL_KV_READBACK` is set to `1`, `true` or `yes`.
///
/// When set, Metal decode uses the generic full-output readback instead of the
/// per-row fast path.
fn metal_full_kv_readback() -> bool {
    let raw = rlx_ir::env::var("RLX_GEMMA_METAL_FULL_KV_READBACK");
    matches!(raw.as_deref(), Some("1" | "true" | "yes"))
}
/// Executes a prepared bucketed decode step and returns logits plus updated per-layer KV.
///
/// On Metal, unless full readback is requested via the environment, the per-row fast path
/// is used. Every other case runs the graph with full outputs and compacts each K/V output
/// to `past_seq + 1` rows with batch size 1.
fn finish_bucketed_decode(
    compiled: &mut CompiledGraph,
    upper: usize,
    past_seq: usize,
    specs: &[CacheRunInput<'_>],
    output_inners: &[usize],
    kv: &LayerKvCache,
    kv_dims: &[usize],
) -> Result<DecodeLogitsKv> {
    let use_row_readback = compiled.device() == Device::Metal && !metal_full_kv_readback();
    if use_row_readback {
        return run_bucketed_decode_on_compiled_metal_rows(
            compiled, upper, past_seq, specs, kv, kv_dims,
        );
    }
    let num_layers = kv_dims.len();
    let full_outputs =
        run_bucketed_decode_on_compiled(compiled, upper, past_seq, specs, output_inners);
    split_bucketed_decode_kv_per_layer(full_outputs, past_seq, kv_dims, num_layers, 1)
}
/// Runs a bucketed decode graph and returns every output, with row-structured outputs
/// trimmed to `upper + 1` rows.
///
/// Inputs that declare `row_inner` but are not already `upper` rows long are zero-padded
/// with `pad_rows` first. Except on Metal, or when `RLX_DISABLE_ACTIVE_EXTENT=1`, the active
/// extent is set to `(upper + 1, upper + 1)` for the run and cleared afterwards.
/// `output_inners[i]` gives the row width of output `i`. A width of `0`, or a missing
/// entry, leaves that output untouched.
fn run_bucketed_decode_on_compiled(
    compiled: &mut CompiledGraph,
    upper: usize,
    _past_seq: usize,
    specs: &[CacheRunInput<'_>],
    output_inners: &[usize],
) -> Vec<Vec<f32>> {
    let kv_rows = upper + 1;
    let repadded: Vec<Option<Vec<f32>>> = specs
        .iter()
        .map(|spec| {
            let inner = spec.row_inner?;
            (spec.data.len() != upper * inner).then(|| pad_rows(spec.data, inner, upper as u64))
        })
        .collect();
    let feeds: Vec<(&str, &[f32])> = specs
        .iter()
        .zip(repadded.iter())
        .map(|(spec, buf)| match buf {
            Some(owned) => (spec.name, owned.as_slice()),
            None => (spec.name, spec.data),
        })
        .collect();
    let restrict_extent = if compiled.device() == Device::Metal {
        false
    } else {
        rlx_ir::env::var("RLX_DISABLE_ACTIVE_EXTENT").as_deref() != Some("1")
    };
    if restrict_extent {
        compiled.set_active_extent(Some((kv_rows, kv_rows)));
    }
    let outputs = compiled.run(&feeds);
    if restrict_extent {
        compiled.set_active_extent(None);
    }
    outputs
        .into_iter()
        .enumerate()
        .map(|(idx, buf)| match output_inners.get(idx) {
            Some(&inner) if inner != 0 => slice_rows(&buf, inner, kv_rows),
            _ => buf,
        })
        .collect()
}
/// Returns the graph input names for the past KV cache, interleaved per layer:
/// `past_k_0, past_v_0, past_k_1, past_v_1, ...`.
pub fn past_kv_input_names(num_layers: usize) -> Vec<String> {
    let mut names = Vec::with_capacity(2 * num_layers);
    for layer in 0..num_layers {
        names.push(format!("past_k_{layer}"));
        names.push(format!("past_v_{layer}"));
    }
    names
}
/// Splits raw decode outputs `[logits, k_0, v_0, k_1, v_1, ...]` into
/// `(logits, ks, vs)` without reshaping any buffer.
///
/// # Errors
/// Fails if `outputs` does not contain exactly `1 + 2 * num_layers` buffers.
pub fn split_decode_logits_kv(outputs: Vec<Vec<f32>>, num_layers: usize) -> Result<DecodeLogitsKv> {
    let expected = 1 + 2 * num_layers;
    if outputs.len() != expected {
        anyhow::bail!(
            "decode graph produced {} outputs, expected {}",
            outputs.len(),
            expected
        );
    }
    let mut remaining = outputs.into_iter();
    let logits = remaining.next().context("decode logits missing")?;
    let (mut layers_k, mut layers_v) =
        (Vec::with_capacity(num_layers), Vec::with_capacity(num_layers));
    for _layer in 0..num_layers {
        let k = remaining.next().context("decode k missing")?;
        layers_k.push(k);
        let v = remaining.next().context("decode v missing")?;
        layers_v.push(v);
    }
    Ok((logits, layers_k, layers_v))
}
/// Uniform-width variant of [`kv_from_prefill_outputs_per_layer`]: every layer's K/V
/// row has `kv_dim` elements.
pub fn kv_from_prefill_outputs(
    outputs: Vec<Vec<f32>>,
    batch: usize,
    seq: usize,
    kv_dim: usize,
    num_layers: usize,
) -> Result<(Vec<f32>, LayerKvCache)> {
    let uniform_dims: Vec<usize> = std::iter::repeat(kv_dim).take(num_layers).collect();
    kv_from_prefill_outputs_per_layer(outputs, batch, seq, &uniform_dims, num_layers)
}
/// Converts prefill outputs `[logits, k_0, v_0, ...]` into logits plus a [`LayerKvCache`]
/// whose `past_len` is `seq`.
///
/// # Errors
/// Fails if the output count is not `1 + 2 * num_layers`, or if `kv_dims` does not have
/// `num_layers` entries. Also fails if any layer's K or V length differs from
/// `batch * seq * kv_dims[layer]`.
pub fn kv_from_prefill_outputs_per_layer(
    outputs: Vec<Vec<f32>>,
    batch: usize,
    seq: usize,
    kv_dims: &[usize],
    num_layers: usize,
) -> Result<(Vec<f32>, LayerKvCache)> {
    let expected_outputs = 1 + 2 * num_layers;
    if outputs.len() != expected_outputs {
        anyhow::bail!(
            "prefill produced {} outputs, expected {}",
            outputs.len(),
            expected_outputs
        );
    }
    if kv_dims.len() != num_layers {
        anyhow::bail!(
            "kv_dims has {} entries, expected {num_layers}",
            kv_dims.len()
        );
    }
    let mut remaining = outputs.into_iter();
    let logits = remaining.next().context("prefill logits missing")?;
    let mut layers_k = Vec::with_capacity(num_layers);
    let mut layers_v = Vec::with_capacity(num_layers);
    for (layer, &kv_dim) in kv_dims.iter().enumerate() {
        let k = remaining.next().context("prefill k missing")?;
        let v = remaining.next().context("prefill v missing")?;
        let expected_kv_len = batch * seq * kv_dim;
        let lengths_ok = k.len() == expected_kv_len && v.len() == expected_kv_len;
        if !lengths_ok {
            anyhow::bail!(
                "layer {layer}: k.len={} v.len={} expected {expected_kv_len} (kv_dim={})",
                k.len(),
                v.len(),
                kv_dim,
            );
        }
        layers_k.push(k);
        layers_v.push(v);
    }
    let cache = LayerKvCache {
        past_len: seq,
        layers_k,
        layers_v,
    };
    Ok((logits, cache))
}
/// Runs one bucketed decode step from a [`Graph`] builder, using `past_seq` itself as the
/// compile-cache key. See [`run_bucketed_kv_decode_keyed`] for the details.
pub fn run_bucketed_kv_decode<F>(
    cache: &mut BucketedCompileCache,
    past_seq: usize,
    kv: &LayerKvCache,
    kv_dim: usize,
    num_layers: usize,
    fixed_inputs: &[CacheRunInput<'_>],
    build: F,
    options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
    F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
{
    let cache_key = past_seq as u64;
    run_bucketed_kv_decode_keyed(
        cache, cache_key, past_seq, kv, kv_dim, num_layers, fixed_inputs, build, options,
    )
}
pub fn run_bucketed_kv_decode_keyed<F>(
cache: &mut BucketedCompileCache,
cache_key: u64,
past_seq: usize,
kv: &LayerKvCache,
kv_dim: usize,
num_layers: usize,
fixed_inputs: &[CacheRunInput<'_>],
build: F,
options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
{
let (upper_u64, compiled) = cache
.ensure_graph_with_params(cache_key, build, options)
.ok_or_else(|| anyhow::anyhow!("cache_key {cache_key} outside decode buckets"))?;
let upper = upper_u64 as usize;
let (padded_k, padded_v) = kv.pad_layers_to_upper(upper_u64, kv_dim);
let key_names = past_kv_input_names(num_layers);
let mut specs: Vec<CacheRunInput<'_>> = Vec::with_capacity(fixed_inputs.len() + 2 * num_layers);
for inp in fixed_inputs {
specs.push(CacheRunInput {
name: inp.name,
data: inp.data,
row_inner: inp.row_inner,
});
}
for i in 0..num_layers {
specs.push(CacheRunInput {
name: key_names[2 * i].as_str(),
data: padded_k[i].as_slice(),
row_inner: Some(kv_dim),
});
specs.push(CacheRunInput {
name: key_names[2 * i + 1].as_str(),
data: padded_v[i].as_slice(),
row_inner: Some(kv_dim),
});
}
let kv_dims = vec![kv_dim; num_layers];
let output_inners = output_inners_for_kv(&kv_dims);
finish_bucketed_decode(
compiled,
upper,
past_seq,
&specs,
&output_inners,
kv,
&kv_dims,
)
}
pub fn run_bucketed_kv_decode_hir<F>(
cache: &mut BucketedCompileCache,
past_seq: usize,
kv: &LayerKvCache,
kv_dim: usize,
num_layers: usize,
fixed_inputs: &[CacheRunInput<'_>],
build: F,
options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
{
let key = past_seq as u64;
let (upper_u64, compiled) = cache
.ensure_hir_with_params(key, build, options)
.ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;
let upper = upper_u64 as usize;
let (padded_k, padded_v) = kv.pad_layers_to_upper(upper_u64, kv_dim);
let key_names = past_kv_input_names(num_layers);
let mut specs: Vec<CacheRunInput<'_>> = Vec::with_capacity(fixed_inputs.len() + 2 * num_layers);
for inp in fixed_inputs {
specs.push(CacheRunInput {
name: inp.name,
data: inp.data,
row_inner: inp.row_inner,
});
}
for i in 0..num_layers {
specs.push(CacheRunInput {
name: key_names[2 * i].as_str(),
data: padded_k[i].as_slice(),
row_inner: Some(kv_dim),
});
specs.push(CacheRunInput {
name: key_names[2 * i + 1].as_str(),
data: padded_v[i].as_slice(),
row_inner: Some(kv_dim),
});
}
let kv_dims = vec![kv_dim; num_layers];
let output_inners = output_inners_for_kv(&kv_dims);
finish_bucketed_decode(
compiled,
upper,
past_seq,
&specs,
&output_inners,
kv,
&kv_dims,
)
}
/// Runs one bucketed decode step from a [`HirModule`] builder where each layer may have
/// its own KV row width (`kv_dims[i]`).
///
/// The module for the bucket containing `past_seq` is compiled (with its params) on first
/// use. Each layer's cache is zero-padded to the bucket size and fed as `past_k_i` /
/// `past_v_i` after `fixed_inputs`.
///
/// # Errors
/// Fails if `kv_dims.len() != num_layers`. This is checked before anything is compiled.
/// Also fails if `past_seq` falls outside every decode bucket, or if the decode run or the
/// output split fails.
pub fn run_bucketed_kv_decode_hir_layers<F>(
    cache: &mut BucketedCompileCache,
    past_seq: usize,
    kv: &LayerKvCache,
    kv_dims: &[usize],
    num_layers: usize,
    fixed_inputs: &[CacheRunInput<'_>],
    build: F,
    options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
    F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
{
    // Validate before touching the cache so a malformed call never triggers a compile.
    if kv_dims.len() != num_layers {
        anyhow::bail!(
            "run_bucketed_kv_decode_hir_layers: kv_dims len {} != num_layers {num_layers}",
            kv_dims.len()
        );
    }
    let key = past_seq as u64;
    let (upper_u64, compiled) = cache
        .ensure_hir_with_params(key, build, options)
        .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;
    let upper = upper_u64 as usize;
    let (padded_k, padded_v) = kv.pad_layers_to_upper_per_layer(upper_u64, kv_dims);
    let key_names = past_kv_input_names(num_layers);
    let mut specs: Vec<CacheRunInput<'_>> = Vec::with_capacity(fixed_inputs.len() + 2 * num_layers);
    for inp in fixed_inputs {
        specs.push(CacheRunInput {
            name: inp.name,
            data: inp.data,
            row_inner: inp.row_inner,
        });
    }
    for i in 0..num_layers {
        specs.push(CacheRunInput {
            name: key_names[2 * i].as_str(),
            data: padded_k[i].as_slice(),
            row_inner: Some(kv_dims[i]),
        });
        specs.push(CacheRunInput {
            name: key_names[2 * i + 1].as_str(),
            data: padded_v[i].as_slice(),
            row_inner: Some(kv_dims[i]),
        });
    }
    let output_inners = output_inners_for_kv(kv_dims);
    finish_bucketed_decode(
        compiled,
        upper,
        past_seq,
        &specs,
        &output_inners,
        kv,
        kv_dims,
    )
}
/// Like [`run_bucketed_kv_decode_hir_layers`], but pads the KV cache into caller-owned
/// scratch buffers (`padded_k` / `padded_v`) so they can be reused across steps.
///
/// Each scratch buffer is resized to `upper * kv_dims[i]` elements and filled from
/// `kv` with `pad_rows_into`.
///
/// # Errors
/// Fails with "layer count mismatch" if `kv_dims`, `padded_k` or `padded_v` do not have
/// `num_layers` entries, or if `kv` holds fewer layers. These checks run before anything
/// is compiled. Also fails if `past_seq` falls outside every decode bucket, or if the
/// decode run or the output split fails.
pub fn run_bucketed_kv_decode_hir_scratch<F>(
    cache: &mut BucketedCompileCache,
    past_seq: usize,
    kv: &LayerKvCache,
    kv_dims: &[usize],
    num_layers: usize,
    padded_k: &mut [Vec<f32>],
    padded_v: &mut [Vec<f32>],
    fixed_inputs: &[CacheRunInput<'_>],
    build: F,
    options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
    F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
{
    // Validate before touching the cache so a malformed call never triggers a compile,
    // and so the per-layer indexing below cannot go out of bounds.
    if kv_dims.len() != num_layers
        || padded_k.len() != num_layers
        || padded_v.len() != num_layers
        || kv.layers_k.len() < num_layers
        || kv.layers_v.len() < num_layers
    {
        anyhow::bail!("run_bucketed_kv_decode_hir_scratch: layer count mismatch");
    }
    let key = past_seq as u64;
    let (upper_u64, compiled) = cache
        .ensure_hir_with_params(key, build, options)
        .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;
    let upper = upper_u64 as usize;
    for (i, &kd) in kv_dims.iter().enumerate() {
        let need = upper * kd;
        // Size K and V independently: the two scratch buffers can drift apart, and
        // `resize` is a no-op when the length already matches.
        padded_k[i].resize(need, 0.0);
        padded_v[i].resize(need, 0.0);
        pad_rows_into(&mut padded_k[i], kv.layers_k[i].as_slice(), kd);
        pad_rows_into(&mut padded_v[i], kv.layers_v[i].as_slice(), kd);
    }
    let key_names = past_kv_input_names(num_layers);
    let mut specs: Vec<CacheRunInput<'_>> = Vec::with_capacity(fixed_inputs.len() + 2 * num_layers);
    for inp in fixed_inputs {
        specs.push(CacheRunInput {
            name: inp.name,
            data: inp.data,
            row_inner: inp.row_inner,
        });
    }
    for i in 0..num_layers {
        specs.push(CacheRunInput {
            name: key_names[2 * i].as_str(),
            data: padded_k[i].as_slice(),
            row_inner: Some(kv_dims[i]),
        });
        specs.push(CacheRunInput {
            name: key_names[2 * i + 1].as_str(),
            data: padded_v[i].as_slice(),
            row_inner: Some(kv_dims[i]),
        });
    }
    let output_inners = output_inners_for_kv(kv_dims);
    finish_bucketed_decode(
        compiled,
        upper,
        past_seq,
        &specs,
        &output_inners,
        kv,
        kv_dims,
    )
}
/// Runs one bucketed decode step from a [`Graph`] builder with per-layer KV widths. It pads
/// the KV cache into caller-owned scratch buffers and can upload pre-packed weights.
///
/// When this call builds the bucket's graph, and `packed_upload` is provided, and
/// `packed_loaded` did not already contain the bucket size, every packed entry is uploaded
/// with `set_param_typed(.., DType::U8)`. The bucket size is then recorded in `packed_loaded`.
///
/// # Errors
/// Fails with "layer count mismatch" if `kv_dims`, `padded_k` or `padded_v` do not have
/// `num_layers` entries, or if `kv` holds fewer layers. These checks run before anything
/// is compiled. Also fails if `past_seq` falls outside every decode bucket, or if the
/// decode run or the output split fails.
pub fn run_bucketed_kv_decode_graph_layers_scratch<F>(
    cache: &mut BucketedCompileCache,
    past_seq: usize,
    kv: &LayerKvCache,
    kv_dims: &[usize],
    num_layers: usize,
    padded_k: &mut [Vec<f32>],
    padded_v: &mut [Vec<f32>],
    fixed_inputs: &[CacheRunInput<'_>],
    build: F,
    packed_upload: PackedUploadMap<'_>,
    packed_loaded: &mut std::collections::HashSet<u64>,
    options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
    F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
{
    // Validate before touching the cache so a malformed call never triggers a compile or
    // a packed upload, and so the per-layer indexing below cannot go out of bounds.
    if kv_dims.len() != num_layers
        || padded_k.len() != num_layers
        || padded_v.len() != num_layers
        || kv.layers_k.len() < num_layers
        || kv.layers_v.len() < num_layers
    {
        anyhow::bail!("run_bucketed_kv_decode_graph_layers_scratch: layer count mismatch");
    }
    let key = past_seq as u64;
    let needs_build = cache.compiled_for_key_mut(key).is_none();
    let (upper_u64, compiled) = cache
        .ensure_graph_with_params(key, build, options)
        .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?;
    if needs_build {
        if let Some(packed) = packed_upload {
            // NOTE(review): if `packed_loaded` can already hold `upper_u64` while this bucket
            // was just (re)built (e.g. after an eviction, or if the set is shared with another
            // cache), the fresh graph would not receive the packed params. Confirm that
            // cannot happen.
            if packed_loaded.insert(upper_u64) {
                for (name, (bytes, _scheme, _shape)) in packed {
                    compiled.set_param_typed(name, bytes, rlx_ir::DType::U8);
                }
            }
        }
    }
    let upper = upper_u64 as usize;
    for (i, &kd) in kv_dims.iter().enumerate() {
        let need = upper * kd;
        // Size K and V independently: the two scratch buffers can drift apart, and
        // `resize` is a no-op when the length already matches.
        padded_k[i].resize(need, 0.0);
        padded_v[i].resize(need, 0.0);
        pad_rows_into(&mut padded_k[i], kv.layers_k[i].as_slice(), kd);
        pad_rows_into(&mut padded_v[i], kv.layers_v[i].as_slice(), kd);
    }
    let key_names = past_kv_input_names(num_layers);
    let mut specs: Vec<CacheRunInput<'_>> = Vec::with_capacity(fixed_inputs.len() + 2 * num_layers);
    for inp in fixed_inputs {
        specs.push(CacheRunInput {
            name: inp.name,
            data: inp.data,
            row_inner: inp.row_inner,
        });
    }
    for i in 0..num_layers {
        specs.push(CacheRunInput {
            name: key_names[2 * i].as_str(),
            data: padded_k[i].as_slice(),
            row_inner: Some(kv_dims[i]),
        });
        specs.push(CacheRunInput {
            name: key_names[2 * i + 1].as_str(),
            data: padded_v[i].as_slice(),
            row_inner: Some(kv_dims[i]),
        });
    }
    let output_inners = output_inners_for_kv(kv_dims);
    finish_bucketed_decode(
        compiled,
        upper,
        past_seq,
        &specs,
        &output_inners,
        kv,
        kv_dims,
    )
}
/// Splits full bucketed decode outputs `[logits, k_0, v_0, ...]` into logits plus per-layer
/// K/V buffers compacted to `past_seq + 1` rows of `batch * kv_dims[layer]` floats.
///
/// # Errors
/// Fails if the output count is not `1 + 2 * num_layers`, or if `kv_dims` has fewer than
/// `num_layers` entries. Also fails if any layer's K or V buffer holds fewer than
/// `(past_seq + 1) * batch * kv_dim` floats.
pub fn split_bucketed_decode_kv_per_layer(
    outputs: Vec<Vec<f32>>,
    past_seq: usize,
    kv_dims: &[usize],
    num_layers: usize,
    batch: usize,
) -> Result<DecodeLogitsKv> {
    if outputs.len() != 1 + 2 * num_layers {
        anyhow::bail!(
            "bucketed decode produced {} outputs, expected {}",
            outputs.len(),
            1 + 2 * num_layers
        );
    }
    // Report a short `kv_dims` as an error instead of panicking on the indexing below.
    if kv_dims.len() < num_layers {
        anyhow::bail!(
            "bucketed decode: kv_dims has {} entries, need {num_layers}",
            kv_dims.len()
        );
    }
    let kept_rows = past_seq + 1;
    let mut iter = outputs.into_iter();
    let logits = iter.next().context("bucketed decode logits missing")?;
    let mut new_k = Vec::with_capacity(num_layers);
    let mut new_v = Vec::with_capacity(num_layers);
    for (layer, &kv_dim) in kv_dims.iter().take(num_layers).enumerate() {
        let row_size = batch * kv_dim;
        let need = kept_rows * row_size;
        let k = iter.next().context("bucketed k missing")?;
        let v = iter.next().context("bucketed v missing")?;
        if k.len() < need {
            anyhow::bail!(
                "bucketed K layer {layer}: got {} f32, need at least {} (past_seq={past_seq}, kv_dim={kv_dim})",
                k.len(),
                need
            );
        }
        // V gets the same check as K; a short V would otherwise be silently compacted.
        if v.len() < need {
            anyhow::bail!(
                "bucketed V layer {layer}: got {} f32, need at least {} (past_seq={past_seq}, kv_dim={kv_dim})",
                v.len(),
                need
            );
        }
        new_k.push(compact_bucketed_kv_buffer(&k, kept_rows, kv_dim, batch));
        new_v.push(compact_bucketed_kv_buffer(&v, kept_rows, kv_dim, batch));
    }
    Ok((logits, new_k, new_v))
}
/// Uniform-KV-width decode from a [`HirModule`] builder. This forwards unchanged to
/// [`run_bucketed_kv_decode_hir`].
pub fn run_bucketed_kv_decode_hir_uniform<F>(
    cache: &mut BucketedCompileCache,
    past_seq: usize,
    kv: &LayerKvCache,
    kv_dim: usize,
    num_layers: usize,
    fixed_inputs: &[CacheRunInput<'_>],
    build: F,
    options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
    F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
{
    let result =
        run_bucketed_kv_decode_hir(cache, past_seq, kv, kv_dim, num_layers, fixed_inputs, build, options);
    result
}
/// Uniform-width version of bucketed output splitting. It splits
/// `[logits, k_0, v_0, ...]` into logits plus per-layer K/V buffers compacted to
/// `past_seq + 1` rows of `batch * kv_dim` floats.
///
/// # Errors
/// Fails if the output count is not `1 + 2 * num_layers`, or if any layer's K or V buffer
/// holds fewer than `(past_seq + 1) * batch * kv_dim` floats.
pub fn split_bucketed_decode_kv(
    outputs: Vec<Vec<f32>>,
    past_seq: usize,
    kv_dim: usize,
    num_layers: usize,
    batch: usize,
) -> Result<DecodeLogitsKv> {
    if outputs.len() != 1 + 2 * num_layers {
        anyhow::bail!(
            "bucketed decode produced {} outputs, expected {}",
            outputs.len(),
            1 + 2 * num_layers
        );
    }
    let mut iter = outputs.into_iter();
    let logits = iter.next().context("bucketed decode logits missing")?;
    let row_size = batch * kv_dim;
    let kept_rows = past_seq + 1;
    let need = kept_rows * row_size;
    let mut new_k = Vec::with_capacity(num_layers);
    let mut new_v = Vec::with_capacity(num_layers);
    for layer in 0..num_layers {
        let k = iter.next().context("bucketed k missing")?;
        let v = iter.next().context("bucketed v missing")?;
        if k.len() < need {
            anyhow::bail!(
                "bucketed K layer {layer}: got {} f32, need at least {} (past_seq={past_seq}, kv_dim={kv_dim})",
                k.len(),
                need
            );
        }
        // V gets the same check as K; a short V would otherwise be silently compacted.
        if v.len() < need {
            anyhow::bail!(
                "bucketed V layer {layer}: got {} f32, need at least {} (past_seq={past_seq}, kv_dim={kv_dim})",
                v.len(),
                need
            );
        }
        new_k.push(compact_bucketed_kv_buffer(&k, kept_rows, kv_dim, batch));
        new_v.push(compact_bucketed_kv_buffer(&v, kept_rows, kv_dim, batch));
    }
    Ok((logits, new_k, new_v))
}
/// Compacts a bucketed K or V output buffer down to `past_len` rows.
///
/// `buf` is treated as consecutive rows of `batch * kv_dim` floats. When it holds more
/// than `past_len` complete rows, the result has two parts. First come the first
/// `past_len - 1` rows (the cached history). Then comes the buffer's last complete row,
/// presumably the row written by the decode step.
///
/// Degenerate inputs are handled without panicking:
/// - `past_len == 0`, or a zero-width row (`batch == 0` or `kv_dim == 0`), yields an empty vector;
/// - a buffer shorter than one row is returned unchanged;
/// - a buffer with at most `past_len` complete rows is truncated to `past_len` rows
///   (or returned whole if shorter).
pub fn compact_bucketed_kv_buffer(
    buf: &[f32],
    past_len: usize,
    kv_dim: usize,
    batch: usize,
) -> Vec<f32> {
    if past_len == 0 {
        return Vec::new();
    }
    let row_size = batch * kv_dim;
    // A zero-width row carries no data; this also guards the division below.
    if row_size == 0 {
        return Vec::new();
    }
    if buf.len() < row_size {
        return buf.to_vec();
    }
    let total_rows = buf.len() / row_size;
    if total_rows <= past_len {
        let n = (past_len * row_size).min(buf.len());
        return buf[..n].to_vec();
    }
    // Here total_rows > past_len >= 1, so both slices below are in bounds.
    let history_len = (past_len - 1) * row_size;
    let last_row_start = (total_rows - 1) * row_size;
    let mut out = Vec::with_capacity(past_len * row_size);
    out.extend_from_slice(&buf[..history_len]);
    out.extend_from_slice(&buf[last_row_start..last_row_start + row_size]);
    out
}
pub fn compile_cache_ensure_graph<'a>(
cache: &'a mut CompileCache,
key: u64,
graph: Graph,
params: HashMap<String, Vec<f32>>,
options: &CompileOptions,
) -> &'a mut CompiledGraph {
if !cache.contains(key) {
let compiled = cache.get_or_compile_with_options(key, || graph, options);
for (name, data) in ¶ms {
compiled.set_param(name, data);
}
}
cache.get_or_compile_with_options(
key,
|| panic!("compile_cache_ensure_graph: missing {key}"),
options,
)
}