rlx-models-core 0.2.1

Shared config, weight loading, and compile helpers for RLX model crates
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Shared helpers for autoregressive decode loops (KV cache + bucketed compile cache).

use anyhow::{Context, Result};
use rlx_ir::{Graph, hir::HirModule};
use rlx_runtime::compile_cache::{BucketedCompileCache, CacheRunInput, CompileCache};
use rlx_runtime::kv_cache::LayerKvCache;
use rlx_runtime::{CompileOptions, CompiledGraph};
use std::collections::HashMap;

/// Decode step outputs: logits plus per-layer K/V (row-major `[seq * kv_dim]`).
pub type DecodeLogitsKv = (Vec<f32>, Vec<Vec<f32>>, Vec<Vec<f32>>);

pub use rlx_runtime::kv_cache::LayerKvCache as KvCacheState;

/// Build the LRU prefill-cache key for a `(batch, seq)` shape.
///
/// Layout is `(batch << 32) | seq`: `batch` fills the high 32 bits and `seq` the low 32 bits.
/// Keys stay unique only while `seq < 2^32` (larger values spill into the batch bits); any
/// bits of `batch` above bit 31 are shifted out.
pub fn prefill_cache_key(batch: usize, seq: usize) -> u64 {
    let high_bits = (batch as u64) << 32;
    let low_bits = seq as u64;
    high_bits | low_bits
}

/// Graph input names for the past K/V of `num_layers` decoder layers.
///
/// Names are interleaved per layer: `past_k_0, past_v_0, past_k_1, past_v_1, ...`.
pub fn past_kv_input_names(num_layers: usize) -> Vec<String> {
    let mut names = Vec::with_capacity(2 * num_layers);
    for layer in 0..num_layers {
        names.push(format!("past_k_{layer}"));
        names.push(format!("past_v_{layer}"));
    }
    names
}

/// Split one-shot decode outputs into logits and per-layer K/V, without any slicing.
///
/// `outputs` must be laid out as `[logits, k_0, v_0, k_1, v_1, ...]` for `num_layers` layers;
/// the tensors are moved out unchanged.
///
/// # Errors
///
/// Fails when `outputs.len()` is not `1 + 2 * num_layers`.
pub fn split_decode_logits_kv(outputs: Vec<Vec<f32>>, num_layers: usize) -> Result<DecodeLogitsKv> {
    let expected = 1 + 2 * num_layers;
    if outputs.len() != expected {
        anyhow::bail!(
            "decode graph produced {} outputs, expected {}",
            outputs.len(),
            expected
        );
    }

    let mut tensors = outputs.into_iter();
    let logits = tensors.next().context("decode logits missing")?;
    let mut layers_k = Vec::with_capacity(num_layers);
    let mut layers_v = Vec::with_capacity(num_layers);
    while layers_k.len() < num_layers {
        let k = tensors.next().context("decode k missing")?;
        let v = tensors.next().context("decode v missing")?;
        layers_k.push(k);
        layers_v.push(v);
    }
    Ok((logits, layers_k, layers_v))
}

/// Turn prefill-with-cache outputs into logits plus an initial [`LayerKvCache`].
///
/// `outputs` is expected to hold `logits` followed by alternating `k_i`, `v_i` tensors for
/// each of the `num_layers` layers. Every K/V tensor must contain exactly
/// `batch * seq * kv_dim` values. The returned cache records `past_len = seq`.
///
/// # Errors
///
/// Fails when the output count is not `1 + 2 * num_layers`, or when any layer's K or V
/// length differs from `batch * seq * kv_dim`.
pub fn kv_from_prefill_outputs(
    outputs: Vec<Vec<f32>>,
    batch: usize,
    seq: usize,
    kv_dim: usize,
    num_layers: usize,
) -> Result<(Vec<f32>, LayerKvCache)> {
    let expected_outputs = 1 + 2 * num_layers;
    if outputs.len() != expected_outputs {
        anyhow::bail!(
            "prefill produced {} outputs, expected {}",
            outputs.len(),
            expected_outputs
        );
    }

    let expected_kv_len = batch * seq * kv_dim;
    let mut tensors = outputs.into_iter();
    let logits = tensors.next().context("prefill logits missing")?;

    let mut layers_k = Vec::with_capacity(num_layers);
    let mut layers_v = Vec::with_capacity(num_layers);
    while layers_k.len() < num_layers {
        let layer = layers_k.len();
        let k = tensors.next().context("prefill k missing")?;
        let v = tensors.next().context("prefill v missing")?;
        let lengths_match = k.len() == expected_kv_len && v.len() == expected_kv_len;
        if !lengths_match {
            anyhow::bail!(
                "layer {layer}: k.len={} v.len={} expected {expected_kv_len}",
                k.len(),
                v.len()
            );
        }
        layers_k.push(k);
        layers_v.push(v);
    }

    let cache = LayerKvCache {
        past_len: seq,
        layers_k,
        layers_v,
    };
    Ok((logits, cache))
}

