1use std::path::{Path, PathBuf};
38
39use anyhow::{Context, Result, bail};
40use ndarray::{Array1, Array2, Array3, ArrayView1, ArrayView2, ArrayView3, s};
41use rustfft::{FftPlanner, num_complex::Complex};
42use safetensors::SafeTensors;
43
/// Sample rate of the decoded audio, in Hz.
pub const SAMPLE_RATE: u32 = 24_000;

/// Input sample rate for the NeuCodec encoder, in Hz.
pub const ENCODER_SAMPLE_RATE: u32 = 16_000;

/// Output samples per speech token at [`SAMPLE_RATE`]: 480 samples = 20 ms,
/// i.e. 50 tokens per second.
pub const SAMPLES_PER_TOKEN: usize = SAMPLE_RATE as usize / 50;

/// Encoder input samples per speech token at [`ENCODER_SAMPLE_RATE`]
/// (320 samples = 20 ms, the same 50 tokens per second).
pub const ENCODER_SAMPLES_PER_TOKEN: usize = ENCODER_SAMPLE_RATE as usize / 50;

/// Default encoder input length: 10 seconds at [`ENCODER_SAMPLE_RATE`].
pub const ENCODER_DEFAULT_INPUT_SAMPLES: usize = ENCODER_SAMPLE_RATE as usize * 10;

/// Number of quantisation levels in each of the 8 FSQ dimensions.
///
/// A packed code holds one base-4 digit per dimension, giving 4^8 = 65 536
/// distinct codes (`0..=65535`).
pub(crate) const FSQ_LEVELS: [i32; 8] = [4; 8];

