dcrypt_algorithms/poly/
sampling.rs

1//! sampling.rs - Cryptographic sampling algorithms
2
3#![cfg_attr(not(feature = "std"), no_std)]
4
5use super::params::Modulus;
6use super::polynomial::Polynomial;
7use crate::error::{Error, Result};
8use rand::{CryptoRng, RngCore};
9
10/// Trait for sampling polynomials uniformly at random
11pub trait UniformSampler<M: Modulus> {
12    /// Samples a polynomial with coefficients uniformly random in [0, Q-1]
13    fn sample_uniform<R: RngCore + CryptoRng>(rng: &mut R) -> Result<Polynomial<M>>;
14}
15
16/// Trait for sampling polynomials from a Centered Binomial Distribution (CBD)
17pub trait CbdSampler<M: Modulus> {
18    /// Samples a polynomial with coefficients from CBD(eta)
19    fn sample_cbd<R: RngCore + CryptoRng>(rng: &mut R, eta: u8) -> Result<Polynomial<M>>;
20}
21
22/// Trait for sampling polynomials from a discrete Gaussian distribution
23pub trait GaussianSampler<M: Modulus> {
24    /// Samples a polynomial with coefficients from a discrete Gaussian distribution
25    fn sample_gaussian<R: RngCore + CryptoRng>(rng: &mut R, sigma: f64) -> Result<Polynomial<M>>;
26}
27
/// Default implementation of the cryptographic samplers.
///
/// This is a stateless unit type: the sampler traits implemented for it expose
/// associated functions that receive the RNG explicitly, so no instance is
/// required to sample. The derives make the type usable in generic code and
/// debug output at no cost.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultSamplers;
30
31impl<M: Modulus> UniformSampler<M> for DefaultSamplers {
32    fn sample_uniform<R: RngCore + CryptoRng>(rng: &mut R) -> Result<Polynomial<M>> {
33        let mut poly = Polynomial::<M>::zero();
34        let q = M::Q;
35
36        // Handle different modulus sizes
37        if q <= (1 << 16) {
38            // For small moduli, use rejection sampling with u16
39            sample_uniform_small::<M, R>(rng, &mut poly)?;
40        } else if q <= (1 << 24) {
41            // For medium moduli, use rejection sampling with u32
42            sample_uniform_medium::<M, R>(rng, &mut poly)?;
43        } else {
44            // For large moduli up to 2^31
45            sample_uniform_large::<M, R>(rng, &mut poly)?;
46        }
47
48        Ok(poly)
49    }
50}
51
52/// Rejection sampling for small moduli (Q <= 2^16)
53fn sample_uniform_small<M: Modulus, R: RngCore + CryptoRng>(
54    rng: &mut R,
55    poly: &mut Polynomial<M>,
56) -> Result<()> {
57    let q = M::Q;
58    let n = M::N;
59
60    // Find the largest multiple of q that fits in u16
61    let threshold = ((1u32 << 16) / q) * q;
62
63    for i in 0..n {
64        loop {
65            let mut bytes = [0u8; 2];
66            rng.fill_bytes(&mut bytes);
67            let sample = u16::from_le_bytes(bytes) as u32;
68
69            // Rejection sampling for uniform distribution
70            if sample < threshold {
71                poly.coeffs[i] = sample % q;
72                break;
73            }
74        }
75    }
76
77    Ok(())
78}
79
80/// Rejection sampling for medium moduli (2^16 < Q <= 2^24)
81fn sample_uniform_medium<M: Modulus, R: RngCore + CryptoRng>(
82    rng: &mut R,
83    poly: &mut Polynomial<M>,
84) -> Result<()> {
85    let q = M::Q;
86    let n = M::N;
87
88    // Use 3 bytes for sampling
89    let threshold = ((1u32 << 24) / q) * q;
90
91    for i in 0..n {
92        loop {
93            let mut bytes = [0u8; 3];
94            rng.fill_bytes(&mut bytes);
95            let sample = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], 0]);
96
97            if sample < threshold {
98                poly.coeffs[i] = sample % q;
99                break;
100            }
101        }
102    }
103
104    Ok(())
105}
106
107/// Rejection sampling for large moduli (2^24 < Q <= 2^31)
108fn sample_uniform_large<M: Modulus, R: RngCore + CryptoRng>(
109    rng: &mut R,
110    poly: &mut Polynomial<M>,
111) -> Result<()> {
112    let q = M::Q;
113    let n = M::N;
114
115    // Use full u32 with MSB clear to ensure < 2^31
116    let threshold = ((1u32 << 31) / q) * q;
117
118    for i in 0..n {
119        loop {
120            let mut bytes = [0u8; 4];
121            rng.fill_bytes(&mut bytes);
122            bytes[3] &= 0x7F; // Clear MSB
123            let sample = u32::from_le_bytes(bytes);
124
125            if sample < threshold {
126                poly.coeffs[i] = sample % q;
127                break;
128            }
129        }
130    }
131
132    Ok(())
133}
134
135impl<M: Modulus> CbdSampler<M> for DefaultSamplers {
136    fn sample_cbd<R: RngCore + CryptoRng>(rng: &mut R, eta: u8) -> Result<Polynomial<M>> {
137        if eta == 0 || eta > 16 {
138            return Err(Error::Parameter {
139                name: "CBD sampling".into(),
140                reason: format!("eta must be in range [1, 16], got {}", eta).into(),
141            });
142        }
143
144        let mut poly = Polynomial::<M>::zero();
145        let n = M::N;
146        let q = M::Q;
147
148        // CBD(eta): sample 2*eta bits, compute sum of first eta bits minus sum of second eta bits
149        let bytes_per_sample = (2 * eta as usize).div_ceil(8); // FIXED: Use div_ceil
150        let mut buffer = [0u8; 4]; // Max 32 bits for eta=16
151
152        for i in 0..n {
153            rng.fill_bytes(&mut buffer[..bytes_per_sample]);
154
155            let mut a = 0i32;
156            let mut b = 0i32;
157
158            // Extract eta bits for positive contribution
159            for j in 0..eta {
160                let byte_idx = j as usize / 8;
161                let bit_idx = j as usize % 8;
162                a += ((buffer[byte_idx] >> bit_idx) & 1) as i32;
163            }
164
165            // Extract eta bits for negative contribution
166            for j in 0..eta {
167                let bit_pos = (eta + j) as usize;
168                let byte_idx = bit_pos / 8;
169                let bit_idx = bit_pos % 8;
170                b += ((buffer[byte_idx] >> bit_idx) & 1) as i32;
171            }
172
173            // CBD sample is in range [-eta, eta]
174            let sample = a - b;
175
176            // Convert to [0, q) range
177            poly.coeffs[i] = ((sample + q as i32) % q as i32) as u32;
178        }
179
180        Ok(poly)
181    }
182}
183
184impl<M: Modulus> GaussianSampler<M> for DefaultSamplers {
185    fn sample_gaussian<R: RngCore + CryptoRng>(_rng: &mut R, _sigma: f64) -> Result<Polynomial<M>> {
186        // Gaussian sampling is complex and will be implemented in Falcon phase
187        Err(Error::NotImplemented {
188            feature: "Gaussian sampler (reserved for Falcon phase)",
189        })
190    }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Test parameters: Q = 3329, N = 256.
    #[derive(Clone)]
    struct TestModulus;
    impl Modulus for TestModulus {
        const Q: u32 = 3329;
        const N: usize = 256;
    }

    /// Asserts that every coefficient of `poly` lies in `[0, Q)`.
    fn assert_coeffs_reduced(poly: &Polynomial<TestModulus>) {
        for &coeff in poly.as_coeffs_slice() {
            assert!(coeff < TestModulus::Q);
        }
    }

    #[test]
    fn test_uniform_sampling() {
        let mut rng = StdRng::seed_from_u64(42);
        let poly =
            <DefaultSamplers as UniformSampler<TestModulus>>::sample_uniform(&mut rng).unwrap();

        // Check all coefficients are in valid range
        assert_coeffs_reduced(&poly);
    }

    #[test]
    fn test_cbd_sampling() {
        let mut rng = StdRng::seed_from_u64(42);

        // Cover the whole accepted range of eta, not only the small values.
        for eta in 1..=16 {
            let poly =
                <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta).unwrap();

            // Check all coefficients are in valid range
            assert_coeffs_reduced(&poly);
        }
    }

    #[test]
    fn test_cbd_rejects_invalid_eta() {
        let mut rng = StdRng::seed_from_u64(42);

        // eta must lie in [1, 16]; both neighbours of that range are rejected.
        for eta in [0u8, 17, u8::MAX] {
            let result = <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta);
            assert!(matches!(result, Err(Error::Parameter { .. })));
        }
    }

    #[test]
    fn test_cbd_distribution() {
        // Simple statistical test for CBD
        let mut rng = StdRng::seed_from_u64(42);
        let eta: u8 = 2;
        let num_samples = 10000;
        let mut histogram = vec![0u32; (2 * eta + 1) as usize];

        for _ in 0..num_samples {
            let poly =
                <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta).unwrap();

            // Check first coefficient distribution: shift the centered value
            // from [-eta, eta] (stored mod Q) to a histogram index in [0, 2*eta].
            let coeff = poly.coeffs[0];
            let centered = (coeff as i32 + eta as i32) % TestModulus::Q as i32;

            // A value outside [-eta, eta] is a sampler bug; fail loudly instead
            // of silently dropping it from the histogram.
            assert!(
                (0..=2 * eta as i32).contains(&centered),
                "CBD sample out of range: coefficient {}",
                coeff
            );
            histogram[centered as usize] += 1;
        }

        // CBD(2) should have distribution:
        // P(X = -2) = 1/16, P(X = -1) = 4/16, P(X = 0) = 6/16,
        // P(X = 1) = 4/16, P(X = 2) = 1/16
        let expected: [u32; 5] = [625, 2500, 3750, 2500, 625]; // Out of 10000
        assert_eq!(histogram.len(), expected.len());

        // Pearson chi-squared statistic over the five buckets.
        let chi_squared: f64 = histogram
            .iter()
            .zip(expected.iter())
            .map(|(&observed, &expected_val)| {
                let (observed, expected_val) = (f64::from(observed), f64::from(expected_val));
                (observed - expected_val).powi(2) / expected_val
            })
            .sum();

        // Degrees of freedom = 4. The critical value at 0.05 significance is
        // about 9.488; the looser bound of 15.0 (roughly the 0.005 level,
        // 14.86) is used so the check only trips on a grossly wrong distribution.
        assert!(
            chi_squared < 15.0,
            "Chi-squared test failed: {}",
            chi_squared
        );
    }

    #[test]
    fn test_gaussian_not_implemented() {
        let mut rng = StdRng::seed_from_u64(42);
        let result =
            <DefaultSamplers as GaussianSampler<TestModulus>>::sample_gaussian(&mut rng, 1.0);
        assert!(matches!(result, Err(Error::NotImplemented { .. })));
    }
}