1use crate::backbone::embed::EmbeddingTables;
19use crate::codec::layers::{
20 CodecConvBlock, CodecTransformer, compute_semantic_embedding, load_codec_layer, run_conv,
21 run_transformer, take_conv,
22};
23use crate::codec::layout::{EncoderExecBlock, encoder_execution_plan};
24use crate::config::CodecArgs;
25use crate::tokens::AUDIO_TOKEN_OFFSET;
26use crate::voice::VoiceEmbedding;
27use anyhow::{Context, Result, bail, ensure};
28use ndarray::Array2;
29use std::collections::HashMap;
30
31enum EncoderBlock {
32 Transformer(CodecTransformer),
33 Conv(CodecConvBlock),
34}
35
36pub struct CodecEncoder {
37 cfg: CodecArgs,
38 patch_size: usize,
39 input_weight: ndarray::Array3<f32>,
40 input_stride: usize,
41 input_pad_left: usize,
42 blocks: Vec<EncoderBlock>,
43 semantic_embedding: Array2<f32>,
44}
45
46impl CodecEncoder {
47 pub fn from_tensors(
48 prefix: &str,
49 tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
50 cfg: &CodecArgs,
51 ) -> Result<Self> {
52 if !has_encoder_tensors(prefix, tensors) {
53 bail!(
54 "codec encoder weights missing under {prefix}input_proj / encoder_blocks.\n\
55 Reference-audio cloning needs a checkpoint that includes trained encoder weights."
56 );
57 }
58 let sem_sum = crate::codec::layers::take2d(
59 tensors,
60 &format!("{prefix}quantizer.semantic_codebook.embedding_sum"),
61 )?;
62 let sem_usage = crate::codec::layers::take1d(
63 tensors,
64 &format!("{prefix}quantizer.semantic_codebook.cluster_usage"),
65 )?;
66 let semantic_embedding = compute_semantic_embedding(&sem_sum, &sem_usage);
67
68 let input_weight = take_conv(tensors, &format!("{prefix}input_proj"))?;
69 let input_kernel = input_weight.shape()[2];
70 let input_stride = 1;
71 let input_pad_left = input_kernel.saturating_sub(input_stride);
72
73 let mut blocks = Vec::new();
74 for step in encoder_execution_plan(cfg)? {
75 match step {
76 EncoderExecBlock::Transformer {
77 block_idx,
78 n_layers,
79 window,
80 } => {
81 blocks.push(EncoderBlock::Transformer(CodecTransformer {
82 window,
83 layers: (0..n_layers)
84 .map(|li| {
85 load_codec_layer(
86 tensors,
87 &format!("{prefix}encoder_blocks.{block_idx}.layers.{li}"),
88 cfg,
89 )
90 })
91 .collect::<Result<_>>()?,
92 }));
93 }
94 EncoderExecBlock::ConvDown {
95 block_idx,
96 kernel,
97 stride,
98 in_dim,
99 out_dim,
100 } => {
101 let weight =
102 take_conv(tensors, &format!("{prefix}encoder_blocks.{block_idx}"))?;
103 let (oc, ic, _) = (weight.shape()[0], weight.shape()[1], weight.shape()[2]);
104 ensure!(
105 oc == out_dim && ic == in_dim,
106 "encoder_blocks.{block_idx} conv shape [{oc}, {ic}, _] \
107 != expected [{out_dim}, {in_dim}, {kernel}]"
108 );
109 blocks.push(EncoderBlock::Conv(CodecConvBlock::Forward {
110 weight,
111 stride,
112 pad_left: kernel.saturating_sub(stride),
113 }));
114 }
115 }
116 }
117
118 Ok(Self {
119 cfg: cfg.clone(),
120 patch_size: cfg.pretransform_patch_size,
121 input_weight,
122 input_stride,
123 input_pad_left,
124 blocks,
125 semantic_embedding,
126 })
127 }
128
129 pub fn encode_pcm_to_voice_embedding(
131 &self,
132 pcm: &[f32],
133 embed: &EmbeddingTables,
134 name: &str,
135 ) -> Result<VoiceEmbedding> {
136 ensure!(!pcm.is_empty(), "reference audio is empty");
137 let latent = self.forward_encoder(pcm)?;
138 let (semantic, acoustic) = self.quantizer_encode(&latent)?;
139 let hidden = embed.hidden_size();
140 let n_frames = semantic.len();
141 let mut data = Vec::with_capacity(n_frames * hidden);
142 for fi in 0..n_frames {
143 let mut frame = vec![0u32; 37];
144 frame[0] = semantic[fi] as u32 + AUDIO_TOKEN_OFFSET;
145 for ai in 0..36 {
146 frame[1 + ai] = acoustic[fi * 36 + ai] + AUDIO_TOKEN_OFFSET;
147 }
148 let row = embed.embed_audio_frame(&frame);
149 data.extend(row.iter().copied());
150 }
151 Ok(VoiceEmbedding {
152 name: name.to_string(),
153 data,
154 n_tokens: n_frames,
155 hidden,
156 })
157 }
158
159 fn forward_encoder(&self, pcm: &[f32]) -> Result<Array2<f32>> {
160 let patch = self.patch_size;
161 let mut samples = pcm.to_vec();
162 let rem = samples.len() % patch;
163 if rem != 0 {
164 samples.extend(std::iter::repeat_n(0f32, patch - rem));
165 }
166 let n_patches = samples.len() / patch;
167 let mut x = Array2::<f32>::zeros((patch, n_patches));
168 for pi in 0..n_patches {
169 for ic in 0..patch {
170 x[[ic, pi]] = samples[pi * patch + ic];
171 }
172 }
173 x = crate::math::conv1d(
174 x.view(),
175 self.input_weight.view(),
176 self.input_stride,
177 self.input_pad_left,
178 );
179
180 for block in &self.blocks {
181 match block {
182 EncoderBlock::Conv(conv) => {
183 x = run_conv(&x, conv);
184 }
185 EncoderBlock::Transformer(tr) => {
186 x = run_transformer(&x, tr)?;
187 }
188 }
189 }
190
191 let latent_dim = self.cfg.latent_dim();
192 let (d, _t) = x.dim();
193 ensure!(
194 d == latent_dim,
195 "encoder output channels {d} != latent_dim {latent_dim}"
196 );
197 Ok(x.slice(ndarray::s![..latent_dim, ..]).to_owned())
198 }
199
200 fn quantizer_encode(&self, latent: &Array2<f32>) -> Result<(Vec<usize>, Vec<u32>)> {
201 let d_sem = self.cfg.semantic_dim;
202 let (_, n_frames) = latent.dim();
203 let levels = self.cfg.acoustic_codebook_size;
204 let mut semantic = Vec::with_capacity(n_frames);
205 let mut acoustic = Vec::with_capacity(n_frames * 36);
206 for fi in 0..n_frames {
207 let mut best_id = 0usize;
208 let mut best_dist = f32::INFINITY;
209 for cid in 0..self.semantic_embedding.dim().0 {
210 let mut dist = 0f32;
211 for di in 0..d_sem {
212 let diff = latent[[di, fi]] - self.semantic_embedding[[cid, di]];
213 dist += diff * diff;
214 }
215 if dist < best_dist {
216 best_dist = dist;
217 best_id = cid;
218 }
219 }
220 semantic.push(best_id);
221 for ai in 0..36 {
222 let v = latent[[d_sem + ai, fi]].tanh();
223 let scaled = ((v + 1.0) * 0.5) * (levels as f32 - 1.0);
224 let code = scaled.round().clamp(0.0, levels as f32 - 1.0) as u32;
225 acoustic.push(code);
226 }
227 }
228 Ok((semantic, acoustic))
229 }
230}
231
/// Reports whether `tensors` contains an encoder input-projection weight under `prefix`.
///
/// Either of two key layouts counts:
/// - `{prefix}input_proj.conv.parametrizations.weight.original1` (presumably a
///   weight-norm parametrized export — TODO confirm);
/// - a plain `{prefix}input_proj.conv.weight`.
///
/// `encoder_blocks` tensors are not checked here.
pub fn has_encoder_tensors(
    prefix: &str,
    tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
) -> bool {
    const WEIGHT_SUFFIXES: [&str; 2] = ["conv.parametrizations.weight.original1", "conv.weight"];
    WEIGHT_SUFFIXES
        .iter()
        .any(|suffix| tensors.contains_key(&format!("{prefix}input_proj.{suffix}")))
}
240
241pub fn has_encoder_weights(keys: &std::collections::HashSet<String>, prefix: &str) -> bool {
242 keys.iter()
243 .any(|k| k.starts_with(&format!("{prefix}input_proj")))
244 && keys
245 .iter()
246 .any(|k| k.starts_with(&format!("{prefix}encoder_blocks")))
247}
248
249pub fn load_mono_wav(path: &std::path::Path, target_rate: u32) -> Result<Vec<f32>> {
250 let mut reader =
251 hound::WavReader::open(path).with_context(|| format!("open wav {}", path.display()))?;
252 let spec = reader.spec();
253 ensure!(
254 spec.channels == 1,
255 "reference wav must be mono (got {} channels)",
256 spec.channels
257 );
258 let samples: Result<Vec<f32>, _> = match spec.sample_format {
259 hound::SampleFormat::Float => reader
260 .samples::<f32>()
261 .map(|s| s.map_err(|e| anyhow::anyhow!("{e}")))
262 .collect(),
263 hound::SampleFormat::Int => reader
264 .samples::<i32>()
265 .map(|s| {
266 s.map(|v| v as f32 / i32::MAX as f32)
267 .map_err(|e| anyhow::anyhow!("{e}"))
268 })
269 .collect(),
270 };
271 let mut pcm = samples?;
272 if spec.sample_rate != target_rate {
273 pcm = resample_linear(&pcm, spec.sample_rate, target_rate);
274 }
275 Ok(pcm)
276}
277
/// Resamples `pcm` from `from_rate` to `to_rate` by linear interpolation between neighbours.
///
/// The output length is `floor(len * to_rate / from_rate)`. Positions past the last input
/// sample reuse that last sample. Identical rates or empty input return a copy of `pcm`.
/// No anti-aliasing filter is applied.
fn resample_linear(pcm: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if pcm.is_empty() || from_rate == to_rate {
        return pcm.to_vec();
    }
    let last = pcm.len() - 1;
    let src_rate = f64::from(from_rate);
    let dst_rate = f64::from(to_rate);
    let n_out = (pcm.len() as u64 * u64::from(to_rate) / u64::from(from_rate)) as usize;
    (0..n_out)
        .map(|i| {
            let pos = i as f64 * src_rate / dst_rate;
            let lo = pos.floor() as usize;
            let hi = last.min(lo + 1);
            let t = (pos - lo as f64) as f32;
            pcm[lo] * (1.0 - t) + pcm[hi] * t
        })
        .collect()
}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// Halving the sample rate should halve the number of output samples.
    #[test]
    fn resample_downsample_halves_length() {
        let input = [0.0f32, 1.0, 0.0, -1.0];
        let resampled = resample_linear(&input, 24_000, 12_000);
        assert_eq!(resampled.len(), 2);
    }
}