use std::path::{Path, PathBuf};
use anyhow::{bail, Context, Result};
use ndarray::{s, Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3};
use rustfft::{num_complex::Complex, FftPlanner};
use safetensors::SafeTensors;
/// Audio sample rate in Hz (24 kHz). The decoder's startup log divides this by the
/// hop length to report tokens per second.
pub const SAMPLE_RATE: u32 = 24_000;
/// Sample rate in Hz (16 kHz) associated with the encoder.
/// NOTE(review): not referenced in the visible code — presumably the encoder's input rate; confirm.
pub const ENCODER_SAMPLE_RATE: u32 = 16_000;
/// Number of audio samples per token on the decoder side (480).
/// NOTE(review): not referenced in the visible code; at `SAMPLE_RATE` this would be 50 tokens/s — confirm.
pub const SAMPLES_PER_TOKEN: usize = 480;
/// Number of audio samples per token on the encoder side (320).
/// NOTE(review): not referenced in the visible code — confirm against the encoder.
pub const ENCODER_SAMPLES_PER_TOKEN: usize = 320;
/// Default encoder input length in samples: `16_000 * 10`, i.e. ten seconds at 16 kHz.
pub const ENCODER_DEFAULT_INPUT_SAMPLES: usize = 16_000 * 10;
/// Reports whether this crate was compiled with the `wgpu` Cargo feature.
///
/// The answer is fixed at compile time; nothing is probed at runtime.
pub fn wgpu_feature_enabled() -> bool {
    const WGPU_COMPILED_IN: bool = cfg!(feature = "wgpu");
    WGPU_COMPILED_IN
}
/// Number of quantisation levels for each of the eight FSQ digits (4 levels each).
/// `fsq_decode` uses these as the modulus when extracting a digit from a code.
pub(crate) const FSQ_LEVELS: [i32; 8] = [4, 4, 4, 4, 4, 4, 4, 4];
/// Place value of each FSQ digit (successive powers of 4). `fsq_decode` divides a code
/// by the basis and reduces it modulo the matching level count to recover the digit.
pub(crate) const FSQ_BASIS: [i32; 8] = [1, 4, 16, 64, 256, 1_024, 4_096, 16_384];
/// Reads the tensor `name` from `st` and returns its elements as `f32`, in stored order.
///
/// * `F32` data is interpreted as little-endian 32-bit floats.
/// * `BF16` data is widened by placing the 16 stored bits in the upper half of an `f32`.
///
/// # Errors
///
/// Returns an error when the tensor is missing, when its dtype is neither `F32` nor
/// `BF16`, or when its byte length is not a whole number of elements. The length
/// problems used to panic (F32) or silently drop the trailing byte (BF16).
fn load_f32(st: &SafeTensors<'_>, name: &str) -> Result<Vec<f32>> {
    use safetensors::tensor::Dtype;

    let view = st
        .tensor(name)
        .with_context(|| format!("Missing weight: {name}"))?;
    let raw = view.data();

    match view.dtype() {
        Dtype::F32 => {
            if raw.len() % 4 != 0 {
                bail!(
                    "Tensor {name}: F32 byte length {} is not divisible by 4",
                    raw.len()
                );
            }
            let n = raw.len() / 4;
            let mut out: Vec<f32> = Vec::with_capacity(n);
            #[cfg(target_endian = "little")]
            {
                // SAFETY: `out` owns an allocation with room for `n` f32 values, i.e.
                // exactly `raw.len()` bytes, and cannot overlap the borrowed `raw`.
                // Copying as bytes places no alignment requirement on `raw`. Every bit
                // pattern is a valid `f32`, so after the copy the first `n` elements
                // are initialised and `set_len(n)` is sound. On little-endian targets
                // the stored byte order already matches the in-memory representation.
                unsafe {
                    std::ptr::copy_nonoverlapping(
                        raw.as_ptr(),
                        out.as_mut_ptr().cast::<u8>(),
                        raw.len(),
                    );
                    out.set_len(n);
                }
            }
            #[cfg(not(target_endian = "little"))]
            {
                // Big-endian hosts must byte-swap each element.
                out.extend(
                    raw.chunks_exact(4)
                        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
                );
            }
            Ok(out)
        }
        Dtype::BF16 => {
            if raw.len() % 2 != 0 {
                bail!(
                    "Tensor {name}: BF16 byte length {} is not divisible by 2",
                    raw.len()
                );
            }
            Ok(raw
                .chunks_exact(2)
                .map(|b| {
                    // bfloat16 is the top 16 bits of an IEEE-754 binary32 value.
                    let bits = u16::from_le_bytes([b[0], b[1]]);
                    f32::from_bits(u32::from(bits) << 16)
                })
                .collect())
        }
        dt => bail!("Tensor {name}: unsupported dtype {dt:?} (expected F32 or BF16)"),
    }
}
/// Returns the dimensions of the tensor `name` as an owned vector.
///
/// # Errors
///
/// Fails with "Missing weight: …" when the tensor is not present in `st`.
fn shape_of(st: &SafeTensors<'_>, name: &str) -> Result<Vec<usize>> {
    let view = st
        .tensor(name)
        .with_context(|| format!("Missing weight: {name}"))?;
    let dims: Vec<usize> = view.shape().to_vec();
    Ok(dims)
}
/// Wraps `data` as a 1-D array of length `n`.
///
/// # Panics
///
/// Panics with "1-D shape mismatch" when `data.len() != n`.
fn as1d(data: Vec<f32>, n: usize) -> Array1<f32> {
    let built = Array1::from_shape_vec(n, data);
    built.expect("1-D shape mismatch")
}
/// Wraps `data` as a row-major `rows × cols` matrix.
///
/// # Panics
///
/// Panics with "2-D shape mismatch" when `data.len() != rows * cols`.
fn as2d(data: Vec<f32>, rows: usize, cols: usize) -> Array2<f32> {
    let dims = (rows, cols);
    let built = Array2::from_shape_vec(dims, data);
    built.expect("2-D shape mismatch")
}
/// Wraps `data` as a row-major `d0 × d1 × d2` array.
///
/// # Panics
///
/// Panics with "3-D shape mismatch" when `data.len() != d0 * d1 * d2`.
fn as3d(data: Vec<f32>, d0: usize, d1: usize, d2: usize) -> Array3<f32> {
    let dims = (d0, d1, d2);
    let built = Array3::from_shape_vec(dims, data);
    built.expect("3-D shape mismatch")
}
/// Affine map `x · wᵀ (+ b)`.
///
/// `x` is `(rows, in)`, `w` is `(out, in)`, and the result is `(rows, out)`. When a
/// bias is supplied it is added to every output row.
fn linear(x: ArrayView2<f32>, w: ArrayView2<f32>, b: Option<ArrayView1<f32>>) -> Array2<f32> {
    let mut projected = x.dot(&w.t());
    match b {
        Some(bias) => projected += &bias,
        None => {}
    }
    projected
}
/// 1-D convolution (stride 1) implemented as im2col followed by a matrix product.
///
/// * `x` — input of shape `(c_in, t)`.
/// * `w` — kernel of shape `(c_out, c_in, k)`.
/// * `b` — optional per-output-channel bias of length `c_out`.
/// * `pad` — number of implicit zeros before the first input sample.
///
/// The output always has `t` columns, so it is a "same" convolution only when
/// `k == 2 * pad + 1`. Taps that fall outside the input contribute zero.
///
/// # Panics
///
/// Panics with "conv1d reshape" if `w` cannot be viewed as `(c_out, c_in * k)`.
fn conv1d(
    x: ArrayView2<f32>,
    w: ArrayView3<f32>,
    b: Option<ArrayView1<f32>>,
    pad: usize,
) -> Array2<f32> {
    use ndarray::Axis;

    let (c_in, t) = x.dim();
    let (c_out, _, k) = w.dim();

    // im2col: row `ti` holds every (channel, tap) input value feeding output column `ti`.
    let mut col = Array2::<f32>::zeros((t, c_in * k));
    for ti in 0..t {
        for ki in 0..k {
            let shifted = ti + ki;
            if shifted < pad || shifted >= t + pad {
                // Tap lands in the zero padding; leave the pre-zeroed entries alone.
                continue;
            }
            let src = shifted - pad;
            for ci in 0..c_in {
                col[[ti, ci * k + ki]] = x[[ci, src]];
            }
        }
    }

    let kernel = w
        .into_shape_with_order((c_out, c_in * k))
        .expect("conv1d reshape");
    let time_major = col.dot(&kernel.t());
    let mut out = time_major.t().to_owned();
    if let Some(bias) = b {
        // Broadcast the bias down the time axis.
        out += &bias.insert_axis(Axis(1));
    }
    out
}
/// Group normalisation over a `(channels, time)` input.
///
/// Channels are split into `n_groups` consecutive groups of `c / n_groups` channels.
/// Each group is normalised with its own mean and (biased) variance taken over all of
/// its channels and time steps, then scaled by `w` and shifted by `b` per channel.
///
/// If `c` is not a multiple of `n_groups`, the trailing channels are left at zero.
fn group_norm(
    x: ArrayView2<f32>,
    n_groups: usize,
    w: ArrayView1<f32>,
    b: ArrayView1<f32>,
    eps: f32,
) -> Array2<f32> {
    let (c, t) = x.dim();
    let group_size = c / n_groups;
    let n = (group_size * t) as f32;
    let mut out = Array2::<f32>::zeros((c, t));

    for g in 0..n_groups {
        let channels = g * group_size..g * group_size + group_size;
        let block = x.slice(s![channels.clone(), ..]);

        let mean = block.iter().sum::<f32>() / n;
        let var = block
            .iter()
            .map(|&v| {
                let d = v - mean;
                d * d
            })
            .sum::<f32>()
            / n;
        let inv_std = 1.0 / (var + eps).sqrt();

        for ci in channels {
            let scale = inv_std * w[ci];
            let shift = b[ci];
            let src = x.row(ci);
            let mut dst = out.row_mut(ci);
            for (o, &v) in dst.iter_mut().zip(src.iter()) {
                *o = (v - mean) * scale + shift;
            }
        }
    }
    out
}
/// Layer normalisation over the last axis of a `(time, channels)` input.
///
/// Each row is normalised with its own mean and (biased) variance, then scaled by `w`
/// and shifted by `b` element-wise.
fn layer_norm(
    x: ArrayView2<f32>,
    w: ArrayView1<f32>,
    b: ArrayView1<f32>,
    eps: f32,
) -> Array2<f32> {
    let (t, c) = x.dim();
    let width = c as f32;
    let mut out = Array2::<f32>::zeros((t, c));

    for (row, mut out_row) in x.outer_iter().zip(out.outer_iter_mut()) {
        let mean = row.iter().sum::<f32>() / width;
        let var = row
            .iter()
            .map(|&v| {
                let d = v - mean;
                d * d
            })
            .sum::<f32>()
            / width;
        let inv_std = 1.0 / (var + eps).sqrt();

        for (ci, (o, &v)) in out_row.iter_mut().zip(row.iter()).enumerate() {
            *o = (v - mean) * inv_std * w[ci] + b[ci];
        }
    }
    out
}
/// RMS normalisation over the last axis of a `(time, channels)` input.
///
/// Each row is divided by `sqrt(mean(row²) + eps)` and then scaled element-wise by `w`.
/// No mean is subtracted and there is no bias.
fn rms_norm(x: ArrayView2<f32>, w: ArrayView1<f32>, eps: f32) -> Array2<f32> {
    let (t, c) = x.dim();
    let width = c as f32;
    let mut out = Array2::<f32>::zeros((t, c));

    for (row, mut out_row) in x.outer_iter().zip(out.outer_iter_mut()) {
        let mean_square = row.iter().map(|&v| v * v).sum::<f32>() / width;
        let scale = 1.0 / (mean_square + eps).sqrt();

        for (ci, (o, &v)) in out_row.iter_mut().zip(row.iter()).enumerate() {
            *o = v * scale * w[ci];
        }
    }
    out
}
/// SiLU (swish) activation: `x · sigmoid(x)`, written as `x / (1 + e^(-x))`.
#[inline(always)]
fn silu(x: f32) -> f32 {
    let denom = 1.0 + f32::exp(-x);
    x / denom
}
/// Replaces `x` with its softmax, in place.
///
/// The maximum is subtracted before exponentiating so large inputs do not overflow.
/// An empty slice is left untouched.
fn softmax_inplace(x: &mut [f32]) {
    let mut peak = f32::NEG_INFINITY;
    for &v in x.iter() {
        peak = peak.max(v);
    }

    let mut total = 0.0f32;
    for v in x.iter_mut() {
        let e = (*v - peak).exp();
        *v = e;
        total += e;
    }

    for v in x.iter_mut() {
        *v /= total;
    }
}
/// Converts FSQ codes into projected embeddings.
///
/// Each code is split into `FSQ_BASIS.len()` digits via `(code / basis) % levels`.
/// A digit `d` in `0..levels` is mapped linearly onto `[-1, 1]`, and the resulting
/// `(codes.len(), FSQ_BASIS.len())` matrix is passed through `linear` with
/// `proj_w` / `proj_b`.
///
/// The scaling is derived from `FSQ_LEVELS` rather than a literal, so it stays correct
/// if the level table changes. For the current 4-level digits it equals the former
/// hard-coded `1.5`.
///
/// Codes are not range-checked here; negative codes produce out-of-range digits.
fn fsq_decode(
    codes: &[i32],
    proj_w: ArrayView2<f32>,
    proj_b: ArrayView1<f32>,
) -> Array2<f32> {
    let mut digits = Array2::<f32>::zeros((codes.len(), FSQ_BASIS.len()));
    for (i, &code) in codes.iter().enumerate() {
        for (j, (&basis, &levels)) in FSQ_BASIS.iter().zip(FSQ_LEVELS.iter()).enumerate() {
            // Half the distance between the lowest and highest digit: maps 0 → -1 and
            // `levels - 1` → +1.
            let half_width = (levels - 1) as f32 / 2.0;
            let d = (code / basis) % levels;
            digits[[i, j]] = d as f32 / half_width - 1.0;
        }
    }
    linear(digits.view(), proj_w, Some(proj_b))
}
/// Fast polynomial approximation of `(sin x, cos x)` used for rotary embeddings.
///
/// The argument is reduced to `r ∈ [-π/4, π/4]` by subtracting the nearest multiple of
/// π/2. Degree-7 (sine) and degree-6 (cosine) Taylor polynomials are evaluated on `r`,
/// and the quadrant selects which of `±sin r` / `±cos r` to return.
///
/// The earlier version reduced only to `[-π, π]`, where the same polynomials are far
/// off (cos(π) came out near -1.21 and sin(π) near -0.075). On the narrower interval
/// the truncation error is below about 4e-6.
///
/// Accuracy still degrades for very large `|x|` because the range reduction itself is
/// done in `f32`.
#[cfg(not(feature = "precise"))]
#[inline(always)]
pub(crate) fn rope_sin_cos(x: f32) -> (f32, f32) {
    use std::f32::consts::{FRAC_2_PI, FRAC_PI_2};

    // k = nearest integer to x / (π/2); r is the remainder in [-π/4, π/4].
    let k = (x * FRAC_2_PI).round();
    let r = x - k * FRAC_PI_2;
    let r2 = r * r;

    let s = r * (1.0 + r2 * (-1.0 / 6.0 + r2 * (1.0 / 120.0 - r2 * (1.0 / 5040.0))));
    let c = 1.0 + r2 * (-0.5 + r2 * (1.0 / 24.0 - r2 * (1.0 / 720.0)));

    // x = r + k·π/2, so rotate (sin r, cos r) by k quarter turns. `& 3` also yields the
    // right quadrant for negative k under two's-complement arithmetic.
    match (k as i32) & 3 {
        0 => (s, c),
        1 => (c, -s),
        2 => (-s, -c),
        _ => (-c, s),
    }
}
/// Exact variant of `rope_sin_cos`, compiled only when the `precise` feature is enabled.
///
/// Returns `(sin x, cos x)` from the standard library instead of a polynomial
/// approximation.
#[cfg(feature = "precise")]
#[inline(always)]
pub(crate) fn rope_sin_cos(x: f32) -> (f32, f32) {
    x.sin_cos()
}
/// Applies rotary position embeddings in place to a `(time, heads, head_dim)` tensor.
///
/// Element `i` of the first half of each head is paired with element `i + head_dim / 2`.
/// The pair is rotated by the angle `position · 10000^(-2i / head_dim)`. The same
/// rotation is used for every head. With an odd `head_dim` the last element is left
/// unchanged.
fn apply_rope(x: &mut Array3<f32>) {
    let (t, n_heads, head_dim) = x.dim();
    let half = head_dim / 2;

    let inv_freqs: Vec<f32> = (0..half)
        .map(|i| 1.0_f32 / 10_000_f32.powf(2.0 * i as f32 / head_dim as f32))
        .collect();

    for p in 0..t {
        let position = p as f32;
        for (i, &freq) in inv_freqs.iter().enumerate() {
            // The angle depends only on (position, i), so compute it once per pair index.
            let (sin, cos) = rope_sin_cos(position * freq);
            let j = i + half;
            for h in 0..n_heads {
                let lo = x[[p, h, i]];
                let hi = x[[p, h, j]];
                x[[p, h, i]] = lo * cos - hi * sin;
                x[[p, h, j]] = lo * sin + hi * cos;
            }
        }
    }
}
/// Weights of one transformer block. None of the projections has a bias term.
///
/// Shapes below are those produced by `load_transformer`, where `d` is the model width.
pub(crate) struct TransformerWeights {
    /// RMS-norm scale applied before attention; length `d`.
    pub(crate) att_norm_w: Array1<f32>,
    /// Fused query/key/value projection; `(3 * d, d)`.
    pub(crate) c_attn_w: Array2<f32>,
    /// Attention output projection; `(d, d)`.
    pub(crate) c_proj_w: Array2<f32>,
    /// RMS-norm scale applied before the feed-forward network; length `d`.
    pub(crate) ffn_norm_w: Array1<f32>,
    /// Feed-forward expansion; `(4 * d, d)`.
    pub(crate) fc1_w: Array2<f32>,
    /// Feed-forward contraction; `(d, 4 * d)`.
    pub(crate) fc2_w: Array2<f32>,
}
/// Runs one pre-norm transformer block on a `(time, d)` input and returns `(time, d)`.
///
/// 1. RMS-norm, then a fused QKV projection split into `n_heads` heads.
/// 2. Rotary embeddings on queries and keys.
/// 3. Scaled dot-product attention per head over all position pairs (no mask).
/// 4. Output projection added to the input (first residual).
/// 5. RMS-norm, `fc1`, SiLU, `fc2`, added to the result of step 4 (second residual).
///
/// # Panics
///
/// Panics if `n_heads` is zero or the Q/K/V slices cannot be reshaped to
/// `(time, n_heads, d / n_heads)`.
fn transformer_block(x: ArrayView2<f32>, w: &TransformerWeights, n_heads: usize) -> Array2<f32> {
    let (t, d) = x.dim();
    let head_dim = d / n_heads;
    let head_shape = (t, n_heads, head_dim);

    // --- self-attention -------------------------------------------------------------
    let attn_in = rms_norm(x, w.att_norm_w.view(), 1e-6);
    let qkv = linear(attn_in.view(), w.c_attn_w.view(), None);

    let q_cols = qkv.slice(s![.., ..d]).to_owned();
    let k_cols = qkv.slice(s![.., d..2 * d]).to_owned();
    let v_cols = qkv.slice(s![.., 2 * d..]).to_owned();

    let mut q = q_cols
        .into_shape_with_order(head_shape)
        .expect("q reshape");
    let mut k = k_cols
        .into_shape_with_order(head_shape)
        .expect("k reshape");
    let v = v_cols
        .into_shape_with_order(head_shape)
        .expect("v reshape");

    apply_rope(&mut q);
    apply_rope(&mut k);

    let scale = 1.0 / (head_dim as f32).sqrt();
    let mut attn_out = Array3::<f32>::zeros(head_shape);
    for h in 0..n_heads {
        // Owned copies give each head a contiguous matrix for the products below.
        let q_h = q.slice(s![.., h, ..]).to_owned();
        let k_h = k.slice(s![.., h, ..]).to_owned();
        let v_h = v.slice(s![.., h, ..]).to_owned();

        let mut scores = q_h.dot(&k_h.t());
        scores *= scale;
        for mut row in scores.outer_iter_mut() {
            softmax_inplace(row.as_slice_mut().unwrap());
        }

        let mixed = scores.dot(&v_h);
        attn_out.slice_mut(s![.., h, ..]).assign(&mixed);
    }

    let merged = attn_out
        .into_shape_with_order((t, d))
        .expect("attn out reshape");
    let attn_proj = linear(merged.view(), w.c_proj_w.view(), None);
    let after_attn = &x + &attn_proj;

    // --- feed-forward ---------------------------------------------------------------
    let ffn_in = rms_norm(after_attn.view(), w.ffn_norm_w.view(), 1e-6);
    let expanded = linear(ffn_in.view(), w.fc1_w.view(), None).mapv(silu);
    let contracted = linear(expanded.view(), w.fc2_w.view(), None);
    &after_attn + &contracted
}
/// Weights of one residual convolution block (norm → SiLU → conv, twice).
///
/// Shapes below are those produced by `load_resnet_block`, where `c` is the channel count.
pub(crate) struct ResnetBlockWeights {
    /// First group-norm scale and bias; length `c` each.
    pub(crate) norm1_w: Array1<f32>, pub(crate) norm1_b: Array1<f32>,
    /// First convolution kernel `(c, c, 3)` and bias of length `c`.
    pub(crate) conv1_w: Array3<f32>, pub(crate) conv1_b: Array1<f32>,
    /// Second group-norm scale; length `c`.
    pub(crate) norm2_w: Array1<f32>,
    /// Second group-norm bias; length `c`.
    pub(crate) norm2_b: Array1<f32>,
    /// Second convolution kernel `(c, c, 3)` and bias of length `c`.
    pub(crate) conv2_w: Array3<f32>, pub(crate) conv2_b: Array1<f32>,
}
/// Residual block on a `(channels, time)` input:
/// `x + conv2(silu(gn2(conv1(silu(gn1(x))))))`.
///
/// Both group norms use 32 groups with `eps = 1e-6`; both convolutions use padding 1.
fn resnet_block(x: ArrayView2<f32>, w: &ResnetBlockWeights) -> Array2<f32> {
    const GROUPS: usize = 32;
    const EPS: f32 = 1e-6;
    const PAD: usize = 1;

    let first = group_norm(x, GROUPS, w.norm1_w.view(), w.norm1_b.view(), EPS).mapv(silu);
    let first = conv1d(first.view(), w.conv1_w.view(), Some(w.conv1_b.view()), PAD);

    let second =
        group_norm(first.view(), GROUPS, w.norm2_w.view(), w.norm2_b.view(), EPS).mapv(silu);
    let second = conv1d(second.view(), w.conv2_w.view(), Some(w.conv2_b.view()), PAD);

    &x + &second
}
/// Inverse short-time Fourier transform with overlap-add and window-envelope
/// normalisation.
///
/// * `mag` — log-magnitudes, shape `(n_bins, n_frames)`; each value is exponentiated
///   and clamped to at most `1e2`.
/// * `phase` — phases in radians, same shape as `mag`.
/// * `hop` — samples between consecutive frames (expected to be `n_fft / 4`).
/// * `window` — synthesis window of length `n_fft = (n_bins - 1) * 2`.
/// * `ifft` — inverse FFT plan for length `n_fft`.
///
/// Returns `n_frames * hop` samples: the overlap-added signal with its first
/// `n_fft / 2` samples removed and the tail truncated.
///
/// An input with no frames, or with fewer than two frequency bins, yields an empty
/// vector. Previously the `n_frames - 1` / `n_bins - 1` computations underflowed on
/// such input.
pub(crate) fn istft_burn(
    mag: ArrayView2<f32>,
    phase: ArrayView2<f32>,
    hop: usize,
    window: &[f32],
    ifft: &dyn rustfft::Fft<f32>,
) -> Vec<f32> {
    let n_bins = mag.shape()[0];
    let n_frames = mag.shape()[1];
    if n_frames == 0 || n_bins < 2 {
        // Nothing to synthesise; also keeps the size arithmetic below from underflowing.
        return Vec::new();
    }

    let n_fft = (n_bins - 1) * 2;
    debug_assert_eq!(n_fft, window.len());
    debug_assert_eq!(hop, n_fft / 4);

    let out_size = (n_frames - 1) * hop + n_fft;
    let mut y = vec![0.0f32; out_size];
    let mut env = vec![0.0f32; out_size];
    let mut buf = vec![Complex::<f32>::default(); n_fft];
    // The inverse FFT result is divided by n_fft below to normalise it.
    let norm = n_fft as f32;

    for ti in 0..n_frames {
        // Non-negative frequencies come straight from the prediction.
        for fi in 0..n_bins {
            let m = mag[[fi, ti]].exp().min(1e2);
            let p = phase[[fi, ti]];
            buf[fi] = Complex::new(m * p.cos(), m * p.sin());
        }
        // Negative frequencies are the conjugate mirror, so the time signal is real.
        // Together with the loop above this overwrites every slot of `buf`, which
        // still holds the previous frame's time-domain output.
        for fi in 1..n_bins - 1 {
            buf[n_fft - fi] = buf[fi].conj();
        }
        ifft.process(&mut buf);

        let offset = ti * hop;
        for i in 0..n_fft {
            let sample = buf[i].re / norm * window[i];
            y[offset + i] += sample;
            env[offset + i] += window[i] * window[i];
        }
    }

    // Undo the summed squared-window envelope; skip near-zero values to avoid blow-up.
    for (sample, &e) in y.iter_mut().zip(env.iter()) {
        if e > 1e-11 {
            *sample /= e;
        }
    }

    let start = n_fft / 2;
    let length = n_frames * hop;
    y[start..start + length].to_vec()
}
/// Builds a periodic Hann window of length `n`: `w[i] = 0.5 · (1 − cos(2πi / n))`.
///
/// Returns an empty vector for `n == 0`.
fn hann_window(n: usize) -> Vec<f32> {
    let mut window = Vec::with_capacity(n);
    for i in 0..n {
        let angle = 2.0 * std::f32::consts::PI * i as f32 / n as f32;
        window.push(0.5 * (1.0 - angle.cos()));
    }
    window
}
/// All weights and derived settings needed to run the decoder.
///
/// Shapes below are those produced by `load_decoder_weights`.
pub(crate) struct DecoderWeights {
    /// FSQ output projection `(fsq_out_dim, fsq_in_dim)` and its bias.
    pub(crate) fsq_proj_w: Array2<f32>, pub(crate) fsq_proj_b: Array1<f32>,
    /// Projection from the FSQ embedding to the model width `(hidden_dim, fsq_out_dim)` and its bias.
    pub(crate) fc_post_a_w: Array2<f32>, pub(crate) fc_post_a_b: Array1<f32>,
    /// Backbone input convolution `(hidden_dim, hidden_dim, k)` and its bias.
    pub(crate) embed_w: Array3<f32>, pub(crate) embed_b: Array1<f32>,
    /// Residual blocks applied before the transformer stack (two are loaded).
    pub(crate) prior_net: Vec<ResnetBlockWeights>,
    /// Transformer blocks, in application order.
    pub(crate) transformers: Vec<TransformerWeights>,
    /// Final layer-norm scale; length `hidden_dim`.
    pub(crate) final_norm_w: Array1<f32>,
    /// Final layer-norm bias; length `hidden_dim`.
    pub(crate) final_norm_b: Array1<f32>,
    /// Residual blocks applied after the transformer stack (two are loaded).
    pub(crate) post_net: Vec<ResnetBlockWeights>,
    /// Output head `(out_dim, hidden_dim)` and its bias; its output is split into
    /// magnitude and phase rows by `decode_forward`.
    pub(crate) head_w: Array2<f32>, pub(crate) head_b: Array1<f32>,
    /// ISTFT synthesis window: read from the file when present, otherwise a Hann window.
    pub(crate) window: Vec<f32>,
    /// Model width, taken from the first dimension of the embed convolution weight.
    pub(crate) hidden_dim: usize,
    /// ISTFT hop in samples, derived as `(out_dim - 2) / 4`.
    pub(crate) hop_length: usize,
    /// Number of transformer blocks found in the file.
    pub(crate) depth: usize,
    /// Attention head count, from the `n_heads` metadata entry (default 16).
    pub(crate) n_heads: usize,
    /// Inverse FFT plan for length `4 * hop_length`.
    pub(crate) ifft_plan: std::sync::Arc<dyn rustfft::Fft<f32>>,
}
/// Loads one residual block whose tensors live under `prefix`, with `c` channels.
///
/// Norm parameters and conv biases have length `c`; conv kernels are `(c, c, 3)`.
///
/// # Errors
///
/// Fails on the first tensor that is missing or has an unsupported dtype.
///
/// # Panics
///
/// Panics if a tensor's element count does not match the expected shape.
fn load_resnet_block(st: &SafeTensors<'_>, prefix: &str, c: usize) -> Result<ResnetBlockWeights> {
    let fetch = |key: String| load_f32(st, &key);

    let norm1_w = as1d(fetch(format!("{prefix}.norm1.weight"))?, c);
    let norm1_b = as1d(fetch(format!("{prefix}.norm1.bias"))?, c);
    let conv1_w = as3d(fetch(format!("{prefix}.conv1.weight"))?, c, c, 3);
    let conv1_b = as1d(fetch(format!("{prefix}.conv1.bias"))?, c);
    let norm2_w = as1d(fetch(format!("{prefix}.norm2.weight"))?, c);
    let norm2_b = as1d(fetch(format!("{prefix}.norm2.bias"))?, c);
    let conv2_w = as3d(fetch(format!("{prefix}.conv2.weight"))?, c, c, 3);
    let conv2_b = as1d(fetch(format!("{prefix}.conv2.bias"))?, c);

    Ok(ResnetBlockWeights {
        norm1_w,
        norm1_b,
        conv1_w,
        conv1_b,
        norm2_w,
        norm2_b,
        conv2_w,
        conv2_b,
    })
}
/// Loads one transformer block whose tensors live under `prefix`, with model width `d`.
///
/// The feed-forward hidden width is fixed at `4 * d`, and the attention input
/// projection is the fused `(3 * d, d)` QKV matrix.
///
/// # Errors
///
/// Fails on the first tensor that is missing or has an unsupported dtype.
///
/// # Panics
///
/// Panics if a tensor's element count does not match the expected shape.
fn load_transformer(st: &SafeTensors<'_>, prefix: &str, d: usize) -> Result<TransformerWeights> {
    let fetch = |key: String| load_f32(st, &key);

    let att_norm_w = as1d(fetch(format!("{prefix}.att_norm.weight"))?, d);
    let c_attn_w = as2d(fetch(format!("{prefix}.att.c_attn.weight"))?, 3 * d, d);
    let c_proj_w = as2d(fetch(format!("{prefix}.att.c_proj.weight"))?, d, d);
    let ffn_norm_w = as1d(fetch(format!("{prefix}.ffn_norm.weight"))?, d);
    let fc1_w = as2d(fetch(format!("{prefix}.mlp.fc1.weight"))?, 4 * d, d);
    let fc2_w = as2d(fetch(format!("{prefix}.mlp.fc2.weight"))?, d, 4 * d);

    Ok(TransformerWeights {
        att_norm_w,
        c_attn_w,
        c_proj_w,
        ffn_norm_w,
        fc1_w,
        fc2_w,
    })
}
fn load_decoder_weights(
st: &SafeTensors<'_>,
user_meta: &Option<std::collections::HashMap<String, String>>,
) -> Result<DecoderWeights> {
let embed_shape = shape_of(st, "generator.backbone.embed.weight")?;
let hidden_dim = embed_shape[0];
let head_shape = shape_of(st, "generator.head.out.weight")?;
let out_dim = head_shape[0]; let hop_length = (out_dim - 2) / 4;
let depth = (0..64)
.take_while(|&i| {
st.tensor(&format!(
"generator.backbone.transformers.{i}.att_norm.weight"
))
.is_ok()
})
.count();
if depth == 0 {
bail!("No transformer blocks found — is the safetensors file correct?");
}
let n_heads: usize = user_meta
.as_ref()
.and_then(|m| m.get("n_heads"))
.and_then(|s| s.parse().ok())
.unwrap_or(16);
let fsq_proj_key = if st.tensor("generator.quantizer.fsqs.0.project_out.weight").is_ok() {
"generator.quantizer.fsqs.0.project_out.weight"
} else {
"generator.quantizer.project_out.weight"
};
let fsq_bias_key = if st.tensor("generator.quantizer.fsqs.0.project_out.bias").is_ok() {
"generator.quantizer.fsqs.0.project_out.bias"
} else {
"generator.quantizer.project_out.bias"
};
let fsq_shape = shape_of(st, fsq_proj_key)?;
let fsq_out_dim = fsq_shape[0]; let fsq_in_dim = fsq_shape[1];
let fsq_proj_w = as2d(
load_f32(st, fsq_proj_key)?,
fsq_out_dim,
fsq_in_dim,
);
let fsq_proj_b = as1d(
load_f32(st, fsq_bias_key)?,
fsq_out_dim,
);
let fc_post_a_w = as2d(
load_f32(st, "fc_post_a.weight")?,
hidden_dim,
fsq_out_dim,
);
let fc_post_a_b = as1d(load_f32(st, "fc_post_a.bias")?, hidden_dim);
let embed_k = embed_shape[2];
let embed_w = as3d(
load_f32(st, "generator.backbone.embed.weight")?,
hidden_dim,
hidden_dim,
embed_k,
);
let embed_b = as1d(
load_f32(st, "generator.backbone.embed.bias")?,
hidden_dim,
);
let prior_net = (0..2)
.map(|i| {
load_resnet_block(
st,
&format!("generator.backbone.prior_net.{i}"),
hidden_dim,
)
})
.collect::<Result<Vec<_>>>()?;
let transformers = (0..depth)
.map(|i| {
load_transformer(
st,
&format!("generator.backbone.transformers.{i}"),
hidden_dim,
)
})
.collect::<Result<Vec<_>>>()?;
let final_norm_w = as1d(
load_f32(st, "generator.backbone.final_layer_norm.weight")?,
hidden_dim,
);
let final_norm_b = as1d(
load_f32(st, "generator.backbone.final_layer_norm.bias")?,
hidden_dim,
);
let post_net = (0..2)
.map(|i| {
load_resnet_block(
st,
&format!("generator.backbone.post_net.{i}"),
hidden_dim,
)
})
.collect::<Result<Vec<_>>>()?;
let n_fft = hop_length * 4;
let head_w = as2d(
load_f32(st, "generator.head.out.weight")?,
out_dim,
hidden_dim,
);
let head_b = as1d(load_f32(st, "generator.head.out.bias")?, out_dim);
let window = if st.tensor("generator.head.istft.window").is_ok() {
load_f32(st, "generator.head.istft.window")?
} else {
hann_window(n_fft)
};
let ifft_plan = {
let mut planner = FftPlanner::<f32>::new();
planner.plan_fft_inverse(n_fft)
};
Ok(DecoderWeights {
fsq_proj_w,
fsq_proj_b,
fc_post_a_w,
fc_post_a_b,
embed_w,
embed_b,
prior_net,
transformers,
final_norm_w,
final_norm_b,
post_net,
head_w,
head_b,
window,
hidden_dim,
hop_length,
depth,
n_heads,
ifft_plan,
})
}
/// Runs the full CPU decoder: FSQ codes in, audio samples out.
///
/// Pipeline: FSQ decode → `fc_post_a` → embed convolution → prior residual blocks →
/// transformer stack → post residual blocks → layer norm → output head → split into
/// log-magnitude and phase → ISTFT.
///
/// Activations alternate between `(channels, time)` for the convolutional stages and
/// `(time, channels)` for the linear and transformer stages. Each switch is an explicit
/// transpose-and-copy.
fn decode_forward(codes: &[i32], w: &DecoderWeights) -> Vec<f32> {
    let hop = w.hop_length;
    let n_fft = hop * 4;
    let embed_pad = w.embed_w.shape()[2] / 2;

    // Codes → (time, hidden).
    let emb = fsq_decode(codes, w.fsq_proj_w.view(), w.fsq_proj_b.view());
    let x = linear(emb.view(), w.fc_post_a_w.view(), Some(w.fc_post_a_b.view()));

    // Convolutional front end in (channels, time).
    let embed_in = x.t().to_owned();
    let mut x_ct = conv1d(
        embed_in.view(),
        w.embed_w.view(),
        Some(w.embed_b.view()),
        embed_pad,
    );
    for block in &w.prior_net {
        x_ct = resnet_block(x_ct.view(), block);
    }

    // Transformer stack in (time, channels).
    let mut x_tc = x_ct.t().to_owned();
    for layer in &w.transformers {
        x_tc = transformer_block(x_tc.view(), layer, w.n_heads);
    }

    // Convolutional back end in (channels, time).
    let mut x_ct = x_tc.t().to_owned();
    for block in &w.post_net {
        x_ct = resnet_block(x_ct.view(), block);
    }

    // Final norm and head in (time, channels).
    let norm_in = x_ct.t().to_owned();
    let normed = layer_norm(
        norm_in.view(),
        w.final_norm_w.view(),
        w.final_norm_b.view(),
        1e-6,
    );
    let x_pred = linear(normed.view(), w.head_w.view(), Some(w.head_b.view()));

    // The head emits n_fft/2 + 1 magnitude rows followed by the phase rows.
    let spec = x_pred.t().to_owned();
    let n_bins = (n_fft / 2) + 1;
    let mag = spec.slice(s![..n_bins, ..]).to_owned();
    let phase = spec.slice(s![n_bins.., ..]).to_owned();

    istft_burn(mag.view(), phase.view(), hop, &w.window, w.ifft_plan.as_ref())
}
/// Location of the decoder weights used by `NeuCodecDecoder::new`, relative to the
/// current working directory.
fn default_decoder_path() -> PathBuf {
    Path::new("models/neucodec_decoder.safetensors").to_path_buf()
}
/// Decoder that turns NeuCodec FSQ codes into audio samples.
///
/// Weights are held in memory for the CPU path. When the `wgpu` feature is enabled an
/// optional accelerated decoder is also stored and is preferred by `decode` when present.
pub struct NeuCodecDecoder {
    /// Parsed weights and derived hyper-parameters for the CPU path.
    weights: DecoderWeights,
    /// Path the weights were loaded from.
    path: PathBuf,
    /// Accelerated decoder state, guarded by a mutex; only present with the `wgpu` feature.
    #[cfg(feature = "wgpu")]
    burn_decoder: std::sync::Mutex<LazyBurnDecoder>,
}
/// State of the optional accelerated decoder (only compiled with the `wgpu` feature).
#[cfg(feature = "wgpu")]
enum LazyBurnDecoder {
    /// Initialisation has run. `Some` holds the accelerated decoder; `None` means none
    /// is available, in which case `NeuCodecDecoder::decode` uses the CPU path.
    Ready(Option<Box<dyn crate::codec_burn::BurnDecoder + Send>>),
}
impl NeuCodecDecoder {
    /// Loads the decoder from the default weights path
    /// (`models/neucodec_decoder.safetensors`).
    ///
    /// # Errors
    ///
    /// See [`NeuCodecDecoder::from_file`].
    pub fn new() -> Result<Self> {
        let default_path = default_decoder_path();
        Self::from_file(&default_path)
    }

