1use std::hash::{BuildHasher, Hash, Hasher};
3use std::ops::Range;
4
5use ahash::RandomState;
6use rand::{RngCore, SeedableRng};
7
8use crate::errors::{FindSimdocError, Result};
9use crate::shingling::ShingleIter;
10
11#[derive(Clone, Debug)]
13pub struct FeatureConfig {
14 window_size: usize,
15 delimiter: Option<char>,
16 build_hasher: RandomState,
17}
18
19impl FeatureConfig {
20 pub fn new(window_size: usize, delimiter: Option<char>, seed: u64) -> Result<Self> {
29 if window_size == 0 {
30 return Err(FindSimdocError::input("Window size must not be 0."));
31 }
32 let mut seeder = rand_xoshiro::SplitMix64::seed_from_u64(seed);
33 let build_hasher = RandomState::with_seeds(
34 seeder.next_u64(),
35 seeder.next_u64(),
36 seeder.next_u64(),
37 seeder.next_u64(),
38 );
39 Ok(Self {
40 window_size,
41 delimiter,
42 build_hasher,
43 })
44 }
45
46 fn hash<I, T>(&self, iter: I) -> u64
47 where
48 I: IntoIterator<Item = T>,
49 T: Hash,
50 {
51 let mut s = self.build_hasher.build_hasher();
52 for t in iter {
53 t.hash(&mut s);
54 }
55 s.finish()
56 }
57}
58
59pub struct FeatureExtractor<'a> {
61 config: &'a FeatureConfig,
62}
63
64impl<'a> FeatureExtractor<'a> {
65 pub const fn new(config: &'a FeatureConfig) -> Self {
67 Self { config }
68 }
69
70 pub fn extract<S>(&self, text: S, feature: &mut Vec<u64>)
72 where
73 S: AsRef<str>,
74 {
75 let text = text.as_ref();
76
77 feature.clear();
78 if self.config.delimiter.is_none() && self.config.window_size == 1 {
79 text.chars().for_each(|c| feature.push(c as u64));
81 } else {
82 let token_ranges = self.tokenize(text);
83 for ranges in ShingleIter::new(&token_ranges, self.config.window_size) {
84 feature.push(self.config.hash(ranges.iter().cloned().map(|r| &text[r])));
85 }
86 }
87 }
88
89 pub fn extract_with_weights<S>(&self, text: S, feature: &mut Vec<(u64, f64)>)
91 where
92 S: AsRef<str>,
93 {
94 let text = text.as_ref();
95
96 feature.clear();
97 if self.config.delimiter.is_none() && self.config.window_size == 1 {
98 text.chars().for_each(|c| {
100 let f = c as u64;
101 let w = 1.;
102 feature.push((f, w))
103 });
104 } else {
105 let token_ranges = self.tokenize(text);
106 for ranges in ShingleIter::new(&token_ranges, self.config.window_size) {
107 let f = self.config.hash(ranges.iter().cloned().map(|r| &text[r]));
108 let w = 1.;
109 feature.push((f, w))
110 }
111 }
112 }
113
114 fn tokenize(&self, text: &str) -> Vec<Range<usize>> {
115 let mut token_ranges = vec![];
116 for _ in 1..self.config.window_size {
117 token_ranges.push(0..0); }
119 let mut offset = 0;
120 if let Some(delim) = self.config.delimiter {
121 while offset < text.len() {
122 let len = text[offset..].find(delim);
123 if let Some(len) = len {
124 token_ranges.push(offset..offset + len);
125 offset += len + 1;
126 } else {
127 token_ranges.push(offset..text.len());
128 break;
129 }
130 }
131 } else {
132 for c in text.chars() {
133 let len = c.len_utf8();
134 token_ranges.push(offset..offset + len);
135 offset += len;
136 }
137 }
138 for _ in 1..self.config.window_size {
139 token_ranges.push(text.len()..text.len()); }
141 token_ranges
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a configuration with the fixed seed 42, extracts the features
    /// of `text`, and returns both so expectations can reuse the same hasher.
    fn run(window_size: usize, delimiter: Option<char>, text: &str) -> (FeatureConfig, Vec<u64>) {
        let config = FeatureConfig::new(window_size, delimiter, 42).unwrap();
        let mut feature = vec![];
        FeatureExtractor::new(&config).extract(text, &mut feature);
        (config, feature)
    }

    #[test]
    fn test_char_unigram() {
        let (_, feature) = run(1, None, "abcd");
        // Character unigrams are the raw code points, not hashes.
        let expected = vec![
            u64::from('a'),
            u64::from('b'),
            u64::from('c'),
            u64::from('d'),
        ];
        assert_eq!(feature, expected);
    }

    #[test]
    fn test_char_bigram() {
        let (config, feature) = run(2, None, "abcd");
        let expected = vec![
            config.hash(&["", "a"]),
            config.hash(&["a", "b"]),
            config.hash(&["b", "c"]),
            config.hash(&["c", "d"]),
            config.hash(&["d", ""]),
        ];
        assert_eq!(feature, expected);
    }

    #[test]
    fn test_char_trigram() {
        let (config, feature) = run(3, None, "abcd");
        let expected = vec![
            config.hash(&["", "", "a"]),
            config.hash(&["", "a", "b"]),
            config.hash(&["a", "b", "c"]),
            config.hash(&["b", "c", "d"]),
            config.hash(&["c", "d", ""]),
            config.hash(&["d", "", ""]),
        ];
        assert_eq!(feature, expected);
    }

    #[test]
    fn test_word_unigram() {
        let (config, feature) = run(1, Some(' '), "abc de fgh");
        let expected = vec![
            config.hash(&["abc"]),
            config.hash(&["de"]),
            config.hash(&["fgh"]),
        ];
        assert_eq!(feature, expected);
    }

    #[test]
    fn test_word_bigram() {
        let (config, feature) = run(2, Some(' '), "abc de fgh");
        let expected = vec![
            config.hash(&["", "abc"]),
            config.hash(&["abc", "de"]),
            config.hash(&["de", "fgh"]),
            config.hash(&["fgh", ""]),
        ];
        assert_eq!(feature, expected);
    }

    #[test]
    fn test_word_trigram() {
        let (config, feature) = run(3, Some(' '), "abc de fgh");
        let expected = vec![
            config.hash(&["", "", "abc"]),
            config.hash(&["", "abc", "de"]),
            config.hash(&["abc", "de", "fgh"]),
            config.hash(&["de", "fgh", ""]),
            config.hash(&["fgh", "", ""]),
        ];
        assert_eq!(feature, expected);
    }
}