/// One bucketed decode step: compile bucket if needed, pad K/V, run, slice updated K/V.
///
/// `fixed_inputs` must include `mask` (use [`bucket_decode_mask`]) when the graph uses
/// `MaskKind::Custom`. Rope / token inputs should use `row_inner: None`.
///
/// # Errors
///
/// Returns an error when:
/// - `past_seq` lies outside every decode bucket;
/// - the padded K/V state has fewer than `num_layers` layers;
/// - `run_padded_mixed` reports failure;
/// - the graph yields a number of outputs other than `1 + 2 * num_layers`.
pub fn run_bucketed_kv_decode<F>(
    cache: &mut BucketedCompileCache,
    past_seq: usize,
    kv: &LayerKvCache,
    kv_dim: usize,
    num_layers: usize,
    fixed_inputs: &[CacheRunInput<'_>],
    build: F,
    options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
    F: FnOnce(u64) -> (Graph, HashMap<String, Vec<f32>>),
{
    let key = past_seq as u64;
    let upper = cache
        .ensure_graph_with_params(key, build, options)
        .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?
        .0 as usize;

    let (padded_k, padded_v) = kv.pad_layers_to_upper(upper as u64, kv_dim);
    let input_names = past_kv_input_names(num_layers);
    let specs =
        bucketed_decode_run_inputs(fixed_inputs, &input_names, &padded_k, &padded_v, kv_dim)?;
    let output_inners = bucketed_decode_output_inners(kv_dim, num_layers);

    let (_upper, raw) = cache
        .run_padded_mixed(
            key,
            past_seq + 1,
            // The bucket was ensured above, so this builder must never be needed.
            |_| panic!("bucket must be compiled via ensure_graph_with_params"),
            &specs,
            &output_inners,
        )
        .ok_or_else(|| anyhow::anyhow!("run_padded_mixed failed for past_seq {past_seq}"))?;

    split_bucketed_decode_kv(raw, past_seq, kv_dim, num_layers)
}

/// Assemble the run inputs for one bucketed decode step.
///
/// `fixed_inputs` are copied through unchanged. They are followed by one
/// `past_k_{i}` / `past_v_{i}` pair per layer, with names taken pairwise from
/// `input_names`. Each pair carries that layer's padded K/V rows with
/// `row_inner = Some(kv_dim)`.
///
/// Returns an error (rather than panicking on an out-of-bounds index) when the padded state
/// has fewer layers than there are name pairs. Extra padded layers are ignored.
fn bucketed_decode_run_inputs<'a, 'f: 'a, K, V>(
    fixed_inputs: &[CacheRunInput<'f>],
    input_names: &'a [String],
    padded_k: &'a [K],
    padded_v: &'a [V],
    kv_dim: usize,
) -> Result<Vec<CacheRunInput<'a>>>
where
    K: AsRef<[f32]>,
    V: AsRef<[f32]>,
{
    let num_layers = input_names.len() / 2;
    if padded_k.len() < num_layers || padded_v.len() < num_layers {
        anyhow::bail!(
            "padded kv has {} k / {} v layers, expected {num_layers}",
            padded_k.len(),
            padded_v.len()
        );
    }

    let mut specs = Vec::with_capacity(fixed_inputs.len() + input_names.len());
    // Copy field by field so the `'f` borrows shorten to `'a`. This lets the caller's fixed
    // inputs and the locally owned per-layer buffers share one vector.
    specs.extend(fixed_inputs.iter().map(|inp| CacheRunInput {
        name: inp.name,
        data: inp.data,
        row_inner: inp.row_inner,
    }));
    for (names, (k, v)) in input_names
        .chunks_exact(2)
        .zip(padded_k.iter().zip(padded_v.iter()))
    {
        specs.push(CacheRunInput {
            name: names[0].as_str(),
            data: k.as_ref(),
            row_inner: Some(kv_dim),
        });
        specs.push(CacheRunInput {
            name: names[1].as_str(),
            data: v.as_ref(),
            row_inner: Some(kv_dim),
        });
    }
    Ok(specs)
}

/// Per-output row widths passed to `run_padded_mixed`.
///
/// The logits output gets `0`, then each of the `2 * num_layers` K/V outputs gets `kv_dim`.
/// NOTE(review): `0` presumably marks an output with no padded row dimension; confirm
/// against `run_padded_mixed`.
fn bucketed_decode_output_inners(kv_dim: usize, num_layers: usize) -> Vec<usize> {
    std::iter::once(0)
        .chain(std::iter::repeat_n(kv_dim, 2 * num_layers))
        .collect()
}

/// Like [`run_bucketed_kv_decode`] but compiles decode graphs from HIR (Gemma / Llama / Qwen3.5).
///
/// # Errors
///
/// Same conditions as [`run_bucketed_kv_decode`].
pub fn run_bucketed_kv_decode_hir<F>(
    cache: &mut BucketedCompileCache,
    past_seq: usize,
    kv: &LayerKvCache,
    kv_dim: usize,
    num_layers: usize,
    fixed_inputs: &[CacheRunInput<'_>],
    build: F,
    options: &CompileOptions,
) -> Result<DecodeLogitsKv>
where
    F: FnOnce(u64) -> (HirModule, HashMap<String, Vec<f32>>),
{
    let key = past_seq as u64;
    let upper = cache
        .ensure_hir_with_params(key, build, options)
        .ok_or_else(|| anyhow::anyhow!("past_seq {past_seq} outside decode buckets"))?
        .0 as usize;

    let (padded_k, padded_v) = kv.pad_layers_to_upper(upper as u64, kv_dim);
    let input_names = past_kv_input_names(num_layers);
    let specs =
        bucketed_decode_run_inputs(fixed_inputs, &input_names, &padded_k, &padded_v, kv_dim)?;
    let output_inners = bucketed_decode_output_inners(kv_dim, num_layers);

    let (_upper, raw) = cache
        .run_padded_mixed(
            key,
            past_seq + 1,
            // The bucket was ensured above, so this builder must never be needed.
            |_| panic!("bucket must be compiled via ensure_hir_with_params"),
            &specs,
            &output_inners,
        )
        .ok_or_else(|| anyhow::anyhow!("run_padded_mixed failed for past_seq {past_seq}"))?;

    split_bucketed_decode_kv(raw, past_seq, kv_dim, num_layers)
}

/// Trim bucket-padded decode K/V outputs back to `past_seq + 1` rows.
///
/// `outputs` holds logits followed by alternating K/V tensors for `num_layers` layers.
/// Logits are returned untouched. Each K/V tensor is copied down to its first
/// `(past_seq + 1) * kv_dim` values, or copied whole if it is already shorter than that.
///
/// # Errors
///
/// Fails when the output count is not `1 + 2 * num_layers`.
pub fn split_bucketed_decode_kv(
    outputs: Vec<Vec<f32>>,
    past_seq: usize,
    kv_dim: usize,
    num_layers: usize,
) -> Result<DecodeLogitsKv> {
    let expected = 1 + 2 * num_layers;
    if outputs.len() != expected {
        anyhow::bail!(
            "bucketed decode produced {} outputs, expected {}",
            outputs.len(),
            expected
        );
    }

    let mut tensors = outputs.into_iter();
    let logits = tensors.next().context("bucketed decode logits missing")?;

    // Keep only the first `past_seq + 1` rows; anything past that is bucket padding.
    let real_len = (past_seq + 1) * kv_dim;
    let keep_real_rows = |padded: &[f32]| -> Vec<f32> {
        let end = padded.len().min(real_len);
        padded[..end].to_vec()
    };

    let mut new_k = Vec::with_capacity(num_layers);
    let mut new_v = Vec::with_capacity(num_layers);
    while new_k.len() < num_layers {
        let k = tensors.next().context("bucketed k missing")?;
        let v = tensors.next().context("bucketed v missing")?;
        new_k.push(keep_real_rows(&k));
        new_v.push(keep_real_rows(&v));
    }
    Ok((logits, new_k, new_v))
}

/// Insert a pre-sized graph into an LRU [`CompileCache`] and return the compiled entry.
///
/// If `key` is not yet cached, `graph` is compiled with `options` and every entry of
/// `params` is applied to the new entry via `set_param`. If `key` is already cached, the
/// existing entry is returned as-is and `params` is not re-applied.
pub fn compile_cache_ensure_graph<'a>(
    cache: &'a mut CompileCache,
    key: u64,
    graph: Graph,
    params: HashMap<String, Vec<f32>>,
    options: &CompileOptions,
) -> &'a mut CompiledGraph {
    // Probe before the lookup: params belong only to a freshly compiled entry. A single
    // `get_or_compile_with_options` call avoids a second lookup whose fallback would panic
    // if the entry were not retained.
    let freshly_compiled = !cache.contains(key);
    let compiled = cache.get_or_compile_with_options(key, || graph, options);
    if freshly_compiled {
        for (name, data) in &params {
            compiled.set_param(name, data);
        }
    }
    compiled
}