/// Place value of each FSQ digit inside a packed code.
///
/// `FSQ_BASIS[j]` is the product of `FSQ_LEVELS[..j]`, so digit `j` of a code is
/// `(code / FSQ_BASIS[j]) % FSQ_LEVELS[j]`.
pub(crate) const FSQ_BASIS: [i32; 8] = [1, 4, 16, 64, 256, 1 << 10, 1 << 12, 1 << 14];
73fn load_f32(st: &SafeTensors<'_>, name: &str) -> Result<Vec<f32>> {
76 let view = st
77 .tensor(name)
78 .with_context(|| format!("Missing weight: {name}"))?;
79 let raw = view.data();
80 use safetensors::tensor::Dtype;
81 Ok(match view.dtype() {
82 Dtype::F32 => {
83 assert!(
87 raw.len() % 4 == 0,
88 "F32 tensor byte length not divisible by 4"
89 );
90 let n = raw.len() / 4;
91 let mut out = Vec::with_capacity(n);
92 #[cfg(target_endian = "little")]
95 {
96 unsafe {
99 std::ptr::copy_nonoverlapping(
100 raw.as_ptr(),
101 out.as_mut_ptr() as *mut u8,
102 raw.len(),
103 );
104 out.set_len(n);
105 }
106 }
107 #[cfg(not(target_endian = "little"))]
108 {
109 out.extend(
110 raw.chunks_exact(4)
111 .map(|b| f32::from_le_bytes([b[0], b[1], b[2], b[3]])),
112 );
113 }
114 out
115 }
116 Dtype::BF16 => raw
117 .chunks_exact(2)
118 .map(|b| {
119 let bits = u16::from_le_bytes([b[0], b[1]]);
120 f32::from_bits((bits as u32) << 16)
121 })
122 .collect(),
123 dt => bail!("Tensor {name}: unsupported dtype {dt:?} (expected F32 or BF16)"),
124 })
125}
126
127fn shape_of(st: &SafeTensors<'_>, name: &str) -> Result<Vec<usize>> {
128 Ok(st
129 .tensor(name)
130 .with_context(|| format!("Missing weight: {name}"))?
131 .shape()
132 .to_vec())
133}
134
135fn as1d(data: Vec<f32>, n: usize) -> Array1<f32> {
136 Array1::from_shape_vec(n, data).expect("1-D shape mismatch")
137}
138
139fn as2d(data: Vec<f32>, rows: usize, cols: usize) -> Array2<f32> {
140 Array2::from_shape_vec((rows, cols), data).expect("2-D shape mismatch")
141}
142
143fn as3d(data: Vec<f32>, d0: usize, d1: usize, d2: usize) -> Array3<f32> {
144 Array3::from_shape_vec((d0, d1, d2), data).expect("3-D shape mismatch")
145}
146
147fn linear(x: ArrayView2<f32>, w: ArrayView2<f32>, b: Option<ArrayView1<f32>>) -> Array2<f32> {
156 let mut out = x.dot(&w.t()); if let Some(b) = b {
158 out += &b;
159 }
160 out
161}
162
163fn conv1d(
170 x: ArrayView2<f32>,
171 w: ArrayView3<f32>,
172 b: Option<ArrayView1<f32>>,
173 pad: usize,
174) -> Array2<f32> {
175 let (c_in, t) = (x.shape()[0], x.shape()[1]);
176 let (c_out, _, k) = (w.shape()[0], w.shape()[1], w.shape()[2]);
177
178 let mut col = Array2::<f32>::zeros((t, c_in * k));
180 for ti in 0..t {
181 for ci in 0..c_in {
182 for ki in 0..k {
183 let src = ti + ki;
184 if src >= pad && src < t + pad {
185 col[[ti, ci * k + ki]] = x[[ci, src - pad]];
186 }
187 }
189 }
190 }
191
192 let w2 = w
194 .into_shape_with_order((c_out, c_in * k))
195 .expect("conv1d reshape");
196
197 let out_t = col.dot(&w2.t());
199 let mut out = out_t.t().to_owned(); if let Some(b) = b {
202 use ndarray::Axis;
204 out += &b.view().insert_axis(Axis(1));
205 }
206 out
207}
208
209fn group_norm(
215 x: ArrayView2<f32>,
216 n_groups: usize,
217 w: ArrayView1<f32>,
218 b: ArrayView1<f32>,
219 eps: f32,
220) -> Array2<f32> {
221 let (c, t) = (x.shape()[0], x.shape()[1]);
222 let group_size = c / n_groups;
223 let n = (group_size * t) as f32;
224 let mut out = Array2::<f32>::zeros((c, t));
225
226 for g in 0..n_groups {
227 let c_start = g * group_size;
228 let c_end = c_start + group_size;
229 let block = x.slice(s![c_start..c_end, ..]);
230
231 let mean = block.iter().sum::<f32>() / n;
233 let var = block
235 .iter()
236 .map(|&v| {
237 let d = v - mean;
238 d * d
239 })
240 .sum::<f32>()
241 / n;
242 let inv_std = 1.0 / (var + eps).sqrt();
243
244 for ci in c_start..c_end {
245 let scale = inv_std * w[ci];
246 let shift = b[ci];
247 for ti in 0..t {
248 out[[ci, ti]] = (x[[ci, ti]] - mean) * scale + shift;
249 }
250 }
251 }
252 out
253}
254
255fn layer_norm(x: ArrayView2<f32>, w: ArrayView1<f32>, b: ArrayView1<f32>, eps: f32) -> Array2<f32> {
260 let (t, c) = (x.shape()[0], x.shape()[1]);
261 let c_f = c as f32;
262 let mut out = Array2::<f32>::zeros((t, c));
263 for ti in 0..t {
264 let row = x.slice(s![ti, ..]);
265 let mean = row.iter().sum::<f32>() / c_f;
266 let var = row
267 .iter()
268 .map(|&v| {
269 let d = v - mean;
270 d * d
271 })
272 .sum::<f32>()
273 / c_f;
274 let inv_std = 1.0 / (var + eps).sqrt();
275 for ci in 0..c {
276 out[[ti, ci]] = (x[[ti, ci]] - mean) * inv_std * w[ci] + b[ci];
277 }
278 }
279 out
280}
281
282fn rms_norm(x: ArrayView2<f32>, w: ArrayView1<f32>, eps: f32) -> Array2<f32> {
287 let (t, c) = (x.shape()[0], x.shape()[1]);
288 let c_f = c as f32;
289 let mut out = Array2::<f32>::zeros((t, c));
290 for ti in 0..t {
291 let row = x.slice(s![ti, ..]);
292 let ms = row.iter().map(|&v| v * v).sum::<f32>() / c_f;
293 let scale = 1.0 / (ms + eps).sqrt();
294 for ci in 0..c {
295 out[[ti, ci]] = x[[ti, ci]] * scale * w[ci];
296 }
297 }
298 out
299}
300
/// SiLU / swish activation: `x · sigmoid(x) = x / (1 + e^(-x))`.
#[inline(always)]
fn silu(x: f32) -> f32 {
    let denom = 1.0 + (-x).exp();
    x / denom
}
306
/// Numerically stable softmax, computed in place.
///
/// The slice maximum is subtracted before exponentiating so large logits cannot
/// overflow. An empty slice is left untouched.
fn softmax_inplace(x: &mut [f32]) {
    let peak = x.iter().copied().fold(f32::NEG_INFINITY, f32::max);

    let mut total = 0.0f32;
    for v in x.iter_mut() {
        *v = (*v - peak).exp();
        total += *v;
    }

    for v in x.iter_mut() {
        *v /= total;
    }
}
317
318fn fsq_decode(
329 codes: &[i32],
330 proj_w: ArrayView2<f32>, proj_b: ArrayView1<f32>, ) -> Array2<f32> {
333 let t = codes.len();
334 let _out_dim = proj_w.shape()[0];
335
336 let mut digits = Array2::<f32>::zeros((t, FSQ_BASIS.len()));
338 for (i, &code) in codes.iter().enumerate() {
339 for (j, (&basis, &levels)) in FSQ_BASIS.iter().zip(FSQ_LEVELS.iter()).enumerate() {
340 let d = (code / basis) % levels;
341 digits[[i, j]] = d as f32 / 1.5 - 1.0;
344 }
345 }
346
347 linear(digits.view(), proj_w, Some(proj_b))
349}
350
/// Fast `(sin x, cos x)` for rotary position embeddings.
///
/// The angle is reduced to `[-π, π]` and then folded into `[-π/2, π/2]` using
/// `sin(π - y) = sin y` and `cos(π - y) = -cos y`. On that interval, truncated
/// Taylor series (through `x^11` for sine and `x^12` for cosine) are accurate to
/// roughly 1e-6. A low-order series evaluated directly near ±π is off by ~0.2
/// for the cosine, which is why the fold is needed.
///
/// Enable the `precise` feature to use `f32::sin_cos` instead.
#[cfg(not(feature = "precise"))]
#[inline(always)]
pub(crate) fn rope_sin_cos(x: f32) -> (f32, f32) {
    use std::f32::consts::{FRAC_PI_2, PI, TAU};

    // Range-reduce to [-π, π].
    let x = x - TAU * (x * (1.0 / TAU)).round();

    // Fold into [-π/2, π/2], remembering whether the cosine changes sign.
    let (y, cos_sign): (f32, f32) = if x > FRAC_PI_2 {
        (PI - x, -1.0)
    } else if x < -FRAC_PI_2 {
        (-PI - x, -1.0)
    } else {
        (x, 1.0)
    };

    let y2 = y * y;
    // sin y = y - y³/3! + y⁵/5! - y⁷/7! + y⁹/9! - y¹¹/11!  (Horner form)
    let s = y
        * (1.0
            + y2 * (-1.0 / 6.0
                + y2 * (1.0 / 120.0
                    + y2 * (-1.0 / 5_040.0
                        + y2 * (1.0 / 362_880.0 - y2 * (1.0 / 39_916_800.0))))));
    // cos y = 1 - y²/2! + y⁴/4! - y⁶/6! + y⁸/8! - y¹⁰/10! + y¹²/12!
    let c = 1.0
        + y2 * (-0.5
            + y2 * (1.0 / 24.0
                + y2 * (-1.0 / 720.0
                    + y2 * (1.0 / 40_320.0
                        + y2 * (-1.0 / 3_628_800.0 + y2 * (1.0 / 479_001_600.0))))));
    (s, cos_sign * c)
}

