use crate::math::{GaussianSampler, NttContext, Poly};
use crate::rlwe::{RlweCiphertext, RlweSecretKey, SeededRlweCiphertext};
use rand::RngCore;
use serde::{Deserialize, Serialize};
/// Draws a fresh error polynomial of dimension `ring_dim` over `moduli`,
/// delegating coefficient sampling to `Poly::sample_gaussian_moduli` with `sampler`.
fn sample_error_poly(ring_dim: usize, moduli: &[u64], sampler: &mut GaussianSampler) -> Poly {
    Poly::sample_gaussian_moduli(ring_dim, moduli, sampler)
}
/// Gadget vector parameters. The gadget entries are `base^0, base^1, …,
/// base^(len-1)`, each reduced modulo `q` (see `GadgetVector::powers`).
///
/// NOTE: field order is part of the serde representation; do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct GadgetVector {
    /// Base `B` of the gadget powers.
    pub base: u64,
    /// Number of gadget entries `ℓ`.
    pub len: usize,
    /// Modulus the gadget powers are reduced by.
    pub q: u64,
}
impl GadgetVector {
    /// Creates a gadget vector with an explicit base and length.
    ///
    /// In debug builds, asserts that `base > 1` and `len > 0`.
    pub fn new(base: u64, len: usize, q: u64) -> Self {
        debug_assert!(base > 1, "Gadget base must be > 1");
        debug_assert!(len > 0, "Gadget length must be > 0");
        Self { base, len, q }
    }

    /// Creates a gadget vector whose length is the smallest `ℓ ≥ 1` with
    /// `base^ℓ ≥ q`, i.e. `ℓ = ⌈log_base(q)⌉` (at least 1).
    ///
    /// The length is computed with exact integer arithmetic. A floating-point
    /// `log2` ratio can round the wrong way for moduli near a power of `base`.
    /// For example, `q = 2^60 + 1` converts to exactly `2^60` as an `f64`, which
    /// would yield `ℓ = 3` for `base = 2^20` instead of the required `4`.
    ///
    /// # Panics
    /// Panics if `base <= 1`, since no finite length exists in that case.
    pub fn from_base(base: u64, q: u64) -> Self {
        assert!(base > 1, "Gadget base must be > 1");
        let base_wide = u128::from(base);
        let target = u128::from(q);
        // `reach` stays < 2^64 before each multiply (it is < q), so
        // `reach * base_wide` < 2^128 and cannot overflow.
        let mut len = 1usize;
        let mut reach = base_wide;
        while reach < target {
            reach *= base_wide;
            len += 1;
        }
        Self::new(base, len, q)
    }

    /// Returns `base^i mod q`. For `i == 0` this is `1` (not reduced modulo `q`),
    /// matching the first entry of [`GadgetVector::powers`].
    ///
    /// # Panics
    /// Panics if `q == 0` and `i > 0` (remainder by zero).
    pub fn power(&self, i: usize) -> u64 {
        let base = u128::from(self.base);
        let q = u128::from(self.q);
        // u128 holds the product of two values below 2^64 without overflow.
        (0..i).fold(1u128, |acc, _| acc * base % q) as u64
    }

    /// Returns the `len` gadget powers `[1, base mod q, base^2 mod q, …]`.
    ///
    /// # Panics
    /// Panics if `q == 0` and `len > 0` (remainder by zero).
    pub fn powers(&self) -> Vec<u64> {
        let base = u128::from(self.base);
        let q = u128::from(self.q);
        std::iter::successors(Some(1u128), |&p| Some(p * base % q))
            .take(self.len)
            .map(|p| p as u64)
            .collect()
    }
}
/// An RGSW ciphertext: `2ℓ` RLWE rows that share one gadget vector.
///
/// NOTE: field order is part of the serde representation; do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct RgswCiphertext {
    /// The `2ℓ` RLWE rows; see [`RgswCiphertext::encrypt`] for their layout.
    pub rows: Vec<RlweCiphertext>,
    /// Gadget vector whose powers scale the message in each row.
    pub gadget: GadgetVector,
}

