use std::collections::BTreeMap;
use std::fs;
use std::io::Write;
use std::path::Path;
use ndarray::{Array1, Array2, Array3, Array4};
use safetensors::SafeTensors;
use safetensors::tensor::{Dtype, TensorView, serialize_to_file};
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors produced while building, querying, loading, or saving a [`StateDict`].
#[derive(Debug, Error)]
pub enum StateDictError {
/// No tensor with the requested name exists in the dict.
#[error("missing key: {0}")]
MissingKey(String),
/// A tensor exists, but its stored shape differs from the shape the caller asked for.
#[error("shape mismatch for {key}: expected {expected:?}, got {actual:?}")]
ShapeMismatch {
/// Name of the offending tensor.
key: String,
/// Shape requested by the caller.
expected: Vec<usize>,
/// Shape actually stored in the dict.
actual: Vec<usize>,
},
/// Filesystem failure; the legacy loader also reports tensor data lying outside the
/// `.bin` file this way.
#[error("io error: {0}")]
Io(#[from] std::io::Error),
/// The legacy JSON header could not be parsed or serialized.
#[error("json error: {0}")]
Json(#[from] serde_json::Error),
/// A failure reported by the `safetensors` crate (stringified), or a tensor whose byte
/// length disagrees with its shape.
#[error("safetensors error: {0}")]
Safetensors(String),
/// A tensor whose element type is not f32 was encountered; holds the dtype's name.
#[error("unsupported dtype {0:?} (only f32 is currently supported)")]
UnsupportedDtype(String),
}
/// One tensor entry in the legacy JSON header (see [`Header`]).
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct TensorMeta {
/// Tensor name; becomes the key in [`StateDict::tensors`] on load.
pub name: String,
/// Tensor dimensions; the element count is their product.
pub shape: Vec<usize>,
/// Byte offset of this tensor's data within the companion `.bin` file.
pub offset: u64,
/// Element type name. Only `"f32"` is accepted on load; a missing value defaults to `"f32"`.
#[serde(default = "default_dtype")]
pub dtype: String,
}
/// Serde default for `TensorMeta::dtype`: header entries that omit `dtype` are read as f32.
fn default_dtype() -> String {
    String::from("f32")
}
/// Root object of the legacy JSON header file: one entry per tensor stored in the
/// companion `.bin` file.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Header {
/// Tensor entries, each locating its data by byte offset in the `.bin` file.
pub tensors: Vec<TensorMeta>,
}
/// An in-memory collection of named `f32` tensors.
///
/// Each entry maps a tensor name to `(shape, flat data)`. The map is a `BTreeMap`, so
/// iteration follows sorted name order. [`StateDict::insert`] checks that the data length
/// equals the product of the shape; writing to `tensors` directly bypasses that check.
#[derive(Debug, Clone, Default)]
pub struct StateDict {
pub tensors: BTreeMap<String, (Vec<usize>, Vec<f32>)>,
}
impl StateDict {
    /// Creates an empty state dict.
    pub fn new() -> Self {
        Self::default()
    }

    /// Inserts tensor `name` with the given `shape` and flat `data`, replacing any existing
    /// tensor of the same name.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` is not the product of `shape` (an empty shape `[]` therefore
    /// requires exactly one element).
    pub fn insert(&mut self, name: impl Into<String>, shape: Vec<usize>, data: Vec<f32>) {
        let n: usize = shape.iter().product();
        assert_eq!(
            n,
            data.len(),
            "data length {} ≠ product of shape {shape:?}",
            data.len()
        );
        self.tensors.insert(name.into(), (shape, data));
    }

    /// Looks up tensor `name`, returning its `(shape, data)` pair.
    ///
    /// # Errors
    ///
    /// [`StateDictError::MissingKey`] if no tensor with that name exists.
    pub fn get(&self, name: &str) -> Result<&(Vec<usize>, Vec<f32>), StateDictError> {
        self.tensors
            .get(name)
            .ok_or_else(|| StateDictError::MissingKey(name.into()))
    }

    /// Borrows the flat data of tensor `name` after checking that its shape is exactly
    /// `expected`. Despite the name, nothing is removed from the dict.
    ///
    /// # Errors
    ///
    /// [`StateDictError::MissingKey`] if the tensor is absent, or
    /// [`StateDictError::ShapeMismatch`] if its stored shape differs from `expected`.
    pub fn take_with_shape(
        &self,
        name: &str,
        expected: &[usize],
    ) -> Result<&[f32], StateDictError> {
        let (shape, data) = self.get(name)?;
        if shape != expected {
            return Err(StateDictError::ShapeMismatch {
                key: name.into(),
                expected: expected.to_vec(),
                actual: shape.clone(),
            });
        }
        Ok(data.as_slice())
    }

