use serde::Deserialize;
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::path::{Path, PathBuf};
use trustformers_core::{
errors::{invalid_format, runtime_error, Result, TrustformersError},
tensor::Tensor,
};
use super::config::{WeightDataType, WeightFormat, WeightLoadingConfig};
/// Shard index as stored in `pytorch_model.bin.index.json` (or synthesized
/// locally for single-file models by `create_single_file_index`).
#[derive(Debug, Deserialize)]
pub struct HuggingFaceIndex {
    /// Index-wide metadata block.
    pub metadata: HuggingFaceMetadata,
    /// Maps each tensor name to the weight file that stores it; a "*" key
    /// acts as a wildcard entry for single-file fallback indexes.
    pub weight_map: HashMap<String, String>,
}
/// Metadata section of a HuggingFace weight index.
#[derive(Debug, Deserialize)]
pub struct HuggingFaceMetadata {
    /// Total byte size recorded in the index (0 when the index was
    /// synthesized locally rather than read from disk).
    pub total_size: u64,
    /// Source format label, e.g. "safetensors" or "pytorch".
    pub format: String,
}
/// Parsed SafeTensors JSON header: an optional `__metadata__` map plus one
/// entry per tensor.
#[derive(Debug)]
pub struct SafeTensorsHeader {
    /// Contents of the optional "__metadata__" object, when present and
    /// parseable as a string-to-string map.
    pub metadata: Option<HashMap<String, String>>,
    /// Per-tensor dtype/shape/offset records, keyed by tensor name.
    pub tensors: HashMap<String, TensorInfo>,
}
impl<'de> serde::Deserialize<'de> for SafeTensorsHeader {
    /// Custom deserializer: the header is a flat JSON object holding one
    /// optional "__metadata__" entry plus one entry per tensor, so it is
    /// read generically first and then split.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let mut raw: HashMap<String, serde_json::Value> = HashMap::deserialize(deserializer)?;
        let metadata = raw
            .remove("__metadata__")
            .and_then(|v| serde_json::from_value(v).ok());
        // Entries that fail to parse as TensorInfo are silently skipped,
        // keeping the original lenient behavior for unknown header content.
        let mut tensors = HashMap::with_capacity(raw.len());
        for (name, value) in raw {
            if let Ok(info) = serde_json::from_value::<TensorInfo>(value) {
                tensors.insert(name, info);
            }
        }
        Ok(SafeTensorsHeader { metadata, tensors })
    }
}
/// One tensor entry in a SafeTensors header.
#[derive(Debug, Clone, Deserialize)]
pub struct TensorInfo {
    /// Dtype tag as written by SafeTensors, e.g. "F32", "F16", "I8".
    pub dtype: String,
    /// Tensor dimensions.
    pub shape: Vec<usize>,
    /// `[begin, end)` byte offsets, relative to the start of the data
    /// section (i.e. after the 8-byte length prefix and the JSON header).
    pub data_offsets: [u64; 2],
}
/// Heuristically recovered description of a tensor inside a PyTorch pickle
/// file (produced by `extract_pytorch_tensor_info_static`).
#[derive(Debug)]
struct PyTorchTensorInfo {
    /// Tensor dimensions — inferred from the tensor *name*, not the file.
    pub shape: Vec<usize>,
    /// Element type; the extractor currently always sets Float32.
    pub dtype: WeightDataType,
    /// Byte offset of the raw data within the in-memory file buffer.
    pub data_offset: usize,
}
/// Common interface for backends that read model weights from disk.
pub trait WeightLoader {
    /// Loads the named tensor, returning it fully materialized.
    fn load_tensor(&mut self, name: &str) -> Result<Tensor>;
    /// Lists the names of all tensors known to this loader.
    fn list_tensors(&self) -> Result<Vec<String>>;
    /// Returns metadata for the named tensor, if available.
    fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>>;
    /// Releases file handles and cached data.
    fn close(&mut self) -> Result<()>;
}
/// Shape, dtype and on-disk location details for a stored tensor.
#[derive(Debug, Clone)]
pub struct TensorMetadata {
    pub shape: Vec<usize>,
    pub dtype: WeightDataType,
    /// Size of the serialized tensor data in bytes.
    pub size_bytes: u64,
    /// Byte offset of the data within its weight file.
    pub offset: u64,
}
/// Deferred tensor handle: captures enough context (name, metadata, model
/// directory, config) to load the tensor later via a fresh
/// `HuggingFaceLoader` — see `LazyTensor::load`.
pub struct LazyTensor {
    name: String,
    // Currently unused by `load`, which re-resolves the file from the index.
    #[allow(dead_code)]
    filename: String,
    metadata: TensorMetadata,
    model_dir: PathBuf,
    config: WeightLoadingConfig,
}
/// Loads tensors from a HuggingFace-style model directory, supporting both
/// PyTorch `.bin` and SafeTensors single-file or sharded layouts.
pub struct HuggingFaceLoader {
    config: WeightLoadingConfig,
    /// Shard index, read from disk or synthesized for single-file models.
    index: HuggingFaceIndex,
    /// Cache of open, buffered file handles keyed by weight filename.
    file_handles: HashMap<String, BufReader<File>>,
    model_dir: PathBuf,
    /// Fully loaded tensors; populated only when lazy loading is disabled.
    tensor_cache: HashMap<String, Tensor>,
}
impl HuggingFaceLoader {
/// Creates a loader for `model_dir`, reading the sharded-model index when
/// present and otherwise synthesizing one from the single weight file.
pub fn new(model_dir: impl AsRef<Path>, config: WeightLoadingConfig) -> Result<Self> {
    let model_dir = model_dir.as_ref().to_path_buf();
    let index_path = model_dir.join("pytorch_model.bin.index.json");
    let index = match index_path.exists() {
        true => Self::load_index(&index_path)?,
        false => Self::create_single_file_index(&model_dir)?,
    };
    Ok(Self {
        config,
        index,
        model_dir,
        file_handles: HashMap::new(),
        tensor_cache: HashMap::new(),
    })
}
/// Parses a HuggingFace shard index JSON file.
fn load_index(path: &Path) -> Result<HuggingFaceIndex> {
    // Stream-parse the JSON straight from a buffered file handle.
    let reader = BufReader::new(File::open(path)?);
    serde_json::from_reader(reader).map_err(|e| {
        TrustformersError::weight_load_error(format!("Failed to parse HuggingFace index: {}", e))
    })
}
/// Builds an in-memory index for a directory holding a single weight file,
/// preferring `model.safetensors` over `pytorch_model.bin`.
fn create_single_file_index(model_dir: &Path) -> Result<HuggingFaceIndex> {
    let safetensors_path = model_dir.join("model.safetensors");
    let bin_path = model_dir.join("pytorch_model.bin");
    let (weight_file, is_safetensors) = if safetensors_path.exists() {
        ("model.safetensors", true)
    } else if bin_path.exists() {
        ("pytorch_model.bin", false)
    } else {
        return Err(TrustformersError::file_not_found(
            "No weight files found in model directory".to_string(),
        ));
    };
    let mut weight_map = HashMap::new();
    // PyTorch files always use the wildcard route; SafeTensors fall back to
    // it only when the header cannot be read.
    let mut use_wildcard = !is_safetensors;
    if is_safetensors {
        // Enumerate real tensor names so the synthesized index behaves
        // like a genuine one.
        match Self::read_safetensors_tensor_names(&model_dir.join(weight_file)) {
            Ok(names) => {
                for name in names {
                    weight_map.insert(name, weight_file.to_string());
                }
            },
            Err(e) => {
                eprintln!(
                    "Warning: Failed to read SafeTensors header: {}. Using fallback index.",
                    e
                );
                use_wildcard = true;
            },
        }
    }
    if use_wildcard {
        // "*" routes every tensor lookup to the single weight file.
        weight_map.insert("*".to_string(), weight_file.to_string());
    }
    let format = if is_safetensors { "safetensors" } else { "pytorch" };
    Ok(HuggingFaceIndex {
        metadata: HuggingFaceMetadata {
            total_size: 0,
            format: format.to_string(),
        },
        weight_map,
    })
}
fn read_safetensors_tensor_names(path: &Path) -> Result<Vec<String>> {
use std::io::Read;
let file = File::open(path)?;
let mut reader = BufReader::new(file);
let mut header_len_bytes = [0u8; 8];
reader.read_exact(&mut header_len_bytes)?;
let header_len = u64::from_le_bytes(header_len_bytes);
let mut header_bytes = vec![0u8; header_len as usize];
reader.read_exact(&mut header_bytes)?;
let header_str = String::from_utf8(header_bytes).map_err(|e| {
TrustformersError::weight_load_error(format!(
"Invalid UTF-8 in SafeTensors header: {}",
e
))
})?;
let header: serde_json::Value = serde_json::from_str(&header_str).map_err(|e| {
TrustformersError::weight_load_error(format!(
"Failed to parse SafeTensors header: {}",
e
))
})?;
let mut tensor_names = Vec::new();
if let Some(obj) = header.as_object() {
for (key, _value) in obj {
if key != "__metadata__" {
tensor_names.push(key.clone());
}
}
}
Ok(tensor_names)
}
fn get_file_handle(&mut self, filename: &str) -> Result<&mut BufReader<File>> {
if !self.file_handles.contains_key(filename) {
let file_path = self.model_dir.join(filename);
let file = File::open(&file_path)?;
let reader = BufReader::new(file);
self.file_handles.insert(filename.to_string(), reader);
}
self.file_handles.get_mut(filename).ok_or_else(|| {
TrustformersError::runtime_error(format!(
"File handle for {} not found after insertion",
filename
))
})
}
fn load_from_pytorch_bin(&mut self, name: &str, filename: &str) -> Result<Tensor> {
let reader = self.get_file_handle(filename)?;
let mut buffer = Vec::new();
reader.read_to_end(&mut buffer).map_err(|e| {
TrustformersError::weight_load_error(format!("Failed to read tensor file: {}", e))
})?;
match Self::parse_pytorch_pickle_static(&buffer, name) {
Ok(tensor) => Ok(tensor),
Err(e) => {
eprintln!(
"Warning: Pickle parsing failed for {}: {}. Attempting raw tensor parsing.",
name, e
);
Self::parse_raw_tensor_data_static(&buffer, name)
},
}
}
/// Reads the remainder of `reader` and parses a tensor out of it, using
/// the pickle heuristic with a raw-data fallback. (Currently unused.)
#[allow(dead_code)]
fn parse_pytorch_tensor(&mut self, reader: &mut BufReader<File>, name: &str) -> Result<Tensor> {
    // Slurp the stream from its current position.
    let mut bytes = Vec::new();
    if let Err(e) = reader.read_to_end(&mut bytes) {
        return Err(TrustformersError::weight_load_error(format!(
            "Failed to read tensor file: {}",
            e
        )));
    }
    // Pickle heuristic first; on failure, warn and fall back to scanning
    // for raw little-endian f32 data.
    Self::parse_pytorch_pickle_static(&bytes, name).or_else(|e| {
        eprintln!(
            "Warning: Pickle parsing failed for {}: {}. Attempting raw tensor parsing.",
            name, e
        );
        Self::parse_raw_tensor_data_static(&bytes, name)
    })
}
/// Instance-method convenience wrapper around
/// [`Self::parse_pytorch_pickle_static`].
#[allow(dead_code)]
fn parse_pytorch_pickle(&self, data: &[u8], name: &str) -> Result<Tensor> {
    Self::parse_pytorch_pickle_static(data, name)
}
fn parse_pytorch_pickle_static(data: &[u8], name: &str) -> Result<Tensor> {
if data.len() < 8 {
return Err(TrustformersError::weight_load_error(
"File too small to contain tensor data".to_string(),
));
}
if let Some(tensor_info) = Self::extract_pytorch_tensor_info_static(data, name) {
let offset = tensor_info.data_offset;
let shape = tensor_info.shape;
let dtype = tensor_info.dtype;
let total_elements: usize = shape.iter().product();
match dtype {
WeightDataType::Float32 => {
let data_size = total_elements * 4;
if offset + data_size <= data.len() {
let tensor_data = &data[offset..offset + data_size];
let float_data: Vec<f32> = tensor_data
.chunks_exact(4)
.map(|chunk| {
f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]])
})
.collect();
Tensor::from_vec(float_data, &shape).map_err(|e| {
TrustformersError::weight_load_error(format!(
"Failed to create tensor: {}",
e
))
})
} else {
Err(TrustformersError::weight_load_error(
"Insufficient data for tensor".to_string(),
))
}
},
WeightDataType::Float16 => {
let data_size = total_elements * 2;
if offset + data_size <= data.len() {
let tensor_data = &data[offset..offset + data_size];
let float_data: Vec<f32> = tensor_data
.chunks_exact(2)
.map(|chunk| {
let half_val = half::f16::from_le_bytes([chunk[0], chunk[1]]);
half_val.to_f32()
})
.collect();
Tensor::from_vec(float_data, &shape).map_err(|e| {
TrustformersError::weight_load_error(format!(
"Failed to create tensor: {}",
e
))
})
} else {
Err(TrustformersError::weight_load_error(
"Insufficient data for tensor".to_string(),
))
}
},
_ => Err(TrustformersError::weight_load_error(format!(
"Unsupported tensor dtype: {:?}",
dtype
))),
}
} else {
Err(TrustformersError::weight_load_error(
"Could not extract tensor information from pickle data".to_string(),
))
}
}
/// Instance-method convenience wrapper around
/// [`Self::extract_pytorch_tensor_info_static`].
#[allow(dead_code)]
fn extract_pytorch_tensor_info(&self, data: &[u8], name: &str) -> Option<PyTorchTensorInfo> {
    Self::extract_pytorch_tensor_info_static(data, name)
}
/// Heuristically locates tensor data inside a PyTorch pickle buffer.
///
/// Shapes come from the tensor *name* (BERT-base presets in
/// `infer_tensor_shape_static`); dtype is assumed f32; the data offset is
/// found by scanning for a plausible float window.
///
/// Bug fixes:
/// - The scan previously sliced `&data[i..i.min(i + 16)]`, which is always
///   the empty slice `data[i..i]`, so no window could ever match and the
///   code always fell through to the 1024-byte fallback.
/// - A genuine match at offset 0 was indistinguishable from "not found"
///   and got overridden by the fallback; an `Option` now tracks the scan
///   result explicitly.
fn extract_pytorch_tensor_info_static(data: &[u8], name: &str) -> Option<PyTorchTensorInfo> {
    let shape = Self::infer_tensor_shape_static(name);
    let dtype = WeightDataType::Float32;
    // Scan for the first 16-byte window that plausibly starts float data.
    let mut found_offset = None;
    for i in 0..data.len().saturating_sub(16) {
        if Self::looks_like_tensor_data_static(&data[i..i + 16]) {
            found_offset = Some(i);
            break;
        }
    }
    // Fallback: assume a fixed-size pickle preamble on larger files.
    let data_offset = match found_offset {
        Some(offset) => offset,
        None if data.len() > 1024 => 1024,
        None => 0,
    };
    Some(PyTorchTensorInfo {
        shape,
        dtype,
        data_offset,
    })
}
/// Instance-method convenience wrapper around
/// [`Self::infer_tensor_shape_static`].
#[allow(dead_code)]
fn infer_tensor_shape(&self, name: &str) -> Vec<usize> {
    Self::infer_tensor_shape_static(name)
}
/// Guesses a tensor's shape from its name using BERT-base presets
/// (hidden size 768, vocab 30522, 512 positions, 3072 intermediate).
fn infer_tensor_shape_static(name: &str) -> Vec<usize> {
    // Priority-ordered rule table; more specific patterns come first so
    // e.g. "attention.output.dense.weight" wins over the generic
    // "output.dense.weight".
    const RULES: &[(&[&str], &[usize])] = &[
        (&["embeddings.word_embeddings.weight"], &[30522, 768]),
        (&["embeddings.position_embeddings.weight"], &[512, 768]),
        (
            &[
                "attention.self.query.weight",
                "attention.self.key.weight",
                "attention.self.value.weight",
            ],
            &[768, 768],
        ),
        (&["attention.output.dense.weight"], &[768, 768]),
        (&["intermediate.dense.weight"], &[768, 3072]),
        (&["output.dense.weight"], &[3072, 768]),
        (&["LayerNorm.weight", "LayerNorm.bias"], &[768]),
        (&["bias"], &[768]),
    ];
    for (patterns, shape) in RULES {
        if patterns.iter().any(|p| name.contains(p)) {
            return shape.to_vec();
        }
    }
    // Default: square hidden-size matrix.
    vec![768, 768]
}
/// Instance-method convenience wrapper around
/// [`Self::looks_like_tensor_data_static`].
#[allow(dead_code)]
fn looks_like_tensor_data(&self, chunk: &[u8]) -> bool {
    Self::looks_like_tensor_data_static(chunk)
}
/// Heuristic: does this byte window plausibly begin raw f32 weight data?
///
/// The first four bytes are read as a little-endian f32 and accepted when
/// the value is finite and of modest magnitude (|x| < 100), which is
/// typical of trained-network weights. Windows shorter than four bytes
/// are rejected.
fn looks_like_tensor_data_static(chunk: &[u8]) -> bool {
    match chunk {
        [a, b, c, d, ..] => {
            let value = f32::from_le_bytes([*a, *b, *c, *d]);
            value.is_finite() && value.abs() < 100.0
        },
        _ => false,
    }
}
/// Instance-method convenience wrapper around
/// [`Self::parse_raw_tensor_data_static`].
#[allow(dead_code)]
fn parse_raw_tensor_data(&self, data: &[u8], name: &str) -> Result<Tensor> {
    Self::parse_raw_tensor_data_static(data, name)
}
/// Last-resort fallback: scans for raw little-endian f32 data matching the
/// name-inferred shape by sliding a word-aligned window over the first
/// 1 KiB of the buffer.
fn parse_raw_tensor_data_static(data: &[u8], name: &str) -> Result<Tensor> {
    let shape = Self::infer_tensor_shape_static(name);
    let expected_size: usize = shape.iter().product::<usize>() * 4;
    if data.len() >= expected_size {
        let scan_limit = data.len().min(1024);
        for offset in (0..scan_limit).step_by(4) {
            if offset + expected_size > data.len() {
                continue;
            }
            let window = &data[offset..offset + expected_size];
            let floats: Vec<f32> = window
                .chunks_exact(4)
                .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                .collect();
            // Accept the window if any decoded value looks like a
            // plausible weight.
            if floats.iter().any(|&x| x.is_finite() && x.abs() < 100.0) {
                if let Ok(tensor) = Tensor::from_vec(floats, &shape) {
                    return Ok(tensor);
                }
            }
        }
    }
    Err(TrustformersError::weight_load_error(format!(
        "Could not parse tensor data for {}",
        name
    )))
}
/// Builds a deferred handle for `name` without reading any tensor bytes;
/// the handle captures enough context for a later self-contained load.
#[allow(dead_code)]
fn load_lazy(&mut self, name: &str) -> Result<LazyTensor> {
    let filename = self.find_tensor_file(name)?;
    let metadata = self.get_tensor_metadata(name, &filename)?;
    let lazy = LazyTensor {
        name: name.to_owned(),
        filename,
        metadata,
        model_dir: self.model_dir.clone(),
        config: self.config.clone(),
    };
    Ok(lazy)
}
/// Resolves which weight file stores `name`: exact index match first, then
/// the "*" wildcard written by `create_single_file_index`.
fn find_tensor_file(&self, name: &str) -> Result<String> {
    self.index
        .weight_map
        .get(name)
        .or_else(|| self.index.weight_map.get("*"))
        .cloned()
        .ok_or_else(|| runtime_error(format!("Tensor not found: {}", name)))
}
/// Placeholder metadata lookup.
///
/// NOTE(review): both arguments are ignored and a hard-coded 1024x768 f32
/// descriptor is returned — it does not reflect the real tensor. Callers
/// (`tensor_info`, `load_lazy`) therefore receive fabricated shapes and
/// sizes. TODO: read real metadata from the weight-file header.
fn get_tensor_metadata(&self, _name: &str, _filename: &str) -> Result<TensorMetadata> {
    Ok(TensorMetadata {
        shape: vec![1024, 768],
        dtype: WeightDataType::Float32,
        size_bytes: 1024 * 768 * 4,
        offset: 0,
    })
}
/// Decides the weight format purely from the file suffix.
fn detect_format(&self, filename: &str) -> Result<WeightFormat> {
    // The two suffixes are mutually exclusive, so check order is moot.
    let format = if filename.ends_with(".safetensors") {
        Some(WeightFormat::SafeTensors)
    } else if filename.ends_with(".bin") {
        Some(WeightFormat::HuggingFaceBin)
    } else {
        None
    };
    format.ok_or_else(|| {
        invalid_format(
            "file format",
            format!("Unknown format for file: {}", filename),
        )
    })
}
/// Dispatch shim: SafeTensors loading is implemented in
/// `load_safetensors_tensor_complete`.
fn load_from_safetensors(&mut self, name: &str, filename: &str) -> Result<Tensor> {
    self.load_safetensors_tensor_complete(name, filename)
}
/// Loads one tensor from a SafeTensors file: parses the header, seeks to
/// the tensor's data span, and converts the bytes to an f32 `Tensor`.
///
/// Bug fixes:
/// - The header preview used `&header_str[..500]`, which panics when byte
///   500 falls inside a multi-byte UTF-8 sequence; the preview is now
///   truncated on a char boundary.
/// - `data_offsets[1] - data_offsets[0]` panicked (debug) or wrapped
///   (release) on a malformed header with `end < begin`; `checked_sub`
///   now turns that into an error.
///
/// # Errors
/// Fails on I/O errors, an invalid header, or an unknown tensor name.
fn load_safetensors_tensor_complete(&mut self, name: &str, filename: &str) -> Result<Tensor> {
    let file_path = self.model_dir.join(filename);
    eprintln!(
        "[SAFETENSORS DEBUG] Loading tensor '{}' from file: {:?}",
        name, file_path
    );
    let file = File::open(&file_path)?;
    let mut reader = BufReader::new(file);
    // SafeTensors layout: u64 LE header length, JSON header, raw data.
    let mut header_len_bytes = [0u8; 8];
    reader.read_exact(&mut header_len_bytes)?;
    let header_len = u64::from_le_bytes(header_len_bytes);
    eprintln!("[SAFETENSORS DEBUG] Header length: {} bytes", header_len);
    let mut header_bytes = vec![0u8; header_len as usize];
    reader.read_exact(&mut header_bytes)?;
    let header_str = std::str::from_utf8(&header_bytes).map_err(|e| {
        TrustformersError::weight_load_error(format!(
            "Invalid UTF-8 in SafeTensors header: {}",
            e
        ))
    })?;
    // Char-boundary-safe truncation: `&s[..500]` would panic mid-codepoint.
    let preview: String = header_str.chars().take(500).collect();
    eprintln!(
        "[SAFETENSORS DEBUG] Header preview (first 500 chars): {}",
        preview
    );
    let header: SafeTensorsHeader = serde_json::from_str(header_str).map_err(|e| {
        eprintln!("[SAFETENSORS DEBUG] Failed to parse header, printing full header:");
        eprintln!("{}", header_str);
        TrustformersError::serialization_error(format!(
            "Failed to parse SafeTensors header: {}",
            e
        ))
    })?;
    let tensor_info = header
        .tensors
        .get(name)
        .ok_or_else(|| runtime_error(format!("Tensor not found: {}", name)))?;
    // Header offsets are relative to the start of the data section, which
    // begins after the 8-byte length prefix and the header itself.
    let tensor_data_start = 8 + header_len;
    let [begin, end] = tensor_info.data_offsets;
    let data_len = end.checked_sub(begin).ok_or_else(|| {
        TrustformersError::weight_load_error(format!(
            "Malformed data_offsets for tensor {}: [{}, {}]",
            name, begin, end
        ))
    })? as usize;
    reader.seek(SeekFrom::Start(tensor_data_start + begin))?;
    let mut data = vec![0u8; data_len];
    reader.read_exact(&mut data)?;
    self.bytes_to_tensor(data, &tensor_info.dtype, &tensor_info.shape)
}
/// Reads and parses a SafeTensors header from the reader's current
/// position: a u64 little-endian length prefix followed by exactly that
/// many bytes of JSON. (Currently unused.)
#[allow(dead_code)]
fn parse_safetensors_header(
    &mut self,
    reader: &mut BufReader<File>,
) -> Result<SafeTensorsHeader> {
    let mut len_buf = [0u8; 8];
    reader.read_exact(&mut len_buf)?;
    let header_len = u64::from_le_bytes(len_buf) as usize;
    let mut raw = vec![0u8; header_len];
    reader.read_exact(&mut raw)?;
    let text = std::str::from_utf8(&raw).map_err(|e| {
        TrustformersError::weight_load_error(format!(
            "Invalid UTF-8 in SafeTensors header: {}",
            e
        ))
    })?;
    serde_json::from_str(text).map_err(|e| {
        TrustformersError::serialization_error(format!(
            "Failed to parse SafeTensors header: {}",
            e
        ))
    })
}
/// Reads one tensor's byte span given its header entry and converts it.
/// (Currently unused.)
///
/// NOTE(review): this seeks to `data_offsets[0]` as an *absolute* file
/// position, whereas `load_safetensors_tensor_complete` treats the offsets
/// as relative to the data section and adds `8 + header_len` first. If a
/// caller hands over a reader positioned at the file start, this reads
/// from the wrong place — confirm the intended caller contract before
/// using this helper.
#[allow(dead_code)]
fn load_safetensors_tensor(
    &mut self,
    reader: &mut BufReader<File>,
    info: &TensorInfo,
) -> Result<Tensor> {
    reader.seek(SeekFrom::Start(info.data_offsets[0]))?;
    let data_len = (info.data_offsets[1] - info.data_offsets[0]) as usize;
    let mut data = vec![0u8; data_len];
    reader.read_exact(&mut data)?;
    self.bytes_to_tensor(data, &info.dtype, &info.shape)
}
/// Converts raw little-endian tensor bytes into an f32 `Tensor`.
///
/// Supported SafeTensors dtype tags: "F32", "F16" (widened to f32), and
/// "I8" (each byte reinterpreted as i8, then cast to f32).
///
/// # Errors
/// Returns an `invalid_format` error for any other dtype tag, and
/// propagates `Tensor::from_vec` failures.
fn bytes_to_tensor(&self, data: Vec<u8>, dtype: &str, shape: &[usize]) -> Result<Tensor> {
    match dtype {
        "F32" => {
            let floats: Vec<f32> = data
                .chunks_exact(4)
                .map(|chunk| f32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]))
                .collect();
            Tensor::from_vec(floats, shape)
        },
        "F16" => {
            let floats: Vec<f32> = data
                .chunks_exact(2)
                .map(|chunk| half::f16::from_le_bytes([chunk[0], chunk[1]]).to_f32())
                .collect();
            Tensor::from_vec(floats, shape)
        },
        "I8" => {
            // Single pass byte -> i8 (sign reinterpret) -> f32, avoiding
            // the intermediate Vec<i8> the previous version allocated.
            let floats: Vec<f32> = data.into_iter().map(|b| b as i8 as f32).collect();
            Tensor::from_vec(floats, shape)
        },
        _ => Err(invalid_format(
            "dtype",
            format!("Unsupported dtype: {}", dtype),
        )),
    }
}
}
impl WeightLoader for HuggingFaceLoader {
    /// Loads a tensor by name, serving from the cache when possible and
    /// dispatching on the weight file's detected format otherwise.
    fn load_tensor(&mut self, name: &str) -> Result<Tensor> {
        if let Some(cached) = self.tensor_cache.get(name) {
            return Ok(cached.clone());
        }
        let filename = self.find_tensor_file(name)?;
        let format = self.detect_format(&filename)?;
        let tensor = match format {
            WeightFormat::HuggingFaceBin => self.load_from_pytorch_bin(name, &filename)?,
            WeightFormat::SafeTensors => self.load_from_safetensors(name, &filename)?,
            _ => return Err(invalid_format("weight format", "Unsupported weight format")),
        };
        // Eager mode keeps a copy for future lookups; lazy mode re-reads.
        if !self.config.lazy_loading {
            self.tensor_cache.insert(name.to_string(), tensor.clone());
        }
        Ok(tensor)
    }

