use anyhow::{Context, Result, anyhow};
use std::path::{Path, PathBuf};
use std::sync::{Mutex, OnceLock};
use crate::gguf_support::{ResolveWeightsOptions, resolve_weights_file_with_options};
use crate::weight_loader::{GgufLoader, WeightLoader};
use crate::weight_map::{WeightDrainPolicy, WeightMap};
/// Signature of a function that opens the weights file at the given path and
/// returns a boxed [`WeightLoader`] for it, or an error if it cannot be opened.
pub type WeightLoaderFactory = fn(&Path) -> Result<Box<dyn WeightLoader>>;
/// One entry of the weight-format registry: a format id, the file extensions
/// it claims, and the factory used to open files of that format.
#[derive(Clone, Copy)]
pub struct WeightFormatRegistration {
    /// Format identifier; this is the value reported as `format_id` for
    /// loaded weights and returned by `format_for_extension`.
    pub id: &'static str,
    /// File extensions handled by this format, written without the leading
    /// dot (they are compared against `Path::extension`).
    pub extensions: &'static [&'static str],
    /// Factory invoked with the file path to build the loader.
    pub open: WeightLoaderFactory,
}
/// Built-in factory for `.safetensors` files.
///
/// Opens the file through [`WeightMap::from_file`], which takes a `&str`
/// path, and returns the resulting map boxed as a [`WeightLoader`].
///
/// # Errors
/// Fails when `path` is not valid UTF-8 or when `WeightMap::from_file`
/// returns an error.
fn open_safetensors(path: &Path) -> Result<Box<dyn WeightLoader>> {
    let Some(utf8_path) = path.to_str() else {
        return Err(anyhow!("non-utf8 path {:?}", path));
    };
    let weights = WeightMap::from_file(utf8_path)?;
    Ok(Box::new(weights))
}
/// Built-in factory for `.gguf` files.
///
/// Opens the file through [`GgufLoader::from_file`], which takes a `&str`
/// path, and returns the loader boxed as a [`WeightLoader`].
///
/// # Errors
/// Fails when `path` is not valid UTF-8 or when `GgufLoader::from_file`
/// returns an error.
fn open_gguf(path: &Path) -> Result<Box<dyn WeightLoader>> {
    let Some(utf8_path) = path.to_str() else {
        return Err(anyhow!("non-utf8 path {:?}", path));
    };
    let gguf = GgufLoader::from_file(utf8_path)?;
    Ok(Box::new(gguf))
}
/// Formats that are always available. `formats()` lists these before any
/// custom registration, so they are consulted first during lookup.
static BUILTIN_FORMATS: &[WeightFormatRegistration] = &[
    // safetensors files, opened via `WeightMap::from_file`.
    WeightFormatRegistration {
        id: "safetensors",
        extensions: &["safetensors"],
        open: open_safetensors,
    },
    // GGUF files, opened via `GgufLoader::from_file`.
    WeightFormatRegistration {
        id: "gguf",
        extensions: &["gguf"],
        open: open_gguf,
    },
];
/// Formats added at runtime through `register_weight_format`, kept in
/// registration order behind a mutex.
static CUSTOM_FORMATS: Mutex<Vec<WeightFormatRegistration>> = Mutex::new(Vec::new());
/// Unit-valued once-cell that `formats()` initialises on every call (a no-op
/// after the first). Nothing in this file reads the stored value.
/// NOTE(review): presumably a placeholder for one-time registry setup — confirm.
static REGISTRY_INIT: OnceLock<()> = OnceLock::new();
/// Appends `reg` to the list of custom weight formats.
///
/// No de-duplication is performed: registering the same format twice stores
/// two entries.
///
/// # Panics
/// Panics if the registry mutex is poisoned.
pub fn register_weight_format(reg: WeightFormatRegistration) {
    let mut registry = CUSTOM_FORMATS
        .lock()
        .expect("weight format registry lock");
    registry.push(reg);
}
/// Returns a snapshot of every known format: the built-ins first, followed by
/// the custom registrations in the order they were added.
///
/// # Panics
/// Panics if the registry mutex is poisoned.
fn formats() -> Vec<WeightFormatRegistration> {
    REGISTRY_INIT.get_or_init(|| ());
    let custom = CUSTOM_FORMATS
        .lock()
        .expect("weight format registry lock");
    let mut merged = Vec::with_capacity(BUILTIN_FORMATS.len() + custom.len());
    merged.extend_from_slice(BUILTIN_FORMATS);
    merged.extend_from_slice(&custom);
    merged
}
/// Read-only description of a registered format: the same `id` and
/// `extensions` as a [`WeightFormatRegistration`], without the loader factory.
#[derive(Debug, Clone, Copy)]
pub struct RegisteredFormat {
    /// Format identifier.
    pub id: &'static str,
    /// File extensions claimed by the format.
    pub extensions: &'static [&'static str],
}
pub fn list_registered_formats() -> Vec<RegisteredFormat> {
formats()
.into_iter()
.map(|r| RegisteredFormat {
id: r.id,
extensions: r.extensions,
})
.collect()
}
/// Builds a human-readable list of every registered extension.
///
/// Extensions appear in registration order, each one only once, separated by
/// `", "` and without a leading dot (e.g. `safetensors, gguf`).
pub fn registered_extensions_hint() -> String {
    let registered = list_registered_formats();
    let mut unique: Vec<&str> = Vec::new();
    for ext in registered
        .iter()
        .flat_map(|reg| reg.extensions.iter().copied())
    {
        // Keep the first occurrence so the output order follows registration order.
        if !unique.contains(&ext) {
            unique.push(ext);
        }
    }
    unique.join(", ")
}
/// Returns the id of the registered format that handles `ext`, or `None` when
/// no registration claims it.
///
/// `ext` is given without the leading dot and is compared
/// ASCII-case-insensitively, so `"GGUF"` and `"gguf"` resolve to the same
/// format.
///
/// When several registrations claim the same extension, the first one in
/// `formats()` order wins (built-ins, then custom formats in registration
/// order). `open_weight_loader` dispatches to that same first match, so the
/// id returned here names the loader that actually opens the file.
pub fn format_for_extension(ext: &str) -> Option<&'static str> {
    formats()
        .into_iter()
        .find(|reg| reg.extensions.iter().any(|e| e.eq_ignore_ascii_case(ext)))
        .map(|reg| reg.id)
}
/// Opens `path` with the loader registered for its file extension.
///
/// The extension is matched ASCII-case-insensitively against every
/// registration in `formats()` order, and the first match is used. This is
/// the same case handling as `format_for_extension`, so a file such as
/// `model.GGUF` is opened by the gguf loader instead of being rejected.
///
/// # Errors
/// * The factory's own error, with context naming the path and format id.
/// * An "unsupported weight extension" error listing the registered
///   extensions when nothing matches (including paths with no extension or a
///   non-UTF-8 extension, which are treated as the empty extension).
pub fn open_weight_loader(path: &Path) -> Result<Box<dyn WeightLoader>> {
    let ext = path.extension().and_then(|s| s.to_str()).unwrap_or("");
    let registration = formats()
        .into_iter()
        .find(|reg| reg.extensions.iter().any(|e| e.eq_ignore_ascii_case(ext)));
    if let Some(reg) = registration {
        return (reg.open)(path)
            .with_context(|| format!("opening {:?} as format {}", path, reg.id));
    }
    // The hint is "a, b, c" without dots; prefix every entry so the message
    // reads ".a, .b, .c" rather than ".a, b, c".
    let known = registered_extensions_hint()
        .split(", ")
        .map(|e| format!(".{e}"))
        .collect::<Vec<_>>()
        .join(", ");
    // NOTE(review): the usage hint below mentions
    // `WeightFormatRegistration::new(..).register()`, which is not defined in
    // this file (only `register_weight_format` is) — confirm it exists elsewhere.
    Err(anyhow!(
        "unsupported weight extension `.{ext}` for {path:?}\n\
         Registered extensions: {known}\n\
         Register a custom loader before the first open:\n\
         use rlx_core::weights::WeightFormatRegistration;\n\
         WeightFormatRegistration::new(\"myfmt\", &[\"mybin\"], my_open).register();\n\
         Docs: rlx_core::weights module, README → GGUF → Custom formats"
    ))
}
/// Options accepted by `load_weights_resolved` and `load_weight_map_resolved`.
#[derive(Debug, Clone, Default)]
pub struct LoadWeightsOptions<'a> {
    /// Passed to `resolve_weights_file_with_options` to pick the concrete
    /// weights file for the requested path.
    pub resolve: ResolveWeightsOptions<'a>,
    /// Policy handed to `WeightMap::drain_loader`; only used when `into_map`
    /// is set.
    pub drain: WeightDrainPolicy,
    /// `true` drains the loader into a `WeightMap` (`LoadedWeights::Map`);
    /// `false` returns the opened loader as is (`LoadedWeights::Loader`).
    pub into_map: bool,
}
/// Constructors and chainable setters for [`LoadWeightsOptions`].
impl<'a> LoadWeightsOptions<'a> {
    /// Default options with `into_map` enabled.
    pub fn map() -> Self {
        Self {
            into_map: true,
            ..Self::default()
        }
    }

