use qfall_math::{
error::MathError,
integer::{PolyOverZ, Z},
integer_mod_q::{ModulusPolynomialRingZq, PolynomialRingZq},
traits::{GetCoefficient, SetCoefficient},
};
/// Encodes a non-negative integer into a [`PolynomialRingZq`].
///
/// The digits of `value` w.r.t. `base` are spread over the coefficients, and each digit is
/// scaled by `⌊q / base⌋`. The coefficient of `X^i` of the result is `d_i * ⌊q / base⌋`,
/// where `d_i` is the `i`-th least significant digit of `value` written in base `base`.
///
/// Parameters:
/// - `value`: the non-negative integer to encode
/// - `base`: the base w.r.t. which `value` is split into digits; needs to satisfy
///   `2 <= base <= q`
/// - `modulus`: the modulus of the resulting ring element; its degree bounds the number
///   of digits that can be stored
///
/// Returns the encoding of `value` as a [`PolynomialRingZq`] w.r.t. `modulus`.
///
/// # Errors and Failures
/// - Returns a [`MathError`] of type [`MathError::InvalidIntegerInput`] in any of these cases:
///   - `value` is negative;
///   - `base` is larger than `q`, since every digit would then be scaled by `⌊q / base⌋ = 0`;
///   - `value` needs more digits w.r.t. `base` than the degree of `modulus` provides.
/// - Returns the error produced while computing the number of required digits
///   (via [`Z::log_ceil`]), e.g. if `base < 2`.
pub fn encode_value_in_polynomialringzq(
    value: impl Into<Z>,
    base: impl Into<Z>,
    modulus: &ModulusPolynomialRingZq,
) -> Result<PolynomialRingZq, MathError> {
    let mut value = value.into();
    let base = base.into();
    let modulus: ModulusPolynomialRingZq = modulus.into();
    if value < Z::ZERO {
        return Err(MathError::InvalidIntegerInput(format!(
            "The given value {value} needs to be non-negative."
        )));
    }
    let q = modulus.get_q();
    // For base > q the scaling factor ⌊q / base⌋ is 0, which would silently erase all digits.
    if base > q {
        return Err(MathError::InvalidIntegerInput(format!(
            "The given base {base} is larger than the modulus {q}, which would encode every digit as 0."
        )));
    }
    // `value` has ⌈log_base(value + 1)⌉ digits w.r.t. `base` (0 digits for `value = 0`).
    let min_req_degree = u64::try_from((&value + Z::ONE).log_ceil(&base)?)?;
    if min_req_degree > modulus.get_degree() as u64 {
        return Err(MathError::InvalidIntegerInput(format!(
            "The given value requires {min_req_degree} digits represented w.r.t. base {base}. Your modulus only provides space for {} digits.",
            modulus.get_degree()
        )));
    }

    // Little-endian digits of `value`, i.e. `base_repr[i]` belongs to `X^i`.
    let mut base_repr = Vec::with_capacity(min_req_degree as usize);
    while value > Z::ZERO {
        let digit = &value % &base;
        base_repr.push(digit);
        value = value.div_floor(&base);
    }

    let mut res = PolyOverZ::default();
    for (i, digit) in base_repr.iter().enumerate() {
        // Zero digits need not be written explicitly.
        if digit != &Z::ZERO {
            // SAFETY: `i` stems from `enumerate` and is thus non-negative. It is smaller than
            // `min_req_degree`, which was checked against the degree of `modulus` above, so the
            // cast to `i64` is lossless.
            unsafe { res.set_coeff_unchecked(i as i64, digit) };
        }
    }

    // Spread the digits over Z_q by scaling each of them with ⌊q / base⌋.
    let q_div_base = q.div_floor(&base);
    res *= q_div_base;
    Ok(PolynomialRingZq::from((res, modulus)))
}
/// Decodes an integer from a [`PolynomialRingZq`] whose coefficients encode digits
/// w.r.t. `base`, as produced by `encode_value_in_polynomialringzq`.
///
/// Each coefficient `c` is taken as its least non-negative residue modulo `q`. It is mapped
/// to the digit `round(c * base / q) mod base`, i.e. to the nearest multiple of `q / base`.
/// The digits are then recombined with the coefficient of the highest degree as the most
/// significant digit.
///
/// Parameters:
/// - `poly`: the polynomial to decode
/// - `base`: the base w.r.t. which the digits were encoded
///
/// Returns the decoded integer.
///
/// # Errors and Failures
/// - Returns a [`MathError`] of type [`MathError::InvalidIntegerInput`] if `base < 2`.
pub fn decode_value_from_polynomialringzq(
    poly: &PolynomialRingZq,
    base: impl Into<Z>,
) -> Result<Z, MathError> {
    let base = base.into();
    // Validate `base` before any arithmetic with it, so that e.g. `base = 0` yields an error
    // instead of a division by zero.
    if base <= Z::ONE {
        return Err(MathError::InvalidIntegerInput(format!(
            "The given base {base} is smaller than 2, which does not allow the encoding of any information."
        )));
    }
    let q = poly.get_mod().get_q();
    // Adding ⌊q / 2⌋ before the floor division by q turns it into rounding to the nearest
    // integer: ⌊(c * base + ⌊q / 2⌋) / q⌋ = round(c * base / q).
    let q_half = q.div_floor(2);

    let mut poly = poly.get_representative_least_nonnegative_residue();
    poly *= &base;
    let mut out = Z::default();
    // Horner's scheme, starting with the most significant (highest-degree) digit.
    for i in (0..=poly.get_degree()).rev() {
        // SAFETY: `i` lies in `0..=poly.get_degree()` and is therefore a valid,
        // non-negative coefficient index.
        let mut coeff = unsafe { poly.get_coeff_unchecked(i) };
        coeff += &q_half;
        // The quotient is at most `base`. Reducing modulo `base` maps that wrap-around case
        // back to digit 0; it occurs when a residue close to q encodes 0 with small negative noise.
        let digit = coeff.div_floor(&q) % &base;
        out *= &base;
        out += digit;
    }
    Ok(out)
}
#[cfg(test)]
mod test_encode_value_in_polynomialringzq {
    use crate::utils::{
        common_encodings::encode_value_in_polynomialringzq, common_moduli::new_anticyclic,
    };
    use qfall_math::{integer::Z, traits::GetCoefficient};