/// Exact `(sin x, cos x)` via the standard library (`precise` feature).
#[cfg(feature = "precise")]
#[inline(always)]
pub(crate) fn rope_sin_cos(x: f32) -> (f32, f32) {
    x.sin_cos()
}
399
400fn apply_rope(x: &mut Array3<f32>) {
411 let (t, n_heads, head_dim) = (x.shape()[0], x.shape()[1], x.shape()[2]);
412 let half = head_dim / 2;
413
414 let inv_freqs: Vec<f32> = (0..half)
416 .map(|i| 1.0_f32 / 10_000_f32.powf(2.0 * i as f32 / head_dim as f32))
417 .collect();
418
419 for p in 0..t {
420 let p_f = p as f32;
421 for i in 0..half {
422 let (s, c) = rope_sin_cos(p_f * inv_freqs[i]);
425 for h in 0..n_heads {
427 let x1 = x[[p, h, i]];
428 let x2 = x[[p, h, i + half]];
429 x[[p, h, i]] = x1 * c - x2 * s;
430 x[[p, h, i + half]] = x1 * s + x2 * c;
431 }
432 }
433 }
434}
435
/// Weights of one pre-norm transformer block (see `transformer_block`).
///
/// `d` is the model width. None of the projections has a bias.
pub(crate) struct TransformerWeights {
    /// RMSNorm gain applied before attention, `(d,)`.
    pub(crate) att_norm_w: Array1<f32>,
    /// Fused query/key/value projection, `(3d, d)`; output columns are `q | k | v`.
    pub(crate) c_attn_w: Array2<f32>,
    /// Attention output projection, `(d, d)`.
    pub(crate) c_proj_w: Array2<f32>,
    /// RMSNorm gain applied before the MLP, `(d,)`.
    pub(crate) ffn_norm_w: Array1<f32>,
    /// MLP up-projection, `(4d, d)`.
    pub(crate) fc1_w: Array2<f32>,
    /// MLP down-projection, `(d, 4d)`.
    pub(crate) fc2_w: Array2<f32>,
}
446
447fn transformer_block(x: ArrayView2<f32>, w: &TransformerWeights, n_heads: usize) -> Array2<f32> {
451 let (t, d) = (x.shape()[0], x.shape()[1]);
452 let head_dim = d / n_heads;
453
454 let normed = rms_norm(x, w.att_norm_w.view(), 1e-6);
456 let qkv = linear(normed.view(), w.c_attn_w.view(), None);
458
459 let q_flat = qkv.slice(s![.., 0..d]).to_owned();
461 let k_flat = qkv.slice(s![.., d..2 * d]).to_owned();
462 let v_flat = qkv.slice(s![.., 2 * d..]).to_owned();
463
464 let mut q = q_flat
466 .into_shape_with_order((t, n_heads, head_dim))
467 .expect("q reshape");
468 let mut k = k_flat
469 .into_shape_with_order((t, n_heads, head_dim))
470 .expect("k reshape");
471 let v = v_flat
472 .into_shape_with_order((t, n_heads, head_dim))
473 .expect("v reshape");
474
475 apply_rope(&mut q);
476 apply_rope(&mut k);
477
478 let scale = (head_dim as f32).sqrt().recip();
480 let mut attn_out = Array3::<f32>::zeros((t, n_heads, head_dim));
482
483 for h in 0..n_heads {
484 let qh = q.slice(s![.., h, ..]).to_owned(); let kh = k.slice(s![.., h, ..]).to_owned();
486 let vh = v.slice(s![.., h, ..]).to_owned();
487
488 let mut scores = qh.dot(&kh.t());
490 scores.mapv_inplace(|v| v * scale);
491
492 for ti in 0..t {
494 softmax_inplace(scores.slice_mut(s![ti, ..]).as_slice_mut().unwrap());
495 }
496
497 let wv = scores.dot(&vh);
499 attn_out.slice_mut(s![.., h, ..]).assign(&wv);
500 }
501
502 let attn_flat = attn_out
504 .into_shape_with_order((t, d))
505 .expect("attn out reshape");
506
507 let attn_proj = linear(attn_flat.view(), w.c_proj_w.view(), None);
509
510 let x_attn = &x + &attn_proj;
512
513 let normed2 = rms_norm(x_attn.view(), w.ffn_norm_w.view(), 1e-6);
515 let h1 = linear(normed2.view(), w.fc1_w.view(), None);
516 let h1_act = h1.mapv(silu);
517 let h2 = linear(h1_act.view(), w.fc2_w.view(), None);
518
519 &x_attn + &h2
520}
521
/// Weights of one residual conv block (see `resnet_block`).
///
/// `c` is the channel count. Kernels are loaded as `(c, c, 3)`.
pub(crate) struct ResnetBlockWeights {
    /// First GroupNorm gain and bias, each `(c,)`.
    pub(crate) norm1_w: Array1<f32>,
    pub(crate) norm1_b: Array1<f32>,
    /// First convolution kernel `(c, c, 3)` and bias `(c,)`.
    pub(crate) conv1_w: Array3<f32>,
    pub(crate) conv1_b: Array1<f32>,
    /// Second GroupNorm gain and bias, each `(c,)`.
    pub(crate) norm2_w: Array1<f32>,
    pub(crate) norm2_b: Array1<f32>,
    /// Second convolution kernel `(c, c, 3)` and bias `(c,)`.
    pub(crate) conv2_w: Array3<f32>,
    pub(crate) conv2_b: Array1<f32>,
}
534
535fn resnet_block(x: ArrayView2<f32>, w: &ResnetBlockWeights) -> Array2<f32> {
539 let h = group_norm(x, 32, w.norm1_w.view(), w.norm1_b.view(), 1e-6);
541 let h = h.mapv(silu);
542 let h = conv1d(h.view(), w.conv1_w.view(), Some(w.conv1_b.view()), 1);
543
544 let h = group_norm(h.view(), 32, w.norm2_w.view(), w.norm2_b.view(), 1e-6);
546 let h = h.mapv(silu);
547 let h = conv1d(h.view(), w.conv2_w.view(), Some(w.conv2_b.view()), 1);
548
549 &x + &h
551}
552
553pub(crate) fn istft_burn(
581 mag: ArrayView2<f32>,
582 phase: ArrayView2<f32>,
583 hop: usize,
584 window: &[f32],
585 ifft: &dyn rustfft::Fft<f32>,
586) -> Vec<f32> {
587 let n_bins = mag.shape()[0]; let n_frames = mag.shape()[1];
589 let n_fft = (n_bins - 1) * 2;
590 debug_assert_eq!(n_fft, window.len());
591 debug_assert_eq!(hop, n_fft / 4);
592
593 let out_size = (n_frames - 1) * hop + n_fft;
595 let mut y = vec![0.0f32; out_size];
596 let mut env = vec![0.0f32; out_size];
597
598 let mut buf = vec![Complex::<f32>::default(); n_fft];
599
600 for ti in 0..n_frames {
601 for fi in 0..n_bins {
608 let m = mag[[fi, ti]].exp().min(1e2); let p = phase[[fi, ti]];
610 buf[fi] = Complex::new(m * p.cos(), m * p.sin());
611 }
612 for fi in 1..n_bins - 1 {
614 buf[n_fft - fi] = buf[fi].conj();
615 }
616
617 ifft.process(&mut buf);
619
620 let norm = n_fft as f32;
622 let offset = ti * hop;
623 for i in 0..n_fft {
624 let sample = buf[i].re / norm * window[i];
625 y[offset + i] += sample;
626 env[offset + i] += window[i] * window[i];
627 }
628 }
629
630 for i in 0..out_size {
632 if env[i] > 1e-11 {
633 y[i] /= env[i];
634 }
635 }
636
637 let start = n_fft / 2;
645 let length = n_frames * hop;
646 y[start..start + length].to_vec()
647}
648
/// Periodic Hann window of length `n`: `0.5 · (1 − cos(2πi / n))`.
///
/// The periodic form (denominator `n`, not `n − 1`) starts at 0 and peaks at
/// index `n / 2`.
fn hann_window(n: usize) -> Vec<f32> {
    let mut window = Vec::with_capacity(n);
    for i in 0..n {
        let angle = 2.0 * std::f32::consts::PI * i as f32 / n as f32;
        window.push(0.5 * (1.0 - angle.cos()));
    }
    window
}
655
/// All weights and derived geometry needed by `decode_forward`.
///
/// Built by `load_decoder_weights`. Every array is an owned copy, independent
/// of the safetensors buffer it was read from.
pub(crate) struct DecoderWeights {
    /// FSQ output projection `(fsq_out_dim, fsq_in_dim)` and bias `(fsq_out_dim,)`.
    pub(crate) fsq_proj_w: Array2<f32>,
    pub(crate) fsq_proj_b: Array1<f32>,
    /// Projection from the FSQ embedding to the model width:
    /// `(hidden_dim, fsq_out_dim)` and `(hidden_dim,)`.
    pub(crate) fc_post_a_w: Array2<f32>,
    pub(crate) fc_post_a_b: Array1<f32>,
    /// Input convolution `(hidden_dim, hidden_dim, embed_k)` and bias `(hidden_dim,)`.
    pub(crate) embed_w: Array3<f32>,
    pub(crate) embed_b: Array1<f32>,
    /// Residual conv blocks applied before the transformer stack (two are loaded).
    pub(crate) prior_net: Vec<ResnetBlockWeights>,

