qfall_math/utils/sample/
discrete_gauss.rs

1// Copyright © 2025 Niklas Siemer
2//
3// This file is part of qFALL-math.
4//
5// qFALL-math is free software: you can redistribute it and/or modify it under
6// the terms of the Mozilla Public License Version 2.0 as published by the
7// Mozilla Foundation. See <https://mozilla.org/en-US/MPL/2.0/>.
8
9//! This module includes core functionality to sample according to the
10//! discrete gaussian distribution.
11//!
12//! The main references are listed in the following
13//! and will be further referenced in submodules by these numbers:
14//! - \[1\] Gentry, Craig and Peikert, Chris and Vaikuntanathan, Vinod (2008).
15//!   Trapdoors for hard lattices and new cryptographic constructions.
16//!   In: Proceedings of the fortieth annual ACM symposium on Theory of computing.
17//!   <https://citeseerx.ist.psu.edu/document?doi=d9f54077d568784c786f7b1d030b00493eb3ae35>
18
19use super::uniform::UniformIntegerSampler;
20use crate::{
21    error::{MathError, StringConversionError},
22    integer::{MatZ, Z},
23    rational::{MatQ, Q},
24    traits::{MatrixDimensions, MatrixGetSubmatrix, Pow},
25};
26use rand::Rng;
27use serde::{Deserialize, Serialize};
28use std::collections::HashMap;
29
/// Defines whether a lookup-table should be precomputed, filled on-the-fly,
/// or not used at all for a discrete Gaussian sampler.
///
/// The lookup-table maps integers of the sampling interval to their evaluation
/// of the Gaussian function, see [`gaussian_function`].
#[derive(PartialEq, Clone, Copy, Serialize, Deserialize, Debug)]
pub enum LookupTableSetting {
    /// Evaluate the Gaussian function for every integer of the sampling interval
    /// once during [`DiscreteGaussianIntegerSampler::init`] and store all values.
    Precompute,
    /// Start with an empty table and store each value the first time it is
    /// evaluated in [`DiscreteGaussianIntegerSampler::sample_z`].
    FillOnTheFly,
    /// Store nothing; the Gaussian function is re-evaluated for every candidate sample.
    NoLookup,
}
38
/// This is the global variable used as tailcut in all `sample_discrete_gauss` and `sample_d`
/// functions. Its value should be in `ω(sqrt(log(n)))` as required by SampleZ in \[1\].
/// We set it (as most other libraries) statically to `6.0`.
///
/// You can read and change it in an `unsafe` environment.
/// ```compile_fail
/// unsafe { TAILCUT = 4.0 };
/// ```
/// Make sure that the tailcut stays positive and large enough for your purposes.
/// If you use multi-threading, read up on the behaviour of a `static mut` variable:
/// unsynchronized concurrent writes and reads of it are undefined behaviour.
/// Our tests only cover cases where `TAILCUT = 6.0`.
pub static mut TAILCUT: f64 = 6.0;
51
/// Enables discrete Gaussian sampling out of
/// `[⌈center - s * tailcut⌉ , ⌊center + s * tailcut⌋ ]`.
///
/// **WARNING:** If the attributes are not set using [`DiscreteGaussianIntegerSampler::init`],
/// we can't guarantee sampling from the correct discrete Gaussian distribution.
/// Altering any value will invalidate the [`HashMap`] in `table` and might invalidate
/// other attributes, too.
///
/// Attributes:
/// - `center`: specifies the position of the center with peak probability
/// - `s`: specifies the Gaussian parameter, which is proportional
///   to the standard deviation `sigma * sqrt(2 * pi) = s`
/// - `lower_bound`: specifies the smallest integer that can be sampled,
///   i.e. `⌈center - s * tailcut⌉`
/// - `interval_size`: specifies the number of integers that can be sampled, i.e.
///   candidates are chosen uniformly from `[lower_bound, lower_bound + interval_size - 1]`
/// - `lookup_table_setting`: Specifies whether a lookup-table should be used and
///   how it should be filled, i.e. lazily on-the-fly (impacting sampling time slightly) or precomputed
/// - `table`: the lookup-table if one is used, mapping integers to their evaluation
///   of the Gaussian function
///
/// # Examples
/// ```
/// use qfall_math::{integer::Z, rational::Q};
/// use qfall_math::utils::sample::discrete_gauss::{DiscreteGaussianIntegerSampler, LookupTableSetting};
/// let center = 0.0;
/// let gaussian_parameter = 1.0;
/// let tailcut = 6.0;
///
/// let mut dgis = DiscreteGaussianIntegerSampler::init(center, gaussian_parameter, tailcut, LookupTableSetting::NoLookup).unwrap();
///
/// let sample = dgis.sample_z();
/// ```
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct DiscreteGaussianIntegerSampler {
    pub center: Q,
    pub s: Q,
    pub lower_bound: Z,
    pub interval_size: Z,
    pub lookup_table_setting: LookupTableSetting,
    pub table: HashMap<Z, f64>,
}
92
93impl DiscreteGaussianIntegerSampler {
94    /// Initializes the [`DiscreteGaussianIntegerSampler`] with
95    /// - `center` as the center of the discrete Gaussian to sample from,
96    /// - `s` defining the Gaussian parameter, which is proportional
97    ///   to the standard deviation `sigma * sqrt(2 * pi) = s`,
98    /// - `lower_bound` as `⌈center - 6 * s⌉`,
99    /// - `interval_size` as `⌊center + 6 * s⌋ - ⌈center - 6 * s⌉ + 1`, and
100    /// - `table` as an empty [`HashMap`] to store evaluations of the Gaussian function.
101    ///
102    /// Parameters:
103    /// - `n`: specifies the range from which is sampled
104    /// - `center`: as the center of the discrete Gaussian to sample from
105    /// - `s`: specifies the Gaussian parameter, which is proportional
106    ///   to the standard deviation `sigma * sqrt(2 * pi) = s`
107    ///
108    /// Returns a sample chosen according to the specified discrete Gaussian distribution or
109    /// a [`MathError`] if the specified parameters were not chosen appropriately,
110    /// i.e. `n > 1` or `s > 0`.
111    ///
112    /// # Examples
113    /// ```
114    /// use qfall_math::{integer::Z, rational::Q};
115    /// use qfall_math::utils::sample::discrete_gauss::{DiscreteGaussianIntegerSampler, LookupTableSetting};
116    /// let center = 0.0;
117    /// let gaussian_parameter = 1.0;
118    /// let tailcut = 6.0;
119    ///
120    /// let mut dgis = DiscreteGaussianIntegerSampler::init(center, gaussian_parameter, tailcut, LookupTableSetting::Precompute).unwrap();
121    /// ```
122    ///
123    /// # Errors and Failures
124    /// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput)
125    ///   if `tailcut < 0` or `s < 0`.
126    pub fn init(
127        center: impl Into<Q>,
128        s: impl Into<Q>,
129        tailcut: impl Into<Q>,
130        lookup_table_setting: LookupTableSetting,
131    ) -> Result<Self, MathError> {
132        let center = center.into();
133        let mut s = s.into();
134        let tailcut = tailcut.into();
135        if tailcut < Q::ZERO {
136            return Err(MathError::InvalidIntegerInput(format!(
137                "The value {tailcut} was provided for parameter tailcut of the function sample_z.
138                This function expects this input no smaller than 0."
139            )));
140        }
141        if s < Q::ZERO {
142            return Err(MathError::InvalidIntegerInput(format!(
143                "The value {s} was provided for parameter s of the function sample_z.
144                This function expects this input to be no smaller than 0."
145            )));
146        }
147        if s == Q::ZERO {
148            // Ensure that s != 0 s.t. we can divide by s^2 in the Gaussian function
149            s = Q::from(0.00001);
150        }
151
152        let lower_bound = (&center - &s * &tailcut).ceil();
153        let upper_bound = (&center + &s * tailcut).floor();
154        // interval [lower_bound, upper_bound] has upper_bound - lower_bound + 1 elements in it
155        let interval_size = &upper_bound - &lower_bound + Z::ONE;
156
157        let mut table = HashMap::new();
158
159        if lookup_table_setting != LookupTableSetting::NoLookup && interval_size > u16::MAX {
160            println!(
161                "WARNING: A completely filled lookup table will exceed 2^16 entries. You should reconsider your sampling method for discrete Gaussians."
162            )
163        }
164
165        if lookup_table_setting == LookupTableSetting::Precompute {
166            let mut i = lower_bound.clone();
167            while i <= upper_bound {
168                let evaluated_gauss_function = gaussian_function(&i, &center, &s);
169                table.insert(i.clone(), evaluated_gauss_function);
170                i += Z::ONE;
171            }
172        }
173
174        Ok(Self {
175            center,
176            s,
177            lower_bound,
178            interval_size,
179            lookup_table_setting,
180            table,
181        })
182    }
183
184    /// Chooses a sample according to the discrete Gaussian distribution out of
185    /// `[lower_bound , lower_bound + interval_size ]`.
186    ///
187    /// This function implements discrete Gaussian sampling according to the definition of
188    /// SampleZ as in [\[1\]](<index.html#:~:text=[1]>).
189    ///
190    /// # Examples
191    /// ```
192    /// use qfall_math::{integer::Z, rational::Q};
193    /// use qfall_math::utils::sample::discrete_gauss::{DiscreteGaussianIntegerSampler, LookupTableSetting};
194    /// let center = 0.0;
195    /// let gaussian_parameter = 1.0;
196    /// let tailcut = 6.0;
197    ///
198    /// let mut dgis = DiscreteGaussianIntegerSampler::init(center, gaussian_parameter, tailcut, LookupTableSetting::Precompute).unwrap();
199    ///
200    /// let sample = dgis.sample_z();
201    /// ```
202    pub fn sample_z(&mut self) -> Z {
203        let mut rng = rand::rng();
204        let mut uis = UniformIntegerSampler::init(&self.interval_size).unwrap();
205        loop {
206            // sample x in [c - s * tailcut, c + s * tailcut]
207            let sample = &self.lower_bound + uis.sample();
208
209            let evaluated_gauss_function: &f64 = match self.lookup_table_setting {
210                LookupTableSetting::NoLookup => &gaussian_function(&sample, &self.center, &self.s),
211                LookupTableSetting::FillOnTheFly => {
212                    let pot_evaluated_gauss_function = self.table.get(&sample);
213                    match pot_evaluated_gauss_function {
214                        Some(x) => x,
215                        None => &{
216                            // if the entry doesn't exist yet, compute and insert it
217                            let evaluated_function =
218                                gaussian_function(&sample, &self.center, &self.s);
219                            self.table.insert(sample.clone(), evaluated_function);
220                            evaluated_function
221                        },
222                    }
223                }
224                LookupTableSetting::Precompute => self.table.get(&sample).unwrap(),
225            };
226
227            let random_f64: f64 = rng.random();
228            if evaluated_gauss_function >= &random_f64 {
229                return sample;
230            }
231        }
232    }
233}
234
235/// Computes the value of the Gaussian function for `x`.
236///
237/// **Warning**: This functions assumes `s != 0`.
238///
239/// Parameters:
240/// - `x`: specifies the value/ sample for which the Gaussian function's value is computed
241/// - `c`: specifies the position of the center with peak probability
242/// - `s`: specifies the Gaussian parameter, which is proportional
243///   to the standard deviation `sigma * sqrt(2 * pi) = s`
244///
245/// Returns the computed value of the Gaussian function for `x`.
246///
247/// # Examples
248/// ```
249/// use qfall_math::{integer::Z, rational::Q};
250/// use qfall_math::utils::sample::discrete_gauss::gaussian_function;
251/// let sample = Z::ONE;
252/// let center = Q::ZERO;
253/// let gaussian_parameter = Q::ONE;
254///
255/// let probability = gaussian_function(&sample, &center, &gaussian_parameter);
256/// ```
257///
258/// # Panics ...
259/// - if `s = 0`.
260/// - if `-π * (x - c)^2 / s^2` is larger than [`f64::MAX`]
261pub fn gaussian_function(x: &Z, c: &Q, s: &Q) -> f64 {
262    let num = Q::MINUS_ONE * Q::PI * (x - c).pow(2).unwrap();
263    let den = s.pow(2).unwrap();
264    let res = f64::from(&(num / den));
265    res.exp()
266}
267
268/// SampleD samples a discrete Gaussian from the lattice with `basis` using [`sample_z`] as a subroutine.
269///
270/// We do not check whether `basis` is actually a basis. Hence, the callee is
271/// responsible for making sure that `basis` provides a suitable basis.
272///
273/// Parameters:
274/// - `basis`: specifies a basis for the lattice from which is sampled
275/// - `n`: specifies the range from which [`sample_z`] samples
276/// - `center`: specifies the positions of the center with peak probability
277/// - `s`: specifies the Gaussian parameter, which is proportional
278///   to the standard deviation `sigma * sqrt(2 * pi) = s`
279///
280/// Returns a vector with discrete gaussian error based on a lattice point
281/// as in [\[1\]](<index.html#:~:text=[1]>): SampleD or a [`MathError`], if the
282/// `n <= 1` or `s <= 0`, the number of rows of the `basis` and `center` differ,
283/// or `center` is not a column vector.
284///
285/// # Examples
286/// ```compile_fail
287/// use qfall_math::{integer::{MatZ, Z}, rational::{MatQ, Q}};
288/// use qfall_math::utils::sample::discrete_gauss::sample_d;
289/// let basis = MatZ::identity(5, 5);
290/// let n = Z::from(1024);
291/// let center = MatQ::new(5, 1);
292/// let gaussian_parameter = Q::ONE;
293///
294/// let sample = sample_d(basis, &n, &center, &gaussian_parameter).unwrap();
295/// ```
296///
297/// # Errors and Failures
298/// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput)
299///   if `n <= 1` or `s <= 0`.
300/// - Returns a [`MathError`] of type [`MismatchingMatrixDimension`](MathError::MismatchingMatrixDimension)
301///   if the number of rows of the `basis` and `center` differ.
302/// - Returns a [`MathError`] of type [`StringConversionError`](MathError::StringConversionError)
303///   if `center` is not a column vector.
304pub(crate) fn sample_d(basis: &MatZ, center: &MatQ, s: &Q) -> Result<MatZ, MathError> {
305    let basis_gso = MatQ::from(basis).gso();
306    sample_d_precomputed_gso(basis, &basis_gso, center, s)
307}
308
309/// SampleD samples a discrete Gaussian from the lattice with `basis` using [`sample_z`] as a subroutine.
310///
311/// We do not check whether `basis` is actually a basis or whether `basis_gso` is
312/// actually the gso of `basis`. Hence, the callee is responsible for making sure that
313/// `basis` provides a suitable basis and `basis_gso` is a corresponding GSO.
314///
315/// Parameters:
316/// - `basis`: specifies a basis for the lattice from which is sampled
317/// - `basis_gso`: specifies the precomputed gso for `basis`
318/// - `n`: specifies the range from which [`sample_z`] samples
319/// - `center`: specifies the positions of the center with peak probability
320/// - `s`: specifies the Gaussian parameter, which is proportional
321///   to the standard deviation `sigma * sqrt(2 * pi) = s`
322///
323/// Returns a vector with discrete gaussian error based on a lattice point
324/// as in [\[1\]](<index.html#:~:text=[1]>): SampleD or a [`MathError`], if the
325/// `n <= 1` or `s <= 0`, the number of rows of the `basis` and `center` differ,
326/// or `center` is not a column vector.
327///
328/// # Examples
329/// ```compile_fail
330/// use qfall_math::{integer::{MatZ, Z}, rational::{MatQ, Q}};
331/// use qfall_math::utils::sample::discrete_gauss::sample_d;
332/// let basis = MatZ::identity(5, 5);
333/// let n = Z::from(1024);
334/// let center = MatQ::new(5, 1);
335/// let gaussian_parameter = Q::ONE;
336///
337/// let basis_gso = basis.gso();
338///
339/// let sample = sample_d(basis, &basis_gso, &n, &center, &gaussian_parameter).unwrap();
340/// ```
341///
342/// # Errors and Failures
343/// - Returns a [`MathError`] of type [`InvalidIntegerInput`](MathError::InvalidIntegerInput)
344///   if `n <= 1` or `s <= 0`.
345/// - Returns a [`MathError`] of type [`MismatchingMatrixDimension`](MathError::MismatchingMatrixDimension)
346///   if the number of rows of the `basis` and `center` differ.
347/// - Returns a [`MathError`] of type [`StringConversionError`](MathError::StringConversionError)
348///   if `center` is not a column vector.
349///
350/// # Panics ...
351/// - if the number of rows/columns of `basis_gso` and `basis` mismatch.
352pub(crate) fn sample_d_precomputed_gso(
353    basis: &MatZ,
354    basis_gso: &MatQ,
355    center: &MatQ,
356    s: &Q,
357) -> Result<MatZ, MathError> {
358    let mut center = center.clone();
359    assert_eq!(
360        basis.get_num_rows(),
361        basis_gso.get_num_rows(),
362        "The provided gso can not be based on the provided base, \
363        as they do not have the same number of rows."
364    );
365    assert_eq!(
366        basis.get_num_columns(),
367        basis_gso.get_num_columns(),
368        "The provided gso can not be based on the provided base, \
369        as they do not have the same number of columns."
370    );
371    if center.get_num_rows() != basis.get_num_rows() {
372        return Err(MathError::MismatchingMatrixDimension(format!(
373            "sample_d requires center and basis to have the same number of columns, but they were {} and {}.",
374            center.get_num_rows(),
375            basis.get_num_rows()
376        )));
377    }
378    if !center.is_column_vector() {
379        Err(StringConversionError::InvalidMatrix(format!(
380            "sample_d expects center to be a column vector, but it has dimensions {}x{}.",
381            center.get_num_rows(),
382            center.get_num_columns()
383        )))?;
384    }
385    if s < &Q::ZERO {
386        return Err(MathError::InvalidIntegerInput(format!(
387            "The value {s} was provided for parameter s of the function sample_z.
388            This function expects this input to be larger than 0."
389        )));
390    }
391
392    let mut out = MatZ::new(basis_gso.get_num_rows(), 1);
393
394    for i in (0..basis_gso.get_num_columns()).rev() {
395        // basisvector_i = b_tilde[i]
396        let basisvector_orth_i = unsafe { basis_gso.get_column_unchecked(i) };
397
398        // define the center for sample_z as c_2 = <c, b_tilde[i]> / <b_tilde[i], b_tilde[i]>;
399        let c_2 = center.dot_product(&basisvector_orth_i).unwrap()
400            / basisvector_orth_i.dot_product(&basisvector_orth_i).unwrap();
401
402        // Defines the gaussian parameter to be normalized along the basis vector: s2 = s / ||b_tilde[i]||
403        let s_2 = s / (basisvector_orth_i.norm_eucl_sqrd().unwrap().sqrt());
404
405        // sample z ~ D_{Z, s2, c2}
406        let mut dgis = DiscreteGaussianIntegerSampler::init(
407            &c_2,
408            &s_2,
409            unsafe { TAILCUT },
410            LookupTableSetting::FillOnTheFly,
411        )?;
412        let z = dgis.sample_z();
413
414        // update the center c = c - z * b[i]
415        let basisvector_i = unsafe { basis.get_column_unchecked(i) };
416        center -= MatQ::from(&(&z * &basisvector_i));
417
418        // out = out + z * b[i]
419        out = &out + &z * &basisvector_i;
420    }
421
422    Ok(out)
423}
424
#[cfg(test)]
mod test_discrete_gaussian_integer_sampler {
    use super::DiscreteGaussianIntegerSampler;
    use crate::{
        rational::Q,
        utils::sample::discrete_gauss::{LookupTableSetting, TAILCUT},
    };

