Skip to main content

rlx_neutts/decoder/
eager.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! NeuCodec decoder — eager CPU inference from safetensors weights.
17//!
18//! ## Architecture (XCodec2-based)
19//!
20//! ```text
21//!  codes [T]  ──►  FSQ lookup  ──►  fc_post_a  ──►  VocosBackbone  ──►  ISTFTHead  ──►  audio
22//! (int, 0..65535)  [T, 2048]      [T, 1024]        [T, 1024]                          [T*hop]
23//! ```
24//!
25//! **VocosBackbone**: Conv1d(k=7) → 2×ResnetBlock → 12×TransformerBlock (RoPE) → 2×ResnetBlock → LayerNorm
26//!
27//! **ISTFTHead**: Linear(1024 → n_fft+2) → split mag/phase → ISTFT
28//!
29//! ## Setup (one-time)
30//!
31//! ```sh
32//! python scripts/convert_weights.py   # download + extract decoder weights to safetensors
33//! ```
34//!
35//! Weights are loaded at runtime from `NEUTTS_DECODER_PATH` (not bundled in this crate).
36
37use std::path::{Path, PathBuf};
38
39use anyhow::{Context, Result, bail};
40use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, s};
41use rustfft::{FftPlanner, num_complex::Complex};
42use safetensors::SafeTensors;
43
44// ─── Public constants ─────────────────────────────────────────────────────────
45
/// Output sample rate of the decoder, in Hz (24 kHz).
pub const SAMPLE_RATE: u32 = 24_000;

/// Input sample rate expected by the encoder, in Hz (16 kHz).
pub const ENCODER_SAMPLE_RATE: u32 = 16_000;

/// Nominal decoder output samples per speech token: 24 000 Hz / 50 tokens/s = 480.
/// The decoder itself derives its real hop length from the weight shapes at load time.
pub const SAMPLES_PER_TOKEN: usize = SAMPLE_RATE as usize / 50;

/// Encoder input samples consumed per speech token: 16 000 Hz / 50 tokens/s = 320.
pub const ENCODER_SAMPLES_PER_TOKEN: usize = ENCODER_SAMPLE_RATE as usize / 50;

/// Default encoder reference-audio length: 10 seconds at 16 kHz (160 000 samples).
pub const ENCODER_DEFAULT_INPUT_SAMPLES: usize = ENCODER_SAMPLE_RATE as usize * 10;
61
62// Feature probes live in `crate::features` (re-exported from `decoder`).
63
64// ─── FSQ constants ────────────────────────────────────────────────────────────
65
/// FSQ levels for NeuCodec: 8 dimensions × 4 levels → 4^8 = 65 536 codes.
pub(crate) const FSQ_LEVELS: [i32; 8] = [4, 4, 4, 4, 4, 4, 4, 4];

/// Mixed-radix place values used to split an integer code into per-dimension
/// digits: `FSQ_BASIS[j] = product(FSQ_LEVELS[0..j])`.
///
/// Computed from [`FSQ_LEVELS`] at compile time so the two tables cannot drift
/// apart. For the current levels it is `[1, 4, 16, 64, 256, 1_024, 4_096, 16_384]`.
pub(crate) const FSQ_BASIS: [i32; 8] = fsq_basis(&FSQ_LEVELS);

