rlx_runtime/lm.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Generic language-model runner trait and shared builder.
17//!
18//! Until now every `rlx-<family>` model crate carried its own
19//! `*RunnerBuilder` (Qwen3RunnerBuilder, Llama32RunnerBuilder, …)
20//! with the same fields, the same `*ConfigSource { Embedded |
21//! JsonFile | Explicit(T) }` enum, and the same auto-packed-GGUF
22//! heuristic. This module hoists those shapes upstream so that:
23//!
24//! 1. `LmRunner` can live in `rlx-runtime` (today's home in
25//! `rlx-cli` forces every model crate to take a dependency on
26//! the CLI helper crate).
27//! 2. Per-family runners can `Deref` to / wrap [`LmRunnerBuilder`]
28//! instead of redefining the same fields.
29//! 3. Downstream tools (`skill`, web apps) can talk to runners
30//! through one trait without compiling in every model crate.
31//!
32//! The trait surface mirrors the existing `rlx_cli::LmRunner`. The
33//! CLI re-export is kept for backwards compat.
34
35use std::path::{Path, PathBuf};
36
37use crate::Device;
38
39/// Minimal per-family runner interface used by `auto_dispatch` and
40/// the `rlx-text` / `skill` integration.
41///
42/// Implementations must be `Send` so the boxed trait can move across
43/// threads (e.g. when a server runs inference on a worker pool).
44/// `Sync` is intentionally not required — runners hold mutable
45/// per-call compile / cache state.
46pub trait LmRunner: Send {
47 /// Short family identifier (`"qwen3"`, `"llama32"`, `"gemma"`).
48 fn family(&self) -> &'static str;
49
50 /// LM head vocabulary size.
51 fn vocab_size(&self) -> usize;
52
53 /// Run prefill on `prompt_ids` and return last-token logits.
54 fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>>;
55
56 /// Generate up to `n_new` tokens after `prompt_ids` using greedy
57 /// (argmax) sampling. The default impl re-prefills on the full
58 /// context each step — per-family runners should override with
59 /// their cached decode fast path.
60 ///
61 /// `on_token` returns `true` to continue, `false` to stop.
62 fn generate(
63 &mut self,
64 prompt_ids: &[u32],
65 n_new: usize,
66 on_token: &mut dyn FnMut(u32) -> bool,
67 ) -> anyhow::Result<Vec<u32>> {
68 let mut context: Vec<u32> = prompt_ids.to_vec();
69 let mut produced: Vec<u32> = Vec::with_capacity(n_new);
70 for _ in 0..n_new {
71 let logits = self.predict_logits(&context)?;
72 let next = argmax_u32(&logits);
73 produced.push(next);
74 let cont = on_token(next);
75 context.push(next);
76 if !cont {
77 break;
78 }
79 }
80 Ok(produced)
81 }
82
83 /// Whether this runner supports multimodal (image+text) generation.
84 fn supports_multimodal(&self) -> bool {
85 false
86 }
87
88 /// Multimodal generation: prefill with text where image markers are
89 /// spliced with vision embeddings derived from `rgb`.
90 fn generate_multimodal(
91 &mut self,
92 _prompt: &str,
93 _rgb: &[u8],
94 _img_w: usize,
95 _img_h: usize,
96 _tokenizer: Option<&Path>,
97 _n_new: usize,
98 _on_token: &mut dyn FnMut(u32) -> bool,
99 ) -> anyhow::Result<Vec<u32>> {
100 Err(anyhow::anyhow!(
101 "this LmRunner does not support multimodal generation"
102 ))
103 }
104}
105
/// Index of the largest logit, as a `u32` token id.
///
/// Ties resolve to the earliest index. NaN entries never compare greater
/// than the running best, so they are skipped. An empty slice, or one whose
/// entries are all NaN / `-inf`, yields `0`.
fn argmax_u32(logits: &[f32]) -> u32 {
    let (winner, _) = logits.iter().copied().enumerate().fold(
        (0usize, f32::NEG_INFINITY),
        |(best_idx, best_val), (idx, val)| {
            if val > best_val {
                (idx, val)
            } else {
                (best_idx, best_val)
            }
        },
    );
    winner as u32
}
117
118// ─────────────────────────────────────────────────────────────────
119// Weight format + config source
120// ─────────────────────────────────────────────────────────────────
121
122/// Weight file format. Detected from the file extension by default;
123/// the CLI accepts `--format` to override.
124#[derive(Debug, Clone, Copy, PartialEq, Eq)]
125pub enum WeightFormat {
126 Safetensors,
127 Gguf,
128}
129
130impl WeightFormat {
131 /// Infer format from a path extension.
132 pub fn from_path(path: &Path) -> anyhow::Result<Self> {
133 match path.extension().and_then(|s| s.to_str()) {
134 Some("safetensors") => Ok(Self::Safetensors),
135 Some("gguf") => Ok(Self::Gguf),
136 other => Err(anyhow::anyhow!(
137 "cannot autodetect weight format from extension {:?} on {:?}",
138 other,
139 path
140 )),
141 }
142 }
143
144 /// Parse CLI `--format` values (`safetensors` | `gguf`).
145 pub fn parse(s: &str) -> anyhow::Result<Self> {
146 match s {
147 "safetensors" => Ok(Self::Safetensors),
148 "gguf" => Ok(Self::Gguf),
149 other => Err(anyhow::anyhow!("expected safetensors|gguf, got {other}")),
150 }
151 }
152}
153
/// Where a model config comes from.
///
/// A single generic enum standing in for the old per-family
/// `Qwen3ConfigSource`, `Llama32ConfigSource`, `GemmaConfigSource` and
/// `Qwen35ConfigSource` types. Defaults to [`ConfigSource::Embedded`].
#[derive(Clone, Debug, Default)]
pub enum ConfigSource<T> {
    /// Take the config from the GGUF file's own metadata.
    #[default]
    Embedded,
    /// Load a HuggingFace-style `config.json` from this path.
    JsonFile(PathBuf),
    /// Use this already-built config object as-is.
    Explicit(T),
}
168
169// ─────────────────────────────────────────────────────────────────
170// Sampling
171// ─────────────────────────────────────────────────────────────────
172
/// Which Mirostat variant, if any, terminates the sampler chain.
/// See `crate::samplers::{MirostatV1, MirostatV2}`.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub enum MirostatMode {
    /// Mirostat disabled (the default).
    #[default]
    Off,
    /// Mirostat v1; consumes `mirostat_tau`, `mirostat_eta` and `mirostat_m`.
    V1,
    /// Mirostat v2; consumes `mirostat_tau` and `mirostat_eta`.
    V2,
}
181
182/// Sampling parameters. Greedy when `temperature == 0` and no advanced
183/// sampler is enabled. All "advanced" knobs default to off / no-op so
184/// legacy callers see classic top-k/top-p/temperature behaviour.
185///
186/// `into_chain()` turns these flat fields into a `SamplerChain` that
187/// downstream backends can execute. Ordering follows llama.cpp's
188/// canonical chain (penalties → temperature → top-k → typical → top-p
189/// → top-n-σ → xtc → mirostat).
190#[derive(Debug, Clone)]
191pub struct SampleOpts {
192 pub temperature: f32,
193 pub top_p: f32,
194 pub top_k: Option<u32>,
195 pub repetition_penalty: f32,
196
197 // ── advanced samplers ────────────────────────────────────────
198 /// Dynamic temperature [min, max] gated by softmax entropy.
199 /// `None` ⇒ flat temperature only.
200 pub dynamic_temp: Option<(f32, f32)>,
201 /// Exponent used by [`crate::samplers::DynamicTemperature`].
202 pub dynamic_temp_exponent: f32,
203 /// Locally-typical sampling (Meister et al. 2022). 1.0 ⇒ off.
204 pub typical_p: f32,
205 /// Top-n-σ cutoff (Hewitt et al. 2024). 0 ⇒ off.
206 pub top_n_sigma: f32,
207 /// XTC: probability of dropping high-confidence top tokens.
208 pub xtc_threshold: f32,
209 pub xtc_prob: f32,
210 /// DRY repetition penalty knobs.
211 pub dry_multiplier: f32,
212 pub dry_base: f32,
213 pub dry_allowed_length: usize,
214 pub dry_max_ngram: usize,
215 pub dry_sequence_breakers: Vec<u32>,
216 /// Mirostat mode + parameters.
217 pub mirostat: MirostatMode,
218 pub mirostat_tau: f32,
219 pub mirostat_eta: f32,
220 pub mirostat_m: usize,
221 /// Frequency / presence penalties (OpenAI-style).
222 pub frequency_penalty: f32,
223 pub presence_penalty: f32,
224 pub repetition_window: usize,
225 /// Minimum tokens kept by top-p / typical (avoid one-token nucleus).
226 pub min_keep: usize,
227}
228
229impl Default for SampleOpts {
230 fn default() -> Self {
231 Self::greedy()
232 }
233}
234
235impl SampleOpts {
236 pub fn greedy() -> Self {
237 Self {
238 temperature: 0.0,
239 top_p: 1.0,
240 top_k: None,
241 repetition_penalty: 1.0,
242 dynamic_temp: None,
243 dynamic_temp_exponent: 1.0,
244 typical_p: 1.0,
245 top_n_sigma: 0.0,
246 xtc_threshold: 0.0,
247 xtc_prob: 0.0,
248 dry_multiplier: 0.0,
249 dry_base: 1.75,
250 dry_allowed_length: 2,
251 dry_max_ngram: 32,
252 dry_sequence_breakers: Vec::new(),
253 mirostat: MirostatMode::Off,
254 mirostat_tau: 5.0,
255 mirostat_eta: 0.1,
256 mirostat_m: 100,
257 frequency_penalty: 0.0,
258 presence_penalty: 0.0,
259 repetition_window: 64,
260 min_keep: 1,
261 }
262 }
263
264 pub fn nucleus(temperature: f32, top_p: f32) -> Self {
265 Self {
266 temperature,
267 top_p,
268 ..Self::greedy()
269 }
270 }
271
272 pub fn is_greedy(&self) -> bool {
273 self.temperature <= 0.0 && self.mirostat == MirostatMode::Off
274 }
275
276 /// True when only classic top-k/top-p/temperature are configured;
277 /// backends can take a cheap fast path in this case (e.g. the
278 /// existing `sample_row` CPU kernel) instead of building a chain.
279 pub fn is_classic(&self) -> bool {
280 self.dynamic_temp.is_none()
281 && self.typical_p >= 1.0
282 && self.top_n_sigma <= 0.0
283 && self.xtc_prob <= 0.0
284 && self.dry_multiplier <= 0.0
285 && self.mirostat == MirostatMode::Off
286 && self.frequency_penalty == 0.0
287 && self.presence_penalty == 0.0
288 && (self.repetition_penalty - 1.0).abs() < f32::EPSILON
289 }
290
291 /// Build the `SamplerChain` corresponding to these options. The
292 /// returned chain is ready to drive `SamplerChain::sample` against
293 /// a logits row + history. Greedy decoding produces a chain with
294 /// one `Temperature{t:1e-6}` step (which collapses to argmax after
295 /// softmax) — callers that want true greedy can short-circuit via
296 /// `is_greedy()` before building the chain.
297 pub fn into_chain(&self) -> crate::samplers::SamplerChain {
298 use crate::samplers::*;
299 let mut b = SamplerChain::builder();
300
301 // 1. Penalties operate on raw logits, before any temperature
302 // scaling — matches llama.cpp's order.
303 if (self.repetition_penalty - 1.0).abs() > f32::EPSILON
304 || self.frequency_penalty != 0.0
305 || self.presence_penalty != 0.0
306 {
307 b = b.push(RepetitionPenalty {
308 penalty: self.repetition_penalty,
309 frequency: self.frequency_penalty,
310 presence: self.presence_penalty,
311 last_n: self.repetition_window,
312 });
313 }
314 if self.dry_multiplier > 0.0 {
315 b = b.push(Dry {
316 multiplier: self.dry_multiplier,
317 base: self.dry_base,
318 allowed_length: self.dry_allowed_length,
319 max_ngram: self.dry_max_ngram,
320 sequence_breakers: self.dry_sequence_breakers.clone(),
321 });
322 }
323
324 // 2. Temperature (dynamic or static). Mirostat replaces both.
325 if self.mirostat == MirostatMode::Off {
326 if let Some((mn, mx)) = self.dynamic_temp {
327 b = b.push(DynamicTemperature {
328 min: mn,
329 max: mx,
330 exponent: self.dynamic_temp_exponent,
331 });
332 } else if self.temperature > 0.0 && (self.temperature - 1.0).abs() > f32::EPSILON {
333 b = b.push(Temperature {
334 t: self.temperature,
335 });
336 } else if self.temperature <= 0.0 {
337 b = b.push(Temperature { t: 1e-6 });
338 }
339 }
340
341 // 3. Filters: top-k → typical → top-p → top-n-sigma → xtc.
342 if let Some(k) = self.top_k {
343 if k > 0 {
344 b = b.push(TopK { k: k as usize });
345 }
346 }
347 if self.typical_p < 1.0 && self.typical_p > 0.0 {
348 b = b.push(TypicalP {
349 p: self.typical_p,
350 min_keep: self.min_keep,
351 });
352 }
353 if self.top_p < 1.0 && self.top_p > 0.0 {
354 b = b.push(TopP {
355 p: self.top_p,
356 min_keep: self.min_keep,
357 });
358 }
359 if self.top_n_sigma > 0.0 {
360 b = b.push(TopNSigma {
361 n: self.top_n_sigma,
362 });
363 }
364 if self.xtc_prob > 0.0 && self.xtc_threshold > 0.0 {
365 b = b.push(Xtc {
366 threshold: self.xtc_threshold,
367 prob: self.xtc_prob,
368 min_keep: self.min_keep,
369 });
370 }
371
372 // 4. Mirostat (replaces softmax+sample at the end of the chain).
373 match self.mirostat {
374 MirostatMode::Off => {}
375 MirostatMode::V1 => {
376 b = b.push(MirostatV1 {
377 tau: self.mirostat_tau,
378 eta: self.mirostat_eta,
379 m: self.mirostat_m,
380 });
381 }
382 MirostatMode::V2 => {
383 b = b.push(MirostatV2 {
384 tau: self.mirostat_tau,
385 eta: self.mirostat_eta,
386 });
387 }
388 }
389 b.build()
390 }
391}
392
393// ─────────────────────────────────────────────────────────────────
394// Shared builder
395// ─────────────────────────────────────────────────────────────────
396
/// GGUF files at or above this size (256 MiB) default to K-quant packed
/// loading when the caller has not forced a choice. Cuts host memory ~6×
/// on Q4_K_M models.
pub const PACKED_GGUF_AUTO_THRESHOLD_BYTES: u64 = 256 << 20;
400
401/// Builder fields common to every per-family runner.
402///
403/// Per-family runner builders should wrap this and forward the
404/// methods (or use `#[rlx_runner]` from `rlx-macros`).
405#[derive(Debug, Clone)]
406pub struct LmRunnerBuilder<Cfg> {
407 pub weights: Option<PathBuf>,
408 pub config: ConfigSource<Cfg>,
409 pub device: Device,
410 pub max_seq: usize,
411 pub max_memory_gb: Option<f32>,
412 pub stream: bool,
413 pub sample: SampleOpts,
414 pub format: Option<WeightFormat>,
415 /// `None` = auto-detect (packed when GGUF ≥ 256 MB).
416 pub packed_weights: Option<bool>,
417 /// Substring for picking one GGUF in a directory (default `Q4_K_M`).
418 pub prefer_gguf: Option<String>,
419}
420
421impl<Cfg> Default for LmRunnerBuilder<Cfg> {
422 fn default() -> Self {
423 Self {
424 weights: None,
425 config: ConfigSource::Embedded,
426 device: Device::Cpu,
427 max_seq: 128,
428 max_memory_gb: None,
429 stream: true,
430 sample: SampleOpts::greedy(),
431 format: None,
432 packed_weights: None,
433 prefer_gguf: None,
434 }
435 }
436}
437
438impl<Cfg> LmRunnerBuilder<Cfg> {
439 pub fn new() -> Self {
440 Self::default()
441 }
442
443 pub fn weights<P: Into<PathBuf>>(mut self, p: P) -> Self {
444 self.weights = Some(p.into());
445 self
446 }
447
448 pub fn config(mut self, src: ConfigSource<Cfg>) -> Self {
449 self.config = src;
450 self
451 }
452
453 pub fn config_value(self, cfg: Cfg) -> Self {
454 self.config(ConfigSource::Explicit(cfg))
455 }
456
457 pub fn device(mut self, d: Device) -> Self {
458 self.device = d;
459 self
460 }
461
462 pub fn max_seq(mut self, n: usize) -> Self {
463 self.max_seq = n;
464 self
465 }
466
467 pub fn max_memory_gb(mut self, gb: f32) -> Self {
468 self.max_memory_gb = Some(gb);
469 self
470 }
471
472 pub fn stream(mut self, on: bool) -> Self {
473 self.stream = on;
474 self
475 }
476
477 pub fn sample(mut self, s: SampleOpts) -> Self {
478 self.sample = s;
479 self
480 }
481
482 pub fn format(mut self, fmt: WeightFormat) -> Self {
483 self.format = Some(fmt);
484 self
485 }
486
487 pub fn packed_weights(mut self, on: bool) -> Self {
488 self.packed_weights = Some(on);
489 self
490 }
491
492 pub fn prefer_gguf<S: Into<String>>(mut self, q: S) -> Self {
493 self.prefer_gguf = Some(q.into());
494 self
495 }
496
497 /// Resolve the format using the explicit override or the file extension.
498 pub fn resolved_format(&self) -> anyhow::Result<WeightFormat> {
499 match self.format {
500 Some(f) => Ok(f),
501 None => {
502 let p = self
503 .weights
504 .as_deref()
505 .ok_or_else(|| anyhow::anyhow!("weights path required"))?;
506 WeightFormat::from_path(p)
507 }
508 }
509 }
510
511 /// Determine whether packed GGUF loading should be used. Honors an
512 /// explicit override; otherwise auto-enables for GGUF files at or
513 /// above [`PACKED_GGUF_AUTO_THRESHOLD_BYTES`].
514 pub fn resolved_packed(&self, fmt: WeightFormat) -> bool {
515 match self.packed_weights {
516 Some(b) => b,
517 None => {
518 if !matches!(fmt, WeightFormat::Gguf) {
519 return false;
520 }
521 self.weights
522 .as_deref()
523 .and_then(|p| std::fs::metadata(p).ok())
524 .map(|m| m.len() >= PACKED_GGUF_AUTO_THRESHOLD_BYTES)
525 .unwrap_or(false)
526 }
527 }
528 }
529}
530
531// ─────────────────────────────────────────────────────────────────
532// Model registry (auto-dispatch by path)
533// ─────────────────────────────────────────────────────────────────
534
/// Family-routing entry: a short family name plus a probe function that
/// decides whether that family handles a given weights file.
///
/// Entries are registered at process start by `register_model`, or by a
/// `#[rlx_runner]`-generated `inventory` entry. [`auto_runner_name`] walks
/// the registry and returns the family of the first probe that accepts.
pub struct ModelRegistration {
    /// Short family id; this is what `auto_runner_name` returns.
    pub family: &'static str,
    /// Free-form description of the family.
    pub description: &'static str,
    /// Probe called as `(arch, path) -> bool`.
    ///
    /// `arch` is the GGUF `general.architecture`, already lower-cased
    /// (`""` for safetensors). `path` is the concrete weights file. Return
    /// `true` if this family owns the file.
    pub matches: fn(arch: &str, path: &Path) -> bool,
}
550
// Make `ModelRegistration` an `inventory`-collected type so that
// `inventory::submit!` entries from family crates can be enumerated by
// `registered_models()`.
inventory::collect!(ModelRegistration);

