use super::uniform::UniformIntegerSampler;
use crate::{
error::{MathError, StringConversionError},
integer::{MatZ, Z},
rational::{MatQ, Q},
traits::{MatrixDimensions, MatrixGetSubmatrix, Pow},
};
use rand::RngExt;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// Controls whether a [`DiscreteGaussianIntegerSampler`] stores evaluations of
/// [`gaussian_function`] in its `table`.
///
/// NOTE: the variant order is part of the serialized representation for
/// index-based serde formats, so do not reorder the variants.
#[derive(PartialEq, Clone, Copy, Serialize, Deserialize, Debug)]
pub enum LookupTableSetting {
    /// Evaluate the Gaussian function for every integer of the interval already
    /// during `init` and only read from the table while sampling.
    Precompute,
    /// Start with an empty table; every newly evaluated candidate is stored
    /// during sampling and reused afterwards.
    FillOnTheFly,
    /// Never store evaluations; the Gaussian function is recomputed for every candidate.
    NoLookup,
}
/// Tailcut used by `sample_d_precomputed_gso` when sampling each integer coefficient:
/// the coefficient is restricted to `[c - s * TAILCUT, c + s * TAILCUT]`.
///
/// NOTE(review): this is a `static mut`, so every access needs `unsafe`, and a write
/// that happens while another thread reads it is a data race. It is kept as-is because
/// code elsewhere may read or write it directly; consider a `const` or an atomic in a
/// future breaking release.
pub static mut TAILCUT: f64 = 6.0;
/// Rejection sampler for the discrete Gaussian distribution over the integers,
/// restricted to a tail-cut interval around `center`.
///
/// Construct it with [`DiscreteGaussianIntegerSampler::init`] and draw samples with
/// [`DiscreteGaussianIntegerSampler::sample_z`].
///
/// NOTE: the field order is part of the serialized representation for
/// order-based serde formats, so do not reorder the fields.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DiscreteGaussianIntegerSampler {
    /// Center of the distribution.
    pub center: Q,
    /// Gaussian parameter; `init` replaces a value of `0` by `0.00001`.
    pub s: Q,
    /// Smallest integer that can be sampled, i.e. `ceil(center - s * tailcut)`.
    pub lower_bound: Z,
    /// Number of integers in the sampling interval,
    /// i.e. `floor(center + s * tailcut) - lower_bound + 1`.
    pub interval_size: Z,
    /// Determines whether and when `table` is filled.
    pub lookup_table_setting: LookupTableSetting,
    /// Cached values of `gaussian_function(x, center, s)` keyed by the integer `x`.
    pub table: HashMap<Z, f64>,
}
impl DiscreteGaussianIntegerSampler {
    /// Initializes a sampler for the discrete Gaussian distribution over the integers
    /// with center `center` and Gaussian parameter `s`. The support is restricted to
    /// the integers in `[center - s * tailcut, center + s * tailcut]`.
    ///
    /// Parameters:
    /// - `center`: the center of the distribution
    /// - `s`: the Gaussian parameter; a value of `0` is replaced by `0.00001`
    /// - `tailcut`: the multiple of `s` on each side of `center` that bounds the support
    /// - `lookup_table_setting`: whether evaluations of the Gaussian function are
    ///   precomputed for the whole interval, cached while sampling, or never cached
    ///
    /// Returns the sampler or an error if the parameters do not describe a
    /// samplable distribution.
    ///
    /// Prints a warning to stdout if a table is used and the interval contains
    /// more than `u16::MAX` integers.
    ///
    /// # Errors and Failures
    /// - Returns a [`MathError`] of type [`MathError::InvalidIntegerInput`]
    ///   if `tailcut < 0`, if `s < 0`, or if no integer lies in the tail-cut interval.
    pub fn init(
        center: impl Into<Q>,
        s: impl Into<Q>,
        tailcut: impl Into<Q>,
        lookup_table_setting: LookupTableSetting,
    ) -> Result<Self, MathError> {
        let center = center.into();
        let mut s = s.into();
        let tailcut = tailcut.into();

        if tailcut < Q::ZERO {
            return Err(MathError::InvalidIntegerInput(format!(
                "The value {tailcut} was provided for parameter tailcut of the function \
                 DiscreteGaussianIntegerSampler::init. \
                 This function expects this input to be no smaller than 0."
            )));
        }
        if s < Q::ZERO {
            return Err(MathError::InvalidIntegerInput(format!(
                "The value {s} was provided for parameter s of the function \
                 DiscreteGaussianIntegerSampler::init. \
                 This function expects this input to be no smaller than 0."
            )));
        }
        // `gaussian_function` divides by `s^2`, so an exact 0 is replaced by a tiny value.
        if s == Q::ZERO {
            s = Q::from(0.00001);
        }

        let lower_bound = (&center - &s * &tailcut).ceil();
        let upper_bound = (&center + &s * &tailcut).floor();
        let interval_size = &upper_bound - &lower_bound + Z::ONE;

        // With a small `s * tailcut` and a non-integer center the interval can be empty.
        // `sample_z` would then have no valid candidate, so reject it here.
        if interval_size < Z::ONE {
            return Err(MathError::InvalidIntegerInput(format!(
                "No integer lies between center - s * tailcut and center + s * tailcut \
                 for center {center}, s {s} and tailcut {tailcut}, so \
                 DiscreteGaussianIntegerSampler::init has nothing to sample from. \
                 Increase s or tailcut."
            )));
        }

        if lookup_table_setting != LookupTableSetting::NoLookup && interval_size > u16::MAX {
            println!(
                "WARNING: A completely filled lookup table will exceed 2^16 entries. You should reconsider your sampling method for discrete Gaussians."
            )
        }

        let mut table = HashMap::new();
        if lookup_table_setting == LookupTableSetting::Precompute {
            let mut i = lower_bound.clone();
            while i <= upper_bound {
                let evaluated_gauss_function = gaussian_function(&i, &center, &s);
                table.insert(i.clone(), evaluated_gauss_function);
                i += Z::ONE;
            }
        }

        Ok(Self {
            center,
            s,
            lower_bound,
            interval_size,
            lookup_table_setting,
            table,
        })
    }