/// Cumulative-product basis for a set of FSQ levels. It is `const` so it can
/// initialise [`FSQ_BASIS`]; an `i32` overflow becomes a compile-time error.
const fn fsq_basis<const N: usize>(levels: &[i32; N]) -> [i32; N] {
    let mut basis = [1i32; N];
    let mut j = 1;
    while j < N {
        basis[j] = basis[j - 1] * levels[j - 1];
        j += 1;
    }
    basis
}
72
73// ─── Tensor helpers ───────────────────────────────────────────────────────────
74
75fn load_f32(st: &SafeTensors<'_>, name: &str) -> Result<Vec<f32>> {
76    let view = st
77        .tensor(name)
78        .with_context(|| format!("Missing weight: {name}"))?;
79    let raw = view.data();
80    use safetensors::tensor::Dtype;
81    Ok(match view.dtype() {
82        Dtype::F32 => {
83            // Fast path: the bytes are already little-endian f32.  On LE
84            // hosts (x86, ARM) we can reinterpret directly with no per-byte
85            // work — essentially a single memcpy via the Vec allocation.
86            assert!(
87                raw.len() % 4 == 0,
88                "F32 tensor byte length not divisible by 4"
89            );
90            let n = raw.len() / 4;
91            let mut out = Vec::with_capacity(n);
92            // SAFETY: raw is valid, aligned to u8 (no alignment requirement
93            // for the source), and we write exactly `n` f32 values.
94            #[cfg(target_endian = "little")]
95            {
96                // SAFETY: f32 and u8 have no padding/invalid-bit patterns for
97                // this cast; we own `out` and set its length immediately after.
98                unsafe {
99                    std::ptr::copy_nonoverlapping(
100                        raw.as_ptr(),
101                        out.as_mut_ptr() as *mut u8,
102                        raw.len(),
103                    );
104                    out.set_len(n);
105                }
106            }
107            #[cfg(not(target_endian = "little"))]
108            {
109                out.extend(
110                    raw.chunks_exact(4)
111                        .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
112                );
113            }
114            out
115        }
116        Dtype::BF16 => raw
117            .chunks_exact(2)
118            .map(|b| {
119                let bits = u16::from_le_bytes([b[0], b[1]]);
120                f32::from_bits((bits as u32) << 16)
121            })
122            .collect(),
123        dt => bail!("Tensor {name}: unsupported dtype {dt:?} (expected F32 or BF16)"),
124    })
125}
126
127fn shape_of(st: &SafeTensors<'_>, name: &str) -> Result<Vec<usize>> {
128    Ok(st
129        .tensor(name)
130        .with_context(|| format!("Missing weight: {name}"))?
131        .shape()
132        .to_vec())
133}
134
135fn as1d(data: Vec<f32>, n: usize) -> Array1<f32> {
136    Array1::from_shape_vec(n, data).expect("1-D shape mismatch")
137}
138
139fn as2d(data: Vec<f32>, rows: usize, cols: usize) -> Array2<f32> {
140    Array2::from_shape_vec((rows, cols), data).expect("2-D shape mismatch")
141}
142
143fn as3d(data: Vec<f32>, d0: usize, d1: usize, d2: usize) -> Array3<f32> {
144    Array3::from_shape_vec((d0, d1, d2), data).expect("3-D shape mismatch")
145}
146
147// ─── Math primitives ──────────────────────────────────────────────────────────
148
149/// Linear layer: `out = x @ w.T + b`
150///
151/// * `x`: \[T, in_dim\]
152/// * `w`: \[out_dim, in_dim\]  (PyTorch row-major convention)
153/// * `b`: \[out_dim\]  (optional)
154/// * returns: \[T, out_dim\]
155fn linear(x: ArrayView2<f32>, w: ArrayView2<f32>, b: Option<ArrayView1<f32>>) -> Array2<f32> {
156    let mut out = x.dot(&w.t()); // [T, out_dim]
157    if let Some(b) = b {
158        out += &b;
159    }
160    out
161}
162
163/// Conv1d with same-length output (zero-padded).
164///
165/// * `x`: \[c_in, T\]
166/// * `w`: \[c_out, c_in, k\]
167/// * `b`: \[c_out\]  (optional)
168/// * returns: \[c_out, T\]
169fn conv1d(
170    x: ArrayView2<f32>,
171    w: ArrayView3<f32>,
172    b: Option<ArrayView1<f32>>,
173    pad: usize,
174) -> Array2<f32> {
175    let (c_in, t) = (x.shape()[0], x.shape()[1]);
176    let (c_out, _, k) = (w.shape()[0], w.shape()[1], w.shape()[2]);
177
178    // im2col: build [T, c_in × k] column matrix
179    let mut col = Array2::<f32>::zeros((t, c_in * k));
180    for ti in 0..t {
181        for ci in 0..c_in {
182            for ki in 0..k {
183                let src = ti + ki;
184                if src >= pad && src < t + pad {
185                    col[[ti, ci * k + ki]] = x[[ci, src - pad]];
186                }
187                // else zero-pad (already zeroed)
188            }
189        }
190    }
191
192    // weight: [c_out, c_in × k]
193    let w2 = w
194        .into_shape_with_order((c_out, c_in * k))
195        .expect("conv1d reshape");
196
197    // out_t = col @ w2.T  →  [T, c_out]  then transpose to [c_out, T]
198    let out_t = col.dot(&w2.t());
199    let mut out = out_t.t().to_owned(); // [c_out, T]
200
201    if let Some(b) = b {
202        // Broadcast b [c_out] over [c_out, T] — one ndarray op, no manual loop.
203        use ndarray::Axis;
204        out += &b.view().insert_axis(Axis(1));
205    }
206    out
207}
208
209/// GroupNorm: `affine=True`, over input \[C, T\].
210/// Normalises over (group_size × T) elements per group.
211///
212/// Uses an iterator-based variance computation to avoid the temporary
213/// array that `block.mapv(|v| (v - mean).powi(2))` would allocate.
214fn group_norm(
215    x: ArrayView2<f32>,
216    n_groups: usize,
217    w: ArrayView1<f32>,
218    b: ArrayView1<f32>,
219    eps: f32,
220) -> Array2<f32> {
221    let (c, t) = (x.shape()[0], x.shape()[1]);
222    let group_size = c / n_groups;
223    let n = (group_size * t) as f32;
224    let mut out = Array2::<f32>::zeros((c, t));
225
226    for g in 0..n_groups {
227        let c_start = g * group_size;
228        let c_end = c_start + group_size;
229        let block = x.slice(s![c_start..c_end, ..]);
230
231        // Mean — no temporary allocation
232        let mean = block.iter().sum::<f32>() / n;
233        // Variance — single pass, no temporary allocation
234        let var = block
235            .iter()
236            .map(|&v| {
237                let d = v - mean;
238                d * d
239            })
240            .sum::<f32>()
241            / n;
242        let inv_std = 1.0 / (var + eps).sqrt();
243
244        for ci in c_start..c_end {
245            let scale = inv_std * w[ci];
246            let shift = b[ci];
247            for ti in 0..t {
248                out[[ci, ti]] = (x[[ci, ti]] - mean) * scale + shift;
249            }
250        }
251    }
252    out
253}
254
255/// LayerNorm over the last axis of \[T, C\].
256///
257/// Uses iterator sums to avoid the temporary arrays that `row.mapv(…).sum()`
258/// would allocate for each of the T rows.
259fn layer_norm(x: ArrayView2<f32>, w: ArrayView1<f32>, b: ArrayView1<f32>, eps: f32) -> Array2<f32> {
260    let (t, c) = (x.shape()[0], x.shape()[1]);
261    let c_f = c as f32;
262    let mut out = Array2::<f32>::zeros((t, c));
263    for ti in 0..t {
264        let row = x.slice(s![ti, ..]);
265        let mean = row.iter().sum::<f32>() / c_f;
266        let var = row
267            .iter()
268            .map(|&v| {
269                let d = v - mean;
270                d * d
271            })
272            .sum::<f32>()
273            / c_f;
274        let inv_std = 1.0 / (var + eps).sqrt();
275        for ci in 0..c {
276            out[[ti, ci]] = (x[[ti, ci]] - mean) * inv_std * w[ci] + b[ci];
277        }
278    }
279    out
280}
281
282/// RMSNorm over the last axis of \[T, C\].
283///
284/// Uses an iterator sum to avoid the temporary array that `row.mapv(|v|
285/// v*v).sum()` would allocate for each of the T rows.
286fn rms_norm(x: ArrayView2<f32>, w: ArrayView1<f32>, eps: f32) -> Array2<f32> {
287    let (t, c) = (x.shape()[0], x.shape()[1]);
288    let c_f = c as f32;
289    let mut out = Array2::<f32>::zeros((t, c));
290    for ti in 0..t {
291        let row = x.slice(s![ti, ..]);
292        let ms = row.iter().map(|&v| v * v).sum::<f32>() / c_f;
293        let scale = 1.0 / (ms + eps).sqrt();
294        for ci in 0..c {
295            out[[ti, ci]] = x[[ti, ci]] * scale * w[ci];
296        }
297    }
298    out
299}
300
/// SiLU / swish activation: `x · σ(x)`, evaluated as `x / (1 + e^(−x))`.
#[inline(always)]
fn silu(x: f32) -> f32 {
    let denom = 1.0 + f32::exp(-x);
    x / denom
}
306
/// Numerically stable softmax over one row, computed in place.
///
/// The row maximum is subtracted before exponentiating, so large inputs do not
/// overflow. Every element is then divided by the sum of the exponentials. An
/// empty slice is left untouched.
fn softmax_inplace(x: &mut [f32]) {
    let peak = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);
    let mut total = 0.0f32;
    for v in x.iter_mut() {
        *v = (*v - peak).exp();
        total += *v;
    }
    for v in x.iter_mut() {
        *v /= total;
    }
}
317
318// ─── FSQ decode ───────────────────────────────────────────────────────────────
319
320/// Decode integer FSQ codes → continuous embeddings.
321///
322/// For each code (0..65535):
323/// 1. Decompose into 8 base-4 digits using `FSQ_BASIS`.
324/// 2. Scale each digit d ∈ {0,1,2,3} to {−1, −⅓, ⅓, 1} via `(d/1.5) - 1`.
325/// 3. Apply the `project_out` linear layer (8 → 2048).
326///
327/// Returns \[T, fsq_out_dim\].
328fn fsq_decode(
329    codes: &[i32],
330    proj_w: ArrayView2<f32>, // [fsq_out_dim, 8]
331    proj_b: ArrayView1<f32>, // [fsq_out_dim]
332) -> Array2<f32> {
333    let t = codes.len();
334    let _out_dim = proj_w.shape()[0];
335
336    // Build [T, 8] matrix of scaled FSQ digits
337    let mut digits = Array2::<f32>::zeros((t, FSQ_BASIS.len()));
338    for (i, &code) in codes.iter().enumerate() {
339        for (j, (&basis, &levels)) in FSQ_BASIS.iter().zip(FSQ_LEVELS.iter()).enumerate() {
340            let d = (code / basis) % levels;
341            // Scale from {0,1,…,L-1} to {-1, -1/3, 1/3, 1} for L=4
342            // Formula: (d / ((L-1)/2)) - 1  =  (d / 1.5) - 1
343            digits[[i, j]] = d as f32 / 1.5 - 1.0;
344        }
345    }
346
347    // project_out: [T, 8] @ [8, out_dim] + [out_dim]
348    linear(digits.view(), proj_w, Some(proj_b))
349}
350
351// ─── RoPE sin/cos dispatch ────────────────────────────────────────────────────
352
/// Compute `(sin(x), cos(x))` for Rotary Positional Embedding.
///
/// The implementation is chosen at compile time by the `precise` feature:
///
/// | Build           | Implementation                                         | Approx. max abs. error |
/// |-----------------|--------------------------------------------------------|------------------------|
/// | without `precise` (default) | range reduction + fold + degree-9/8 Taylor polynomials | ≲ 3 × 10⁻⁵ for small `x`; ≈ 2 × 10⁻⁴ near `|x| ≈ 2 000` |
/// | `precise`       | `f32::sin_cos()` from the platform math library        | ~1 ulp                 |
///
/// ### Default (polynomial) path
///
/// 1. `x` is reduced to `r ∈ [−π, π]` with a single `round()`. For large `|x|`
///    (RoPE angles reach ≈ 2 047 rad at position 2 047 with `inv_freq = 1`), the
///    f32 arithmetic in this step adds roughly `2⁻²³ · |x|` of absolute error.
/// 2. `r` is folded into `[−π/2, π/2]` using `sin(±π − r) = sin(r)` and
///    `cos(±π − r) = −cos(r)`.
/// 3. Taylor polynomials of degree 9 (sin) and 8 (cos) are evaluated in Horner
///    form on the folded angle. On `|y| ≤ π/2` their truncation error is below
///    ~3 × 10⁻⁵.
///
/// The fold in step 2 is essential. Degree-7/6 Taylor polynomials evaluated on
/// all of `[−π, π]` (as an earlier version did) are off by up to ~0.2 near
/// `|r| = π`; for example they give `cos(π) ≈ −1.21`. Because `sin² + cos² ≠ 1`
/// there, q/k vectors were rescaled as well as rotated.
///
/// The function is `pub(crate)` so other backends in the crate (the Burn path in
/// `codec_burn::load_weights` is documented as doing so) can build their RoPE
/// tables with the same algorithm as the CPU path.
#[cfg(not(feature = "precise"))]
#[inline(always)]
pub(crate) fn rope_sin_cos(x: f32) -> (f32, f32) {
    use std::f32::consts::{FRAC_PI_2, PI, TAU};

    // 1. Range-reduce to r ∈ [−π, π] with a single round() multiply.
    let r = x - TAU * (x * (1.0 / TAU)).round();

    // 2. Fold into [−π/2, π/2]. sin is unchanged by the reflection; cos flips sign.
    let (y, cos_sign) = if r > FRAC_PI_2 {
        (PI - r, -1.0)
    } else if r < -FRAC_PI_2 {
        (-PI - r, -1.0)
    } else {
        (r, 1.0)
    };

    // 3. Horner-form Taylor polynomials, accurate on |y| ≤ π/2.
    let y2 = y * y;
    // sin y ≈ y(1 − y²/3! + y⁴/5! − y⁶/7! + y⁸/9!)
    let s = y
        * (1.0
            + y2 * (-1.0 / 6.0
                + y2 * (1.0 / 120.0 + y2 * (-1.0 / 5040.0 + y2 * (1.0 / 362_880.0)))));
    // cos y ≈ 1 − y²/2! + y⁴/4! − y⁶/6! + y⁸/8!
    let c = 1.0
        + y2 * (-0.5 + y2 * (1.0 / 24.0 + y2 * (-1.0 / 720.0 + y2 * (1.0 / 40_320.0))));
    (s, cos_sign * c)
}

