use super::PSF;
use crate::{
    sample::g_trapdoor::{
        gadget_parameters::GadgetParametersRing, gadget_ring::gen_trapdoor_ring_lwe,
        short_basis_ring::gen_short_basis_for_trapdoor_ring,
    },
    utils::rotation_matrix::rot_minus_matrix,
};
use qfall_math::{
    integer::{MatPolyOverZ, MatZ, PolyOverZ},
    integer_mod_q::{MatPolynomialRingZq, MatZq},
    rational::{MatQ, PolyOverQ, Q},
    traits::{
        FromCoefficientEmbedding, IntoCoefficientEmbedding, MatrixDimensions, MatrixGetEntry,
        MatrixGetSubmatrix, Pow,
    },
};
use serde::{Deserialize, Serialize};
/// A GPV-style preimage-sampleable function (PSF) over a polynomial ring.
///
/// The function is `f_a(sigma) = a * sigma`, evaluated in the ring described by
/// `gp.modulus`. Trapdoors are generated with `gen_trapdoor_ring_lwe`.
#[derive(Serialize, Deserialize)]
pub struct PSFGPVRing {
    /// Gadget parameters (modulus polynomial, `q`, `k`, ...) passed to the trapdoor
    /// generation and short-basis routines.
    pub gp: GadgetParametersRing,
    /// Gaussian parameter used by `samp_d` and `samp_p`, and in the norm bound of
    /// `check_domain`.
    pub s: Q,
    /// Gaussian parameter passed to `gen_trapdoor_ring_lwe` when generating the trapdoor.
    pub s_td: Q,
}
impl PSF for PSFGPVRing {
type A = MatPolynomialRingZq;
type Trapdoor = (MatPolyOverZ, MatPolyOverZ);
type Domain = MatPolyOverZ;
type Range = MatPolynomialRingZq;
fn trap_gen(&self) -> (MatPolynomialRingZq, (MatPolyOverZ, MatPolyOverZ)) {
let a_bar =
PolyOverZ::sample_uniform(self.gp.modulus.get_degree() - 1, 0, self.gp.modulus.get_q())
.unwrap();
let (a, r, e) = gen_trapdoor_ring_lwe(&self.gp, &a_bar, &self.s_td).unwrap();
(a, (r, e))
}
fn samp_d(&self) -> MatPolyOverZ {
let dimension = self.gp.modulus.get_degree() * (&self.gp.k + 2);
let sample = MatZ::sample_discrete_gauss(dimension, 1, 0, &self.s).unwrap();
MatPolyOverZ::from_coefficient_embedding((&sample, self.gp.modulus.get_degree() - 1))
}
fn samp_p(
&self,
a: &MatPolynomialRingZq,
(r, e): &(MatPolyOverZ, MatPolyOverZ),
u: &MatPolynomialRingZq,
) -> MatPolyOverZ {
let short_basis = gen_short_basis_for_trapdoor_ring(&self.gp, a, r, e);
let u_embedded = u
.get_representative_least_nonnegative_residue()
.into_coefficient_embedding(self.gp.modulus.get_degree());
let a_embedded = a
.get_representative_least_nonnegative_residue()
.into_coefficient_embedding(self.gp.modulus.get_degree());
let rot_a = rot_minus_matrix(&a_embedded);
let u_embedded = MatZq::from((&u_embedded, &self.gp.modulus.get_q()));
let rot_a = MatZq::from((&rot_a, &self.gp.modulus.get_q()));
let sol: MatZ = rot_a
.solve_gaussian_elimination(&u_embedded)
.unwrap()
.get_representative_least_nonnegative_residue();
let center = MatQ::from(&(-1 * &sol));
let mut center_embedded = Vec::new();
for block in 0..(center.get_num_rows() / (self.gp.modulus.get_degree())) {
let sub_mat = center
.get_submatrix(
block * self.gp.modulus.get_degree(),
(block + 1) * self.gp.modulus.get_degree() - 1,
0,
0,
)
.unwrap();
let embedded_sub_mat = PolyOverQ::from_coefficient_embedding(&sub_mat);
center_embedded.push(embedded_sub_mat);
}
MatPolyOverZ::from_coefficient_embedding((&sol, self.gp.modulus.get_degree() - 1))
+ MatPolyOverZ::sample_d(
&short_basis,
self.gp.modulus.get_degree(),
¢er_embedded,
&self.s,
)
.unwrap()
}
fn f_a(&self, a: &MatPolynomialRingZq, sigma: &MatPolyOverZ) -> MatPolynomialRingZq {
assert!(self.check_domain(sigma));
let sigma = MatPolynomialRingZq::from((sigma, &a.get_mod()));
a * sigma
}
fn check_domain(&self, sigma: &MatPolyOverZ) -> bool {
let m = &self.gp.k + 2;
let nr_coeffs = self.gp.modulus.get_degree();
let sigma_embedded = sigma.into_coefficient_embedding(nr_coeffs);
sigma.is_column_vector()
&& m == sigma.get_num_rows()
&& sigma_embedded.norm_eucl_sqrd().unwrap()
<= self.s.pow(2).unwrap() * sigma_embedded.get_num_rows()
}
}
#[cfg(test)]
mod test_gpv_psf {
    use super::super::gpv_ring::PSFGPVRing;
    use super::PSF;
    use crate::sample::g_trapdoor::gadget_parameters::GadgetParametersRing;
    use qfall_math::integer::{MatPolyOverZ, PolyOverZ};
    use qfall_math::integer_mod_q::MatPolynomialRingZq;
    use qfall_math::rational::Q;
    use qfall_math::traits::*;

