1use std::collections::BTreeSet;
2
3use crate::{
4 hash::{hash_bytes, FNV_PRIME},
5 object_pool::PoolReturnable,
6 utils::bits_to_max_feature_index,
7 FeatureHash, FeatureIndex, FeatureMask, NamespaceHash,
8};
9use approx::AbsDiffEq;
10use itertools::Itertools;
11use serde::{Deserialize, Serialize};
/// Iterator over the namespaces stored in a [`SparseFeatures`], yielding each namespace
/// key together with a [`NamespaceIterator`] over that namespace's `(index, value)` pairs.
///
/// It walks the underlying `HashMap` directly. Iteration order is therefore unspecified,
/// and namespaces whose `active` flag is cleared are yielded as well.
pub struct NamespacesIterator<'a> {
    // Borrowed iterator over the owning `SparseFeatures::namespaces` map.
    namespaces: std::collections::hash_map::Iter<'a, Namespace, SparseFeaturesNamespace>,
}
15
16#[derive(Clone)]
17pub struct NamespaceIterator<'a> {
18 indices: std::slice::Iter<'a, FeatureIndex>,
19 values: std::slice::Iter<'a, f32>,
20}
21
22impl<'a> Iterator for NamespacesIterator<'a> {
24 type Item = (Namespace, NamespaceIterator<'a>);
25 fn next(&mut self) -> Option<Self::Item> {
26 self.namespaces.next().map(|(namespace_feats, namespace)| {
27 (
28 *namespace_feats,
29 NamespaceIterator {
30 indices: namespace.feature_indices.iter(),
31 values: namespace.feature_values.iter(),
32 },
33 )
34 })
35 }
36}
37
38impl<'a> Iterator for NamespaceIterator<'a> {
39 type Item = (FeatureIndex, f32);
40 fn next(&mut self) -> Option<Self::Item> {
41 if let Some(index) = self.indices.next() {
42 Some((*index, *self.values.next().expect(
43 "NamespaceIterator::indices and NamespaceIterator::values are not the same length",
44 )))
45 } else {
46 None
47 }
48 }
49
50 fn size_hint(&self) -> (usize, Option<usize>) {
51 self.indices.size_hint()
52 }
53
54 fn nth(&mut self, n: usize) -> Option<Self::Item> {
55 match (self.indices.nth(n), self.values.nth(n)) {
56 (Some(i), Some(v)) => Some((*i, *v)),
57 _ => None,
58 }
59 }
60}
61
/// Features belonging to one namespace, stored as two parallel vectors.
///
/// `feature_indices[i]` pairs with `feature_values[i]`. The `add_*` methods push to both
/// vectors or assert that they end up the same length.
///
/// NOTE: the derived `PartialEq` also compares `active`, whereas the `AbsDiffEq` impl
/// does not.
#[derive(PartialEq, Clone, Debug)]
pub struct SparseFeaturesNamespace {
    // The namespace key these features belong to.
    namespace: Namespace,
    // Feature indices, parallel to `feature_values`.
    feature_indices: Vec<FeatureIndex>,
    // Feature values, parallel to `feature_indices`.
    feature_values: Vec<f32>,
    // Whether `SparseFeatures` lookups (e.g. `get_namespace`) should see this namespace.
    // Entries are deactivated rather than removed from the map, presumably so their
    // buffers can be reused; confirm with callers.
    active: bool,
}
72
73impl AbsDiffEq for SparseFeaturesNamespace {
74 type Epsilon = f32;
75
76 fn default_epsilon() -> Self::Epsilon {
77 core::f32::EPSILON
78 }
79
80 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
81 self.namespace == other.namespace
82 && self.feature_indices == other.feature_indices
83 && self
84 .feature_values
85 .iter()
86 .zip(other.feature_values.iter())
87 .all(|(a, b)| a.abs_diff_eq(b, epsilon))
88 }
89}
90
91impl SparseFeaturesNamespace {
92 pub fn iter(&self) -> NamespaceIterator {
93 NamespaceIterator {
94 indices: self.feature_indices.iter(),
95 values: self.feature_values.iter(),
96 }
97 }
98
99 pub fn new(namespace: Namespace) -> SparseFeaturesNamespace {
100 SparseFeaturesNamespace {
101 namespace,
102 feature_indices: Vec::new(),
103 feature_values: Vec::new(),
104 active: false,
105 }
106 }
107
108 pub fn new_with_capacity(namespace: Namespace, capacity: usize) -> SparseFeaturesNamespace {
109 SparseFeaturesNamespace {
110 namespace,
111 feature_indices: Vec::with_capacity(capacity),
112 feature_values: Vec::with_capacity(capacity),
113 active: false,
114 }
115 }
116
117 pub fn size(&self) -> usize {
118 self.feature_indices.len()
119 }
120
121 pub fn namespace(&self) -> Namespace {
122 self.namespace
123 }
124
125 pub fn reserve(&mut self, size: usize) {
126 self.feature_indices
127 .reserve_exact(size - self.feature_indices.capacity());
128 self.feature_values
129 .reserve(size - self.feature_values.capacity());
130 }
131
132 pub fn add_feature(&mut self, feature_index: FeatureIndex, feature_value: f32) {
133 self.feature_indices.push(feature_index);
134 self.feature_values.push(feature_value);
135 }
136
137 pub fn add_features(&mut self, feature_indices: &[FeatureIndex], feature_values: &[f32]) {
138 assert_eq!(feature_indices.len(), feature_values.len());
139 self.feature_indices.extend_from_slice(feature_indices);
140 self.feature_values.extend_from_slice(feature_values);
141 }
142
143 pub fn add_features_with_iter<I1, I2>(&mut self, feature_indices: I1, feature_values: I2)
144 where
145 I1: Iterator<Item = FeatureIndex>,
146 I2: Iterator<Item = f32>,
147 {
148 self.feature_indices.extend(feature_indices);
149 self.feature_values.extend(feature_values);
150 assert_eq!(self.feature_indices.len(), self.feature_values.len());
151 }
152
153 fn clear(&mut self) {
154 self.feature_indices.clear();
155 self.feature_values.clear();
156 }
157
158 fn is_active(&self) -> bool {
159 self.active
160 }
161
162 fn set_active(&mut self, active: bool) {
163 self.active = active;
164 }
165}
166
/// Identifies a feature namespace: either a namespace named by the hash of its name, or
/// the default namespace.
///
/// NOTE: variant order is significant. It determines the derived `PartialOrd`/`Ord`
/// ordering, and may affect the serialized form in serde formats that encode enum
/// variants by index.
#[derive(Serialize, Deserialize, PartialOrd, Ord, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum Namespace {
    /// A namespace identified by the hash of its name (see `Namespace::from_name`).
    Named(NamespaceHash),
    /// The default namespace; `Namespace::from_name` maps `" "` and `":default"` here.
    Default,
}
172
173impl Namespace {
174 pub fn from_name(namespace_name: &str, hash_seed: u32) -> Namespace {
175 match namespace_name {
176 " " => Namespace::Default,
178 ":default" => Namespace::Default,
179 _ => {
180 let namespace_hash = hash_bytes(namespace_name.as_bytes(), hash_seed).into();
181 Namespace::Named(namespace_hash)
182 }
183 }
184 }
185
186 pub fn hash(&self, _hash_seed: u32) -> NamespaceHash {
187 match self {
188 Namespace::Named(hash) => *hash,
189 Namespace::Default => 0.into(),
190 }
191 }
192}
193
/// A sparse feature set, grouped by namespace.
///
/// In this type's methods, `clear` and `remove` never delete map entries. They empty
/// namespaces and mark them inactive, and accessors such as `get_namespace` skip
/// inactive entries.
///
/// NOTE: the derived `PartialEq` compares the whole map, inactive entries included. The
/// `AbsDiffEq` impl only compares active namespaces.
#[derive(PartialEq, Clone, Debug)]
pub struct SparseFeatures {
    // Per-namespace feature storage, keyed by namespace.
    namespaces: std::collections::HashMap<Namespace, SparseFeaturesNamespace>,
}
198
199impl Default for SparseFeatures {
200 fn default() -> Self {
201 Self::new()
202 }
203}
204
205impl AbsDiffEq for SparseFeatures {
206 type Epsilon = f32;
207
208 fn default_epsilon() -> Self::Epsilon {
209 core::f32::EPSILON
210 }
211
212 fn abs_diff_eq(&self, other: &Self, epsilon: Self::Epsilon) -> bool {
213 let left_ns: BTreeSet<Namespace> = self
214 .namespaces
215 .iter()
216 .filter_map(|(ns, ns_vals)| if ns_vals.is_active() { Some(ns) } else { None })
217 .cloned()
218 .collect();
219
220 let right_ns = other
221 .namespaces
222 .iter()
223 .filter_map(|(ns, ns_vals)| if ns_vals.is_active() { Some(ns) } else { None })
224 .cloned()
225 .collect();
226 if left_ns != right_ns {
227 return false;
228 }
229
230 for ns in left_ns {
231 let ns_vals = self.namespaces.get(&ns).unwrap();
232 let other_ns_vals = other.namespaces.get(&ns).unwrap();
233 if !ns_vals.abs_diff_eq(other_ns_vals, epsilon) {
234 return false;
235 }
236 }
237
238 true
239 }
240}
241fn quadratic_feature_hash(i1: FeatureIndex, i2: FeatureIndex) -> FeatureHash {
243 let multiplied = (FNV_PRIME as u64).wrapping_mul(u32::from(i1) as u64) as u32;
244 (multiplied ^ u32::from(i2)).into()
245}
246
247fn cubic_feature_hash(i1: FeatureIndex, i2: FeatureIndex, i3: FeatureIndex) -> FeatureHash {
249 let multiplied = (FNV_PRIME as u64).wrapping_mul(u32::from(i1) as u64) as u32;
250 let multiplied = (FNV_PRIME as u64).wrapping_mul((multiplied ^ u32::from(i2)) as u64) as u32;
251 (multiplied ^ u32::from(i3)).into()
252}
253
254fn feature_space_median_index(num_bits: u8) -> FeatureIndex {
255 (u32::from(bits_to_max_feature_index(num_bits)) / 2).into()
256}
257
/// Returns the feature index reserved for the constant feature. It sits at the median of
/// the `num_bits`-bit feature space (see `feature_space_median_index`).
pub fn constant_feature_index(num_bits: u8) -> FeatureIndex {
    feature_space_median_index(num_bits)
}
262
263impl SparseFeatures {
264 pub fn namespaces(&self) -> NamespacesIterator {
265 NamespacesIterator {
266 namespaces: self.namespaces.iter(),
267 }
268 }
269
270 pub fn quadratic_features(
271 &self,
272 ns1: Namespace,
273 ns2: Namespace,
274 num_bits: u8,
275 ) -> Option<impl Iterator<Item = (FeatureIndex, f32)> + '_> {
276 let ns1 = self.get_namespace(ns1)?;
277 let ns2 = self.get_namespace(ns2)?;
278
279 let masker = FeatureMask::from_num_bits(num_bits);
280
281 Some(
282 ns1.iter()
283 .cartesian_product(ns2.iter().clone())
284 .map(move |((i1, v1), (i2, v2))| {
285 (quadratic_feature_hash(i1, i2).mask(masker), v1 * v2)
286 }),
287 )
288 }
289
290 pub fn cubic_features(
291 &self,
292 ns1: Namespace,
293 ns2: Namespace,
294 ns3: Namespace,
295 num_bits: u8,
296 ) -> Option<impl Iterator<Item = (FeatureIndex, f32)> + '_> {
297 let ns1 = self.get_namespace(ns1)?;
298 let ns2 = self.get_namespace(ns2)?;
299 let ns3 = self.get_namespace(ns3)?;
300
301 let masker = FeatureMask::from_num_bits(num_bits);
302
303 Some(
304 ns1.iter()
305 .cartesian_product(ns2.iter().clone())
306 .cartesian_product(ns3.iter().clone())
307 .map(move |(((i1, v1), (i2, v2)), (i3, v3))| {
308 (cubic_feature_hash(i1, i2, i3).mask(masker), v1 * v2 * v3)
309 }),
310 )
311 }
312
313 pub fn all_features(&self) -> impl Iterator<Item = (FeatureIndex, f32)> + '_ {
314 self.namespaces
315 .iter()
316 .flat_map(|(_, namespace)| namespace.iter())
317 }
318
319 pub fn new() -> SparseFeatures {
320 SparseFeatures {
321 namespaces: std::collections::HashMap::new(),
322 }
323 }
324
325 pub fn get_namespace(&self, namespace: Namespace) -> Option<&SparseFeaturesNamespace> {
326 self.namespaces
327 .get(&namespace)
328 .filter(|namespace| namespace.is_active())
329 }
330
331 pub fn get_namespace_mut(
332 &mut self,
333 namespace: Namespace,
334 ) -> Option<&mut SparseFeaturesNamespace> {
335 self.namespaces
336 .get_mut(&namespace)
337 .filter(|namespace| namespace.is_active())
338 }
339
340 pub fn clear(&mut self) {
341 for namespace in self.namespaces.values_mut() {
342 namespace.clear();
343 namespace.set_active(false);
344 }
345 }
346
347 pub fn get_or_create_namespace(
355 &mut self,
356 namespace: Namespace,
357 ) -> &mut SparseFeaturesNamespace {
358 let item = self
359 .namespaces
360 .entry(namespace)
361 .or_insert(SparseFeaturesNamespace::new(namespace));
362 item.set_active(true);
363 item
364 }
365
366 pub fn get_or_create_namespace_with_capacity(
367 &mut self,
368 namespace: Namespace,
369 capacity: usize,
370 ) -> &mut SparseFeaturesNamespace {
371 let item =
372 self.namespaces
373 .entry(namespace)
374 .or_insert(SparseFeaturesNamespace::new_with_capacity(
375 namespace, capacity,
376 ));
377 item.set_active(true);
378 item
379 }
380
381 pub fn append(&mut self, other: &SparseFeatures) {
382 for (ns, feats) in &other.namespaces {
383 if feats.active {
384 let container = self.get_or_create_namespace_with_capacity(*ns, feats.size());
385 container.add_features(&feats.feature_indices, &feats.feature_values);
386 }
387 }
388 }
389
390 pub fn remove(&mut self, other: &SparseFeatures) {
393 for (ns, feats) in &other.namespaces {
394 if feats.active {
395 let container = self.get_or_create_namespace(*ns);
396 let size = container.size();
397 container.feature_indices.truncate(size - feats.size());
398 container.feature_values.truncate(size - feats.size());
399
400 if container.size() == 0 {
402 container.set_active(false);
403 }
404 }
405 }
406 }
407
408 pub fn empty(&self) -> bool {
409 self.namespaces.is_empty() || self.namespaces.values().all(|ns| !ns.is_active())
410 }
411}
412
413impl PoolReturnable<SparseFeatures> for SparseFeatures {
414 fn clear_and_return_object(mut self, pool: &crate::object_pool::Pool<SparseFeatures>) {
415 self.clear();
416 pool.return_object(self);
417 }
418}