pt-loader 0.1.0

Safe parser-based PyTorch checkpoint converter to safetensors
Documentation
use pt_loader::{convert_pt_to_safetensors, inspect_pt, ConvertOptions};
use safetensors::SafeTensors;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::BTreeMap;
use std::path::Path;

#[derive(Debug, Clone, Serialize, Deserialize)]
struct ModelYaml {
  format_version: usize,
  source_file: String,
  source_sha256: String,
  safetensors_file: String,
  created_at_unix: u64,
  tensor_count: usize,
  total_tensor_bytes: usize,
  #[serde(default)]
  metadata: serde_yaml::Value,
  #[serde(default)]
  objects: Vec<String>,
  tensors: Vec<ModelYamlTensor>,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
struct ModelYamlTensor {
  name: String,
  dtype: String,
  shape: Vec<usize>,
  sha256: String,
}

#[test]
fn extracts_sample_yolo26n_pt() {
  let crate_root = std::path::Path::new(env!("CARGO_MANIFEST_DIR"));
  let input = crate_root.join("samples").join("yolo26n.pt");
  let reference_safetensors = crate_root.join("samples").join("yolo26n.safetensors");
  let reference_yaml = crate_root.join("samples").join("yolo26n.yaml");
  assert!(
    input.exists(),
    "expected sample checkpoint at {}",
    input.display()
  );
  let needs_regen = reference_yaml_needs_regeneration(&reference_yaml);
  if needs_regen {
    assert!(
      reference_safetensors.exists(),
      "expected sample safetensors at {} to regenerate {}",
      reference_safetensors.display(),
      reference_yaml.display()
    );
    generate_yaml_from_safetensors(&reference_safetensors, &reference_yaml);
  }

  let out_dir = crate_root.join("out");
  if out_dir.exists() {
    std::fs::remove_dir_all(&out_dir).expect("remove old out dir");
  }
  std::fs::create_dir_all(&out_dir).expect("create out dir");

  let result = convert_pt_to_safetensors(&input, &out_dir, ConvertOptions::default())
    .expect("sample checkpoint should convert");

  assert!(result.safetensors_path.exists());
  assert!(result.model_yaml_path.exists());
  assert!(result.tensor_count > 0);
  assert!(result.total_tensor_bytes > 0);

  let report = inspect_pt(&input).expect("sample inspect should pass");
  assert!(report.tensor_count > 0);
  assert_eq!(report.tensor_count, result.tensor_count);

  let converted_shapes = read_safetensors_shapes(&result.safetensors_path);
  assert!(!converted_shapes.is_empty(), "converted safetensors is empty");
  if reference_safetensors.exists() {
    let reference_shapes = read_safetensors_shapes(&reference_safetensors);
    assert!(!reference_shapes.is_empty(), "reference safetensors is empty");
    assert_eq!(
      converted_shapes, reference_shapes,
      "converted out/model.safetensors shapes must match samples/yolo26n.safetensors"
    );
  }

  let out_yaml_tensors = read_model_yaml_tensors(&result.model_yaml_path);
  let reference_yaml_tensors = read_model_yaml_tensors(&reference_yaml);
  assert_eq!(
    out_yaml_tensors, reference_yaml_tensors,
    "converted out/model.yaml tensors must match samples/yolo26n.yaml"
  );

  let out_yaml = read_model_yaml(&result.model_yaml_path);
  let metadata_len = match &out_yaml.metadata {
    serde_yaml::Value::Mapping(map) => map.len(),
    _ => 0,
  };
  assert!(
    metadata_len > 10,
    "expected extracted metadata map to contain many top-level fields, got {}",
    metadata_len
  );
  assert!(
    out_yaml
      .objects
      .iter()
      .any(|item| item == "ultralytics.nn.tasks.DetectionModel"),
    "expected objects to include ultralytics.nn.tasks.DetectionModel"
  );
}

fn reference_yaml_needs_regeneration(yaml_path: &Path) -> bool {
  if !yaml_path.exists() {
    true
  } else {
    let parsed = read_model_yaml(yaml_path);
    let raw = std::fs::read_to_string(yaml_path).expect("read existing yaml");
    parsed.source_sha256.trim().is_empty()
      || parsed.tensors.iter().any(|item| item.sha256.trim().is_empty())
      || raw.contains("shape:\n")
  }
}

fn read_safetensors_shapes(path: &Path) -> BTreeMap<String, Vec<usize>> {
  let bytes = std::fs::read(path).expect("read safetensors bytes");
  let tensors = SafeTensors::deserialize(&bytes).expect("parse safetensors");
  let mut shapes = BTreeMap::new();
  for name in tensors.names() {
    let view = tensors.tensor(name).expect("tensor view should exist");
    shapes.insert(
      name.to_string(),
      view.shape().iter().map(|dim| *dim as usize).collect(),
    );
  }
  shapes
}

fn generate_yaml_from_safetensors(safetensors_path: &Path, yaml_path: &Path) {
  let bytes = std::fs::read(safetensors_path).expect("read safetensors bytes");
  let tensors = SafeTensors::deserialize(&bytes).expect("parse safetensors");

  let mut names = tensors
    .names()
    .iter()
    .map(|value| value.to_string())
    .collect::<Vec<_>>();
  names.sort();

  let mut total_tensor_bytes = 0usize;
  let mut yaml_tensors = Vec::with_capacity(names.len());
  for name in names {
    let view = tensors.tensor(&name).expect("tensor view should exist");
    let raw = view.data();
    total_tensor_bytes += raw.len();
    yaml_tensors.push(ModelYamlTensor {
      name,
      dtype: format!("{:?}", view.dtype()),
      shape: view.shape().iter().map(|dim| *dim as usize).collect(),
      sha256: sha256_bytes(raw),
    });
  }

  let payload = ModelYaml {
    format_version: 1,
    source_file: safetensors_path.display().to_string(),
    source_sha256: sha256_file(safetensors_path),
    safetensors_file: safetensors_path
      .file_name()
      .map(|value| value.to_string_lossy().into_owned())
      .unwrap_or_else(|| "yolo26n.safetensors".to_string()),
    created_at_unix: 0,
    tensor_count: yaml_tensors.len(),
    total_tensor_bytes,
    metadata: serde_yaml::Value::Null,
    objects: Vec::new(),
    tensors: yaml_tensors,
  };

  let encoded = serde_yaml::to_string(&payload).expect("encode yaml");
  let encoded = format_shape_inline_yaml(&encoded);
  std::fs::write(yaml_path, encoded).expect("write generated yaml");
}

fn read_model_yaml(path: &Path) -> ModelYaml {
  let raw = std::fs::read_to_string(path).expect("read yaml");
  serde_yaml::from_str::<ModelYaml>(&raw).expect("parse model yaml")
}

fn read_model_yaml_tensors(path: &Path) -> BTreeMap<String, (String, Vec<usize>, String)> {
  let parsed = read_model_yaml(path);
  let mut map = BTreeMap::new();
  for item in parsed.tensors {
    map.insert(item.name, (item.dtype, item.shape, item.sha256));
  }
  map
}

fn sha256_file(path: &Path) -> String {
  let bytes = std::fs::read(path).expect("read file for sha256");
  sha256_bytes(&bytes)
}

fn sha256_bytes(bytes: &[u8]) -> String {
  let mut hasher = Sha256::new();
  hasher.update(bytes);
  hasher
    .finalize()
    .as_slice()
    .iter()
    .map(|byte| format!("{byte:02x}"))
    .collect()
}

fn format_shape_inline_yaml(raw: &str) -> String {
  let lines = raw.lines().collect::<Vec<_>>();
  let mut out = Vec::new();
  let mut idx = 0usize;
  while idx < lines.len() {
    let line = lines[idx];
    if line.trim() == "shape:" {
      let indent = line.chars().take_while(|ch| ch.is_whitespace()).count();
      let mut values = Vec::new();
      let mut probe = idx + 1;
      while probe < lines.len() {
        let next = lines[probe];
        let next_indent = next.chars().take_while(|ch| ch.is_whitespace()).count();
        if next_indent < indent {
          break;
        }
        let trimmed = next.trim();
        if let Some(value) = trimmed.strip_prefix("- ") {
          values.push(value.trim().to_string());
          probe += 1;
          continue;
        }
        break;
      }

      if values.is_empty() {
        out.push(line.to_string());
        idx += 1;
        continue;
      }

      out.push(format!(
        "{}shape: [{}]",
        " ".repeat(indent),
        values.join(", ")
      ));
      idx = probe;
      continue;
    }

    out.push(line.to_string());
    idx += 1;
  }

  out.join("\n") + "\n"
}