    /// Gaussian parameter used by the tests below, as a function of the ring degree `n`.
    fn compute_s(n: i64) -> Q {
        ((2 * 2 * Q::from(1.005_f64) * Q::from(n).sqrt() + 1) * 2) * 4
    }

    /// Samples produced by `samp_d` are accepted by `check_domain`.
    #[test]
    fn samp_d_samples_from_dn() {
        let (n, q) = (5, 123456789);
        let psf = PSFGPVRing {
            gp: GadgetParametersRing::init_default(n, q),
            s: Q::from(1000),
            s_td: Q::from(1.005_f64),
        };

        for _ in 0..5 {
            assert!(psf.check_domain(&psf.samp_d()));
        }
    }

    /// `samp_p` returns a preimage of the given image that is also in the domain.
    #[test]
    fn samp_p_preimage_and_domain() {
        for (n, q) in [(5, i32::MAX - 57), (6, i32::MAX)] {
            let psf = PSFGPVRing {
                gp: GadgetParametersRing::init_default(n, q),
                s: compute_s(n),
                s_td: Q::from(1.005_f64),
            };
            let (a, td) = psf.trap_gen();
            let domain_sample = psf.samp_d();
            let range_fa = psf.f_a(&a, &domain_sample);

            let preimage = psf.samp_p(&a, &td, &range_fa);
            assert_eq!(range_fa, psf.f_a(&a, &preimage));
            assert!(psf.check_domain(&preimage));
        }
    }

    /// `f_a` agrees with direct multiplication `a * sigma` in the ring.
    #[test]
    fn f_a_works_as_expected() {
        for (n, q) in [(5, 256), (6, 128)] {
            let psf = PSFGPVRing {
                gp: GadgetParametersRing::init_default(n, q),
                s: compute_s(n),
                s_td: Q::from(1.005_f64),
            };
            let (a, _) = psf.trap_gen();
            let domain_sample = psf.samp_d();
            let domain_sample_2 = MatPolynomialRingZq::from((&domain_sample, &a.get_mod()));

            assert_eq!(&a * &domain_sample_2, psf.f_a(&a, &domain_sample));
        }
    }

    /// `f_a` panics on a matrix with more than one column.
    #[test]
    #[should_panic]
    fn f_a_sigma_not_in_domain_matrix() {
        let psf = PSFGPVRing {
            gp: GadgetParametersRing::init_default(8, 1024),
            s: compute_s(8),
            s_td: Q::from(1.005_f64),
        };
        let (a, _) = psf.trap_gen();
        let not_in_domain = MatPolyOverZ::new(a.get_num_columns(), 2);

        let _ = psf.f_a(&a, &not_in_domain);
    }

    /// `f_a` panics on a column vector with too few rows.
    #[test]
    #[should_panic]
    fn f_a_sigma_not_in_domain_incorrect_length() {
        let psf = PSFGPVRing {
            gp: GadgetParametersRing::init_default(8, 1024),
            s: compute_s(8),
            s_td: Q::from(1.005_f64),
        };
        let (a, _) = psf.trap_gen();
        let not_in_domain = MatPolyOverZ::new(a.get_num_columns() - 1, 1);

        let _ = psf.f_a(&a, &not_in_domain);
    }

    /// `f_a` panics on a correctly shaped vector whose norm exceeds the domain bound.
    #[test]
    #[should_panic]
    fn f_a_sigma_not_in_domain_entry_too_large() {
        let psf = PSFGPVRing {
            gp: GadgetParametersRing::init_default(8, 1024),
            s: compute_s(8),
            s_td: Q::from(1.005_f64),
        };
        let (a, _) = psf.trap_gen();
        let not_in_domain = psf.s.round()
            * a.get_num_columns()
            * 8
            * MatPolyOverZ::identity(a.get_num_columns(), 1);

        let _ = psf.f_a(&a, &not_in_domain);
    }

    /// `check_domain` accepts the zero vector and a vector exactly on the norm bound.
    #[test]
    fn check_domain_as_expected() {
        let psf = PSFGPVRing {
            gp: GadgetParametersRing::init_default(9, 1024),
            s: compute_s(9),
            s_td: Q::from(1.005_f64),
        };
        let (a, _) = psf.trap_gen();
        let value = PolyOverZ::from(psf.s.round() * 3);
        let mut in_domain = MatPolyOverZ::new(a.get_num_columns(), 1);
        for i in 0..in_domain.get_num_rows() {
            in_domain.set_entry(i, 0, &value).unwrap();
        }

        assert!(psf.check_domain(&MatPolyOverZ::new(a.get_num_columns(), 1)));
        assert!(psf.check_domain(&in_domain));
    }

    /// `check_domain` rejects wrong shapes and over-long vectors.
    #[test]
    fn check_domain_not_in_dn() {
        let psf = PSFGPVRing {
            gp: GadgetParametersRing::init_default(8, 1024),
            s: compute_s(8),
            s_td: Q::from(1.005_f64),
        };
        let (a, _) = psf.trap_gen();

        let matrix = MatPolyOverZ::new(a.get_num_columns(), 2);
        let too_short = MatPolyOverZ::new(a.get_num_columns() - 1, 1);
        let too_long = MatPolyOverZ::new(a.get_num_columns() + 1, 1);
        let entry_too_large = psf.s.round()
            * a.get_num_columns()
            * 8
            * MatPolyOverZ::identity(a.get_num_columns(), 1);

        assert!(!psf.check_domain(&matrix));
        assert!(!psf.check_domain(&too_long));
        assert!(!psf.check_domain(&too_short));
        assert!(!psf.check_domain(&entry_too_large));
    }
}