mod pir;
mod util;
use clap::Parser;
use fhe::bfv;
use fhe_math::rq::{traits::TryConvertFrom, Poly, Representation};
use fhe_traits::{
DeserializeParametrized, FheDecoder, FheDecrypter, FheEncoder, FheEncoderVariableTime,
FheEncrypter, Serialize,
};
use fhe_util::{inverse, transcode_bidirectional, transcode_to_bytes};
use indicatif::HumanBytes;
use itertools::Itertools;
use rand::{rng, RngCore};
use std::error::Error;
use util::{
encode_database, generate_database, number_elements_per_plaintext,
timeit::{timeit, timeit_n},
};
fn main() -> Result<(), Box<dyn Error>> {
env_logger::init();
let degree = 4096usize;
let plaintext_modulus = 2056193u64;
let moduli_sizes = [36, 36, 37];
let args = pir::Cli::parse();
let database_size = args.database_size;
let elements_size = args.element_size;
let mut rng = rng();
let max_element_size = ((plaintext_modulus.ilog2() as usize) * degree) / 8;
if elements_size > max_element_size || elements_size == 0 || database_size == 0 {
log::error!("Invalid parameters: database_size = {database_size}, elements_size = {elements_size}. The maximum element size if {max_element_size}.");
clap::Error::new(clap::error::ErrorKind::InvalidValue).exit();
}
println!("# SealPIR with fhe.rs");
println!(
"database of {}",
HumanBytes((database_size * elements_size) as u64)
);
println!("\tdatabase_size = {database_size}");
println!("\telements_size = {elements_size}");
let database = timeit!("Database generation", {
generate_database(database_size, elements_size)
});
let params = timeit!(
"Parameters generation",
bfv::BfvParametersBuilder::new()
.set_degree(degree)
.set_plaintext_modulus(plaintext_modulus)
.set_moduli_sizes(&moduli_sizes)
.build_arc()?
);
let (preprocessed_database, (dim1, dim2)) = timeit!("Database preprocessing", {
encode_database(&database, params.clone(), 1)
});
let (sk, ek_expansion_serialized) = timeit!("Client setup", {
let sk = bfv::SecretKey::random(¶ms, &mut rng);
let level = (dim1 + dim2).next_power_of_two().ilog2() as usize;
println!("expansion_level = {level}");
let ek_expansion = bfv::EvaluationKeyBuilder::new_leveled(&sk, 1, 0)?
.enable_expansion(level)?
.build(&mut rng)?;
let ek_expansion_serialized = ek_expansion.to_bytes();
(sk, ek_expansion_serialized)
});
println!(
"📄 Evaluation key: {}",
HumanBytes(ek_expansion_serialized.len() as u64)
);
let ek_expansion = timeit!(
"Server setup",
bfv::EvaluationKey::from_bytes(&ek_expansion_serialized, ¶ms)?
);
let index = (rng.next_u64() as usize) % database_size;
let query = timeit!("Client query", {
let level = (dim1 + dim2).next_power_of_two().ilog2();
let query_index = index
/ number_elements_per_plaintext(
params.degree(),
plaintext_modulus.ilog2() as usize,
elements_size,
);
let mut pt = vec![0u64; dim1 + dim2];
let inv = inverse(1 << level, plaintext_modulus).ok_or("No inverse")?;
pt[query_index / dim2] = inv;
pt[dim1 + (query_index % dim2)] = inv;
let query_pt = bfv::Plaintext::try_encode(&pt, bfv::Encoding::poly_at_level(1), ¶ms)?;
let query: bfv::Ciphertext = sk.try_encrypt(&query_pt, &mut rng)?;
query.to_bytes()
});
println!("📄 Query: {}", HumanBytes(query.len() as u64));
let responses: Vec<Vec<u8>> = timeit_n!("Server response", 5, {
let start = std::time::Instant::now();
let query = bfv::Ciphertext::from_bytes(&query, ¶ms)?;
let expanded_query = ek_expansion.expands(&query, dim1 + dim2)?;
println!("Expand: {}", DisplayDuration(start.elapsed()));
let query_vec = &expanded_query[..dim1];
let dot_product_mod_switch = move |i, database: &[bfv::Plaintext]| {
let column = database.iter().skip(i).step_by(dim2);
let mut c = bfv::dot_product_scalar(query_vec.iter(), column)?;
c.switch_to_level(c.max_switchable_level())?;
Ok(c)
};
let dot_products = (0..dim2)
.map(|i| dot_product_mod_switch(i, &preprocessed_database))
.collect::<fhe::Result<Vec<bfv::Ciphertext>>>()?;
let fold = dot_products
.iter()
.map(|c| {
let mut pt_values = Vec::with_capacity(
2 * (params.degree() * (64 - params.moduli()[0].leading_zeros() as usize))
.div_ceil(plaintext_modulus.ilog2() as usize),
);
pt_values.append(&mut transcode_bidirectional(
c.first().unwrap().coefficients().as_slice().unwrap(),
64 - params.moduli()[0].leading_zeros() as usize,
plaintext_modulus.ilog2() as usize,
));
pt_values.append(&mut transcode_bidirectional(
c.get(1).unwrap().coefficients().as_slice().unwrap(),
64 - params.moduli()[0].leading_zeros() as usize,
plaintext_modulus.ilog2() as usize,
));
unsafe {
bfv::PlaintextVec::try_encode_vt(
&pt_values,
bfv::Encoding::poly_at_level(1),
¶ms,
)
}
})
.collect::<fhe::Result<Vec<bfv::PlaintextVec>>>()?;
(0..fold[0].len())
.map(|i| {
let mut outi = bfv::dot_product_scalar(
expanded_query[dim1..].iter(),
fold.iter().map(|pts| &pts[i]),
)?;
outi.switch_to_level(outi.max_switchable_level())?;
Ok(outi.to_bytes())
})
.collect::<fhe::Result<Vec<Vec<u8>>>>()?
});
println!(
"📄 Response: {}",
HumanBytes(responses.iter().map(|r| r.len()).sum::<usize>() as u64)
);
let answer = timeit!("Client answer", {
let responses = responses
.iter()
.map(|r| bfv::Ciphertext::from_bytes(r, ¶ms).unwrap())
.collect_vec();
let decrypted_pt = responses
.iter()
.flat_map(|r| sk.try_decrypt(r))
.collect_vec();
let decrypted_vec = decrypted_pt
.iter()
.flat_map(|pt| Vec::<u64>::try_decode(pt, bfv::Encoding::poly_at_level(2)).unwrap())
.collect_vec();
let expect_ncoefficients = (params.degree()
* (64 - params.moduli()[0].leading_zeros() as usize))
.div_ceil(plaintext_modulus.ilog2() as usize);
assert!(decrypted_vec.len() >= 2 * expect_ncoefficients);
let mut poly0 = transcode_bidirectional(
&decrypted_vec[..expect_ncoefficients],
plaintext_modulus.ilog2() as usize,
64 - params.moduli()[0].leading_zeros() as usize,
);
let mut poly1 = transcode_bidirectional(
&decrypted_vec[expect_ncoefficients..2 * expect_ncoefficients],
plaintext_modulus.ilog2() as usize,
64 - params.moduli()[0].leading_zeros() as usize,
);
assert!(poly0.len() >= params.degree());
assert!(poly1.len() >= params.degree());
poly0.truncate(params.degree());
poly1.truncate(params.degree());
let ctx = params.context_at_level(2)?;
let ct = bfv::Ciphertext::new(
vec![
Poly::try_convert_from(poly0, ctx, true, Representation::Ntt)?,
Poly::try_convert_from(poly1, ctx, true, Representation::Ntt)?,
],
¶ms,
)?;
let pt = sk.try_decrypt(&ct).unwrap();
let pt = Vec::<u64>::try_decode(&pt, bfv::Encoding::poly_at_level(2))?;
let plaintext = transcode_to_bytes(&pt, plaintext_modulus.ilog2() as usize);
let offset = index
% number_elements_per_plaintext(
params.degree(),
plaintext_modulus.ilog2() as usize,
elements_size,
);
println!("Noise in response (ct): {:?}", unsafe {
sk.measure_noise(&ct)
});
plaintext[offset * elements_size..(offset + 1) * elements_size].to_vec()
});
assert_eq!(&database[index], &answer);
Ok(())
}