rlx_runtime/lm.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Generic language-model runner trait and shared builder.
17//!
18//! Until now every `rlx-<family>` model crate carried its own
19//! `*RunnerBuilder` (Qwen3RunnerBuilder, Llama32RunnerBuilder, …)
20//! with the same fields, the same `*ConfigSource { Embedded |
21//! JsonFile | Explicit(T) }` enum, and the same auto-packed-GGUF
22//! heuristic. This module hoists those shapes upstream so that:
23//!
//! 1. `LmRunner` can live in `rlx-runtime` (its previous home in
//!    `rlx-cli` forced every model crate to take a dependency on
//!    the CLI helper crate).
27//! 2. Per-family runners can `Deref` to / wrap [`LmRunnerBuilder`]
28//! instead of redefining the same fields.
29//! 3. Downstream tools (`skill`, web apps) can talk to runners
30//! through one trait without compiling in every model crate.
31//!
32//! The trait surface mirrors the existing `rlx_cli::LmRunner`. The
33//! CLI re-export is kept for backwards compat.
34
35use std::path::{Path, PathBuf};
36
37use crate::Device;
38
39/// Minimal per-family runner interface used by `auto_dispatch` and
40/// the `rlx-text` / `skill` integration.
41///
42/// Implementations must be `Send` so the boxed trait can move across
43/// threads (e.g. when a server runs inference on a worker pool).
44/// `Sync` is intentionally not required — runners hold mutable
45/// per-call compile / cache state.
46pub trait LmRunner: Send {
47 /// Short family identifier (`"qwen3"`, `"llama32"`, `"gemma"`).
48 fn family(&self) -> &'static str;
49
50 /// LM head vocabulary size.
51 fn vocab_size(&self) -> usize;
52
53 /// Run prefill on `prompt_ids` and return last-token logits.
54 fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>>;
55
56 /// Generate up to `n_new` tokens after `prompt_ids` using greedy
57 /// (argmax) sampling. The default impl re-prefills on the full
58 /// context each step — per-family runners should override with
59 /// their cached decode fast path.
60 ///
61 /// `on_token` returns `true` to continue, `false` to stop.
62 fn generate(
63 &mut self,
64 prompt_ids: &[u32],
65 n_new: usize,
66 on_token: &mut dyn FnMut(u32) -> bool,
67 ) -> anyhow::Result<Vec<u32>> {
68 let mut context: Vec<u32> = prompt_ids.to_vec();
69 let mut produced: Vec<u32> = Vec::with_capacity(n_new);
70 for _ in 0..n_new {
71 let logits = self.predict_logits(&context)?;
72 let next = argmax_u32(&logits);
73 produced.push(next);
74 let cont = on_token(next);
75 context.push(next);
76 if !cont {
77 break;
78 }
79 }
80 Ok(produced)
81 }
82
83 /// Whether this runner supports multimodal (image+text) generation.
84 fn supports_multimodal(&self) -> bool {
85 false
86 }
87
88 /// Multimodal generation: prefill with text where image markers are
89 /// spliced with vision embeddings derived from `rgb`.
90 fn generate_multimodal(
91 &mut self,
92 _prompt: &str,
93 _rgb: &[u8],
94 _img_w: usize,
95 _img_h: usize,
96 _tokenizer: Option<&Path>,
97 _n_new: usize,
98 _on_token: &mut dyn FnMut(u32) -> bool,
99 ) -> anyhow::Result<Vec<u32>> {
100 Err(anyhow::anyhow!(
101 "this LmRunner does not support multimodal generation"
102 ))
103 }
104}
105
/// Index of the largest logit, returned as a token id.
///
/// Ties resolve to the lowest index. NaN entries never win, because every
/// comparison against NaN is false. An empty slice, or one with no value
/// above `-inf`, yields `0`.
fn argmax_u32(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().copied().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(best_idx, best_val), (idx, val)| {
            if val > best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        },
    );
    winner as u32
}
117
118// ─────────────────────────────────────────────────────────────────
119// Weight format + config source
120// ─────────────────────────────────────────────────────────────────
121
122/// Weight file format. Detected from the file extension by default;
123/// the CLI accepts `--format` to override.
124#[derive(Debug, Clone, Copy, PartialEq, Eq)]
125pub enum WeightFormat {
126 Safetensors,
127 Gguf,
128}
129
130impl WeightFormat {
131 /// Infer format from a path extension.
132 pub fn from_path(path: &Path) -> anyhow::Result<Self> {
133 match path.extension().and_then(|s| s.to_str()) {
134 Some("safetensors") => Ok(Self::Safetensors),
135 Some("gguf") => Ok(Self::Gguf),
136 other => Err(anyhow::anyhow!(
137 "cannot autodetect weight format from extension {:?} on {:?}",
138 other,
139 path
140 )),
141 }
142 }
143
144 /// Parse CLI `--format` values (`safetensors` | `gguf`).
145 pub fn parse(s: &str) -> anyhow::Result<Self> {
146 match s {
147 "safetensors" => Ok(Self::Safetensors),
148 "gguf" => Ok(Self::Gguf),
149 other => Err(anyhow::anyhow!("expected safetensors|gguf, got {other}")),
150 }
151 }
152}
153
/// Source of a model config.
///
/// One generic enum in place of the per-family `Qwen3ConfigSource`,
/// `Llama32ConfigSource`, `GemmaConfigSource` and `Qwen35ConfigSource`
/// enums. The default is [`ConfigSource::Embedded`].
#[derive(Default, Clone, Debug)]
pub enum ConfigSource<T> {
    /// Pull the config out of the GGUF file's metadata.
    #[default]
    Embedded,
    /// Load a HuggingFace `config.json` from the given path.
    JsonFile(PathBuf),
    /// Use this already-built config value as-is.
    Explicit(T),
}
168
169// ─────────────────────────────────────────────────────────────────
170// Sampling
171// ─────────────────────────────────────────────────────────────────
172
/// Sampling parameters.
///
/// Greedy (argmax) decoding is selected whenever `temperature` is not
/// strictly positive: `0.0`, negative values and NaN all count as
/// greedy (see [`SampleOpts::is_greedy`]).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct SampleOpts {
    /// Sampling temperature; `<= 0` or NaN selects greedy decoding.
    pub temperature: f32,
    /// Nucleus (top-p) cutoff; [`SampleOpts::greedy`] sets `1.0`.
    pub top_p: f32,
    /// Optional top-k cutoff; `None` in both constructors.
    pub top_k: Option<u32>,
    /// Repetition penalty; both constructors set `1.0`.
    pub repetition_penalty: f32,
}

