Skip to main content

diskann_utils/sampling/
random.rs

1/*
2 * Copyright (c) Microsoft Corporation.
3 * Licensed under the MIT license.
4 */
5
6use rand::{rngs::StdRng, Rng};
7
/// Conversion of an `f32` value into `Self`.
///
/// How the value is narrowed is up to the implementor: the integer
/// implementations in this file round to the nearest integer before casting,
/// while the floating-point implementations convert the value directly.
pub trait RoundFromf32 {
    /// Builds a `Self` from the single-precision value `x`.
    fn round_from_f32(x: f32) -> Self;
}
11
12impl RoundFromf32 for f32 {
13    fn round_from_f32(x: f32) -> Self {
14        x
15    }
16}
17impl RoundFromf32 for i8 {
18    fn round_from_f32(x: f32) -> Self {
19        x.round() as i8
20    }
21}
22impl RoundFromf32 for u8 {
23    fn round_from_f32(x: f32) -> Self {
24        x.round() as u8
25    }
26}
27impl RoundFromf32 for half::f16 {
28    fn round_from_f32(x: f32) -> Self {
29        half::f16::from_f32(x)
30    }
31}
32
33pub trait WithApproximateNorm: Sized {
34    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self>;
35}
36
37impl WithApproximateNorm for f32 {
38    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
39        generate_random_vector_with_norm_signed(dim, norm, true, rng, |x: f32| x)
40    }
41}
42
43impl WithApproximateNorm for half::f16 {
44    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
45        // Small QOL improvement, `diskann_wide::cast_f32_to_f16` works under `Miri` while `half::f16::from_f32`
46        // does not.
47        generate_random_vector_with_norm_signed(dim, norm, true, rng, diskann_wide::cast_f32_to_f16)
48    }
49}
50
51impl WithApproximateNorm for u8 {
52    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
53        generate_random_vector_with_norm_signed(dim, norm, false, rng, |x| x as u8)
54    }
55}
56
57impl WithApproximateNorm for i8 {
58    fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
59        generate_random_vector_with_norm_signed(dim, norm, true, rng, |x| x as i8)
60    }
61}
62
63// Note: private function
64fn generate_random_vector_with_norm_signed<T, F>(
65    dim: usize,
66    norm: f32,
67    signed: bool,
68    rng: &mut StdRng,
69    f: F,
70) -> Vec<T>
71where
72    F: Fn(f32) -> T,
73{
74    let mut vec: Vec<f32> = if signed {
75        (0..dim)
76            .map(|_| rng.random_range(-1.0f32..1.0f32))
77            .collect()
78    } else {
79        (0..dim).map(|_| rng.random_range(0.0f32..1.0f32)).collect()
80    };
81    let current_norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
82    let scale = norm / current_norm;
83    vec.iter_mut().for_each(|x| *x *= scale);
84    vec.into_iter().map(f).collect()
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rstest::rstest;

    /// Fixed seed so every generator test is reproducible.
    const SEED: u64 = 42;

    /// L2 norm of the given components, accumulated in `f32`.
    fn l2_norm(values: impl Iterator<Item = f32>) -> f32 {
        values.map(|v| v * v).sum::<f32>().sqrt()
    }

    // The f32 generator should hit the requested norm up to rounding error.
    #[rstest]
    #[case(1, 0.01)]
    #[case(100, 0.01)]
    #[case(171, 5.0)]
    #[case(1024, 100.7)]
    fn test_generate_random_vector_with_norm_f32(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<f32> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(vec.iter().copied());
        // An absolute 1e-5 bound is about one f32 ULP at norms near 100, so
        // the bound grows with the norm; for norms up to 1.0 it is unchanged.
        let tolerance = 1e-5 * norm.max(1.0);
        assert!(
            (computed_norm - norm).abs() < tolerance,
            "computed norm {computed_norm} is not within {tolerance} of {norm}"
        );
    }

    // Half precision only keeps about three decimal digits, hence the
    // looser relative tolerance.
    #[rstest]
    #[case(1, 0.01)]
    #[case(100, 0.01)]
    #[case(171, 5.0)]
    #[case(1024, 100.7)]
    fn test_generate_random_vector_with_norm_half_f16(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<half::f16> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(vec.iter().map(|x| x.to_f32()));
        let tolerance = 1e-2; // half precision
        assert!(
            (computed_norm - norm).abs() / norm < tolerance,
            "computed norm {computed_norm} deviates from {norm} by more than {tolerance} (relative)"
        );
    }

    #[rstest]
    #[case(17, 50.0)]
    #[case(1024, 1007.0)]
    fn test_generate_random_vector_with_norm_u8(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<u8> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(vec.iter().map(|&x| x as f32));
        let tolerance = 1e-1; // due to quantization
        assert!(
            (computed_norm - norm).abs() / norm < tolerance,
            "computed norm {computed_norm} deviates from {norm} by more than {tolerance} (relative)"
        );
    }

    #[rstest]
    #[case(17, 50.0)]
    #[case(1024, 1007.0)]
    fn test_generate_random_vector_with_norm_i8(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<i8> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(vec.iter().map(|&x| x as f32));
        let tolerance = 1e-1; // due to quantization
        assert!(
            (computed_norm - norm).abs() / norm < tolerance,
            "computed norm {computed_norm} deviates from {norm} by more than {tolerance} (relative)"
        );
    }

    #[rstest]
    #[case(3.6f32, 4i8)]
    #[case(2.3f32, 2i8)]
    #[case(-1.5f32, -2i8)]
    fn test_round_f32_to_i8(#[case] input: f32, #[case] expected: i8) {
        let result: i8 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    // Negative inputs saturate to 0 when narrowed to `u8`.
    #[rstest]
    #[case(3.6f32, 4u8)]
    #[case(2.3f32, 2u8)]
    #[case(-1.5f32, 0u8)]
    fn test_round_f32_to_u8(#[case] input: f32, #[case] expected: u8) {
        let result: u8 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    #[rstest]
    #[case(3.6f32, half::f16::from_f32(3.6f32))]
    #[case(2.3f32, half::f16::from_f32(2.3f32))]
    #[case(-1.5f32, half::f16::from_f32(-1.5f32))]
    fn test_round_f32_to_f16(#[case] input: f32, #[case] expected: half::f16) {
        let result: half::f16 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    #[rstest]
    #[case(3.6f32, 3.6f32)]
    #[case(2.3f32, 2.3f32)]
    #[case(-1.5f32, -1.5f32)]
    fn test_round_f32_to_f32(#[case] input: f32, #[case] expected: f32) {
        let result: f32 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }
}