sunscreen 0.8.1

A Fully Homomorphic Encryption (FHE) compiler supporting the Brakerski/Fan-Vercauteren (BFV) scheme.
Documentation
use std::ops::*;

use crypto_bigint::{nlimbs, Uint, Wrapping};
use paste::paste;
use seal_fhe::Plaintext as SealPlaintext;

use sunscreen_runtime::{
    InnerPlaintext, NumCiphertexts, Plaintext, TryFromPlaintext, TryIntoPlaintext,
};

use crate as sunscreen;
use crate::types::ops::GraphCipherInsert;
use crate::{
    fhe::{with_fhe_ctx, FheContextOps},
    types::{
        ops::{
            GraphCipherAdd, GraphCipherConstAdd, GraphCipherConstMul, GraphCipherConstSub,
            GraphCipherMul, GraphCipherPlainAdd, GraphCipherPlainMul, GraphCipherPlainSub,
            GraphCipherSub, GraphConstCipherSub, GraphPlainCipherSub,
        },
        Cipher,
    },
};
use crate::{
    types::{intern::FheProgramNode, BfvType, FheType, TypeNameInstance},
    FheProgramInputTrait, Params, TypeName as DeriveTypeName, WithContext,
};

#[derive(Debug, Clone, Copy, DeriveTypeName, PartialEq, Eq)]
/**
 * A single unsigned integer.
 *
 * Wraps a [`Uint`] with `LIMBS` limbs. The plaintext conversions in this
 * file encode the value one bit per plaintext coefficient.
 */
pub struct Unsigned<const LIMBS: usize> {
    /** The wrapped integer value. */
    val: Uint<LIMBS>,
}

/**
 * Declares `NUM_CIPHERTEXTS` as 1 for every limb count.
 */
impl<const LIMBS: usize> NumCiphertexts for Unsigned<LIMBS> {
    // Consistent with the plaintext conversions in this file, which produce
    // and accept exactly one inner plaintext.
    const NUM_CIPHERTEXTS: usize = 1;
}

// Empty trait impls: no associated items are provided here, so any behavior
// comes from the trait definitions themselves.
// NOTE(review): presumably these opt `Unsigned` in as an FHE program input
// and as a BFV-compatible FHE type — confirm against the trait docs.
impl<const LIMBS: usize> FheProgramInputTrait for Unsigned<LIMBS> {}
impl<const LIMBS: usize> FheType for Unsigned<LIMBS> {}
impl<const LIMBS: usize> BfvType for Unsigned<LIMBS> {}

impl<const LIMBS: usize> std::fmt::Display for Unsigned<LIMBS> {
    /**
     * Writes the `Display` rendering of the wrapped `Uint` to `f`.
     */
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self { val } = self;

        // Same call `write!` expands to; the inner value is formatted with a
        // plain `{}` placeholder.
        f.write_fmt(format_args!("{}", val))
    }
}

impl<const LIMBS: usize> Default for Unsigned<LIMBS> {
    /**
     * Returns an `Unsigned` wrapping `Uint::ZERO`.
     */
    fn default() -> Self {
        Self { val: Uint::ZERO }
    }
}

impl<const LIMBS: usize> TryIntoPlaintext for Unsigned<LIMBS> {
    /**
     * Encodes this integer as a single SEAL plaintext in which coefficient
     * `i` holds bit `i` of the value (as 0 or 1).
     *
     * The plaintext is resized to the value's bit length as reported by
     * `bits_vartime`, and only those coefficients are written.
     *
     * # Errors
     * Propagates any error from creating the SEAL plaintext.
     */
    fn try_into_plaintext(
        &self,
        params: &Params,
    ) -> std::result::Result<Plaintext, sunscreen_runtime::Error> {
        let mut encoded = SealPlaintext::new()?;

        let significant_bits = self.val.bits_vartime();
        encoded.resize(significant_bits);

        (0..significant_bits).for_each(|idx| {
            #[allow(clippy::unnecessary_cast)]
            let coefficient = self.val.bit_vartime(idx) as u64;
            encoded.set_coefficient(idx, coefficient);
        });

        let data_type = self.type_name_instance();
        let inner = InnerPlaintext::Seal(vec![WithContext {
            params: params.clone(),
            data: encoded,
        }]);

        Ok(Plaintext { data_type, inner })
    }
}

