// find_simdoc/feature.rs

//! Feature extractor.
use std::hash::{BuildHasher, Hash, Hasher};
use std::ops::Range;

use ahash::RandomState;
use rand::{RngCore, SeedableRng};

use crate::errors::{FindSimdocError, Result};
use crate::shingling::ShingleIter;
10
11/// Configuration of feature extraction.
12#[derive(Clone, Debug)]
13pub struct FeatureConfig {
14    window_size: usize,
15    delimiter: Option<char>,
16    build_hasher: RandomState,
17}
18
19impl FeatureConfig {
20    /// Creates an instance.
21    ///
22    /// # Arguments
23    ///
24    /// * `window_size` - Window size for w-shingling in feature extraction (must be more than 0).
25    /// * `delimiter` - Delimiter for recognizing words as tokens in feature extraction.
26    ///                 If `None`, characters are used for tokens.
27    /// * `seed` - Seed value for random values.
28    pub fn new(window_size: usize, delimiter: Option<char>, seed: u64) -> Result<Self> {
29        if window_size == 0 {
30            return Err(FindSimdocError::input("Window size must not be 0."));
31        }
32        let mut seeder = rand_xoshiro::SplitMix64::seed_from_u64(seed);
33        let build_hasher = RandomState::with_seeds(
34            seeder.next_u64(),
35            seeder.next_u64(),
36            seeder.next_u64(),
37            seeder.next_u64(),
38        );
39        Ok(Self {
40            window_size,
41            delimiter,
42            build_hasher,
43        })
44    }
45
46    fn hash<I, T>(&self, iter: I) -> u64
47    where
48        I: IntoIterator<Item = T>,
49        T: Hash,
50    {
51        let mut s = self.build_hasher.build_hasher();
52        for t in iter {
53            t.hash(&mut s);
54        }
55        s.finish()
56    }
57}
58
59/// Extractor of feature vectors.
60pub struct FeatureExtractor<'a> {
61    config: &'a FeatureConfig,
62}
63
64impl<'a> FeatureExtractor<'a> {
65    /// Creates an instance.
66    pub const fn new(config: &'a FeatureConfig) -> Self {
67        Self { config }
68    }
69
70    /// Extracts a feature vector from an input text.
71    pub fn extract<S>(&self, text: S, feature: &mut Vec<u64>)
72    where
73        S: AsRef<str>,
74    {
75        let text = text.as_ref();
76
77        feature.clear();
78        if self.config.delimiter.is_none() && self.config.window_size == 1 {
79            // The simplest case.
80            text.chars().for_each(|c| feature.push(c as u64));
81        } else {
82            let token_ranges = self.tokenize(text);
83            for ranges in ShingleIter::new(&token_ranges, self.config.window_size) {
84                feature.push(self.config.hash(ranges.iter().cloned().map(|r| &text[r])));
85            }
86        }
87    }
88
89    /// Extracts a feature vector from an input text with weights of 1.0.
90    pub fn extract_with_weights<S>(&self, text: S, feature: &mut Vec<(u64, f64)>)
91    where
92        S: AsRef<str>,
93    {
94        let text = text.as_ref();
95
96        feature.clear();
97        if self.config.delimiter.is_none() && self.config.window_size == 1 {
98            // The simplest case.
99            text.chars().for_each(|c| {
100                let f = c as u64;
101                let w = 1.;
102                feature.push((f, w))
103            });
104        } else {
105            let token_ranges = self.tokenize(text);
106            for ranges in ShingleIter::new(&token_ranges, self.config.window_size) {
107                let f = self.config.hash(ranges.iter().cloned().map(|r| &text[r]));
108                let w = 1.;
109                feature.push((f, w))
110            }
111        }
112    }
113
114    fn tokenize(&self, text: &str) -> Vec<Range<usize>> {
115        let mut token_ranges = vec![];
116        for _ in 1..self.config.window_size {
117            token_ranges.push(0..0); // BOS
118        }
119        let mut offset = 0;
120        if let Some(delim) = self.config.delimiter {
121            while offset < text.len() {
122                let len = text[offset..].find(delim);
123                if let Some(len) = len {
124                    token_ranges.push(offset..offset + len);
125                    offset += len + 1;
126                } else {
127                    token_ranges.push(offset..text.len());
128                    break;
129                }
130            }
131        } else {
132            for c in text.chars() {
133                let len = c.len_utf8();
134                token_ranges.push(offset..offset + len);
135                offset += len;
136            }
137        }
138        for _ in 1..self.config.window_size {
139            token_ranges.push(text.len()..text.len()); // EOS
140        }
141        token_ranges
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    // Builds a config with seed 42, extracts features from `text`, and
    // returns both so assertions can recompute expected hashes.
    fn run(window_size: usize, delimiter: Option<char>, text: &str) -> (FeatureConfig, Vec<u64>) {
        let config = FeatureConfig::new(window_size, delimiter, 42).unwrap();
        let mut feature = vec![];
        FeatureExtractor::new(&config).extract(text, &mut feature);
        (config, feature)
    }

    #[test]
    fn test_char_unigram() {
        let (_, feature) = run(1, None, "abcd");
        assert_eq!(
            feature,
            vec!['a' as u64, 'b' as u64, 'c' as u64, 'd' as u64]
        )
    }

    #[test]
    fn test_char_bigram() {
        let (config, feature) = run(2, None, "abcd");
        assert_eq!(
            feature,
            vec![
                config.hash(&["", "a"]),
                config.hash(&["a", "b"]),
                config.hash(&["b", "c"]),
                config.hash(&["c", "d"]),
                config.hash(&["d", ""]),
            ]
        )
    }

    #[test]
    fn test_char_trigram() {
        let (config, feature) = run(3, None, "abcd");
        assert_eq!(
            feature,
            vec![
                config.hash(&["", "", "a"]),
                config.hash(&["", "a", "b"]),
                config.hash(&["a", "b", "c"]),
                config.hash(&["b", "c", "d"]),
                config.hash(&["c", "d", ""]),
                config.hash(&["d", "", ""]),
            ]
        )
    }

    #[test]
    fn test_word_unigram() {
        let (config, feature) = run(1, Some(' '), "abc de fgh");
        assert_eq!(
            feature,
            vec![
                config.hash(&["abc"]),
                config.hash(&["de"]),
                config.hash(&["fgh"]),
            ]
        )
    }

    #[test]
    fn test_word_bigram() {
        let (config, feature) = run(2, Some(' '), "abc de fgh");
        assert_eq!(
            feature,
            vec![
                config.hash(&["", "abc"]),
                config.hash(&["abc", "de"]),
                config.hash(&["de", "fgh"]),
                config.hash(&["fgh", ""]),
            ]
        )
    }

    #[test]
    fn test_word_trigram() {
        let (config, feature) = run(3, Some(' '), "abc de fgh");
        assert_eq!(
            feature,
            vec![
                config.hash(&["", "", "abc"]),
                config.hash(&["", "abc", "de"]),
                config.hash(&["abc", "de", "fgh"]),
                config.hash(&["de", "fgh", ""]),
                config.hash(&["fgh", "", ""]),
            ]
        )
    }
}