tokenizers/models/unigram/
trainer.rs

1use crate::models::unigram::{lattice::Lattice, model::Unigram};
2use crate::tokenizer::{AddedToken, Result, Trainer};
3use crate::utils::parallelism::*;
4use crate::utils::progress::{ProgressBar, ProgressStyle};
5use log::debug;
6use serde::{Deserialize, Serialize};
7use std::cmp::Reverse;
8use std::collections::{HashMap, HashSet};
9use std::convert::TryInto;
10
// A candidate piece (token) and its score.
type SentencePiece = (String, f64);

// A full sentence or word, and its count within the dataset.
type Sentence = (String, u32);
16
/// Approximates the digamma function ψ(x), the derivative of ln Γ(x).
///
/// The recurrence ψ(x) = ψ(x + 1) - 1/x shifts the argument up to at least 7. An asymptotic
/// series in powers of 1/(x - 1/2) is then evaluated at the shifted point. A NaN argument
/// yields NaN.
fn digamma(x: f64) -> f64 {
    // Accumulate the -1/x terms of the recurrence while moving x into the series' accurate range.
    let mut shifted = x;
    let mut correction = 0.0;
    while shifted < 7.0 {
        correction -= 1.0 / shifted;
        shifted += 1.0;
    }

    let t = shifted - 1.0 / 2.0;
    let inv = 1.0 / t;
    let inv2 = inv * inv;
    let inv4 = inv2 * inv2;
    let series = t.ln() + (1.0 / 24.0) * inv2 - 7.0 / 960.0 * inv4 + (31.0 / 8064.0) * inv4 * inv2
        - (127.0 / 30720.0) * inv4 * inv4;

    correction + series
}
31
32#[derive(thiserror::Error, Debug)]
33pub enum UnigramTrainerError {
34    #[error("The vocabulary is not large enough to contain all chars")]
35    VocabularyTooSmall,
36}
37
38fn to_log_prob(pieces: &mut [SentencePiece]) {
39    let sum: f64 = pieces.iter().map(|(_, score)| score).sum();
40    let logsum = sum.ln();
41    for (_, score) in pieces.iter_mut() {
42        *score = score.ln() - logsum;
43    }
44}
45
/// A `UnigramTrainer` can train a `Unigram` model from word counts.
///
/// Training seeds candidate pieces from a suffix array over the corpus. It then alternates
/// `n_sub_iterations` EM iterations with a pruning step until the number of pieces is close to
/// `vocab_size`.
#[non_exhaustive]
#[derive(Builder, Debug, Clone, Serialize, Deserialize)]
pub struct UnigramTrainer {
    /// Whether progress bars are displayed during training.
    #[builder(default = "true")]
    pub show_progress: bool,
    /// Target size of the final vocabulary, special tokens included.
    #[builder(default = "8000")]
    pub vocab_size: u32,
    /// Number of EM iterations run between two pruning steps.
    #[builder(default = "2")]
    pub n_sub_iterations: u32,
    /// Fraction of the pieces each pruning step aims to keep (never below ~1.1 × `vocab_size`).
    #[builder(default = "0.75")]
    pub shrinking_factor: f64,
    /// Tokens placed at the start of the final vocabulary, each with a score of 0.0.
    #[builder(default = "vec![]")]
    pub special_tokens: Vec<AddedToken>,
    /// Chars that must be part of the final vocabulary even if absent from the corpus.
    #[builder(default = "HashSet::new()")]
    pub initial_alphabet: HashSet<char>,

    /// Optional unknown token. When it is not already one of `special_tokens`, it is
    /// prepended to them in the final vocabulary.
    #[builder(default = "None")]
    pub unk_token: Option<String>,

