use super::{gadget_parameters::GadgetParametersRing, gadget_ring::find_solution_gadget_ring};
use qfall_math::{
integer::{MatPolyOverZ, PolyOverZ, Z},
integer_mod_q::{MatPolynomialRingZq, PolynomialRingZq, Zq},
traits::*,
};
/// Computes a short basis from a ring G-trapdoor `(r, e)` for the public row `a`.
///
/// The basis is the product of two factors,
/// ```text
/// S_A = [ I_2 | [e; r] ]  ·  [ [0; m ⊗ S] | m ⊗ [I_2; W] ]
///       [  0  |  I_m   ]
/// ```
/// where `m = [1, X, …, X^{n-1}]`, `S` is the gadget basis from `compute_s` and
/// `W` solves the gadget equation for the first two entries of `a` (`compute_w`).
/// Every entry of the product is afterwards reduced modulo the ring's modulus
/// polynomial, so all entries have degree below that polynomial's degree.
///
/// Parameters:
/// - `params`: gadget parameters (`base`, `k` and the ring modulus).
/// - `a`: the public row vector belonging to the trapdoor; only its entries
///   `(0, 0)` and `(0, 1)` are read.
/// - `r`, `e`: the trapdoor row vectors; they are stacked as `[e; r]`.
///
/// # Panics
/// Panics (via `unwrap`) if any intermediate concatenation fails — presumably
/// when `e` and `r` do not have the same number of columns.
pub fn gen_short_basis_for_trapdoor_ring(
    params: &GadgetParametersRing,
    a: &MatPolynomialRingZq,
    r: &MatPolyOverZ,
    e: &MatPolyOverZ,
) -> MatPolyOverZ {
    let modulus_poly = params
        .modulus
        .get_representative_least_nonnegative_residue();

    let mut short_basis = gen_sa_l(e, r) * gen_sa_r(params, a);
    // Entries of the product may exceed the ring degree; bring them back into
    // canonical (reduced) form.
    short_basis.reduce_by_poly(&modulus_poly);
    short_basis
}
/// Builds the left factor of the short basis:
/// ```text
/// [ 1 0 | e   ]
/// [ 0 1 | r   ]
/// [ 0 0 | I_m ]
/// ```
/// where `m` is the number of columns of `e` (and `r`).
///
/// # Panics
/// Panics (via `unwrap`) if `e` and `r` cannot be stacked vertically.
fn gen_sa_l(e: &MatPolyOverZ, r: &MatPolyOverZ) -> MatPolyOverZ {
    // Right column block [e; r; I_m].
    let trapdoor_rows = e.concat_vertical(r).unwrap();
    let width = trapdoor_rows.get_num_columns();
    let right_block = trapdoor_rows
        .concat_vertical(&MatPolyOverZ::identity(width, width))
        .unwrap();

    // Left column block: the first two unit columns [I_2; 0].
    let left_block = MatPolyOverZ::identity(right_block.get_num_rows(), 2);
    left_block.concat_horizontal(&right_block).unwrap()
}
/// Builds the right factor of the short basis:
/// ```text
/// [ [0_{2 × k·n}; m ⊗ S] | m ⊗ [I_2; W] ]
/// ```
/// with `m = [1, X, …, X^{n-1}]` (`n` = degree of the modulus polynomial),
/// `S = compute_s(params)` and `W = compute_w(params, a)`.
///
/// When `base^k == q`, the columns of `S` are reversed before use; otherwise `S`
/// is used as returned by `compute_s`.
///
/// The result is not reduced modulo the modulus polynomial.
///
/// # Panics
/// Panics (via `unwrap`) if a coefficient cannot be set, if `base^k` cannot be
/// computed, or if a concatenation fails due to mismatching dimensions.
fn gen_sa_r(params: &GadgetParametersRing, a: &MatPolynomialRingZq) -> MatPolyOverZ {
    let degree = params.modulus.get_degree();

    // Row vector of monomials [1, X, X^2, ..., X^{degree-1}].
    let mut monomials = MatPolyOverZ::new(1, degree);
    for exponent in 0..degree {
        let mut monomial = PolyOverZ::default();
        monomial.set_coeff(exponent, 1).unwrap();
        monomials.set_entry(0, exponent, monomial).unwrap();
    }

    let mut gadget_basis = compute_s(params);
    if params.base.pow(&params.k).unwrap() == params.modulus.get_q() {
        gadget_basis.reverse_columns();
    }

    // Left block: two zero rows on top of m ⊗ S.
    let left_block = MatPolyOverZ::new(2, &params.k * degree)
        .concat_vertical(&monomials.tensor_product(&gadget_basis))
        .unwrap();

    // Right block: m ⊗ [I_2; W].
    let identity_over_w = MatPolyOverZ::identity(2, 2)
        .concat_vertical(&compute_w(params, a))
        .unwrap();
    let right_block = monomials.tensor_product(&identity_over_w);

    left_block.concat_horizontal(&right_block).unwrap()
}
/// Computes `W = [w_0 | w_1]`, where each column `w_j` is the solution returned by
/// `find_solution_gadget_ring` for the target `-a[0, j]`, for `j ∈ {0, 1}`.
///
/// Only the entries `(0, 0)` and `(0, 1)` of `a` are read.
///
/// # Panics
/// Panics (via `unwrap`) if `a` has fewer than two entries in row 0 or if the
/// two solution columns cannot be concatenated.
fn compute_w(params: &GadgetParametersRing, a: &MatPolynomialRingZq) -> MatPolyOverZ {
    let minus_one = PolynomialRingZq::from((&PolyOverZ::from(-1), &params.modulus));

    // Solve the gadget equation for -a[0, column]; transposing turns the
    // returned solution into a column of W.
    let solve_for_column = |column: i64| {
        let target: PolynomialRingZq = a.get_entry(0, column).unwrap();
        find_solution_gadget_ring(&(&minus_one * &target), &params.k, &params.base).transpose()
    };

    let first_column = solve_for_column(0);
    let second_column = solve_for_column(1);
    first_column.concat_horizontal(&second_column).unwrap()
}
/// Computes the `k × k` gadget basis `S`.
///
/// `S` has `base` on the diagonal and `-1` on the sub-diagonal. If `base^k`
/// equals the modulus `q`, that matrix is returned as is. Otherwise the last
/// column is overwritten with the base-`base` digits of `q`, least significant
/// digit in row 0.
///
/// # Panics
/// Panics (via `unwrap`) if an entry cannot be set, if `base^k` cannot be
/// computed, or if an exact division by `base` fails.
fn compute_s(params: &GadgetParametersRing) -> MatPolyOverZ {
    let mut s = &params.base * MatPolyOverZ::identity(&params.k, &params.k);
    let last_column = s.get_num_columns() - 1;

    // -1 directly below every diagonal entry.
    for row in 1..s.get_num_rows() {
        s.set_entry(row, row - 1, PolyOverZ::from(-1)).unwrap();
    }

    let q = params.modulus.get_q();
    if params.base.pow(&params.k).unwrap() == q {
        return s;
    }

    // q is not a power of the base: write its digits into the last column.
    let mut remaining = q;
    for row in 0..s.get_num_rows() {
        let digit = Zq::from((&remaining, &params.base))
            .get_representative_least_nonnegative_residue();
        s.set_entry(row, last_column, PolyOverZ::from(&digit))
            .unwrap();
        remaining -= digit;
        remaining = remaining.div_exact(&params.base).unwrap();
    }
    s
}
#[cfg(test)]
mod test_gen_short_basis_for_trapdoor_ring {
    use super::gen_short_basis_for_trapdoor_ring;
    use crate::sample::g_trapdoor::{
        gadget_parameters::GadgetParametersRing, gadget_ring::gen_trapdoor_ring_lwe,
    };
    use qfall_math::{
        integer::{MatPolyOverZ, MatZ, PolyOverZ, Z},
        integer_mod_q::MatPolynomialRingZq,
        rational::{MatQ, Q},
        traits::*,
    };

