use rlp::encode_list;
use seal_fhe::SecurityLevel;
pub use semver::Version;
use serde::{Deserialize, Serialize};
use sunscreen_compiler_common::Type;
use sunscreen_fhe_program::{FheProgram, SchemeType};
use crate::{Error, Result};
/// Argument and return types of an FHE program, stored in
/// [`FheProgramMetadata::signature`].
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct CallSignature {
/// The program's argument types.
pub arguments: Vec<Type>,
/// The program's return types.
pub returns: Vec<Type>,
/// NOTE(review): presumably the number of ciphertexts each value in the
/// signature occupies (per return value?) — confirm against the compiler
/// code that populates this field.
pub num_ciphertexts: Vec<usize>,
}
/// Kinds of keys listed in [`FheProgramMetadata::required_keys`].
///
/// NOTE(review): presumably the keys a program needs in order to be
/// evaluated — confirm with the code that fills `required_keys`.
/// Derived serde impls encode variants by index in non-self-describing
/// formats, so reordering variants can change the serialized form.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum RequiredKeys {
/// Galois keys.
Galois,
/// Presumably relinearization keys.
Relin,
/// A public key.
PublicKey,
}
/// Encryption-scheme parameters for an FHE program.
///
/// Besides serde, this type has its own compact binary encoding via
/// `Params::to_bytes` / `Params::try_from_bytes`.
#[derive(Debug, Clone, Serialize, Hash, Deserialize, PartialEq, Eq)]
pub struct Params {
/// Lattice dimension (presumably the polynomial modulus degree — confirm).
pub lattice_dimension: u64,
/// Coefficient modulus as a list of `u64` values (presumably the moduli of
/// an RNS chain — confirm).
pub coeff_modulus: Vec<u64>,
/// Plaintext modulus.
pub plain_modulus: u64,
/// Which FHE scheme these parameters target.
pub scheme_type: SchemeType,
/// Target security level.
pub security_level: SecurityLevel,
}
impl Params {
pub fn to_bytes(&self) -> Vec<u8> {
let mut bytes = vec![];
bytes.extend_from_slice(&self.lattice_dimension.to_be_bytes());
bytes.extend_from_slice(&self.plain_modulus.to_be_bytes());
let scheme_type: u8 = self.scheme_type.into();
bytes.push(scheme_type);
let security_level: i32 = self.security_level.into();
bytes.extend_from_slice(&security_level.to_be_bytes());
bytes.extend(encode_list(&self.coeff_modulus));
bytes
}
pub fn try_from_bytes(bytes: &[u8]) -> Result<Self> {
let (lattice_dimension, rest) = Self::read_u64(bytes)?;
let (plain_modulus, rest) = Self::read_u64(rest)?;
let (scheme_type, rest) = Self::read_u8(rest)?;
let scheme_type: SchemeType = scheme_type.try_into()?;
let (security_level, rest) = Self::read_i32(rest)?;
let security_level: SecurityLevel = security_level.try_into()?;
let coeff_modulus: Vec<u64> = rlp::decode_list(rest);
Ok(Self {
lattice_dimension,
plain_modulus,
scheme_type,
security_level,
coeff_modulus,
})
}
fn read_u64(bytes: &[u8]) -> Result<(u64, &[u8])> {
let (int_bytes, rest) = bytes.split_at(std::mem::size_of::<u64>());
let val = u64::from_be_bytes(
int_bytes
.try_into()
.map_err(|_| Error::ParamDeserializationError)?,
);
Ok((val, rest))
}
fn read_i32(bytes: &[u8]) -> Result<(i32, &[u8])> {
let (int_bytes, rest) = bytes.split_at(std::mem::size_of::<i32>());
let val = i32::from_be_bytes(
int_bytes
.try_into()
.map_err(|_| Error::ParamDeserializationError)?,
);
Ok((val, rest))
}
fn read_u8(bytes: &[u8]) -> Result<(u8, &[u8])> {
let (int_bytes, rest) = bytes.split_at(std::mem::size_of::<u8>());
let val = u8::from_be_bytes(
int_bytes
.try_into()
.map_err(|_| Error::ParamDeserializationError)?,
);
Ok((val, rest))
}
}
/// Metadata accompanying an FHE program: its scheme parameters, its call
/// signature, and the key kinds it lists as required.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct FheProgramMetadata {
/// Scheme parameters for the program.
pub params: Params,
/// Argument/return types of the program.
pub signature: CallSignature,
/// Key kinds listed for this program (presumably those needed to run it —
/// confirm with the code that builds this metadata).
pub required_keys: Vec<RequiredKeys>,
}
/// An FHE program bundled with its [`FheProgramMetadata`].
///
/// Unlike the other types in this module, this one does not derive `Debug`
/// or `PartialEq`.
#[derive(Clone, Serialize, Deserialize)]
pub struct CompiledFheProgram {
/// The program itself.
pub fhe_program_fn: FheProgram,
/// Parameters, signature and required keys for `fhe_program_fn`.
pub metadata: FheProgramMetadata,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `Params` survives a `to_bytes` / `try_from_bytes` roundtrip unchanged.
    #[test]
    fn can_roundtrip_params() {
        let params = Params {
            lattice_dimension: 4096,
            plain_modulus: 64,
            coeff_modulus: vec![1, 2, 3, 4],
            security_level: SecurityLevel::TC192,
            scheme_type: SchemeType::Bfv,
        };
        let params_2 = Params::try_from_bytes(&params.to_bytes()).unwrap();
        assert_eq!(params, params_2);
    }

    /// An empty modulus list still encodes as a (zero-length) RLP list and
    /// must roundtrip.
    #[test]
    fn can_roundtrip_params_with_empty_coeff_modulus() {
        let params = Params {
            lattice_dimension: 1024,
            plain_modulus: 1,
            coeff_modulus: vec![],
            security_level: SecurityLevel::TC128,
            scheme_type: SchemeType::Bfv,
        };
        let params_2 = Params::try_from_bytes(&params.to_bytes()).unwrap();
        assert_eq!(params, params_2);
    }

    /// `Type` survives a JSON roundtrip with every field intact.
    #[test]
    fn can_serialize_deserialize_typename() {
        let typename = Type {
            name: "foo::Bar".to_owned(),
            version: Version::new(42, 24, 6),
            is_encrypted: false,
        };
        let serialized = serde_json::to_string(&typename).unwrap();
        let deserialized: Type = serde_json::from_str(&serialized).unwrap();
        assert_eq!(deserialized.name, typename.name);
        assert_eq!(deserialized.version, typename.version);
        assert_eq!(deserialized.is_encrypted, typename.is_encrypted);
    }
}