use std::marker::PhantomData;
use std::time::Instant;
use crate::error::*;
use crate::metadata::*;
use crate::ZkpProgramInput;
use crate::{
run_program_unchecked, serialization::WithContext, Ciphertext, FheProgramInput,
InnerCiphertext, InnerPlaintext, Plaintext, PrivateKey, PublicKey, SealCiphertext, SealData,
SealPlaintext, TryFromPlaintext, TryIntoPlaintext, TypeNameInstance,
};
use log::trace;
use sunscreen_fhe_program::FheProgramTrait;
use sunscreen_fhe_program::SchemeType;
use seal_fhe::{
BFVEvaluator, BfvEncryptionParametersBuilder, Context as SealContext, Decryptor, Encryptor,
KeyGenerator, Modulus, PolynomialArray,
};
pub use sunscreen_compiler_common::{Type, TypeName};
use sunscreen_zkp_backend::BigInt;
use sunscreen_zkp_backend::CompiledZkpProgram;
use sunscreen_zkp_backend::Proof;
use sunscreen_zkp_backend::ZkpBackend;
/// Encrypts `val` with `encryptor` and returns the ciphertext together with the
/// three additional values the encryptor hands back.
///
/// Built only when the `deterministic` feature is enabled. When `seed` is
/// `Some`, the seeded `encrypt_return_components_deterministic` entry point is
/// used; otherwise the unseeded `encrypt_return_components` is called.
///
/// # Errors
///
/// Any failure reported by the encryptor is wrapped in [`Error::SealError`].
#[cfg(feature = "deterministic")]
fn encrypt_function(
    encryptor: &Encryptor,
    val: &seal_fhe::Plaintext,
    seed: Option<&[u64; 8]>,
) -> Result<(
    seal_fhe::Ciphertext,
    PolynomialArray,
    PolynomialArray,
    seal_fhe::Plaintext,
)> {
    let components = match seed {
        Some(seed) => encryptor.encrypt_return_components_deterministic(val, seed),
        None => encryptor.encrypt_return_components(val),
    };

    components.map_err(Error::SealError)
}
#[cfg(not(feature = "deterministic"))]
fn encrypt_function(
encryptor: &Encryptor,
val: &seal_fhe::Plaintext,
_seed: Option<&[u64; 8]>,
) -> Result<(
seal_fhe::Ciphertext,
PolynomialArray,
PolynomialArray,
seal_fhe::Plaintext,
)> {
encryptor
.encrypt_return_components(val)
.map_err(Error::SealError)
}
/// Backend-specific encryption context held by the runtime.
///
/// SEAL is the only backend represented here.
enum Context {
/// A SEAL context.
Seal(SealContext),
}
/// Marker traits used as bounds on the runtime's type parameter to select
/// which groups of methods are available.
pub mod marker {
/// Implemented by runtime marker types that expose the FHE methods.
pub trait Fhe {}
/// Implemented by runtime marker types that expose the ZKP methods.
pub trait Zkp {}
}
/// A ciphertext bundled with the per-plaintext values returned alongside it by
/// the encryptor's `encrypt_return_components*` calls.
///
/// NOTE(review): the cryptographic meaning of `u`, `e` and `r` is not visible
/// in this file — confirm against the `seal_fhe` documentation.
#[allow(unused)]
pub struct BFVEncryptionComponents {
// The encrypted value.
ciphertext: Ciphertext,
// Second element of each encryptor result tuple, one entry per inner plaintext.
u: Vec<PolynomialArray>,
// Third element of each encryptor result tuple, one entry per inner plaintext.
e: Vec<PolynomialArray>,
// Fourth element of each encryptor result tuple, wrapped as a `Plaintext`.
r: Vec<Plaintext>,
}
/// Marker type for a runtime that exposes only the FHE methods.
pub struct Fhe {}
impl marker::Fhe for Fhe {}
/// Marker type for a runtime that exposes only the ZKP methods.
pub struct Zkp {}
impl marker::Zkp for Zkp {}
/// Marker type for a runtime that exposes both the FHE and the ZKP methods.
pub struct FheZkp {}
impl marker::Fhe for FheZkp {}
impl marker::Zkp for FheZkp {}
/// State needed by the FHE methods of the runtime.
struct FheRuntimeData {
// Scheme parameters the context was built from; cloned into every
// `WithContext` wrapper the runtime produces.
params: Params,
// Backend context created from `params`.
context: Context,
}
/// State needed by the ZKP methods of the runtime; carries no data.
struct ZkpRuntimeData;
/// The state stored by a runtime, depending on which capabilities it has.
enum RuntimeData {
/// FHE-only runtime.
Fhe(FheRuntimeData),
/// ZKP-only runtime.
Zkp(ZkpRuntimeData),
/// Runtime with both FHE and ZKP state.
FheZkp(FheRuntimeData, ZkpRuntimeData),
}
impl RuntimeData {
    /// Returns the FHE portion of the runtime data.
    ///
    /// # Panics
    ///
    /// Panics if this value is the ZKP-only variant, which carries no FHE data.
    fn unwrap_fhe(&self) -> &FheRuntimeData {
        match self {
            Self::Fhe(fhe) | Self::FheZkp(fhe, _) => fhe,
            Self::Zkp(_) => panic!("Expected RuntimeData::Fhe or RuntimeData::FheZkp."),
        }
    }
}
/// A runtime parameterised by a capability marker `T` and a ZKP backend `B`.
///
/// Which methods are available depends on the `marker` traits that `T`
/// implements.
pub struct GenericRuntime<T, B> {
// FHE and/or ZKP state.
runtime_data: RuntimeData,
// Ties the capability marker to the type without storing a value of it.
_phantom_t: PhantomData<T>,
// Backend used by the ZKP methods; `()` for the FHE-only runtime.
zkp_backend: B,
}
impl<T, B> GenericRuntime<T, B>
where
T: self::marker::Fhe,
{
pub fn decrypt<P>(&self, ciphertext: &Ciphertext, private_key: &PrivateKey) -> Result<P>
where
P: TryFromPlaintext + TypeName,
{
let expected_type = Type {
is_encrypted: true,
..P::type_name()
};
if expected_type != ciphertext.data_type {
return Err(Error::type_mismatch(&expected_type, &ciphertext.data_type));
}
let fhe_data = self.runtime_data.unwrap_fhe();
let val = match (&fhe_data.context, &ciphertext.inner) {
(Context::Seal(context), InnerCiphertext::Seal(ciphertexts)) => {
let decryptor = Decryptor::new(context, &private_key.0)?;
let plaintexts = ciphertexts
.iter()
.map(|c| {
if decryptor
.invariant_noise_budget(c)
.map_err(Error::SealError)?
== 0
{
return Err(Error::TooMuchNoise);
}
decryptor.decrypt(c).map_err(Error::SealError)
})
.collect::<Result<Vec<SealPlaintext>>>()?
.drain(0..)
.map(|p| WithContext {
params: fhe_data.params.clone(),
data: p,
})
.collect();
P::try_from_plaintext(
&Plaintext {
data_type: P::type_name(),
inner: InnerPlaintext::Seal(plaintexts),
},
&fhe_data.params,
)?
}
};
Ok(val)
}
/// Returns the smallest invariant noise budget among the inner ciphertexts
/// of `c`, as measured with `private_key`.
///
/// If `c` holds no inner ciphertexts the result is `u32::MAX`.
///
/// # Errors
///
/// Fails if the decryptor cannot be created or a budget cannot be measured.
pub fn measure_noise_budget(&self, c: &Ciphertext, private_key: &PrivateKey) -> Result<u32> {
    let fhe_data = self.runtime_data.unwrap_fhe();

    match (&fhe_data.context, &c.inner) {
        (Context::Seal(ctx), InnerCiphertext::Seal(ciphertexts)) => {
            let decryptor = Decryptor::new(ctx, &private_key.0)?;

            let mut lowest = u32::MAX;

            for inner in ciphertexts.iter() {
                let budget = decryptor.invariant_noise_budget(&inner.data)?;
                lowest = u32::min(lowest, budget);
            }

            Ok(lowest)
        }
    }
}
/// Generates a fresh public/private key pair for this runtime's parameters.
///
/// The public key bundle also carries Galois and relinearization keys when
/// the key generator can create them; a failure to create either one is not
/// an error and simply leaves the corresponding field as `None`.
///
/// # Errors
///
/// Fails if the key generator itself cannot be constructed.
pub fn generate_keys(&self) -> Result<(PublicKey, PrivateKey)> {
    let fhe_data = self.runtime_data.unwrap_fhe();

    match &fhe_data.context {
        Context::Seal(context) => {
            let keygen = KeyGenerator::new(context)?;

            let galois_key = match keygen.create_galois_keys() {
                Ok(data) => Some(WithContext {
                    params: fhe_data.params.clone(),
                    data,
                }),
                Err(_) => None,
            };

            let relin_key = match keygen.create_relinearization_keys() {
                Ok(data) => Some(WithContext {
                    params: fhe_data.params.clone(),
                    data,
                }),
                Err(_) => None,
            };

            let public_key = PublicKey {
                public_key: WithContext {
                    params: fhe_data.params.clone(),
                    data: keygen.create_public_key(),
                },
                galois_key,
                relin_key,
            };

            let private_key = PrivateKey(WithContext {
                params: fhe_data.params.clone(),
                data: keygen.secret_key(),
            });

            Ok((public_key, private_key))
        }
    }
}
/// Returns the scheme parameters this runtime was created with.
pub fn params(&self) -> &Params {
    &self.runtime_data.unwrap_fhe().params
}
pub fn run<I>(
&self,
fhe_program: &CompiledFheProgram,
mut arguments: Vec<I>,
public_key: &PublicKey,
) -> Result<Vec<Ciphertext>>
where
I: Into<FheProgramInput>,
{
fhe_program.fhe_program_fn.validate()?;
if public_key.relin_key.is_none() && fhe_program.fhe_program_fn.requires_relin_keys() {
return Err(Error::MissingRelinearizationKeys);
}
if public_key.galois_key.is_none() && fhe_program.fhe_program_fn.requires_galois_keys() {
return Err(Error::MissingGaloisKeys);
}
let mut arguments: Vec<FheProgramInput> = arguments.drain(0..).map(|a| a.into()).collect();
let expected_args = &fhe_program.metadata.signature.arguments;
if expected_args.len() != arguments.len() {
return Err(Error::IncorrectCiphertextCount);
}
if arguments
.iter()
.enumerate()
.any(|(i, a)| a.type_name_instance() != expected_args[i])
{
return Err(Error::argument_mismatch(
expected_args,
&arguments
.iter()
.map(|a| a.type_name_instance())
.collect::<Vec<Type>>(),
));
}
if fhe_program.metadata.signature.num_ciphertexts.len()
!= fhe_program.metadata.signature.returns.len()
{
return Err(Error::ReturnTypeMetadataError);
}
let fhe_data = self.runtime_data.unwrap_fhe();
match &fhe_data.context {
Context::Seal(context) => {
let evaluator = BFVEvaluator::new(context)?;
let mut inputs: Vec<SealData> = vec![];
for i in arguments.drain(0..) {
match i {
FheProgramInput::Ciphertext(c) => match c.inner {
InnerCiphertext::Seal(mut c) => {
for j in c.drain(0..) {
inputs.push(SealData::Ciphertext(j.data));
}
}
},
FheProgramInput::Plaintext(p) => {
let p = p.try_into_plaintext(&fhe_data.params)?;
match p.inner {
InnerPlaintext::Seal(mut p) => {
for j in p.drain(0..) {
inputs.push(SealData::Plaintext(j.data));
}
}
}
}
}
}
let relin_key = public_key.relin_key.as_ref().map(|p| &p.data);
let galois_key = public_key.galois_key.as_ref().map(|p| &p.data);
let mut raw_ciphertexts = unsafe {
run_program_unchecked(
&fhe_program.fhe_program_fn,
&inputs,
&evaluator,
&relin_key,
&galois_key,
)
}?;
let mut packed_ciphertexts = vec![];
for (i, ciphertext_count) in fhe_program
.metadata
.signature
.num_ciphertexts
.iter()
.enumerate()
{
packed_ciphertexts.push(Ciphertext {
data_type: fhe_program.metadata.signature.returns[i].clone(),
inner: InnerCiphertext::Seal(
raw_ciphertexts
.drain(0..*ciphertext_count)
.map(|c| WithContext {
params: fhe_data.params.clone(),
data: c,
})
.collect(),
),
});
}
Ok(packed_ciphertexts)
}
}
}
/// Encrypts `val` under `public_key` and returns the resulting ciphertext.
///
/// # Errors
///
/// Fails if `val` cannot be converted to a plaintext or encryption fails.
pub fn encrypt<P>(&self, val: P, public_key: &PublicKey) -> Result<Ciphertext>
where
    P: TryIntoPlaintext + TypeName,
{
    let components = self.encrypt_return_components_switched(val, public_key, false, None)?;

    Ok(components.ciphertext)
}
/// Encrypts `val` under `public_key`, passing `seed` down to the encryption
/// routine, and returns the resulting ciphertext.
///
/// Available only with the `deterministic` feature.
///
/// # Errors
///
/// Fails if `val` cannot be converted to a plaintext or encryption fails.
#[cfg(feature = "deterministic")]
pub fn encrypt_deterministic<P>(
    &self,
    val: P,
    public_key: &PublicKey,
    seed: &[u64; 8],
) -> Result<Ciphertext>
where
    P: TryIntoPlaintext + TypeName,
{
    let components =
        self.encrypt_return_components_switched(val, public_key, false, Some(seed))?;

    Ok(components.ciphertext)
}
/// Encrypts `val` under `public_key` and returns the full
/// [`BFVEncryptionComponents`] rather than just the ciphertext.
///
/// Delegates to `encrypt_return_components_switched` with
/// `export_components = true` and no seed.
#[allow(dead_code)]
fn encrypt_return_components<P>(
&self,
val: P,
public_key: &PublicKey,
) -> Result<BFVEncryptionComponents>
where
P: TryIntoPlaintext + TypeName,
{
self.encrypt_return_components_switched(val, public_key, true, None)
}
/// Seeded variant of `encrypt_return_components`; available only with the
/// `deterministic` feature.
///
/// Delegates to `encrypt_return_components_switched` with
/// `export_components = true` and `Some(seed)`.
#[cfg(feature = "deterministic")]
#[allow(dead_code)]
fn encrypt_return_components_deterministic<P>(
&self,
val: P,
public_key: &PublicKey,
seed: &[u64; 8],
) -> Result<BFVEncryptionComponents>
where
P: TryIntoPlaintext + TypeName,
{
self.encrypt_return_components_switched(val, public_key, true, Some(seed))
}
fn encrypt_return_components_switched<P>(
&self,
val: P,
public_key: &PublicKey,
export_components: bool,
seed: Option<&[u64; 8]>,
) -> Result<BFVEncryptionComponents>
where
P: TryIntoPlaintext + TypeName,
{
let fhe_data = self.runtime_data.unwrap_fhe();
let plaintext = val.try_into_plaintext(&fhe_data.params)?;
let (ciphertext, u, e, r) = match (&fhe_data.context, plaintext.inner) {
(Context::Seal(context), InnerPlaintext::Seal(inner_plain)) => {
let encryptor = Encryptor::with_public_key(context, &public_key.public_key.data)?;
let capacity = if export_components {
inner_plain.len()
} else {
0
};
let mut us = Vec::with_capacity(capacity);
let mut es = Vec::with_capacity(capacity);
let mut rs = Vec::with_capacity(capacity);
let ciphertexts = inner_plain
.iter()
.map(|p| {
let ciphertext = if export_components {
encryptor.encrypt(p).map_err(Error::SealError)
} else {
let (ciphertext, u, e, r) = encrypt_function(&encryptor, p, seed)?;
let r_context = WithContext {
params: fhe_data.params.clone(),
data: r,
};
let r = Plaintext {
data_type: P::type_name(),
inner: InnerPlaintext::Seal(vec![r_context]),
};
us.push(u);
es.push(e);
rs.push(r);
Ok(ciphertext)
}?;
Ok(ciphertext)
})
.collect::<Result<Vec<SealCiphertext>>>()?
.drain(0..)
.map(|c| WithContext {
params: fhe_data.params.clone(),
data: c,
})
.collect();
(
Ciphertext {
data_type: Type {
is_encrypted: true,
..P::type_name()
},
inner: InnerCiphertext::Seal(ciphertexts),
},
us,
es,
rs,
)
}
};
Ok(BFVEncryptionComponents {
ciphertext,
u,
e,
r,
})
}
}
impl<T, B> GenericRuntime<T, B>
where
T: marker::Zkp,
B: ZkpBackend,
{
/// Creates a proof for `program` from the given inputs.
///
/// Every input is converted to a [`ZkpProgramInput`] and flattened to native
/// field elements. The program is JIT-compiled for the prover with the
/// constant, public and private inputs, and the backend is then asked to
/// prove it over the public inputs followed by the private inputs.
///
/// # Errors
///
/// Propagates any error from the backend's JIT or proving step.
pub fn prove<I>(
    &self,
    program: &CompiledZkpProgram,
    private_inputs: Vec<I>,
    public_inputs: Vec<I>,
    constant_inputs: Vec<I>,
) -> Result<Proof>
where
    I: Into<ZkpProgramInput>,
{
    // Flattens one group of inputs into native field elements.
    let to_fields = |inputs: Vec<I>| -> Vec<BigInt> {
        inputs
            .into_iter()
            .flat_map(|x| I::into(x).0.to_native_fields())
            .collect()
    };

    let private_inputs = to_fields(private_inputs);
    let public_inputs = to_fields(public_inputs);
    let constant_inputs = to_fields(constant_inputs);

    let backend = &self.zkp_backend;

    trace!("Starting JIT (prover)...");

    let jit_start = Instant::now();

    let prog = backend.jit_prover(program, &constant_inputs, &public_inputs, &private_inputs)?;

    trace!("Prover JIT time {}s", jit_start.elapsed().as_secs_f64());

    // The backend receives public inputs first, then private inputs.
    let mut inputs = public_inputs;
    inputs.extend(private_inputs);

    trace!("Starting backend prove...");

    let proof = backend.prove(&prog, &inputs)?;

    Ok(proof)
}
/// Returns a [`ProofBuilder`] bound to this runtime and `program`, for
/// assembling inputs incrementally before proving.
pub fn proof_builder<'r, 'p>(
&'r self,
program: &'p CompiledZkpProgram,
) -> ProofBuilder<'r, 'p, T, B> {
ProofBuilder::new(self, program)
}
/// Verifies `proof` against `program` and the given public and constant
/// inputs.
///
/// Every input is converted to a [`ZkpProgramInput`] and flattened to native
/// field elements. The program is JIT-compiled for the verifier and the
/// backend is then asked to check the proof.
///
/// # Errors
///
/// Propagates any error from the backend's JIT or verification step.
pub fn verify<I>(
    &self,
    program: &CompiledZkpProgram,
    proof: &Proof,
    public_inputs: Vec<I>,
    constant_inputs: Vec<I>,
) -> Result<()>
where
    I: Into<ZkpProgramInput>,
{
    // Flattens one group of inputs into native field elements.
    let to_fields = |inputs: Vec<I>| -> Vec<BigInt> {
        inputs
            .into_iter()
            .flat_map(|x| I::into(x).0.to_native_fields())
            .collect()
    };

    let constant_inputs = to_fields(constant_inputs);
    let public_inputs = to_fields(public_inputs);

    let backend = &self.zkp_backend;

    trace!("Starting JIT (verifier)");

    let jit_start = Instant::now();

    let prog = backend.jit_verifier(program, &constant_inputs, &public_inputs)?;

    trace!("Verifier JIT time {}s", jit_start.elapsed().as_secs_f64());
    trace!("Starting backend verify...");

    backend.verify(&prog, proof)?;

    Ok(())
}
/// Returns a [`VerificationBuilder`] bound to this runtime and `program`, for
/// assembling the proof and inputs incrementally before verifying.
pub fn verification_builder<'r, 'p>(
&'r self,
program: &'p CompiledZkpProgram,
) -> VerificationBuilder<'r, 'p, '_, T, B> {
VerificationBuilder::new(self, program)
}
}
impl GenericRuntime<(), ()> {
/// Creates an FHE-only runtime. Deprecated; this simply forwards to
/// `new_fhe`.
#[deprecated]
pub fn new(params: &Params) -> Result<FheRuntime> {
Self::new_fhe(params)
}
fn make_fhe_runtime_data(params: &Params) -> Result<FheRuntimeData> {
match params.scheme_type {
SchemeType::Bfv => {
let bfv_params = BfvEncryptionParametersBuilder::new()
.set_plain_modulus_u64(params.plain_modulus)
.set_poly_modulus_degree(params.lattice_dimension)
.set_coefficient_modulus(
params
.coeff_modulus
.iter()
.map(|v| Modulus::new(*v).unwrap())
.collect::<Vec<Modulus>>(),
)
.build()?;
let context = SealContext::new(&bfv_params, true, params.security_level)?;
Ok(FheRuntimeData {
params: params.clone(),
context: Context::Seal(context),
})
}
}
}
/// Builds the ZKP state, which currently holds no data.
fn make_zkp_runtime_data() -> ZkpRuntimeData {
ZkpRuntimeData
}
/// Creates a runtime that exposes only the FHE methods.
///
/// # Errors
///
/// Fails if the FHE state cannot be built from `params`.
pub fn new_fhe(params: &Params) -> Result<FheRuntime> {
    let fhe_data = Self::make_fhe_runtime_data(params)?;

    Ok(GenericRuntime {
        runtime_data: RuntimeData::Fhe(fhe_data),
        _phantom_t: PhantomData,
        zkp_backend: (),
    })
}
/// Creates a runtime that exposes only the ZKP methods, taking ownership of
/// `backend`.
pub fn new_zkp<B>(backend: B) -> Result<ZkpRuntime<B>>
where
    B: ZkpBackend + 'static,
{
    let zkp_data = Self::make_zkp_runtime_data();

    Ok(GenericRuntime {
        runtime_data: RuntimeData::Zkp(zkp_data),
        _phantom_t: PhantomData,
        zkp_backend: backend,
    })
}
/// Creates a runtime that exposes both the FHE and the ZKP methods.
///
/// The backend is cloned into the runtime.
///
/// # Errors
///
/// Fails if the FHE state cannot be built from `params`.
pub fn new_fhe_zkp<B>(params: &Params, zkp_backend: &B) -> Result<FheZkpRuntime<B>>
where
    B: ZkpBackend + Clone + 'static,
{
    let fhe_data = Self::make_fhe_runtime_data(params)?;
    let zkp_data = Self::make_zkp_runtime_data();

    Ok(GenericRuntime {
        runtime_data: RuntimeData::FheZkp(fhe_data, zkp_data),
        _phantom_t: PhantomData,
        zkp_backend: zkp_backend.clone(),
    })
}
}
/// A runtime that exposes both the FHE and the ZKP methods.
pub type FheZkpRuntime<B> = GenericRuntime<FheZkp, B>;
impl<B> FheZkpRuntime<B> {
/// Creates a combined FHE + ZKP runtime; forwards to `Runtime::new_fhe_zkp`.
pub fn new(params: &Params, zkp_backend: &B) -> Result<Self>
where
B: ZkpBackend + Clone + 'static,
{
Runtime::new_fhe_zkp(params, zkp_backend)
}
}
/// A runtime that exposes only the FHE methods.
pub type FheRuntime = GenericRuntime<Fhe, ()>;
impl FheRuntime {
/// Creates an FHE-only runtime; forwards to `Runtime::new_fhe`.
pub fn new(params: &Params) -> Result<Self> {
Runtime::new_fhe(params)
}
}
/// A runtime that exposes only the ZKP methods.
pub type ZkpRuntime<B> = GenericRuntime<Zkp, B>;
impl<B> ZkpRuntime<B> {
/// Creates a ZKP-only runtime; forwards to `Runtime::new_zkp`.
pub fn new(backend: B) -> Result<Self>
where
B: ZkpBackend + 'static,
{
Runtime::new_zkp(backend)
}
}
/// Capability-less runtime type that hosts the `new_fhe`, `new_zkp` and
/// `new_fhe_zkp` constructors.
pub type Runtime = GenericRuntime<(), ()>;
/// Builder that accumulates the inputs for a call to the runtime's `prove`.
pub struct ProofBuilder<'r, 'p, T: marker::Zkp, B: ZkpBackend> {
// Runtime that will create the proof.
runtime: &'r GenericRuntime<T, B>,
// Program the proof is created for.
program: &'p CompiledZkpProgram,
// Constant inputs, in the order they were added.
constant_inputs: Vec<ZkpProgramInput>,
// Public inputs, in the order they were added.
public_inputs: Vec<ZkpProgramInput>,
// Private inputs, in the order they were added.
private_inputs: Vec<ZkpProgramInput>,
}
impl<'r, 'p, T: marker::Zkp, B: ZkpBackend> ProofBuilder<'r, 'p, T, B> {
    /// Creates a builder for `program` on `runtime` with no inputs.
    pub fn new(runtime: &'r GenericRuntime<T, B>, program: &'p CompiledZkpProgram) -> Self {
        Self {
            runtime,
            program,
            constant_inputs: vec![],
            public_inputs: vec![],
            private_inputs: vec![],
        }
    }