    /// Transformer blocks, `depth` of them, applied in order.
    pub(crate) transformers: Vec<TransformerWeights>,

    /// Final LayerNorm gain, `(hidden_dim,)`.
    pub(crate) final_norm_w: Array1<f32>,
    /// Final LayerNorm bias, `(hidden_dim,)`.
    pub(crate) final_norm_b: Array1<f32>,

    /// Residual conv blocks applied after the transformer stack (two are loaded).
    pub(crate) post_net: Vec<ResnetBlockWeights>,

    /// Output head `(out_dim, hidden_dim)` and bias `(out_dim,)`. Its rows hold
    /// `n_fft / 2 + 1` log-magnitudes followed by the phases.
    pub(crate) head_w: Array2<f32>,
    pub(crate) head_b: Array1<f32>,
    /// iSTFT synthesis window: taken from the file when present, otherwise a
    /// periodic Hann window of length `n_fft`.
    pub(crate) window: Vec<f32>,
    /// Model width, taken from the embed convolution's first dimension.
    pub(crate) hidden_dim: usize,
    /// iSTFT hop in samples, `(out_dim - 2) / 4`; `n_fft = 4 * hop_length`.
    pub(crate) hop_length: usize,
    /// Number of transformer blocks found in the file.
    pub(crate) depth: usize,
    /// Attention heads: read from the `n_heads` metadata key, defaulting to 16.
    pub(crate) n_heads: usize,

