use std::path::Path;
use crate::model::{Architecture, ModelConfig, ModelError, ModelResult, ModelSource};
use crate::model::hf_config::HfConfig;
use crate::tensor::Tensor;
use super::{ShardedSafeTensors, TensorNameMapper};
use super::reader::{bf16_to_f32, f16_to_f32, SafeTensorsDtype};
/// A `ModelSource` backed by a HuggingFace SafeTensors checkpoint directory.
///
/// Callers request tensors by internal names (e.g. `token_embd.weight`); these
/// are translated to the HuggingFace names stored on disk via `name_mapper`
/// before being read from `files`.
pub struct SafeTensorsLoader {
// The SafeTensors file(s) opened from the checkpoint directory.
files: ShardedSafeTensors,
// Maps internal tensor names to the HuggingFace names present in `files`.
name_mapper: TensorNameMapper,
// Model hyper-parameters derived from the directory's `config.json`.
config: ModelConfig,
// Architecture resolved from `config.json` (with a `model_type` fallback).
architecture: Architecture,
}
impl SafeTensorsLoader {
    /// Opens a HuggingFace-style SafeTensors checkpoint.
    ///
    /// `path` may be the checkpoint directory itself or a file inside it (for
    /// example `model.safetensors`). For a file, its parent directory is used,
    /// and a bare relative file name resolves to the current directory. The
    /// directory is opened with `ShardedSafeTensors::open` and must also contain
    /// a `config.json`.
    ///
    /// If the architecture declared in `config.json` is `Architecture::Unknown`,
    /// it is derived from `model_type` instead (defaulting to `"llama"` when the
    /// field is absent).
    ///
    /// # Errors
    /// - `ModelError::ConfigError` if the SafeTensors files cannot be opened or
    ///   `config.json` cannot be read.
    /// - Any error returned by `HfConfig::from_json` or `HfConfig::to_model_config`.
    pub fn load(path: &Path) -> ModelResult<Self> {
        let dir = if path.is_dir() {
            path.to_path_buf()
        } else {
            // `Path::parent` returns `Some("")` for a bare relative file name such as
            // `model.safetensors`. An empty path cannot be relied on to open as a
            // directory, so fall back to "." for it as well as for a missing parent.
            path.parent()
                .filter(|parent| !parent.as_os_str().is_empty())
                .unwrap_or_else(|| Path::new("."))
                .to_path_buf()
        };

        let files = ShardedSafeTensors::open(&dir)
            .map_err(|e| ModelError::ConfigError(format!("SafeTensors: {e}")))?;

        let config_path = dir.join("config.json");
        let config_str = std::fs::read_to_string(&config_path).map_err(|e| {
            ModelError::ConfigError(format!("Failed to read {}: {e}", config_path.display()))
        })?;
        let hf_config = HfConfig::from_json(&config_str)?;

        // Prefer the architecture declared by the config; otherwise fall back to
        // interpreting `model_type`.
        let architecture = match hf_config.architecture() {
            Architecture::Unknown => {
                let model_type = hf_config.model_type.as_deref().unwrap_or("llama");
                Architecture::from_gguf_str(model_type)
            }
            known => known,
        };
        let config = hf_config.to_model_config()?;

        let tensor_names = files.tensor_names();
        let name_mapper = TensorNameMapper::from_tensor_names(&tensor_names, architecture);

        tracing::info!(
            "SafeTensors model: {} tensors, {} mapped, arch={:?}",
            files.num_tensors(),
            name_mapper.len(),
            architecture,
        );

        Ok(Self { files, name_mapper, config, architecture })
    }

    /// Whether norm weights must be shifted by +1.0 when loaded.
    ///
    /// Currently always `false`: Gemma checkpoints are returned with their stored
    /// norm values unchanged (pinned by `test_gemma_norm_no_offset`). `&self` is
    /// unused; presumably it is kept so the decision can later depend on
    /// `self.architecture`.
    fn needs_gemma_norm_offset(&self) -> bool {
        false
    }

    /// Returns true for internal tensor names of norm weights, i.e. names ending
    /// in `_norm.weight` (this includes `output_norm.weight`).
    fn is_norm_weight(name: &str) -> bool {
        name.ends_with("_norm.weight")
    }

    /// Adds 1.0 to every element of `data` in place.
    fn apply_norm_offset(data: &mut [f32]) {
        for v in data.iter_mut() {
            *v += 1.0;
        }
    }
}
impl ModelSource for SafeTensorsLoader {
    /// Model hyper-parameters parsed from the checkpoint's `config.json`.
    fn config(&self) -> &ModelConfig {
        &self.config
    }

    /// Mutable access to the parsed model hyper-parameters.
    fn config_mut(&mut self) -> &mut ModelConfig {
        &mut self.config
    }

