Skip to main content

rlx_runtime/
lm.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Generic language-model runner trait and shared builder.
17//!
18//! Until now every `rlx-<family>` model crate carried its own
19//! `*RunnerBuilder` (Qwen3RunnerBuilder, Llama32RunnerBuilder, …)
20//! with the same fields, the same `*ConfigSource { Embedded |
21//! JsonFile | Explicit(T) }` enum, and the same auto-packed-GGUF
22//! heuristic. This module hoists those shapes upstream so that:
23//!
24//!   1. `LmRunner` can live in `rlx-runtime` (today's home in
25//!      `rlx-cli` forces every model crate to take a dependency on
26//!      the CLI helper crate).
27//!   2. Per-family runners can `Deref` to / wrap [`LmRunnerBuilder`]
28//!      instead of redefining the same fields.
29//!   3. Downstream tools (`skill`, web apps) can talk to runners
30//!      through one trait without compiling in every model crate.
31//!
32//! The trait surface mirrors the existing `rlx_cli::LmRunner`. The
33//! CLI re-export is kept for backwards compat.
34
35use std::path::{Path, PathBuf};
36
37use crate::Device;
38
39/// Minimal per-family runner interface used by `auto_dispatch` and
40/// the `rlx-text` / `skill` integration.
41///
42/// Implementations must be `Send` so the boxed trait can move across
43/// threads (e.g. when a server runs inference on a worker pool).
44/// `Sync` is intentionally not required — runners hold mutable
45/// per-call compile / cache state.
46pub trait LmRunner: Send {
47    /// Short family identifier (`"qwen3"`, `"llama32"`, `"gemma"`).
48    fn family(&self) -> &'static str;
49
50    /// LM head vocabulary size.
51    fn vocab_size(&self) -> usize;
52
53    /// Run prefill on `prompt_ids` and return last-token logits.
54    fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>>;
55
56    /// Generate up to `n_new` tokens after `prompt_ids` using greedy
57    /// (argmax) sampling. The default impl re-prefills on the full
58    /// context each step — per-family runners should override with
59    /// their cached decode fast path.
60    ///
61    /// `on_token` returns `true` to continue, `false` to stop.
62    fn generate(
63        &mut self,
64        prompt_ids: &[u32],
65        n_new: usize,
66        on_token: &mut dyn FnMut(u32) -> bool,
67    ) -> anyhow::Result<Vec<u32>> {
68        let mut context: Vec<u32> = prompt_ids.to_vec();
69        let mut produced: Vec<u32> = Vec::with_capacity(n_new);
70        for _ in 0..n_new {
71            let logits = self.predict_logits(&context)?;
72            let next = argmax_u32(&logits);
73            produced.push(next);
74            let cont = on_token(next);
75            context.push(next);
76            if !cont {
77                break;
78            }
79        }
80        Ok(produced)
81    }
82
83    /// Whether this runner supports multimodal (image+text) generation.
84    fn supports_multimodal(&self) -> bool {
85        false
86    }
87
88    /// Multimodal generation: prefill with text where image markers are
89    /// spliced with vision embeddings derived from `rgb`.
90    fn generate_multimodal(
91        &mut self,
92        _prompt: &str,
93        _rgb: &[u8],
94        _img_w: usize,
95        _img_h: usize,
96        _tokenizer: Option<&Path>,
97        _n_new: usize,
98        _on_token: &mut dyn FnMut(u32) -> bool,
99    ) -> anyhow::Result<Vec<u32>> {
100        Err(anyhow::anyhow!(
101            "this LmRunner does not support multimodal generation"
102        ))
103    }
104}
105
/// Index of the largest logit, as a token id.
///
/// Ties resolve to the lowest index because the comparison is strict.
/// NaN entries never compare greater and are therefore skipped. An
/// empty slice (or one with no value above negative infinity) yields 0.
fn argmax_u32(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().copied().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(top_idx, top_val), (idx, val)| {
            if val > top_val {
                (idx, val)
            } else {
                (top_idx, top_val)
            }
        },
    );
    winner as u32
}
117
118// ─────────────────────────────────────────────────────────────────
119// Weight format + config source
120// ─────────────────────────────────────────────────────────────────
121
/// Weight file format. Detected from the file extension by default;
/// the CLI accepts `--format` to override.
///
/// See [`WeightFormat::from_path`] for extension-based detection and
/// [`WeightFormat::parse`] for the CLI spelling of each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WeightFormat {
    /// Selected by the `safetensors` extension / `--format` value.
    Safetensors,
    /// Selected by the `gguf` extension / `--format` value.
    Gguf,
}
129
130impl WeightFormat {
131    /// Infer format from a path extension.
132    pub fn from_path(path: &Path) -> anyhow::Result<Self> {
133        match path.extension().and_then(|s| s.to_str()) {
134            Some("safetensors") => Ok(Self::Safetensors),
135            Some("gguf") => Ok(Self::Gguf),
136            other => Err(anyhow::anyhow!(
137                "cannot autodetect weight format from extension {:?} on {:?}",
138                other,
139                path
140            )),
141        }
142    }
143
144    /// Parse CLI `--format` values (`safetensors` | `gguf`).
145    pub fn parse(s: &str) -> anyhow::Result<Self> {
146        match s {
147            "safetensors" => Ok(Self::Safetensors),
148            "gguf" => Ok(Self::Gguf),
149            other => Err(anyhow::anyhow!("expected safetensors|gguf, got {other}")),
150        }
151    }
152}
153
/// Where to read a model config from.
///
/// Replaces the per-family `Qwen3ConfigSource`, `Llama32ConfigSource`,
/// `GemmaConfigSource`, `Qwen35ConfigSource` enums. `T` is the
/// family-specific config type carried by [`ConfigSource::Explicit`].
#[derive(Debug, Clone, Default)]
pub enum ConfigSource<T> {
    /// Read from GGUF metadata. This is the `Default` variant.
    #[default]
    Embedded,
    /// Read from a HuggingFace `config.json` at this path.
    JsonFile(PathBuf),
    /// Use the supplied config object directly.
    Explicit(T),
}
168
169// ─────────────────────────────────────────────────────────────────
170// Sampling
171// ─────────────────────────────────────────────────────────────────
172
/// Mirostat variant selection. See `crate::samplers::{MirostatV1, MirostatV2}`.
///
/// Derives `Eq` alongside `PartialEq` (like [`WeightFormat`]) because
/// the enum has only unit variants, so equality is total.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum MirostatMode {
    /// Mirostat disabled; this is the `Default` variant.
    #[default]
    Off,
    /// Select the `MirostatV1` sampler.
    V1,
    /// Select the `MirostatV2` sampler.
    V2,
}
181
/// Sampling parameters. `SampleOpts::greedy()` (also the `Default`)
/// turns every "advanced" knob off / no-op so legacy callers see
/// classic top-k/top-p/temperature behaviour; see `is_greedy()` and
/// `is_classic()` for the exact predicates.
///
/// `into_chain()` turns these flat fields into a `SamplerChain` that
/// downstream backends can execute. Ordering follows llama.cpp's
/// canonical chain (penalties → DRY → temperature → top-k → typical
/// → top-p → top-n-σ → xtc → mirostat).
#[derive(Debug, Clone)]
pub struct SampleOpts {
    /// Flat softmax temperature. `into_chain()` treats `<= 0` as
    /// near-greedy (`1e-6`) and adds no step for exactly `1.0`; it is
    /// ignored when `dynamic_temp` is set or Mirostat is enabled.
    pub temperature: f32,
    /// Nucleus (top-p) mass. Applied only when strictly between 0 and 1.
    pub top_p: f32,
    /// Top-k cutoff. `None` or `Some(0)` ⇒ off.
    pub top_k: Option<u32>,
    /// Multiplicative repetition penalty. 1.0 ⇒ off.
    pub repetition_penalty: f32,

