use crate::ks::KeySwitchingMatrix;
use crate::lwe::LweCiphertext;
use crate::math::Poly;
use crate::params::InspireParams;
use crate::rlwe::RlweCiphertext;
use super::collapse::{collapse, collapse_partial};
use super::transform::{aggregate, transform_at_slot};
use super::types::AggregatedCiphertext;
use serde::{Deserialize, Serialize};
/// Packs exactly `params.ring_dim` LWE ciphertexts into one RLWE ciphertext.
///
/// Ciphertext `i` is mapped into slot `i` with `transform_at_slot`. The per-slot results are
/// combined with `aggregate`, and the aggregate is collapsed with the key-switching matrices
/// `k_g` and `k_h`.
///
/// # Panics
///
/// Panics if `lwe_ciphertexts.len() != params.ring_dim`.
pub fn pack(
    lwe_ciphertexts: &[LweCiphertext],
    k_g: &KeySwitchingMatrix,
    k_h: &KeySwitchingMatrix,
    params: &InspireParams,
) -> RlweCiphertext {
    let ring_dim = params.ring_dim;
    assert_eq!(
        lwe_ciphertexts.len(),
        ring_dim,
        "Must provide exactly d ciphertexts for full packing"
    );

    // One transformed intermediate per slot, in slot order.
    let mut per_slot = Vec::with_capacity(ring_dim);
    for (slot, ct) in lwe_ciphertexts.iter().enumerate() {
        per_slot.push(transform_at_slot(ct, slot, params));
    }

    collapse(&aggregate(&per_slot, params), k_g, k_h, params)
}
/// Packs `γ ≤ d/2` LWE ciphertexts into one RLWE ciphertext using only `k_g`.
///
/// Ciphertext `i` is mapped into slot `i`, and all slots are aggregated. The aggregate is then
/// converted with `to_intermediate` and finished by `collapse_partial`. An empty input yields
/// the all-zero RLWE ciphertext without touching the key-switching matrix.
///
/// # Panics
///
/// Panics if `lwe_ciphertexts.len() > params.ring_dim / 2`.
pub fn partial_pack(
    lwe_ciphertexts: &[LweCiphertext],
    k_g: &KeySwitchingMatrix,
    params: &InspireParams,
) -> RlweCiphertext {
    let ring_dim = params.ring_dim;
    let gamma = lwe_ciphertexts.len();
    assert!(gamma <= ring_dim / 2, "partial_pack requires γ ≤ d/2 ciphertexts");

    if lwe_ciphertexts.is_empty() {
        // Nothing to pack: return (0, 0) over the parameter moduli.
        let zero = || Poly::zero_moduli(ring_dim, params.moduli());
        return RlweCiphertext::from_parts(zero(), zero());
    }

    let mut per_slot = Vec::with_capacity(gamma);
    for (slot, ct) in lwe_ciphertexts.iter().enumerate() {
        per_slot.push(transform_at_slot(ct, slot, params));
    }
    let combined = aggregate(&per_slot, params);

    collapse_partial(gamma, &combined.to_intermediate(), k_g, params)
}
/// Offline state for [`pack_online`], produced by [`precompute_packing`].
///
/// Holds the aggregate of the CRS `a` vectors, each taken with `b = 0`, so that the online
/// phase only has to add the LWE `b` values.
///
/// NOTE: this type derives `Serialize`/`Deserialize`. Renaming, removing or reordering fields
/// may break previously stored precomputations.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct PackingPrecomputation {
    /// Output of `aggregate` over `transform_at_slot(a_i with b = 0, i)` for every CRS vector.
    precomputed_a_aggregate: AggregatedCiphertext,
    /// Number of CRS `a` vectors; `pack_online` requires exactly this many `b` values.
    num_ciphertexts: usize,
    /// `params.ring_dim` at precomputation time; sizes the online `b` polynomial.
    ring_dim: usize,
    /// `params.q` at precomputation time (not read by the code visible in this file).
    q: u64,
    /// Copy of `params.moduli()`; used to build the online `b` polynomial.
    moduli: Vec<u64>,
}
impl PackingPrecomputation {
    /// Number of LWE ciphertexts (CRS `a` vectors) this precomputation covers.
    ///
    /// `pack_online` expects exactly this many `b` values.
    pub fn num_ciphertexts(&self) -> usize {
        let Self { num_ciphertexts, .. } = self;
        *num_ciphertexts
    }
}
/// Offline phase of packing: aggregates the CRS `a` vectors ahead of time.
///
/// Each `crs_a_vectors[i]` is treated as the mask of an LWE ciphertext with `b = 0` and placed
/// at slot `i` via `transform_at_slot`. All slots are then combined with `aggregate`.
/// [`pack_online`] later adds the real `b` values to this aggregate and collapses it.
///
/// `_k_g` and `_k_h` are not used here; they are kept for API compatibility.
///
/// # Panics
///
/// Panics if:
/// - `crs_a_vectors` is empty;
/// - it has more than `params.ring_dim` entries (only `d` slots exist, and `pack_online` can
///   place at most `d` `b` values);
/// - any vector's length differs from `params.ring_dim`.
pub fn precompute_packing(
    crs_a_vectors: &[Vec<u64>],
    _k_g: &KeySwitchingMatrix,
    _k_h: &KeySwitchingMatrix,
    params: &InspireParams,
) -> PackingPrecomputation {
    let d = params.ring_dim;
    let q = params.q;
    let n = crs_a_vectors.len();

    assert!(n > 0, "Must have at least one a vector");
    // With n > d, the a-part would include slots the online b-part can never fill, and the
    // packed ciphertext would be silently wrong.
    assert!(
        n <= d,
        "Cannot precompute packing for more than d ciphertexts (got {}, d = {})",
        n,
        d
    );
    // Validate every vector, not just the first: a short/long mask would otherwise reach
    // `transform_at_slot` unchecked.
    for (i, a) in crs_a_vectors.iter().enumerate() {
        assert_eq!(
            a.len(),
            d,
            "a vectors must have dimension d (vector {} has wrong length)",
            i
        );
    }

    let intermediates: Vec<_> = crs_a_vectors
        .iter()
        .enumerate()
        .map(|(slot, a)| {
            // Only the mask is known offline, so b is zero here; pack_online supplies it.
            let dummy = LweCiphertext {
                a: a.clone(),
                b: 0,
                q,
            };
            transform_at_slot(&dummy, slot, params)
        })
        .collect();
    let aggregated = aggregate(&intermediates, params);

    PackingPrecomputation {
        precomputed_a_aggregate: aggregated,
        num_ciphertexts: n,
        ring_dim: d,
        q,
        moduli: params.moduli().to_vec(),
    }
}
/// Online phase of packing: finishes a [`PackingPrecomputation`] with fresh LWE `b` values.
///
/// The steps are:
/// 1. Place `lwe_b_values[i]` at coefficient `i` of a polynomial over the precomputation's moduli.
/// 2. Add that polynomial to the precomputed aggregate's `b` polynomial.
/// 3. Collapse the combined aggregate with `k_g` / `k_h`.
///
/// # Panics
///
/// Panics if:
/// - the number of `b` values differs from `precomp.num_ciphertexts()`;
/// - `params.ring_dim` differs from the ring dimension the precomputation was built with;
/// - there are more `b` values than ring coefficients. This can only happen with a
///   hand-built or deserialized precomputation.
pub fn pack_online(
    lwe_b_values: &[u64],
    precomp: &PackingPrecomputation,
    k_g: &KeySwitchingMatrix,
    k_h: &KeySwitchingMatrix,
    params: &InspireParams,
) -> RlweCiphertext {
    let d = precomp.ring_dim;
    let moduli = &precomp.moduli;
    let n = lwe_b_values.len();

    assert_eq!(
        n, precomp.num_ciphertexts,
        "Number of b values must match precomputation"
    );
    // `collapse` below works under `params`, so it must describe the same ring as `precomp`.
    assert_eq!(
        d, params.ring_dim,
        "Precomputation ring dimension must match params.ring_dim"
    );
    // Every b value needs its own coefficient. Dropping extras would leave their `a`
    // contribution in the precomputed aggregate and yield a silently wrong ciphertext.
    assert!(
        n <= d,
        "Cannot place {} b values into a ring of dimension {}",
        n,
        d
    );

    let mut b_coeffs = vec![0u64; d];
    b_coeffs[..n].copy_from_slice(lwe_b_values);
    let b_poly = Poly::from_coeffs_moduli(b_coeffs, moduli);

    let full_aggregate = AggregatedCiphertext::new(
        precomp.precomputed_a_aggregate.a_polys.clone(),
        &precomp.precomputed_a_aggregate.b_poly + &b_poly,
    );
    collapse(&full_aggregate, k_g, k_h, params)
}
#[cfg(test)]
mod tests {
    // Unit tests for the packing entry points. The key-switching matrices are
    // `KeySwitchingMatrix::dummy`, so these tests check shapes and plumbing rather than
    // decryption of the packed RLWE result.
    use super::*;
    use crate::lwe::LweSecretKey;
    use crate::math::GaussianSampler;
    use rand::Rng;
    use rand::SeedableRng;

