1use std::{fs, path::Path};
2
3use lindera_dictionary::{
4 builder::{
5 character_definition::CharacterDefinitionBuilderOptions,
6 connection_cost_matrix::ConnectionCostMatrixBuilderOptions, metadata::MetadataBuilder,
7 unknown_dictionary::UnknownDictionaryBuilderOptions,
8 user_dictionary::build_user_dictionary,
9 },
10 dictionary::{
11 character_definition::CharacterDefinition, metadata::Metadata, schema::Schema,
12 UserDictionary,
13 },
14 error::LinderaErrorKind,
15 LinderaResult,
16};
17
18use crate::dictionary::to_dict::prefix_dictionary::{
19 generate_prefix_dictionary,
20 parser::{
21 DefaultParser, DefaultParserOptions, UserDictionaryParser, UserDictionaryParserOptions,
22 },
23 write_prefix_dictionary, CSVReaderOptions,
24};
25
26use super::word_encoding::JPreprocessDictionaryWordEncoding;
27
28mod prefix_dictionary;
29
30pub struct JPreprocessDictionaryBuilder {
31 metadata: Metadata,
32}
33
34impl JPreprocessDictionaryBuilder {
35 pub fn new(metadata: Metadata) -> Self {
36 Self { metadata }
37 }
38}
39
40impl Default for JPreprocessDictionaryBuilder {
41 fn default() -> Self {
42 Self {
43 metadata: Self::default_metadata(),
44 }
45 }
46}
47
48impl JPreprocessDictionaryBuilder {
49 pub fn default_metadata() -> Metadata {
50 Metadata {
51 dictionary_schema: Schema::new(vec![
52 "surface".to_string(),
53 "left_context_id".to_string(),
54 "right_context_id".to_string(),
55 "cost".to_string(),
56 "major_pos".to_string(),
57 "middle_pos".to_string(),
58 "small_pos".to_string(),
59 "fine_pos".to_string(),
60 "conjugation_type".to_string(),
61 "conjugation_form".to_string(),
62 "base_form".to_string(),
63 "reading".to_string(),
64 "pronunciation".to_string(),
65 "accent_morasize".to_string(),
67 "chain_rule".to_string(),
68 "chain_flag".to_string(),
69 ]),
70 ..Default::default()
71 }
72 }
73
74 pub fn build_dictionary(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
75 fs::create_dir_all(output_dir)
76 .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
77
78 self.build_metadata(output_dir)?;
79 let chardef = self.build_character_definition(input_dir, output_dir)?;
80 self.build_unknown_dictionary(input_dir, &chardef, output_dir)?;
81 self.build_prefix_dictionary(input_dir, output_dir)?;
82 self.build_connection_cost_matrix(input_dir, output_dir)?;
83
84 Ok(())
85 }
86
87 pub fn build_metadata(&self, output_dir: &Path) -> LinderaResult<()> {
88 MetadataBuilder::new().build(&self.metadata, output_dir)
89 }
90
91 pub fn build_user_dictionary(
92 &self,
93 input_file: &Path,
94 output_file: &Path,
95 ) -> LinderaResult<()> {
96 let user_dict = self.build_user_dict(input_file)?;
97 build_user_dictionary(user_dict, output_file)
98 }
99
100 pub fn build_character_definition(
101 &self,
102 input_dir: &Path,
103 output_dir: &Path,
104 ) -> LinderaResult<CharacterDefinition> {
105 CharacterDefinitionBuilderOptions::default()
106 .encoding(self.metadata.encoding.clone())
107 .compress_algorithm(self.metadata.compress_algorithm)
108 .builder()
109 .unwrap()
110 .build(input_dir, output_dir)
111 }
112
113 pub fn build_unknown_dictionary(
114 &self,
115 input_dir: &Path,
116 chardef: &CharacterDefinition,
117 output_dir: &Path,
118 ) -> LinderaResult<()> {
119 UnknownDictionaryBuilderOptions::default()
120 .encoding(self.metadata.encoding.clone())
121 .compress_algorithm(self.metadata.compress_algorithm)
122 .builder()
123 .unwrap()
124 .build(input_dir, chardef, output_dir)
125 }
126
127 pub fn build_prefix_dictionary(
128 &self,
129 input_dir: &Path,
130 output_dir: &Path,
131 ) -> LinderaResult<()> {
132 let reader = CSVReaderOptions::default()
133 .flexible_csv(self.metadata.flexible_csv)
134 .encoding(self.metadata.encoding.clone())
135 .normalize_details(self.metadata.normalize_details)
136 .builder()
137 .unwrap();
138 let rows = reader.load_csv_data(input_dir)?;
139
140 let parser = DefaultParserOptions::default()
141 .skip_invalid_cost_or_id(self.metadata.skip_invalid_cost_or_id)
142 .schema(self.metadata.dictionary_schema.clone())
143 .normalize_details(self.metadata.normalize_details)
144 .builder()
145 .unwrap();
146
147 write_prefix_dictionary::<DefaultParser, JPreprocessDictionaryWordEncoding>(
148 &parser,
149 &rows,
150 output_dir,
151 self.metadata.compress_algorithm,
152 )
153 }
154
155 pub fn build_connection_cost_matrix(
156 &self,
157 input_dir: &Path,
158 output_dir: &Path,
159 ) -> LinderaResult<()> {
160 ConnectionCostMatrixBuilderOptions::default()
161 .encoding(self.metadata.encoding.clone())
162 .compress_algorithm(self.metadata.compress_algorithm)
163 .builder()
164 .unwrap()
165 .build(input_dir, output_dir)
166 }
167
168 pub fn build_user_dict(&self, input_file: &Path) -> LinderaResult<UserDictionary> {
169 let reader = CSVReaderOptions::default()
170 .flexible_csv(self.metadata.flexible_csv)
171 .builder()
172 .unwrap();
173 let rows = reader.read_csv_files(&[input_file.to_path_buf()])?;
174
175 self.build_user_dict_from_rows(rows)
176 }
177 pub fn build_user_dict_from_data(&self, data: Vec<Vec<&str>>) -> LinderaResult<UserDictionary> {
178 let rows = data
179 .into_iter()
180 .map(csv::StringRecord::from_iter)
181 .collect::<Vec<_>>();
182
183 self.build_user_dict_from_rows(rows)
184 }
185
186 fn build_user_dict_from_rows(
187 &self,
188 rows: Vec<csv::StringRecord>,
189 ) -> LinderaResult<UserDictionary> {
190 let parser = UserDictionaryParserOptions::default()
191 .user_dictionary_fields_num(self.metadata.user_dictionary_schema.field_count())
192 .default_word_cost(self.metadata.default_word_cost)
193 .default_left_context_id(self.metadata.default_left_context_id)
194 .default_right_context_id(self.metadata.default_right_context_id)
195 .dictionary_parser(
196 DefaultParserOptions::default()
197 .skip_invalid_cost_or_id(self.metadata.skip_invalid_cost_or_id)
198 .schema(self.metadata.dictionary_schema.clone())
199 .normalize_details(self.metadata.normalize_details)
200 .builder()
201 .unwrap(),
202 )
203 .user_dictionary_parser(
204 DefaultParserOptions::default()
205 .schema(self.metadata.user_dictionary_schema.clone())
206 .normalize_details(self.metadata.normalize_details)
207 .builder()
208 .unwrap(),
209 )
210 .builder()
211 .unwrap();
212
213 let dict = generate_prefix_dictionary::<
214 UserDictionaryParser,
215 JPreprocessDictionaryWordEncoding,
216 >(&parser, &rows, false)?;
217
218 Ok(UserDictionary { dict })
219 }
220}
221
222#[deprecated(
223 note = "Use JPreprocessDictionaryBuilder::build_user_dict_from_data instead",
224 since = "0.13.0"
225)]
226pub fn build_user_dict_from_data(data: Vec<Vec<&str>>) -> LinderaResult<UserDictionary> {
227 let rows = data
228 .into_iter()
229 .map(csv::StringRecord::from_iter)
230 .collect::<Vec<_>>();
231
232 let builder = JPreprocessDictionaryBuilder::new(Metadata {
233 default_word_cost: -10000,
234 default_left_context_id: 0,
235 default_right_context_id: 0,
236 ..JPreprocessDictionaryBuilder::default_metadata()
237 });
238
239 builder.build_user_dict_from_rows(rows)
240}
241
#[cfg(test)]
mod tests {
    use lindera_dictionary::viterbi::{LexType, WordEntry, WordId};

