use super::gadget_parameters::GadgetParameters;
use qfall_math::{
error::MathError,
integer::{MatZ, Z},
integer_mod_q::{MatZq, Zq},
traits::*,
};
use std::fmt::Display;
/// Generates a gadget-based trapdoor `(A, R)` from the parameters in `params`.
///
/// `R` is sampled from `params.distribution` with dimensions
/// `m_bar x (n * k)`. `A` is assembled as
/// `A = [A_bar | tag * G - A_bar * R]`, where `G` is the gadget matrix from
/// [`gen_gadget_mat`]. Then `A * [R; I] = tag * G`.
///
/// # Parameters
/// - `params`: gadget parameters (`n`, `k`, `base`, `m_bar`, `distribution`, ...).
/// - `a_bar`: the left part `A_bar` of the resulting matrix.
/// - `tag`: the tag matrix that multiplies the gadget matrix.
///
/// # Returns
/// The pair `(A, R)`.
///
/// # Errors
/// Propagates the error of `concat_horizontal` if `A_bar` and
/// `tag * G - A_bar * R` cannot be concatenated (presumably on a row-count
/// mismatch — see `qfall_math` docs).
///
/// # Panics
/// Panics if `params.n` does not fit into an `i64`.
pub fn gen_trapdoor(
    params: &GadgetParameters,
    a_bar: &MatZq,
    tag: &MatZq,
) -> Result<(MatZq, MatZ), MathError> {
    let n: i64 = (&params.n)
        .try_into()
        .expect("params.n must fit into an i64");
    let g = gen_gadget_mat(n, &params.k, &params.base);
    let r = params
        .distribution
        .sample(&params.m_bar, &(&params.n * &params.k));
    // A = [A_bar | tag * G - A_bar * R]
    let a = a_bar.concat_horizontal(&(tag * g - a_bar * &r))?;
    Ok((a, r))
}
/// Builds the gadget matrix `G = I_n ⊗ g^t` of dimension `n x (n * k)`.
///
/// Row `i` holds the row vector `g^t = [1, base, ..., base^{k-1}]` in the
/// columns `i * k .. (i + 1) * k`. Every other entry is zero.
///
/// # Parameters
/// - `n`: number of rows.
/// - `k`: length of the gadget vector.
/// - `base`: base of the gadget vector.
pub fn gen_gadget_mat(n: i64, k: impl TryInto<i64> + Display, base: &Z) -> MatZ {
    let g_row = gen_gadget_vec(k, base).transpose();
    let width = g_row.get_num_columns();
    let mut gadget = MatZ::new(n, n * width);
    for row in 0..gadget.get_num_rows() {
        // Copy the whole of `g_row` into row `row`, starting at column `row * width`.
        let first_col = row * width;
        gadget
            .set_submatrix(row, first_col, &g_row, 0, 0, -1, -1)
            .unwrap();
    }
    gadget
}
/// Builds the gadget column vector `g = [1, base, base^2, ..., base^{k-1}]^t`.
///
/// # Parameters
/// - `k`: number of entries.
/// - `base`: the base whose successive powers fill the vector.
pub fn gen_gadget_vec(k: impl TryInto<i64> + Display, base: &Z) -> MatZ {
    let mut gadget = MatZ::new(k, 1);
    let mut power = Z::ONE;
    for row in 0..gadget.get_num_rows() {
        gadget.set_entry(row, 0, &power).unwrap();
        // Advance to the next power of `base`.
        power *= base;
    }
    gadget
}
/// Decomposes `value` into `k` digits in base `base`.
///
/// The least significant digit comes first. The digits are returned as a
/// `k x 1` column vector `x` such that `g^t * x` equals the least non-negative
/// representative of `value`, where `g = gen_gadget_vec(k, base)`.
///
/// # Panics
/// Panics if `base^k` is smaller than the modulus of `value`, because then not
/// every residue is guaranteed to be representable with `k` digits.
pub fn find_solution_gadget_vec(value: &Zq, k: &Z, base: &Z) -> MatZ {
    if base.pow(k).unwrap() < value.get_mod() {
        panic!("The modulus is too large, the value is potentially not representable.");
    }

    let mut digits = MatZ::new(k, 1);
    let mut rest = value.get_representative_least_nonnegative_residue();
    for row in 0..digits.get_num_rows() {
        // Peel off the least significant digit, then shift `rest` down by one digit.
        let digit = &rest % base;
        digits.set_entry(row, 0, &digit).unwrap();
        rest = (rest - digit).div_exact(base).unwrap();
    }
    digits
}
pub fn find_solution_gadget_mat(value: &MatZq, k: &Z, base: &Z) -> MatZ {
let mut out = MatZ::new(k * value.get_num_rows(), value.get_num_columns());
for i in 0..value.get_num_columns() as usize {
for j in 0..value.get_num_rows() as usize {
let sol_j = find_solution_gadget_vec(&value.get_entry(j, i).unwrap(), k, base);
out.set_submatrix(k * j as u64, i, &sol_j, 0, 0, -1, 0)
.unwrap();
}
}
out
}
/// Builds the short basis `S = I_n ⊗ S_k` associated with the gadget matrix
/// defined by `params`.
///
/// This is intended to be the Micciancio–Peikert basis of `Λ^⊥(G)`.
/// `S_k` is the `k x k` matrix with `base` on the diagonal and `-1` on the
/// sub-diagonal. If `q != base^k`, the last column of `S_k` is replaced by the
/// base-`base` digits of `q`, least significant first.
///
/// # Returns
/// An `(n * k) x (n * k)` block-diagonal matrix with `n` copies of `S_k`.
///
/// # Panics
/// Panics if `params.n` or `params.k` does not fit into an `i64`.
pub fn short_basis_gadget(params: &GadgetParameters) -> MatZ {
    let n: i64 = (&params.n)
        .try_into()
        .expect("params.n must fit into an i64");
    let k: i64 = (&params.k)
        .try_into()
        .expect("params.k must fit into an i64");

    // S_k: `base` on the diagonal, `-1` directly below it.
    let mut sk = MatZ::new(&params.k, &params.k);
    for j in 0..k {
        sk.set_entry(j, j, &params.base).unwrap();
    }
    for i in 0..k - 1 {
        sk.set_entry(i + 1, i, Z::MINUS_ONE).unwrap();
    }

    // For q that is not an exact power of `base`, overwrite the last column
    // with the digits of q (least significant digit in row 0).
    if params.base.pow(k).unwrap() != params.q {
        let mut q = Z::from(&params.q);
        for i in 0..k {
            let q_i = &q % &params.base;
            sk.set_entry(i, k - 1, &q_i).unwrap();
            q = (q - q_i).div_exact(&params.base).unwrap();
        }
    }

    // Place `n` copies of S_k along the diagonal.
    let mut out = MatZ::new(n * k, n * k);
    for j in 0..n {
        out.set_submatrix(
            j * sk.get_num_rows(),
            j * sk.get_num_columns(),
            &sk,
            0,
            0,
            -1,
            -1,
        )
        .unwrap();
    }
    out
}
#[cfg(test)]
mod test_gen_gadget_vec {
    use crate::sample::g_trapdoor::gadget_classical::gen_gadget_vec;
    use qfall_math::integer::{MatZ, Z};
    use std::str::FromStr;