    /// Pre-planned inverse FFT of size `n_fft`, shared across decode calls.
    pub(crate) ifft_plan: std::sync::Arc<dyn rustfft::Fft<f32>>,
}
701
702fn load_resnet_block(st: &SafeTensors<'_>, prefix: &str, c: usize) -> Result<ResnetBlockWeights> {
703 Ok(ResnetBlockWeights {
704 norm1_w: as1d(load_f32(st, &format!("{prefix}.norm1.weight"))?, c),
705 norm1_b: as1d(load_f32(st, &format!("{prefix}.norm1.bias"))?, c),
706 conv1_w: as3d(load_f32(st, &format!("{prefix}.conv1.weight"))?, c, c, 3),
707 conv1_b: as1d(load_f32(st, &format!("{prefix}.conv1.bias"))?, c),
708 norm2_w: as1d(load_f32(st, &format!("{prefix}.norm2.weight"))?, c),
709 norm2_b: as1d(load_f32(st, &format!("{prefix}.norm2.bias"))?, c),
710 conv2_w: as3d(load_f32(st, &format!("{prefix}.conv2.weight"))?, c, c, 3),
711 conv2_b: as1d(load_f32(st, &format!("{prefix}.conv2.bias"))?, c),
712 })
713}
714
715fn load_transformer(st: &SafeTensors<'_>, prefix: &str, d: usize) -> Result<TransformerWeights> {
716 Ok(TransformerWeights {
717 att_norm_w: as1d(load_f32(st, &format!("{prefix}.att_norm.weight"))?, d),
718 c_attn_w: as2d(
719 load_f32(st, &format!("{prefix}.att.c_attn.weight"))?,
720 3 * d,
721 d,
722 ),
723 c_proj_w: as2d(load_f32(st, &format!("{prefix}.att.c_proj.weight"))?, d, d),
724 ffn_norm_w: as1d(load_f32(st, &format!("{prefix}.ffn_norm.weight"))?, d),
725 fc1_w: as2d(load_f32(st, &format!("{prefix}.mlp.fc1.weight"))?, 4 * d, d),
726 fc2_w: as2d(load_f32(st, &format!("{prefix}.mlp.fc2.weight"))?, d, 4 * d),
727 })
728}
729
730fn load_decoder_weights(
731 st: &SafeTensors<'_>,
732 user_meta: &Option<std::collections::HashMap<String, String>>,
733) -> Result<DecoderWeights> {
734 let embed_shape = shape_of(st, "generator.backbone.embed.weight")?;
736 let hidden_dim = embed_shape[0]; let head_shape = shape_of(st, "generator.head.out.weight")?;
739 let out_dim = head_shape[0]; let hop_length = (out_dim - 2) / 4;
741
742 let depth = (0..64)
744 .take_while(|&i| {
745 st.tensor(&format!(
746 "generator.backbone.transformers.{i}.att_norm.weight"
747 ))
748 .is_ok()
749 })
750 .count();
751
752 if depth == 0 {
753 bail!("No transformer blocks found — is the safetensors file correct?");
754 }
755
756 let n_heads: usize = user_meta
758 .as_ref()
759 .and_then(|m| m.get("n_heads"))
760 .and_then(|s| s.parse().ok())
761 .unwrap_or(16);
762
763 let fsq_proj_key = if st
766 .tensor("generator.quantizer.fsqs.0.project_out.weight")
767 .is_ok()
768 {
769 "generator.quantizer.fsqs.0.project_out.weight"
770 } else {
771 "generator.quantizer.project_out.weight"
772 };
773 let fsq_bias_key = if st
774 .tensor("generator.quantizer.fsqs.0.project_out.bias")
775 .is_ok()
776 {
777 "generator.quantizer.fsqs.0.project_out.bias"
778 } else {
779 "generator.quantizer.project_out.bias"
780 };
781
782 let fsq_shape = shape_of(st, fsq_proj_key)?;
783 let fsq_out_dim = fsq_shape[0]; let fsq_in_dim = fsq_shape[1]; let fsq_proj_w = as2d(load_f32(st, fsq_proj_key)?, fsq_out_dim, fsq_in_dim);
787 let fsq_proj_b = as1d(load_f32(st, fsq_bias_key)?, fsq_out_dim);
788
789 let fc_post_a_w = as2d(load_f32(st, "fc_post_a.weight")?, hidden_dim, fsq_out_dim);
791 let fc_post_a_b = as1d(load_f32(st, "fc_post_a.bias")?, hidden_dim);
792
793 let embed_k = embed_shape[2];
795 let embed_w = as3d(
796 load_f32(st, "generator.backbone.embed.weight")?,
797 hidden_dim,
798 hidden_dim,
799 embed_k,
800 );
801 let embed_b = as1d(load_f32(st, "generator.backbone.embed.bias")?, hidden_dim);
802
803 let prior_net = (0..2)
805 .map(|i| load_resnet_block(st, &format!("generator.backbone.prior_net.{i}"), hidden_dim))
806 .collect::<Result<Vec<_>>>()?;
807
808 let transformers = (0..depth)
810 .map(|i| {
811 load_transformer(
812 st,
813 &format!("generator.backbone.transformers.{i}"),
814 hidden_dim,
815 )
816 })
817 .collect::<Result<Vec<_>>>()?;
818
819 let final_norm_w = as1d(
821 load_f32(st, "generator.backbone.final_layer_norm.weight")?,
822 hidden_dim,
823 );
824 let final_norm_b = as1d(
825 load_f32(st, "generator.backbone.final_layer_norm.bias")?,
826 hidden_dim,
827 );
828
829 let post_net = (0..2)
831 .map(|i| load_resnet_block(st, &format!("generator.backbone.post_net.{i}"), hidden_dim))
832 .collect::<Result<Vec<_>>>()?;
833
834 let n_fft = hop_length * 4;
836 let head_w = as2d(
837 load_f32(st, "generator.head.out.weight")?,
838 out_dim,
839 hidden_dim,
840 );
841 let head_b = as1d(load_f32(st, "generator.head.out.bias")?, out_dim);
842
843 let window = if st.tensor("generator.head.istft.window").is_ok() {
845 load_f32(st, "generator.head.istft.window")?
846 } else {
847 hann_window(n_fft)
848 };
849
850 let ifft_plan = {
854 let mut planner = FftPlanner::<f32>::new();
855 planner.plan_fft_inverse(n_fft)
856 };
857
858 Ok(DecoderWeights {
859 fsq_proj_w,
860 fsq_proj_b,
861 fc_post_a_w,
862 fc_post_a_b,
863 embed_w,
864 embed_b,
865 prior_net,
866 transformers,
867 final_norm_w,
868 final_norm_b,
869 post_net,
870 head_w,
871 head_b,
872 window,
873 hidden_dim,
874 hop_length,
875 depth,
876 n_heads,
877 ifft_plan,
878 })
879}
880
881pub(crate) fn decode_forward(codes: &[i32], w: &DecoderWeights) -> Vec<f32> {
884 let hop = w.hop_length;
885 let n_fft = hop * 4;
886 let embed_k = w.embed_w.shape()[2];
887 let embed_pad = embed_k / 2;
888
889 let emb = fsq_decode(codes, w.fsq_proj_w.view(), w.fsq_proj_b.view());
891
892 let x = linear(emb.view(), w.fc_post_a_w.view(), Some(w.fc_post_a_b.view()));
894
895 let x_ct = x.t().to_owned(); let x_ct = conv1d(
898 x_ct.view(),
899 w.embed_w.view(),
900 Some(w.embed_b.view()),
901 embed_pad,
902 );
903
904 let x_ct = w
906 .prior_net
907 .iter()
908 .fold(x_ct, |acc, rw| resnet_block(acc.view(), rw));
909
910 let x_tc = x_ct.t().to_owned(); let x_tc = w
913 .transformers
914 .iter()
915 .fold(x_tc, |acc, tw| transformer_block(acc.view(), tw, w.n_heads));
916
917 let x_ct = x_tc.t().to_owned(); let x_ct = w
920 .post_net
921 .iter()
922 .fold(x_ct, |acc, rw| resnet_block(acc.view(), rw));
923
924 let x_tc = x_ct.t().to_owned(); let x_tc = layer_norm(
927 x_tc.view(),
928 w.final_norm_w.view(),
929 w.final_norm_b.view(),
930 1e-6,
931 );
932
933 let x_pred = linear(x_tc.view(), w.head_w.view(), Some(w.head_b.view()));
935
936 let x_pred_ct = x_pred.t().to_owned(); let half = (n_fft / 2) + 1; let mag = x_pred_ct.slice(s![0..half, ..]).to_owned();
940 let phase = x_pred_ct.slice(s![half.., ..]).to_owned();
941
942 istft_burn(
944 mag.view(),
945 phase.view(),
946 hop,
947 &w.window,
948 w.ifft_plan.as_ref(),
949 )
950}
951
/// NeuCodec speech-token decoder: turns FSQ speech codes into audio at
/// [`SAMPLE_RATE`].
///
/// Construct with [`NeuCodecDecoder::new`] (default weights location) or
/// [`NeuCodecDecoder::from_file`], then call [`NeuCodecDecoder::decode`].
pub struct NeuCodecDecoder {
    /// Decoder weights, held as owned copies; the weights file is not kept mapped.
    weights: DecoderWeights,
    /// Path the weights were loaded from.
    path: PathBuf,

