use crate::error::{ModelError, ModelResult};
use crate::ModelType;
use safetensors::tensor::SafeTensors;
use scirs2_core::ndarray::{Array1, Array2, ArrayD};
use std::collections::HashMap;
use std::fs::File;
use std::io::Read;
use std::path::Path;
/// In-memory reader for a safetensors archive.
///
/// Construct with [`ModelLoader::new`] (from a file path) or
/// [`ModelLoader::from_bytes`] (from an already-read buffer).
pub struct ModelLoader {
// Parsed safetensors view. The 'static lifetime is obtained by leaking a
// copy of the byte buffer in `new`/`from_bytes` (see `Box::leak` there).
tensors: SafeTensors<'static>,
// Raw byte buffer kept alongside the view. NOTE(review): the parsed view
// actually borrows a leaked clone, not this vector, so this field is not
// what keeps the tensors alive.
_data: Vec<u8>,
}
impl ModelLoader {
    /// Opens `path`, reads the whole file into memory, and parses it as a
    /// safetensors archive.
    ///
    /// # Errors
    /// Returns a load error when the file cannot be opened or read, or when
    /// the bytes are not a valid safetensors payload.
    pub fn new<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
        let mut file = File::open(path.as_ref())
            .map_err(|e| ModelError::simple_load_error(format!("Failed to open file: {}", e)))?;
        let mut data = Vec::new();
        file.read_to_end(&mut data)
            .map_err(|e| ModelError::simple_load_error(format!("Failed to read file: {}", e)))?;
        // Delegate so the parse/leak logic lives in exactly one place.
        Self::from_bytes(data)
    }

    /// Parses an in-memory safetensors image.
    ///
    /// `SafeTensors<'static>` must borrow bytes that live for the rest of the
    /// program, so the buffer is intentionally leaked via `Box::leak`.
    /// (The previous implementation leaked a *clone* and also retained the
    /// original vector in `_data`, doubling resident memory per load.)
    ///
    /// # Errors
    /// Returns a load error when the bytes are not valid safetensors.
    pub fn from_bytes(data: Vec<u8>) -> ModelResult<Self> {
        let data_static: &'static [u8] = Box::leak(data.into_boxed_slice());
        let tensors = SafeTensors::deserialize(data_static).map_err(|e| {
            ModelError::simple_load_error(format!("Failed to parse safetensors: {}", e))
        })?;
        Ok(Self {
            tensors,
            // The parsed view borrows the leaked buffer; nothing else needs
            // to be retained here.
            _data: Vec::new(),
        })
    }

    /// Names of all tensors stored in the archive.
    pub fn list_tensors(&self) -> Vec<String> {
        self.tensors.names().iter().map(|s| s.to_string()).collect()
    }

    /// Shape and dtype metadata for `name`, or `None` when the tensor is
    /// absent.
    pub fn tensor_info(&self, name: &str) -> Option<TensorInfo> {
        self.tensors.tensor(name).ok().map(|view| TensorInfo {
            name: name.to_string(),
            shape: view.shape().to_vec(),
            dtype: format!("{:?}", view.dtype()),
        })
    }

    /// Looks up `name`, converting the safetensors error into a `ModelError`.
    fn view(&self, name: &str) -> ModelResult<safetensors::tensor::TensorView<'_>> {
        self.tensors.tensor(name).map_err(|e| {
            ModelError::simple_load_error(format!("Tensor '{}' not found: {}", name, e))
        })
    }

    /// Decodes a tensor's raw little-endian bytes into `f32` values.
    ///
    /// Supports F32 (bit-exact), F64 (narrowed) and F16 (widened); any other
    /// dtype yields an error. Shared by every `load_array*` method so dtype
    /// support stays consistent.
    fn decode_f32(
        name: &str,
        view: &safetensors::tensor::TensorView<'_>,
    ) -> ModelResult<Vec<f32>> {
        let data = view.data();
        match view.dtype() {
            safetensors::Dtype::F32 => Ok(data
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect()),
            safetensors::Dtype::F64 => Ok(data
                .chunks_exact(8)
                .map(|c| {
                    f64::from_le_bytes([c[0], c[1], c[2], c[3], c[4], c[5], c[6], c[7]]) as f32
                })
                .collect()),
            safetensors::Dtype::F16 => Ok(data
                .chunks_exact(2)
                .map(|c| half::f16::from_bits(u16::from_le_bytes([c[0], c[1]])).to_f32())
                .collect()),
            dtype => Err(ModelError::simple_load_error(format!(
                "Unsupported dtype for '{}': {:?}",
                name, dtype
            ))),
        }
    }

    /// Loads a 1-D tensor as `Array1<f32>`.
    ///
    /// # Errors
    /// Fails when the tensor is missing, is not 1-D, or has an unsupported
    /// dtype. (F16 is now accepted here, matching `load_array`.)
    pub fn load_array1(&self, name: &str) -> ModelResult<Array1<f32>> {
        let view = self.view(name)?;
        let shape = view.shape();
        if shape.len() != 1 {
            return Err(ModelError::simple_load_error(format!(
                "Expected 1D tensor for '{}', got shape {:?}",
                name, shape
            )));
        }
        Ok(Array1::from_vec(Self::decode_f32(name, &view)?))
    }

    /// Loads a 2-D tensor as `Array2<f32>`.
    ///
    /// # Errors
    /// Fails when the tensor is missing, is not 2-D, or has an unsupported
    /// dtype. (F16 is now accepted here, matching `load_array`.)
    pub fn load_array2(&self, name: &str) -> ModelResult<Array2<f32>> {
        let view = self.view(name)?;
        let shape = view.shape();
        if shape.len() != 2 {
            return Err(ModelError::simple_load_error(format!(
                "Expected 2D tensor for '{}', got shape {:?}",
                name, shape
            )));
        }
        let floats = Self::decode_f32(name, &view)?;
        Array2::from_shape_vec((shape[0], shape[1]), floats)
            .map_err(|e| ModelError::simple_load_error(format!("Failed to create Array2: {}", e)))
    }

    /// Loads a tensor of any rank as a dynamic-dimension `ArrayD<f32>`.
    ///
    /// # Errors
    /// Fails when the tensor is missing or has an unsupported dtype.
    pub fn load_array(&self, name: &str) -> ModelResult<ArrayD<f32>> {
        let view = self.view(name)?;
        let floats = Self::decode_f32(name, &view)?;
        ArrayD::from_shape_vec(view.shape(), floats)
            .map_err(|e| ModelError::simple_load_error(format!("Failed to create ArrayD: {}", e)))
    }

    /// Loads a 3-D tensor as nested `Vec`s indexed `[dim0][dim1][dim2]`.
    ///
    /// # Errors
    /// Fails when the tensor is missing, not 3-D, or has an unsupported
    /// dtype.
    pub fn load_array3(&self, name: &str) -> ModelResult<Vec<Vec<Vec<f32>>>> {
        let array_d = self.load_array(name)?;
        if array_d.ndim() != 3 {
            return Err(ModelError::simple_load_error(format!(
                "Expected 3D tensor for '{}', got {}D tensor",
                name,
                array_d.ndim()
            )));
        }
        let shape = array_d.shape();
        let (dim0, dim1, dim2) = (shape[0], shape[1], shape[2]);
        let mut result = Vec::with_capacity(dim0);
        for i in 0..dim0 {
            let mut plane = Vec::with_capacity(dim1);
            for j in 0..dim1 {
                let mut row = Vec::with_capacity(dim2);
                for k in 0..dim2 {
                    row.push(array_d[[i, j, k]]);
                }
                plane.push(row);
            }
            result.push(plane);
        }
        Ok(result)
    }

    /// Whether the archive contains a tensor named `name`.
    pub fn has_tensor(&self, name: &str) -> bool {
        self.tensors.tensor(name).is_ok()
    }

    /// Loads every tensor in the archive into a name → array map.
    ///
    /// # Errors
    /// Fails on the first tensor with an unsupported dtype.
    pub fn load_all(&self) -> ModelResult<HashMap<String, ArrayD<f32>>> {
        let mut result = HashMap::new();
        for name in self.list_tensors() {
            let array = self.load_array(&name)?;
            result.insert(name, array);
        }
        Ok(result)
    }

    /// Prints a human-readable summary of all tensors, grouped by their
    /// dotted-name prefix, to stdout.
    pub fn print_summary(&self) {
        println!("SafeTensors Weight Summary");
        println!("==========================");
        println!("Total tensors: {}", self.list_tensors().len());
        println!();
        // Group tensor names by everything before their final dot segment.
        let mut prefixes: HashMap<String, Vec<String>> = HashMap::new();
        for name in self.list_tensors() {
            let parts: Vec<&str> = name.split('.').collect();
            let prefix = if parts.len() > 1 {
                parts[0..parts.len() - 1].join(".")
            } else {
                "root".to_string()
            };
            prefixes.entry(prefix).or_default().push(name);
        }
        for (prefix, tensors) in prefixes.iter() {
            println!("\n[{}]", prefix);
            for name in tensors {
                if let Some(info) = self.tensor_info(name) {
                    println!(
                        " {} - shape: {:?}, dtype: {}",
                        name, info.shape, info.dtype
                    );
                }
            }
        }
    }

    /// Per-tensor element counts plus a `"__total_parameters"` grand total.
    pub fn get_size_stats(&self) -> HashMap<String, usize> {
        let mut stats = HashMap::new();
        let mut total_params = 0usize;
        for name in self.list_tensors() {
            if let Some(info) = self.tensor_info(&name) {
                let size: usize = info.shape.iter().product();
                // `name` is owned per-iteration; no clone needed.
                stats.insert(name, size);
                total_params += size;
            }
        }
        stats.insert("__total_parameters".to_string(), total_params);
        stats
    }

    /// Names of all tensors whose name contains `pattern` as a substring.
    pub fn search_tensors(&self, pattern: &str) -> Vec<String> {
        self.list_tensors()
            .into_iter()
            .filter(|name| name.contains(pattern))
            .collect()
    }
}
/// Lightweight metadata describing one stored tensor (no payload data).
#[derive(Debug, Clone)]
pub struct TensorInfo {
/// Tensor name exactly as it appears in the safetensors header.
pub name: String,
/// Dimension sizes, outermost dimension first.
pub shape: Vec<usize>,
/// Debug-formatted safetensors dtype, e.g. "F32".
pub dtype: String,
}
/// Builder-style wrapper around [`ModelLoader`] adding weight validation and
/// tensor-name remapping on top of raw access.
pub struct WeightLoader {
// Underlying safetensors reader.
loader: ModelLoader,
// Optional architecture tag, set via the `model_type` builder method.
model_type: Option<ModelType>,
// When true (the default), `validate_weights` errors on missing tensors;
// when false, validation is a no-op.
strict: bool,
// Optional explicit old-name -> new-name overrides consulted by
// `remap_name`.
name_mapping: Option<HashMap<String, String>>,
}
impl WeightLoader {
    /// Wraps `loader` with default settings: strict validation enabled,
    /// no model type, no name mapping.
    pub fn new(loader: ModelLoader) -> Self {
        Self {
            loader,
            model_type: None,
            strict: true,
            name_mapping: None,
        }
    }

