Skip to main content

rlx_models_core/
weight_registry.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Extensible weight-format registry — register custom loaders for new extensions.
17
18use anyhow::{Context, Result, anyhow};
19use std::path::{Path, PathBuf};
20use std::sync::{Mutex, OnceLock};
21
22use crate::gguf_support::{ResolveWeightsOptions, resolve_weights_file_with_options};
23use crate::weight_loader::{GgufLoader, WeightLoader};
24use crate::weight_map::{WeightDrainPolicy, WeightMap};
25
/// Opens a file path into a [`WeightLoader`].
///
/// A plain function pointer (not a closure), which keeps
/// [`WeightFormatRegistration`] `Copy` and usable inside `static` tables.
pub type WeightLoaderFactory = fn(&Path) -> Result<Box<dyn WeightLoader>>;
28
/// Describes one on-disk weight format.
#[derive(Clone, Copy)]
pub struct WeightFormatRegistration {
    /// Identifier reported for this format (e.g. by [`format_for_extension`]).
    pub id: &'static str,
    /// File extensions, without the leading dot, that select this format.
    /// [`format_for_extension`] lowercases the queried extension before
    /// comparing, so entries should be lowercase to be matched there.
    pub extensions: &'static [&'static str],
    /// Factory used to open a file of this format.
    pub open: WeightLoaderFactory,
}
36
37fn open_safetensors(path: &Path) -> Result<Box<dyn WeightLoader>> {
38    let path_str = path
39        .to_str()
40        .ok_or_else(|| anyhow!("non-utf8 path {:?}", path))?;
41    Ok(Box::new(WeightMap::from_file(path_str)?))
42}
43
44fn open_gguf(path: &Path) -> Result<Box<dyn WeightLoader>> {
45    let path_str = path
46        .to_str()
47        .ok_or_else(|| anyhow!("non-utf8 path {:?}", path))?;
48    Ok(Box::new(GgufLoader::from_file(path_str)?))
49}
50
/// Formats that are always available. `formats()` lists these first, ahead of
/// anything added through [`register_weight_format`].
static BUILTIN_FORMATS: &[WeightFormatRegistration] = &[
    WeightFormatRegistration {
        id: "safetensors",
        extensions: &["safetensors"],
        open: open_safetensors,
    },
    WeightFormatRegistration {
        id: "gguf",
        extensions: &["gguf"],
        open: open_gguf,
    },
];
63
/// Formats added at runtime through [`register_weight_format`], in registration order.
static CUSTOM_FORMATS: Mutex<Vec<WeightFormatRegistration>> = Mutex::new(Vec::new());
// NOTE(review): only touched by `formats()`, which initialises it with `()`; nothing in
// this file reads the stored value — confirm whether it is still needed.
static REGISTRY_INIT: OnceLock<()> = OnceLock::new();
66
67/// Register a custom weight format (call before the first load). Later entries override
68/// built-ins when the same extension is registered twice.
69pub fn register_weight_format(reg: WeightFormatRegistration) {
70    CUSTOM_FORMATS
71        .lock()
72        .expect("weight format registry lock")
73        .push(reg);
74}
75
76fn formats() -> Vec<WeightFormatRegistration> {
77    REGISTRY_INIT.get_or_init(|| ());
78    let mut out: Vec<WeightFormatRegistration> = BUILTIN_FORMATS.to_vec();
79    let custom = CUSTOM_FORMATS.lock().expect("weight format registry lock");
80    out.extend(custom.iter().copied());
81    out
82}
83
/// One registered on-disk format (built-in or custom).
///
/// A read-only view of a [`WeightFormatRegistration`] without its `open` factory.
#[derive(Debug, Clone, Copy)]
pub struct RegisteredFormat {
    /// Format identifier copied from the registration.
    pub id: &'static str,
    /// Extensions (without the leading dot) copied from the registration.
    pub extensions: &'static [&'static str],
}
90
91/// All registered formats (built-ins first, then custom registrations).
92pub fn list_registered_formats() -> Vec<RegisteredFormat> {
93    formats()
94        .into_iter()
95        .map(|r| RegisteredFormat {
96            id: r.id,
97            extensions: r.extensions,
98        })
99        .collect()
100}
101
102/// Comma-separated extension list for error messages.
103pub fn registered_extensions_hint() -> String {
104    let mut exts: Vec<&str> = Vec::new();
105    for reg in list_registered_formats() {
106        for e in reg.extensions {
107            if !exts.contains(e) {
108                exts.push(e);
109            }
110        }
111    }
112    exts.join(", ")
113}
114
115/// Extension → format id (last registration wins).
116pub fn format_for_extension(ext: &str) -> Option<&'static str> {
117    let ext = ext.to_ascii_lowercase();
118    let mut found = None;
119    for reg in formats() {
120        if reg.extensions.iter().any(|e| *e == ext) {
121            found = Some(reg.id);
122        }
123    }
124    found
125}
126
127/// Open a single file via the format registry.
128pub fn open_weight_loader(path: &Path) -> Result<Box<dyn WeightLoader>> {
129    let ext = path.extension().and_then(|s| s.to_str()).unwrap_or("");
130    for reg in formats() {
131        if reg.extensions.contains(&ext) {
132            return (reg.open)(path)
133                .with_context(|| format!("opening {:?} as format {}", path, reg.id));
134        }
135    }
136    let known = registered_extensions_hint();
137    Err(anyhow!(
138        "unsupported weight extension `.{ext}` for {path:?}\n\
139         Registered extensions: .{known}\n\
140         Register a custom loader before the first open:\n\
141           use rlx_core::weights::WeightFormatRegistration;\n\
142           WeightFormatRegistration::new(\"myfmt\", &[\"mybin\"], my_open).register();\n\
143         Docs: rlx_core::weights module, README → GGUF → Custom formats"
144    ))
145}
146
/// Options for [`load_weights_resolved`] — prefer [`crate::weights::LoadOpts`] presets at call sites.
///
/// The derived `Default` leaves `into_map` as `false`, i.e. the live loader is returned.
#[derive(Debug, Clone, Default)]
pub struct LoadWeightsOptions<'a> {
    /// Passed through to `resolve_weights_file_with_options` to pick the file to open.
    pub resolve: ResolveWeightsOptions<'a>,
    /// How to drain into a [`WeightMap`] when `into_map` is true.
    pub drain: WeightDrainPolicy,
    /// If true, return a drained [`WeightMap`]; if false, return the live loader.
    pub into_map: bool,
}
156
157impl<'a> LoadWeightsOptions<'a> {
158    pub fn map() -> Self {
159        Self {
160            into_map: true,
161            ..Default::default()
162        }
163    }
164
165    pub fn loader() -> Self {
166        Self {
167            into_map: false,
168            ..Default::default()
169        }
170    }
171
172    pub fn prefer_q4_k_m(self) -> Self {
173        self.prefer_substring("Q4_K_M")
174    }
175
176    pub fn prefer_substring(mut self, sub: &'a str) -> Self {
177        self.resolve.prefer_gguf_substring = Some(sub);
178        self
179    }
180
181    pub fn gguf_index(mut self, idx: usize) -> Self {
182        self.resolve.gguf_index = Some(idx);
183        self
184    }
185
186    pub fn drain(mut self, policy: WeightDrainPolicy) -> Self {
187        self.drain = policy;
188        self
189    }
190
191    pub fn warn_unused(self) -> Self {
192        self.drain(WeightDrainPolicy::AllF32WarnUnused)
193    }
194}
195
/// Result of resolving and opening weights.
///
/// Both variants record the resolved file path and the id of the format
/// that was selected for it.
pub enum LoadedWeights {
    /// Live loader (supports packed `take`, MTP, mmap borrow).
    Loader {
        /// Resolved weights file that was opened.
        path: PathBuf,
        /// Id of the registered format matched for `path`.
        format_id: &'static str,
        /// The opened loader, not yet drained.
        loader: Box<dyn WeightLoader>,
    },
    /// Drained map (F32 tensors + optional packed sidecar).
    Map {
        /// Resolved weights file that was opened.
        path: PathBuf,
        /// Id of the registered format matched for `path`.
        format_id: &'static str,
        /// Tensors drained from the loader.
        map: WeightMap,
        /// Packed tensors returned alongside `map` by `WeightMap::drain_loader`.
        packed: Vec<crate::weight_map::NamedPackedWeight>,
    },
}
212
213impl LoadedWeights {
214    /// Drained map, if this load used `into_map: true`.
215    pub fn as_map(&self) -> Option<&WeightMap> {
216        match self {
217            Self::Map { map, .. } => Some(map),
218            Self::Loader { .. } => None,
219        }
220    }
221
222    /// Live loader, if this load used `into_map: false`.
223    pub fn as_loader_mut(&mut self) -> Option<&mut dyn WeightLoader> {
224        match self {
225            Self::Loader { loader, .. } => Some(loader.as_mut()),
226            Self::Map { .. } => None,
227        }
228    }
229
230    /// Packed K-quant tensors when returned as [`LoadedWeights::Map`].
231    pub fn packed_tensors(&self) -> Option<&[crate::weight_map::NamedPackedWeight]> {
232        match self {
233            Self::Map { packed, .. } => Some(packed.as_slice()),
234            Self::Loader { .. } => None,
235        }
236    }
237
238    pub fn path(&self) -> &Path {
239        match self {
240            Self::Loader { path, .. } | Self::Map { path, .. } => path,
241        }
242    }
243
244    pub fn format_id(&self) -> &'static str {
245        match self {
246            Self::Loader { format_id, .. } | Self::Map { format_id, .. } => format_id,
247        }
248    }
249
250    pub fn into_map(self) -> Result<WeightMap> {
251        match self {
252            Self::Map { map, packed, .. } => {
253                if !packed.is_empty() {
254                    anyhow::bail!(
255                        "into_map: {} packed tensors were not merged (use Loader path for packed mode)",
256                        packed.len()
257                    );
258                }
259                Ok(map)
260            }
261            Self::Loader { mut loader, .. } => Ok(WeightMap::from_weight_loader(loader.as_mut())?),
262        }
263    }
264}
265
266/// Resolve a file or directory, enforce GGUF arch policy, open via registry, optionally drain.
267pub fn load_weights_resolved(path: &Path, opts: LoadWeightsOptions<'_>) -> Result<LoadedWeights> {
268    let file = resolve_weights_file_with_options(path, &opts.resolve)?;
269    // Split GGUF merge happens inside GgufLoader::from_file / load_gguf_file.
270    let ext = file.extension().and_then(|s| s.to_str()).unwrap_or("");
271    let format_id = format_for_extension(ext)
272        .ok_or_else(|| anyhow!("no registered loader for extension `.{ext}`"))?;
273    let mut loader = open_weight_loader(&file)?;
274    if opts.into_map {
275        let (map, packed) = WeightMap::drain_loader(loader.as_mut(), opts.drain)?;
276        Ok(LoadedWeights::Map {
277            path: file,
278            format_id,
279            map,
280            packed,
281        })
282    } else {
283        Ok(LoadedWeights::Loader {
284            path: file,
285            format_id,
286            loader,
287        })
288    }
289}
290
291/// Convenience: resolve + drain to F32 [`WeightMap`].
292pub fn load_weight_map_resolved(
293    path: &Path,
294    opts: LoadWeightsOptions<'_>,
295) -> Result<(PathBuf, WeightMap)> {
296    let mut o = opts;
297    o.into_map = true;
298    let loaded = load_weights_resolved(path, o)?;
299    Ok((loaded.path().to_path_buf(), loaded.into_map()?))
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    /// An extension no format claims must be rejected before any file I/O.
    #[test]
    fn unknown_extension_errors() {
        let path = std::env::temp_dir().join("rlx_weight_registry_test.noext");
        let Err(e) = open_weight_loader(&path) else {
            panic!("expected unsupported extension for {path:?}");
        };
        assert!(
            e.to_string().contains("unsupported weight extension"),
            "{e}"
        );
    }

    /// Built-in extensions map to their ids; an unregistered one maps to `None`.
    #[test]
    fn format_for_extension_builtin() {
        let cases = [
            ("gguf", Some("gguf")),
            ("safetensors", Some("safetensors")),
            ("bin", None),
        ];
        for (ext, expected) in cases {
            assert_eq!(format_for_extension(ext), expected);
        }
    }

    /// The listing always contains both built-in formats.
    #[test]
    fn list_formats_includes_builtins() {
        let listed = list_registered_formats();
        for builtin in ["gguf", "safetensors"] {
            assert!(listed.iter().any(|r| r.id == builtin));
        }
    }
}
331}