dcrypt_algorithms/poly/
sampling.rs1#![cfg_attr(not(feature = "std"), no_std)]
4
5use super::params::Modulus;
6use super::polynomial::Polynomial;
7use crate::error::{Error, Result};
8use rand::{CryptoRng, RngCore};
9
10pub trait UniformSampler<M: Modulus> {
12 fn sample_uniform<R: RngCore + CryptoRng>(rng: &mut R) -> Result<Polynomial<M>>;
14}
15
16pub trait CbdSampler<M: Modulus> {
18 fn sample_cbd<R: RngCore + CryptoRng>(rng: &mut R, eta: u8) -> Result<Polynomial<M>>;
20}
21
22pub trait GaussianSampler<M: Modulus> {
24 fn sample_gaussian<R: RngCore + CryptoRng>(rng: &mut R, sigma: f64) -> Result<Polynomial<M>>;
26}
27
/// Stateless provider of the default samplers in this module: uniform
/// (rejection sampling), centered binomial, and a not-yet-implemented Gaussian.
///
/// Being a zero-sized unit struct, it is freely `Copy`/`Default`; the common
/// traits are derived so it can be debug-printed and stored in generic contexts.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultSamplers;
30
31impl<M: Modulus> UniformSampler<M> for DefaultSamplers {
32 fn sample_uniform<R: RngCore + CryptoRng>(rng: &mut R) -> Result<Polynomial<M>> {
33 let mut poly = Polynomial::<M>::zero();
34 let q = M::Q;
35
36 if q <= (1 << 16) {
38 sample_uniform_small::<M, R>(rng, &mut poly)?;
40 } else if q <= (1 << 24) {
41 sample_uniform_medium::<M, R>(rng, &mut poly)?;
43 } else {
44 sample_uniform_large::<M, R>(rng, &mut poly)?;
46 }
47
48 Ok(poly)
49 }
50}
51
52fn sample_uniform_small<M: Modulus, R: RngCore + CryptoRng>(
54 rng: &mut R,
55 poly: &mut Polynomial<M>,
56) -> Result<()> {
57 let q = M::Q;
58 let n = M::N;
59
60 let threshold = ((1u32 << 16) / q) * q;
62
63 for i in 0..n {
64 loop {
65 let mut bytes = [0u8; 2];
66 rng.fill_bytes(&mut bytes);
67 let sample = u16::from_le_bytes(bytes) as u32;
68
69 if sample < threshold {
71 poly.coeffs[i] = sample % q;
72 break;
73 }
74 }
75 }
76
77 Ok(())
78}
79
80fn sample_uniform_medium<M: Modulus, R: RngCore + CryptoRng>(
82 rng: &mut R,
83 poly: &mut Polynomial<M>,
84) -> Result<()> {
85 let q = M::Q;
86 let n = M::N;
87
88 let threshold = ((1u32 << 24) / q) * q;
90
91 for i in 0..n {
92 loop {
93 let mut bytes = [0u8; 3];
94 rng.fill_bytes(&mut bytes);
95 let sample = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], 0]);
96
97 if sample < threshold {
98 poly.coeffs[i] = sample % q;
99 break;
100 }
101 }
102 }
103
104 Ok(())
105}
106
107fn sample_uniform_large<M: Modulus, R: RngCore + CryptoRng>(
109 rng: &mut R,
110 poly: &mut Polynomial<M>,
111) -> Result<()> {
112 let q = M::Q;
113 let n = M::N;
114
115 let threshold = ((1u32 << 31) / q) * q;
117
118 for i in 0..n {
119 loop {
120 let mut bytes = [0u8; 4];
121 rng.fill_bytes(&mut bytes);
122 bytes[3] &= 0x7F; let sample = u32::from_le_bytes(bytes);
124
125 if sample < threshold {
126 poly.coeffs[i] = sample % q;
127 break;
128 }
129 }
130 }
131
132 Ok(())
133}
134
135impl<M: Modulus> CbdSampler<M> for DefaultSamplers {
136 fn sample_cbd<R: RngCore + CryptoRng>(rng: &mut R, eta: u8) -> Result<Polynomial<M>> {
137 if eta == 0 || eta > 16 {
138 return Err(Error::Parameter {
139 name: "CBD sampling".into(),
140 reason: format!("eta must be in range [1, 16], got {}", eta).into(),
141 });
142 }
143
144 let mut poly = Polynomial::<M>::zero();
145 let n = M::N;
146 let q = M::Q;
147
148 let bytes_per_sample = (2 * eta as usize).div_ceil(8); let mut buffer = [0u8; 4]; for i in 0..n {
153 rng.fill_bytes(&mut buffer[..bytes_per_sample]);
154
155 let mut a = 0i32;
156 let mut b = 0i32;
157
158 for j in 0..eta {
160 let byte_idx = j as usize / 8;
161 let bit_idx = j as usize % 8;
162 a += ((buffer[byte_idx] >> bit_idx) & 1) as i32;
163 }
164
165 for j in 0..eta {
167 let bit_pos = (eta + j) as usize;
168 let byte_idx = bit_pos / 8;
169 let bit_idx = bit_pos % 8;
170 b += ((buffer[byte_idx] >> bit_idx) & 1) as i32;
171 }
172
173 let sample = a - b;
175
176 poly.coeffs[i] = ((sample + q as i32) % q as i32) as u32;
178 }
179
180 Ok(poly)
181 }
182}
183
184impl<M: Modulus> GaussianSampler<M> for DefaultSamplers {
185 fn sample_gaussian<R: RngCore + CryptoRng>(_rng: &mut R, _sigma: f64) -> Result<Polynomial<M>> {
186 Err(Error::NotImplemented {
188 feature: "Gaussian sampler (reserved for Falcon phase)",
189 })
190 }
191}
192
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Kyber-sized modulus used by all tests (Q = 3329, N = 256).
    #[derive(Clone)]
    struct TestModulus;
    impl Modulus for TestModulus {
        const Q: u32 = 3329;
        const N: usize = 256;
    }

    /// Maps a residue in `[0, Q)` back to its centered representative in `(-Q/2, Q/2]`.
    fn centered(coeff: u32) -> i64 {
        let q = i64::from(TestModulus::Q);
        let c = i64::from(coeff);
        if c > q / 2 {
            c - q
        } else {
            c
        }
    }

    #[test]
    fn test_uniform_sampling() {
        let mut rng = StdRng::seed_from_u64(42);
        let poly =
            <DefaultSamplers as UniformSampler<TestModulus>>::sample_uniform(&mut rng).unwrap();

        // Every coefficient must be a canonical residue modulo Q.
        assert!(poly
            .as_coeffs_slice()
            .iter()
            .all(|&coeff| coeff < TestModulus::Q));
    }

    #[test]
    fn test_cbd_sampling() {
        let mut rng = StdRng::seed_from_u64(42);

        // Cover the whole accepted eta range, including multi-byte draws (eta > 4).
        for eta in 1..=16u8 {
            let poly =
                <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta).unwrap();

            for &coeff in poly.as_coeffs_slice() {
                assert!(coeff < TestModulus::Q);
                // A CBD(eta) sample always lies in [-eta, eta].
                assert!(
                    centered(coeff).abs() <= i64::from(eta),
                    "coefficient {coeff} out of range for eta = {eta}"
                );
            }
        }
    }

    #[test]
    fn test_cbd_rejects_invalid_eta() {
        let mut rng = StdRng::seed_from_u64(42);

        for eta in [0u8, 17, u8::MAX] {
            let result = <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta);
            assert!(
                matches!(result, Err(Error::Parameter { .. })),
                "eta = {eta} should be rejected"
            );
        }
    }

    #[test]
    fn test_cbd_distribution() {
        let mut rng = StdRng::seed_from_u64(42);
        let eta = 2;
        let num_samples = 10000;
        let mut histogram = vec![0u32; (2 * eta + 1) as usize];

        for _ in 0..num_samples {
            let poly =
                <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta).unwrap();

            // Shift the first coefficient from [-eta, eta] (mod Q) into [0, 2*eta].
            let coeff = poly.coeffs[0];
            let centered = (coeff as i32 + eta as i32) % TestModulus::Q as i32;
            if centered <= 2 * eta as i32 {
                histogram[centered as usize] += 1;
            }
        }

        // CBD(2) probabilities are 1/16, 4/16, 6/16, 4/16, 1/16, scaled to 10000 samples.
        let expected: [u32; 5] = [625, 2500, 3750, 2500, 625];
        let chi_squared: f64 = histogram
            .iter()
            .zip(expected.iter())
            .map(|(&observed, &expected_count)| {
                let observed = f64::from(observed);
                let expected_count = f64::from(expected_count);
                (observed - expected_count).powi(2) / expected_count
            })
            .sum();

        // 4 degrees of freedom; 15.0 corresponds to p ~= 0.005.
        assert!(
            chi_squared < 15.0,
            "Chi-squared test failed: {}",
            chi_squared
        );
    }

    #[test]
    fn test_gaussian_not_implemented() {
        let mut rng = StdRng::seed_from_u64(42);
        let result =
            <DefaultSamplers as GaussianSampler<TestModulus>>::sample_gaussian(&mut rng, 1.0);
        assert!(matches!(result, Err(Error::NotImplemented { .. })));
    }
}