1mod error;
2mod logits_sampler;
3mod onnx_builder;
4mod text;
5
6use {
7 async_stream::stream,
8 log::{debug, info},
9 ndarray::{
10 Array, Array2, ArrayBase, ArrayD, ArrayView2, Axis, IxDyn, OwnedRepr, concatenate, s,
11 },
12 ort::{
13 inputs,
14 session::{RunOptions, Session},
15 value::{Tensor, TensorRef},
16 },
17 rodio::{Source, buffer::SamplesBuffer, decoder::Decoder, source::UniformSourceIterator},
18 std::{io::Cursor, path::Path, time::SystemTime},
19 tokio::fs::read,
20};
21pub use {
22 error::*,
23 futures::{Stream, StreamExt},
24 logits_sampler::*,
25 onnx_builder::*,
26 text::*,
27};
28
/// Token id treated as end-of-sequence by the step decoder loop: generation
/// stops when the arg-max token equals it, and any occurrence remaining in the
/// returned sequence is rewritten to 0.
const T2S_DECODER_EOS: i64 = 1024;
/// Vocabulary size handed to `Sampler::new`.
const VOCAB_SIZE: usize = 1025;
/// Number of per-layer K/V cache pairs exchanged with the T2S decoder sessions.
const NUM_LAYERS: usize = 24;