    /// Copies the rank-1 tensor `name` of length `n` into a new `Vec`.
    ///
    /// # Errors
    ///
    /// Same as [`StateDict::take_with_shape`] with expected shape `[n]`.
    pub fn take_vec(&self, name: &str, n: usize) -> Result<Vec<f32>, StateDictError> {
        self.take_with_shape(name, &[n]).map(<[f32]>::to_vec)
    }

    /// Copies the rank-1 tensor `name` of length `n` into an [`Array1`].
    ///
    /// # Errors
    ///
    /// Same as [`StateDict::take_with_shape`] with expected shape `[n]`.
    pub fn take_array1(&self, name: &str, n: usize) -> Result<Array1<f32>, StateDictError> {
        Ok(Array1::from_vec(self.take_vec(name, n)?))
    }

    /// Copies tensor `name` into a `rows × cols` [`Array2`], reading the flat data in the
    /// standard (row-major) layout used by `from_shape_vec`.
    ///
    /// # Errors
    ///
    /// Same as [`StateDict::take_with_shape`] with expected shape `[rows, cols]`.
    pub fn take_array2(
        &self,
        name: &str,
        rows: usize,
        cols: usize,
    ) -> Result<Array2<f32>, StateDictError> {
        let data = self.take_with_shape(name, &[rows, cols])?.to_vec();
        // The shape check above guarantees `data.len() == rows * cols`.
        Ok(Array2::from_shape_vec((rows, cols), data)
            .expect("shape check passed but from_shape_vec failed"))
    }

    /// Copies tensor `name` into a `d0 × d1 × d2` [`Array3`] (standard row-major layout).
    ///
    /// # Errors
    ///
    /// Same as [`StateDict::take_with_shape`] with expected shape `[d0, d1, d2]`.
    pub fn take_array3(
        &self,
        name: &str,
        d0: usize,
        d1: usize,
        d2: usize,
    ) -> Result<Array3<f32>, StateDictError> {
        let data = self.take_with_shape(name, &[d0, d1, d2])?.to_vec();
        // The shape check above guarantees `data.len() == d0 * d1 * d2`.
        Ok(Array3::from_shape_vec((d0, d1, d2), data)
            .expect("shape check passed but from_shape_vec failed"))
    }

    /// Copies tensor `name` into a `d0 × d1 × d2 × d3` [`Array4`] (standard row-major
    /// layout).
    ///
    /// # Errors
    ///
    /// Same as [`StateDict::take_with_shape`] with expected shape `[d0, d1, d2, d3]`.
    pub fn take_array4(
        &self,
        name: &str,
        d0: usize,
        d1: usize,
        d2: usize,
        d3: usize,
    ) -> Result<Array4<f32>, StateDictError> {
        let data = self.take_with_shape(name, &[d0, d1, d2, d3])?.to_vec();
        // The shape check above guarantees `data.len() == d0 * d1 * d2 * d3`.
        Ok(Array4::from_shape_vec((d0, d1, d2, d3), data)
            .expect("shape check passed but from_shape_vec failed"))
    }