    /// Builder: records which model architecture this checkpoint targets.
    pub fn model_type(mut self, model_type: ModelType) -> Self {
        self.model_type = Some(model_type);
        self
    }

    /// Builder: enables or disables strict validation.
    pub fn strict(mut self, strict: bool) -> Self {
        self.strict = strict;
        self
    }

    /// Verifies that every name in `required` exists in the archive.
    /// Always succeeds when strict mode is off.
    pub fn validate_weights(&self, required: &[&str]) -> ModelResult<()> {
        if !self.strict {
            return Ok(());
        }
        let missing: Vec<&str> = required
            .iter()
            .copied()
            .filter(|name| !self.loader.has_tensor(name))
            .collect();
        if missing.is_empty() {
            Ok(())
        } else {
            Err(ModelError::simple_load_error(format!(
                "Missing required weights: {:?}",
                missing
            )))
        }
    }

    /// Read-only access to the wrapped [`ModelLoader`].
    pub fn loader(&self) -> &ModelLoader {
        &self.loader
    }

    /// Builder: installs an explicit tensor-name translation table.
    pub fn with_name_mapping(mut self, mapping: HashMap<String, String>) -> Self {
        self.name_mapping = Some(mapping);
        self
    }

    /// Translates `name` through the installed mapping, falling back to the
    /// original name when there is no mapping or no entry for it.
    pub fn remap_name<'a>(&'a self, name: &'a str) -> &'a str {
        self.name_mapping
            .as_ref()
            .and_then(|mapping| mapping.get(name))
            .map_or(name, String::as_str)
    }

