use seal_fhe::Plaintext as SealPlaintext;
use crate::{
fhe::{with_fhe_ctx, FheContextOps},
types::{
ops::{
GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstDiv, GraphCipherConstMul,
GraphCipherConstSub, GraphCipherInsert, GraphCipherMul, GraphCipherNeg,
GraphCipherPlainAdd, GraphCipherPlainMul, GraphCipherPlainSub, GraphCipherSub,
GraphConstCipherSub, GraphPlainCipherSub,
},
Cipher,
},
};
use crate::{
types::{intern::FheProgramNode, BfvType, FheType, Type, Version},
FheProgramInputTrait, Params, WithContext,
};
use sunscreen_runtime::{
InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext, TypeName,
TypeNameInstance,
};
use std::ops::*;
/// A floating-point value wrapped for use as an FHE program type.
///
/// The value is held as a plain `f64`. The `INT_BITS` const parameter is
/// consulted by the plaintext conversions in this file: encoding rejects
/// values whose highest binary power exceeds it, and decoding treats
/// coefficient indices below it as non-negative powers of two.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Fractional<const INT_BITS: usize> {
/// The wrapped floating-point value.
val: f64,
}
/// Lets a [`Fractional`] be read as its underlying `f64`.
impl<const INT_BITS: usize> std::ops::Deref for Fractional<INT_BITS> {
    type Target = f64;

    /// Borrows the wrapped floating-point value.
    fn deref(&self) -> &f64 {
        let Self { val } = self;
        val
    }
}
/// Declares how many ciphertexts a [`Fractional`] occupies.
impl<const INT_BITS: usize> NumCiphertexts for Fractional<INT_BITS> {
// A single ciphertext; the plaintext decoder in this file likewise expects
// exactly one inner plaintext.
const NUM_CIPHERTEXTS: usize = 1;
}
/// Marker implementation: the trait requires no items here.
impl<const INT_BITS: usize> FheProgramInputTrait for Fractional<INT_BITS> {}
/// The default [`Fractional`] wraps positive zero.
impl<const INT_BITS: usize> Default for Fractional<INT_BITS> {
    /// Builds a value whose underlying `f64` is `0.0`.
    fn default() -> Self {
        Self { val: 0.0 }
    }
}
/// Static type description for [`Fractional`].
impl<const INT_BITS: usize> TypeName for Fractional<INT_BITS> {
    /// Builds the [`Type`] descriptor: the name embeds `INT_BITS`, the
    /// version is this crate's package version, and the type is reported as
    /// not encrypted.
    ///
    /// # Panics
    ///
    /// Panics if the crate version baked in at compile time is not valid
    /// semver.
    fn type_name() -> Type {
        let name = format!("sunscreen::types::Fractional<{}>", INT_BITS);
        let version = Version::parse(env!("CARGO_PKG_VERSION"))
            .expect("Crate version is not a valid semver");

        Type {
            name,
            version,
            is_encrypted: false,
        }
    }
}
impl<const INT_BITS: usize> TypeNameInstance for Fractional<INT_BITS> {
fn type_name_instance(&self) -> Type {
Self::type_name()
}
}
/// Marker implementation: the trait requires no items here.
impl<const INT_BITS: usize> FheType for Fractional<INT_BITS> {}
/// Marker implementation: the trait requires no items here.
impl<const INT_BITS: usize> BfvType for Fractional<INT_BITS> {}
// Intentionally empty inherent impl; no inherent methods are defined.
impl<const INT_BITS: usize> Fractional<INT_BITS> {}
/// Graph construction for `ciphertext + ciphertext`.
impl<const INT_BITS: usize> GraphCipherAdd for Fractional<INT_BITS> {
    type Left = Self;
    type Right = Self;

    /// Records an addition node in the current FHE context, taking the first
    /// node id of each operand, and wraps the resulting id in a new program
    /// node.
    fn graph_cipher_add(
        a: FheProgramNode<Cipher<Self::Left>>,
        b: FheProgramNode<Cipher<Self::Right>>,
    ) -> FheProgramNode<Cipher<Self::Left>> {
        with_fhe_ctx(|ctx| FheProgramNode::new(&[ctx.add_addition(a.ids[0], b.ids[0])]))
    }
}
/// Graph construction for `ciphertext + plaintext`.
impl<const INT_BITS: usize> GraphCipherPlainAdd for Fractional<INT_BITS> {
    type Left = Self;
    type Right = Self;

