// goya_ipadic/ipadic_loader.rs

1use super::ipadic::IPADic;
2use csv::ReaderBuilder;
3use encoding_rs::EUC_JP;
4use glob::glob;
5use goya::char_class::{CharClass, CharClassifier, CharDefinition, InvokeTiming};
6use goya::morpheme::Morpheme;
7use goya::word_features::WordFeaturesMap;
8use indexmap::IndexSet;
9use regex::Regex;
10use serde::Deserialize;
11use std::collections::{HashMap, HashSet};
12use std::error::Error;
13use std::fs;
14use std::path::Path;
15use std::vec::Vec;
16
// Column indices of a MeCab IPADic CSV row.
const COL_SURFACE_FORM: usize = 0; // surface form (表層形)
const COL_LEFT_CONTEXT_ID: usize = 1; // left context ID (左文脈ID)
const COL_RIGHT_CONTEXT_ID: usize = 2; // right context ID (右文脈ID)
const COL_COST: usize = 3; // word cost (コスト)
21
/// Everything produced by [`IPADicLoader::load`].
pub struct LoadResult {
    /// The assembled dictionary: vocabulary, homonyms, character
    /// classes, connection matrix and unknown-word entries.
    pub ipadic: IPADic,
    /// Feature strings for both known and unknown words, indexed by
    /// word id.
    pub word_set: WordFeaturesMap,
    /// Known-word id -> surface form.
    pub surfaces: HashMap<usize, String>,
}
27
28pub struct IPADicLoader {}
29impl IPADicLoader {
30    pub fn load(&self, dir: &str) -> Result<LoadResult, Box<dyn Error>> {
31        let classes = load_chars(Path::new(dir).join("char.def"))?;
32        let matrix = load_matrix(Path::new(dir).join("matrix.def"))?;
33        let unknown = load_unknown(Path::new(dir).join("unk.def"))?;
34        let csv_pattern = Path::new(dir).join("*.csv");
35        let csv_pattern = csv_pattern.to_str().ok_or("Failed to build glob pattern")?;
36
37        let mut vocabulary_index: IndexSet<Morpheme> = IndexSet::new();
38        let mut surfaces = HashMap::new();
39        let mut known_features = HashMap::new();
40        let mut vocabulary = HashMap::new();
41        let mut tmp_homonyms = HashMap::new();
42        let mut id: usize = 1;
43        for path in glob(csv_pattern)? {
44            for row in load_words_csv(path?)? {
45                surfaces.insert(id, row.surface_form.to_string());
46                known_features.insert(id, row.features.clone());
47                tmp_homonyms
48                    .entry(row.surface_form.to_string())
49                    .or_insert_with(Vec::new)
50                    .push(id);
51
52                let (idx, _) = vocabulary_index.insert_full(row.into());
53                vocabulary.insert(id, idx);
54                id += 1;
55            }
56        }
57        let mut homonyms: HashMap<usize, Vec<usize>> = HashMap::new();
58        for wids in tmp_homonyms.values() {
59            for wid in wids.iter() {
60                homonyms.insert(*wid, wids.iter().copied().collect());
61            }
62        }
63
64        let mut unknown_vocabulary = HashMap::new();
65        let mut unknown_features = HashMap::new();
66        let mut unknown_classes = HashMap::new();
67        let mut id = 1;
68        for (class, words) in unknown.into_iter() {
69            for row in words {
70                unknown_features.insert(id, row.features.clone());
71                let (idx, _) = vocabulary_index.insert_full(row.into());
72                unknown_vocabulary.insert(id, idx);
73                unknown_classes
74                    .entry(class.to_string())
75                    .or_insert_with(Vec::new)
76                    .push(id);
77                id += 1;
78            }
79        }
80
81        let word_set = WordFeaturesMap::new(
82            map_to_vec(known_features, Vec::new),
83            map_to_vec(unknown_features, Vec::new),
84        );
85        let ipadic = IPADic::from(
86            map_to_vec(vocabulary, || 0),
87            map_to_vec(homonyms, Vec::new),
88            classes,
89            matrix,
90            unknown_classes,
91            map_to_vec(unknown_vocabulary, || 0),
92            vocabulary_index,
93        );
94        let ret = LoadResult {
95            word_set,
96            ipadic,
97            surfaces,
98        };
99        Ok(ret)
100    }
101}
102
/// One row of a MeCab IPADic word CSV (also the row format of `unk.def`).
#[derive(Debug, Clone, Deserialize)]
struct CSVRow {
    /// Surface form (表層形).
    /// https://taku910.github.io/mecab/dic-detail.html
    surface_form: String,
    /// Left context ID (the context ID seen when looking at the word
    /// from the left).
    /// https://taku910.github.io/mecab/dic-detail.html
    left_context_id: usize,
    /// Right context ID (the context ID seen when looking at the word
    /// from the right).
    /// https://taku910.github.io/mecab/dic-detail.html
    right_context_id: usize,
    /// Word cost: the smaller the value, the more likely the word is to
    /// appear. Cost values must fit in a short int (16-bit integer),
    /// hence `i16`.
    cost: i16,
    /// Columns 5 onward are user-defined CSV fields; basically any
    /// content the CSV format allows can be appended.
    /// https://taku910.github.io/mecab/dic-detail.html
    features: Vec<String>,
}
121impl From<CSVRow> for Morpheme {
122    fn from(row: CSVRow) -> Self {
123        Morpheme {
124            left_context_id: row.left_context_id,
125            right_context_id: row.right_context_id,
126            cost: row.cost,
127        }
128    }
129}
130
131fn load_words_csv<P>(path: P) -> Result<Vec<CSVRow>, Box<dyn Error>>
132where
133    P: AsRef<Path>,
134{
135    let eucjp = fs::read(path)?;
136    let (utf8, _, _) = EUC_JP.decode(&eucjp);
137    let mut rdr = ReaderBuilder::new()
138        .has_headers(false)
139        .from_reader(utf8.as_bytes());
140    let mut words = vec![];
141    for row in rdr.records() {
142        let row = row?;
143        words.push(CSVRow {
144            surface_form: row[COL_SURFACE_FORM].to_string(),
145            left_context_id: row[COL_LEFT_CONTEXT_ID].parse::<usize>().unwrap(),
146            right_context_id: row[COL_RIGHT_CONTEXT_ID].parse::<usize>().unwrap(),
147            cost: row[COL_COST].parse::<i16>().unwrap(),
148            features: row
149                .iter()
150                .skip(COL_COST + 1)
151                .map(|v| v.to_string())
152                .collect::<Vec<_>>(),
153        })
154    }
155    Ok(words)
156}
157
158fn load_chars<P>(path: P) -> Result<CharClassifier, Box<dyn Error>>
159where
160    P: AsRef<Path>,
161{
162    let eucjp = fs::read(path)?;
163    let (utf8, _, _) = EUC_JP.decode(&eucjp);
164    let lines = utf8
165        .lines()
166        .filter(|line| !line.is_empty() && !line.starts_with('#'))
167        .map(|line| Regex::new(r"#.*$").unwrap().replace(line, ""))
168        .collect::<Vec<_>>();
169
170    let head = lines.iter().take_while(|line| {
171        let parts = line.trim().split_ascii_whitespace().collect::<Vec<_>>();
172        !parts[0].starts_with("0x")
173    });
174    let mut chars = HashMap::new();
175    for line in head {
176        let parts = line.trim().split_ascii_whitespace().collect::<Vec<_>>();
177        let kind = parts[0].to_owned();
178        let class = kind.to_string();
179        let timing = if parts[1] == "0" {
180            InvokeTiming::Fallback
181        } else {
182            InvokeTiming::Always
183        };
184        let group_by_same_kind = parts[2] == "1";
185        let len = parts[3].parse::<usize>()?;
186        chars.insert(
187            kind,
188            CharDefinition {
189                class,
190                timing,
191                group_by_same_kind,
192                len,
193                compatibilities: HashSet::new(),
194            },
195        );
196    }
197
198    let tail = lines.iter().skip_while(|line| {
199        let parts = line.trim().split_ascii_whitespace().collect::<Vec<_>>();
200        !parts[0].starts_with("0x")
201    });
202    let mut ranges = vec![];
203    for line in tail {
204        let parts = line.trim().split_ascii_whitespace().collect::<Vec<_>>();
205        let range = parts[0]
206            .split("..")
207            .map(|c| u32::from_str_radix(&c[2..], 16).unwrap())
208            .map(|c| char::from_u32(c).unwrap())
209            .collect::<Vec<_>>();
210        let range = if range.len() > 1 {
211            (range[0] as u32, range[1] as u32)
212        } else {
213            (range[0] as u32, range[0] as u32)
214        };
215        let class = parts[1];
216        let compatibilities = parts
217            .iter()
218            .skip(2)
219            .map(|s| s.to_string())
220            .collect::<HashSet<_>>();
221        chars.get_mut(class).unwrap().compatibilities = compatibilities;
222        ranges.push(CharClass::from(range, class.to_string()));
223    }
224
225    Ok(CharClassifier::from(chars, ranges))
226}
227
228fn load_matrix<P>(path: P) -> Result<Vec<Vec<i16>>, Box<dyn Error>>
229where
230    P: AsRef<Path>,
231{
232    let eucjp = fs::read(path)?;
233    let (utf8, _, _) = EUC_JP.decode(&eucjp);
234    let mut lines = utf8.lines();
235    let size = lines
236        .next()
237        .expect("failed to read the first line")
238        .split_ascii_whitespace()
239        .map(|p| p.parse::<usize>().unwrap())
240        .collect::<Vec<_>>();
241    let mut matrix = vec![vec![-1; size[1]]; size[0]];
242    for line in lines {
243        let parts = line.split_ascii_whitespace().collect::<Vec<_>>();
244        let left = parts[0].parse::<usize>()?;
245        let right = parts[1].parse::<usize>()?;
246        let cost = parts[2].parse::<i16>()?;
247        matrix[left][right] = cost;
248    }
249    Ok(matrix)
250}
251
252fn load_unknown<P>(path: P) -> Result<HashMap<String, Vec<CSVRow>>, Box<dyn Error>>
253where
254    P: AsRef<Path>,
255{
256    let words = load_words_csv(path)?;
257    let mut map = HashMap::<String, Vec<CSVRow>>::new();
258    for w in words.into_iter() {
259        map.entry(w.surface_form.to_string())
260            .or_insert_with(Vec::new)
261            .push(w);
262    }
263    Ok(map)
264}
265
/// Convert an id-keyed map into a dense `Vec` indexed by id, filling
/// every unset slot (including the unused index 0) with `default()`.
///
/// Ids in this file are dense and 1-based, for which the result has
/// exactly `map.len() + 1` slots as before; sizing by the maximum key
/// additionally tolerates sparse key sets instead of panicking on an
/// out-of-range index.
fn map_to_vec<T: Clone>(map: HashMap<usize, T>, default: impl Fn() -> T) -> Vec<T> {
    // Large enough for every key, and never smaller than the original
    // `len + 1` sizing so existing callers see identical output.
    let len = map
        .keys()
        .max()
        .map(|max| max + 1)
        .unwrap_or(1)
        .max(map.len() + 1);
    let mut ret = vec![default(); len];
    for (idx, value) in map {
        ret[idx] = value;
    }
    ret
}