1use rand::{rngs::StdRng, Rng};
7use rand_distr::StandardNormal;
8
/// Conversion from an `f32` into the implementing element type.
///
/// How the value is converted is up to each implementation: the integer
/// implementations in this file round to the nearest integer before casting,
/// while the floating-point implementations convert the value directly.
pub trait RoundFromf32 {
    /// Builds a `Self` from the `f32` value `x`.
    fn round_from_f32(x: f32) -> Self;
}
12
13impl RoundFromf32 for f32 {
14 fn round_from_f32(x: f32) -> Self {
15 x
16 }
17}
18impl RoundFromf32 for i8 {
19 fn round_from_f32(x: f32) -> Self {
20 x.round() as i8
21 }
22}
23impl RoundFromf32 for u8 {
24 fn round_from_f32(x: f32) -> Self {
25 x.round() as u8
26 }
27}
28impl RoundFromf32 for half::f16 {
29 fn round_from_f32(x: f32) -> Self {
30 half::f16::from_f32(x)
31 }
32}
33
34pub trait WithApproximateNorm: Sized {
35 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self>;
36}
37
38impl WithApproximateNorm for f32 {
39 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
40 generate_random_vector_with_norm_signed(dim, norm, true, rng, |x: f32| x)
41 }
42}
43
44impl WithApproximateNorm for half::f16 {
45 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
46 generate_random_vector_with_norm_signed(dim, norm, true, rng, diskann_wide::cast_f32_to_f16)
49 }
50}
51
52impl WithApproximateNorm for u8 {
53 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
54 generate_random_vector_with_norm_signed(dim, norm, false, rng, |x| x as u8)
55 }
56}
57
58impl WithApproximateNorm for i8 {
59 fn with_approximate_norm(dim: usize, norm: f32, rng: &mut StdRng) -> Vec<Self> {
60 generate_random_vector_with_norm_signed(dim, norm, true, rng, |x| x as i8)
61 }
62}
63
64fn generate_random_vector_with_norm_signed<T, F>(
67 dim: usize,
68 norm: f32,
69 signed: bool,
70 rng: &mut StdRng,
71 f: F,
72) -> Vec<T>
73where
74 F: Fn(f32) -> T,
75{
76 let mut vec: Vec<f32> = (0..dim).map(|_| rng.sample(StandardNormal)).collect();
77 let current_norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
78 let scale = norm / current_norm;
79 if signed {
80 vec.iter_mut().for_each(|x| *x *= scale);
81 } else {
82 vec.iter_mut().for_each(|x| *x = (*x * scale).abs());
83 };
84 vec.into_iter().map(f).collect()
85}
86
#[cfg(test)]
mod tests {
    use super::*;
    use rand::SeedableRng;
    use rstest::rstest;
    use std::f32::consts::PI;

    /// Fixed seed so the norm tests see the same random vectors on every run.
    const SEED: u64 = 42;

    /// Euclidean norm of the yielded components, accumulated in `f32`.
    fn l2_norm(components: impl Iterator<Item = f32>) -> f32 {
        components.map(|c| c * c).sum::<f32>().sqrt()
    }

    /// Fails unless the relative error between `actual` and `expected` is
    /// strictly below `tolerance`.
    fn assert_norm_close(actual: f32, expected: f32, tolerance: f32) {
        let relative_error = (actual - expected).abs() / expected;
        assert!(
            relative_error < tolerance,
            "norm {} differs from requested {} by relative error {} (tolerance {})",
            actual,
            expected,
            relative_error,
            tolerance
        );
    }

    #[rstest]
    #[case(1, 0.01)]
    #[case(100, 0.01)]
    #[case(171, 5.0)]
    #[case(1024, 100.7)]
    fn test_generate_random_vector_with_norm_f32(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<f32> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        // No conversion loss for f32, so only floating-point round-off is tolerated.
        assert_norm_close(l2_norm(vec.iter().copied()), norm, 1e-5);
    }

    #[rstest]
    #[case(1, 0.01)]
    #[case(100, 0.01)]
    #[case(171, 5.0)]
    #[case(1024, 100.7)]
    fn test_generate_random_vector_with_norm_half_f16(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<half::f16> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        // Looser tolerance: each component was narrowed to half precision.
        assert_norm_close(l2_norm(vec.iter().map(|x| x.to_f32())), norm, 1e-2);
    }

    #[rstest]
    #[case(17, 50.0)]
    #[case(1024, 1007.0)]
    fn test_generate_random_vector_with_norm_u8(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<u8> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        // Integer quantization dominates the error, hence the 10% tolerance.
        assert_norm_close(l2_norm(vec.iter().map(|&x| x as f32)), norm, 1e-1);
    }