    /// Checks whether samples are kept in correct interval for a small interval,
    /// i.e. `[⌈15 - 1/2 * 8⌉, ⌊15 + 1/2 * 8⌋] = [11, 19]`.
    #[test]
    fn small_interval() {
        let center = Q::from(15);
        let gaussian_parameter = Q::from((1, 2));

        let mut dgis = DiscreteGaussianIntegerSampler::init(
            &center,
            &gaussian_parameter,
            8.0,
            LookupTableSetting::FillOnTheFly,
        )
        .unwrap();

        for _ in 0..64 {
            let sample = dgis.sample_z();

            assert!(11 <= sample);
            assert!(sample <= 19);
        }
    }

    /// Checks whether samples are kept in correct interval for a large interval.
    #[test]
    fn large_interval() {
        let center = Q::MINUS_ONE;
        let gaussian_parameter = Q::ONE;

        let mut dgis = DiscreteGaussianIntegerSampler::init(
            &center,
            &gaussian_parameter,
            unsafe { TAILCUT },
            LookupTableSetting::FillOnTheFly,
        )
        .unwrap();

        for _ in 0..256 {
            let sample = dgis.sample_z();

            assert!(-64 <= sample);
            assert!(sample <= 62);
        }
    }

    /// Checks whether `init` returns an error if the gaussian parameter `s < 0`.
    #[test]
    fn invalid_gaussian_parameter() {
        let center = Q::ZERO;

        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &Q::MINUS_ONE,
                6.0,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &Q::from(i64::MIN),
                6.0,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
    }

