qfall_math/utils/sample/discrete_gauss.rs
1// Copyright © 2025 Niklas Siemer
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module includes core functionality to sample according to the
10//! discrete gaussian distribution.
11//!
12//! The main references are listed in the following
13//! and will be further referenced in submodules by these numbers:
14//! - \[1\] Gentry, Craig and Peikert, Chris and Vaikuntanathan, Vinod (2008).
15//! Trapdoors for hard lattices and new cryptographic constructions.
16//! In: Proceedings of the fortieth annual ACM symposium on Theory of computing.
17//! <https://citeseerx.ist.psu.edu/document?doi=d9f54077d568784c786f7b1d030b00493eb3ae35>
18
19use super::uniform::UniformIntegerSampler;
20use crate::{
21 error::{MathError, StringConversionError},
22 integer::{MatZ, Z},
23 rational::{MatQ, Q},
24 traits::{MatrixDimensions, MatrixGetSubmatrix, Pow},
25};
26use rand::Rng;
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29
30/// Defines whether a lookup-table should be precomputed, filled on-the-fly,
31/// or not used at all for a discrete Gaussian sampler.
32#[derive(PartialEq, Clone, Copy, Serialize, Deserialize, Debug)]
33pub enum LookupTableSetting {
34 Precompute,
35 FillOnTheFly,
36 NoLookup,
37}
38
/// This is the global variable used in all `sample_discrete_gauss` and `sample_d`
/// functions. Its value should be in `ω(log(sqrt(n)))`. We set it (as most other libraries)
/// statically to `6.0`.
///
/// NOTE(review): GPV08 states the requirement on the tailcut function of SampleZ
/// as `ω(sqrt(log(n)))` - confirm which expression is intended here.
///
/// You can read and change it inside an `unsafe` block.
/// ```compile_fail
/// unsafe { TAILCUT = 4.0 };
/// ```
/// Make sure that the tailcut stays positive and large enough for your purposes.
/// If you use multi-threading, read up on the behaviour of a `static mut` variable.
/// Our tests only cover cases where `TAILCUT = 6.0`.
pub static mut TAILCUT: f64 = 6.0;
51
/// Enables discrete Gaussian sampling from the integers in
/// `[⌈center - s * tailcut⌉ , ⌊center + s * tailcut⌋ ]`.
///
/// **WARNING:** If the attributes are not set using [`DiscreteGaussianIntegerSampler::init`],
/// we can't guarantee sampling from the correct discrete Gaussian distribution.
/// Altering any value will invalidate the [`HashMap`] in `table` and might invalidate
/// other attributes, too.
///
/// Attributes:
/// - `center`: specifies the position of the center with peak probability
/// - `s`: specifies the Gaussian parameter, which is proportional
///   to the standard deviation `sigma * sqrt(2 * pi) = s`
/// - `lower_bound`: specifies the lower bound to sample uniformly from
/// - `interval_size`: specifies the interval size to sample uniformly from
/// - `lookup_table_setting`: specifies whether a lookup-table should be used and
///   how it should be filled, i.e. lazily on-the-fly (impacting sampling time slightly) or precomputed
/// - `table`: the lookup-table if one is used
///
/// # Examples
/// ```
/// use qfall_math::utils::sample::discrete_gauss::{DiscreteGaussianIntegerSampler, LookupTableSetting};
/// let center = 0.0;
/// let gaussian_parameter = 1.0;
/// let tailcut = 6.0;
///
/// let mut dgis = DiscreteGaussianIntegerSampler::init(center, gaussian_parameter, tailcut, LookupTableSetting::NoLookup).unwrap();
///
/// let sample = dgis.sample_z();
/// ```
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DiscreteGaussianIntegerSampler {
    /// Center of the distribution, i.e. the position with peak probability.
    pub center: Q,
    /// Gaussian parameter `s`; `init` replaces `0` by a small positive value.
    pub s: Q,
    /// Smallest integer that can be sampled, `⌈center - s * tailcut⌉`.
    pub lower_bound: Z,
    /// Number of integers in the sampling interval.
    pub interval_size: Z,
    /// Chosen strategy for caching evaluations of the Gaussian function.
    pub lookup_table_setting: LookupTableSetting,
    /// Cache mapping an integer to the value of the Gaussian function at that integer.
    pub table: HashMap<Z, f64>,
}
92
93impl DiscreteGaussianIntegerSampler {
94 /// Initializes the [`DiscreteGaussianIntegerSampler`] with
95 /// - `center` as the center of the discrete Gaussian to sample from,
96 /// - `s` defining the Gaussian parameter, which is proportional
97 /// to the standard deviation `sigma * sqrt(2 * pi) = s`,
98 /// - `lower_bound` as `⌈center - 6 * s⌉`,
99 /// - `interval_size` as `⌊center + 6 * s⌋ - ⌈center - 6 * s⌉ + 1`, and
100 /// - `table` as an empty [`HashMap`] to store evaluations of the Gaussian function.
101 ///
102 /// Parameters:
103 /// - `n`: specifies the range from which is sampled
104 /// - `center`: as the center of the discrete Gaussian to sample from
105 /// - `s`: specifies the Gaussian parameter, which is proportional
106 /// to the standard deviation `sigma * sqrt(2 * pi) = s`
107 ///
108 /// Returns a sample chosen according to the specified discrete Gaussian distribution or
109 /// a [`MathError`] if the specified parameters were not chosen appropriately,
110 /// i.e. `n > 1` or `s > 0`.
111 ///
112 /// # Examples
113 /// ```
114 /// use qfall_math::{integer::Z, rational::Q};
115 /// use qfall_math::utils::sample::discrete_gauss::{DiscreteGaussianIntegerSampler, LookupTableSetting};
116 /// let center = 0.0;
117 /// let gaussian_parameter = 1.0;
118 /// let tailcut = 6.0;
119 ///
120 /// let mut dgis = DiscreteGaussianIntegerSampler::init(center, gaussian_parameter, tailcut, LookupTableSetting::Precompute).unwrap();
121 /// ```
122 ///
123 /// # Errors and Failures
124 /// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput)
125 /// if `tailcut < 0` or `s < 0`.
126 pub fn init(
127 center: impl Into<Q>,
128 s: impl Into<Q>,
129 tailcut: impl Into<Q>,
130 lookup_table_setting: LookupTableSetting,
131 ) -> Result<Self, MathError> {
132 let center = center.into();
133 let mut s = s.into();
134 let tailcut = tailcut.into();
135 if tailcut < Q::ZERO {
136 return Err(MathError::InvalidIntegerInput(format!(
137 "The value {tailcut} was provided for parameter tailcut of the function sample_z.
138 This function expects this input no smaller than 0."
139 )));
140 }
141 if s < Q::ZERO {
142 return Err(MathError::InvalidIntegerInput(format!(
143 "The value {s} was provided for parameter s of the function sample_z.
144 This function expects this input to be no smaller than 0."
145 )));
146 }
147 if s == Q::ZERO {
148 // Ensure that s != 0 s.t. we can divide by s^2 in the Gaussian function
149 s = Q::from(0.00001);
150 }
151
152 let lower_bound = (¢er - &s * &tailcut).ceil();
153 let upper_bound = (¢er + &s * tailcut).floor();
154 // interval [lower_bound, upper_bound] has upper_bound - lower_bound + 1 elements in it
155 let interval_size = &upper_bound - &lower_bound + Z::ONE;
156
157 let mut table = HashMap::new();
158
159 if lookup_table_setting != LookupTableSetting::NoLookup && interval_size > u16::MAX {
160 println!(
161 "WARNING: A completely filled lookup table will exceed 2^16 entries. You should reconsider your sampling method for discrete Gaussians."
162 )
163 }
164
165 if lookup_table_setting == LookupTableSetting::Precompute {
166 let mut i = lower_bound.clone();
167 while i <= upper_bound {
168 let evaluated_gauss_function = gaussian_function(&i, ¢er, &s);
169 table.insert(i.clone(), evaluated_gauss_function);
170 i += Z::ONE;
171 }
172 }
173
174 Ok(Self {
175 center,
176 s,
177 lower_bound,
178 interval_size,
179 lookup_table_setting,
180 table,
181 })
182 }
183
184 /// Chooses a sample according to the discrete Gaussian distribution out of
185 /// `[lower_bound , lower_bound + interval_size ]`.
186 ///
187 /// This function implements discrete Gaussian sampling according to the definition of
188 /// SampleZ as in [\[1\]](<index.html#:~:text=[1]>).
189 ///
190 /// # Examples
191 /// ```
192 /// use qfall_math::{integer::Z, rational::Q};
193 /// use qfall_math::utils::sample::discrete_gauss::{DiscreteGaussianIntegerSampler, LookupTableSetting};
194 /// let center = 0.0;
195 /// let gaussian_parameter = 1.0;
196 /// let tailcut = 6.0;
197 ///
198 /// let mut dgis = DiscreteGaussianIntegerSampler::init(center, gaussian_parameter, tailcut, LookupTableSetting::Precompute).unwrap();
199 ///
200 /// let sample = dgis.sample_z();
201 /// ```
202 pub fn sample_z(&mut self) -> Z {
203 let mut rng = rand::rng();
204 let mut uis = UniformIntegerSampler::init(&self.interval_size).unwrap();
205 loop {
206 // sample x in [c - s * tailcut, c + s * tailcut]
207 let sample = &self.lower_bound + uis.sample();
208
209 let evaluated_gauss_function: &f64 = match self.lookup_table_setting {
210 LookupTableSetting::NoLookup => &gaussian_function(&sample, &self.center, &self.s),
211 LookupTableSetting::FillOnTheFly => {
212 let pot_evaluated_gauss_function = self.table.get(&sample);
213 match pot_evaluated_gauss_function {
214 Some(x) => x,
215 None => &{
216 // if the entry doesn't exist yet, compute and insert it
217 let evaluated_function =
218 gaussian_function(&sample, &self.center, &self.s);
219 self.table.insert(sample.clone(), evaluated_function);
220 evaluated_function
221 },
222 }
223 }
224 LookupTableSetting::Precompute => self.table.get(&sample).unwrap(),
225 };
226
227 let random_f64: f64 = rng.random();
228 if evaluated_gauss_function >= &random_f64 {
229 return sample;
230 }
231 }
232 }
233}
234
235/// Computes the value of the Gaussian function for `x`.
236///
237/// **Warning**: This functions assumes `s != 0`.
238///
239/// Parameters:
240/// - `x`: specifies the value/ sample for which the Gaussian function's value is computed
241/// - `c`: specifies the position of the center with peak probability
242/// - `s`: specifies the Gaussian parameter, which is proportional
243/// to the standard deviation `sigma * sqrt(2 * pi) = s`
244///
245/// Returns the computed value of the Gaussian function for `x`.
246///
247/// # Examples
248/// ```
249/// use qfall_math::{integer::Z, rational::Q};
250/// use qfall_math::utils::sample::discrete_gauss::gaussian_function;
251/// let sample = Z::ONE;
252/// let center = Q::ZERO;
253/// let gaussian_parameter = Q::ONE;
254///
255/// let probability = gaussian_function(&sample, ¢er, &gaussian_parameter);
256/// ```
257///
258/// # Panics ...
259/// - if `s = 0`.
260/// - if `-π * (x - c)^2 / s^2` is larger than [`f64::MAX`]
261pub fn gaussian_function(x: &Z, c: &Q, s: &Q) -> f64 {
262 let num = Q::MINUS_ONE * Q::PI * (x - c).pow(2).unwrap();
263 let den = s.pow(2).unwrap();
264 let res = f64::from(&(num / den));
265 res.exp()
266}
267
268/// SampleD samples a discrete Gaussian from the lattice with `basis` using [`sample_z`] as a subroutine.
269///
270/// We do not check whether `basis` is actually a basis. Hence, the callee is
271/// responsible for making sure that `basis` provides a suitable basis.
272///
273/// Parameters:
274/// - `basis`: specifies a basis for the lattice from which is sampled
275/// - `n`: specifies the range from which [`sample_z`] samples
276/// - `center`: specifies the positions of the center with peak probability
277/// - `s`: specifies the Gaussian parameter, which is proportional
278/// to the standard deviation `sigma * sqrt(2 * pi) = s`
279///
280/// Returns a vector with discrete gaussian error based on a lattice point
281/// as in [\[1\]](<index.html#:~:text=[1]>): SampleD or a [`MathError`], if the
282/// `n <= 1` or `s <= 0`, the number of rows of the `basis` and `center` differ,
283/// or `center` is not a column vector.
284///
285/// # Examples
286/// ```compile_fail
287/// use qfall_math::{integer::{MatZ, Z}, rational::{MatQ, Q}};
288/// use qfall_math::utils::sample::discrete_gauss::sample_d;
289/// let basis = MatZ::identity(5, 5);
290/// let n = Z::from(1024);
291/// let center = MatQ::new(5, 1);
292/// let gaussian_parameter = Q::ONE;
293///
294/// let sample = sample_d(basis, &n, ¢er, &gaussian_parameter).unwrap();
295/// ```
296///
297/// # Errors and Failures
298/// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput)
299/// if `n <= 1` or `s <= 0`.
300/// - Returns a [`MathError`] of type [`MismatchingMatrixDimension`](MathError::MismatchingMatrixDimension)
301/// if the number of rows of the `basis` and `center` differ.
302/// - Returns a [`MathError`] of type [`StringConversionError`](MathError::StringConversionError)
303/// if `center` is not a column vector.
304pub(crate) fn sample_d(basis: &MatZ, center: &MatQ, s: &Q) -> Result<MatZ, MathError> {
305 let basis_gso = MatQ::from(basis).gso();
306 sample_d_precomputed_gso(basis, &basis_gso, center, s)
307}
308
309/// SampleD samples a discrete Gaussian from the lattice with `basis` using [`sample_z`] as a subroutine.
310///
311/// We do not check whether `basis` is actually a basis or whether `basis_gso` is
312/// actually the gso of `basis`. Hence, the callee is responsible for making sure that
313/// `basis` provides a suitable basis and `basis_gso` is a corresponding GSO.
314///
315/// Parameters:
316/// - `basis`: specifies a basis for the lattice from which is sampled
317/// - `basis_gso`: specifies the precomputed gso for `basis`
318/// - `n`: specifies the range from which [`sample_z`] samples
319/// - `center`: specifies the positions of the center with peak probability
320/// - `s`: specifies the Gaussian parameter, which is proportional
321/// to the standard deviation `sigma * sqrt(2 * pi) = s`
322///
323/// Returns a vector with discrete gaussian error based on a lattice point
324/// as in [\[1\]](<index.html#:~:text=[1]>): SampleD or a [`MathError`], if the
325/// `n <= 1` or `s <= 0`, the number of rows of the `basis` and `center` differ,
326/// or `center` is not a column vector.
327///
328/// # Examples
329/// ```compile_fail
330/// use qfall_math::{integer::{MatZ, Z}, rational::{MatQ, Q}};
331/// use qfall_math::utils::sample::discrete_gauss::sample_d;
332/// let basis = MatZ::identity(5, 5);
333/// let n = Z::from(1024);
334/// let center = MatQ::new(5, 1);
335/// let gaussian_parameter = Q::ONE;
336///
337/// let basis_gso = basis.gso();
338///
339/// let sample = sample_d(basis, &basis_gso, &n, ¢er, &gaussian_parameter).unwrap();
340/// ```
341///
342/// # Errors and Failures
343/// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput)
344/// if `n <= 1` or `s <= 0`.
345/// - Returns a [`MathError`] of type [`MismatchingMatrixDimension`](MathError::MismatchingMatrixDimension)
346/// if the number of rows of the `basis` and `center` differ.
347/// - Returns a [`MathError`] of type [`StringConversionError`](MathError::StringConversionError)
348/// if `center` is not a column vector.
349///
350/// # Panics ...
351/// - if the number of rows/columns of `basis_gso` and `basis` mismatch.
352pub(crate) fn sample_d_precomputed_gso(
353 basis: &MatZ,
354 basis_gso: &MatQ,
355 center: &MatQ,
356 s: &Q,
357) -> Result<MatZ, MathError> {
358 let mut center = center.clone();
359 assert_eq!(
360 basis.get_num_rows(),
361 basis_gso.get_num_rows(),
362 "The provided gso can not be based on the provided base, \
363 as they do not have the same number of rows."
364 );
365 assert_eq!(
366 basis.get_num_columns(),
367 basis_gso.get_num_columns(),
368 "The provided gso can not be based on the provided base, \
369 as they do not have the same number of columns."
370 );
371 if center.get_num_rows() != basis.get_num_rows() {
372 return Err(MathError::MismatchingMatrixDimension(format!(
373 "sample_d requires center and basis to have the same number of columns, but they were {} and {}.",
374 center.get_num_rows(),
375 basis.get_num_rows()
376 )));
377 }
378 if !center.is_column_vector() {
379 Err(StringConversionError::InvalidMatrix(format!(
380 "sample_d expects center to be a column vector, but it has dimensions {}x{}.",
381 center.get_num_rows(),
382 center.get_num_columns()
383 )))?;
384 }
385 if s < &Q::ZERO {
386 return Err(MathError::InvalidIntegerInput(format!(
387 "The value {s} was provided for parameter s of the function sample_z.
388 This function expects this input to be larger than 0."
389 )));
390 }
391
392 let mut out = MatZ::new(basis_gso.get_num_rows(), 1);
393
394 for i in (0..basis_gso.get_num_columns()).rev() {
395 // basisvector_i = b_tilde[i]
396 let basisvector_orth_i = unsafe { basis_gso.get_column_unchecked(i) };
397
398 // define the center for sample_z as c_2 = <c, b_tilde[i]> / <b_tilde[i], b_tilde[i]>;
399 let c_2 = center.dot_product(&basisvector_orth_i).unwrap()
400 / basisvector_orth_i.dot_product(&basisvector_orth_i).unwrap();
401
402 // Defines the gaussian parameter to be normalized along the basis vector: s2 = s / ||b_tilde[i]||
403 let s_2 = s / (basisvector_orth_i.norm_eucl_sqrd().unwrap().sqrt());
404
405 // sample z ~ D_{Z, s2, c2}
406 let mut dgis = DiscreteGaussianIntegerSampler::init(
407 &c_2,
408 &s_2,
409 unsafe { TAILCUT },
410 LookupTableSetting::FillOnTheFly,
411 )?;
412 let z = dgis.sample_z();
413
414 // update the center c = c - z * b[i]
415 let basisvector_i = unsafe { basis.get_column_unchecked(i) };
416 center -= MatQ::from(&(&z * &basisvector_i));
417
418 // out = out + z * b[i]
419 out = &out + &z * &basisvector_i;
420 }
421
422 Ok(out)
423}
424
#[cfg(test)]
mod test_discrete_gaussian_integer_sampler {
    use super::DiscreteGaussianIntegerSampler;
    use crate::{
        rational::Q,
        utils::sample::discrete_gauss::{LookupTableSetting, TAILCUT},
    };