/// Re-export of `inventory` so the `register_lm_runner!` proc-macro
/// can call `::rlx_runtime::lm::inventory::submit!` without forcing
/// every caller to add `inventory` to their Cargo.toml.
pub extern crate inventory;
557
558/// Iterate over every registered family.
559pub fn registered_models() -> impl Iterator<Item = &'static ModelRegistration> {
560 inventory::iter::<ModelRegistration>.into_iter()
561}
562
563/// Find the family that claims `(arch, path)`.
564pub fn auto_runner_name(arch: &str, path: &Path) -> Option<&'static str> {
565 let arch_lc = arch.to_ascii_lowercase();
566 registered_models()
567 .find(|m| (m.matches)(&arch_lc, path))
568 .map(|m| m.family)
569}
570
#[cfg(test)]
mod tests {
    use super::*;

    /// Toy runner whose logits peak at `context.len() % 4`, so the default
    /// `generate` loop yields a predictable token sequence.
    struct CountingRunner;

    impl LmRunner for CountingRunner {
        fn family(&self) -> &'static str {
            "counting"
        }

        fn vocab_size(&self) -> usize {
            4
        }

        fn predict_logits(&mut self, prompt_ids: &[u32]) -> anyhow::Result<Vec<f32>> {
            let mut logits = vec![0.0f32; 4];
            logits[prompt_ids.len() % 4] = 1.0;
            Ok(logits)
        }
    }

    #[test]
    fn config_source_default_is_embedded() {
        let s: ConfigSource<()> = ConfigSource::default();
        assert!(matches!(s, ConfigSource::Embedded));
    }

    #[test]
    fn builder_defaults_match_legacy_runners() {
        let b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        assert_eq!(b.device, Device::Cpu);
        assert_eq!(b.max_seq, 128);
        assert!(b.stream);
        assert!(b.sample.is_greedy());
        assert!(b.packed_weights.is_none());
    }

    #[test]
    fn packed_auto_size_threshold() {
        let mut b: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        b.weights = Some("/nonexistent/file.gguf".into());
        // Missing file → auto returns false (no metadata).
        assert!(!b.resolved_packed(WeightFormat::Gguf));
        // Explicit override wins.
        b.packed_weights = Some(true);
        assert!(b.resolved_packed(WeightFormat::Gguf));
        // Non-GGUF never auto-packs.
        b.packed_weights = None;
        assert!(!b.resolved_packed(WeightFormat::Safetensors));
    }

    #[test]
    fn weight_format_from_extension() {
        assert_eq!(
            WeightFormat::from_path(Path::new("m.gguf")).unwrap(),
            WeightFormat::Gguf
        );
        assert_eq!(
            WeightFormat::from_path(Path::new("m.safetensors")).unwrap(),
            WeightFormat::Safetensors
        );
        assert!(WeightFormat::from_path(Path::new("m.bin")).is_err());
        assert!(WeightFormat::from_path(Path::new("no_extension")).is_err());
    }

    #[test]
    fn weight_format_parse() {
        assert_eq!(WeightFormat::parse("gguf").unwrap(), WeightFormat::Gguf);
        assert_eq!(
            WeightFormat::parse("safetensors").unwrap(),
            WeightFormat::Safetensors
        );
        assert!(WeightFormat::parse("onnx").is_err());
    }

    #[test]
    fn resolved_format_prefers_override() {
        let b: LmRunnerBuilder<()> = LmRunnerBuilder::new()
            .weights("model.bin")
            .format(WeightFormat::Gguf);
        assert_eq!(b.resolved_format().unwrap(), WeightFormat::Gguf);
        // No override and no weights path → error.
        let empty: LmRunnerBuilder<()> = LmRunnerBuilder::new();
        assert!(empty.resolved_format().is_err());
    }

    #[test]
    fn argmax_prefers_first_maximum() {
        assert_eq!(argmax_u32(&[0.0, 3.0, 3.0]), 1);
        assert_eq!(argmax_u32(&[]), 0);
    }

    #[test]
    fn nucleus_is_not_greedy() {
        assert!(!SampleOpts::nucleus(0.7, 0.9).is_greedy());
        assert!(SampleOpts::greedy().is_classic());
    }

    #[test]
    fn default_generate_is_greedy_and_honours_stop() {
        let mut runner = CountingRunner;
        let all = runner.generate(&[7], 3, &mut |_: u32| true).unwrap();
        assert_eq!(all, vec![1, 2, 3]);

        let mut seen = Vec::new();
        let stopped = runner
            .generate(&[7], 3, &mut |t: u32| {
                seen.push(t);
                false
            })
            .unwrap();
        // The token that triggered the stop is still returned.
        assert_eq!(stopped, vec![1]);
        assert_eq!(seen, vec![1]);
    }
}