    /// Records a ciphertext-plus-plaintext addition node in the current FHE
    /// context and wraps the resulting id in a new program node.
    fn graph_cipher_plain_add(
        a: FheProgramNode<Cipher<Self::Left>>,
        b: FheProgramNode<Self::Right>,
    ) -> FheProgramNode<Cipher<Self::Left>> {
        with_fhe_ctx(|ctx| {
            FheProgramNode::new(&[ctx.add_addition_plaintext(a.ids[0], b.ids[0])])
        })
    }
}
/// Inserts an `f64` literal into the program graph as a plaintext node.
impl<const INT_BITS: usize> GraphCipherInsert for Fractional<INT_BITS> {
    type Lit = f64;
    type Val = Self;

    /// Encodes `lit` as a plaintext using the context's parameters, records
    /// it as a plaintext literal, and returns a program node for it.
    ///
    /// # Panics
    ///
    /// Panics if the literal cannot be encoded (the encoding result is
    /// unwrapped).
    fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode<Self::Val> {
        with_fhe_ctx(|ctx| {
            let wrapped: Self = lit.into();
            let encoded = wrapped.try_into_plaintext(&ctx.data).unwrap();
            let literal_id = ctx.add_plaintext_literal(encoded.inner);

            FheProgramNode::new(&[literal_id])
        })
    }
}
impl<const INT_BITS: usize> GraphCipherConstAdd for Fractional<INT_BITS> {
type Left = Fractional<INT_BITS>;
type Right = f64;
fn graph_cipher_const_add(
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
let lit = Self::graph_cipher_insert(b);
with_fhe_ctx(|ctx| {
let n = ctx.add_addition_plaintext(a.ids[0], lit.ids[0]);
FheProgramNode::new(&[n])
})
}
}
/// Graph construction for `ciphertext - ciphertext`.
impl<const INT_BITS: usize> GraphCipherSub for Fractional<INT_BITS> {
    type Left = Self;
    type Right = Self;

    /// Records a subtraction node (`a - b`) in the current FHE context and
    /// wraps the resulting id in a new program node.
    fn graph_cipher_sub(
        a: FheProgramNode<Cipher<Self::Left>>,
        b: FheProgramNode<Cipher<Self::Right>>,
    ) -> FheProgramNode<Cipher<Self::Left>> {
        with_fhe_ctx(|ctx| FheProgramNode::new(&[ctx.add_subtraction(a.ids[0], b.ids[0])]))
    }
}
/// Graph construction for `ciphertext - plaintext`.
impl<const INT_BITS: usize> GraphCipherPlainSub for Fractional<INT_BITS> {
    type Left = Self;
    type Right = Self;

    /// Records a ciphertext-minus-plaintext subtraction node (`a - b`) in
    /// the current FHE context and wraps the resulting id in a new program
    /// node.
    fn graph_cipher_plain_sub(
        a: FheProgramNode<Cipher<Self::Left>>,
        b: FheProgramNode<Self::Right>,
    ) -> FheProgramNode<Cipher<Self::Left>> {
        with_fhe_ctx(|ctx| {
            FheProgramNode::new(&[ctx.add_subtraction_plaintext(a.ids[0], b.ids[0])])
        })
    }
}
/// Graph construction for `plaintext - ciphertext`.
impl<const INT_BITS: usize> GraphPlainCipherSub for Fractional<INT_BITS> {
    type Left = Self;
    type Right = Self;

