use super::prelude::LweDimension;
use super::{PaddingBit, ShortintEncoding};
use crate::core_crypto::commons::computation_buffers::ComputationBuffers;
use crate::core_crypto::commons::generators::{
DeterministicSeeder, EncryptionRandomGenerator, SecretRandomGenerator,
};
#[cfg(feature = "zk-pok")]
use crate::core_crypto::commons::math::random::RandomGenerator;
use crate::core_crypto::commons::math::random::{DefaultRandomGenerator, Seeder};
use crate::core_crypto::commons::parameters::CiphertextModulus;
use crate::core_crypto::entities::*;
use crate::core_crypto::prelude::{ContainerMut, GlweSize, UnsignedInteger};
use crate::core_crypto::seeders::new_seeder;
use crate::shortint::ciphertext::{Degree, MaxDegree};
use crate::shortint::prelude::PolynomialSize;
use crate::shortint::{CarryModulus, MessageModulus};
use std::cell::RefCell;
use std::fmt::Debug;
mod client_side;
mod public_side;
mod server_side;
#[cfg(feature = "experimental")]
mod wopbs;
// One `ShortintEngine` per thread, built with `ShortintEngine::new()` the first time the
// thread touches it (std `thread_local!` statics are initialised lazily on first access).
// Wrapped in a `RefCell` so `ShortintEngine::with_thread_local_mut` can hand out `&mut`
// access and detect re-entrant borrows via `try_borrow_mut`.
thread_local! {
    static LOCAL_ENGINE: RefCell<ShortintEngine> = RefCell::new(ShortintEngine::new());
}
/// Growable scratch storage from which a single LWE ciphertext view is carved out, so callers
/// do not allocate a fresh container every time.
///
/// The storage is held as `u128` words; it is reinterpreted through `bytemuck` as a slice of
/// the requested scalar type (NOTE(review): `u128` presumably chosen so the backing memory is
/// large and aligned enough for any scalar of at most 128 bits — confirm).
#[derive(Default)]
struct CiphertextBuffer {
    buffer: Vec<u128>,
}

