// pt-loader 0.1.2
//
// Safe parser-based PyTorch checkpoint converter to safetensors.
// Documentation
use crate::extract::key_as_string;
use crate::types::Value;

/// Encodes `bytes` as a lowercase hexadecimal string, two digits per byte.
///
/// The result carries no prefix and no separators; an empty slice yields an empty string.
fn bytes_to_hex(bytes: &[u8]) -> String {
  const HEX_DIGITS: &[u8; 16] = b"0123456789abcdef";

  let mut out = String::with_capacity(bytes.len() * 2);
  for &byte in bytes {
    // A table lookup per nibble avoids building a temporary `String` through `format!`
    // for every input byte.
    out.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
    out.push(char::from(HEX_DIGITS[usize::from(byte & 0x0f)]));
  }
  out
}

/// Projects `root` into a `serde_yaml::Value` for metadata output.
///
/// This simply delegates to [`project_value_for_metadata`]; the root value receives no
/// special handling here.
pub(crate) fn project_root_metadata(root: &Value) -> serde_yaml::Value {
  project_value_for_metadata(root)
}

pub(crate) fn project_value_for_metadata(value: &Value) -> serde_yaml::Value {
  match value {
    Value::Marker => serde_yaml::Value::String("<marker>".to_string()),
    Value::None => serde_yaml::Value::Null,
    Value::Bool(v) => serde_yaml::Value::Bool(*v),
    Value::Int(v) => serde_yaml::to_value(*v).unwrap_or(serde_yaml::Value::String(v.to_string())),
    Value::Float(v) => serde_yaml::to_value(*v).unwrap_or(serde_yaml::Value::String(v.to_string())),
    Value::String(v) => serde_yaml::Value::String(v.clone()),
    Value::Bytes(v) => serde_yaml::Value::String(format!("0x{}", bytes_to_hex(v))),
    Value::List(items) | Value::Tuple(items) | Value::Set(items) => {
      serde_yaml::Value::Sequence(items.iter().map(project_value_for_metadata).collect())
    }
    Value::Dict(entries) => serde_yaml::Value::Mapping(
      entries
        .iter()
        .map(|(k, v)| {
          let key = key_as_string(k).unwrap_or_else(|| "<non-string-key>".to_string());
          (serde_yaml::Value::String(key), project_value_for_metadata(v))
        })
        .collect(),
    ),
    Value::OrderedDict(entries) => serde_yaml::Value::Mapping(
      entries
        .iter()
        .map(|(k, v)| (serde_yaml::Value::String(k.clone()), project_value_for_metadata(v)))
        .collect(),
    ),
    Value::Global { module, name } => serde_yaml::Value::String(format!("{module}.{name}")),
    Value::StorageRef(storage) => serde_yaml::Value::Mapping(
      [
        (
          serde_yaml::Value::String("$type".to_string()),
          serde_yaml::Value::String("storage_ref".to_string()),
        ),
        (
          serde_yaml::Value::String("dtype".to_string()),
          serde_yaml::Value::String(storage.dtype.as_safetensors().to_string()),
        ),
        (
          serde_yaml::Value::String("key".to_string()),
          serde_yaml::Value::String(storage.key.clone()),
        ),
        (
          serde_yaml::Value::String("size_elems".to_string()),
          serde_yaml::to_value(storage.size_elems).unwrap_or(serde_yaml::Value::Null),
        ),
      ]
      .into_iter()
      .collect(),
    ),
    Value::TensorRef(spec) => serde_yaml::Value::Mapping(
      [
        (
          serde_yaml::Value::String("$type".to_string()),
          serde_yaml::Value::String("tensor_ref".to_string()),
        ),
        (
          serde_yaml::Value::String("dtype".to_string()),
          serde_yaml::Value::String(spec.storage.dtype.as_safetensors().to_string()),
        ),
        (
          serde_yaml::Value::String("storage_key".to_string()),
          serde_yaml::Value::String(spec.storage.key.clone()),
        ),
        (
          serde_yaml::Value::String("offset_elems".to_string()),
          serde_yaml::to_value(spec.offset_elems).unwrap_or(serde_yaml::Value::Null),
        ),
        (
          serde_yaml::Value::String("shape".to_string()),
          serde_yaml::Value::Sequence(
            spec
              .shape
              .iter()
              .map(|v| serde_yaml::to_value(*v).unwrap_or(serde_yaml::Value::Null))
              .collect(),
          ),
        ),
        (
          serde_yaml::Value::String("stride".to_string()),
          serde_yaml::Value::Sequence(
            spec
              .stride
              .iter()
              .map(|v| serde_yaml::to_value(*v).unwrap_or(serde_yaml::Value::Null))
              .collect(),
          ),
        ),
      ]
      .into_iter()
      .collect(),
    ),
    Value::Object {
      module,
      name,
      args,
      state,
    } => {
      let mut map = serde_yaml::Mapping::new();
      map.insert(
        serde_yaml::Value::String("$type".to_string()),
        serde_yaml::Value::String("object".to_string()),
      );
      map.insert(
        serde_yaml::Value::String("$class".to_string()),
        serde_yaml::Value::String(format!("{module}.{name}")),
      );
      if let Some(args) = args {
        if !is_empty_args(args) {
          map.insert(
            serde_yaml::Value::String("$args".to_string()),
            project_value_for_metadata(args),
          );
        }
      }
      if let Some(state) = state {
        let projected_state = project_value_for_metadata(state);
        match projected_state {
          serde_yaml::Value::Mapping(state_map) => {
            for (key, value) in state_map {
              map.insert(key, value);
            }
          }
          serde_yaml::Value::Null => {}
          other => {
            map.insert(serde_yaml::Value::String("$state".to_string()), other);
          }
        }
      }
      serde_yaml::Value::Mapping(map)
    }
  }
}