impl<const LIMBS: usize> TryFromPlaintext for Unsigned<LIMBS> {
    /**
     * Decodes an `Unsigned` from a plaintext by evaluating its coefficients
     * as a base-2 expansion: coefficient `i` contributes `coeff * 2^i`.
     *
     * Coefficients at or above `(plain_modulus + 1) / 2` are treated as
     * negative: `(plain_modulus - coeff) * 2^i` is subtracted instead. All
     * arithmetic goes through the wrapping helpers in this file.
     *
     * At most `size_of::<Uint<LIMBS>>() * 8` coefficients are read, and never
     * more than the plaintext holds.
     *
     * # Errors
     * Returns `IncorrectCiphertextCount` unless the plaintext contains
     * exactly one inner SEAL plaintext.
     */
    fn try_from_plaintext(
        plaintext: &Plaintext,
        params: &Params,
    ) -> std::result::Result<Self, sunscreen_runtime::Error> {
        let polys = match &plaintext.inner {
            InnerPlaintext::Seal(p) => p,
        };

        if polys.len() != 1 {
            return Err(sunscreen_runtime::Error::IncorrectCiphertextCount);
        }
        let poly = &polys[0];

        let capacity_bits = std::mem::size_of::<Uint<LIMBS>>() * 8;
        let bits = usize::min(capacity_bits, poly.len());

        let negative_cutoff = (params.plain_modulus + 1) / 2;

        let val = (0..bits).fold(Uint::<LIMBS>::ZERO, |acc, i| {
            let coeff = poly.get_coefficient(i);
            // 2^i, the weight of this coefficient.
            let place = Uint::<LIMBS>::from_u8(0x1) << i;

            if coeff >= negative_cutoff {
                // Upper half of the modulus range: interpret as a negative
                // coefficient and subtract its magnitude.
                let magnitude = Uint::from_u64(params.plain_modulus - coeff);
                wrapping_sub(acc, wrapping_mul(place, magnitude))
            } else {
                wrapping_add(acc, wrapping_mul(place, Uint::from_u64(coeff)))
            }
        });

        Ok(Self { val })
    }
}

impl<const LIMBS: usize> From<Uint<LIMBS>> for Unsigned<LIMBS> {
    fn from(val: Uint<LIMBS>) -> Self {
        Self { val }
    }
}

impl<const LIMBS: usize> From<u64> for Unsigned<LIMBS> {
    /**
     * Builds an `Unsigned` from a `u64` via `Uint::from_u64`.
     */
    fn from(n: u64) -> Self {
        let val = Uint::from_u64(n);

        Self { val }
    }
}

impl<const LIMBS: usize> From<Unsigned<LIMBS>> for Uint<LIMBS> {
    fn from(unsigned: Unsigned<LIMBS>) -> Self {
        unsigned.val
    }
}

fn wrapping_add<const LIMBS: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>) -> Uint<LIMBS> {
    (Wrapping(lhs) + Wrapping(rhs)).0
}
fn wrapping_mul<const LIMBS: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>) -> Uint<LIMBS> {
    (Wrapping(lhs) * Wrapping(rhs)).0
}
fn wrapping_sub<const LIMBS: usize>(lhs: Uint<LIMBS>, rhs: Uint<LIMBS>) -> Uint<LIMBS> {
    (Wrapping(lhs) - Wrapping(rhs)).0
}

// Implements the listed `std::ops` operator traits for `Unsigned`.
//
// For each trait name `$op` (e.g. `Add`), `paste!` builds the method name from
// the lower-cased trait name (`add`) and the name of the `Uint` method the
// impl forwards to (`wrapping_add`). Five operand combinations are generated
// per operator:
//   Unsigned op Unsigned, Unsigned op Uint, Uint op Unsigned,
//   Unsigned op u64,      u64 op Unsigned.
// Every combination produces an `Unsigned<LIMBS>`.
macro_rules! impl_std_op {
    ($($op:ident),+) => {
        $(
            paste! {
                // Unsigned op Unsigned
                impl<const LIMBS: usize> $op for Unsigned<LIMBS> {
                    type Output = Self;

                    fn [<$op:lower>](self, rhs: Unsigned<LIMBS>) -> Self::Output {
                        Self::Output {
                            val: self.val.[<wrapping_ $op:lower>](&rhs.val),
                        }
                    }
                }

                // Unsigned op Uint
                impl<const LIMBS: usize> $op<Uint<LIMBS>> for Unsigned<LIMBS> {
                    type Output = Self;

                    fn [<$op:lower>](self, rhs: Uint<LIMBS>) -> Self::Output {
                        Self {
                            val: self.val.[<wrapping_ $op:lower>](&rhs),
                        }
                    }
                }

                // Uint op Unsigned
                impl<const LIMBS: usize> $op<Unsigned<LIMBS>> for Uint<LIMBS> {
                    type Output = Unsigned<LIMBS>;

                    fn [<$op:lower>](self, rhs: Self::Output) -> Self::Output {
                        Self::Output {
                            val: self.[<wrapping_ $op:lower>](&rhs.val),
                        }
                    }
                }

                // Unsigned op u64: the u64 is widened with `Uint::from_u64` first.
                impl<const LIMBS: usize> $op<u64> for Unsigned<LIMBS> {
                    type Output = Self;

                    fn [<$op:lower>](self, rhs: u64) -> Self::Output {
                        Self {
                            val: self.val.[<wrapping_ $op:lower>](&Uint::<LIMBS>::from_u64(rhs)),
                        }
                    }
                }

                // u64 op Unsigned: the u64 is widened with `Uint::from_u64` first.
                impl<const LIMBS: usize> $op<Unsigned<LIMBS>> for u64 {
                    type Output = Unsigned<LIMBS>;

                    fn [<$op:lower>](self, rhs: Self::Output) -> Self::Output {
                        Self::Output {
                            val: Uint::from_u64(self).[<wrapping_ $op:lower>](&rhs.val),
                        }
                    }
                }
            }
        )+
    };
}