    /// Computes `a - b` as `-(b - a)`: the plaintext-subtraction node takes
    /// the ciphertext operand first, so the operands are swapped and the
    /// result is negated.
    fn graph_plain_cipher_sub(
        a: FheProgramNode<Self::Left>,
        b: FheProgramNode<Cipher<Self::Right>>,
    ) -> FheProgramNode<Cipher<Self::Left>> {
        with_fhe_ctx(|ctx| {
            let b_minus_a = ctx.add_subtraction_plaintext(b.ids[0], a.ids[0]);

            FheProgramNode::new(&[ctx.add_negate(b_minus_a)])
        })
    }
}
impl<const INT_BITS: usize> GraphCipherConstSub for Fractional<INT_BITS> {
type Left = Fractional<INT_BITS>;
type Right = f64;
fn graph_cipher_const_sub(
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
let lit = Self::graph_cipher_insert(b);
with_fhe_ctx(|ctx| {
let n = ctx.add_subtraction_plaintext(a.ids[0], lit.ids[0]);
FheProgramNode::new(&[n])
})
}
}
impl<const INT_BITS: usize> GraphConstCipherSub for Fractional<INT_BITS> {
type Left = f64;
type Right = Fractional<INT_BITS>;
fn graph_const_cipher_sub(
a: Self::Left,
b: FheProgramNode<Cipher<Self::Right>>,
) -> FheProgramNode<Cipher<Self::Right>> {
let lit = Self::graph_cipher_insert(a);
with_fhe_ctx(|ctx| {
let n = ctx.add_subtraction_plaintext(b.ids[0], lit.ids[0]);
let n = ctx.add_negate(n);
FheProgramNode::new(&[n])
})
}
}
/// Graph construction for `ciphertext * ciphertext`.
impl<const INT_BITS: usize> GraphCipherMul for Fractional<INT_BITS> {
    type Left = Self;
    type Right = Self;

    /// Records a multiplication node in the current FHE context and wraps
    /// the resulting id in a new program node.
    fn graph_cipher_mul(
        a: FheProgramNode<Cipher<Self::Left>>,
        b: FheProgramNode<Cipher<Self::Right>>,
    ) -> FheProgramNode<Cipher<Self::Left>> {
        with_fhe_ctx(|ctx| FheProgramNode::new(&[ctx.add_multiplication(a.ids[0], b.ids[0])]))
    }
}
/// Graph construction for `ciphertext * plaintext`.
impl<const INT_BITS: usize> GraphCipherPlainMul for Fractional<INT_BITS> {
    type Left = Self;
    type Right = Self;

    /// Records a ciphertext-times-plaintext multiplication node in the
    /// current FHE context and wraps the resulting id in a new program node.
    fn graph_cipher_plain_mul(
        a: FheProgramNode<Cipher<Self::Left>>,
        b: FheProgramNode<Self::Right>,
    ) -> FheProgramNode<Cipher<Self::Left>> {
        with_fhe_ctx(|ctx| {
            FheProgramNode::new(&[ctx.add_multiplication_plaintext(a.ids[0], b.ids[0])])
        })
    }
}
impl<const INT_BITS: usize> GraphCipherConstMul for Fractional<INT_BITS> {
type Left = Fractional<INT_BITS>;
type Right = f64;
fn graph_cipher_const_mul(
a: FheProgramNode<Cipher<Self::Left>>,
b: Self::Right,
) -> FheProgramNode<Cipher<Self::Left>> {
let lit = Self::graph_cipher_insert(b);
with_fhe_ctx(|ctx| {
let n = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[0]);
FheProgramNode::new(&[n])
})
}
}
impl<const INT_BITS: usize> GraphCipherConstDiv for Fractional<INT_BITS> {
type Left = Fractional<INT_BITS>;
type Right = f64;
fn graph_cipher_const_div(
a: FheProgramNode<Cipher<Self::Left>>,
b: f64,
) -> FheProgramNode<Cipher<Self::Left>> {
let lit = Self::graph_cipher_insert(1. / b);
with_fhe_ctx(|ctx| {
let n = ctx.add_multiplication_plaintext(a.ids[0], lit.ids[0]);
FheProgramNode::new(&[n])
})
}
}
/// Graph construction for unary negation of a ciphertext.
impl<const INT_BITS: usize> GraphCipherNeg for Fractional<INT_BITS> {
    type Val = Self;

