Skip to main content

diskann_utils/sampling/
random.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use rand::{rngs::StdRng, Rng};
7use rand_distr::StandardNormal;
8
/// Conversion from a single-precision float into an element type.
///
/// Each implementation decides how the value is mapped: passed through
/// unchanged, rounded to an integer, or narrowed to a smaller float type.
pub trait RoundFromf32 {
    /// Builds a `Self` from the `f32` value `x`.
    fn round_from_f32(x: f32) -> Self;
}
12
13impl RoundFromf32 for f32 {
14    fn round_from_f32(x: f32) -> Self {
15        x
16    }
17}
18impl RoundFromf32 for i8 {
19    fn round_from_f32(x: f32) -> Self {
20        x.round() as i8
21    }
22}
23impl RoundFromf32 for u8 {
24    fn round_from_f32(x: f32) -> Self {
25        x.round() as u8
26    }
27}
28impl RoundFromf32 for half::f16 {
29    fn round_from_f32(x: f32) -> Self {
30        half::f16::from_f32(x)
31    }
32}
33
34pub trait WithApproximateNorm: Sized {
35    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self>;
36}
37
38impl WithApproximateNorm for f32 {
39    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
40        generate_random_vector_with_norm_signed(dim, norm, true, rng, |x: f32| x)
41    }
42}
43
44impl WithApproximateNorm for half::f16 {
45    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
46        // Small QOL improvement, `diskann_wide::cast_f32_to_f16` works under `Miri` while `half::f16::from_f32`
47        // does not.
48        generate_random_vector_with_norm_signed(dim, norm, true, rng, diskann_wide::cast_f32_to_f16)
49    }
50}
51
52impl WithApproximateNorm for u8 {
53    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
54        generate_random_vector_with_norm_signed(dim, norm, false, rng, |x| x as u8)
55    }
56}
57
58impl WithApproximateNorm for i8 {
59    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
60        generate_random_vector_with_norm_signed(dim, norm, true, rng, |x| x as i8)
61    }
62}
63
64// This function uses StandardNormal distribution. StandardNormal creates uniformly
65// distributed points on sphere surface, making the graph easier to navigate.
66fn generate_random_vector_with_norm_signed<T, F>(
67    dim: usize,
68    norm: f32,
69    signed: bool,
70    rng: &mut StdRng,
71    f: F,
72) -> Vec<T>
73where
74    F: Fn(f32) -> T,
75{
76    let mut vec: Vec<f32> = (0..dim).map(|_| rng.sample(StandardNormal)).collect();
77    let current_norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
78    let scale = norm / current_norm;
79    if signed {
80        vec.iter_mut().for_each(|x| *x *= scale);
81    } else {
82        vec.iter_mut().for_each(|x| *x = (*x * scale).abs());
83    };
84    vec.into_iter().map(f).collect()
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rstest::rstest;

    /// Euclidean norm of `components`, accumulated in `f32`.
    fn l2_norm(components: impl Iterator<Item = f32>) -> f32 {
        components.map(|c| c * c).sum::<f32>().sqrt()
    }

    /// Asserts that `actual` lies within the relative `tolerance` of `expected`.
    fn assert_norm_close(actual: f32, expected: f32, tolerance: f32) {
        let relative_error = (actual - expected).abs() / expected;
        assert!(
            relative_error < tolerance,
            "norm {} deviates from requested {} by relative error {} (limit {})",
            actual,
            expected,
            relative_error,
            tolerance
        );
    }

    #[rstest]
    #[case(1, 0.01)]
    #[case(100, 0.01)]
    #[case(171, 5.0)]
    #[case(1024, 100.7)]
    fn test_generate_random_vector_with_norm_f32(#[case] dim: usize, #[case] norm: f32) {
        let seed = 42;
        let mut rng = StdRng::seed_from_u64(seed);
        let vec: Vec<f32> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(vec.iter().copied());
        assert_norm_close(computed_norm, norm, 1e-5);
    }

    #[rstest]
    #[case(1, 0.01)]
    #[case(100, 0.01)]
    #[case(171, 5.0)]
    #[case(1024, 100.7)]
    fn test_generate_random_vector_with_norm_half_f16(#[case] dim: usize, #[case] norm: f32) {
        let seed = 42;
        let mut rng = StdRng::seed_from_u64(seed);
        let vec: Vec<half::f16> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(vec.iter().map(|x| x.to_f32()));
        // Looser tolerance: half precision.
        assert_norm_close(computed_norm, norm, 1e-2);
    }

    #[rstest]
    #[case(17, 50.0)]
    #[case(1024, 1007.0)]
    fn test_generate_random_vector_with_norm_u8(#[case] dim: usize, #[case] norm: f32) {
        let seed = 42;
        let mut rng = StdRng::seed_from_u64(seed);
        let vec: Vec<u8> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(vec.iter().map(|&x| x as f32));
        // Looser tolerance: due to quantization.
        assert_norm_close(computed_norm, norm, 1e-1);
    }

    #[rstest]
    #[case(17, 50.0)]
    #[case(1024, 1007.0)]
    fn test_generate_random_vector_with_norm_i8(#[case] dim: usize, #[case] norm: f32) {
        let seed = 42;
        let mut rng = StdRng::seed_from_u64(seed);
        let vec: Vec<i8> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(vec.iter().map(|&x| x as f32));
        // Looser tolerance: due to quantization.
        assert_norm_close(computed_norm, norm, 1e-1);
    }

    #[rstest]
    #[case(3.6f32, 4i8)]
    #[case(2.3f32, 2i8)]
    #[case(-1.5f32, -2i8)]
    fn test_round_f32_to_i8(#[case] input: f32, #[case] expected: i8) {
        let result: i8 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    #[rstest]
    #[case(3.6f32, 4u8)]
    #[case(2.3f32, 2u8)]
    #[case(-1.5f32, 0u8)]
    fn test_round_f32_to_u8(#[case] input: f32, #[case] expected: u8) {
        let result: u8 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    #[rstest]
    #[case(3.6f32, half::f16::from_f32(3.6f32))]
    #[case(2.3f32, half::f16::from_f32(2.3f32))]
    #[case(-1.5f32, half::f16::from_f32(-1.5f32))]
    fn test_round_f32_to_f16(#[case] input: f32, #[case] expected: half::f16) {
        let result: half::f16 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    #[rstest]
    #[case(3.6f32, 3.6f32)]
    #[case(2.3f32, 2.3f32)]
    #[case(-1.5f32, -1.5f32)]
    fn test_round_f32_to_f32(#[case] input: f32, #[case] expected: f32) {
        let result: f32 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    /// Test that generated points are evenly distributed on a circle.
    ///
    /// **Testing methodology:**
    /// 1. Split the circle into 36 buckets (signed) or the first quadrant into 9 buckets
    ///    (unsigned), each covering 10 degrees
    /// 2. Generate points and count how many fall into each angular bucket
    /// 3. Check that each bucket's count is within `tolerance_sigmas × σ` of the expected count,
    ///    where σ = sqrt(expected) is the statistical noise for random sampling
    /// 4. Fail if any bucket deviates too much (indicates clustering instead of uniform distribution)
    ///
    /// **Tolerance levels (per bucket, two-sided normal approximation):**
    ///   - tolerance_sigmas = 1.0 → Very strict, only allows ±1σ deviation (about 68% of buckets would naturally fall within this)
    ///   - tolerance_sigmas = 3.0 → Moderate, allows ±3σ deviation (99.7% would naturally fall within this)
    ///   - tolerance_sigmas = 6.0 → Very lenient, allows ±6σ deviation (about 99.9999998% would naturally fall within this)
    ///
    /// The seeds are fixed, so the outcome of each case is deterministic.
    #[rstest]
    #[case(true, 500, 3.0, 42)]
    #[case(true, 500, 3.0, 43)]
    #[case(true, 500, 3.0, 44)]
    #[case(false, 500, 3.0, 42)]
    #[case(false, 500, 3.0, 43)]
    #[case(false, 500, 3.0, 44)]
    fn test_generate_random_vector_with_norm_signed_produces_uniform_distribution_on_circle(
        #[case] signed: bool,
        #[case] expected_per_bucket: usize,
        #[case] tolerance_sigmas: f32,
        #[case] seed: u64,
    ) {
        let dim = 2;
        let norm = 1.0;
        let mut rng = StdRng::seed_from_u64(seed);

        // Step 1: Pick number of buckets and calculate samples
        let num_buckets = if signed { 36 } else { 9 };
        let num_samples = num_buckets * expected_per_bucket;

        // Generate samples
        let samples: Vec<Vec<f32>> = (0..num_samples)
            .map(|_| generate_random_vector_with_norm_signed(dim, norm, signed, &mut rng, |x| x))
            .collect();

        // Step 2: Count hits per bucket
        let mut counts = vec![0usize; num_buckets];

        let full_turn = 2.0 * std::f32::consts::PI;
        let quarter_turn = std::f32::consts::PI / 2.0;

        for sample in &samples {
            let theta = sample[1].atan2(sample[0]); // atan2(y, x) returns [-π, π]

            // Map to bucket: floor(θ / range × buckets)
            let bucket = if signed {
                // Full circle [0, 2π) → [0, 36)
                let normalized_theta = if theta < 0.0 {
                    theta + full_turn
                } else {
                    theta
                };
                // The modulo folds an angle that rounds up to exactly 2π back into bucket 0.
                ((normalized_theta / full_turn) * num_buckets as f32).floor() as usize
                    % num_buckets
            } else {
                // First quadrant [0, π/2] → [0, 9)
                // `atan2` can round to exactly π/2 when x is (nearly) zero, which would
                // index one past the end; clamp such points into the last bucket, the
                // one adjacent to π/2.
                (((theta / quarter_turn) * num_buckets as f32).floor() as usize)
                    .min(num_buckets - 1)
            };

            counts[bucket] += 1;
        }

        // Step 3: Check each bucket is within tolerance_sigmas × σ
        // Noise per bucket: σ ≈ sqrt(expected)
        // A bucket fails when |observed - expected| / expected > tolerance_sigmas / sqrt(expected)
        let sigma = (expected_per_bucket as f32).sqrt();
        let threshold = tolerance_sigmas / sigma;

        let failed_count = counts
            .iter()
            .filter(|&&observed| {
                let deviation = (observed as f32 - expected_per_bucket as f32).abs()
                    / expected_per_bucket as f32;
                deviation > threshold
            })
            .count();

        assert_eq!(
            failed_count,
            0,
            "Distribution not uniform: {} out of {} bucket(s) had point counts that deviated more than {}σ from expected. \
             This indicates the generator is producing clustered points instead of evenly distributed points on the circle surface.",
            failed_count,
            num_buckets,
            tolerance_sigmas
        );
    }
}