    /// Checks whether samples are kept in correct interval for a small interval.
    #[test]
    fn small_interval() {
        let center = Q::from(15);
        let gaussian_parameter = Q::from((1, 2));

        let mut dgis = DiscreteGaussianIntegerSampler::init(
            &center,
            &gaussian_parameter,
            8.0,
            LookupTableSetting::FillOnTheFly,
        )
        .unwrap();

        for _ in 0..64 {
            let sample = dgis.sample_z();

            // interval is [⌈15 - 8 * 1/2⌉, ⌊15 + 8 * 1/2⌋] = [11, 19]
            assert!(11 <= sample);
            assert!(sample <= 19);
        }
    }

    /// Checks whether samples are kept in correct interval for a large interval.
    #[test]
    fn large_interval() {
        let center = Q::MINUS_ONE;
        let gaussian_parameter = Q::ONE;

        let mut dgis = DiscreteGaussianIntegerSampler::init(
            &center,
            &gaussian_parameter,
            unsafe { TAILCUT },
            LookupTableSetting::FillOnTheFly,
        )
        .unwrap();

        for _ in 0..256 {
            let sample = dgis.sample_z();

            // with `TAILCUT = 6.0` the interval is [⌈-1 - 6 * 1⌉, ⌊-1 + 6 * 1⌋] = [-7, 5]
            assert!(-7 <= sample);
            assert!(sample <= 5);
        }
    }