    /// Default options with `into_map` disabled.
    pub fn loader() -> Self {
        Self {
            into_map: false,
            ..Self::default()
        }
    }

    /// Shorthand for `prefer_substring("Q4_K_M")`.
    pub fn prefer_q4_k_m(self) -> Self {
        const Q4_K_M: &str = "Q4_K_M";
        self.prefer_substring(Q4_K_M)
    }

    /// Sets `resolve.prefer_gguf_substring` to `sub`.
    pub fn prefer_substring(mut self, sub: &'a str) -> Self {
        self.resolve.prefer_gguf_substring = Some(sub);
        self
    }

    /// Sets `resolve.gguf_index` to `idx`.
    pub fn gguf_index(mut self, idx: usize) -> Self {
        self.resolve.gguf_index = Some(idx);
        self
    }

    /// Replaces the drain policy.
    pub fn drain(mut self, policy: WeightDrainPolicy) -> Self {
        self.drain = policy;
        self
    }

    /// Selects the `WeightDrainPolicy::AllF32WarnUnused` drain policy.
    pub fn warn_unused(mut self) -> Self {
        self.drain = WeightDrainPolicy::AllF32WarnUnused;
        self
    }
}
/// Result of `load_weights_resolved`: either the opened loader itself or the
/// weights already drained into a map.
pub enum LoadedWeights {
    /// The loader is returned undrained (`into_map == false`).
    Loader {
        /// Resolved weights file that was opened.
        path: PathBuf,
        /// Id of the registered format matched for the file's extension.
        format_id: &'static str,
        /// The opened loader.
        loader: Box<dyn WeightLoader>,
    },
    /// The loader was drained with `WeightMap::drain_loader` (`into_map == true`).
    Map {
        /// Resolved weights file that was opened.
        path: PathBuf,
        /// Id of the registered format matched for the file's extension.
        format_id: &'static str,
        /// Map returned by `WeightMap::drain_loader`.
        map: WeightMap,
        /// Packed tensors returned alongside the map by `WeightMap::drain_loader`.
        packed: Vec<crate::weight_map::NamedPackedWeight>,
    },
}
/// Accessors and conversions for [`LoadedWeights`].
impl LoadedWeights {
    /// Borrows the weight map, or `None` for the `Loader` variant.
    pub fn as_map(&self) -> Option<&WeightMap> {
        if let Self::Map { map, .. } = self {
            Some(map)
        } else {
            None
        }
    }

