gpt_sovits/text/
utils.rs

1use std::{collections::HashMap, sync::LazyLock};
2
3use ndarray::{ArrayView, IntoDimension, IxDyn};
4
5#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
6pub struct PolyChar {
7    pub index: usize,
8    pub phones: Vec<(String, usize)>,
9}
10
11#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
12pub struct MonoChar {
13    pub phone: String,
14}
15
16pub static MONO_CHARS_DIST_STR: &str = include_str!("dict_mono_chars.json");
17pub static POLY_CHARS_DIST_STR: &str = include_str!("dict_poly_chars.json");
18pub static DEFAULT_ZH_WORD_DICT: &str = include_str!("zh_word_dict.json");
19pub static BERT_TOKENIZER: &str = include_str!("g2pw_tokenizer.json");
20
21pub fn load_mono_chars() -> HashMap<char, MonoChar> {
22    if let Ok(dir) = std::env::var("G2PW_DIST_DIR") {
23        let s = std::fs::read_to_string(format!("{}/dict_mono_chars.json", dir))
24            .expect("dict_mono_chars.json not found");
25        serde_json::from_str(&s).expect("dict_mono_chars.json parse error")
26    } else {
27        serde_json::from_str(MONO_CHARS_DIST_STR).unwrap()
28    }
29}
30
31pub fn load_poly_chars() -> HashMap<char, PolyChar> {
32    if let Ok(dir) = std::env::var("G2PW_DIST_DIR") {
33        let s = std::fs::read_to_string(format!("{}/dict_poly_chars.json", dir))
34            .expect("dict_poly_chars.json not found");
35        serde_json::from_str(&s).expect("dict_poly_chars.json parse error")
36    } else {
37        serde_json::from_str(POLY_CHARS_DIST_STR).unwrap()
38    }
39}
40
41pub static DICT_MONO_CHARS: LazyLock<HashMap<char, MonoChar>> = LazyLock::new(load_mono_chars);
42pub static DICT_POLY_CHARS: LazyLock<HashMap<char, PolyChar>> = LazyLock::new(load_poly_chars);
43
44pub fn str_is_chinese(s: &str) -> bool {
45    let mut r = true;
46    for c in s.chars() {
47        if !DICT_MONO_CHARS.contains_key(&c) && !DICT_POLY_CHARS.contains_key(&c) {
48            r &= false;
49        }
50    }
51    r
52}
53
54// Finds the index of the maximum value in a 2D tensor
55pub fn argmax_2d(tensor: &ArrayView<f32, IxDyn>) -> (usize, usize) {
56    let mut max_index = (0, 0);
57    let mut max_value = tensor
58        .get(IxDyn::zeros(2))
59        .copied()
60        .unwrap_or(f32::NEG_INFINITY);
61
62    for i in 0..tensor.shape()[0] {
63        for j in 0..tensor.shape()[1] {
64            if let Some(value) = tensor.get((i, j).into_dimension()) {
65                if *value > max_value {
66                    max_value = *value;
67                    max_index = (i, j);
68                }
69            }
70        }
71    }
72    max_index
73}