    /// Loads the decoder from the safetensors file at `path`.
    ///
    /// The file is memory-mapped, its tensors are copied into owned arrays, and the
    /// mapping is released before returning. A one-line model summary is printed to
    /// stdout. With the `wgpu` feature the accelerated backend is also initialised and
    /// its start-up time is printed.
    ///
    /// # Errors
    ///
    /// Fails when the file does not exist, cannot be opened or mapped, is not valid
    /// safetensors, or lacks required weights.
    pub fn from_file(path: &Path) -> Result<Self> {
        if !path.exists() {
            bail!(
                "NeuCodec decoder weights not found: {}\n\
                 \n\
                 Run the one-time conversion to generate them:\n\
                 \n\
                 \tpython scripts/convert_weights.py\n\
                 \n\
                 Or set a custom path with NeuCodecDecoder::from_file().",
                path.display()
            );
        }

        let file = std::fs::File::open(path)
            .with_context(|| format!("Failed to open {}", path.display()))?;
        // The mapping is only read inside this function and is dropped before returning.
        // NOTE(review): soundness relies on the file not being modified while mapped.
        let mapped = unsafe { memmap2::Mmap::map(&file) };
        let mmap = mapped.with_context(|| format!("Failed to mmap {}", path.display()))?;
        let bytes: &[u8] = &mmap;

        // The header is parsed separately first to obtain the user metadata map.
        let (_, header) = SafeTensors::read_metadata(bytes)
            .with_context(|| format!("Failed to parse safetensors header: {}", path.display()))?;
        let user_meta = header.metadata().clone();

        let tensors = SafeTensors::deserialize(bytes)
            .with_context(|| format!("Failed to parse safetensors: {}", path.display()))?;
        let weights = load_decoder_weights(&tensors, &user_meta)
            .with_context(|| format!("Failed to load decoder weights from {}", path.display()))?;
        // All tensors are now owned copies, so release the view and the mapping.
        drop(tensors);
        drop(mmap);

        let tokens_per_second = SAMPLE_RATE as usize / weights.hop_length;
        println!(
            "NeuCodec decoder: hidden={}, depth={}, heads={}, hop={} ({} samples/token = {} tokens/s)",
            weights.hidden_dim,
            weights.depth,
            weights.n_heads,
            weights.hop_length,
            weights.hop_length,
            tokens_per_second,
        );

        #[cfg(feature = "wgpu")]
        let burn_decoder = {
            let started = std::time::Instant::now();
            let dec = crate::codec_burn::make_burn_decoder(&weights);
            println!(
                "NeuCodec: {} backend ready in {:.2} s",
                dec.as_ref().map_or("cpu (ndarray)", |b| b.backend_name()),
                started.elapsed().as_secs_f32(),
            );
            std::sync::Mutex::new(LazyBurnDecoder::Ready(dec))
        };

        Ok(Self {
            weights,
            path: path.to_path_buf(),
            #[cfg(feature = "wgpu")]
            burn_decoder,
        })
    }