    /// Draws one sample from the tail-cut discrete Gaussian distribution by rejection
    /// sampling. A candidate is chosen uniformly from
    /// `[lower_bound, lower_bound + interval_size - 1]`. It is accepted if
    /// `gaussian_function(candidate, center, s)` is at least a uniformly random `f64`.
    ///
    /// With [`LookupTableSetting::FillOnTheFly`], every newly evaluated candidate is
    /// cached in `table`, which is why `self` is borrowed mutably.
    ///
    /// # Panics ...
    /// - if a [`UniformIntegerSampler`] can not be initialized with `interval_size`.
    /// - if `lookup_table_setting` is [`LookupTableSetting::Precompute`] and `table`
    ///   lacks a candidate. `init` never produces this state; it can only arise from
    ///   modifying the public fields afterwards.
    pub fn sample_z(&mut self) -> Z {
        let mut rng = rand::rng();
        let mut uniform_sampler = UniformIntegerSampler::init(&self.interval_size).unwrap();

        loop {
            let candidate = &self.lower_bound + uniform_sampler.sample();

            let acceptance_probability = match self.lookup_table_setting {
                LookupTableSetting::NoLookup => {
                    gaussian_function(&candidate, &self.center, &self.s)
                }
                LookupTableSetting::FillOnTheFly => {
                    // Borrow the disjoint fields separately so `table` can be borrowed mutably.
                    let (center, s) = (&self.center, &self.s);
                    *self
                        .table
                        .entry(candidate.clone())
                        .or_insert_with(|| gaussian_function(&candidate, center, s))
                }
                LookupTableSetting::Precompute => *self.table.get(&candidate).expect(
                    "a precomputed lookup table contains every integer of the sampling interval",
                ),
            };

            let threshold: f64 = rng.random();
            if acceptance_probability >= threshold {
                return candidate;
            }
        }
    }
}
/// Evaluates the unnormalized Gaussian function
/// `rho_{s, c}(x) = exp(-pi * (x - c)^2 / s^2)` at the integer `x`.
///
/// The exponent is computed in [`Q`] arithmetic, using [`Q::PI`] for pi. It is
/// converted to [`f64`] only right before exponentiating, so the result is subject
/// to `f64` rounding and underflows to `0.0` for points far from `c`.
///
/// Parameters:
/// - `x`: the integer at which the function is evaluated
/// - `c`: the center of the Gaussian function
/// - `s`: the Gaussian parameter
///
/// Returns `rho_{s, c}(x)` as an [`f64`] in `[0, 1]`.
///
/// # Panics ...
/// - if `s` is `0`.
pub fn gaussian_function(x: &Z, c: &Q, s: &Q) -> f64 {
    let squared_distance = (x - c).pow(2).unwrap();
    let squared_parameter = s.pow(2).unwrap();
    let exponent = Q::MINUS_ONE * Q::PI * squared_distance / squared_parameter;
    f64::from(&exponent).exp()
}
/// Samples a vector from the discrete Gaussian distribution over the lattice spanned
/// by the columns of `basis`, centered at `center` with Gaussian parameter `s`.
///
/// The Gram-Schmidt orthogonalization of `basis` is recomputed on every call. Callers
/// that sample repeatedly from the same basis should compute it once and call
/// [`sample_d_precomputed_gso`] directly.
///
/// # Errors and Failures
/// - Returns every error produced by [`sample_d_precomputed_gso`].
pub(crate) fn sample_d(basis: &MatZ, center: &MatQ, s: &Q) -> Result<MatZ, MathError> {
    sample_d_precomputed_gso(basis, &MatQ::from(basis).gso(), center, s)
}
/// Samples a vector from the discrete Gaussian distribution over the lattice spanned
/// by the columns of `basis`, centered at `center` with Gaussian parameter `s`, using
/// the precomputed Gram-Schmidt orthogonalization `basis_gso` of `basis`.
///
/// The columns are processed from last to first, following the structure of SampleD
/// by Gentry, Peikert and Vaikuntanathan. For column `i`:
/// - the remaining center is projected onto the GSO vector `b*_i`, giving
///   `c_i = <center, b*_i> / <b*_i, b*_i>`;
/// - an integer `z` is sampled from the discrete Gaussian with center `c_i` and
///   parameter `s / ||b*_i||`, tail-cut at [`TAILCUT`];
/// - `z * b_i` is added to the output and subtracted from the remaining center.
///
/// Parameters:
/// - `basis`: the basis whose columns span the lattice
/// - `basis_gso`: the Gram-Schmidt orthogonalization of `basis`
/// - `center`: column vector with as many rows as `basis`
/// - `s`: the Gaussian parameter; must be no smaller than `0`
///
/// Returns a lattice vector as a column vector, or an error.
///
/// # Errors and Failures
/// - Returns a [`MathError`] of type [`MathError::MismatchingMatrixDimension`]
///   if `center` and `basis` have a different number of rows.
/// - Returns the [`MathError`] converted from [`StringConversionError::InvalidMatrix`]
///   if `center` is not a column vector.
/// - Returns a [`MathError`] of type [`MathError::InvalidIntegerInput`] if `s < 0`.
/// - Returns every error produced by [`DiscreteGaussianIntegerSampler::init`].
///
/// # Panics ...
/// - if `basis` and `basis_gso` differ in their number of rows or columns.
/// - presumably, if a column of `basis_gso` is zero (division by zero), e.g. for
///   linearly dependent bases.
pub(crate) fn sample_d_precomputed_gso(
    basis: &MatZ,
    basis_gso: &MatQ,
    center: &MatQ,
    s: &Q,
) -> Result<MatZ, MathError> {
    assert_eq!(
        basis.get_num_rows(),
        basis_gso.get_num_rows(),
        "The provided gso can not be based on the provided base, \
         as they do not have the same number of rows."
    );
    assert_eq!(
        basis.get_num_columns(),
        basis_gso.get_num_columns(),
        "The provided gso can not be based on the provided base, \
         as they do not have the same number of columns."
    );
    if center.get_num_rows() != basis.get_num_rows() {
        return Err(MathError::MismatchingMatrixDimension(format!(
            "sample_d requires center and basis to have the same number of rows, but they were {} and {}.",
            center.get_num_rows(),
            basis.get_num_rows()
        )));
    }
    if !center.is_column_vector() {
        Err(StringConversionError::InvalidMatrix(format!(
            "sample_d expects center to be a column vector, but it has dimensions {}x{}.",
            center.get_num_rows(),
            center.get_num_columns()
        )))?;
    }
    if s < &Q::ZERO {
        return Err(MathError::InvalidIntegerInput(format!(
            "The value {s} was provided for parameter s of the function sample_d. \
             This function expects this input to be no smaller than 0."
        )));
    }

    // SAFETY: reading a `static mut` is only sound if no other thread writes `TAILCUT`
    // concurrently; it is read once so that every coefficient uses the same tailcut.
    let tailcut = unsafe { TAILCUT };

    // The remaining center is updated in place after each sampled coefficient.
    let mut center = center.clone();
    let mut out = MatZ::new(basis_gso.get_num_rows(), 1);
    for i in (0..basis_gso.get_num_columns()).rev() {
        // SAFETY: `i < basis_gso.get_num_columns()` by the loop range.
        let basisvector_orth_i = unsafe { basis_gso.get_column_unchecked(i) };
        let c_2 = center.dot_product(&basisvector_orth_i).unwrap()
            / basisvector_orth_i.dot_product(&basisvector_orth_i).unwrap();
        let s_2 = s / (basisvector_orth_i.norm_eucl_sqrd().unwrap().sqrt());

        let mut dgis = DiscreteGaussianIntegerSampler::init(
            &c_2,
            &s_2,
            tailcut,
            LookupTableSetting::FillOnTheFly,
        )?;
        let z = dgis.sample_z();

        // SAFETY: `basis` has as many columns as `basis_gso` (asserted above),
        // so `i` is a valid column index.
        let basisvector_i = unsafe { basis.get_column_unchecked(i) };
        let shift = &z * &basisvector_i;
        center -= MatQ::from(&shift);
        out = &out + &shift;
    }
    Ok(out)
}
#[cfg(test)]
mod test_discrete_gaussian_integer_sampler {
    use super::DiscreteGaussianIntegerSampler;
    use crate::{
        rational::Q,
        utils::sample::discrete_gauss::{LookupTableSetting, TAILCUT},
    };

