use super::{
gadget_classical::{find_solution_gadget_mat, short_basis_gadget},
gadget_parameters::GadgetParameters,
};
use qfall_math::{
integer::{MatZ, Z},
integer_mod_q::MatZq,
traits::*,
};
/// Computes a short basis `S_A` of the kernel lattice of `a` from the gadget trapdoor `r`.
///
/// The basis is assembled as the product of two block matrices,
/// `[[I, R], [0, I]] * [[0, I], [S, W]]`, where
/// - `S` is the short basis of the gadget lattice,
/// - `W` solves `G * W = -tag^{-1} * Ā`, and
/// - `Ā` is the first `params.m_bar` columns of `a`.
///
/// NOTE(review): this appears to follow Micciancio–Peikert (EUROCRYPT 2012), Lemma 5.3 — confirm.
///
/// Parameters:
/// - `params`: the gadget parameters that `a` and `r` were generated with
/// - `tag`: the tag matrix used when generating `a`; it must be invertible mod `q`
/// - `a`: the public matrix `A`
/// - `r`: the trapdoor matrix `R` belonging to `a`
///
/// Returns a [`MatZ`] whose columns `b` all satisfy `a * b = 0 mod q`.
///
/// # Panics
/// - if `tag` cannot be inverted modulo `q`
/// - if the dimensions of `a`, `r` and `params` do not fit together
pub fn gen_short_basis_for_trapdoor(
    params: &GadgetParameters,
    tag: &MatZq,
    a: &MatZq,
    r: &MatZ,
) -> MatZ {
    // Left factor depends only on the trapdoor, right factor on the public data.
    gen_sa_l(r) * gen_sa_r(params, tag, a)
}
/// Builds the left factor `[[I, R], [0, I]]` of the short basis.
///
/// The result is square of dimension `rows(r) + cols(r)`: an identity matrix with all
/// of `r` copied into the upper-right corner, starting at row `0`, column `rows(r)`.
fn gen_sa_l(r: &MatZ) -> MatZ {
    let (upper_dim, lower_dim) = (r.get_num_rows(), r.get_num_columns());
    let total_dim = upper_dim + lower_dim;

    let mut block = MatZ::identity(total_dim, total_dim);
    block.set_submatrix(0, upper_dim, r, 0, 0, -1, -1).unwrap();
    block
}
/// Builds the right factor `[[0, I], [S, W]]` of the short basis.
///
/// Layout of the result:
/// - `S` is the gadget short basis from `short_basis_gadget`.
/// - `W` is the output of `compute_w`, with `m_bar = cols(W)` columns.
/// - The result has `rows(S) + m_bar` rows and `cols(S) + m_bar` columns.
/// - The top `m_bar` rows are zero except for an `m_bar x m_bar` identity in the rightmost columns.
/// - The bottom rows hold `S` followed by `W`.
///
/// # Panics
/// - if `compute_w` panics (e.g. `tag` not invertible modulo `q`)
/// - if `S` and `W` do not have the same number of rows
fn gen_sa_r(params: &GadgetParameters, tag: &MatZq, a: &MatZq) -> MatZ {
    let mut s = short_basis_gadget(params);
    // When q is exactly base^k, the gadget basis columns are used in reversed order.
    // Reordering columns changes the Gram-Schmidt vectors. NOTE(review): presumably
    // this keeps the GSO norms small, as checked by
    // `ensure_orthogonalized_length_perfect_power` — confirm.
    if params.base.pow(&params.k).unwrap() == params.q {
        s.reverse_columns();
    }
    let w = compute_w(params, tag, a);

    let m_bar = w.get_num_columns();
    let s_cols = s.get_num_columns();
    let mut sa_r = MatZ::new(s.get_num_rows() + m_bar, s_cols + m_bar);

    // Upper-right identity block: rows 0..m_bar, columns s_cols..s_cols + m_bar.
    for diagonal in 0..m_bar {
        // SAFETY: `diagonal < m_bar <= rows(S) + m_bar` (number of rows), and
        // `diagonal + s_cols < m_bar + s_cols` (number of columns), so the entry is in bounds.
        unsafe { sa_r.set_entry_unchecked(diagonal, diagonal + s_cols, 1) };
    }

    // Lower block row `[S, W]` starts directly below the identity rows.
    sa_r.set_submatrix(m_bar, 0, &s, 0, 0, -1, -1).unwrap();
    sa_r.set_submatrix(m_bar, s_cols, &w, 0, 0, -1, -1).unwrap();
    sa_r
}
/// Computes `W` with `G * W = -tag^{-1} * Ā mod q`.
///
/// Here `G` is the gadget matrix for `params.k` and `params.base`, and `Ā` is the first
/// `params.m_bar` columns of `a`. The columns are selected by multiplying `a` with the
/// rectangular identity `I_{cols(a) x m_bar}`.
///
/// # Panics
/// - if `tag` cannot be inverted modulo `q`
fn compute_w(params: &GadgetParameters, tag: &MatZq, a: &MatZq) -> MatZ {
    let tag_inv = tag
        .inverse()
        .expect("the tag matrix must be invertible modulo q");
    let a_bar = a * MatZ::identity(a.get_num_columns(), &params.m_bar);
    let rhs = Z::MINUS_ONE * tag_inv * a_bar;
    find_solution_gadget_mat(&rhs, &params.k, &params.base)
}
#[cfg(test)]
mod test_gen_short_basis_for_trapdoor {
    use super::gen_short_basis_for_trapdoor;
    use crate::sample::g_trapdoor::{
        gadget_classical::gen_trapdoor, gadget_default::gen_trapdoor_default,
        gadget_parameters::GadgetParameters,
    };
    use qfall_math::{
        integer::{MatZ, Z},
        integer_mod_q::{MatZq, Modulus},
        rational::{MatQ, Q},
        traits::*,
    };