    /// Samples a fresh trapdoor `(a, r, e)` for `params` with a uniform `a_bar`.
    fn sample_trapdoor(
        params: &GadgetParametersRing,
    ) -> (MatPolynomialRingZq, MatPolyOverZ, MatPolyOverZ) {
        let a_bar = PolyOverZ::sample_uniform(&params.n, 0, params.modulus.get_q()).unwrap();
        gen_trapdoor_ring_lwe(params, &a_bar, 5).unwrap()
    }

    /// Returns the largest Euclidean column norm of `embedded` (zero if it has
    /// no columns).
    fn max_column_norm(embedded: &MatZ) -> Q {
        let mut max = Q::ZERO;
        for i in 0..embedded.get_num_columns() {
            let norm = embedded
                .get_column(i)
                .unwrap()
                .norm_eucl_sqrd()
                .unwrap()
                .sqrt();
            if norm > max {
                max = norm;
            }
        }
        max
    }

    /// For every `n` in `4..8` and modulus `q`, asserts that each Gram-Schmidt
    /// vector of the short basis has length at most
    /// `(s1(r) + s1(e) + 1) * orth_s_length`. Here `s1(·)` is approximated by the
    /// largest column norm of the coefficient embedding.
    fn assert_orthogonalized_length_bounded(q: &Z, orth_s_length: &Q) {
        for n in 4..8 {
            let params = GadgetParametersRing::init_default(n, q.clone());
            let (a, r, e) = sample_trapdoor(&params);
            let short_base = gen_short_basis_for_trapdoor_ring(&params, &a, &r, &e);
            let orthogonalized_short_basis =
                MatQ::from(&short_base.into_coefficient_embedding(n)).gso();

            let s1_r = max_column_norm(&r.into_coefficient_embedding(n));
            let s1_e = max_column_norm(&e.into_coefficient_embedding(n));
            let upper_bound: Q = &(s1_r + s1_e + 1) * orth_s_length;
            // Compare squared norms to avoid a square root per column.
            let upper_bound_sqrd = upper_bound.pow(2).unwrap();

            for i in 0..orthogonalized_short_basis.get_num_columns() {
                let b_tilde_i = orthogonalized_short_basis.get_column(i).unwrap();
                assert!(b_tilde_i.norm_eucl_sqrd().unwrap() <= upper_bound_sqrd)
            }
        }
    }

