use crate::math::{NttContext, Poly};
use crate::rlwe::RlweCiphertext;
use super::types::{GadgetVector, RgswCiphertext};
/// Splits every coefficient of `poly` into `gadget.len` digits in base `gadget.base`.
///
/// The returned vector holds `gadget.len` polynomials, each with the same dimension and
/// moduli as `poly`. Polynomial `i` stores digit `i` (least significant first) of each
/// coefficient, so every stored digit lies in `[0, gadget.base)`. Whatever remains of a
/// coefficient after `gadget.len` digits have been extracted is discarded.
///
/// NOTE(review): the digits are unsigned (no balanced/centered representation) — confirm
/// this matches the noise analysis the callers rely on.
pub fn gadget_decompose(poly: &Poly, gadget: &GadgetVector) -> Vec<Poly> {
    let dim = poly.dimension();
    let base = gadget.base;
    let mut digits: Vec<Poly> = (0..gadget.len)
        .map(|_| Poly::zero_moduli(dim, poly.moduli()))
        .collect();
    for j in 0..dim {
        let mut remaining = poly.coeff(j);
        for digit_poly in digits.iter_mut() {
            digit_poly.set_coeff(j, remaining % base);
            remaining /= base;
        }
    }
    digits
}
/// Recombines gadget digits into one polynomial: `Σ_i decomposed[i] · gadget.powers()[i]`.
///
/// The result has the dimension and moduli of `decomposed[0]`. Presumably
/// `gadget.powers()[i] == gadget.base^i`, which would make this the inverse of
/// `gadget_decompose` for coefficients that fit in `gadget.len` digits — confirm against
/// `GadgetVector::powers`.
///
/// # Panics
///
/// Panics if `decomposed` is empty, if its length differs from `gadget.len`, or if any
/// entry's dimension or moduli differ from those of `decomposed[0]`.
pub fn gadget_reconstruct(decomposed: &[Poly], gadget: &GadgetVector) -> Poly {
    assert!(!decomposed.is_empty(), "Decomposition cannot be empty");
    assert_eq!(decomposed.len(), gadget.len, "Decomposition length must match gadget length");

    let first = &decomposed[0];
    let d = first.dimension();
    let moduli = first.moduli();
    for (idx, poly) in decomposed.iter().enumerate() {
        assert_eq!(poly.dimension(), d, "Decomposed poly[{idx}] has mismatched dimension");
        assert_eq!(poly.moduli(), moduli, "Decomposed poly[{idx}] has mismatched moduli");
    }

    let powers = gadget.powers();
    decomposed
        .iter()
        .enumerate()
        .fold(Poly::zero_moduli(d, moduli), |mut acc, (i, poly)| {
            // Index (rather than zip) so a too-short `powers` still panics loudly.
            acc += poly.scalar_mul(powers[i]);
            acc
        })
}
/// Computes the external product of an RLWE ciphertext with an RGSW ciphertext.
///
/// Both components of `rlwe` are gadget-decomposed with `rgsw.gadget`. Digit `i` of
/// `rlwe.a` is multiplied (via `mul_ntt`) with RGSW row `i`, and digit `i` of `rlwe.b` with
/// row `ℓ + i`. All products are then summed component-wise into a fresh RLWE ciphertext.
///
/// # Panics
///
/// Panics if:
/// - `rlwe.a` and `rlwe.b` have different moduli;
/// - `ctx` was built for different moduli;
/// - `rgsw` does not have exactly `2ℓ` rows;
/// - any row disagrees with `rlwe` in ring dimension or moduli.
pub fn external_product(
    rlwe: &RlweCiphertext,
    rgsw: &RgswCiphertext,
    ctx: &NttContext,
) -> RlweCiphertext {
    let d = rlwe.ring_dim();
    let moduli = rlwe.a.moduli();
    let gadget = &rgsw.gadget;
    let ell = gadget.len;

    assert_eq!(rlwe.b.moduli(), moduli, "RLWE components must share moduli");
    assert_eq!(ctx.moduli(), moduli, "NTT context moduli must match ciphertext moduli");
    assert_eq!(rgsw.rows.len(), 2 * ell, "RGSW must have 2ℓ rows");
    for (idx, row) in rgsw.rows.iter().enumerate() {
        assert_eq!(row.ring_dim(), d, "RGSW row[{idx}] has mismatched ring dimension");
        assert_eq!(row.a.moduli(), moduli, "RGSW row[{idx}] moduli mismatch in a component");
        assert_eq!(row.b.moduli(), moduli, "RGSW row[{idx}] moduli mismatch in b component");
    }

    let a_digits = gadget_decompose(&rlwe.a, gadget);
    let b_digits = gadget_decompose(&rlwe.b, gadget);
    // The length check above guarantees both halves hold exactly `ell` rows.
    let (rows_for_a, rows_for_b) = rgsw.rows.split_at(ell);

    let mut acc_a = Poly::zero_moduli(d, moduli);
    let mut acc_b = Poly::zero_moduli(d, moduli);
    let a_terms = a_digits.iter().zip(rows_for_a);
    let b_terms = b_digits.iter().zip(rows_for_b);
    for ((a_digit, a_row), (b_digit, b_row)) in a_terms.zip(b_terms) {
        // Digit of rlwe.a against the first ℓ RGSW rows.
        acc_a += a_digit.mul_ntt(&a_row.a, ctx);
        acc_b += a_digit.mul_ntt(&a_row.b, ctx);
        // Digit of rlwe.b against the last ℓ RGSW rows.
        acc_a += b_digit.mul_ntt(&b_row.a, ctx);
        acc_b += b_digit.mul_ntt(&b_row.b, ctx);
    }
    RlweCiphertext::from_parts(acc_a, acc_b)
}
#[cfg(test)]
mod tests {
    // Unit tests for gadget decomposition/reconstruction and the RLWE × RGSW external product.
    use super::*;
    use crate::math::GaussianSampler;
    use crate::params::InspireParams;
    use crate::rlwe::RlweSecretKey;

