use std::collections::HashMap;
use std::fs::File;
use std::path::Path;
use memmap2::Mmap;
use serde::Deserialize;
#[derive(Debug, thiserror::Error)]
pub enum WeightFileError {
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("manifest JSON parse error: {0}")]
Json(#[from] serde_json::Error),
#[error("manifest missing 'tensors' key")]
MissingTensors,
#[error(
"tensor '{name}' span [{offset}, {end}) extends past mmap'd file ({file_size} bytes)"
)]
OutOfBounds {
name: String,
offset: u64,
end: u64,
file_size: u64,
},
}
/// Metadata describing where one tensor lives inside the weight blob and how
/// its bytes should be interpreted.
#[derive(Clone, Debug)]
pub struct TensorInfo {
    /// Byte offset of the tensor's first byte within the weight file.
    pub offset: u64,
    /// Length of the tensor's data, in bytes.
    pub size: u64,
    /// Dimensions of the tensor as listed in the manifest.
    pub shape: Vec<usize>,
    /// Element type name exactly as written in the manifest (e.g. `"U32"`).
    pub dtype: String,
    /// Bit width recorded for the tensor; when the manifest omits it, the
    /// manifest loader fills in a default.
    pub bits: i32,
}
/// A memory-mapped weight blob paired with the tensor table parsed from its
/// manifest.
pub struct WeightFile {
    /// Mapping of the binary weight file's contents.
    mmap: Mmap,
    /// Per-tensor metadata, keyed by tensor name.
    tensors: HashMap<String, TensorInfo>,
}
impl WeightFile {
    /// Opens a weight blob together with its JSON manifest.
    ///
    /// The file at `bin_path` is memory-mapped, the manifest at
    /// `manifest_path` is parsed into the tensor table, and every tensor span
    /// is validated against the mapped length before the handle is returned.
    /// On success two summary lines are written to stderr.
    ///
    /// # Errors
    ///
    /// * [`WeightFileError::Io`] if the blob cannot be opened or mapped.
    /// * Whatever error manifest parsing reports.
    /// * [`WeightFileError::OutOfBounds`] if any tensor's span extends past
    ///   the end of the mapped file.
    pub fn open(bin_path: &Path, manifest_path: &Path) -> Result<Self, WeightFileError> {
        let file = File::open(bin_path)?;
        // SAFETY: file-backed maps are only sound while the underlying file
        // is not modified or truncated for the lifetime of the mapping. This
        // code cannot enforce that; it relies on the weight file being
        // treated as immutable while a `WeightFile` exists.
        let mmap = unsafe { Mmap::map(&file)? };
        let tensors = parse_manifest(manifest_path)?;

        let wf = WeightFile { mmap, tensors };
        // Validate once up front so later lookups never see a span that
        // reaches outside the mapping.
        wf.bounds_check_all()?;

        eprintln!(
            "[weights] mmap'd {:.2} GB from {}",
            wf.mmap.len() as f64 / 1e9,
            bin_path.display()
        );
        eprintln!(
            "[manifest] Loaded {} tensors from {}",
            wf.tensors.len(),
            manifest_path.display()
        );
        Ok(wf)
    }

    /// Returns the length of the mapped weight file in bytes.
    pub fn file_size(&self) -> usize {
        self.mmap.len()
    }

    /// Returns the number of tensors listed in the manifest.
    pub fn len(&self) -> usize {
        self.tensors.len()
    }

    /// Returns `true` when the manifest lists no tensors.
    pub fn is_empty(&self) -> bool {
        self.tensors.is_empty()
    }

