Skip to main content

rlx_cli/
lm_runner.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Boxed-trait dispatch for LM runners (PLAN.md M3 + M8).
17//!
18//! `LmRunner` is the minimal abstraction `rlx_models::run::auto_runner`
19//! returns from a path. M3 shipped the single-shot `predict_logits`
20//! method; M8 adds streaming `generate(prompt_ids, n_new, on_token)`
21//! with a sampler-agnostic greedy default that delegates to
22//! `predict_logits` on each step.
23//!
24//! Per-family runners with a cached decode path
25//! (`Qwen3Runner::generate`, `Qwen35Runner::generate_with_opts`,
26//! `GemmaRunner::generate`, `Llama32Runner::generate`) should
27//! override this default with their fast path — the default exists
28//! so `auto_runner(path)?.generate(...)` always works, not as a
29//! recommended hot path.
30//!
31//! Each per-family crate provides `impl LmRunner for FooRunner` so
32//! that `rlx-models` can hand back a `Box<dyn LmRunner>` from a
33//! single GGUF path without the caller knowing the family upfront.
34
35use anyhow::Result;
36
37/// Minimal per-family runner interface.
38///
39/// Implementations must be `Send` so the boxed trait can move across
40/// threads (e.g. when `skill` runs inference on a worker pool).
41/// `Sync` is intentionally not required — most runners hold mutable
42/// per-call compile / cache state.
43pub trait LmRunner: Send {
44    /// Short family identifier matching `rlx-cli::arch_runner_name`
45    /// (e.g. `"qwen3"`, `"qwen35"`, `"gemma"`, `"llama32"`). Useful
46    /// for logging / metrics / per-family branches in the caller.
47    fn family(&self) -> &'static str;
48
49    /// LM head vocab size — useful for callers that need to size a
50    /// logit buffer or validate token ids before calling
51    /// [`Self::predict_logits`]. PLAN.md M9.
52    fn vocab_size(&self) -> usize;
53
54    /// Run prefill on `prompt_ids` and return the last-token logits
55    /// over the full vocab. Mirrors the existing `predict_logits`
56    /// method on every per-family runner.
57    fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>>;
58
59    /// Generate up to `n_new` tokens after `prompt_ids` using greedy
60    /// (argmax) sampling. `on_token` fires once per generated token
61    /// and **returns `true` to continue, `false` to stop**. Returns
62    /// the generated id sequence (excluding the prompt).
63    ///
64    /// **Stop-signal honoring varies by family** (PLAN.md M9):
65    ///  * default impl + `Qwen35Runner` — honor the return value.
66    ///  * `Qwen3Runner` / `GemmaRunner` / `Llama32Runner` — call the
67    ///    callback but ignore its return (their inherent `generate`
68    ///    doesn't take a bool callback). Pass an EOS-aware sampler
69    ///    in the caller, or check `produced.last()` after the call.
70    ///
71    /// **Default impl is naive**: re-prefill on the full context
72    /// each step. Per-family runners override with their cached
73    /// decode fast path.
74    fn generate(
75        &mut self,
76        prompt_ids: &[u32],
77        n_new: usize,
78        on_token: &mut dyn FnMut(u32) -> bool,
79    ) -> Result<Vec<u32>> {
80        let mut context: Vec<u32> = prompt_ids.to_vec();
81        let mut produced: Vec<u32> = Vec::with_capacity(n_new);
82        for _ in 0..n_new {
83            let logits = self.predict_logits(&context)?;
84            let next = argmax_u32(&logits);
85            produced.push(next);
86            let cont = on_token(next);
87            context.push(next);
88            if !cont {
89                break;
90            }
91        }
92        Ok(produced)
93    }
94
95    /// Whether this runner supports multimodal (image+text) generation
96    /// via [`Self::generate_multimodal`]. Default `false`. Per-family
97    /// runners that wire a vision encoder (e.g. `Qwen35Runner` with an
98    /// mmproj path) override to `true`.
99    fn supports_multimodal(&self) -> bool {
100        false
101    }
102
103    /// Multimodal text generation: prefill the trunk with `prompt` text
104    /// where image markers are spliced with vision embeddings derived
105    /// from `rgb` (raw RGB bytes, row-major `[h, w, 3]`). Streams one
106    /// token per `on_token` call; returns the full produced sequence.
107    ///
108    /// Default impl returns an error — only family runners that wire
109    /// a vision encoder override this. Match parity with llama-cpp's
110    /// MtmdContext-based multimodal eval path.
111    fn generate_multimodal(
112        &mut self,
113        _prompt: &str,
114        _rgb: &[u8],
115        _img_w: usize,
116        _img_h: usize,
117        _tokenizer: Option<&std::path::Path>,
118        _n_new: usize,
119        _on_token: &mut dyn FnMut(u32) -> bool,
120    ) -> Result<Vec<u32>> {
121        Err(anyhow::anyhow!(
122            "this LmRunner does not support multimodal generation"
123        ))
124    }
125}
126
/// Index of the largest logit, returned as a `u32` token id.
///
/// Only a strictly greater value replaces the running best, so ties
/// resolve to the lowest index and NaN entries are never selected
/// (any comparison against NaN is false). An empty slice, or one with
/// no value above `f32::NEG_INFINITY`, yields `0`.
fn argmax_u32(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(top_idx, top_val), (idx, &val)| {
            if val > top_val {
                (idx, val)
            } else {
                (top_idx, top_val)
            }
        },
    );
    winner as u32
}