    /// Small single-modulus parameter set used by every test in this module.
    fn test_params() -> InspireParams {
        InspireParams {
            ring_dim: 256,
            q: 1152921504606830593,
            crt_moduli: vec![1152921504606830593],
            p: 65536,
            sigma: 6.4,
            gadget_base: 1 << 20,
            gadget_len: 3,
            security_level: crate::params::SecurityLevel::Bits128,
        }
    }

    /// Uniformly random (not encrypted) LWE ciphertext of dimension `ring_dim`.
    fn random_lwe<R: Rng>(rng: &mut R, params: &InspireParams) -> LweCiphertext {
        let a: Vec<u64> = (0..params.ring_dim)
            .map(|_| rng.gen_range(0..params.q))
            .collect();
        let b = rng.gen_range(0..params.q);
        LweCiphertext { a, b, q: params.q }
    }

    /// Real LWE encryption of `message` with a small error in `[-3, 3]`.
    fn encrypt_lwe<R: Rng>(
        sk: &LweSecretKey,
        message: u64,
        rng: &mut R,
        params: &InspireParams,
    ) -> LweCiphertext {
        let a: Vec<u64> = (0..params.ring_dim)
            .map(|_| rng.gen_range(0..params.q))
            .collect();
        let error = (rng.gen::<u8>() % 7) as i64 - 3;
        LweCiphertext::encrypt(sk, message, params.delta(), a, error)
    }