    /// Architecture resolved when the checkpoint was opened.
    fn architecture(&self) -> Architecture {
        self.architecture
    }

    /// Reads tensor `name` (an internal name such as `token_embd.weight`) and
    /// returns it decoded to f32.
    ///
    /// The name is first translated to its HuggingFace counterpart. F32, BF16
    /// and F16 storage are supported. Rank-2 shapes are returned with their two
    /// dimensions swapped; every other rank keeps its stored order.
    ///
    /// # Errors
    /// - `ModelError::MissingTensor` if `name` has no HuggingFace mapping, or the
    ///   mapped tensor is not present in the checkpoint.
    /// - `ModelError::ConfigError` for any other storage dtype.
    /// - Any error returned by `Tensor::from_f32`.
    fn load_tensor(&self, name: &str) -> ModelResult<Tensor> {
        let hf_name = match self.name_mapper.to_hf(name) {
            Some(mapped) => mapped,
            None => {
                return Err(ModelError::MissingTensor(format!(
                    "{name} (no HuggingFace mapping found)"
                )));
            }
        };

        // Metadata and payload lookups report the same error when the tensor is absent.
        let missing = || ModelError::MissingTensor(hf_name.to_string());
        let info = self.files.tensor_info(hf_name).ok_or_else(missing)?;
        let raw = self.files.tensor_data(hf_name).ok_or_else(missing)?;

        let mut values = match info.dtype {
            SafeTensorsDtype::BF16 => bf16_to_f32(raw),
            SafeTensorsDtype::F16 => f16_to_f32(raw),
            SafeTensorsDtype::F32 => raw
                .chunks_exact(4)
                .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]]))
                .collect(),
            unsupported => {
                return Err(ModelError::ConfigError(format!(
                    "Unsupported SafeTensors dtype {:?} for tensor {name}",
                    unsupported
                )));
            }
        };

        let shift_norm = self.needs_gemma_norm_offset() && Self::is_norm_weight(name);
        if shift_norm {
            Self::apply_norm_offset(&mut values);
        }

        // NOTE(review): only rank-2 shapes are swapped; higher-rank tensors keep
        // their stored dimension order. Confirm this matches what consumers expect.
        let dims = match info.shape.as_slice() {
            &[rows, cols] => vec![cols, rows],
            _ => info.shape.clone(),
        };

        let mut tensor = Tensor::from_f32(&values, dims)?;
        tensor.set_name(name);
        Ok(tensor)
    }

    /// Same as [`load_tensor`](Self::load_tensor), but any error becomes `None`.
    fn try_load_tensor(&self, name: &str) -> Option<Tensor> {
        self.load_tensor(name).ok()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    /// Writes a minimal single-file SafeTensors checkpoint containing only F32 tensors.
    ///
    /// Layout: an 8-byte little-endian header length, the JSON header, then the
    /// concatenated little-endian tensor bytes referenced by `data_offsets`.
    fn write_safetensors_file(
        path: &std::path::Path,
        tensors: &[(&str, Vec<usize>, &[f32])],
    ) {
        let mut data_buf = Vec::new();
        let mut header = serde_json::Map::new();
        for &(name, ref shape, values) in tensors {
            let start = data_buf.len();
            for &v in values {
                data_buf.extend_from_slice(&v.to_le_bytes());
            }
            let end = data_buf.len();
            let shape_json: Vec<serde_json::Value> = shape
                .iter()
                .map(|&d| serde_json::Value::Number(serde_json::Number::from(d)))
                .collect();
            let mut obj = serde_json::Map::new();
            obj.insert("dtype".into(), "F32".into());
            obj.insert("shape".into(), serde_json::Value::Array(shape_json));
            obj.insert("data_offsets".into(), serde_json::json!([start, end]));
            header.insert(name.to_string(), serde_json::Value::Object(obj));
        }
        let header_str = serde_json::to_string(&serde_json::Value::Object(header)).unwrap();
        let header_bytes = header_str.as_bytes();
        let mut f = std::fs::File::create(path).unwrap();
        f.write_all(&(header_bytes.len() as u64).to_le_bytes()).unwrap();
        f.write_all(header_bytes).unwrap();
        f.write_all(&data_buf).unwrap();
    }

    /// Writes a small `config.json` (hidden_size 64, vocab_size 8) with the given `model_type`.
    fn write_config_json(dir: &std::path::Path, model_type: &str) {
        let config = serde_json::json!({
            "model_type": model_type,
            "hidden_size": 64,
            "intermediate_size": 128,
            "num_hidden_layers": 1,
            "num_attention_heads": 2,
            "num_key_value_heads": 2,
            "max_position_embeddings": 128,
            "rms_norm_eps": 1e-5,
            "vocab_size": 8,
            "head_dim": 32,
        });
        let path = dir.join("config.json");
        std::fs::write(path, serde_json::to_string_pretty(&config).unwrap()).unwrap();
    }

    #[test]
    fn test_load_safetensors_basic() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();
        write_config_json(dir_path, "llama");
        let embd_data: Vec<f32> = (0..512).map(|i| i as f32 * 0.01).collect();
        let norm_data: Vec<f32> = vec![1.0; 64];
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[
                ("model.embed_tokens.weight", vec![8, 64], &embd_data),
                ("model.norm.weight", vec![64], &norm_data),
            ],
        );
        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        assert_eq!(loader.architecture(), Architecture::Llama);
        assert_eq!(loader.config().hidden_size, 64);
        assert_eq!(loader.config().vocab_size, 8);
        // Rank-2 shapes come back with their dimensions swapped.
        let embd = loader.load_tensor("token_embd.weight").unwrap();
        assert_eq!(embd.shape(), &[64, 8]);
        let norm = loader.load_tensor("output_norm.weight").unwrap();
        assert_eq!(norm.shape(), &[64]);
    }

    #[test]
    fn test_load_from_file_path_uses_parent_dir() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();
        write_config_json(dir_path, "llama");
        let file_path = dir_path.join("model.safetensors");
        write_safetensors_file(
            &file_path,
            &[("model.embed_tokens.weight", vec![8, 64], &[0.0f32; 512])],
        );
        // Passing a file instead of the directory must resolve to the same checkpoint.
        let loader = SafeTensorsLoader::load(&file_path).unwrap();
        assert_eq!(loader.config().hidden_size, 64);
        assert!(loader.try_load_tensor("token_embd.weight").is_some());
    }

    #[test]
    fn test_missing_config_returns_error() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[("model.embed_tokens.weight", vec![8, 64], &[0.0f32; 512])],
        );
        assert!(SafeTensorsLoader::load(dir_path).is_err());
    }

    #[test]
    fn test_missing_tensor_returns_error() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();
        write_config_json(dir_path, "llama");
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[("model.embed_tokens.weight", vec![8, 64], &[0.0f32; 512])],
        );
        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        assert!(loader.load_tensor("nonexistent.weight").is_err());
    }

    #[test]
    fn test_try_load_tensor_returns_none() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();
        write_config_json(dir_path, "llama");
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[("model.embed_tokens.weight", vec![8, 64], &[0.0f32; 512])],
        );
        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        assert!(loader.try_load_tensor("nonexistent.weight").is_none());
        assert!(loader.try_load_tensor("token_embd.weight").is_some());
    }

    #[test]
    fn test_gemma_norm_no_offset() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();
        let config = serde_json::json!({
            "model_type": "gemma4",
            "hidden_size": 4,
            "intermediate_size": 8,
            "num_hidden_layers": 1,
            "num_attention_heads": 1,
            "num_key_value_heads": 1,
            "max_position_embeddings": 16,
            "rms_norm_eps": 1e-5,
            "vocab_size": 4,
            "head_dim": 4,
        });
        std::fs::write(
            dir_path.join("config.json"),
            serde_json::to_string_pretty(&config).unwrap(),
        )
        .unwrap();
        let norm_data = vec![9.375f32; 4];
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[
                ("model.norm.weight", vec![4], &norm_data),
                ("model.embed_tokens.weight", vec![4, 4], &[0.5f32; 16]),
            ],
        );
        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        let norm = loader.load_tensor("output_norm.weight").unwrap();
        let norm_vals = norm.as_f32().unwrap();
        for &v in norm_vals.iter() {
            assert!((v - 9.375).abs() < 1e-6, "Expected 9.375, got {v}");
        }
    }

    #[test]
    fn test_non_gemma_no_norm_offset() {
        let dir = TempDir::new().unwrap();
        let dir_path = dir.path();
        write_config_json(dir_path, "llama");
        let norm_data = vec![0.5f32; 64];
        write_safetensors_file(
            &dir_path.join("model.safetensors"),
            &[
                ("model.norm.weight", vec![64], &norm_data),
                ("model.embed_tokens.weight", vec![8, 64], &[0.0f32; 512]),
            ],
        );
        let loader = SafeTensorsLoader::load(dir_path).unwrap();
        assert!(!loader.architecture().is_gemma());
        let norm = loader.load_tensor("output_norm.weight").unwrap();
        let norm_vals = norm.as_f32().unwrap();
        for &v in norm_vals.iter() {
            assert!((v - 0.5).abs() < 1e-6, "Expected 0.5, got {v}");
        }
    }
}