    /// Checks whether `init` returns an error if `tailcut < 0`.
    #[test]
    fn invalid_tailcut() {
        let center = Q::MINUS_ONE;
        let gaussian_parameter = Q::ONE;

        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &gaussian_parameter,
                -0.1,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
        assert!(
            DiscreteGaussianIntegerSampler::init(
                &center,
                &gaussian_parameter,
                i64::MIN,
                LookupTableSetting::FillOnTheFly
            )
            .is_err()
        );
    }
}
528
#[cfg(test)]
mod test_gaussian_function {
    use super::{Q, Z, gaussian_function};
    use crate::traits::Distance;

    /// Ensures that the doc test would run properly.
    #[test]
    fn doc_test() {
        // exp(-π) ≈ 0.0432139 computed via WolframAlpha
        let expected = Q::from((43214, 1_000_000));

        let value = gaussian_function(&Z::ONE, &Q::ZERO, &Q::ONE);

        assert!(expected.distance(&Q::from(value)) < Q::from((1, 1_000_000)));
    }

    /// Checks whether the values for small values are computed appropriately
    /// and with appropriate precision.
    #[test]
    fn small_values() {
        let center = Q::MINUS_ONE;
        let narrow = Q::from((1, 2));
        let wide = Q::from((3, 2));

        // distance 1 to the center: exp(-4π) ≈ 0.00000348734 and
        // exp(-4π/9) ≈ 0.247520, both computed via WolframAlpha
        let off_center_narrow = gaussian_function(&Z::ZERO, &center, &narrow);
        let off_center_wide = gaussian_function(&Z::ZERO, &center, &wide);

        assert!(
            Q::from((349, 100_000_000)).distance(&Q::from(off_center_narrow))
                < Q::from((3, 1_000_000_000))
        );
        assert!(
            Q::from((24752, 100_000)).distance(&Q::from(off_center_wide))
                < Q::from((1, 1_000_000))
        );

        // at the center itself, the Gaussian function evaluates to exactly 1
        for parameter in [&narrow, &wide] {
            assert_eq!(1.0, gaussian_function(&Z::MINUS_ONE, &center, parameter));
        }
    }

