// moshi_db/quantization.rs

// Copyright (c) Kyutai, all rights reserved.
// This source code is licensed under the license found in the
// LICENSE file in the root directory of this source tree.

5use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
6use candle_nn::{linear, Linear, VarBuilder};
7
8struct CodebookEncode;
9
10impl candle::CustomOp2 for CodebookEncode {
11    fn name(&self) -> &'static str {
12        "cb"
13    }
14
15    fn cpu_fwd(
16        &self,
17        lhs_storage: &candle::CpuStorage,
18        lhs_layout: &Layout,
19        rhs_storage: &candle::CpuStorage,
20        rhs_layout: &Layout,
21    ) -> Result<(candle::CpuStorage, Shape)> {
22        use rayon::prelude::*;
23
24        let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
25        let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
26        if lhs_dim2 != rhs_dim2 {
27            candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
28        }
29        if lhs_dim2 == 0 {
30            candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
31        }
32        let lhs = match lhs_layout.contiguous_offsets() {
33            None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
34            Some((o1, o2)) => {
35                let slice = lhs_storage.as_slice::<f32>()?;
36                &slice[o1..o2]
37            }
38        };
39        let rhs = match rhs_layout.contiguous_offsets() {
40            None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
41            Some((o1, o2)) => {
42                let slice = rhs_storage.as_slice::<f32>()?;
43                &slice[o1..o2]
44            }
45        };
46        let dst = (0..lhs_dim1)
47            .into_par_iter()
48            .map(|idx1| {
49                let mut where_min = 0;
50                let mut min_dist = f32::INFINITY;
51                let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
52                for idx2 in 0..rhs_dim1 {
53                    let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
54                    let mut dist = 0f32;
55                    for (a, b) in lhs.iter().zip(rhs.iter()) {
56                        dist += (a - b) * (a - b)
57                    }
58                    if dist < min_dist {
59                        min_dist = dist;
60                        where_min = idx2;
61                    }
62                }
63                where_min as u32
64            })
65            .collect();
66        let storage = candle::WithDType::to_cpu_storage_owned(dst);
67        Ok((storage, (lhs_dim1,).into()))
68    }
69}
70
/// Codebook performing nearest-neighbor lookups with squared euclidean distance.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct EuclideanCodebook {
    // Initialization flag loaded from the checkpoint ("_initialized"); kept for
    // checkpoint compatibility, not read by the encode/decode paths below.
    initialized: Tensor,
    // Per-entry usage counts; divides `embedding_sum` to produce `embedding`.
    cluster_usage: Tensor,
    // Running sum of assigned vectors per entry, shape (codebook_size, dim).
    embedding_sum: Tensor,
    // Normalized codebook: embedding_sum / max(cluster_usage, epsilon).
    embedding: Tensor,
    // Half the squared L2 norm of each embedding row, precomputed for `encode_slow`.
    c2: Tensor,
    // Lower bound applied to `cluster_usage` to avoid division by zero.
    epsilon: f64,
    // Dimension of each codebook entry.
    dim: usize,
    // Tracing spans for the encode/decode paths.
    span_encode: tracing::Span,
    span_decode: tracing::Span,
}
84
85impl EuclideanCodebook {
86    pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
87        let epsilon = 1e-5;
88        let initialized = vb.get(1, "_initialized")?;
89        let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
90        let embedding_sum = vb.get((codebook_size, dim), "embedding_sum")?;
91        let embedding = {
92            let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
93            embedding_sum.broadcast_div(&cluster_usage)?
94        };
95        let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
96        Ok(Self {
97            initialized,
98            cluster_usage,
99            embedding_sum,
100            embedding,
101            c2,
102            epsilon,
103            dim,
104            span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
105            span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
106        })
107    }
108
109    pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
110        let _enter = self.span_encode.enter();
111        let mut target_shape = xs.dims().to_vec();
112        target_shape.pop();
113        let xs = xs.flatten_to(D::Minus2)?;
114        let _ = xs.dims2()?;
115        // TODO: avoid repeating this.
116        let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?;
117        let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?;
118        // Manual cdist implementation.
119        let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?;
120        let dists = diff.sqr()?.sum(D::Minus1)?;
121        let codes = dists.argmin(D::Minus1)?;
122        codes.reshape(target_shape)
123    }
124
125    pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
126        let _enter = self.span_encode.enter();
127        let mut target_shape = xs.dims().to_vec();
128        target_shape.pop();
129        let xs = xs.flatten_to(D::Minus2)?;
130        let _ = xs.dims2()?;
131        let dot_prod = xs.matmul(&self.embedding.t()?)?;
132        let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
133        codes.reshape(target_shape)
134    }
135
136    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
137        let _enter = self.span_encode.enter();
138        let mut target_shape = xs.dims().to_vec();
139        target_shape.pop();
140        let xs = xs.flatten_to(D::Minus2)?;
141        let _ = xs.dims2()?;
142        let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
143        codes.reshape(target_shape)
144    }
145
146    pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
147        let _enter = self.span_decode.enter();
148        // let ys = candle_nn::Embedding::new(self.embedding.clone(), self.dim).forward(xs)?;
149        let mut final_dims = indexes.dims().to_vec();
150        final_dims.push(self.dim);
151        let indexes = indexes.flatten_all()?;
152        let values = self.embedding.index_select(&indexes, 0)?;
153        let values = values.reshape(final_dims)?;
154        Ok(values)
155    }
156}
157
/// A single vector-quantization layer: optional in/out linear projections
/// around a `EuclideanCodebook`.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct VectorQuantization {
    // Projections are present only when the codebook dimension differs from
    // the model dimension.
    project_in: Option<Linear>,
    project_out: Option<Linear>,
    codebook: EuclideanCodebook,
}
165
166impl VectorQuantization {
167    pub fn new(
168        dim: usize,
169        codebook_size: usize,
170        codebook_dim: Option<usize>,
171        vb: VarBuilder,
172    ) -> Result<Self> {
173        let codebook_dim = codebook_dim.unwrap_or(dim);
174        let (project_in, project_out) = if codebook_dim == dim {
175            (None, None)
176        } else {
177            let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
178            let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
179            (Some(p_in), Some(p_out))
180        };
181        let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("_codebook"))?;
182        Ok(Self { project_in, project_out, codebook })
183    }
184
185    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
186        let xs = xs.t()?.apply(&self.project_in.as_ref())?;
187        self.codebook.encode_slow(&xs)
188    }
189
190    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
191        let quantized = self.codebook.decode(codes)?;
192        let quantized = match &self.project_out {
193            None => quantized,
194            Some(p) => quantized.apply(p)?,
195        };
196        quantized.t()
197    }
198}
199
/// A stack of `VectorQuantization` layers, each quantizing the residual left
/// by the previous ones.
#[derive(Debug, Clone)]
pub struct ResidualVectorQuantization {
    layers: Vec<VectorQuantization>,
}
204
205impl ResidualVectorQuantization {
206    pub fn new(
207        n_q: usize,
208        dim: usize,
209        codebook_size: usize,
210        codebook_dim: Option<usize>,
211        vb: VarBuilder,
212    ) -> Result<Self> {
213        let vb = vb.pp("layers");
214        let mut layers = Vec::with_capacity(n_q);
215        for i in 0..n_q {
216            let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
217            layers.push(layer)
218        }
219        Ok(Self { layers })
220    }
221
222    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
223        let mut codes = Vec::with_capacity(self.layers.len());
224        let mut residual = xs.clone();
225        for layer in self.layers.iter() {
226            let indices = layer.encode(&residual)?;
227            let quantized = layer.decode(&indices)?;
228            residual = (residual - quantized)?;
229            codes.push(indices)
230        }
231        Tensor::stack(&codes, 0)
232    }
233
234    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
235        if self.layers.is_empty() {
236            candle::bail!("empty layers in ResidualVectorQuantization")
237        }
238        if self.layers.len() != xs.dim(0)? {
239            candle::bail!(
240                "mismatch between the number of layers {} and the code shape {:?}",
241                self.layers.len(),
242                xs.shape()
243            )
244        }
245        let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
246        for (i, layer) in self.layers.iter().enumerate().skip(1) {
247            let xs = xs.i(i)?;
248            quantized = (quantized + layer.decode(&xs))?
249        }
250        Ok(quantized)
251    }
252}
253
/// Residual vector quantizer with optional 1x1-conv projections on input and
/// output.
#[allow(unused)]
#[derive(Debug, Clone)]
pub struct ResidualVectorQuantizer {
    vq: ResidualVectorQuantization,
    input_proj: Option<candle_nn::Conv1d>,
    output_proj: Option<candle_nn::Conv1d>,
}
261
262impl ResidualVectorQuantizer {
263    pub fn new(
264        dim: usize,
265        input_dim: Option<usize>,
266        output_dim: Option<usize>,
267        n_q: usize,
268        bins: usize,
269        force_projection: bool,
270        vb: VarBuilder,
271    ) -> Result<Self> {
272        let input_dim = input_dim.unwrap_or(dim);
273        let output_dim = output_dim.unwrap_or(dim);
274
275        let input_proj = if input_dim == dim && !force_projection {
276            None
277        } else {
278            let c = candle_nn::conv1d_no_bias(
279                input_dim,
280                dim,
281                1,
282                Default::default(),
283                vb.pp("input_proj"),
284            )?;
285            Some(c)
286        };
287        let output_proj = if output_dim == dim && !force_projection {
288            None
289        } else {
290            let c = candle_nn::conv1d_no_bias(
291                dim,
292                output_dim,
293                1,
294                Default::default(),
295                vb.pp("output_proj"),
296            )?;
297            Some(c)
298        };
299
300        let vq = ResidualVectorQuantization::new(
301            n_q,
302            dim,
303            /* codebook_size */ bins,
304            /* codebook_dim */ None,
305            vb.pp("vq"),
306        )?;
307        Ok(Self { vq, input_proj, output_proj })
308    }
309
310    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
311        let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
312        codes.transpose(0, 1)
313    }
314
315    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
316        // codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
317        let codes = codes.transpose(0, 1)?;
318        let quantized = self.vq.decode(&codes)?;
319        match &self.output_proj {
320            None => Ok(quantized),
321            Some(p) => quantized.apply(p),
322        }
323    }
324}
325
326// we do not use any codebook_offset at the moment. When reconstructing the codes, we could just
327// concatenate the indexes.
#[derive(Debug, Clone)]
pub struct SplitResidualVectorQuantizer {
    // Quantizer for the first codebook (semantic tokens).
    rvq_first: ResidualVectorQuantizer,
    // Quantizer for the remaining codebooks (acoustic tokens).
    rvq_rest: ResidualVectorQuantizer,
    // Total number of codebooks across both quantizers.
    n_q: usize,
    // Tracing spans for the encode/decode paths.
    span_encode: tracing::Span,
    span_decode: tracing::Span,
}
336
337impl SplitResidualVectorQuantizer {
338    pub fn new(
339        dim: usize,
340        input_dim: Option<usize>,
341        output_dim: Option<usize>,
342        n_q: usize,
343        bins: usize,
344        vb: VarBuilder,
345    ) -> Result<Self> {
346        let rvq_first = ResidualVectorQuantizer::new(
347            dim,
348            input_dim,
349            output_dim,
350            1,
351            bins,
352            true,
353            vb.pp("rvq_first"),
354        )?;
355        let rvq_rest = ResidualVectorQuantizer::new(
356            dim,
357            input_dim,
358            output_dim,
359            n_q - 1,
360            bins,
361            true,
362            vb.pp("rvq_rest"),
363        )?;
364        let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
365        let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
366        Ok(Self { rvq_first, rvq_rest, n_q, span_encode, span_decode })
367    }
368
369    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
370        let _enter = self.span_encode.enter();
371        let codes = self.rvq_first.encode(xs)?;
372        if self.n_q > 1 {
373            // We encode xs again here rather than the residual. The decomposition is not
374            // hierarchical but rather having semantic tokens for rvq_first and the acoustic tokens
375            // for rvq_rest.
376            let rest_codes = self.rvq_rest.encode(xs)?;
377            Tensor::cat(&[codes, rest_codes], 1)
378        } else {
379            Ok(codes)
380        }
381    }
382
383    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
384        // codes is [B, K, T], with T frames, K nb of codebooks.
385        let _enter = self.span_decode.enter();
386        let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
387        let quantized = if self.n_q > 1 {
388            (quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
389        } else {
390            quantized
391        };
392        Ok(quantized)
393    }
394}