    /// Checks whether `init` returns an error if the gaussian parameter `s < 0`.
    #[test]
    fn invalid_gaussian_parameter() {
        let center = Q::ZERO;

        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &Q::MINUS_ONE,
                6.0,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &Q::from(i64::MIN),
                6.0,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
    }

    /// Checks whether `init` returns an error if `tailcut < 0`.
    #[test]
    fn invalid_tailcut() {
        let center = Q::MINUS_ONE;
        let gaussian_parameter = Q::ONE;

        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &gaussian_parameter,
                -0.1,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &gaussian_parameter,
                i64::MIN,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
    }
}
528
#[cfg(test)]
mod test_gaussian_function {
    use super::{Q, Z, gaussian_function};
    use crate::traits::Distance;

    /// Mirrors the doc test to make sure it would run properly.
    #[test]
    fn doc_test() {
        // reference value of roughly 0.0432139 was computed via WolframAlpha
        let expected = Q::from((43214, 1_000_000));

        let actual = gaussian_function(&Z::ONE, &Q::ZERO, &Q::ONE);

        let tolerance = Q::from((1, 1_000_000));
        assert!(expected.distance(&Q::from(actual)) < tolerance);
    }

    /// Evaluates the function for small inputs and checks the results
    /// against reference values with appropriate precision.
    #[test]
    fn small_values() {
        let center = Q::MINUS_ONE;
        let (off_center, at_center) = (Z::ZERO, Z::MINUS_ONE);
        let (narrow, wide) = (Q::from((1, 2)), Q::from((3, 2)));
        // reference value of roughly 0.00000348734 was computed via WolframAlpha
        let expected_narrow = Q::from((349, 100_000_000));
        // reference value of 0.247520 was computed via WolframAlpha
        let expected_wide = Q::from((24752, 100_000));

        let off_center_narrow = gaussian_function(&off_center, &center, &narrow);
        let off_center_wide = gaussian_function(&off_center, &center, &wide);
        let at_center_narrow = gaussian_function(&at_center, &center, &narrow);
        let at_center_wide = gaussian_function(&at_center, &center, &wide);

        assert!(
            expected_narrow.distance(&Q::from(off_center_narrow)) < Q::from((3, 1_000_000_000))
        );
        assert!(expected_wide.distance(&Q::from(off_center_wide)) < Q::from((1, 1_000_000)));
        // the function attains its maximum of 1 at the center
        assert_eq!(1.0, at_center_narrow);
        assert_eq!(1.0, at_center_wide);
    }

