jpreprocess_dictionary/dictionary/to_dict/
mod.rs

1use std::{fs, path::Path};
2
3use lindera_dictionary::{
4    builder::{
5        character_definition::CharacterDefinitionBuilderOptions,
6        connection_cost_matrix::ConnectionCostMatrixBuilderOptions, metadata::MetadataBuilder,
7        unknown_dictionary::UnknownDictionaryBuilderOptions,
8        user_dictionary::build_user_dictionary,
9    },
10    dictionary::{
11        character_definition::CharacterDefinition, metadata::Metadata, schema::Schema,
12        UserDictionary,
13    },
14    error::LinderaErrorKind,
15    LinderaResult,
16};
17
18use crate::dictionary::to_dict::prefix_dictionary::{
19    generate_prefix_dictionary,
20    parser::{
21        DefaultParser, DefaultParserOptions, UserDictionaryParser, UserDictionaryParserOptions,
22    },
23    write_prefix_dictionary, CSVReaderOptions,
24};
25
26use super::word_encoding::JPreprocessDictionaryWordEncoding;
27
28mod prefix_dictionary;
29
30pub struct JPreprocessDictionaryBuilder {
31    metadata: Metadata,
32}
33
34impl JPreprocessDictionaryBuilder {
35    pub fn new(metadata: Metadata) -> Self {
36        Self { metadata }
37    }
38}
39
40impl Default for JPreprocessDictionaryBuilder {
41    fn default() -> Self {
42        Self {
43            metadata: Self::default_metadata(),
44        }
45    }
46}
47
48impl JPreprocessDictionaryBuilder {
49    pub fn default_metadata() -> Metadata {
50        Metadata {
51            dictionary_schema: Schema::new(vec![
52                "surface".to_string(),
53                "left_context_id".to_string(),
54                "right_context_id".to_string(),
55                "cost".to_string(),
56                "major_pos".to_string(),
57                "middle_pos".to_string(),
58                "small_pos".to_string(),
59                "fine_pos".to_string(),
60                "conjugation_type".to_string(),
61                "conjugation_form".to_string(),
62                "base_form".to_string(),
63                "reading".to_string(),
64                "pronunciation".to_string(),
65                // Additional fields
66                "accent_morasize".to_string(),
67                "chain_rule".to_string(),
68                "chain_flag".to_string(),
69            ]),
70            ..Default::default()
71        }
72    }
73
74    pub fn build_dictionary(&self, input_dir: &Path, output_dir: &Path) -> LinderaResult<()> {
75        fs::create_dir_all(output_dir)
76            .map_err(|err| LinderaErrorKind::Io.with_error(anyhow::anyhow!(err)))?;
77
78        self.build_metadata(output_dir)?;
79        let chardef = self.build_character_definition(input_dir, output_dir)?;
80        self.build_unknown_dictionary(input_dir, &chardef, output_dir)?;
81        self.build_prefix_dictionary(input_dir, output_dir)?;
82        self.build_connection_cost_matrix(input_dir, output_dir)?;
83
84        Ok(())
85    }
86
87    pub fn build_metadata(&self, output_dir: &Path) -> LinderaResult<()> {
88        MetadataBuilder::new().build(&self.metadata, output_dir)
89    }
90
91    pub fn build_user_dictionary(
92        &self,
93        input_file: &Path,
94        output_file: &Path,
95    ) -> LinderaResult<()> {
96        let user_dict = self.build_user_dict(input_file)?;
97        build_user_dictionary(user_dict, output_file)
98    }
99
100    pub fn build_character_definition(
101        &self,
102        input_dir: &Path,
103        output_dir: &Path,
104    ) -> LinderaResult<CharacterDefinition> {
105        CharacterDefinitionBuilderOptions::default()
106            .encoding(self.metadata.encoding.clone())
107            .compress_algorithm(self.metadata.compress_algorithm)
108            .builder()
109            .unwrap()
110            .build(input_dir, output_dir)
111    }
112
113    pub fn build_unknown_dictionary(
114        &self,
115        input_dir: &Path,
116        chardef: &CharacterDefinition,
117        output_dir: &Path,
118    ) -> LinderaResult<()> {
119        UnknownDictionaryBuilderOptions::default()
120            .encoding(self.metadata.encoding.clone())
121            .compress_algorithm(self.metadata.compress_algorithm)
122            .builder()
123            .unwrap()
124            .build(input_dir, chardef, output_dir)
125    }
126
127    pub fn build_prefix_dictionary(
128        &self,
129        input_dir: &Path,
130        output_dir: &Path,
131    ) -> LinderaResult<()> {
132        let reader = CSVReaderOptions::default()
133            .flexible_csv(self.metadata.flexible_csv)
134            .encoding(self.metadata.encoding.clone())
135            .normalize_details(self.metadata.normalize_details)
136            .builder()
137            .unwrap();
138        let rows = reader.load_csv_data(input_dir)?;
139
140        let parser = DefaultParserOptions::default()
141            .skip_invalid_cost_or_id(self.metadata.skip_invalid_cost_or_id)
142            .schema(self.metadata.dictionary_schema.clone())
143            .normalize_details(self.metadata.normalize_details)
144            .builder()
145            .unwrap();
146
147        write_prefix_dictionary::<DefaultParser, JPreprocessDictionaryWordEncoding>(
148            &parser,
149            &rows,
150            output_dir,
151            self.metadata.compress_algorithm,
152        )
153    }
154
155    pub fn build_connection_cost_matrix(
156        &self,
157        input_dir: &Path,
158        output_dir: &Path,
159    ) -> LinderaResult<()> {
160        ConnectionCostMatrixBuilderOptions::default()
161            .encoding(self.metadata.encoding.clone())
162            .compress_algorithm(self.metadata.compress_algorithm)
163            .builder()
164            .unwrap()
165            .build(input_dir, output_dir)
166    }
167
168    pub fn build_user_dict(&self, input_file: &Path) -> LinderaResult<UserDictionary> {
169        let reader = CSVReaderOptions::default()
170            .flexible_csv(self.metadata.flexible_csv)
171            .builder()
172            .unwrap();
173        let rows = reader.read_csv_files(&[input_file.to_path_buf()])?;
174
175        self.build_user_dict_from_rows(rows)
176    }
177    pub fn build_user_dict_from_data(&self, data: Vec<Vec<&str>>) -> LinderaResult<UserDictionary> {
178        let rows = data
179            .into_iter()
180            .map(csv::StringRecord::from_iter)
181            .collect::<Vec<_>>();
182
183        self.build_user_dict_from_rows(rows)
184    }
185
186    fn build_user_dict_from_rows(
187        &self,
188        rows: Vec<csv::StringRecord>,
189    ) -> LinderaResult<UserDictionary> {
190        let parser = UserDictionaryParserOptions::default()
191            .user_dictionary_fields_num(self.metadata.user_dictionary_schema.field_count())
192            .default_word_cost(self.metadata.default_word_cost)
193            .default_left_context_id(self.metadata.default_left_context_id)
194            .default_right_context_id(self.metadata.default_right_context_id)
195            .dictionary_parser(
196                DefaultParserOptions::default()
197                    .skip_invalid_cost_or_id(self.metadata.skip_invalid_cost_or_id)
198                    .schema(self.metadata.dictionary_schema.clone())
199                    .normalize_details(self.metadata.normalize_details)
200                    .builder()
201                    .unwrap(),
202            )
203            .user_dictionary_parser(
204                DefaultParserOptions::default()
205                    .schema(self.metadata.user_dictionary_schema.clone())
206                    .normalize_details(self.metadata.normalize_details)
207                    .builder()
208                    .unwrap(),
209            )
210            .builder()
211            .unwrap();
212
213        let dict = generate_prefix_dictionary::<
214            UserDictionaryParser,
215            JPreprocessDictionaryWordEncoding,
216        >(&parser, &rows, false)?;
217
218        Ok(UserDictionary { dict })
219    }
220}
221
222#[deprecated(
223    note = "Use JPreprocessDictionaryBuilder::build_user_dict_from_data instead",
224    since = "0.13.0"
225)]
226pub fn build_user_dict_from_data(data: Vec<Vec<&str>>) -> LinderaResult<UserDictionary> {
227    let rows = data
228        .into_iter()
229        .map(csv::StringRecord::from_iter)
230        .collect::<Vec<_>>();
231
232    let builder = JPreprocessDictionaryBuilder::new(Metadata {
233        default_word_cost: -10000,
234        default_left_context_id: 0,
235        default_right_context_id: 0,
236        ..JPreprocessDictionaryBuilder::default_metadata()
237    });
238
239    builder.build_user_dict_from_rows(rows)
240}
241
#[cfg(test)]
mod tests {
    use lindera_dictionary::viterbi::{LexType, WordEntry, WordId};