    /// Compares `gen_gadget_vec(k, base)` against the matrix parsed from `expected`.
    fn assert_gadget_vec(k: i64, base: i64, expected: &str) {
        let expected = MatZ::from_str(expected).unwrap();
        assert_eq!(expected, gen_gadget_vec(k, &Z::from(base)));
    }

    /// Powers of 2: `[1, 2, 4, 8, 16]^t`.
    #[test]
    fn correctness_base_2() {
        assert_gadget_vec(5, 2, "[[1],[2],[4],[8],[16]]");
    }

    /// Powers of 5: `[1, 5, 25, 125]^t`.
    #[test]
    fn correctness_base_5() {
        assert_gadget_vec(4, 5, "[[1],[5],[25],[125]]");
    }
}
#[cfg(test)]
mod test_gen_gadget_mat {
    use super::gen_gadget_mat;
    use qfall_math::integer::{MatZ, Z};
    use std::str::FromStr;

    /// Compares `gen_gadget_mat(n, k, base)` against the matrix parsed from `expected`.
    fn assert_gadget_mat(n: i64, k: i64, base: i64, expected: &str) {
        let expected = MatZ::from_str(expected).unwrap();
        assert_eq!(expected, gen_gadget_mat(n, k, &Z::from(base)));
    }

    /// Three block rows of `[1, 2, 4]`.
    #[test]
    fn correctness_base_2_3x3() {
        assert_gadget_mat(
            3,
            3,
            2,
            "[[1, 2, 4, 0, 0, 0, 0, 0, 0],\
             [0, 0, 0, 1, 2, 4, 0, 0, 0],\
             [0, 0, 0, 0, 0, 0, 1, 2, 4]]",
        );
    }

