import argparse
import inspect
import sys
import types
import warnings
from pathlib import Path
import numpy as np
import torch
import torch.nn as nn
# Sample rate, in Hz, used to convert between seconds and sample counts
# (the export summary reports it as 16 kHz).
ENCODER_SAMPLE_RATE = 16_000
# Input samples per output token: input lengths are rounded down to a multiple
# of this value and the expected token count is n_samples // this value.
ENCODER_SAMPLES_PER_TOKEN = 320
class EncoderWrapper(nn.Module):
    """Expose a codec's ``encode_code`` call as a plain ``nn.Module``.

    The wrapper stores the codec it is given and routes ``forward`` to the
    codec's ``encode_code`` method, so the encode path can be invoked (and
    traced) like any other module.
    """

    def __init__(self, codec):
        """Store ``codec``; it must provide ``encode_code(audio_or_path=...)``."""
        super().__init__()
        self.codec = codec

    def forward(self, audio: torch.Tensor) -> torch.Tensor:
        """Encode ``audio`` and return the resulting codes.

        A rank-3 result has ``squeeze(1)`` applied (which only removes the
        second axis when it has size 1); results of any other rank are
        returned unchanged.
        """
        encoded = self.codec.encode_code(audio_or_path=audio)
        return encoded.squeeze(1) if encoded.dim() == 3 else encoded
def patch_alias_free_torch(wrapper: nn.Module, probe_input: torch.Tensor) -> int:
    """Replace the forward of alias-free resampler modules with a fixed-weight conv.

    For each of ``UpSample1d`` and ``DownSample1d`` looked up in
    ``neucodec.alias_free_torch.resample`` the function:

    1. attaches temporary forward hooks to every instance found in ``wrapper``
       and runs ``wrapper(probe_input)`` once to record the channel count
       (``input.shape[1]``) each instance receives;
    2. inspects the source of the class's ``forward`` to choose between
       ``conv1d`` / ``conv_transpose1d`` and between ``"replicate"`` /
       ``"constant"`` padding;
    3. registers a per-module ``_onnx_weight`` buffer (the module's ``filter``
       repeated once per channel) and rebinds ``forward`` on the instance to a
       pad-then-grouped-convolution using that buffer.

    Args:
        wrapper: Model whose sub-modules are scanned and patched in place.
        probe_input: Example input passed to ``wrapper`` to discover channel
            counts.

    Returns:
        The number of module instances whose ``forward`` was replaced. Classes
        that cannot be imported, or that have no instance reached by the probe
        pass, contribute zero.

    Raises:
        Whatever ``wrapper(probe_input)`` raises; the temporary hooks are
        removed before the exception propagates.
    """
    # Local import so the file's top-level import block stays untouched.
    import importlib

    def make_hook(module_path: str, seen: dict):
        # Bind the path per hook; setdefault keeps only the first observation.
        def record(_mod, inputs, _out):
            seen.setdefault(module_path, int(inputs[0].shape[1]))
        return record

    def make_forward(pl: int, pr: int, s: int, c: int, fn, mode: str):
        # All shape-dependent values are captured as Python constants so the
        # replacement forward contains no data-dependent logic.
        def fwd(self, x: torch.Tensor) -> torch.Tensor:
            x = torch.nn.functional.pad(x, (pl, pr), mode=mode)
            return fn(x, self._onnx_weight, stride=s, groups=c)
        return fwd

    patched = 0
    for class_name in ("UpSample1d", "DownSample1d"):
        try:
            resample = importlib.import_module("neucodec.alias_free_torch.resample")
            cls = getattr(resample, class_name)
        except (ImportError, AttributeError):
            # Package or class not available: nothing to patch for this name.
            continue

        # Exact type match on purpose: subclasses are not patched.
        targets = [
            (path, mod) for path, mod in wrapper.named_modules() if type(mod) is cls
        ]
        if not targets:
            # No instance to observe, so the probe pass would be wasted work.
            continue

        channel_map: dict[str, int] = {}
        hooks = [
            mod.register_forward_hook(make_hook(path, channel_map))
            for path, mod in targets
        ]
        try:
            with torch.no_grad():
                wrapper(probe_input)
        finally:
            # Always detach the probes, even if the forward pass fails, so the
            # model is never left with stray hooks.
            for hook in hooks:
                hook.remove()

        if not channel_map:
            continue

        # Heuristic: pick the conv op and padding mode from the text of the
        # original forward implementation.
        src = inspect.getsource(cls.forward)
        conv_fn = (
            torch.nn.functional.conv_transpose1d
            if "conv_transpose1d" in src
            else torch.nn.functional.conv1d
        )
        uses_replicate = any(
            marker in src for marker in ("replicate", "edge", "replication")
        )
        pad_mode = "replicate" if uses_replicate else "constant"

        # NOTE(review): the replacement only pads and convolves; if the
        # original forward also scales or crops its output, that is not
        # reproduced here — confirm against the installed neucodec source.
        for path, mod in targets:
            if path not in channel_map:
                # Instance exists but was not reached by the probe pass.
                continue
            channels = channel_map[path]
            pad_left = int(mod.pad_left)
            pad_right = int(mod.pad_right)
            if hasattr(mod, "stride"):
                stride = int(mod.stride)
            elif hasattr(mod, "ratio"):
                stride = int(mod.ratio)
            else:
                stride = 1

            # One filter copy per channel, materialised once as a buffer.
            fixed_w = mod.filter.view(1, 1, -1).repeat(channels, 1, 1).detach().clone()
            mod.register_buffer("_onnx_weight", fixed_w)
            mod.forward = types.MethodType(
                make_forward(pad_left, pad_right, stride, channels, conv_fn, pad_mode),
                mod,
            )
            patched += 1
    return patched
