use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use safetensors::tensor::{SafeTensors, TensorView};
/// Failure modes of the unsharding routines in this file.
#[derive(Debug)]
pub enum UnshardError {
    /// A filesystem operation failed; wraps the underlying `std::io::Error`.
    Io(std::io::Error),
    /// The `safetensors` crate reported an error; wraps it unchanged.
    SafeTensors(safetensors::SafeTensorError),
    /// The input was rejected; the string is the complete human-readable
    /// message.
    Invalid(String),
}
impl std::fmt::Display for UnshardError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
UnshardError::Io(e) => write!(f, "i/o error: {e}"),
UnshardError::SafeTensors(e) => write!(f, "safetensors error: {e}"),
UnshardError::Invalid(m) => write!(f, "{m}"),
}
}
}
// Marker impl relying on the trait's default methods, so `source()` reports no
// underlying cause; the `Display` impl embeds the wrapped error's text instead.
impl std::error::Error for UnshardError {}
impl From<std::io::Error> for UnshardError {
fn from(e: std::io::Error) -> Self {
UnshardError::Io(e)
}
}
/// Lets `?` lift a `safetensors::SafeTensorError` into
/// [`UnshardError::SafeTensors`].
impl From<safetensors::SafeTensorError> for UnshardError {
    fn from(source: safetensors::SafeTensorError) -> Self {
        Self::SafeTensors(source)
    }
}
/// Summary returned by `unshard_safetensors_dir` after a successful merge.
#[derive(Debug, Clone)]
pub struct UnshardReport {
    /// Path the merged file was written to.
    pub output: PathBuf,
    /// Number of entries in the index's `weight_map`.
    pub tensor_count: usize,
    /// Number of distinct shard files referenced by the index.
    pub shard_count: usize,
    /// Sum of the byte lengths of all tensor data buffers.
    pub total_size: u64,
}
/// The parts of `model.safetensors.index.json` that unsharding needs.
struct Index {
    /// `(tensor name, shard filename)` pairs, in the order they appear in the
    /// index's `weight_map` object.
    weight_map: Vec<(String, String)>,
    /// Value of the `total_size` field when one could be read as an unsigned
    /// integer, otherwise `None`.
    total_size: Option<u64>,
}
/// Extracts `weight_map` and the optional `total_size` from the text of a
/// `model.safetensors.index.json` file.
///
/// This is a small purpose-built scanner, not a general JSON parser:
///
/// * `total_size` is taken from the first occurrence of the `"total_size"`
///   key anywhere in the document; it is `None` when the key is absent or is
///   not followed by an unsigned decimal integer.
/// * `weight_map` is the first `{ ... }` object following the first
///   `"weight_map"` key. The object is scanned with awareness of JSON string
///   literals (including backslash escapes), so `:`, `,`, `{` and `}` inside a
///   quoted tensor or shard name are treated as ordinary characters.
/// * Keys and values have their surrounding quotes removed by `unquote`;
///   escape sequences inside them are kept verbatim, not decoded.
/// * Empty entries (from `{}` or a stray comma) and entries whose key or value
///   is empty are skipped.
///
/// # Errors
///
/// Returns [`UnshardError::Invalid`] when the text is not wrapped in `{ }`,
/// when the `weight_map` key or its opening/closing brace is missing, or when
/// a non-empty entry has no `:` separating key from value.
fn parse_index_json(json: &str) -> Result<Index, UnshardError> {
    const TOTAL_SIZE_KEY: &str = "\"total_size\"";
    const WEIGHT_MAP_KEY: &str = "\"weight_map\"";

    let json = json.trim();
    if !json.starts_with('{') || !json.ends_with('}') {
        return Err(UnshardError::Invalid(
            "index.json is not a JSON object".to_string(),
        ));
    }

    let total_size = json.find(TOTAL_SIZE_KEY).and_then(|pos| {
        let after = &json[pos + TOTAL_SIZE_KEY.len()..];
        let colon = after.find(':')?;
        let digits = after[colon + 1..].trim_start();
        let end = digits
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(digits.len());
        digits[..end].parse::<u64>().ok()
    });

    let wm_start = json.find(WEIGHT_MAP_KEY).ok_or_else(|| {
        UnshardError::Invalid("missing 'weight_map' key in index.json".to_string())
    })?;
    let after_key = &json[wm_start + WEIGHT_MAP_KEY.len()..];
    let obj_start = after_key.find('{').ok_or_else(|| {
        UnshardError::Invalid("malformed weight_map: missing opening brace".to_string())
    })?;
    // `obj` begins at the opening brace of the weight_map object.
    let obj = &after_key[obj_start..];

    // One pass over the object: find its closing brace and cut it into raw
    // `key : value` slices at top-level commas, remembering where the
    // separating colon of each slice sits. Structural characters are only
    // recognised outside string literals.
    let mut raw_entries: Vec<(&str, Option<usize>)> = Vec::new();
    let mut depth = 0i32;
    let mut in_string = false;
    let mut escaped = false;
    let mut entry_start = 1usize; // first byte after the opening '{'
    let mut colon_at: Option<usize> = None;
    let mut closed = false;
    for (i, c) in obj.char_indices() {
        if in_string {
            if escaped {
                // The character after a backslash never ends the string.
                escaped = false;
            } else if c == '\\' {
                escaped = true;
            } else if c == '"' {
                in_string = false;
            }
            continue;
        }
        match c {
            '"' => in_string = true,
            '{' => depth += 1,
            '}' => {
                depth -= 1;
                if depth == 0 {
                    raw_entries.push((&obj[entry_start..i], colon_at.map(|p| p - entry_start)));
                    closed = true;
                    break;
                }
            }
            ':' if depth == 1 && colon_at.is_none() => colon_at = Some(i),
            ',' if depth == 1 => {
                raw_entries.push((&obj[entry_start..i], colon_at.map(|p| p - entry_start)));
                entry_start = i + 1;
                colon_at = None;
            }
            _ => {}
        }
    }
    if !closed {
        return Err(UnshardError::Invalid(
            "malformed weight_map: missing closing brace".to_string(),
        ));
    }

    let mut entries = Vec::with_capacity(raw_entries.len());
    for (raw, colon) in raw_entries {
        if raw.trim().is_empty() {
            // `{}` or a stray/trailing comma: nothing to record.
            continue;
        }
        let colon = match colon {
            Some(offset) => offset,
            None => {
                // Dropping the entry silently would lose a tensor from the
                // merged output, so report it instead.
                return Err(UnshardError::Invalid(format!(
                    "malformed weight_map entry (missing ':'): {}",
                    raw.trim()
                )));
            }
        };
        let key = unquote(raw[..colon].trim());
        let val = unquote(raw[colon + 1..].trim());
        if !key.is_empty() && !val.is_empty() {
            entries.push((key, val));
        }
    }

    Ok(Index {
        weight_map: entries,
        total_size,
    })
}
/// Trims surrounding whitespace and, when the remainder is wrapped in a pair
/// of double quotes, removes that one pair.
///
/// Text that is not fully quoted (including a lone `"`) is returned trimmed
/// but otherwise unchanged. Escape sequences are not interpreted.
fn unquote(s: &str) -> String {
    let trimmed = s.trim();
    trimmed
        .strip_prefix('"')
        .and_then(|rest| rest.strip_suffix('"'))
        .unwrap_or(trimmed)
        .to_string()
}
/// Checks that a shard filename from the index stays below the directory it
/// will be joined onto.
///
/// Plain names and `.` components are accepted. Absolute paths, any `..`
/// component, and root or prefix components are rejected with
/// [`UnshardError::Invalid`]; the first offending component decides the
/// message.
fn validate_shard_path(name: &str) -> Result<(), UnshardError> {
    use std::path::Component;

    let path = Path::new(name);
    if path.is_absolute() {
        return Err(UnshardError::Invalid(format!(
            "shard filename must be relative: {name}"
        )));
    }

    for component in path.components() {
        let message = match component {
            // Harmless components: keep scanning.
            Component::Normal(_) | Component::CurDir => continue,
            Component::ParentDir => format!("shard filename contains '..': {name}"),
            Component::RootDir | Component::Prefix(_) => {
                format!("shard filename has root component: {name}")
            }
        };
        return Err(UnshardError::Invalid(message));
    }

    Ok(())
}
/// Merges a sharded safetensors checkpoint directory into a single file.
///
/// Steps, in order:
///
/// 1. Read and parse `model.safetensors.index.json` inside `input_dir`.
/// 2. Validate every shard filename named by the index, then read each
///    distinct shard fully into memory (in first-appearance order).
/// 3. Deserialize each shard and look up every tensor listed in `weight_map`
///    in the shard the index assigns it to.
/// 4. If the index declares `total_size`, require it to equal the summed byte
///    length of the tensor data that was found.
/// 5. Serialize all tensors into one buffer and write it to `output`,
///    creating the parent directory if needed.
///
/// # Errors
///
/// * [`UnshardError::Invalid`] — `input_dir` is not a directory, the index
///   file or a shard file is missing, `weight_map` is empty, a shard filename
///   fails validation, a tensor name appears more than once in `weight_map`,
///   a listed tensor is absent from its shard, or `total_size` disagrees with
///   the shard contents.
/// * [`UnshardError::Io`] — reading the index or a shard, creating the output
///   directory, or writing the output failed.
/// * [`UnshardError::SafeTensors`] — a shard could not be deserialized or the
///   merged tensors could not be serialized.
pub fn unshard_safetensors_dir(
    input_dir: &Path,
    output: &Path,
) -> Result<UnshardReport, UnshardError> {
    use std::collections::HashSet;

    if !input_dir.is_dir() {
        return Err(UnshardError::Invalid(format!(
            "input is not a directory: {}",
            input_dir.display()
        )));
    }
    let index_path = input_dir.join("model.safetensors.index.json");
    if !index_path.is_file() {
        return Err(UnshardError::Invalid(format!(
            "missing model.safetensors.index.json in {}",
            input_dir.display()
        )));
    }
    let index_text = fs::read_to_string(&index_path)?;
    let index = parse_index_json(&index_text)?;
    if index.weight_map.is_empty() {
        return Err(UnshardError::Invalid(
            "weight_map is empty in index.json".to_string(),
        ));
    }

    // Distinct shard names in first-appearance order. Every name is validated
    // before any shard file is touched.
    let mut shard_order: Vec<&str> = Vec::new();
    let mut seen_shards: HashSet<&str> = HashSet::new();
    for (_tensor, shard) in &index.weight_map {
        validate_shard_path(shard)?;
        if seen_shards.insert(shard.as_str()) {
            shard_order.push(shard.as_str());
        }
    }

    // Raw bytes of each shard, parallel to `shard_order`. This vector must be
    // complete before deserializing, because the parsed shards borrow from it.
    let mut shard_bytes: Vec<Vec<u8>> = Vec::with_capacity(shard_order.len());
    for &shard in &shard_order {
        let shard_path = input_dir.join(shard);
        if !shard_path.is_file() {
            return Err(UnshardError::Invalid(format!(
                "shard file missing on disk: {}",
                shard_path.display()
            )));
        }
        shard_bytes.push(fs::read(&shard_path)?);
    }

    let mut by_shard: HashMap<&str, SafeTensors<'_>> = HashMap::with_capacity(shard_order.len());
    for (&shard, bytes) in shard_order.iter().zip(&shard_bytes) {
        by_shard.insert(shard, SafeTensors::deserialize(bytes)?);
    }

    // Collect the views directly in the shape `serialize` consumes, borrowing
    // tensor names from the index instead of cloning them.
    let mut views: Vec<(&str, TensorView<'_>)> = Vec::with_capacity(index.weight_map.len());
    let mut seen_tensors: HashSet<&str> = HashSet::with_capacity(index.weight_map.len());
    let mut total_bytes: u64 = 0;
    for (tensor_name, shard_name) in &index.weight_map {
        // A repeated name would otherwise be handed to `serialize` twice and
        // counted twice in `total_bytes`.
        if !seen_tensors.insert(tensor_name.as_str()) {
            return Err(UnshardError::Invalid(format!(
                "tensor '{tensor_name}' is listed more than once in weight_map"
            )));
        }
        let st = by_shard
            .get(shard_name.as_str())
            .ok_or_else(|| UnshardError::Invalid(format!("shard not loaded: {shard_name}")))?;
        let view = st.tensor(tensor_name).map_err(|e| {
            UnshardError::Invalid(format!(
                "tensor '{tensor_name}' declared in weight_map but not present in shard {shard_name}: {e}"
            ))
        })?;
        total_bytes = total_bytes.saturating_add(view.data().len() as u64);
        views.push((tensor_name.as_str(), view));
    }

    if let Some(declared) = index.total_size {
        if declared != total_bytes {
            return Err(UnshardError::Invalid(format!(
                "index.json total_size {declared} disagrees with shard contents {total_bytes}"
            )));
        }
    }

    let serialized = safetensors::serialize(views, &None).map_err(UnshardError::SafeTensors)?;
    if let Some(parent) = output.parent() {
        // A bare filename has an empty parent; there is nothing to create.
        if !parent.as_os_str().is_empty() {
            fs::create_dir_all(parent)?;
        }
    }
    fs::write(output, &serialized)?;

    Ok(UnshardReport {
        output: output.to_path_buf(),
        tensor_count: index.weight_map.len(),
        shard_count: shard_order.len(),
        total_size: total_bytes,
    })
}
#[cfg(test)]
mod parser_tests {
    use super::{parse_index_json, validate_shard_path};

    /// A two-entry index with a declared `total_size` yields both the size and
    /// the entries, in document order.
    #[test]
    fn parses_minimal_index() {
        let json = r#"{
"metadata": {"total_size": 1024},
"weight_map": {
"a.weight": "model-00001-of-00002.safetensors",
"b.weight": "model-00002-of-00002.safetensors"
}
}"#;
        let parsed = parse_index_json(json).expect("parse");
        assert_eq!(parsed.total_size, Some(1024));
        assert_eq!(parsed.weight_map.len(), 2);
        assert_eq!(parsed.weight_map[0].0, "a.weight");
    }

    /// A leading `..` component must be refused.
    #[test]
    fn rejects_traversal() {
        assert!(validate_shard_path("../escape.safetensors").is_err());
    }

    /// A path starting at the filesystem root must be refused.
    #[test]
    fn rejects_absolute() {
        assert!(validate_shard_path("/etc/passwd").is_err());
    }

    /// An ordinary shard filename passes validation.
    #[test]
    fn accepts_relative_filename() {
        assert!(validate_shard_path("model-00001-of-00002.safetensors").is_ok());
    }
}