use flint_sys::fmpz_poly::fmpz_poly_set_coeff_fmpz;
use qfall_math::{
integer::{MatPolyOverZ, PolyOverZ, Z},
integer_mod_q::{MatPolynomialRingZq, ModulusPolynomialRingZq, PolynomialRingZq},
traits::{GetCoefficient, MatrixDimensions, MatrixGetEntry, MatrixSetEntry, Pow},
};
pub trait LossyCompressionFIPS203 {
type CompressedType;
type ModulusType;
fn lossy_compress(&self, d: impl Into<Z>) -> Self::CompressedType;
fn lossy_decompress(
compressed: &Self::CompressedType,
d: impl Into<Z>,
modulus: &Self::ModulusType,
) -> Self;
}
impl LossyCompressionFIPS203 for PolynomialRingZq {
type CompressedType = PolyOverZ;
type ModulusType = ModulusPolynomialRingZq;
fn lossy_compress(&self, d: impl Into<Z>) -> Self::CompressedType {
let d = d.into();
assert!(
d >= Z::ONE,
"Performing this function with d < 1 implies reducing mod 1, leaving no information to recover. Choose a larger parameter d."
);
let two_pow_d = Z::from(2).pow(d).unwrap();
let q = self.get_mod().get_q();
let q_div_2 = q.div_floor(2);
let mut out = PolyOverZ::default();
for coeff_i in 0..=self.get_degree() {
let mut coeff: Z = unsafe { self.get_coeff_unchecked(coeff_i) };
coeff *= &two_pow_d;
coeff += &q_div_2;
let mut res = coeff.div_floor(&q) % &two_pow_d;
unsafe {
fmpz_poly_set_coeff_fmpz(out.get_fmpz_poly_struct(), coeff_i, res.get_fmpz());
};
}
out
}
fn lossy_decompress(
compressed: &Self::CompressedType,
d: impl Into<Z>,
modulus: &Self::ModulusType,
) -> Self {
let d = d.into();
assert!(
d >= Z::ONE,
"Performing this function with d < 1 implies reducing mod 1, leaving no information to recover. Choose a larger parameter d."
);
let two_pow_d_minus_1 = Z::from(2).pow(d - 1).unwrap();
let two_pow_d = &two_pow_d_minus_1 * 2;
let q = modulus.get_q();
let mut out = Self::from(modulus);
for coeff_i in 0..=compressed.get_degree() {
let mut coeff: Z = unsafe { compressed.get_coeff_unchecked(coeff_i) };
coeff *= &q;
coeff += &two_pow_d_minus_1;
let mut res = coeff.div_floor(&two_pow_d);
unsafe {
fmpz_poly_set_coeff_fmpz(out.get_fmpz_poly_struct(), coeff_i, res.get_fmpz());
};
}
out
}
}
impl LossyCompressionFIPS203 for MatPolynomialRingZq {
type CompressedType = MatPolyOverZ;
type ModulusType = ModulusPolynomialRingZq;
fn lossy_compress(&self, d: impl Into<Z>) -> Self::CompressedType {
let d = d.into();
let mut out = MatPolyOverZ::new(self.get_num_rows(), self.get_num_columns());
for row in 0..self.get_num_rows() {
for col in 0..self.get_num_columns() {
let entry: PolynomialRingZq = unsafe { self.get_entry_unchecked(row, col) };
let res = entry.lossy_compress(&d);
unsafe { out.set_entry_unchecked(row, col, res) };
}
}
out
}
fn lossy_decompress(
compressed: &Self::CompressedType,
d: impl Into<Z>,
modulus: &Self::ModulusType,
) -> Self {
let d = d.into();
let mut out = MatPolynomialRingZq::new(
compressed.get_num_rows(),
compressed.get_num_columns(),
modulus,
);
for row in 0..compressed.get_num_rows() {
for col in 0..compressed.get_num_columns() {
let entry: PolyOverZ = unsafe { compressed.get_entry_unchecked(row, col) };
let res = PolynomialRingZq::lossy_decompress(&entry, &d, modulus);
unsafe { out.set_entry_unchecked(row, col, res) };
}
}
out
}
}
#[cfg(test)]
mod test_compression_poly {
use crate::{compression::LossyCompressionFIPS203, utils::common_moduli::new_anticyclic};
use qfall_math::{
integer::Z,
integer_mod_q::PolynomialRingZq,
traits::{Distance, GetCoefficient, Pow},
};
#[test]
fn round_trip_small_d() {
let d = 4;
let q = Z::from(257);
let modulus = new_anticyclic(16, &q).unwrap();
let poly = PolynomialRingZq::sample_uniform(&modulus);
let compressed = poly.lossy_compress(d);
let decompressed = PolynomialRingZq::lossy_decompress(&compressed, d, &modulus);
for i in 0..modulus.get_degree() {
let orig_coeff: Z = poly.get_coeff(i).unwrap();
let coeff: Z = decompressed.get_coeff(i).unwrap();
let mut distance = orig_coeff.distance(coeff);
if distance > &q / 2 {
distance = &q - distance;
}
assert!(distance <= Z::from(2).pow(q.log_ceil(2).unwrap() - d - 1).unwrap());
}
}
#[test]
fn round_trip() {
let d = 11;
let q = Z::from(3329);
let modulus = new_anticyclic(16, &q).unwrap();
let poly = PolynomialRingZq::sample_uniform(&modulus);
let compressed = poly.lossy_compress(d);
let decompressed = PolynomialRingZq::lossy_decompress(&compressed, d, &modulus);
for i in 0..modulus.get_degree() {
let orig_coeff: Z = poly.get_coeff(i).unwrap();
let coeff: Z = decompressed.get_coeff(i).unwrap();
let mut distance = orig_coeff.distance(coeff);
if distance > &q / 2 {
distance = &q - distance;
}
assert!(distance <= Z::from(2).pow(q.log_ceil(2).unwrap() - d - 1).unwrap());
}
}
#[test]
#[should_panic]
fn too_small_d() {
let d = 0;
let q = Z::from(3329);
let modulus = new_anticyclic(16, &q).unwrap();
let poly = PolynomialRingZq::sample_uniform(&modulus);
poly.lossy_compress(d);
}
}
#[cfg(test)]
mod test_compression_matrix {
use crate::{compression::LossyCompressionFIPS203, utils::common_moduli::new_anticyclic};
use qfall_math::{
integer::Z,
integer_mod_q::{MatPolynomialRingZq, PolynomialRingZq},
traits::{Distance, GetCoefficient, MatrixDimensions, MatrixGetEntry, Pow},
};
#[test]
fn round_trip() {
let d = 11;
let q = Z::from(3329);
let modulus = new_anticyclic(16, &q).unwrap();
let matrix = MatPolynomialRingZq::sample_uniform(2, 2, &modulus);
let compressed = matrix.lossy_compress(d);
let decompressed = MatPolynomialRingZq::lossy_decompress(&compressed, d, &modulus);
for row in 0..matrix.get_num_rows() {
for col in 0..matrix.get_num_columns() {
let orig_entry: PolynomialRingZq = matrix.get_entry(row, col).unwrap();
let entry: PolynomialRingZq = decompressed.get_entry(row, col).unwrap();
for i in 0..modulus.get_degree() {
let orig_coeff: Z = orig_entry.get_coeff(i).unwrap();
let coeff: Z = entry.get_coeff(i).unwrap();
let mut distance = orig_coeff.distance(coeff);
if distance > &q / 2 {
distance = &q - distance;
}
assert!(distance <= Z::from(2).pow(q.log_ceil(2).unwrap() - d - 1).unwrap());
}
}
}
}
#[test]
#[should_panic]
fn too_small_d() {
let d = 0;
let q = Z::from(3329);
let modulus = new_anticyclic(16, &q).unwrap();
let poly = MatPolynomialRingZq::sample_uniform(2, 3, &modulus);
poly.lossy_compress(d);
}
}