/// Element type of the K/V cache tensors.
type KvDType = f32;
34
/// Features derived from a reference utterance by
/// `GptSoVitsModel::get_reference_data` and consumed by
/// `GptSoVitsModel::synthesize`.
#[derive(Clone)]
pub struct ReferenceData {
    /// Token sequence of the reference text, shaped `(1, n)`.
    ref_seq: Array2<i64>,
    /// BERT features of the reference text, concatenated along axis 0 over all
    /// text segments.
    ref_bert: Array2<f32>,
    /// Reference audio resampled to 32 kHz, shaped `(1, samples)`.
    ref_audio_32k: Array2<f32>,
    /// `ssl_content` output of the SSL session run on the 16 kHz reference audio.
    ssl_content: ArrayBase<OwnedRepr<f32>, IxDyn>,
}
42
43impl AsRef<Self> for ReferenceData {
44 fn as_ref(&self) -> &Self {
45 self
46 }
47}
48
/// GPT-SoVITS inference pipeline: a text front-end plus the ONNX sessions it
/// drives.
pub struct GptSoVitsModel {
    /// Text front-end; yields `(text, token sequence, BERT features)` per segment.
    text_processor: TextProcessor,
    /// Session that produces the final `audio` output.
    sovits: Session,
    /// Session that produces `ssl_content` from the 16 kHz reference audio.
    ssl: Session,
    /// Session that produces `prompts` from `ssl_content`.
    t2s_encoder: Session,
    /// First-step decoder session: returns the initial logits and the
    /// per-layer K/V caches.
    t2s_fs_decoder: Session,
    /// Step decoder session, run once per generated token.
    t2s_s_decoder: Session,
    /// Number of per-layer K/V cache pairs (initialised from `NUM_LAYERS`).
    num_layers: usize,
    /// Run options passed to every `run_async` call.
    run_options: RunOptions,
}
59
/// Axis-1 capacity used when pre-allocating the K/V cache buffers after the
/// first decoder step.
const INITIAL_CACHE_SIZE: usize = 2048;
/// Axis-1 capacity added each time the K/V cache buffers run out of room.
const CACHE_REALLOC_INCREMENT: usize = 1024;
65
66impl GptSoVitsModel {
67 pub fn new<P>(
73 sovits_path: P,
74 ssl_path: P,
75 t2s_encoder_path: P,
76 t2s_fs_decoder_path: P,
77 t2s_s_decoder_path: P,
78 bert_path: Option<P>,
79 g2pw_path: Option<P>,
80 g2p_en_path: Option<P>,
81 ) -> Result<Self, GSVError>
82 where
83 P: AsRef<Path>,
84 {
85 info!("Initializing TTSModel with ONNX sessions");
86
87 let g2pw = G2PW::new(g2pw_path)?;
88
89 let text_processor =
90 TextProcessor::new(g2pw, G2pEn::new(g2p_en_path)?, BertModel::new(bert_path)?)?;
91
92 Ok(GptSoVitsModel {
93 text_processor,
94 sovits: create_onnx_cpu_session(sovits_path)?,
95 ssl: create_onnx_cpu_session(ssl_path)?,
96 t2s_encoder: create_onnx_cpu_session(t2s_encoder_path)?,
97 t2s_fs_decoder: create_onnx_cpu_session(t2s_fs_decoder_path)?,
98 t2s_s_decoder: create_onnx_cpu_session(t2s_s_decoder_path)?,
99 num_layers: NUM_LAYERS,
100 run_options: RunOptions::new()?,
101 })
102 }
103
104 pub async fn get_reference_data<P, S>(
105 &mut self,
106 reference_audio_path: P,
107 ref_text: S,
108 lang_id: LangId,
109 ) -> Result<ReferenceData, GSVError>
110 where
111 P: AsRef<Path>,
112 S: AsRef<str>,
113 {
114 info!("Processing reference audio and text: {}", ref_text.as_ref());
115 let ref_text = ensure_punctuation(ref_text);
116 let phones = self.text_processor.get_phone_and_bert(&ref_text, lang_id)?;
117 let ref_seq: Vec<i64> = phones.iter().fold(Vec::new(), |mut seq, p| {
118 seq.extend(p.1.clone());
119 seq
120 });
121
122 let ref_bert: Vec<Array2<f32>> = phones.iter().map(|f| f.2.clone()).collect();
123 let ref_bert = concatenate(
125 Axis(0),
126 &ref_bert.iter().map(|v| v.view()).collect::<Vec<_>>(),
127 )?;
128
129 let ref_seq = Array2::from_shape_vec((1, ref_seq.len()), ref_seq)?;
130 let (ref_audio_16k, ref_audio_32k) = read_and_resample_audio(&reference_audio_path).await?;
131 let ssl_content = self.process_ssl(&ref_audio_16k).await?;
132
133 Ok(ReferenceData {
134 ref_seq,
135 ref_bert,
136 ref_audio_32k,
137 ssl_content,
138 })
139 }
140
141 async fn process_ssl(
142 &mut self,
143 ref_audio_16k: &Array2<f32>,
144 ) -> Result<ArrayBase<OwnedRepr<f32>, IxDyn>, GSVError> {
145 let time = SystemTime::now();
146 let ssl_output = self
147 .ssl
148 .run_async(
149 inputs!["ref_audio_16k" => TensorRef::from_array_view(ref_audio_16k).unwrap()],
150 &self.run_options,
151 )?
152 .await?;
153 debug!("SSL processing time: {:?}", time.elapsed()?);
154 Ok(ssl_output["ssl_content"]
155 .try_extract_array::<f32>()?
156 .into_owned())
157 }
158
    /// Runs the step decoder session repeatedly, sampling one token per step,
    /// until a stop condition is met; returns the generated tokens shaped
    /// `[1, 1, len]`.
    ///
    /// * `y_vec` – running token sequence; fed to the session as `iy` and
    ///   extended by one sampled token each step.
    /// * `k_caches` / `v_caches` – per-layer cache buffers; only the first
    ///   `valid_len` entries along axis 1 are sent to the session each step.
    /// * `prefix_len` – forwarded unchanged as the `y_len` input.
    /// * `initial_valid_len` – number of axis-1 cache entries already filled.
    ///
    /// The loop stops once `idx >= 1500` or the arg-max of the step's logits
    /// equals `T2S_DECODER_EOS`.
    ///
    /// # Errors
    /// Propagates failures from tensor construction, the session run, output
    /// extraction and construction of the result array.
    async fn run_t2s_s_decoder_loop(
        &mut self,
        sampler: &mut Sampler,
        sampling_param: SamplingParams,
        mut y_vec: Vec<i64>,
        mut k_caches: Vec<ArrayBase<OwnedRepr<KvDType>, IxDyn>>,
        mut v_caches: Vec<ArrayBase<OwnedRepr<KvDType>, IxDyn>>,
        prefix_len: usize,
        initial_valid_len: usize,
    ) -> Result<ArrayBase<OwnedRepr<i64>, IxDyn>, GSVError> {
        let mut idx = 0;
        let mut valid_len = initial_valid_len;
        // The loop pushes at most 1501 tokens (see the `idx >= 1500` stop
        // below), so this reservation covers every push.
        y_vec.reserve(2048);

        loop {
            // SAFETY / NOTE(review): the view is built from `y_vec`'s own
            // pointer and length, so it covers exactly the elements that
            // exist right now. Going through a raw pointer detaches the view
            // from `y_vec`'s borrow, which is what allows `y_vec.push` below
            // while `output` is still alive. This presumably relies on the
            // session not reading the `iy` input after the run has completed
            // — confirm against the `ort` API contract.
            let mut inputs = inputs![
                "iy" => TensorRef::from_array_view(unsafe {ArrayView2::from_shape_ptr((1, y_vec.len()), y_vec.as_ptr())})?,
                "y_len" => Tensor::from_array(Array::from_vec(vec![prefix_len as i64]))?,
                "idx" => Tensor::from_array(Array::from_vec(vec![idx as i64]))?,
            ];

            // Send only the filled part of each layer's cache; `to_owned`
            // copies that part into a fresh array for the tensor.
            for i in 0..self.num_layers {
                let k = k_caches[i].slice(s![.., 0..valid_len, ..]).to_owned();
                let v = v_caches[i].slice(s![.., 0..valid_len, ..]).to_owned();

                inputs.push((
                    format!("ik_cache_{}", i).into(),
                    Tensor::from_array(k)?.into(),
                ));
                inputs.push((
                    format!("iv_cache_{}", i).into(),
                    Tensor::from_array(v)?.into(),
                ));
            }
            let mut output = self
                .t2s_s_decoder
                .run_async(inputs, &self.run_options)?
                .await?;

            // Copy the logits out of the session output into an owned Vec.
            // `as_slice_mut().unwrap()` panics if the extracted array is not
            // contiguous.
            let mut logits = output["logits"].try_extract_array_mut::<f32>()?;
            let mut logits = logits.as_slice_mut().unwrap().to_owned();

            // Drop the last logit during the first 11 steps.
            // NOTE(review): with VOCAB_SIZE = 1025 and T2S_DECODER_EOS = 1024
            // the last entry is presumably the EOS logit, which would keep
            // generation from ending this early — confirm.
            if idx < 11 {
                logits.pop();
            }

            y_vec.push(sampler.sample(&mut logits, &y_vec, &sampling_param));

            // Computed after sampling: `sample` received `&mut logits`, so
            // this sees whatever the sampler left in the buffer.
            let argmax_value = argmax(&logits);

            let new_valid_len = valid_len + 1;

            // Grow every layer's buffers by CACHE_REALLOC_INCREMENT along
            // axis 1 when the next entry would not fit, carrying over the
            // filled part.
            if new_valid_len > k_caches[0].shape()[1] {
                info!(
                    "Reallocating KV cache from {} to {}",
                    k_caches[0].shape()[1],
                    k_caches[0].shape()[1] + CACHE_REALLOC_INCREMENT
                );
                for i in 0..self.num_layers {
                    let old_k = &k_caches[i];
                    let old_v = &v_caches[i];

                    let mut new_k_dims = old_k.raw_dim().clone();
                    new_k_dims[1] += CACHE_REALLOC_INCREMENT;
                    let mut new_v_dims = old_v.raw_dim().clone();
                    new_v_dims[1] += CACHE_REALLOC_INCREMENT;

                    let mut new_k = Array::zeros(new_k_dims);
                    let mut new_v = Array::zeros(new_v_dims);

                    new_k
                        .slice_mut(s![.., 0..valid_len, ..])
                        .assign(&old_k.slice(s![.., 0..valid_len, ..]));
                    new_v
                        .slice_mut(s![.., 0..valid_len, ..])
                        .assign(&old_v.slice(s![.., 0..valid_len, ..]));

                    k_caches[i] = new_k;
                    v_caches[i] = new_v;
                }
            }

            // Of the caches returned by the session, copy only the entry at
            // axis-1 index `valid_len` into the local buffers.
            for i in 0..self.num_layers {
                let inc_k_cache =
                    output[format!("k_cache_{}", i)].try_extract_array::<KvDType>()?;
                let inc_v_cache =
                    output[format!("v_cache_{}", i)].try_extract_array::<KvDType>()?;

                let k_new_slice = inc_k_cache.slice(s![.., valid_len, ..]);
                let v_new_slice = inc_v_cache.slice(s![.., valid_len, ..]);

                k_caches[i]
                    .slice_mut(s![.., valid_len, ..])
                    .assign(&k_new_slice);
                v_caches[i]
                    .slice_mut(s![.., valid_len, ..])
                    .assign(&v_new_slice);
            }

            valid_len = new_valid_len;

            if idx >= 1500 || argmax_value == T2S_DECODER_EOS {
                // Keep `y_vec[len - idx + 1 .. len - 1]`, rewrite any EOS
                // token to 0 and append a trailing 0.
                // NOTE(review): this range is only valid for `idx >= 2`; the
                // `idx < 11` branch above appears to make an earlier EOS stop
                // impossible — confirm `argmax` cannot return
                // T2S_DECODER_EOS on the shortened logits.
                let mut sliced = y_vec[(y_vec.len() - idx + 1)..(y_vec.len() - 1)]
                    .iter()
                    .map(|&i| if i == T2S_DECODER_EOS { 0 } else { i })
                    .collect::<Vec<i64>>();
                sliced.push(0);
                debug!(
                    "t2s final len: {}, prefix_len: {}",
                    sliced.len(),
                    prefix_len
                );
                let y = ArrayD::from_shape_vec(IxDyn(&[1, 1, sliced.len()]), sliced)?;
                return Ok(y);
            }
            idx += 1;
        }
    }
290
291 pub async fn synthesize<R, S>(
298 &mut self,
299 text: S,
300 reference_data: R,
301 sampling_param: SamplingParams,
302 lang_id: LangId,
303 ) -> Result<impl Stream<Item = Result<Vec<f32>, GSVError>> + Send + Unpin, GSVError>
304 where
305 R: AsRef<ReferenceData>,
306 S: AsRef<str>,
307 {
308 let time = SystemTime::now();
309 let texts_and_seqs = self
310 .text_processor
311 .get_phone_and_bert(text.as_ref(), lang_id)?;
312 debug!("g2pw and preprocess time: {:?}", time.elapsed()?);
313 let ref_data = reference_data.as_ref().clone();
314
315 let stream = stream! {
316 for (text, seq, bert) in texts_and_seqs {
317 debug!("process: {:?}", text);
318 yield self.in_stream_once_gen(&text, &bert, &seq, &ref_data, sampling_param).await;
319 }
320 };
321
322 Ok(Box::pin(stream))
323 }
324
325 async fn in_stream_once_gen(
326 &mut self,
327 _text: &str,
328 text_bert: &Array2<f32>,
329 text_seq_vec: &[i64],
330 ref_data: &ReferenceData,
331 sampling_param: SamplingParams,
332 ) -> Result<Vec<f32>, GSVError> {
333 let text_seq = Array2::from_shape_vec((1, text_seq_vec.len()), text_seq_vec.to_vec())?;
334 let mut sampler = Sampler::new(VOCAB_SIZE);
335
336 let prompts = {
337 let time = SystemTime::now();
338 let encoder_output = self
339 .t2s_encoder
340 .run_async(
341 inputs![
342 "ssl_content" => TensorRef::from_array_view(&ref_data.ssl_content)?
343 ],
344 &self.run_options,
345 )?
346 .await?;
347 debug!("T2S Encoder time: {:?}", time.elapsed()?);
348 encoder_output["prompts"]
349 .try_extract_array::<i64>()?
350 .into_owned()
351 };
352
353 let x = concatenate(Axis(1), &[ref_data.ref_seq.view(), text_seq.view()])?.to_owned();
354 let bert = concatenate(
355 Axis(1),
356 &[
357 ref_data.ref_bert.clone().permuted_axes([1, 0]).view(),
358 text_bert.clone().permuted_axes([1, 0]).view(),
359 ],
360 )?;
361
362 let bert = bert.insert_axis(Axis(0)).to_owned();
363
364 let (mut y_vec, _) = prompts.clone().into_raw_vec_and_offset();
365
366 let prefix_len = y_vec.len();
367
368 let (y_vec, k_caches, v_caches, initial_seq_len) = {
369 let time = SystemTime::now();
370 let fs_decoder_output = self
371 .t2s_fs_decoder
372 .run_async(
373 inputs![
374 "x" => Tensor::from_array(x)?,
375 "prompts" => TensorRef::from_array_view(&prompts)?,
376 "bert" => Tensor::from_array(bert)?,
377 ],
378 &self.run_options,
379 )?
380 .await?;
381 debug!("T2S FS Decoder time: {:?}", time.elapsed()?);
382
383 let logits = fs_decoder_output["logits"]
384 .try_extract_array::<f32>()?
385 .into_owned();
386
387 let k_init_first = fs_decoder_output["k_cache_0"].try_extract_array::<KvDType>()?;
390 let initial_dims_dyn = k_init_first.raw_dim();
391 let initial_seq_len = initial_dims_dyn[1];
392
393 let mut large_cache_dims = initial_dims_dyn.clone();
395 large_cache_dims[1] = INITIAL_CACHE_SIZE;
396
397 let mut k_caches = Vec::with_capacity(self.num_layers);
398 let mut v_caches = Vec::with_capacity(self.num_layers);
399
400 for i in 0..self.num_layers {
401 let k_init =
402 fs_decoder_output[format!("k_cache_{}", i)].try_extract_array::<KvDType>()?;
403 let v_init =
404 fs_decoder_output[format!("v_cache_{}", i)].try_extract_array::<KvDType>()?;
405
406 let mut k_large = Array::zeros(large_cache_dims.clone());
408 let mut v_large = Array::zeros(large_cache_dims.clone());
409
410 k_large
412 .slice_mut(s![.., 0..initial_seq_len, ..])
413 .assign(&k_init);
414 v_large
415 .slice_mut(s![.., 0..initial_seq_len, ..])
416 .assign(&v_init);
417
418 k_caches.push(k_large);
419 v_caches.push(v_large);
420 }
421 let (mut logits_vec, _) = logits.into_raw_vec_and_offset();
422 logits_vec.pop(); let sampling_rst = sampler.sample(&mut logits_vec, &y_vec, &sampling_param);
424 y_vec.push(sampling_rst);
425 (y_vec, k_caches, v_caches, initial_seq_len)
426 };
427
428 let time = SystemTime::now();
429 let pred_semantic = self
430 .run_t2s_s_decoder_loop(
431 &mut sampler,
432 sampling_param,
433 y_vec,
434 k_caches,
435 v_caches,
436 prefix_len,
437 initial_seq_len,
438 )
439 .await?;
440 debug!("T2S S Decoder all time: {:?}", time.elapsed()?);
441
442 let time = SystemTime::now();
443 let outputs = self
444 .sovits
445 .run_async(
446 inputs![
447 "text_seq" => TensorRef::from_array_view(&text_seq)?,
448 "pred_semantic" => TensorRef::from_array_view(&pred_semantic)?,
449 "ref_audio" => TensorRef::from_array_view(&ref_data.ref_audio_32k)?
450 ],
451 &self.run_options,
452 )?
453 .await?;
454 debug!("SoVITS time: {:?}", time.elapsed()?);
455 let output_audio = outputs["audio"].try_extract_array::<f32>()?;
456 let (mut audio, _) = output_audio.into_owned().into_raw_vec_and_offset();
457 for sample in &mut audio {
458 *sample = *sample * 4.0;
459 }
460 let max_audio = audio
462 .iter()
463 .filter(|&&x| x.is_finite()) .fold(0.0f32, |acc, &x| acc.max(x.abs()));
465 let audio = if max_audio > 1.0 {
466 audio
467 .into_iter()
468 .map(|x| x / max_audio)
469 .collect::<Vec<f32>>()
470 } else {
471 audio
472 };
473
474 Ok(audio)
475 }
476}
477
/// Returns `text` as an owned string that is guaranteed to end with a
/// sentence terminator, appending `。` when none is present.
///
/// Recognised terminators are the ideographic full stop, the full-width
/// `！ ？ ；` and the ASCII `. ! ? ;`. An empty input yields just `。`.
fn ensure_punctuation<S>(text: S) -> String
where
    S: AsRef<str>,
{
    // The full-width forms are written as escapes so that they cannot be
    // silently folded into their ASCII look-alikes by text normalisation.
    const SENTENCE_TERMINATORS: [char; 8] = [
        '。',
        '\u{FF01}', // full-width exclamation mark
        '\u{FF1F}', // full-width question mark
        '\u{FF1B}', // full-width semicolon
        '.',
        '!',
        '?',
        ';',
    ];

    let text = text.as_ref();
    let mut result = String::with_capacity(text.len() + "。".len());
    result.push_str(text);
    if !text.ends_with(SENTENCE_TERMINATORS) {
        result.push_str("。");
    }
    result
}
491
492fn resample_audio(input: &[f32], in_rate: u32, out_rate: u32) -> Vec<f32> {
493 if in_rate == out_rate {
494 return input.to_owned();
495 }
496
497 UniformSourceIterator::new(SamplesBuffer::new(1, in_rate, input), 1, out_rate).collect()
498}
499
500async fn read_and_resample_audio<P>(path: P) -> Result<(Array2<f32>, Array2<f32>), GSVError>
501where
502 P: AsRef<Path>,
503{
504 let data = Cursor::new(read(path).await?);
505 let decoder = Decoder::new(data)?;
506 let sample_rate = decoder.sample_rate();
507 let samples = if decoder.channels() == 1 {
508 decoder.collect::<Vec<_>>()
509 } else {
510 UniformSourceIterator::new(decoder, 1, sample_rate).collect()
511 };
512
513 let mut ref_audio_16k = resample_audio(&samples, sample_rate, 16000);
515 let ref_audio_32k = resample_audio(&samples, sample_rate, 32000);
516
517 let silence_16k = vec![0.0; (0.3 * 16000.0) as usize]; ref_audio_16k.splice(0..0, silence_16k);
521
522 Ok((
524 Array2::from_shape_vec((1, ref_audio_16k.len()), ref_audio_16k)?,
525 Array2::from_shape_vec((1, ref_audio_32k.len()), ref_audio_32k)?,
526 ))
527}