use super::PSF;
use crate::sample::g_trapdoor::{
gadget_classical::short_basis_gadget,
gadget_classical::{find_solution_gadget_mat, gen_trapdoor},
gadget_parameters::GadgetParameters,
};
use qfall_math::{
integer::{MatZ, Z},
integer_mod_q::MatZq,
rational::{MatQ, Q},
traits::{Concatenate, MatrixDimensions, Pow},
};
use serde::{Deserialize, Serialize};
/// A preimage-sampleable function (PSF) built on a gadget trapdoor.
///
/// [`PSF::samp_p`] returns preimages of the form `p + [R; I] * z`:
/// - `p` is a perturbation vector sampled with the precomputed covariance root.
/// - `z` is sampled over the lattice given by the short gadget basis
///   (see `randomized_nearest_plane_gadget`).
///
/// NOTE(review): the construction appears to follow Micciancio–Peikert (2012)
/// perturbation-based preimage sampling. Confirm before relying on it.
#[derive(Serialize, Deserialize)]
pub struct PSFPerturbation {
    /// Gadget parameters used for trapdoor generation and gadget sampling.
    /// The fields read in this file are `n`, `k`, `m_bar`, `q` and `base`.
    pub gp: GadgetParameters,
    /// Parameter `r`. The gadget sampler uses width `r * sqrt(base^2 + 1)`,
    /// `r` is passed to the perturbation sampler, and `r^2` scales both the
    /// perturbation covariance and the domain bound.
    pub r: Q,
    /// Parameter `s`. Domain samples use width `s * r`, trapdoor generation
    /// uses the covariance `s^2 * I`, and the domain bound is
    /// `||sigma||^2 <= s^2 * m * r^2`.
    pub s: Q,
}
impl PSFPerturbation {
    /// Computes a square root of the perturbation covariance used by [`PSF::samp_p`].
    ///
    /// Let `T = [R; I]`, where `mat_r` is stacked on top of an identity block with
    /// `rows(mat_sigma) - rows(mat_r)` rows and `cols(mat_r)` columns.
    ///
    /// The method computes:
    /// - `Sigma_p = mat_sigma - (base^2 + 1) * T * T^t`
    /// - `Sigma_2 = r^2 / (2 * pi) * (Sigma_p - I)`
    ///
    /// It returns `Sigma_2.cholesky_decomposition_flint()`.
    ///
    /// Parameters:
    /// - `mat_r`: the trapdoor matrix `R`
    /// - `mat_sigma`: the target covariance of preimages (e.g. `s^2 * I`)
    ///
    /// # Panics
    /// - If the concatenation or the power computation returns an error.
    /// - NOTE(review): presumably also if `Sigma_2` is not positive definite
    ///   (Cholesky decomposition). Confirm in qfall_math.
    pub fn compute_sqrt_sigma_2(&self, mat_r: &MatZ, mat_sigma: &MatQ) -> MatQ {
        // Shape of the identity block placed underneath R.
        let identity_rows = mat_sigma.get_num_rows() - mat_r.get_num_rows();
        let identity_cols = mat_r.get_num_columns();
        let stacked_td = mat_r
            .concat_vertical(&MatZ::identity(identity_rows, identity_cols))
            .unwrap();

        // Remove the covariance contributed by the gadget sampling step, scaled by (base^2 + 1).
        let gadget_factor = self.gp.base.pow(2).unwrap() + 1;
        let mat_sigma_p: MatQ = mat_sigma - gadget_factor * &stacked_td * stacked_td.transpose();

        // Subtract the identity before scaling by r^2 / (2 * pi).
        let dim_rows = mat_sigma_p.get_num_rows();
        let dim_cols = mat_sigma_p.get_num_columns();
        let shifted = &mat_sigma_p - MatQ::identity(dim_rows, dim_cols);

        let scale = 1.0 / (2.0 * Q::PI);
        let sigma_2: MatQ = scale * self.r.pow(2).unwrap() * shifted;

        sigma_2.cholesky_decomposition_flint()
    }
}
/// Samples a short solution `y` with `G * y = vec_u` for the gadget matrix `G`.
///
/// The function works in three steps:
/// 1. Computes a (possibly long) solution `x` deterministically with
///    [`find_solution_gadget_mat`].
/// 2. Samples a vector `v` from the lattice spanned by `short_basis_gadget`,
///    centered at `-x`, using [`MatZ::sample_d_precomputed_gso`] and the
///    precomputed Gram–Schmidt basis.
/// 3. Returns `x + v`.
///
/// The Gaussian parameter is `psf.r * sqrt(base^2 + 1)`.
///
/// Parameters:
/// - `psf`: supplies `r` and the gadget parameters `k` and `base`
/// - `vec_u`: the syndrome to invert under the gadget matrix
/// - `short_basis_gadget`: a short basis of the gadget lattice
/// - `short_basis_gadget_gso`: the Gram–Schmidt orthogonalisation of that basis
///
/// # Panics
/// If the power computation or `sample_d_precomputed_gso` returns an error.
pub(crate) fn randomized_nearest_plane_gadget(
    psf: &PSFPerturbation,
    vec_u: &MatZq,
    short_basis_gadget: &MatZ,
    short_basis_gadget_gso: &MatQ,
) -> MatZ {
    let s = &psf.r * (psf.gp.base.pow(2).unwrap() + Z::ONE).sqrt();
    let long_solution = find_solution_gadget_mat(vec_u, &psf.gp.k, &psf.gp.base);
    // Sampling the lattice vector around -x shifts the result x + v
    // towards the origin, which keeps it short.
    let center = MatQ::from(&(-1 * &long_solution));
    long_solution
        + MatZ::sample_d_precomputed_gso(short_basis_gadget, short_basis_gadget_gso, &center, &s)
            .unwrap()
}
impl PSF for PSFPerturbation {
    type A = MatZq;
    type Trapdoor = (MatZ, MatQ, (MatZ, MatQ));
    type Domain = MatZ;
    type Range = MatZq;