    /// Appends a single constant input.
    pub fn constant_input(mut self, input: impl Into<ZkpProgramInput>) -> Self {
        self.constant_inputs.push(input.into());
        self
    }

    /// Appends every item of `inputs` as a constant input, in iteration order.
    ///
    /// The item type only needs to be convertible into a [`ZkpProgramInput`];
    /// it is unrelated to the runtime marker type `T`.
    pub fn constant_inputs<I>(mut self, inputs: I) -> Self
    where
        I: IntoIterator,
        ZkpProgramInput: From<I::Item>,
    {
        self.constant_inputs
            .extend(inputs.into_iter().map(ZkpProgramInput::from));
        self
    }

    /// Appends a single public input.
    pub fn public_input(mut self, input: impl Into<ZkpProgramInput>) -> Self {
        self.public_inputs.push(input.into());
        self
    }

    /// Appends every item of `inputs` as a public input, in iteration order.
    ///
    /// The item type only needs to be convertible into a [`ZkpProgramInput`];
    /// it is unrelated to the runtime marker type `T`.
    pub fn public_inputs<I>(mut self, inputs: I) -> Self
    where
        I: IntoIterator,
        ZkpProgramInput: From<I::Item>,
    {
        self.public_inputs
            .extend(inputs.into_iter().map(ZkpProgramInput::from));
        self
    }