    #[test]
    fn test_pack_dimensions() {
        let params = test_params();
        let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(12345);
        let lwe_cts: Vec<LweCiphertext> = (0..params.ring_dim)
            .map(|_| random_lwe(&mut rng, &params))
            .collect();
        let k_g = KeySwitchingMatrix::dummy(params.ring_dim, params.moduli(), params.gadget_len);
        let k_h = KeySwitchingMatrix::dummy(params.ring_dim, params.moduli(), params.gadget_len);
        let result = pack(&lwe_cts, &k_g, &k_h, &params);
        assert_eq!(result.ring_dim(), params.ring_dim);
        assert_eq!(result.modulus(), params.q);
    }

    #[test]
    fn test_partial_pack_dimensions() {
        let params = test_params();
        let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(54321);
        let gamma = params.ring_dim / 4;
        let lwe_cts: Vec<LweCiphertext> =
            (0..gamma).map(|_| random_lwe(&mut rng, &params)).collect();
        let k_g = KeySwitchingMatrix::dummy(params.ring_dim, params.moduli(), params.gadget_len);
        let result = partial_pack(&lwe_cts, &k_g, &params);
        assert_eq!(result.ring_dim(), params.ring_dim);
    }

    #[test]
    fn test_precompute_pack_online() {
        let params = test_params();
        let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(98765);
        let n = 16;
        let crs_a_vectors: Vec<Vec<u64>> = (0..n)
            .map(|_| {
                (0..params.ring_dim)
                    .map(|_| rng.gen_range(0..params.q))
                    .collect()
            })
            .collect();
        let k_g = KeySwitchingMatrix::dummy(params.ring_dim, params.moduli(), params.gadget_len);
        let k_h = KeySwitchingMatrix::dummy(params.ring_dim, params.moduli(), params.gadget_len);
        let precomp = precompute_packing(&crs_a_vectors, &k_g, &k_h, &params);
        assert_eq!(precomp.num_ciphertexts(), n);
        let b_values: Vec<u64> = (0..n).map(|_| rng.gen_range(0..params.q)).collect();
        let result = pack_online(&b_values, &precomp, &k_g, &k_h, &params);
        assert_eq!(result.ring_dim(), params.ring_dim);
    }

