qfall_math/utils/sample/uniform.rs
1// Copyright © 2025 Niklas Siemer
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module includes core functionality to sample according to the
10//! uniform random distribution.
11
12use crate::{error::MathError, integer::Z};
13use flint_sys::fmpz::{fmpz_addmul_ui, fmpz_set_ui};
14use rand::{RngCore, rngs::ThreadRng};
15
16/// Enables uniformly random sampling a [`Z`] in `[0, interval_size)`.
17///
18/// Attributes:
19/// - `interval_size`: defines the interval [0, interval_size), which we sample from
20/// - `two_pow_32`: is a helper to shift bits by 32-bits left by multiplication
21/// - `nr_iterations`: defines how many full samples of u32 are required
22/// - `upper_modulo`: is a power of two to remove superfluously sampled bits to increase
23/// the probability of accepting a sample to at least 1/2
24/// - `rng`: defines the [`ThreadRng`] that's used to sample uniform [u32] integers
25///
26/// # Examples
27/// ```
28/// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
29/// let interval_size = Z::from(20);
30///
31/// let mut uis = UniformIntegerSampler::init(&interval_size).unwrap();
32///
33/// let sample = uis.sample();
34///
35/// assert!(Z::ZERO <= sample);
36/// assert!(sample < interval_size);
37/// ```
38pub struct UniformIntegerSampler {
39 interval_size: Z,
40 two_pow_32: u64,
41 nr_iterations: u32,
42 upper_modulo: u32,
43 rng: ThreadRng,
44}
45
46impl UniformIntegerSampler {
47 /// Initializes the [`UniformIntegerSampler`] with
48 /// - `interval_size` as `interval_size`,
49 /// - `two_pow_32` as a [u64] containing 2^32
50 /// - `nr_iterations` as `(interval_size - 1).bits() / 32` floored
51 /// - `upper_modulo` as 2^{(interval_size - 1).bits() mod 32}
52 /// - `rng` as a fresh [`ThreadRng`]
53 ///
54 /// Parameters:
55 /// - `interval_size`: specifies the interval `[0, interval_size)`
56 /// from which the samples are drawn
57 ///
58 /// Returns a [`UniformIntegerSampler`] or a [`MathError`],
59 /// if the interval size is chosen smaller than or equal to `1`.
60 ///
61 /// # Examples
62 /// ```
63 /// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
64 /// let interval_size = Z::from(20);
65 ///
66 /// let mut uis = UniformIntegerSampler::init(&interval_size).unwrap();
67 /// ```
68 ///
69 /// # Errors and Failures
70 /// - Returns a [`MathError`] of type [`InvalidInterval`](MathError::InvalidInterval)
71 /// if the interval is chosen smaller than `1`.
72 pub fn init(interval_size: &Z) -> Result<Self, MathError> {
73 if interval_size < &Z::ONE {
74 return Err(MathError::InvalidInterval(format!(
75 "An invalid interval size {interval_size} was provided."
76 )));
77 }
78
79 // Compute 2^32 to be able to shift bits to the left
80 // by 32 bits using multiplication
81 let two_pow_32 = u32::MAX as u64 + 1;
82
83 let bit_size = (interval_size - Z::ONE).bits() as u32;
84
85 // div rounds towards 0, i.e. div_floor in this case, i.e. this is
86 // perfect for sampling the top one first and then iterating
87 // nr_iterations-many times
88 let nr_iterations = bit_size / 32;
89
90 // Set upper_modulo to 2^{bit_size mod 32}
91 // defines how many bits will be discarded / have been sampled too much
92 let upper_modulo = 2_u32.pow(bit_size % 32);
93
94 let rng = rand::rng();
95
96 Ok(Self {
97 interval_size: interval_size.clone(),
98 two_pow_32,
99 nr_iterations,
100 upper_modulo,
101 rng,
102 })
103 }
104
105 /// Computes a uniformly chosen [`Z`] sample in `[0, interval_size)`
106 /// using rejection sampling that accepts samples with probability at least 1/2.
107 ///
108 /// # Examples
109 /// ```
110 /// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
111 /// let interval_size = Z::from(20);
112 ///
113 /// let mut uis = UniformIntegerSampler::init(&interval_size).unwrap();
114 ///
115 /// let sample = uis.sample();
116 ///
117 /// assert!(Z::ZERO <= sample);
118 /// assert!(sample < interval_size);
119 /// ```
120 pub fn sample(&mut self) -> Z {
121 if self.interval_size.is_one() {
122 return Z::ZERO;
123 }
124
125 let mut sample = self.sample_bits_uniform();
126 while sample >= self.interval_size {
127 sample = self.sample_bits_uniform();
128 }
129
130 sample
131 }
132
133 /// Computes `self.nr_iterations * 32 + upper_modulo` many uniformly chosen bits.
134 ///
135 /// Returns a [`Z`] containing `self.nr_iterations * 32 + upper_modulo`-many uniformly
136 /// chosen bits.
137 ///
138 /// # Examples
139 /// ```
140 /// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
141 /// let interval = Z::from(u16::MAX) + 1;
142 ///
143 /// let mut uis = UniformIntegerSampler::init(&interval).unwrap();
144 ///
145 /// let sample = uis.sample_bits_uniform();
146 ///
147 /// assert!(Z::ZERO <= sample);
148 /// assert!(sample < interval);
149 /// ```
150 pub fn sample_bits_uniform(&mut self) -> Z {
151 // remove superfluously sampled bits to increase chance of acception to at lest 1/2
152 let mut value = Z::from(self.rng.next_u32() % self.upper_modulo);
153
154 for _ in 0..self.nr_iterations {
155 let sample = self.rng.next_u32();
156
157 let mut res = Z::default();
158 unsafe {
159 fmpz_set_ui(&mut res.value, sample as u64);
160 // Sets res = res + value * 2^32 reusing the memory allocated of res
161 // could be optimized by shifting bits left by 32 bits once lshift is part of flint-sys
162 fmpz_addmul_ui(&mut res.value, &value.value, self.two_pow_32);
163 };
164 value = res;
165 }
166
167 value
168 }
169}
170
#[cfg(test)]
mod test_uis {
    use super::{UniformIntegerSampler, Z};
    use std::collections::HashSet;