    /// `a * S_A` must vanish and `S_A` must have `n` columns per entry of `a`.
    #[test]
    fn is_basis() {
        for n in [5, 10, 12] {
            let params = GadgetParametersRing::init_default(n, 16);
            let (a, r, e) = sample_trapdoor(&params);
            let short_base = gen_short_basis_for_trapdoor_ring(&params, &a, &r, &e);
            let short_base = MatPolynomialRingZq::from((&short_base, &params.modulus));

            assert_eq!(n * a.get_num_columns(), short_base.get_num_columns());

            let res = a * short_base;
            for i in 0..res.get_num_columns() {
                let entry: PolyOverZ = res.get_entry(0, i).unwrap();
                assert!(entry.is_zero())
            }
        }
    }

    /// All entries of the returned basis are reduced modulo the modulus polynomial.
    #[test]
    fn basis_is_reduced() {
        for n in [5, 10, 12] {
            let params = GadgetParametersRing::init_default(n, 16);
            let (a, r, e) = sample_trapdoor(&params);
            let short_base = gen_short_basis_for_trapdoor_ring(&params, &a, &r, &e);

            for i in 0..short_base.get_num_rows() {
                for j in 0..short_base.get_num_columns() {
                    let entry = short_base.get_entry(i, j).unwrap();
                    assert!(entry.get_degree() < n)
                }
            }
        }
    }

    /// For `q = base^k`, the Gram-Schmidt norm of the gadget basis is `base = 2`.
    #[test]
    fn ensure_orthogonalized_length_perfect_power() {
        assert_orthogonalized_length_bounded(&Z::from(32), &Q::from(2));
    }