// Generate `Add`, `Sub` and `Mul` impls for `Unsigned` (forwarding to
// `wrapping_add`, `wrapping_sub` and `wrapping_mul` respectively).
impl_std_op! {
    Add, Sub, Mul
}

// Implements the FHE-graph operator traits for `Unsigned`.
//
// Each input is a pair `($op, $op_noun)`, e.g. `(Add, addition)`. `$op` selects
// the trait (`GraphCipherAdd`, `GraphCipherPlainAdd`, `GraphCipherConstAdd`)
// and its method name; `$op_noun` selects the context method that is called
// (`add_addition`, `add_addition_plaintext`). Every generated method passes
// the first id of each operand node to the context and wraps the returned id
// in a new `FheProgramNode`.
macro_rules! impl_graph_cipher_op {
    ($(($op:ident, $op_noun:ident)),+) => {
        $(
            paste! {
                // ciphertext op ciphertext
                impl<const LIMBS: usize> [<GraphCipher $op>] for Unsigned<LIMBS> {
                    type Left = Self;
                    type Right = Self;

                    fn [<graph_cipher_ $op:lower>](
                        a: FheProgramNode<Cipher<Self::Left>>,
                        b: FheProgramNode<Cipher<Self::Right>>,
                    ) -> FheProgramNode<Cipher<Self::Left>> {
                        with_fhe_ctx(|ctx| {
                            let n = ctx.[<add_ $op_noun>](a.ids[0], b.ids[0]);

                            FheProgramNode::new(&[n])
                        })
                    }
                }

                // ciphertext op plaintext node
                impl<const LIMBS: usize> [<GraphCipherPlain $op>] for Unsigned<LIMBS> {
                    type Left = Self;
                    type Right = Self;

                    fn [<graph_cipher_plain_ $op:lower>](
                        a: FheProgramNode<Cipher<Self::Left>>,
                        b: FheProgramNode<Self::Right>,
                    ) -> FheProgramNode<Cipher<Self::Left>> {
                        with_fhe_ctx(|ctx| {
                            let n = ctx.[<add_ $op_noun _plaintext>](a.ids[0], b.ids[0]);

                            FheProgramNode::new(&[n])
                        })
                    }
                }

                // ciphertext op `Uint` constant
                impl<const LIMBS: usize> [<GraphCipherConst $op>] for Unsigned<LIMBS> {
                    type Left = Self;
                    type Right = Uint<LIMBS>;

                    fn [<graph_cipher_const_ $op:lower>](
                        a: FheProgramNode<Cipher<Self::Left>>,
                        b: Uint<LIMBS>,
                    ) -> FheProgramNode<Cipher<Self::Left>> {
                        // The constant is first inserted into the graph as a
                        // plaintext literal, then combined like a plaintext node.
                        let lit = Self::graph_cipher_insert(b);
                        with_fhe_ctx(|ctx| {
                            let [<$op:lower>] = ctx.[<add_ $op_noun _plaintext>](a.ids[0], lit.ids[0]);

                            FheProgramNode::new(&[[<$op:lower>]])
                        })
                    }
                }
            }
        )+
    };
}

// Generate the cipher/cipher, cipher/plain and cipher/const graph impls for
// addition, subtraction and multiplication. The second element of each pair
// is the noun used in the context method names (e.g. `add_addition`).
impl_graph_cipher_op! {
    (Add, addition),
    (Sub, subtraction),
    (Mul, multiplication)
}

impl<const LIMBS: usize> GraphCipherInsert for Unsigned<LIMBS> {
    type Lit = Uint<LIMBS>;
    type Val = Self;

    fn graph_cipher_insert(lit: Self::Lit) -> FheProgramNode<Self::Val> {
        with_fhe_ctx(|ctx| {
            let lit = Self::from(lit).try_into_plaintext(&ctx.data).unwrap();
            let lit = ctx.add_plaintext_literal(lit.inner);

            FheProgramNode::new(&[lit])
        })
    }
}

