#![cfg_attr(not(feature = "std"), no_std)]
use super::params::Modulus;
use super::polynomial::Polynomial;
use crate::error::{Error, Result};
use rand::{CryptoRng, RngCore};
/// Source of polynomials whose coefficients are expected to be uniformly
/// distributed modulo `M::Q`.
pub trait UniformSampler<M: Modulus> {
    /// Draws one polynomial, taking all randomness from `rng`.
    fn sample_uniform<R>(rng: &mut R) -> Result<Polynomial<M>>
    where
        R: RngCore + CryptoRng;
}
/// Source of polynomials whose coefficients follow a centered binomial
/// distribution (CBD) parameterised by `eta`.
pub trait CbdSampler<M: Modulus> {
    /// Draws one polynomial with CBD parameter `eta`, taking all randomness
    /// from `rng`. Implementations may reject unsupported `eta` values with an
    /// error.
    fn sample_cbd<R>(rng: &mut R, eta: u8) -> Result<Polynomial<M>>
    where
        R: RngCore + CryptoRng;
}
/// Source of polynomials with (discrete) Gaussian-distributed coefficients.
pub trait GaussianSampler<M: Modulus> {
    /// Draws one polynomial using randomness from `rng`; `sigma` is presumably
    /// the distribution's standard deviation (TODO confirm with implementers).
    fn sample_gaussian<R>(rng: &mut R, sigma: f64) -> Result<Polynomial<M>>
    where
        R: RngCore + CryptoRng;
}
/// Default implementation of the sampler traits in this module.
///
/// The type carries no state: every sampling call takes its randomness from
/// the `rng` argument. Common traits are derived so the marker can be
/// debug-printed, copied, and default-constructed by callers.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultSamplers;
impl<M: Modulus> UniformSampler<M> for DefaultSamplers {
    /// Samples a polynomial with coefficients uniform modulo `M::Q`.
    ///
    /// The width of each rejection-sampling candidate is picked from the size
    /// of the modulus: 2 bytes when `q <= 2^16`, 3 bytes when `q <= 2^24`, and
    /// 4 bytes otherwise.
    fn sample_uniform<R: RngCore + CryptoRng>(rng: &mut R) -> Result<Polynomial<M>> {
        let mut out = Polynomial::<M>::zero();
        match M::Q {
            modulus if modulus <= (1 << 16) => sample_uniform_small::<M, R>(rng, &mut out)?,
            modulus if modulus <= (1 << 24) => sample_uniform_medium::<M, R>(rng, &mut out)?,
            _ => sample_uniform_large::<M, R>(rng, &mut out)?,
        }
        Ok(out)
    }
}
/// Rejection-samples `M::N` coefficients uniformly from `[0, M::Q)` using
/// 2-byte little-endian candidates.
///
/// Candidates at or above the largest multiple of `q` that fits in 16 bits are
/// discarded, so reducing an accepted candidate modulo `q` is exactly uniform.
/// Only meant for `q <= 2^16`: for larger `q` that multiple is 0, every
/// candidate is rejected and the loop would never terminate.
///
/// # Panics
/// Panics if `M::Q` is 0 (division by zero).
fn sample_uniform_small<M: Modulus, R: RngCore + CryptoRng>(
    rng: &mut R,
    poly: &mut Polynomial<M>,
) -> Result<()> {
    let modulus = M::Q;
    let accept_below = modulus * ((1u32 << 16) / modulus);
    for idx in 0..M::N {
        let value = loop {
            let mut raw = [0u8; 2];
            rng.fill_bytes(&mut raw);
            let candidate = u32::from(u16::from_le_bytes(raw));
            if candidate < accept_below {
                break candidate % modulus;
            }
        };
        poly.coeffs[idx] = value;
    }
    Ok(())
}
/// Rejection-samples `M::N` coefficients uniformly from `[0, M::Q)` using
/// 3-byte little-endian candidates.
///
/// Candidates at or above the largest multiple of `q` below `2^24` are
/// discarded so that the final `% q` is unbiased. Only meant for
/// `q <= 2^24`: for larger `q` that multiple is 0 and the loop would never
/// terminate.
///
/// # Panics
/// Panics if `M::Q` is 0 (division by zero).
fn sample_uniform_medium<M: Modulus, R: RngCore + CryptoRng>(
    rng: &mut R,
    poly: &mut Polynomial<M>,
) -> Result<()> {
    let modulus = M::Q;
    let accept_below = modulus * ((1u32 << 24) / modulus);
    for idx in 0..M::N {
        let value = loop {
            // Only the low three bytes receive randomness; the top byte stays 0.
            let mut raw = [0u8; 4];
            rng.fill_bytes(&mut raw[..3]);
            let candidate = u32::from_le_bytes(raw);
            if candidate < accept_below {
                break candidate % modulus;
            }
        };
        poly.coeffs[idx] = value;
    }
    Ok(())
}
/// Rejection-samples `M::N` coefficients uniformly from `[0, M::Q)` using
/// 4-byte little-endian candidates.
///
/// When `q <= 2^31`, each candidate keeps only its low 31 bits (the top bit is
/// cleared). This is the original behaviour, so seeded outputs for such moduli
/// are unchanged. When `q > 2^31`, a 31-bit candidate can never cover
/// `[0, q)`, and the acceptance threshold `(2^31 / q) * q` would be 0. That
/// made the loop spin forever. For those moduli the full 32 bits are used, and
/// the threshold is computed in `u64` so that `2^32` is representable.
///
/// # Panics
/// Panics if `M::Q` is 0 (division by zero).
fn sample_uniform_large<M: Modulus, R: RngCore + CryptoRng>(
    rng: &mut R,
    poly: &mut Polynomial<M>,
) -> Result<()> {
    let q = M::Q;
    let n = M::N;
    // Mask applied to each raw candidate, and the number of distinct
    // candidate values it leaves (2^31 or 2^32).
    let (mask, range): (u32, u64) = if q <= (1 << 31) {
        (0x7FFF_FFFF, 1 << 31)
    } else {
        (u32::MAX, 1 << 32)
    };
    // Largest multiple of q not exceeding `range`; rejecting candidates at or
    // above it makes `candidate % q` exactly uniform.
    let threshold = (range / u64::from(q)) * u64::from(q);
    for i in 0..n {
        loop {
            let mut bytes = [0u8; 4];
            rng.fill_bytes(&mut bytes);
            let sample = u32::from_le_bytes(bytes) & mask;
            if u64::from(sample) < threshold {
                poly.coeffs[i] = sample % q;
                break;
            }
        }
    }
    Ok(())
}
impl<M: Modulus> CbdSampler<M> for DefaultSamplers {
    /// Samples a polynomial whose coefficients follow a centered binomial
    /// distribution with parameter `eta`, reduced into `[0, M::Q)`.
    ///
    /// For every coefficient, `ceil(2 * eta / 8)` fresh bytes are read from
    /// `rng` and treated as a little-endian bit string. The value is the
    /// popcount of bits `0..eta` minus the popcount of bits `eta..2*eta`, so
    /// it lies in `[-eta, eta]` before reduction.
    ///
    /// # Errors
    /// Returns [`Error::Parameter`] if `eta` is 0 or greater than 16. The upper
    /// limit keeps the per-coefficient randomness within 32 bits (4 bytes).
    ///
    /// # Panics
    /// Panics if `M::Q` is 0.
    fn sample_cbd<R: RngCore + CryptoRng>(rng: &mut R, eta: u8) -> Result<Polynomial<M>> {
        if eta == 0 || eta > 16 {
            return Err(Error::Parameter {
                name: "CBD sampling".into(),
                reason: format!("eta must be in range [1, 16], got {}", eta).into(),
            });
        }
        let mut poly = Polynomial::<M>::zero();
        let n = M::N;
        // Reduce in i64 with `rem_euclid`. The previous i32 expression
        // `(sample + q as i32) % q as i32` had three failure modes:
        // - it overflowed for q near i32::MAX;
        // - it went wrong for q > i32::MAX, where the cast wraps negative;
        // - it produced a negative remainder (a huge u32) when q < |sample|.
        let q = i64::from(M::Q);
        let eta = u32::from(eta);
        let bytes_per_sample = (2 * eta as usize).div_ceil(8);
        // Selects `eta` consecutive bits; eta <= 16 so the shift cannot overflow.
        let half_mask = (1u32 << eta) - 1;
        for i in 0..n {
            // Bytes past `bytes_per_sample` stay zero and are masked out anyway.
            let mut buffer = [0u8; 4];
            rng.fill_bytes(&mut buffer[..bytes_per_sample]);
            let bits = u32::from_le_bytes(buffer);
            let a = i64::from((bits & half_mask).count_ones());
            let b = i64::from(((bits >> eta) & half_mask).count_ones());
            poly.coeffs[i] = (a - b).rem_euclid(q) as u32;
        }
        Ok(poly)
    }
}
impl<M: Modulus> GaussianSampler<M> for DefaultSamplers {
fn sample_gaussian<R: RngCore + CryptoRng>(_rng: &mut R, _sigma: f64) -> Result<Polynomial<M>> {
Err(Error::NotImplemented {
feature: "Gaussian sampler (reserved for Falcon phase)",
})
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    /// Small modulus (q = 3329 <= 2^16): exercises the 2-byte uniform path.
    #[derive(Clone)]
    struct TestModulus;
    impl Modulus for TestModulus {
        const Q: u32 = 3329;
        const N: usize = 256;
    }

    /// Modulus in (2^16, 2^24]: exercises the 3-byte uniform path.
    #[derive(Clone)]
    struct MediumModulus;
    impl Modulus for MediumModulus {
        const Q: u32 = 8_380_417;
        const N: usize = 256;
    }

    /// Modulus above 2^24 (and <= 2^31): exercises the 4-byte uniform path.
    #[derive(Clone)]
    struct LargeModulus;
    impl Modulus for LargeModulus {
        const Q: u32 = 2_147_483_647;
        const N: usize = 256;
    }

    #[test]
    fn test_uniform_sampling() {
        let mut rng = StdRng::seed_from_u64(42);
        let poly =
            <DefaultSamplers as UniformSampler<TestModulus>>::sample_uniform(&mut rng).unwrap();
        for &coeff in poly.as_coeffs_slice() {
            assert!(coeff < TestModulus::Q);
        }
    }

    #[test]
    fn test_uniform_sampling_medium_modulus() {
        let mut rng = StdRng::seed_from_u64(7);
        let poly =
            <DefaultSamplers as UniformSampler<MediumModulus>>::sample_uniform(&mut rng).unwrap();
        for &coeff in poly.as_coeffs_slice() {
            assert!(coeff < MediumModulus::Q);
        }
    }

    #[test]
    fn test_uniform_sampling_large_modulus() {
        let mut rng = StdRng::seed_from_u64(7);
        let poly =
            <DefaultSamplers as UniformSampler<LargeModulus>>::sample_uniform(&mut rng).unwrap();
        for &coeff in poly.as_coeffs_slice() {
            assert!(coeff < LargeModulus::Q);
        }
    }

    #[test]
    fn test_cbd_sampling() {
        let mut rng = StdRng::seed_from_u64(42);
        for eta in 1..=8 {
            let poly =
                <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta).unwrap();
            for &coeff in poly.as_coeffs_slice() {
                assert!(coeff < TestModulus::Q);
            }
        }
    }