    /// Prints the underlying tensor summary to stdout.
    pub fn print_weights(&self) {
        self.loader.print_summary();
    }

    /// Suggests (HuggingFace name, kizzasi name) pairs when the checkpoint
    /// looks like a HuggingFace export (i.e. contains "backbone.layers"
    /// tensors); otherwise returns an empty list.
    pub fn suggest_huggingface_mapping(&self) -> Vec<(String, String)> {
        let tensors = self.loader.list_tensors();
        if !tensors.iter().any(|t| t.contains("backbone.layers")) {
            return Vec::new();
        }
        tensors
            .iter()
            .filter_map(|tensor| {
                self.hf_to_kizzasi_name(tensor)
                    .map(|kizzasi| (tensor.clone(), kizzasi))
            })
            .collect()
    }

    /// Rewrites a HuggingFace tensor name into the kizzasi convention via
    /// simple substring substitutions; `None` when the result is empty.
    fn hf_to_kizzasi_name(&self, hf_name: &str) -> Option<String> {
        let name = hf_name
            .replace("backbone.", "")
            .replace(".mixer.", ".")
            .replace("conv1d", "conv")
            .replace("A_log", "ssm.log_a")
            .replace(".D", ".ssm.d_skip");
        (!name.is_empty()).then_some(name)
    }
}
/// Minimal loader facade over any [`crate::incremental_loader::WeightSource`]
/// implementation, mirroring the read-only surface of [`ModelLoader`].
pub struct WeightSourceLoader<S: crate::incremental_loader::WeightSource> {
    source: S,
}

