Skip to main content

rlx_runtime/
lm.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Generic language-model runner trait and shared builder.
17//!
18//! Until now every `rlx-<family>` model crate carried its own
19//! `*RunnerBuilder` (Qwen3RunnerBuilder, Llama32RunnerBuilder, …)
20//! with the same fields, the same `*ConfigSource { Embedded |
21//! JsonFile | Explicit(T) }` enum, and the same auto-packed-GGUF
22//! heuristic. This module hoists those shapes upstream so that:
23//!
24//!   1. `LmRunner` can live in `rlx-runtime` (today's home in
25//!      `rlx-cli` forces every model crate to take a dependency on
26//!      the CLI helper crate).
27//!   2. Per-family runners can `Deref` to / wrap [`LmRunnerBuilder`]
28//!      instead of redefining the same fields.
29//!   3. Downstream tools (`skill`, web apps) can talk to runners
30//!      through one trait without compiling in every model crate.
31//!
32//! The trait surface mirrors the existing `rlx_cli::LmRunner`. The
33//! CLI re-export is kept for backwards compat.
34
35use std::path::{Path, PathBuf};
36
37use crate::Device;
38
39/// Minimal per-family runner interface used by `auto_dispatch` and
40/// the `rlx-text` / `skill` integration.
41///
42/// Implementations must be `Send` so the boxed trait can move across
43/// threads (e.g. when a server runs inference on a worker pool).
44/// `Sync` is intentionally not required — runners hold mutable
45/// per-call compile / cache state.
46pub trait LmRunner: Send {
47    /// Short family identifier (`"qwen3"`, `"llama32"`, `"gemma"`).
48    fn family(&self) -> &'static str;
49
50    /// LM head vocabulary size.
51    fn vocab_size(&self) -> usize;
52
53    /// Run prefill on `prompt_ids` and return last-token logits.
54    fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>>;
55
56    /// Generate up to `n_new` tokens after `prompt_ids` using greedy
57    /// (argmax) sampling. The default impl re-prefills on the full
58    /// context each step — per-family runners should override with
59    /// their cached decode fast path.
60    ///
61    /// `on_token` returns `true` to continue, `false` to stop.
62    fn generate(
63        &mut self,
64        prompt_ids: &[u32],
65        n_new: usize,
66        on_token: &mut dyn FnMut(u32) -> bool,
67    ) -> anyhow::Result<Vec<u32>> {
68        let mut context: Vec<u32> = prompt_ids.to_vec();
69        let mut produced: Vec<u32> = Vec::with_capacity(n_new);
70        for _ in 0..n_new {
71            let logits = self.predict_logits(&context)?;
72            let next = argmax_u32(&logits);
73            produced.push(next);
74            let cont = on_token(next);
75            context.push(next);
76            if !cont {
77                break;
78            }
79        }
80        Ok(produced)
81    }
82
83    /// Whether this runner supports multimodal (image+text) generation.
84    fn supports_multimodal(&self) -> bool {
85        false
86    }
87
88    /// Multimodal generation: prefill with text where image markers are
89    /// spliced with vision embeddings derived from `rgb`.
90    fn generate_multimodal(
91        &mut self,
92        _prompt: &str,
93        _rgb: &[u8],
94        _img_w: usize,
95        _img_h: usize,
96        _tokenizer: Option<&Path>,
97        _n_new: usize,
98        _on_token: &mut dyn FnMut(u32) -> bool,
99    ) -> anyhow::Result<Vec<u32>> {
100        Err(anyhow::anyhow!(
101            "this LmRunner does not support multimodal generation"
102        ))
103    }
104}
105
/// Index of the largest logit, as a `u32` token id.
///
/// Ties resolve to the earliest index (strict `>` comparison). NaN
/// entries never win, and an empty slice (or one with no value above
/// `-inf`) yields `0`.
fn argmax_u32(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().copied().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(best_idx, best_val), (idx, val)| {
            if val > best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        },
    );
    winner as u32
}
117
118// ─────────────────────────────────────────────────────────────────
119// Weight format + config source
120// ─────────────────────────────────────────────────────────────────
121
122/// Weight file format. Detected from the file extension by default;
123/// the CLI accepts `--format` to override.
124#[derive(Debug, Clone, Copy, PartialEq, Eq)]
125pub enum WeightFormat {
126    Safetensors,
127    Gguf,
128}
129
130impl WeightFormat {
131    /// Infer format from a path extension.
132    pub fn from_path(path: &Path) -> anyhow::Result<Self> {
133        match path.extension().and_then(|s| s.to_str()) {
134            Some("safetensors") => Ok(Self::Safetensors),
135            Some("gguf") => Ok(Self::Gguf),
136            other => Err(anyhow::anyhow!(
137                "cannot autodetect weight format from extension {:?} on {:?}",
138                other,
139                path
140            )),
141        }
142    }
143
144    /// Parse CLI `--format` values (`safetensors` | `gguf`).
145    pub fn parse(s: &str) -> anyhow::Result<Self> {
146        match s {
147            "safetensors" => Ok(Self::Safetensors),
148            "gguf" => Ok(Self::Gguf),
149            other => Err(anyhow::anyhow!("expected safetensors|gguf, got {other}")),
150        }
151    }
152}
153
/// Where to read a model config from.
///
/// Replaces the per-family `Qwen3ConfigSource`, `Llama32ConfigSource`,
/// `GemmaConfigSource`, `Qwen35ConfigSource` enums. The default is
/// [`ConfigSource::Embedded`].
#[derive(Debug, Clone)]
pub enum ConfigSource<T> {
    /// Read from GGUF metadata.
    Embedded,
    /// Read from a HuggingFace `config.json` at this path.
    JsonFile(PathBuf),
    /// Use the supplied config object directly.
    Explicit(T),
}

