pt-loader 0.1.2

Safe, parser-based converter from PyTorch checkpoints to safetensors
Documentation
use std::collections::BTreeMap;

use crate::types::{ConvertError, Result, TensorRef, Value};

/// Locates the tensors of a checkpoint's state dict inside an unpickled value tree.
///
/// Three strategies are tried in order, stopping at the first one that yields any tensor:
/// 1. a module-state walk of the `"model"` entry, when `root` is a mapping that has one;
/// 2. a module-state walk of `root` itself;
/// 3. the single largest string-keyed tensor mapping found anywhere in the tree.
///
/// # Errors
/// Returns [`ConvertError::InvalidStructure`] when none of the strategies finds a tensor.
pub(crate) fn extract_state_dict_tensors(root: &Value) -> Result<BTreeMap<String, TensorRef>> {
  // Runs a fresh, unprefixed module-state walk starting at `start`.
  let module_walk = |start: &Value| {
    let mut found = BTreeMap::new();
    collect_module_state_tensors(start, "", &mut found);
    found
  };

  let mut tensors = find_named_child(root, "model").map(module_walk).unwrap_or_default();
  if tensors.is_empty() {
    tensors = module_walk(root);
  }
  if tensors.is_empty() {
    // Still empty, so it can act directly as the initial "best so far" map.
    collect_largest_tensor_map(root, &mut tensors);
  }

  if tensors.is_empty() {
    Err(ConvertError::InvalidStructure(
      "could not find a tensor state_dict in checkpoint pickle".to_string(),
    ))
  } else {
    Ok(tensors)
  }
}

/// Searches the whole value tree for the mapping that yields the most tensors.
///
/// Every node is offered to [`tensor_map_from_value`]. A node's map replaces `best` only when it
/// is strictly larger, so among equally sized maps the one reached first wins. Nodes are
/// reached in depth-first pre-order: a parent before its children, and children in stored
/// order (for objects, `args` before `state`).
fn collect_largest_tensor_map(value: &Value, best: &mut BTreeMap<String, TensorRef>) {
  // Explicit work stack. Children are pushed in reverse so they are popped in stored order,
  // which gives the same pre-order visit sequence as a recursive walk.
  let mut pending: Vec<&Value> = vec![value];
  while let Some(node) = pending.pop() {
    let candidate = tensor_map_from_value(node);
    if best.len() < candidate.len() {
      *best = candidate;
    }

    match node {
      Value::Dict(entries) => pending.extend(entries.iter().rev().map(|(_, child)| child)),
      Value::OrderedDict(entries) => pending.extend(entries.iter().rev().map(|(_, child)| child)),
      Value::List(items) | Value::Tuple(items) => pending.extend(items.iter().rev()),
      Value::Object { args, state, .. } => {
        // Push `state` first so that `args` is popped, and therefore visited, first.
        if let Some(state) = state {
          pending.push(state);
        }
        if let Some(args) = args {
          pending.push(args);
        }
      }
      _ => {}
    }
  }
}

fn collect_module_state_tensors(value: &Value, prefix: &str, out: &mut BTreeMap<String, TensorRef>) {
  match value {
    Value::Object { state: Some(state), .. } => {
      collect_from_module_state(state, prefix, out);
    }
    Value::Dict(entries) => {
      for (key, child) in entries {
        let Some(name) = key_as_string(key) else {
          continue;
        };
        if name == "state_dict" {
          collect_prefixed_tensor_map(child, "", out);
          continue;
        }
        let next_prefix = join_name(prefix, &name);
        match child {
          Value::Object { .. } => collect_module_state_tensors(child, &next_prefix, out),
          Value::Dict(_) | Value::OrderedDict(_) | Value::List(_) | Value::Tuple(_) => {
            collect_module_state_tensors(child, &next_prefix, out)
          }
          _ => {}
        }
      }
    }
    Value::OrderedDict(entries) => {
      for (name, child) in entries {
        let next_prefix = join_name(prefix, name);
        match child {
          Value::Object { .. } => collect_module_state_tensors(child, &next_prefix, out),
          Value::Dict(_) | Value::OrderedDict(_) | Value::List(_) | Value::Tuple(_) => {
            collect_module_state_tensors(child, &next_prefix, out)
          }
          _ => {}
        }
      }
    }
    Value::List(items) | Value::Tuple(items) => {
      for (idx, child) in items.iter().enumerate() {
        let next_prefix = join_name(prefix, &idx.to_string());
        collect_module_state_tensors(child, &next_prefix, out);
      }
    }
    _ => {}
  }
}