    /// Decodes FSQ `codes` into audio samples.
    ///
    /// An empty input yields an empty output. With the `wgpu` feature and an available
    /// accelerated decoder, decoding is delegated to it while the mutex is held.
    /// Otherwise the CPU implementation runs.
    ///
    /// # Errors
    ///
    /// Fails when any code lies outside `0..=65535`, reporting the first offender, or
    /// when the accelerated decoder reports an error.
    pub fn decode(&self, codes: &[i32]) -> Result<Vec<f32>> {
        if codes.is_empty() {
            return Ok(Vec::new());
        }

        let first_invalid = codes
            .iter()
            .copied()
            .enumerate()
            .find(|&(_, c)| !(0..=65535).contains(&c));
        if let Some((i, code)) = first_invalid {
            bail!(
                "Speech token at index {i} is out of range: {code} \
                 (NeuCodec FSQ codes must be in 0..=65535)"
            );
        }

        #[cfg(feature = "wgpu")]
        {
            let guard = self.burn_decoder.lock().unwrap();
            if let LazyBurnDecoder::Ready(Some(bd)) = &*guard {
                return bd.decode(codes);
            }
        }

        Ok(decode_forward(codes, &self.weights))
    }

    /// Name of the backend that `decode` will use: the accelerated backend's own name
    /// when one is active, otherwise `"cpu (ndarray)"`.
    pub fn backend_name(&self) -> &str {
        #[cfg(feature = "wgpu")]
        {
            let guard = self.burn_decoder.lock().unwrap();
            return match &*guard {
                LazyBurnDecoder::Ready(Some(bd)) => bd.backend_name(),
                LazyBurnDecoder::Ready(None) => "cpu (ndarray)",
            };
        }
        #[cfg(not(feature = "wgpu"))]
        "cpu (ndarray)"
    }

