Skip to main content

rlx_voxtral_tts/codec/
encoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Voxtral codec **encoder** — reference audio → discrete codes → voice embeddings.
17
18use crate::backbone::embed::EmbeddingTables;
19use crate::codec::layers::{
20    CodecConvBlock, CodecTransformer, compute_semantic_embedding, load_codec_layer, run_conv,
21    run_transformer, take_conv,
22};
23use crate::codec::layout::{EncoderExecBlock, encoder_execution_plan};
24use crate::config::CodecArgs;
25use crate::tokens::AUDIO_TOKEN_OFFSET;
26use crate::voice::VoiceEmbedding;
27use anyhow::{Context, Result, bail, ensure};
28use ndarray::Array2;
29use std::collections::HashMap;
30
31enum EncoderBlock {
32    Transformer(CodecTransformer),
33    Conv(CodecConvBlock),
34}
35
36pub struct CodecEncoder {
37    cfg: CodecArgs,
38    patch_size: usize,
39    input_weight: ndarray::Array3<f32>,
40    input_stride: usize,
41    input_pad_left: usize,
42    blocks: Vec<EncoderBlock>,
43    semantic_embedding: Array2<f32>,
44}
45
46impl CodecEncoder {
47    pub fn from_tensors(
48        prefix: &str,
49        tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
50        cfg: &CodecArgs,
51    ) -> Result<Self> {
52        if !has_encoder_tensors(prefix, tensors) {
53            bail!(
54                "codec encoder weights missing under {prefix}input_proj / encoder_blocks.\n\
55                 Reference-audio cloning needs a checkpoint that includes trained encoder weights."
56            );
57        }
58        let sem_sum = crate::codec::layers::take2d(
59            tensors,
60            &format!("{prefix}quantizer.semantic_codebook.embedding_sum"),
61        )?;
62        let sem_usage = crate::codec::layers::take1d(
63            tensors,
64            &format!("{prefix}quantizer.semantic_codebook.cluster_usage"),
65        )?;
66        let semantic_embedding = compute_semantic_embedding(&sem_sum, &sem_usage);
67
68        let input_weight = take_conv(tensors, &format!("{prefix}input_proj"))?;
69        let input_kernel = input_weight.shape()[2];
70        let input_stride = 1;
71        let input_pad_left = input_kernel.saturating_sub(input_stride);
72
73        let mut blocks = Vec::new();
74        for step in encoder_execution_plan(cfg)? {
75            match step {
76                EncoderExecBlock::Transformer {
77                    block_idx,
78                    n_layers,
79                    window,
80                } => {
81                    blocks.push(EncoderBlock::Transformer(CodecTransformer {
82                        window,
83                        layers: (0..n_layers)
84                            .map(|li| {
85                                load_codec_layer(
86                                    tensors,
87                                    &format!("{prefix}encoder_blocks.{block_idx}.layers.{li}"),
88                                    cfg,
89                                )
90                            })
91                            .collect::<Result<_>>()?,
92                    }));
93                }
94                EncoderExecBlock::ConvDown {
95                    block_idx,
96                    kernel,
97                    stride,
98                    in_dim,
99                    out_dim,
100                } => {
101                    let weight =
102                        take_conv(tensors, &format!("{prefix}encoder_blocks.{block_idx}"))?;
103                    let (oc, ic, _) = (weight.shape()[0], weight.shape()[1], weight.shape()[2]);
104                    ensure!(
105                        oc == out_dim && ic == in_dim,
106                        "encoder_blocks.{block_idx} conv shape [{oc}, {ic}, _] \
107                         != expected [{out_dim}, {in_dim}, {kernel}]"
108                    );
109                    blocks.push(EncoderBlock::Conv(CodecConvBlock::Forward {
110                        weight,
111                        stride,
112                        pad_left: kernel.saturating_sub(stride),
113                    }));
114                }
115            }
116        }
117
118        Ok(Self {
119            cfg: cfg.clone(),
120            patch_size: cfg.pretransform_patch_size,
121            input_weight,
122            input_stride,
123            input_pad_left,
124            blocks,
125            semantic_embedding,
126        })
127    }
128
129    /// Encode mono PCM at model sample rate into LLM voice rows `[n_frames, hidden]`.
130    pub fn encode_pcm_to_voice_embedding(
131        &self,
132        pcm: &[f32],
133        embed: &EmbeddingTables,
134        name: &str,
135    ) -> Result<VoiceEmbedding> {
136        ensure!(!pcm.is_empty(), "reference audio is empty");
137        let latent = self.forward_encoder(pcm)?;
138        let (semantic, acoustic) = self.quantizer_encode(&latent)?;
139        let hidden = embed.hidden_size();
140        let n_frames = semantic.len();
141        let mut data = Vec::with_capacity(n_frames * hidden);
142        for fi in 0..n_frames {
143            let mut frame = vec![0u32; 37];
144            frame[0] = semantic[fi] as u32 + AUDIO_TOKEN_OFFSET;
145            for ai in 0..36 {
146                frame[1 + ai] = acoustic[fi * 36 + ai] + AUDIO_TOKEN_OFFSET;
147            }
148            let row = embed.embed_audio_frame(&frame);
149            data.extend(row.iter().copied());
150        }
151        Ok(VoiceEmbedding {
152            name: name.to_string(),
153            data,
154            n_tokens: n_frames,
155            hidden,
156        })
157    }
158
159    fn forward_encoder(&self, pcm: &[f32]) -> Result<Array2<f32>> {
160        let patch = self.patch_size;
161        let mut samples = pcm.to_vec();
162        let rem = samples.len() % patch;
163        if rem != 0 {
164            samples.extend(std::iter::repeat_n(0f32, patch - rem));
165        }
166        let n_patches = samples.len() / patch;
167        let mut x = Array2::<f32>::zeros((patch, n_patches));
168        for pi in 0..n_patches {
169            for ic in 0..patch {
170                x[[ic, pi]] = samples[pi * patch + ic];
171            }
172        }
173        x = crate::math::conv1d(
174            x.view(),
175            self.input_weight.view(),
176            self.input_stride,
177            self.input_pad_left,
178        );
179
180        for block in &self.blocks {
181            match block {
182                EncoderBlock::Conv(conv) => {
183                    x = run_conv(&x, conv);
184                }
185                EncoderBlock::Transformer(tr) => {
186                    x = run_transformer(&x, tr)?;
187                }
188            }
189        }
190
191        let latent_dim = self.cfg.latent_dim();
192        let (d, _t) = x.dim();
193        ensure!(
194            d == latent_dim,
195            "encoder output channels {d} != latent_dim {latent_dim}"
196        );
197        Ok(x.slice(ndarray::s![..latent_dim, ..]).to_owned())
198    }
199
200    fn quantizer_encode(&self, latent: &Array2<f32>) -> Result<(Vec<usize>, Vec<u32>)> {
201        let d_sem = self.cfg.semantic_dim;
202        let (_, n_frames) = latent.dim();
203        let levels = self.cfg.acoustic_codebook_size;
204        let mut semantic = Vec::with_capacity(n_frames);
205        let mut acoustic = Vec::with_capacity(n_frames * 36);
206        for fi in 0..n_frames {
207            let mut best_id = 0usize;
208            let mut best_dist = f32::INFINITY;
209            for cid in 0..self.semantic_embedding.dim().0 {
210                let mut dist = 0f32;
211                for di in 0..d_sem {
212                    let diff = latent[[di, fi]] - self.semantic_embedding[[cid, di]];
213                    dist += diff * diff;
214                }
215                if dist < best_dist {
216                    best_dist = dist;
217                    best_id = cid;
218                }
219            }
220            semantic.push(best_id);
221            for ai in 0..36 {
222                let v = latent[[d_sem + ai, fi]].tanh();
223                let scaled = ((v + 1.0) * 0.5) * (levels as f32 - 1.0);
224                let code = scaled.round().clamp(0.0, levels as f32 - 1.0) as u32;
225                acoustic.push(code);
226            }
227        }
228        Ok((semantic, acoustic))
229    }
230}
231
/// Reports whether `tensors` holds an encoder input-projection weight under `prefix`.
///
/// Both the weight-normalized key (`input_proj.conv.parametrizations.weight.original1`)
/// and the plain key (`input_proj.conv.weight`) are accepted.
pub fn has_encoder_tensors(
    prefix: &str,
    tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
) -> bool {
    const INPUT_PROJ_KEYS: [&str; 2] = [
        "input_proj.conv.parametrizations.weight.original1",
        "input_proj.conv.weight",
    ];
    INPUT_PROJ_KEYS
        .iter()
        .any(|suffix| tensors.contains_key(&format!("{prefix}{suffix}")))
}
240
/// Returns `true` when `keys` contains at least one name starting with
/// `{prefix}input_proj` and at least one starting with `{prefix}encoder_blocks`.
///
/// This only checks key names; it does not validate shapes or completeness.
pub fn has_encoder_weights(keys: &std::collections::HashSet<String>, prefix: &str) -> bool {
    // Build both prefixes once instead of re-formatting them for every key visited.
    let input_proj = format!("{prefix}input_proj");
    let encoder_blocks = format!("{prefix}encoder_blocks");
    keys.iter().any(|k| k.starts_with(&input_proj))
        && keys.iter().any(|k| k.starts_with(&encoder_blocks))
}
248
249pub fn load_mono_wav(path: &std::path::Path, target_rate: u32) -> Result<Vec<f32>> {
250    let mut reader =
251        hound::WavReader::open(path).with_context(|| format!("open wav {}", path.display()))?;
252    let spec = reader.spec();
253    ensure!(
254        spec.channels == 1,
255        "reference wav must be mono (got {} channels)",
256        spec.channels
257    );
258    let samples: Result<Vec<f32>, _> = match spec.sample_format {
259        hound::SampleFormat::Float => reader
260            .samples::<f32>()
261            .map(|s| s.map_err(|e| anyhow::anyhow!("{e}")))
262            .collect(),
263        hound::SampleFormat::Int => reader
264            .samples::<i32>()
265            .map(|s| {
266                s.map(|v| v as f32 / i32::MAX as f32)
267                    .map_err(|e| anyhow::anyhow!("{e}"))
268            })
269            .collect(),
270    };
271    let mut pcm = samples?;
272    if spec.sample_rate != target_rate {
273        pcm = resample_linear(&pcm, spec.sample_rate, target_rate);
274    }
275    Ok(pcm)
276}
277
/// Linearly interpolates `pcm` from `from_rate` Hz to `to_rate` Hz.
///
/// The output holds `floor(len * to_rate / from_rate)` samples. The final input
/// sample is held at the right edge. No anti-aliasing filter is applied, so
/// downsampling may alias. Matching rates or empty input return a plain copy.
fn resample_linear(pcm: &[f32], from_rate: u32, to_rate: u32) -> Vec<f32> {
    if pcm.is_empty() || from_rate == to_rate {
        return pcm.to_vec();
    }
    let last = pcm.len() - 1;
    let src_rate = from_rate as f64;
    let dst_rate = to_rate as f64;
    let n_out = (pcm.len() as u64 * u64::from(to_rate) / u64::from(from_rate)) as usize;
    (0..n_out)
        .map(|i| {
            let pos = i as f64 * src_rate / dst_rate;
            let lo = pos.floor() as usize;
            let hi = last.min(lo + 1);
            let t = (pos - lo as f64) as f32;
            pcm[lo] * (1.0 - t) + pcm[hi] * t
        })
        .collect()
}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resample_downsample_halves_length() {
        let pcm = vec![0.0, 1.0, 0.0, -1.0];
        let out = resample_linear(&pcm, 24000, 12000);
        assert_eq!(out.len(), 2);
    }

    #[test]
    fn resample_upsample_interpolates_linearly() {
        let out = resample_linear(&[0.0, 1.0], 12000, 24000);
        // Midpoints are blended; the last sample is held because it has no right neighbour.
        assert_eq!(out, vec![0.0, 0.5, 1.0, 1.0]);
    }

    #[test]
    fn resample_same_rate_or_empty_is_identity() {
        let pcm = vec![0.25, -0.5, 0.75];
        assert_eq!(resample_linear(&pcm, 16000, 16000), pcm);
        assert!(resample_linear(&[], 16000, 24000).is_empty());
    }
}