use super::{gadget_classical::find_solution_gadget_mat, gadget_parameters::GadgetParametersRing};
use qfall_math::{
error::MathError,
integer::{MatPolyOverZ, PolyOverZ, Z},
integer_mod_q::{MatPolynomialRingZq, MatZq, PolynomialRingZq},
rational::Q,
traits::{
Concatenate, IntoCoefficientEmbedding, MatrixDimensions, MatrixGetEntry, MatrixSetEntry,
Pow, SetCoefficient,
},
};
use std::fmt::Display;
/// Generates a ring-LWE based G-trapdoor together with its public matrix.
///
/// The public matrix is `A = [1 | a_bar | g^t - (a_bar * r + e)]`, where `g` is the
/// gadget vector returned by [`gen_gadget_ring`]. `r` and `e` are sampled
/// independently from `params.distribution` using the parameter `s`. With this
/// layout, `A * [e; r; I_k] = g^t`.
///
/// Parameters:
/// - `params`: the gadget parameters (`n`, `k`, `base`, `modulus` and the sampling
///   `distribution`)
/// - `a_bar`: the polynomial placed in the second column of `A`
/// - `s`: the parameter passed to the distribution when sampling `r` and `e`
///
/// Returns `(A, r, e)`, where `A` is reduced modulo `params.modulus`.
///
/// # Errors
/// Returns a [`MathError`] if setting an entry of `A` or the horizontal
/// concatenation with the gadget part fails.
pub fn gen_trapdoor_ring_lwe(
    params: &GadgetParametersRing,
    a_bar: &PolyOverZ,
    s: impl Into<Q>,
) -> Result<(MatPolynomialRingZq, MatPolyOverZ, MatPolyOverZ), MathError> {
    let s = s.into();
    // `r` and `e` are drawn independently with identical parameters.
    let r = params.distribution.sample(&params.n, &params.k, &s);
    let e = params.distribution.sample(&params.n, &params.k, &s);

    // Leading columns of A: [1 | a_bar].
    let mut big_a = MatPolyOverZ::new(1, 2);
    big_a.set_entry(0, 0, &PolyOverZ::from(1))?;
    big_a.set_entry(0, 1, a_bar)?;

    // Trailing k columns: g^t - (a_bar * r + e). The `-(a_bar * r + e)` term cancels
    // the contribution of the first two columns when A is multiplied by [e; r; I_k].
    let g = gen_gadget_ring(&params.k, &params.base);
    big_a = big_a.concat_horizontal(&(g.transpose() - (a_bar * &r + &e)))?;

    Ok((MatPolynomialRingZq::from((&big_a, &params.modulus)), r, e))
}
/// Generates the ring gadget vector `g = (1, base, base^2, ..., base^(k-1))^t`.
///
/// The result is a `k x 1` matrix whose entries are constant polynomials.
///
/// Parameters:
/// - `k`: the number of entries of the gadget vector
/// - `base`: the base of the gadget decomposition
///
/// # Panics
/// Panics if `k` is not a valid matrix dimension for [`MatPolyOverZ::new`], or if
/// computing a power of `base` or setting an entry unexpectedly fails.
pub fn gen_gadget_ring(k: impl TryInto<i64> + Display, base: &Z) -> MatPolyOverZ {
    let mut out = MatPolyOverZ::new(k, 1);
    for j in 0..out.get_num_rows() {
        let power = base.pow(j).expect("exponent j is never negative");
        // `j` only ranges over the rows of `out`, so the checked setter cannot fail.
        // Using it avoids an `unsafe` block that carried no safety justification.
        out.set_entry(j, 0, &PolyOverZ::from(power))
            .expect("row index j lies within the k x 1 matrix");
    }
    out
}
/// Computes a preimage of `u` under the ring gadget vector.
///
/// The preimage is a `1 x k` row `x` of polynomials such that multiplying the gadget
/// vector by `x` yields `u`.
///
/// How it works:
/// - The least non-negative representative of `u` is embedded coefficient-wise into
///   a vector over `Z_q`.
/// - That vector is decomposed with the classical solver [`find_solution_gadget_mat`].
/// - The result is reassembled into `k` polynomials. Coefficient `j` of output
///   polynomial `i` is taken from row `i + j * k` of the classical solution.
///
/// Parameters:
/// - `u`: the target ring element
/// - `k`: the length of the gadget vector
/// - `base`: the base of the gadget decomposition
///
/// # Panics
/// Panics if `k` does not fit into an `i64`, or if reading an entry, setting a
/// coefficient or setting a matrix entry fails.
pub fn find_solution_gadget_ring(u: &PolynomialRingZq, k: &Z, base: &Z) -> MatPolyOverZ {
    let gadget_len = i64::try_from(k).unwrap();
    let ring_modulus = u.get_mod();
    let degree = ring_modulus.get_degree();

    // Coefficient vector of `u`, lifted to Z_q so the classical solver can decompose it.
    let embedded = MatZq::from((
        &u.get_representative_least_nonnegative_residue()
            .into_coefficient_embedding(degree),
        ring_modulus.get_q(),
    ));
    let decomposition = find_solution_gadget_mat(&embedded, k, base);

    let mut solution = MatPolyOverZ::new(1, k);
    for column in 0..gadget_len {
        let mut poly = PolyOverZ::default();
        for coeff_idx in 0..degree {
            // Rows for coefficient `coeff_idx` start at `coeff_idx * k`; `column`
            // selects the digit within that block.
            let digit = decomposition
                .get_entry(column + coeff_idx * gadget_len, 0)
                .unwrap();
            poly.set_coeff(coeff_idx, &digit).unwrap();
        }
        solution.set_entry(0, column, &poly).unwrap();
    }
    solution
}
#[cfg(test)]
mod test_gen_trapdoor_ring {
    use crate::sample::g_trapdoor::{
        gadget_parameters::GadgetParametersRing, gadget_ring::gen_trapdoor_ring_lwe,
    };
    use qfall_math::{
        integer::{MatPolyOverZ, PolyOverZ, Z},
        integer_mod_q::MatPolynomialRingZq,
        traits::{Concatenate, GetCoefficient, MatrixDimensions, MatrixGetEntry, Pow},
    };

