use std::collections::BTreeMap;
use std::fs;
use std::path::{Path, PathBuf};
use safetensors::tensor::{Dtype, SafeTensors, TensorView};
/// Errors reported while splitting a safetensors file into shards.
#[derive(Debug)]
pub enum ShardError {
/// A `--max-shard-size` value was rejected; rendered as `invalid --max-shard-size: <detail>`.
ParseSize(String),
/// A filesystem operation failed; converted from `std::io::Error` by `?`.
Io(std::io::Error),
/// The safetensors crate returned an error while parsing or serializing.
SafeTensors(safetensors::SafeTensorError),
/// A precondition was not met; the payload is the complete message, rendered verbatim.
Invalid(String),
}
impl std::fmt::Display for ShardError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ShardError::ParseSize(m) => write!(f, "invalid --max-shard-size: {m}"),
ShardError::Io(e) => write!(f, "i/o error: {e}"),
ShardError::SafeTensors(e) => write!(f, "safetensors error: {e}"),
ShardError::Invalid(m) => write!(f, "{m}"),
}
}
}
// Marker impl with default methods only: `source()` is not overridden, and the
// wrapped error's text is already embedded in the `Display` output.
impl std::error::Error for ShardError {}
impl From<std::io::Error> for ShardError {
fn from(e: std::io::Error) -> Self {
ShardError::Io(e)
}
}
/// Wraps a safetensors failure in [`ShardError::SafeTensors`] so `?` can be used
/// on calls into the safetensors crate.
impl From<safetensors::SafeTensorError> for ShardError {
    fn from(err: safetensors::SafeTensorError) -> Self {
        Self::SafeTensors(err)
    }
}
/// Summary of a completed sharding run, as returned by `shard_safetensors_file`.
#[derive(Debug, Clone)]
pub struct ShardReport {
/// Paths of the shard files, in the order they were written.
pub shard_files: Vec<PathBuf>,
/// Path of the `model.safetensors.index.json` file written into the output directory.
pub index_path: PathBuf,
/// Sum of the computed byte sizes of all tensors; also recorded as `total_size` in the index.
pub total_size: u64,
/// Number of tensors found in the input file.
pub tensor_count: usize,
}
/// Returns the width, in bytes, of one element of dtype `dt`.
///
/// Dtypes that are not listed explicitly fall through to a width of 1 byte.
fn dtype_size(dt: Dtype) -> usize {
    match dt {
        // 64-bit element types.
        Dtype::U64 | Dtype::I64 | Dtype::F64 => 8,
        // 32-bit element types.
        Dtype::U32 | Dtype::I32 | Dtype::F32 => 4,
        // 16-bit element types.
        Dtype::U16 | Dtype::I16 | Dtype::F16 | Dtype::BF16 => 2,
        // 8-bit element types.
        Dtype::BOOL | Dtype::U8 | Dtype::I8 | Dtype::F8_E4M3 | Dtype::F8_E5M2 => 1,
        // Anything else is counted as one byte per element.
        _ => 1,
    }
}
/// Computes the byte size of the tensor described by `view` as
/// `element count * element width`.
///
/// The element count is the product of the shape dimensions (an empty shape
/// counts as one element) and the width comes from `dtype_size`. Every
/// multiplication saturates at `u64::MAX`, so an implausibly large shape can
/// neither panic in debug builds nor wrap around in release builds.
fn tensor_byte_size(view: &TensorView<'_>) -> u64 {
    let elems = view
        .shape()
        .iter()
        .fold(1u64, |count, &dim| count.saturating_mul(dim as u64));
    elems.saturating_mul(dtype_size(view.dtype()) as u64)
}
/// Greedily groups tensor indices into shards, preserving input order.
///
/// Walks `sizes` front to back and appends each index to the most recent
/// shard while the shard's running total stays at or below `max_shard_size`;
/// otherwise a new shard is started. A tensor that is larger than the limit
/// therefore ends up alone in its own shard. `names` is only used to check
/// (in debug builds) that it has the same length as `sizes`.
///
/// Returns one `Vec` of indices into `sizes` per shard; empty input yields an
/// empty plan.
fn plan_shards<'a>(names: &'a [&'a str], sizes: &[u64], max_shard_size: u64) -> Vec<Vec<usize>> {
    debug_assert_eq!(names.len(), sizes.len());

    let mut plan: Vec<Vec<usize>> = Vec::new();
    // Bytes already assigned to the last shard in `plan`.
    let mut used: u64 = 0;

    for (idx, size) in sizes.iter().copied().enumerate() {
        let grown = used.saturating_add(size);
        match plan.last_mut() {
            // Still within budget: extend the shard that is currently open.
            Some(open) if grown <= max_shard_size => {
                open.push(idx);
                used = grown;
            }
            // First tensor, or the open shard would exceed the limit.
            _ => {
                plan.push(vec![idx]);
                used = size;
            }
        }
    }

    plan
}
/// Builds the file name of shard number `index` out of `total` shards.
///
/// Both numbers are zero-padded to at least five digits, giving names of the
/// form `model-00001-of-00003.safetensors`.
fn shard_filename(index: usize, total: usize) -> String {
    let position = format!("{index:05}");
    let count = format!("{total:05}");
    ["model-", &position, "-of-", &count, ".safetensors"].concat()
}
/// Renders the shard index as a JSON document.
///
/// The document has a `metadata` object holding `total_size`, followed by a
/// `weight_map` object with one `"<tensor name>": "<shard file>"` entry per
/// element of `weight_map`, in the map's iteration (sorted-key) order. Keys
/// and values are escaped with `json_string`.
fn build_index_json(weight_map: &BTreeMap<String, String>, total_size: u64) -> String {
    // One rendered line per tensor, without the separating commas.
    let entries: Vec<String> = weight_map
        .iter()
        .map(|(tensor, file)| format!(" {}: {}", json_string(tensor), json_string(file)))
        .collect();

    let mut json = String::with_capacity(weight_map.len() * 80 + 64);
    json.push_str("{\n");
    json.push_str(" \"metadata\": {\n");
    json.push_str(&format!(" \"total_size\": {total_size}\n"));
    json.push_str(" },\n");
    json.push_str(" \"weight_map\": {\n");
    // An empty map leaves the object body empty, with no blank line inside it.
    if !entries.is_empty() {
        json.push_str(&entries.join(",\n"));
        json.push('\n');
    }
    json.push_str(" }\n");
    json.push_str("}\n");
    json
}
/// Encodes `s` as a double-quoted JSON string literal.
///
/// Quotes and backslashes are backslash-escaped; newline, carriage return and
/// tab use their short escapes; any other character below U+0020 is written
/// as `\u00XX`. All remaining characters are copied through unchanged.
fn json_string(s: &str) -> String {
    let mut quoted = String::with_capacity(s.len() + 2);
    quoted.push('"');
    for ch in s.chars() {
        // Characters that have a dedicated two-character escape.
        let short_escape = match ch {
            '"' => Some("\\\""),
            '\\' => Some("\\\\"),
            '\n' => Some("\\n"),
            '\r' => Some("\\r"),
            '\t' => Some("\\t"),
            _ => None,
        };
        if let Some(escape) = short_escape {
            quoted.push_str(escape);
        } else if u32::from(ch) < 0x20 {
            // Remaining control characters need the generic \uXXXX form.
            quoted.push_str(&format!("\\u{:04x}", u32::from(ch)));
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('"');
    quoted
}
/// Splits the safetensors file at `input` into several smaller safetensors
/// files plus a JSON index, all written into `output_dir`.
///
/// The whole input is read into memory. Tensors are processed in ascending
/// name order and grouped greedily by `plan_shards`, so that the computed
/// tensor bytes of each shard stay at or below `max_shard_size`; a single
/// tensor larger than the limit gets a shard of its own. Shards are named by
/// `shard_filename`, serialized without a metadata map, and the mapping from
/// tensor name to shard file is written to `model.safetensors.index.json`.
/// `output_dir` is created if it does not exist.
///
/// # Errors
///
/// * [`ShardError::Invalid`] if `max_shard_size` is zero, `input` is not a
///   regular file, or the input contains no tensors.
/// * [`ShardError::Io`] if reading the input, creating `output_dir`, or
///   writing a shard or the index fails.
/// * [`ShardError::SafeTensors`] if the input cannot be parsed, a tensor
///   cannot be looked up, or a shard cannot be serialized.
///
/// Files already written are not removed when a later step fails.
pub fn shard_safetensors_file(
    input: &Path,
    max_shard_size: u64,
    output_dir: &Path,
) -> Result<ShardReport, ShardError> {
    if max_shard_size == 0 {
        return Err(ShardError::Invalid(
            "--max-shard-size must be positive".to_string(),
        ));
    }
    if !input.is_file() {
        return Err(ShardError::Invalid(format!(
            "input is not a file: {}",
            input.display()
        )));
    }

    let bytes = fs::read(input)?;
    let st = SafeTensors::deserialize(&bytes)?;
    let mut names: Vec<&str> = st.names().into_iter().map(String::as_str).collect();
    if names.is_empty() {
        return Err(ShardError::Invalid("input has no tensors".to_string()));
    }
    // NOTE(review): `SafeTensors::names` documents no ordering (it appears to
    // collect hash-map keys), so the raw order can change from run to run.
    // Sorting makes the shard layout reproducible for a given input.
    names.sort_unstable();

    let views: Vec<TensorView<'_>> = names
        .iter()
        .map(|n| st.tensor(n))
        .collect::<Result<Vec<_>, _>>()?;
    let sizes: Vec<u64> = views.iter().map(tensor_byte_size).collect();
    // Saturating, like the arithmetic in `plan_shards`, so the total cannot overflow.
    let total_size: u64 = sizes
        .iter()
        .fold(0u64, |total, &size| total.saturating_add(size));

    let plan = plan_shards(&names, &sizes, max_shard_size);
    let total_shards = plan.len();

    fs::create_dir_all(output_dir)?;

    let mut weight_map = BTreeMap::new();
    let mut shard_files = Vec::with_capacity(total_shards);
    for (idx, group) in plan.iter().enumerate() {
        // Shard numbers in file names are 1-based.
        let file_name = shard_filename(idx + 1, total_shards);
        let shard_path = output_dir.join(&file_name);

        let shard_tensors: Vec<(&str, TensorView<'_>)> = group
            .iter()
            .map(|&i| (names[i], views[i].clone()))
            .collect();
        let serialized =
            safetensors::serialize(shard_tensors, &None).map_err(ShardError::SafeTensors)?;
        fs::write(&shard_path, &serialized)?;

        // Record which shard each tensor of this group went into.
        for &i in group {
            weight_map.insert(names[i].to_string(), file_name.clone());
        }
        shard_files.push(shard_path);
    }

    let index_path = output_dir.join("model.safetensors.index.json");
    let index_json = build_index_json(&weight_map, total_size);
    fs::write(&index_path, index_json)?;

    Ok(ShardReport {
        shard_files,
        index_path,
        total_size,
        tensor_count: names.len(),
    })
}
/// Unit tests for the shard planner and the shard file naming scheme.
#[cfg(test)]
mod plan_tests {
    use super::{plan_shards, shard_filename};

    /// Everything fits under the limit, so one shard holds every index.
    #[test]
    fn single_shard_when_under_limit() {
        let shards = plan_shards(&["a", "b", "c"], &[10, 20, 30], 1000);
        assert_eq!(shards.len(), 1);
        assert_eq!(shards[0], vec![0, 1, 2]);
    }

    /// No two 60-byte tensors fit in a 100-byte shard together.
    #[test]
    fn splits_when_over_limit() {
        let shards = plan_shards(&["a", "b", "c", "d"], &[60, 60, 60, 60], 100);
        assert_eq!(shards.len(), 4);
    }

    /// A tensor above the limit is placed in a shard by itself.
    #[test]
    fn oversized_tensor_alone() {
        let shards = plan_shards(&["a", "big", "c"], &[10, 5000, 10], 100);
        assert_eq!(shards.len(), 3);
        assert_eq!(shards[1], vec![1]);
    }

    /// Reading the shards back to back yields the original index order.
    #[test]
    fn preserves_insertion_order() {
        let shards = plan_shards(&["x", "y", "z"], &[50, 50, 50], 100);
        let order: Vec<usize> = shards.concat();
        assert_eq!(order, vec![0, 1, 2]);
    }

    /// Shard numbers are zero-padded to five digits.
    #[test]
    fn shard_filename_format() {
        assert_eq!(shard_filename(1, 3), "model-00001-of-00003.safetensors");
        assert_eq!(shard_filename(42, 100), "model-00042-of-00100.safetensors");
    }
}