import json
from collections import defaultdict
from pathlib import Path

import torch
def analyze_checkpoint(checkpoint_path):
    """Load a PyTorch checkpoint and print a summary of its contents.

    Args:
        checkpoint_path: Path to a file produced by ``torch.save``.

    Returns:
        The state dict (mapping of parameter name -> tensor). If the
        checkpoint is a dict wrapping a "state_dict" key, that inner
        dict is returned; otherwise the loaded object itself is assumed
        to be the state dict.
    """
    print(f"Analyzing checkpoint: {checkpoint_path}")
    # NOTE(security): torch.load unpickles arbitrary objects -- only load
    # checkpoints from trusted sources.
    checkpoint = torch.load(checkpoint_path, map_location="cpu")
    if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
        print("Checkpoint contains a state_dict")
        state_dict = checkpoint["state_dict"]
        # Report any metadata stored alongside the weights (epoch, optimizer, ...).
        other_keys = [k for k in checkpoint.keys() if k != "state_dict"]
        if other_keys:
            print(f"Other keys in checkpoint: {other_keys}")
    else:
        print("Checkpoint is a direct state_dict")
        state_dict = checkpoint
    print(f"\nState dict contains {len(state_dict)} tensors")
    # Group keys by their first dotted component (top-level submodule name).
    prefixes = defaultdict(list)
    for key in state_dict.keys():
        prefixes[key.split(".")[0]].append(key)
    print("\nKey prefixes:")
    for prefix, keys in prefixes.items():
        print(f" {prefix}: {len(keys)} keys")
    print("\nSample keys (first 10):")
    for i, key in enumerate(list(state_dict.keys())[:10]):
        tensor = state_dict[key]
        print(f" {key}: shape={list(tensor.shape)}, dtype={tensor.dtype}")
    return state_dict
def create_key_mapping(state_dict, model_type):
    """Map PyTorch state-dict keys to the naming scheme used on the Candle side.

    Args:
        state_dict: Mapping whose keys are PyTorch parameter names.
        model_type: One of "e2vid_unet" or "firenet". Any other value
            yields an empty mapping.

    Returns:
        Dict of {pytorch_key: candle_key}.
    """
    mappings = {}
    if model_type == "e2vid_unet":
        for key in state_dict.keys():
            if key.startswith("encoders.0."):
                mappings[key] = key.replace("encoders.0.", "encoder1.")
            elif key.startswith("encoders.1."):
                mappings[key] = key.replace("encoders.1.", "encoder2.")
            elif key.startswith("decoders."):
                # "decoders.0.foo" -> "decoder0.foo": the index is fused
                # directly onto the "decoder" name (no separating dot).
                mappings[key] = key.replace("decoders.", "decoder")
            else:
                mappings[key] = key
    elif model_type == "firenet":
        # FireNet keys already match the Candle names, so every key maps
        # to itself (the original branched on head./res_blocks./tail.
        # prefixes, but all branches produced the identity mapping).
        mappings = {key: key for key in state_dict}
    return mappings
def save_weight_info(checkpoint_path, output_path):
    """Analyze a checkpoint and dump per-tensor metadata and key mappings to JSON.

    Args:
        checkpoint_path: Path to the PyTorch checkpoint to inspect.
        output_path: Destination for the JSON summary file.
    """
    state_dict = analyze_checkpoint(checkpoint_path)

    # Guess the model family from the checkpoint's filename.
    filename = Path(checkpoint_path).stem.lower()
    if "e2vid" in filename and "lightweight" in filename:
        model_type = "e2vid_unet"
    elif "firenet" in filename:
        model_type = "firenet"
    else:
        model_type = "unknown"
    print(f"\nDetected model type: {model_type}")

    info = {
        "model_type": model_type,
        "num_parameters": len(state_dict),
        "key_mappings": create_key_mapping(state_dict, model_type),
        "tensor_info": {
            name: {
                "shape": list(t.shape),
                "dtype": str(t.dtype),
                "numel": t.numel(),
            }
            for name, t in state_dict.items()
        },
    }

    with open(output_path, "w") as f:
        json.dump(info, f, indent=2)
    print(f"\nSaved weight info to: {output_path}")
def test_tensor_conversion():
    """Convert a handful of representative tensors to NumPy and print their details."""
    # Representative conv + batch-norm parameter tensors.
    specs = [
        ("conv.weight", torch.randn(64, 5, 3, 3)),
        ("conv.bias", torch.randn(64)),
        ("bn.weight", torch.ones(64)),
        ("bn.bias", torch.zeros(64)),
        ("bn.running_mean", torch.zeros(64)),
        ("bn.running_var", torch.ones(64)),
    ]
    print("Testing tensor conversion:")
    for name, t in specs:
        arr = t.detach().cpu().numpy()
        print(f" {name}:")
        print(f" PyTorch shape: {list(t.shape)}")
        print(f" NumPy shape: {arr.shape}")
        print(f" dtype: {arr.dtype}")
        print(f" bytes: {arr.tobytes()[:20]}...")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Test PyTorch weight loading")
    parser.add_argument("--checkpoint", type=str, help="Path to PyTorch checkpoint")
    parser.add_argument(
        "--output",
        type=str,
        default="weight_info.json",
        help="Output path for weight info JSON",
    )
    parser.add_argument(
        "--test-conversion", action="store_true", help="Test tensor conversion"
    )
    args = parser.parse_args()

    if args.test_conversion:
        test_tensor_conversion()
    elif args.checkpoint:
        save_weight_info(args.checkpoint, args.output)
    else:
        # No explicit checkpoint: fall back to the default E2VID download
        # location, if the file is there.
        default_ckpt = Path("models/E2VID_lightweight.pth.tar")
        if default_ckpt.exists():
            save_weight_info(default_ckpt, "e2vid_weight_info.json")
        else:
            print(
                "No checkpoint found. Please provide --checkpoint or download E2VID model first."
            )