1use std::cmp::max;
2use std::collections::HashMap;
3use std::fmt::Debug;
4
5use crate::constants::CLASSIC_MEMORIZABLE_BINS_LOG;
6use rand_xoshiro::rand_core::{RngCore, SeedableRng};
7
8use crate::data_types::Latent;
9
/// Smallest input length for which `calc_sample_n` produces a sample size, and
/// the minimum number of values `choose_sample` must collect to return `Some`.
pub const MIN_SAMPLE: usize = 10;
/// Beyond the first `MIN_SAMPLE` input elements, one additional element is
/// sampled per this many input elements (see `calc_sample_n`).
const SAMPLE_RATIO: usize = 40;
/// `1 << CLASSIC_MEMORIZABLE_BINS_LOG` as a float. `est_bits_saved_per_num`
/// divides the sample length by this to get the occurrence count (floored, and
/// never below 1) at or below which a primary latent counts as infrequent.
const CLASSIC_MEMORIZABLE_BINS: f64 = (1 << CLASSIC_MEMORIZABLE_BINS_LOG) as f64;
/// `choose_sample` gives up after `SAMPLING_PERSISTENCE * target_sample_size`
/// random draws, even if the target sample size has not been reached.
const SAMPLING_PERSISTENCE: usize = 4;
19
20fn calc_sample_n(n: usize) -> Option<usize> {
21 if n >= MIN_SAMPLE {
22 Some(MIN_SAMPLE + (n - MIN_SAMPLE) / SAMPLE_RATIO)
23 } else {
24 None
25 }
26}
27
28#[inline(never)]
29pub fn choose_sample<T, S: Copy + Debug, Filter: Fn(&T) -> Option<S>>(
30 nums: &[T],
31 filter: Filter,
32) -> Option<Vec<S>> {
33 let target_sample_size = calc_sample_n(nums.len())?;
39
40 let mut rng = rand_xoshiro::Xoroshiro128PlusPlus::seed_from_u64(0);
41 let mut visited = vec![0_u8; nums.len().div_ceil(8)];
42 let mut res = Vec::with_capacity(target_sample_size);
43 let mut n_iters = 0;
44 while res.len() < target_sample_size && n_iters < SAMPLING_PERSISTENCE * target_sample_size {
45 let rand_idx = rng.next_u64() as usize % nums.len();
46 let visited_idx = rand_idx / 8;
47 let visited_bit = rand_idx % 8;
48 let mask = 1 << visited_bit;
49 let is_visited = visited[visited_idx] & mask;
50 if is_visited == 0 {
51 if let Some(x) = filter(&nums[rand_idx]) {
52 res.push(x);
53 }
54 visited[visited_idx] |= mask;
55 }
56 n_iters += 1;
57 }
58
59 if res.len() >= MIN_SAMPLE {
60 Some(res)
61 } else {
62 None
63 }
64}
65
/// Result of the per-value callback passed to `est_bits_saved_per_num`: the
/// primary latent a sampled value maps to, and the bits saved for that value.
pub struct PrimaryLatentAndSavings<L: Latent> {
    /// Key by which `est_bits_saved_per_num` groups sampled values.
    pub primary: L,
    /// Bits saved for this value; summed per primary latent by
    /// `est_bits_saved_per_num`. NOTE(review): what the savings are measured
    /// against is not visible here — confirm with the callers.
    pub bits_saved: f64,
}
70
71#[inline(never)]
72pub fn est_bits_saved_per_num<L: Latent, S: Copy, F: Fn(S) -> PrimaryLatentAndSavings<L>>(
73 sample: &[S],
74 primary_fn: F,
75) -> f64 {
76 let mut primary_counts_and_savings = HashMap::<L, (usize, f64)>::with_capacity(sample.len());
77 for &x in sample {
78 let PrimaryLatentAndSavings {
79 primary: primary_latent,
80 bits_saved,
81 } = primary_fn(x);
82 let entry = primary_counts_and_savings
83 .entry(primary_latent)
84 .or_default();
85 entry.0 += 1;
86 entry.1 += bits_saved;
87 }
88
89 let infrequent_cutoff = max(
90 1,
91 (sample.len() as f64 / CLASSIC_MEMORIZABLE_BINS) as usize,
92 );
93
94 let sample_bits_saved = primary_counts_and_savings
97 .values()
98 .filter(|&&(count, _)| count <= infrequent_cutoff)
99 .map(|&(_, bits_saved)| bits_saved)
100 .sum::<f64>();
101 sample_bits_saved / sample.len() as f64
102}
103
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `calc_sample_n` against a table of known input lengths.
    #[test]
    fn test_sample_n() {
        let cases = [
            (9, None),
            (10, Some(10)),
            (100, Some(12)),
            (1000010, Some(25010)),
        ];
        for (n, expected) in cases {
            assert_eq!(calc_sample_n(n), expected);
        }
    }

    /// Pins the deterministic output of `choose_sample` on 0, -1, ..., -149,
    /// with zero filtered out.
    #[test]
    fn test_choose_sample() {
        let nums: Vec<f32> = (0..150).map(|i: i32| (-i) as f32).collect();
        let mut sample = choose_sample(&nums, |&num: &f32| (num != 0.0).then_some(num)).unwrap();
        sample.sort_unstable_by(f32::total_cmp);
        assert_eq!(sample.len(), 13);
        assert_eq!(&sample[0..3], &[-147.0, -142.0, -119.0]);
    }
}