finalfusion/compat/floret/
indexer.rs

1use std::io::Cursor;
2
3use murmur3::murmur3_x64_128;
4use smallvec::smallvec;
5
6use crate::subword::{Indexer, IndicesScope, NGramVec, StrWithCharLen};
7
/// Subword indexer that is compatible with floret.
///
/// By default, floret does not keep a separate word embedding matrix.
/// Instead, the full word as well as each of its n-grams is mapped to
/// 1 to 4 hash-derived indices.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct FloretIndexer {
    /// Number of buckets; every index is reduced modulo this value.
    n_buckets: u64,
    /// Seed passed to the murmur3 hash function.
    seed: u32,
    /// Number of indices (1 to 4) produced per n-gram.
    n_hashes: u32,
}
18
19impl FloretIndexer {
20    pub fn new(n_buckets: u64, n_hashes: u32, seed: u32) -> Self {
21        assert!(
22            n_hashes > 0 && n_hashes <= 4,
23            "Floret indexer needs 1 to 4 hashes, got {}",
24            n_hashes
25        );
26
27        assert_ne!(n_buckets, 0, "Floret needs at least 1 bucket.");
28
29        Self {
30            n_buckets,
31            n_hashes,
32            seed,
33        }
34    }
35
36    pub fn n_buckets(&self) -> u64 {
37        self.n_buckets
38    }
39
40    pub fn n_hashes(&self) -> u32 {
41        self.n_hashes
42    }
43
44    pub fn seed(&self) -> u32 {
45        self.seed
46    }
47}
48
49impl Indexer for FloretIndexer {
50    fn index_ngram(&self, ngram: &StrWithCharLen) -> NGramVec {
51        let hash = murmur3_x64_128(&mut Cursor::new(ngram.as_bytes()), self.seed)
52            .expect("Murmur hash failed");
53
54        let mut hash_array = [0; 4];
55        hash_array[0] = hash as u32;
56        hash_array[1] = (hash >> 32) as u32;
57        hash_array[2] = (hash >> 64) as u32;
58        hash_array[3] = (hash >> 96) as u32;
59
60        let mut indices = smallvec![0; self.n_hashes as usize];
61        for i in 0..self.n_hashes as usize {
62            indices[i] = hash_array[i] as u64 % self.n_buckets;
63        }
64
65        indices
66    }
67
68    fn upper_bound(&self) -> u64 {
69        self.n_buckets
70    }
71
72    fn infallible() -> bool {
73        true
74    }
75
76    fn scope() -> IndicesScope {
77        IndicesScope::StringAndSubstrings
78    }
79}