    /// For arbitrary `q`, the Gram-Schmidt norm of the gadget basis is bounded
    /// by `sqrt(base^2 + 1) = sqrt(5)`.
    #[test]
    fn ensure_orthogonalized_length_not_perfect_power() {
        assert_orthogonalized_length_bounded(&Z::from(42), &Q::from(5).sqrt());
    }
}
#[cfg(test)]
mod test_gen_sa {
    use crate::sample::g_trapdoor::{
        gadget_parameters::GadgetParametersRing,
        short_basis_ring::{gen_sa_l, gen_sa_r},
    };
    use qfall_math::{
        integer::{MatPolyOverZ, MatZ},
        integer_mod_q::MatPolynomialRingZq,
        traits::IntoCoefficientEmbedding,
    };
    use std::str::FromStr;

    /// Returns a fixed trapdoor `(params, a, r, e)` for `n = 4`, `q = 16`.
    /// It satisfies `a = [1, a_bar, g^T - (a_bar * r + e)]` over the ring.
    fn get_fixed_trapdoor() -> (
        GadgetParametersRing,
        MatPolynomialRingZq,
        MatPolyOverZ,
        MatPolyOverZ,
    ) {
        let params = GadgetParametersRing::init_default(4, 16);
        let a = MatPolyOverZ::from_str(
            "[[1 1, 4 2 8 8 12, 4 11 10 7 13, 4 9 6 6 12, 4 6 11 1 6, 4 3 10 2 9]]",
        )
        .unwrap();
        let a = MatPolynomialRingZq::from((&a, &params.modulus));
        let r = MatPolyOverZ::from_str("[[4 -1 7 6 -8, 3 0 -2 4, 4 0 3 -4 1, 4 6 4 -1 3]]")
            .unwrap();
        let e =
            MatPolyOverZ::from_str("[[4 -4 8 -3 7, 4 1 -2 2 4, 3 -6 7 -5, 4 -7 10 -12 -15]]")
                .unwrap();
        (params, a, r, e)
    }

    /// `gen_sa_l(e, r)` must place `e` in row 0 and `r` in row 1, exactly as
    /// `gen_short_basis_for_trapdoor_ring` calls it.
    #[test]
    fn working_sa_l() {
        let (_, _, r, e) = get_fixed_trapdoor();
        let sa_l = gen_sa_l(&e, &r);
        let sa_l_cmp = MatPolyOverZ::from_str(
            "[\
            [1 1, 0, 4 -4 8 -3 7, 4 1 -2 2 4, 3 -6 7 -5, 4 -7 10 -12 -15],\
            [0, 1 1, 4 -1 7 6 -8, 3 0 -2 4, 4 0 3 -4 1, 4 6 4 -1 3],\
            [0, 0, 1 1, 0, 0, 0],\
            [0, 0, 0, 1 1, 0, 0],\
            [0, 0, 0, 0, 1 1, 0],\
            [0, 0, 0, 0, 0, 1 1]]",
        )
        .unwrap();
        assert_eq!(sa_l_cmp, sa_l)
    }

    /// The reduced right factor matches a hand-computed coefficient embedding.
    #[test]
    fn working_sa_r() {
        let (params, a, _, _) = get_fixed_trapdoor();
        let mut sa_r = gen_sa_r(&params, &a);
        sa_r.reduce_by_poly(
            &params
                .modulus
                .get_representative_least_nonnegative_residue(),
        );
        let sa_r_cmp = MatZ::from_str(
            "[\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],\
            [0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 1, 0],\
            [0, 0, 2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, 0, 0, 0, 0, 0, 0, 1, 1],\
            [0, 2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, -1, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, -1, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, -1],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, 0, 0, 1, 0, 0, 0, 0, 1, 1],\
            [2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, -1, 0, -1],\
            [0, 0, 0, 0, 2, -1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0, 0, -1],\
            [0, 0, 0, 0, 0, 0, 0, 0, 2, -1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, -1, 0, 0, 0, 0, 0, 1, 0, 1, 1, 1]]",
        )
        .unwrap();
        assert_eq!(sa_r_cmp, sa_r.into_coefficient_embedding(4));
    }
}
#[cfg(test)]
mod test_compute_s {
    use crate::sample::g_trapdoor::{
        gadget_parameters::GadgetParametersRing, short_basis_ring::compute_s,
    };
    use qfall_math::integer::{MatPolyOverZ, Z};
    use std::str::FromStr;