    /// Asserts that every column `b` of `short_basis` satisfies `a * b = 0 mod q`.
    ///
    /// This checks kernel membership only, not that the columns generate the whole lattice.
    fn assert_columns_in_kernel(a: &MatZq, short_basis: &MatZ, q: &Modulus) {
        let zero_vec = MatZq::new(a.get_num_rows(), 1, q);
        for i in 0..short_basis.get_num_columns() {
            assert_eq!(zero_vec, a * short_basis.get_column(i).unwrap())
        }
    }

    /// Asserts that every Gram-Schmidt vector of `short_basis` has Euclidean norm
    /// at most `upper_bound`. Squared norms are compared to avoid square roots.
    fn assert_gso_norms_bounded(short_basis: &MatZ, upper_bound: &Q) {
        let orthogonalized_short_basis = MatQ::from(short_basis).gso();
        let upper_bound_sqrd = upper_bound.pow(2).unwrap();
        for i in 0..orthogonalized_short_basis.get_num_columns() {
            let b_tilde_i = orthogonalized_short_basis.get_column(i).unwrap();
            assert!(b_tilde_i.norm_eucl_sqrd().unwrap() <= upper_bound_sqrd)
        }
    }

    /// Ensures that the trapdoor generated is actually a basis for the lattice,
    /// when q is not a power of the base and the tag is the identity.
    #[test]
    fn is_basis_not_power_tag_identity() {
        for n in [1, 5, 10, 12] {
            let q = Modulus::from(127 + 3 * n);
            let params = GadgetParameters::init_default(n, &q);
            let (a, r) = gen_trapdoor_default(&params.n, &q);

            let tag = MatZq::identity(&params.n, &params.n, &q);

            let short_basis = gen_short_basis_for_trapdoor(&params, &tag, &a, &r);

            assert_columns_in_kernel(&a, &short_basis, &q);
        }
    }

    /// Ensures that the trapdoor generated is actually a basis for the lattice,
    /// when the tag is a scalar multiple of the identity.
    #[test]
    fn is_basis_with_tag_factor_identity() {
        for n in [2, 5, 10, 12] {
            let q = Modulus::from(124 + 2 * n);
            let params = GadgetParameters::init_default(n, &q);

            let tag = 17 * MatZq::identity(n, n, &params.q);
            let a_bar = MatZq::sample_uniform(n, &params.m_bar, &params.q);

            let (a, r) = gen_trapdoor(&params, &a_bar, &tag).unwrap();

            let short_basis = gen_short_basis_for_trapdoor(&params, &tag, &a, &r);

            assert_columns_in_kernel(&a, &short_basis, &q);
        }
    }

