rlx-cli 0.2.0

Shared CLI helpers and multiplexer registry for RLX model binaries
Documentation
// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.

//! Boxed-trait dispatch for LM runners (PLAN.md M3 + M8).
//!
//! `LmRunner` is the minimal abstraction `rlx_models::run::auto_runner`
//! returns from a path. M3 shipped the single-shot `predict_logits`
//! method; M8 adds streaming `generate(prompt_ids, n_new, on_token)`
//! with a sampler-agnostic greedy default that delegates to
//! `predict_logits` on each step.
//!
//! Per-family runners with a cached decode path
//! (`Qwen3Runner::generate`, `Qwen35Runner::generate_with_opts`,
//! `GemmaRunner::generate`, `Llama32Runner::generate`) should
//! override this default with their fast path — the default exists
//! so `auto_runner(path)?.generate(...)` always works, not as a
//! recommended hot path.
//!
//! Each per-family crate provides `impl LmRunner for FooRunner` so
//! that `rlx-models` can hand back a `Box<dyn LmRunner>` from a
//! single GGUF path without the caller knowing the family upfront.

use anyhow::Result;

/// Minimal per-family runner interface.
///
/// Implementations must be `Send` so the boxed trait can move across
/// threads (e.g. when `skill` runs inference on a worker pool).
/// `Sync` is intentionally not required — most runners hold mutable
/// per-call compile / cache state.
pub trait LmRunner: Send {
    /// Short family identifier matching `rlx_cli::arch_runner_name`
    /// (e.g. `"qwen3"`, `"qwen35"`, `"gemma"`, `"llama32"`). Useful
    /// for logging / metrics / per-family branches in the caller.
    fn family(&self) -> &'static str;

    /// LM head vocab size — useful for callers that need to size a
    /// logit buffer or validate token ids before calling
    /// [`Self::predict_logits`]. PLAN.md M9.
    fn vocab_size(&self) -> usize;

    /// Run prefill on `prompt_ids` and return the last-token logits
    /// over the full vocab. Mirrors the existing `predict_logits`
    /// method on every per-family runner.
    ///
    /// # Errors
    ///
    /// Implementation-defined. The default [`Self::generate`] propagates
    /// any error returned here unchanged.
    fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>>;

    /// Generate up to `n_new` tokens after `prompt_ids` using greedy
    /// (argmax) sampling. `on_token` fires once per generated token
    /// and **returns `true` to continue, `false` to stop**. Returns
    /// the generated id sequence (excluding the prompt).
    ///
    /// **Stop-signal honoring varies by family** (PLAN.md M9):
    ///  * default impl + `Qwen35Runner` — honor the return value.
    ///  * `Qwen3Runner` / `GemmaRunner` / `Llama32Runner` — call the
    ///    callback but ignore its return (their inherent `generate`
    ///    doesn't take a bool callback). Pass an EOS-aware sampler
    ///    in the caller, or check `produced.last()` after the call.
    ///
    /// **Default impl is naive**: re-prefill on the full context
    /// each step (quadratic in the generated length). Per-family
    /// runners override with their cached decode fast path.
    ///
    /// # Errors
    ///
    /// The default impl propagates any error from
    /// [`Self::predict_logits`]. It also returns an error if
    /// `predict_logits` yields an empty logit vector, because there is
    /// no token to select in that case.
    fn generate(
        &mut self,
        prompt_ids: &[u32],
        n_new: usize,
        on_token: &mut dyn FnMut(u32) -> bool,
    ) -> Result<Vec<u32>> {
        // `n_new` is an upper bound, not a promise: a caller may pass a
        // very large value and rely on `on_token` to stop early. Cap the
        // up-front reservation so that doesn't trigger a huge (or
        // capacity-overflowing) allocation before the first step.
        const MAX_PREALLOC: usize = 4096;
        let reserve = n_new.min(MAX_PREALLOC);

        let mut context: Vec<u32> =
            Vec::with_capacity(prompt_ids.len().saturating_add(reserve));
        context.extend_from_slice(prompt_ids);
        let mut produced: Vec<u32> = Vec::with_capacity(reserve);

        for _ in 0..n_new {
            let logits = self.predict_logits(&context)?;
            // An empty logit vector has no argmax; `argmax_u32` would fall
            // back to id 0 and we would silently emit garbage tokens.
            anyhow::ensure!(
                !logits.is_empty(),
                "LmRunner `{}`: predict_logits returned an empty logit vector",
                self.family()
            );
            let next = argmax_u32(&logits);
            produced.push(next);
            let keep_going = on_token(next);
            context.push(next);
            if !keep_going {
                break;
            }
        }
        Ok(produced)
    }

    /// Whether this runner supports multimodal (image+text) generation
    /// via [`Self::generate_multimodal`]. Default `false`. Per-family
    /// runners that wire a vision encoder (e.g. `Qwen35Runner` with an
    /// mmproj path) override to `true`.
    fn supports_multimodal(&self) -> bool {
        false
    }

    /// Multimodal text generation: prefill the trunk with `prompt` text
    /// where image markers are spliced with vision embeddings derived
    /// from `rgb` (raw RGB bytes, row-major `[h, w, 3]`). Streams one
    /// token per `on_token` call; returns the full produced sequence.
    ///
    /// Parameters:
    ///  * `img_w` / `img_h` — image width and height in pixels. Given
    ///    the `[h, w, 3]` layout, `rgb.len()` is presumably expected to
    ///    be `img_h * img_w * 3` (implementations should validate).
    ///  * `tokenizer` — optional tokenizer path; how (or whether) it is
    ///    used is implementation-defined.
    ///  * `n_new` / `on_token` — same meaning as in [`Self::generate`].
    ///
    /// Default impl returns an error — only family runners that wire
    /// a vision encoder override this. Intended to match llama.cpp's
    /// `MtmdContext`-based multimodal eval path.
    ///
    /// # Errors
    ///
    /// The default impl always returns an error. Callers should check
    /// [`Self::supports_multimodal`] first.
    // The argument list is part of the public trait contract; bundling it
    // into a struct would break every implementor, so the lint is silenced.
    #[allow(clippy::too_many_arguments)]
    fn generate_multimodal(
        &mut self,
        _prompt: &str,
        _rgb: &[u8],
        _img_w: usize,
        _img_h: usize,
        _tokenizer: Option<&std::path::Path>,
        _n_new: usize,
        _on_token: &mut dyn FnMut(u32) -> bool,
    ) -> Result<Vec<u32>> {
        Err(anyhow::anyhow!(
            "this LmRunner does not support multimodal generation"
        ))
    }
}

/// Return the index of the largest logit as a token id.
///
/// Only a strictly greater value replaces the running best, so ties
/// resolve to the earliest index. NaN never compares greater, so NaN
/// entries are effectively skipped. An empty slice, or one containing
/// only NaN / `-inf`, yields `0`.
///
/// The final `as u32` is a plain cast; vocab sizes are assumed to fit
/// in `u32`.
fn argmax_u32(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().copied().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(top_idx, top_val), (idx, val)| {
            if val > top_val {
                (idx, val)
            } else {
                (top_idx, top_val)
            }
        },
    );
    winner as u32
}