    /// Names come straight from the (possibly synthesized) weight map.
    fn list_tensors(&self) -> Result<Vec<String>> {
        let names: Vec<String> = self.index.weight_map.keys().cloned().collect();
        Ok(names)
    }

    fn tensor_info(&self, name: &str) -> Result<Option<TensorMetadata>> {
        let filename = self.find_tensor_file(name)?;
        let metadata = self.get_tensor_metadata(name, &filename)?;
        Ok(Some(metadata))
    }

    /// Dropping the readers closes the underlying files.
    fn close(&mut self) -> Result<()> {
        self.file_handles.clear();
        self.tensor_cache.clear();
        Ok(())
    }
}
impl LazyTensor {
    /// Materializes the tensor by constructing a short-lived loader over
    /// the captured model directory and configuration.
    pub fn load(&self) -> Result<Tensor> {
        HuggingFaceLoader::new(&self.model_dir, self.config.clone())
            .and_then(|mut loader| loader.load_tensor(&self.name))
    }

    /// Metadata captured when this lazy handle was created.
    pub fn metadata(&self) -> &TensorMetadata {
        &self.metadata
    }
}
// Unit tests: config/format value semantics, metadata plumbing, loader
// construction failure modes, and SafeTensors header parsing. No real
// weight files are read.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::weight_loading::config::{WeightDataType, WeightFormat, WeightLoadingConfig};

    // Identical enum variants compare equal.
    #[test]
    fn test_weight_format_equality() {
        let f1 = WeightFormat::SafeTensors;
        let f2 = WeightFormat::SafeTensors;
        assert_eq!(f1, f2);
    }

    // The Custom variant carries its format-name payload.
    #[test]
    fn test_weight_format_custom() {
        let f = WeightFormat::Custom("msgpack".to_string());
        if let WeightFormat::Custom(name) = &f {
            assert_eq!(name, "msgpack");
        } else {
            panic!("Expected Custom variant");
        }
    }

    // Distinct variants compare unequal.
    #[test]
    fn test_weight_format_hf_bin_ne_safetensors() {
        assert_ne!(WeightFormat::HuggingFaceBin, WeightFormat::SafeTensors);
    }

    // Defaults: eager loading, no mmap/streaming, CPU device, checksums on.
    #[test]
    fn test_weight_loading_config_default() {
        let cfg = WeightLoadingConfig::default();
        assert!(!cfg.lazy_loading);
        assert!(!cfg.memory_mapped);
        assert!(!cfg.streaming);
        assert_eq!(cfg.device, "cpu");
        assert!(cfg.verify_checksums);
    }

    // Struct-update syntax overrides a single flag.
    #[test]
    fn test_weight_loading_config_lazy_loading_flag() {
        let cfg = WeightLoadingConfig {
            lazy_loading: true,
            ..WeightLoadingConfig::default()
        };
        assert!(cfg.lazy_loading);
    }

    #[test]
    fn test_weight_loading_config_with_format() {
        let cfg = WeightLoadingConfig {
            format: Some(WeightFormat::SafeTensors),
            ..WeightLoadingConfig::default()
        };
        if let Some(WeightFormat::SafeTensors) = &cfg.format {
        } else {
            panic!("Expected SafeTensors format");
        }
    }

    #[test]
    fn test_tensor_metadata_construction() {
        let meta = TensorMetadata {
            shape: vec![128, 256],
            dtype: WeightDataType::Float32,
            size_bytes: 128 * 256 * 4,
            offset: 0,
        };
        assert_eq!(meta.shape, vec![128, 256]);
        assert_eq!(meta.size_bytes, 131072);
    }

    #[test]
    fn test_tensor_metadata_clone() {
        let meta = TensorMetadata {
            shape: vec![32, 64],
            dtype: WeightDataType::Float16,
            size_bytes: 32 * 64 * 2,
            offset: 1024,
        };
        let cloned = meta.clone();
        assert_eq!(cloned.shape, meta.shape);
        assert_eq!(cloned.offset, meta.offset);
    }

    #[test]
    fn test_tensor_info_clone() {
        let info = TensorInfo {
            dtype: "F32".to_string(),
            shape: vec![64, 128],
            data_offsets: [0, 32768],
        };
        let cloned = info.clone();
        assert_eq!(cloned.dtype, "F32");
        assert_eq!(cloned.shape, vec![64, 128]);
        assert_eq!(cloned.data_offsets, [0, 32768]);
    }

    // Construction should fail fast when the directory does not exist.
    #[test]
    fn test_huggingface_loader_nonexistent_dir() {
        let cfg = WeightLoadingConfig::default();
        let result = HuggingFaceLoader::new("/nonexistent/model/dir", cfg);
        assert!(
            result.is_err(),
            "Expected error for nonexistent model directory"
        );
    }

    // ... and when the directory exists but holds no weight files.
    #[test]
    fn test_huggingface_loader_empty_dir() {
        let dir = std::env::temp_dir().join("trustformers_hf_test_empty_dir");
        let _ = std::fs::create_dir_all(&dir);
        let cfg = WeightLoadingConfig::default();
        let result = HuggingFaceLoader::new(&dir, cfg);
        assert!(
            result.is_err(),
            "Expected error for directory with no weight files"
        );
        let _ = std::fs::remove_dir_all(&dir);
    }

    // A header containing only "__metadata__" parses to zero tensors.
    #[test]
    fn test_safetensors_header_empty() {
        let json = r#"{"__metadata__": {}}"#;
        let result: std::result::Result<SafeTensorsHeader, _> = serde_json::from_str(json);
        assert!(
            result.is_ok(),
            "Empty SafeTensors header parse failed: {:?}",
            result.err()
        );
        let header = result.expect("expected Ok");
        assert!(header.tensors.is_empty());
    }

    // Metadata and a tensor entry are split into the right fields.
    #[test]
    fn test_safetensors_header_with_tensor() {
        let json = r#"{
"__metadata__": {"format": "pt"},
"model.weight": {"dtype": "F32", "shape": [64, 128], "data_offsets": [0, 32768]}
}"#;
        let result: std::result::Result<SafeTensorsHeader, _> = serde_json::from_str(json);
        assert!(
            result.is_ok(),
            "SafeTensors header parse failed: {:?}",
            result.err()
        );
        let header = result.expect("expected Ok");
        assert!(header.tensors.contains_key("model.weight"));
        assert!(header.metadata.is_some());
    }

    // Multiple tensors, no metadata entry.
    #[test]
    fn test_safetensors_header_multiple_tensors() {
        let json = r#"{
"layer.0.weight": {"dtype": "F16", "shape": [256, 512], "data_offsets": [0, 262144]},
"layer.0.bias": {"dtype": "F32", "shape": [256], "data_offsets": [262144, 263168]}
}"#;
        let result: std::result::Result<SafeTensorsHeader, _> = serde_json::from_str(json);
        assert!(result.is_ok());
        let header = result.expect("expected Ok");
        assert_eq!(header.tensors.len(), 2);
        assert!(header.metadata.is_none());
    }
}