def export_onnx(
    wrapper: nn.Module,
    n_samples: int,
    out_path: Path,
    opset: int,
) -> int:
    """Patch the resampler modules in ``wrapper`` and export it to ONNX.

    The function runs ``wrapper`` on an all-zero ``(1, 1, n_samples)`` float32
    input before and after calling ``patch_alias_free_torch``, reports how many
    codes changed, and then exports the patched model with
    ``torch.onnx.export`` (``dynamo=False``) using that same input.

    Args:
        wrapper: Module to patch in place and export.
        n_samples: Length of the fixed input along the last axis.
        out_path: Destination of the ONNX file.
        opset: ONNX opset version passed to the exporter.

    Returns:
        The size of the last axis of the patched model's output.
    """
    dummy = torch.zeros(1, 1, n_samples, dtype=torch.float32)

    print(" Computing reference codes (pre-patch) …")
    with torch.no_grad():
        ref_codes = wrapper(dummy).cpu().numpy()

    print(" Patching alias_free_torch sampler modules …")
    n_patched = patch_alias_free_torch(wrapper, dummy)
    print(f" Patched {n_patched} module(s).")

    with torch.no_grad():
        patched_codes = wrapper(dummy).cpu().numpy()

    if ref_codes.shape != patched_codes.shape:
        # The comparison is diagnostic only. An element-wise `!=` on arrays of
        # different shapes would raise (or collapse to a scalar), aborting the
        # export, so report the mismatch and carry on instead.
        print(
            f" ⚠ Output shape changed after patch: {ref_codes.shape} → "
            f"{patched_codes.shape}; skipping token comparison."
        )
    else:
        n_diff = int(
            (ref_codes.astype(np.int32) != patched_codes.astype(np.int32)).sum()
        )
        if n_diff == 0:
            print(" ✓ Patched model produces identical codes to the original.")
        else:
            pct = 100.0 * n_diff / ref_codes.size
            print(
                f" ⚠ {n_diff}/{ref_codes.size} ({pct:.1f}%) tokens differ after patch.\n"
                " This reflects boundary changes at the few samples touched by the\n"
                " padding mode. For speaker-identity encoding this is negligible."
            )

    n_tokens = int(patched_codes.shape[-1])
    print(f"\n Tracing: input={n_samples} samples ({n_samples/ENCODER_SAMPLE_RATE:.1f} s)"
          f" → {n_tokens} tokens")

    # Silence the warning categories below for the duration of the export only.
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
        warnings.filterwarnings("ignore", category=UserWarning)
        warnings.filterwarnings("ignore", category=DeprecationWarning)
        torch.onnx.export(
            wrapper,
            (dummy,),
            str(out_path),
            input_names=["audio"],
            output_names=["codes"],
            opset_version=opset,
            do_constant_folding=True,
            dynamo=False,
        )

    print(f" Written: {out_path} ({out_path.stat().st_size / 1e6:.1f} MB)")
    return n_tokens
def verify_with_ort(
    out_path: Path,
    n_samples: int,
    patched_codes: np.ndarray,
) -> None:
    """Run the exported model in ONNX Runtime and compare against PyTorch codes.

    Loads ``out_path`` on the CPU execution provider, feeds an all-zero
    ``(1, 1, n_samples)`` float32 input, and prints whether the resulting codes
    (cast to int32) equal ``patched_codes``. If ``onnxruntime`` cannot be
    imported the check is skipped with a message.

    Args:
        out_path: Path of the ONNX file to load.
        n_samples: Length of the zero input along the last axis.
        patched_codes: Reference codes to compare the ORT output against.
    """
    try:
        import onnxruntime as ort
    except ImportError:
        print(" (onnxruntime not installed — skipping ORT verification)")
        return

    print("\nVerifying with OnnxRuntime …")
    # Name the provider explicitly: ORT builds that ship several providers can
    # refuse to create a session without one, and CPU matches the reference.
    sess = ort.InferenceSession(str(out_path), providers=["CPUExecutionProvider"])
    inp_info = sess.get_inputs()[0]
    out_info = sess.get_outputs()[0]
    print(f" ORT model input : name={inp_info.name!r} shape={inp_info.shape}")
    print(f" ORT model output: name={out_info.name!r} "
          f"shape={out_info.shape}")

    dummy_np = np.zeros((1, 1, n_samples), dtype=np.float32)
    # Feed by the name the model itself reports instead of a hard-coded one.
    ort_codes = sess.run(None, {inp_info.name: dummy_np})[0].astype(np.int32)
    print(f" ORT output: shape={ort_codes.shape} dtype={ort_codes.dtype}")
    print(f" Code range: [{ort_codes.min()}, {ort_codes.max()}]")

    ref = patched_codes.astype(np.int32)
    if ort_codes.shape != ref.shape:
        # An element-wise comparison is meaningless (and errors out) when the
        # shapes differ, so report that directly.
        print(
            f" ⚠ ORT output shape {ort_codes.shape} differs from PyTorch "
            f"shape {ref.shape}; skipping token comparison."
        )
        return

    n_diff = int((ort_codes != ref).sum())
    if n_diff == 0:
        print(" ✓ ORT codes match patched PyTorch codes exactly.")
    else:
        pct = 100.0 * n_diff / ref.size
        print(
            f" ⚠ {n_diff}/{ref.size} ({pct:.1f}%) ORT tokens differ from PyTorch.\n"
            " Usually caused by non-deterministic quantisation rounding."
        )