    /// Stacks the sampled components into the trapdoor matrix `[e; r; I_k]`.
    fn compute_trapdoor(r: &MatPolyOverZ, e: &MatPolyOverZ, k: &Z) -> MatPolyOverZ {
        let i_k = MatPolyOverZ::identity(k, k);
        e.concat_vertical(r).unwrap().concat_vertical(&i_k).unwrap()
    }

    /// Ensures that `A * [e; r; I_k]` is a `1 x k` row whose `i`-th entry has
    /// constant coefficient `base^i`.
    #[test]
    fn is_trapdoor() {
        let params = GadgetParametersRing::init_default(6, 32);
        let a_bar = PolyOverZ::sample_uniform(&params.n, 0, params.modulus.get_q()).unwrap();

        let (a, r, e) = gen_trapdoor_ring_lwe(&params, &a_bar, 10).unwrap();

        let trapdoor =
            MatPolynomialRingZq::from((&compute_trapdoor(&r, &e, &params.k), &params.modulus));
        let res: MatPolynomialRingZq = &a * &trapdoor;

        assert_eq!(params.k, Z::from(res.get_num_columns()));
        assert_eq!(1, res.get_num_rows());
        for i in 0..i64::try_from(&params.k).unwrap() {
            let res_entry: PolyOverZ = res.get_entry(0, i).unwrap();
            assert_eq!(res_entry.get_coeff(0).unwrap(), params.base.pow(i).unwrap())
        }
    }
}
#[cfg(test)]
mod test_find_solution_gadget_ring {
    use super::{find_solution_gadget_ring, gen_gadget_ring};
    use crate::sample::g_trapdoor::gadget_parameters::GadgetParametersRing;
    use qfall_math::{
        integer::PolyOverZ,
        integer_mod_q::{MatPolynomialRingZq, PolynomialRingZq},
    };
    use std::str::FromStr;

    /// Checks that the computed preimage maps back onto `u` under the gadget vector,
    /// with all arithmetic done in the ring given by the parameters' modulus.
    #[test]
    fn is_correct_solution() {
        let params = GadgetParametersRing::init_default(3, 32);

        let gadget_vec = MatPolynomialRingZq::from((
            &gen_gadget_ring(&params.k, &params.base),
            &params.modulus,
        ));

        let target_poly = PolyOverZ::from_str("10 5 124 12 14 14 1 2 4 1 5").unwrap();
        let target = PolynomialRingZq::from((&target_poly, &params.modulus));

        let preimage = MatPolynomialRingZq::from((
            &find_solution_gadget_ring(&target, &params.k, &params.base),
            &params.modulus,
        ));

        assert_eq!(target, gadget_vec.dot_product(&preimage).unwrap())
    }
}