    /// Evaluates the function for large inputs and checks the result
    /// against a reference value with appropriate precision.
    #[test]
    fn large_values() {
        let x = Z::from(i64::MAX);
        let c = Q::from(i64::MAX as u64 + 1);
        let s = Q::from((1, 2));
        // reference value of roughly 0.00000348734 was computed via WolframAlpha
        let expected = Q::from((349, 100_000_000));

        let actual = gaussian_function(&x, &c, &s);

        let tolerance = Q::from((3, 1_000_000_000));
        assert!(expected.distance(&Q::from(actual)) < tolerance);
    }

    /// Makes sure that a Gaussian parameter of `s = 0` results in a panic.
    #[test]
    #[should_panic]
    fn invalid_s() {
        let x = Z::from(i64::MAX);
        let c = Q::from(i64::MAX as u64 + 1);

        let _ = gaussian_function(&x, &c, &Q::ZERO);
    }
}
599
#[cfg(test)]
mod test_sample_d {
    use super::sample_d_precomputed_gso;
    use crate::traits::{Concatenate, MatrixDimensions, MatrixGetSubmatrix, Pow};
    use crate::utils::sample::discrete_gauss::sample_d;
    use crate::{
        integer::{MatZ, Z},
        rational::{MatQ, Q},
    };
    use flint_sys::fmpz_mat::fmpz_mat_hnf;
    use std::str::FromStr;

