// vibrato/trainer.rs

//! Module for training models.
//!
//! # Examples
//!
//! ```
//! # fn main() -> Result<(), Box<dyn std::error::Error>> {
//! use std::fs::File;
//! use vibrato::trainer::{Corpus, Trainer, TrainerConfig};
//! use vibrato::{SystemDictionaryBuilder, Tokenizer};
//!
//! // Loads configurations
//! let lexicon_rdr = File::open("src/tests/resources/train_lex.csv")?;
//! let char_prop_rdr = File::open("src/tests/resources/char.def")?;
//! let unk_handler_rdr = File::open("src/tests/resources/train_unk.def")?;
//! let feature_templates_rdr = File::open("src/tests/resources/feature.def")?;
//! let rewrite_rules_rdr = File::open("src/tests/resources/rewrite.def")?;
//! let config = TrainerConfig::from_readers(
//!     lexicon_rdr,
//!     char_prop_rdr,
//!     unk_handler_rdr,
//!     feature_templates_rdr,
//!     rewrite_rules_rdr,
//! )?;
//!
//! // Initializes trainer
//! let trainer = Trainer::new(config)?
//!     .regularization_cost(0.01)
//!     .max_iter(300)
//!     .num_threads(20);
//!
//! // Loads corpus
//! let corpus_rdr = File::open("src/tests/resources/corpus.txt")?;
//! let corpus = Corpus::from_reader(corpus_rdr)?;
//!
//! // Model data
//! let mut lexicon_trained = vec![];
//! let mut connector_trained = vec![];
//! let mut unk_handler_trained = vec![];
//! let mut user_lexicon_trained = vec![];
//!
//! // Starts training
//! let mut model = trainer.train(corpus)?;
//!
//! model.write_dictionary(
//!     &mut lexicon_trained,
//!     &mut connector_trained,
//!     &mut unk_handler_trained,
//!     &mut user_lexicon_trained,
//! )?;
//!
//! // Loads trained model
//! let char_prop_rdr = File::open("src/tests/resources/char.def")?;
//! let dict = SystemDictionaryBuilder::from_readers(
//!     &*lexicon_trained,
//!     &*connector_trained,
//!     char_prop_rdr,
//!     &*unk_handler_trained,
//! )?;
//!
//! let tokenizer = Tokenizer::new(dict);
//! let mut worker = tokenizer.new_worker();
//!
//! worker.reset_sentence("外国人参政権");
//! worker.tokenize();
//! assert_eq!(worker.num_tokens(), 4); // 外国/人/参政/権
//! # Ok(())
//! # }
//! ```

mod config;
mod corpus;
mod feature_extractor;
mod feature_rewriter;
mod model;

use std::num::NonZeroU32;

use hashbrown::{HashMap, HashSet};
use rucrf::{Edge, FeatureProvider, FeatureSet, Lattice};

use crate::dictionary::word_idx::WordIdx;
use crate::dictionary::LexType;
use crate::errors::Result;
pub use crate::trainer::config::TrainerConfig;
pub use crate::trainer::corpus::{Corpus, Example, Word};
use crate::trainer::feature_extractor::FeatureExtractor;
use crate::trainer::feature_rewriter::FeatureRewriter;
pub use crate::trainer::model::Model;
use crate::trainer::model::ModelData;
use crate::utils::{self, FromU32};