    /// Records a negation node for `a` in the current FHE context and wraps
    /// the resulting id in a new program node.
    fn graph_cipher_neg(a: FheProgramNode<Cipher<Self>>) -> FheProgramNode<Cipher<Self::Val>> {
        with_fhe_ctx(|ctx| FheProgramNode::new(&[ctx.add_negate(a.ids[0])]))
    }
}
/// Encodes a [`Fractional`] into a single plaintext polynomial.
impl<const INT_BITS: usize> TryIntoPlaintext for Fractional<INT_BITS> {
    /// Converts the wrapped `f64` into a plaintext with
    /// `params.lattice_dimension` coefficients.
    ///
    /// The value is split into its IEEE-754 sign, exponent and 53-bit
    /// mantissa (including the implicit leading one). Each mantissa bit
    /// represents a binary power `p`:
    ///
    /// * `p >= 0` is stored in coefficient `p`;
    /// * `p < 0` is stored in coefficient `n + p` with its sign inverted,
    ///   mirroring the sign flip the decoder applies to negative powers.
    ///
    /// Negative coefficients are stored as `plain_modulus - bit`. Zero and
    /// subnormal values are encoded as the freshly resized plaintext with no
    /// coefficient written.
    ///
    /// # Errors
    ///
    /// Returns an FHE type error when the value is NaN or infinite, when its
    /// highest binary power needs more than `INT_BITS` integer bits, or when
    /// a set mantissa bit has no valid coefficient slot for the given lattice
    /// dimension. Also propagates a failure to allocate the SEAL plaintext.
    fn try_into_plaintext(
        &self,
        params: &Params,
    ) -> std::result::Result<Plaintext, sunscreen_runtime::Error> {
        if self.val.is_nan() {
            return Err(sunscreen_runtime::Error::fhe_type_error("Value is NaN."));
        }

        if self.val.is_infinite() {
            return Err(sunscreen_runtime::Error::fhe_type_error(
                "Value is infinite.",
            ));
        }

        let mut seal_plaintext = SealPlaintext::new()?;
        let n = params.lattice_dimension as usize;

        seal_plaintext.resize(n);

        // Zero and subnormals are both encoded as the untouched plaintext.
        if self.val.is_subnormal() || self.val == 0.0 {
            return Ok(Plaintext {
                data_type: self.type_name_instance(),
                inner: InnerPlaintext::Seal(vec![WithContext {
                    params: params.clone(),
                    data: seal_plaintext,
                }]),
            });
        }

        // Pick the IEEE-754 double apart: 1 sign bit, 11 exponent bits,
        // 52 stored mantissa bits.
        let as_u64: u64 = self.val.to_bits();
        let sign_mask = 0x1 << 63;
        let mantissa_mask = 0xFFFFFFFFFFFFF;
        let exp_mask = !mantissa_mask & !sign_mask;

        // Restore the implicit leading one above the stored mantissa bits.
        let mantissa = (as_u64 & mantissa_mask) | (mantissa_mask + 1);
        let exp = as_u64 & exp_mask;
        // Remove the exponent bias to get the power of the leading bit.
        let power = (exp >> (f64::MANTISSA_DIGITS - 1)) as i64 - 1023;
        let sign = (as_u64 & sign_mask) >> 63;

        if power + 1 > INT_BITS as i64 {
            return Err(sunscreen_runtime::Error::fhe_type_error("Out of range"));
        }

        for i in 0..f64::MANTISSA_DIGITS {
            let bit_value = (mantissa >> i) & 0x1;
            let bit_power = power - (f64::MANTISSA_DIGITS - i - 1) as i64;

            // Integer bits must fit inside the polynomial. Fractional bits
            // must land in [INT_BITS, n): anything lower would be read back
            // as an integer power by the decoder, and a negative index would
            // wrap around when cast to `usize`.
            let coeff_index = if bit_power >= 0 {
                let index = bit_power as usize;
                if index < n {
                    Some(index)
                } else {
                    None
                }
            } else {
                let index = n as i64 + bit_power;
                if index >= INT_BITS as i64 && index < n as i64 {
                    Some(index as usize)
                } else {
                    None
                }
            };

            let coeff_index = match coeff_index {
                Some(index) => index,
                // An unset bit without a slot carries no information.
                None if bit_value == 0 => continue,
                None => {
                    return Err(sunscreen_runtime::Error::fhe_type_error(
                        "Value cannot be represented with the given lattice dimension.",
                    ));
                }
            };

            // For powers less than 0, we invert the sign.
            let bit_sign = if bit_power >= 0 { sign } else { !sign & 0x1 };

            let coeff = if bit_sign == 0 {
                bit_value
            } else if bit_value > 0 {
                // Negative one is represented modulo the plain modulus.
                params.plain_modulus - bit_value
            } else {
                0
            };

            seal_plaintext.set_coefficient(coeff_index, coeff);
        }

        Ok(Plaintext {
            data_type: self.type_name_instance(),
            inner: InnerPlaintext::Seal(vec![WithContext {
                params: params.clone(),
                data: seal_plaintext,
            }]),
        })
    }
}
/// Decodes a [`Fractional`] from a single plaintext polynomial.
impl<const INT_BITS: usize> TryFromPlaintext for Fractional<INT_BITS> {
    /// Rebuilds the `f64` by summing every coefficient weighted by its power
    /// of two.
    ///
    /// Coefficient indices below `INT_BITS` carry the non-negative power
    /// `i`; the remaining indices carry the negative power `i - n` and have
    /// their sign flipped. Coefficients at or above half the plain modulus
    /// are interpreted as negative (`coeff - plain_modulus`).
    ///
    /// # Errors
    ///
    /// Returns `IncorrectCiphertextCount` when the plaintext does not hold
    /// exactly one inner polynomial.
    fn try_from_plaintext(
        plaintext: &Plaintext,
        params: &Params,
    ) -> std::result::Result<Self, sunscreen_runtime::Error> {
        match &plaintext.inner {
            InnerPlaintext::Seal(polys) => {
                if polys.len() != 1 {
                    return Err(sunscreen_runtime::Error::IncorrectCiphertextCount);
                }

                let poly = &polys[0];
                let degree = params.lattice_dimension as usize;
                let modulus = params.plain_modulus;
                let half_modulus = (modulus + 1) / 2;
                let coeff_count = usize::min(degree, poly.len());

                let mut total = 0.0f64;

                for idx in 0..coeff_count {
                    let exponent = if idx < INT_BITS {
                        idx as i64
                    } else {
                        idx as i64 - degree as i64
                    };
                    // Negative powers were stored with an inverted sign.
                    let direction = if exponent >= 0 { 1f64 } else { -1f64 };
                    let weight = (exponent as f64).exp2();
                    let coeff = poly.get_coefficient(idx);

                    if coeff >= half_modulus {
                        total -= direction * (modulus - coeff) as f64 * weight;
                    } else {
                        total += direction * coeff as f64 * weight;
                    }
                }

                Ok(Self { val: total })
            }
        }
    }
}
impl<const INT_BITS: usize> From<f64> for Fractional<INT_BITS> {
fn from(val: f64) -> Self {
Self { val }
}
}
impl<const INT_BITS: usize> From<Fractional<INT_BITS>> for f64 {
fn from(frac: Fractional<INT_BITS>) -> Self {
frac.val
}
}
impl<const INT_BITS: usize> Add for Fractional<INT_BITS> {
type Output = Self;
fn add(self, rhs: Self) -> Self {
Self {
val: self.val + rhs.val,
}
}
}
impl<const INT_BITS: usize> Add<f64> for Fractional<INT_BITS> {
type Output = Self;
fn add(self, rhs: f64) -> Self {
Self {
val: self.val + rhs,
}
}
}
/// Plain (non-FHE) addition of an `f64` and a [`Fractional`].
impl<const INT_BITS: usize> Add<Fractional<INT_BITS>> for f64 {
    type Output = Fractional<INT_BITS>;