    /// Ensures that the doc-test compiles and runs properly.
    #[test]
    fn doc_test() {
        let basis = MatZ::identity(5, 5);
        let center = MatQ::new(5, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let _ = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let _ = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter).unwrap();
    }

    /// Ensures that `sample_d` works properly for a non-zero center.
    #[test]
    fn non_zero_center() {
        let basis = MatZ::identity(5, 5);
        let center = MatQ::identity(5, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let _ = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let _ = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter).unwrap();
    }

    /// Ensures that `sample_d` works properly for a different basis.
    #[test]
    fn non_identity_basis() {
        let basis = MatZ::from_str("[[2, 1],[1, 2]]").unwrap();
        let center = MatQ::new(2, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let _ = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let _ = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter).unwrap();
    }

    /// Ensures that `sample_d` outputs a vector that's part of the specified lattice.
    ///
    /// Checks whether the Hermite Normal Form HNF of the basis is equal to the HNF of
    /// the basis concatenated with the sampled vector. If it is part of the lattice, it
    /// should become a zero vector at the end of the matrix.
    ///
    /// NOTE(review): FLINT documents `fmpz_mat_hnf` as the *row* Hermite normal form,
    /// whereas `sample_d` treats the columns of `basis` as basis vectors. Confirm that
    /// this check is meaningful for non-zero samples and does not only pass because
    /// the sampled vector is (almost always) the zero vector for these parameters.
    #[test]
    fn point_of_lattice() {
        let basis = MatZ::from_str("[[7, 0],[7, 3]]").unwrap();
        let center = MatQ::new(2, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let sample = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let sample_prec =
            sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter).unwrap();

        // check whether hermite normal form of HNF(b) = HNF([b|sample_vector])
        let basis_concat_sample = basis.concat_horizontal(&sample).unwrap();
        let basis_concat_sample_prec = basis.concat_horizontal(&sample_prec).unwrap();
        let mut hnf_basis = MatZ::new(2, 2);
        // SAFETY: both arguments are initialized matrices of dimensions 2x2;
        // FLINT requires the output to have the same dimensions as the input.
        unsafe { fmpz_mat_hnf(&mut hnf_basis.matrix, &basis.matrix) };
        let mut hnf_basis_concat_sample = MatZ::new(2, 3);
        let mut hnf_basis_concat_sample_prec = MatZ::new(2, 3);
        // SAFETY: both arguments are initialized matrices of dimensions 2x3.
        unsafe {
            fmpz_mat_hnf(
                &mut hnf_basis_concat_sample.matrix,
                &basis_concat_sample.matrix,
            )
        };
        // SAFETY: both arguments are initialized matrices of dimensions 2x3.
        unsafe {
            fmpz_mat_hnf(
                &mut hnf_basis_concat_sample_prec.matrix,
                &basis_concat_sample_prec.matrix,
            )
        };
        assert_eq!(
            hnf_basis.get_column(0).unwrap(),
            hnf_basis_concat_sample.get_column(0).unwrap()
        );
        assert_eq!(
            hnf_basis.get_column(0).unwrap(),
            hnf_basis_concat_sample_prec.get_column(0).unwrap()
        );
        assert_eq!(
            hnf_basis.get_column(1).unwrap(),
            hnf_basis_concat_sample.get_column(1).unwrap()
        );
        assert_eq!(
            hnf_basis.get_column(1).unwrap(),
            hnf_basis_concat_sample_prec.get_column(1).unwrap()
        );
        // check whether last vector is zero, i.e. was linearly dependent and part of lattice
        assert!(hnf_basis_concat_sample.get_column(2).unwrap().is_zero());
        assert!(
            hnf_basis_concat_sample_prec
                .get_column(2)
                .unwrap()
                .is_zero()
        );
    }

