use super::core::VarMap;
use crate::error::{Error, Result};
use crate::format::gguf::Gguf;
use crate::format::safetensors::SafeTensors;
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::tensor::Tensor;
use std::collections::HashMap;
use std::path::Path;
impl<R: Runtime<DType = DType>> VarMap<R> {
pub fn from_safetensors<P: AsRef<Path>>(path: P, device: &R::Device) -> Result<Self> {
Self::from_safetensors_with_model_type(path, device, None)
}
pub fn from_safetensors_with_model_type<P: AsRef<Path>>(
path: P,
device: &R::Device,
model_type: Option<&str>,
) -> Result<Self> {
use crate::format::safetensors_name_map::normalize_hf_name;
let mut st = SafeTensors::open(path)?;
let all = st.load_all::<R>(device)?;
let mut map = Self::new();
for (name, tensor) in all {
let mapped = match model_type {
Some(mt) => normalize_hf_name(mt, &name),
None => name,
};
map.insert(mapped, tensor);
}
Ok(map)
}
pub fn from_safetensors_sharded<P: AsRef<Path>>(dir: P, device: &R::Device) -> Result<Self> {
Self::from_safetensors_sharded_with_model_type(dir, device, None)
}
pub fn from_safetensors_sharded_with_model_type<P: AsRef<Path>>(
dir: P,
device: &R::Device,
model_type: Option<&str>,
) -> Result<Self> {
let dir = dir.as_ref();
let index_path = dir.join("model.safetensors.index.json");
let index_str = std::fs::read_to_string(&index_path).map_err(|e| Error::ModelError {
reason: format!("failed to read index file: {e}"),
})?;
let index: serde_json::Value =
serde_json::from_str(&index_str).map_err(|e| Error::ModelError {
reason: format!("failed to parse index JSON: {e}"),
})?;
let weight_map = index
.get("weight_map")
.and_then(|v| v.as_object())
.ok_or_else(|| Error::ModelError {
reason: "index.json missing 'weight_map' object".into(),
})?;
let mut shard_to_names: HashMap<String, Vec<String>> = HashMap::new();
for (tensor_name, shard_val) in weight_map {
let shard_file = shard_val.as_str().ok_or_else(|| Error::ModelError {
reason: format!("weight_map value for '{tensor_name}' is not a string"),
})?;
shard_to_names
.entry(shard_file.to_string())
.or_default()
.push(tensor_name.clone());
}
let mut map = Self::new();
for (shard_file, names) in &shard_to_names {
let shard_path = dir.join(shard_file);
let mut st = SafeTensors::open(&shard_path)?;
for name in names {
let tensor = st.load_tensor::<R>(name, device)?;
let mapped = match model_type {
Some(mt) => crate::format::safetensors_name_map::normalize_hf_name(mt, name),
None => name.clone(),
};
map.insert(mapped, tensor);
}
}
Ok(map)
}
pub fn from_gguf<P: AsRef<Path>>(path: P, device: &R::Device) -> Result<Self>
where
R::Client: numr::ops::ShapeOps<R>,
{
use crate::format::gguf::gguf_to_hf_name;
let mut gguf = Gguf::open(path)?;
let names: Vec<String> = gguf.tensor_names().map(|s| s.to_string()).collect();
let mut map = Self::new();
for name in &names {
let hf_name = gguf_to_hf_name(name);
let info = gguf.tensor_info(name)?.clone();
if info.ggml_type.is_quantized() {
let qt = gguf.load_tensor_quantized::<R>(name, device)?;
map.insert_quant(hf_name, qt);
} else {
let t = gguf.load_tensor_f32::<R>(name, device)?;
map.insert(hf_name, t);
}
}
Self::stack_moe_experts(&mut map, device)?;
Ok(map)
}
fn stack_moe_experts(map: &mut Self, _device: &R::Device) -> Result<()>
where
R::Client: numr::ops::ShapeOps<R>,
{
use std::collections::BTreeMap;
let mut groups: HashMap<String, BTreeMap<usize, String>> = HashMap::new();
let all_names: Vec<String> = map.names().map(|s| s.to_string()).collect();
for name in &all_names {
if let Some(experts_pos) = name.find(".experts.") {
let after_experts = &name[experts_pos + ".experts.".len()..];
if let Some(dot_pos) = after_experts.find('.') {
let id_str = &after_experts[..dot_pos];
if let Ok(expert_id) = id_str.parse::<usize>() {
let prefix = &name[..experts_pos];
let suffix = &after_experts[dot_pos + 1..];
let group_key = format!("{prefix}.experts.{suffix}");
groups
.entry(group_key)
.or_default()
.insert(expert_id, name.clone());
}
}
}
}
for (stacked_name, expert_entries) in &groups {
if expert_entries.len() < 2 {
continue;
}
let mut tensors: Vec<Tensor<R>> = Vec::with_capacity(expert_entries.len());
let mut all_standard = true;
for name in expert_entries.values() {
match map.get(name) {
Ok(w) if !w.is_quantized() => {
if let Ok(t) = w.as_tensor() {
tensors.push(t.clone());
} else {
all_standard = false;
break;
}
}
_ => {
all_standard = false;
break;
}
}
}
if !all_standard || tensors.is_empty() {
continue;
}
let tensor_refs: Vec<&Tensor<R>> = tensors.iter().collect();
let stacked = Tensor::<R>::stack(&tensor_refs, 0).map_err(|e| {
crate::error::Error::ModelError {
reason: format!("Failed to stack expert tensors for {stacked_name}: {e}"),
}
})?;
for name in expert_entries.values() {
map.remove(name);
}
map.insert(stacked_name.clone(), stacked);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::format::safetensors::save_safetensors;
    use crate::quant::QuantFormat;
    use numr::runtime::cpu::{CpuDevice, CpuRuntime};
    use numr::tensor::Tensor;

    /// Device used by every test in this module.
    fn device() -> CpuDevice {
        CpuDevice::new()
    }

    /// Asserts that `actual` and `expected` have the same length and element-wise agree
    /// within `1e-6`.
    fn assert_close(actual: &[f32], expected: &[f32]) {
        assert_eq!(actual.len(), expected.len(), "length mismatch: {actual:?} vs {expected:?}");
        for (a, e) in actual.iter().zip(expected) {
            assert!((a - e).abs() < 1e-6, "{actual:?} != {expected:?}");
        }
    }

    #[test]
    fn test_varmap_from_safetensors_roundtrip() {
        let d = device();
        let tmp = tempfile::NamedTempFile::new().unwrap();
        let mut tensors = HashMap::new();
        tensors.insert(
            "a".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0, 3.0], &[3], &d),
        );
        tensors.insert(
            "b".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[4.0f32, 5.0], &[2], &d),
        );
        save_safetensors(tmp.path(), &tensors, None).unwrap();

