Skip to main content

rlx_models/
run.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! High-level runner API — re-exported from per-model crates.
17//!
18//! Prefer depending on a specific model crate (`rlx-qwen3`, …) and
19//! its `rlx-<family>` binary when you only need one family.
20
21pub use crate::sam_runner::{SamArch, SamPredictionAny, SamRunner, SamRunnerBuilder};
22pub use rlx_cli::{
23    AssembledTurn, ChatMessage, ChatTemplate, ChatTemplateSource, CompatSource,
24    CompatibilityReport, CompatibilityStatus, GgufRequiredFields, LmRunner, MediaSource,
25    ModelRunner, MtmdContext, MtmdTurn, SniffedFrom, SniffedRunner, UnimplementedArch,
26    WeightFormat, arch_runner_name, auto_chat_template, auto_dispatch, auto_runner_name,
27    auto_sniff, check_hf_repo, check_path, debug_resolve_name, dispatch, dispatch_help,
28    known_unimplemented_arch, known_unimplemented_keys, list_mtp_keys, looks_like_hf_repo,
29    model_type_runner_name, open_gguf_loader, open_loader, open_loader_resolved,
30    open_loader_with_format, register_cli, register_runner, registered_runners, run_auto,
31    run_check, run_inspect, run_registered,
32};
33pub use rlx_dinov2::{DinoV2Output, DinoV2Runner, DinoV2RunnerBuilder, DinoV2Variant};
34pub use rlx_flux2::{Flux2Output, Flux2Runner, Flux2RunnerBuilder};
35pub use rlx_gemma::{GemmaConfigSource, GemmaRunner, GemmaRunnerBuilder};
36pub use rlx_llama32::{Llama32ConfigSource, Llama32Runner, Llama32RunnerBuilder};
37pub use rlx_qwen3::{Precision, Qwen3ConfigSource, Qwen3Runner, Qwen3RunnerBuilder};
38pub use rlx_qwen35::{Qwen35ConfigSource, Qwen35Runner, Qwen35RunnerBuilder};
39pub use rlx_vjepa2::{
40    Vjepa2Output, Vjepa2PoolOutput, Vjepa2PredictOutput, Vjepa2Runner, Vjepa2RunnerBuilder,
41};
42pub use rlx_wav2vec2_bert::{Wav2Vec2BertRunner, Wav2Vec2BertRunnerBuilder};
43
/// Back-compat alias for [`Qwen3ConfigSource`].
///
/// New code should name the family-specific config-source type directly
/// (e.g. [`Qwen3ConfigSource`], [`GemmaConfigSource`]).
pub type ConfigSource = Qwen3ConfigSource;
46
47use anyhow::{Result, bail};
48use std::path::Path;
49
50/// Sniff `path` for its GGUF / safetensors arch and return a boxed
51/// runner that implements [`LmRunner`]. The factory uses the existing
52/// [`auto_sniff`] arch-dispatch and constructs the per-family runner
53/// via its default builder.
54///
55/// Today this covers the four `text` LM families with a stable
56/// `predict_logits` API: `qwen3`, `qwen35`, `gemma`, `llama32`. Other
57/// families (vision-language, diffusion, embed) don't fit the
58/// `LmRunner` shape and return an error here. They keep their
59/// per-family builders.
60///
61/// PLAN.md M3. The `LmRunner` trait gained a default `generate(..)`
62/// in M8, so a boxed runner from this function can stream tokens too.
63pub fn auto_runner(path: &Path) -> Result<Box<dyn LmRunner>> {
64    auto_runner_with_mmproj(path, None)
65}
66
67/// Same as [`auto_runner`] but also attaches an mmproj vision encoder
68/// when the model family supports multimodal prefill (today: `qwen35`
69/// non-MTP path). For other families `mmproj` is silently ignored —
70/// matches llama-cpp's behaviour where mmproj on a text-only model is
71/// a no-op. The returned runner's [`LmRunner::supports_multimodal`]
72/// will report `true` only when both the family is multimodal-capable
73/// and `mmproj` was attached.
74pub fn auto_runner_with_mmproj(path: &Path, mmproj: Option<&Path>) -> Result<Box<dyn LmRunner>> {
75    let sniff = auto_sniff(path)?;
76    let weights = sniff.path.as_path();
77    // Packed-K-quant auto-detection is now inside each runner's
78    // `.build()` (matches llama.cpp's behaviour — K-quant tensors stay
79    // packed in memory, never materialise to a dense F32 matrix).
80    let runner: Box<dyn LmRunner> = match sniff.runner_name {
81        "qwen3" => Box::new(Qwen3Runner::builder().weights(weights).build()?),
82        "qwen35" => {
83            // PLAN.md M6 — auto-route MTP-equipped GGUFs through
84            // `Qwen35SpecRunner` for speculative decode. The
85            // `Qwen35MtpHead` HIR op now dispatches `DequantMatMul`
86            // per-weight (via `weight_schemes` plumbed through
87            // `lower_qwen35_mtp_head`), so packed K-quant GGUFs can
88            // run MTP without falling back to F32-only.
89            if gguf_has_mtp_heads(weights).unwrap_or(false) {
90                Box::new(
91                    rlx_qwen35::Qwen35SpecRunner::builder()
92                        .weights(weights)
93                        .build()?,
94                )
95            } else {
96                let mut b = Qwen35Runner::builder().weights(weights);
97                if let Some(mp) = mmproj {
98                    b = b.mmproj(mp);
99                }
100                Box::new(b.build()?)
101            }
102        }
103        "gemma" => Box::new(GemmaRunner::builder().weights(weights).build()?),
104        "llama32" => Box::new(Llama32Runner::builder().weights(weights).build()?),
105        "lfm" => Box::new(rlx_lfm::LfmRunner::builder().weights(weights).build()?),
106        other => bail!(
107            "auto_runner: runner `{other}` (sniffed from {:?}) has no `LmRunner` impl yet — \
108             use its per-family builder directly",
109            sniff.from
110        ),
111    };
112    Ok(runner)
113}
114
115/// Peek at a GGUF's `<arch>.nextn_predict_layers` metadata key without
116/// fully loading weights. Returns `Ok(true)` when the file declares ≥1
117/// MTP head. Non-GGUF or missing-key → `Ok(false)`.
118fn gguf_has_mtp_heads(path: &Path) -> Result<bool> {
119    use rlx_gguf::{GgufFile, MetaValue};
120    let is_gguf = path
121        .extension()
122        .and_then(|s| s.to_str())
123        .map(|s| s.eq_ignore_ascii_case("gguf"))
124        .unwrap_or(false);
125    if !is_gguf {
126        return Ok(false);
127    }
128    let raw = GgufFile::from_path(path)?;
129    let arch = raw
130        .metadata
131        .get("general.architecture")
132        .and_then(MetaValue::as_str)
133        .unwrap_or("");
134    // Try `<arch>.nextn_predict_layers` first; fall back to `qwen35.*` for
135    // converters that reuse the qwen35 prefix on qwen36 files.
136    for k in [
137        format!("{arch}.nextn_predict_layers"),
138        "qwen35.nextn_predict_layers".to_string(),
139        "qwen36.nextn_predict_layers".to_string(),
140    ] {
141        if let Some(MetaValue::U32(n)) = raw.metadata.get(&k) {
142            return Ok(*n > 0);
143        }
144    }
145    Ok(false)
146}
147
148/// Encode `text` to LM token ids using a HuggingFace `tokenizer.json`
149/// resolved next to the GGUF / safetensors at `weights_path`. Pass
150/// `explicit_tokenizer` to override the auto-discovery (sibling
151/// `<weights>.tokenizer.json` or `tokenizer.json` in the weights dir).
152///
153/// PLAN.md M8 — closes the loop between [`auto_chat_template`] (which
154/// returns a rendered string) and [`LmRunner::predict_logits`] /
155/// [`LmRunner::generate`] (which take raw token ids).
156///
157/// **Fallback (PLAN.md M8):** when no `tokenizer.json` is available
158/// and the weights are a GGUF, `encode_prompt_auto` automatically
159/// reconstructs a byte-level BPE tokenizer from
160/// `tokenizer.ggml.{tokens, merges}`. Works for the GPT-2/Qwen/Llama
161/// family (`tokenizer.ggml.model = "gpt2"`); SentencePiece tokenizers
162/// (`tokenizer.ggml.model = "llama"` legacy) still require a sibling
163/// `tokenizer.json`.
164pub fn auto_tokenize(
165    weights_path: &Path,
166    text: &str,
167    explicit_tokenizer: Option<&Path>,
168) -> Result<Vec<u32>> {
169    use anyhow::Context;
170    match rlx_qwen35::encode_prompt_auto(weights_path, explicit_tokenizer, text) {
171        Ok(ids) => Ok(ids),
172        Err(e) => {
173            // Augment with the GGUF-vocab fallback hint when applicable.
174            let is_gguf = weights_path
175                .extension()
176                .and_then(|s| s.to_str())
177                .map(|s| s.eq_ignore_ascii_case("gguf"))
178                .unwrap_or(false);
179            if !is_gguf {
180                return Err(e);
181            }
182            Err(e).with_context(|| {
183                format!(
184                    "auto_tokenize: no `tokenizer.json` resolved for {weights_path:?}. \
185                     The GGUF ships a vocab at `tokenizer.ggml.tokens` but \
186                     reconstructing a BPE encoder from GGUF-only metadata is \
187                     per-family work (PLAN.md M8 follow-up). Options: \
188                     (1) place `tokenizer.json` next to the GGUF; \
189                     (2) pass an explicit path via the `explicit_tokenizer` arg; \
190                     (3) download the matching `tokenizer.json` from the model's \
191                     HF repo and point at it"
192                )
193            })
194        }
195    }
196}
197
198/// Inverse of [`auto_tokenize`] — turn `ids` back into text, using the
199/// same tokenizer resolution chain (sibling `tokenizer.json` →
200/// `explicit_tokenizer` → GGUF-embedded byte-level BPE vocab).
201///
202/// `skip_special_tokens=true` removes EOS / chat-template control
203/// tokens (`<|im_end|>`, `<|endoftext|>`, …) — what you want for
204/// streaming user-facing chat output. Set `false` to keep them
205/// (useful for debugging or stop-string matching).
206pub fn auto_detokenize(
207    weights_path: &Path,
208    ids: &[u32],
209    explicit_tokenizer: Option<&Path>,
210    skip_special_tokens: bool,
211) -> Result<String> {
212    rlx_qwen35::decode_ids_auto(weights_path, explicit_tokenizer, ids, skip_special_tokens)
213}