    /// Checks whether `sample_d` returns an error if the gaussian parameter `s < 0`.
    #[test]
    fn invalid_gaussian_parameter() {
        let basis = MatZ::identity(5, 5);
        let center = MatQ::new(5, 1);
        let basis_gso = MatQ::from(&basis).gso();

        assert!(sample_d(&basis, &center, &Q::MINUS_ONE).is_err());
        assert!(sample_d(&basis, &center, &Q::from(i64::MIN)).is_err());

        assert!(sample_d_precomputed_gso(&basis, &basis_gso, &center, &Q::MINUS_ONE).is_err());
        assert!(sample_d_precomputed_gso(&basis, &basis_gso, &center, &Q::from(i64::MIN)).is_err());
    }

    /// Checks whether `sample_d` returns an error if the basis and center number of rows differs.
    #[test]
    fn mismatching_matrix_dimensions() {
        let basis = MatZ::identity(3, 5);
        let center = MatQ::new(4, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let res = sample_d(&basis, &center, &gaussian_parameter);
        let res_prec = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter);

        assert!(res.is_err());
        assert!(res_prec.is_err());
    }

    /// Checks whether `sample_d` returns an error if center isn't a column vector.
    #[test]
    fn center_not_column_vector() {
        let basis = MatZ::identity(2, 2);
        let center = MatQ::new(2, 2);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let res = sample_d(&basis, &center, &gaussian_parameter);
        let res_prec = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter);