    /// Alias for [`NeuCodecDecoder::from_file`].
    pub fn load(path: &Path) -> Result<Self> {
        Self::from_file(path)
    }

    /// Path the weights were loaded from.
    pub fn weights_path(&self) -> &Path {
        self.path.as_path()
    }

    /// ISTFT hop length in samples, as derived from the loaded weights.
    pub fn hop_length(&self) -> usize {
        self.weights.hop_length
    }
}
pub struct NeuCodecEncoder;
impl NeuCodecEncoder {
pub fn new() -> Result<Self> {
bail!(
"The NeuCodec encoder is not yet implemented in the pure-Rust build.\n\
\n\
To encode reference audio, use the Python neucodec package:\n\
\n\
\tpip install neucodec huggingface_hub\n\
\tpython scripts/encode_reference.py --audio reference.wav --out ref.npy\n\
\n\
Then pass the .npy file via --ref-codes to the synthesis examples."
)
}
pub fn load(_path: &Path) -> Result<Self> {
Self::new()
}
pub fn encode_wav(&self, _path: &Path) -> Result<Vec<i32>> {
bail!("Encoder not implemented — see NeuCodecEncoder docs")
}
pub fn backend_name(&self) -> &str {
"not available"
}
}
/// Resamples `samples` from `from_hz` to `to_hz` by linear interpolation.
///
/// * Equal rates return a copy of the input.
/// * Empty input, or a zero rate on either side, returns an empty vector.
/// * Otherwise the output has exactly `ceil(len · to_hz / from_hz)` samples. Positions
///   past the final input sample repeat that sample.
///
/// The output length is computed in integer arithmetic and the source index is clamped.
/// Previously a floating-point rounding error in the length could make the last output
/// index read one element past the end of `samples`, and `from_hz == 0` produced an
/// unbounded length.
///
/// No anti-aliasing filter is applied when downsampling.
pub fn resample(samples: &[f32], from_hz: u32, to_hz: u32) -> Vec<f32> {
    if from_hz == to_hz {
        return samples.to_vec();
    }
    if samples.is_empty() || from_hz == 0 || to_hz == 0 {
        return Vec::new();
    }

    let last = samples.len() - 1;
    let ratio = f64::from(from_hz) / f64::from(to_hz);

    // Exact ceil(len * to / from); u128 cannot overflow for any usize length and u32 rate.
    let exact_len = (samples.len() as u128 * u128::from(to_hz) + u128::from(from_hz) - 1)
        / u128::from(from_hz);
    let out_len = usize::try_from(exact_len).unwrap_or(usize::MAX);

    (0..out_len)
        .map(|i| {
            let src = i as f64 * ratio;
            // Clamp so a rounding error in `src` can never index past the final sample.
            let lo = (src.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = (src - lo as f64) as f32;
            samples[lo] * (1.0 - frac) + samples[hi] * frac
        })
        .collect()
}
/// Unit tests for the numeric building blocks. Most check output shapes plus one
/// analytically known value; none of them load real weights.
#[cfg(test)]
mod tests {
    use super::*;

    /// Four codes through a (4, 8) projection yield a (4, 4) output.
    #[test]
    fn test_fsq_decode_shape() {
        let w = Array2::ones((4, 8));
        let b = Array1::zeros(4);
        let codes = vec![0i32, 1, 2, 65535];
        let out = fsq_decode(&codes, w.view(), b.view());
        assert_eq!(out.shape(), &[4, 4]);
    }

    /// Code 0 has every digit at 0, which maps to -1.0 (identity projection exposes the digits).
    #[test]
    fn test_fsq_code_0() {
        let w = Array2::eye(8);
        let b = Array1::zeros(8);
        let out = fsq_decode(&[0], w.view(), b.view());
        for v in out.iter() {
            assert!((*v + 1.0).abs() < 1e-5, "expected -1.0, got {v}");
        }
    }

    /// Code 65535 has every digit at its maximum, which maps to +1.0.
    #[test]
    fn test_fsq_code_max() {
        let w = Array2::eye(8);
        let b = Array1::zeros(8);
        let out = fsq_decode(&[65535], w.view(), b.view());
        for v in out.iter() {
            assert!((*v - 1.0).abs() < 1e-5, "expected 1.0, got {v}");
        }
    }

    /// (5, 3) input with a (7, 3) weight gives a (5, 7) output.
    #[test]
    fn test_linear_shape() {
        let x = Array2::ones((5, 3));
        let w = Array2::ones((7, 3));
        let b = Array1::zeros(7);
        let out = linear(x.view(), w.view(), Some(b.view()));
        assert_eq!(out.shape(), &[5, 7]);
    }

    /// Kernel 3 with padding 1 preserves the time length.
    #[test]
    fn test_conv1d_same_length() {
        let c_in = 4;
        let c_out = 8;
        let t = 16;
        let k = 3;
        let x = Array2::ones((c_in, t));
        let w = Array3::ones((c_out, c_in, k));
        let b = Array1::zeros(c_out);
        let out = conv1d(x.view(), w.view(), Some(b.view()), 1);
        assert_eq!(out.shape(), &[c_out, t]);
    }

    /// A constant input has zero variance, so group norm outputs ~0 everywhere.
    #[test]
    fn test_group_norm_shape() {
        let c = 64;
        let t = 10;
        let x = Array2::ones((c, t));
        let w = Array1::ones(c);
        let b = Array1::zeros(c);
        let out = group_norm(x.view(), 4, w.view(), b.view(), 1e-6);
        assert_eq!(out.shape(), &[c, t]);
        for &v in out.iter() {
            assert!(v.abs() < 1e-4, "expected ~0 after group_norm of all-ones, got {v}");
        }
    }

    /// A constant row normalises to ~0 under layer norm.
    #[test]
    fn test_layer_norm_shape() {
        let t = 5;
        let c = 32;
        let x = Array2::from_elem((t, c), 2.0f32);
        let w = Array1::ones(c);
        let b = Array1::zeros(c);
        let out = layer_norm(x.view(), w.view(), b.view(), 1e-6);
        assert_eq!(out.shape(), &[t, c]);
        for &v in out.iter() {
            assert!(v.abs() < 1e-4, "expected ~0, got {v}");
        }
    }

    /// An all-ones row has RMS 1, so RMS norm with unit scale returns ~1.
    #[test]
    fn test_rms_norm_shape() {
        let t = 3;
        let c = 8;
        let x = Array2::ones((t, c));
        let w = Array1::ones(c);
        let out = rms_norm(x.view(), w.view(), 1e-6);
        assert_eq!(out.shape(), &[t, c]);
        for &v in out.iter() {
            assert!((v - 1.0).abs() < 1e-4, "expected 1.0, got {v}");
        }
    }

    /// Rotary embedding works in place and must not change the tensor's shape.
    #[test]
    fn test_rope_shape_preserved() {
        let t = 4;
        let n_heads = 2;
        let head_dim = 8;
        let mut x = Array3::ones((t, n_heads, head_dim));
        apply_rope(&mut x);
        assert_eq!(x.shape(), &[t, n_heads, head_dim]);
    }

    /// A length-4 periodic Hann window starts at 0 and peaks at 1 in the middle.
    #[test]
    fn test_hann_window() {
        let w = hann_window(4);
        assert_eq!(w.len(), 4);
        assert!(w[0].abs() < 1e-6);
        assert!((w[2] - 1.0).abs() < 1e-6);
    }

    /// Helper: builds an inverse FFT plan of length `n_fft` for the ISTFT tests.
    fn make_ifft(n_fft: usize) -> std::sync::Arc<dyn rustfft::Fft<f32>> {
        FftPlanner::<f32>::new().plan_fft_inverse(n_fft)
    }

    /// The ISTFT returns exactly `frames * hop` samples.
    #[test]
    fn test_istft_length() {
        let hop = 4;
        let n_fft = 16;
        let t = 10;
        let n_bins = n_fft / 2 + 1;
        let mag = Array2::zeros((n_bins, t));
        let phase = Array2::zeros((n_bins, t));
        let win = hann_window(n_fft);
        let ifft = make_ifft(n_fft);
        let audio = istft_burn(mag.view(), phase.view(), hop, &win, ifft.as_ref());
        assert_eq!(audio.len(), t * hop, "expected {} samples, got {}", t * hop, audio.len());
    }

    /// A log-magnitude of 50 would overflow without the clamp inside `istft_burn`;
    /// the output must stay finite and bounded.
    #[test]
    fn test_istft_clamp_does_not_blow_up() {
        let hop = 4;
        let n_fft = 16;
        let t = 4;
        let n_bins = n_fft / 2 + 1;
        let mag = Array2::from_elem((n_bins, t), 50.0f32);
        let phase = Array2::zeros((n_bins, t));
        let win = hann_window(n_fft);
        let ifft = make_ifft(n_fft);
        let audio = istft_burn(mag.view(), phase.view(), hop, &win, ifft.as_ref());
        for &s in &audio {
            assert!(s.is_finite(), "sample is not finite: {s}");
            assert!(s.abs() < 1e6, "sample magnitude suspiciously large: {s}");
        }
    }

    /// Smoke test: the feature probe is callable regardless of build features.
    #[test]
    fn test_wgpu_feature_fn() {
        let _ = wgpu_feature_enabled();
    }

    /// Resampling to the same rate returns the input unchanged.
    #[test]
    fn test_resample_identity() {
        let s: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let r = resample(&s, 16_000, 16_000);
        assert_eq!(r, s);
    }
}