    /// Two block rows of `[1, 3, 9, 27, 81]`.
    #[test]
    fn correctness_base_3_2x5() {
        assert_gadget_mat(
            2,
            5,
            3,
            "[[1, 3, 9, 27, 81, 0, 0, 0, 0, 0],\
             [ 0, 0, 0, 0, 0, 1, 3, 9, 27, 81]]",
        );
    }
}
#[cfg(test)]
mod test_gen_trapdoor {
    use super::gen_trapdoor;
    use crate::sample::g_trapdoor::{
        gadget_classical::gen_gadget_mat, gadget_parameters::GadgetParameters,
    };
    use qfall_math::{
        integer::{MatZ, Z},
        integer_mod_q::{MatZq, Modulus},
        traits::*,
    };

    /// Builds the full trapdoor `[R; I]` for `A = [A_bar | tag*G - A_bar*R]`.
    /// The identity block has `#cols(A) - #rows(R)` rows, so that
    /// `A * [R; I] = tag * G`.
    fn stacked_trapdoor(a: &MatZq, r: &MatZ) -> MatZ {
        r.concat_vertical(&MatZ::identity(
            a.get_num_columns() - r.get_num_rows(),
            r.get_num_columns(),
        ))
        .unwrap()
    }

    /// With `tag = I`, `A * [R; I]` must equal `G` itself.
    #[test]
    fn is_trapdoor_without_tag() {
        let params = GadgetParameters::init_default(42, 32);
        let a_bar = MatZq::sample_uniform(42, &params.m_bar, &params.q);
        let tag = MatZq::identity(42, 42, &params.q);
        let (a, r) = gen_trapdoor(&params, &a_bar, &tag).unwrap();
        let trapdoor = stacked_trapdoor(&a, &r);
        // Use the same base as `gen_trapdoor` instead of assuming base 2.
        let gadget_mat = gen_gadget_mat(42, &params.k, &params.base);
        assert_eq!(
            MatZq::from((&gadget_mat, &params.q)),
            a * MatZq::from((&trapdoor, &params.q))
        );
    }

    /// With an invertible tag `H`, `A * [R; I]` must equal `H * G`.
    #[test]
    fn is_trapdoor_with_tag() {
        let modulus = Modulus::from(32);
        let params = GadgetParameters::init_default(42, &modulus);
        let a_bar = MatZq::sample_uniform(42, &params.m_bar, &params.q);
        let tag = calculate_invertible_tag(42, &modulus);
        let (a, r) = gen_trapdoor(&params, &a_bar, &tag).unwrap();
        let trapdoor = stacked_trapdoor(&a, &r);
        let gadget_mat = gen_gadget_mat(42, &params.k, &params.base);
        assert_eq!(
            tag * MatZq::from((&gadget_mat, &modulus)),
            a * MatZq::from((&trapdoor, &modulus))
        );
    }

    /// Returns a random upper-unitriangular `size x size` matrix mod `modulus`.
    /// It has ones on the diagonal and uniform entries above it, so it is
    /// always invertible.
    fn calculate_invertible_tag(size: i64, modulus: &Modulus) -> MatZq {
        let max_value = Z::from(modulus);
        let mut out = MatZq::identity(size, size, modulus);
        for row in 0..size {
            for column in 0..size {
                if row < column {
                    out.set_entry(row, column, Z::sample_uniform(0, &max_value).unwrap())
                        .unwrap();
                }
            }
        }
        out
    }
}
#[cfg(test)]
mod test_find_solution_gadget {
    use super::find_solution_gadget_vec;
    use crate::sample::g_trapdoor::gadget_classical::{
        find_solution_gadget_mat, gen_gadget_mat, gen_gadget_vec,
    };
    use qfall_math::{
        integer::Z,
        integer_mod_q::{MatZq, Zq},
        traits::MatrixGetEntry,
    };
    use std::str::FromStr;

