pt-loader 0.1.4

Safe parser-based PyTorch checkpoint converter to safetensors
Documentation
use std::collections::BTreeMap;

use crate::types::{ConvertError, LoadOptions, Result, TensorRef, Value};

/// Extracts tensor references from a parsed checkpoint, grouped by root key.
///
/// When `opts.state_dict_root_keys` is empty, the whole `root` value is
/// searched and every tensor found is returned in a single group named
/// `"root"`. Otherwise each configured key is resolved as a dot-separated
/// path below `root` and the tensors found under it form a group named
/// after that key.
///
/// # Errors
///
/// Returns `ConvertError::InvalidStructure` when:
/// - no root keys are configured and no tensor is found under `root`;
/// - `opts.state_dict_root_strict` is set and a configured key is missing
///   or has no tensors under it;
/// - root keys are configured but none of them produced a group.
pub(crate) fn extract_state_dict_tensors(
  root: &Value,
  opts: &LoadOptions,
) -> Result<BTreeMap<String, BTreeMap<String, TensorRef>>> {
  // Collects every tensor reachable from `node`, starting with an empty name prefix.
  let gather = |node: &Value| {
    let mut tensors = BTreeMap::new();
    collect_module_state_tensors(node, "", &mut tensors);
    tensors
  };

  let mut groups = BTreeMap::new();

  if opts.state_dict_root_keys.is_empty() {
    let tensors = gather(root);
    if tensors.is_empty() {
      return Err(ConvertError::InvalidStructure(
        "could not find a tensor state_dict in checkpoint pickle".to_string(),
      ));
    }
    groups.insert("root".to_string(), tensors);
    return Ok(groups);
  }

  for key in &opts.state_dict_root_keys {
    let tensors = match find_deep_child(root, key) {
      Some(target) => gather(target),
      None if opts.state_dict_root_strict => {
        return Err(ConvertError::InvalidStructure(format!(
          "could not find configured state_dict_root_key '{}' in checkpoint pickle",
          key
        )));
      }
      // Non-strict mode: a missing key is silently skipped.
      None => continue,
    };

    if !tensors.is_empty() {
      groups.insert(key.clone(), tensors);
    } else if opts.state_dict_root_strict {
      return Err(ConvertError::InvalidStructure(format!(
        "could not find a tensor state_dict under configured state_dict_root_key '{}'",
        key
      )));
    }
  }

  if groups.is_empty() {
    Err(ConvertError::InvalidStructure(
      "could not find a tensor state_dict in any configured state_dict_root_key".to_string(),
    ))
  } else {
    Ok(groups)
  }
}

/// Recursively walks `value` and records every tensor reference it finds in
/// `out`, naming each one by the dot-joined path below `prefix`.
///
/// - An `Object` carrying a state is handed to `collect_from_module_state`.
/// - A `Dict` or `OrderedDict` first contributes its direct tensor entries,
///   then its object/dict/list/tuple children are visited under
///   `prefix.<key>`.
/// - In a `Dict`, a child stored under the key `"state_dict"` contributes
///   only its direct tensor entries, with an empty prefix, and is not
///   descended into any further.
/// - `List`/`Tuple` items and `Call` arguments are visited under
///   `prefix.<index>`; a `Call` state is visited under `prefix` itself.
///
/// Existing entries in `out` are never overwritten by the helpers used here.
fn collect_module_state_tensors(value: &Value, prefix: &str, out: &mut BTreeMap<String, TensorRef>) {
  // Children of a mapping are only descended into when they are one of these kinds.
  fn is_walkable(node: &Value) -> bool {
    matches!(
      node,
      Value::Object { .. } | Value::Dict(_) | Value::OrderedDict(_) | Value::List(_) | Value::Tuple(_)
    )
  }

  match value {
    Value::Object { state: Some(state), .. } => collect_from_module_state(state, prefix, out),
    Value::Dict(entries) => {
      collect_prefixed_tensor_map(value, prefix, out);
      for (key, child) in entries {
        match key_as_string(key) {
          // Keys that cannot be rendered as a name are ignored.
          None => {}
          Some(name) if name == "state_dict" => collect_prefixed_tensor_map(child, "", out),
          Some(name) => {
            if is_walkable(child) {
              collect_module_state_tensors(child, &join_name(prefix, &name), out);
            }
          }
        }
      }
    }
    Value::OrderedDict(entries) => {
      collect_prefixed_tensor_map(value, prefix, out);
      for (name, child) in entries {
        if is_walkable(child) {
          collect_module_state_tensors(child, &join_name(prefix, name), out);
        }
      }
    }
    Value::List(items) | Value::Tuple(items) => {
      for (position, child) in items.iter().enumerate() {
        let child_prefix = join_name(prefix, &position.to_string());
        collect_module_state_tensors(child, &child_prefix, out);
      }
    }
    Value::Call { args, state, .. } => {
      for (position, arg) in args.iter().enumerate() {
        let arg_prefix = join_name(prefix, &position.to_string());
        collect_module_state_tensors(arg, &arg_prefix, out);
      }
      if let Some(call_state) = state {
        collect_module_state_tensors(call_state, prefix, out);
      }
    }
    _ => {}
  }
}

