gpt_sovits/
lib.rs

1mod error;
2mod logits_sampler;
3mod onnx_builder;
4mod text;
5
6use {
7    async_stream::stream,
8    log::{debug, info},
9    ndarray::{
10        Array, Array2, ArrayBase, ArrayD, ArrayView2, Axis, IxDyn, OwnedRepr, concatenate, s,
11    },
12    ort::{
13        inputs,
14        session::{RunOptions, Session},
15        value::{Tensor, TensorRef},
16    },
17    rodio::{Source, buffer::SamplesBuffer, decoder::Decoder, source::UniformSourceIterator},
18    std::{io::Cursor, path::Path, time::SystemTime},
19    tokio::fs::read,
20};
21pub use {
22    error::*,
23    futures::{Stream, StreamExt},
24    logits_sampler::*,
25    onnx_builder::*,
26    text::*,
27};
28
/// Token id the T2S decoder uses for end-of-sequence; it is the last entry of the logits.
const T2S_DECODER_EOS: i64 = 1024;
/// Size of the T2S decoder output vocabulary: one past the EOS id, so EOS is the final logit.
const VOCAB_SIZE: usize = T2S_DECODER_EOS as usize + 1;
/// Number of decoder layers, i.e. the number of `k_cache_{i}` / `v_cache_{i}` pairs exchanged
/// with the T2S decoder sessions.
const NUM_LAYERS: usize = 24;