    /// Appends a single private input.
    pub fn private_input(mut self, input: impl Into<ZkpProgramInput>) -> Self {
        self.private_inputs.push(input.into());
        self
    }

    /// Appends every item of `inputs` as a private input, in iteration order.
    ///
    /// The item type only needs to be convertible into a [`ZkpProgramInput`];
    /// it is unrelated to the runtime marker type `T`.
    pub fn private_inputs<I>(mut self, inputs: I) -> Self
    where
        I: IntoIterator,
        ZkpProgramInput: From<I::Item>,
    {
        self.private_inputs
            .extend(inputs.into_iter().map(ZkpProgramInput::from));
        self
    }

    /// Consumes the builder and creates the proof by calling the runtime's
    /// `prove` with the accumulated private, public and constant inputs.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the runtime's `prove`.
    pub fn prove(self) -> Result<Proof> {
        self.runtime.prove(
            self.program,
            self.private_inputs,
            self.public_inputs,
            self.constant_inputs,
        )
    }
}
/// Builder that accumulates the proof and inputs for a call to the runtime's
/// `verify`.
pub struct VerificationBuilder<'r, 'p, 'a, T: marker::Zkp, B: ZkpBackend> {
// Runtime that will verify the proof.
runtime: &'r GenericRuntime<T, B>,
// Program the proof is verified against.
program: &'p CompiledZkpProgram,
// Proof to verify; `None` until `proof` is called.
proof: Option<&'a Proof>,
// Constant inputs, in the order they were added.
constant_inputs: Vec<ZkpProgramInput>,
// Public inputs, in the order they were added.
public_inputs: Vec<ZkpProgramInput>,
}
impl<'r, 'p, 'a, T: marker::Zkp, B: ZkpBackend> VerificationBuilder<'r, 'p, 'a, T, B> {
    /// Creates a builder for `program` on `runtime` with no proof and no
    /// inputs.
    pub fn new(runtime: &'r GenericRuntime<T, B>, program: &'p CompiledZkpProgram) -> Self {
        Self {
            runtime,
            program,
            proof: None,
            public_inputs: vec![],
            constant_inputs: vec![],
        }
    }