    /// Optional accelerated backend created once at load time. `decode` and
    /// `backend_name` prefer it when it holds a backend.
    #[cfg(feature = "burn")]
    burn_decoder: std::sync::Mutex<LazyBurnDecoder>,
}

/// State of the optional Burn backend.
///
/// `Ready(None)` is stored when `make_burn_decoder` returned no backend; the
/// load message then reports "cpu (ndarray)".
/// NOTE(review): only a `Ready` variant exists despite the "Lazy" name —
/// confirm whether deferred initialisation is still planned.
#[cfg(feature = "burn")]
enum LazyBurnDecoder {
    Ready(Option<Box<dyn super::burn::BurnDecoder + Send>>),
}
1000
1001impl NeuCodecDecoder {
1002 pub fn new() -> Result<Self> {
1004 let path = super::decoder_weights_path()?;
1005 Self::from_file(&path)
1006 }
1007
1008 pub fn from_file(path: &Path) -> Result<Self> {
1010 if !path.exists() {
1011 bail!(
1012 "NeuCodec decoder weights not found: {}\n\
1013 Set NEUTTS_DECODER_PATH or pass an explicit path to NeuCodecDecoder::from_file().",
1014 path.display()
1015 );
1016 }
1017
1018 let file = std::fs::File::open(path)
1022 .with_context(|| format!("Failed to open {}", path.display()))?;
1023 let mmap = unsafe {
1027 memmap2::Mmap::map(&file)
1028 .with_context(|| format!("Failed to mmap {}", path.display()))?
1029 };
1030 let bytes: &[u8] = &mmap;
1031
1032 let (_, file_meta) = SafeTensors::read_metadata(bytes)
1034 .with_context(|| format!("Failed to parse safetensors header: {}", path.display()))?;
1035 let user_meta = file_meta.metadata().clone();
1036
1037 let st = SafeTensors::deserialize(bytes)
1038 .with_context(|| format!("Failed to parse safetensors: {}", path.display()))?;
1039
1040 let weights = load_decoder_weights(&st, &user_meta)
1041 .with_context(|| format!("Failed to load decoder weights from {}", path.display()))?;
1042
1043 drop(st);
1046 drop(mmap);
1047
1048 println!(
1049 "NeuCodec decoder: hidden={}, depth={}, heads={}, hop={} ({} samples/token = {} tokens/s)",
1050 weights.hidden_dim,
1051 weights.depth,
1052 weights.n_heads,
1053 weights.hop_length,
1054 weights.hop_length,
1055 SAMPLE_RATE as usize / weights.hop_length,
1056 );
1057
1058 #[cfg(feature = "burn")]
1066 let burn_decoder = {
1067 let t0 = std::time::Instant::now();
1068 let dec = super::burn::make_burn_decoder(&weights);
1069 println!(
1070 "NeuCodec: {} backend ready in {:.2} s",
1071 dec.as_ref().map_or("cpu (ndarray)", |b| b.backend_name()),
1072 t0.elapsed().as_secs_f32(),
1073 );
1074 std::sync::Mutex::new(LazyBurnDecoder::Ready(dec))
1075 };
1076
1077 Ok(Self {
1078 weights,
1079 path: path.to_path_buf(),
1080 #[cfg(feature = "burn")]
1081 burn_decoder,
1082 })
1083 }
1084
1085 pub fn decode(&self, codes: &[i32]) -> Result<Vec<f32>> {
1092 if codes.is_empty() {
1093 return Ok(Vec::new());
1094 }
1095
1096 for (i, &code) in codes.iter().enumerate() {
1099 if !(0..=65535).contains(&code) {
1100 anyhow::bail!(
1101 "Speech token at index {i} is out of range: {code} \
1102 (NeuCodec FSQ codes must be in 0..=65535)"
1103 );
1104 }
1105 }
1106
1107 #[cfg(feature = "burn")]
1112 {
1113 let state = self.burn_decoder.lock().unwrap();
1114 if let LazyBurnDecoder::Ready(Some(ref bd)) = *state {
1115 return bd.decode(codes);
1116 }
1117 }
1118
1119 #[cfg(feature = "rlx")]
1121 {
1122 super::rlx::decode(codes, &self.weights)
1123 }
1124 #[cfg(not(feature = "rlx"))]
1125 {
1126 Ok(decode_forward(codes, &self.weights))
1127 }
1128 }
1129
1130 pub fn backend_name(&self) -> &'static str {
1132 #[cfg(feature = "burn")]
1133 {
1134 let state = self.burn_decoder.lock().unwrap();
1135 if let LazyBurnDecoder::Ready(Some(bd)) = &*state {
1136 return bd.backend_name();
1137 }
1138 }
1139 if cfg!(feature = "rlx") {
1140 "rlx/eager-parity"
1141 } else {
1142 "codec/eager-ndarray"
1143 }
1144 }
1145
1146 pub fn load(path: &Path) -> Result<Self> {
1148 Self::from_file(path)
1149 }
1150
1151 pub fn weights_path(&self) -> &Path {
1153 &self.path
1154 }
1155
1156 pub fn hop_length(&self) -> usize {
1158 self.weights.hop_length
1159 }
1160}
1161
1162pub struct NeuCodecEncoder;
1179
1180impl NeuCodecEncoder {
1181 pub fn new() -> Result<Self> {
1183 bail!(
1184 "The NeuCodec encoder is not yet implemented in the pure-Rust build.\n\
1185 \n\
1186 To encode reference audio, use the Python neucodec package:\n\
1187 \n\
1188 \tpip install neucodec huggingface_hub\n\
1189 \tpython scripts/encode_reference.py --audio reference.wav --out ref.npy\n\
1190 \n\
1191 Then pass the .npy file via --ref-codes to the synthesis examples."
1192 )
1193 }
1194
1195 pub fn load(_path: &Path) -> Result<Self> {
1197 Self::new()
1198 }
1199
1200 pub fn encode_wav(&self, _path: &Path) -> Result<Vec<i32>> {
1202 bail!("Encoder not implemented — see NeuCodecEncoder docs")
1203 }
1204
1205 pub fn backend_name(&self) -> &str {
1207 "not available"
1208 }
1209}
1210
/// Linearly resample mono `samples` from `from_hz` to `to_hz`.
///
/// Output sample `i` is linearly interpolated at source position
/// `i * from_hz / to_hz`. The output holds `ceil(len * to_hz / from_hz)`
/// samples. No anti-aliasing filter is applied, so downsampling can alias.
/// Equal rates return a copy, and empty input returns an empty vector.
///
/// # Panics
/// Panics if the rates differ and either is zero.
// NOTE(review): the lint is allowed presumably because no caller exists in some
// builds (e.g. until the encoder lands) — confirm it is still needed.
#[allow(dead_code)]
pub fn resample(samples: &[f32], from_hz: u32, to_hz: u32) -> Vec<f32> {
    if from_hz == to_hz {
        return samples.to_vec();
    }
    if samples.is_empty() {
        return Vec::new();
    }
    assert!(
        from_hz > 0 && to_hz > 0,
        "resample: sample rates must be non-zero (got {from_hz} Hz -> {to_hz} Hz)"
    );

    let last = samples.len() - 1;
    let ratio = from_hz as f64 / to_hz as f64;
    let out_len = (samples.len() as f64 / ratio).ceil() as usize;

    (0..out_len)
        .map(|i| {
            let src = i as f64 * ratio;
            // Clamp: rounding in `out_len` / `src` can place the final position
            // at (or just past) the last input sample.
            let lo = (src.floor() as usize).min(last);
            let hi = (lo + 1).min(last);
            let frac = (src - lo as f64).clamp(0.0, 1.0) as f32;
            samples[lo] * (1.0 - frac) + samples[hi] * frac
        })
        .collect()
}
1231
#[cfg(test)]
mod tests {
    use super::*;