    /// Returns a [`Fractional`] wrapping `self + rhs`.
    fn add(self, rhs: Fractional<INT_BITS>) -> Fractional<INT_BITS> {
        let val = self + rhs.val;
        Fractional { val }
    }
}
impl<const INT_BITS: usize> Mul for Fractional<INT_BITS> {
type Output = Self;
fn mul(self, rhs: Self) -> Self {
Self {
val: self.val * rhs.val,
}
}
}
impl<const INT_BITS: usize> Mul<f64> for Fractional<INT_BITS> {
type Output = Self;
fn mul(self, rhs: f64) -> Self {
Self {
val: self.val * rhs,
}
}
}
/// Plain (non-FHE) multiplication of an `f64` by a [`Fractional`].
impl<const INT_BITS: usize> Mul<Fractional<INT_BITS>> for f64 {
    type Output = Fractional<INT_BITS>;

    /// Returns a [`Fractional`] wrapping `self * rhs`.
    fn mul(self, rhs: Fractional<INT_BITS>) -> Fractional<INT_BITS> {
        let val = self * rhs.val;
        Fractional { val }
    }
}
impl<const INT_BITS: usize> Sub for Fractional<INT_BITS> {
type Output = Self;
fn sub(self, rhs: Self) -> Self {
Self {
val: self.val - rhs.val,
}
}
}
impl<const INT_BITS: usize> Sub<f64> for Fractional<INT_BITS> {
type Output = Self;
fn sub(self, rhs: f64) -> Self {
Self {
val: self.val - rhs,
}
}
}
/// Plain (non-FHE) subtraction of a [`Fractional`] from an `f64`.
impl<const INT_BITS: usize> Sub<Fractional<INT_BITS>> for f64 {
    type Output = Fractional<INT_BITS>;

