use crate::{fhe::FheCompile, Error, FheProgramFn, Result, SecurityLevel};
use log::{debug, trace};
use seal_fhe::{
BfvEncryptionParametersBuilder, CoefficientModulus, Context, KeyGenerator, Modulus,
PlainModulus,
};
use sunscreen_backend::noise_model::{
noise_budget_to_noise, predict_noise, MeasuredModel, TargetNoiseLevel,
};
use sunscreen_fhe_program::{FheProgram, FheProgramTrait, Operation, SchemeType};
pub use sunscreen_runtime::Params;
/// Describes how the plaintext modulus should be chosen when searching for
/// scheme parameters.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PlainModulusConstraint {
/// Use this exact value; it is handed to `PlainModulus::raw` unchanged.
Raw(u64),
/// Request a modulus from `PlainModulus::batching`. The bit count asked
/// for is the larger of this value and the per-dimension floor stored in
/// `BATCHING_MIN_BITS`.
BatchingMinimum(u32),
}
/// Candidate lattice dimensions, in the order `determine_params` tries them
/// (smallest first).
const LATTICE_DIMENSIONS: &[u64] = &[1024, 2048, 4096, 8192, 16384, 32768];
/// Floor on the bit count requested for a batching plaintext modulus. Indexed
/// in parallel with `LATTICE_DIMENSIONS`: entry `i` applies to dimension `i`.
// NOTE(review): the origin of these specific bit counts is not visible in
// this file — confirm against the SEAL batching requirements before editing.
const BATCHING_MIN_BITS: &[u32] = &[14, 14, 16, 17, 17, 17];
fn plaintext_constraint_to_modulus(
constraint: PlainModulusConstraint,
lattice_dimension_index: usize,
) -> Result<seal_fhe::Modulus> {
let lattice_dimension = LATTICE_DIMENSIONS[lattice_dimension_index];
let plaintext_modulus = match constraint {
PlainModulusConstraint::Raw(v) => PlainModulus::raw(v).unwrap(),
PlainModulusConstraint::BatchingMinimum(min) => {
let min_batching_bits = BATCHING_MIN_BITS[lattice_dimension_index];
match PlainModulus::batching(lattice_dimension, u32::max(min_batching_bits, min)) {
Ok(v) => v,
Err(e) => {
trace!(
"Can't use batching with {} bits for dimension n={}: {:#?}",
min_batching_bits,
lattice_dimension,
e
);
return Err(Error::UnsatisfiableConstraint);
}
}
}
};
Ok(plaintext_modulus)
}
/// Checks whether every key type `fhe_program` needs can be generated under
/// `params`.
///
/// A SEAL context and key generator are built from `params`; Galois and
/// relinearization keys are then generated only if the program reports that
/// it requires them.
///
/// Returns `Ok(true)` when all required keys could be created and `Ok(false)`
/// when at least one required key type could not.
///
/// # Errors
///
/// Returns the converted SEAL error if the plaintext modulus, coefficient
/// modulus chain, encryption parameters, context, or key generator cannot be
/// constructed from `params`.
fn can_make_required_keys(fhe_program: &FheProgram, params: &Params) -> Result<bool> {
    let plain_modulus = PlainModulus::raw(params.plain_modulus)?;

    let modulus_chain = params
        .coeff_modulus
        .iter()
        .map(|x| Modulus::new(*x).map_err(Error::from))
        .collect::<Result<Vec<Modulus>>>()?;

    let enc_params = BfvEncryptionParametersBuilder::new()
        .set_plain_modulus(plain_modulus)
        .set_coefficient_modulus(modulus_chain)
        .set_poly_modulus_degree(params.lattice_dimension)
        .build()?;

    // Propagate failures rather than panicking: an unusable parameter set is
    // an expected outcome during parameter search, not an invariant violation.
    let context = Context::new(&enc_params, true, params.security_level)?;
    let keygen = KeyGenerator::new(&context)?;

    // Only attempt key generation for the key types the program needs.
    let can_create_galois =
        !fhe_program.requires_galois_keys() || keygen.create_galois_keys().is_ok();
    let can_create_relin =
        !fhe_program.requires_relin_keys() || keygen.create_relinearization_keys().is_ok();

    Ok(can_create_galois && can_create_relin)
}
pub fn determine_params(
fhe_program_fns: &[Box<dyn FheProgramFn>],
plaintext_constraint: PlainModulusConstraint,
security_level: SecurityLevel,
noise_margin_bits: u32,
scheme_type: SchemeType,
) -> Result<Params> {
'params_loop: for (i, n) in LATTICE_DIMENSIONS.iter().enumerate() {
let plaintext_modulus = match plaintext_constraint_to_modulus(plaintext_constraint, i) {
Ok(v) => v,
Err(_) => {
continue 'params_loop;
}
};
let coeff = CoefficientModulus::bfv_default(*n, security_level).unwrap();
let params = Params {
coeff_modulus: coeff.iter().map(|v| v.value()).collect(),
lattice_dimension: *n,
plain_modulus: plaintext_modulus.value(),
security_level,
scheme_type,
};
trace!(
"Trying to build scheme with \\lambda={:#?} p={} n={} c=default(\\lambda, n).",
security_level,
plaintext_modulus.value(),
n
);
for program in fhe_program_fns {
trace!("Successfully created parameters.");
trace!("Running backend compilation for {}", program.name());
let ir = program.build(¶ms)?.compile();
ir.validate().map_err(Error::FheProgramError)?;
trace!("Built and validated {}", program.name());
match can_make_required_keys(&ir, ¶ms) {
Ok(can_make_keys) => {
if !can_make_keys {
continue 'params_loop;
}
}
Err(_) => {
continue 'params_loop;
}
};
let mut chain_noise_level = 0f64;
for _ in 0..program.chain_count() {
let noise_targets = ir
.graph
.node_weights()
.filter(|n| {
matches!(
n.operation,
Operation::InputCiphertext(_) | Operation::InputPlaintext(_)
)
})
.map(|n| match n.operation {
Operation::InputCiphertext(_) => {
if chain_noise_level == 0f64 {
TargetNoiseLevel::Fresh
} else {
TargetNoiseLevel::InvariantNoise(chain_noise_level)
}
}
Operation::InputPlaintext(_) => TargetNoiseLevel::NotApplicable,
_ => unreachable!(),
})
.collect::<Vec<TargetNoiseLevel>>();
let model = match MeasuredModel::new(&ir, ¶ms, &noise_targets) {
Ok(v) => v,
Err(_) => {
trace!(
"Failed to construct noise model for {} with lattice_dimension={}",
program.name(),
n
);
continue 'params_loop;
}
};
let output_noises = predict_noise(&model, &ir);
let target_noise = noise_budget_to_noise(noise_margin_bits as f64);
for output_noise in output_noises {
if output_noise > target_noise {
trace!(
"Failed to meet noise constraints with lattice dimension {} for program {}",
n,
program.name()
);
continue 'params_loop;
} else if output_noise > chain_noise_level {
chain_noise_level = output_noise
}
}
}
debug!("Using params lattice_dimension={} and ={:#?}", n, coeff);
}
return Ok(params);
}
Err(Error::NoParams)
}