    /// Mutably borrows the loader, or `None` for the `Map` variant.
    pub fn as_loader_mut(&mut self) -> Option<&mut dyn WeightLoader> {
        match self {
            Self::Map { .. } => None,
            Self::Loader { loader, .. } => Some(loader.as_mut()),
        }
    }

    /// Borrows the packed tensors, or `None` for the `Loader` variant.
    pub fn packed_tensors(&self) -> Option<&[crate::weight_map::NamedPackedWeight]> {
        if let Self::Map { packed, .. } = self {
            Some(packed.as_slice())
        } else {
            None
        }
    }

    /// Path of the weights file that was opened.
    pub fn path(&self) -> &Path {
        let (Self::Loader { path, .. } | Self::Map { path, .. }) = self;
        path
    }

    /// Id of the format the file was matched to.
    pub fn format_id(&self) -> &'static str {
        let (Self::Loader { format_id, .. } | Self::Map { format_id, .. }) = self;
        format_id
    }

    /// Consumes `self` and returns a [`WeightMap`].
    ///
    /// The `Loader` variant is converted with `WeightMap::from_weight_loader`.
    ///
    /// # Errors
    /// * The `Map` variant still holds packed tensors.
    /// * `WeightMap::from_weight_loader` fails for the `Loader` variant.
    pub fn into_map(self) -> Result<WeightMap> {
        let (map, packed) = match self {
            Self::Loader { mut loader, .. } => {
                return Ok(WeightMap::from_weight_loader(loader.as_mut())?);
            }
            Self::Map { map, packed, .. } => (map, packed),
        };
        // A plain map cannot carry packed tensors, so refuse to drop them silently.
        if !packed.is_empty() {
            anyhow::bail!(
                "into_map: {} packed tensors were not merged (use Loader path for packed mode)",
                packed.len()
            );
        }
        Ok(map)
    }
}
/// Resolves `path` to a concrete weights file, opens it with the registered
/// loader for its extension, and returns it as [`LoadedWeights`].
///
/// With `opts.into_map` set, the loader is drained via
/// `WeightMap::drain_loader` using `opts.drain` and the `Map` variant is
/// returned; otherwise the opened loader is returned in the `Loader` variant.
///
/// # Errors
/// Fails when the file cannot be resolved, when no format is registered for
/// its extension, when the loader cannot be opened, or when draining fails.
pub fn load_weights_resolved(path: &Path, opts: LoadWeightsOptions<'_>) -> Result<LoadedWeights> {
    let file = resolve_weights_file_with_options(path, &opts.resolve)?;
    // A missing or non-UTF-8 extension is looked up as the empty string.
    let ext = file.extension().and_then(|s| s.to_str()).unwrap_or("");
    let Some(format_id) = format_for_extension(ext) else {
        return Err(anyhow!("no registered loader for extension `.{ext}`"));
    };
    let mut loader = open_weight_loader(&file)?;
    if !opts.into_map {
        return Ok(LoadedWeights::Loader {
            path: file,
            format_id,
            loader,
        });
    }
    let (map, packed) = WeightMap::drain_loader(loader.as_mut(), opts.drain)?;
    Ok(LoadedWeights::Map {
        path: file,
        format_id,
        map,
        packed,
    })
}
pub fn load_weight_map_resolved(
path: &Path,
opts: LoadWeightsOptions<'_>,
) -> Result<(PathBuf, WeightMap)> {
let mut o = opts;
o.into_map = true;
let loaded = load_weights_resolved(path, o)?;
Ok((loaded.path().to_path_buf(), loaded.into_map()?))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A path whose extension no format claims is rejected with the
    /// "unsupported weight extension" error.
    #[test]
    fn unknown_extension_errors() {
        let path = std::env::temp_dir().join("rlx_weight_registry_test.noext");
        let Err(e) = open_weight_loader(&path) else {
            panic!("expected unsupported extension for {path:?}");
        };
        assert!(
            e.to_string().contains("unsupported weight extension"),
            "{e}"
        );
    }

    /// Built-in extensions map to their format ids; an unregistered one maps to `None`.
    #[test]
    fn format_for_extension_builtin() {
        let cases = [
            ("gguf", Some("gguf")),
            ("safetensors", Some("safetensors")),
            ("bin", None),
        ];
        for (ext, expected) in cases {
            assert_eq!(format_for_extension(ext), expected);
        }
    }

    /// Both built-in formats show up in the public listing.
    #[test]
    fn list_formats_includes_builtins() {
        let listed = list_registered_formats();
        for builtin in ["gguf", "safetensors"] {
            assert!(listed.iter().any(|r| r.id == builtin));
        }
    }
}