    /// Ensures that samples of a narrow distribution stay within the tail-cut interval
    /// `[ceil(15 - 0.5 * 8), floor(15 + 0.5 * 8)] = [11, 19]`, checked with some slack.
    #[test]
    fn small_interval() {
        let center = Q::from(15);
        let gaussian_parameter = Q::from((1, 2));
        let mut dgis = DiscreteGaussianIntegerSampler::init(
            &center,
            &gaussian_parameter,
            8.0,
            LookupTableSetting::FillOnTheFly,
        )
        .unwrap();

        for _ in 0..64 {
            let sample = dgis.sample_z();
            assert!(10 <= sample);
            assert!(sample <= 20);
        }
    }

    /// Ensures that samples around center `-1` with the global [`TAILCUT`]
    /// stay within generous bounds.
    #[test]
    fn large_interval() {
        let center = Q::MINUS_ONE;
        let gaussian_parameter = Q::ONE;
        let mut dgis = DiscreteGaussianIntegerSampler::init(
            &center,
            &gaussian_parameter,
            unsafe { TAILCUT },
            LookupTableSetting::FillOnTheFly,
        )
        .unwrap();

        for _ in 0..256 {
            let sample = dgis.sample_z();
            assert!(-64 <= sample);
            assert!(sample <= 62);
        }
    }

