rlx_voxtral_tts/codec/
decoder.rs1use crate::codec::layers::{
19 CodecConvBlock, CodecTransformer, compute_semantic_embedding, load_codec_layer, rescale_fsq,
20 run_conv, run_transformer, take_conv, take_conv_transpose, take1d, take2d,
21};
22use crate::config::CodecArgs;
23use crate::tokens::{AUDIO_TOKEN_OFFSET, split_voxtral_frames};
24use anyhow::{Result, ensure};
25use ndarray::Array2;
26use std::collections::HashMap;
27
28pub struct CodecDecoder {
29 cfg: CodecArgs,
30 semantic_embedding: Array2<f32>,
31 blocks: Vec<DecoderBlock>,
32 output_weight: ndarray::Array3<f32>,
33}
34
35enum DecoderBlock {
36 Conv(CodecConvBlock),
37 Transformer(CodecTransformer),
38}
39
40impl CodecDecoder {
41 pub fn from_tensors(
42 prefix: &str,
43 tensors: &HashMap<String, (Vec<f32>, Vec<usize>)>,
44 cfg: &CodecArgs,
45 ) -> Result<Self> {
46 let sem_sum = take2d(
47 tensors,
48 &format!("{prefix}quantizer.semantic_codebook.embedding_sum"),
49 )?;
50 let sem_usage = take1d(
51 tensors,
52 &format!("{prefix}quantizer.semantic_codebook.cluster_usage"),
53 )?;
54 let semantic_embedding = compute_semantic_embedding(&sem_sum, &sem_usage);
55
56 let dec_kernels = cfg.decoder_convs_kernels();
57 let dec_strides = cfg.decoder_convs_strides();
58 let dec_lens = cfg.decoder_transformer_lengths();
59 ensure!(
60 dec_kernels.len() == dec_strides.len(),
61 "decoder conv config mismatch"
62 );
63
64 let mut blocks = Vec::new();
65 let mut window = cfg.attn_sliding_window_size;
66 let mut block_idx = 0usize;
67
68 blocks.push(DecoderBlock::Conv(CodecConvBlock::Forward {
69 weight: take_conv(tensors, &format!("{prefix}decoder_blocks.{block_idx}"))?,
70 stride: dec_strides[0],
71 pad_left: dec_kernels[0] - dec_strides[0],
72 }));
73 block_idx += 1;
74
75 for (stage, n_layers) in dec_lens.iter().enumerate() {
76 blocks.push(DecoderBlock::Transformer(CodecTransformer {
77 window,
78 layers: (0..*n_layers)
79 .map(|li| {
80 load_codec_layer(
81 tensors,
82 &format!("{prefix}decoder_blocks.{block_idx}.layers.{li}"),
83 cfg,
84 )
85 })
86 .collect::<Result<_>>()?,
87 }));
88 block_idx += 1;
89
90 if stage + 1 < dec_lens.len() {
91 let k = dec_kernels[stage + 1];
92 let st = dec_strides[stage + 1];
93 let total_pad = k - st;
94 blocks.push(DecoderBlock::Conv(CodecConvBlock::Transpose {
95 weight: take_conv_transpose(
96 tensors,
97 &format!("{prefix}decoder_blocks.{block_idx}"),
98 )?,
99 stride: st,
100 trim_left: total_pad - (total_pad / 2),
101 trim_right: total_pad / 2,
102 }));
103 if st > 1 {
104 window *= 2;
105 }
106 block_idx += 1;
107 }
108 }
109
110 let output_weight = take_conv(tensors, &format!("{prefix}output_proj"))?;
111 Ok(Self {
112 cfg: cfg.clone(),
113 semantic_embedding,
114 blocks,
115 output_weight,
116 })
117 }
118
119 pub fn decode_codes(&self, codes: &[u32], n_frames: usize) -> Result<Vec<f32>> {
121 ensure!(
122 codes.len() == n_frames * 37,
123 "expected {}*37 codes",
124 n_frames
125 );
126 let (semantic, acoustic, actual) = split_voxtral_frames(codes, n_frames);
127 if actual == 0 {
128 return Ok(Vec::new());
129 }
130 let latent = self.quantizer_decode(&semantic, &acoustic, actual)?;
131 self.forward_decoder(&latent)
132 }
133
134 fn quantizer_decode(
135 &self,
136 semantic: &[usize],
137 acoustic: &[u32],
138 n_frames: usize,
139 ) -> Result<Array2<f32>> {
140 let d_sem = self.cfg.semantic_dim;
141 let d_ac = self.cfg.acoustic_dim;
142 let mut out = Array2::<f32>::zeros((d_sem + d_ac, n_frames));
143 for fi in 0..n_frames {
144 let sid = semantic[fi];
145 ensure!(
146 sid < self.semantic_embedding.dim().0,
147 "semantic id {sid} oob"
148 );
149 for di in 0..d_sem {
150 out[[di, fi]] = self.semantic_embedding[[sid, di]];
151 }
152 for ai in 0..36 {
153 let level = acoustic[fi * 36 + ai];
154 let v = rescale_fsq(level, self.cfg.acoustic_codebook_size);
155 out[[d_sem + ai, fi]] = v;
156 }
157 }
158 Ok(out)
159 }
160
161 fn forward_decoder(&self, emb: &Array2<f32>) -> Result<Vec<f32>> {
162 let mut x = emb.to_owned();
163 for block in &self.blocks {
164 match block {
165 DecoderBlock::Conv(conv) => {
166 x = run_conv(&x, conv);
167 }
168 DecoderBlock::Transformer(tr) => {
169 x = run_transformer(&x, tr)?;
170 }
171 }
172 }
173 let k = self.output_weight.shape()[2];
174 let pad_left = k - 1;
175 let wav = crate::math::conv1d(x.view(), self.output_weight.view(), 1, pad_left);
176 let (c, t) = wav.dim();
177 let mut pcm = Vec::with_capacity(c * t);
178 for ti in 0..t {
179 for ci in 0..c {
180 pcm.push(wav[[ci, ti]]);
181 }
182 }
183 Ok(pcm)
184 }
185}
186
187#[allow(dead_code)]
188pub fn apply_audio_offset(codes: &mut [u32]) {
189 for c in codes.iter_mut() {
190 *c += AUDIO_TOKEN_OFFSET;
191 }
192}