voxcpm-rs
Pure-Rust inference for VoxCPM2 — a zero-shot text-to-speech model with voice cloning — built on top of the Burn ML framework.
Runs locally on your machine via Vulkan (AMD, NVIDIA, Intel) or a pure-CPU fallback. No Python, no CUDA, no ONNX runtime — just a cargo dependency.
```rust
// `B` is your chosen backend type (see Backends & features); argument shapes
// are illustrative; see examples/tts.rs for the exact call.
let model = VoxCPM::<B>::from_local("path/to/VoxCPM2")?;
let wav = model.generate("Hello from pure Rust!", &GenerateOptions::builder().build())?;
write_wav("out.wav", &wav, model.sample_rate())?;
```
Contents
- Why
- Quick start
- Backends & features
- API tour
- Architecture
- Examples
- Contributing
- Related projects
- License
Why
The upstream VoxCPM2 reference is Python + PyTorch + CUDA. That is a heavy dependency tree to ship inside a desktop app, a game, a CLI tool, or any other Rust project where you want offline, on-device TTS.
voxcpm-rs is a single cargo add away and runs on:
- Any Vulkan-capable GPU (AMD, NVIDIA, Intel, Apple via MoltenVK).
- Pure CPU with SIMD elementwise ops, optionally with vendored OpenBLAS for multi-core matmul — no system libraries required.
It aims to stay faithful to the official implementation (see vendor/VoxCPM) while
exposing a small, idiomatic Rust API.
Quick start
- Grab a checkpoint. Download the VoxCPM2 weights from Hugging Face.
  You should end up with a directory containing `config.json`, `tokenizer.json`, `model.safetensors`, and `audiovae.pth`. The crate consumes this layout as-shipped — no manual weight conversion step is required. See Model files below for the full accepted layout.

- Add the crate:

  ```toml
  # Cargo.toml
  [dependencies]
  voxcpm-rs = { version = "0.1", default-features = false, features = ["wgpu"] }
  ```

- Synthesize something:

  ```rust
  // Argument shapes are illustrative; see examples/tts.rs for the exact, compiling version.
  use voxcpm_rs::{GenerateOptions, VoxCPM, Wgpu};

  type B = Wgpu;

  let model = VoxCPM::<B>::from_local("path/to/VoxCPM2")?;
  let wav = model.generate("Hello from pure Rust!", &GenerateOptions::builder().build())?;
  write_wav("out.wav", &wav, model.sample_rate())?;
  ```

  `VoxCPM::generate` takes `&self`, so one loaded model can serve any number of sequential synthesis calls without reloading. Note however that `VoxCPM<B>` is not `Sync` — burn's `Param<Tensor<...>>` wraps a `std::cell::OnceCell` for lazy device materialization, which transitively makes the whole model `!Sync`. To share a single loaded model across threads or async tasks, wrap it in `Arc<Mutex<VoxCPM<B>>>` (or `Arc<parking_lot::Mutex<...>>`) and serialize `generate` calls; for true parallel inference, load one `VoxCPM<B>` per worker.

- Or just run the bundled example, `examples/tts.rs`.
Model files
VoxCPM::from_local expects a directory with:
| File | Purpose | Format accepted |
|---|---|---|
| `config.json` | Model architecture config | JSON |
| `tokenizer.json` | HuggingFace tokenizer | JSON |
| `model.safetensors` / `model.pth` | LM + DiT backbone weights | SafeTensors preferred, `.pth`/`.pt` fallback |
| `audiovae.safetensors` / `audiovae.pth` | AudioVAE decoder weights | SafeTensors preferred, `.pth` fallback |
The upstream HF repo currently ships `model.safetensors` + `audiovae.pth`; both
work directly with no conversion. PyTorch `state_dict.` / `model.` / `module.`
top-level container prefixes are stripped automatically.
Weight loading takes ~20–25 s on first call (a 4.3 GB BF16 backbone is upcast
to F32 for the wgpu backend — WGSL has no BF16 type). The cost is paid
once per from_local; subsequent generate() calls are free of any I/O.
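Because the load dominates startup, long-lived processes usually load once and reuse. A minimal sketch, assuming the backend alias `B` from Quick start and the single-argument `from_local` call shown there (the `Mutex` is needed because `VoxCPM<B>` is `!Sync`):

```rust
use std::sync::{Mutex, OnceLock};

// Pay the ~20–25 s from_local cost exactly once per process, then reuse the
// model for every subsequent request (serialized through the Mutex).
static MODEL: OnceLock<Mutex<VoxCPM<B>>> = OnceLock::new();

fn model() -> &'static Mutex<VoxCPM<B>> {
    MODEL.get_or_init(|| {
        Mutex::new(VoxCPM::<B>::from_local("path/to/VoxCPM2").expect("checkpoint loads"))
    })
}
```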
Load-phase progress is reported via the log
crate, so wiring up env_logger / tracing-log surfaces it.
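For example, with `env_logger` (a minimal sketch; any `log`-compatible logger works):

```rust
fn main() {
    // Initialise a `log` backend before calling from_local so the per-file
    // load-phase messages are visible; run with RUST_LOG=info.
    env_logger::init();

    // ... load the model and synthesize as in Quick start ...
}
```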
Backends & features
Pick exactly one backend:
| Feature | Backend | Notes |
|---|---|---|
| `cpu` (default) | `burn-ndarray` + SIMD | Works everywhere. Matmul is single-threaded. |
| `cpu-blas` | `cpu` + vendored OpenBLAS | Multi-core matmul. Builds OpenBLAS from source (no system deps). |
| `wgpu` | Vulkan / Metal / DX12 | Recommended for GPUs. Fast cold start. |
| `wgpu-fast` | `wgpu` + fusion + autotune | ~5–7% faster steady-state; pays a one-time autotune cost (cached). |
| `vulkan` | Native Vulkan + bf16 weights | ~2.6× faster than `wgpu` on AMD RDNA4. Requires a patch — see below. |
```sh
# CPU + BLAS
cargo add voxcpm-rs --no-default-features --features cpu-blas

# Vulkan, tuned
cargo add voxcpm-rs --no-default-features --features wgpu-fast
```
Tip: with `wgpu-fast`, set `CUBECL_AUTOTUNE_LEVEL=minimal` to shrink the first-run autotune cost. Results are cached in `target/autotune/`.
Bf16 Vulkan backend (opt-in, fastest path)
The vulkan feature uses Burn's native Vulkan backend and runs the model in
bf16 end-to-end (the upstream weight dtype — no f32 upcast, half the VRAM,
substantially faster on bf16-capable hardware). Verified ~2.6× speedup over
wgpu on an AMD RX 9070 XT (RDNA4).
It needs two small patches that aren't in the released burn-cubecl /
cubecl-spirv crates yet — one fixes a conv accumulator dtype, the other
promotes a handful of bf16 SPIR-V ops that mesa's NIR translator doesn't lower
correctly. Add this to your project's Cargo.toml (alongside
voxcpm-rs = { …, features = ["vulkan"] }):
```toml
[patch.crates-io]
burn-cubecl = { git = "https://github.com/mii-nipah/voxcpm-rs", branch = "main" }
cubecl-spirv = { git = "https://github.com/mii-nipah/voxcpm-rs", branch = "main" }
```
That's it — cargo clones the repo, finds the patched crates by name, and
rebuilds. No mesa rebuild, no environment variables, no extra steps. Pin to a
specific rev = "…" instead of branch = "main" for reproducible builds.
Why a patch and not a published crate? Cargo's `[patch.crates-io]` only takes effect at the workspace root, so a library can't transparently pull in a patched dependency on its consumers' behalf — the patch block must live in the consumer's manifest either way. A `git = "…"` reference is the lowest-friction form that doesn't require maintaining renamed forks on crates.io. See `patches/README.md` for the patch contents and rationale.
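Put together, a consumer manifest for the bf16 Vulkan path looks roughly like this (version number illustrative; swap `branch` for a pinned `rev` in real projects):

```toml
[dependencies]
voxcpm-rs = { version = "0.1", default-features = false, features = ["vulkan"] }

[patch.crates-io]
burn-cubecl = { git = "https://github.com/mii-nipah/voxcpm-rs", branch = "main" }
cubecl-spirv = { git = "https://github.com/mii-nipah/voxcpm-rs", branch = "main" }
```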
API tour
Zero-shot synthesis
```rust
let wav = model.generate("Text to speak.", &GenerateOptions::builder().build())?;
```
Voice cloning
Provide a short reference clip (ideally a few seconds of clean speech):
```rust
use voxcpm_rs::Prompt;

// The fields of `Prompt::Reference` are elided here; see examples/clone.rs for the exact shape.
let opts = GenerateOptions::builder()
    .prompt(Prompt::Reference { /* reference clip + its transcript */ })
    .build();
let wav = model.generate("Say this in the cloned voice.", &opts)?;
```
Or continue from an existing utterance (the model picks up after `audio`):
```rust
// Same builder, different prompt variant: the continuation prompt carries the
// audio to pick up from (variant name and fields elided; see the Prompt docs).
let opts = GenerateOptions::builder()
    .prompt(/* continuation prompt built from the existing utterance */)
    .build();
```
Audio from memory
Prompt audio doesn't have to live on disk. PromptAudio
accepts three sources — a path, already-encoded bytes, or raw PCM samples — so
you can plug the model into an in-memory pipeline (microphone capture, HTTP
upload, another TTS stage, …):
```rust
use voxcpm_rs::{Prompt, PromptAudio};

// (Field shapes are elided below; see the PromptAudio docs for the exact constructors.)

// 1. From a file path (the default — `Into<PromptAudio>` is implemented for
//    `&str`, `&Path` and `PathBuf`):
let a = Prompt::Reference { /* audio: "reference.wav".into(), … */ };

// 2. From encoded bytes in memory (any format Symphonia supports):
let bytes: Vec<u8> = std::fs::read("reference.ogg")?;
let b = Prompt::Reference { /* audio: PromptAudio built from `bytes`, … */ };

// 3. From raw mono f32 PCM you already have:
let c = Prompt::Reference { /* audio: PromptAudio built from (samples, sample_rate), … */ };
```
Symmetrically, audio::load_audio_bytes /
audio::load_audio_bytes_as let you decode encoded audio
buffers without touching the filesystem.
Streaming
For real-time playback, network streaming, or just to start hearing audio
before the whole utterance is ready, use
VoxCPM::generate_stream. It returns an iterator
of Result<Vec<f32>> chunks at model.sample_rate():
```rust
let opts = GenerateOptions::builder()
    .chunk_patches(5) // ~400 ms / chunk at the default model config
    .build();

for chunk in model.generate_stream("A longer paragraph worth streaming.", &opts)? {
    let samples: Vec<f32> = chunk?;
    // play or forward `samples` as soon as they arrive
}
```
Concatenating every chunk yields exactly the same waveform generate()
would have returned — chunk boundaries are seamless because the AudioVAE
decoder is causal. chunk_patches trades latency for throughput: smaller
→ lower per-chunk latency, larger → fewer chunks. The default 5 is a
sensible balance for live playback.
See examples/tts_stream.rs for an end-to-end
run with per-chunk timing.
Implementation note. The autoregressive loop (LM + DiT) runs incrementally with KV-cache, so streaming adds no AR overhead compared to `generate()`. The AudioVAE decoder, however, is currently stateless across chunks — each chunk re-decodes the cumulative latent and emits only the new tail samples, making total VAE work `O(N²/chunk_patches)` over an utterance instead of `O(N)`. AR cost dominates in practice, so the difference is rarely visible.
Throughput: batched & parallel-segment generation
Single-utterance inference at batch size 1 is launch-bound on most modern GPUs — the kernels are fast, but each one carries fixed dispatch overhead and each weight matrix is re-read from VRAM per call. Both costs amortize beautifully across a larger batch, so running multiple sequences through one forward pass gives close-to-linear speedup until you hit the actual compute or memory ceiling. This benefits every backend (Vulkan / wgpu / CPU, fp32 or bf16) — it is a property of the dispatch model, not of the numeric format.
voxcpm-rs exposes two complementary APIs that share the same right-pad
batched-prefill + per-element stop machinery underneath.
VoxCPM::batch() — independent utterances at once
When you have several unrelated requests (a server handling N clients, a
batch job rendering many lines), put them into one batch and get one PCM
buffer per item back, in order. Each item carries its own [Prompt], so
different items can use different reference voices in the same batch.
```rust
// The `.add()` argument shape is illustrative; see examples/bench_batch.rs for
// the exact builder calls (each item can carry its own Prompt).
let outs: Vec<Vec<f32>> = model
    .batch()
    .add("First, unrelated utterance.")
    .add("Second one, in a different voice via its own prompt.")
    .add("Third.")
    .run()?;

for (i, wav) in outs.iter().enumerate() {
    // one PCM buffer per item, returned in submission order
    println!("item {i}: {} samples", wav.len());
}
```
Measured on an AMD RX 9070 XT (Vulkan + bf16, 8 short utterances):
| Mode | Wall time | Audio | RTF | Speedup |
|---|---|---|---|---|
| serial (b=1) | 19.9 s | 30.1 s | 0.66 | 1.00× |
| batch b=2 | 13.1 s | 29.9 s | 0.44 | 1.52× |
| batch b=4 | 9.9 s | 29.8 s | 0.33 | 2.00× |
| batch b=8 | 8.6 s | 29.0 s | 0.30 | 2.31× |
RTF below 1.0 means faster than realtime — at b=8 the GPU produces audio ~3.4× faster than playback speed.
parallel_segments — split one paragraph, share one voice
For a single long text (a book chapter, a long reply), set
[GenerateOptions::parallel_segments(n)]: voxcpm-rs splits the text on
sentence boundaries and feeds groups of n segments through the same
batched path. To keep the voice consistent across sentences when no
reference audio is supplied, the first segment is generated serially and
its audio is encoded as the reference for the rest ("self-seeding");
with [Prompt::Reference] the user-provided voice is used directly and
everything runs batched.
```rust
// `long_paragraph` is the full text to split on sentence boundaries.
let opts = GenerateOptions::builder()
    .parallel_segments(4) // batch size for the segment groups
    .build();
let wav = model.generate(long_paragraph, &opts)?;
```
Same hardware, 10-sentence paragraph:
| Mode | Wall time | Audio | RTF | Speedup vs per-sentence serial |
|---|---|---|---|---|
| per-sentence serial | 22.0 s | 34.1 s | 0.65 | 1.00× |
| `parallel_segments(2)` | 24.8 s | 49.9 s | 0.50 | 0.89× |
| `parallel_segments(4)` | 18.6 s | 43.2 s | 0.43 | 1.18× |
| `parallel_segments(8)` | 12.8 s | 34.1 s | 0.38 | 1.72× |
(The audio-length differences come from the per-sentence stop head firing at slightly different points; RTF is the apples-to-apples number.)
Which one to use? If you have multiple independent inputs, prefer
batch() — there is no first-segment serial step, so the speedup is
purely batched. If you have one long text and want the whole thing
ready faster, use parallel_segments.
Warning
Parallel segments may degrade voice consistency and quality; use with caution.
How far does batching scale?
Batching helps as long as the GPU is launch-bound; once each step saturates compute or memory, adding more elements just adds proportional work. To find the sweet spot we generated the same medium sentence N times in one batch (so no element dominates) and swept N. RX 9070 XT, Vulkan + bf16:
| Batch | Wall time | Audio | RTF | Throughput | Speedup vs b=1 |
|---|---|---|---|---|---|
| 1 | 4.3 s | 5.8s | 0.75 | 1.34× | 1.00× |
| 2 | 4.3 s | 8.8s | 0.49 | 2.05× | 2.00× (free) |
| 4 | 5.9 s | 18.2s | 0.32 | 3.12× | 2.94× |
| 8 | 9.6 s | 36.2s | 0.27 | 3.75× | 3.57× |
| 16 | 22.7 s | 76.5s | 0.30 | 3.37× | 3.03× |
| 32 | 54.8 s | 151.7s | 0.36 | 2.77× | 2.51× |
| 64 | 192.5 s | 316.5s | 0.61 | 1.64× | 1.43× |
Takeaways:
- b=1 → b=2 is literally free — at b=1 the GPU is 100 % launch-bound, so the second sequence rides along at zero extra cost.
- b=8 is the sweet spot on this card — peak throughput 3.75× realtime.
- b≥16 starts regressing. Past saturation, more batch members do not hide step cost; they only add proportional work, and at b=64 something (likely allocator pressure or an autotune miss for the rare giant shape) makes things noticeably worse.
- These numbers are hardware-specific. The shape of the curve
(free-doubling at small B, peak somewhere around 4–8, regression past
the GPU's saturation point) is universal — re-run
`examples/batch_scale_sweep.rs` on your own hardware to find your own sweet spot.
For a server batching independent requests, target b=4–8 and queue beyond that; for latency-sensitive interactive use, treat b=8 as the upper bound.
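A minimal sketch of that queue-then-batch pattern, assuming the `batch()` builder shown above, the backend alias `B` from Quick start, a hypothetical `Request` type, and `anyhow` for error plumbing:

```rust
use std::sync::mpsc::{Receiver, Sender};

// Hypothetical request type: text to synthesize plus a channel to send the PCM back on.
struct Request {
    text: String,
    reply: Sender<Vec<f32>>,
}

const MAX_BATCH: usize = 8; // past this the GPU is saturated on this card

fn serve(model: &VoxCPM<B>, queue: Receiver<Request>) -> anyhow::Result<()> {
    loop {
        // Block for the first request, then opportunistically drain up to MAX_BATCH - 1 more.
        let mut pending = vec![queue.recv()?];
        while pending.len() < MAX_BATCH {
            match queue.try_recv() {
                Ok(req) => pending.push(req),
                Err(_) => break,
            }
        }

        // One batched forward pass for everything collected this round.
        // (`add` argument shape is illustrative; see examples/bench_batch.rs.)
        let mut batch = model.batch();
        for req in &pending {
            batch = batch.add(&req.text);
        }
        let outs = batch.run()?;

        // Answer each caller in submission order; a dropped receiver is not an error.
        for (req, wav) in pending.into_iter().zip(outs) {
            let _ = req.reply.send(wav);
        }
    }
}
```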
Tuning knobs
All options flow through the fluent builder:
```rust
// Values are illustrative; every knob is optional and falls back to its default.
let opts = GenerateOptions::builder()
    .cfg(2.0)          // classifier-free guidance; 1.5–3.0 is typical
    .timesteps(10)     // diffusion Euler steps; fewer = faster, <6 degrades
    .min_len(1)
    .max_len(2000)     // hard cap on generated latent patches (~80 ms each)
    .chunk_patches(5)  // patches per chunk in `generate_stream`
    .build();
```
Cancellation
Long generations can be cancelled cooperatively from another thread via
CancelToken. The autoregressive loop polls the token between every
diffusion step, so cancel latency is bounded by one step
(~200 ms on wgpu at default timesteps=10).
```rust
use voxcpm_rs::CancelToken;

let cancel = CancelToken::new();
let opts = GenerateOptions::builder().cancel(cancel.clone()).build();

// Hand a clone of `cancel` to another thread or signal handler and trigger it
// there (see the CancelToken docs for the triggering method).
match model.generate(text, &opts) {
    Ok(wav) => { /* completed normally */ }
    Err(err) => { /* cancelled, or a real error; inspect `err` */ }
}
```
CancelToken is Clone + Send + Sync (an Arc<AtomicBool> underneath),
so you can hand copies to as many watchers as you like.
Architecture
VoxCPM2 is a cascade of four components — each lives in its own module:
text ──► tokenizer ──► minicpm4 (LM backbone) ──► locenc ──► locdit (diffusion) ──► audiovae ──► wav
| Module | Role |
|---|---|
| `tokenizer` | HF tokenizers wrapper for the LlamaTokenizerFast vocab. |
| `minicpm4` | Decoder-only LM backbone (rotary attention + KV cache). |
| `locenc` | Local encoder — conditions the diffusion head on LM hidden states. |
| `locdit` | Local DiT + conditional flow-matching sampler. |
| `audiovae` | VAE decoder that turns FSQ patches into 16 kHz audio. |
| `voxcpm2` | Glue + convenient VoxCPM façade. |
Weights are loaded directly from .safetensors or .pth via
burn-store with the PyTorchToBurnAdapter,
so HuggingFace checkpoints drop in with no manual conversion step.
Examples
Browse examples/ for standalone binaries:
- `tts.rs` — end-to-end synthesis.
- `tts_stream.rs` — chunked streaming synthesis with per-chunk latency logging.
- `clone.rs` — voice cloning from a reference wav.
- `bench_parallel.rs` — RTF benchmark for `parallel_segments` (one long paragraph).
- `bench_batch.rs` — RTF benchmark for `VoxCPM::batch()` (many independent utterances).
- `batch_varlen.rs` — 8 wildly-different-length utterances in one batched call (writes to `/tmp/voxbatching/`).
- `batch_scale_sweep.rs` — sweep batch sizes 1→64 with uniform-length input to find your hardware's saturation point.
- `lm_check.rs`, `vae_check.rs`, `feat_check.rs` — per-component parity checks against the reference implementation.
- `bench_rmsnorm.rs` — microbench for hot kernels.
Contributing
Contributions are very welcome — especially:
- Bug reports with a minimal repro and the backend/feature flags you used.
- Performance PRs (kernels, memory layout, KV cache, sampler).
- New backends supported by Burn (CUDA, Metal direct, etc.).
Before opening a PR:
- `cargo fmt --all` and `cargo clippy --all-targets`.
- `cargo test --no-default-features --features cpu`.
- If you touched a numeric path, run the matching `*_check` example against a real checkpoint and include the RTF / parity numbers in the PR description.
Keep PRs focused — one feature or fix per PR makes review much easier.
Related projects
- VoxCPM (official, Python) — the reference implementation this crate tracks. A copy lives under `vendor/VoxCPM` for parity testing.
- Burn — the ML framework powering all the tensor math here.
- cubecl — the GPU kernel compiler behind Burn's `wgpu` backend.
License
Licensed under the Apache License, Version 2.0. The vendored reference
implementation under vendor/VoxCPM/ (kept in the repository for parity testing,
not shipped on crates.io) retains its own license — see the
upstream LICENSE.