impl<const LIMBS: usize> GraphConstCipherSub for Unsigned<LIMBS> {
    type Left = Uint<LIMBS>;
    type Right = Self;

    /**
     * Builds the graph nodes for `a - b`, where `a` is a constant and `b` is
     * a ciphertext.
     *
     * The constant is inserted as a plaintext literal; the subtraction is
     * then emitted with the ciphertext as the first operand and the result
     * is negated.
     */
    fn graph_const_cipher_sub(
        a: Uint<LIMBS>,
        b: FheProgramNode<Cipher<Self::Right>>,
    ) -> FheProgramNode<Cipher<Self::Right>> {
        // Must happen before entering the context closure below, since the
        // insertion itself goes through `with_fhe_ctx`.
        let literal = Self::graph_cipher_insert(a);

        with_fhe_ctx(|ctx| {
            let cipher_minus_const = ctx.add_subtraction_plaintext(b.ids[0], literal.ids[0]);
            let negated = ctx.add_negate(cipher_minus_const);

            FheProgramNode::new(&[negated])
        })
    }
}

impl<const LIMBS: usize> GraphPlainCipherSub for Unsigned<LIMBS> {
    type Left = Self;
    type Right = Self;

    /**
     * Builds the graph nodes for `a - b`, where `a` is a plaintext node and
     * `b` is a ciphertext node.
     *
     * The subtraction is emitted with the ciphertext as the first operand
     * and the result is then negated.
     */
    fn graph_plain_cipher_sub(
        a: FheProgramNode<Self::Left>,
        b: FheProgramNode<Cipher<Self::Right>>,
    ) -> FheProgramNode<Cipher<Self::Left>> {
        with_fhe_ctx(|ctx| {
            let cipher_minus_plain = ctx.add_subtraction_plaintext(b.ids[0], a.ids[0]);
            let negated = ctx.add_negate(cipher_minus_plain);

            FheProgramNode::new(&[negated])
        })
    }
}

// Defines one public type alias per listed bit width.
//
// For a width `N`, `paste!` builds the alias name `UnsignedN` and the alias
// points at `Unsigned<{nlimbs!(N)}>`, i.e. the limb count is computed from the
// bit width by crypto_bigint's `nlimbs!`. Each alias gets the doc string
// "Unsigned N-bit integer".
macro_rules! type_synonyms {
    ($($bits:expr),+) => {
        $(
            paste! {
                #[doc= concat!("Unsigned ", stringify!($bits), "-bit integer")]
                pub type [<Unsigned $bits>] = Unsigned<{nlimbs!($bits)}>;
            }
        )+
    };
}

// Defines `Unsigned64`, `Unsigned128`, `Unsigned256` and `Unsigned512`.
type_synonyms! {
    64, 128, 256, 512
}

#[cfg(test)]
mod tests {
    use super::*;

    /**
     * Addition outside of an FHE program, for every operand combination
     * involving `u64`: `Unsigned + Unsigned`, `Unsigned + u64`, `u64 + Unsigned`.
     */
    #[test]
    fn can_add_non_fhe() {
        let a = Unsigned256::from(5);
        let b = Unsigned256::from(10);

        assert_eq!(a + b, 15.into());
        assert_eq!(a + 10, 15.into());
        assert_eq!(10 + a, 15.into());
    }

    /**
     * Multiplication outside of an FHE program, for every operand combination
     * involving `u64`.
     */
    #[test]
    fn can_mul_non_fhe() {
        let a = Unsigned64::from(5);
        let b = Unsigned64::from(10);

        assert_eq!(a * b, 50.into());
        assert_eq!(a * 10, 50.into());
        assert_eq!(10 * a, 50.into());
    }

    /**
     * Subtraction outside of an FHE program, for every operand combination
     * involving `u64`, including `u64 - Unsigned`.
     */
    #[test]
    fn can_sub_non_fhe() {
        let a = Unsigned128::from(5);
        let b = Unsigned128::from(11);

        assert_eq!(b - a, 6.into());
        assert_eq!(b - 5, 6.into());
        assert_eq!(11 - a, 6.into());
    }

    /**
     * Subtraction forwards to `wrapping_sub`, so going below zero wraps
     * around rather than panicking; adding the subtrahend back recovers the
     * original value.
     */
    #[test]
    fn sub_wraps_on_underflow() {
        let a = Unsigned128::from(5);
        let b = Unsigned128::from(11);

        let wrapped = a - b;

        assert_ne!(wrapped, a);
        assert_eq!(wrapped + b, a);
    }
}