    /// Ensures that the trapdoor generated is actually a basis for the lattice,
    /// when the tag is a random invertible matrix.
    #[test]
    fn is_basis_with_tag_arbitrarily() {
        for n in [2, 5, 10, 12] {
            let q = Modulus::from(124 + 2 * n);
            let params = GadgetParameters::init_default(n, &q);

            let tag = calculate_invertible_tag(n, &q);
            let a_bar = MatZq::sample_uniform(n, &params.m_bar, &params.q);

            let (a, r) = gen_trapdoor(&params, &a_bar, &tag).unwrap();

            let short_basis = gen_short_basis_for_trapdoor(&params, &tag, &a, &r);

            assert_columns_in_kernel(&a, &short_basis, &q);
        }
    }

    /// Ensures that the Gram-Schmidt lengths stay below `(sqrt(m_bar) + 1) * 2`
    /// when q = 128 is a perfect power of the base.
    #[test]
    fn ensure_orthogonalized_length_perfect_power() {
        for n in [1, 5, 7] {
            let q = Modulus::from(128);
            let params = GadgetParameters::init_default(n, &q);

            let tag = calculate_invertible_tag(n, &q);
            let a_bar = MatZq::sample_uniform(n, &params.m_bar, &params.q);

            let (a, r) = gen_trapdoor(&params, &a_bar, &tag).unwrap();

            let short_basis = gen_short_basis_for_trapdoor(&params, &tag, &a, &r);

            // `sqrt(m_bar)` is used here as the estimate for the largest singular value of R.
            let s1_r = params.m_bar.sqrt();
            let orth_s_length = 2;
            let upper_bound: Q = (s1_r + 1) * orth_s_length;

            assert_gso_norms_bounded(&short_basis, &upper_bound);
        }
    }

    /// Ensures that the Gram-Schmidt lengths stay below `(sqrt(m_bar) + 1) * sqrt(5)`
    /// when q = 127 is not a perfect power of the base.
    #[test]
    fn ensure_orthogonalized_length_not_perfect_power() {
        for n in [1, 5, 7] {
            let q = Modulus::from(127);
            let params = GadgetParameters::init_default(n, &q);

            let tag = calculate_invertible_tag(n, &q);
            let a_bar = MatZq::sample_uniform(n, &params.m_bar, &params.q);

            let (a, r) = gen_trapdoor(&params, &a_bar, &tag).unwrap();

            let short_basis = gen_short_basis_for_trapdoor(&params, &tag, &a, &r);

            // `sqrt(m_bar)` is used here as the estimate for the largest singular value of R.
            let s1_r = params.m_bar.sqrt();
            let orth_s_length = Q::from(5).sqrt();
            let upper_bound: Q = (s1_r + 1) * orth_s_length;

            assert_gso_norms_bounded(&short_basis, &upper_bound);
        }
    }

    /// Generates a random `size x size` tag that is invertible modulo `q`.
    ///
    /// The matrix starts as the identity, and entries strictly above the diagonal are
    /// sampled uniformly from `[0, q)`. The result is unit upper triangular, so its
    /// determinant is 1 and it is invertible for every modulus.
    fn calculate_invertible_tag(size: i64, q: &Modulus) -> MatZq {
        let max_value = Z::from(q);
        let mut out = MatZq::identity(size, size, q);
        for row in 0..size {
            for column in (row + 1)..size {
                out.set_entry(row, column, Z::sample_uniform(0, &max_value).unwrap())
                    .unwrap();
            }
        }
        out
    }
}
#[cfg(test)]
mod test_gen_sa {
    use super::{gen_sa_l, gen_sa_r};
    use crate::sample::g_trapdoor::gadget_parameters::GadgetParameters;
    use qfall_math::{integer::MatZ, integer_mod_q::MatZq};
    use std::str::FromStr;

