qfall_math/utils/sample/
uniform.rs

1// Copyright © 2025 Niklas Siemer
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module includes core functionality to sample according to the
10//! uniform random distribution.
11
12use crate::{error::MathError, integer::Z};
13use flint_sys::fmpz::{fmpz_addmul_ui, fmpz_set_ui};
14use rand::{RngCore, rngs::ThreadRng};
15
16/// Enables uniformly random sampling a [`Z`] in `[0, interval_size)`.
17///
18/// Attributes:
19/// - `interval_size`: defines the interval [0, interval_size), which we sample from
20/// - `two_pow_32`: is a helper to shift bits by 32-bits left by multiplication
21/// - `nr_iterations`: defines how many full samples of u32 are required
22/// - `upper_modulo`: is a power of two to remove superfluously sampled bits to increase
23///   the probability of accepting a sample to at least 1/2
24/// - `rng`: defines the [`ThreadRng`] that's used to sample uniform [u32] integers
25///
26/// # Examples
27/// ```
28/// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
29/// let interval_size = Z::from(20);
30///
31/// let mut uis = UniformIntegerSampler::init(&interval_size).unwrap();
32///
33/// let sample = uis.sample();
34///
35/// assert!(Z::ZERO <= sample);
36/// assert!(sample < interval_size);
37/// ```
38pub struct UniformIntegerSampler {
39    interval_size: Z,
40    two_pow_32: u64,
41    nr_iterations: u32,
42    upper_modulo: u32,
43    rng: ThreadRng,
44}
45
46impl UniformIntegerSampler {
47    /// Initializes the [`UniformIntegerSampler`] with
48    /// - `interval_size` as `interval_size`,
49    /// - `two_pow_32` as a [u64] containing 2^32
50    /// - `nr_iterations` as `(interval_size - 1).bits() / 32` floored
51    /// - `upper_modulo` as 2^{(interval_size - 1).bits() mod 32}
52    /// - `rng` as a fresh [`ThreadRng`]
53    ///
54    /// Parameters:
55    /// - `interval_size`: specifies the interval `[0, interval_size)`
56    ///   from which the samples are drawn
57    ///
58    /// Returns a [`UniformIntegerSampler`] or a [`MathError`],
59    /// if the interval size is chosen smaller than or equal to `1`.
60    ///
61    /// # Examples
62    /// ```
63    /// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
64    /// let interval_size = Z::from(20);
65    ///
66    /// let mut uis = UniformIntegerSampler::init(&interval_size).unwrap();
67    /// ```
68    ///
69    /// # Errors and Failures
70    /// - Returns a [`MathError`] of type [`InvalidInterval`](MathError::InvalidInterval)
71    ///   if the interval is chosen smaller than `1`.
72    pub fn init(interval_size: &Z) -> Result<Self, MathError> {
73        if interval_size < &Z::ONE {
74            return Err(MathError::InvalidInterval(format!(
75                "An invalid interval size {interval_size} was provided."
76            )));
77        }
78
79        // Compute 2^32 to be able to shift bits to the left
80        // by 32 bits using multiplication
81        let two_pow_32 = u32::MAX as u64 + 1;
82
83        let bit_size = (interval_size - Z::ONE).bits() as u32;
84
85        // div rounds towards 0, i.e. div_floor in this case, i.e. this is
86        // perfect for sampling the top one first and then iterating
87        // nr_iterations-many times
88        let nr_iterations = bit_size / 32;
89
90        // Set upper_modulo to 2^{bit_size mod 32}
91        // defines how many bits will be discarded / have been sampled too much
92        let upper_modulo = 2_u32.pow(bit_size % 32);
93
94        let rng = rand::rng();
95
96        Ok(Self {
97            interval_size: interval_size.clone(),
98            two_pow_32,
99            nr_iterations,
100            upper_modulo,
101            rng,
102        })
103    }
104
105    /// Computes a uniformly chosen [`Z`] sample in `[0, interval_size)`
106    /// using rejection sampling that accepts samples with probability at least 1/2.
107    ///
108    /// # Examples
109    /// ```
110    /// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
111    /// let interval_size = Z::from(20);
112    ///
113    /// let mut uis = UniformIntegerSampler::init(&interval_size).unwrap();
114    ///
115    /// let sample = uis.sample();
116    ///
117    /// assert!(Z::ZERO <= sample);
118    /// assert!(sample < interval_size);
119    /// ```
120    pub fn sample(&mut self) -> Z {
121        if self.interval_size.is_one() {
122            return Z::ZERO;
123        }
124
125        let mut sample = self.sample_bits_uniform();
126        while sample >= self.interval_size {
127            sample = self.sample_bits_uniform();
128        }
129
130        sample
131    }
132
133    /// Computes `self.nr_iterations * 32 + upper_modulo` many uniformly chosen bits.
134    ///
135    /// Returns a [`Z`] containing `self.nr_iterations * 32 + upper_modulo`-many uniformly
136    /// chosen bits.
137    ///
138    /// # Examples
139    /// ```
140    /// use qfall_math::{utils::sample::uniform::UniformIntegerSampler, integer::Z};
141    /// let interval = Z::from(u16::MAX) + 1;
142    ///
143    /// let mut uis = UniformIntegerSampler::init(&interval).unwrap();
144    ///
145    /// let sample = uis.sample_bits_uniform();
146    ///
147    /// assert!(Z::ZERO <= sample);
148    /// assert!(sample < interval);
149    /// ```
150    pub fn sample_bits_uniform(&mut self) -> Z {
151        // remove superfluously sampled bits to increase chance of acception to at lest 1/2
152        let mut value = Z::from(self.rng.next_u32() % self.upper_modulo);
153
154        for _ in 0..self.nr_iterations {
155            let sample = self.rng.next_u32();
156
157            let mut res = Z::default();
158            unsafe {
159                fmpz_set_ui(&mut res.value, sample as u64);
160                // Sets res = res + value * 2^32 reusing the memory allocated of res
161                // could be optimized by shifting bits left by 32 bits once lshift is part of flint-sys
162                fmpz_addmul_ui(&mut res.value, &value.value, self.two_pow_32);
163            };
164            value = res;
165        }
166
167        value
168    }
169}
170
#[cfg(test)]
mod test_uis {
    use super::{UniformIntegerSampler, Z};
    use std::collections::HashSet;