    /// Checks whether the values for large values are computed appropriately
    /// and with appropriate precision.
    #[test]
    fn large_values() {
        // the squared distance is 1, so the result is again exp(-4π) ≈ 0.00000348734
        let expected = Q::from((349, 100_000_000));

        let result = gaussian_function(
            &Z::from(i64::MAX),
            &Q::from(i64::MAX as u64 + 1),
            &Q::from((1, 2)),
        );

        assert!(expected.distance(&Q::from(result)) < Q::from((3, 1_000_000_000)));
    }

    /// Checks whether `s = 0` results in a panic.
    #[test]
    #[should_panic]
    fn invalid_s() {
        let _ = gaussian_function(&Z::from(i64::MAX), &Q::from(i64::MAX as u64 + 1), &Q::ZERO);
    }
}
599
#[cfg(test)]
mod test_sample_d {
    use super::sample_d_precomputed_gso;
    use crate::traits::{Concatenate, MatrixDimensions, MatrixGetSubmatrix, Pow};
    use crate::utils::sample::discrete_gauss::sample_d;
    use crate::{
        integer::{MatZ, Z},
        rational::{MatQ, Q},
    };
    use flint_sys::fmpz_mat::fmpz_mat_hnf;
    use std::str::FromStr;

    /// Ensures that the doc-test compiles and runs properly.
    #[test]
    fn doc_test() {
        let basis = MatZ::identity(5, 5);
        let center = MatQ::new(5, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let _ = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let _ = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter).unwrap();
    }

