use crate::ModelState;
use crate::models::seanet::{SEANetDecoder, SEANetEncoder};
use crate::models::transformer::ProjectedTransformer;
use crate::modules::conv::{ConvDownsample1d, ConvTrUpsample1d};
use candle_core::{Result, Tensor};
use candle_nn::{Conv1d, Conv1dConfig, Module, VarBuilder};
#[derive(Clone)]
pub struct Quantizer {
output_proj: Conv1d,
}
impl Quantizer {
pub fn new(dimension: usize, output_dimension: usize, vb: VarBuilder) -> Result<Self> {
let config = Conv1dConfig {
groups: 1,
padding: 0,
stride: 1,
dilation: 1,
..Default::default()
};
let output_proj = candle_nn::conv1d_no_bias(
dimension,
output_dimension,
1,
config,
vb.pp("output_proj"),
)?;
Ok(Self { output_proj })
}
pub fn forward(&self, x: &Tensor) -> Result<Tensor> {
self.output_proj.forward(x)
}
}
#[derive(Clone)]
pub struct MimiModel {
pub encoder: SEANetEncoder,
pub decoder: SEANetDecoder,
pub encoder_transformer: ProjectedTransformer,
pub decoder_transformer: ProjectedTransformer,
pub quantizer: Quantizer,
pub downsample: Option<ConvDownsample1d>,
pub upsample: Option<ConvTrUpsample1d>,
pub frame_rate: f64,
pub encoder_frame_rate: f64,
pub sample_rate: usize,
pub channels: usize,
pub dimension: usize,
}
impl MimiModel {
#[allow(clippy::too_many_arguments)]
pub fn new(
encoder: SEANetEncoder,
decoder: SEANetDecoder,
encoder_transformer: ProjectedTransformer,
decoder_transformer: ProjectedTransformer,
frame_rate: f64,
encoder_frame_rate: f64,
sample_rate: usize,
channels: usize,
dimension: usize, output_dimension: usize, name: &str,
vb: VarBuilder,
) -> Result<Self> {
let quantizer = Quantizer::new(dimension, output_dimension, vb.pp("quantizer"))?;
let (downsample, upsample) = if encoder_frame_rate != frame_rate {
let stride = (encoder_frame_rate / frame_rate) as usize;
(
Some(ConvDownsample1d::new(
stride,
output_dimension,
&format!("{}.downsample", name),
vb.pp("downsample"),
)?),
Some(ConvTrUpsample1d::new(
stride,
output_dimension,
&format!("{}.upsample", name),
vb.pp("upsample"),
)?),
)
} else {
(None, None)
};
Ok(Self {
encoder,
decoder,
encoder_transformer,
decoder_transformer,
quantizer,
downsample,
upsample,
frame_rate,
encoder_frame_rate,
sample_rate,
channels,
dimension,
})
}
pub fn frame_size(&self) -> usize {
(self.sample_rate as f64 / self.frame_rate) as usize
}
pub fn encode_to_latent(
&self,
x: &Tensor,
model_state: &mut ModelState,
step: usize,
) -> Result<Tensor> {
let _frame_size = self.frame_size();
let (b, c, _t_orig) = x.dims3()?;
let t = x.dims()[2];
let hop = self.frame_size();
let x = if !t.is_multiple_of(hop) {
let padding = hop - (t % hop);
let pad = Tensor::zeros((b, c, padding), x.dtype(), x.device())?;
Tensor::cat(&[x, &pad], 2)?
} else {
x.clone()
};
let mut emb = self.encoder.forward(&x, model_state, step)?;
let mut embs = self.encoder_transformer.forward(&emb, model_state, step)?;
emb = embs.remove(0);
if let Some(down) = &self.downsample {
emb = down.forward(&emb, model_state, step)?;
}
Ok(emb)
}
pub fn decode_from_latent(
&self,
latent: &Tensor,
model_state: &mut ModelState,
step: usize,
) -> Result<Tensor> {
let mut emb = latent.clone();
if let Some(up) = &self.upsample {
emb = up.forward(&emb, model_state, step)?;
}
let mut embs = self.decoder_transformer.forward(&emb, model_state, step)?;
emb = embs.remove(0);
let out = self.decoder.forward(&emb, model_state, step)?;
Ok(out)
}
pub fn quantize(&self, x: &Tensor) -> Result<Tensor> {
self.quantizer.forward(x)
}
}
#[cfg(test)]
mod tests {
use super::*;
use candle_core::{DType, Device, Tensor};
use candle_nn::VarBuilder;
use std::collections::HashMap;
#[test]
fn test_mimi_shapes() -> Result<()> {
let device = Device::Cpu;
let vb = VarBuilder::zeros(DType::F32, &device);
let encoder = SEANetEncoder::new(
1,
128,
32,
1,
&[2, 2],
7,
7,
3,
2,
"constant",
2,
"encoder",
vb.pp("encoder"),
)?;
let decoder = SEANetDecoder::new(
1,
128,
32,
1,
&[2, 2],
7,
7,
3,
2,
"constant",
2,
"decoder",
vb.pp("decoder"),
)?;
let encoder_transformer = ProjectedTransformer::new(
128,
vec![128],
128,
4,
1,
0.1,
10,
10000.0,
512,
"enc_tr",
vb.pp("enc_tr"),
)?;
let decoder_transformer = ProjectedTransformer::new(
128,
vec![128],
128,
4,
1,
0.1,
10,
10000.0,
512,
"dec_tr",
vb.pp("dec_tr"),
)?;
let mimi = MimiModel::new(
encoder,
decoder,
encoder_transformer,
decoder_transformer,
12.5,
50.0,
16000,
1,
128,
512,
"mimi",
vb.pp("mimi"),
)?;
let _audio = Tensor::zeros((1, 1, 1280), DType::F32, &device)?;
let mut _model_state: HashMap<String, HashMap<String, Tensor>> = HashMap::new();
assert_eq!(mimi.frame_size(), 1280);
Ok(())
}
}