    /// Ensures that negative Gaussian parameters are rejected.
    #[test]
    fn invalid_gaussian_parameter() {
        let center = Q::ZERO;

        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &Q::MINUS_ONE,
                6.0,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                Q::from(i64::MIN),
                6.0,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
    }

    /// Ensures that negative tailcuts are rejected.
    #[test]
    fn invalid_tailcut() {
        let center = Q::MINUS_ONE;
        let gaussian_parameter = Q::ONE;

        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &gaussian_parameter,
                -0.1,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &gaussian_parameter,
                i64::MIN,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
    }
}
#[cfg(test)]
mod test_gaussian_function {
    use super::{Q, Z, gaussian_function};
    use crate::traits::Distance;

    /// Ensures that `rho_{1, 0}(1) = exp(-pi)` is approximately `0.043214`.
    #[test]
    fn doc_test() {
        let sample = Z::ONE;
        let center = Q::ZERO;
        let gaussian_parameter = Q::ONE;
        let cmp = Q::from((43214, 1_000_000));

        let value = gaussian_function(&sample, &center, &gaussian_parameter);

        assert!(cmp.distance(&Q::from(value)) < Q::from((1, 1_000_000)));
    }

    /// Checks small inputs against precomputed values; evaluating at the center yields `1`.
    #[test]
    fn small_values() {
        let sample_0 = Z::ZERO;
        let sample_1 = Z::MINUS_ONE;
        let center = Q::MINUS_ONE;
        let gaussian_parameter_0 = Q::from((1, 2));
        let gaussian_parameter_1 = Q::from((3, 2));
        let cmp_0 = Q::from((349, 100_000_000));
        let cmp_1 = Q::from((24752, 100_000));

        let res_0 = gaussian_function(&sample_0, &center, &gaussian_parameter_0);
        let res_1 = gaussian_function(&sample_0, &center, &gaussian_parameter_1);
        let res_2 = gaussian_function(&sample_1, &center, &gaussian_parameter_0);
        let res_3 = gaussian_function(&sample_1, &center, &gaussian_parameter_1);

        assert!(cmp_0.distance(&Q::from(res_0)) < Q::from((3, 1_000_000_000)));
        assert!(cmp_1.distance(&Q::from(res_1)) < Q::from((1, 1_000_000)));
        assert_eq!(1.0, res_2);
        assert_eq!(1.0, res_3);
    }