impl CiphertextBuffer {
    /// Returns a mutable LWE view holding `dim.to_lwe_size()` scalars, backed by this buffer.
    ///
    /// The buffer is grown (zero-filled) when it is too small and is never shrunk; contents
    /// left over from a previous call are NOT cleared.
    ///
    /// # Panics
    /// If `bytemuck` cannot reinterpret the `u128` storage as a slice of `Scalar`.
    fn as_lwe<Scalar>(
        &mut self,
        dim: LweDimension,
        ciphertext_modulus: CiphertextModulus<Scalar>,
    ) -> LweCiphertextMutView<'_, Scalar>
    where
        Scalar: UnsignedInteger,
    {
        let scalars_per_word = 128 / Scalar::BITS;
        let lwe_size = dim.to_lwe_size().0;
        let words_needed = lwe_size.div_ceil(scalars_per_word);

        // Grow only when needed; after the resize `words_needed` is exactly the length.
        if self.buffer.len() < words_needed {
            self.buffer.resize(words_needed, 0u128);
        }
        let words = &mut self.buffer[..words_needed];

        let scalars = match bytemuck::try_cast_slice_mut(words) {
            Ok(scalars) => scalars,
            Err(_) => panic!(
                "Scalar of size {} are not supported by the shortint engine",
                Scalar::BITS
            ),
        };

        // The last word may hold more scalars than needed; expose exactly `lwe_size` of them.
        LweCiphertextMutView::from_container(&mut scalars[..lwe_size], ciphertext_modulus)
    }
}
/// Fills `accumulator` with the lookup table of `f`, encoded for a programmable bootstrap.
///
/// The input space `0..input_message_modulus * input_carry_modulus` is split into boxes of
/// `polynomial_size / (input_message_modulus * input_carry_modulus)` body coefficients. Box `i`
/// is filled with `f(i)` encoded with the output message/carry moduli, the native ciphertext
/// modulus and a padding bit. Body coefficients not covered by any box are zeroed, and the mask
/// is zeroed. The table is then rotated left by half a box (see the comment at the rotation).
///
/// Returns the largest value `f` produced over the input space.
///
/// # Panics
/// - If `accumulator`'s polynomial size or GLWE size differ from `polynomial_size` /
///   `glwe_size`.
/// - If `input_message_modulus * input_carry_modulus` is zero (division by zero).
///
/// NOTE(review): unlike `fill_many_lut_accumulator`, the encoding uses
/// `CiphertextModulus::new_native()` rather than `accumulator.ciphertext_modulus()`; confirm this
/// is intended for accumulators with a non-native modulus.
// NOTE(review): eight parameters; kept as-is to preserve the crate-internal signature.
#[allow(clippy::too_many_arguments)]
pub(crate) fn fill_accumulator_with_encoding<F, C>(
    accumulator: &mut GlweCiphertext<C>,
    polynomial_size: PolynomialSize,
    glwe_size: GlweSize,
    input_message_modulus: MessageModulus,
    input_carry_modulus: CarryModulus,
    output_message_modulus: MessageModulus,
    output_carry_modulus: CarryModulus,
    f: F,
) -> u64
where
    C: ContainerMut<Element = u64>,
    F: Fn(u64) -> u64,
{
    assert_eq!(accumulator.polynomial_size(), polynomial_size);
    assert_eq!(accumulator.glwe_size(), glwe_size);

    let output_encoding = ShortintEncoding {
        ciphertext_modulus: CiphertextModulus::new_native(),
        message_modulus: output_message_modulus,
        carry_modulus: output_carry_modulus,
        padding_bit: PaddingBit::Yes,
    };

    let mut accumulator_view = accumulator.as_mut_view();
    accumulator_view.get_mut_mask().as_mut().fill(0);

    let input_modulus_sup = (input_message_modulus.0 * input_carry_modulus.0) as usize;
    let box_size = polynomial_size.0 / input_modulus_sup;

    let mut body = accumulator_view.get_mut_body();
    let accumulator_u64 = body.as_mut();

    let mut max_value = 0;
    for i in 0..input_modulus_sup {
        let index = i * box_size;
        let f_eval = f(i as u64);
        max_value = max_value.max(f_eval);
        accumulator_u64[index..index + box_size].fill(output_encoding.encode(Cleartext(f_eval)).0);
    }

    // When `polynomial_size` is not a multiple of `input_modulus_sup` (or `box_size` is 0) the
    // loop above leaves the trailing coefficients untouched. Clear them so that stale data from
    // a reused buffer cannot leak into the LUT; `fill_many_lut_accumulator` likewise zeroes the
    // whole body. This is an empty range in the usual evenly-divisible case.
    accumulator_u64[input_modulus_sup * box_size..].fill(0);

    // Rotate the table left by half a box. The half box that wraps around to the end is negated
    // first, i.e. a negacyclic rotation (assuming the usual X^N + 1 polynomial ring).
    let half_box_size = box_size / 2;
    for a_i in accumulator_u64[0..half_box_size].iter_mut() {
        *a_i = (*a_i).wrapping_neg();
    }
    accumulator_u64.rotate_left(half_box_size);

    max_value
}
/// Fills `accumulator` with a raw lookup table: every mask coefficient is set to zero and body
/// coefficient `i` is set to `f(i)`.
///
/// Unlike `fill_accumulator_with_encoding`, no message encoding, box repetition or rotation is
/// applied; `f` receives the raw coefficient index.
///
/// # Panics
/// If `accumulator`'s polynomial size or GLWE size differ from `polynomial_size` / `glwe_size`.
pub(crate) fn fill_accumulator_no_encoding<F, C>(
    accumulator: &mut GlweCiphertext<C>,
    polynomial_size: PolynomialSize,
    glwe_size: GlweSize,
    f: F,
) where
    C: ContainerMut<Element = u64>,
    F: Fn(u64) -> u64,
{
    assert_eq!(accumulator.polynomial_size(), polynomial_size);
    assert_eq!(accumulator.glwe_size(), glwe_size);

    let mut view = accumulator.as_mut_view();
    view.get_mut_mask().as_mut().fill(0);

    let mut glwe_body = view.get_mut_body();
    let coefficients = glwe_body.as_mut();
    for (coefficient, input) in coefficients.iter_mut().zip(0u64..) {
        *coefficient = f(input);
    }
}
/// Fills `accumulator` with a "many LUT" table that packs every function of `functions` into a
/// single accumulator.
///
/// Each input value gets a box of `polynomial_size / (message_modulus * carry_modulus)` body
/// coefficients. The body is split into consecutive sub-tables of
/// `single_function_sub_lut_size = (max_degree + 1) * box_size` coefficients, one per function.
/// Sub-table `j` holds, for every input `m` in `0..=max_degree`, `functions[j](m)`. Values are
/// encoded with `message_modulus` / `carry_modulus`, the accumulator's ciphertext modulus and a
/// padding bit. The mask and all unused body coefficients are zero. The whole table is then
/// rotated left by half a box (see the comment at the rotation).
///
/// Returns `(max_degree, single_function_sub_lut_size, per_fn_output_degree)`:
/// - `max_degree` is the largest input each sub-table covers, i.e.
///   `message_modulus * carry_modulus / functions.len() - 1`;
/// - `single_function_sub_lut_size` is the number of body coefficients per sub-table;
/// - `per_fn_output_degree[j]` is the largest value `functions[j]` produced over its inputs.
///
/// # Panics
/// - If `accumulator`'s polynomial size or GLWE size differ from `polynomial_size` /
///   `glwe_size`.
/// - If `functions` is empty.
/// - If `functions.len()` exceeds `message_modulus * carry_modulus / 2`.
pub(crate) fn fill_many_lut_accumulator<C>(
    accumulator: &mut GlweCiphertext<C>,
    polynomial_size: PolynomialSize,
    glwe_size: GlweSize,
    message_modulus: MessageModulus,
    carry_modulus: CarryModulus,
    functions: &[&dyn Fn(u64) -> u64],
) -> (MaxDegree, usize, Vec<Degree>)
where
    C: ContainerMut<Element = u64>,
{
    assert_eq!(accumulator.polynomial_size(), polynomial_size);
    assert_eq!(accumulator.glwe_size(), glwe_size);

    let modulus_sup = (message_modulus.0 * carry_modulus.0) as usize;
    let fn_counts = functions.len();
    // Validate the function count before touching the accumulator. Without the first check an
    // empty slice would pass the second one and then fail with a bare division by zero when
    // `max_degree` is computed below.
    assert!(
        fn_counts > 0,
        "Cannot generate many lut accumulator for 0 functions, at least one is required"
    );
    assert!(
        fn_counts <= modulus_sup / 2,
        "Cannot generate many lut accumulator for {fn_counts} functions, maximum possible is {}",
        modulus_sup / 2
    );

    let encoding = ShortintEncoding {
        ciphertext_modulus: accumulator.ciphertext_modulus(),
        message_modulus,
        carry_modulus,
        padding_bit: PaddingBit::Yes,
    };

    let mut accumulator_view = accumulator.as_mut_view();
    accumulator_view.get_mut_mask().as_mut().fill(0);

    let box_size = polynomial_size.0 / modulus_sup;

    let mut body = accumulator_view.get_mut_body();
    let accumulator_u64 = body.as_mut();
    // Coefficients not covered by any function's sub-table stay zero.
    accumulator_u64.fill(0u64);

    // The asserts above guarantee `modulus_sup / fn_counts >= 2`, so this cannot underflow.
    let max_degree = MaxDegree::new((modulus_sup / fn_counts - 1) as u64);
    let mut per_fn_output_degree = vec![Degree::new(0); fn_counts];
    let single_function_sub_lut_size = (max_degree.get() as usize + 1) * box_size;

    for ((function_sub_lut, output_degree), function) in accumulator_u64
        .chunks_mut(single_function_sub_lut_size)
        .zip(per_fn_output_degree.iter_mut())
        .zip(functions)
    {
        for (msg_value, sub_lut_box) in function_sub_lut.chunks_exact_mut(box_size).enumerate() {
            let function_eval = function(msg_value as u64);
            *output_degree = Degree::new(function_eval.max(output_degree.get()));
            sub_lut_box.fill(encoding.encode(Cleartext(function_eval)).0);
        }
    }

    // Rotate the table left by half a box. The half box that wraps around to the end is negated
    // first, i.e. a negacyclic rotation (assuming the usual X^N + 1 polynomial ring).
    let half_box_size = box_size / 2;
    for a_i in accumulator_u64[0..half_box_size].iter_mut() {
        *a_i = (*a_i).wrapping_neg();
    }
    accumulator_u64.rotate_left(half_box_size);

    (
        max_degree,
        single_function_sub_lut_size,
        per_fn_output_degree,
    )
}
/// Type-erased error wrapper.
///
/// Any `std::error::Error + 'static` value converts into it through `From`, so `?` can be used
/// in functions returning `Result<_, EngineError>`. It displays exactly like the wrapped error.
///
/// It deliberately does not implement `std::error::Error` itself: doing so would make the blanket
/// `From<T>` impl below overlap with core's reflexive `impl<T> From<T> for T`.
// NOTE(review): presumably silences unused-code warnings in builds where nothing constructs
// this type — confirm.
#[allow(dead_code)]
#[derive(Debug)]
pub struct EngineError {
    error: Box<dyn std::error::Error>,
}