impl<S: crate::incremental_loader::WeightSource> WeightSourceLoader<S> {
    /// Wraps the given weight source.
    pub fn new(source: S) -> Self {
        Self { source }
    }

    /// Names of every tensor the source can provide.
    pub fn list_tensors(&self) -> Vec<String> {
        self.source.tensor_names()
    }

    /// Whether the source can provide a tensor named `name`.
    pub fn has_tensor(&self, name: &str) -> bool {
        self.source.contains(name)
    }

    /// Loads `name` from the source as a flat `f32` vector.
    pub fn load_flat(&mut self, name: &str) -> ModelResult<Vec<f32>> {
        self.source.load_tensor(name)
    }

    /// Consumes the loader, returning the wrapped source.
    pub fn into_source(self) -> S {
        self.source
    }
}
impl WeightLoader {
    /// Builds a [`WeightLoader`] from an arbitrary weight source by
    /// serializing every tensor into an in-memory safetensors image
    /// (8-byte little-endian header length, JSON header, then the raw F32
    /// payload) and re-parsing it with [`ModelLoader::from_bytes`].
    ///
    /// All tensors are stored as flat 1-D F32 arrays.
    ///
    /// # Errors
    /// Fails when a tensor cannot be loaded from the source or the built
    /// image does not parse as safetensors.
    pub fn from_weight_source<S: crate::incremental_loader::WeightSource>(
        mut source: S,
        model_type: Option<crate::ModelType>,
        strict: bool,
    ) -> ModelResult<Self> {
        let mut payload: Vec<u8> = Vec::new();
        let mut header_map = serde_json::Map::new();
        for name in source.tensor_names() {
            let values = source.load_tensor(&name)?;
            let begin = payload.len();
            for value in &values {
                payload.extend_from_slice(&value.to_le_bytes());
            }
            let end = payload.len();
            header_map.insert(
                name,
                serde_json::json!({
                    "dtype": "F32",
                    "shape": [values.len()],
                    "data_offsets": [begin, end]
                }),
            );
        }
        let header_json = serde_json::Value::Object(header_map).to_string();
        let header_bytes = header_json.as_bytes();
        // Assemble: u64 header length | JSON header | tensor payload.
        let mut image: Vec<u8> = Vec::with_capacity(8 + header_bytes.len() + payload.len());
        image.extend_from_slice(&(header_bytes.len() as u64).to_le_bytes());
        image.extend_from_slice(header_bytes);
        image.extend_from_slice(&payload);
        let mut weight_loader = WeightLoader::new(ModelLoader::from_bytes(image)?);
        if let Some(mt) = model_type {
            weight_loader = weight_loader.model_type(mt);
        }
        Ok(weight_loader.strict(strict))
    }
}
/// Translates checkpoint tensor names (HuggingFace-style) into this crate's
/// internal naming convention. Stateless; unknown keys pass through.
#[derive(Debug, Clone, Default)]
pub struct NameRemapper;