    /// Ensures that binary digits are placed into the correct coefficients
    /// and scaled by `⌊q / 2⌋`.
    #[test]
    fn binary() {
        let q = 257;
        let q_half = q / 2;
        let modulus = new_anticyclic(16, q).unwrap();
        let res0 = encode_value_in_polynomialringzq(1, 2, &modulus).unwrap();
        let res1 = encode_value_in_polynomialringzq(2, 2, &modulus).unwrap();
        let res2 = encode_value_in_polynomialringzq(3, 2, &modulus).unwrap();
        assert_eq!(GetCoefficient::<Z>::get_coeff(&res0, 0).unwrap(), q_half);
        assert_eq!(res0.get_degree(), 0);
        assert_eq!(GetCoefficient::<Z>::get_coeff(&res1, 0).unwrap(), 0);
        assert_eq!(GetCoefficient::<Z>::get_coeff(&res1, 1).unwrap(), q_half);
        assert_eq!(res1.get_degree(), 1);
        assert_eq!(GetCoefficient::<Z>::get_coeff(&res2, 0).unwrap(), q_half);
        assert_eq!(GetCoefficient::<Z>::get_coeff(&res2, 1).unwrap(), q_half);
        assert_eq!(res2.get_degree(), 1);
    }

    /// Ensures that ternary digits are placed into the correct coefficients
    /// and scaled by `⌊q / 3⌋`.
    #[test]
    fn ternary() {
        let q = 257;
        let q_third = q / 3;
        let modulus = new_anticyclic(16, q).unwrap();
        let res0 = encode_value_in_polynomialringzq(1, 3, &modulus).unwrap();
        let res1 = encode_value_in_polynomialringzq(2, 3, &modulus).unwrap();
        let res2 = encode_value_in_polynomialringzq(3, 3, &modulus).unwrap();
        assert_eq!(GetCoefficient::<Z>::get_coeff(&res0, 0).unwrap(), q_third);
        assert_eq!(res0.get_degree(), 0);
        assert_eq!(
            GetCoefficient::<Z>::get_coeff(&res1, 0).unwrap(),
            2 * q_third
        );
        assert_eq!(res1.get_degree(), 0);
        assert_eq!(GetCoefficient::<Z>::get_coeff(&res2, 0).unwrap(), 0);
        assert_eq!(GetCoefficient::<Z>::get_coeff(&res2, 1).unwrap(), q_third);
        assert_eq!(res2.get_degree(), 1);
    }

    /// Ensures that a value needing more digits than the modulus degree provides is rejected.
    #[test]
    fn not_enough_space() {
        let modulus = new_anticyclic(16, 257).unwrap();
        let res = encode_value_in_polynomialringzq(u16::MAX as u32 + 1, 2, &modulus);
        assert!(res.is_err());
    }

    /// Ensures that a base smaller than 2 is rejected.
    #[test]
    fn too_small_base() {
        let modulus = new_anticyclic(16, 257).unwrap();
        let res = encode_value_in_polynomialringzq(4, 1, &modulus);
        assert!(res.is_err());
    }

    /// Ensures that a negative value is rejected.
    /// A valid base is used so that the error can only stem from the sign of the value.
    #[test]
    fn negative_value() {
        let modulus = new_anticyclic(16, 257).unwrap();
        let res = encode_value_in_polynomialringzq(-1, 2, &modulus);
        assert!(res.is_err());
    }
}
#[cfg(test)]
mod test_decode_value_from_polynomialringzq {
    use crate::utils::{
        common_encodings::{decode_value_from_polynomialringzq, encode_value_in_polynomialringzq},
        common_moduli::new_anticyclic,
    };
    use qfall_math::{integer::Z, integer_mod_q::PolynomialRingZq};

    /// Encoding followed by decoding w.r.t. base 2 returns a random 16-bit message.
    #[test]
    fn round_trip_binary() {
        let base = 2;
        let modulus = new_anticyclic(17, 257).unwrap();
        let message = Z::sample_uniform(0, u16::MAX).unwrap();

        let recovered = encode_value_in_polynomialringzq(&message, base, &modulus)
            .and_then(|encoded| decode_value_from_polynomialringzq(&encoded, base))
            .unwrap();

        assert_eq!(message, recovered);
    }

    /// Encoding followed by decoding w.r.t. base 3 returns a random 16-bit message.
    #[test]
    fn round_trip_ternary() {
        let base = 3;
        let modulus = new_anticyclic(16, 257).unwrap();
        let message = Z::sample_uniform(0, u16::MAX).unwrap();

        let recovered = encode_value_in_polynomialringzq(&message, base, &modulus)
            .and_then(|encoded| decode_value_from_polynomialringzq(&encoded, base))
            .unwrap();

        assert_eq!(message, recovered);
    }

    /// Decoding w.r.t. a base smaller than 2 is rejected.
    #[test]
    fn too_small_base() {
        let modulus = new_anticyclic(16, 257).unwrap();
        let random_poly = PolynomialRingZq::sample_uniform(&modulus);

        assert!(decode_value_from_polynomialringzq(&random_poly, 1).is_err());
    }
}