    /// Iterates over tensor names that start with `prefix`, in sorted order.
    pub fn keys_with_prefix<'a>(&'a self, prefix: &'a str) -> impl Iterator<Item = &'a str> + 'a {
        self.tensors
            .keys()
            .filter(move |k| k.starts_with(prefix))
            .map(String::as_str)
    }
}
/// Builds a [`StateDict`] from `(name, shape, data)` triples.
///
/// A later triple with the same name as an earlier one replaces it.
///
/// # Panics
///
/// Panics (via [`StateDict::insert`]) if any triple's data length does not equal the
/// product of its shape.
pub fn from_named_tensors(triples: Vec<(String, Vec<usize>, Vec<f32>)>) -> StateDict {
    triples
        .into_iter()
        .fold(StateDict::new(), |mut acc, (name, shape, data)| {
            acc.insert(name, shape, data);
            acc
        })
}
/// Loads a [`StateDict`] from disk, choosing the format from `path`'s extension.
///
/// * `*.safetensors` — read as a safetensors file.
/// * `*.json` — read as a legacy JSON header, with the raw data in the sibling `*.bin`.
/// * anything else — if the sibling `*.safetensors` exists, load that. Otherwise, if `path`
///   is the legacy data file (`*.bin`) or does not exist, and a sibling `*.json` header
///   exists, load that legacy pair. Otherwise treat `path` itself as a legacy JSON header.
///
/// Note: a safetensors file written under some other extension (which [`save`] allows) is
/// not recognised here; it is parsed as a legacy JSON header and fails.
///
/// # Errors
///
/// Propagates the errors of the selected reader: I/O, JSON, safetensors, or unsupported
/// dtype.
pub fn load(path: impl AsRef<Path>) -> Result<StateDict, StateDictError> {
    let path = path.as_ref();
    let ext = path.extension().and_then(|s| s.to_str());
    match ext {
        Some("safetensors") => load_safetensors(path),
        Some("json") => load_legacy_json_bin(path),
        _ => {
            let st = path.with_extension("safetensors");
            if st.is_file() {
                return load_safetensors(&st);
            }
            // `path` cannot be the legacy header if it is the raw `.bin` data or does not
            // exist at all (e.g. an extension-less base name); then look for `<base>.json`.
            let json = path.with_extension("json");
            let cannot_be_header = ext == Some("bin") || !path.is_file();
            if cannot_be_header && json.is_file() {
                load_legacy_json_bin(&json)
            } else {
                load_legacy_json_bin(path)
            }
        }
    }
}
/// Reads a safetensors file into a [`StateDict`]. The whole file is read into memory first.
///
/// # Errors
///
/// * [`StateDictError::Io`] if the file cannot be read.
/// * [`StateDictError::Safetensors`] if it cannot be parsed, or if a tensor's byte length
///   does not match its shape.
/// * [`StateDictError::UnsupportedDtype`] for the first tensor that is not `F32`.
fn load_safetensors(path: &Path) -> Result<StateDict, StateDictError> {
    let raw = fs::read(path)?;
    let parsed =
        SafeTensors::deserialize(&raw).map_err(|e| StateDictError::Safetensors(e.to_string()))?;

    let mut out = StateDict::new();
    for key in parsed.names() {
        let view = parsed
            .tensor(key)
            .map_err(|e| StateDictError::Safetensors(e.to_string()))?;

        let dtype = view.dtype();
        if dtype != Dtype::F32 {
            return Err(StateDictError::UnsupportedDtype(format!("{:?}", dtype)));
        }

        let dims: Vec<usize> = view.shape().to_vec();
        let count: usize = dims.iter().product();
        let payload = view.data();
        let want = count * 4;
        if payload.len() != want {
            return Err(StateDictError::Safetensors(format!(
                "tensor {name}: expected {expected} bytes, got {}",
                payload.len(),
                name = key,
                expected = want,
            )));
        }

        // The byte buffer carries no alignment guarantee, so decode 4-byte groups rather
        // than reinterpreting the slice in place.
        let values: Vec<f32> = payload
            .chunks_exact(4)
            .map(|c| f32::from_le_bytes([c[0], c[1], c[2], c[3]]))
            .collect();
        out.insert(key.to_string(), dims, values);
    }
    Ok(out)
}
fn load_legacy_json_bin(json_path: &Path) -> Result<StateDict, StateDictError> {
let header_bytes = fs::read(json_path)?;
let header: Header = serde_json::from_slice(&header_bytes)?;
let bin_path = json_path.with_extension("bin");
let mut bin = fs::File::open(&bin_path)?;
let mut all = Vec::new();
std::io::Read::read_to_end(&mut bin, &mut all)?;
let mut sd = StateDict::new();
for meta in header.tensors {
if meta.dtype != "f32" {
return Err(StateDictError::UnsupportedDtype(meta.dtype));
}
let nelem: usize = meta.shape.iter().product();
let nbytes = nelem * 4;
let off = meta.offset as usize;
if off + nbytes > all.len() {
return Err(StateDictError::Io(std::io::Error::new(
std::io::ErrorKind::UnexpectedEof,
format!("tensor {} truncated at offset {off}", meta.name),
)));
}
let mut data = Vec::with_capacity(nelem);
for i in 0..nelem {
let b = &all[off + i * 4..off + (i + 1) * 4];
data.push(f32::from_le_bytes([b[0], b[1], b[2], b[3]]));
}
sd.insert(meta.name, meta.shape, data);
}
Ok(sd)
}
/// Writes `sd` to disk, choosing the format from `path`'s extension.
///
/// * `*.json` — legacy format: JSON header at `path`, raw data in the sibling `*.bin`.
/// * any other extension — safetensors, written to `path` exactly as given.
/// * no extension — safetensors, written to `path` with `.safetensors` appended.
///
/// # Errors
///
/// Propagates I/O, JSON, or safetensors errors from the selected writer.
pub fn save(sd: &StateDict, path: impl AsRef<Path>) -> Result<(), StateDictError> {
    let path = path.as_ref();
    match path.extension() {
        Some(ext) if ext.to_str() == Some("json") => save_legacy_json_bin(sd, path),
        Some(_) => save_safetensors(sd, path),
        None => save_safetensors(sd, &path.with_extension("safetensors")),
    }
}
/// Serializes every tensor in `sd` as little-endian `F32` into a safetensors file at `path`.
///
/// # Errors
///
/// [`StateDictError::Safetensors`] if a tensor view cannot be built or the file cannot be
/// written.
fn save_safetensors(sd: &StateDict, path: &Path) -> Result<(), StateDictError> {
    // One little-endian byte buffer per tensor, in the map's sorted key order. They are
    // built up front because each `TensorView` below only borrows its buffer.
    let buffers: Vec<Vec<u8>> = sd
        .tensors
        .values()
        .map(|(_, data)| {
            let mut buf = Vec::with_capacity(data.len() * 4);
            for x in data {
                buf.extend_from_slice(&x.to_le_bytes());
            }
            buf
        })
        .collect();

    // `iter()` and `values()` walk the BTreeMap in the same order, so the zip pairs each
    // tensor with its own buffer.
    let mut views: BTreeMap<String, TensorView<'_>> = BTreeMap::new();
    for ((name, (shape, _)), buf) in sd.tensors.iter().zip(&buffers) {
        let view = TensorView::new(Dtype::F32, shape.clone(), buf.as_slice())
            .map_err(|e| StateDictError::Safetensors(e.to_string()))?;
        views.insert(name.clone(), view);
    }

    serialize_to_file(&views, None, path)
        .map_err(|e| StateDictError::Safetensors(e.to_string()))?;
    Ok(())
}
fn save_legacy_json_bin(sd: &StateDict, json_path: &Path) -> Result<(), StateDictError> {
let bin_path = json_path.with_extension("bin");
let mut header = Header {
tensors: Vec::new(),
};
let mut blob: Vec<u8> = Vec::new();
for (name, (shape, data)) in &sd.tensors {
let offset = blob.len() as u64;
for v in data {
blob.write_all(&v.to_le_bytes())?;
}
header.tensors.push(TensorMeta {
name: name.clone(),
shape: shape.clone(),
offset,
dtype: "f32".into(),
});
}
fs::write(json_path, serde_json::to_vec_pretty(&header)?)?;
fs::write(&bin_path, &blob)?;
Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile_like::TempDir;

    /// A tiny stand-in for the `tempfile` crate: a uniquely named directory under the system
    /// temp dir that is removed (best-effort) on drop.
    mod tempfile_like {
        use std::path::PathBuf;
        use std::sync::atomic::{AtomicUsize, Ordering};

        /// Per-process sequence number. Tests run on parallel threads and can read the clock
        /// at the same instant, especially on coarse clocks. With only pid + timestamp, two
        /// tests could share a directory and one test's `Drop` could delete the other's files.
        static NEXT_ID: AtomicUsize = AtomicUsize::new(0);

        pub struct TempDir(PathBuf);

        impl TempDir {
            pub fn new() -> std::io::Result<Self> {
                let mut p = std::env::temp_dir();
                let pid = std::process::id();
                let seq = NEXT_ID.fetch_add(1, Ordering::Relaxed);
                let nonce: u64 = std::time::SystemTime::now()
                    .duration_since(std::time::UNIX_EPOCH)
                    .map(|d| d.as_nanos() as u64)
                    .unwrap_or(0);
                p.push(format!("tabicl-test-{pid}-{nonce}-{seq}"));
                std::fs::create_dir_all(&p)?;
                Ok(Self(p))
            }

            pub fn path(&self) -> &std::path::Path {
                &self.0
            }
        }

        impl Drop for TempDir {
            fn drop(&mut self) {
                // Best-effort cleanup: a leftover temp dir must not fail the test.
                let _ = std::fs::remove_dir_all(&self.0);
            }
        }
    }

    #[test]
    fn round_trip_save_load_safetensors() {
        let mut sd = StateDict::new();
        sd.insert("alpha", vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        sd.insert("beta", vec![4], vec![10.0, 20.0, 30.0, 40.0]);
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("state.safetensors");
        save(&sd, &path).unwrap();
        let loaded = load(&path).unwrap();
        assert_eq!(loaded.tensors.len(), 2);
        let (sh, d) = loaded.get("alpha").unwrap();
        assert_eq!(sh, &vec![2, 3]);
        assert_eq!(d, &vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let (sh, d) = loaded.get("beta").unwrap();
        assert_eq!(sh, &vec![4]);
        assert_eq!(d, &vec![10.0, 20.0, 30.0, 40.0]);
    }

    #[test]
    fn round_trip_save_load_legacy_json() {
        let mut sd = StateDict::new();
        sd.insert("alpha", vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        // A second tensor makes its data start at a non-zero offset in the `.bin` blob.
        sd.insert("beta", vec![4], vec![10.0, 20.0, 30.0, 40.0]);
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("state.json");
        save(&sd, &path).unwrap();
        assert!(path.is_file());
        assert!(path.with_extension("bin").is_file());
        let loaded = load(&path).unwrap();
        assert_eq!(loaded.tensors.len(), 2);
        let (sh, d) = loaded.get("alpha").unwrap();
        assert_eq!(sh, &vec![2, 3]);
        assert_eq!(d, &vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let (sh, d) = loaded.get("beta").unwrap();
        assert_eq!(sh, &vec![4]);
        assert_eq!(d, &vec![10.0, 20.0, 30.0, 40.0]);
    }

    #[test]
    fn extensionless_path_round_trips_via_safetensors() {
        let mut sd = StateDict::new();
        sd.insert("w", vec![2], vec![1.5, -2.5]);
        let tmp = TempDir::new().unwrap();
        let base = tmp.path().join("state");
        save(&sd, &base).unwrap();
        assert!(base.with_extension("safetensors").is_file());
        let loaded = load(&base).unwrap();
        assert_eq!(loaded.take_vec("w", 2).unwrap(), vec![1.5, -2.5]);
    }

    #[test]
    fn missing_key_errors() {
        let sd = StateDict::new();
        assert!(matches!(sd.get("nope"), Err(StateDictError::MissingKey(_))));
    }

    #[test]
    fn shape_mismatch_errors() {
        let mut sd = StateDict::new();
        sd.insert("w", vec![2, 3], vec![0.0; 6]);
        let err = sd.take_with_shape("w", &[3, 2]).unwrap_err();
        assert!(matches!(err, StateDictError::ShapeMismatch { .. }));
    }

    #[test]
    fn take_array2_round_trip() {
        let mut sd = StateDict::new();
        sd.insert("w", vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let a = sd.take_array2("w", 2, 3).unwrap();
        assert_eq!(a.shape(), &[2, 3]);
        assert_eq!(a[(0, 0)], 1.0);
        assert_eq!(a[(1, 2)], 6.0);
    }

    #[test]
    fn take_array3_round_trip() {
        let mut sd = StateDict::new();
        sd.insert("t", vec![2, 1, 2], vec![1.0, 2.0, 3.0, 4.0]);
        let a = sd.take_array3("t", 2, 1, 2).unwrap();
        assert_eq!(a.shape(), &[2, 1, 2]);
        assert_eq!(a[(0, 0, 1)], 2.0);
        assert_eq!(a[(1, 0, 1)], 4.0);
    }

    #[test]
    fn keys_with_prefix_filters() {
        let mut sd = StateDict::new();
        sd.insert("a.b.c", vec![1], vec![0.0]);
        sd.insert("a.b.d", vec![1], vec![0.0]);
        sd.insert("a.e.f", vec![1], vec![0.0]);
        let keys: Vec<_> = sd.keys_with_prefix("a.b.").collect();
        assert_eq!(keys, vec!["a.b.c", "a.b.d"]);
    }

    #[test]
    fn unsupported_dtype_errors_on_legacy_load() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("ckpt.json");
        let header = serde_json::json!({
            "tensors": [
                { "name": "x", "shape": [2], "offset": 0, "dtype": "f64" }
            ]
        });
        std::fs::write(&path, serde_json::to_vec(&header).unwrap()).unwrap();
        std::fs::write(path.with_extension("bin"), vec![0_u8; 16]).unwrap();
        assert!(matches!(
            load(&path),
            Err(StateDictError::UnsupportedDtype(_))
        ));
    }

    #[test]
    fn truncated_legacy_data_errors() {
        let tmp = TempDir::new().unwrap();
        let path = tmp.path().join("ckpt.json");
        // No "dtype" key: exercises the f32 serde default. Four f32 values need 16 bytes,
        // but the `.bin` file only holds 8.
        let header = serde_json::json!({
            "tensors": [
                { "name": "x", "shape": [4], "offset": 0 }
            ]
        });
        std::fs::write(&path, serde_json::to_vec(&header).unwrap()).unwrap();
        std::fs::write(path.with_extension("bin"), vec![0_u8; 8]).unwrap();
        assert!(matches!(load(&path), Err(StateDictError::Io(_))));
    }
}