impl Default for SampleOpts {
    /// Same as [`SampleOpts::greedy`].
    fn default() -> Self {
        Self::greedy()
    }
}

impl SampleOpts {
    /// Greedy decoding: temperature `0.0`, `top_p` `1.0`, no top-k,
    /// repetition penalty `1.0`.
    #[must_use]
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_p: 1.0,
            top_k: None,
            repetition_penalty: 1.0,
        }
    }

    /// Nucleus sampling with the given temperature and top-p. The other
    /// fields keep their [`SampleOpts::greedy`] values (no top-k,
    /// repetition penalty `1.0`).
    #[must_use]
    pub fn nucleus(temperature: f32, top_p: f32) -> Self {
        Self {
            temperature,
            top_p,
            ..Self::greedy()
        }
    }

    /// Whether decoding should use argmax instead of sampling.
    ///
    /// NaN is checked explicitly: `NaN <= 0.0` is false, so the plain
    /// comparison would treat a NaN temperature as a real sampling
    /// temperature. NaN cannot be a meaningful temperature, so it falls
    /// back to greedy.
    pub fn is_greedy(&self) -> bool {
        self.temperature.is_nan() || self.temperature <= 0.0
    }
}
211
212// ─────────────────────────────────────────────────────────────────
213// Shared builder
214// ─────────────────────────────────────────────────────────────────
215
/// Size cutoff (256 MiB) above which GGUF files default to K-quant
/// packed loading when the caller has not chosen explicitly.
/// Cuts host memory ~6× on Q4_K_M models.
pub const PACKED_GGUF_AUTO_THRESHOLD_BYTES: u64 = 256 << 20;
219
220/// Builder fields common to every per-family runner.
221///
222/// Per-family runner builders should wrap this and forward the
223/// methods (or use `#[rlx_runner]` from `rlx-macros`).
224#[derive(Debug, Clone)]
225pub struct LmRunnerBuilder<Cfg> {
226 pub weights: Option<PathBuf>,
227 pub config: ConfigSource<Cfg>,
228 pub device: Device,
229 pub max_seq: usize,
230 pub max_memory_gb: Option<f32>,
231 pub stream: bool,
232 pub sample: SampleOpts,
233 pub format: Option<WeightFormat>,
234 /// `None` = auto-detect (packed when GGUF ≥ 256 MB).
235 pub packed_weights: Option<bool>,
236 /// Substring for picking one GGUF in a directory (default `Q4_K_M`).
237 pub prefer_gguf: Option<String>,
238}
239
240impl<Cfg> Default for LmRunnerBuilder<Cfg> {
241 fn default() -> Self {
242 Self {
243 weights: None,
244 config: ConfigSource::Embedded,
245 device: Device::Cpu,
246 max_seq: 128,
247 max_memory_gb: None,
248 stream: true,
249 sample: SampleOpts::greedy(),
250 format: None,
251 packed_weights: None,
252 prefer_gguf: None,
253 }
254 }
255}
256
257impl<Cfg> LmRunnerBuilder<Cfg> {
258 pub fn new() -> Self {
259 Self::default()
260 }
261
262 pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
263 self.weights = Some(p.into());
264 self
265 }
266
267 pub fn config(mut self, src: ConfigSource<Cfg>) -> Self {
268 self.config = src;
269 self
270 }
271
272 pub fn config_value(self, cfg: Cfg) -> Self {
273 self.config(ConfigSource::Explicit(cfg))
274 }
275
276 pub fn device(mut self, d: Device) -> Self {
277 self.device = d;
278 self
279 }
280
281 pub fn max_seq(mut self, n: usize) -> Self {
282 self.max_seq = n;
283 self
284 }
285
286 pub fn max_memory_gb(mut self, gb: f32) -> Self {
287 self.max_memory_gb = Some(gb);
288 self
289 }
290
291 pub fn stream(mut self, on: bool) -> Self {
292 self.stream = on;
293 self
294 }
295
296 pub fn sample(mut self, s: SampleOpts) -> Self {
297 self.sample = s;
298 self
299 }
300
301 pub fn format(mut self, fmt: WeightFormat) -> Self {
302 self.format = Some(fmt);
303 self
304 }
305
306 pub fn packed_weights(mut self, on: bool) -> Self {
307 self.packed_weights = Some(on);
308 self
309 }
310
311 pub fn prefer_gguf<S: Into<String>>(mut self, q: S) -> Self {
312 self.prefer_gguf = Some(q.into());
313 self
314 }
315
316 /// Resolve the format using the explicit override or the file extension.
317 pub fn resolved_format(&self) -> anyhow::Result<WeightFormat> {
318 match self.format {
319 Some(f) => Ok(f),
320 None => {
321 let p = self
322 .weights
323 .as_deref()
324 .ok_or_else(|| anyhow::anyhow!("weights path required"))?;
325 WeightFormat::from_path(p)
326 }
327 }
328 }
329
330 /// Determine whether packed GGUF loading should be used. Honors an
331 /// explicit override; otherwise auto-enables for GGUF files at or
332 /// above [`PACKED_GGUF_AUTO_THRESHOLD_BYTES`].
333 pub fn resolved_packed(&self, fmt: WeightFormat) -> bool {
334 match self.packed_weights {
335 Some(b) => b,
336 None => {
337 if !matches!(fmt, WeightFormat::Gguf) {
338 return false;
339 }
340 self.weights
341 .as_deref()
342 .and_then(|p| std::fs::metadata(p).ok())
343 .map(|m| m.len() >= PACKED_GGUF_AUTO_THRESHOLD_BYTES)
344 .unwrap_or(false)
345 }
346 }
347 }
348}
349
350// ─────────────────────────────────────────────────────────────────
351// Model registry (auto-dispatch by path)
352// ─────────────────────────────────────────────────────────────────
353
/// One entry in the family-routing table: a short family name plus a
/// probe that decides whether the family handles a given weights file.
///
/// Entries are registered at process start by `register_model` (or by a
/// `#[rlx_runner]`-generated `inventory` entry). [`auto_runner_name`]
/// scans them and returns the first family whose probe accepts the file.
pub struct ModelRegistration {
    /// Short family identifier; this is what `auto_runner_name` returns.
    pub family: &'static str,
    /// Human-readable description of the family.
    pub description: &'static str,
    /// Probe invoked as `matches(arch, path)`.
    ///
    /// `arch` is the lower-cased GGUF `general.architecture` (`""` for
    /// safetensors), and `path` is the concrete weights file. Return `true`
    /// if this family owns the file.
    pub matches: fn(arch: &str, path: &Path) -> bool,
}
369
370inventory::collect!(ModelRegistration);
371
372/// Re-export of `inventory` so the `register_lm_runner!` proc-macro
373/// can call `::rlx_runtime::lm::inventory::submit!` without forcing
374/// every caller to add `inventory` to their Cargo.toml.
375pub extern crate inventory;
376
377/// Iterate over every registered family.
378pub fn registered_models() -> impl Iterator<Item = &'static ModelRegistration> {
379 inventory::iter::<ModelRegistration>.into_iter()
380}
381
382/// Find the family that claims `(arch, path)`.
383pub fn auto_runner_name(arch: &str, path: &Path) -> Option<&'static str> {
384 let arch_lc = arch.to_ascii_lowercase();
385 registered_models()
386 .find(|m| (m.matches)(&arch_lc, path))
387 .map(|m| m.family)
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// File in the system temp dir, resized with `set_len` and removed on
    /// drop so that a failing assertion does not leave it behind.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        fn with_len(tag: &str, len: u64) -> Self {
            let path = std::env::temp_dir().join(format!(
                "rlx-lm-test-{}-{tag}.gguf",
                std::process::id()
            ));
            // Build the guard first so cleanup runs even if resizing fails.
            let guard = TempFile(path);
            std::fs::File::create(&guard.0)
                .expect("create temp file")
                .set_len(len)
                .expect("resize temp file");
            guard
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; nothing useful to do on failure.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Toy runner whose logits always favour `(last_token + 1) % 4`.
    struct CountUp;

    impl LmRunner for CountUp {
        fn family(&self) -> &'static str {
            "count-up"
        }

        fn vocab_size(&self) -> usize {
            4
        }

        fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>> {
            let next = prompt_ids.last().map_or(0, |&t| (t + 1) % 4) as usize;
            let mut logits = vec![0.0_f32; 4];
            logits[next] = 1.0;
            Ok(logits)
        }
    }

    #[test]
    fn config_source_default_is_embedded() {
        let s: ConfigSource<()> = ConfigSource::default();
        assert!(matches!(s, ConfigSource::Embedded));
    }

    #[test]
    fn builder_defaults_match_legacy_runners() {
        let b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        assert_eq!(b.device, Device::Cpu);
        assert_eq!(b.max_seq, 128);
        assert!(b.stream);
        assert!(b.sample.is_greedy());
        assert!(b.packed_weights.is_none());
    }

    #[test]
    fn packed_auto_size_threshold() {
        let mut b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        b.weights = Some("/nonexistent/file.gguf".into());
        // Missing file → auto returns false (no metadata).
        assert!(!b.resolved_packed(WeightFormat::Gguf));
        // Explicit override wins.
        b.packed_weights = Some(true);
        assert!(b.resolved_packed(WeightFormat::Gguf));
        // Non-GGUF never auto-packs.
        b.packed_weights = None;
        assert!(!b.resolved_packed(WeightFormat::Safetensors));
    }

    #[test]
    fn packed_auto_uses_file_size() {
        // Sparse files: `set_len` extends without writing data, so this
        // stays cheap even at the 256 MiB threshold.
        let small = TempFile::with_len("small", PACKED_GGUF_AUTO_THRESHOLD_BYTES - 1);
        let big = TempFile::with_len("big", PACKED_GGUF_AUTO_THRESHOLD_BYTES);

        let mut b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        b.weights = Some(small.0.clone());
        assert!(!b.resolved_packed(WeightFormat::Gguf));

        b.weights = Some(big.0.clone());
        assert!(b.resolved_packed(WeightFormat::Gguf));
        assert!(!b.resolved_packed(WeightFormat::Safetensors));

        // An explicit override beats the size heuristic in both directions.
        b.packed_weights = Some(false);
        assert!(!b.resolved_packed(WeightFormat::Gguf));

        // A directory is never auto-packed.
        let mut d: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        d.weights = Some(std::env::temp_dir());
        assert!(!d.resolved_packed(WeightFormat::Gguf));
    }

    #[test]
    fn weight_format_detection() {
        assert_eq!(
            WeightFormat::from_path(Path::new("m.gguf")).unwrap(),
            WeightFormat::Gguf
        );
        assert_eq!(
            WeightFormat::from_path(Path::new("m.safetensors")).unwrap(),
            WeightFormat::Safetensors
        );
        assert!(WeightFormat::from_path(Path::new("m.bin")).is_err());
        assert!(WeightFormat::from_path(Path::new("model")).is_err());
        assert_eq!(WeightFormat::parse("gguf").unwrap(), WeightFormat::Gguf);
        assert_eq!(
            WeightFormat::parse("safetensors").unwrap(),
            WeightFormat::Safetensors
        );
        assert!(WeightFormat::parse("onnx").is_err());
    }

    #[test]
    fn resolved_format_prefers_override() {
        let mut b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        // No override and no weights path → error.
        assert!(b.resolved_format().is_err());
        b.weights = Some("m.safetensors".into());
        assert_eq!(b.resolved_format().unwrap(), WeightFormat::Safetensors);
        b.format = Some(WeightFormat::Gguf);
        assert_eq!(b.resolved_format().unwrap(), WeightFormat::Gguf);
    }

    #[test]
    fn argmax_picks_first_maximum() {
        assert_eq!(argmax_u32(&[0.1, 0.9, 0.9, 0.2]), 1);
        assert_eq!(argmax_u32(&[f32::NAN, -1.0]), 1);
        assert_eq!(argmax_u32(&[]), 0);
    }

    #[test]
    fn default_generate_is_greedy_and_honors_stop() {
        let mut r = CountUp;
        let mut seen = Vec::new();
        let out = r
            .generate(&[1], 5, &mut |t: u32| {
                seen.push(t);
                t != 0
            })
            .unwrap();
        // 1 → 2 → 3 → 0; the callback stops on 0, but 0 is still returned.
        assert_eq!(out, vec![2, 3, 0]);
        assert_eq!(seen, out);
        assert!(r.generate(&[1], 0, &mut |_: u32| true).unwrap().is_empty());
    }
}