use std::collections::HashMap;
use std::io;
use std::path::Path;
use crate::config::HFConfig;
use crate::ir::{Architecture, DType, ModelConfig};
use crate::safetensors::{SafeTensorsError, SafeTensorsFile};
use crate::weight_loader::ModelWeights;
/// Errors produced while loading a SafeTensors checkpoint and deriving its model configuration.
#[derive(Debug, thiserror::Error)]
pub enum SafeTensorsLoadError {
/// An I/O failure. Besides wrapping `io::Error` values via `From`, the loader in this file
/// also uses this variant to report tensor data that extends past the end of the file and
/// dtype conversion failures.
#[error("I/O error: {0}")]
Io(#[from] io::Error),
/// The SafeTensors parser rejected the file.
#[error("SafeTensors parse error: {0}")]
Parse(#[from] SafeTensorsError),
/// The model configuration could not be derived from the tensor names/shapes; the payload
/// is a human-readable reason.
#[error("could not infer model config from tensor shapes: {0}")]
ConfigInference(String),
}
/// Loads a SafeTensors checkpoint from `path` and converts every tensor to `f32`.
///
/// The file is memory-mapped, its header is parsed, and each tensor's byte range is decoded
/// with `bytes_to_f32`. The model configuration is taken from a `config.json` next to the
/// checkpoint when that file exists and can be turned into a `ModelConfig`; otherwise it is
/// inferred from the tensor shapes via [`infer_model_config`].
///
/// # Errors
///
/// * [`SafeTensorsLoadError::Io`] if the checkpoint (or an existing `config.json`) cannot be
///   read, if a tensor's byte range is inverted, overflows, or extends past the end of the
///   file, or if a tensor's dtype cannot be converted to `f32`.
/// * [`SafeTensorsLoadError::Parse`] if the SafeTensors header is rejected by the parser.
/// * [`SafeTensorsLoadError::ConfigInference`] if no usable `config.json` is present and the
///   configuration cannot be inferred from the tensor shapes.
pub fn load_safetensors(
    path: impl AsRef<Path>,
) -> Result<(ModelConfig, ModelWeights), SafeTensorsLoadError> {
    let path = path.as_ref();
    let file = std::fs::File::open(path)?;
    // SAFETY: the mapping is read-only and is only borrowed for the duration of this call.
    // As with any file-backed mapping, soundness relies on the file not being truncated or
    // rewritten by another process while it is mapped.
    let mmap = unsafe { memmap2::Mmap::map(&file)? };
    let data: &[u8] = &mmap;

    let st_file = crate::safetensors::parse(std::io::Cursor::new(data))?;
    let data_offset = st_file.data_offset as usize;

    let mut tensors: HashMap<String, Vec<f32>> = HashMap::with_capacity(st_file.tensors.len());
    for info in &st_file.tensors {
        // The offsets come straight from the (untrusted) header: reject ranges that overflow
        // or run backwards instead of panicking when slicing below.
        let range = data_offset
            .checked_add(info.data_start)
            .zip(data_offset.checked_add(info.data_end))
            .filter(|(start, end)| start <= end);
        let Some((abs_start, abs_end)) = range else {
            return Err(SafeTensorsLoadError::Io(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("tensor {} has invalid data offsets", info.name),
            )));
        };
        if abs_end > data.len() {
            return Err(SafeTensorsLoadError::Io(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                format!("tensor {} extends past end of file", info.name),
            )));
        }

        let raw = &data[abs_start..abs_end];
        let floats = bytes_to_f32(raw, info.dtype).map_err(|e| {
            SafeTensorsLoadError::Io(io::Error::new(
                io::ErrorKind::InvalidData,
                format!("dtype conversion for {}: {e}", info.name),
            ))
        })?;
        tensors.insert(info.name.clone(), floats);
    }

