finalfusion/
subword.rs

1//! Utilities for subword units.
2
3use std::cmp;
4use std::collections::{HashMap, VecDeque};
5use std::fmt;
6use std::hash::{Hash, Hasher};
7use std::marker::PhantomData;
8use std::ops::Deref;
9
10use fnv::FnvHasher;
11use smallvec::{smallvec, SmallVec};
12
13use crate::util::CollectWithCapacity;
14
15pub type NGramVec = SmallVec<[u64; 4]>;
16
17/// N-Gram indexer
18///
19/// An indexer maps an n-gram to an index in the subword embedding
20/// matrix.
21pub trait Indexer {
22    /// Map an n-gram to indices in the subword embedding matrix.
23    fn index_ngram(&self, ngram: &StrWithCharLen) -> NGramVec;
24
25    /// Return the (exclusive) upper bound of this indexer.
26    fn upper_bound(&self) -> u64;
27
28    /// Indicates whether this Indexer never fails to produce an index.
29    fn infallible() -> bool;
30
31    /// The scope of the indexer.
32    fn scope() -> IndicesScope;
33}
34
35/// N-Gram indexer with bucketing.
36pub trait BucketIndexer: Indexer {
37    /// Create a new indexer.
38    ///
39    /// The buckets argument is the number of buckets or the
40    /// bucket exponent (depending on the implementation).
41    fn new(buckets: usize) -> Self;
42
43    /// Get the number of buckets.
44    ///
45    /// Depending on the indexer, this may be the actual number of
46    /// buckets or the bucket exponent.
47    fn buckets(&self) -> usize;
48}
49
/// Hash-based n-gram indexer.
///
/// Each n-gram is hashed with a freshly constructed `H`. The hash is then
/// reduced to one of *2^buckets_exp* buckets by keeping only its lowest
/// `buckets_exp` bits.
///
/// The bucket exponent can be at most 64.
pub struct HashIndexer<H> {
    /// Base-2 logarithm of the number of buckets.
    buckets_exp: usize,

    /// Bit mask with the `buckets_exp` lowest bits set.
    mask: u64,