    #[test]
    fn test_cbd_coefficients_bounded_by_eta() {
        let mut rng = StdRng::seed_from_u64(7);
        let q = TestModulus::Q;
        for eta in 1..=16u8 {
            let poly =
                <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta).unwrap();
            let bound = u32::from(eta);
            for &coeff in poly.as_coeffs_slice() {
                // A centered value v in [-eta, eta] is stored as v mod q.
                assert!(
                    coeff <= bound || coeff >= q - bound,
                    "eta={}, coeff={}",
                    eta,
                    coeff
                );
            }
        }
    }

    #[test]
    fn test_cbd_rejects_out_of_range_eta() {
        let mut rng = StdRng::seed_from_u64(42);
        for eta in [0u8, 17, u8::MAX] {
            let result = <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta);
            assert!(matches!(result, Err(Error::Parameter { .. })));
        }
    }

    #[test]
    fn test_cbd_distribution() {
        let mut rng = StdRng::seed_from_u64(42);
        let eta = 2;
        let num_samples = 10000;
        let mut histogram = vec![0u32; (2 * eta + 1) as usize];
        for _ in 0..num_samples {
            let poly =
                <DefaultSamplers as CbdSampler<TestModulus>>::sample_cbd(&mut rng, eta).unwrap();
            let coeff = poly.coeffs[0];
            // Map q-eta..q-1 and 0..eta onto 0..2*eta.
            let centered = (coeff as i32 + eta as i32) % TestModulus::Q as i32;
            if centered <= 2 * eta as i32 {
                histogram[centered as usize] += 1;
            }
        }
        // a - b + eta ~ Binomial(2*eta, 1/2): counts 1:4:6:4:1 out of 16.
        let expected = [625, 2500, 3750, 2500, 625];
        let mut chi_squared = 0.0;
        for i in 0..histogram.len() {
            let observed = histogram[i] as f64;
            let expected_val = expected[i] as f64;
            chi_squared += (observed - expected_val).powi(2) / expected_val;
        }
        assert!(
            chi_squared < 15.0,
            "Chi-squared test failed: {}",
            chi_squared
        );
    }

    #[test]
    fn test_gaussian_not_implemented() {
        let mut rng = StdRng::seed_from_u64(42);
        let result =
            <DefaultSamplers as GaussianSampler<TestModulus>>::sample_gaussian(&mut rng, 1.0);
        assert!(matches!(result, Err(Error::NotImplemented { .. })));
    }
}