/// Element type of the key/value cache tensors.
type KvDType = f32;
34
35#[derive(Clone)]
36pub struct ReferenceData {
37    ref_seq: Array2<i64>,
38    ref_bert: Array2<f32>,
39    ref_audio_32k: Array2<f32>,
40    ssl_content: ArrayBase<OwnedRepr<f32>, IxDyn>,
41}
42
43impl AsRef<Self> for ReferenceData {
44    fn as_ref(&self) -> &Self {
45        self
46    }
47}
48
49pub struct GptSoVitsModel {
50    text_processor: TextProcessor,
51    sovits: Session,
52    ssl: Session,
53    t2s_encoder: Session,
54    t2s_fs_decoder: Session,
55    t2s_s_decoder: Session,
56    num_layers: usize,
57    run_options: RunOptions,
58}
59
60// --- KV Cache Configuration ---
61/// Initial size for the sequence length of the KV cache.
62const INITIAL_CACHE_SIZE: usize = 2048;
63/// How much to increment the KV cache size by when reallocating.
64const CACHE_REALLOC_INCREMENT: usize = 1024;
65
66impl GptSoVitsModel {
67    /// create new tts instance
68    /// bert_path, g2pw_path and g2p_en_path can be None
69    /// if bert path is none, the speech speed in chinese may become worse
70    /// if g2pw path is none, the chinese speech quality may be worse
71    /// g2p_en is still experimental, english speak quality may not be better because of bugs
72    pub fn new<P>(
73        sovits_path: P,
74        ssl_path: P,
75        t2s_encoder_path: P,
76        t2s_fs_decoder_path: P,
77        t2s_s_decoder_path: P,
78        bert_path: Option<P>,
79        g2pw_path: Option<P>,
80        g2p_en_path: Option<P>,
81    ) -> Result<Self, GSVError>
82    where
83        P: AsRef<Path>,
84    {
85        info!("Initializing TTSModel with ONNX sessions");
86
87        let g2pw = G2PW::new(g2pw_path)?;
88
89        let text_processor =
90            TextProcessor::new(g2pw, G2pEn::new(g2p_en_path)?, BertModel::new(bert_path)?)?;
91
92        Ok(GptSoVitsModel {
93            text_processor,
94            sovits: create_onnx_cpu_session(sovits_path)?,
95            ssl: create_onnx_cpu_session(ssl_path)?,
96            t2s_encoder: create_onnx_cpu_session(t2s_encoder_path)?,
97            t2s_fs_decoder: create_onnx_cpu_session(t2s_fs_decoder_path)?,
98            t2s_s_decoder: create_onnx_cpu_session(t2s_s_decoder_path)?,
99            num_layers: NUM_LAYERS,
100            run_options: RunOptions::new()?,
101        })
102    }
103
104    pub async fn get_reference_data<P, S>(
105        &mut self,
106        reference_audio_path: P,
107        ref_text: S,
108        lang_id: LangId,
109    ) -> Result<ReferenceData, GSVError>
110    where
111        P: AsRef<Path>,
112        S: AsRef<str>,
113    {
114        info!("Processing reference audio and text: {}", ref_text.as_ref());
115        let ref_text = ensure_punctuation(ref_text);
116        let phones = self.text_processor.get_phone_and_bert(&ref_text, lang_id)?;
117        let ref_seq: Vec<i64> = phones.iter().fold(Vec::new(), |mut seq, p| {
118            seq.extend(p.1.clone());
119            seq
120        });
121
122        let ref_bert: Vec<Array2<f32>> = phones.iter().map(|f| f.2.clone()).collect();
123        // Concatenate along dimension 0
124        let ref_bert = concatenate(
125            Axis(0),
126            &ref_bert.iter().map(|v| v.view()).collect::<Vec<_>>(),
127        )?;
128
129        let ref_seq = Array2::from_shape_vec((1, ref_seq.len()), ref_seq)?;
130        let (ref_audio_16k, ref_audio_32k) = read_and_resample_audio(&reference_audio_path).await?;
131        let ssl_content = self.process_ssl(&ref_audio_16k).await?;
132
133        Ok(ReferenceData {
134            ref_seq,
135            ref_bert,
136            ref_audio_32k,
137            ssl_content,
138        })
139    }
140
141    async fn process_ssl(
142        &mut self,
143        ref_audio_16k: &Array2<f32>,
144    ) -> Result<ArrayBase<OwnedRepr<f32>, IxDyn>, GSVError> {
145        let time = SystemTime::now();
146        let ssl_output = self
147            .ssl
148            .run_async(
149                inputs!["ref_audio_16k" => TensorRef::from_array_view(ref_audio_16k).unwrap()],
150                &self.run_options,
151            )?
152            .await?;
153        debug!("SSL processing time: {:?}", time.elapsed()?);
154        Ok(ssl_output["ssl_content"]
155            .try_extract_array::<f32>()?
156            .into_owned())
157    }
158
159    /// Efficiently runs the streaming decoder loop with a pre-allocated, resizable KV cache.
160    async fn run_t2s_s_decoder_loop(
161        &mut self,
162        sampler: &mut Sampler,
163        sampling_param: SamplingParams,
164        mut y_vec: Vec<i64>,
165        mut k_caches: Vec<ArrayBase<OwnedRepr<KvDType>, IxDyn>>,
166        mut v_caches: Vec<ArrayBase<OwnedRepr<KvDType>, IxDyn>>,
167        prefix_len: usize,
168        initial_valid_len: usize,
169    ) -> Result<ArrayBase<OwnedRepr<i64>, IxDyn>, GSVError> {
170        let mut idx = 0;
171        let mut valid_len = initial_valid_len;
172        y_vec.reserve(2048);
173
174        loop {
175            // --- 1. Prepare inputs using views of the valid cache portion ---
176            let mut inputs = inputs![
177                "iy" => TensorRef::from_array_view(unsafe {ArrayView2::from_shape_ptr((1, y_vec.len()), y_vec.as_ptr())})?,
178                "y_len" => Tensor::from_array(Array::from_vec(vec![prefix_len as i64]))?,
179                "idx" => Tensor::from_array(Array::from_vec(vec![idx as i64]))?,
180            ];
181
182            for i in 0..self.num_layers {
183                // Create a view of the valid part of the cache
184                let k = k_caches[i].slice(s![.., 0..valid_len, ..]).to_owned();
185                let v = v_caches[i].slice(s![.., 0..valid_len, ..]).to_owned();
186
187                inputs.push((
188                    format!("ik_cache_{}", i).into(),
189                    Tensor::from_array(k)?.into(),
190                ));
191                inputs.push((
192                    format!("iv_cache_{}", i).into(),
193                    Tensor::from_array(v)?.into(),
194                ));
195            }
196            // --- 2. Run the decoder model for one step ---
197            let mut output = self
198                .t2s_s_decoder
199                .run_async(inputs, &self.run_options)?
200                .await?;
201
202            let mut logits = output["logits"].try_extract_array_mut::<f32>()?;
203            let mut logits = logits.as_slice_mut().unwrap().to_owned();
204
205            if idx < 11 {
206                logits.pop();
207            }
208
209            y_vec.push(sampler.sample(&mut logits, &y_vec, &sampling_param));
210
211            let argmax_value = argmax(&logits);
212
213            // --- 3. Check for reallocation and update caches ---
214            let new_valid_len = valid_len + 1;
215
216            // Check if we need to reallocate BEFORE writing to the new index.
217            if new_valid_len > k_caches[0].shape()[1] {
218                info!(
219                    "Reallocating KV cache from {} to {}",
220                    k_caches[0].shape()[1],
221                    k_caches[0].shape()[1] + CACHE_REALLOC_INCREMENT
222                );
223                for i in 0..self.num_layers {
224                    let old_k = &k_caches[i];
225                    let old_v = &v_caches[i];
226
227                    // Create new, larger arrays
228                    let mut new_k_dims = old_k.raw_dim().clone();
229                    new_k_dims[1] += CACHE_REALLOC_INCREMENT;
230                    let mut new_v_dims = old_v.raw_dim().clone();
231                    new_v_dims[1] += CACHE_REALLOC_INCREMENT;
232
233                    let mut new_k = Array::zeros(new_k_dims);
234                    let mut new_v = Array::zeros(new_v_dims);
235
236                    // Copy existing valid data to the new arrays
237                    new_k
238                        .slice_mut(s![.., 0..valid_len, ..])
239                        .assign(&old_k.slice(s![.., 0..valid_len, ..]));
240                    new_v
241                        .slice_mut(s![.., 0..valid_len, ..])
242                        .assign(&old_v.slice(s![.., 0..valid_len, ..]));
243
244                    // Replace the old caches with the new, larger ones
245                    k_caches[i] = new_k;
246                    v_caches[i] = new_v;
247                }
248            }
249
250            // Update KV caches by pasting the newly generated slice of data
251            for i in 0..self.num_layers {
252                let inc_k_cache =
253                    output[format!("k_cache_{}", i)].try_extract_array::<KvDType>()?;
254                let inc_v_cache =
255                    output[format!("v_cache_{}", i)].try_extract_array::<KvDType>()?;
256
257                // The new data is the last row of the incremental output from the model
258                let k_new_slice = inc_k_cache.slice(s![.., valid_len, ..]);
259                let v_new_slice = inc_v_cache.slice(s![.., valid_len, ..]);
260
261                // Paste the new row into our long-running cache at the correct position
262                k_caches[i]
263                    .slice_mut(s![.., valid_len, ..])
264                    .assign(&k_new_slice);
265                v_caches[i]
266                    .slice_mut(s![.., valid_len, ..])
267                    .assign(&v_new_slice);
268            }
269
270            // --- 4. Update valid length and check stop condition ---
271            valid_len = new_valid_len;
272
273            if idx >= 1500 || argmax_value == T2S_DECODER_EOS {
274                let mut sliced = y_vec[(y_vec.len() - idx + 1)..(y_vec.len() - 1)]
275                    .iter()
276                    .map(|&i| if i == T2S_DECODER_EOS { 0 } else { i })
277                    .collect::<Vec<i64>>();
278                sliced.push(0);
279                debug!(
280                    "t2s final len: {}, prefix_len: {}",
281                    sliced.len(),
282                    prefix_len
283                );
284                let y = ArrayD::from_shape_vec(IxDyn(&[1, 1, sliced.len()]), sliced)?;
285                return Ok(y);
286            }
287            idx += 1;
288        }
289    }
290
291    /// synthesize async
292    ///
293    /// `text` is input text for run
294    ///
295    /// `lang_id` can be LangId::Auto(Mandarin) or LangId::AutoYue(cantonese)
296    ///
297    pub async fn synthesize<R, S>(
298        &mut self,
299        text: S,
300        reference_data: R,
301        sampling_param: SamplingParams,
302        lang_id: LangId,
303    ) -> Result<impl Stream<Item = Result<Vec<f32>, GSVError>> + Send + Unpin, GSVError>
304    where
305        R: AsRef<ReferenceData>,
306        S: AsRef<str>,
307    {
308        let time = SystemTime::now();
309        let texts_and_seqs = self
310            .text_processor
311            .get_phone_and_bert(text.as_ref(), lang_id)?;
312        debug!("g2pw and preprocess time: {:?}", time.elapsed()?);
313        let ref_data = reference_data.as_ref().clone();
314
315        let stream = stream! {
316            for (text, seq, bert) in texts_and_seqs {
317                debug!("process: {:?}", text);
318                yield self.in_stream_once_gen(&text, &bert, &seq, &ref_data, sampling_param).await;
319            }
320        };
321
322        Ok(Box::pin(stream))
323    }
324
325    async fn in_stream_once_gen(
326        &mut self,
327        _text: &str,
328        text_bert: &Array2<f32>,
329        text_seq_vec: &[i64],
330        ref_data: &ReferenceData,
331        sampling_param: SamplingParams,
332    ) -> Result<Vec<f32>, GSVError> {
333        let text_seq = Array2::from_shape_vec((1, text_seq_vec.len()), text_seq_vec.to_vec())?;
334        let mut sampler = Sampler::new(VOCAB_SIZE);
335
336        let prompts = {
337            let time = SystemTime::now();
338            let encoder_output = self
339                .t2s_encoder
340                .run_async(
341                    inputs![
342                        "ssl_content" => TensorRef::from_array_view(&ref_data.ssl_content)?
343                    ],
344                    &self.run_options,
345                )?
346                .await?;
347            debug!("T2S Encoder time: {:?}", time.elapsed()?);
348            encoder_output["prompts"]
349                .try_extract_array::<i64>()?
350                .into_owned()
351        };
352
353        let x = concatenate(Axis(1), &[ref_data.ref_seq.view(), text_seq.view()])?.to_owned();
354        let bert = concatenate(
355            Axis(1),
356            &[
357                ref_data.ref_bert.clone().permuted_axes([1, 0]).view(),
358                text_bert.clone().permuted_axes([1, 0]).view(),
359            ],
360        )?;
361
362        let bert = bert.insert_axis(Axis(0)).to_owned();
363
364        let (mut y_vec, _) = prompts.clone().into_raw_vec_and_offset();
365
366        let prefix_len = y_vec.len();
367
368        let (y_vec, k_caches, v_caches, initial_seq_len) = {
369            let time = SystemTime::now();
370            let fs_decoder_output = self
371                .t2s_fs_decoder
372                .run_async(
373                    inputs![
374                        "x" => Tensor::from_array(x)?,
375                        "prompts" => TensorRef::from_array_view(&prompts)?,
376                        "bert" => Tensor::from_array(bert)?,
377                    ],
378                    &self.run_options,
379                )?
380                .await?;
381            debug!("T2S FS Decoder time: {:?}", time.elapsed()?);
382
383            let logits = fs_decoder_output["logits"]
384                .try_extract_array::<f32>()?
385                .into_owned();
386
387            // --- Initialize large KV Caches ---
388            // Get shape and initial data from the first-pass decoder.
389            let k_init_first = fs_decoder_output["k_cache_0"].try_extract_array::<KvDType>()?;
390            let initial_dims_dyn = k_init_first.raw_dim();
391            let initial_seq_len = initial_dims_dyn[1];
392
393            // Define the shape for our large, pre-allocated cache.
394            let mut large_cache_dims = initial_dims_dyn.clone();
395            large_cache_dims[1] = INITIAL_CACHE_SIZE;
396
397            let mut k_caches = Vec::with_capacity(self.num_layers);
398            let mut v_caches = Vec::with_capacity(self.num_layers);
399
400            for i in 0..self.num_layers {
401                let k_init =
402                    fs_decoder_output[format!("k_cache_{}", i)].try_extract_array::<KvDType>()?;
403                let v_init =
404                    fs_decoder_output[format!("v_cache_{}", i)].try_extract_array::<KvDType>()?;
405
406                // Create large, zero-initialized caches.
407                let mut k_large = Array::zeros(large_cache_dims.clone());
408                let mut v_large = Array::zeros(large_cache_dims.clone());
409
410                // Copy the initial data from the first-pass decoder into the start of our large caches.
411                k_large
412                    .slice_mut(s![.., 0..initial_seq_len, ..])
413                    .assign(&k_init);
414                v_large
415                    .slice_mut(s![.., 0..initial_seq_len, ..])
416                    .assign(&v_init);
417
418                k_caches.push(k_large);
419                v_caches.push(v_large);
420            }
421            let (mut logits_vec, _) = logits.into_raw_vec_and_offset();
422            logits_vec.pop(); // remove T2S_DECODER_EOS
423            let sampling_rst = sampler.sample(&mut logits_vec, &y_vec, &sampling_param);
424            y_vec.push(sampling_rst);
425            (y_vec, k_caches, v_caches, initial_seq_len)
426        };
427
428        let time = SystemTime::now();
429        let pred_semantic = self
430            .run_t2s_s_decoder_loop(
431                &mut sampler,
432                sampling_param,
433                y_vec,
434                k_caches,
435                v_caches,
436                prefix_len,
437                initial_seq_len,
438            )
439            .await?;
440        debug!("T2S S Decoder all time: {:?}", time.elapsed()?);
441
442        let time = SystemTime::now();
443        let outputs = self
444            .sovits
445            .run_async(
446                inputs![
447                    "text_seq" => TensorRef::from_array_view(&text_seq)?,
448                    "pred_semantic" => TensorRef::from_array_view(&pred_semantic)?,
449                    "ref_audio" => TensorRef::from_array_view(&ref_data.ref_audio_32k)?
450                ],
451                &self.run_options,
452            )?
453            .await?;
454        debug!("SoVITS time: {:?}", time.elapsed()?);
455        let output_audio = outputs["audio"].try_extract_array::<f32>()?;
456        let (mut audio, _) = output_audio.into_owned().into_raw_vec_and_offset();
457        for sample in &mut audio {
458            *sample = *sample * 4.0;
459        }
460        // Find the maximum absolute value in the audio
461        let max_audio = audio
462            .iter()
463            .filter(|&&x| x.is_finite()) // Ignore NaN or inf
464            .fold(0.0f32, |acc, &x| acc.max(x.abs()));
465        let audio = if max_audio > 1.0 {
466            audio
467                .into_iter()
468                .map(|x| x / max_audio)
469                .collect::<Vec<f32>>()
470        } else {
471            audio
472        };
473
474        Ok(audio)
475    }
476}
477
/// Returns `text` as an owned `String`, appending `。` unless it already ends with a
/// sentence terminator: full-width `。！？；` or ASCII `.!?;`.
///
/// An empty string gets `。` appended as well.
fn ensure_punctuation<S>(text: S) -> String
where
    S: AsRef<str>,
{
    // Full-width CJK terminators followed by their ASCII counterparts. The previous list
    // repeated the ASCII `!?;`, so text ending in `！`, `？` or `；` got a spurious `。`.
    const TERMINATORS: [char; 8] = ['。', '！', '？', '；', '.', '!', '?', ';'];

    let text = text.as_ref();
    if text.ends_with(TERMINATORS) {
        text.to_owned()
    } else {
        format!("{text}。")
    }
}
491
492fn resample_audio(input: &[f32], in_rate: u32, out_rate: u32) -> Vec<f32> {
493    if in_rate == out_rate {
494        return input.to_owned();
495    }
496
497    UniformSourceIterator::new(SamplesBuffer::new(1, in_rate, input), 1, out_rate).collect()
498}
499
500async fn read_and_resample_audio<P>(path: P) -> Result<(Array2<f32>, Array2<f32>), GSVError>
501where
502    P: AsRef<Path>,
503{
504    let data = Cursor::new(read(path).await?);
505    let decoder = Decoder::new(data)?;
506    let sample_rate = decoder.sample_rate();
507    let samples = if decoder.channels() == 1 {
508        decoder.collect::<Vec<_>>()
509    } else {
510        UniformSourceIterator::new(decoder, 1, sample_rate).collect()
511    };
512
513    // Resample to 16kHz and 32kHz
514    let mut ref_audio_16k = resample_audio(&samples, sample_rate, 16000);
515    let ref_audio_32k = resample_audio(&samples, sample_rate, 32000);
516
517    // Prepend 0.3 seconds of silence
518    let silence_16k = vec![0.0; (0.3 * 16000.0) as usize]; // 8000 samples for 16kHz
519
520    ref_audio_16k.splice(0..0, silence_16k);
521
522    // Convert to Array2
523    Ok((
524        Array2::from_shape_vec((1, ref_audio_16k.len()), ref_audio_16k)?,
525        Array2::from_shape_vec((1, ref_audio_32k.len()), ref_audio_32k)?,
526    ))
527}