parasol_cpu 0.10.0

This crate contains the Parasol CPU, which runs programs over a mix of encrypted and plaintext data.
Documentation
use std::sync::{Arc, OnceLock};

use parasol_runtime::{
    DEFAULT_128, Encryption, Evaluation, L1GlweCiphertext, SecretKey,
    fluent::{Int, UInt},
    test_utils::get_compute_key_128,
};
use rayon::{ThreadPool, ThreadPoolBuilder};
use sunscreen_tfhe::entities::Polynomial;

use crate::{Byte, FheComputer, ToArg};

pub fn poly_one() -> Arc<Polynomial<u64>> {
    static ONE: OnceLock<Arc<Polynomial<u64>>> = OnceLock::new();

    ONE.get_or_init(|| {
        let mut coeffs = vec![0; 1024];
        coeffs[0] = 1;
        Arc::new(Polynomial::new(&coeffs))
    })
    .clone()
}

pub fn get_thread_pool() -> Arc<ThreadPool> {
    static THREAD_POOL: OnceLock<Arc<ThreadPool>> = OnceLock::new();

    THREAD_POOL
        .get_or_init(|| {
            Arc::new(
                ThreadPoolBuilder::new()
                    .thread_name(|x| format!("Fhe worker {x}"))
                    .build()
                    .unwrap(),
            )
        })
        .clone()
}

pub fn make_computer_128() -> (FheComputer, Encryption) {
    let compute_key = get_compute_key_128();
    let enc = Encryption::new(&DEFAULT_128);
    let eval = Evaluation::new(compute_key, &DEFAULT_128, &enc);

    (
        FheComputer::new_with_threadpool(&enc, &eval, get_thread_pool()),
        enc,
    )
}

pub trait TestFrom<T> {
    fn test_from(value: T) -> Self;
}

macro_rules! impl_test_from {
    (($fty:ty) $( $x:ty ),* ) => {
        $(
            impl TestFrom<$fty> for $x {
                fn test_from(value: $fty) -> Self {
                    value as $x
                }
            }
        )*
    };
}

impl_test_from!((u128) u128, u64, u32, u16, u8);
impl_test_from!((i128) u128, u64, u32, u16, u8);
impl_test_from!((u128) i128, i64, i32, i16, i8);
impl_test_from!((i128) i128, i64, i32, i16, i8);

pub trait Bits<const N: usize> {
    type PlaintextType: num::Num + TestFrom<u128> + TestFrom<i128> + std::fmt::Debug + Copy + ToArg;
}

pub struct BitsUnsigned();

impl Bits<32> for BitsUnsigned {
    type PlaintextType = u32;
}

impl Bits<16> for BitsUnsigned {
    type PlaintextType = u16;
}

impl Bits<8> for BitsUnsigned {
    type PlaintextType = u8;
}

pub enum MaybeEncryptedUInt<const N: usize>
where
    BitsUnsigned: Bits<N>,
    <BitsUnsigned as Bits<N>>::PlaintextType: Into<u64>,
{
    Plain(<BitsUnsigned as Bits<N>>::PlaintextType),
    Encrypted(UInt<N, L1GlweCiphertext>),
}

impl<const N: usize> std::fmt::Debug for MaybeEncryptedUInt<N>
where
    BitsUnsigned: Bits<N>,
    <BitsUnsigned as Bits<N>>::PlaintextType: Into<u64>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MaybeEncryptedUInt {{ .. }}")
    }
}

impl<const N: usize> MaybeEncryptedUInt<N>
where
    BitsUnsigned: Bits<N>,
    <BitsUnsigned as Bits<N>>::PlaintextType: Into<u64>,
{
    pub fn new(val: u128, enc: &Encryption, sk: &SecretKey, encrypt: bool) -> Self {
        if !encrypt {
            Self::Plain(<BitsUnsigned as Bits<N>>::PlaintextType::test_from(val))
        } else {
            Self::Encrypted(UInt::encrypt_secret(val, enc, sk))
        }
    }

    pub fn get(
        &self,
        enc: &Encryption,
        sk: &SecretKey,
    ) -> <BitsUnsigned as Bits<N>>::PlaintextType {
        match self {
            Self::Plain(x) => *x,
            Self::Encrypted(x) => {
                <BitsUnsigned as Bits<N>>::PlaintextType::test_from(x.decrypt(enc, sk))
            }
        }
    }
}