    /// Ensures that large sample and center values at distance `1` are handled exactly.
    #[test]
    fn large_values() {
        let sample = Z::from(i64::MAX);
        let center = Q::from(i64::MAX as u64 + 1);
        let gaussian_parameter = Q::from((1, 2));
        let cmp = Q::from((349, 100_000_000));

        let res = gaussian_function(&sample, &center, &gaussian_parameter);

        assert!(cmp.distance(&Q::from(res)) < Q::from((3, 1_000_000_000)));
    }

    /// Ensures that a Gaussian parameter of `0` panics.
    #[test]
    #[should_panic]
    fn invalid_s() {
        let sample = Z::from(i64::MAX);
        let center = Q::from(i64::MAX as u64 + 1);
        let gaussian_parameter = Q::ZERO;

        let _ = gaussian_function(&sample, &center, &gaussian_parameter);
    }
}
#[cfg(test)]
mod test_sample_d {
    use super::sample_d_precomputed_gso;
    use crate::traits::{Concatenate, MatrixDimensions, MatrixGetSubmatrix, Pow};
    use crate::utils::sample::discrete_gauss::sample_d;
    use crate::{
        integer::{MatZ, Z},
        rational::{MatQ, Q},
    };
    use flint_sys::fmpz_mat::fmpz_mat_hnf;
    use std::str::FromStr;

    /// Ensures that both variants run with an identity basis and a zero center.
    #[test]
    fn doc_test() {
        let basis = MatZ::identity(5, 5);
        let center = MatQ::new(5, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let _ = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let _ = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter)
            .unwrap();
    }

    /// Ensures that both variants accept a non-zero center.
    #[test]
    fn non_zero_center() {
        let basis = MatZ::identity(5, 5);
        let center = MatQ::identity(5, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let _ = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let _ = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter)
            .unwrap();
    }

    /// Ensures that both variants accept a non-orthogonal basis.
    #[test]
    fn non_identity_basis() {
        let basis = MatZ::from_str("[[2, 1],[1, 2]]").unwrap();
        let center = MatQ::new(2, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let _ = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let _ = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter)
            .unwrap();
    }

    /// Ensures that samples are lattice points. Appending a lattice vector to the basis
    /// leaves the Hermite normal form unchanged apart from an additional zero column.
    #[test]
    fn point_of_lattice() {
        let basis = MatZ::from_str("[[7, 0],[7, 3]]").unwrap();
        let center = MatQ::new(2, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let sample = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let sample_prec =
            sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter).unwrap();

        let basis_concat_sample = basis.concat_horizontal(&sample).unwrap();
        let basis_concat_sample_prec = basis.concat_horizontal(&sample_prec).unwrap();

        let mut hnf_basis = MatZ::new(2, 2);
        // SAFETY: both matrices are initialized by `MatZ` and have the same 2x2 dimensions.
        unsafe { fmpz_mat_hnf(&mut hnf_basis.matrix, &basis.matrix) };

        let mut hnf_basis_concat_sample = MatZ::new(2, 3);
        let mut hnf_basis_concat_sample_prec = MatZ::new(2, 3);
        // SAFETY: both matrices are initialized by `MatZ` and have the same 2x3 dimensions.
        unsafe {
            fmpz_mat_hnf(
                &mut hnf_basis_concat_sample.matrix,
                &basis_concat_sample.matrix,
            )
        };
        // SAFETY: both matrices are initialized by `MatZ` and have the same 2x3 dimensions.
        unsafe {
            fmpz_mat_hnf(
                &mut hnf_basis_concat_sample_prec.matrix,
                &basis_concat_sample_prec.matrix,
            )
        };

        assert_eq!(
            hnf_basis.get_column(0).unwrap(),
            hnf_basis_concat_sample.get_column(0).unwrap()
        );
        assert_eq!(
            hnf_basis.get_column(0).unwrap(),
            hnf_basis_concat_sample_prec.get_column(0).unwrap()
        );
        assert_eq!(
            hnf_basis.get_column(1).unwrap(),
            hnf_basis_concat_sample.get_column(1).unwrap()
        );
        assert_eq!(
            hnf_basis.get_column(1).unwrap(),
            hnf_basis_concat_sample_prec.get_column(1).unwrap()
        );
        assert!(hnf_basis_concat_sample.get_column(2).unwrap().is_zero());
        assert!(
            hnf_basis_concat_sample_prec
                .get_column(2)
                .unwrap()
                .is_zero()
        );
    }