    /// Returns a [`Fractional`] wrapping `self - rhs`.
    fn sub(self, rhs: Fractional<INT_BITS>) -> Fractional<INT_BITS> {
        let val = self - rhs.val;
        Fractional { val }
    }
}
impl<const INT_BITS: usize> Div<f64> for Fractional<INT_BITS> {
type Output = Self;
fn div(self, rhs: f64) -> Self {
Self {
val: self.val / rhs,
}
}
}
impl<const INT_BITS: usize> Neg for Fractional<INT_BITS> {
type Output = Self;
fn neg(self) -> Self {
Self { val: -self.val }
}
}
#[cfg(test)]
mod tests {
    // Literals such as 3.14 below are arbitrary test data, not approximations
    // of pi, so the lint is silenced for this module.
    #![allow(clippy::approx_constant)]

    use super::*;
    use crate::{SchemeType, SecurityLevel};
    use float_cmp::ApproxEq;

    /// Encoding a value to a plaintext and decoding it again must yield the
    /// exact same value, for positive, negative, integral and fractional
    /// inputs.
    #[test]
    fn can_encode_decode_fractional() {
        let round_trip = |x: f64| {
            let params = Params {
                lattice_dimension: 4096,
                plain_modulus: 1_000_000,
                coeff_modulus: vec![],
                scheme_type: SchemeType::Bfv,
                security_level: SecurityLevel::TC128,
            };

            let f_1 = Fractional::<64>::from(x);
            let pt = f_1.try_into_plaintext(&params).unwrap();
            let f_2 = Fractional::<64>::try_from_plaintext(&pt, &params).unwrap();

            assert_eq!(f_1, f_2);
        };

        round_trip(3.14);
        round_trip(0.0);
        round_trip(1.0);
        round_trip(5.8125);
        round_trip(6.0);
        round_trip(6.6);
        round_trip(1.2);
        round_trip(1e13);
        round_trip(0.0000000005);
        round_trip(-1.0);
        round_trip(-5.875);
        round_trip(-6.0);
        round_trip(-6.6);
        round_trip(-1.2);
        round_trip(-1e13);
        round_trip(-0.0000000005);
    }

    /// Addition outside of an FHE program works for every operand pairing.
    #[test]
    fn can_add_non_fhe() {
        let a = Fractional::<64>::from(3.14);
        let b = Fractional::<64>::from(1.5);

        assert!((a + b).approx_eq(4.64, (0.0, 1)));
        assert!((3.14 + b).approx_eq(4.64, (0.0, 1)));
        assert!((a + 1.5).approx_eq(4.64, (0.0, 1)));
    }

    /// Multiplication outside of an FHE program works for every operand
    /// pairing.
    #[test]
    fn can_mul_non_fhe() {
        let a = Fractional::<64>::from(3.14);
        let b = Fractional::<64>::from(1.5);

        assert!((a * b).approx_eq(4.71, (0.0, 1)));
        assert!((3.14 * b).approx_eq(4.71, (0.0, 1)));
        assert!((a * 1.5).approx_eq(4.71, (0.0, 1)));
    }

    /// Subtraction outside of an FHE program works for every operand pairing.
    #[test]
    fn can_sub_non_fhe() {
        let a = Fractional::<64>::from(3.14);
        let b = Fractional::<64>::from(1.5);

        assert!((a - b).approx_eq(1.64, (0.0, 1)));
        assert!((3.14 - b).approx_eq(1.64, (0.0, 1)));
        assert!((a - 1.5).approx_eq(1.64, (0.0, 1)));
    }

    /// Division by an `f64` outside of an FHE program matches plain `f64`
    /// division.
    #[test]
    fn can_div_non_fhe() {
        let a = Fractional::<64>::from(3.14);

        assert!((a / 1.5).approx_eq(3.14 / 1.5, (0.0, 1)));
    }

    /// Negation outside of an FHE program flips the sign of the wrapped
    /// value.
    #[test]
    fn can_neg_non_fhe() {
        let a = Fractional::<64>::from(3.14);

        assert_eq!(-a, (-3.14).into());
    }
}