use crate::models::unigram::{lattice::Lattice, model::Unigram};
use crate::tokenizer::{AddedToken, Result, Trainer};
use crate::utils::parallelism::*;
use crate::utils::progress::{ProgressBar, ProgressStyle};
use log::debug;
use serde::{Deserialize, Serialize};
use std::cmp::Reverse;
use std::collections::{HashMap, HashSet};
use std::convert::TryInto;

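// The `UnigramTrainer` implements the training procedure of the SentencePiece
// Unigram language model (Kudo, 2018): seed a large candidate vocabulary from
// a suffix array over the corpus, alternate EM steps to estimate piece
// log-probabilities, and prune low-utility pieces until the target vocabulary
// size is reached.
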
/// A candidate token and its (log-probability) score.
type SentencePiece = (String, f64);

/// A sequence from the corpus and the number of times it occurs.
type Sentence = (String, u32);

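/// Computes the digamma function ψ(x) for x > 0: the argument is shifted up to
/// at least 7 via the recurrence ψ(x) = ψ(x + 1) - 1/x, then an asymptotic
/// series is evaluated at x - 1/2. As a sanity check, ψ(1) = -γ ≈ -0.5772
/// (see `test_digamma` below).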
fn digamma(mut x: f64) -> f64 {
    let mut result = 0.0;
    while x < 7.0 {
        result -= 1.0 / x;
        x += 1.0;
    }
    x -= 1.0 / 2.0;
    let xx = 1.0 / x;
    let xx2 = xx * xx;
    let xx4 = xx2 * xx2;
    result += x.ln() + (1.0 / 24.0) * xx2 - 7.0 / 960.0 * xx4 + (31.0 / 8064.0) * xx4 * xx2
        - (127.0 / 30720.0) * xx4 * xx4;
    result
}

#[derive(thiserror::Error, Debug)]
pub enum UnigramTrainerError {
    #[error("The vocabulary is not large enough to contain all chars")]
    VocabularyTooSmall,
}

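/// Normalizes raw scores in place into log probabilities: each score becomes
/// ln(score) - ln(sum of scores). For example, scores [1.0, 2.0] become
/// roughly [-1.0986, -0.4055], i.e. ln(1/3) and ln(2/3), as exercised by
/// `test_to_log_prob` below.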
fn to_log_prob(pieces: &mut [SentencePiece]) {
    let sum: f64 = pieces.iter().map(|(_, score)| score).sum();
    let logsum = sum.ln();
    for (_, score) in pieces.iter_mut() {
        *score = score.ln() - logsum;
    }
}

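/// Trainer for the Unigram model, configured through `UnigramTrainerBuilder`
/// (generated by the `Builder` derive). A minimal usage sketch, assuming the
/// trainer is reachable through the crate's public exports:
///
/// ```ignore
/// let trainer = UnigramTrainer::builder()
///     .show_progress(false)
///     .vocab_size(8000)
///     .build()
///     .unwrap();
/// ```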
#[non_exhaustive]
#[derive(Builder, Debug, Clone, Serialize, Deserialize)]
pub struct UnigramTrainer {
    #[builder(default = "true")]
    pub show_progress: bool,
    #[builder(default = "8000")]
    pub vocab_size: u32,
    #[builder(default = "2")]
    pub n_sub_iterations: u32,
    #[builder(default = "0.75")]
    pub shrinking_factor: f64,
    #[builder(default = "vec![]")]
    pub special_tokens: Vec<AddedToken>,
    #[builder(default = "HashSet::new()")]
    pub initial_alphabet: HashSet<char>,

    #[builder(default = "None")]
    pub unk_token: Option<String>,

    #[builder(default = "16")]
    pub max_piece_length: usize,
    #[builder(default = "1_000_000")]
    seed_size: usize,
    #[builder(default = "HashMap::new()")]
    words: HashMap<String, u32>,
}

impl Default for UnigramTrainer {
    fn default() -> Self {
        Self::builder().build().unwrap()
    }
}

impl UnigramTrainer {
    pub fn builder() -> UnigramTrainerBuilder {
        UnigramTrainerBuilder::default()
    }

    fn setup_progress(&self) -> Option<ProgressBar> {
        if self.show_progress {
            let p = ProgressBar::new(0);
            p.set_style(
                ProgressStyle::default_bar()
                    .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}")
                    .expect("Invalid progress template"),
            );
            Some(p)
        } else {
            None
        }
    }

    fn is_valid_sentencepiece(&self, char_string: &[char]) -> bool {
        // Only the length is checked: empty pieces and pieces longer than
        // `max_piece_length` are rejected.
        let n = char_string.len();
        if char_string.is_empty() || n > self.max_piece_length {
            return false;
        }

        true
    }

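    /// Builds the final model: required characters are inserted first (with a
    /// growing score penalty for characters the model never learned), then the
    /// remaining learned pieces by decreasing score, capped so that the final
    /// vocabulary (including special tokens, plus the unk token if one must be
    /// added) stays within `vocab_size`.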
    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
        let mut min_score_penalty = 0.0;
        let min_score_penalty_delta = 0.0001;

        let mut pieces: Vec<(String, f64)> = vec![];
        let mut inserted: HashSet<String> = HashSet::new();

        inserted.insert("<UNK>".into());

        let existing_pieces: HashMap<String, f64> = model.iter().cloned().collect();
        for c in required_chars {
            if let Some(t) = existing_pieces.get(&c) {
                inserted.insert(c.clone());
                pieces.push((c, *t));
            } else {
                let score = model.min_score + min_score_penalty;

                inserted.insert(c.clone());
                pieces.push((c, score));
                min_score_penalty += min_score_penalty_delta;
            }
        }

        let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
            let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
                if t.content == *unk {
                    Some(i)
                } else {
                    None
                }
            });
            match unk_id {
                Some(id) => (Some(id), false),
                None => (Some(0), true),
            }
        } else {
            (None, false)
        };

        let vocab_size_without_special_tokens = if need_add_unk {
            self.vocab_size as usize - self.special_tokens.len() - 1
        } else {
            self.vocab_size as usize - self.special_tokens.len()
        };
        for (token, score) in model.iter() {
            if inserted.contains::<str>(token) {
                continue;
            }
            inserted.insert(token.to_string());
            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));

            if pieces.len() == vocab_size_without_special_tokens {
                break;
            }
        }
        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());

        let mut special_tokens = self
            .special_tokens
            .iter()
            .map(|t| (t.content.clone(), 0.0))
            .collect::<Vec<_>>();
        if need_add_unk {
            special_tokens.insert(0, (self.unk_token.clone().unwrap(), 0.0));
        }

        Unigram::from(
            special_tokens.into_iter().chain(pieces).collect(),
            unk_id,
            model.byte_fallback(),
        )
    }

    fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
        word_counts
            .iter()
            .flat_map(|(s, _count)| s.chars())
            .chain(self.initial_alphabet.iter().copied())
            .map(|c| c.to_string())
            .collect()
    }
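
    /// Builds the initial candidate vocabulary: every character seen in the
    /// corpus, plus the most frequent substrings extracted from a suffix array
    /// over the concatenated sentences (via `esaxx_rs`). Substrings are scored
    /// as frequency × length, and the result is normalized to log probabilities.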
    fn make_seed_sentence_pieces(
        &self,
        sentences: &[Sentence],
        _progress: &Option<ProgressBar>,
    ) -> Vec<SentencePiece> {
        // Concatenate all sentences into one string, separated by '\0' so that
        // no extracted substring can span a sentence boundary.
        let total: usize = sentences
            .iter()
            .map(|(s, _)| s.chars().count())
            .sum::<usize>()
            + sentences.len();
        let mut flat_string = String::with_capacity(total);
        let mut all_chars: HashMap<char, u32> = HashMap::new();
        let c_sentence_boundary = '\0';
        let k_sentence_boundary = '\0'.to_string();
        for (string, n) in sentences {
            if string.is_empty() {
                continue;
            }
            flat_string.push_str(string);
            flat_string.push_str(&k_sentence_boundary);
            for c in string.chars() {
                if c != c_sentence_boundary {
                    *all_chars.entry(c).or_insert(0) += n;
                }
            }
        }
        flat_string.shrink_to_fit();
        #[cfg(feature = "esaxx_fast")]
        let suffix = esaxx_rs::suffix(&flat_string).unwrap();
        #[cfg(not(feature = "esaxx_fast"))]
        let suffix = esaxx_rs::suffix_rs(&flat_string).unwrap();

        let mut seed_sentencepieces: Vec<SentencePiece> = vec![];

        let mut sall_chars: Vec<_> = all_chars.into_iter().map(|(a, b)| (b, a)).collect();
        // Sort characters by decreasing frequency.
        sall_chars.sort_by_key(|&a| Reverse(a));
        let mut substr_index: Vec<_> = suffix
            .iter()
            .filter_map(|(string, freq)| {
                if string.len() <= 1 {
                    return None;
                }
                if string.contains(&c_sentence_boundary) {
                    return None;
                }
                if !self.is_valid_sentencepiece(string) {
                    return None;
                }
                let score = freq * string.len() as u32;
                Some((score, string))
            })
            .collect();

        // Every character seen in the corpus is always a seed piece.
        for (count, character) in sall_chars {
            seed_sentencepieces.push((character.to_string(), count.into()));
        }

        // Then add the best substrings, by decreasing score, up to `seed_size`.
        substr_index.sort_by_key(|&a| Reverse(a));
        for (score, char_string) in substr_index {
            assert!(self.is_valid_sentencepiece(char_string));
            let string: String = char_string.iter().collect();
            seed_sentencepieces.push((string, score.into()));
            if seed_sentencepieces.len() >= self.seed_size {
                break;
            }
        }
        to_log_prob(&mut seed_sentencepieces);
        seed_sentencepieces
    }
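
    /// Prunes the piece inventory following the SentencePiece heuristic: for
    /// each piece, estimate the likelihood loss incurred if the piece were
    /// removed and its occurrences re-segmented with the best alternative,
    /// then keep the pieces with the highest loss. The result is capped at
    /// max(1.1 × `vocab_size`, `shrinking_factor` × current size).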
    fn prune_sentence_pieces(
        &self,
        model: &Unigram,
        pieces: &[SentencePiece],
        sentences: &[Sentence],
    ) -> Vec<SentencePiece> {
        let mut always_keep = vec![true; pieces.len()];
        let mut alternatives: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];

        let bos_id = pieces.len() + 1;
        let eos_id = pieces.len() + 2;

        // First, segment each piece with the current model to decide whether it
        // must be kept, and record its best alternative segmentation.
        for (id, (token, _score)) in pieces.iter().enumerate() {
            // The unk piece at index 0 is handled separately.
            if id == 0 {
                always_keep[id] = false;
                continue;
            }
            let mut lattice = Lattice::from(token, bos_id, eos_id);
            model.populate_nodes(&mut lattice);

            let nbests = lattice.nbest(2);
            if nbests.len() == 1 {
                // Only one segmentation exists: the piece must be kept.
                always_keep[id] = true;
            } else if nbests[0].len() >= 2 {
                // The piece is not even its own best segmentation: droppable.
                always_keep[id] = false;
            } else if nbests[0].len() == 1 {
                // The piece is its own best segmentation; the second-best
                // segmentation is the alternative used to estimate the loss.
                always_keep[id] = true;
                for node in &nbests[1] {
                    let alt_id = node.borrow().id;
                    alternatives[id].push(alt_id);
                }
            }
        }

        // Second, compute in parallel how often each piece appears in the
        // Viterbi segmentation of the corpus, and which sentences use it.
        let chunk_size = std::cmp::max(sentences.len() / current_num_threads(), 1);
        let indexed_sentences: Vec<(usize, &Sentence)> = sentences.iter().enumerate().collect();
        let collected: (f64, Vec<f64>, Vec<Vec<usize>>) = indexed_sentences
            .maybe_par_chunks(chunk_size)
            .map(|enumerated_sentence_count_chunk| {
                let mut vsum = 0.0;
                let mut freq: Vec<f64> = vec![0.0; pieces.len()];
                let mut inverted: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];

                for (i, (sentence, count)) in enumerated_sentence_count_chunk {
                    let mut lattice = Lattice::from(sentence, bos_id, eos_id);
                    model.populate_nodes(&mut lattice);
                    vsum += *count as f64;
                    for node_ref in lattice.viterbi() {
                        let id = node_ref.borrow().id;
                        freq[id] += *count as f64;
                        inverted[id].push(*i);
                    }
                }
                (vsum, freq, inverted)
            })
            .reduce(
                || (0.0, vec![0.0; pieces.len()], vec![Vec::new(); pieces.len()]),
                |(vsum, freq, inverted), (lvsum, lfreq, linverted)| {
                    (
                        vsum + lvsum,
                        freq.iter()
                            .zip(lfreq)
                            .map(|(global_el, local_el)| global_el + local_el)
                            .collect(),
                        inverted
                            .iter()
                            .zip(linverted)
                            .map(|(global_el, local_el)| [&global_el[..], &local_el[..]].concat())
                            .collect(),
                    )
                },
            );

        let (vsum, freq, inverted) = collected;

        // Finally, estimate for each removable piece the likelihood loss its
        // removal would cause, and keep the most valuable pieces.
        let sum: f64 = freq.iter().sum();
        let logsum = sum.ln();
        let mut candidates: Vec<(usize, f64)> = vec![];
        let mut new_pieces: Vec<SentencePiece> = Vec::with_capacity(self.vocab_size as usize);
        new_pieces.push(pieces[0].clone());

        for (id, (token, score)) in pieces.iter().enumerate() {
            if id == 0 {
                continue;
            }
            if freq[id] == 0.0 && !always_keep[id] {
                // Unused piece without a keep requirement: drop it outright.
                continue;
            } else if alternatives[id].is_empty() {
                // No alternative segmentation: keep the piece unconditionally.
                new_pieces.push((token.to_string(), *score));
            } else {
                // Fraction of the corpus (weighted by counts) covered by the
                // sentences that use this piece.
                let mut f = 0.0;
                for n in &inverted[id] {
                    f += sentences[*n].1 as f64;
                }
                if f == 0.0 || f.is_nan() {
                    continue;
                }
                f /= vsum;

                // Log probability of the piece under the current counts.
                let logprob_sp = freq[id].ln() - logsum;

                // Normalizer once freq[id] is redistributed over the alternatives.
                let logsum_alt = (sum + freq[id] * (alternatives.len() - 1) as f64).ln();

                // Log probability of the alternative segmentation.
                let mut logprob_alt = 0.0;
                for n in &alternatives[id] {
                    logprob_alt += (freq[*n] + freq[id]).ln() - logsum_alt;
                }

                // Expected likelihood loss if the piece is removed.
                let loss = f * (logprob_sp - logprob_alt);
                if loss.is_nan() {
                    panic!("Likelihood loss is NaN while pruning sentence pieces");
                }

                candidates.push((id, loss));
            }
        }

        // Prune down to max(1.1 × vocab_size, shrinking_factor × current size).
        let desired_vocab_size: usize = (self.vocab_size as usize * 11) / 10;
        let pruned_size: usize = ((pieces.len() as f64) * self.shrinking_factor) as usize;
        let pruned_size = desired_vocab_size.max(pruned_size);

        candidates.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
        for (id, _score) in candidates {
            if new_pieces.len() == pruned_size {
                break;
            }
            new_pieces.push(pieces[id].clone());
        }

        new_pieces
    }

    fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &'static str) {
        if let Some(p) = p {
            p.set_message(message);
            p.set_length(len as u64);
            p.reset();
        }
    }

    fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
        if let Some(p) = p {
            p.set_length(final_len as u64);
            p.finish();
            println!();
        }
    }

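    /// E step of the EM loop: computes, in parallel over sentence chunks, the
    /// expected frequency of each piece (via the lattice marginals), the total
    /// number of Viterbi tokens, and the objective value.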
    fn run_e_step(&self, model: &Unigram, sentences: &[Sentence]) -> (f64, u32, Vec<f64>) {
        let all_sentence_freq: u32 = sentences.iter().map(|(_a, b)| *b).sum();

        let chunk_size = std::cmp::max(sentences.len() / current_num_threads(), 1);
        let collected: (f64, u32, Vec<f64>) = sentences
            .maybe_par_chunks(chunk_size)
            .map(|sentences_chunk| {
                let mut expected: Vec<f64> = vec![0.0; model.len()];
                let mut objs: f64 = 0.0;
                let mut ntokens: u32 = 0;

                for (string, freq) in sentences_chunk {
                    let mut lattice = Lattice::from(string, model.bos_id, model.eos_id);
                    model.populate_nodes(&mut lattice);

                    let z: f64 = lattice.populate_marginal(*freq as f64, &mut expected);
                    if z.is_nan() {
                        panic!("likelihood is NAN. Input sentence may be too long.");
                    }
                    ntokens += lattice.viterbi().len() as u32;
                    objs -= z / (all_sentence_freq as f64);
                }
                (objs, ntokens, expected)
            })
            .reduce(
                || (0.0, 0, vec![0.0; model.len()]),
                |(objs, ntokens, expected), (lobjs, lntokens, lexpected)| {
                    (
                        objs + lobjs,
                        ntokens + lntokens,
                        expected
                            .iter()
                            .zip(lexpected)
                            .map(|(global_el, local_el)| global_el + local_el)
                            .collect(),
                    )
                },
            );

        collected
    }
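
    /// M step of the EM loop: re-estimates piece scores from the expected
    /// frequencies. Pieces whose expected frequency falls below 0.5 are
    /// dropped (except the unk piece at index 0), and scores are smoothed
    /// with the digamma function, as in SentencePiece's Bayesianified EM.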
    fn run_m_step(&self, pieces: &[SentencePiece], expected: &[f64]) -> Vec<SentencePiece> {
        if pieces.len() != expected.len() {
            panic!(
                "Those two iterators are supposed to be the same length ({} vs {})",
                pieces.len(),
                expected.len()
            );
        }
        let mut new_pieces: Vec<SentencePiece> =
            Vec::with_capacity(self.vocab_size.try_into().unwrap());

        let mut sum = 0.0;
        let expected_frequency_threshold = 0.5;

        for (i, (freq, (piece, _score))) in expected.iter().zip(pieces).enumerate() {
            // The unk piece is kept as-is, with a placeholder score.
            if i == 0 {
                new_pieces.push((piece.clone(), f64::NAN));
                continue;
            }
            if *freq < expected_frequency_threshold {
                continue;
            }
            new_pieces.push((piece.clone(), *freq));
            sum += freq;
        }

        // Smooth the estimates with digamma: ψ(freq) - ψ(sum) instead of
        // ln(freq) - ln(sum).
        let logsum = digamma(sum);
        let new_pieces: Vec<_> = new_pieces
            .into_iter()
            .map(|(s, c)| (s, digamma(c) - logsum))
            .collect();
        new_pieces
    }
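
    /// Runs the full training procedure: seed a large candidate vocabulary,
    /// alternate E and M steps for `n_sub_iterations`, prune with
    /// `prune_sentence_pieces` until at most 1.1 × `vocab_size` pieces remain,
    /// then finalize the vocabulary. A minimal sketch of direct use, mirroring
    /// the tests below:
    ///
    /// ```ignore
    /// let trainer = UnigramTrainer::builder().show_progress(false).build().unwrap();
    /// let mut model = Unigram::default();
    /// let sentences = vec![("hello".to_string(), 10), ("world".to_string(), 5)];
    /// trainer.do_train(sentences, &mut model).unwrap();
    /// ```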
    pub fn do_train(
        &self,
        sentences: Vec<Sentence>,
        model: &mut Unigram,
    ) -> Result<Vec<AddedToken>> {
        let progress = self.setup_progress();
        self.update_progress(&progress, sentences.len(), "Suffix array seeds");
        let mut pieces: Vec<SentencePiece> =
            Vec::with_capacity(self.vocab_size.try_into().unwrap());

        // Index 0 is reserved for the unk token.
        pieces.push(("<UNK>".into(), f64::NAN));
        pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress));
        self.finalize_progress(&progress, sentences.len());

        debug!(
            "Using {} pieces on {} sentences for EM training",
            pieces.len(),
            sentences.len()
        );

        // Train until the vocabulary is within 10% of the target size; each
        // pruning round shrinks the inventory by roughly `shrinking_factor`.
        let desired_vocab_size: usize = (self.vocab_size as usize * 11) / 10;
        let expected_loops = (((desired_vocab_size as f64).ln() - (pieces.len() as f64).ln())
            / self.shrinking_factor.ln()) as usize
            + 1;
        let expected_updates = expected_loops * self.n_sub_iterations as usize;
        self.update_progress(&progress, expected_updates, "EM training");
        let required_chars = self.required_chars(&sentences);
        if required_chars.len() as u32 > self.vocab_size {
            return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
        }
        let mut new_model = Unigram::from(pieces.clone(), Some(0), false)?;
        loop {
            for _iter in 0..self.n_sub_iterations {
                let (_objective, _num_tokens, expected) = self.run_e_step(&new_model, &sentences);

                pieces = self.run_m_step(&pieces, &expected);
                new_model = Unigram::from(pieces.clone(), Some(0), false)?;

                debug!(
                    "Em iter={} size={} obj={} num_tokens={} num_tokens/piece={}",
                    _iter,
                    new_model.len(),
                    _objective,
                    _num_tokens,
                    _num_tokens as f64 / new_model.len() as f64
                );
                if let Some(p) = &progress {
                    p.inc(1);
                }
            }

            // Stop once the vocabulary is close enough to the target size.
            if pieces.len() <= desired_vocab_size {
                break;
            }

            pieces = self.prune_sentence_pieces(&new_model, &pieces, &sentences);
            new_model = Unigram::from(pieces.clone(), Some(0), false)?;
        }
        self.finalize_progress(&progress, expected_updates);

        *model = self.finalize(new_model, required_chars)?;

        Ok(self.special_tokens.clone())
    }
}

impl Trainer for UnigramTrainer {
    type Model = Unigram;

    /// Trains the `Unigram` model from the word counts gathered by `feed`.
    fn train(&self, model: &mut Unigram) -> Result<Vec<AddedToken>> {
        let sentences: Vec<_> = self.words.iter().map(|(s, i)| (s.to_owned(), *i)).collect();
        self.do_train(sentences, model)
    }

    fn should_show_progress(&self) -> bool {
        self.show_progress
    }

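    /// Feeds the trainer: each sequence is pre-processed into words by
    /// `process`, and the word counts are accumulated (in parallel when the
    /// parallelism feature is enabled) into `self.words`.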
    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
    where
        I: Iterator<Item = S> + Send,
        S: AsRef<str> + Send,
        F: Fn(&str) -> Result<Vec<String>> + Sync,
    {
        let words: Result<HashMap<String, u32>> = iterator
            .maybe_par_bridge()
            .map(|sequence| {
                let words = process(sequence.as_ref())?;
                let mut map = HashMap::new();
                for word in words {
                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
                }
                Ok(map)
            })
            .reduce(
                || Ok(HashMap::new()),
                |acc, ws| {
                    let mut acc = acc?;
                    for (k, v) in ws? {
                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
                    }
                    Ok(acc)
                },
            );

        self.words = words?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use std::iter::FromIterator;

    #[test]
    fn test_unigram_chars() {
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .build()
            .unwrap();

        let sentences = vec![
            ("This is a".to_string(), 1),
            ("こんにちは友達".to_string(), 1),
        ];

        let required_chars = trainer.required_chars(&sentences);
        assert_eq!(required_chars.len(), 13);

        let progress = None;
        let table = trainer.make_seed_sentence_pieces(&sentences, &progress);

        let target_strings = vec![
            "s", "i", " ", "達", "友", "ん", "は", "に", "ち", "こ", "h", "a", "T", "is ", "s ",
        ];

        let strings: Vec<_> = table.iter().map(|(string, _)| string).collect();
        assert_eq!(strings, target_strings);

        let scores = table.iter().map(|(_, score)| score);
        let target_scores = vec![
            -2.5649493574615367,
            -2.5649493574615367,
            -2.5649493574615367,
            -3.258096538021482,
            -3.258096538021482,
            -3.258096538021482,
            -3.258096538021482,
            -3.258096538021482,
            -3.258096538021482,
            -3.258096538021482,
            -3.258096538021482,
            -3.258096538021482,
            -3.258096538021482,
            -1.4663370687934272,
            -1.8718021769015916,
        ];

        for (score, target_score) in scores.zip(target_scores) {
            assert_approx_eq!(*score, target_score, 0.01);
        }
    }

    #[test]
    fn test_initial_alphabet() {
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .initial_alphabet(HashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
            .build()
            .unwrap();

        let sentences = vec![("こんにちは友達".to_string(), 1)];
        let required_chars = trainer.required_chars(&sentences);
        assert_eq!(
            required_chars,
            vec!["こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f"]
                .into_iter()
                .map(|s| s.to_owned())
                .collect::<HashSet<_>>()
        );
    }

    #[test]
    fn test_unk_token() {
        // 1. The `unk_token` is added as the first special token when it is
        //    not already among them.
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .special_tokens(vec![
                AddedToken::from("[SEP]", true),
                AddedToken::from("[CLS]", true),
            ])
            .unk_token(Some("[UNK]".into()))
            .build()
            .unwrap();

        let mut unigram = Unigram::default();
        trainer
            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();

        let mut pieces = unigram.iter();
        assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));

        // 2. The `unk_token` keeps its position when it is already among the
        //    special tokens.
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .special_tokens(vec![
                AddedToken::from("[SEP]", true),
                AddedToken::from("[CLS]", true),
                AddedToken::from("[UNK]", true),
            ])
            .unk_token(Some("[UNK]".into()))
            .build()
            .unwrap();

        let mut unigram = Unigram::default();
        trainer
            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();

        let mut pieces = unigram.iter();
        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));

        // 3. No unk token is added when none is configured.
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .build()
            .unwrap();

        let mut unigram = Unigram::default();
        trainer
            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();

        let mut pieces = unigram.iter();
        assert_eq!(pieces.next().unwrap().0, "e".to_string());
    }

    #[test]
    fn test_special_tokens() {
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .special_tokens(vec![
                AddedToken::from("[SEP]", true),
                AddedToken::from("[CLS]", true),
            ])
            .build()
            .unwrap();

        let mut unigram = Unigram::default();
        trainer
            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();

        let mut pieces = unigram.iter();
        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
    }

    #[test]
    fn test_to_log_prob() {
        let mut a = vec![("".to_string(), 1.0), ("".to_string(), 2.0)];
        to_log_prob(&mut a);
        let scores = a.iter().map(|(_, score)| *score).collect::<Vec<_>>();
        // ln(1/3)
        assert_approx_eq!(scores[0], -1.098, 0.01);
        // ln(2/3)
        assert_approx_eq!(scores[1], -0.405, 0.01);
    }
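
    // A sanity check for `digamma` against the known values ψ(1) = -γ and
    // ψ(2) = 1 - γ, where γ ≈ 0.5772156649 is the Euler-Mascheroni constant.
    #[test]
    fn test_digamma() {
        assert_approx_eq!(digamma(1.0), -0.5772156649, 1e-4);
        assert_approx_eq!(digamma(2.0), 0.4227843351, 1e-4);
    }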
}