impl RgswCiphertext {
    /// Wraps already-built rows together with their gadget vector.
    ///
    /// Debug builds check that exactly `2 * gadget.len` rows are supplied.
    pub fn from_rows(rows: Vec<RlweCiphertext>, gadget: GadgetVector) -> Self {
        debug_assert_eq!(rows.len(), 2 * gadget.len, "RGSW must have 2ℓ rows");
        Self { rows, gadget }
    }

    /// Encrypts the polynomial `message` under `sk`.
    ///
    /// Every row starts from a fresh `a = Poly::random_moduli(..)` and an error
    /// `e` drawn from `sampler`, with `b = -(a·s) + e` (the product via
    /// `mul_ntt`). With `g_i = gadget.powers()[i]`:
    /// * rows `0..ℓ` add `message · g_i` to the `a` component;
    /// * rows `ℓ..2ℓ` add `message · g_i` to the `b` component.
    ///
    /// # Panics
    /// Panics if `gadget.powers()` yields fewer than `gadget.len` entries.
    pub fn encrypt(
        sk: &RlweSecretKey,
        message: &Poly,
        gadget: &GadgetVector,
        sampler: &mut GaussianSampler,
        ctx: &NttContext,
    ) -> Self {
        let ring_dim = sk.ring_dim();
        let moduli = sk.poly.moduli();
        let ell = gadget.len;
        let gadget_powers = gadget.powers();
        assert!(
            gadget_powers.len() >= ell,
            "gadget powers must have at least {} entries, got {}",
            ell,
            gadget_powers.len()
        );

        let mut rows = Vec::with_capacity(2 * ell);
        // First pass puts the scaled message into `a`; the second into `b`.
        for message_in_b in [false, true] {
            for &g in &gadget_powers[..ell] {
                let mask = Poly::random_moduli(ring_dim, moduli);
                let noise = sample_error_poly(ring_dim, moduli, sampler);
                let masked_key = mask.mul_ntt(&sk.poly, ctx);
                let body = &(-masked_key) + &noise;
                let scaled = message.scalar_mul(g);
                let row = if message_in_b {
                    RlweCiphertext::from_parts(mask, &body + &scaled)
                } else {
                    RlweCiphertext::from_parts(&mask + &scaled, body)
                };
                rows.push(row);
            }
        }

        Self {
            rows,
            gadget: gadget.clone(),
        }
    }

    /// Encrypts the constant polynomial `message` (built with
    /// `Poly::constant_moduli` over the key's ring dimension and moduli).
    pub fn encrypt_scalar(
        sk: &RlweSecretKey,
        message: u64,
        gadget: &GadgetVector,
        sampler: &mut GaussianSampler,
        ctx: &NttContext,
    ) -> Self {
        let constant = Poly::constant_moduli(message, sk.ring_dim(), sk.poly.moduli());
        Self::encrypt(sk, &constant, gadget, sampler, ctx)
    }

    /// Ring dimension, read from the first row.
    ///
    /// # Panics
    /// Panics if `rows` is empty.
    pub fn ring_dim(&self) -> usize {
        let first_row = &self.rows[0];
        first_row.ring_dim()
    }

    /// Modulus, read from the first row.
    ///
    /// # Panics
    /// Panics if `rows` is empty.
    pub fn modulus(&self) -> u64 {
        let first_row = &self.rows[0];
        first_row.modulus()
    }

    /// Gadget length `ℓ` (half the number of rows).
    pub fn gadget_len(&self) -> usize {
        self.gadget.len
    }
}
/// Compressed RGSW ciphertext whose rows store a 32-byte seed in place of the
/// `a` component; see [`SeededRgswCiphertext::expand`].
///
/// NOTE: field order is part of the serde representation; do not reorder.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct SeededRgswCiphertext {
    /// The `2ℓ` seeded RLWE rows.
    pub rows: Vec<SeededRlweCiphertext>,
    /// Gadget vector whose powers scale the message in each row.
    pub gadget: GadgetVector,
}

