use super::PSF;
use crate::sample::g_trapdoor::{
gadget_classical::gen_trapdoor, gadget_parameters::GadgetParameters,
short_basis_classical::gen_short_basis_for_trapdoor,
};
use qfall_math::{
integer::MatZ,
integer_mod_q::MatZq,
rational::{MatQ, Q},
traits::{MatrixDimensions, Pow},
};
use serde::{Deserialize, Serialize};
/// Preimage sampleable function (PSF) in the style of GPV, built on a
/// G-trapdoor (see the `PSF` implementation for `PSFGPV`).
#[derive(Serialize, Deserialize)]
pub struct PSFGPV {
    /// Gadget parameters; `trap_gen`, `samp_d` and `check_domain` read the
    /// dimensions `n`, `k`, `m_bar` and the modulus `q` from here.
    pub gp: GadgetParameters,
    /// Gaussian parameter used by `samp_d` and `samp_p`; `check_domain` also
    /// uses it to bound the norm of domain elements by `s * sqrt(m)`.
    pub s: Q,
}
impl PSF for PSFGPV {
    type A = MatZq;
    type Trapdoor = (MatZ, MatQ);
    type Domain = MatZ;
    type Range = MatZq;

    /// Generates a public matrix `a` together with its trapdoor.
    ///
    /// A uniform matrix `a_bar` (dimensions given by `n` and `m_bar`, modulus `q`)
    /// is extended to `a` via [`gen_trapdoor`], using the identity matrix as tag.
    /// The returned trapdoor is the short basis computed by
    /// [`gen_short_basis_for_trapdoor`] together with its Gram-Schmidt
    /// orthogonalisation. The GSO is computed once here so that `samp_p` can
    /// reuse it for every call.
    ///
    /// # Panics
    /// Panics if [`gen_trapdoor`] returns an error for `self.gp`.
    fn trap_gen(&self) -> (MatZq, (MatZ, MatQ)) {
        let a_bar = MatZq::sample_uniform(&self.gp.n, &self.gp.m_bar, &self.gp.q);
        let tag = MatZq::identity(&self.gp.n, &self.gp.n, &self.gp.q);
        let (a, r) = gen_trapdoor(&self.gp, &a_bar, &tag)
            .expect("trapdoor generation failed for the configured gadget parameters");
        let short_base = gen_short_basis_for_trapdoor(&self.gp, &tag, &a, &r);
        let short_base_gso = MatQ::from(&short_base).gso();
        (a, (short_base, short_base_gso))
    }

    /// Samples an element of the domain.
    ///
    /// Returns a column vector with `m = n * k + m_bar` entries, sampled via
    /// `MatZ::sample_discrete_gauss` with center `0` and Gaussian parameter `s`.
    ///
    /// # Panics
    /// Panics if discrete Gaussian sampling fails for `m` and `self.s`.
    fn samp_d(&self) -> MatZ {
        let m = &self.gp.n * &self.gp.k + &self.gp.m_bar;
        MatZ::sample_discrete_gauss(&m, 1, 0, &self.s)
            .expect("discrete Gaussian sampling over the domain failed")
    }

    /// Samples a short preimage of `u` under `f_a`.
    ///
    /// First an arbitrary solution `sol` of `a * sol = u` is found by Gaussian
    /// elimination. Then a vector `v` is sampled with the precomputed short basis
    /// (and its GSO) around the center `-sol`, and `sol + v` is returned.
    /// Assuming `short_base` spans vectors in the kernel of `a` mod `q` (TODO
    /// confirm against `gen_short_basis_for_trapdoor`), `sol + v` remains a
    /// preimage of `u` while being short.
    ///
    /// # Panics
    /// Panics if Gaussian elimination finds no solution for `u`, or if sampling
    /// with the trapdoor basis fails.
    fn samp_p(
        &self,
        a: &MatZq,
        (short_base, short_base_gso): &(MatZ, MatQ),
        u: &MatZq,
    ) -> MatZ {
        let sol: MatZ = a
            .solve_gaussian_elimination(u)
            .expect("no solution of a * x = u found via Gaussian elimination")
            .get_representative_least_nonnegative_residue();
        // Centering the sample at `-sol` makes the sampled vector cancel out most
        // of `sol`, so the sum returned below is short.
        let center = MatQ::from(&(-1 * &sol));
        let offset = MatZ::sample_d_precomputed_gso(short_base, short_base_gso, &center, &self.s)
            .expect("sampling with the trapdoor basis failed");
        sol + offset
    }

    /// Evaluates the function, i.e. returns `a * sigma` as a matrix mod `q`.
    ///
    /// # Panics
    /// Panics if `sigma` is not in the domain according to `check_domain`.
    fn f_a(&self, a: &MatZq, sigma: &MatZ) -> MatZq {
        assert!(
            self.check_domain(sigma),
            "sigma is not in the domain of the PSF"
        );
        a * sigma
    }

    /// Checks whether `sigma` lies in the domain.
    ///
    /// The domain consists of column vectors with `m = n * k + m_bar` entries
    /// whose squared Euclidean norm is at most `s^2 * m`, i.e.
    /// `||sigma|| <= s * sqrt(m)`.
    fn check_domain(&self, sigma: &MatZ) -> bool {
        let m = &self.gp.n * &self.gp.k + &self.gp.m_bar;
        // The `&&` short-circuit guarantees `sigma` is a column vector before its
        // norm is requested.
        sigma.is_column_vector()
            && m == sigma.get_num_rows()
            && sigma
                .norm_eucl_sqrd()
                .expect("the Euclidean norm of a column vector is defined")
                <= self.s.pow(2).expect("squaring s is well defined") * &m
    }
}
#[cfg(test)]
mod test_gpv_psf {
    use super::super::gpv::PSFGPV;
    use super::PSF;
    use crate::sample::g_trapdoor::gadget_parameters::GadgetParameters;
    use qfall_math::integer::MatZ;
    use qfall_math::rational::Q;
    use qfall_math::traits::*;