impl NameRemapper {
    /// Creates a remapper (stateless, so this is free).
    pub fn new() -> Self {
        Self
    }

    /// Maps one checkpoint key to its internal name.
    ///
    /// Rules, in order: a `backbone.` prefix is stripped (with special cases
    /// for the embedding and final norm), whole-key embedding/head aliases
    /// are resolved, and `layers.N.<suffix>` keys have their suffix
    /// translated. Anything unrecognized is returned unchanged.
    pub fn remap(&self, key: &str) -> String {
        let stripped = key.strip_prefix("backbone.");
        if let Some(rest) = stripped {
            match rest {
                "embeddings.weight" => return "input_proj".to_string(),
                "norm_f.weight" => return "final_norm.weight".to_string(),
                _ => {}
            }
        }
        let key = stripped.unwrap_or(key);
        match key {
            "embedding.weight" => return "input_proj".to_string(),
            "lm_head.weight" => return "output_proj".to_string(),
            _ => {}
        }
        if let Some(layer_idx) = Self::extract_layer_index(key) {
            let suffix = Self::strip_layer_prefix(key, layer_idx);
            if let Some(mapped) = Self::remap_layer_suffix(suffix) {
                return format!("layers.{}.{}", layer_idx, mapped);
            }
        }
        key.to_string()
    }

    /// Parses the `N` out of a `layers.N....` key, if present.
    fn extract_layer_index(key: &str) -> Option<usize> {
        let mut segments = key.splitn(3, '.');
        if segments.next() != Some("layers") {
            return None;
        }
        segments.next()?.parse().ok()
    }

    /// Returns everything after the `layers.N.` prefix, or "" when the key
    /// has nothing beyond the prefix.
    fn strip_layer_prefix(key: &str, layer_idx: usize) -> &str {
        let prefix_len = "layers.".len() + layer_idx.to_string().len() + 1;
        key.get(prefix_len..).unwrap_or("")
    }