    /// Generates a public matrix `A` and its trapdoor.
    ///
    /// The trapdoor is the tuple `(R, sqrt(Sigma_2), (S_G, GSO(S_G)))`:
    /// - `R` is the gadget trapdoor returned by `gen_trapdoor` (the tag is the identity).
    /// - `sqrt(Sigma_2)` comes from [`PSFPerturbation::compute_sqrt_sigma_2`] with
    ///   target covariance `s^2 * I`.
    /// - `S_G` is the short gadget basis and `GSO(S_G)` its Gram–Schmidt orthogonalisation.
    ///
    /// # Panics
    /// If `gen_trapdoor` or any of the `unwrap`ped computations returns an error.
    fn trap_gen(&self) -> (MatZq, (MatZ, MatQ, (MatZ, MatQ))) {
        let gp = &self.gp;

        let mat_a_bar = MatZq::sample_uniform(&gp.n, &gp.m_bar, &gp.q);
        let identity_tag = MatZq::identity(&gp.n, &gp.n, &gp.q);
        let (mat_a, mat_r) = gen_trapdoor(gp, &mat_a_bar, &identity_tag).unwrap();

        // Target covariance s^2 * I over the full width of A.
        let dim = mat_a.get_num_columns();
        let mat_sigma = &self.s.pow(2).unwrap() * MatQ::identity(dim, dim);
        let mat_sqrt_sigma_2 = self.compute_sqrt_sigma_2(&mat_r, &mat_sigma);

        let gadget_basis = short_basis_gadget(gp);
        let gadget_basis_gso = MatQ::from(&gadget_basis).gso();

        let trapdoor = (mat_r, mat_sqrt_sigma_2, (gadget_basis, gadget_basis_gso));
        (mat_a, trapdoor)
    }

    /// Samples a domain element.
    ///
    /// The result is a discrete Gaussian column vector of length
    /// `n * k + m_bar`, centered at 0, with parameter `s * r`.
    fn samp_d(&self) -> MatZ {
        let dimension = &self.gp.n * &self.gp.k + &self.gp.m_bar;
        let gaussian_width = &self.s * &self.r;
        MatZ::sample_discrete_gauss(dimension, 1, 0, gaussian_width).unwrap()
    }