#[cfg(feature = "precise")]
#[inline(always)]
pub(crate) fn rope_sin_cos(x: f32) -> (f32, f32) {
    x.sin_cos()
}
399
400// ─── Rotary positional embedding ──────────────────────────────────────────────
401
402/// Apply split-half RoPE (torchtune convention) to `x` in-place.
403///
404/// * `x`: \[T, n_heads, head_dim\]
405///
406/// The outer loop order is `(position, freq_index, head)` so each
407/// `sin_cos()` result is computed **once per (position, freq)** and reused
408/// across all heads — previously it was computed once per (position, freq,
409/// head), allocating a fresh `Vec<f32>` for each position.
410fn apply_rope(x: &mut Array3<f32>) {
411    let (t, n_heads, head_dim) = (x.shape()[0], x.shape()[1], x.shape()[2]);
412    let half = head_dim / 2;
413
414    // Inverse frequencies — only `half` f32 values, computed once.
415    let inv_freqs: Vec<f32> = (0..half)
416        .map(|i| 1.0_f32 / 10_000_f32.powf(2.0 * i as f32 / head_dim as f32))
417        .collect();
418
419    for p in 0..t {
420        let p_f = p as f32;
421        for i in 0..half {
422            // Dispatch to rope_sin_cos(): polynomial (fast, default) or
423            // stdlib sin_cos (precise feature).
424            let (s, c) = rope_sin_cos(p_f * inv_freqs[i]);
425            // Apply the same rotation to every head — no per-head recompute.
426            for h in 0..n_heads {
427                let x1 = x[[p, h, i]];
428                let x2 = x[[p, h, i + half]];
429                x[[p, h, i]] = x1 * c - x2 * s;
430                x[[p, h, i + half]] = x1 * s + x2 * c;
431            }
432        }
433    }
434}
435
436// ─── Transformer components ───────────────────────────────────────────────────
437
438pub(crate) struct TransformerWeights {
439    pub(crate) att_norm_w: Array1<f32>, // RMSNorm  [D]
440    pub(crate) c_attn_w: Array2<f32>,   // Linear   [3D, D]  (no bias)
441    pub(crate) c_proj_w: Array2<f32>,   // Linear   [D, D]   (no bias)
442    pub(crate) ffn_norm_w: Array1<f32>, // RMSNorm  [D]
443    pub(crate) fc1_w: Array2<f32>,      // Linear   [4D, D]  (no bias)
444    pub(crate) fc2_w: Array2<f32>,      // Linear   [D, 4D]  (no bias)
445}
446
447/// Single Transformer block (RMSNorm → Attention → RMSNorm → MLP), residual.
448///
449/// * `x`: \[T, D\]  (modified in-place conceptually; returns new array)
450fn transformer_block(x: ArrayView2<f32>, w: &TransformerWeights, n_heads: usize) -> Array2<f32> {
451    let (t, d) = (x.shape()[0], x.shape()[1]);
452    let head_dim = d / n_heads;
453
454    // ── Attention sub-layer ───────────────────────────────────────────────────
455    let normed = rms_norm(x, w.att_norm_w.view(), 1e-6);
456    // qkv: [T, 3D]  (no bias)
457    let qkv = linear(normed.view(), w.c_attn_w.view(), None);
458
459    // Split into Q, K, V each [T, D]
460    let q_flat = qkv.slice(s![.., 0..d]).to_owned();
461    let k_flat = qkv.slice(s![.., d..2 * d]).to_owned();
462    let v_flat = qkv.slice(s![.., 2 * d..]).to_owned();
463
464    // Reshape to [T, n_heads, head_dim]
465    let mut q = q_flat
466        .into_shape_with_order((t, n_heads, head_dim))
467        .expect("q reshape");
468    let mut k = k_flat
469        .into_shape_with_order((t, n_heads, head_dim))
470        .expect("k reshape");
471    let v = v_flat
472        .into_shape_with_order((t, n_heads, head_dim))
473        .expect("v reshape");
474
475    apply_rope(&mut q);
476    apply_rope(&mut k);
477
478    // Scaled dot-product attention per head
479    let scale = (head_dim as f32).sqrt().recip();
480    // attn_out: [T, n_heads, head_dim]
481    let mut attn_out = Array3::<f32>::zeros((t, n_heads, head_dim));
482
483    for h in 0..n_heads {
484        let qh = q.slice(s![.., h, ..]).to_owned(); // [T, head_dim]
485        let kh = k.slice(s![.., h, ..]).to_owned();
486        let vh = v.slice(s![.., h, ..]).to_owned();
487
488        // scores = qh @ kh.T * scale  →  [T, T]
489        let mut scores = qh.dot(&kh.t());
490        scores.mapv_inplace(|v| v * scale);
491
492        // softmax over last dim (per query row)
493        for ti in 0..t {
494            softmax_inplace(scores.slice_mut(s![ti, ..]).as_slice_mut().unwrap());
495        }
496
497        // weighted_v = scores @ vh  →  [T, head_dim]
498        let wv = scores.dot(&vh);
499        attn_out.slice_mut(s![.., h, ..]).assign(&wv);
500    }
501
502    // Reshape [T, n_heads, head_dim] → [T, D]
503    let attn_flat = attn_out
504        .into_shape_with_order((t, d))
505        .expect("attn out reshape");
506
507    // Project: c_proj (no bias)
508    let attn_proj = linear(attn_flat.view(), w.c_proj_w.view(), None);
509
510    // Residual
511    let x_attn = &x + &attn_proj;
512
513    // ── MLP sub-layer ─────────────────────────────────────────────────────────
514    let normed2 = rms_norm(x_attn.view(), w.ffn_norm_w.view(), 1e-6);
515    let h1 = linear(normed2.view(), w.fc1_w.view(), None);
516    let h1_act = h1.mapv(silu);
517    let h2 = linear(h1_act.view(), w.fc2_w.view(), None);
518
519    &x_attn + &h2
520}
521
522// ─── ResnetBlock ─────────────────────────────────────────────────────────────
523
524pub(crate) struct ResnetBlockWeights {
525    pub(crate) norm1_w: Array1<f32>, // GroupNorm [C]
526    pub(crate) norm1_b: Array1<f32>,
527    pub(crate) conv1_w: Array3<f32>, // Conv1d [C, C, 3]
528    pub(crate) conv1_b: Array1<f32>,
529    pub(crate) norm2_w: Array1<f32>,
530    pub(crate) norm2_b: Array1<f32>,
531    pub(crate) conv2_w: Array3<f32>, // Conv1d [C, C, 3]
532    pub(crate) conv2_b: Array1<f32>,
533}
534
535/// ResnetBlock: GroupNorm → swish → Conv1d(k=3) → GroupNorm → swish → Conv1d(k=3) + residual.
536///
537/// * `x`: \[C, T\]  (channels-first)
538fn resnet_block(x: ArrayView2<f32>, w: &ResnetBlockWeights) -> Array2<f32> {
539    // norm1 → swish → conv1
540    let h = group_norm(x, 32, w.norm1_w.view(), w.norm1_b.view(), 1e-6);
541    let h = h.mapv(silu);
542    let h = conv1d(h.view(), w.conv1_w.view(), Some(w.conv1_b.view()), 1);
543
544    // norm2 → swish → (dropout=no-op at inference) → conv2
545    let h = group_norm(h.view(), 32, w.norm2_w.view(), w.norm2_b.view(), 1e-6);
546    let h = h.mapv(silu);
547    let h = conv1d(h.view(), w.conv2_w.view(), Some(w.conv2_b.view()), 1);
548
549    // residual (in_channels == out_channels so no projection)
550    &x + &h
551}
552
553// ─── ISTFT ────────────────────────────────────────────────────────────────────
554
555/// Inverse STFT matching PyTorch `torch.istft(..., center=True)`.
556///
557/// * `mag`: \[n_fft/2+1, T\]  **log**-magnitudes (the model head outputs log-mag)
558/// * `phase`: \[n_fft/2+1, T\] phase angles in radians
559/// * `hop`: hop length (= n_fft / 4)
560/// * `window`: Hann window \[n_fft\]
561/// * returns: waveform of exactly `T × hop` samples
562///
563/// ### Two bugs this function previously had (now fixed)
564///
565/// 1. **Clamp-before-exp** — the original code did `mag.min(1e2).exp()`, which
566///    caps the *log*-magnitude at 100 (meaning `exp(100) ≈ 2.7e43` for large
567///    bins).  The correct Python behaviour is `exp(mag).clamp(max=1e2)` — clamp
568///    the *linear* magnitude to 100.  Large log-magnitude bins (common for
569///    loud/low-frequency speech) therefore blew up, drowning out high-frequency
570///    content and causing muffled output.
571///
572/// 2. **Wrong center trim** — PyTorch's `center=True` removes `n_fft/2` samples
573///    from the **start** of the OLA buffer and then takes exactly `T*hop`
574///    samples.  The old code instead removed `(n_fft-hop)/2` from **both ends**,
575///    which is a 240-sample temporal offset (at 24 kHz with hop=480) and
576///    includes partially-overlapped edge frames with poor reconstruction quality.
577///
578/// `pub(crate)` so the Burn decoder in `codec_burn.rs` can call it after
579/// pulling the head output back from the device.
580pub(crate) fn istft_burn(
581    mag: ArrayView2<f32>,
582    phase: ArrayView2<f32>,
583    hop: usize,
584    window: &[f32],
585    ifft: &dyn rustfft::Fft<f32>,
586) -> Vec<f32> {
587    let n_bins = mag.shape()[0]; // n_fft/2 + 1
588    let n_frames = mag.shape()[1];
589    let n_fft = (n_bins - 1) * 2;
590    debug_assert_eq!(n_fft, window.len());
591    debug_assert_eq!(hop, n_fft / 4);
592
593    // Output buffer length before trimming
594    let out_size = (n_frames - 1) * hop + n_fft;
595    let mut y = vec![0.0f32; out_size];
596    let mut env = vec![0.0f32; out_size];
597
598    let mut buf = vec![Complex::<f32>::default(); n_fft];
599
600    for ti in 0..n_frames {
601        // Build the complex spectrum from log-magnitude + phase angle.
602        //
603        // FIX 1: exp() first, then clamp — matching PyTorch's
604        //   `mag = torch.exp(mag).clamp(max=1e2)`
605        // The old `.min(1e2).exp()` capped the log-magnitude at 100, which
606        // effectively allowed linear magnitudes up to exp(100) ≈ 2.7e43.
607        for fi in 0..n_bins {
608            let m = mag[[fi, ti]].exp().min(1e2); // ← fixed: clamp linear mag
609            let p = phase[[fi, ti]];
610            buf[fi] = Complex::new(m * p.cos(), m * p.sin());
611        }
612        // Hermitian symmetry for real IFFT output
613        for fi in 1..n_bins - 1 {
614            buf[n_fft - fi] = buf[fi].conj();
615        }
616
617        // Inverse FFT (rustfft is unnormalized — we divide by n_fft below)
618        ifft.process(&mut buf);
619
620        // Normalize + apply synthesis window, then overlap-add
621        let norm = n_fft as f32;
622        let offset = ti * hop;
623        for i in 0..n_fft {
624            let sample = buf[i].re / norm * window[i];
625            y[offset + i] += sample;
626            env[offset + i] += window[i] * window[i];
627        }
628    }
629
630    // Weighted overlap-add normalization
631    for i in 0..out_size {
632        if env[i] > 1e-11 {
633            y[i] /= env[i];
634        }
635    }
636
637    // FIX 2: match PyTorch center=True — trim n_fft/2 from the START only,
638    // then take exactly T*hop samples.
639    //
640    // Old code: y[(n_fft-hop)/2 .. out_size-(n_fft-hop)/2]
641    //   → 240-sample temporal offset + includes edge frames with 1-2 overlaps.
642    // Correct:  y[n_fft/2 .. n_fft/2 + T*hop]
643    //   → first fully-overlapped sample (≥4 frames) through end of signal.
644    let start = n_fft / 2;
645    let length = n_frames * hop;
646    y[start..start + length].to_vec()
647}
648
/// Periodic Hann window of length `n`: `w[i] = 0.5 · (1 − cos(2π·i / n))`.
///
/// It divides by `n` rather than `n − 1`, the same convention as PyTorch's
/// default `torch.hann_window(n)`. Returns an empty vector for `n == 0`.
fn hann_window(n: usize) -> Vec<f32> {
    let len = n as f32;
    let mut win = Vec::with_capacity(n);
    for i in 0..n {
        let phase = 2.0 * std::f32::consts::PI * i as f32 / len;
        win.push(0.5 * (1.0 - phase.cos()));
    }
    win
}
655
656// ─── Decoder weights ──────────────────────────────────────────────────────────
657
658pub(crate) struct DecoderWeights {
659    // FSQ
660    pub(crate) fsq_proj_w: Array2<f32>, // [2048, 8]
661    pub(crate) fsq_proj_b: Array1<f32>, // [2048]
662
663    // fc_post_a: Linear(2048, 1024)
664    pub(crate) fc_post_a_w: Array2<f32>, // [1024, 2048]
665    pub(crate) fc_post_a_b: Array1<f32>, // [1024]
666
667    // backbone.embed: Conv1d(1024, 1024, k=7, pad=3)
668    pub(crate) embed_w: Array3<f32>, // [1024, 1024, 7]
669    pub(crate) embed_b: Array1<f32>, // [1024]
670
671    // backbone.prior_net (2 ResnetBlocks)
672    pub(crate) prior_net: Vec<ResnetBlockWeights>,
673
674    // backbone.transformers (N TransformerBlocks)
675    pub(crate) transformers: Vec<TransformerWeights>,
676
677    // backbone.final_layer_norm: LayerNorm [D]
678    pub(crate) final_norm_w: Array1<f32>,
679    pub(crate) final_norm_b: Array1<f32>,
680
681    // backbone.post_net (2 ResnetBlocks)
682    pub(crate) post_net: Vec<ResnetBlockWeights>,
683
684    // head.out: Linear(D, n_fft+2)
685    pub(crate) head_w: Array2<f32>, // [n_fft+2, 1024]
686    pub(crate) head_b: Array1<f32>, // [n_fft+2]
687
688    // Hann window
689    pub(crate) window: Vec<f32>, // [n_fft]
690
691    // Detected hyper-parameters
692    pub(crate) hidden_dim: usize,
693    pub(crate) hop_length: usize,
694    pub(crate) depth: usize,
695    pub(crate) n_heads: usize,
696
697    // Cached IFFT plan — created once at load time so the plan cache is not
698    // discarded between decode() calls.
699    pub(crate) ifft_plan: std::sync::Arc<dyn rustfft::Fft<f32>>,
700}
701
702fn load_resnet_block(st: &SafeTensors<'_>, prefix: &str, c: usize) -> Result<ResnetBlockWeights> {
703    Ok(ResnetBlockWeights {
704        norm1_w: as1d(load_f32(st, &format!("{prefix}.norm1.weight"))?, c),
705        norm1_b: as1d(load_f32(st, &format!("{prefix}.norm1.bias"))?, c),
706        conv1_w: as3d(load_f32(st, &format!("{prefix}.conv1.weight"))?, c, c, 3),
707        conv1_b: as1d(load_f32(st, &format!("{prefix}.conv1.bias"))?, c),
708        norm2_w: as1d(load_f32(st, &format!("{prefix}.norm2.weight"))?, c),
709        norm2_b: as1d(load_f32(st, &format!("{prefix}.norm2.bias"))?, c),
710        conv2_w: as3d(load_f32(st, &format!("{prefix}.conv2.weight"))?, c, c, 3),
711        conv2_b: as1d(load_f32(st, &format!("{prefix}.conv2.bias"))?, c),
712    })
713}
714
715fn load_transformer(st: &SafeTensors<'_>, prefix: &str, d: usize) -> Result<TransformerWeights> {
716    Ok(TransformerWeights {
717        att_norm_w: as1d(load_f32(st, &format!("{prefix}.att_norm.weight"))?, d),
718        c_attn_w: as2d(
719            load_f32(st, &format!("{prefix}.att.c_attn.weight"))?,
720            3 * d,
721            d,
722        ),
723        c_proj_w: as2d(load_f32(st, &format!("{prefix}.att.c_proj.weight"))?, d, d),
724        ffn_norm_w: as1d(load_f32(st, &format!("{prefix}.ffn_norm.weight"))?, d),
725        fc1_w: as2d(load_f32(st, &format!("{prefix}.mlp.fc1.weight"))?, 4 * d, d),
726        fc2_w: as2d(load_f32(st, &format!("{prefix}.mlp.fc2.weight"))?, d, 4 * d),
727    })
728}
729
730fn load_decoder_weights(
731    st: &SafeTensors<'_>,
732    user_meta: &Option<std::collections::HashMap<String, String>>,
733) -> Result<DecoderWeights> {
734    // ── Auto-detect hyper-parameters from weight shapes ───────────────────────
735    let embed_shape = shape_of(st, "generator.backbone.embed.weight")?;
736    let hidden_dim = embed_shape[0]; // c_out
737
738    let head_shape = shape_of(st, "generator.head.out.weight")?;
739    let out_dim = head_shape[0]; // n_fft + 2
740    let hop_length = (out_dim - 2) / 4;
741
742    // Count transformer blocks by probing for weight keys
743    let depth = (0..64)
744        .take_while(|&i| {
745            st.tensor(&format!(
746                "generator.backbone.transformers.{i}.att_norm.weight"
747            ))
748            .is_ok()
749        })
750        .count();
751
752    if depth == 0 {
753        bail!("No transformer blocks found — is the safetensors file correct?");
754    }
755
756    // n_heads: read from safetensors __metadata__ if present, otherwise default to 16
757    let n_heads: usize = user_meta
758        .as_ref()
759        .and_then(|m| m.get("n_heads"))
760        .and_then(|s| s.parse().ok())
761        .unwrap_or(16);
762
763    // FSQ codebook projection
764    // Try the nested key first (older exports), fall back to the flat key
765    let fsq_proj_key = if st
766        .tensor("generator.quantizer.fsqs.0.project_out.weight")
767        .is_ok()
768    {
769        "generator.quantizer.fsqs.0.project_out.weight"
770    } else {
771        "generator.quantizer.project_out.weight"
772    };
773    let fsq_bias_key = if st
774        .tensor("generator.quantizer.fsqs.0.project_out.bias")
775        .is_ok()
776    {
777        "generator.quantizer.fsqs.0.project_out.bias"
778    } else {
779        "generator.quantizer.project_out.bias"
780    };
781
782    let fsq_shape = shape_of(st, fsq_proj_key)?;
783    let fsq_out_dim = fsq_shape[0]; // 2048
784    let fsq_in_dim = fsq_shape[1]; // 8
785
786    let fsq_proj_w = as2d(load_f32(st, fsq_proj_key)?, fsq_out_dim, fsq_in_dim);
787    let fsq_proj_b = as1d(load_f32(st, fsq_bias_key)?, fsq_out_dim);
788
789    // fc_post_a: [1024, 2048]
790    let fc_post_a_w = as2d(load_f32(st, "fc_post_a.weight")?, hidden_dim, fsq_out_dim);
791    let fc_post_a_b = as1d(load_f32(st, "fc_post_a.bias")?, hidden_dim);
792
793    // backbone.embed Conv1d
794    let embed_k = embed_shape[2];
795    let embed_w = as3d(
796        load_f32(st, "generator.backbone.embed.weight")?,
797        hidden_dim,
798        hidden_dim,
799        embed_k,
800    );
801    let embed_b = as1d(load_f32(st, "generator.backbone.embed.bias")?, hidden_dim);
802
803    // prior_net (2 ResnetBlocks)
804    let prior_net = (0..2)
805        .map(|i| load_resnet_block(st, &format!("generator.backbone.prior_net.{i}"), hidden_dim))
806        .collect::<Result<Vec<_>>>()?;
807
808    // transformers
809    let transformers = (0..depth)
810        .map(|i| {
811            load_transformer(
812                st,
813                &format!("generator.backbone.transformers.{i}"),
814                hidden_dim,
815            )
816        })
817        .collect::<Result<Vec<_>>>()?;
818
819    // final_layer_norm
820    let final_norm_w = as1d(
821        load_f32(st, "generator.backbone.final_layer_norm.weight")?,
822        hidden_dim,
823    );
824    let final_norm_b = as1d(
825        load_f32(st, "generator.backbone.final_layer_norm.bias")?,
826        hidden_dim,
827    );
828
829    // post_net (2 ResnetBlocks)
830    let post_net = (0..2)
831        .map(|i| load_resnet_block(st, &format!("generator.backbone.post_net.{i}"), hidden_dim))
832        .collect::<Result<Vec<_>>>()?;
833
834    // head.out
835    let n_fft = hop_length * 4;
836    let head_w = as2d(
837        load_f32(st, "generator.head.out.weight")?,
838        out_dim,
839        hidden_dim,
840    );
841    let head_b = as1d(load_f32(st, "generator.head.out.bias")?, out_dim);
842
843    // Hann window: try to load from safetensors; compute as fallback
844    let window = if st.tensor("generator.head.istft.window").is_ok() {
845        load_f32(st, "generator.head.istft.window")?
846    } else {
847        hann_window(n_fft)
848    };
849
850    // Build the IFFT plan once at load time — the FftPlanner caches plans
851    // internally, but recreating the planner on each decode() call would
852    // silently discard that cache and re-plan from scratch every time.
853    let ifft_plan = {
854        let mut planner = FftPlanner::<f32>::new();
855        planner.plan_fft_inverse(n_fft)
856    };
857
858    Ok(DecoderWeights {
859        fsq_proj_w,
860        fsq_proj_b,
861        fc_post_a_w,
862        fc_post_a_b,
863        embed_w,
864        embed_b,
865        prior_net,
866        transformers,
867        final_norm_w,
868        final_norm_b,
869        post_net,
870        head_w,
871        head_b,
872        window,
873        hidden_dim,
874        hop_length,
875        depth,
876        n_heads,
877        ifft_plan,
878    })
879}
880
881// ─── Decoder forward pass ─────────────────────────────────────────────────────
882
883pub(crate) fn decode_forward(codes: &[i32], w: &DecoderWeights) -> Vec<f32> {
884    let hop = w.hop_length;
885    let n_fft = hop * 4;
886    let embed_k = w.embed_w.shape()[2];
887    let embed_pad = embed_k / 2;
888
889    // 1. FSQ decode: [T] → [T, fsq_out_dim]
890    let emb = fsq_decode(codes, w.fsq_proj_w.view(), w.fsq_proj_b.view());
891
892    // 2. fc_post_a: [T, fsq_out_dim] → [T, hidden_dim]
893    let x = linear(emb.view(), w.fc_post_a_w.view(), Some(w.fc_post_a_b.view()));
894
895    // 3. backbone.embed Conv1d: [hidden_dim, T]
896    let x_ct = x.t().to_owned(); // [hidden_dim, T]
897    let x_ct = conv1d(
898        x_ct.view(),
899        w.embed_w.view(),
900        Some(w.embed_b.view()),
901        embed_pad,
902    );
903
904    // 4. prior_net (ResnetBlocks, channels-first)
905    let x_ct = w
906        .prior_net
907        .iter()
908        .fold(x_ct, |acc, rw| resnet_block(acc.view(), rw));
909
910    // 5. Transformers (sequence-first)
911    let x_tc = x_ct.t().to_owned(); // [T, hidden_dim]
912    let x_tc = w
913        .transformers
914        .iter()
915        .fold(x_tc, |acc, tw| transformer_block(acc.view(), tw, w.n_heads));
916
917    // 6. post_net (channels-first)
918    let x_ct = x_tc.t().to_owned(); // [hidden_dim, T]
919    let x_ct = w
920        .post_net
921        .iter()
922        .fold(x_ct, |acc, rw| resnet_block(acc.view(), rw));
923
924    // 7. final_layer_norm (sequence-first)
925    let x_tc = x_ct.t().to_owned(); // [T, hidden_dim]
926    let x_tc = layer_norm(
927        x_tc.view(),
928        w.final_norm_w.view(),
929        w.final_norm_b.view(),
930        1e-6,
931    );
932
933    // 8. head.out: [T, hidden_dim] → [T, n_fft+2]
934    let x_pred = linear(x_tc.view(), w.head_w.view(), Some(w.head_b.view()));
935
936    // 9. Transpose → [n_fft+2, T], split mag and phase
937    let x_pred_ct = x_pred.t().to_owned(); // [n_fft+2, T]
938    let half = (n_fft / 2) + 1; // n_bins = 641 for n_fft=1280, 961 for n_fft=1920
939    let mag = x_pred_ct.slice(s![0..half, ..]).to_owned();
940    let phase = x_pred_ct.slice(s![half.., ..]).to_owned();
941
942    // 10. ISTFT — use the pre-built plan from DecoderWeights
943    istft_burn(
944        mag.view(),
945        phase.view(),
946        hop,
947        &w.window,
948        w.ifft_plan.as_ref(),
949    )
950}
951
952// ─── Public API ───────────────────────────────────────────────────────────────
953
/// NeuCodec decoder: converts speech token IDs to a 24 kHz audio waveform.
///
/// ## Setup
///
/// Set `NEUTTS_DECODER_PATH` to `neucodec_decoder.safetensors`, then:
/// ```rust,ignore
/// let dec = NeuCodecDecoder::new()?;
/// let audio = dec.decode(&codes)?;
/// ```
///
/// ## Backend selection
///
/// The backend is chosen **once, eagerly**, inside
/// [`from_file`](Self::from_file), so any GPU upload cost is part of model
/// loading rather than synthesis latency. [`decode`](Self::decode) never
/// initialises anything; it only uses the backend picked at load time.
///
/// When built with the `burn` feature:
///
/// | Priority | Backend                   | When used                          |
/// |----------|---------------------------|------------------------------------|
/// | 1        | Burn wgpu (GPU)           | Metal / Vulkan / DX12 adapter found|
/// | 2        | Burn NdArray (CPU)        | No GPU adapter available           |
/// | 3        | Raw CPU path (see below)  | Burn init failed entirely          |
///
/// The raw CPU path is also used whenever the `burn` feature is off. It is
/// the `rlx` decoder when built with the `rlx` feature, and the eager
/// ndarray forward pass otherwise. [`backend_name`](Self::backend_name)
/// reports which backend is active.
// NOTE(review): earlier docs said `--features wgpu`; every gate in this file
// is `feature = "burn"`. Confirm in Cargo.toml which flag enables the GPU
// adapter (rows 1 vs 2 are decided inside `make_burn_decoder`, not visible here).
pub struct NeuCodecDecoder {
    weights: DecoderWeights,
    path: PathBuf,