def main() -> None:
    """Command-line entry point: load NeuCodec, export its encoder, verify it.

    Steps, in order: parse arguments, validate the requested duration, check
    that the required packages import, load the model from ``--repo``, export
    it with ``export_onnx``, re-run the patched model to obtain reference
    codes, verify the file with ``verify_with_ort``, and print a usage summary.

    Exits with status 2 (via ``argparse``) when ``--max-duration-s`` is shorter
    than one token, and with status 1 when a required package is missing.
    """
    parser = argparse.ArgumentParser(
        description="Export NeuCodec PyTorch encoder → ONNX",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--repo", default="neuphonic/neucodec",
                        help="HuggingFace repo for the PyTorch NeuCodec model")
    parser.add_argument("--out", default="neucodec_encoder.onnx",
                        help="Output ONNX file path")
    parser.add_argument("--opset", type=int, default=17,
                        help="ONNX opset version")
    parser.add_argument("--max-duration-s", type=float, default=30.0,
                        help=(
                            "Fixed input duration the model is exported for (seconds). "
                            "Rust zero-pads shorter clips and trims longer ones."
                        ))
    args = parser.parse_args()

    # Round the requested duration down to a whole number of tokens.
    n_samples = (
        int(args.max_duration_s * ENCODER_SAMPLE_RATE)
        // ENCODER_SAMPLES_PER_TOKEN
        * ENCODER_SAMPLES_PER_TOKEN
    )
    if n_samples <= 0:
        # Fail before the expensive model load: a zero or negative length
        # cannot be turned into a valid fixed-size input.
        parser.error(
            "--max-duration-s must cover at least one token "
            f"({ENCODER_SAMPLES_PER_TOKEN / ENCODER_SAMPLE_RATE:.3f} s); "
            f"got {args.max_duration_s}"
        )

    for pkg, install in [
        ("neucodec", "neucodec"),
        ("onnx", "onnx"),
        ("onnxscript", "onnxscript"),
    ]:
        try:
            __import__(pkg)
        except ImportError:
            print(f"ERROR: '{pkg}' not installed. Run: pip install {install}",
                  file=sys.stderr)
            sys.exit(1)

    from neucodec import NeuCodec

    print(f"Loading {args.repo} …")
    codec = NeuCodec.from_pretrained(args.repo)
    codec.eval()
    wrapper = EncoderWrapper(codec)

    n_tokens_expected = n_samples // ENCODER_SAMPLES_PER_TOKEN
    print(f"\nFixed input : {n_samples} samples = {n_samples/ENCODER_SAMPLE_RATE:.1f} s "
          f"→ {n_tokens_expected} tokens")

    out_path = Path(args.out)
    print(f"\nExporting to {out_path} (opset {args.opset}) …")
    n_tokens = export_onnx(wrapper, n_samples, out_path, args.opset)

    # export_onnx returns only the token count, so recompute the patched
    # model's codes here to serve as the ORT comparison reference.
    dummy = torch.zeros(1, 1, n_samples, dtype=torch.float32)
    with torch.no_grad():
        patched_codes = wrapper(dummy).cpu().numpy()
    verify_with_ort(out_path, n_samples, patched_codes)

    print(f"""
Done → {out_path}
Fixed input : {n_samples} samples ({n_samples/ENCODER_SAMPLE_RATE:.1f} s @ 16 kHz)
Fixed output : {n_tokens} tokens ({n_tokens*ENCODER_SAMPLES_PER_TOKEN/ENCODER_SAMPLE_RATE:.1f} s)
Rust encoder behaviour:
• clips shorter than {n_samples/ENCODER_SAMPLE_RATE:.0f} s → zero-padded
• clips longer than {n_samples/ENCODER_SAMPLE_RATE:.0f} s → truncated
• output tokens always trimmed to floor(clip_len / 320)
If your reference audio is longer than {args.max_duration_s:.0f} s, re-export:
python scripts/export_encoder.py --max-duration-s 60
Encode a reference clip:
cargo run --example encode_reference -- --audio reference.wav
Voice clone (caches encoding automatically):
cargo run --example clone_voice --features espeak -- \\
--ref-audio reference.wav --text 'Hello!'
""")
# Run the exporter only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()