    /// Maximum length, in chars, of a seed piece.
    #[builder(default = "16")]
    pub max_piece_length: usize,
    /// Substring seeds stop being added once this many seeds exist (single chars are always
    /// seeded).
    #[builder(default = "1_000_000")]
    seed_size: usize,
    /// Word counts accumulated by `Trainer::feed` and consumed by `Trainer::train`.
    #[builder(default = "HashMap::new()")]
    words: HashMap<String, u32>,
}
73
74impl Default for UnigramTrainer {
75    fn default() -> Self {
76        Self::builder().build().unwrap()
77    }
78}
79
80impl UnigramTrainer {
81    pub fn builder() -> UnigramTrainerBuilder {
82        UnigramTrainerBuilder::default()
83    }
84
85    /// Setup a progress bar if asked to show progress
86    fn setup_progress(&self) -> Option<ProgressBar> {
87        if self.show_progress {
88            let p = ProgressBar::new(0);
89            p.set_style(
90                ProgressStyle::default_bar()
91                    .template("[{elapsed_precise}] {msg:<30!} {wide_bar} {pos:<9!}/{len:>9!}")
92                    .expect("Invalid progress template"),
93            );
94            Some(p)
95        } else {
96            None
97        }
98    }
99
100    fn is_valid_sentencepiece(&self, char_string: &[char]) -> bool {
101        // Checks string length
102        // Space not in the substring, numbers, hiragana and more should be taken
103        // care of within pre_tokenizers.
104        // https://github.com/google/sentencepiece/blob/26be9516cd81d5315ee31c48d2438018e0eab879/src/trainer_interface.cc#L203
105        let n = char_string.len();
106        if char_string.is_empty() || n > self.max_piece_length {
107            return false;
108        }
109
110        true
111    }
112
113    fn finalize(&self, model: Unigram, required_chars: HashSet<String>) -> Result<Unigram> {
114        let mut min_score_penalty = 0.0;
115        let min_score_penalty_delta = 0.0001;
116
117        let mut pieces: Vec<(String, f64)> = vec![];
118        let mut inserted: HashSet<String> = HashSet::new();
119
120        // We don't want to include the <UNK> that was used to train
121        inserted.insert("<UNK>".into());
122
123        let existing_pieces: HashMap<String, f64> = model.iter().cloned().collect();
124        for c in required_chars {
125            if let Some(t) = existing_pieces.get(&c) {
126                inserted.insert(c.clone());
127                pieces.push((c, *t));
128            } else {
129                let score = model.min_score + min_score_penalty;
130
131                inserted.insert(c.clone());
132                pieces.push((c, score));
133                min_score_penalty += min_score_penalty_delta;
134            }
135        }
136
137        let (unk_id, need_add_unk) = if let Some(ref unk) = self.unk_token {
138            let unk_id = self.special_tokens.iter().enumerate().find_map(|(i, t)| {
139                if t.content == *unk {
140                    Some(i)
141                } else {
142                    None
143                }
144            });
145            match unk_id {
146                Some(id) => (Some(id), false),
147                None => (Some(0), true),
148            }
149        } else {
150            (None, false)
151        };
152
153        let vocab_size_without_special_tokens = if need_add_unk {
154            self.vocab_size as usize - self.special_tokens.len() - 1
155        } else {
156            self.vocab_size as usize - self.special_tokens.len()
157        };
158        for (token, score) in model.iter() {
159            if inserted.contains::<str>(token) {
160                continue;
161            }
162            inserted.insert(token.to_string());
163            pieces.push((token.to_string(), if score.is_nan() { 0.0 } else { *score }));
164
165            if pieces.len() == vocab_size_without_special_tokens {
166                break;
167            }
168        }
169        pieces.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
170
171        // Insert the necessary tokens
172        let mut special_tokens = self
173            .special_tokens
174            .iter()
175            .map(|t| (t.content.clone(), 0.0))
176            .collect::<Vec<_>>();
177        if need_add_unk {
178            special_tokens.insert(0, (self.unk_token.clone().unwrap(), 0.0));
179        }
180
181        Unigram::from(
182            special_tokens.into_iter().chain(pieces).collect(),
183            unk_id,
184            model.byte_fallback(),
185        )
186    }
187
188    fn required_chars(&self, word_counts: &[Sentence]) -> HashSet<String> {
189        word_counts
190            .iter()
191            .flat_map(|(s, _count)| s.chars())
192            .chain(self.initial_alphabet.iter().copied())
193            .map(|c| c.to_string())
194            .collect()
195    }
196    fn make_seed_sentence_pieces(
197        &self,
198        sentences: &[Sentence],
199        _progress: &Option<ProgressBar>,
200    ) -> Vec<SentencePiece> {
201        // Put all sentences in a string, separated by \0
202        let total: usize = sentences
203            .iter()
204            .map(|(s, _)| s.chars().count())
205            .sum::<usize>()
206            + sentences.len();
207        let mut flat_string = String::with_capacity(total);
208        let mut all_chars: HashMap<char, u32> = HashMap::new();
209        let c_sentence_boundary = '\0';
210        let k_sentence_boundary = '\0'.to_string();
211        for (string, n) in sentences {
212            if string.is_empty() {
213                continue;
214            }
215            flat_string.push_str(string);
216            // XXX
217            // Comment suggests we add sentence boundary, but it seems to be missing from actual
218            // code in spm.
219            flat_string.push_str(&k_sentence_boundary);
220            for c in string.chars() {
221                if c != c_sentence_boundary {
222                    *all_chars.entry(c).or_insert(0) += n;
223                }
224            }
225        }
226        flat_string.shrink_to_fit();
227        #[cfg(feature = "esaxx_fast")]
228        let suffix = esaxx_rs::suffix(&flat_string).unwrap();
229        #[cfg(not(feature = "esaxx_fast"))]
230        let suffix = esaxx_rs::suffix_rs(&flat_string).unwrap();
231
232        //  Basic chars need to be in sentence pieces.
233        let mut seed_sentencepieces: Vec<SentencePiece> = vec![];
234
235        let mut sall_chars: Vec<_> = all_chars.into_iter().map(|(a, b)| (b, a)).collect();
236        // Reversed order
237        sall_chars.sort_by_key(|&a| Reverse(a));
238        let mut substr_index: Vec<_> = suffix
239            .iter()
240            .filter_map(|(string, freq)| {
241                if string.len() <= 1 {
242                    return None;
243                }
244                if string.contains(&c_sentence_boundary) {
245                    return None;
246                }
247                if !self.is_valid_sentencepiece(string) {
248                    return None;
249                }
250                let score = freq * string.len() as u32;
251                // if let Some(p) = &progress {
252                //     p.inc(1);
253                // }
254                Some((score, string))
255            })
256            .collect();
257
258        // Fill seed_sentencepieces
259        for (count, character) in sall_chars {
260            seed_sentencepieces.push((character.to_string(), count.into()));
261        }
262
263        // sort by decreasing score
264        substr_index.sort_by_key(|&a| Reverse(a));
265        for (score, char_string) in substr_index {
266            // Just in case
267            assert!(self.is_valid_sentencepiece(char_string));
268            let string: String = char_string.iter().collect();
269            seed_sentencepieces.push((string, score.into()));
270            if seed_sentencepieces.len() >= self.seed_size {
271                break;
272            }
273        }
274        to_log_prob(&mut seed_sentencepieces);
275        seed_sentencepieces
276    }
277    fn prune_sentence_pieces(
278        &self,
279        model: &Unigram,
280        pieces: &[SentencePiece],
281        sentences: &[Sentence],
282    ) -> Vec<SentencePiece> {
283        let mut always_keep = vec![true; pieces.len()];
284        let mut alternatives: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];
285
286        let bos_id = pieces.len() + 1;
287        let eos_id = pieces.len() + 2;
288
289        // First, segments the current sentencepieces to know
290        // how each sentencepiece is resegmented if this sentencepiece is removed
291        // from the vocabulary.
292        // To do so, we take the second best segmentation of sentencepiece[i].
293        // alternatives[i] stores the sequence of second best sentencepieces.
294        for (id, (token, _score)) in pieces.iter().enumerate() {
295            // Always keep unk.
296            if id == 0 {
297                always_keep[id] = false;
298                continue;
299            }
300            let mut lattice = Lattice::from(token, bos_id, eos_id);
301            model.populate_nodes(&mut lattice);
302
303            let nbests = lattice.nbest(2);
304            if nbests.len() == 1 {
305                always_keep[id] = true;
306            } else if nbests[0].len() >= 2 {
307                always_keep[id] = false;
308            } else if nbests[0].len() == 1 {
309                always_keep[id] = true;
310                for node in &nbests[1] {
311                    let alt_id = node.borrow().id;
312                    alternatives[id].push(alt_id);
313                }
314            }
315        }
316
317        // Second, segments all sentences to compute likelihood
318        // with a unigram language model. inverted[i] stores
319        // the set of sentence index where the sentencepieces[i] appears.
320        let chunk_size = std::cmp::max(sentences.len() / current_num_threads(), 1);
321        let indexed_sentences: Vec<(usize, &Sentence)> = sentences.iter().enumerate().collect();
322        let collected: (f64, Vec<f64>, Vec<Vec<usize>>) = indexed_sentences
323            .maybe_par_chunks(chunk_size)
324            .map(|enumerated_sentence_count_chunk| {
325                let mut vsum = 0.0;
326                let mut freq: Vec<f64> = vec![0.0; pieces.len()];
327                let mut inverted: Vec<Vec<usize>> = vec![Vec::new(); pieces.len()];
328
329                for (i, (sentence, count)) in enumerated_sentence_count_chunk {
330                    let mut lattice = Lattice::from(sentence, bos_id, eos_id);
331                    model.populate_nodes(&mut lattice);
332                    vsum += *count as f64;
333                    for node_ref in lattice.viterbi() {
334                        let id = node_ref.borrow().id;
335                        freq[id] += *count as f64;
336                        inverted[id].push(*i);
337                    }
338                }
339                (vsum, freq, inverted)
340            })
341            .reduce(
342                || (0.0, vec![0.0; pieces.len()], vec![Vec::new(); pieces.len()]),
343                |(vsum, freq, inverted), (lvsum, lfreq, linverted)| {
344                    (
345                        vsum + lvsum,
346                        freq.iter()
347                            .zip(lfreq)
348                            .map(|(global_el, local_el)| global_el + local_el)
349                            .collect(),
350                        inverted
351                            .iter()
352                            .zip(linverted)
353                            .map(|(global_el, local_el)| [&global_el[..], &local_el[..]].concat())
354                            .collect(),
355                    )
356                },
357            );
358
359        let (vsum, freq, inverted) = collected;
360
361        let sum: f64 = freq.iter().sum();
362        let logsum = sum.ln();
363        let mut candidates: Vec<(usize, f64)> = vec![];
364        let mut new_pieces: Vec<SentencePiece> = Vec::with_capacity(self.vocab_size as usize);
365        new_pieces.push(pieces[0].clone());
366
367        // Finally, computes how likely the LM likelihood is reduced if
368        // the sentencepiece[i] is removed from the vocabulary.
369        // Since the exact computation of loss is difficult, we compute the
370        // loss approximately by assuming that all sentencepiece[i] in the sentences
371        // are replaced with alternatives[i] when sentencepiece[i] is removed.
372        for (id, (token, score)) in pieces.iter().enumerate() {
373            if id == 0 {
374                continue;
375            }
376            if freq[id] == 0.0 && !always_keep[id] {
377                // not found in Viterbi path. Can remove this entry safely.
378                continue;
379            } else if alternatives[id].is_empty() {
380                // no alternatives. Keeps this entry.
381                new_pieces.push((token.to_string(), *score));
382            } else {
383                let mut f = 0.0; // the frequency of pieces[i];
384
385                for n in &inverted[id] {
386                    let score = sentences[*n].1 as f64;
387                    f += score;
388                }
389                // TODO: Temporary hack to avoid Nans.
390                if f == 0.0 || f.is_nan() {
391                    // new_pieces.push((token.to_string(), *score));
392                    continue;
393                }
394                f /= vsum; // normalizes by all sentence frequency.
395                let logprob_sp = freq[id].ln() - logsum;
396
397                // After removing the sentencepiece[i], its frequency freq[i] is
398                // re-assigned to alternatives.
399                // new_sum = current_sum - freq[i] + freq[i] * alternatives.size()
400                //         = current_sum + freq[i] (alternatives - 1)
401
402                let logsum_alt = (sum + freq[id] * (alternatives.len() - 1) as f64).ln();
403
404                // The frequencies of altenatives are increased by freq[i].
405                let mut logprob_alt = 0.0;
406                for n in &alternatives[id] {
407                    logprob_alt += (freq[*n] + freq[id]).ln() - logsum_alt;
408                }
409
410                // loss: the diff of likelihood after removing the sentencepieces[i].
411                let loss = f * (logprob_sp - logprob_alt);
412                if loss.is_nan() {
413                    panic!("");
414                }
415
416                candidates.push((id, loss));
417            }
418        }
419        let desired_vocab_size: usize = (self.vocab_size as usize * 11) / 10; // * 1.1
420        let pruned_size: usize = ((pieces.len() as f64) * self.shrinking_factor) as usize;
421        let pruned_size = desired_vocab_size.max(pruned_size);
422
423        candidates.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
424        for (id, _score) in candidates {
425            if new_pieces.len() == pruned_size {
426                break;
427            }
428            new_pieces.push(pieces[id].clone());
429        }
430
431        new_pieces.to_vec()
432    }
433
434    /// Update the progress bar with the new provided length and message
435    fn update_progress(&self, p: &Option<ProgressBar>, len: usize, message: &'static str) {
436        if let Some(p) = p {
437            p.set_message(message);
438            p.set_length(len as u64);
439            p.reset();
440        }
441    }
442    /// Set the progress bar in the finish state
443    fn finalize_progress(&self, p: &Option<ProgressBar>, final_len: usize) {
444        if let Some(p) = p {
445            p.set_length(final_len as u64);
446            p.finish();
447            println!();
448        }
449    }
450
451    fn run_e_step(&self, model: &Unigram, sentences: &[Sentence]) -> (f64, u32, Vec<f64>) {
452        let all_sentence_freq: u32 = sentences.iter().map(|(_a, b)| *b).sum();
453
454        let chunk_size = std::cmp::max(sentences.len() / current_num_threads(), 1);
455        let collected: (f64, u32, Vec<f64>) = sentences
456            .maybe_par_chunks(chunk_size)
457            .map(|sentences_chunk| {
458                let mut expected: Vec<f64> = vec![0.0; model.len()];
459                let mut objs: f64 = 0.0;
460                let mut ntokens: u32 = 0;
461
462                for (string, freq) in sentences_chunk {
463                    let mut lattice = Lattice::from(string, model.bos_id, model.eos_id);
464                    model.populate_nodes(&mut lattice);
465
466                    let z: f64 = lattice.populate_marginal(*freq as f64, &mut expected);
467                    if z.is_nan() {
468                        panic!("likelihood is NAN. Input sentence may be too long.");
469                    }
470                    ntokens += lattice.viterbi().len() as u32;
471                    objs -= z / (all_sentence_freq as f64);
472                }
473                (objs, ntokens, expected)
474            })
475            .reduce(
476                || (0.0, 0, vec![0.0; model.len()]),
477                |(objs, ntokens, expected), (lobjs, lntokens, lexpected)| {
478                    (
479                        objs + lobjs,
480                        ntokens + lntokens,
481                        expected
482                            .iter()
483                            .zip(lexpected)
484                            .map(|(global_el, local_el)| global_el + local_el)
485                            .collect(),
486                    )
487                },
488            );
489
490        collected
491    }
492    fn run_m_step(&self, pieces: &[SentencePiece], expected: &[f64]) -> Vec<SentencePiece> {
493        if pieces.len() != expected.len() {
494            panic!(
495                "Those two iterators are supposed to be the same length ({} vs {})",
496                pieces.len(),
497                expected.len()
498            );
499        }
500        let mut new_pieces: Vec<SentencePiece> =
501            Vec::with_capacity(self.vocab_size.try_into().unwrap());
502
503        let mut sum = 0.0;
504        let expected_frequency_threshold = 0.5;
505
506        for (i, (freq, (piece, _score))) in expected.iter().zip(pieces).enumerate() {
507            // Always keep unk.
508            if i == 0 {
509                new_pieces.push((piece.clone(), f64::NAN));
510                continue;
511            }
512            if *freq < expected_frequency_threshold {
513                continue;
514            }
515            new_pieces.push((piece.clone(), *freq));
516            sum += freq;
517        }
518        // // Here we do not use the original EM, but use the
519        // // Bayesianified/DPified EM algorithm.
520        // // https://cs.stanford.edu/~pliang/papers/tutorial-acl2007-talk.pdf
521        // // This modification will act as a sparse prior.
522        let logsum = digamma(sum);
523        let new_pieces: Vec<_> = new_pieces
524            .into_iter()
525            .map(|(s, c)| (s, digamma(c) - logsum))
526            .collect();
527        new_pieces
528    }
529    pub fn do_train(
530        &self,
531        sentences: Vec<Sentence>,
532        model: &mut Unigram,
533    ) -> Result<Vec<AddedToken>> {
534        let progress = self.setup_progress();
535        //
536        // 1. Compute frequent substrings
537        // TODO Should be able to upgrade to u64 when needed
538        self.update_progress(&progress, sentences.len(), "Suffix array seeds");
539        let mut pieces: Vec<SentencePiece> =
540            Vec::with_capacity(self.vocab_size.try_into().unwrap());
541
542        // We use a UNK token when training, whatever the `self.unk_token`
543        pieces.push(("<UNK>".into(), f64::NAN));
544        pieces.extend(self.make_seed_sentence_pieces(&sentences, &progress));
545        self.finalize_progress(&progress, sentences.len());
546
547        // Useful to check compatibility with spm.
548        debug!(
549            "Using {} pieces on {} sentences for EM training",
550            pieces.len(),
551            sentences.len()
552        );
553
554        let desired_vocab_size: usize = (self.vocab_size as usize * 11) / 10; // * 1.1
555
556        // 2. Run E-M Loops to fine grain the pieces.
557        // We will shrink the vocab by shrinking_factor every loop on average
558        // Some other pieces are dropped if logprob is too small
559        // V = N * (f)**k
560        // k = log(V / N) / log(f)
561        let expected_loops = (((desired_vocab_size as f64).ln() - (pieces.len() as f64).ln())
562            / self.shrinking_factor.ln()) as usize
563            + 1;
564        let expected_updates = expected_loops * self.n_sub_iterations as usize;
565        self.update_progress(&progress, expected_updates, "EM training");
566        let required_chars = self.required_chars(&sentences);
567        if required_chars.len() as u32 > self.vocab_size {
568            return Err(Box::new(UnigramTrainerError::VocabularyTooSmall));
569        }
570        let mut new_model = Unigram::from(pieces.clone(), Some(0), false)?;
571        loop {
572            // Sub-EM iteration.
573            for _iter in 0..self.n_sub_iterations {
574                // Executes E step
575                let (_objective, _num_tokens, expected) = self.run_e_step(&new_model, &sentences);
576
577                // Executes M step.
578                pieces = self.run_m_step(&pieces, &expected);
579                new_model = Unigram::from(pieces.clone(), Some(0), false)?;
580
581                // Useful comment for checking compatibility with spm
582                debug!(
583                    "Em iter={} size={} obj={} num_tokens={} num_tokens/piece={}",
584                    _iter,
585                    new_model.len(),
586                    _objective,
587                    _num_tokens,
588                    _num_tokens as f64 / model.len() as f64
589                );
590                if let Some(p) = &progress {
591                    p.inc(1);
592                }
593            } // end of Sub EM iteration
594
595            // Stops the iteration when the size of sentences reaches to the
596            // desired symbol size.
597            if pieces.len() <= desired_vocab_size {
598                break;
599            }
600
601            // Prunes pieces.
602            pieces = self.prune_sentence_pieces(&new_model, &pieces, &sentences);
603            new_model = Unigram::from(pieces.clone(), Some(0), false)?;
604        }
605        self.finalize_progress(&progress, expected_updates);
606
607        // Finally, adjusts the size of sentencepices to be |vocab_size|.
608        *model = self.finalize(new_model, required_chars)?;
609
610        Ok(self.special_tokens.clone())
611    }
612}
613
614impl Trainer for UnigramTrainer {
615    type Model = Unigram;
616
617    /// Train a Unigram model
618    fn train(&self, model: &mut Unigram) -> Result<Vec<AddedToken>> {
619        let sentences: Vec<_> = self.words.iter().map(|(s, i)| (s.to_owned(), *i)).collect();
620        self.do_train(sentences, model)
621    }
622
623    /// Whether we should show progress
624    fn should_show_progress(&self) -> bool {
625        self.show_progress
626    }
627
628    fn feed<I, S, F>(&mut self, iterator: I, process: F) -> Result<()>
629    where
630        I: Iterator<Item = S> + Send,
631        S: AsRef<str> + Send,
632        F: Fn(&str) -> Result<Vec<String>> + Sync,
633    {
634        let words: Result<HashMap<String, u32>> = iterator
635            .maybe_par_bridge()
636            .map(|sequence| {
637                let words = process(sequence.as_ref())?;
638                let mut map = HashMap::new();
639                for word in words {
640                    map.entry(word).and_modify(|c| *c += 1).or_insert(1);
641                }
642                Ok(map)
643            })
644            .reduce(
645                || Ok(HashMap::new()),
646                |acc, ws| {
647                    let mut acc = acc?;
648                    for (k, v) in ws? {
649                        acc.entry(k).and_modify(|c| *c += v).or_insert(v);
650                    }
651                    Ok(acc)
652                },
653            );
654
655        self.words = words?;
656        Ok(())
657    }
658}
659
#[cfg(test)]
mod tests {
    use super::*;
    use assert_approx_eq::assert_approx_eq;
    use std::iter::FromIterator;

