1use std::{collections::HashMap, sync::LazyLock};
2
3use ndarray::{ArrayView, IntoDimension, IxDyn};
4
5#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
6pub struct PolyChar {
7 pub index: usize,
8 pub phones: Vec<(String, usize)>,
9}
10
11#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
12pub struct MonoChar {
13 pub phone: String,
14}
15
16pub static MONO_CHARS_DIST_STR: &str = include_str!("dict_mono_chars.json");
17pub static POLY_CHARS_DIST_STR: &str = include_str!("dict_poly_chars.json");
18pub static DEFAULT_ZH_WORD_DICT: &str = include_str!("zh_word_dict.json");
19pub static BERT_TOKENIZER: &str = include_str!("g2pw_tokenizer.json");
20
21pub fn load_mono_chars() -> HashMap<char, MonoChar> {
22 if let Ok(dir) = std::env::var("G2PW_DIST_DIR") {
23 let s = std::fs::read_to_string(format!("{}/dict_mono_chars.json", dir))
24 .expect("dict_mono_chars.json not found");
25 serde_json::from_str(&s).expect("dict_mono_chars.json parse error")
26 } else {
27 serde_json::from_str(MONO_CHARS_DIST_STR).unwrap()
28 }
29}
30
31pub fn load_poly_chars() -> HashMap<char, PolyChar> {
32 if let Ok(dir) = std::env::var("G2PW_DIST_DIR") {
33 let s = std::fs::read_to_string(format!("{}/dict_poly_chars.json", dir))
34 .expect("dict_poly_chars.json not found");
35 serde_json::from_str(&s).expect("dict_poly_chars.json parse error")
36 } else {
37 serde_json::from_str(POLY_CHARS_DIST_STR).unwrap()
38 }
39}
40
41pub static DICT_MONO_CHARS: LazyLock<HashMap<char, MonoChar>> = LazyLock::new(load_mono_chars);
42pub static DICT_POLY_CHARS: LazyLock<HashMap<char, PolyChar>> = LazyLock::new(load_poly_chars);
43
44pub fn str_is_chinese(s: &str) -> bool {
45 let mut r = true;
46 for c in s.chars() {
47 if !DICT_MONO_CHARS.contains_key(&c) && !DICT_POLY_CHARS.contains_key(&c) {
48 r &= false;
49 }
50 }
51 r
52}
53
54pub fn argmax_2d(tensor: &ArrayView<f32, IxDyn>) -> (usize, usize) {
56 let mut max_index = (0, 0);
57 let mut max_value = tensor
58 .get(IxDyn::zeros(2))
59 .copied()
60 .unwrap_or(f32::NEG_INFINITY);
61
62 for i in 0..tensor.shape()[0] {
63 for j in 0..tensor.shape()[1] {
64 if let Some(value) = tensor.get((i, j).into_dimension()) {
65 if *value > max_value {
66 max_value = *value;
67 max_index = (i, j);
68 }
69 }
70 }
71 }
72 max_index
73}