use std::{marker::PhantomData, mem::size_of, ops::Deref, sync::Arc};
use crate::{
Encryption, Evaluation, KeylessEvaluation, L1GlweCiphertext, SecretKey,
crypto::{PublicKey, PublicOneTimePad},
fluent::{DynamicGenericInt, EncryptedRecryptedGenricInt, PackedDynamicGenericIntGraphNode},
recrypt_one_time_pad,
safe_bincode::GetSize,
};
use super::{CiphertextOps, FheCircuit, FheCircuitCtx, Muxable, PolynomialCiphertextOps};
use mux_circuits::MuxCircuit;
use parasol_concurrency::AtomicRefCell;
use petgraph::stable_graph::NodeIndex;
use serde::{Deserialize, Serialize};
use sunscreen_tfhe::entities::Polynomial;
/// Operations a plaintext integer type must provide so it can be bit-decomposed for
/// encryption and reassembled after decryption.
///
/// `to_bits` and `from_bits` must agree on bit ordering: packed encryption writes the
/// output of `to_bits` into polynomial coefficients in order, and decryption feeds those
/// coefficients back to `from_bits` in the same order.
pub trait PlaintextOps: Copy + PartialEq + std::fmt::Debug {
/// Checks that `self` is representable in `bits` bits.
///
/// NOTE(review): callers invoke this before encoding and ignore any return value, so
/// implementations presumably panic when the value does not fit — confirm.
fn assert_in_bounds(&self, bits: usize);
/// Rebuilds a plaintext value from a sequence of bits (same ordering as `to_bits`).
fn from_bits<I: Iterator<Item = bool>>(iter: I) -> Self;
/// Yields the bits of `self` for a `len`-bit representation (same ordering as
/// `from_bits`).
fn to_bits(&self, len: usize) -> impl Iterator<Item = bool>;
}
/// Signedness marker for generic encrypted integers.
///
/// Supplies the associated plaintext type and the sign-dependent circuit constructions.
/// NOTE(review): presumably implemented by separate unsigned/signed marker types — confirm.
pub trait Sign {
/// Plaintext integer type that encrypts/decrypts to values of this signedness.
type PlaintextType: PlaintextOps;
/// Builds a comparison [`MuxCircuit`] for operands of up to `max_len` bits.
///
/// NOTE(review): `gt` and `eq` presumably select the relation computed (>, >=, <, <=);
/// confirm against implementations.
fn gen_compare_circuit(max_len: usize, gt: bool, eq: bool) -> MuxCircuit;
/// Appends a multiplication of the bit nodes `a` and `b` to `uop_graph`, producing
/// nodes of ciphertext type `OutCt`.
///
/// NOTE(review): the meaning of the two returned node vectors (e.g. low/high halves of the
/// product) is not visible here — confirm against implementations.
fn append_multiply<OutCt: Muxable>(
uop_graph: &mut FheCircuit,
a: &[NodeIndex],
b: &[NodeIndex],
) -> (Vec<NodeIndex>, Vec<NodeIndex>);
/// Describes how to resize an integer from `old_size` bits to `new_size` bits.
///
/// NOTE(review): the returned `(usize, usize, bool)` semantics (e.g. bits kept, bits
/// added, whether to sign-extend) are not visible here — confirm.
fn resize_config(old_size: usize, new_size: usize) -> (usize, usize, bool);
}
/// Fixed-width encrypted integer: `N` bit ciphertexts of type `T` with signedness `U`.
///
/// A thin wrapper over [`DynamicGenericInt`] whose bit count is fixed at compile time;
/// conversions from a `DynamicGenericInt` assert that it has exactly `N` bits.
#[derive(Clone, Serialize, Deserialize)]
pub struct GenericInt<const N: usize, T: CiphertextOps, U: Sign> {
// The dynamically sized integer; expected to always hold exactly `N` bits.
inner: DynamicGenericInt<T, U>,
}
/// Exposes the wrapped [`DynamicGenericInt`] so its operations are available on a
/// fixed-width [`GenericInt`] by reference.
impl<const N: usize, T: CiphertextOps, U: Sign> Deref for GenericInt<N, T, U> {
    type Target = DynamicGenericInt<T, U>;