impl SeededRgswCiphertext {
    /// Encrypts the polynomial `message` under `sk`, storing each row's `a` as a seed.
    ///
    /// Each row draws a random 32-byte seed from `rand::thread_rng()` and derives
    /// `a = Poly::from_seed_moduli(seed, ..)`, then computes `b = -(a·s) + e`
    /// with error `e` from `sampler`. With `g_i = gadget.powers()[i]`:
    /// * rows `0..ℓ` add `(message · g_i)·s` to `b`. Because `a` must be
    ///   reproducible from the seed, the message cannot be folded into `a` as
    ///   [`RgswCiphertext::encrypt`] does. Assuming the usual RLWE phase
    ///   `b + a·s`, both forms yield the same phase — TODO confirm against
    ///   the decryptor;
    /// * rows `ℓ..2ℓ` add `message · g_i` to `b`.
    ///
    /// # Panics
    /// Panics if `gadget.powers()` yields fewer than `gadget.len` entries.
    pub fn encrypt(
        sk: &RlweSecretKey,
        message: &Poly,
        gadget: &GadgetVector,
        sampler: &mut GaussianSampler,
        ctx: &NttContext,
    ) -> Self {
        let ring_dim = sk.ring_dim();
        let moduli = sk.poly.moduli();
        let ell = gadget.len;
        let gadget_powers = gadget.powers();
        let mut rng = rand::thread_rng();
        assert!(
            gadget_powers.len() >= ell,
            "gadget powers must have at least {} entries, got {}",
            ell,
            gadget_powers.len()
        );

        let mut rows = Vec::with_capacity(2 * ell);
        // First pass adds (m·g_i)·s to `b`; the second adds m·g_i to `b`.
        for message_in_b in [false, true] {
            for &g in &gadget_powers[..ell] {
                let mut seed = [0u8; 32];
                rng.fill_bytes(&mut seed);
                let mask = Poly::from_seed_moduli(&seed, ring_dim, moduli);
                let noise = sample_error_poly(ring_dim, moduli, sampler);
                let masked_key = mask.mul_ntt(&sk.poly, ctx);
                let body = &(-masked_key) + &noise;
                let scaled = message.scalar_mul(g);
                let offset = if message_in_b {
                    scaled
                } else {
                    scaled.mul_ntt(&sk.poly, ctx)
                };
                rows.push(SeededRlweCiphertext::new(seed, &body + &offset));
            }
        }

        Self {
            rows,
            gadget: gadget.clone(),
        }
    }

    /// Encrypts the constant polynomial `message` (built with
    /// `Poly::constant_moduli` over the key's ring dimension and moduli).
    pub fn encrypt_scalar(
        sk: &RlweSecretKey,
        message: u64,
        gadget: &GadgetVector,
        sampler: &mut GaussianSampler,
        ctx: &NttContext,
    ) -> Self {
        let constant = Poly::constant_moduli(message, sk.ring_dim(), sk.poly.moduli());
        Self::encrypt(sk, &constant, gadget, sampler, ctx)
    }

    /// Expands every seeded row via `SeededRlweCiphertext::expand` and wraps the
    /// result as a full [`RgswCiphertext`] with a clone of this gadget.
    pub fn expand(&self) -> RgswCiphertext {
        let full_rows = self.rows.iter().map(|row| row.expand()).collect();
        RgswCiphertext::from_rows(full_rows, self.gadget.clone())
    }

    /// Ring dimension, read from the first row.
    ///
    /// # Panics
    /// Panics if `rows` is empty.
    pub fn ring_dim(&self) -> usize {
        let first_row = &self.rows[0];
        first_row.ring_dim()
    }

    /// Modulus, read from the first row.
    ///
    /// # Panics
    /// Panics if `rows` is empty.
    pub fn modulus(&self) -> u64 {
        let first_row = &self.rows[0];
        first_row.modulus()
    }

