diskann_utils/sampling/
random.rs1use rand::{rngs::StdRng, Rng};
7
/// Conversion from an `f32` value into `Self`.
///
/// The implementations in this file round to the nearest integer for the integer targets
/// (`i8`, `u8`), pass the value through unchanged for `f32`, and convert to half precision for
/// `half::f16`.
pub trait RoundFromf32 {
    /// Builds a `Self` from `x`, applying whatever rounding the target type requires.
    fn round_from_f32(x: f32) -> Self;
}
11
12impl RoundFromf32 for f32 {
13 fn round_from_f32(x: f32) -> Self {
14 x
15 }
16}
17impl RoundFromf32 for i8 {
18 fn round_from_f32(x: f32) -> Self {
19 x.round() as i8
20 }
21}
22impl RoundFromf32 for u8 {
23 fn round_from_f32(x: f32) -> Self {
24 x.round() as u8
25 }
26}
27impl RoundFromf32 for half::f16 {
28 fn round_from_f32(x: f32) -> Self {
29 half::f16::from_f32(x)
30 }
31}
32
33pub trait WithApproximateNorm: Sized {
34 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self>;
35}
36
37impl WithApproximateNorm for f32 {
38 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
39 generate_random_vector_with_norm_signed(dim, norm, true, rng, |x: f32| x)
40 }
41}
42
43impl WithApproximateNorm for half::f16 {
44 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
45 generate_random_vector_with_norm_signed(dim, norm, true, rng, diskann_wide::cast_f32_to_f16)
48 }
49}
50
51impl WithApproximateNorm for u8 {
52 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
53 generate_random_vector_with_norm_signed(dim, norm, false, rng, |x| x as u8)
54 }
55}
56
57impl WithApproximateNorm for i8 {
58 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
59 generate_random_vector_with_norm_signed(dim, norm, true, rng, |x| x as i8)
60 }
61}
62
63fn generate_random_vector_with_norm_signed<T, F>(
65 dim: usize,
66 norm: f32,
67 signed: bool,
68 rng: &mut StdRng,
69 f: F,
70) -> Vec<T>
71where
72 F: Fn(f32) -> T,
73{
74 let mut vec: Vec<f32> = if signed {
75 (0..dim)
76 .map(|_| rng.random_range(-1.0f32..1.0f32))
77 .collect()
78 } else {
79 (0..dim).map(|_| rng.random_range(0.0f32..1.0f32)).collect()
80 };
81 let current_norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
82 let scale = norm / current_norm;
83 vec.iter_mut().for_each(|x| *x *= scale);
84 vec.into_iter().map(f).collect()
85}
86
#[cfg(test)]
mod tests {
    //! Tests for the random-vector generators and the `RoundFromf32` conversions.

    use super::*;
    use rand::SeedableRng;
    use rstest::rstest;

    /// Fixed seed so that every randomized test is reproducible.
    const SEED: u64 = 42;

    /// L2 norm of `values`, with each element widened to `f32` by `to_f32` and the sum of
    /// squares accumulated sequentially in `f32`.
    fn l2_norm<T, F>(values: &[T], to_f32: F) -> f32
    where
        T: Copy,
        F: Fn(T) -> f32,
    {
        values
            .iter()
            .map(|&value| {
                let widened = to_f32(value);
                widened * widened
            })
            .sum::<f32>()
            .sqrt()
    }

    /// Asserts that `actual` is within a relative error of `tolerance` from `expected`.
    fn assert_relative_error_below(actual: f32, expected: f32, tolerance: f32) {
        let relative_error = (actual - expected).abs() / expected;
        assert!(
            relative_error < tolerance,
            "norm {actual} deviates from {expected}: relative error {relative_error} >= {tolerance}"
        );
    }

    #[rstest]
    #[case(1, 0.01)]
    #[case(100, 0.01)]
    #[case(171, 5.0)]
    #[case(1024, 100.7)]
    fn test_generate_random_vector_with_norm_f32(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<f32> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(&vec, |x| x);
        // A fixed absolute bound of 1e-5 is about one `f32` ulp at norm ~100, which makes the
        // check depend on the seed's rounding luck. Scale the bound with the norm once it
        // exceeds 1.0; for smaller norms the bound stays absolute (never stricter than 1e-5).
        let tolerance = 1e-5 * norm.max(1.0);
        assert!(
            (computed_norm - norm).abs() < tolerance,
            "norm {computed_norm} deviates from {norm} by more than {tolerance}"
        );
    }

    #[rstest]
    #[case(1, 0.01)]
    #[case(100, 0.01)]
    #[case(171, 5.0)]
    #[case(1024, 100.7)]
    fn test_generate_random_vector_with_norm_half_f16(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<half::f16> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(&vec, |x: half::f16| x.to_f32());
        assert_relative_error_below(computed_norm, norm, 1e-2);
    }

    #[rstest]
    #[case(17, 50.0)]
    #[case(1024, 1007.0)]
    fn test_generate_random_vector_with_norm_u8(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<u8> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(&vec, |x: u8| x as f32);
        assert_relative_error_below(computed_norm, norm, 1e-1);
    }

    #[rstest]
    #[case(17, 50.0)]
    #[case(1024, 1007.0)]
    fn test_generate_random_vector_with_norm_i8(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<i8> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        let computed_norm = l2_norm(&vec, |x: i8| x as f32);
        assert_relative_error_below(computed_norm, norm, 1e-1);
    }

    #[rstest]
    #[case(3.6f32, 4i8)]
    #[case(2.3f32, 2i8)]
    #[case(-1.5f32, -2i8)]
    fn test_round_f32_to_i8(#[case] input: f32, #[case] expected: i8) {
        assert_eq!(i8::round_from_f32(input), expected);
    }

    #[rstest]
    #[case(3.6f32, 4u8)]
    #[case(2.3f32, 2u8)]
    #[case(-1.5f32, 0u8)]
    fn test_round_f32_to_u8(#[case] input: f32, #[case] expected: u8) {
        assert_eq!(u8::round_from_f32(input), expected);
    }

    #[rstest]
    #[case(3.6f32, half::f16::from_f32(3.6f32))]
    #[case(2.3f32, half::f16::from_f32(2.3f32))]
    #[case(-1.5f32, half::f16::from_f32(-1.5f32))]
    fn test_round_f32_to_f16(#[case] input: f32, #[case] expected: half::f16) {
        assert_eq!(<half::f16 as RoundFromf32>::round_from_f32(input), expected);
    }

    #[rstest]
    #[case(3.6f32, 3.6f32)]
    #[case(2.3f32, 2.3f32)]
    #[case(-1.5f32, -1.5f32)]
    fn test_round_f32_to_f32(#[case] input: f32, #[case] expected: f32) {
        assert_eq!(<f32 as RoundFromf32>::round_from_f32(input), expected);
    }
}