impl<T> From<T> for EngineError
where
    T: std::error::Error + 'static,
{
    fn from(error: T) -> Self {
        let boxed: Box<dyn std::error::Error> = Box::new(error);
        Self { error: boxed }
    }
}

impl std::fmt::Display for EngineError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Render the inner error's `Display` output verbatim.
        write!(f, "{}", self.error)
    }
}
/// State used by shortint operations: random generators whose seeds are all derived from a
/// single root seed (see [`ShortintEngine::new_from_seeder`]), plus scratch buffers that are
/// reused between operations (see [`ShortintEngine::get_buffers`]).
///
/// A per-thread instance is available through [`ShortintEngine::with_thread_local_mut`].
pub struct ShortintEngine {
    /// `SecretRandomGenerator`; its seed is the first one drawn from `seeder` in
    /// `new_from_seeder`.
    pub(crate) secret_generator: SecretRandomGenerator<DefaultRandomGenerator>,
    /// `EncryptionRandomGenerator`, built in `new_from_seeder` from a seed drawn from `seeder`
    /// and a mutable borrow of `seeder` itself.
    pub(crate) encryption_generator: EncryptionRandomGenerator<DefaultRandomGenerator>,
    /// Deterministic seeder created from the root seed; every generator above was seeded from
    /// it, and it is kept in the engine (NOTE(review): presumably to seed further generators
    /// later — confirm with callers).
    pub(crate) seeder: DeterministicSeeder<DefaultRandomGenerator>,
    /// General-purpose random generator, only present with the `zk-pok` feature; seeded from
    /// `seeder` in `new_from_seeder`.
    #[cfg(feature = "zk-pok")]
    pub(crate) random_generator: RandomGenerator<DefaultRandomGenerator>,
    /// Scratch buffers handed out by `get_buffers` / `get_computation_buffers`.
    computation_buffers: ComputationBuffers,
    /// Reusable storage for the temporary LWE ciphertext returned by `get_buffers`.
    ciphertext_buffers: CiphertextBuffer,
}