    /// Looks up the metadata for the tensor called `name`, if present.
    pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
        self.tensors.get(name)
    }

    /// Returns the raw bytes of the tensor called `name`.
    ///
    /// Yields `None` when the tensor is unknown. It also yields `None`
    /// (instead of panicking) if the recorded span cannot be addressed on
    /// this platform or falls outside the mapping.
    pub fn tensor_bytes(&self, name: &str) -> Option<&[u8]> {
        let info = self.tensor_info(name)?;
        let len = Self::to_index(info.size)?;
        // Share the checked slicing path with `bytes_at` so both accessors
        // agree on bounds handling.
        self.bytes_at(info.offset, len)
    }

    /// Returns `len` bytes of the mapping starting at byte `offset`.
    ///
    /// Yields `None` when `offset` does not fit in `usize`, when
    /// `offset + len` overflows, or when the span reaches past the end of the
    /// mapped file.
    pub fn bytes_at(&self, offset: u64, len: usize) -> Option<&[u8]> {
        // A plain `as usize` cast would silently truncate large offsets on
        // targets where `usize` is narrower than `u64`, returning bytes from
        // the wrong location; convert with an explicit range check instead.
        let start = Self::to_index(offset)?;
        let end = start.checked_add(len)?;
        self.mmap.get(start..end)
    }

    /// Iterates over `(name, info)` pairs for every tensor, in the hash
    /// map's (unspecified) order.
    pub fn iter(&self) -> impl Iterator<Item = (&str, &TensorInfo)> {
        self.tensors.iter().map(|(k, v)| (k.as_str(), v))
    }

    /// Verifies that every tensor's `[offset, offset + size)` span lies
    /// inside the mapped file, reporting the first violation found.
    fn bounds_check_all(&self) -> Result<(), WeightFileError> {
        let file_size = self.mmap.len() as u64;
        for (name, info) in &self.tensors {
            // Saturating keeps an overflowing span at `u64::MAX`, which is
            // still rejected below and still reportable in the error.
            let end = info.offset.saturating_add(info.size);
            if end > file_size {
                return Err(WeightFileError::OutOfBounds {
                    name: name.clone(),
                    offset: info.offset,
                    end,
                    file_size,
                });
            }
        }
        Ok(())
    }

    /// Converts a 64-bit manifest quantity into a platform index, returning
    /// `None` when it is not representable as `usize`.
    fn to_index(value: u64) -> Option<usize> {
        // Fully qualified so this compiles regardless of whether `TryFrom`
        // is in the edition's prelude.
        <usize as std::convert::TryFrom<u64>>::try_from(value).ok()
    }
}
/// Compact debug view: prints only the mapped length and the tensor count
/// rather than the mapping's contents or the full tensor table.
impl std::fmt::Debug for WeightFile {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mapped_len = self.mmap.len();
        let tensor_count = self.tensors.len();

        let mut builder = f.debug_struct("WeightFile");
        builder.field("file_size", &mapped_len);
        builder.field("num_tensors", &tensor_count);
        builder.finish()
    }
}
/// On-disk shape of the manifest document: a single top-level object whose
/// `tensors` member maps tensor names to their raw descriptions.
#[derive(Deserialize, Debug)]
struct RawManifest {
    /// Tensor descriptions keyed by tensor name.
    tensors: HashMap<String, RawTensor>,
}
/// One tensor entry exactly as it is written in the manifest JSON.
#[derive(Deserialize, Debug)]
struct RawTensor {
    /// Byte offset of the tensor within the weight file.
    offset: u64,
    /// Length of the tensor's data, in bytes.
    size: u64,
    /// Dimensions of the tensor.
    shape: Vec<usize>,
    /// Element type name (e.g. `"U32"`, `"BF16"`).
    dtype: String,
    /// Optional bit width; deserializes to `None` when the key is absent.
    #[serde(default)]
    bits: Option<i32>,
}
fn parse_manifest(
path: &Path,
) -> Result<HashMap<String, TensorInfo>, WeightFileError> {
let bytes = std::fs::read(path)?;
let raw: RawManifest = serde_json::from_slice(&bytes)?;
let mut out = HashMap::with_capacity(raw.tensors.len());
for (name, t) in raw.tensors {
let bits = t.bits.unwrap_or_else(|| {
if t.dtype == "U32" {
4
} else {
0
}
});
out.insert(
name,
TensorInfo {
offset: t.offset,
size: t.size,
shape: t.shape,
dtype: t.dtype,
bits,
},
);
}
Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deserializes a three-tensor manifest and checks that an explicit
    /// `bits` value is preserved while an omitted one comes back as `None`.
    /// Only the raw deserialization layer is exercised here.
    #[test]
    fn parse_manifest_with_and_without_bits() {
        let json = br#"{
"tensors": {
"weights_a": {
"offset": 0,
"size": 1024,
"shape": [16, 16],
"dtype": "U32",
"bits": 4
},
"weights_b_old": {
"offset": 1024,
"size": 2048,
"shape": [32, 16],
"dtype": "U32"
},
"scales": {
"offset": 3072,
"size": 64,
"shape": [32],
"dtype": "BF16"
}
}
}"#;
        let manifest = serde_json::from_slice::<RawManifest>(json).unwrap();
        assert_eq!(manifest.tensors.len(), 3);

        // (tensor name, expected raw `bits`) — both entries are U32.
        let expectations = [("weights_a", Some(4)), ("weights_b_old", None)];
        for (name, expected_bits) in expectations {
            let tensor = &manifest.tensors[name];
            assert_eq!(tensor.bits, expected_bits);
            assert_eq!(tensor.dtype, "U32");
        }
    }
}