impl<T> Default for ConfigSource<T> {
    /// Returns [`ConfigSource::Embedded`]. No `T: Default` bound is
    /// needed because the default variant carries no payload.
    fn default() -> Self {
        Self::Embedded
    }
}
168
169// ─────────────────────────────────────────────────────────────────
170// Sampling
171// ─────────────────────────────────────────────────────────────────
172
/// Sampling parameters.
///
/// Greedy (argmax) decoding is selected whenever `temperature` is not
/// strictly positive — zero, negative, or NaN; see
/// [`SampleOpts::is_greedy`].
#[derive(Debug, Clone, Copy)]
pub struct SampleOpts {
    /// Softmax temperature; `<= 0` or NaN means greedy.
    pub temperature: f32,
    /// Nucleus (top-p) cutoff; [`SampleOpts::greedy`] uses `1.0`.
    pub top_p: f32,
    /// Optional top-k cutoff; `None` in both built-in presets.
    pub top_k: Option<u32>,
    /// Repetition penalty; both built-in presets use `1.0`.
    pub repetition_penalty: f32,
}

impl Default for SampleOpts {
    /// Same as [`SampleOpts::greedy`].
    fn default() -> Self {
        Self::greedy()
    }
}

impl SampleOpts {
    /// Greedy decoding: temperature `0.0`, `top_p = 1.0`, no top-k,
    /// repetition penalty `1.0`.
    pub fn greedy() -> Self {
        Self {
            temperature: 0.0,
            top_p: 1.0,
            top_k: None,
            repetition_penalty: 1.0,
        }
    }

    /// Nucleus sampling with the given `temperature` and `top_p`; no
    /// top-k, repetition penalty `1.0`.
    pub fn nucleus(temperature: f32, top_p: f32) -> Self {
        Self {
            temperature,
            top_p,
            top_k: None,
            repetition_penalty: 1.0,
        }
    }