    /// `fsq_decode` returns one row per code and `out_dim` columns.
    #[test]
    fn test_fsq_decode_shape() {
        // Projection with out_dim = 4 over the 8 FSQ dimensions.
        let w = Array2::ones((4, 8));
        let b = Array1::zeros(4);
        let codes = vec![0i32, 1, 2, 65535];
        let out = fsq_decode(&codes, w.view(), b.view());
        assert_eq!(out.shape(), &[4, 4]);
    }

    /// Code 0 has every digit 0, which maps to -1.0; the identity projection
    /// exposes the digits directly.
    #[test]
    fn test_fsq_code_0() {
        let w = Array2::eye(8);
        let b = Array1::zeros(8);
        let out = fsq_decode(&[0], w.view(), b.view());
        for v in out.iter() {
            assert!((*v + 1.0).abs() < 1e-5, "expected -1.0, got {v}");
        }
    }

    /// Code 65535 (= 4^8 - 1) has every base-4 digit equal to 3, which maps to +1.0.
    #[test]
    fn test_fsq_code_max() {
        let w = Array2::eye(8);
        let b = Array1::zeros(8);
        let out = fsq_decode(&[65535], w.view(), b.view());
        for v in out.iter() {
            assert!((*v - 1.0).abs() < 1e-5, "expected 1.0, got {v}");
        }
    }

    /// `linear` maps (rows, in) x (out, in) to (rows, out).
    #[test]
    fn test_linear_shape() {
        let x = Array2::ones((5, 3));
        let w = Array2::ones((7, 3));
        let b = Array1::zeros(7);
        let out = linear(x.view(), w.view(), Some(b.view()));
        assert_eq!(out.shape(), &[5, 7]);
    }

    /// With k = 3 and pad = 1 the convolution preserves the time length.
    #[test]
    fn test_conv1d_same_length() {
        let c_in = 4;
        let c_out = 8;
        let t = 16;
        let k = 3;
        let x = Array2::ones((c_in, t));
        let w = Array3::ones((c_out, c_in, k));
        let b = Array1::zeros(c_out);
        let out = conv1d(x.view(), w.view(), Some(b.view()), 1);
        assert_eq!(out.shape(), &[c_out, t]);
    }