    // Checks `required_chars` and the seed pieces (order and log-prob scores) computed from a
    // small two-sentence corpus mixing Latin and Japanese chars.
    #[test]
    fn test_unigram_chars() {
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .build()
            .unwrap();

        let sentences = vec![
            ("This is a".to_string(), 1),
            ("こんにちは友達".to_string(), 1),
        ];

        // 6 distinct Latin chars (space included) + 7 distinct Japanese chars.
        let required_chars = trainer.required_chars(&sentences);
        assert_eq!(required_chars.len(), 13);

        let progress = None;
        let table = trainer.make_seed_sentence_pieces(&sentences, &progress);

        // Single chars come first (by decreasing count, then decreasing char), followed by the
        // multi-char substrings by decreasing score.
        let target_strings = vec![
            "s", "i", " ", "達", "友", "ん", "は", "に", "ち", "こ", "h", "a", "T", "is ", "s ",
        ];

        let strings: Vec<_> = table.iter().map(|(string, _)| string).collect();
        assert_eq!(strings, target_strings);

        // Each trailing comment is the raw score before `to_log_prob` normalization.
        let scores = table.iter().map(|(_, score)| score);
        let target_scores = vec![
            -2.5649493574615367, // 2.0
            -2.5649493574615367, // 2.0
            -2.5649493574615367, // 2.0
            -3.258096538021482,  // 1.0
            -3.258096538021482,  // 1.0
            -3.258096538021482,  // 1.0
            -3.258096538021482,  // 1.0
            -3.258096538021482,  // 1.0
            -3.258096538021482,  // 1.0
            -3.258096538021482,  // 1.0
            -3.258096538021482,  // 1.0
            -3.258096538021482,  // 1.0
            -3.258096538021482,  // 1.0
            -1.4663370687934272, // 6.0
            -1.8718021769015916, // 4.0
        ];

        for (score, target_score) in scores.zip(target_scores) {
            assert_approx_eq!(*score, target_score, 0.01);
        }
    }