    /// Gadget length `ℓ` (half the number of rows).
    pub fn gadget_len(&self) -> usize {
        self.gadget.len
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::params::InspireParams;

    fn test_params() -> InspireParams {
        InspireParams::secure_128_d2048()
    }

    fn make_ctx(params: &InspireParams) -> NttContext {
        params.ntt_context()
    }

    #[test]
    fn test_gadget_vector_creation() {
        let q = 1152921504606830593u64;
        let gadget = GadgetVector::new(1 << 20, 3, q);
        assert_eq!(gadget.base, 1 << 20);
        assert_eq!(gadget.len, 3);
        assert_eq!(gadget.q, q);
    }

    #[test]
    fn test_gadget_powers() {
        let q = 1152921504606830593u64;
        let base = 1 << 20;
        let gadget = GadgetVector::new(base, 3, q);
        let powers = gadget.powers();
        assert_eq!(powers.len(), 3);
        assert_eq!(powers[0], 1);
        assert_eq!(powers[1], base);
        assert_eq!(
            powers[2],
            ((base as u128 * base as u128) % q as u128) as u64
        );
    }

    /// `power(i)` must agree with the `i`-th entry of `powers()`.
    #[test]
    fn test_gadget_power_matches_powers() {
        let q = 1152921504606830593u64;
        let gadget = GadgetVector::new(1 << 20, 3, q);
        for (i, &p) in gadget.powers().iter().enumerate() {
            assert_eq!(gadget.power(i), p);
        }
    }

    #[test]
    fn test_gadget_from_base() {
        let q = 1152921504606830593u64;
        let gadget = GadgetVector::from_base(1 << 20, q);
        assert_eq!(gadget.len, 3);
    }

    #[test]
    fn test_rgsw_encryption_structure() {
        let params = test_params();
        let ctx = make_ctx(&params);
        let mut sampler = GaussianSampler::new(params.sigma);
        let sk = RlweSecretKey::generate(&params, &mut sampler);
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let rgsw = RgswCiphertext::encrypt_scalar(&sk, 1, &gadget, &mut sampler, &ctx);
        assert_eq!(rgsw.rows.len(), 2 * params.gadget_len);
        assert_eq!(rgsw.ring_dim(), params.ring_dim);
        assert_eq!(rgsw.modulus(), params.q);
    }

    #[test]
    fn test_rgsw_encrypt_zero() {
        let params = test_params();
        let ctx = make_ctx(&params);
        let mut sampler = GaussianSampler::new(params.sigma);
        let sk = RlweSecretKey::generate(&params, &mut sampler);
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let rgsw = RgswCiphertext::encrypt_scalar(&sk, 0, &gadget, &mut sampler, &ctx);
        // Derive the expected row count from the parameters instead of hard-coding it.
        assert_eq!(rgsw.rows.len(), 2 * params.gadget_len);
        assert_eq!(rgsw.gadget_len(), params.gadget_len);
        assert_eq!(rgsw.ring_dim(), params.ring_dim);
    }

    #[test]
    fn test_seeded_rgsw_expand_structure() {
        let params = test_params();
        let ctx = make_ctx(&params);
        let mut sampler = GaussianSampler::new(params.sigma);
        let sk = RlweSecretKey::generate(&params, &mut sampler);
        let gadget = GadgetVector::new(params.gadget_base, params.gadget_len, params.q);
        let seeded = SeededRgswCiphertext::encrypt_scalar(&sk, 1, &gadget, &mut sampler, &ctx);
        assert_eq!(seeded.rows.len(), 2 * params.gadget_len);
        assert_eq!(seeded.gadget_len(), params.gadget_len);
        let expanded = seeded.expand();
        assert_eq!(expanded.rows.len(), 2 * params.gadget_len);
        assert_eq!(expanded.ring_dim(), params.ring_dim);
    }
}