/// Same as [`ShortintEngine::new`].
impl Default for ShortintEngine {
    fn default() -> Self {
        Self::new()
    }
}
impl ShortintEngine {
    /// Runs `func` with mutable access to this thread's engine and returns its result.
    ///
    /// If the thread-local engine is already mutably borrowed (for instance because `func`
    /// itself calls this function again), `func` receives a freshly constructed engine instead.
    /// This way a re-entrant call does not hit a `RefCell` double-borrow panic.
    #[inline]
    pub fn with_thread_local_mut<F, R>(func: F) -> R
    where
        F: FnOnce(&mut Self) -> R,
    {
        LOCAL_ENGINE.with(|cell| match cell.try_borrow_mut() {
            Ok(mut engine) => func(&mut engine),
            // The thread-local engine is in use further up the stack: fall back to a new one.
            Err(_) => func(&mut Self::new()),
        })
    }

    /// Creates an engine rooted in a seed drawn from the seeder returned by `new_seeder()`.
    pub fn new() -> Self {
        let mut root = new_seeder();
        Self::new_from_seeder(root.as_mut())
    }

    /// Creates an engine from a single seed drawn from `root_seeder`.
    ///
    /// That seed initialises a `DeterministicSeeder`, from which the seed of every generator is
    /// drawn in a fixed order. The seeder is then stored in the engine.
    pub fn new_from_seeder(root_seeder: &mut dyn Seeder) -> Self {
        let mut seeder = DeterministicSeeder::<DefaultRandomGenerator>::new(root_seeder.seed());

        // The order of the draws below decides which seed each generator receives for a given
        // root seed; do not reorder them.
        let secret_generator = SecretRandomGenerator::new(seeder.seed());
        let encryption_seed = seeder.seed();
        let encryption_generator = EncryptionRandomGenerator::new(encryption_seed, &mut seeder);
        #[cfg(feature = "zk-pok")]
        let random_generator = RandomGenerator::new(seeder.seed());

        Self {
            secret_generator,
            encryption_generator,
            #[cfg(feature = "zk-pok")]
            random_generator,
            seeder,
            computation_buffers: ComputationBuffers::default(),
            ciphertext_buffers: CiphertextBuffer::default(),
        }
    }