    /// Returns fixed parameters (n = 2, q = 8), a fixed public matrix `A` and the
    /// trapdoor `R` that `A` was generated with (tag = identity).
    fn get_fixed_trapdoor_for_tag_identity() -> (GadgetParameters, MatZq, MatZ) {
        let params = GadgetParameters::init_default(2, 8);

        let a = MatZq::from_str(
            "[\
            [2, 6, 2, 5, 3, 0, 1, 1, 1, 6, 5, 0, 6],\
            [6, 0, 3, 1, 5, 6, 2, 7, 0, 3, 7, 7, 0]] mod 8",
        )
        .unwrap();

        let r = MatZ::from_str(
            "[[0, 1, 0, 1, 1, 0],\
            [-1, 1, 0, 0, 0, -1],\
            [-1, 0, -1, -1, -1, 0],\
            [-1, 1, 0, 0, 0, 1],\
            [-1, -1, 0, 1, 0, 1],\
            [-1, 0, 0, -1, 0, 1],\
            [0, -1, 0, 0, 0, 0]]",
        )
        .unwrap();

        (params, a, r)
    }

    /// Ensures that the left factor `[[I, R], [0, I]]` is built as expected.
    #[test]
    fn working_sa_l() {
        let (_, _, r) = get_fixed_trapdoor_for_tag_identity();

        let sa_l = gen_sa_l(&r);

        let sa_l_cmp = MatZ::from_str(
            "[\
            [1, 0, 0, 0, 0, 0, 0, 0, 1, 0, 1, 1, 0],\
            [0, 1, 0, 0, 0, 0, 0, -1, 1, 0, 0, 0, -1],\
            [0, 0, 1, 0, 0, 0, 0, -1, 0, -1, -1, -1, 0],\
            [0, 0, 0, 1, 0, 0, 0, -1, 1, 0, 0, 0, 1],\
            [0, 0, 0, 0, 1, 0, 0, -1, -1, 0, 1, 0, 1],\
            [0, 0, 0, 0, 0, 1, 0, -1, 0, 0, -1, 0, 1],\
            [0, 0, 0, 0, 0, 0, 1, 0, -1, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]]",
        )
        .unwrap();
        assert_eq!(sa_l_cmp, sa_l);
    }

    /// Ensures that the right factor `[[0, I], [S, W]]` is built as expected
    /// for q = 8 = 2^3 (reversed gadget basis) and tag = identity.
    #[test]
    fn working_sa_r_identity() {
        let (params, a, _) = get_fixed_trapdoor_for_tag_identity();
        let tag = MatZq::identity(&params.n, &params.n, &params.q);

        let sa_r = gen_sa_r(&params, &tag, &a);

        let sa_r_cmp = MatZ::from_str(
            "[\
            [0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0],\
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1],\
            [0, 0, 0, 0, 0, 2, 0, 0, 0, 1, 1, 0, 1],\
            [0, 0, 0, 0, 2, -1, 1, 1, 1, 1, 0, 0, 1],\
            [0, 0, 0, 2, -1, 0, 1, 0, 1, 0, 1, 0, 1],\
            [0, 0, 2, 0, 0, 0, 0, 0, 1, 1, 1, 0, 0],\
            [0, 2, -1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1],\
            [2, -1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 1]]",
        )
        .unwrap();
        assert_eq!(sa_r_cmp, sa_r);
    }
}
#[cfg(test)]
mod test_compute_w {
    use super::compute_w;
    use crate::sample::g_trapdoor::{
        gadget_classical::gen_gadget_mat, gadget_parameters::GadgetParameters,
    };
    use qfall_math::{
        integer::{MatZ, Z},
        integer_mod_q::MatZq,
        traits::MatrixDimensions,
    };
    use std::str::FromStr;

    /// Ensures that `W` satisfies `G * W = -Ā mod q` for tag = identity,
    /// where `Ā` is the first `m_bar` columns of `A`.
    #[test]
    fn working_example_tag_identity() {
        let params = GadgetParameters::init_default(2, 8);
        let tag = MatZq::identity(2, 2, &params.q);
        let a = MatZq::from_str(
            "[\
            [2, 6, 2, 5, 3, 0, 1, 1, 1, 6, 5, 0, 6],\
            [6, 0, 3, 1, 5, 6, 2, 7, 0, 3, 7, 7, 0]] mod 8",
        )
        .unwrap();

        let w = compute_w(&params, &tag, &a);

        let g = gen_gadget_mat((&params.n).try_into().unwrap(), &params.k, &params.base);
        let gw = MatZq::from((&(g * w), &params.q));
        // Multiplying by the rectangular identity selects the first `m_bar` columns of `a`.
        let rhs = &a * MatZ::identity(a.get_num_columns(), &params.m_bar);
        assert_eq!(gw, Z::MINUS_ONE * rhs)
    }
}