fn collect_from_module_state(state: &Value, prefix: &str, out: &mut BTreeMap<String, TensorRef>) {
  let Some(entries) = mapping_entries(state) else {
    return;
  };

  if let Some(parameters) = mapping_get(entries.as_slice(), "_parameters") {
    collect_named_tensors_from_mapping(parameters, prefix, out);
  }
  if let Some(buffers) = mapping_get(entries.as_slice(), "_buffers") {
    collect_named_tensors_from_mapping(buffers, prefix, out);
  }
  if let Some(modules) = mapping_get(entries.as_slice(), "_modules") {
    if let Some(module_entries) = mapping_entries(modules) {
      for (name, child) in module_entries {
        if matches!(child, Value::None) {
          continue;
        }
        let next_prefix = join_name(prefix, &name);
        collect_module_state_tensors(child, &next_prefix, out);
      }
    }
  }
}

/// Inserts every tensor-bearing entry of the mapping `value` into `out` as `prefix.<name>`.
///
/// A non-mapping `value` is ignored, as are entries from which [`extract_tensor_ref`] cannot
/// obtain a tensor. A name that is already in `out` keeps its existing tensor.
fn collect_named_tensors_from_mapping(value: &Value, prefix: &str, out: &mut BTreeMap<String, TensorRef>) {
  let tensors = mapping_entries(value)
    .into_iter()
    .flatten()
    .filter_map(|(name, child)| extract_tensor_ref(child).map(|tensor| (name, tensor)));

  for (name, tensor) in tensors {
    out.entry(join_name(prefix, &name)).or_insert(tensor);
  }
}

/// Copies the tensor map read from `value` (see [`tensor_map_from_value`]) into `out`.
///
/// Each entry is named `prefix.<name>`, or just `<name>` when `prefix` is empty. Names already
/// present in `out` are left untouched.
fn collect_prefixed_tensor_map(value: &Value, prefix: &str, out: &mut BTreeMap<String, TensorRef>) {
  tensor_map_from_value(value)
    .into_iter()
    .for_each(|(name, tensor)| {
      out.entry(join_name(prefix, &name)).or_insert(tensor);
    });
}

/// Returns the first child of the mapping `value` whose name equals `key`.
///
/// `Dict` int keys are compared in their stringified form. Returns `None` when `value` is not a
/// mapping or has no matching entry.
fn find_named_child<'a>(value: &'a Value, key: &str) -> Option<&'a Value> {
  mapping_entries(value)?
    .into_iter()
    .find_map(|(name, child)| (name == key).then_some(child))
}

/// Lists the entries of a `Dict` or `OrderedDict` as `(name, child)` pairs in stored order.
///
/// `Dict` keys are converted with [`key_as_string`], and entries whose key cannot be converted
/// are dropped. Returns `None` for any value that is not one of those two mapping kinds.
pub(crate) fn mapping_entries(value: &Value) -> Option<Vec<(String, &Value)>> {
  let entries = match value {
    Value::Dict(pairs) => {
      let mut named = Vec::new();
      for (key, child) in pairs {
        if let Some(name) = key_as_string(key) {
          named.push((name, child));
        }
      }
      named
    }
    Value::OrderedDict(pairs) => pairs
      .iter()
      .map(|(key, child)| (key.to_owned(), child))
      .collect(),
    _ => return None,
  };
  Some(entries)
}

