Skip to main content

rlx_voxtral_tts/codec/decoder.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Voxtral codec **decoder** (vLLM-Omni `VoxtralTTSAudioTokenizer` decode path).
17
18use crate::codec::layers::{
19    CodecConvBlock, CodecTransformer, compute_semantic_embedding, load_codec_layer, rescale_fsq,
20    run_conv, run_transformer, take_conv, take_conv_transpose, take1d, take2d,
21};
22use crate::config::CodecArgs;
23use crate::tokens::{AUDIO_TOKEN_OFFSET, split_voxtral_frames};
24use anyhow::{Result, ensure};
25use ndarray::Array2;
26use std::collections::HashMap;
27
28pub struct CodecDecoder {
29    cfg: CodecArgs,
30    semantic_embedding: Array2<f32>,
31    blocks: Vec<DecoderBlock>,
32    output_weight: ndarray::Array3<f32>,
33}
34
35enum DecoderBlock {
36    Conv(CodecConvBlock),
37    Transformer(CodecTransformer),
38}
39
40impl CodecDecoder {
41    pub fn from_tensors(
42        prefix: &str,
43        tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
44        cfg: &CodecArgs,
45    ) -> Result<Self> {
46        let sem_sum = take2d(
47            tensors,
48            &format!("{prefix}quantizer.semantic_codebook.embedding_sum"),
49        )?;
50        let sem_usage = take1d(
51            tensors,
52            &format!("{prefix}quantizer.semantic_codebook.cluster_usage"),
53        )?;
54        let semantic_embedding = compute_semantic_embedding(&sem_sum, &sem_usage);
55
56        let dec_kernels = cfg.decoder_convs_kernels();
57        let dec_strides = cfg.decoder_convs_strides();
58        let dec_lens = cfg.decoder_transformer_lengths();
59        ensure!(
60            dec_kernels.len() == dec_strides.len(),
61            "decoder conv config mismatch"
62        );
63
64        let mut blocks = Vec::new();
65        let mut window = cfg.attn_sliding_window_size;
66        let mut block_idx = 0usize;
67
68        blocks.push(DecoderBlock::Conv(CodecConvBlock::Forward {
69            weight: take_conv(tensors, &format!("{prefix}decoder_blocks.{block_idx}"))?,
70            stride: dec_strides[0],
71            pad_left: dec_kernels[0] - dec_strides[0],
72        }));
73        block_idx += 1;
74
75        for (stage, n_layers) in dec_lens.iter().enumerate() {
76            blocks.push(DecoderBlock::Transformer(CodecTransformer {
77                window,
78                layers: (0..*n_layers)
79                    .map(|li| {
80                        load_codec_layer(
81                            tensors,
82                            &format!("{prefix}decoder_blocks.{block_idx}.layers.{li}"),
83                            cfg,
84                        )
85                    })
86                    .collect::<Result<_>>()?,
87            }));
88            block_idx += 1;
89
90            if stage + 1 < dec_lens.len() {
91                let k = dec_kernels[stage + 1];
92                let st = dec_strides[stage + 1];
93                let total_pad = k - st;
94                blocks.push(DecoderBlock::Conv(CodecConvBlock::Transpose {
95                    weight: take_conv_transpose(
96                        tensors,
97                        &format!("{prefix}decoder_blocks.{block_idx}"),
98                    )?,
99                    stride: st,
100                    trim_left: total_pad - (total_pad / 2),
101                    trim_right: total_pad / 2,
102                }));
103                if st > 1 {
104                    window *= 2;
105                }
106                block_idx += 1;
107            }
108        }
109
110        let output_weight = take_conv(tensors, &format!("{prefix}output_proj"))?;
111        Ok(Self {
112            cfg: cfg.clone(),
113            semantic_embedding,
114            blocks,
115            output_weight,
116        })
117    }
118
119    /// Decode `[n_frames, 37]` vLLM-layout codes (semantic raw, acoustic +2 offset).
120    pub fn decode_codes(&self, codes: &[u32], n_frames: usize) -> Result<Vec<f32>> {
121        ensure!(
122            codes.len() == n_frames * 37,
123            "expected {}*37 codes",
124            n_frames
125        );
126        let (semantic, acoustic, actual) = split_voxtral_frames(codes, n_frames);
127        if actual == 0 {
128            return Ok(Vec::new());
129        }
130        let latent = self.quantizer_decode(&semantic, &acoustic, actual)?;
131        self.forward_decoder(&latent)
132    }
133
134    fn quantizer_decode(
135        &self,
136        semantic: &[usize],
137        acoustic: &[u32],
138        n_frames: usize,
139    ) -> Result<Array2<f32>> {
140        let d_sem = self.cfg.semantic_dim;
141        let d_ac = self.cfg.acoustic_dim;
142        let mut out = Array2::<f32>::zeros((d_sem + d_ac, n_frames));
143        for fi in 0..n_frames {
144            let sid = semantic[fi];
145            ensure!(
146                sid < self.semantic_embedding.dim().0,
147                "semantic id {sid} oob"
148            );
149            for di in 0..d_sem {
150                out[[di, fi]] = self.semantic_embedding[[sid, di]];
151            }
152            for ai in 0..36 {
153                let level = acoustic[fi * 36 + ai];
154                let v = rescale_fsq(level, self.cfg.acoustic_codebook_size);
155                out[[d_sem + ai, fi]] = v;
156            }
157        }
158        Ok(out)
159    }
160
161    fn forward_decoder(&self, emb: &Array2<f32>) -> Result<Vec<f32>> {
162        let mut x = emb.to_owned();
163        for block in &self.blocks {
164            match block {
165                DecoderBlock::Conv(conv) => {
166                    x = run_conv(&x, conv);
167                }
168                DecoderBlock::Transformer(tr) => {
169                    x = run_transformer(&x, tr)?;
170                }
171            }
172        }
173        let k = self.output_weight.shape()[2];
174        let pad_left = k - 1;
175        let wav = crate::math::conv1d(x.view(), self.output_weight.view(), 1, pad_left);
176        let (c, t) = wav.dim();
177        let mut pcm = Vec::with_capacity(c * t);
178        for ti in 0..t {
179            for ci in 0..c {
180                pcm.push(wav[[ci, ti]]);
181            }
182        }
183        Ok(pcm)
184    }
185}
186
187#[allow(dead_code)]
188pub fn apply_audio_offset(codes: &mut [u32]) {
189    for c in codes.iter_mut() {
190        *c += AUDIO_TOKEN_OFFSET;
191    }
192}