impl<const N: usize> ToArg for MaybeEncryptedUInt<N>
where
    BitsUnsigned: Bits<N>,
    <BitsUnsigned as Bits<N>>::PlaintextType: Into<u64>,
{
    fn alignment() -> usize {
        <BitsUnsigned as Bits<N>>::PlaintextType::alignment()
    }

    fn size() -> usize {
        <BitsUnsigned as Bits<N>>::PlaintextType::size()
    }

    fn to_bytes(&self) -> Vec<Byte> {
        match self {
            Self::Plain(x) => x.to_bytes(),
            Self::Encrypted(x) => x.to_bytes(),
        }
    }

    fn try_from_bytes(data: Vec<crate::Byte>) -> crate::Result<Self> {
        match &data[0] {
            Byte::Plaintext(_) => Ok(Self::Plain(
                <<BitsUnsigned as Bits<N>>::PlaintextType>::try_from_bytes(data)?,
            )),
            Byte::Ciphertext(_) => Ok(Self::Encrypted(UInt::try_from_bytes(data)?)),
        }
    }
}

pub struct BitsSigned();

impl Bits<32> for BitsSigned {
    type PlaintextType = i32;
}

impl Bits<16> for BitsSigned {
    type PlaintextType = i16;
}

impl Bits<8> for BitsSigned {
    type PlaintextType = i8;
}

pub enum MaybeEncryptedInt<const N: usize>
where
    BitsSigned: Bits<N>,
    <BitsSigned as Bits<N>>::PlaintextType: Into<i64>,
{
    Plain(<BitsSigned as Bits<N>>::PlaintextType),
    Encrypted(Int<N, L1GlweCiphertext>),
}

impl<const N: usize> std::fmt::Debug for MaybeEncryptedInt<N>
where
    BitsSigned: Bits<N>,
    <BitsSigned as Bits<N>>::PlaintextType: Into<i64>,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "MaybeEncryptedInt {{ .. }}")
    }
}

impl<const N: usize> MaybeEncryptedInt<N>
where
    BitsSigned: Bits<N>,
    <BitsSigned as Bits<N>>::PlaintextType: Into<i64>,
{
    pub fn new(val: i128, enc: &Encryption, sk: &SecretKey, encrypt: bool) -> Self {
        if !encrypt {
            Self::Plain(<BitsSigned as Bits<N>>::PlaintextType::test_from(val))
        } else {
            Self::Encrypted(Int::encrypt_secret(val, enc, sk))
        }
    }

    pub fn get(&self, enc: &Encryption, sk: &SecretKey) -> <BitsSigned as Bits<N>>::PlaintextType {
        match self {
            Self::Plain(x) => *x,
            Self::Encrypted(x) => {
                <BitsSigned as Bits<N>>::PlaintextType::test_from(x.decrypt(enc, sk))
            }
        }
    }
}

impl<const N: usize> ToArg for MaybeEncryptedInt<N>
where
    BitsSigned: Bits<N>,
    <BitsSigned as Bits<N>>::PlaintextType: Into<i64>,
{
    fn alignment() -> usize {
        <BitsSigned as Bits<N>>::PlaintextType::size()
    }

    fn size() -> usize {
        <BitsSigned as Bits<N>>::PlaintextType::size()
    }

    fn to_bytes(&self) -> Vec<Byte> {
        match self {
            Self::Plain(x) => x.to_bytes(),
            Self::Encrypted(x) => x.to_bytes(),
        }
    }

    fn try_from_bytes(data: Vec<crate::Byte>) -> crate::Result<Self> {
        match &data[0] {
            Byte::Plaintext(_) => Ok(Self::Plain(
                <<BitsSigned as Bits<N>>::PlaintextType>::try_from_bytes(data)?,
            )),
            Byte::Ciphertext(_) => Ok(Self::Encrypted(Int::try_from_bytes(data)?)),
        }
    }
}

#[cfg(test)]
mod tests {
    use parasol_runtime::test_utils::{get_encryption_128, get_secret_keys_128};

    use super::{MaybeEncryptedInt, MaybeEncryptedUInt};

    #[test]
    fn can_roundtrip_maybeuint() {
        let enc = get_encryption_128();
        let sk = get_secret_keys_128();

        for i in 0..10 {
            let val = MaybeEncryptedUInt::<8>::new(i, &enc, &sk, i % 2 == 0);
            assert_eq!(val.get(&enc, &sk), i as u8);
        }
    }

    #[test]
    fn can_roundtrip_maybeint() {
        let enc = get_encryption_128();
        let sk = get_secret_keys_128();

        for i in 118..138u8 {
            let val = MaybeEncryptedInt::<8>::new(i.cast_signed() as i128, &enc, &sk, i % 2 == 0);
            assert_eq!(val.get(&enc, &sk), i as i8);
        }
    }
}