    /// Ensures that every vector returned by `samp_d` is accepted by `check_domain`.
    #[test]
    fn samp_d_samples_from_dn() {
        for (n, q) in [(5, 256), (10, 128), (15, 157)] {
            let psf = PSFGPV {
                gp: GadgetParameters::init_default(n, q),
                s: Q::from(10),
            };
            for _ in 0..5 {
                assert!(psf.check_domain(&psf.samp_d()));
            }
        }
    }

    /// Ensures that `samp_p` returns a preimage of `f_a(x)` for a domain sample `x`,
    /// and that this preimage is itself in the domain.
    // NOTE(review): the domain check on the preimage presumably depends on `s = 10`
    // being large enough relative to the trapdoor basis; confirm this test is not
    // probabilistically flaky for these parameters.
    #[test]
    fn samp_p_preimage_and_domain() {
        for (n, q) in [(5, 256), (6, 128)] {
            let psf = PSFGPV {
                gp: GadgetParameters::init_default(n, q),
                s: Q::from(10),
            };
            let (a, r) = psf.trap_gen();
            let domain_sample = psf.samp_d();
            let range_fa = psf.f_a(&a, &domain_sample);

            let preimage = psf.samp_p(&a, &r, &range_fa);

            assert_eq!(range_fa, psf.f_a(&a, &preimage));
            assert!(psf.check_domain(&preimage));
        }
    }

    /// Ensures that `f_a` computes `a * sigma` for a domain element `sigma`.
    #[test]
    fn f_a_works_as_expected() {
        for (n, q) in [(5, 256), (6, 128)] {
            let psf = PSFGPV {
                gp: GadgetParameters::init_default(n, q),
                s: Q::from(10),
            };
            let (a, _) = psf.trap_gen();
            let domain_sample = psf.samp_d();

            assert_eq!(&a * &domain_sample, psf.f_a(&a, &domain_sample));
        }
    }

    /// Ensures that `f_a` panics if `sigma` is a matrix rather than a column vector.
    #[test]
    #[should_panic]
    fn f_a_sigma_not_in_domain_matrix() {
        let psf = PSFGPV {
            gp: GadgetParameters::init_default(8, 128),
            s: Q::from(10),
        };
        let (a, _) = psf.trap_gen();
        let not_in_domain = MatZ::new(a.get_num_columns(), 2);

        let _ = psf.f_a(&a, &not_in_domain);
    }

    /// Ensures that `f_a` panics if `sigma` has one entry fewer than `a` has columns.
    #[test]
    #[should_panic]
    fn f_a_sigma_not_in_domain_incorrect_length() {
        let psf = PSFGPV {
            gp: GadgetParameters::init_default(8, 128),
            s: Q::from(10),
        };
        let (a, _) = psf.trap_gen();
        let not_in_domain = MatZ::new(a.get_num_columns() - 1, 1);

        let _ = psf.f_a(&a, &not_in_domain);
    }

    /// Ensures that `f_a` panics if the norm of `sigma` is too large: its first
    /// entry is `round(s) * m`, which exceeds the bound `s * sqrt(m)`.
    #[test]
    #[should_panic]
    fn f_a_sigma_not_in_domain_too_long() {
        let psf = PSFGPV {
            gp: GadgetParameters::init_default(8, 128),
            s: Q::from(10),
        };
        let (a, _) = psf.trap_gen();
        let not_in_domain =
            psf.s.round() * a.get_num_columns() * MatZ::identity(a.get_num_columns(), 1);

        let _ = psf.f_a(&a, &not_in_domain);
    }

    /// Ensures that the zero vector and the vector with every entry equal to
    /// `round(s)` (squared norm exactly `s^2 * m` for integral `s`) are accepted.
    #[test]
    fn check_domain_as_expected() {
        let psf = PSFGPV {
            gp: GadgetParameters::init_default(8, 128),
            s: Q::from(10),
        };
        let (a, _) = psf.trap_gen();
        let value = psf.s.round();
        let mut in_domain = MatZ::new(a.get_num_columns(), 1);
        for i in 0..in_domain.get_num_rows() {
            in_domain.set_entry(i, 0, &value).unwrap();
        }

        assert!(psf.check_domain(&MatZ::new(a.get_num_columns(), 1)));
        assert!(psf.check_domain(&in_domain));
    }

    /// Ensures that `check_domain` rejects matrices, vectors of the wrong length
    /// and vectors whose norm exceeds `s * sqrt(m)`.
    #[test]
    fn check_domain_not_in_dn() {
        let psf = PSFGPV {
            gp: GadgetParameters::init_default(8, 128),
            s: Q::from(10),
        };
        let (a, _) = psf.trap_gen();

        let matrix = MatZ::new(a.get_num_columns(), 2);
        let too_short = MatZ::new(a.get_num_columns() - 1, 1);
        let too_long = MatZ::new(a.get_num_columns() + 1, 1);
        let entry_too_large =
            psf.s.round() * a.get_num_columns() * MatZ::identity(a.get_num_columns(), 1);

        assert!(!psf.check_domain(&matrix));
        assert!(!psf.check_domain(&too_long));
        assert!(!psf.check_domain(&too_short));
        assert!(!psf.check_domain(&entry_too_large));
    }
}