    // Chars of `initial_alphabet` are required even when they never appear in the corpus.
    #[test]
    fn test_initial_alphabet() {
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .initial_alphabet(HashSet::from_iter(vec!['a', 'b', 'c', 'd', 'e', 'f']))
            .build()
            .unwrap();

        let sentences = vec![("こんにちは友達".to_string(), 1)];
        let required_chars = trainer.required_chars(&sentences);
        assert_eq!(
            required_chars,
            vec!["こ", "ん", "に", "ち", "は", "友", "達", "a", "b", "c", "d", "e", "f"]
                .into_iter()
                .map(|s| s.to_owned())
                .collect::<HashSet<_>>()
        );
    }

    // Placement of `unk_token` in the final vocabulary relative to the special tokens.
    #[test]
    fn test_unk_token() {
        // 1. Should add `unk_token` as first special token
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .special_tokens(vec![
                AddedToken::from("[SEP]", true),
                AddedToken::from("[CLS]", true),
            ])
            .unk_token(Some("[UNK]".into()))
            .build()
            .unwrap();

        let mut unigram = Unigram::default();
        trainer
            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();

        let mut pieces = unigram.iter();
        assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));

        // 2. Let it where it is
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .special_tokens(vec![
                AddedToken::from("[SEP]", true),
                AddedToken::from("[CLS]", true),
                AddedToken::from("[UNK]", true),
            ])
            .unk_token(Some("[UNK]".into()))
            .build()
            .unwrap();

        let mut unigram = Unigram::default();
        trainer
            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();

        let mut pieces = unigram.iter();
        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[UNK]".into(), 0.0)));

        // 3. Don't put it there if not needed
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .build()
            .unwrap();

        let mut unigram = Unigram::default();
        trainer
            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();

        let mut pieces = unigram.iter();
        assert_eq!(pieces.next().unwrap().0, "e".to_string());
    }

    // Special tokens are placed first, in order, with a score of 0.0.
    #[test]
    fn test_special_tokens() {
        let trainer = UnigramTrainerBuilder::default()
            .show_progress(false)
            .special_tokens(vec![
                AddedToken::from("[SEP]", true),
                AddedToken::from("[CLS]", true),
            ])
            .build()
            .unwrap();

        let mut unigram = Unigram::default();
        trainer
            .do_train(vec![("The".into(), 12), ("are".into(), 11)], &mut unigram)
            .unwrap();

        let mut pieces = unigram.iter();
        assert_eq!(pieces.next(), Some(&("[SEP]".into(), 0.0)));
        assert_eq!(pieces.next(), Some(&("[CLS]".into(), 0.0)));
    }

    // `to_log_prob` maps each score s to ln(s) - ln(sum of scores).
    #[test]
    fn test_to_log_prob() {
        let mut a = vec![("".to_string(), 1.0), ("".to_string(), 2.0)];
        to_log_prob(&mut a);
        let scores = a.iter().map(|(_, score)| *score).collect::<Vec<_>>();
        // ln(1) - ln(3)
        assert_approx_eq!(scores[0], -1.098, 0.01);
        // ln(2) - ln(3)
        assert_approx_eq!(scores[1], -0.405, 0.01);
    }
}