    #[rstest]
    #[case(17, 50.0)]
    #[case(1024, 1007.0)]
    fn test_generate_random_vector_with_norm_i8(#[case] dim: usize, #[case] norm: f32) {
        let mut rng = StdRng::seed_from_u64(SEED);
        let vec: Vec<i8> = WithApproximateNorm::with_approximate_norm(dim, norm, &mut rng);
        // Integer quantization dominates the error, hence the 10% tolerance.
        assert_norm_close(l2_norm(vec.iter().map(|&x| x as f32)), norm, 1e-1);
    }

    #[rstest]
    #[case(3.6f32, 4i8)]
    #[case(2.3f32, 2i8)]
    #[case(-1.5f32, -2i8)]
    fn test_round_f32_to_i8(#[case] input: f32, #[case] expected: i8) {
        let result: i8 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    #[rstest]
    #[case(3.6f32, 4u8)]
    #[case(2.3f32, 2u8)]
    #[case(-1.5f32, 0u8)]
    fn test_round_f32_to_u8(#[case] input: f32, #[case] expected: u8) {
        let result: u8 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    #[rstest]
    #[case(3.6f32, half::f16::from_f32(3.6f32))]
    #[case(2.3f32, half::f16::from_f32(2.3f32))]
    #[case(-1.5f32, half::f16::from_f32(-1.5f32))]
    fn test_round_f32_to_f16(#[case] input: f32, #[case] expected: half::f16) {
        let result: half::f16 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    #[rstest]
    #[case(3.6f32, 3.6f32)]
    #[case(2.3f32, 2.3f32)]
    #[case(-1.5f32, -1.5f32)]
    fn test_round_f32_to_f32(#[case] input: f32, #[case] expected: f32) {
        let result: f32 = RoundFromf32::round_from_f32(input);
        assert_eq!(result, expected);
    }

    // Histograms the polar angle of many 2-D unit vectors and checks that every
    // angular bucket holds roughly `expected_per_bucket` samples. Signed vectors
    // are spread over the full circle, unsigned ones over the first quadrant.
    #[rstest]
    #[case(true, 500, 3.0, 42)]
    #[case(true, 500, 3.0, 43)]
    #[case(true, 500, 3.0, 44)]
    #[case(false, 500, 3.0, 42)]
    #[case(false, 500, 3.0, 43)]
    #[case(false, 500, 3.0, 44)]
    fn test_generate_random_vector_with_norm_signed_produces_uniform_distribution_on_circle(
        #[case] signed: bool,
        #[case] expected_per_bucket: usize,
        #[case] tolerance_sigmas: f32,
        #[case] seed: u64,
    ) {
        let dim = 2;
        let norm = 1.0;
        let mut rng = StdRng::seed_from_u64(seed);

        // 36 buckets over the full circle or 9 over the quadrant: 10 degrees each.
        let num_buckets = if signed { 36 } else { 9 };
        let num_samples = num_buckets * expected_per_bucket;

        let mut counts = vec![0usize; num_buckets];

        for _ in 0..num_samples {
            let sample: Vec<f32> =
                generate_random_vector_with_norm_signed(dim, norm, signed, &mut rng, |x: f32| x);
            let theta = sample[1].atan2(sample[0]);

            let bucket = if signed {
                // atan2 can return negative angles; shift them into [0, 2*PI).
                let normalized_theta = if theta < 0.0 {
                    theta + 2.0 * PI
                } else {
                    theta
                };
                // The modulo folds an angle that rounds up to exactly 2*PI back into bucket 0.
                (((normalized_theta / (2.0 * PI)) * num_buckets as f32).floor() as usize)
                    % num_buckets
            } else {
                let raw = ((theta / (PI / 2.0)) * num_buckets as f32).floor() as usize;
                // theta == PI/2 (x component of zero) would otherwise index one past
                // the end of `counts`; it belongs to the last bucket.
                raw.min(num_buckets - 1)
            };

            counts[bucket] += 1;
        }

        // Bucket counts are treated as having standard deviation sqrt(expected),
        // so a relative deviation of `tolerance_sigmas / sqrt(expected)` corresponds
        // to `tolerance_sigmas` standard deviations.
        let sigma = (expected_per_bucket as f32).sqrt();
        let threshold = tolerance_sigmas / sigma;

        let failed_count = counts
            .iter()
            .filter(|&&observed| {
                let deviation = (observed as f32 - expected_per_bucket as f32).abs()
                    / expected_per_bucket as f32;
                deviation > threshold
            })
            .count();

        assert_eq!(
            failed_count,
            0,
            "Distribution not uniform: {} out of {} bucket(s) had point counts that deviated more than {}σ from expected. \
             This indicates the generator is producing clustered points instead of evenly distributed points on the circle surface.",
            failed_count,
            num_buckets,
            tolerance_sigmas
        );
    }
}