    #[test]
    fn test_pack_with_real_encryption() {
        let params = test_params();
        let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(11111);
        let mut sampler = GaussianSampler::new(params.sigma);
        let lwe_sk = LweSecretKey::generate(params.ring_dim, params.q, &mut sampler);
        let messages: Vec<u64> = (0..params.ring_dim)
            .map(|i| (i as u64 * 7) % params.p)
            .collect();
        let lwe_cts: Vec<LweCiphertext> = messages
            .iter()
            .map(|&m| encrypt_lwe(&lwe_sk, m, &mut rng, &params))
            .collect();
        // Sanity-check the inputs decrypt before packing them.
        for (ct, &expected) in lwe_cts.iter().zip(messages.iter()) {
            let decrypted = ct.decrypt(&lwe_sk, params.delta(), params.p);
            assert_eq!(decrypted, expected, "LWE decryption failed");
        }
        // Dummy key-switching matrices: only the output shape of the packed result is checked.
        let k_g = KeySwitchingMatrix::dummy(params.ring_dim, params.moduli(), params.gadget_len);
        let k_h = KeySwitchingMatrix::dummy(params.ring_dim, params.moduli(), params.gadget_len);
        let packed = pack(&lwe_cts, &k_g, &k_h, &params);
        assert_eq!(packed.ring_dim(), params.ring_dim);
    }

    #[test]
    fn test_empty_partial_pack() {
        let params = test_params();
        let k_g = KeySwitchingMatrix::dummy(params.ring_dim, params.moduli(), params.gadget_len);
        let result = partial_pack(&[], &k_g, &params);
        assert_eq!(result.ring_dim(), params.ring_dim);
        for i in 0..params.ring_dim {
            assert_eq!(result.a.coeff(i), 0);
            assert_eq!(result.b.coeff(i), 0);
        }
    }

    #[test]
    fn test_aggregate_properties() {
        let params = test_params();
        let mut rng = rand_chacha::ChaCha20Rng::seed_from_u64(33333);
        let n = 4;
        let lwe_cts: Vec<LweCiphertext> =
            (0..n).map(|_| random_lwe(&mut rng, &params)).collect();
        let intermediates: Vec<_> = lwe_cts
            .iter()
            .enumerate()
            .map(|(i, lwe)| transform_at_slot(lwe, i, &params))
            .collect();
        let aggregated = aggregate(&intermediates, &params);
        // Slot i's b value must land at coefficient i of the aggregated b polynomial.
        for (i, ct) in lwe_cts.iter().enumerate() {
            assert_eq!(
                aggregated.b_poly.coeff(i),
                ct.b,
                "b coefficient mismatch at position {}",
                i
            );
        }
    }
}