/// Collects tensors from a module's state mapping.
///
/// Reads the `_parameters` and `_buffers` entries (in that order) as
/// name-to-tensor tables under `prefix`, then walks each non-`None` entry
/// of `_modules` under `prefix.<name>`. Does nothing when `state` is not a
/// mapping.
fn collect_from_module_state(state: &Value, prefix: &str, out: &mut BTreeMap<String, TensorRef>) {
  let entries = match mapping_entries(state) {
    Some(entries) => entries,
    None => return,
  };

  for table_key in ["_parameters", "_buffers"] {
    if let Some(table) = mapping_get(&entries, table_key) {
      collect_named_tensors_from_mapping(table, prefix, out);
    }
  }

  let submodules = mapping_get(&entries, "_modules").and_then(|modules| mapping_entries(modules));
  for (name, child) in submodules.into_iter().flatten() {
    // A `None` slot holds no submodule to walk.
    if !matches!(child, Value::None) {
      collect_module_state_tensors(child, &join_name(prefix, &name), out);
    }
  }
}

/// Adds every direct tensor entry of the mapping `value` to `out` under
/// `prefix.<name>`.
///
/// Entries that are not tensors are skipped, names already present in `out`
/// keep their existing tensor, and a non-mapping `value` is ignored.
fn collect_named_tensors_from_mapping(value: &Value, prefix: &str, out: &mut BTreeMap<String, TensorRef>) {
  if let Some(entries) = mapping_entries(value) {
    for (name, child) in entries {
      if let Some(tensor) = extract_tensor_ref(child) {
        out.entry(join_name(prefix, &name)).or_insert(tensor);
      }
    }
  }
}

/// Copies the direct tensor entries of `value` (as produced by
/// `tensor_map_from_value`) into `out`, qualifying each name with `prefix`.
///
/// A name that already exists in `out` keeps its existing tensor.
fn collect_prefixed_tensor_map(value: &Value, prefix: &str, out: &mut BTreeMap<String, TensorRef>) {
  tensor_map_from_value(value).into_iter().for_each(|(name, tensor)| {
    out.entry(join_name(prefix, &name)).or_insert(tensor);
  });
}

/// Returns the first child of the mapping `value` whose name equals `key`.
///
/// Yields `None` when `value` is not a mapping or has no such entry.
fn find_named_child<'a>(value: &'a Value, key: &str) -> Option<&'a Value> {
  mapping_entries(value)?
    .into_iter()
    .find_map(|(name, child)| (name == key).then_some(child))
}

/// Resolves a dot-separated `key_path` below `value`, one mapping level per
/// segment.
///
/// Returns `None` as soon as any segment cannot be found.
fn find_deep_child<'a>(value: &'a Value, key_path: &str) -> Option<&'a Value> {
  key_path
    .split('.')
    .try_fold(value, |node, segment| find_named_child(node, segment))
}

/// Views a `Dict` or `OrderedDict` as a list of `(name, child)` pairs in
/// their stored order.
///
/// `Dict` entries whose key cannot be rendered by `key_as_string` are
/// omitted. Any other kind of value yields `None`.
pub(crate) fn mapping_entries(value: &Value) -> Option<Vec<(String, &Value)>> {
  let pairs = match value {
    Value::Dict(items) => {
      let mut named = Vec::new();
      for (raw_key, child) in items {
        if let Some(label) = key_as_string(raw_key) {
          named.push((label, child));
        }
      }
      named
    }
    Value::OrderedDict(items) => {
      let mut named = Vec::with_capacity(items.len());
      for (label, child) in items {
        named.push((label.clone(), child));
      }
      named
    }
    _ => return None,
  };
  Some(pairs)
}