    /// Per-layer suffix translation table; `None` for unknown suffixes.
    fn remap_layer_suffix(suffix: &str) -> Option<&'static str> {
        const TABLE: [(&str, &str); 9] = [
            ("mixer.in_proj.weight", "input_proj"),
            ("mixer.out_proj.weight", "output_proj"),
            ("attn.q_proj.weight", "attention.q"),
            ("attn.k_proj.weight", "attention.k"),
            ("attn.v_proj.weight", "attention.v"),
            ("attn.o_proj.weight", "attention.out"),
            ("mlp.gate_proj.weight", "ff.gate"),
            ("mlp.up_proj.weight", "ff.up"),
            ("mlp.down_proj.weight", "ff.down"),
        ];
        TABLE
            .iter()
            .find(|(from, _)| *from == suffix)
            .map(|(_, to)| *to)
    }

    /// Applies [`NameRemapper::remap`] to every key of `weights`.
    pub fn remap_map(&self, weights: HashMap<String, Vec<f32>>) -> HashMap<String, Vec<f32>> {
        let mut remapped = HashMap::with_capacity(weights.len());
        for (key, value) in weights {
            remapped.insert(self.remap(&key), value);
        }
        remapped
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_info() {
        let info = TensorInfo {
            name: "test".to_string(),
            shape: vec![2, 3],
            dtype: "F32".to_string(),
        };
        assert_eq!(info.name, "test");
        assert_eq!(info.shape, vec![2, 3]);
    }

    #[test]
    fn test_name_remapper_layers() {
        let remapper = NameRemapper::new();
        assert_eq!(
            remapper.remap("layers.0.mixer.in_proj.weight"),
            "layers.0.input_proj"
        );
        assert_eq!(
            remapper.remap("layers.3.mixer.out_proj.weight"),
            "layers.3.output_proj"
        );
        assert_eq!(
            remapper.remap("layers.7.attn.q_proj.weight"),
            "layers.7.attention.q"
        );
        assert_eq!(
            remapper.remap("layers.7.attn.k_proj.weight"),
            "layers.7.attention.k"
        );
        assert_eq!(
            remapper.remap("layers.7.attn.v_proj.weight"),
            "layers.7.attention.v"
        );
        assert_eq!(
            remapper.remap("layers.7.attn.o_proj.weight"),
            "layers.7.attention.out"
        );
        assert_eq!(
            remapper.remap("layers.2.mlp.gate_proj.weight"),
            "layers.2.ff.gate"
        );
        assert_eq!(
            remapper.remap("layers.2.mlp.up_proj.weight"),
            "layers.2.ff.up"
        );
        assert_eq!(
            remapper.remap("layers.2.mlp.down_proj.weight"),
            "layers.2.ff.down"
        );
    }

    #[test]
    fn test_name_remapper_embedding() {
        let remapper = NameRemapper::new();
        assert_eq!(remapper.remap("embedding.weight"), "input_proj");
        assert_eq!(remapper.remap("lm_head.weight"), "output_proj");
    }

    #[test]
    fn test_name_remapper_backbone() {
        let remapper = NameRemapper::new();
        assert_eq!(
            remapper.remap("backbone.embeddings.weight"),
            "input_proj",
            "HuggingFace backbone.embeddings.weight should remap to input_proj"
        );
        assert_eq!(
            remapper.remap("backbone.norm_f.weight"),
            "final_norm.weight",
            "HuggingFace backbone.norm_f.weight should remap to final_norm.weight"
        );
        assert_eq!(
            remapper.remap("backbone.layers.0.mixer.in_proj.weight"),
            "layers.0.input_proj",
            "backbone-prefixed layer key should remap via the normal layer-suffix rules"
        );
        let raw_unknown = "backbone.something.unknown";
        assert_eq!(
            remapper.remap(raw_unknown),
            "something.unknown",
            "unknown backbone sub-key should pass through with backbone. prefix stripped"
        );
    }

    #[test]
    fn test_name_remapper_passthrough() {
        let remapper = NameRemapper::new();
        let unknown = "some.random.unknown.key";
        assert_eq!(remapper.remap(unknown), unknown);
        let another = "custom_layer_bias";
        assert_eq!(remapper.remap(another), another);
    }

    #[test]
    fn test_name_remapper_remap_map() {
        let remapper = NameRemapper::new();
        let mut weights = HashMap::new();
        weights.insert("embedding.weight".to_string(), vec![1.0f32, 2.0]);
        weights.insert("lm_head.weight".to_string(), vec![3.0f32, 4.0]);
        weights.insert("layers.0.attn.q_proj.weight".to_string(), vec![5.0f32]);
        let remapped = remapper.remap_map(weights);
        assert!(remapped.contains_key("input_proj"));
        assert!(remapped.contains_key("output_proj"));
        assert!(remapped.contains_key("layers.0.attention.q"));
    }

    // NOTE(review): this test exercises the NameRemapper layer rules;
    // WeightLoader::remap_name itself requires a constructed ModelLoader
    // (valid safetensors bytes) and is not covered here. The previously
    // declared-but-unused `mapping` HashMap has been removed.
    #[test]
    fn test_weight_loader_remap_name() {
        let remapper = NameRemapper::new();
        assert_eq!(
            remapper.remap("layers.1.mlp.gate_proj.weight"),
            "layers.1.ff.gate"
        );
    }
}