    use super::*;

    // Expands to the lookup result expected for a single user-dictionary word:
    // a one-element `Vec<WordEntry>` with the user lexicon type and identical
    // left/right context ids.
    macro_rules! user_entry {
        ($id:expr, $cost:expr, $context_id:expr) => {
            vec![WordEntry {
                word_id: WordId {
                    id: $id,
                    is_system: false,
                    lex_type: LexType::User,
                },
                word_cost: $cost,
                left_id: $context_id,
                right_id: $context_id,
            }]
        };
    }

    #[test]
    fn test_user_dictionary() {
        let builder = JPreprocessDictionaryBuilder::default();

        // Rows in the full 16-column dictionary schema carry their own cost and
        // context ids, which must be used verbatim.
        let data = vec![
            vec![
                "東京スカイツリー",
                "1285",
                "1285",
                "-3000",
                "名詞",
                "固有名詞",
                "一般",
                "*",
                "*",
                "*",
                "*",
                "トウキョウスカイツリー",
                "トウキョウスカイツリー",
                "13",
                "*",
                "*",
            ],
            vec![
                "すもももももももものうち",
                "1285",
                "1285",
                "-3000",
                "名詞",
                "固有名詞",
                "一般",
                "*",
                "*",
                "*",
                "*",
                "スモモモモモモモモノウチ",
                "スモモモモモモモモノウチ",
                "13",
                "*",
                "*",
            ],
        ];

        let user_dict = builder.build_user_dict_from_data(data).unwrap();
        assert_eq!(
            user_dict.dict.find_surface("東京スカイツリー"),
            user_entry!(0, -3000, 1285)
        );
        assert_eq!(
            user_dict.dict.find_surface("すもももももももものうち"),
            user_entry!(1, -3000, 1285)
        );
    }

    #[test]
    fn test_simple_user_dictionary() {
        let builder = JPreprocessDictionaryBuilder::default();

        // Three-column rows (presumably surface, reading, pronunciation) carry no cost
        // or context ids, so the builder's metadata defaults are expected instead.
        let data = vec![
            vec![
                "東京スカイツリー",
                "トウキョウスカイツリー",
                "トーキョースカイツリー",
            ],
            vec![
                "すもももももももものうち",
                "スモモモモモモモモノウチ",
                "スモモモモモモモモノウチ",
            ],
        ];

        let user_dict = builder.build_user_dict_from_data(data).unwrap();
        assert_eq!(
            user_dict.dict.find_surface("東京スカイツリー"),
            user_entry!(0, -10000, 1288)
        );
        assert_eq!(
            user_dict.dict.find_surface("すもももももももものうち"),
            user_entry!(1, -10000, 1288)
        );
    }
}
365}