use crate::transformer::CaSrc;
use candle::{Context, DType, Result, Tensor, D};
use candle_nn::{linear_no_bias, Linear, VarBuilder};
use candle_transformers::models::t5;

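/// Configuration for the TTS pipeline: the T5 text-encoder config, the audio
/// language-model config, the Mimi codec config, and the generation limits
/// (maximum duration, speaker-conditioning duration, and speaker count).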
pub struct Config {
    pub t5: t5::Config,
    pub lm: crate::lm::Config,
    pub mimi: crate::mimi::Config,
    pub max_duration_s: f64,
    pub speaker_cond_duration_s: f64,
    pub max_speakers: usize,
}

impl Config {
    /// v0.1 preset: up to 60s of generated audio, 4s of speaker
    /// conditioning, and at most 5 speakers.
    pub fn v0_1(t5: t5::Config) -> Self {
        let lm = crate::lm::Config::tts_v0_1();
        let mimi = crate::mimi::Config::v0_1(None);
        Self { t5, lm, mimi, max_duration_s: 60., speaker_cond_duration_s: 4., max_speakers: 5 }
    }

    /// v0.2 preset: same LM and Mimi configs as v0.1, but with 10s of
    /// speaker conditioning and at most 2 speakers.
    pub fn v0_2(t5: t5::Config) -> Self {
        let lm = crate::lm::Config::tts_v0_1();
        let mimi = crate::mimi::Config::v0_1(None);
        Self { t5, lm, mimi, max_duration_s: 60., speaker_cond_duration_s: 10., max_speakers: 2 }
    }
}

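/// Text-to-speech model combining a T5 text encoder, the audio language
/// model, and an optional Mimi-based speaker-conditioning branch.
///
/// A rough construction sketch; the weight file names and the way the T5
/// config is obtained are placeholders, not something this crate prescribes:
///
/// ```ignore
/// let dev = candle::Device::Cpu;
/// let vb_t5 = unsafe {
///     candle_nn::VarBuilder::from_mmaped_safetensors(&["t5.safetensors"], candle::DType::F32, &dev)?
/// };
/// let vb_lm = unsafe {
///     candle_nn::VarBuilder::from_mmaped_safetensors(&["lm.safetensors"], candle::DType::BF16, &dev)?
/// };
/// let cfg = Config::v0_1(t5_cfg); // t5_cfg: t5::Config for the matching checkpoint
/// let mut model = Model::new(&cfg, vb_t5, vb_lm, None)?;
/// ```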
#[derive(Clone)]
pub struct Model {
    t5: t5::T5EncoderModel,
    pub lm: crate::lm::LmModel,
    speaker_cond: Option<(crate::mimi::Mimi, Linear)>,
    t5_proj: Linear,
    pub sample_rate: f64,
    frame_rate: f64,
    audio_vocab_size: u32,
    audio_codebooks: usize,
    pub max_duration_s: f64,
    max_speakers: usize,
    end_of_gen: Option<usize>,
}

impl Model {
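    /// Builds the model from three sets of weights: `vb_t5` for the T5 text
    /// encoder, `vb_lm` for the audio language model (which also holds the
    /// conditioner projections), and, for speaker-conditioned checkpoints,
    /// `vb_speaker_cond` for the Mimi encoder applied to the speaker sample.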
    pub fn new(
        cfg: &Config,
        vb_t5: VarBuilder,
        vb_lm: VarBuilder,
        vb_speaker_cond: Option<VarBuilder>,
    ) -> Result<Self> {
        let t5 = t5::T5EncoderModel::load(vb_t5, &cfg.t5)?;
        let speaker_cond = match vb_speaker_cond {
            None => None,
            Some(vb) => {
                let mimi = crate::mimi::Mimi::new(cfg.mimi.clone(), vb)?;
                let proj = linear_no_bias(
                    cfg.mimi.seanet.dimension,
                    cfg.lm.transformer.d_model,
                    vb_lm.pp("condition_provider.conditioners.speaker_wavs.output_proj"),
                )?;
                Some((mimi, proj))
            }
        };
        let t5_proj = {
            // Speaker-conditioned checkpoints use a diarized-transcript
            // conditioner, plain ones a transcript conditioner.
            let name = if speaker_cond.is_some() {
                "condition_provider.conditioners.diarized_transcript_in_segment.output_proj"
            } else {
                "condition_provider.conditioners.transcript_in_segment.output_proj"
            };
            linear_no_bias(cfg.t5.d_model, cfg.lm.transformer.d_model, vb_lm.pp(name))?
        };
        let lm =
            crate::lm::LmModel::new(&cfg.lm, crate::nn::MaybeQuantizedVarBuilder::Real(vb_lm))?;
        Ok(Self {
            t5,
            lm,
            speaker_cond,
            t5_proj,
            sample_rate: cfg.mimi.sample_rate,
            frame_rate: cfg.mimi.frame_rate,
            audio_vocab_size: cfg.lm.audio_vocab_size as u32,
            audio_codebooks: cfg.lm.audio_codebooks,
            max_duration_s: cfg.max_duration_s,
            max_speakers: cfg.max_speakers,
            end_of_gen: None,
        })
    }
}

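/// Adds fixed sinusoidal position embeddings to `xs`, expected to have shape
/// `(batch, seq_len, dim)`. With `h = dim / 2`, position `p` and channel
/// `i < h` get `cos(p * w_i)` at channel `i` and `sin(p * w_i)` at channel
/// `h + i`, where `w_i = 10000^(-i / (h - 1))`. The addition is performed in
/// `f32` before casting back to the input dtype.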
pub fn add_sin_embeddings(xs: &Tensor) -> Result<Tensor> {
    let target_dtype = xs.dtype();
    let (_b_size, seq_len, dim) = xs.dims3()?;
    let dev = xs.device();
    let half_dim = dim / 2;
    let positions =
        Tensor::arange(0u32, seq_len as u32, dev)?.unsqueeze(1)?.to_dtype(DType::F32)?;
    let inv_freq: Vec<_> =
        (0..half_dim).map(|i| 1f32 / 10000f32.powf(i as f32 / (half_dim - 1) as f32)).collect();
    let inv_freq_len = inv_freq.len();
    let inv_freq = Tensor::from_vec(inv_freq, (1, inv_freq_len), dev)?;
    let freqs = positions.broadcast_mul(&inv_freq)?;
    let pos_emb = Tensor::cat(&[freqs.cos()?, freqs.sin()?], D::Minus1)?;
    let xs = xs.to_dtype(DType::F32)?.broadcast_add(&pos_emb)?;
    xs.to_dtype(target_dtype)
}

impl Model {
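    /// Computes the cross-attention conditioning from the tokenized
    /// transcript `token_ids`, and optionally from a speaker PCM sample.
    ///
    /// Without speaker conditioning, this is simply the projected T5
    /// encoding. With it, the result is a batch of two sequences, one
    /// conditioned on the speaker embedding and one with it zeroed out, as
    /// consumed by the classifier-free guidance path in `sample_lp`.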
    pub fn conditions(
        &mut self,
        token_ids: &Tensor,
        speaker_pcm: Option<&Tensor>,
    ) -> Result<Tensor> {
        let t5_condition =
            self.t5.forward(token_ids)?.to_dtype(candle::DType::BF16)?.apply(&self.t5_proj)?;
        let conditions = match speaker_pcm {
            None => t5_condition,
            Some(speaker_pcm) => {
                let sc = match self.speaker_cond.as_mut() {
                    None => candle::bail!("speaker_pcm specified without a speaker-cond model"),
                    Some((mimi, proj)) => mimi
                        .encode_pre_quantize(speaker_pcm)?
                        .t()?
                        .to_dtype(candle::DType::BF16)?
                        .apply(proj)?,
                };
                // Build the conditioned (c1) and unconditioned (c2) sequences
                // for classifier-free guidance; unused speaker slots are
                // zero-padded up to `max_speakers`.
                let z = sc.zeros_like()?;
                let mut c1 = vec![&t5_condition, &sc];
                let mut c2 = vec![&t5_condition, &z];
                for _i in 0..self.max_speakers - 1 {
                    c1.push(&z);
                    c2.push(&z);
                }
                let c1 = Tensor::cat(&c1, 1)?;
                let c2 = Tensor::cat(&c2, 1)?;
                // Batch the two variants together so a single forward pass
                // produces both sets of logits.
                let xs = Tensor::cat(&[&c1, &c2], 0)?;
                add_sin_embeddings(&xs)?
            }
        };
        Ok(conditions)
    }

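    /// Samples audio tokens using the default logits processor: a fixed seed
    /// with top-k sampling, `k = 100` and temperature 0.8.
    ///
    /// A usage sketch; the `cfg_alpha` value of 3.0 is an arbitrary example
    /// and only has an effect on speaker-conditioned models:
    ///
    /// ```ignore
    /// let conditions = model.conditions(&token_ids, Some(&speaker_pcm))?;
    /// let audio_tokens = model.sample(&conditions, 3.0)?;
    /// ```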
    pub fn sample(&mut self, conditions: &Tensor, cfg_alpha: f64) -> Result<Vec<Vec<u32>>> {
        let lp = candle_transformers::generation::LogitsProcessor::from_sampling(
            299792458,
            candle_transformers::generation::Sampling::TopK { k: 100, temperature: 0.8 },
        );
        self.sample_lp(conditions, cfg_alpha, lp)
    }

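    /// Same as [`Self::sample`], but with a caller-provided logits processor.
    ///
    /// Returns one `Vec<u32>` of codebook tokens per time step. The first
    /// (semantic) codebook is stored at the step that produced it while the
    /// remaining (acoustic) codebooks are stored two steps earlier, matching
    /// the model's delay pattern; slots that are never written remain
    /// `u32::MAX`.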
    pub fn sample_lp(
        &mut self,
        conditions: &Tensor,
        cfg_alpha: f64,
        mut lp: candle_transformers::generation::LogitsProcessor,
    ) -> Result<Vec<Vec<u32>>> {
        let max_steps = (self.max_duration_s * self.frame_rate) as usize + 1;
        let audio_codebooks = self.audio_codebooks;
        let audio_vocab_size = self.audio_vocab_size;
        // Two extra steps so that the acoustic codebooks, which lag two steps
        // behind the semantic one, can be fully flushed.
        let mut audio_tokens: Vec<Vec<u32>> = vec![vec![u32::MAX; audio_codebooks]; max_steps + 2];
        let forced_audio_tokens = crate::lm::ForcedAudioTokens::new(
            2, // acoustic delay in steps
            self.lm.audio_pad_token(),
            &[audio_codebooks],
        );
        let quantizer_bins = audio_vocab_size - 2;
        for step_idx in 0..(max_steps + 2) {
            let mut codes = Vec::with_capacity(audio_codebooks);
            for codebook in 0..audio_codebooks {
                // Feed back the previous step's output. The semantic codebook
                // reads the token stored at step_idx - 1; acoustic tokens are
                // stored with a two-step delay, so the previous acoustic
                // output lives at step_idx - 3. Before any history exists,
                // the special initial token (audio_vocab_size - 1) is used.
                let t = if codebook == 0 {
                    if step_idx == 0 {
                        audio_vocab_size - 1
                    } else {
                        audio_tokens[step_idx - 1][0]
                    }
                } else if step_idx <= 2 {
                    audio_vocab_size - 1
                } else {
                    audio_tokens[step_idx - 3][codebook]
                };
                let t = Tensor::new(&[t], conditions.device())?.unsqueeze(0)?;
                codes.push(Some(t))
            }
            let (_text_logits, ys) = self.lm.forward_ca(
                None,
                codes,
                &CaSrc::Tokens(conditions.clone()),
                None,
                &().into(),
            )?;
            // Speaker-conditioned models are sampled with classifier-free
            // guidance over the conditioned/unconditioned pair built in
            // `conditions`.
            let last_audio_tokens = if self.speaker_cond.is_some() {
                self.lm.depformer_sample_cfg(
                    &ys,
                    cfg_alpha,
                    None,
                    forced_audio_tokens.forced_tokens(step_idx),
                    &mut lp,
                )?
            } else {
                self.lm.depformer_sample(
                    &ys,
                    None,
                    forced_audio_tokens.forced_tokens(step_idx),
                    &mut lp,
                )?
            };
            let last_audio_tokens = last_audio_tokens.context("no depformer")?;
            for (c_idx, token) in last_audio_tokens.into_iter().enumerate() {
                if step_idx > 0 && token >= quantizer_bins && self.end_of_gen.is_none() {
                    // A token outside the quantizer range signals the end of
                    // generation; run two more steps so the delayed acoustic
                    // codebooks catch up.
                    self.end_of_gen = Some(step_idx + 2)
                }
                let delay = if c_idx == 0 { 0 } else { 2 };
                audio_tokens[step_idx.saturating_sub(delay)][c_idx] = token
            }
            if Some(step_idx) == self.end_of_gen {
                break;
            }
        }
        Ok(audio_tokens)
    }
}
225}