    /// `true` when sampling should fall back to argmax.
    ///
    /// A NaN temperature counts as greedy. Scaling logits by NaN would
    /// poison every probability, so argmax is the only safe choice.
    pub fn is_greedy(&self) -> bool {
        self.temperature <= 0.0 || self.temperature.is_nan()
    }
}
211
212// ─────────────────────────────────────────────────────────────────
213// Shared builder
214// ─────────────────────────────────────────────────────────────────
215
/// Size cutoff (256 MiB) for auto-enabling K-quant packed loading:
/// GGUF files at or above this many bytes load packed unless the caller
/// overrides it. Per the original author, packing cuts host memory ~6×
/// on Q4_K_M models.
pub const PACKED_GGUF_AUTO_THRESHOLD_BYTES: u64 = 256 << 20;
219
220/// Builder fields common to every per-family runner.
221///
222/// Per-family runner builders should wrap this and forward the
223/// methods (or use `#[rlx_runner]` from `rlx-macros`).
224#[derive(Debug, Clone)]
225pub struct LmRunnerBuilder<Cfg> {
226    pub weights: Option<PathBuf>,
227    pub config: ConfigSource<Cfg>,
228    pub device: Device,
229    pub max_seq: usize,
230    pub max_memory_gb: Option<f32>,
231    pub stream: bool,
232    pub sample: SampleOpts,
233    pub format: Option<WeightFormat>,
234    /// `None` = auto-detect (packed when GGUF ≥ 256 MB).
235    pub packed_weights: Option<bool>,
236    /// Substring for picking one GGUF in a directory (default `Q4_K_M`).
237    pub prefer_gguf: Option<String>,
238}
239
240impl<Cfg> Default for LmRunnerBuilder<Cfg> {
241    fn default() -> Self {
242        Self {
243            weights: None,
244            config: ConfigSource::Embedded,
245            device: Device::Cpu,
246            max_seq: 128,
247            max_memory_gb: None,
248            stream: true,
249            sample: SampleOpts::greedy(),
250            format: None,
251            packed_weights: None,
252            prefer_gguf: None,
253        }
254    }
255}
256
257impl<Cfg> LmRunnerBuilder<Cfg> {
258    pub fn new() -> Self {
259        Self::default()
260    }
261
262    pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
263        self.weights = Some(p.into());
264        self
265    }
266
267    pub fn config(mut self, src: ConfigSource<Cfg>) -> Self {
268        self.config = src;
269        self
270    }
271
272    pub fn config_value(self, cfg: Cfg) -> Self {
273        self.config(ConfigSource::Explicit(cfg))
274    }
275
276    pub fn device(mut self, d: Device) -> Self {
277        self.device = d;
278        self
279    }
280
281    pub fn max_seq(mut self, n: usize) -> Self {
282        self.max_seq = n;
283        self
284    }
285
286    pub fn max_memory_gb(mut self, gb: f32) -> Self {
287        self.max_memory_gb = Some(gb);
288        self
289    }
290
291    pub fn stream(mut self, on: bool) -> Self {
292        self.stream = on;
293        self
294    }
295
296    pub fn sample(mut self, s: SampleOpts) -> Self {
297        self.sample = s;
298        self
299    }
300
301    pub fn format(mut self, fmt: WeightFormat) -> Self {
302        self.format = Some(fmt);
303        self
304    }
305
306    pub fn packed_weights(mut self, on: bool) -> Self {
307        self.packed_weights = Some(on);
308        self
309    }
310
311    pub fn prefer_gguf<S: Into<String>>(mut self, q: S) -> Self {
312        self.prefer_gguf = Some(q.into());
313        self
314    }
315
316    /// Resolve the format using the explicit override or the file extension.
317    pub fn resolved_format(&self) -> anyhow::Result<WeightFormat> {
318        match self.format {
319            Some(f) => Ok(f),
320            None => {
321                let p = self
322                    .weights
323                    .as_deref()
324                    .ok_or_else(|| anyhow::anyhow!("weights path required"))?;
325                WeightFormat::from_path(p)
326            }
327        }
328    }
329
330    /// Determine whether packed GGUF loading should be used. Honors an
331    /// explicit override; otherwise auto-enables for GGUF files at or
332    /// above [`PACKED_GGUF_AUTO_THRESHOLD_BYTES`].
333    pub fn resolved_packed(&self, fmt: WeightFormat) -> bool {
334        match self.packed_weights {
335            Some(b) => b,
336            None => {
337                if !matches!(fmt, WeightFormat::Gguf) {
338                    return false;
339                }
340                self.weights
341                    .as_deref()
342                    .and_then(|p| std::fs::metadata(p).ok())
343                    .map(|m| m.len() >= PACKED_GGUF_AUTO_THRESHOLD_BYTES)
344                    .unwrap_or(false)
345            }
346        }
347    }
348}
349
350// ─────────────────────────────────────────────────────────────────
351// Model registry (auto-dispatch by path)
352// ─────────────────────────────────────────────────────────────────
353
/// One entry in the model-family registry: a family name, a description,
/// and a probe that decides whether the family owns a weights file.
///
/// Entries are gathered with `inventory` (this type is declared as an
/// `inventory` collection in this module). [`auto_runner_name`] walks
/// them and returns the first family whose probe accepts the file.
pub struct ModelRegistration {
    /// Short family identifier, as returned by [`auto_runner_name`].
    pub family: &'static str,
    /// Description of the family.
    pub description: &'static str,
    /// Ownership probe, called as `matches(arch, path)`.
    ///
    /// `arch` is the GGUF `general.architecture` value (`""` for
    /// safetensors). When called through [`auto_runner_name`] it has
    /// already been ASCII-lowercased. `path` is the concrete weights
    /// file. Return `true` if this family should handle the file.
    pub matches: fn(arch: &str, path: &Path) -> bool,
}
369
370inventory::collect!(ModelRegistration);
371
372/// Re-export of `inventory` so the `register_lm_runner!` proc-macro
373/// can call `::rlx_runtime::lm::inventory::submit!` without forcing
374/// every caller to add `inventory` to their Cargo.toml.
375pub extern crate inventory;
376
377/// Iterate over every registered family.
378pub fn registered_models() -> impl Iterator<Item = &'static ModelRegistration> {
379    inventory::iter::<ModelRegistration>.into_iter()
380}
381
382/// Find the family that claims `(arch, path)`.
383pub fn auto_runner_name(arch: &str, path: &Path) -> Option<&'static str> {
384    let arch_lc = arch.to_ascii_lowercase();
385    registered_models()
386        .find(|m| (m.matches)(&arch_lc, path))
387        .map(|m| m.family)
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Create (or truncate) a file of exactly `len` bytes in the system
    /// temp directory and return its path. The process id in the name
    /// keeps concurrent test runs from colliding.
    fn sized_temp_gguf(tag: &str, len: u64) -> PathBuf {
        let path = std::env::temp_dir().join(format!(
            "rlx_lm_test_{}_{}.gguf",
            std::process::id(),
            tag
        ));
        let file = std::fs::File::create(&path).expect("create temp file");
        // `set_len` extends without writing data, so this stays cheap
        // (sparse on most filesystems) even at the 256 MiB threshold.
        file.set_len(len).expect("resize temp file");
        path
    }

    #[test]
    fn config_source_default_is_embedded() {
        let s: ConfigSource<()> = ConfigSource::default();
        assert!(matches!(s, ConfigSource::Embedded));
    }

    #[test]
    fn builder_defaults_match_legacy_runners() {
        let b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        assert_eq!(b.device, Device::Cpu);
        assert_eq!(b.max_seq, 128);
        assert!(b.stream);
        assert!(b.sample.is_greedy());
        assert!(b.packed_weights.is_none());
        assert!(b.weights.is_none());
        assert!(b.format.is_none());
        assert!(b.prefer_gguf.is_none());
        assert!(b.max_memory_gb.is_none());
        assert!(matches!(b.config, ConfigSource::Embedded));
    }

    #[test]
    fn packed_auto_size_threshold() {
        let mut b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        b.weights = Some("/nonexistent/file.gguf".into());
        // Missing file → auto returns false (no metadata).
        assert!(!b.resolved_packed(WeightFormat::Gguf));
        // Explicit override wins.
        b.packed_weights = Some(true);
        assert!(b.resolved_packed(WeightFormat::Gguf));
        // Non-GGUF never auto-packs.
        b.packed_weights = None;
        assert!(!b.resolved_packed(WeightFormat::Safetensors));
    }

    #[test]
    fn packed_auto_compares_file_size_to_threshold() {
        let small = sized_temp_gguf("below", 16);
        let large = sized_temp_gguf("at", PACKED_GGUF_AUTO_THRESHOLD_BYTES);
        let mut b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        b.weights = Some(small.clone());
        let below = b.resolved_packed(WeightFormat::Gguf);
        b.weights = Some(large.clone());
        let at = b.resolved_packed(WeightFormat::Gguf);
        // Clean up before asserting so a failure does not leak files.
        let _ = std::fs::remove_file(&small);
        let _ = std::fs::remove_file(&large);
        assert!(!below, "file below the threshold must not auto-pack");
        assert!(at, "file at the threshold must auto-pack");
    }

    #[test]
    fn resolved_format_prefers_override_then_extension() {
        let b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        assert!(b.resolved_format().is_err());
        let b = b.weights("model.gguf");
        assert_eq!(b.resolved_format().unwrap(), WeightFormat::Gguf);
        let b = b.format(WeightFormat::Safetensors);
        assert_eq!(b.resolved_format().unwrap(), WeightFormat::Safetensors);
    }

    #[test]
    fn weight_format_from_lowercase_extension_and_name() {
        assert_eq!(
            WeightFormat::from_path(Path::new("m.gguf")).unwrap(),
            WeightFormat::Gguf
        );
        assert_eq!(
            WeightFormat::from_path(Path::new("dir/m.safetensors")).unwrap(),
            WeightFormat::Safetensors
        );
        assert!(WeightFormat::from_path(Path::new("m.bin")).is_err());
        assert!(WeightFormat::from_path(Path::new("no_extension")).is_err());
        assert_eq!(WeightFormat::parse("gguf").unwrap(), WeightFormat::Gguf);
        assert_eq!(
            WeightFormat::parse("safetensors").unwrap(),
            WeightFormat::Safetensors
        );
        assert!(WeightFormat::parse("onnx").is_err());
    }

    #[test]
    fn argmax_picks_first_maximum() {
        assert_eq!(argmax_u32(&[]), 0);
        assert_eq!(argmax_u32(&[0.1, 0.9, 0.3]), 1);
        assert_eq!(argmax_u32(&[2.0, 5.0, 5.0]), 1);
        assert_eq!(argmax_u32(&[f32::NAN, -1.0]), 1);
    }
}