    use super::*;

    /// Rows that supply every schema field keep their own cost and context ids.
    #[test]
    fn test_user_dictionary() {
        // (surface, reading) pairs; in these rows the pronunciation equals the reading.
        let words = [
            ("東京スカイツリー", "トウキョウスカイツリー"),
            ("すもももももももものうち", "スモモモモモモモモノウチ"),
        ];

        let data = words
            .iter()
            .map(|&(surface, reading)| {
                vec![
                    surface,    // surface
                    "1285",     // left_context_id
                    "1285",     // right_context_id
                    "-3000",    // cost
                    "名詞",     // major_pos
                    "固有名詞", // middle_pos
                    "一般",     // small_pos
                    "*",        // fine_pos
                    "*",        // conjugation_type
                    "*",        // conjugation_form
                    "*",        // base_form
                    reading,    // reading
                    reading,    // pronunciation
                    "13",       // accent_morasize
                    "*",        // chain_rule
                    "*",        // chain_flag
                ]
            })
            .collect::<Vec<_>>();

        let user_dict = JPreprocessDictionaryBuilder::default()
            .build_user_dict_from_data(data)
            .unwrap();

        // Word ids follow the row order, starting at 0.
        for (id, &(surface, _)) in (0..).zip(words.iter()) {
            let expected = vec![WordEntry {
                word_id: WordId {
                    id,
                    is_system: false,
                    lex_type: LexType::User,
                },
                word_cost: -3000,
                left_id: 1285,
                right_id: 1285,
            }];
            assert_eq!(user_dict.dict.find_surface(surface), expected);
        }
    }

    /// Three-field rows fall back to the default cost and context ids.
    #[test]
    fn test_simple_user_dictionary() {
        // Each row is: surface, reading, pronunciation.
        let data = vec![
            vec![
                "東京スカイツリー",
                "トウキョウスカイツリー",
                "トーキョースカイツリー",
            ],
            vec![
                "すもももももももものうち",
                "スモモモモモモモモノウチ",
                "スモモモモモモモモノウチ",
            ],
        ];

        let user_dict = JPreprocessDictionaryBuilder::default()
            .build_user_dict_from_data(data)
            .unwrap();

        let surfaces = ["東京スカイツリー", "すもももももももものうち"];
        for (id, &surface) in (0..).zip(surfaces.iter()) {
            let expected = vec![WordEntry {
                word_id: WordId {
                    id,
                    is_system: false,
                    lex_type: LexType::User,
                },
                word_cost: -10000,
                left_id: 1288,
                right_id: 1288,
            }];
            assert_eq!(user_dict.dict.find_surface(surface), expected);
        }
    }
}