/// Returns the value of the first entry in `entries` named `key`, if any.
fn mapping_get<'a>(entries: &'a [(String, &'a Value)], key: &str) -> Option<&'a Value> {
  entries
    .iter()
    .find(|(name, _)| name.as_str() == key)
    .map(|&(_, value)| value)
}

/// Converts a mapping key into a name.
///
/// `String` keys are copied and `Int` keys are formatted with `to_string`. Any other kind of
/// key yields `None`.
pub(crate) fn key_as_string(value: &Value) -> Option<String> {
  if let Value::String(text) = value {
    Some(text.clone())
  } else if let Value::Int(number) = value {
    Some(number.to_string())
  } else {
    None
  }
}

/// Joins two dotted-name components with `.`.
///
/// An empty component is omitted, so no leading, trailing or doubled dots are produced.
fn join_name(prefix: &str, name: &str) -> String {
  match (prefix.is_empty(), name.is_empty()) {
    (true, _) => name.to_owned(),
    (false, true) => prefix.to_owned(),
    (false, false) => [prefix, name].join("."),
  }
}

/// Returns the tensor reference held by `value`.
///
/// The reference may be held directly, or as the first element of a non-empty tuple. That first
/// element is not searched any deeper. Anything else yields `None`.
fn extract_tensor_ref(value: &Value) -> Option<TensorRef> {
  let candidate = match value {
    Value::Tuple(items) => items.first()?,
    other => other,
  };
  match candidate {
    Value::TensorRef(spec) => Some(spec.clone()),
    _ => None,
  }
}

/// Reads `value` as a flat `name -> tensor` mapping.
///
/// Only `Dict` entries with `String` keys and all `OrderedDict` entries are considered. An entry
/// is kept when [`extract_tensor_ref`] finds a tensor in it. If a name occurs more than once,
/// its last occurrence wins. Any non-mapping value yields an empty map.
fn tensor_map_from_value(value: &Value) -> BTreeMap<String, TensorRef> {
  let named: Vec<(&str, &Value)> = match value {
    Value::Dict(entries) => entries
      .iter()
      .filter_map(|(key, child)| match key {
        Value::String(name) => Some((name.as_str(), child)),
        _ => None,
      })
      .collect(),
    Value::OrderedDict(entries) => entries
      .iter()
      .map(|(key, child)| (key.as_str(), child))
      .collect(),
    _ => return BTreeMap::new(),
  };

  let mut tensors = BTreeMap::new();
  for (name, child) in named {
    if let Some(spec) = extract_tensor_ref(child) {
      // `insert` overwrites, so a repeated name keeps its last tensor.
      tensors.insert(name.to_owned(), spec);
    }
  }
  tensors
}

/// Returns the number of elements described by `shape` (the product of its dimensions).
///
/// An empty shape yields 1.
///
/// # Errors
/// Returns [`ConvertError::InvalidStructure`] if the product overflows `usize`.
pub(crate) fn numel(shape: &[usize]) -> Result<usize> {
  let mut count = 1usize;
  for &dim in shape {
    count = count
      .checked_mul(dim)
      .ok_or_else(|| ConvertError::InvalidStructure("tensor shape product overflow".to_string()))?;
  }
  Ok(count)
}

/// Computes row-major (C-contiguous) strides, in elements, for `shape`.
///
/// The last dimension has stride 1, and each earlier stride is the product of all later
/// dimensions. Products saturate at `usize::MAX` instead of overflowing. An empty shape yields
/// an empty stride vector.
pub(crate) fn contiguous_stride(shape: &[usize]) -> Vec<usize> {
  // Walk the dimensions from innermost to outermost, emitting the running product before
  // folding the current dimension into it.
  let mut strides: Vec<usize> = shape
    .iter()
    .rev()
    .scan(1usize, |running, &dim| {
      let stride = *running;
      *running = (*running).saturating_mul(dim);
      Some(stride)
    })
    .collect();
  strides.reverse();
  strides
}