    /// Burn backend chosen at load time by `from_file`.
    ///
    /// * `Ready(Some(_))` — Burn backend available; `decode` uses it.
    /// * `Ready(None)`    — Burn init was attempted but failed; `decode`
    ///                      falls through to the rlx / eager ndarray path.
    ///
    /// The stored trait object is only `Send`, not `Sync`. Wrapping it in a
    /// `Mutex` makes `NeuCodecDecoder` `Send + Sync` (`Mutex<T>: Sync` when
    /// `T: Send`). It also means concurrent `decode(&self)` calls on the
    /// Burn backend are serialised.
    #[cfg(feature = "burn")]
    burn_decoder: std::sync::Mutex<LazyBurnDecoder>,
}
994
/// Outcome of Burn backend initialisation, stored by `NeuCodecDecoder::from_file`.
///
/// Despite the name, nothing here is lazy. `from_file` always stores
/// `Ready(..)` immediately after calling `make_burn_decoder`.
#[cfg(feature = "burn")]
enum LazyBurnDecoder {
    /// `Some(_)` — Burn backend initialised and used by `decode`.
    /// `None`    — initialisation failed; `decode` uses the non-Burn CPU path.
    Ready(Option<Box<dyn super::burn::BurnDecoder + Send>>),
}
1000
1001impl NeuCodecDecoder {
1002    /// Load from `NEUTTS_DECODER_PATH`.
1003    pub fn new() -> Result<Self> {
1004        let path = super::decoder_weights_path()?;
1005        Self::from_file(&path)
1006    }
1007
1008    /// Load from an explicit file path.
1009    pub fn from_file(path: &Path) -> Result<Self> {
1010        if !path.exists() {
1011            bail!(
1012                "NeuCodec decoder weights not found: {}\n\
1013                 Set NEUTTS_DECODER_PATH or pass an explicit path to NeuCodecDecoder::from_file().",
1014                path.display()
1015            );
1016        }
1017
1018        // Memory-map the file so the OS pages in tensor data on demand instead
1019        // of reading all 840 MB into a heap Vec<u8> upfront.  This halves peak
1020        // RAM usage during loading and avoids a large malloc + full-file copy.
1021        let file = std::fs::File::open(path)
1022            .with_context(|| format!("Failed to open {}", path.display()))?;
1023        // SAFETY: we do not mutate the mapping, and we hold `mmap` for the
1024        // full lifetime of `st` (both are dropped at the end of this block
1025        // after `load_decoder_weights` has copied all tensors into ndarray).
1026        let mmap = unsafe {
1027            memmap2::Mmap::map(&file)
1028                .with_context(|| format!("Failed to mmap {}", path.display()))?
1029        };
1030        let bytes: &[u8] = &mmap;
1031
1032        // Read user-defined metadata (n_heads, depth, etc.) from the file header
1033        let (_, file_meta) = SafeTensors::read_metadata(bytes)
1034            .with_context(|| format!("Failed to parse safetensors header: {}", path.display()))?;
1035        let user_meta = file_meta.metadata().clone();
1036
1037        let st = SafeTensors::deserialize(bytes)
1038            .with_context(|| format!("Failed to parse safetensors: {}", path.display()))?;
1039
1040        let weights = load_decoder_weights(&st, &user_meta)
1041            .with_context(|| format!("Failed to load decoder weights from {}", path.display()))?;
1042
1043        // `st` and `mmap` are dropped here — all tensor data is now owned by
1044        // the ndarray arrays inside `weights`.
1045        drop(st);
1046        drop(mmap);
1047
1048        println!(
1049            "NeuCodec decoder: hidden={}, depth={}, heads={}, hop={} ({} samples/token = {} tokens/s)",
1050            weights.hidden_dim,
1051            weights.depth,
1052            weights.n_heads,
1053            weights.hop_length,
1054            weights.hop_length,
1055            SAMPLE_RATE as usize / weights.hop_length,
1056        );
1057
1058        // ── Eagerly initialise the Burn GPU/CPU backend ───────────────────────
1059        //
1060        // Initialising here (at load time) rather than lazily on the first
1061        // decode() call moves the ~1-2 s GPU upload cost out of synthesis
1062        // latency and into model loading, which is a better user experience:
1063        // the "loaded in X s" number is accurate, and "synth took Y s" reflects
1064        // only the actual forward pass.
1065        #[cfg(feature = "burn")]
1066        let burn_decoder = {
1067            let t0 = std::time::Instant::now();
1068            let dec = super::burn::make_burn_decoder(&weights);
1069            println!(
1070                "NeuCodec: {} backend ready in {:.2} s",
1071                dec.as_ref().map_or("cpu (ndarray)", |b| b.backend_name()),
1072                t0.elapsed().as_secs_f32(),
1073            );
1074            std::sync::Mutex::new(LazyBurnDecoder::Ready(dec))
1075        };
1076
1077        Ok(Self {
1078            weights,
1079            path: path.to_path_buf(),
1080            #[cfg(feature = "burn")]
1081            burn_decoder,
1082        })
1083    }
1084
1085    /// Decode speech token IDs to a 24 kHz audio waveform.
1086    ///
1087    /// * `codes` — integer token IDs in `0..=65535` (NeuCodec FSQ range).
1088    ///   Out-of-range values are rejected with an error rather than silently
1089    ///   producing garbage digits from the FSQ decomposition.
1090    /// * returns — `Vec<f32>` of `codes.len() × hop_length` samples.
1091    pub fn decode(&self, codes: &[i32]) -> Result<Vec<f32>> {
1092        if codes.is_empty() {
1093            return Ok(Vec::new());
1094        }
1095
1096        // Validate before touching any weights — an out-of-range code would
1097        // silently produce wrong FSQ digits (e.g. a negative modulo result).
1098        for (i, &code) in codes.iter().enumerate() {
1099            if !(0..=65535).contains(&code) {
1100                anyhow::bail!(
1101                    "Speech token at index {i} is out of range: {code} \
1102                     (NeuCodec FSQ codes must be in 0..=65535)"
1103                );
1104            }
1105        }
1106
1107        // ── Prefer Burn-accelerated path (wgpu GPU or NdArray CPU via Burn) ──
1108        //
1109        // The backend is always Ready here: it is initialised eagerly in
1110        // from_file() so there is no lazy-init stall inside synthesis.
1111        #[cfg(feature = "burn")]
1112        {
1113            let state = self.burn_decoder.lock().unwrap();
1114            if let LazyBurnDecoder::Ready(Some(ref bd)) = *state {
1115                return bd.decode(codes);
1116            }
1117        }
1118
1119        // ── RLX path (byte-identical to eager until compiled graph lands) ────
1120        #[cfg(feature = "rlx")]
1121        {
1122            super::rlx::decode(codes, &self.weights)
1123        }
1124        #[cfg(not(feature = "rlx"))]
1125        {
1126            Ok(decode_forward(codes, &self.weights))
1127        }
1128    }
1129
1130    /// Name of the active inference backend.
1131    pub fn backend_name(&self) -> &'static str {
1132        #[cfg(feature = "burn")]
1133        {
1134            let state = self.burn_decoder.lock().unwrap();
1135            if let LazyBurnDecoder::Ready(Some(bd)) = &*state {
1136                return bd.backend_name();
1137            }
1138        }
1139        if cfg!(feature = "rlx") {
1140            "rlx/eager-parity"
1141        } else {
1142            "codec/eager-ndarray"
1143        }
1144    }
1145
1146    /// Alias for [`from_file`](Self::from_file) — load from an explicit path.
1147    pub fn load(path: &Path) -> Result<Self> {
1148        Self::from_file(path)
1149    }
1150
1151    /// Path from which the decoder was loaded.
1152    pub fn weights_path(&self) -> &Path {
1153        &self.path
1154    }
1155
1156    /// Detected `hop_length` (audio samples per speech token).
1157    pub fn hop_length(&self) -> usize {
1158        self.weights.hop_length
1159    }
1160}
1161
1162// ─── Encoder (stub) ───────────────────────────────────────────────────────────
1163
1164/// NeuCodec encoder: converts a 16 kHz audio waveform to speech token IDs.
1165///
1166/// **Note**: The full NeuCodec encoder requires Wav2Vec2BertModel (~600 MB)
1167/// as a semantic feature extractor.  Encoder support is not yet implemented
1168/// in this pure-Rust build.
1169///
1170/// For reference audio encoding, use the Python `neucodec` package:
1171/// ```python
1172/// from neucodec import NeuCodec
1173/// model = NeuCodec.from_pretrained("neuphonic/neucodec")
1174/// codes = model.encode_code(waveform)   # → i32 array
1175/// ```
1176/// Then save the codes as a `.npy` file and pass via `--ref-codes` to the
1177/// synthesis examples.
1178pub struct NeuCodecEncoder;
1179
1180impl NeuCodecEncoder {
1181    /// Always returns an error — encoder not yet implemented.
1182    pub fn new() -> Result<Self> {
1183        bail!(
1184            "The NeuCodec encoder is not yet implemented in the pure-Rust build.\n\
1185             \n\
1186             To encode reference audio, use the Python neucodec package:\n\
1187             \n\
1188             \tpip install neucodec huggingface_hub\n\
1189             \tpython scripts/encode_reference.py --audio reference.wav --out ref.npy\n\
1190             \n\
1191             Then pass the .npy file via --ref-codes to the synthesis examples."
1192        )
1193    }
1194
1195    /// Always returns an error — encoder not yet implemented.
1196    pub fn load(_path: &Path) -> Result<Self> {
1197        Self::new()
1198    }
1199
1200    /// Encode a WAV file to speech token IDs (not implemented).
1201    pub fn encode_wav(&self, _path: &Path) -> Result<Vec<i32>> {
1202        bail!("Encoder not implemented — see NeuCodecEncoder docs")
1203    }
1204
1205    /// Backend name.
1206    pub fn backend_name(&self) -> &str {
1207        "not available"
1208    }
1209}
1210
1211// ─── Resample helper ──────────────────────────────────────────────────────────
1212
/// Naive linear resampler: changes sample rate of `samples` from `from_hz` to `to_hz`.
///
/// Output sample `i` is read at source position `i * from_hz / to_hz`. Its
/// value is linearly interpolated between the two neighbouring input samples,
/// and positions at or past the last input sample hold that last sample. No
/// anti-aliasing filter is applied, so downsampling can alias.
///
/// Returns a copy of `samples` when the rates are equal. Returns an empty
/// vector when `samples` is empty or `to_hz` is zero.
///
/// # Panics
///
/// Panics if `from_hz` is zero while the rates differ and `samples` is non-empty.
#[allow(dead_code)]
pub fn resample(samples: &[f32], from_hz: u32, to_hz: u32) -> Vec<f32> {
    if from_hz == to_hz {
        return samples.to_vec();
    }
    // Nothing to interpolate, or a zero-length output rate.
    let Some(last) = samples.len().checked_sub(1) else {
        return Vec::new();
    };
    if to_hz == 0 {
        return Vec::new();
    }
    // Without this, a zero source rate would request a usize::MAX-length output.
    assert!(
        from_hz > 0,
        "resample: source sample rate must be non-zero (got {from_hz} -> {to_hz})"
    );

    let ratio = f64::from(from_hz) / f64::from(to_hz);
    let out_len = (samples.len() as f64 / ratio).ceil() as usize;
    (0..out_len)
        .map(|i| {
            let src = i as f64 * ratio;
            // Clamp: float rounding in `out_len` / `src` can place the final
            // position at (or just past) `samples.len()`.
            let lo = (src.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = (src - lo as f64) as f32;
            samples[lo] * (1.0 - frac) + samples[hi] * frac
        })
        .collect()
}
1231
1232// ─── Unit tests ───────────────────────────────────────────────────────────────
1233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_fsq_decode_shape() {
        // Minimal project_out: 4-dim output, 8-dim input
        let w = Array2::ones((4, 8));
        let b = Array1::zeros(4);
        let codes = vec![0i32, 1, 2, 65535];
        let out = fsq_decode(&codes, w.view(), b.view());
        assert_eq!(out.shape(), &[4, 4]);
    }

    #[test]
    fn test_fsq_code_0() {
        // Code 0 → all digits 0 → all scaled to -1.0
        // project_out identity (8×8)
        let w = Array2::eye(8);
        let b = Array1::zeros(8);
        let out = fsq_decode(&[0], w.view(), b.view());
        for v in out.iter() {
            assert!((*v + 1.0).abs() < 1e-5, "expected -1.0, got {v}");
        }
    }

    #[test]
    fn test_fsq_code_max() {
        // Code 65535 = 4^8 - 1 → all digits 3 → all scaled to 1.0
        let w = Array2::eye(8);
        let b = Array1::zeros(8);
        let out = fsq_decode(&[65535], w.view(), b.view());
        for v in out.iter() {
            assert!((*v - 1.0).abs() < 1e-5, "expected 1.0, got {v}");
        }
    }

    #[test]
    fn test_linear_shape() {
        let x = Array2::ones((5, 3));
        let w = Array2::ones((7, 3));
        let b = Array1::zeros(7);
        let out = linear(x.view(), w.view(), Some(b.view()));
        assert_eq!(out.shape(), &[5, 7]);
    }

    #[test]
    fn test_conv1d_same_length() {
        let c_in = 4;
        let c_out = 8;
        let t = 16;
        let k = 3;
        let x = Array2::ones((c_in, t));
        let w = Array3::ones((c_out, c_in, k));
        let b = Array1::zeros(c_out);
        let out = conv1d(x.view(), w.view(), Some(b.view()), 1);
        assert_eq!(out.shape(), &[c_out, t]); // same length
    }

    #[test]
    fn test_group_norm_shape() {
        let c = 64;
        let t = 10;
        let x = Array2::ones((c, t));
        let w = Array1::ones(c);
        let b = Array1::zeros(c);
        let out = group_norm(x.view(), 4, w.view(), b.view(), 1e-6);
        assert_eq!(out.shape(), &[c, t]);
        // All-ones input → mean 1, var 0 → norm 0*w + b = 0
        for &v in out.iter() {
            assert!(
                v.abs() < 1e-4,
                "expected ~0 after group_norm of all-ones, got {v}"
            );
        }
    }

    #[test]
    fn test_layer_norm_shape() {
        let t = 5;
        let c = 32;
        let x = Array2::from_elem((t, c), 2.0f32);
        let w = Array1::ones(c);
        let b = Array1::zeros(c);
        let out = layer_norm(x.view(), w.view(), b.view(), 1e-6);
        assert_eq!(out.shape(), &[t, c]);
        // Constant input → LayerNorm output is 0
        for &v in out.iter() {
            assert!(v.abs() < 1e-4, "expected ~0, got {v}");
        }
    }

    #[test]
    fn test_rms_norm_shape() {
        let t = 3;
        let c = 8;
        let x = Array2::ones((t, c));
        let w = Array1::ones(c);
        let out = rms_norm(x.view(), w.view(), 1e-6);
        assert_eq!(out.shape(), &[t, c]);
        // RMSNorm of all-ones → 1/rms(1) * 1 = 1
        for &v in out.iter() {
            assert!((v - 1.0).abs() < 1e-4, "expected 1.0, got {v}");
        }
    }

    #[test]
    fn test_rope_shape_preserved() {
        let t = 4;
        let n_heads = 2;
        let head_dim = 8;
        let mut x = Array3::ones((t, n_heads, head_dim));
        apply_rope(&mut x);
        assert_eq!(x.shape(), &[t, n_heads, head_dim]);
    }

    #[test]
    fn test_hann_window() {
        let w = hann_window(4);
        assert_eq!(w.len(), 4);
        // Periodic Hann of length n: w[0] = 0 and the peak w[n/2] = 1
        // (there is no w[n]; the window is one sample shorter than a
        // symmetric Hann that would end in 0).
        assert!(w[0].abs() < 1e-6);
        assert!((w[2] - 1.0).abs() < 1e-6);
    }

    fn make_ifft(n_fft: usize) -> std::sync::Arc<dyn rustfft::Fft<f32>> {
        FftPlanner::<f32>::new().plan_fft_inverse(n_fft)
    }

    #[test]
    fn test_istft_length() {
        let hop = 4;
        let n_fft = 16; // hop * 4
        let t = 10;
        let n_bins = n_fft / 2 + 1; // 9
        // Zero mag → exp(0)=1 magnitude, zero phase → cos(0)=1, sin(0)=0
        let mag = Array2::zeros((n_bins, t));
        let phase = Array2::zeros((n_bins, t));
        let win = hann_window(n_fft);
        let ifft = make_ifft(n_fft);
        let audio = istft_burn(mag.view(), phase.view(), hop, &win, ifft.as_ref());
        // center=True: output is exactly T*hop samples
        assert_eq!(
            audio.len(),
            t * hop,
            "expected {} samples, got {}",
            t * hop,
            audio.len()
        );
    }

    #[test]
    fn test_istft_clamp_does_not_blow_up() {
        // Log-magnitudes well above ln(100)≈4.6 must be clamped to 100 (linear),
        // not allowed to reach exp(large) ≈ infinity.
        let hop = 4;
        let n_fft = 16;
        let t = 4;
        let n_bins = n_fft / 2 + 1;
        // All log-magnitudes = 50 (would give exp(50) ≈ 5e21 without the fix)
        let mag = Array2::from_elem((n_bins, t), 50.0f32);
        let phase = Array2::zeros((n_bins, t));
        let win = hann_window(n_fft);
        let ifft = make_ifft(n_fft);
        let audio = istft_burn(mag.view(), phase.view(), hop, &win, ifft.as_ref());
        // All samples must be finite and ≤ some reasonable bound (the clamp
        // limits linear magnitude to 1e2, so waveform values should be bounded)
        for &s in &audio {
            assert!(s.is_finite(), "sample is not finite: {s}");
            assert!(s.abs() < 1e6, "sample magnitude suspiciously large: {s}");
        }
    }

    #[test]
    fn test_burn_feature_fn() {
        let _ = crate::features::burn_feature_enabled();
    }

    #[test]
    fn test_resample_identity() {
        let s: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let r = resample(&s, 16_000, 16_000);
        assert_eq!(r, s);
    }

    /// Active decoder path (eager / burn / rlx) must match the ndarray gold forward.
    #[test]
    fn decode_output_matches_eager_forward() {
        let Some(path) = crate::decoder::decoder_weights_path_if_available() else {
            eprintln!("skip decode_output_matches_eager_forward: set NEUTTS_DECODER_PATH");
            return;
        };

        let codes: Vec<i32> = vec![0, 42, 128, 512, 1023];
        let dec = NeuCodecDecoder::from_file(&path).expect("NeuCodecDecoder::from_file");
        let actual = dec.decode(&codes).expect("decode");
        eprintln!(
            "decode_output_matches_eager_forward: backend={}",
            dec.backend_name()
        );

        let data = std::fs::read(&path).expect("read safetensors");
        // Read the header metadata exactly as `from_file` does, so the
        // reference forward uses the same hyper-parameters (e.g. n_heads).
        // Passing `&None` here would silently fall back to the default head
        // count and invalidate the comparison for files that store it.
        let (_, file_meta) =
            safetensors::SafeTensors::read_metadata(&data).expect("safetensors header");
        let user_meta = file_meta.metadata().clone();
        let st = safetensors::SafeTensors::deserialize(&data).expect("safetensors");
        let w = load_decoder_weights(&st, &user_meta).expect("load_decoder_weights");
        assert_eq!(
            w.n_heads, dec.weights.n_heads,
            "reference weights and decoder disagree on n_heads"
        );
        let expected = decode_forward(&codes, &w);

        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(a.is_finite() && e.is_finite(), "non-finite at {i}");
            let diff = (a - e).abs();
            assert!(
                diff < 1e-3,
                "sample {i}: actual={a} expected={e} diff={diff} backend={}",
                dec.backend_name()
            );
        }
    }
}