        let map = VarMap::<CpuRuntime>::from_safetensors(tmp.path(), &d).unwrap();
        assert_eq!(map.len(), 2);

        let a = map.get_tensor("a").unwrap();
        assert_eq!(a.shape(), &[3]);
        assert_close(&a.to_vec::<f32>(), &[1.0, 2.0, 3.0]);

        let b = map.get_tensor("b").unwrap();
        assert_eq!(b.shape(), &[2]);
        assert_close(&b.to_vec::<f32>(), &[4.0, 5.0]);
    }

    #[test]
    fn test_varmap_from_gguf() {
        let d = device();
        let tmp = create_test_gguf_file();
        let map = VarMap::<CpuRuntime>::from_gguf(tmp.path(), &d).unwrap();
        assert_eq!(map.len(), 2);

        let f32_w = map.get("weight_f32").unwrap();
        assert!(!f32_w.is_quantized());
        let t = f32_w.as_tensor().unwrap();
        assert_eq!(t.shape(), &[4]);
        assert_close(&t.to_vec::<f32>(), &[1.0, 2.0, 3.0, 4.0]);

        let q4_w = map.get("weight_q4").unwrap();
        assert!(q4_w.is_quantized());
        let qt = q4_w.as_quant_tensor().unwrap();
        assert_eq!(qt.shape(), &[32]);
        assert_eq!(qt.format(), QuantFormat::Q4_0);
    }

    #[test]
    fn test_varmap_from_safetensors_sharded() {
        let d = device();
        let dir = tempfile::TempDir::new().unwrap();

        let mut shard1 = HashMap::new();
        shard1.insert(
            "layers.0.weight".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &d),
        );
        save_safetensors(
            dir.path().join("model-00001-of-00002.safetensors"),
            &shard1,
            None,
        )
        .unwrap();

        let mut shard2 = HashMap::new();
        shard2.insert(
            "layers.1.weight".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0, 5.0], &[3], &d),
        );
        save_safetensors(
            dir.path().join("model-00002-of-00002.safetensors"),
            &shard2,
            None,
        )
        .unwrap();

        let index = serde_json::json!({
            "metadata": {"total_size": 20},
            "weight_map": {
                "layers.0.weight": "model-00001-of-00002.safetensors",
                "layers.1.weight": "model-00002-of-00002.safetensors"
            }
        });
        std::fs::write(
            dir.path().join("model.safetensors.index.json"),
            serde_json::to_string_pretty(&index).unwrap(),
        )
        .unwrap();

        let map = VarMap::<CpuRuntime>::from_safetensors_sharded(dir.path(), &d).unwrap();
        assert_eq!(map.len(), 2);

        let t0 = map.get_tensor("layers.0.weight").unwrap();
        assert_eq!(t0.shape(), &[2]);
        assert_close(&t0.to_vec::<f32>(), &[1.0, 2.0]);

        let t1 = map.get_tensor("layers.1.weight").unwrap();
        assert_eq!(t1.shape(), &[3]);
        assert_close(&t1.to_vec::<f32>(), &[3.0, 4.0, 5.0]);
    }

    #[test]
    fn test_varmap_from_safetensors_sharded_missing_index_errors() {
        let d = device();
        let dir = tempfile::TempDir::new().unwrap();
        assert!(VarMap::<CpuRuntime>::from_safetensors_sharded(dir.path(), &d).is_err());
    }

    #[test]
    fn test_stack_moe_experts_stacks_contiguous_experts() {
        let d = device();
        let mut map = VarMap::<CpuRuntime>::new();
        map.insert(
            "blk.0.ffn.experts.0.weight".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[1.0f32, 2.0], &[2], &d),
        );
        map.insert(
            "blk.0.ffn.experts.1.weight".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[3.0f32, 4.0], &[2], &d),
        );
        map.insert(
            "blk.0.norm.weight".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[5.0f32], &[1], &d),
        );

        VarMap::<CpuRuntime>::stack_moe_experts(&mut map, &d).unwrap();

        // The two expert entries collapse into one stacked entry; unrelated names survive.
        assert_eq!(map.len(), 2);
        assert!(map.get("blk.0.ffn.experts.0.weight").is_err());
        assert!(map.get("blk.0.ffn.experts.1.weight").is_err());
        let stacked = map.get_tensor("blk.0.ffn.experts.weight").unwrap();
        assert_eq!(stacked.shape(), &[2, 2]);
        assert_close(&stacked.to_vec::<f32>(), &[1.0, 2.0, 3.0, 4.0]);
        let norm = map.get_tensor("blk.0.norm.weight").unwrap();
        assert_eq!(norm.shape(), &[1]);
    }

    /// Builds a minimal GGUF v3 file containing two tensors:
    /// - `weight_f32`: F32, shape `[4]`, values `1.0, 2.0, 3.0, 4.0`, at data offset 0;
    /// - `weight_q4`: Q4_0, shape `[32]`, one block (an f16 scale of `1.0` followed by
    ///   16 bytes of `0x88`), at data offset 16.
    ///
    /// The header carries one metadata entry (`general.architecture = "test"`). The tensor
    /// data section starts at the next 32-byte boundary after the header.
    fn create_test_gguf_file() -> tempfile::NamedTempFile {
        use crate::format::gguf::types::{GgmlType, GgufValueType};
        use std::io::Write;

        let mut buf = Vec::new();
        // Header: magic, version, tensor count, metadata key/value count.
        let gguf_magic: u32 = 0x46554747;
        buf.extend_from_slice(&gguf_magic.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&2u64.to_le_bytes());
        buf.extend_from_slice(&1u64.to_le_bytes());

        // Metadata: key, value type tag, value.
        write_gguf_str(&mut buf, "general.architecture");
        buf.extend_from_slice(&(GgufValueType::String as u32).to_le_bytes());
        write_gguf_str(&mut buf, "test");

        // Tensor infos: name, n_dims, dims, ggml type, offset into the data section.
        write_gguf_str(&mut buf, "weight_f32");
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&4u64.to_le_bytes());
        buf.extend_from_slice(&(GgmlType::F32 as u32).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());

        write_gguf_str(&mut buf, "weight_q4");
        buf.extend_from_slice(&1u32.to_le_bytes());
        buf.extend_from_slice(&32u64.to_le_bytes());
        buf.extend_from_slice(&(GgmlType::Q4_0 as u32).to_le_bytes());
        buf.extend_from_slice(&16u64.to_le_bytes());

        // Pad so the data section begins on a 32-byte boundary.
        let aligned = buf.len().div_ceil(32) * 32;
        buf.resize(aligned, 0);

        // weight_f32 data (16 bytes).
        for f in [1.0f32, 2.0, 3.0, 4.0] {
            buf.extend_from_slice(&f.to_le_bytes());
        }

        // weight_q4 data: f16 scale (little-endian) followed by 16 packed bytes.
        let scale_bits = half::f16::from_f32(1.0).to_bits();
        buf.push((scale_bits & 0xFF) as u8);
        buf.push(((scale_bits >> 8) & 0xFF) as u8);
        buf.extend(std::iter::repeat_n(0x88u8, 16));

        let mut file = tempfile::NamedTempFile::new().unwrap();
        file.write_all(&buf).unwrap();
        file.flush().unwrap();
        file
    }

    /// Appends a GGUF string: its byte length as a little-endian `u64`, then the raw bytes.
    fn write_gguf_str(buf: &mut Vec<u8>, s: &str) {
        buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
        buf.extend_from_slice(s.as_bytes());
    }
}