        assert!(res.is_err());
        assert!(res_prec.is_err());
    }

    /// Ensures that the concentration bound holds.
    #[test]
    fn concentration_bound() {
        let n = Z::from(20);
        let basis = MatZ::sample_uniform(&n, &n, 0, 5000).unwrap();
        let orth = MatQ::from(&basis).gso();
        // determine the length of the longest vector of the orthogonalized basis
        let mut len = Q::ZERO;
        for i in 0..orth.get_num_columns() {
            let column = orth.get_column(i).unwrap();
            let column_len = column.norm_eucl_sqrd().unwrap().sqrt();
            if column_len > len {
                len = column_len
            }
        }

        let expl_text = String::from(
            "This test can fail with probability close to 0. \
             It fails if the length of the sampled vector is longer than expected. \
             If this happens, rerun the tests several times and check whether this issue comes up again.",
        );

        let center = MatQ::new(&n, 1);
        let gaussian_parameter =
            len * n.log(2).unwrap().sqrt() * (n.log(2).unwrap().log(2).unwrap());
        // the bound s^2 * n on the squared length does not change between iterations
        let squared_length_bound = gaussian_parameter.pow(2).unwrap().round() * &n;

        for _ in 0..20 {
            let res = sample_d(&basis, &center, &gaussian_parameter).unwrap();
            let res_prec =
                sample_d_precomputed_gso(&basis, &orth, &center, &gaussian_parameter).unwrap();

            assert!(
                res.norm_eucl_sqrd().unwrap() <= squared_length_bound,
                "{expl_text}"
            );
            assert!(
                res_prec.norm_eucl_sqrd().unwrap() <= squared_length_bound,
                "{expl_text}"
            );
        }
    }

    /// Ensure that an orthogonalized base with not matching rows panics.
    #[test]
    #[should_panic]
    fn precomputed_gso_mismatching_rows() {
        let n = Z::from(20);
        let basis = MatZ::sample_uniform(&n, &n, 0, 5000).unwrap();
        let center = MatQ::new(&n, 1);
        let false_gso = MatQ::new(basis.get_num_rows() + 1, basis.get_num_columns());

        let _ = sample_d_precomputed_gso(&basis, &false_gso, &center, &Q::from(5)).unwrap();
    }

    /// Ensure that an orthogonalized base with not matching columns panics.
    #[test]
    #[should_panic]
    fn precomputed_gso_mismatching_columns() {
        let n = Z::from(20);
        let basis = MatZ::sample_uniform(&n, &n, 0, 5000).unwrap();
        let center = MatQ::new(&n, 1);
        let false_gso = MatQ::new(basis.get_num_rows(), basis.get_num_columns() + 1);

        let _ = sample_d_precomputed_gso(&basis, &false_gso, &center, &Q::from(5)).unwrap();
    }
}
815}