pco/
sampling.rs

1use std::cmp::max;
2use std::collections::HashMap;
3use std::fmt::Debug;
4
5use crate::constants::CLASSIC_MEMORIZABLE_BINS_LOG;
6use rand_xoshiro::rand_core::{RngCore, SeedableRng};
7
8use crate::data_types::Latent;
9
// Smallest usable sample: inputs with fewer nums than this are not sampled at
// all, and a collected sample smaller than this is discarded.
pub const MIN_SAMPLE: usize = 10;
// Beyond the first MIN_SAMPLE nums, 1 in this many nums gets put into the
// sample.
const SAMPLE_RATIO: usize = 40;
// Primary latents (e.g. int mults) will be considered infrequent if they make
// up at most about 1/this of the sample; a latent seen only once is always
// considered infrequent.
const CLASSIC_MEMORIZABLE_BINS: f64 = (1 << CLASSIC_MEMORIZABLE_BINS_LOG) as f64;
// How many times over to try collecting samples without replacement before
// giving up: the sampling loop draws at most this many times the target sample
// size.
const SAMPLING_PERSISTENCE: usize = 4;
19
20fn calc_sample_n(n: usize) -> Option<usize> {
21  if n >= MIN_SAMPLE {
22    Some(MIN_SAMPLE + (n - MIN_SAMPLE) / SAMPLE_RATIO)
23  } else {
24    None
25  }
26}
27
28#[inline(never)]
29pub fn choose_sample<T, S: Copy + Debug, Filter: Fn(&T) -> Option<S>>(
30  nums: &[T],
31  filter: Filter,
32) -> Option<Vec<S>> {
33  // We can't modify the list, and copying it may be expensive, but we want to
34  // sample a small fraction from it without replacement, so we keep a
35  // bitpacked vector representing whether each one is used yet and just keep
36  // resampling.
37  // Maybe this is a bad idea, but it works for now.
38  let target_sample_size = calc_sample_n(nums.len())?;
39
40  let mut rng = rand_xoshiro::Xoroshiro128PlusPlus::seed_from_u64(0);
41  let mut visited = vec![0_u8; nums.len().div_ceil(8)];
42  let mut res = Vec::with_capacity(target_sample_size);
43  let mut n_iters = 0;
44  while res.len() < target_sample_size && n_iters < SAMPLING_PERSISTENCE * target_sample_size {
45    let rand_idx = rng.next_u64() as usize % nums.len();
46    let visited_idx = rand_idx / 8;
47    let visited_bit = rand_idx % 8;
48    let mask = 1 << visited_bit;
49    let is_visited = visited[visited_idx] & mask;
50    if is_visited == 0 {
51      if let Some(x) = filter(&nums[rand_idx]) {
52        res.push(x);
53      }
54      visited[visited_idx] |= mask;
55    }
56    n_iters += 1;
57  }
58
59  if res.len() >= MIN_SAMPLE {
60    Some(res)
61  } else {
62    None
63  }
64}
65
/// What a `primary_fn` passed to `est_bits_saved_per_num` reports for one
/// sample element: the primary latent it maps to and the savings attributed
/// to it.
pub struct PrimaryLatentAndSavings<L: Latent> {
  /// Key under which `est_bits_saved_per_num` groups sample elements.
  pub primary: L,
  /// Savings for this element (presumably in bits, going by the name);
  /// `est_bits_saved_per_num` sums these per distinct `primary`.
  pub bits_saved: f64,
}
70
71#[inline(never)]
72pub fn est_bits_saved_per_num<L: Latent, S: Copy, F: Fn(S) -> PrimaryLatentAndSavings<L>>(
73  sample: &[S],
74  primary_fn: F,
75) -> f64 {
76  let mut primary_counts_and_savings = HashMap::<L, (usize, f64)>::with_capacity(sample.len());
77  for &x in sample {
78    let PrimaryLatentAndSavings {
79      primary: primary_latent,
80      bits_saved,
81    } = primary_fn(x);
82    let entry = primary_counts_and_savings
83      .entry(primary_latent)
84      .or_default();
85    entry.0 += 1;
86    entry.1 += bits_saved;
87  }
88
89  let infrequent_cutoff = max(
90    1,
91    (sample.len() as f64 / CLASSIC_MEMORIZABLE_BINS) as usize,
92  );
93
94  // Maybe this should be made fuzzy instead of a hard cutoff because it's just
95  // a sample.
96  let sample_bits_saved = primary_counts_and_savings
97    .values()
98    .filter(|&&(count, _)| count <= infrequent_cutoff)
99    .map(|&(_, bits_saved)| bits_saved)
100    .sum::<f64>();
101  sample_bits_saved / sample.len() as f64
102}
103
#[cfg(test)]
mod tests {
  use super::*;

  // Spot-checks the target sample size at and around the minimum, and for a
  // large input.
  #[test]
  fn test_sample_n() {
    let cases = [
      (9, None),
      (10, Some(10)),
      (100, Some(12)),
      (1000010, Some(25010)),
    ];
    for (n, expected) in cases {
      assert_eq!(calc_sample_n(n), expected);
    }
  }

  // Pins the deterministic sample drawn from 0, -1, ..., -149 when zero is
  // filtered out.
  #[test]
  fn test_choose_sample() {
    fn keep_nonzero(num: &f32) -> Option<f32> {
      if *num == 0.0 {
        None
      } else {
        Some(*num)
      }
    }

    let nums: Vec<f32> = (0..150_i32).map(|i| (-i) as f32).collect();
    let mut sample = choose_sample(&nums, keep_nonzero).unwrap();
    sample.sort_unstable_by(f32::total_cmp);

    assert_eq!(sample.len(), 13);
    let smallest_three = &sample[0..3];
    assert_eq!(smallest_three, &[-147.0, -142.0, -119.0]);
  }
}