    /// Sets the proof to verify, replacing any proof set earlier.
    pub fn proof(mut self, proof: &'a Proof) -> Self {
        self.proof = Some(proof);
        self
    }

    /// Appends a single constant input.
    pub fn constant_input(mut self, input: impl Into<ZkpProgramInput>) -> Self {
        self.constant_inputs.push(input.into());
        self
    }

    /// Appends every item of `inputs` as a constant input, in iteration order.
    ///
    /// The item type only needs to be convertible into a [`ZkpProgramInput`];
    /// it is unrelated to the runtime marker type `T`.
    pub fn constant_inputs<I>(mut self, inputs: I) -> Self
    where
        I: IntoIterator,
        ZkpProgramInput: From<I::Item>,
    {
        self.constant_inputs
            .extend(inputs.into_iter().map(ZkpProgramInput::from));
        self
    }

    /// Appends a single public input.
    pub fn public_input(mut self, input: impl Into<ZkpProgramInput>) -> Self {
        self.public_inputs.push(input.into());
        self
    }

    /// Appends every item of `inputs` as a public input, in iteration order.
    ///
    /// The item type only needs to be convertible into a [`ZkpProgramInput`];
    /// it is unrelated to the runtime marker type `T`.
    pub fn public_inputs<I>(mut self, inputs: I) -> Self
    where
        I: IntoIterator,
        ZkpProgramInput: From<I::Item>,
    {
        self.public_inputs
            .extend(inputs.into_iter().map(ZkpProgramInput::from));
        self
    }

    /// Consumes the builder and verifies the proof by calling the runtime's
    /// `verify` with the accumulated public and constant inputs.
    ///
    /// # Errors
    ///
    /// Returns a builder error if no proof was supplied, and otherwise
    /// propagates any error returned by the runtime's `verify`.
    pub fn verify(self) -> Result<()> {
        let proof = self.proof.ok_or_else(|| {
            Error::zkp_builder_error(
                "You must supply a proof to the verification builder before calling `verify`",
            )
        })?;

        self.runtime.verify(
            self.program,
            proof,
            self.public_inputs,
            self.constant_inputs,
        )
    }
}