    /// Ensures that `sample_d` works properly for a non-zero center.
    #[test]
    fn non_zero_center() {
        let basis = MatZ::identity(5, 5);
        let center = MatQ::identity(5, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let _ = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let _ = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter).unwrap();
    }

    /// Ensures that `sample_d` works properly for a different basis.
    #[test]
    fn non_identity_basis() {
        let basis = MatZ::from_str("[[2, 1],[1, 2]]").unwrap();
        let center = MatQ::new(2, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let _ = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let _ = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter).unwrap();
    }

    /// Ensures that `sample_d` outputs a vector that's part of the specified lattice.
    ///
    /// Checks whether the Hermite Normal Form HNF of the basis is equal to the HNF of
    /// the basis concatenated with the sampled vector. If it is part of the lattice, it
    /// should become a zero vector at the end of the matrix.
    #[test]
    fn point_of_lattice() {
        let basis = MatZ::from_str("[[7, 0],[7, 3]]").unwrap();
        let center = MatQ::new(2, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let sample = sample_d(&basis, &center, &gaussian_parameter).unwrap();
        let sample_prec =
            sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter).unwrap();

        // check whether hermite normal form of HNF(b) = HNF([b|sample_vector])
        let basis_concat_sample = basis.concat_horizontal(&sample).unwrap();
        let basis_concat_sample_prec = basis.concat_horizontal(&sample_prec).unwrap();
        // Each output matrix is allocated with the same dimensions as the matrix whose
        // HNF is written into it: 2x2 for `basis`, 2x3 for the concatenations.
        let mut hnf_basis = MatZ::new(2, 2);
        unsafe { fmpz_mat_hnf(&mut hnf_basis.matrix, &basis.matrix) };
        let mut hnf_basis_concat_sample = MatZ::new(2, 3);
        let mut hnf_basis_concat_sample_prec = MatZ::new(2, 3);
        unsafe {
            fmpz_mat_hnf(
                &mut hnf_basis_concat_sample.matrix,
                &basis_concat_sample.matrix,
            )
        };
        unsafe {
            fmpz_mat_hnf(
                &mut hnf_basis_concat_sample_prec.matrix,
                &basis_concat_sample_prec.matrix,
            )
        };
        // the first two columns of both HNFs must coincide with the HNF of the basis
        assert_eq!(
            hnf_basis.get_column(0).unwrap(),
            hnf_basis_concat_sample.get_column(0).unwrap()
        );
        assert_eq!(
            hnf_basis.get_column(0).unwrap(),
            hnf_basis_concat_sample_prec.get_column(0).unwrap()
        );
        assert_eq!(
            hnf_basis.get_column(1).unwrap(),
            hnf_basis_concat_sample.get_column(1).unwrap()
        );
        assert_eq!(
            hnf_basis.get_column(1).unwrap(),
            hnf_basis_concat_sample_prec.get_column(1).unwrap()
        );
        // check whether last vector is zero, i.e. was linearly dependent and part of lattice
        assert!(hnf_basis_concat_sample.get_column(2).unwrap().is_zero());
        assert!(
            hnf_basis_concat_sample_prec
                .get_column(2)
                .unwrap()
                .is_zero()
        );
    }