    // ── advanced samplers ────────────────────────────────────────
    /// Dynamic temperature [min, max] gated by softmax entropy.
    /// `None` ⇒ flat temperature only.
    pub dynamic_temp: Option<(f32, f32)>,
    /// Exponent used by [`crate::samplers::DynamicTemperature`].
    pub dynamic_temp_exponent: f32,
    /// Locally-typical sampling (Meister et al. 2022). 1.0 ⇒ off;
    /// applied only when strictly between 0 and 1.
    pub typical_p: f32,
    /// Top-n-σ cutoff. 0 ⇒ off.
    /// NOTE(review): this was cited as "Hewitt et al. 2024"; the
    /// top-nσ paper appears to be Tang et al. 2024 — confirm.
    pub top_n_sigma: f32,
    /// XTC threshold passed to the `Xtc` sampler. XTC is enabled only
    /// when both this and `xtc_prob` are > 0.
    pub xtc_threshold: f32,
    /// XTC: probability of dropping high-confidence top tokens.
    pub xtc_prob: f32,
    /// DRY repetition penalty multiplier. `<= 0` ⇒ DRY off.
    pub dry_multiplier: f32,
    /// DRY `base` parameter, forwarded unchanged to the `Dry` sampler.
    pub dry_base: f32,
    /// DRY `allowed_length` parameter, forwarded unchanged.
    pub dry_allowed_length: usize,
    /// DRY `max_ngram` parameter, forwarded unchanged.
    pub dry_max_ngram: usize,
    /// Token ids forwarded to the `Dry` sampler as `sequence_breakers`.
    pub dry_sequence_breakers: Vec<u32>,
    /// Mirostat mode + parameters. When not `Off`, `into_chain()`
    /// skips the temperature step and appends the Mirostat sampler.
    pub mirostat: MirostatMode,
    /// Mirostat `tau`, used by both V1 and V2.
    pub mirostat_tau: f32,
    /// Mirostat `eta`, used by both V1 and V2.
    pub mirostat_eta: f32,
    /// Mirostat `m`, used by V1 only.
    pub mirostat_m: usize,
    /// Frequency penalty (OpenAI-style). 0 ⇒ off.
    pub frequency_penalty: f32,
    /// Presence penalty (OpenAI-style). 0 ⇒ off.
    pub presence_penalty: f32,
    /// History window forwarded to `RepetitionPenalty` as `last_n`.
    pub repetition_window: usize,
    /// Minimum tokens kept by top-p / typical / XTC (avoid a
    /// one-token nucleus).
    pub min_keep: usize,
}
228
229impl Default for SampleOpts {
230    fn default() -> Self {
231        Self::greedy()
232    }
233}
234
235impl SampleOpts {
236    pub fn greedy() -> Self {
237        Self {
238            temperature: 0.0,
239            top_p: 1.0,
240            top_k: None,
241            repetition_penalty: 1.0,
242            dynamic_temp: None,
243            dynamic_temp_exponent: 1.0,
244            typical_p: 1.0,
245            top_n_sigma: 0.0,
246            xtc_threshold: 0.0,
247            xtc_prob: 0.0,
248            dry_multiplier: 0.0,
249            dry_base: 1.75,
250            dry_allowed_length: 2,
251            dry_max_ngram: 32,
252            dry_sequence_breakers: Vec::new(),
253            mirostat: MirostatMode::Off,
254            mirostat_tau: 5.0,
255            mirostat_eta: 0.1,
256            mirostat_m: 100,
257            frequency_penalty: 0.0,
258            presence_penalty: 0.0,
259            repetition_window: 64,
260            min_keep: 1,
261        }
262    }
263
264    pub fn nucleus(temperature: f32, top_p: f32) -> Self {
265        Self {
266            temperature,
267            top_p,
268            ..Self::greedy()
269        }
270    }
271
272    pub fn is_greedy(&self) -> bool {
273        self.temperature <= 0.0 && self.mirostat == MirostatMode::Off
274    }
275
276    /// True when only classic top-k/top-p/temperature are configured;
277    /// backends can take a cheap fast path in this case (e.g. the
278    /// existing `sample_row` CPU kernel) instead of building a chain.
279    pub fn is_classic(&self) -> bool {
280        self.dynamic_temp.is_none()
281            && self.typical_p >= 1.0
282            && self.top_n_sigma <= 0.0
283            && self.xtc_prob <= 0.0
284            && self.dry_multiplier <= 0.0
285            && self.mirostat == MirostatMode::Off
286            && self.frequency_penalty == 0.0
287            && self.presence_penalty == 0.0
288            && (self.repetition_penalty - 1.0).abs() < f32::EPSILON
289    }
290
291    /// Build the `SamplerChain` corresponding to these options. The
292    /// returned chain is ready to drive `SamplerChain::sample` against
293    /// a logits row + history. Greedy decoding produces a chain with
294    /// one `Temperature{t:1e-6}` step (which collapses to argmax after
295    /// softmax) — callers that want true greedy can short-circuit via
296    /// `is_greedy()` before building the chain.
297    pub fn into_chain(&self) -> crate::samplers::SamplerChain {
298        use crate::samplers::*;
299        let mut b = SamplerChain::builder();
300
301        // 1. Penalties operate on raw logits, before any temperature
302        //    scaling — matches llama.cpp's order.
303        if (self.repetition_penalty - 1.0).abs() > f32::EPSILON
304            || self.frequency_penalty != 0.0
305            || self.presence_penalty != 0.0
306        {
307            b = b.push(RepetitionPenalty {
308                penalty: self.repetition_penalty,
309                frequency: self.frequency_penalty,
310                presence: self.presence_penalty,
311                last_n: self.repetition_window,
312            });
313        }
314        if self.dry_multiplier > 0.0 {
315            b = b.push(Dry {
316                multiplier: self.dry_multiplier,
317                base: self.dry_base,
318                allowed_length: self.dry_allowed_length,
319                max_ngram: self.dry_max_ngram,
320                sequence_breakers: self.dry_sequence_breakers.clone(),
321            });
322        }
323
324        // 2. Temperature (dynamic or static). Mirostat replaces both.
325        if self.mirostat == MirostatMode::Off {
326            if let Some((mn, mx)) = self.dynamic_temp {
327                b = b.push(DynamicTemperature {
328                    min: mn,
329                    max: mx,
330                    exponent: self.dynamic_temp_exponent,
331                });
332            } else if self.temperature > 0.0 && (self.temperature - 1.0).abs() > f32::EPSILON {
333                b = b.push(Temperature {
334                    t: self.temperature,
335                });
336            } else if self.temperature <= 0.0 {
337                b = b.push(Temperature { t: 1e-6 });
338            }
339        }
340
341        // 3. Filters: top-k → typical → top-p → top-n-sigma → xtc.
342        if let Some(k) = self.top_k {
343            if k > 0 {
344                b = b.push(TopK { k: k as usize });
345            }
346        }
347        if self.typical_p < 1.0 && self.typical_p > 0.0 {
348            b = b.push(TypicalP {
349                p: self.typical_p,
350                min_keep: self.min_keep,
351            });
352        }
353        if self.top_p < 1.0 && self.top_p > 0.0 {
354            b = b.push(TopP {
355                p: self.top_p,
356                min_keep: self.min_keep,
357            });
358        }
359        if self.top_n_sigma > 0.0 {
360            b = b.push(TopNSigma {
361                n: self.top_n_sigma,
362            });
363        }
364        if self.xtc_prob > 0.0 && self.xtc_threshold > 0.0 {
365            b = b.push(Xtc {
366                threshold: self.xtc_threshold,
367                prob: self.xtc_prob,
368                min_keep: self.min_keep,
369            });
370        }
371
372        // 4. Mirostat (replaces softmax+sample at the end of the chain).
373        match self.mirostat {
374            MirostatMode::Off => {}
375            MirostatMode::V1 => {
376                b = b.push(MirostatV1 {
377                    tau: self.mirostat_tau,
378                    eta: self.mirostat_eta,
379                    m: self.mirostat_m,
380                });
381            }
382            MirostatMode::V2 => {
383                b = b.push(MirostatV2 {
384                    tau: self.mirostat_tau,
385                    eta: self.mirostat_eta,
386                });
387            }
388        }
389        b.build()
390    }
391}
392
393// ─────────────────────────────────────────────────────────────────
394// Shared builder
395// ─────────────────────────────────────────────────────────────────
396
/// Auto-packed threshold (256 MiB): prefer K-quant packed loading for
/// GGUF files >= this size. Cuts host memory ~6× on Q4_K_M models.
///
/// Consulted by `LmRunnerBuilder::resolved_packed` only when no
/// explicit `packed_weights` override is set.
pub const PACKED_GGUF_AUTO_THRESHOLD_BYTES: u64 = 256 * 1024 * 1024;
400
/// Builder fields common to every per-family runner.
///
/// Per-family runner builders should wrap this and forward the
/// methods (or use `#[rlx_runner]` from `rlx-macros`).
///
/// `Cfg` is the family-specific config type carried by
/// [`ConfigSource::Explicit`].
#[derive(Debug, Clone)]
pub struct LmRunnerBuilder<Cfg> {
    /// Weights path. Needed by `resolved_format()` when `format` is
    /// `None`, and stat'ed by `resolved_packed()` for auto-detection.
    pub weights: Option<PathBuf>,
    /// Where the model config comes from (default: `Embedded`).
    pub config: ConfigSource<Cfg>,
    /// Target device (default: `Device::Cpu`).
    pub device: Device,
    /// Default 128. Presumably the maximum sequence length the runner
    /// is built for — TODO confirm against the per-family runners.
    pub max_seq: usize,
    /// Optional memory budget in GB; `None` by default. How it is
    /// enforced is up to the consuming runner (not visible here).
    pub max_memory_gb: Option<f32>,
    /// Default `true`. Presumably toggles streaming token output —
    /// verify against the consuming runner.
    pub stream: bool,
    /// Sampling options (default: `SampleOpts::greedy()`).
    pub sample: SampleOpts,
    /// Explicit weight format; `None` = infer from the `weights`
    /// extension (see `resolved_format()`).
    pub format: Option<WeightFormat>,
    /// `None` = auto-detect (packed when GGUF ≥ 256 MB).
    pub packed_weights: Option<bool>,
    /// Substring for picking one GGUF in a directory (default `Q4_K_M`).
    /// NOTE(review): the builder itself defaults this to `None`; the
    /// `Q4_K_M` fallback must be applied by the consumer — confirm.
    pub prefer_gguf: Option<String>,
}
420
421impl<Cfg> Default for LmRunnerBuilder<Cfg> {
422    fn default() -> Self {
423        Self {
424            weights: None,
425            config: ConfigSource::Embedded,
426            device: Device::Cpu,
427            max_seq: 128,
428            max_memory_gb: None,
429            stream: true,
430            sample: SampleOpts::greedy(),
431            format: None,
432            packed_weights: None,
433            prefer_gguf: None,
434        }
435    }
436}
437
438impl<Cfg> LmRunnerBuilder<Cfg> {
439    pub fn new() -> Self {
440        Self::default()
441    }
442
443    pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
444        self.weights = Some(p.into());
445        self
446    }
447
448    pub fn config(mut self, src: ConfigSource<Cfg>) -> Self {
449        self.config = src;
450        self
451    }
452
453    pub fn config_value(self, cfg: Cfg) -> Self {
454        self.config(ConfigSource::Explicit(cfg))
455    }
456
457    pub fn device(mut self, d: Device) -> Self {
458        self.device = d;
459        self
460    }
461
462    pub fn max_seq(mut self, n: usize) -> Self {
463        self.max_seq = n;
464        self
465    }
466
467    pub fn max_memory_gb(mut self, gb: f32) -> Self {
468        self.max_memory_gb = Some(gb);
469        self
470    }
471
472    pub fn stream(mut self, on: bool) -> Self {
473        self.stream = on;
474        self
475    }
476
477    pub fn sample(mut self, s: SampleOpts) -> Self {
478        self.sample = s;
479        self
480    }
481
482    pub fn format(mut self, fmt: WeightFormat) -> Self {
483        self.format = Some(fmt);
484        self
485    }
486
487    pub fn packed_weights(mut self, on: bool) -> Self {
488        self.packed_weights = Some(on);
489        self
490    }
491
492    pub fn prefer_gguf<S: Into<String>>(mut self, q: S) -> Self {
493        self.prefer_gguf = Some(q.into());
494        self
495    }
496
497    /// Resolve the format using the explicit override or the file extension.
498    pub fn resolved_format(&self) -> anyhow::Result<WeightFormat> {
499        match self.format {
500            Some(f) => Ok(f),
501            None => {
502                let p = self
503                    .weights
504                    .as_deref()
505                    .ok_or_else(|| anyhow::anyhow!("weights path required"))?;
506                WeightFormat::from_path(p)
507            }
508        }
509    }
510
511    /// Determine whether packed GGUF loading should be used. Honors an
512    /// explicit override; otherwise auto-enables for GGUF files at or
513    /// above [`PACKED_GGUF_AUTO_THRESHOLD_BYTES`].
514    pub fn resolved_packed(&self, fmt: WeightFormat) -> bool {
515        match self.packed_weights {
516            Some(b) => b,
517            None => {
518                if !matches!(fmt, WeightFormat::Gguf) {
519                    return false;
520                }
521                self.weights
522                    .as_deref()
523                    .and_then(|p| std::fs::metadata(p).ok())
524                    .map(|m| m.len() >= PACKED_GGUF_AUTO_THRESHOLD_BYTES)
525                    .unwrap_or(false)
526            }
527        }
528    }
529}
530
531// ─────────────────────────────────────────────────────────────────
532// Model registry (auto-dispatch by path)
533// ─────────────────────────────────────────────────────────────────
534
/// Family-routing entry: a short name + a probe closure that returns
/// `true` for files this family should handle.
///
/// Registered at process start by `register_model` (or by a
/// `#[rlx_runner]`-generated `inventory` entry). [`auto_runner_name`]
/// walks the registry and returns the first matching family.
pub struct ModelRegistration {
    /// Family name returned by [`auto_runner_name`] on a match.
    pub family: &'static str,
    /// Free-form description; not consulted by anything in this module.
    pub description: &'static str,
    /// `(arch_str_lower_case, path) -> bool`. `arch_str_lower_case` is
    /// the GGUF `general.architecture` (`""` for safetensors); `path`
    /// is the concrete weights file. Implementations should return
    /// `true` if the family owns this file.
    ///
    /// [`auto_runner_name`] ASCII-lowercases `arch` before calling.
    pub matches: fn(arch: &str, path: &Path) -> bool,
}
550
551inventory::collect!(ModelRegistration);
552
553/// Re-export of `inventory` so the `register_lm_runner!` proc-macro
554/// can call `::rlx_runtime::lm::inventory::submit!` without forcing
555/// every caller to add `inventory` to their Cargo.toml.
556pub extern crate inventory;
557
558/// Iterate over every registered family.
559pub fn registered_models() -> impl Iterator<Item = &'static ModelRegistration> {
560    inventory::iter::<ModelRegistration>.into_iter()
561}
562
563/// Find the family that claims `(arch, path)`.
564pub fn auto_runner_name(arch: &str, path: &Path) -> Option<&'static str> {
565    let arch_lc = arch.to_ascii_lowercase();
566    registered_models()
567        .find(|m| (m.matches)(&arch_lc, path))
568        .map(|m| m.family)
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    /// `ConfigSource::default()` must pick the GGUF-embedded variant.
    #[test]
    fn config_source_default_is_embedded() {
        assert!(matches!(
            ConfigSource::<()>::default(),
            ConfigSource::Embedded
        ));
    }

    /// The shared builder keeps the defaults the per-family builders had.
    #[test]
    fn builder_defaults_match_legacy_runners() {
        let builder = LmRunnerBuilder::<()>::new();
        assert_eq!(builder.device, Device::Cpu);
        assert_eq!(builder.max_seq, 128);
        assert!(builder.stream);
        assert!(builder.sample.is_greedy());
        assert!(builder.packed_weights.is_none());
    }

    /// Exercises the three `resolved_packed` paths that need no real file.
    #[test]
    fn packed_auto_size_threshold() {
        let auto = LmRunnerBuilder::<()>::new().weights("/nonexistent/file.gguf");
        // Missing file → auto returns false (no metadata).
        assert!(!auto.resolved_packed(WeightFormat::Gguf));
        // Non-GGUF never auto-packs.
        assert!(!auto.resolved_packed(WeightFormat::Safetensors));
        // Explicit override wins.
        let forced = auto.packed_weights(true);
        assert!(forced.resolved_packed(WeightFormat::Gguf));
    }
}