    // Prefer an explicit config.json sitting next to the checkpoint. A config that fails to
    // parse or convert is not an error: we fall back to shape-based inference instead.
    let config = match path.parent().map(|dir| dir.join("config.json")) {
        Some(config_path) if config_path.exists() => {
            let json = std::fs::read(&config_path)?;
            if let Ok(hf_cfg) = HFConfig::from_json(&json) {
                hf_cfg.to_model_config()
            } else {
                None
            }
        }
        _ => None,
    };
    let config = match config {
        Some(c) => c,
        None => infer_model_config(&st_file, &tensors)?,
    };

    Ok((config, ModelWeights { tensors }))
}
/// Derives a [`ModelConfig`] from the tensor names and shapes of a SafeTensors file.
///
/// What is read from the header:
/// * `model.embed_tokens.weight` gives `vocab_size` (dim 0) and `hidden_size` (dim 1).
/// * The highest `model.layers.N.*` index gives `num_layers`.
/// * Layer 0's `q_proj` / `k_proj` output dims give `head_dim`, `num_attention_heads` and
///   `num_kv_heads` (via the `infer_head_dim` heuristic).
/// * Layer 0's `gate_proj` (or `up_proj`) output dim gives `intermediate_size`, falling back
///   to `4 * hidden_size` when neither is present.
/// * The first tensor's dtype gives `dtype` (anything other than F16/BF16 maps to F32).
///
/// The presence of `model.layers.0.self_attn.q_proj.bias` — in either the header or the
/// loaded `tensors` map — selects `Architecture::Qwen2` and sets `qkv_bias`; otherwise the
/// architecture is `Architecture::Llama`.
///
/// Values that cannot be recovered from shapes are hard-coded here: `max_seq_len = 4096`,
/// `rms_norm_eps = 1e-5`, `rope_theta = 10000.0`, no sliding window, SiLU activation.
///
/// # Errors
///
/// Returns [`SafeTensorsLoadError::ConfigInference`] when the embedding or layer-0 `q_proj`
/// weight is missing or has fewer than two dimensions, or when no `model.layers.N.*` tensor
/// exists.
pub fn infer_model_config(
    st_file: &SafeTensorsFile,
    tensors: &HashMap<String, Vec<f32>>,
) -> Result<ModelConfig, SafeTensorsLoadError> {
    const Q_PROJ_BIAS: &str = "model.layers.0.self_attn.q_proj.bias";

    let shape_map: HashMap<&str, &[usize]> = st_file
        .tensors
        .iter()
        .map(|t| (t.name.as_str(), t.shape.as_slice()))
        .collect();

    let embed_shape = shape_map.get("model.embed_tokens.weight").ok_or_else(|| {
        SafeTensorsLoadError::ConfigInference("missing model.embed_tokens.weight".to_string())
    })?;
    if embed_shape.len() < 2 {
        return Err(SafeTensorsLoadError::ConfigInference(
            "model.embed_tokens.weight must be 2-D".to_string(),
        ));
    }
    let vocab_size = embed_shape[0];
    let hidden_size = embed_shape[1];

    let num_layers = count_num_layers(&shape_map);
    if num_layers == 0 {
        return Err(SafeTensorsLoadError::ConfigInference(
            "could not find any model.layers.N.* tensors".to_string(),
        ));
    }

    let q_shape = shape_map
        .get("model.layers.0.self_attn.q_proj.weight")
        .ok_or_else(|| {
            SafeTensorsLoadError::ConfigInference(
                "missing model.layers.0.self_attn.q_proj.weight".to_string(),
            )
        })?;
    if q_shape.len() < 2 {
        return Err(SafeTensorsLoadError::ConfigInference(
            "q_proj.weight must be at least 2-D".to_string(),
        ));
    }
    let q_out_dim = q_shape[0];
    // k_proj is optional; when absent the KV head count mirrors the query head count.
    let k_out_dim = shape_map
        .get("model.layers.0.self_attn.k_proj.weight")
        .and_then(|s| s.first().copied());

    let head_dim = infer_head_dim(q_out_dim, k_out_dim);
    // Guard the divisions: a zero head_dim would otherwise panic.
    let num_attention_heads = if head_dim > 0 {
        q_out_dim / head_dim
    } else {
        1
    };
    let num_kv_heads = match k_out_dim {
        Some(k) if head_dim > 0 => k / head_dim,
        _ => num_attention_heads,
    };

    let intermediate_size = shape_map
        .get("model.layers.0.mlp.gate_proj.weight")
        .or_else(|| shape_map.get("model.layers.0.mlp.up_proj.weight"))
        .and_then(|s| s.first().copied())
        .unwrap_or(hidden_size * 4);

    let dtype = st_file
        .tensors
        .first()
        .map(|t| match t.dtype {
            DType::F16 => DType::F16,
            DType::BF16 => DType::BF16,
            _ => DType::F32,
        })
        .unwrap_or(DType::F32);

    // Consult the header as well as the loaded tensors so that callers holding only the
    // parsed header (with an empty `tensors` map) still get the bias detected.
    let has_q_bias = shape_map.contains_key(Q_PROJ_BIAS) || tensors.contains_key(Q_PROJ_BIAS);
    let architecture = if has_q_bias {
        Architecture::Qwen2
    } else {
        Architecture::Llama
    };

    Ok(ModelConfig {
        architecture,
        hidden_size,
        intermediate_size,
        num_layers,
        num_attention_heads,
        num_kv_heads,
        head_dim,
        vocab_size,
        max_seq_len: 4096,
        rms_norm_eps: 1e-5,
        rope_theta: 10000.0,
        dtype,
        sliding_window_size: None,
        qkv_bias: has_q_bias,
        hidden_activation: crate::ir::HiddenActivation::SiLU,
    })
}
/// Returns the number of transformer layers implied by the tensor names in `shape_map`.
///
/// Every key of the form `model.layers.<N>.<rest>` contributes its numeric index `N`; the
/// result is the highest index plus one, or `0` when no key matches that pattern.
fn count_num_layers(shape_map: &HashMap<&str, &[usize]>) -> usize {
    shape_map
        .keys()
        .filter_map(|name| {
            let after_prefix = name.strip_prefix("model.layers.")?;
            // Only names that continue past the index ("N.something") count.
            let (index, _) = after_prefix.split_once('.')?;
            index.parse::<usize>().ok()
        })
        .max()
        .map_or(0, |highest| highest + 1)
}
/// Guesses the attention head dimension from the q/k projection output sizes.
///
/// First tries a fixed list of head dimensions, largest first, and returns the first one that
/// evenly divides `q_out_dim` and (when given) `k_out_dim`. If none fits, falls back to
/// dividing `q_out_dim` by the first head count in `[32, 16, 8, 4, 1]` that divides it; the
/// `k_out_dim` is not consulted in that fallback.
fn infer_head_dim(q_out_dim: usize, k_out_dim: Option<usize>) -> usize {
    const COMMON_HEAD_DIMS: [usize; 8] = [256, 192, 160, 128, 96, 80, 64, 32];
    const FALLBACK_HEAD_COUNTS: [usize; 5] = [32, 16, 8, 4, 1];

    let fits_both = |dim: usize| {
        q_out_dim.is_multiple_of(dim) && k_out_dim.is_none_or(|k| k.is_multiple_of(dim))
    };
    if let Some(dim) = COMMON_HEAD_DIMS.iter().copied().find(|&dim| fits_both(dim)) {
        return dim;
    }

    FALLBACK_HEAD_COUNTS
        .iter()
        .copied()
        .find(|&heads| q_out_dim.is_multiple_of(heads))
        .map_or(q_out_dim, |heads| q_out_dim / heads)
}
/// Decodes a little-endian tensor payload of the given `dtype` into a vector of `f32`.
///
/// Handles `F32`, `F16`, `BF16`, `I32` and `I64`; integer values are converted with an `as`
/// cast. Returns an error string when the byte length is not a whole number of elements or
/// when the dtype is any other variant.
fn bytes_to_f32(data: &[u8], dtype: DType) -> Result<Vec<f32>, String> {
    /// Splits `data` into `WIDTH`-byte elements and maps each through `to_f32`, after
    /// checking that the length is a multiple of `WIDTH`. `label` names the dtype in the
    /// error message.
    fn decode_elements<const WIDTH: usize, F>(
        data: &[u8],
        label: &str,
        to_f32: F,
    ) -> Result<Vec<f32>, String>
    where
        F: Fn([u8; WIDTH]) -> f32,
    {
        if !data.len().is_multiple_of(WIDTH) {
            return Err(format!(
                "{} data length {} not divisible by {}",
                label,
                data.len(),
                WIDTH
            ));
        }
        let values = data
            .chunks_exact(WIDTH)
            .map(|chunk| {
                let element: [u8; WIDTH] = chunk
                    .try_into()
                    .expect("chunks_exact yields exactly WIDTH bytes");
                to_f32(element)
            })
            .collect();
        Ok(values)
    }

    match dtype {
        DType::F32 => decode_elements::<4, _>(data, "F32", f32::from_le_bytes),
        DType::F16 => {
            decode_elements::<2, _>(data, "F16", |raw| f16_to_f32(u16::from_le_bytes(raw)))
        }
        DType::BF16 => {
            decode_elements::<2, _>(data, "BF16", |raw| bf16_to_f32(u16::from_le_bytes(raw)))
        }
        DType::I32 => decode_elements::<4, _>(data, "I32", |raw| i32::from_le_bytes(raw) as f32),
        DType::I64 => decode_elements::<8, _>(data, "I64", |raw| i64::from_le_bytes(raw) as f32),
        other => Err(format!(
            "unsupported dtype for SafeTensors loading: {other}"
        )),
    }
}
/// Widens an IEEE 754 half-precision bit pattern to `f32`.
///
/// Zeros, subnormals, normals, infinities and NaNs are each mapped to the corresponding
/// single-precision encoding; the sign bit is carried over unchanged.
#[inline]
fn f16_to_f32(bits: u16) -> f32 {
    let half = u32::from(bits);
    let sign_bit = (half >> 15) << 31;
    let exponent = (half >> 10) & 0x1f;
    let fraction = half & 0x3ff;

    let magnitude = match (exponent, fraction) {
        // Signed zero.
        (0, 0) => 0,
        // Subnormal half: shift the fraction up until its leading one reaches bit 10 (the
        // implicit-one position), lowering the exponent once per shift.
        (0, _) => {
            let shift = fraction.leading_zeros() - 21;
            let rebased_exponent = (127 - 14) - shift;
            let normalized = (fraction << shift) & 0x3ff;
            (rebased_exponent << 23) | (normalized << 13)
        }
        // Infinity (fraction == 0) or NaN (fraction != 0).
        (0x1f, _) => (0xff << 23) | (fraction << 13),
        // Normal number: re-bias the exponent from 15 to 127.
        _ => ((exponent + 127 - 15) << 23) | (fraction << 13),
    };

    f32::from_bits(sign_bit | magnitude)
}
/// Widens a bfloat16 bit pattern to `f32` by placing it in the upper 16 bits of the
/// single-precision word and zero-filling the lower 16 bits.
#[inline]
fn bf16_to_f32(bits: u16) -> f32 {
    let widened = u32::from(bits) << 16;
    f32::from_bits(widened)
}
// Unit tests for shape-based config inference, dtype decoding, and layer counting.
#[cfg(test)]
mod tests {
use super::*;
use crate::safetensors::{SafeTensorInfo, SafeTensorsFile};
// Assembles an in-memory SafeTensors image: the header length as an 8-byte little-endian
// integer, then the JSON header bytes, then the raw tensor data.
fn build_safetensors_bytes(header_json: &str, tensor_data: &[u8]) -> Vec<u8> {
let hdr = header_json.as_bytes();
let mut buf = Vec::new();
buf.extend_from_slice(&(hdr.len() as u64).to_le_bytes());
buf.extend_from_slice(hdr);
buf.extend_from_slice(tensor_data);
buf
}
// Wraps tensor descriptors in a `SafeTensorsFile` with a zero data offset and no metadata.
fn make_st_file(tensor_infos: Vec<SafeTensorInfo>) -> SafeTensorsFile {
SafeTensorsFile {
tensors: tensor_infos,
data_offset: 0,
metadata: HashMap::new(),
}
}
// Builds a tensor descriptor whose byte range starts at 0.
// NOTE: the byte length is computed as 4 bytes per element regardless of `dtype`; the tests
// below only pass `DType::F32`.
fn make_info(name: &str, shape: Vec<usize>, dtype: DType) -> SafeTensorInfo {
let numel: usize = shape.iter().product();
let bytes = numel * 4; SafeTensorInfo {
name: name.to_string(),
dtype,
shape,
data_start: 0,
data_end: bytes,
}
}
// A two-layer model without a q_proj bias is inferred as Llama, with every dimension
// recovered from the tensor shapes (hidden 2560 / head_dim 160 = 16 query heads, 4 KV heads).
#[test]
fn infer_config_basic_llama() {
let hidden = 2560usize;
let head_dim = 160usize; let num_heads = hidden / head_dim; let num_kv_heads = 4usize;
let intermediate = 6912usize;
let vocab = 32000usize;
let infos = vec![
make_info("model.embed_tokens.weight", vec![vocab, hidden], DType::F32),
make_info(
"model.layers.0.self_attn.q_proj.weight",
vec![num_heads * head_dim, hidden],
DType::F32,
),
make_info(
"model.layers.0.self_attn.k_proj.weight",
vec![num_kv_heads * head_dim, hidden],
DType::F32,
),
make_info(
"model.layers.0.mlp.gate_proj.weight",
vec![intermediate, hidden],
DType::F32,
),
make_info(
"model.layers.1.self_attn.q_proj.weight",
vec![num_heads * head_dim, hidden],
DType::F32,
),
];
let st_file = make_st_file(infos);
// Inference here relies on the header shapes only; no tensor data is supplied.
let tensors: HashMap<String, Vec<f32>> = HashMap::new();
let config = infer_model_config(&st_file, &tensors).unwrap();
assert_eq!(config.vocab_size, vocab);
assert_eq!(config.hidden_size, hidden);
assert_eq!(config.num_layers, 2);
assert_eq!(config.num_attention_heads, num_heads);
assert_eq!(config.num_kv_heads, num_kv_heads);
assert_eq!(config.head_dim, head_dim);
assert_eq!(config.intermediate_size, intermediate);
assert_eq!(config.architecture, Architecture::Llama);
}
// With no tensors at all, the embedding lookup fails and inference reports an error.
#[test]
fn infer_config_missing_embed_tokens_errors() {
let st_file = make_st_file(vec![]);
let tensors = HashMap::new();
let result = infer_model_config(&st_file, &tensors);
assert!(matches!(
result,
Err(SafeTensorsLoadError::ConfigInference(_))
));
}
// The embedding and a layer-0 tensor exist, but q_proj is absent, so inference must fail.
#[test]
fn infer_config_missing_q_proj_errors() {
let infos = vec![
make_info("model.embed_tokens.weight", vec![32000, 2048], DType::F32),
make_info(
"model.layers.0.self_attn.k_proj.weight",
vec![512, 2048],
DType::F32,
),
];
let st_file = make_st_file(infos);
let tensors = HashMap::new();
let result = infer_model_config(&st_file, &tensors);
assert!(matches!(
result,
Err(SafeTensorsLoadError::ConfigInference(_))
));
}
// 0x3C00 is the half-precision encoding of 1.0.
#[test]
fn f16_to_f32_one() {
assert_eq!(f16_to_f32(0x3C00), 1.0f32);
}
// The all-zero half-precision pattern decodes to 0.0.
#[test]
fn f16_to_f32_zero() {
assert_eq!(f16_to_f32(0x0000), 0.0f32);
}
// 0x3F80 is the bfloat16 encoding of 1.0 (the upper half of f32 1.0 = 0x3F800000).
#[test]
fn bf16_to_f32_one() {
assert_eq!(bf16_to_f32(0x3F80), 1.0f32);
}
// F32 payloads round-trip through their little-endian byte encoding.
#[test]
fn bytes_to_f32_f32_dtype() {
let floats: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
let raw: Vec<u8> = floats.iter().flat_map(|f| f.to_le_bytes()).collect();
let out = bytes_to_f32(&raw, DType::F32).unwrap();
assert_eq!(out, vec![1.0f32, 2.0, 3.0, 4.0]);
}
// Two little-endian F16 values of 0x3C00 (1.0) decode to two f32 ones.
#[test]
fn bytes_to_f32_f16_dtype() {
let raw: Vec<u8> = vec![0x00, 0x3C, 0x00, 0x3C]; let out = bytes_to_f32(&raw, DType::F16).unwrap();
assert_eq!(out.len(), 2);
assert!((out[0] - 1.0f32).abs() < 1e-6);
assert!((out[1] - 1.0f32).abs() < 1e-6);
}
// Two little-endian BF16 values of 0x3F80 (1.0) decode to two elements; the first is checked.
#[test]
fn bytes_to_f32_bf16_dtype() {
let raw: Vec<u8> = vec![0x80, 0x3F, 0x80, 0x3F];
let out = bytes_to_f32(&raw, DType::BF16).unwrap();
assert_eq!(out.len(), 2);
assert!((out[0] - 1.0f32).abs() < 1e-6);
}
// A 3-byte buffer is not a whole number of 4-byte F32 elements and must be rejected.
#[test]
fn bytes_to_f32_unaligned_returns_error() {
let raw = vec![0u8; 3]; assert!(bytes_to_f32(&raw, DType::F32).is_err());
}
// The layer count is the highest layer index plus one (indices 0 and 3 present -> 4),
// and names outside `model.layers.` are ignored.
#[test]
fn count_layers_from_shape_map() {
let mut map: HashMap<&str, &[usize]> = HashMap::new();
let shape = vec![1024usize, 1024];
let s: &[usize] = &shape;
map.insert("model.layers.0.self_attn.q_proj.weight", s);
map.insert("model.layers.3.self_attn.q_proj.weight", s);
map.insert("model.embed_tokens.weight", s);
assert_eq!(count_num_layers(&map), 4); }
// End-to-end over an in-memory image: parse the header, locate the tensor's bytes using the
// parsed offsets, and decode them back to the original floats.
#[test]
fn load_minimal_safetensors_bytes() {
let floats: [f32; 4] = [1.0, 2.0, 3.0, 4.0];
let raw: Vec<u8> = floats.iter().flat_map(|f| f.to_le_bytes()).collect();
let header = r#"{"w": {"dtype": "F32", "shape": [2, 2], "data_offsets": [0, 16]}}"#;
let buf = build_safetensors_bytes(header, &raw);
let st_file = crate::safetensors::parse(std::io::Cursor::new(&buf)).expect("parse failed");
assert_eq!(st_file.tensors.len(), 1);
assert_eq!(st_file.tensors[0].name, "w");
assert_eq!(st_file.tensors[0].shape, vec![2usize, 2]);
// Tensor offsets in the header are relative to the start of the data section.
let data_start = st_file.data_offset as usize + st_file.tensors[0].data_start;
let data_end = st_file.data_offset as usize + st_file.tensors[0].data_end;
let loaded = bytes_to_f32(&buf[data_start..data_end], DType::F32).unwrap();
assert_eq!(loaded, vec![1.0f32, 2.0, 3.0, 4.0]);
}
}