    /// Checks whether `sample_d` returns an error if the gaussian parameter `s < 0`.
    #[test]
    fn invalid_gaussian_parameter() {
        let basis = MatZ::identity(5, 5);
        let center = MatQ::new(5, 1);
        let basis_gso = MatQ::from(&basis).gso();

        assert!(sample_d(&basis, &center, &Q::MINUS_ONE).is_err());
        assert!(sample_d(&basis, &center, &Q::from(i64::MIN)).is_err());

        assert!(sample_d_precomputed_gso(&basis, &basis_gso, &center, &Q::MINUS_ONE).is_err());
        assert!(sample_d_precomputed_gso(&basis, &basis_gso, &center, &Q::from(i64::MIN)).is_err());
    }

    /// Checks whether `sample_d` returns an error if the basis and center number of rows differs.
    #[test]
    fn mismatching_matrix_dimensions() {
        let basis = MatZ::identity(3, 5);
        let center = MatQ::new(4, 1);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let res = sample_d(&basis, &center, &gaussian_parameter);
        let res_prec = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter);

        assert!(res.is_err());
        assert!(res_prec.is_err());
    }

    /// Checks whether `sample_d` returns an error if center isn't a column vector.
    #[test]
    fn center_not_column_vector() {
        let basis = MatZ::identity(2, 2);
        let center = MatQ::new(2, 2);
        let gaussian_parameter = Q::ONE;
        let basis_gso = MatQ::from(&basis).gso();

        let res = sample_d(&basis, &center, &gaussian_parameter);
        let res_prec = sample_d_precomputed_gso(&basis, &basis_gso, &center, &gaussian_parameter);

        assert!(res.is_err());
        assert!(res_prec.is_err());
    }

    /// Ensures that the concentration bound holds.
    ///
    /// The Gaussian parameter is chosen as the maximal length of a GSO vector
    /// times `sqrt(log_2(n)) * log_2(log_2(n))`, and every sample's squared norm
    /// is checked against `round(s^2) * n`, i.e. roughly `||x|| <= s * sqrt(n)`.
    #[test]
    fn concentration_bound() {
        let n = Z::from(20);
        let basis = MatZ::sample_uniform(&n, &n, 0, 5000).unwrap();
        let orth = MatQ::from(&basis).gso();
        // maximal Euclidean length over all GSO vectors
        let mut len = Q::ZERO;
        for i in 0..orth.get_num_columns() {
            let column = orth.get_column(i).unwrap();
            let column_len = column.norm_eucl_sqrd().unwrap().sqrt();
            if column_len > len {
                len = column_len
            }
        }

        let expl_text = String::from("This test can fail with probability close to 0. 
        It fails if the length of the sampled is longer than expected. 
        If this happens, rerun the tests several times and check whether this issue comes up again.");

        let center = MatQ::new(&n, 1);
        let gaussian_parameter =
            len * n.log(2).unwrap().sqrt() * (n.log(2).unwrap().log(2).unwrap());

        for _ in 0..20 {
            let res = sample_d(&basis, &center, &gaussian_parameter).unwrap();
            let res_prec =
                sample_d_precomputed_gso(&basis, &orth, &center, &gaussian_parameter).unwrap();

            assert!(
                res.norm_eucl_sqrd().unwrap() <= gaussian_parameter.pow(2).unwrap().round() * &n,
                "{expl_text}"
            );
            assert!(
                res_prec.norm_eucl_sqrd().unwrap()
                    <= gaussian_parameter.pow(2).unwrap().round() * &n,
                "{expl_text}"
            );
        }
    }

    /// Ensure that an orthogonalized base with not matching rows panics.
    #[test]
    #[should_panic]
    fn precomputed_gso_mismatching_rows() {
        let n = Z::from(20);
        let basis = MatZ::sample_uniform(&n, &n, 0, 5000).unwrap();
        let center = MatQ::new(&n, 1);
        let false_gso = MatQ::new(basis.get_num_rows() + 1, basis.get_num_columns());

        let _ = sample_d_precomputed_gso(&basis, &false_gso, &center, &Q::from(5)).unwrap();
    }

    /// Ensure that an orthogonalized base with not matching columns panics.
    #[test]
    #[should_panic]
    fn precomputed_gso_mismatching_columns() {
        let n = Z::from(20);
        let basis = MatZ::sample_uniform(&n, &n, 0, 5000).unwrap();
        let center = MatQ::new(&n, 1);
        let false_gso = MatQ::new(basis.get_num_rows(), basis.get_num_columns() + 1);

        let _ = sample_d_precomputed_gso(&basis, &false_gso, &center, &Q::from(5)).unwrap();
    }
}