    /// Every residue mod 125 (including the largest one, 124) must be
    /// reconstructed by `g^t * x`.
    #[test]
    fn returns_correct_solution_vec() {
        let k = Z::from(5);
        let base = Z::from(3);
        // 0..125 covers all residues; the previous 0..124 skipped 124.
        for i in 0..125 {
            let value = Zq::from((i, 125));
            let sol = find_solution_gadget_vec(&value, &k, &base);
            assert_eq!(
                value.get_representative_least_nonnegative_residue(),
                (gen_gadget_vec(&k, &base).transpose() * sol)
                    .get_entry(0, 0)
                    .unwrap()
            )
        }
    }

    /// `G * X` must reproduce the representative matrix of `value`.
    #[test]
    fn returns_correct_solution_mat() {
        let k = Z::from(5);
        let base = Z::from(3);
        let value = MatZq::from_str("[[1, 42],[2, 40],[3, 90]] mod 125").unwrap();
        let sol = find_solution_gadget_mat(&value, &k, &base);
        assert_eq!(
            value.get_representative_least_nonnegative_residue(),
            gen_gadget_mat(3, &k, &base) * sol
        )
    }
}
#[cfg(test)]
mod test_short_basis_gadget {
    use crate::sample::g_trapdoor::{
        gadget_classical::short_basis_gadget, gadget_parameters::GadgetParameters,
    };
    use qfall_math::integer::{MatZ, Z};
    use std::str::FromStr;

    /// q = 16 = 2^4 with n = 2: two plain `S_k` blocks on the diagonal.
    #[test]
    fn base_2_power_two() {
        let params = GadgetParameters::init_default(2, 16);
        let s = short_basis_gadget(&params);
        let s_cmp = MatZ::from_str(
            "[[2, 0, 0, 0, 0, 0, 0, 0],\
             [-1, 2, 0, 0, 0, 0, 0, 0],\
             [0, -1, 2, 0, 0, 0, 0, 0],\
             [0, 0, -1, 2, 0, 0, 0, 0],\
             [0, 0, 0, 0, 2, 0, 0, 0],\
             [0, 0, 0, 0, -1, 2, 0, 0],\
             [0, 0, 0, 0, 0, -1, 2, 0],\
             [0, 0, 0, 0, 0, 0, -1, 2]]",
        )
        .unwrap();
        assert_eq!(s_cmp, s)
    }

    /// q = 0b1100110 is not a power of 2: the last column holds its bits,
    /// least significant first.
    #[test]
    fn base_2_arbitrary() {
        let q = Z::from(0b1100110);
        let params = GadgetParameters::init_default(1, q);
        let s = short_basis_gadget(&params);
        let s_cmp = MatZ::from_str(
            "[[2, 0, 0, 0, 0, 0, 0],\
             [-1, 2, 0, 0, 0, 0, 1],\
             [0, -1, 2, 0, 0, 0, 1],\
             [0, 0, -1, 2, 0, 0, 0],\
             [0, 0, 0, -1, 2, 0, 0],\
             [0, 0, 0, 0, -1, 2, 1],\
             [0, 0, 0, 0, 0, -1, 1]]",
        )
        .unwrap();
        assert_eq!(s_cmp, s)
    }

    /// q = 625 = 5^4 with base 5 and k = 4: plain `S_k`.
    #[test]
    fn base_5_power_5() {
        let mut params = GadgetParameters::init_default(1, 625);
        params.k = Z::from(4);
        params.base = Z::from(5);
        let s = short_basis_gadget(&params);
        let s_cmp = MatZ::from_str(
            "[[5, 0, 0, 0],\
             [-1, 5, 0, 0],\
             [0, -1, 5, 0],\
             [0, 0, -1, 5]]",
        )
        .unwrap();
        assert_eq!(s_cmp, s)
    }

    /// q = 4123 in base 5: the last column holds the digits 3, 2, 1, 4.
    #[test]
    fn base_5_arbitrary() {
        let q = Z::from_str_b("4123", 5).unwrap();
        let mut params = GadgetParameters::init_default(1, q);
        params.k = Z::from(4);
        params.base = Z::from(5);
        let s = short_basis_gadget(&params);
        let s_cmp = MatZ::from_str(
            "[[5, 0, 0, 3],\
             [-1, 5, 0, 2],\
             [0, -1, 5, 1],\
             [0, 0, -1, 4]]",
        )
        .unwrap();
        assert_eq!(s_cmp, s)
    }
}