    /// Parameter set shared by every test in this module.
    fn test_params() -> InspireParams {
        InspireParams::secure_128_d2048()
    }

    /// Builds the NTT context matching `params`.
    fn make_ctx(params: &InspireParams) -> NttContext {
        params.ntt_context()
    }

    /// Samples a Gaussian error polynomial with the given dimension and moduli.
    fn sample_error_poly(dim: usize, moduli: &[u64], sampler: &mut GaussianSampler) -> Poly {
        Poly::sample_gaussian_moduli(dim, moduli, sampler)
    }

    /// Decomposing and then reconstructing a random polynomial must give it back unchanged.
    #[test]
    fn test_gadget_decompose_reconstruct_roundtrip() {
        let params = test_params();
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let poly = Poly::random_moduli(params.ring_dim, params.moduli());
        let decomposed = gadget_decompose(&poly, &gadget);
        let reconstructed = gadget_reconstruct(&decomposed, &gadget);
        assert_eq!(poly, reconstructed);
    }

    /// Every digit produced by the decomposition must be strictly below the gadget base.
    #[test]
    fn test_gadget_decompose_small_digits() {
        let params = test_params();
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let poly = Poly::random_moduli(params.ring_dim, params.moduli());
        let decomposed = gadget_decompose(&poly, &gadget);
        for digit_poly in &decomposed {
            for j in 0..params.ring_dim {
                let coeff = digit_poly.coeff(j);
                assert!(
                    coeff < params.gadget_base,
                    "Digit {} exceeds base {}",
                    coeff,
                    params.gadget_base
                );
            }
        }
    }

    /// The zero polynomial decomposes into all-zero digit polynomials.
    #[test]
    fn test_gadget_decompose_zero() {
        let params = test_params();
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let zero = Poly::zero_moduli(params.ring_dim, params.moduli());
        let decomposed = gadget_decompose(&zero, &gadget);
        for digit_poly in &decomposed {
            assert!(digit_poly.is_zero());
        }
    }

    /// An external product with an RGSW encryption of 0 must decrypt to the zero message.
    #[test]
    fn test_external_product_by_zero() {
        let params = test_params();
        let ctx = make_ctx(&params);
        let mut sampler = GaussianSampler::new(params.sigma);
        let delta = params.delta();
        let sk = RlweSecretKey::generate(&params, &mut sampler);
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let msg_coeffs: Vec<u64> = (0..params.ring_dim)
            .map(|i| (i as u64) % params.p)
            .collect();
        let msg = Poly::from_coeffs_moduli(msg_coeffs, params.moduli());
        let a = Poly::random_moduli(params.ring_dim, params.moduli());
        let e = sample_error_poly(params.ring_dim, params.moduli(), &mut sampler);
        let rlwe = RlweCiphertext::encrypt(&sk, &msg, delta, a, &e, &ctx);
        let rgsw_zero =
            super::super::RgswCiphertext::encrypt_scalar(&sk, 0, &gadget, &mut sampler, &ctx);
        let result = external_product(&rlwe, &rgsw_zero, &ctx);
        let decrypted = result.decrypt(&sk, delta, params.p, &ctx);
        for i in 0..params.ring_dim {
            assert_eq!(decrypted.coeff(i), 0, "Expected 0 at coefficient {}", i);
        }
    }

