rlx_cli/lm_runner.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Boxed-trait dispatch for LM runners (PLAN.md M3 + M8).
17//!
18//! `LmRunner` is the minimal abstraction `rlx_models::run::auto_runner`
19//! returns from a path. M3 shipped the single-shot `predict_logits`
20//! method; M8 adds streaming `generate(prompt_ids, n_new, on_token)`
21//! with a sampler-agnostic greedy default that delegates to
22//! `predict_logits` on each step.
23//!
24//! Per-family runners with a cached decode path
25//! (`Qwen3Runner::generate`, `Qwen35Runner::generate_with_opts`,
26//! `GemmaRunner::generate`, `Llama32Runner::generate`) should
27//! override this default with their fast path — the default exists
28//! so `auto_runner(path)?.generate(...)` always works, not as a
29//! recommended hot path.
30//!
31//! Each per-family crate provides `impl LmRunner for FooRunner` so
32//! that `rlx-models` can hand back a `Box<dyn LmRunner>` from a
33//! single GGUF path without the caller knowing the family upfront.
34
35use anyhow::Result;
36
37/// Minimal per-family runner interface.
38///
39/// Implementations must be `Send` so the boxed trait can move across
40/// threads (e.g. when `skill` runs inference on a worker pool).
41/// `Sync` is intentionally not required — most runners hold mutable
42/// per-call compile / cache state.
43pub trait LmRunner: Send {
44 /// Short family identifier matching `rlx-cli::arch_runner_name`
45 /// (e.g. `"qwen3"`, `"qwen35"`, `"gemma"`, `"llama32"`). Useful
46 /// for logging / metrics / per-family branches in the caller.
47 fn family(&self) -> &'static str;
48
49 /// LM head vocab size — useful for callers that need to size a
50 /// logit buffer or validate token ids before calling
51 /// [`Self::predict_logits`]. PLAN.md M9.
52 fn vocab_size(&self) -> usize;
53
54 /// Run prefill on `prompt_ids` and return the last-token logits
55 /// over the full vocab. Mirrors the existing `predict_logits`
56 /// method on every per-family runner.
57 fn predict_logits(&mut self, prompt_ids: &[u32]) -> Result<Vec<f32>>;
58
59 /// Generate up to `n_new` tokens after `prompt_ids` using greedy
60 /// (argmax) sampling. `on_token` fires once per generated token
61 /// and **returns `true` to continue, `false` to stop**. Returns
62 /// the generated id sequence (excluding the prompt).
63 ///
64 /// **Stop-signal honoring varies by family** (PLAN.md M9):
65 /// * default impl + `Qwen35Runner` — honor the return value.
66 /// * `Qwen3Runner` / `GemmaRunner` / `Llama32Runner` — call the
67 /// callback but ignore its return (their inherent `generate`
68 /// doesn't take a bool callback). Pass an EOS-aware sampler
69 /// in the caller, or check `produced.last()` after the call.
70 ///
71 /// **Default impl is naive**: re-prefill on the full context
72 /// each step. Per-family runners override with their cached
73 /// decode fast path.
74 fn generate(
75 &mut self,
76 prompt_ids: &[u32],
77 n_new: usize,
78 on_token: &mut dyn FnMut(u32) -> bool,
79 ) -> Result<Vec<u32>> {
80 let mut context: Vec<u32> = prompt_ids.to_vec();
81 let mut produced: Vec<u32> = Vec::with_capacity(n_new);
82 for _ in 0..n_new {
83 let logits = self.predict_logits(&context)?;
84 let next = argmax_u32(&logits);
85 produced.push(next);
86 let cont = on_token(next);
87 context.push(next);
88 if !cont {
89 break;
90 }
91 }
92 Ok(produced)
93 }
94
95 /// Whether this runner supports multimodal (image+text) generation
96 /// via [`Self::generate_multimodal`]. Default `false`. Per-family
97 /// runners that wire a vision encoder (e.g. `Qwen35Runner` with an
98 /// mmproj path) override to `true`.
99 fn supports_multimodal(&self) -> bool {
100 false
101 }
102
103 /// Multimodal text generation: prefill the trunk with `prompt` text
104 /// where image markers are spliced with vision embeddings derived
105 /// from `rgb` (raw RGB bytes, row-major `[h, w, 3]`). Streams one
106 /// token per `on_token` call; returns the full produced sequence.
107 ///
108 /// Default impl returns an error — only family runners that wire
109 /// a vision encoder override this. Match parity with llama-cpp's
110 /// MtmdContext-based multimodal eval path.
111 fn generate_multimodal(
112 &mut self,
113 _prompt: &str,
114 _rgb: &[u8],
115 _img_w: usize,
116 _img_h: usize,
117 _tokenizer: Option<&std::path::Path>,
118 _n_new: usize,
119 _on_token: &mut dyn FnMut(u32) -> bool,
120 ) -> Result<Vec<u32>> {
121 Err(anyhow::anyhow!(
122 "this LmRunner does not support multimodal generation"
123 ))
124 }
125}
126
/// Index of the largest logit, as a token id.
///
/// The comparison is a strict `>` against a running maximum that starts
/// at negative infinity, so the earliest index wins on ties and NaN
/// entries are never selected. Returns `0` when `logits` is empty or
/// when no entry compares greater than negative infinity.
fn argmax_u32(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(top_idx, top_score), (idx, &score)| {
            if score > top_score {
                (idx, score)
            } else {
                (top_idx, top_score)
            }
        },
    );
    winner as u32
}
137}