    /// Ensures that negative Gaussian parameters are rejected by both variants.
    #[test]
    fn invalid_gaussian_parameter() {
        let basis = MatZ::identity(5, 5);
        let center = MatQ::new(5, 1);
        let basis_gso = MatQ::from(&basis).gso();

        assert!(sample_d(&basis, &center, &Q::MINUS_ONE).is_err());
        assert!(sample_d(&basis, &center, &Q::from(i64::MIN)).is_err());
        assert!(sample_d_precomputed_gso(&basis, &basis_gso, &center, &Q::MINUS_ONE).is_err());
        assert!(
            sample_d_precomputed_gso(&basis, &basis_gso, &center, &Q::from(i64::MIN)).is_err()
        );
    }

    /// Ensures that a center whose row count differs from the basis is rejected.
    #[test]
    fn mismatching_matrix_dimensions() {
        let basis = MatZ::identity(3, 5);
        let center = MatQ::new(4, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let res = sample_d(&basis, &center, &gaussian_parameter);
        let res_prec = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter);

        assert!(res.is_err());
        assert!(res_prec.is_err());
    }

    /// Ensures that a center that is not a column vector is rejected.
    #[test]
    fn center_not_column_vector() {
        let basis = MatZ::identity(2, 2);
        let center = MatQ::new(2, 2);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let res = sample_d(&basis, &center, &gaussian_parameter);
        let res_prec = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter);

        assert!(res.is_err());
        assert!(res_prec.is_err());
    }

    /// Ensures that the squared norm of samples stays below `s^2 * n` for a Gaussian
    /// parameter `s` above the smoothing bound of a random basis. This fails with
    /// negligible probability.
    #[test]
    fn concentration_bound() {
        let n = Z::from(20);
        let basis = MatZ::sample_uniform(&n, &n, 0, 5000).unwrap();
        let orth = MatQ::from(&basis).gso();

        let mut len = Q::ZERO;
        for i in 0..orth.get_num_columns() {
            let column = orth.get_column(i).unwrap();
            let column_len = column.norm_eucl_sqrd().unwrap().sqrt();
            if column_len > len {
                len = column_len
            }
        }

        let expl_text = String::from("This test can fail with probability close to 0.
It fails if the length of the sampled is longer than expected.
If this happens, rerun the tests several times and check whether this issue comes up again.");

        let center = MatQ::new(&n, 1);
        let gaussian_parameter =
            len * n.log(2).unwrap().sqrt() * (n.log(2).unwrap().log(2).unwrap());

        for _ in 0..20 {
            let res = sample_d(&basis, &center, &gaussian_parameter).unwrap();
            let res_prec =
                sample_d_precomputed_gso(&basis, &orth, &center, &gaussian_parameter).unwrap();

            assert!(
                res.norm_eucl_sqrd().unwrap() <= gaussian_parameter.pow(2).unwrap().round() * &n,
                "{expl_text}"
            );
            assert!(
                res_prec.norm_eucl_sqrd().unwrap()
                    <= gaussian_parameter.pow(2).unwrap().round() * &n,
                "{expl_text}"
            );
        }
    }

    /// Ensures that a GSO with more rows than the basis panics.
    #[test]
    #[should_panic]
    fn precomputed_gso_mismatching_rows() {
        let n = Z::from(20);
        let basis = MatZ::sample_uniform(&n, &n, 0, 5000).unwrap();
        let center = MatQ::new(&n, 1);
        let false_gso = MatQ::new(basis.get_num_rows() + 1, basis.get_num_columns());

        let _ = sample_d_precomputed_gso(&basis, &false_gso, &center, &Q::from(5)).unwrap();
    }

    /// Ensures that a GSO with more columns than the basis panics.
    #[test]
    #[should_panic]
    fn precomputed_gso_mismatching_columns() {
        let n = Z::from(20);
        let basis = MatZ::sample_uniform(&n, &n, 0, 5000).unwrap();
        let center = MatQ::new(&n, 1);
        let false_gso = MatQ::new(basis.get_num_rows(), basis.get_num_columns() + 1);

        let _ = sample_d_precomputed_gso(&basis, &false_gso, &center, &Q::from(5)).unwrap();
    }
}