/// Looks up `key` in a list of `(name, value)` pairs and returns the value
/// of the first match, if any.
fn mapping_get<'a>(entries: &'a [(String, &'a Value)], key: &str) -> Option<&'a Value> {
  for (name, value) in entries {
    if name == key {
      return Some(*value);
    }
  }
  None
}

/// Renders a mapping key as a name: strings are copied, integers are
/// formatted with `to_string`, and every other kind of key yields `None`.
pub(crate) fn key_as_string(value: &Value) -> Option<String> {
  if let Value::String(text) = value {
    return Some(text.clone());
  }
  if let Value::Int(number) = value {
    return Some(number.to_string());
  }
  None
}

/// Joins `prefix` and `name` with a `.` separator.
///
/// When either side is empty the other is returned unchanged, so no leading
/// or trailing dot is ever produced.
fn join_name(prefix: &str, name: &str) -> String {
  match (prefix.is_empty(), name.is_empty()) {
    (true, _) => name.to_owned(),
    (false, true) => prefix.to_owned(),
    (false, false) => [prefix, name].join("."),
  }
}

/// Returns a clone of the tensor reference held by `value`.
///
/// Accepts either a `TensorRef` directly or a tuple whose first element is a
/// `TensorRef`; anything else (including an empty tuple) yields `None`.
fn extract_tensor_ref(value: &Value) -> Option<TensorRef> {
  // For a tuple only the first element is inspected; it is not unwrapped further.
  let candidate = match value {
    Value::Tuple(items) => items.first()?,
    other => other,
  };
  match candidate {
    Value::TensorRef(spec) => Some(spec.clone()),
    _ => None,
  }
}

/// Builds a name-to-tensor map from the direct entries of a `Dict` or
/// `OrderedDict`.
///
/// Only entries whose value is recognised by `extract_tensor_ref` are kept.
/// For a `Dict`, only string keys are considered. If a name occurs more than
/// once, the later entry replaces the earlier one. Any other kind of value
/// produces an empty map.
fn tensor_map_from_value(value: &Value) -> BTreeMap<String, TensorRef> {
  let mut tensors = BTreeMap::new();

  match value {
    Value::Dict(entries) => {
      for (raw_key, child) in entries {
        let Value::String(name) = raw_key else {
          continue;
        };
        if let Some(spec) = extract_tensor_ref(child) {
          tensors.insert(name.clone(), spec);
        }
      }
    }
    Value::OrderedDict(entries) => {
      for (name, child) in entries {
        if let Some(spec) = extract_tensor_ref(child) {
          tensors.insert(name.clone(), spec);
        }
      }
    }
    _ => {}
  }

  tensors
}

/// Computes the number of elements described by `shape` as the product of
/// its dimensions. An empty shape has one element.
///
/// # Errors
///
/// Returns `ConvertError::InvalidStructure` if the running product
/// overflows `usize`.
pub(crate) fn numel(shape: &[usize]) -> Result<usize> {
  let mut total = 1usize;
  for &dim in shape {
    total = total
      .checked_mul(dim)
      .ok_or_else(|| ConvertError::InvalidStructure("tensor shape product overflow".to_string()))?;
  }
  Ok(total)
}

/// Computes strides (in elements) for `shape` such that the last dimension
/// has stride 1 and each earlier dimension's stride is the product of all
/// later dimension sizes.
///
/// The running product saturates at `usize::MAX` instead of overflowing. An
/// empty shape yields an empty vector.
pub(crate) fn contiguous_stride(shape: &[usize]) -> Vec<usize> {
  let mut strides = vec![1usize; shape.len()];
  let mut span = 1usize;
  // Walk from the innermost dimension outwards, accumulating the element span.
  for (slot, dim) in strides.iter_mut().zip(shape).rev() {
    *slot = span;
    span = span.saturating_mul(*dim);
  }
  strides
}

// Unit tests for `extract_state_dict_tensors`, built on hand-made `Value`
// trees that stand in for a parsed checkpoint.
#[cfg(test)]
mod tests {
  use super::*;
  use crate::types::{DType, StorageRef};

  // Builds a one-element F32 tensor reference whose storage is identified by `key`.
  fn tensor_ref_with_key(key: &str) -> Value {
    Value::TensorRef(TensorRef {
      storage: StorageRef {
        key: key.to_string(),
        dtype: DType::F32,
        size_elems: 1,
      },
      offset_elems: 0,
      shape: vec![1],
      stride: vec![1],
    })
  }

  // Builds an object value whose state holds a single parameter named
  // `param_name` under `_parameters`, with empty `_buffers` and `_modules`.
  fn module_with_named_parameter(param_name: &str, storage_key: &str) -> Value {
    Value::Object {
      module: "torch.nn.modules.linear".to_string(),
      name: "Linear".to_string(),
      args: Vec::new(),
      state: Some(Box::new(Value::Dict(vec![
        (
          Value::String("_parameters".to_string()),
          Value::Dict(vec![(
            Value::String(param_name.to_string()),
            tensor_ref_with_key(storage_key),
          )]),
        ),
        (Value::String("_buffers".to_string()), Value::Dict(Vec::new())),
        (Value::String("_modules".to_string()), Value::Dict(Vec::new())),
      ]))),
    }
  }

  // With default options the whole tree is searched and every tensor lands
  // in the single "root" group, named by its path from the root.
  // NOTE(review): relies on `LoadOptions::default()` having no root keys — confirm in `types`.
  #[test]
  fn extracts_from_root_when_no_root_key_is_set() {
    let root = Value::Dict(vec![
      (
        Value::String("model".to_string()),
        module_with_named_parameter("w_model", "model_key"),
      ),
      (
        Value::String("module".to_string()),
        module_with_named_parameter("w_module", "module_key"),
      ),
    ]);

    let out = extract_state_dict_tensors(&root, &LoadOptions::default()).expect("extract tensors");
    let root_group = out.get("root").expect("root group");
    assert!(root_group.contains_key("model.w_model"));
    assert!(root_group.contains_key("module.w_module"));
  }

  // A configured root key restricts extraction to that subtree; tensor names
  // are relative to the subtree, and siblings are not included.
  #[test]
  fn supports_custom_state_dict_root_key_list() {
    let root = Value::Dict(vec![
      (
        Value::String("model".to_string()),
        module_with_named_parameter("w_model", "model_key"),
      ),
      (
        Value::String("module".to_string()),
        module_with_named_parameter("w_module", "module_key"),
      ),
    ]);

    let opts = LoadOptions {
      state_dict_root_keys: vec!["module".to_string()],
      ..LoadOptions::default()
    };
    let out = extract_state_dict_tensors(&root, &opts).expect("extract tensors");
    let module_group = out.get("module").expect("module group");
    assert!(module_group.contains_key("w_module"));
    assert!(!module_group.contains_key("w_model"));
  }

  // A configured key that does not exist must produce an error rather than
  // falling back to searching the whole tree.
  // NOTE(review): presumably `LoadOptions::default()` enables strict mode — confirm in `types`.
  #[test]
  fn fails_without_fallback_when_custom_root_key_is_missing_in_strict_mode() {
    let root = Value::Dict(vec![(
      Value::String("model".to_string()),
      module_with_named_parameter("w_model", "model_key"),
    )]);

    let opts = LoadOptions {
      state_dict_root_keys: vec!["ema".to_string()],
      ..LoadOptions::default()
    };
    let err = extract_state_dict_tensors(&root, &opts).expect_err("missing custom root key should fail");
    assert!(err
      .to_string()
      .contains("could not find configured state_dict_root_key"));
  }

  // With strict mode disabled, a missing key is skipped and the remaining
  // configured keys are still extracted.
  #[test]
  fn skips_missing_when_non_strict_mode_is_enabled() {
    let root = Value::Dict(vec![(
      Value::String("model".to_string()),
      module_with_named_parameter("w_model", "model_key"),
    )]);

    let opts = LoadOptions {
      state_dict_root_keys: vec!["ema".to_string(), "model".to_string()],
      state_dict_root_strict: false,
      ..LoadOptions::default()
    };
    let out = extract_state_dict_tensors(&root, &opts).expect("should extract model key and skip ema");
    assert!(!out.contains_key("ema"));
    assert!(out.contains_key("model"));
  }

  // A dotted root key is resolved one mapping level per segment, and the
  // resulting group is named after the full dotted key.
  #[test]
  fn supports_deep_key_path() {
    let root = Value::Dict(vec![(
      Value::String("ema".to_string()),
      Value::Dict(vec![(
        Value::String("model".to_string()),
        module_with_named_parameter("w_ema_model", "ema_model_key"),
      )]),
    )]);

    let opts = LoadOptions {
      state_dict_root_keys: vec!["ema.model".to_string()],
      ..LoadOptions::default()
    };
    let out = extract_state_dict_tensors(&root, &opts).expect("extract deep key path");
    let group = out.get("ema.model").expect("ema.model group");
    assert!(group.contains_key("w_ema_model"));
  }
}