use candle::{IndexOp, Layout, Result, Shape, Tensor, D};
use candle_nn::{linear, Linear, VarBuilder};

struct CodebookEncode;

impl candle::CustomOp2 for CodebookEncode {
    fn name(&self) -> &'static str {
        "cb"
    }

    fn cpu_fwd(
        &self,
        lhs_storage: &candle::CpuStorage,
        lhs_layout: &Layout,
        rhs_storage: &candle::CpuStorage,
        rhs_layout: &Layout,
    ) -> Result<(candle::CpuStorage, Shape)> {
        use rayon::prelude::*;

        let (lhs_dim1, lhs_dim2) = lhs_layout.shape().dims2()?;
        let (rhs_dim1, rhs_dim2) = rhs_layout.shape().dims2()?;
        if lhs_dim2 != rhs_dim2 {
            candle::bail!("CodebookEncode, mismatch on last dim, {lhs_layout:?} {rhs_layout:?}");
        }
        if lhs_dim2 == 0 {
            candle::bail!("CodebookEncode, empty last dim {lhs_layout:?}")
        }
        let lhs = match lhs_layout.contiguous_offsets() {
            None => candle::bail!("CodebookEncode, lhs has to be contiguous, got {lhs_layout:?}"),
            Some((o1, o2)) => {
                let slice = lhs_storage.as_slice::<f32>()?;
                &slice[o1..o2]
            }
        };
        let rhs = match rhs_layout.contiguous_offsets() {
            None => candle::bail!("CodebookEncode, rhs has to be contiguous, got {rhs_layout:?}"),
            Some((o1, o2)) => {
                let slice = rhs_storage.as_slice::<f32>()?;
                &slice[o1..o2]
            }
        };
        let dst = (0..lhs_dim1)
            .into_par_iter()
            .map(|idx1| {
                let mut where_min = 0;
                let mut min_dist = f32::INFINITY;
                let lhs = &lhs[idx1 * lhs_dim2..(idx1 + 1) * lhs_dim2];
                for idx2 in 0..rhs_dim1 {
                    let rhs = &rhs[idx2 * rhs_dim2..(idx2 + 1) * rhs_dim2];
                    let mut dist = 0f32;
                    for (a, b) in lhs.iter().zip(rhs.iter()) {
                        dist += (a - b) * (a - b)
                    }
                    if dist < min_dist {
                        min_dist = dist;
                        where_min = idx2;
                    }
                }
                where_min as u32
            })
            .collect();
        let storage = candle::WithDType::to_cpu_storage_owned(dst);
        Ok((storage, (lhs_dim1,).into()))
    }
}
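
// Illustrative check, not part of the original file: for each lhs row the op
// returns the index of the nearest rhs row under squared Euclidean distance.
#[cfg(test)]
mod codebook_encode_tests {
    use super::*;
    use candle::Device;

    #[test]
    fn nearest_row_is_selected() -> Result<()> {
        let dev = Device::Cpu;
        // Two query vectors and a three-entry codebook in R^2.
        let xs = Tensor::new(&[[0.0f32, 0.1], [0.9, 1.0]], &dev)?;
        let cb = Tensor::new(&[[0.0f32, 0.0], [1.0, 1.0], [5.0, 5.0]], &dev)?;
        let codes = Tensor::apply_op2(&xs, &cb, CodebookEncode)?;
        assert_eq!(codes.to_vec1::<u32>()?, vec![0, 1]);
        Ok(())
    }
}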

#[allow(unused)]
#[derive(Debug, Clone)]
pub struct EuclideanCodebook {
    initialized: Tensor,
    cluster_usage: Tensor,
    embedding_sum: Tensor,
    embedding: Tensor,
    c2: Tensor,
    epsilon: f64,
    dim: usize,
    span_encode: tracing::Span,
    span_decode: tracing::Span,
}

impl EuclideanCodebook {
    pub fn new(dim: usize, codebook_size: usize, vb: VarBuilder) -> Result<Self> {
        let epsilon = 1e-5;
        let initialized = vb.get(1, "_initialized")?;
        let cluster_usage = vb.get(codebook_size, "cluster_usage")?;
        let embedding_sum = vb.get((codebook_size, dim), "embedding_sum")?;
        let embedding = {
            let cluster_usage = cluster_usage.maximum(epsilon)?.unsqueeze(1)?;
            embedding_sum.broadcast_div(&cluster_usage)?
        };
        let c2 = ((&embedding * &embedding)?.sum(D::Minus1)? / 2.0)?;
        Ok(Self {
            initialized,
            cluster_usage,
            embedding_sum,
            embedding,
            c2,
            epsilon,
            dim,
            span_encode: tracing::span!(tracing::Level::TRACE, "euclidean-encode"),
            span_decode: tracing::span!(tracing::Level::TRACE, "euclidean-decode"),
        })
    }

    // Reference implementation: materializes all pairwise differences between
    // the inputs and the codebook, then takes the argmin of the squared
    // distances.
    pub fn encode_very_slow(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span_encode.enter();
        let mut target_shape = xs.dims().to_vec();
        target_shape.pop();
        let xs = xs.flatten_to(D::Minus2)?;
        let _ = xs.dims2()?;
        let cluster_usage = self.cluster_usage.maximum(self.epsilon)?.unsqueeze(1)?;
        let embedding = self.embedding_sum.broadcast_div(&cluster_usage)?;
        let diff = xs.unsqueeze(1)?.broadcast_sub(&embedding.unsqueeze(0)?)?;
        let dists = diff.sqr()?.sum(D::Minus1)?;
        let codes = dists.argmin(D::Minus1)?;
        codes.reshape(target_shape)
    }

    pub fn encode_slow(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span_encode.enter();
        let mut target_shape = xs.dims().to_vec();
        target_shape.pop();
        let xs = xs.flatten_to(D::Minus2)?;
        let _ = xs.dims2()?;
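        // argmin_j ||x - e_j||^2 = argmin_j (||e_j||^2 / 2 - <x, e_j>): the
        // ||x||^2 term is the same for every j, so with `c2` caching
        // ||e_j||^2 / 2, a single matmul is enough to rank the codebook rows.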
        let dot_prod = xs.matmul(&self.embedding.t()?)?;
        let codes = self.c2.broadcast_sub(&dot_prod)?.argmin(D::Minus1)?;
        codes.reshape(target_shape)
    }

    // Fast path: delegates the nearest-neighbour search to the parallel CPU
    // custom op defined above.
    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span_encode.enter();
        let mut target_shape = xs.dims().to_vec();
        target_shape.pop();
        let xs = xs.flatten_to(D::Minus2)?;
        let _ = xs.dims2()?;
        let codes = Tensor::apply_op2(&xs, &self.embedding, CodebookEncode)?;
        codes.reshape(target_shape)
    }

    pub fn decode(&self, indexes: &Tensor) -> Result<Tensor> {
        let _enter = self.span_decode.enter();
        let mut final_dims = indexes.dims().to_vec();
        final_dims.push(self.dim);
        let indexes = indexes.flatten_all()?;
        let values = self.embedding.index_select(&indexes, 0)?;
        let values = values.reshape(final_dims)?;
        Ok(values)
    }
}
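
// Illustrative sketch, not part of the original file: with `cluster_usage`
// all ones, `embedding` equals `embedding_sum`, so encoding a vector close to
// a codebook row must return that row's index, and `decode` maps indices back
// to embedding rows.
#[cfg(test)]
mod euclidean_codebook_tests {
    use super::*;
    use candle::{DType, Device};
    use std::collections::HashMap;

    #[test]
    fn encode_recovers_codebook_rows() -> Result<()> {
        let dev = Device::Cpu;
        let mut ws = HashMap::new();
        ws.insert("_initialized".to_string(), Tensor::ones(1, DType::F32, &dev)?);
        ws.insert("cluster_usage".to_string(), Tensor::ones(4, DType::F32, &dev)?);
        ws.insert(
            "embedding_sum".to_string(),
            Tensor::new(&[[0f32, 0.], [1., 0.], [0., 1.], [1., 1.]], &dev)?,
        );
        let vb = VarBuilder::from_tensors(ws, DType::F32, &dev);
        let cb = EuclideanCodebook::new(2, 4, vb)?;
        // Two vectors, closest to codebook rows 3 and 2 respectively.
        let xs = Tensor::new(&[[[0.9f32, 1.1], [0.1, 0.8]]], &dev)?;
        let codes = cb.encode(&xs)?;
        assert_eq!(codes.to_vec2::<u32>()?, vec![vec![3, 2]]);
        let decoded = cb.decode(&codes)?;
        assert_eq!(decoded.dims(), &[1, 2, 2]);
        Ok(())
    }
}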

#[allow(unused)]
#[derive(Debug, Clone)]
pub struct VectorQuantization {
    project_in: Option<Linear>,
    project_out: Option<Linear>,
    codebook: EuclideanCodebook,
}

impl VectorQuantization {
    pub fn new(
        dim: usize,
        codebook_size: usize,
        codebook_dim: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let codebook_dim = codebook_dim.unwrap_or(dim);
        let (project_in, project_out) = if codebook_dim == dim {
            (None, None)
        } else {
            let p_in = linear(dim, codebook_dim, vb.pp("project_in"))?;
            let p_out = linear(codebook_dim, dim, vb.pp("project_out"))?;
            (Some(p_in), Some(p_out))
        };
        let codebook = EuclideanCodebook::new(codebook_dim, codebook_size, vb.pp("_codebook"))?;
        Ok(Self { project_in, project_out, codebook })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let xs = xs.t()?.apply(&self.project_in.as_ref())?;
        self.codebook.encode_slow(&xs)
    }

    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        let quantized = self.codebook.decode(codes)?;
        let quantized = match &self.project_out {
            None => quantized,
            Some(p) => quantized.apply(p)?,
        };
        quantized.t()
    }
}
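
// Added note on layout, inferred from the transposes above: `encode` takes
// channel-first activations of shape (batch, dim, time), moves them to
// (batch, time, dim) before the optional projection and codebook lookup, and
// returns codes of shape (batch, time); `decode` reverses the process.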

#[derive(Debug, Clone)]
pub struct ResidualVectorQuantization {
    layers: Vec<VectorQuantization>,
}

impl ResidualVectorQuantization {
    pub fn new(
        n_q: usize,
        dim: usize,
        codebook_size: usize,
        codebook_dim: Option<usize>,
        vb: VarBuilder,
    ) -> Result<Self> {
        let vb = vb.pp("layers");
        let mut layers = Vec::with_capacity(n_q);
        for i in 0..n_q {
            let layer = VectorQuantization::new(dim, codebook_size, codebook_dim, vb.pp(i))?;
            layers.push(layer)
        }
        Ok(Self { layers })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let mut codes = Vec::with_capacity(self.layers.len());
        let mut residual = xs.clone();
        for layer in self.layers.iter() {
            let indices = layer.encode(&residual)?;
            let quantized = layer.decode(&indices)?;
            residual = (residual - quantized)?;
            codes.push(indices)
        }
        Tensor::stack(&codes, 0)
    }

    pub fn decode(&self, xs: &Tensor) -> Result<Tensor> {
        if self.layers.is_empty() {
            candle::bail!("empty layers in ResidualVectorQuantization")
        }
        if self.layers.len() != xs.dim(0)? {
            candle::bail!(
                "mismatch between the number of layers {} and the code shape {:?}",
                self.layers.len(),
                xs.shape()
            )
        }
        let mut quantized = self.layers[0].decode(&xs.i(0)?)?;
        for (i, layer) in self.layers.iter().enumerate().skip(1) {
            let xs = xs.i(i)?;
            quantized = (quantized + layer.decode(&xs))?
        }
        Ok(quantized)
    }
}
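
// The residual scheme, spelled out: with layers Q_1..Q_n and r_1 = xs, each
// step computes c_i = Q_i.encode(r_i) and r_{i+1} = r_i - Q_i.decode(c_i), so
// every layer quantizes the error left by the previous ones and `decode`
// reconstructs sum_i Q_i.decode(c_i).
//
// Illustrative shape check, not part of the original file (the weights are
// all zeros here, so only the shapes are meaningful): codes are stacked on a
// new leading dim of size n_q and decoding restores the input layout.
#[cfg(test)]
mod residual_vq_tests {
    use super::*;
    use candle::{DType, Device};

    #[test]
    fn encode_decode_shapes() -> Result<()> {
        let dev = Device::Cpu;
        let vb = VarBuilder::zeros(DType::F32, &dev);
        // n_q = 2 quantizers over dim = 4 with 8 codebook entries each.
        let rvq = ResidualVectorQuantization::new(2, 4, 8, None, vb)?;
        let xs = Tensor::zeros((1, 4, 3), DType::F32, &dev)?; // (batch, dim, time)
        let codes = rvq.encode(&xs)?;
        assert_eq!(codes.dims(), &[2, 1, 3]); // (n_q, batch, time)
        let ys = rvq.decode(&codes)?;
        assert_eq!(ys.dims(), xs.dims());
        Ok(())
    }
}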

#[allow(unused)]
#[derive(Debug, Clone)]
pub struct ResidualVectorQuantizer {
    vq: ResidualVectorQuantization,
    input_proj: Option<candle_nn::Conv1d>,
    output_proj: Option<candle_nn::Conv1d>,
}

impl ResidualVectorQuantizer {
    pub fn new(
        dim: usize,
        input_dim: Option<usize>,
        output_dim: Option<usize>,
        n_q: usize,
        bins: usize,
        force_projection: bool,
        vb: VarBuilder,
    ) -> Result<Self> {
        let input_dim = input_dim.unwrap_or(dim);
        let output_dim = output_dim.unwrap_or(dim);

        let input_proj = if input_dim == dim && !force_projection {
            None
        } else {
            let c = candle_nn::conv1d_no_bias(
                input_dim,
                dim,
                1,
                Default::default(),
                vb.pp("input_proj"),
            )?;
            Some(c)
        };
        let output_proj = if output_dim == dim && !force_projection {
            None
        } else {
            let c = candle_nn::conv1d_no_bias(
                dim,
                output_dim,
                1,
                Default::default(),
                vb.pp("output_proj"),
            )?;
            Some(c)
        };

        let vq = ResidualVectorQuantization::new(n_q, dim, bins, None, vb.pp("vq"))?;
        Ok(Self { vq, input_proj, output_proj })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let codes = self.vq.encode(&xs.apply(&self.input_proj.as_ref())?)?;
        codes.transpose(0, 1)
    }

    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        let codes = codes.transpose(0, 1)?;
        let quantized = self.vq.decode(&codes)?;
        match &self.output_proj {
            None => Ok(quantized),
            Some(p) => quantized.apply(p),
        }
    }
}
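
// Added note: `ResidualVectorQuantization::encode` stacks codes as
// (n_q, batch, time); the `transpose(0, 1)` calls above swap this to
// (batch, n_q, time) on the way out of `encode` and back on the way into
// `decode`.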

#[derive(Debug, Clone)]
pub struct SplitResidualVectorQuantizer {
    rvq_first: ResidualVectorQuantizer,
    rvq_rest: ResidualVectorQuantizer,
    n_q: usize,
    span_encode: tracing::Span,
    span_decode: tracing::Span,
}

impl SplitResidualVectorQuantizer {
    pub fn new(
        dim: usize,
        input_dim: Option<usize>,
        output_dim: Option<usize>,
        n_q: usize,
        bins: usize,
        vb: VarBuilder,
    ) -> Result<Self> {
        let rvq_first = ResidualVectorQuantizer::new(
            dim,
            input_dim,
            output_dim,
            1,
            bins,
            true,
            vb.pp("rvq_first"),
        )?;
        let rvq_rest = ResidualVectorQuantizer::new(
            dim,
            input_dim,
            output_dim,
            n_q - 1,
            bins,
            true,
            vb.pp("rvq_rest"),
        )?;
        let span_encode = tracing::span!(tracing::Level::TRACE, "split-rvq-encode");
        let span_decode = tracing::span!(tracing::Level::TRACE, "split-rvq-decode");
        Ok(Self { rvq_first, rvq_rest, n_q, span_encode, span_decode })
    }

    pub fn encode(&self, xs: &Tensor) -> Result<Tensor> {
        let _enter = self.span_encode.enter();
        let codes = self.rvq_first.encode(xs)?;
        if self.n_q > 1 {
            let rest_codes = self.rvq_rest.encode(xs)?;
            Tensor::cat(&[codes, rest_codes], 1)
        } else {
            Ok(codes)
        }
    }

    pub fn decode(&self, codes: &Tensor) -> Result<Tensor> {
        let _enter = self.span_decode.enter();
        let quantized = self.rvq_first.decode(&codes.i((.., ..1))?)?;
        let quantized = if self.n_q > 1 {
            (quantized + self.rvq_rest.decode(&codes.i((.., 1..))?))?
        } else {
            quantized
        };
        Ok(quantized)
    }
}
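
// Added note: with the (batch, n_q, time) code layout produced by
// `ResidualVectorQuantizer::encode`, the first codebook occupies
// `codes.i((.., ..1))` and the remaining n_q - 1 codebooks
// `codes.i((.., 1..))`, which is exactly the slicing `decode` uses above.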