    /// Samples a preimage of `vec_u` under `A`, using the trapdoor from [`PSF::trap_gen`].
    ///
    /// The method proceeds in three steps:
    /// 1. Samples a perturbation `p` with the precomputed covariance root and `r`.
    /// 2. Samples `z` with `randomized_nearest_plane_gadget` for the shifted
    ///    syndrome `u - A * p`.
    /// 3. Returns `p + [R; I] * z`.
    fn samp_p(
        &self,
        mat_a: &MatZq,
        (mat_r, mat_sqrt_sigma_2, (gadget_basis, gadget_basis_gso)): &(
            MatZ,
            MatQ,
            (MatZ, MatQ),
        ),
        vec_u: &MatZq,
    ) -> MatZ {
        // Step 1: perturbation vector.
        let vec_p = MatZ::sample_d_common_non_spherical(mat_sqrt_sigma_2, &self.r).unwrap();

        // Step 2: gadget preimage of the perturbed syndrome.
        let vec_v = vec_u - mat_a * &vec_p;
        let vec_z = randomized_nearest_plane_gadget(self, &vec_v, gadget_basis, gadget_basis_gso);

        // Step 3: map z back through [R; I] and add the perturbation.
        let nk = mat_r.get_num_columns();
        let full_td = mat_r.concat_vertical(&MatZ::identity(nk, nk)).unwrap();
        let correction = full_td * vec_z;
        vec_p + correction
    }

    /// Evaluates `A * sigma`.
    ///
    /// # Panics
    /// If `sigma` is not in the domain according to [`PSF::check_domain`].
    fn f_a(&self, mat_a: &MatZq, sigma: &MatZ) -> MatZq {
        assert!(self.check_domain(sigma));
        mat_a * sigma
    }

    /// Returns `true` if `sigma` is in the domain.
    ///
    /// `sigma` is in the domain if it is a column vector of length
    /// `m = n * k + m_bar` and satisfies `||sigma||^2 <= s^2 * m * r^2`.
    fn check_domain(&self, sigma: &MatZ) -> bool {
        let m = &self.gp.n * &self.gp.k + &self.gp.m_bar;
        if !sigma.is_column_vector() || m != sigma.get_num_rows() {
            return false;
        }
        let bound = self.s.pow(2).unwrap() * &m * &self.r.pow(2).unwrap();
        sigma.norm_eucl_sqrd().unwrap() <= bound
    }
}
#[cfg(test)]
mod test_psf_perturbation {
    use super::PSF;
    use super::PSFPerturbation;
    use crate::sample::g_trapdoor::gadget_parameters::GadgetParameters;
    use qfall_math::integer::MatZ;
    use qfall_math::rational::Q;
    use qfall_math::traits::*;

    /// Ensures that `samp_d` only produces vectors accepted by `check_domain`.
    #[test]
    fn samp_d_samples_from_dn() {
        for (n, q) in [(5, 256), (10, 128), (15, 157)] {
            let psf = PSFPerturbation {
                gp: GadgetParameters::init_default(n, q),
                r: Q::from(n).log(2).unwrap(),
                s: Q::from(25),
            };

            for _ in 0..5 {
                assert!(psf.check_domain(&psf.samp_d()));
            }
        }
    }

    /// Ensures that `samp_p` produces a valid preimage that lies in the domain.
    #[test]
    fn samp_p_preimage_and_domain() {
        for (n, q) in [(5, 256), (6, 128)] {
            let psf = PSFPerturbation {
                gp: GadgetParameters::init_default(n, q),
                r: Q::from(n).log(2).unwrap(),
                s: Q::from(25),
            };
            let (a, r) = psf.trap_gen();
            let domain_sample = psf.samp_d();
            let range_fa = psf.f_a(&a, &domain_sample);

            let preimage = psf.samp_p(&a, &r, &range_fa);
            assert_eq!(range_fa, psf.f_a(&a, &preimage));
            assert!(psf.check_domain(&preimage));
        }
    }