    /// Checks whether sampling works fine for small interval sizes.
    #[test]
    fn small_interval() {
        let size_2 = Z::from(2);
        let size_7 = Z::from(7);

        let mut uis_2 = UniformIntegerSampler::init(&size_2).unwrap();
        let mut uis_7 = UniformIntegerSampler::init(&size_7).unwrap();

        for _ in 0..3 {
            let sample_2 = uis_2.sample();
            let sample_7 = uis_7.sample();

            assert!(Z::ZERO <= sample_2);
            assert!(sample_2 < size_2);
            assert!(Z::ZERO <= sample_7);
            assert!(sample_7 < size_7);
        }
    }

    /// Checks whether sampling works fine for large interval sizes.
    #[test]
    fn large_interval() {
        let size_0 = Z::from(u64::MAX);
        let size_1 = Z::from(u64::MAX) * 2 + 1;

        let mut uis_0 = UniformIntegerSampler::init(&size_0).unwrap();
        let mut uis_1 = UniformIntegerSampler::init(&size_1).unwrap();

        for _ in 0..u8::MAX {
            let sample_0 = uis_0.sample();
            let sample_1 = uis_1.sample();

            assert!(Z::ZERO <= sample_0);
            assert!(sample_0 < size_0);
            assert!(Z::ZERO <= sample_1);
            assert!(sample_1 < size_1);
        }
    }

    /// Checks whether it samples from the entire interval.
    #[test]
    fn entire_interval() {
        let interval_sizes = [6, 7, 16];

        for interval_size in interval_sizes {
            let interval = Z::from(interval_size);

            let mut uis = UniformIntegerSampler::init(&interval).unwrap();

            let mut samples = HashSet::new();
            for _ in 0..2_u32.pow(interval_size) {
                samples.insert(uis.sample());
            }
            // if len(samples) == interval_size, then every element in [0, interval_size)
            // needs to be represented in samples
            assert_eq!(
                interval_size,
                samples.len() as u32,
                "This test may fail with low probability."
            );
        }
    }

    /// Checks whether an interval of size 1 is accepted and always yields `0`.
    #[test]
    fn interval_of_size_one() {
        let mut uis = UniformIntegerSampler::init(&Z::ONE).unwrap();

        for _ in 0..8 {
            assert!(Z::ZERO == uis.sample());
            assert!(Z::ZERO == uis.sample_bits_uniform());
        }
    }

    /// Checks whether interval sizes smaller than 1 result in an error.
    #[test]
    fn invalid_interval() {
        assert!(UniformIntegerSampler::init(&Z::ZERO).is_err());
        assert!(UniformIntegerSampler::init(&Z::MINUS_ONE).is_err());
    }

    /// Checks whether random bit sampling doesn't fill more bits than required.
    #[test]
    fn sample_bits_uniform_necessary_nr_bytes() {
        let size_0 = Z::from(8);
        let size_1 = Z::from(256);
        let size_2 = Z::from(u32::MAX) + Z::ONE;

        let mut uis_0 = UniformIntegerSampler::init(&size_0).unwrap();
        let mut uis_1 = UniformIntegerSampler::init(&size_1).unwrap();
        let mut uis_2 = UniformIntegerSampler::init(&size_2).unwrap();

        for _ in 0..u8::MAX {
            let sample_0 = uis_0.sample_bits_uniform();
            let sample_1 = uis_1.sample_bits_uniform();
            let sample_2 = uis_2.sample_bits_uniform();

            assert!(Z::ZERO <= sample_0);
            assert!(sample_0 < size_0);
            assert!(Z::ZERO <= sample_1);
            assert!(sample_1 < size_1);
            assert!(Z::ZERO <= sample_2);
            assert!(sample_2 < size_2);
        }
    }
}