    /// Returns a temporary LWE ciphertext of dimension `lwe_dimension` backed by the engine's
    /// reusable ciphertext buffer, together with the engine's computation buffers.
    ///
    /// The ciphertext's contents are whatever the buffer last held (it is only zero-filled when
    /// it has to grow).
    pub fn get_buffers<Scalar>(
        &mut self,
        lwe_dimension: LweDimension,
        ciphertext_modulus: CiphertextModulus<Scalar>,
    ) -> (LweCiphertextMutView<'_, Scalar>, &mut ComputationBuffers)
    where
        Scalar: UnsignedInteger,
    {
        // Destructure to borrow the two fields independently.
        let Self {
            ciphertext_buffers,
            computation_buffers,
            ..
        } = self;
        let lwe = ciphertext_buffers.as_lwe::<Scalar>(lwe_dimension, ciphertext_modulus);
        (lwe, computation_buffers)
    }

    /// Returns the engine's reusable computation buffers.
    pub fn get_computation_buffers(&mut self) -> &mut ComputationBuffers {
        let buffers = &mut self.computation_buffers;
        buffers
    }
}
#[cfg(test)]
mod test {
    use crate::shortint::parameters::test_params::TEST_PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
    use crate::shortint::{CompactPrivateKey, CompactPublicKey};
    use rand::Rng;
    use rayon::iter::{IntoParallelRefIterator, ParallelIterator};

    /// Encrypts many independent batches from rayon worker threads with one shared
    /// `CompactPublicKey`. The test passes when every parallel encryption completes without
    /// panicking. NOTE(review): presumably meant to exercise per-thread `ShortintEngine` reuse,
    /// as the name suggests.
    #[test]
    fn test_engine_thread_reuse_ci_run_filter() {
        const ELEMENT_COUNT: usize = 500;
        const VALUES_PER_ELEMENT: usize = 16;

        let params = TEST_PARAM_PKE_MESSAGE_2_CARRY_2_KS_PBS_TUNIFORM_2M128;
        let packed_modulus = params.message_modulus.0 * params.carry_modulus.0;

        let private_key = CompactPrivateKey::new(params);
        let public_key = CompactPublicKey::new(&private_key);

        // Each batch repeats one random message `VALUES_PER_ELEMENT` times.
        let mut rng = rand::thread_rng();
        let batches: Vec<Vec<u64>> = (0..ELEMENT_COUNT)
            .map(|_| vec![rng.gen_range(0..packed_modulus); VALUES_PER_ELEMENT])
            .collect();

        batches.par_iter().for_each(|batch| {
            public_key.encrypt_iter_with_modulus(batch.iter().copied(), packed_modulus);
        });
    }
}