    /// Checks whether sampling works fine for small interval sizes.
    #[test]
    fn small_interval() {
        let size_2 = Z::from(2);
        let size_7 = Z::from(7);

        let mut uis_2 = UniformIntegerSampler::init(&size_2).unwrap();
        let mut uis_7 = UniformIntegerSampler::init(&size_7).unwrap();

        for _ in 0..3 {
            let sample_2 = uis_2.sample();
            let sample_7 = uis_7.sample();

            assert!(Z::ZERO <= sample_2);
            assert!(sample_2 < size_2);
            assert!(Z::ZERO <= sample_7);
            assert!(sample_7 < size_7);
        }
    }

    /// Checks whether the interval `[0, 1)` is accepted and only yields `0`.
    #[test]
    fn interval_of_size_one() {
        let mut uis = UniformIntegerSampler::init(&Z::ONE).unwrap();

        for _ in 0..3 {
            assert!(uis.sample() == Z::ZERO);
            assert!(uis.sample_bits_uniform() == Z::ZERO);
        }
    }

    /// Checks whether sampling works fine for large interval sizes.
    #[test]
    fn large_interval() {
        let size_0 = Z::from(u64::MAX);
        let size_1 = Z::from(u64::MAX) * 2 + 1;

        let mut uis_0 = UniformIntegerSampler::init(&size_0).unwrap();
        let mut uis_1 = UniformIntegerSampler::init(&size_1).unwrap();

        for _i in 0..u8::MAX {
            let sample_0 = uis_0.sample();
            let sample_1 = uis_1.sample();

            assert!(Z::ZERO <= sample_0);
            assert!(sample_0 < size_0);
            assert!(Z::ZERO <= sample_1);
            assert!(sample_1 < size_1);
        }
    }

    /// Checks whether it samples from the entire interval.
    #[test]
    fn entire_interval() {
        for interval_size in [6_u32, 7, 16] {
            let interval = Z::from(interval_size);

            let mut uis = UniformIntegerSampler::init(&interval).unwrap();

            let mut samples = HashSet::new();
            for _ in 0..2_u32.pow(interval_size) {
                samples.insert(uis.sample());
            }
            // if len(samples) == interval_size, then every element in [0, interval_size)
            // needs to be represented in samples
            assert_eq!(
                interval_size,
                samples.len() as u32,
                "This test may fail with low probability."
            );
        }
    }

    /// Checks whether interval sizes smaller than 1 result in an error.
    #[test]
    fn invalid_interval() {
        assert!(UniformIntegerSampler::init(&Z::ZERO).is_err());
        assert!(UniformIntegerSampler::init(&Z::MINUS_ONE).is_err());
    }

    /// Checks whether random bit sampling doesn't fill more bits than required,
    /// i.e. for interval sizes that are powers of two every candidate lies in the interval.
    #[test]
    fn sample_bits_uniform_necessary_nr_bytes() {
        let size_0 = Z::from(8);
        let size_1 = Z::from(256);
        let size_2 = Z::from(u32::MAX) + Z::ONE;

        let mut uis_0 = UniformIntegerSampler::init(&size_0).unwrap();
        let mut uis_1 = UniformIntegerSampler::init(&size_1).unwrap();
        let mut uis_2 = UniformIntegerSampler::init(&size_2).unwrap();

        for _ in 0..u8::MAX {
            let sample_0 = uis_0.sample_bits_uniform();
            let sample_1 = uis_1.sample_bits_uniform();
            let sample_2 = uis_2.sample_bits_uniform();

            assert!(Z::ZERO <= sample_0);
            assert!(sample_0 < size_0);
            assert!(Z::ZERO <= sample_1);
            assert!(sample_1 < size_1);
            assert!(Z::ZERO <= sample_2);
            assert!(sample_2 < size_2);
        }
    }
}