    /// A constant input has zero deviation from its group mean, so with unit
    /// gain and zero bias every output is ~0.
    #[test]
    fn test_group_norm_shape() {
        let c = 64;
        let t = 10;
        let x = Array2::ones((c, t));
        let w = Array1::ones(c);
        let b = Array1::zeros(c);
        let out = group_norm(x.view(), 4, w.view(), b.view(), 1e-6);
        assert_eq!(out.shape(), &[c, t]);
        for &v in out.iter() {
            assert!(
                v.abs() < 1e-4,
                "expected ~0 after group_norm of all-ones, got {v}"
            );
        }
    }

    /// Constant rows normalise to ~0 under LayerNorm with unit gain and zero bias.
    #[test]
    fn test_layer_norm_shape() {
        let t = 5;
        let c = 32;
        let x = Array2::from_elem((t, c), 2.0f32);
        let w = Array1::ones(c);
        let b = Array1::zeros(c);
        let out = layer_norm(x.view(), w.view(), b.view(), 1e-6);
        assert_eq!(out.shape(), &[t, c]);
        for &v in out.iter() {
            assert!(v.abs() < 1e-4, "expected ~0, got {v}");
        }
    }

    /// All-ones rows have RMS 1, so RMSNorm with unit gain leaves them at ~1.
    #[test]
    fn test_rms_norm_shape() {
        let t = 3;
        let c = 8;
        let x = Array2::ones((t, c));
        let w = Array1::ones(c);
        let out = rms_norm(x.view(), w.view(), 1e-6);
        assert_eq!(out.shape(), &[t, c]);
        for &v in out.iter() {
            assert!((v - 1.0).abs() < 1e-4, "expected 1.0, got {v}");
        }
    }

    /// RoPE rotates values in place and never changes the tensor shape.
    #[test]
    fn test_rope_shape_preserved() {
        let t = 4;
        let n_heads = 2;
        let head_dim = 8;
        let mut x = Array3::ones((t, n_heads, head_dim));
        apply_rope(&mut x);
        assert_eq!(x.shape(), &[t, n_heads, head_dim]);
    }

    /// The periodic Hann window starts at 0 and peaks (1.0) at index n/2.
    #[test]
    fn test_hann_window() {
        let w = hann_window(4);
        assert_eq!(w.len(), 4);
        assert!(w[0].abs() < 1e-6);
        assert!((w[2] - 1.0).abs() < 1e-6);
    }

    /// Build an inverse FFT plan of size `n_fft` for the iSTFT tests.
    fn make_ifft(n_fft: usize) -> std::sync::Arc<dyn rustfft::Fft<f32>> {
        FftPlanner::<f32>::new().plan_fft_inverse(n_fft)
    }

    /// The iSTFT returns exactly `n_frames * hop` samples.
    #[test]
    fn test_istft_length() {
        let hop = 4;
        let n_fft = 16;
        let t = 10;
        let n_bins = n_fft / 2 + 1;
        let mag = Array2::zeros((n_bins, t));
        let phase = Array2::zeros((n_bins, t));
        let win = hann_window(n_fft);
        let ifft = make_ifft(n_fft);
        let audio = istft_burn(mag.view(), phase.view(), hop, &win, ifft.as_ref());
        assert_eq!(
            audio.len(),
            t * hop,
            "expected {} samples, got {}",
            t * hop,
            audio.len()
        );
    }

    /// Very large log-magnitudes are clamped (exp(·) capped at 100), so the
    /// output stays finite and bounded.
    #[test]
    fn test_istft_clamp_does_not_blow_up() {
        let hop = 4;
        let n_fft = 16;
        let t = 4;
        let n_bins = n_fft / 2 + 1;
        let mag = Array2::from_elem((n_bins, t), 50.0f32);
        let phase = Array2::zeros((n_bins, t));
        let win = hann_window(n_fft);
        let ifft = make_ifft(n_fft);
        let audio = istft_burn(mag.view(), phase.view(), hop, &win, ifft.as_ref());
        for &s in &audio {
            assert!(s.is_finite(), "sample is not finite: {s}");
            assert!(s.abs() < 1e6, "sample magnitude suspiciously large: {s}");
        }
    }

    /// Smoke test: the feature-query helper is callable in every build.
    #[test]
    fn test_burn_feature_fn() {
        let _ = crate::features::burn_feature_enabled();
    }

    /// Resampling to the same rate returns the input unchanged.
    #[test]
    fn test_resample_identity() {
        let s: Vec<f32> = (0..100).map(|i| i as f32).collect();
        let r = resample(&s, 16_000, 16_000);
        assert_eq!(r, s);
    }

    /// Parity check: `NeuCodecDecoder::decode` (whichever backend is active)
    /// must match the eager `decode_forward` within 1e-3 per sample. Skipped
    /// unless decoder weights are available.
    #[test]
    fn decode_output_matches_eager_forward() {
        let Some(path) = crate::decoder::decoder_weights_path_if_available() else {
            eprintln!("skip decode_output_matches_eager_forward: set NEUTTS_DECODER_PATH");
            return;
        };

        let codes: Vec<i32> = vec![0, 42, 128, 512, 1023];
        let dec = NeuCodecDecoder::from_file(&path).expect("NeuCodecDecoder::from_file");
        let actual = dec.decode(&codes).expect("decode");
        eprintln!(
            "decode_output_matches_eager_forward: backend={}",
            dec.backend_name()
        );

        // Reload independently and run the reference path. No metadata is passed,
        // so n_heads falls back to its default here.
        let data = std::fs::read(&path).expect("read safetensors");
        let st = safetensors::SafeTensors::deserialize(&data).expect("safetensors");
        let w = load_decoder_weights(&st, &None).expect("load_decoder_weights");
        let expected = decode_forward(&codes, &w);

        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected.iter()).enumerate() {
            assert!(a.is_finite() && e.is_finite(), "non-finite at {i}");
            let diff = (a - e).abs();
            assert!(
                diff < 1e-3,
                "sample {i}: actual={a} expected={e} diff={diff} backend={}",
                dec.backend_name()
            );
        }
    }
}