use anyhow::{Context, Result};
use rlx_ir::{Graph, hir::HirModule};
use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
use rlx_runtime::kv_cache::LayerKvCache;
use rlx_runtime::{CompileOptions, CompiledGraph};
use std::collections::HashMap;
/// Result of one decode step as a tuple `(logits, keys, values)`.
///
/// The second and third elements hold one flat `f32` buffer per layer, in
/// layer order (key buffers and value buffers respectively).
pub type DecodeLogitsKv = (Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>);
pub use rlx_runtime::kv_cache::LayerKvCache as KvCacheState;
/// Packs a `(batch, seq)` pair into a single `u64` cache key.
///
/// `batch` is shifted into the upper 32 bits and `seq` is OR-ed into the
/// result unchanged, so a `seq` that does not fit in 32 bits overlaps the
/// `batch` bits.
pub fn prefill_cache_key(batch: usize, seq: usize) -> u64 {
    let high = (batch as u64) << 32;
    let low = seq as u64;
    high | low
}
/// Builds the input names for the past K/V tensors of every layer.
///
/// The result is interleaved per layer: `past_k_0`, `past_v_0`, `past_k_1`,
/// `past_v_1`, ... and contains `2 * num_layers` entries.
pub fn past_kv_input_names(num_layers: usize) -> Vec<String> {
    let mut names = Vec::new();
    for i in 0..num_layers {
        // Key name first, then the value name of the same layer.
        names.push(format!("past_k_{i}"));
        names.push(format!("past_v_{i}"));
    }
    names
}
/// Splits the flat output list of a decode graph into logits and per-layer
/// K/V buffers.
///
/// `outputs` must contain exactly `1 + 2 * num_layers` buffers: the logits
/// first, then for each layer its key buffer followed by its value buffer.
///
/// # Errors
///
/// Returns an error when the number of outputs is not `1 + 2 * num_layers`.
pub fn split_decode_logits_kv(outputs: Vec<Vec<f32>>, num_layers: usize) -> Result<DecodeLogitsKv> {
    let expected = 1 + 2 * num_layers;
    let produced = outputs.len();
    if produced != expected {
        anyhow::bail!("decode graph produced {} outputs, expected {}", produced, expected);
    }

    let mut remaining = outputs.into_iter();
    let logits = remaining.next().context("decode logits missing")?;

    // Pull one (key, value) pair per layer, stopping at the first missing buffer.
    let pairs = (0..num_layers)
        .map(|_| -> Result<(Vec<f32>, Vec<f32>)> {
            let k = remaining.next().context("decode k missing")?;
            let v = remaining.next().context("decode v missing")?;
            Ok((k, v))
        })
        .collect::<Result<Vec<(Vec<f32>, Vec<f32>)>>>()?;

    let (layers_k, layers_v): (Vec<Vec<f32>>, Vec<Vec<f32>>) = pairs.into_iter().unzip();
    Ok((logits, layers_k, layers_v))
}
/// Converts the flat output list of a prefill run into logits plus a
/// [`LayerKvCache`] whose `past_len` is `seq`.
///
/// `outputs` must contain exactly `1 + 2 * num_layers` buffers: the logits
/// first, then for each layer its key buffer followed by its value buffer.
/// Every key and value buffer must have `batch * seq * kv_dim` elements.
///
/// # Errors
///
/// Returns an error when the output count is wrong or when any layer's key
/// or value buffer does not have the expected length.
pub fn kv_from_prefill_outputs(
    outputs: Vec<Vec<f32>>,
    batch: usize,
    seq: usize,
    kv_dim: usize,
    num_layers: usize,
) -> Result<(Vec<f32>, LayerKvCache)> {
    let want_outputs = 1 + 2 * num_layers;
    if outputs.len() != want_outputs {
        anyhow::bail!("prefill produced {} outputs, expected {}", outputs.len(), want_outputs);
    }

    let want_len = batch * seq * kv_dim;
    let mut remaining = outputs.into_iter();
    let logits = remaining.next().context("prefill logits missing")?;

    // Fill the cache in place; it is simply dropped if a later layer fails validation.
    let mut kv = LayerKvCache {
        past_len: seq,
        layers_k: Vec::with_capacity(num_layers),
        layers_v: Vec::with_capacity(num_layers),
    };
    for idx in 0..num_layers {
        let k = remaining.next().context("prefill k missing")?;
        let v = remaining.next().context("prefill v missing")?;
        let sizes_ok = k.len() == want_len && v.len() == want_len;
        if !sizes_ok {
            anyhow::bail!(
                "layer {layer}: k.len={} v.len={} expected {expected_kv_len}",
                k.len(),
                v.len(),
                layer = idx,
                expected_kv_len = want_len
            );
        }
        kv.layers_k.push(k);
        kv.layers_v.push(v);
    }

    Ok((logits, kv))
}
/// Runs one KV-cached decode step through a bucketed compile cache, building
/// the bucket from a [`Graph`] if needed.
///
/// Steps, in order:
/// 1. `ensure_graph_with_params` is called with `past_seq` as the key; the
///    first element of its result is used as the bucket's upper bound.
/// 2. `kv` is padded to that upper bound via `pad_layers_to_upper`.
/// 3. The run inputs are `fixed_inputs` (copied field by field) followed by
///    `past_k_{i}` / `past_v_{i}` for each layer, each with
///    `row_inner = Some(kv_dim)`.
/// 4. `run_padded_mixed` is invoked with `past_seq + 1` and the raw outputs are
///    handed to `split_bucketed_decode_kv`.
///
/// `build` is forwarded to `ensure_graph_with_params`.
/// NOTE(review): presumably it is only invoked when the bucket has not been
/// compiled yet — confirm against `BucketedCompileCache`.
///
/// # Errors
///
/// Returns an error when `ensure_graph_with_params` or `run_padded_mixed`
/// yields `None`, or when splitting the outputs fails.
///
/// # Panics
///
/// Panics if the padded K/V lists have fewer than `num_layers` entries, or if
/// `run_padded_mixed` invokes its fallback build closure.
pub fn run_bucketed_kv_decode<F>(
    cache: &mut BucketedCompileCache,
    past_seq: usize,
    kv: &LayerKvCache,
    kv_dim: usize,
    num_layers: usize,
    fixed_inputs: &[CacheRunInput<'_>],
    build: F,
    options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
    F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
{
    let bucket_key = past_seq as u64;
    let bucket_upper = match cache.ensure_graph_with_params(bucket_key, build, options) {
        Some(ensured) => ensured.0 as usize,
        None => anyhow::bail!("past_seq {past_seq} outside decode buckets"),
    };

    let (padded_k, padded_v) = kv.pad_layers_to_upper(bucket_upper as u64, kv_dim);
    let kv_names = past_kv_input_names(num_layers);

    let mut specs: Vec<CacheRunInput<'_>> = Vec::with_capacity(fixed_inputs.len() + 2 * num_layers);
    // Caller-supplied inputs come first, unchanged.
    specs.extend(fixed_inputs.iter().map(|fixed| CacheRunInput {
        name: fixed.name,
        data: fixed.data,
        row_inner: fixed.row_inner,
    }));
    // Then one key input and one value input per layer, matching `kv_names` order.
    specs.extend((0..num_layers).flat_map(|layer| {
        [
            CacheRunInput {
                name: kv_names[2 * layer].as_str(),
                data: padded_k[layer].as_slice(),
                row_inner: Some(kv_dim),
            },
            CacheRunInput {
                name: kv_names[2 * layer + 1].as_str(),
                data: padded_v[layer].as_slice(),
                row_inner: Some(kv_dim),
            },
        ]
    }));

    // One entry per graph output: 0 for the logits, `kv_dim` for every K/V output.
    // NOTE(review): the meaning of 0 is defined by `run_padded_mixed` — confirm there.
    let output_inners: Vec<usize> = std::iter::once(0usize)
        .chain(std::iter::repeat_n(kv_dim, 2 * num_layers))
        .collect();

    let raw = match cache.run_padded_mixed(
        bucket_key,
        past_seq + 1,
        |_| panic!("bucket must be compiled via ensure_graph_with_params"),
        &specs,
        &output_inners,
    ) {
        Some((_upper, raw)) => raw,
        None => anyhow::bail!("run_padded_mixed failed for past_seq {past_seq}"),
    };

    split_bucketed_decode_kv(raw, past_seq, kv_dim, num_layers)
}
/// Runs one KV-cached decode step through a bucketed compile cache, building
/// the bucket from a [`HirModule`] if needed.
///
/// Steps, in order:
/// 1. `ensure_hir_with_params` is called with `past_seq` as the key; the first
///    element of its result is used as the bucket's upper bound.
/// 2. `kv` is padded to that upper bound via `pad_layers_to_upper`.
/// 3. The run inputs are `fixed_inputs` (copied field by field) followed by
///    `past_k_{i}` / `past_v_{i}` for each layer, each with
///    `row_inner = Some(kv_dim)`.
/// 4. `run_padded_mixed` is invoked with `past_seq + 1` and the raw outputs are
///    handed to `split_bucketed_decode_kv`.
///
/// `build` is forwarded to `ensure_hir_with_params`.
/// NOTE(review): presumably it is only invoked when the bucket has not been
/// compiled yet — confirm against `BucketedCompileCache`.
///
/// # Errors
///
/// Returns an error when `ensure_hir_with_params` or `run_padded_mixed` yields
/// `None`, or when splitting the outputs fails.
///
/// # Panics
///
/// Panics if the padded K/V lists have fewer than `num_layers` entries, or if
/// `run_padded_mixed` invokes its fallback build closure.
pub fn run_bucketed_kv_decode_hir<F>(
    cache: &mut BucketedCompileCache,
    past_seq: usize,
    kv: &LayerKvCache,
    kv_dim: usize,
    num_layers: usize,
    fixed_inputs: &[CacheRunInput<'_>],
    build: F,
    options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
    F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
{
    let bucket_key = past_seq as u64;
    let bucket_upper = match cache.ensure_hir_with_params(bucket_key, build, options) {
        Some(ensured) => ensured.0 as usize,
        None => anyhow::bail!("past_seq {past_seq} outside decode buckets"),
    };

    let (padded_k, padded_v) = kv.pad_layers_to_upper(bucket_upper as u64, kv_dim);
    let kv_names = past_kv_input_names(num_layers);

    let mut specs: Vec<CacheRunInput<'_>> = Vec::with_capacity(fixed_inputs.len() + 2 * num_layers);
    // Caller-supplied inputs come first, unchanged.
    specs.extend(fixed_inputs.iter().map(|fixed| CacheRunInput {
        name: fixed.name,
        data: fixed.data,
        row_inner: fixed.row_inner,
    }));
    // Then one key input and one value input per layer, matching `kv_names` order.
    specs.extend((0..num_layers).flat_map(|layer| {
        [
            CacheRunInput {
                name: kv_names[2 * layer].as_str(),
                data: padded_k[layer].as_slice(),
                row_inner: Some(kv_dim),
            },
            CacheRunInput {
                name: kv_names[2 * layer + 1].as_str(),
                data: padded_v[layer].as_slice(),
                row_inner: Some(kv_dim),
            },
        ]
    }));

    // One entry per graph output: 0 for the logits, `kv_dim` for every K/V output.
    // NOTE(review): the meaning of 0 is defined by `run_padded_mixed` — confirm there.
    let output_inners: Vec<usize> = std::iter::once(0usize)
        .chain(std::iter::repeat_n(kv_dim, 2 * num_layers))
        .collect();

    let raw = match cache.run_padded_mixed(
        bucket_key,
        past_seq + 1,
        |_| panic!("bucket must be compiled via ensure_hir_with_params"),
        &specs,
        &output_inners,
    ) {
        Some((_upper, raw)) => raw,
        None => anyhow::bail!("run_padded_mixed failed for past_seq {past_seq}"),
    };

    split_bucketed_decode_kv(raw, past_seq, kv_dim, num_layers)
}
/// Splits the outputs of a bucketed decode run into logits and per-layer K/V
/// buffers, trimming each K/V buffer to the real sequence length.
///
/// `outputs` must contain exactly `1 + 2 * num_layers` buffers: the logits
/// first, then for each layer its key buffer followed by its value buffer.
/// Each K/V buffer is cut to its first `(past_seq + 1) * kv_dim` elements; a
/// buffer that is already shorter than that is returned whole.
///
/// # Errors
///
/// Returns an error when the number of outputs is not `1 + 2 * num_layers`.
pub fn split_bucketed_decode_kv(
    outputs: Vec<Vec<f32>>,
    past_seq: usize,
    kv_dim: usize,
    num_layers: usize,
) -> Result<DecodeLogitsKv> {
    let expected = 1 + 2 * num_layers;
    if outputs.len() != expected {
        anyhow::bail!("bucketed decode produced {} outputs, expected {}", outputs.len(), expected);
    }

    let mut remaining = outputs.into_iter();
    let logits = remaining.next().context("bucketed decode logits missing")?;

    let real_len = (past_seq + 1) * kv_dim;
    // Copies the leading real rows into an exactly-sized vector.
    let real_prefix = |padded: &[f32]| -> Vec<f32> {
        let keep = padded.len().min(real_len);
        padded[..keep].to_vec()
    };

    let mut new_k = Vec::with_capacity(num_layers);
    let mut new_v = Vec::with_capacity(num_layers);
    for _layer in 0..num_layers {
        let full_k = remaining.next().context("bucketed k missing")?;
        let full_v = remaining.next().context("bucketed v missing")?;
        new_k.push(real_prefix(&full_k));
        new_v.push(real_prefix(&full_v));
    }

    Ok((logits, new_k, new_v))
}
pub fn compile_cache_ensure_graph<'a>(
cache: &'a mut CompileCache,
key: u64,
graph: Graph,
params: HashMap<String, Vec<f32>>,
options: &CompileOptions,
) -> &'a mut CompiledGraph {
if !cache.contains(key) {
let compiled = cache.get_or_compile_with_options(key, || graph, options);
for (name, data) in ¶ms {
compiled.set_param(name, data);
}
}
cache.get_or_compile_with_options(
key,
|| panic!("compile_cache_ensure_graph: missing {key}"),
options,
)
}