1use anyhow::{Context, Result, anyhow};
19use std::path::{Path, PathBuf};
20use std::sync::{Mutex, OnceLock};
21
22use crate::gguf_support::{ResolveWeightsOptions, resolve_weights_file_with_options};
23use crate::weight_loader::{GgufLoader, WeightLoader};
24use crate::weight_map::{WeightDrainPolicy, WeightMap};
25
26pub type WeightLoaderFactory = fn(&Path) -> Result<Box<dyn WeightLoader>>;
28
29#[derive(Clone, Copy)]
31pub struct WeightFormatRegistration {
32 pub id: &'static str,
33 pub extensions: &'static [&'static str],
34 pub open: WeightLoaderFactory,
35}
36
37fn open_safetensors(path: &Path) -> Result<Box<dyn WeightLoader>> {
38 let path_str = path
39 .to_str()
40 .ok_or_else(|| anyhow!("non-utf8 path {:?}", path))?;
41 Ok(Box::new(WeightMap::from_file(path_str)?))
42}
43
44fn open_gguf(path: &Path) -> Result<Box<dyn WeightLoader>> {
45 let path_str = path
46 .to_str()
47 .ok_or_else(|| anyhow!("non-utf8 path {:?}", path))?;
48 Ok(Box::new(GgufLoader::from_file(path_str)?))
49}
50
/// Formats that are always registered.
///
/// `formats()` lists these before any custom registrations. `open_weight_loader`
/// opens with the first matching entry, so a custom format that reuses one of
/// these extensions is not used for opening.
static BUILTIN_FORMATS: &[WeightFormatRegistration] = &[
    WeightFormatRegistration {
        id: "safetensors",
        extensions: &["safetensors"],
        open: open_safetensors,
    },
    WeightFormatRegistration {
        id: "gguf",
        extensions: &["gguf"],
        open: open_gguf,
    },
];
63
/// Formats added at runtime through `register_weight_format`, in registration order.
/// `formats()` appends them after `BUILTIN_FORMATS`.
static CUSTOM_FORMATS: Mutex<Vec<WeightFormatRegistration>> = Mutex::new(Vec::new());
/// NOTE(review): only touched by `formats()` via `get_or_init(|| ())`, which has no
/// observable effect; looks like a placeholder for one-time setup — confirm or remove.
static REGISTRY_INIT: OnceLock<()> = OnceLock::new();
66
67pub fn register_weight_format(reg: WeightFormatRegistration) {
70 CUSTOM_FORMATS
71 .lock()
72 .expect("weight format registry lock")
73 .push(reg);
74}
75
76fn formats() -> Vec<WeightFormatRegistration> {
77 REGISTRY_INIT.get_or_init(|| ());
78 let mut out: Vec<WeightFormatRegistration> = BUILTIN_FORMATS.to_vec();
79 let custom = CUSTOM_FORMATS.lock().expect("weight format registry lock");
80 out.extend(custom.iter().copied());
81 out
82}
83
/// Public, opener-free view of a registered weight format.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RegisteredFormat {
    /// Format identifier (e.g. `"gguf"`).
    pub id: &'static str,
    /// Extensions claimed by the format, without the leading dot.
    pub extensions: &'static [&'static str],
}
90
91pub fn list_registered_formats() -> Vec<RegisteredFormat> {
93 formats()
94 .into_iter()
95 .map(|r| RegisteredFormat {
96 id: r.id,
97 extensions: r.extensions,
98 })
99 .collect()
100}
101
102pub fn registered_extensions_hint() -> String {
104 let mut exts: Vec<&str> = Vec::new();
105 for reg in list_registered_formats() {
106 for e in reg.extensions {
107 if !exts.contains(e) {
108 exts.push(e);
109 }
110 }
111 }
112 exts.join(", ")
113}
114
115pub fn format_for_extension(ext: &str) -> Option<&'static str> {
117 let ext = ext.to_ascii_lowercase();
118 let mut found = None;
119 for reg in formats() {
120 if reg.extensions.iter().any(|e| *e == ext) {
121 found = Some(reg.id);
122 }
123 }
124 found
125}
126
127pub fn open_weight_loader(path: &Path) -> Result<Box<dyn WeightLoader>> {
129 let ext = path.extension().and_then(|s| s.to_str()).unwrap_or("");
130 for reg in formats() {
131 if reg.extensions.contains(&ext) {
132 return (reg.open)(path)
133 .with_context(|| format!("opening {:?} as format {}", path, reg.id));
134 }
135 }
136 let known = registered_extensions_hint();
137 Err(anyhow!(
138 "unsupported weight extension `.{ext}` for {path:?}\n\
139 Registered extensions: .{known}\n\
140 Register a custom loader before the first open:\n\
141 use rlx_core::weights::WeightFormatRegistration;\n\
142 WeightFormatRegistration::new(\"myfmt\", &[\"mybin\"], my_open).register();\n\
143 Docs: rlx_core::weights module, README → GGUF → Custom formats"
144 ))
145}
146
/// Options for `load_weights_resolved` / `load_weight_map_resolved`.
///
/// `Default` yields loader mode (`into_map == false`) with default resolve
/// options and drain policy.
#[derive(Debug, Clone, Default)]
pub struct LoadWeightsOptions<'a> {
    /// Passed to `resolve_weights_file_with_options` to pick the concrete weights file.
    pub resolve: ResolveWeightsOptions<'a>,
    /// Policy handed to `WeightMap::drain_loader`; only consulted in map mode.
    pub drain: WeightDrainPolicy,
    /// When true, the opened loader is drained into a `WeightMap`; otherwise the
    /// loader itself is returned.
    pub into_map: bool,
}
156
157impl<'a> LoadWeightsOptions<'a> {
158 pub fn map() -> Self {
159 Self {
160 into_map: true,
161 ..Default::default()
162 }
163 }
164
165 pub fn loader() -> Self {
166 Self {
167 into_map: false,
168 ..Default::default()
169 }
170 }
171
172 pub fn prefer_q4_k_m(self) -> Self {
173 self.prefer_substring("Q4_K_M")
174 }
175
176 pub fn prefer_substring(mut self, sub: &'a str) -> Self {
177 self.resolve.prefer_gguf_substring = Some(sub);
178 self
179 }
180
181 pub fn gguf_index(mut self, idx: usize) -> Self {
182 self.resolve.gguf_index = Some(idx);
183 self
184 }
185
186 pub fn drain(mut self, policy: WeightDrainPolicy) -> Self {
187 self.drain = policy;
188 self
189 }
190
191 pub fn warn_unused(self) -> Self {
192 self.drain(WeightDrainPolicy::AllF32WarnUnused)
193 }
194}
195
/// Outcome of `load_weights_resolved`: either the opened loader or the
/// `WeightMap` it was drained into.
pub enum LoadedWeights {
    /// Loader mode (`LoadWeightsOptions::into_map == false`).
    Loader {
        /// Resolved weights file that was opened.
        path: PathBuf,
        /// Id of the registered format matched by the file extension.
        format_id: &'static str,
        /// The opened, un-drained loader.
        loader: Box<dyn WeightLoader>,
    },
    /// Map mode: the loader was drained with `WeightMap::drain_loader`.
    Map {
        /// Resolved weights file that was opened.
        path: PathBuf,
        /// Id of the registered format matched by the file extension.
        format_id: &'static str,
        /// Tensors drained into the map.
        map: WeightMap,
        /// Tensors `drain_loader` returned separately from `map`.
        /// NOTE(review): presumably weights kept in packed/quantized form — confirm.
        packed: Vec<crate::weight_map::NamedPackedWeight>,
    },
}
212
213impl LoadedWeights {
214 pub fn as_map(&self) -> Option<&WeightMap> {
216 match self {
217 Self::Map { map, .. } => Some(map),
218 Self::Loader { .. } => None,
219 }
220 }
221
222 pub fn as_loader_mut(&mut self) -> Option<&mut dyn WeightLoader> {
224 match self {
225 Self::Loader { loader, .. } => Some(loader.as_mut()),
226 Self::Map { .. } => None,
227 }
228 }
229
230 pub fn packed_tensors(&self) -> Option<&[crate::weight_map::NamedPackedWeight]> {
232 match self {
233 Self::Map { packed, .. } => Some(packed.as_slice()),
234 Self::Loader { .. } => None,
235 }
236 }
237
238 pub fn path(&self) -> &Path {
239 match self {
240 Self::Loader { path, .. } | Self::Map { path, .. } => path,
241 }
242 }
243
244 pub fn format_id(&self) -> &'static str {
245 match self {
246 Self::Loader { format_id, .. } | Self::Map { format_id, .. } => format_id,
247 }
248 }
249
250 pub fn into_map(self) -> Result<WeightMap> {
251 match self {
252 Self::Map { map, packed, .. } => {
253 if !packed.is_empty() {
254 anyhow::bail!(
255 "into_map: {} packed tensors were not merged (use Loader path for packed mode)",
256 packed.len()
257 );
258 }
259 Ok(map)
260 }
261 Self::Loader { mut loader, .. } => Ok(WeightMap::from_weight_loader(loader.as_mut())?),
262 }
263 }
264}
265
266pub fn load_weights_resolved(path: &Path, opts: LoadWeightsOptions<'_>) -> Result<LoadedWeights> {
268 let file = resolve_weights_file_with_options(path, &opts.resolve)?;
269 let ext = file.extension().and_then(|s| s.to_str()).unwrap_or("");
271 let format_id = format_for_extension(ext)
272 .ok_or_else(|| anyhow!("no registered loader for extension `.{ext}`"))?;
273 let mut loader = open_weight_loader(&file)?;
274 if opts.into_map {
275 let (map, packed) = WeightMap::drain_loader(loader.as_mut(), opts.drain)?;
276 Ok(LoadedWeights::Map {
277 path: file,
278 format_id,
279 map,
280 packed,
281 })
282 } else {
283 Ok(LoadedWeights::Loader {
284 path: file,
285 format_id,
286 loader,
287 })
288 }
289}
290
291pub fn load_weight_map_resolved(
293 path: &Path,
294 opts: LoadWeightsOptions<'_>,
295) -> Result<(PathBuf, WeightMap)> {
296 let mut o = opts;
297 o.into_map = true;
298 let loaded = load_weights_resolved(path, o)?;
299 Ok((loaded.path().to_path_buf(), loaded.into_map()?))
300}
301
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn unknown_extension_errors() {
        let path = std::env::temp_dir().join("rlx_weight_registry_test.noext");
        let Err(e) = open_weight_loader(&path) else {
            panic!("expected unsupported extension for {path:?}");
        };
        assert!(e.to_string().contains("unsupported weight extension"), "{e}");
    }

    #[test]
    fn format_for_extension_builtin() {
        assert_eq!(format_for_extension("gguf"), Some("gguf"));
        assert_eq!(format_for_extension("safetensors"), Some("safetensors"));
        assert_eq!(format_for_extension("bin"), None);
    }

    #[test]
    fn format_for_extension_ignores_ascii_case() {
        assert_eq!(format_for_extension("GGUF"), Some("gguf"));
        assert_eq!(format_for_extension("SafeTensors"), Some("safetensors"));
    }

    #[test]
    fn list_formats_includes_builtins() {
        let ids: Vec<_> = list_registered_formats().iter().map(|r| r.id).collect();
        assert!(ids.contains(&"gguf"));
        assert!(ids.contains(&"safetensors"));
    }

    #[test]
    fn extensions_hint_lists_builtins() {
        let hint = registered_extensions_hint();
        assert!(hint.contains("gguf"), "{hint}");
        assert!(hint.contains("safetensors"), "{hint}");
    }
}