/// Reports whether `value` represents "no arguments".
///
/// That is the case for `Value::None` and for any empty `Tuple`, `List`, `Set`, `Dict` or
/// `OrderedDict`. Every other value, including non-empty containers, yields `false`.
fn is_empty_args(value: &Value) -> bool {
  matches!(value, Value::None)
    || matches!(value, Value::Tuple(seq) | Value::List(seq) | Value::Set(seq) if seq.is_empty())
    || matches!(value, Value::Dict(pairs) if pairs.is_empty())
    || matches!(value, Value::OrderedDict(pairs) if pairs.is_empty())
}

/// Collects the `"{module}.{name}"` names of all `Value::Object` nodes reachable from `root`.
///
/// The work is done by `collect_constructor_types_inner`. Each name appears once in the
/// returned vector: the set only tracks which names were already recorded, while the vector
/// keeps them in the order they were first pushed.
pub(crate) fn collect_constructor_types(root: &Value) -> Vec<String> {
  let mut ordered_types = Vec::new();
  let mut already_recorded = std::collections::BTreeSet::new();
  collect_constructor_types_inner(root, &mut already_recorded, &mut ordered_types);
  ordered_types
}

/// Walks `value` depth-first and appends the `"{module}.{name}"` name of each `Value::Object`
/// to `out` the first time that name is encountered.
///
/// `seen` is the deduplication set for the walk; `out` keeps names in first-encounter order.
/// Children are visited in this order: an object's `args` before its `state`, a dict entry's
/// key before its value, ordered-dict values only (their keys are plain strings), and
/// sequence items front to back. All other variants have no children to visit.
///
/// The traversal uses an explicit work stack rather than recursion, so deeply nested input
/// grows a heap-allocated vector instead of the call stack.
fn collect_constructor_types_inner(
  value: &Value,
  seen: &mut std::collections::BTreeSet<String>,
  out: &mut Vec<String>,
) {
  let mut pending: Vec<&Value> = vec![value];
  while let Some(node) = pending.pop() {
    // Children are pushed in visit order below; reversing just that freshly pushed segment
    // afterwards makes `pop` hand them back in the same pre-order a recursive walk follows.
    let first_child = pending.len();
    match node {
      Value::Object {
        module,
        name,
        args,
        state,
      } => {
        let ty = format!("{module}.{name}");
        if seen.insert(ty.clone()) {
          out.push(ty);
        }
        if let Some(args) = args {
          pending.push(args);
        }
        if let Some(state) = state {
          pending.push(state);
        }
      }
      Value::Dict(entries) => {
        for (k, v) in entries {
          pending.push(k);
          pending.push(v);
        }
      }
      Value::OrderedDict(entries) => {
        for (_, v) in entries {
          pending.push(v);
        }
      }
      Value::List(items) | Value::Tuple(items) | Value::Set(items) => {
        for item in items {
          pending.push(item);
        }
      }
      _ => {}
    }
    pending[first_child..].reverse();
  }
}