    /// Marker for the hasher type; no hasher instance is stored.
    _phantom: PhantomData<H>,
}
61
62impl<H> BucketIndexer for HashIndexer<H>
63where
64    H: Default + Hasher,
65{
66    /// Construct a `HashIndexer`.
67    ///
68    /// The largest possible bucket exponent is 64.
69    fn new(buckets_exp: usize) -> Self {
70        assert!(
71            buckets_exp <= 64,
72            "The largest possible buckets exponent is 64."
73        );
74
75        let mask = if buckets_exp == 64 {
76            !0
77        } else {
78            (1 << buckets_exp) - 1
79        };
80
81        HashIndexer {
82            buckets_exp,
83            mask,
84            _phantom: PhantomData,
85        }
86    }
87
88    fn buckets(&self) -> usize {
89        self.buckets_exp
90    }
91}
92
93impl<H> Clone for HashIndexer<H> {
94    fn clone(&self) -> Self {
95        *self
96    }
97}
98
99impl<H> Copy for HashIndexer<H> {}
100
101impl<H> fmt::Debug for HashIndexer<H> {
102    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
103        // std::intrinsics::type_name requires nightly
104        write!(f, "HashIndexer<impl Hasher> {{ mask: {} }}", self.mask)
105    }
106}
107
108impl<H> Eq for HashIndexer<H> {}
109
110impl<H> Indexer for HashIndexer<H>
111where
112    H: Default + Hasher,
113{
114    fn index_ngram(&self, ngram: &StrWithCharLen) -> NGramVec {
115        let mut hasher = H::default();
116        ngram.hash(&mut hasher);
117        smallvec![hasher.finish() & self.mask]
118    }
119
120    fn upper_bound(&self) -> u64 {
121        // max val is <= 64
122        2u64.pow(self.buckets_exp as u32)
123    }
124
125    fn infallible() -> bool {
126        true
127    }
128
129    fn scope() -> IndicesScope {
130        IndicesScope::Substrings
131    }
132}
133
134impl<H> PartialEq for HashIndexer<H> {
135    fn eq(&self, other: &Self) -> bool {
136        self.mask.eq(&other.mask)
137    }
138}
139
140/// Standard hash-based indexer in finalfusion.
141pub type FinalfusionHashIndexer = HashIndexer<FnvHasher>;
142
/// Indexer backed by an explicit table of n-grams.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ExplicitIndexer {
    /// The stored n-grams, in the order they were provided.
    ngrams: Vec<String>,

    /// Mapping from n-gram to its index.
    index: HashMap<String, u64>,

    /// Number of distinct indices (the exclusive upper bound).
    bound: usize,
}
150
151impl ExplicitIndexer {
152    pub fn ngrams(&self) -> &[String] {
153        &self.ngrams
154    }
155}
156
157impl ExplicitIndexer {
158    /// Construct a new explicit indexer.
159    ///
160    /// Panics when there are duplicate ngrams.
161    pub fn new(ngrams: impl Into<Vec<String>>) -> Self {
162        let ngrams = ngrams.into();
163        let index = ngrams
164            .iter()
165            .cloned()
166            .enumerate()
167            .map(|(idx, ngram)| (ngram, idx as u64))
168            .collect::<HashMap<String, u64>>();
169        assert_eq!(
170            index.len(),
171            ngrams.len(),
172            "ngrams contained duplicate entries."
173        );
174        let bound = index.len();
175        ExplicitIndexer {
176            ngrams,
177            index,
178            bound,
179        }
180    }
181
182    /// Construct a new explicit indexer with given indices.
183    ///
184    /// The `(String, u64)` tuples resemble the original `subword -> index` mapping. This mapping
185    /// does not need to be perfect, i.e. multiple subwords can map to the same index as it is
186    /// common with bucketed indexing.
187    ///
188    /// This constructor numbers the original indices as they appear and assigns a new index
189    /// accordingly. After construction, subwords that originally had the same index will still be
190    /// indexed by a common number. It is guaranteed that the new indices cover
191    /// `(0..n_original_indices)` where `n_original_indices` is the number of unique indices in the
192    /// `subword -> index` mapping.
193    ///
194    /// The second item in the returned tuple holds the `provided_index -> new_index` mapping.
195    /// I.e.: the `i`th unique `provided_index` in `ngram_tuples` is mapped to the `new_index` `i`.
196    ///
197    /// Panics when there are duplicate ngrams.
198    pub fn new_with_indices(
199        ngram_tuples: impl IntoIterator<Item = (String, u64)>,
200    ) -> (Self, HashMap<u64, usize>) {
201        let ngram_tuples = ngram_tuples.into_iter();
202        let mut old_to_new_indices = HashMap::with_capacity(ngram_tuples.size_hint().0);
203        let mut index = HashMap::with_capacity(ngram_tuples.size_hint().0);
204        let mut ngrams = Vec::with_capacity(ngram_tuples.size_hint().0);
205        for (ngram, bucket) in ngram_tuples {
206            let cur_idx = old_to_new_indices.len();
207            let new_idx = *old_to_new_indices.entry(bucket).or_insert(cur_idx);
208            assert!(
209                index.insert(ngram.clone(), new_idx as u64).is_none(),
210                "ngrams contains duplicate entries."
211            );
212            ngrams.push(ngram);
213        }
214        let bound = old_to_new_indices.len();
215        (
216            ExplicitIndexer {
217                ngrams,
218                index,
219                bound,
220            },
221            old_to_new_indices,
222        )
223    }
224}
225
226impl Indexer for ExplicitIndexer {
227    fn index_ngram(&self, ngram: &StrWithCharLen) -> NGramVec {
228        match self.index.get(ngram.inner) {
229            Some(&idx) => smallvec![idx],
230            None => smallvec![],
231        }
232    }
233
234    fn upper_bound(&self) -> u64 {
235        self.bound as u64
236    }
237
238    fn infallible() -> bool {
239        false
240    }
241
242    fn scope() -> IndicesScope {
243        IndicesScope::Substrings
244    }
245}
246
/// A string reference with its length in characters.
///
/// The character length counts `char`s (Unicode scalar values), which can
/// differ from the string's length in bytes. This is a cheap borrowed view,
/// so it is `Copy`.
#[derive(Clone, Copy, Debug)]
pub struct StrWithCharLen<'a> {
    inner: &'a str,
    char_len: usize,
}
252
253impl<'a> From<&'a str> for StrWithCharLen<'a> {
254    fn from(s: &'a str) -> Self {
255        StrWithCharLen::new(s)
256    }
257}
258
259impl<'a> StrWithCharLen<'a> {
260    /// Construct `StrWithCharLen`.
261    ///
262    /// Counts the number of chars in a `&str` and constructs a `StrWithCharLen` from it.
263    pub fn new(s: &'a str) -> Self {
264        let char_len = s.chars().count();
265        StrWithCharLen { inner: s, char_len }
266    }
267
268    pub fn as_str(&self) -> &str {
269        self.inner
270    }
271
272    pub fn char_len(&self) -> usize {
273        self.char_len
274    }
275}
276
277impl<'a> Deref for StrWithCharLen<'a> {
278    type Target = str;
279
280    fn deref(&self) -> &Self::Target {
281        self.inner
282    }
283}
284
285impl<'a> Hash for StrWithCharLen<'a> {
286    fn hash<H>(&self, hasher: &mut H)
287    where
288        H: Hasher,
289    {
290        hasher.write(&(self.char_len as u64).to_le_bytes());
291        self.inner
292            .chars()
293            .for_each(|ch| hasher.write(&(ch as u32).to_le_bytes()));
294    }
295}
296
/// Iterator over the n-grams of a string.
///
/// `NGrams` yields every n-gram of a string whose length, in characters,
/// lies between a minimum and a maximum (both inclusive).
///
/// **Warning:** no guarantee is provided with regard to the iteration
/// order. The iterator only guarantees that all n-grams are produced.
pub struct NGrams<'a> {
    /// Maximum n-gram length in characters.
    max_n: usize,

    /// Minimum n-gram length in characters.
    min_n: usize,

    /// The string that n-grams are taken from.
    string: &'a str,

    /// Byte offsets of the characters of the current suffix of `string`.
    char_offsets: VecDeque<usize>,

    /// Length in characters of the next n-gram of the current suffix.
    ngram_len: usize,
}
311
312impl<'a> NGrams<'a> {
313    /// Create a new n-ngram iterator.
314    ///
315    /// The iterator will create n-ngrams of length *[min_n, max_n]*
316    pub fn new(string: &'a str, min_n: usize, max_n: usize) -> Self {
317        assert!(min_n != 0, "The minimum n-gram length cannot be zero.");
318        assert!(
319            min_n <= max_n,
320            "The maximum length should be equal to or greater than the minimum length."
321        );
322
323        // Get the byte offsets of the characters in `string`.
324        let char_offsets = string
325            .char_indices()
326            .map(|(idx, _)| idx)
327            .collect_with_capacity::<VecDeque<_>>(string.len());
328
329        let ngram_len = cmp::min(max_n, char_offsets.len());
330
331        NGrams {
332            max_n,
333            min_n,
334            string,
335            char_offsets,
336            ngram_len,
337        }
338    }
339}
340
341impl<'a> Iterator for NGrams<'a> {
342    type Item = StrWithCharLen<'a>;
343
344    #[inline]
345    fn next(&mut self) -> Option<Self::Item> {
346        // If the n-grams for the current suffix are exhausted,
347        // move to the next suffix.
348        if self.ngram_len < self.min_n {
349            // Remove first character, to get the next suffix.
350            self.char_offsets.pop_front();
351
352            // If the suffix is smaller than the minimal n-gram
353            // length, the iterator is exhausted.
354            if self.char_offsets.len() < self.min_n {
355                return None;
356            }
357
358            // Get the maximum n-gram length for this suffix.
359            self.ngram_len = cmp::min(self.max_n, self.char_offsets.len());
360        }
361
362        let ngram = if self.ngram_len == self.char_offsets.len() {
363            &self.string[self.char_offsets[0]..]
364        } else {
365            &self.string[self.char_offsets[0]..self.char_offsets[self.ngram_len]]
366        };
367
368        let ngram_with_len = StrWithCharLen {
369            inner: ngram,
370            char_len: self.ngram_len,
371        };
372
373        self.ngram_len -= 1;
374
375        Some(ngram_with_len)
376    }
377
378    #[inline]
379    fn size_hint(&self) -> (usize, Option<usize>) {
380        let cap_approx = (self.max_n - self.min_n + 1) * self.char_offsets.len();
381        (cap_approx, Some(cap_approx))
382    }
383}
384
/// The parts of a string that subword indices are computed for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IndicesScope {
    /// Indices for substrings (n-grams) only.
    Substrings,

    /// Indices for the complete string as well as its substrings.
    StringAndSubstrings,
}
394
395/// Trait returning iterators over subwords and indices.
396///
397/// Defines methods to iterate over the subwords and
398/// their corresponding indices as assigned through the
399/// given `Indexer`. The `Indexer` can allow collisions.
400pub trait SubwordIndices<'a, 'b, I>
401where
402    I: Indexer + 'b,
403{
404    type Iter: Iterator<Item = (&'a str, NGramVec)>;
405
406    /// Return an iterator over the subword indices of a string.
407    ///
408    /// The n-grams that are used are of length *[min_n, max_n]*, these are
409    /// mapped to indices using the given indexer.
410    fn subword_indices(
411        &'a self,
412        min_n: usize,
413        max_n: usize,
414        indexer: &'b I,
415    ) -> Box<dyn Iterator<Item = u64> + 'a>
416    where
417        'b: 'a,
418    {
419        Box::new(
420            self.subword_indices_with_ngrams(min_n, max_n, indexer)
421                .flat_map(|(_, indices)| indices),
422        )
423    }
424
425    /// Return an iterator over the subwords and subword indices of a string.
426    ///
427    /// The n-grams that are used are of length *[min_n, max_n]*, these are
428    /// mapped to indices using the given indexer.
429    fn subword_indices_with_ngrams(
430        &'a self,
431        min_n: usize,
432        max_n: usize,
433        indexer: &'b I,
434    ) -> Self::Iter;
435}
436
437impl<'a, 'b, I> SubwordIndices<'a, 'b, I> for str
438where
439    I: Indexer + 'b,
440{
441    type Iter = NGramsIndicesIter<'a, 'b, I>;
442    fn subword_indices_with_ngrams(
443        &'a self,
444        min_n: usize,
445        max_n: usize,
446        indexer: &'b I,
447    ) -> Self::Iter {
448        NGramsIndicesIter::new(self, min_n, max_n, indexer)
449    }
450}
451
452/// Iterator over the n-grams in a word and the corresponding subword indices.
453///
454/// `NGramsIndicesIter` is an iterator that produces the n-grams in a word and
455/// the corresponding subword indices as tuples `(ngram, index)`.
456///
457/// **Warning:** no guarantee is provided with regard to the iteration
458/// order. The iterator only guarantees that all n-grams and their indices are produced.
459pub struct NGramsIndicesIter<'a, 'b, I> {
460    string: Option<&'a str>,
461    ngrams: NGrams<'a>,
462    indexer: &'b I,
463}
464
465impl<'a, 'b, I> NGramsIndicesIter<'a, 'b, I> {
466    /// Create a new ngrams-indices iterator.
467    ///
468    /// The iterator will create all ngrams of length *[min_n, max_n]* and corresponding
469    /// subword indices.
470    pub fn new(string: &'a str, min_n: usize, max_n: usize, indexer: &'b I) -> Self
471    where
472        I: Indexer,
473    {
474        let ngrams = NGrams::new(string, min_n, max_n);
475        let string = match I::scope() {
476            IndicesScope::Substrings => None,
477            IndicesScope::StringAndSubstrings => Some(string),
478        };
479
480        NGramsIndicesIter {
481            ngrams,
482            indexer,
483            string,
484        }
485    }
486}
487
488impl<'a, 'b, I> Iterator for NGramsIndicesIter<'a, 'b, I>
489where
490    I: Indexer,
491{
492    type Item = (&'a str, NGramVec);
493
494    #[inline]
495    fn next(&mut self) -> Option<Self::Item> {
496        if let Some(string) = self.string.take() {
497            return Some((string, self.indexer.index_ngram(&string.into())));
498        }
499
500        self.ngrams
501            .next()
502            .map(|ngram| (ngram.inner, self.indexer.index_ngram(&ngram)))
503    }
504
505    #[inline]
506    fn size_hint(&self) -> (usize, Option<usize>) {
507        self.ngrams.size_hint()
508    }
509}
510
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use lazy_static::lazy_static;
    use maplit::hashmap;
    use smallvec::smallvec;

    use super::{BucketIndexer, FinalfusionHashIndexer, NGrams, SubwordIndices};
    use crate::subword::NGramVec;

    #[test]
    fn ngrams_test() {
        let mut hello_check: Vec<&str> = vec![
            "h", "he", "hel", "e", "el", "ell", "l", "ll", "llö", "l", "lö", "lö ", "ö", "ö ",
            "ö w", " ", " w", " wo", "w", "wo", "wor", "o", "or", "orl", "r", "rl", "rld", "l",
            "ld", "d",
        ];

        hello_check.sort_unstable();

        let mut hello_ngrams: Vec<_> = NGrams::new("hellö world", 1, 3).map(|s| s.inner).collect();
        hello_ngrams.sort_unstable();

        assert_eq!(hello_check, hello_ngrams);
    }

    #[test]
    fn ngrams_23_test() {
        let mut hello_check: Vec<&str> = vec![
            "he", "hel", "el", "ell", "ll", "llo", "lo", "lo ", "o ", "o w", " w", " wo", "wo",
            "wor", "or", "orl", "rl", "rld", "ld",
        ];

        hello_check.sort_unstable();

        let mut hello_ngrams: Vec<_> = NGrams::new("hello world", 2, 3).map(|s| s.inner).collect();
        hello_ngrams.sort_unstable();

        assert_eq!(hello_check, hello_ngrams);
    }

    #[test]
    fn short_ngram_test() {
        let mut yep_check: Vec<&str> = vec!["ˈjə", "jəp", "ˈjəp"];
        yep_check.sort_unstable();

        let mut yep_ngrams: Vec<_> = NGrams::new("ˈjəp", 3, 6).map(|s| s.inner).collect();
        yep_ngrams.sort_unstable();

        assert_eq!(yep_check, yep_ngrams);
    }

    #[test]
    fn empty_ngram_test() {
        let check: &[&str] = &[];
        assert_eq!(
            NGrams::new("", 1, 3).map(|s| s.inner).collect::<Vec<_>>(),
            check
        );
    }

    #[test]
    #[should_panic]
    fn incorrect_min_n_test() {
        NGrams::new("", 0, 3);
    }

    #[test]
    #[should_panic]
    fn incorrect_max_n_test() {
        NGrams::new("", 2, 1);
    }

    lazy_static! {
        static ref SUBWORD_TESTS_2: HashMap<&'static str, Vec<u64>> = hashmap! {
            "<Daniël>" =>
                vec![0, 0, 0, 0, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 3, 3, 3],
            "<hallo>" =>
                vec![0, 0, 0, 0, 1, 1, 1, 2, 3, 3, 3, 3, 3, 3],
        };
    }

    lazy_static! {
        static ref SUBWORD_TESTS_21: HashMap<&'static str, Vec<u64>> = hashmap! {
            "<Daniël>" =>
                vec![214157, 233912, 311961, 488897, 620206, 741276, 841219,
                     1167494, 1192256, 1489905, 1532271, 1644730, 1666166,
                     1679745, 1680294, 1693100, 2026735, 2065822],
            "<hallo>" =>
                vec![75867, 104120, 136555, 456131, 599360, 722393, 938007,
                     985859, 1006102, 1163391, 1218704, 1321513, 1505861,
                     1892376],
        };
    }

    lazy_static! {
        static ref NGRAMS_INDICES_TESTS_36: HashMap<&'static str, Vec<(&'static str, NGramVec)>> =
            [
                (
                    "<Daniël>",
                    vec![
                        ("Dan", smallvec![214157]),
                        ("iël", smallvec![233912]),
                        ("Danië", smallvec![311961]),
                        ("iël>", smallvec![488897]),
                        ("niël>", smallvec![620206]),
                        ("anië", smallvec![741276]),
                        ("Dani", smallvec![841219]),
                        ("Daniël", smallvec![1167494]),
                        ("ani", smallvec![1192256]),
                        ("niël", smallvec![1489905]),
                        ("ël>", smallvec![1532271]),
                        ("nië", smallvec![1644730]),
                        ("<Dan", smallvec![1666166]),
                        ("aniël", smallvec![1679745]),
                        ("<Danië", smallvec![1680294]),
                        ("aniël>", smallvec![1693100]),
                        ("<Da", smallvec![2026735]),
                        ("<Dani", smallvec![2065822])
                    ]
                ),
                (
                    "<hallo>",
                    vec![
                        ("lo>", smallvec![75867]),
                        ("<hal", smallvec![104120]),
                        ("hallo>", smallvec![136555]),
                        ("hal", smallvec![456131]),
                        ("allo>", smallvec![599360]),
                        ("llo", smallvec![722393]),
                        ("all", smallvec![938007]),
                        ("<ha", smallvec![985859]),
                        ("hallo", smallvec![1006102]),
                        ("allo", smallvec![1163391]),
                        ("llo>", smallvec![1218704]),
                        ("<hallo", smallvec![1321513]),
                        ("<hall", smallvec![1505861]),
                        ("hall", smallvec![1892376])
                    ]
                )
            ]
            .iter()
            .cloned()
            .collect();
    }

    #[test]
    fn subword_indices_4_test() {
        // The goal of this test is to ensure that we are correctly bucketing
        // subwords. With a bucket exponent of 2, there are 2^2 = 4 buckets,
        // so we should see bucket numbers [0..3].

        let indexer = FinalfusionHashIndexer::new(2);
        for (word, indices_check) in SUBWORD_TESTS_2.iter() {
            let mut indices = word.subword_indices(3, 6, &indexer).collect::<Vec<_>>();
            indices.sort_unstable();
            assert_eq!(indices_check, &indices);
        }
    }

    #[test]
    fn subword_indices_2m_test() {
        // This test checks against precomputed bucket numbers. The goal of
        // this test is to ensure that the subword_indices() method hashes
        // to the same buckets in the future.

        let indexer = FinalfusionHashIndexer::new(21);
        for (word, indices_check) in SUBWORD_TESTS_21.iter() {
            let mut indices = word.subword_indices(3, 6, &indexer).collect::<Vec<_>>();
            indices.sort_unstable();
            assert_eq!(indices_check, &indices);
        }
    }

    #[test]
    fn ngrams_indices_2m_test() {
        // This test checks against precomputed (n-gram, bucket) pairs. The
        // goal of this test is to ensure that subword_indices_with_ngrams()
        // produces every n-gram and hashes it to the same bucket in the future.

        let indexer = FinalfusionHashIndexer::new(21);
        for (word, ngrams_indices_check) in NGRAMS_INDICES_TESTS_36.iter() {
            let mut ngrams_indices_test = word
                .subword_indices_with_ngrams(3, 6, &indexer)
                .collect::<Vec<_>>();
            ngrams_indices_test.sort_by(|a, b| a.1.cmp(&b.1));

            let mut ngrams_indices_expected = ngrams_indices_check.clone();
            ngrams_indices_expected.sort_by(|a, b| a.1.cmp(&b.1));

            // Compare the complete vectors: zipping them would silently ignore
            // missing or superfluous n-grams, and the indices matter as much as
            // the n-grams themselves.
            assert_eq!(ngrams_indices_expected, ngrams_indices_test);
        }
    }
}