    /// An external product with an RGSW encryption of 1 must preserve the RLWE message.
    #[test]
    fn test_external_product_by_one() {
        let params = test_params();
        let ctx = make_ctx(&params);
        let mut sampler = GaussianSampler::new(params.sigma);
        let delta = params.delta();
        let sk = RlweSecretKey::generate(&params, &mut sampler);
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let msg_coeffs: Vec<u64> = (0..params.ring_dim)
            .map(|i| (i as u64) % params.p)
            .collect();
        let msg = Poly::from_coeffs_moduli(msg_coeffs.clone(), params.moduli());
        let a = Poly::random_moduli(params.ring_dim, params.moduli());
        let e = sample_error_poly(params.ring_dim, params.moduli(), &mut sampler);
        let rlwe = RlweCiphertext::encrypt(&sk, &msg, delta, a, &e, &ctx);
        let rgsw_one =
            super::super::RgswCiphertext::encrypt_scalar(&sk, 1, &gadget, &mut sampler, &ctx);
        let result = external_product(&rlwe, &rgsw_one, &ctx);
        let decrypted = result.decrypt(&sk, delta, params.p, &ctx);
        for (i, expected) in msg_coeffs.iter().enumerate().take(params.ring_dim) {
            assert_eq!(
                decrypted.coeff(i),
                *expected,
                "Mismatch at coefficient {}",
                i
            );
        }
    }

    /// An external product with an RGSW encryption of a scalar must scale every coefficient mod p.
    #[test]
    fn test_external_product_scalar_multiplication() {
        let params = test_params();
        let ctx = make_ctx(&params);
        let mut sampler = GaussianSampler::new(params.sigma);
        let delta = params.delta();
        let sk = RlweSecretKey::generate(&params, &mut sampler);
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let msg_coeffs: Vec<u64> = (0..params.ring_dim).map(|i| (i as u64) % 10).collect();
        let msg = Poly::from_coeffs_moduli(msg_coeffs.clone(), params.moduli());
        let a = Poly::random_moduli(params.ring_dim, params.moduli());
        let e = sample_error_poly(params.ring_dim, params.moduli(), &mut sampler);
        let rlwe = RlweCiphertext::encrypt(&sk, &msg, delta, a, &e, &ctx);
        let scalar = 3u64;
        let rgsw_scalar =
            super::super::RgswCiphertext::encrypt_scalar(&sk, scalar, &gadget, &mut sampler, &ctx);
        let result = external_product(&rlwe, &rgsw_scalar, &ctx);
        let decrypted = result.decrypt(&sk, delta, params.p, &ctx);
        for (i, msg_coeff) in msg_coeffs.iter().enumerate().take(params.ring_dim) {
            let expected = (*msg_coeff * scalar) % params.p;
            assert_eq!(
                decrypted.coeff(i),
                expected,
                "Mismatch at coefficient {}: expected {}, got {}",
                i,
                expected,
                decrypted.coeff(i)
            );
        }
    }

    /// Multiplying the constant message 5 by an RGSW encryption of `X` must yield `5X`.
    #[test]
    fn test_external_product_monomial() {
        let params = test_params();
        let ctx = make_ctx(&params);
        let mut sampler = GaussianSampler::new(params.sigma);
        let delta = params.delta();
        let sk = RlweSecretKey::generate(&params, &mut sampler);
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let mut msg_coeffs = vec![0u64; params.ring_dim];
        msg_coeffs[0] = 5;
        let msg = Poly::from_coeffs_moduli(msg_coeffs, params.moduli());
        let a = Poly::random_moduli(params.ring_dim, params.moduli());
        let e = sample_error_poly(params.ring_dim, params.moduli(), &mut sampler);
        let rlwe = RlweCiphertext::encrypt(&sk, &msg, delta, a, &e, &ctx);
        let mut monomial_coeffs = vec![0u64; params.ring_dim];
        monomial_coeffs[1] = 1;
        let monomial = Poly::from_coeffs_moduli(monomial_coeffs, params.moduli());
        let rgsw_mono =
            super::super::RgswCiphertext::encrypt(&sk, &monomial, &gadget, &mut sampler, &ctx);
        let result = external_product(&rlwe, &rgsw_mono, &ctx);
        let decrypted = result.decrypt(&sk, delta, params.p, &ctx);
        assert_eq!(decrypted.coeff(0), 0, "Constant term should be 0");
        assert_eq!(decrypted.coeff(1), 5, "X coefficient should be 5");
        for i in 2..params.ring_dim {
            assert_eq!(decrypted.coeff(i), 0, "Higher terms should be 0");
        }
    }
}