    fn deref(&self) -> &Self::Target {
        let Self { inner } = self;
        inner
    }
}
impl<const N: usize, T: CiphertextOps, U: Sign> From<GenericInt<N, T, U>>
for DynamicGenericInt<T, U>
{
fn from(value: GenericInt<N, T, U>) -> DynamicGenericInt<T, U> {
value.inner
}
}
/// Wraps a dynamically sized integer as an `N`-bit [`GenericInt`].
///
/// # Panics
/// If `value` does not have exactly `N` bits.
impl<const N: usize, T: CiphertextOps, U: Sign> From<DynamicGenericInt<T, U>>
    for GenericInt<N, T, U>
{
    fn from(value: DynamicGenericInt<T, U>) -> Self {
        let len = value.bits.len();
        // A GenericInt<N, ..> must always hold exactly N bits; report both sizes so a
        // mismatch is diagnosable from the panic message alone.
        assert_eq!(
            len, N,
            "cannot convert a {len}-bit DynamicGenericInt into a {}-bit GenericInt",
            N
        );
        Self { inner: value }
    }
}
impl<const N: usize, T: CiphertextOps, U: Sign> GetSize for GenericInt<N, T, U> {
    /// Size budget for an `N`-bit integer: `N` ciphertexts plus one `u64`.
    ///
    /// NOTE(review): the `u64` presumably accounts for the serialized length prefix of the
    /// bit vector — confirm against the `safe_bincode` encoding.
    fn get_size(params: &crate::Params) -> usize {
        let per_bit = T::get_size(params);
        size_of::<u64>() + per_bit * N
    }

