1mod config;
71mod corpus;
72mod feature_extractor;
73mod feature_rewriter;
74mod model;
75
76use std::num::NonZeroU32;
77
78use hashbrown::{HashMap, HashSet};
79use rucrf::{Edge, FeatureProvider, FeatureSet, Lattice};
80
81use crate::dictionary::word_idx::WordIdx;
82use crate::dictionary::LexType;
83use crate::errors::Result;
84pub use crate::trainer::config::TrainerConfig;
85pub use crate::trainer::corpus::{Corpus, Example, Word};
86use crate::trainer::feature_extractor::FeatureExtractor;
87use crate::trainer::feature_rewriter::FeatureRewriter;
88pub use crate::trainer::model::Model;
89use crate::trainer::model::ModelData;
90use crate::utils::{self, FromU32};
91
92pub struct Trainer {
94 config: TrainerConfig,
95 max_grouping_len: Option<usize>,
96 provider: FeatureProvider,
97
98 label_id_map: HashMap<String, HashMap<char, NonZeroU32>>,
101
102 label_id_map_unk: Vec<NonZeroU32>,
103 regularization_cost: f64,
104 max_iter: u64,
105 num_threads: usize,
106}
107
108impl Trainer {
109 fn extract_feature_set(
110 feature_extractor: &mut FeatureExtractor,
111 unigram_rewriter: &FeatureRewriter,
112 left_rewriter: &FeatureRewriter,
113 right_rewriter: &FeatureRewriter,
114 feature_str: &str,
115 cate_id: u32,
116 ) -> FeatureSet {
117 let features = utils::parse_csv_row(feature_str);
118 let unigram_features = if let Some(rewrite) = unigram_rewriter.rewrite(&features) {
119 feature_extractor.extract_unigram_feature_ids(&rewrite, cate_id)
120 } else {
121 feature_extractor.extract_unigram_feature_ids(&features, cate_id)
122 };
123 let left_features = if let Some(rewrite) = left_rewriter.rewrite(&features) {
124 feature_extractor.extract_left_feature_ids(&rewrite)
125 } else {
126 feature_extractor.extract_left_feature_ids(&features)
127 };
128 let right_features = if let Some(rewrite) = right_rewriter.rewrite(&features) {
129 feature_extractor.extract_right_feature_ids(&rewrite)
130 } else {
131 feature_extractor.extract_right_feature_ids(&features)
132 };
133 FeatureSet::new(&unigram_features, &right_features, &left_features)
134 }
135
136 pub fn new(mut config: TrainerConfig) -> Result<Self> {
146 let mut provider = FeatureProvider::default();
147 let mut label_id_map = HashMap::new();
148 let mut label_id_map_unk = vec![];
149
150 for word_id in 0..u32::try_from(config.surfaces.len()).unwrap() {
151 let word_idx = WordIdx::new(LexType::System, word_id);
152 let feature_str = config.dict.system_lexicon().word_feature(word_idx);
153 let first_char = config.surfaces[usize::from_u32(word_id)]
154 .chars()
155 .next()
156 .unwrap();
157 let cate_id = config.dict.char_prop().char_info(first_char).base_id();
158 let feature_set = Self::extract_feature_set(
159 &mut config.feature_extractor,
160 &config.unigram_rewriter,
161 &config.left_rewriter,
162 &config.right_rewriter,
163 feature_str,
164 cate_id,
165 );
166 let label_id = provider.add_feature_set(feature_set)?;
167 label_id_map
168 .raw_entry_mut()
169 .from_key(feature_str)
170 .or_insert_with(|| (feature_str.to_string(), HashMap::new()))
171 .1
172 .insert(first_char, label_id);
173 }
174 for word_id in 0..u32::try_from(config.dict.unk_handler().len()).unwrap() {
175 let word_idx = WordIdx::new(LexType::Unknown, word_id);
176 let feature_str = config.dict.unk_handler().word_feature(word_idx);
177 let cate_id = u32::from(config.dict.unk_handler().word_cate_id(word_idx));
178 let feature_set = Self::extract_feature_set(
179 &mut config.feature_extractor,
180 &config.unigram_rewriter,
181 &config.left_rewriter,
182 &config.right_rewriter,
183 feature_str,
184 cate_id,
185 );
186 label_id_map_unk.push(provider.add_feature_set(feature_set)?);
187 }
188
189 Ok(Self {
190 config,
191 max_grouping_len: None,
192 provider,
193 label_id_map,
194 label_id_map_unk,
195 regularization_cost: 0.01,
196 max_iter: 100,
197 num_threads: 1,
198 })
199 }
200
201 pub fn regularization_cost(mut self, cost: f64) -> Self {
210 assert!(cost >= 0.0);
211 self.regularization_cost = cost;
212 self
213 }
214
215 pub fn max_iter(mut self, n: u64) -> Self {
223 assert!(n >= 1);
224 self.max_iter = n;
225 self
226 }
227
228 pub fn num_threads(mut self, n: usize) -> Self {
236 assert!(n >= 1);
237 self.num_threads = n;
238 self
239 }
240
241 pub const fn max_grouping_len(mut self, max_grouping_len: usize) -> Self {
252 if max_grouping_len != 0 {
253 self.max_grouping_len = Some(max_grouping_len);
254 } else {
255 self.max_grouping_len = None;
256 }
257 self
258 }
259
260 fn build_lattice(&mut self, example: &Example) -> Result<Lattice> {
261 let Example { sentence, tokens } = example;
262
263 let input_chars = sentence.chars();
264 let input_len = sentence.len_char();
265
266 let mut edges = vec![];
272 let mut pos = 0;
273 for token in tokens {
274 let len = token.surface().chars().count();
275 let first_char = input_chars[pos];
276 let label_id = self
277 .label_id_map
278 .get(token.feature())
279 .and_then(|hm| hm.get(&first_char))
280 .cloned()
281 .map(Ok)
282 .unwrap_or_else(|| {
283 self.config
284 .dict
285 .unk_handler()
286 .compatible_unk_index(sentence, pos, pos + len, token.feature())
287 .map_or_else(
288 || {
289 eprintln!(
290 "adding virtual edge: {} {}",
291 token.surface(),
292 token.feature()
293 );
294 self.provider
295 .add_feature_set(FeatureSet::new(&[], &[], &[]))
296 },
297 |unk_index| {
298 Ok(self.label_id_map_unk[usize::from_u32(unk_index.word_id)])
299 },
300 )
301 })?;
302 edges.push((pos, Edge::new(pos + len, label_id)));
303 pos += len;
304 }
305 assert_eq!(pos, input_len);
306
307 let mut lattice = Lattice::new(input_len).unwrap();
308
309 for (pos, edge) in edges {
310 lattice.add_edge(pos, edge).unwrap();
311 }
312
313 for start_word in 0..input_len {
315 let mut has_matched = false;
316
317 let suffix = &input_chars[start_word..];
318
319 for m in self
320 .config
321 .dict
322 .system_lexicon()
323 .common_prefix_iterator(suffix)
324 {
325 has_matched = true;
326 let label_id = NonZeroU32::new(m.word_idx.word_id + 1).unwrap();
327 let pos = start_word;
328 let target = pos + m.end_char;
329 let edge = Edge::new(target, label_id);
330 if let Some(first_edge) = lattice.nodes()[pos].edges().first() {
332 if edge == *first_edge {
333 continue;
334 }
335 }
336 lattice.add_edge(pos, edge).unwrap();
337 }
338
339 self.config.dict.unk_handler().gen_unk_words(
340 sentence,
341 start_word,
342 has_matched,
343 self.max_grouping_len,
344 |w| {
345 let id_offset = u32::try_from(self.config.surfaces.len()).unwrap();
346 let label_id = NonZeroU32::new(id_offset + w.word_idx().word_id + 1).unwrap();
347 let pos = start_word;
348 let target = w.end_char();
349 let edge = Edge::new(target, label_id);
350 if let Some(first_edge) = lattice.nodes()[pos].edges().first() {
352 if edge == *first_edge {
353 return;
354 }
355 }
356 lattice.add_edge(pos, edge).unwrap();
357 },
358 );
359 }
360
361 Ok(lattice)
362 }
363
364 pub fn train(mut self, mut corpus: Corpus) -> Result<Model> {
375 let mut lattices = vec![];
376 for example in &mut corpus.examples {
377 example.sentence.compile(self.config.dict.char_prop());
378 lattices.push(self.build_lattice(example)?);
379 }
380
381 let trainer = rucrf::Trainer::new()
382 .regularization(rucrf::Regularization::L1, self.regularization_cost)
383 .unwrap()
384 .max_iter(self.max_iter)
385 .unwrap()
386 .n_threads(self.num_threads)
387 .unwrap();
388 let model = trainer.train(&lattices, self.provider);
389
390 let mut used_right_features = HashSet::new();
392 let unigram_feature_keys: Vec<_> = self
393 .config
394 .feature_extractor
395 .unigram_feature_ids
396 .keys()
397 .cloned()
398 .collect();
399 let left_feature_keys: Vec<_> = self
400 .config
401 .feature_extractor
402 .left_feature_ids
403 .keys()
404 .cloned()
405 .collect();
406 let right_feature_keys: Vec<_> = self
407 .config
408 .feature_extractor
409 .right_feature_ids
410 .keys()
411 .cloned()
412 .collect();
413 for k in &unigram_feature_keys {
414 let id = self
415 .config
416 .feature_extractor
417 .unigram_feature_ids
418 .get(k)
419 .unwrap();
420 if model
421 .unigram_weight_indices()
422 .get(usize::from_u32(id.get() - 1))
423 .cloned()
424 .flatten()
425 .is_none()
426 {
427 self.config.feature_extractor.unigram_feature_ids.remove(k);
428 }
429 }
430 for feature_ids in model.bigram_weight_indices() {
431 for (feature_id, _) in feature_ids {
432 used_right_features.insert(*feature_id);
433 }
434 }
435 for k in &left_feature_keys {
436 let id = self
437 .config
438 .feature_extractor
439 .left_feature_ids
440 .get(k)
441 .unwrap();
442 if let Some(x) = model.bigram_weight_indices().get(usize::from_u32(id.get())) {
443 if x.is_empty() {
444 self.config.feature_extractor.left_feature_ids.remove(k);
445 }
446 }
447 }
448 for k in &right_feature_keys {
449 let id = self
450 .config
451 .feature_extractor
452 .right_feature_ids
453 .get(k)
454 .unwrap();
455 if !used_right_features.contains(&id.get()) {
456 self.config.feature_extractor.right_feature_ids.remove(k);
457 }
458 }
459
460 Ok(Model {
461 data: ModelData {
462 config: self.config,
463 raw_model: model,
464 },
465 merged_model: None,
466 user_entries: vec![],
467 })
468 }
469}