    /// Asserts that `compute_s(params)` equals the matrix written in `expected`.
    fn assert_s_matches(params: &GadgetParametersRing, expected: &str) {
        let expected = MatPolyOverZ::from_str(expected).unwrap();
        assert_eq!(expected, compute_s(params))
    }

    /// Overrides `base` and `k` of `params` (the defaults are base 2).
    fn with_base_and_k(mut params: GadgetParametersRing, base: i64, k: i64) -> GadgetParametersRing {
        params.k = Z::from(k);
        params.base = Z::from(base);
        params
    }

    #[test]
    fn base_2_power_two() {
        let params = GadgetParametersRing::init_default(8, 16);
        assert_s_matches(
            &params,
            "[[1 2, 0, 0, 0],\
            [1 -1, 1 2, 0, 0],\
            [0, 1 -1, 1 2, 0],\
            [0, 0, 1 -1, 1 2]]",
        );
    }

    #[test]
    fn base_2_arbitrary() {
        // 0b1100110 = 102; its binary digits (LSB first) form the last column.
        let params = GadgetParametersRing::init_default(1, Z::from(0b1100110));
        assert_s_matches(
            &params,
            "[[1 2, 0, 0, 0, 0, 0, 0],\
            [1 -1, 1 2, 0, 0, 0, 0, 1 1],\
            [0, 1 -1, 1 2, 0, 0, 0, 1 1],\
            [0, 0, 1 -1, 1 2, 0, 0, 0],\
            [0, 0, 0, 1 -1, 1 2, 0, 0],\
            [0, 0, 0, 0, 1 -1, 1 2, 1 1],\
            [0, 0, 0, 0, 0, 1 -1, 1 1]]",
        );
    }

    #[test]
    fn base_5_power_5() {
        let params = with_base_and_k(GadgetParametersRing::init_default(1, 625), 5, 4);
        assert_s_matches(
            &params,
            "[[1 5, 0, 0, 0],\
            [1 -1, 1 5, 0, 0],\
            [0, 1 -1, 1 5, 0],\
            [0, 0, 1 -1, 1 5]]",
        );
    }

    #[test]
    fn base_5_arbitrary() {
        let q = Z::from_str_b("4123", 5).unwrap();
        let params = with_base_and_k(GadgetParametersRing::init_default(1, q), 5, 4);
        assert_s_matches(
            &params,
            "[[1 5, 0, 0, 1 3],\
            [1 -1, 1 5, 0, 1 2],\
            [0, 1 -1, 1 5, 1 1],\
            [0, 0, 1 -1, 1 4]]",
        );
    }
}
#[cfg(test)]
mod test_compute_w {
    use crate::sample::g_trapdoor::{
        gadget_parameters::GadgetParametersRing,
        gadget_ring::{gen_gadget_ring, gen_trapdoor_ring_lwe},
        short_basis_ring::compute_w,
    };
    use qfall_math::{
        integer::{MatPolyOverZ, PolyOverZ},
        integer_mod_q::MatPolynomialRingZq,
        traits::MatrixDimensions,
    };

    /// `compute_w` must return `W` with `g^T * W = -[a_0 | a_1]` in the ring.
    #[test]
    fn check_w_is_correct_solution() {
        let params = GadgetParametersRing::init_default(8, 16);
        let a_bar = PolyOverZ::sample_uniform(&params.n, 0, params.modulus.get_q()).unwrap();
        let (a, _, _) = gen_trapdoor_ring_lwe(&params, &a_bar, 5).unwrap();

        // Left-hand side: g^T * W, both lifted into the polynomial ring.
        let w_in_ring = MatPolynomialRingZq::from((&compute_w(&params, &a), &params.modulus));
        let gadget_in_ring = MatPolynomialRingZq::from((
            &gen_gadget_ring(&params.k, &params.base),
            &params.modulus,
        ));
        let lhs = gadget_in_ring.transpose() * w_in_ring;

        // Right-hand side: a * (-I), i.e. the first two entries of a, negated.
        let negated_selector = -1 * MatPolyOverZ::identity(a.get_num_columns(), 2);
        let negated_selector = MatPolynomialRingZq::from((&negated_selector, &params.modulus));
        let rhs = a * negated_selector;

        assert_eq!(lhs, rhs)
    }
}