    /// Ensures that `f_a` is plain matrix-vector multiplication on domain elements.
    #[test]
    fn f_a_works_as_expected() {
        for (n, q) in [(5, 256), (6, 128)] {
            let psf = PSFPerturbation {
                gp: GadgetParameters::init_default(n, q),
                r: Q::from(n).log(2).unwrap(),
                s: Q::from(25),
            };
            let (a, _) = psf.trap_gen();
            let domain_sample = psf.samp_d();

            assert_eq!(&a * &domain_sample, psf.f_a(&a, &domain_sample));
        }
    }

    /// `f_a` must panic if `sigma` is a matrix rather than a column vector.
    #[test]
    #[should_panic]
    fn f_a_sigma_not_in_domain_matrix() {
        let psf = PSFPerturbation {
            gp: GadgetParameters::init_default(8, 128),
            r: Q::from(8).log(2).unwrap(),
            s: Q::from(25),
        };
        let (a, _) = psf.trap_gen();
        let not_in_domain = MatZ::new(a.get_num_columns(), 2);

        let _ = psf.f_a(&a, &not_in_domain);
    }

    /// `f_a` must panic if `sigma` has the wrong length.
    #[test]
    #[should_panic]
    fn f_a_sigma_not_in_domain_incorrect_length() {
        let psf = PSFPerturbation {
            gp: GadgetParameters::init_default(8, 128),
            r: Q::from(8).log(2).unwrap(),
            s: Q::from(25),
        };
        let (a, _) = psf.trap_gen();
        let not_in_domain = MatZ::new(a.get_num_columns() - 1, 1);

        let _ = psf.f_a(&a, &not_in_domain);
    }

    /// `f_a` must panic if `sigma` exceeds the norm bound of the domain.
    #[test]
    #[should_panic]
    fn f_a_sigma_not_in_domain_too_long() {
        let psf = PSFPerturbation {
            gp: GadgetParameters::init_default(8, 128),
            r: Q::from(8).log(2).unwrap(),
            s: Q::from(25),
        };
        let (a, _) = psf.trap_gen();
        let not_in_domain =
            psf.s.round() * a.get_num_columns() * MatZ::identity(a.get_num_columns(), 1);

        let _ = psf.f_a(&a, &not_in_domain);
    }

    /// `check_domain` accepts the zero vector and a vector with all entries equal to `round(s)`.
    #[test]
    fn check_domain_as_expected() {
        let psf = PSFPerturbation {
            gp: GadgetParameters::init_default(8, 128),
            r: Q::from(8).log(2).unwrap(),
            s: Q::from(25),
        };
        let (a, _) = psf.trap_gen();
        let value = psf.s.round();
        let mut in_domain = MatZ::new(a.get_num_columns(), 1);
        for i in 0..in_domain.get_num_rows() {
            in_domain.set_entry(i, 0, &value).unwrap();
        }

        assert!(psf.check_domain(&MatZ::new(a.get_num_columns(), 1)));
        assert!(psf.check_domain(&in_domain));
    }

    /// `check_domain` rejects matrices, vectors of the wrong length, and vectors that are too long in norm.
    #[test]
    fn check_domain_not_in_dn() {
        let psf = PSFPerturbation {
            gp: GadgetParameters::init_default(8, 128),
            r: Q::from(8).log(2).unwrap(),
            s: Q::from(25),
        };
        let (a, _) = psf.trap_gen();

        let matrix = MatZ::new(a.get_num_columns(), 2);
        let too_short = MatZ::new(a.get_num_columns() - 1, 1);
        let too_long = MatZ::new(a.get_num_columns() + 1, 1);
        let entry_too_large =
            psf.s.round() * a.get_num_columns() * MatZ::identity(a.get_num_columns(), 1);

        assert!(!psf.check_domain(&matrix));
        assert!(!psf.check_domain(&too_long));
        assert!(!psf.check_domain(&too_short));
        assert!(!psf.check_domain(&entry_too_large));
    }
}