92/// Trainer of morphological analyzer.
93pub struct Trainer {
94    config: TrainerConfig,
95    max_grouping_len: Option<usize>,
96    provider: FeatureProvider,
97
98    // Assume a dictionary word W is associated with id X and feature string F.
99    // It maps F to a hash table that maps the first character of W to X.
100    label_id_map: HashMap<String, HashMap<char, NonZeroU32>>,
101
102    label_id_map_unk: Vec<NonZeroU32>,
103    regularization_cost: f64,
104    max_iter: u64,
105    num_threads: usize,
106}
107
108impl Trainer {
109    fn extract_feature_set(
110        feature_extractor: &mut FeatureExtractor,
111        unigram_rewriter: &FeatureRewriter,
112        left_rewriter: &FeatureRewriter,
113        right_rewriter: &FeatureRewriter,
114        feature_str: &str,
115        cate_id: u32,
116    ) -> FeatureSet {
117        let features = utils::parse_csv_row(feature_str);
118        let unigram_features = if let Some(rewrite) = unigram_rewriter.rewrite(&features) {
119            feature_extractor.extract_unigram_feature_ids(&rewrite, cate_id)
120        } else {
121            feature_extractor.extract_unigram_feature_ids(&features, cate_id)
122        };
123        let left_features = if let Some(rewrite) = left_rewriter.rewrite(&features) {
124            feature_extractor.extract_left_feature_ids(&rewrite)
125        } else {
126            feature_extractor.extract_left_feature_ids(&features)
127        };
128        let right_features = if let Some(rewrite) = right_rewriter.rewrite(&features) {
129            feature_extractor.extract_right_feature_ids(&rewrite)
130        } else {
131            feature_extractor.extract_right_feature_ids(&features)
132        };
133        FeatureSet::new(&unigram_features, &right_features, &left_features)
134    }
135
136    /// Creates a new [`Trainer`] using the specified configuration.
137    ///
138    /// # Arguments
139    ///
140    ///  * `config` - Training configuration.
141    ///
142    /// # Errors
143    ///
144    /// [`VibratoError`](crate::errors::VibratoError) is returned when the model will become too large.
145    pub fn new(mut config: TrainerConfig) -> Result<Self> {
146        let mut provider = FeatureProvider::default();
147        let mut label_id_map = HashMap::new();
148        let mut label_id_map_unk = vec![];
149
150        for word_id in 0..u32::try_from(config.surfaces.len()).unwrap() {
151            let word_idx = WordIdx::new(LexType::System, word_id);
152            let feature_str = config.dict.system_lexicon().word_feature(word_idx);
153            let first_char = config.surfaces[usize::from_u32(word_id)]
154                .chars()
155                .next()
156                .unwrap();
157            let cate_id = config.dict.char_prop().char_info(first_char).base_id();
158            let feature_set = Self::extract_feature_set(
159                &mut config.feature_extractor,
160                &config.unigram_rewriter,
161                &config.left_rewriter,
162                &config.right_rewriter,
163                feature_str,
164                cate_id,
165            );
166            let label_id = provider.add_feature_set(feature_set)?;
167            label_id_map
168                .raw_entry_mut()
169                .from_key(feature_str)
170                .or_insert_with(|| (feature_str.to_string(), HashMap::new()))
171                .1
172                .insert(first_char, label_id);
173        }
174        for word_id in 0..u32::try_from(config.dict.unk_handler().len()).unwrap() {
175            let word_idx = WordIdx::new(LexType::Unknown, word_id);
176            let feature_str = config.dict.unk_handler().word_feature(word_idx);
177            let cate_id = u32::from(config.dict.unk_handler().word_cate_id(word_idx));
178            let feature_set = Self::extract_feature_set(
179                &mut config.feature_extractor,
180                &config.unigram_rewriter,
181                &config.left_rewriter,
182                &config.right_rewriter,
183                feature_str,
184                cate_id,
185            );
186            label_id_map_unk.push(provider.add_feature_set(feature_set)?);
187        }
188
189        Ok(Self {
190            config,
191            max_grouping_len: None,
192            provider,
193            label_id_map,
194            label_id_map_unk,
195            regularization_cost: 0.01,
196            max_iter: 100,
197            num_threads: 1,
198        })
199    }
200
201    /// Changes the cost of L1-regularization.
202    ///
203    /// The greater this value, the stronger the regularization.
204    /// Default to 0.01.
205    ///
206    /// # Panics
207    ///
208    /// The value must be greater than or equal to 0.
209    pub fn regularization_cost(mut self, cost: f64) -> Self {
210        assert!(cost >= 0.0);
211        self.regularization_cost = cost;
212        self
213    }
214
215    /// Changes the maximum number of iterations.
216    ///
217    /// Default to 100.
218    ///
219    /// # Panics
220    ///
221    /// The value must be positive.
222    pub fn max_iter(mut self, n: u64) -> Self {
223        assert!(n >= 1);
224        self.max_iter = n;
225        self
226    }
227
228    /// Enables multi-threading.
229    ///
230    /// Default to 1.
231    ///
232    /// # Panics
233    ///
234    /// The value must be positive.
235    pub fn num_threads(mut self, n: usize) -> Self {
236        assert!(n >= 1);
237        self.num_threads = n;
238        self
239    }
240
241    /// Specifies the maximum grouping length for unknown words.
242    /// By default, the length is infinity.
243    ///
244    /// This option is for compatibility with MeCab.
245    /// Specifies the argument with `24` if you want to obtain the same results as MeCab.
246    ///
247    /// # Arguments
248    ///
249    ///  * `max_grouping_len` - The maximum grouping length for unknown words.
250    ///                         The default value is 0, indicating the infinity length.
251    pub const fn max_grouping_len(mut self, max_grouping_len: usize) -> Self {
252        if max_grouping_len != 0 {
253            self.max_grouping_len = Some(max_grouping_len);
254        } else {
255            self.max_grouping_len = None;
256        }
257        self
258    }
259
260    fn build_lattice(&mut self, example: &Example) -> Result<Lattice> {
261        let Example { sentence, tokens } = example;
262
263        let input_chars = sentence.chars();
264        let input_len = sentence.len_char();
265
266        // Add positive edges
267        // 1. If the word is found in the dictionary, add the edge as it is.
268        // 2. If the word is not found in the dictionary:
269        //   a) If a compatible unknown word is found, add the unknown word edge instead.
270        //   b) If there is no available word, add a virtual edge, which does not have any features.
271        let mut edges = vec![];
272        let mut pos = 0;
273        for token in tokens {
274            let len = token.surface().chars().count();
275            let first_char = input_chars[pos];
276            let label_id = self
277                .label_id_map
278                .get(token.feature())
279                .and_then(|hm| hm.get(&first_char))
280                .cloned()
281                .map(Ok)
282                .unwrap_or_else(|| {
283                    self.config
284                        .dict
285                        .unk_handler()
286                        .compatible_unk_index(sentence, pos, pos + len, token.feature())
287                        .map_or_else(
288                            || {
289                                eprintln!(
290                                    "adding virtual edge: {} {}",
291                                    token.surface(),
292                                    token.feature()
293                                );
294                                self.provider
295                                    .add_feature_set(FeatureSet::new(&[], &[], &[]))
296                            },
297                            |unk_index| {
298                                Ok(self.label_id_map_unk[usize::from_u32(unk_index.word_id)])
299                            },
300                        )
301                })?;
302            edges.push((pos, Edge::new(pos + len, label_id)));
303            pos += len;
304        }
305        assert_eq!(pos, input_len);
306
307        let mut lattice = Lattice::new(input_len).unwrap();
308
309        for (pos, edge) in edges {
310            lattice.add_edge(pos, edge).unwrap();
311        }
312
313        // Add negative edges
314        for start_word in 0..input_len {
315            let mut has_matched = false;
316
317            let suffix = &input_chars[start_word..];
318
319            for m in self
320                .config
321                .dict
322                .system_lexicon()
323                .common_prefix_iterator(suffix)
324            {
325                has_matched = true;
326                let label_id = NonZeroU32::new(m.word_idx.word_id + 1).unwrap();
327                let pos = start_word;
328                let target = pos + m.end_char;
329                let edge = Edge::new(target, label_id);
330                // Skips adding if the edge is already added as a positive edge.
331                if let Some(first_edge) = lattice.nodes()[pos].edges().first() {
332                    if edge == *first_edge {
333                        continue;
334                    }
335                }
336                lattice.add_edge(pos, edge).unwrap();
337            }
338
339            self.config.dict.unk_handler().gen_unk_words(
340                sentence,
341                start_word,
342                has_matched,
343                self.max_grouping_len,
344                |w| {
345                    let id_offset = u32::try_from(self.config.surfaces.len()).unwrap();
346                    let label_id = NonZeroU32::new(id_offset + w.word_idx().word_id + 1).unwrap();
347                    let pos = start_word;
348                    let target = w.end_char();
349                    let edge = Edge::new(target, label_id);
350                    // Skips adding if the edge is already added as a positive edge.
351                    if let Some(first_edge) = lattice.nodes()[pos].edges().first() {
352                        if edge == *first_edge {
353                            return;
354                        }
355                    }
356                    lattice.add_edge(pos, edge).unwrap();
357                },
358            );
359        }
360
361        Ok(lattice)
362    }
363
364    /// Starts training and returns a model.
365    ///
366    /// # Arguments
367    ///
368    /// * `corpus` - Corpus used for training.
369    ///
370    /// # Errors
371    ///
372    /// [`VibratoError`](crate::errors::VibratoError) is returned when the sentence compilation
373    /// fails.
374    pub fn train(mut self, mut corpus: Corpus) -> Result<Model> {
375        let mut lattices = vec![];
376        for example in &mut corpus.examples {
377            example.sentence.compile(self.config.dict.char_prop());
378            lattices.push(self.build_lattice(example)?);
379        }
380
381        let trainer = rucrf::Trainer::new()
382            .regularization(rucrf::Regularization::L1, self.regularization_cost)
383            .unwrap()
384            .max_iter(self.max_iter)
385            .unwrap()
386            .n_threads(self.num_threads)
387            .unwrap();
388        let model = trainer.train(&lattices, self.provider);
389
390        // Remove unused feature strings
391        let mut used_right_features = HashSet::new();
392        let unigram_feature_keys: Vec<_> = self
393            .config
394            .feature_extractor
395            .unigram_feature_ids
396            .keys()
397            .cloned()
398            .collect();
399        let left_feature_keys: Vec<_> = self
400            .config
401            .feature_extractor
402            .left_feature_ids
403            .keys()
404            .cloned()
405            .collect();
406        let right_feature_keys: Vec<_> = self
407            .config
408            .feature_extractor
409            .right_feature_ids
410            .keys()
411            .cloned()
412            .collect();
413        for k in &unigram_feature_keys {
414            let id = self
415                .config
416                .feature_extractor
417                .unigram_feature_ids
418                .get(k)
419                .unwrap();
420            if model
421                .unigram_weight_indices()
422                .get(usize::from_u32(id.get() - 1))
423                .cloned()
424                .flatten()
425                .is_none()
426            {
427                self.config.feature_extractor.unigram_feature_ids.remove(k);
428            }
429        }
430        for feature_ids in model.bigram_weight_indices() {
431            for (feature_id, _) in feature_ids {
432                used_right_features.insert(*feature_id);
433            }
434        }
435        for k in &left_feature_keys {
436            let id = self
437                .config
438                .feature_extractor
439                .left_feature_ids
440                .get(k)
441                .unwrap();
442            if let Some(x) = model.bigram_weight_indices().get(usize::from_u32(id.get())) {
443                if x.is_empty() {
444                    self.config.feature_extractor.left_feature_ids.remove(k);
445                }
446            }
447        }
448        for k in &right_feature_keys {
449            let id = self
450                .config
451                .feature_extractor
452                .right_feature_ids
453                .get(k)
454                .unwrap();
455            if !used_right_features.contains(&id.get()) {
456                self.config.feature_extractor.right_feature_ids.remove(k);
457            }
458        }
459
460        Ok(Model {
461            data: ModelData {
462                config: self.config,
463                raw_model: model,
464            },
465            merged_model: None,
466            user_entries: vec![],
467        })
468    }
469}