    /// Validates every bit ciphertext in order, returning the first error encountered.
    fn check_is_valid(&self, params: &crate::Params) -> crate::Result<()> {
        self.inner
            .bits
            .iter()
            .try_for_each(|bit| bit.borrow().check_is_valid(params))
    }
}
impl<const N: usize, T, U> GenericInt<N, T, U>
where
    T: CiphertextOps,
    U: Sign,
{
    /// Creates a new `N`-bit integer via [`DynamicGenericInt::new`].
    ///
    /// NOTE(review): the initial bit contents are whatever `DynamicGenericInt::new`
    /// produces — confirm (e.g. zero-initialized ciphertexts).
    pub fn new(enc: &Encryption) -> Self {
        Self {
            inner: DynamicGenericInt::new(enc, N),
        }
    }

    /// Wraps the given bit handles via [`DynamicGenericInt::from_bits_shallow`].
    ///
    /// # Panics
    /// If `bits.len() != N`. Every `GenericInt<N, ..>` must hold exactly `N` bits. The
    /// `From<DynamicGenericInt>` conversion and `GetSize::get_size` rely on this invariant.
    pub fn from_bits_shallow(bits: Vec<Arc<AtomicRefCell<T>>>) -> Self {
        let len = bits.len();
        assert_eq!(
            len, N,
            "from_bits_shallow: expected {} bits for this GenericInt, got {len}",
            N
        );
        Self {
            inner: DynamicGenericInt::from_bits_shallow(bits),
        }
    }

    /// Encrypts `val` as an `N`-bit integer with the secret key `sk`.
    pub fn encrypt_secret(val: U::PlaintextType, enc: &Encryption, sk: &SecretKey) -> Self {
        Self {
            inner: DynamicGenericInt::<_, U>::encrypt_secret(val, enc, sk, N),
        }
    }

    /// Decrypts this integer with the secret key `sk`.
    pub fn decrypt(&self, enc: &Encryption, sk: &SecretKey) -> U::PlaintextType {
        self.inner.decrypt(enc, sk)
    }

    /// Builds a trivial encryption of `val` as an `N`-bit integer.
    pub fn trivial(val: U::PlaintextType, enc: &Encryption, eval: &Evaluation) -> Self {
        Self {
            inner: DynamicGenericInt::<_, U>::trivial(val, enc, eval, N),
        }
    }
}
impl<const N: usize, U> GenericInt<N, L1GlweCiphertext, U>
where
    U: Sign,
{
    /// Encrypts `val` as an `N`-bit integer of [`L1GlweCiphertext`] bits under the public
    /// key `pk`.
    pub fn encrypt(val: U::PlaintextType, enc: &Encryption, pk: &PublicKey) -> Self {
        let inner = DynamicGenericInt::<L1GlweCiphertext, U>::encrypt(val, enc, pk, N);
        Self { inner }
    }
}
/// An encrypted integer of `bit_len` bits packed into the coefficients of a single
/// polynomial ciphertext of type `T` (one bit per coefficient), with signedness `U`.
///
/// NOTE: the serialized layout comes from the derived serde impls and follows field
/// declaration order. Do not reorder the fields.
#[derive(Clone, Serialize, Deserialize)]
pub struct PackedDynamicGenericInt<T, U>
where
T: CiphertextOps + PolynomialCiphertextOps,
U: Sign,
{
// Number of plaintext bits stored in the leading coefficients of `ct`.
pub(crate) bit_len: u32,
// The packed polynomial ciphertext, shared behind an `Arc`.
pub(crate) ct: Arc<AtomicRefCell<T>>,
// Carries the signedness marker `U` without storing any data.
pub(crate) _phantom: PhantomData<U>,
}
/// Builds a packed integer from a `(bit_len, ciphertext)` pair without validating that
/// `bit_len` fits the ciphertext's polynomial.
impl<T, U> From<(u32, T)> for PackedDynamicGenericInt<T, U>
where
    T: CiphertextOps + PolynomialCiphertextOps,
    U: Sign,
{
    fn from(value: (u32, T)) -> Self {
        let (bit_len, ct) = value;
        Self {
            bit_len,
            ct: Arc::new(AtomicRefCell::new(ct)),
            _phantom: PhantomData,
        }
    }
}
impl<T: CiphertextOps + PolynomialCiphertextOps, U: Sign> GetSize
    for PackedDynamicGenericInt<T, U>
{
    /// Size budget: one ciphertext of type `T` plus the `u32` bit length.
    fn get_size(params: &crate::Params) -> usize {
        T::get_size(params) + size_of::<u32>()
    }

    /// Validates the single packed ciphertext.
    fn check_is_valid(&self, params: &crate::Params) -> crate::Result<()> {
        let ct = self.ct.borrow();
        ct.check_is_valid(params)
    }
}
impl<T, U> PackedDynamicGenericInt<T, U>
where
    T: CiphertextOps + PolynomialCiphertextOps,
    U: Sign,
{
    /// Encrypts `val` as an `n`-bit packed integer under the public key `pk`.
    ///
    /// The bits of `val` are written one per coefficient of a single polynomial (see
    /// [`Self::encode`]), which is then encrypted as one ciphertext of type `T`.
    ///
    /// # Panics
    /// - If `val` fails [`PlaintextOps::assert_in_bounds`] for `n` bits.
    /// - If `n` is not strictly less than `T`'s polynomial degree.
    pub fn encrypt(val: U::PlaintextType, enc: &Encryption, pk: &PublicKey, n: usize) -> Self {
        let msg = Self::encode(val, enc, n);
        Self::wrap_ct(n, T::encrypt(&msg, enc, pk))
    }

    /// Encodes `val` as a polynomial for `T`. The leading coefficients are the bits
    /// produced by `val.to_bits(n)` (as 0/1), and the remaining coefficients up to `T`'s
    /// polynomial degree are zero.
    ///
    /// Both the public-key and the trivial encryption paths go through here, so the
    /// range check on `val` and the degree check on `n` are applied uniformly.
    ///
    /// # Panics
    /// - If `val` fails [`PlaintextOps::assert_in_bounds`] for `n` bits.
    /// - If `n` is not strictly less than `T`'s polynomial degree.
    fn encode(val: U::PlaintextType, enc: &Encryption, n: usize) -> Polynomial<u64> {
        val.assert_in_bounds(n);
        // Size the polynomial with T's own degree (the same value the bound is checked
        // against) rather than assuming an L1 ciphertext.
        let degree = Self::assert_fits_in_poly(n, enc);

        let coeffs = val
            .to_bits(n)
            .map(u64::from)
            .chain(std::iter::repeat(0))
            .take(degree)
            .collect::<Vec<_>>();

        Polynomial::<u64>::new(&coeffs)
    }

    /// Decrypts the packed integer with the secret key `sk`.
    ///
    /// The first `bit_len` coefficients of the decrypted polynomial are read as bits. A
    /// coefficient equal to 1 is a set bit, and any other value is read as unset.
    ///
    /// # Panics
    /// If `bit_len` is not strictly less than `T`'s polynomial degree.
    pub fn decrypt(&self, enc: &Encryption, sk: &SecretKey) -> U::PlaintextType {
        let n = self.bit_len as usize;
        Self::assert_fits_in_poly(n, enc);

        let poly = <T as PolynomialCiphertextOps>::decrypt(&self.ct.borrow(), enc, sk);

        U::PlaintextType::from_bits(poly.coeffs().iter().take(n).map(|&c| c == 0x1))
    }

    /// Registers the packed ciphertext as an input node in `ctx`'s circuit and returns a
    /// graph-node handle carrying the same bit length.
    pub fn graph_input(&self, ctx: &FheCircuitCtx) -> PackedDynamicGenericIntGraphNode<T, U> {
        PackedDynamicGenericIntGraphNode {
            bit_len: self.bit_len,
            id: ctx.circuit.borrow_mut().add_node(T::graph_input(&self.ct)),
            _phantom: PhantomData,
        }
    }

    /// Builds a trivial encryption of `val` as an `n`-bit packed integer.
    ///
    /// # Panics
    /// - If `val` fails [`PlaintextOps::assert_in_bounds`] for `n` bits. This matches
    ///   [`Self::encrypt`]; previously an out-of-range value was encoded without a check.
    /// - If `n` is not strictly less than `T`'s polynomial degree.
    pub fn trivial_encrypt(val: U::PlaintextType, enc: &Encryption, n: usize) -> Self {
        let msg = Self::encode(val, enc, n);
        Self::wrap_ct(
            n,
            <T as PolynomialCiphertextOps>::trivial_encryption(&msg, enc),
        )
    }

    /// Returns an owned clone of the packed ciphertext.
    pub fn inner(&self) -> T {
        self.ct.borrow().clone()
    }

    /// Asserts that an `n`-bit packed integer fits in a single polynomial of `T` and
    /// returns that polynomial degree (the number of coefficients).
    fn assert_fits_in_poly(n: usize, enc: &Encryption) -> usize {
        let degree = T::poly_degree(&enc.params).0;
        assert!(
            n < degree,
            "packed integer of {n} bits does not fit: bit length must be less than the \
             polynomial degree ({degree})"
        );
        degree
    }

    /// Wraps a freshly produced ciphertext holding `n` packed bits.
    fn wrap_ct(n: usize, ct: T) -> Self {
        Self {
            // `n` has already passed `assert_fits_in_poly`; use a checked conversion
            // rather than a silently truncating `as` cast.
            bit_len: u32::try_from(n).expect("packed bit length exceeds u32::MAX"),
            ct: Arc::new(AtomicRefCell::new(ct)),
            _phantom: PhantomData,
        }
    }
}
impl<U: Sign> PackedDynamicGenericInt<L1GlweCiphertext, U> {
    /// Recrypts the packed ciphertext with the public one-time pad `otp` via
    /// [`recrypt_one_time_pad`]. The result keeps this integer's bit length.
    pub fn recrypt(
        &self,
        enc: &Encryption,
        eval: &KeylessEvaluation,
        otp: &PublicOneTimePad,
    ) -> EncryptedRecryptedGenricInt<U> {
        // Scope the borrow so it is released before the result is wrapped.
        let recrypted = {
            let ct = self.ct.borrow();
            recrypt_one_time_pad(&ct, otp, eval, enc)
        };
        EncryptedRecryptedGenricInt::new(self.bit_len, recrypted)
    }
}