import argparse
import io
import numpy as np
import os
import pickle
import struct
import sys
import zipfile
class StorageRef:
def __init__(self, key, storage_type, size):
self.key = key
self.storage_type = storage_type
self.size = size
class FakeTensor:
def __init__(self, storage, storage_offset, size, stride, dtype=np.float32):
self.storage = storage
self.storage_offset = storage_offset
self.size = size self.stride = stride self.dtype = dtype
def to_numpy(self, raw_bytes_map):
raw = raw_bytes_map[self.storage.key]
item = np.dtype(self.dtype).itemsize
arr = np.frombuffer(raw, dtype=self.dtype)
total = int(np.prod(self.size)) if self.size else 1
start = self.storage_offset
return arr[start:start + total].reshape(self.size).copy()
STORAGE_TO_DTYPE = {
"FloatStorage": np.float32,
"DoubleStorage": np.float64,
"HalfStorage": np.float16,
"BFloat16Storage":np.float32, "IntStorage": np.int32,
"LongStorage": np.int64,
"ShortStorage": np.int16,
"ByteStorage": np.uint8,
"BoolStorage": np.bool_,
}
BF16_TYPES = {"BFloat16Storage"}
class TorchUnpickler(pickle.Unpickler):
def __init__(self, file, zip_file, prefix):
super().__init__(file)
self._zip = zip_file
self._prefix = prefix self._storages = {}
def find_class(self, module, name):
if module.startswith("torch") or module.startswith("_codecs"):
if name in STORAGE_TO_DTYPE:
dtype = STORAGE_TO_DTYPE[name]
is_bf16 = name in BF16_TYPES
def make_storage(dtype=dtype, is_bf16=is_bf16):
return (dtype, is_bf16)
make_storage.__name__ = name
return make_storage
if name == "_rebuild_tensor_v2":
return rebuild_tensor_v2
if name in ("_rebuild_parameter", "_rebuild_parameter_with_state"):
return lambda data, *a, **kw: data
if name == "_rebuild_from_type_v2":
return lambda fn, tp, args, kw: fn(*args)
return lambda *a, **kw: None
if module == "collections" and name == "OrderedDict":
from collections import OrderedDict
return OrderedDict
if module == "_codecs":
return lambda *a, **kw: None
return super().find_class(module, name)
def persistent_load(self, pid):
if isinstance(pid, tuple) and pid[0] in ("storage", b"storage"):
_, storage_callable, key, _location, n_elements = pid
if callable(storage_callable):
info = storage_callable() if info is None:
info = (np.float32, False)
dtype, is_bf16 = info
else:
dtype, is_bf16 = np.float32, False
ref = StorageRef(key, (dtype, is_bf16), n_elements)
self._storages[key] = ref
return ref
return pid
def rebuild_tensor_v2(storage, storage_offset, size, stride, *args, **kwargs):
dtype, is_bf16 = storage.storage_type if isinstance(storage.storage_type, tuple) else (np.float32, False)
return FakeTensor(storage, storage_offset, tuple(size), tuple(stride), dtype=dtype)
def load_pytorch_bin(path):
with zipfile.ZipFile(path, "r") as zf:
names = zf.namelist()
prefix = names[0].split("/")[0]
pkl_name = f"{prefix}/data.pkl"
if pkl_name not in names:
pkl_candidates = [n for n in names if n.endswith("data.pkl")]
if not pkl_candidates:
raise RuntimeError(f"No data.pkl found in {path}. Names: {names[:5]}")
pkl_name = pkl_candidates[0]
prefix = pkl_name.split("/")[0]
print(f" Archive prefix: '{prefix}', reading {pkl_name} ...")
data_prefix = f"{prefix}/data/"
data_names = {n[len(data_prefix):]: n for n in names if n.startswith(data_prefix)}
pkl_bytes = zf.read(pkl_name)
unpickler = TorchUnpickler(io.BytesIO(pkl_bytes), zf, prefix)
import pickle as _pickle
original_dispatch = unpickler.dispatch.copy() if hasattr(unpickler, 'dispatch') else {}
unpickler.dispatch_table = {
"__builtin__.object": object,
}
import sys as _sys
fake_torch_utils = type(_sys)("torch._utils")
fake_torch_utils._rebuild_tensor_v2 = rebuild_tensor_v2
fake_torch_utils._rebuild_parameter = lambda data, *a: data
fake_torch_utils._rebuild_from_type_v2 = lambda fn, tp, args, kw: fn(*args)
_sys.modules.setdefault("torch._utils", fake_torch_utils)
fake_collections = type(_sys)("collections")
class OD(dict): pass
fake_collections.OrderedDict = OD
_sys.modules.setdefault("collections", __import__("collections"))
state_dict_raw = unpickler.load()
raw_bytes_map = {}
for key, ref in unpickler._storages.items():
dtype, is_bf16 = ref.storage_type
fname = data_names.get(str(key))
if fname is None:
print(f" WARNING: data file for storage key {key!r} not found")
continue
raw = zf.read(fname)
if is_bf16:
u16 = np.frombuffer(raw, dtype=np.uint16)
f32 = (u16.astype(np.uint32) << 16).view(np.float32)
raw_bytes_map[key] = f32.tobytes()
ref.storage_type = (np.float32, False)
else:
raw_bytes_map[key] = raw
def convert(obj):
if isinstance(obj, FakeTensor):
return obj.to_numpy(raw_bytes_map).astype(np.float32)
if isinstance(obj, dict):
return {k: convert(v) for k, v in obj.items()}
if isinstance(obj, (list, tuple)):
return type(obj)(convert(v) for v in obj)
return obj
state_dict = convert(state_dict_raw)
return state_dict
def main():
parser = argparse.ArgumentParser(description=__doc__.strip())
parser.add_argument("--out", default="models/neucodec_decoder.safetensors")
parser.add_argument("--repo", default="neuphonic/neucodec")
parser.add_argument("--n-heads", type=int, default=16)
args = parser.parse_args()
try:
from huggingface_hub import hf_hub_download
except ImportError:
print("ERROR: pip install huggingface_hub safetensors numpy")
sys.exit(1)
print(f"Downloading pytorch_model.bin from {args.repo} ...")
bin_path = hf_hub_download(repo_id=args.repo, filename="pytorch_model.bin")
print(f" cached at: {bin_path}")
print("Loading checkpoint (pure-Python, no PyTorch required) ...")
state = load_pytorch_bin(bin_path)
if not isinstance(state, dict):
print(f"ERROR: Expected a dict, got {type(state)}. Keys (if any): {list(state.keys())[:5] if hasattr(state, 'keys') else 'N/A'}")
sys.exit(1)
print(f" total tensors: {len(state)}")
decoder_prefixes = ("generator.", "fc_post_a.")
decoder_state = {}
skipped = []
for k, v in state.items():
if not isinstance(v, np.ndarray):
skipped.append(k)
continue
if any(k.startswith(p) for p in decoder_prefixes):
decoder_state[k] = np.ascontiguousarray(v.astype(np.float32))
if skipped:
print(f" skipped {len(skipped)} non-array entries")
print(f" extracted {len(decoder_state)} decoder tensors")
if not decoder_state:
print("ERROR: No decoder tensors found.")
print(" Available keys:", list(state.keys())[:10])
sys.exit(1)
embed_w = decoder_state.get("generator.backbone.embed.weight")
head_w = decoder_state.get("generator.head.out.weight")
if embed_w is None or head_w is None:
print("ERROR: Missing embed/head weights.")
sys.exit(1)
hidden_dim = embed_w.shape[0]
out_dim = head_w.shape[0]
hop_length = (out_dim - 2) // 4
n_fft = hop_length * 4
depth = sum(1 for k in decoder_state
if k.startswith("generator.backbone.transformers.")
and k.endswith(".att_norm.weight"))
print()
print("Detected configuration:")
print(f" hidden_dim = {hidden_dim}")
print(f" depth = {depth} transformer blocks")
print(f" n_heads = {args.n_heads}")
print(f" hop_length = {hop_length} ({24_000 // hop_length} tokens/s at 24 kHz)")
print(f" n_fft = {n_fft}")
print(f" out_dim = {out_dim}")
try:
from safetensors.numpy import save_file
except ImportError:
print("ERROR: pip install safetensors")
sys.exit(1)
os.makedirs(os.path.dirname(os.path.abspath(args.out)), exist_ok=True)
metadata = {
"hidden_dim": str(hidden_dim),
"depth": str(depth),
"n_heads": str(args.n_heads),
"hop_length": str(hop_length),
"source": args.repo,
}
print(f"\nSaving {args.out} ...")
save_file(decoder_state, args.out, metadata=metadata)
size_mb = os.path.getsize(args.out) / 1024 / 1024
print(f" saved {len(decoder_state)} tensors, {size_mb:.0f} MB")
print()
print("Done! Run: cargo build && cargo run --example test_pipeline")
if __name__ == "__main__":
main()