1use std::{marker::PhantomData, mem::size_of, ops::Deref, sync::Arc};
2
3use crate::{
4 Encryption, Evaluation, KeylessEvaluation, L1GlweCiphertext, SecretKey,
5 crypto::{PublicKey, PublicOneTimePad},
6 fluent::{DynamicGenericInt, EncryptedRecryptedGenricInt, PackedDynamicGenericIntGraphNode},
7 recrypt_one_time_pad,
8 safe_bincode::GetSize,
9};
10
11use super::{CiphertextOps, FheCircuit, FheCircuitCtx, Muxable, PolynomialCiphertextOps};
12
13use mux_circuits::MuxCircuit;
14use parasol_concurrency::AtomicRefCell;
15use petgraph::stable_graph::NodeIndex;
16use serde::{Deserialize, Serialize};
17use sunscreen_tfhe::entities::Polynomial;
18
/// Operations a cleartext integer type must provide so that it can be
/// converted to and from the bit sequences used by the encrypted integer types.
pub trait PlaintextOps
where
    Self: Copy + PartialEq + std::fmt::Debug,
{
    /// Asserts that this value is representable using `bits` bits.
    ///
    /// The method returns nothing, so implementations are presumably expected
    /// to panic when the value does not fit.
    fn assert_in_bounds(&self, bits: usize);

    /// Builds a value from a sequence of bits.
    ///
    /// The bit order must match the one produced by [`PlaintextOps::to_bits`],
    /// since values are round-tripped through both methods.
    fn from_bits<I>(iter: I) -> Self
    where
        I: Iterator<Item = bool>;

    /// Yields the bits of this value for a requested width of `len` bits.
    fn to_bits(&self, len: usize) -> impl Iterator<Item = bool>;
}
30
31pub trait Sign {
33 type PlaintextType: PlaintextOps;
35
36 fn gen_compare_circuit(max_len: usize, gt: bool, eq: bool) -> MuxCircuit;
38
39 fn append_multiply<OutCt: Muxable>(
41 uop_graph: &mut FheCircuit,
42 a: &[NodeIndex],
43 b: &[NodeIndex],
44 ) -> (Vec<NodeIndex>, Vec<NodeIndex>);
45
46 fn resize_config(old_size: usize, new_size: usize) -> (usize, usize, bool);
49}
50
51#[derive(Clone, Serialize, Deserialize)]
52pub struct GenericInt<const N: usize, T: CiphertextOps, U: Sign> {
55 inner: DynamicGenericInt<T, U>,
56}
57
58impl<const N: usize, T: CiphertextOps, U: Sign> Deref for GenericInt<N, T, U> {
59 type Target = DynamicGenericInt<T, U>;
60
61 fn deref(&self) -> &Self::Target {
62 &self.inner
63 }
64}
65
66impl<const N: usize, T: CiphertextOps, U: Sign> From<GenericInt<N, T, U>>
67 for DynamicGenericInt<T, U>
68{
69 fn from(value: GenericInt<N, T, U>) -> DynamicGenericInt<T, U> {
70 value.inner
71 }
72}
73
74impl<const N: usize, T: CiphertextOps, U: Sign> From<DynamicGenericInt<T, U>>
75 for GenericInt<N, T, U>
76{
77 fn from(value: DynamicGenericInt<T, U>) -> Self {
78 assert_eq!(value.bits.len(), N);
79
80 Self { inner: value }
81 }
82}
83
84impl<const N: usize, T: CiphertextOps, U: Sign> GetSize for GenericInt<N, T, U> {
85 fn get_size(params: &crate::Params) -> usize {
86 N * T::get_size(params) + size_of::<u64>()
87 }
88
89 fn check_is_valid(&self, params: &crate::Params) -> crate::Result<()> {
90 for b in &self.inner.bits {
91 b.borrow().check_is_valid(params)?;
92 }
93
94 Ok(())
95 }
96}
97
98impl<const N: usize, T, U> GenericInt<N, T, U>
99where
100 T: CiphertextOps,
101 U: Sign,
102{
103 pub fn new(enc: &Encryption) -> Self {
106 Self {
107 inner: DynamicGenericInt::new(enc, N),
108 }
109 }
110
111 pub fn from_bits_shallow(bits: Vec<Arc<AtomicRefCell<T>>>) -> Self {
113 Self {
114 inner: DynamicGenericInt::from_bits_shallow(bits),
115 }
116 }
117
118 pub fn encrypt_secret(val: U::PlaintextType, enc: &Encryption, sk: &SecretKey) -> Self {
123 Self {
124 inner: DynamicGenericInt::<_, U>::encrypt_secret(val, enc, sk, N),
125 }
126 }
127
128 pub fn decrypt(&self, enc: &Encryption, sk: &SecretKey) -> U::PlaintextType {
130 self.inner.decrypt(enc, sk)
131 }
132
133 pub fn trivial(val: U::PlaintextType, enc: &Encryption, eval: &Evaluation) -> Self {
135 Self {
136 inner: DynamicGenericInt::<_, U>::trivial(val, enc, eval, N),
137 }
138 }
139}
140
141impl<const N: usize, U> GenericInt<N, L1GlweCiphertext, U>
142where
143 U: Sign,
144{
145 pub fn encrypt(val: U::PlaintextType, enc: &Encryption, pk: &PublicKey) -> Self {
156 Self {
157 inner: DynamicGenericInt::<_, U>::encrypt(val, enc, pk, N),
158 }
159 }
160}
161
162#[derive(Clone, Serialize, Deserialize)]
163pub struct PackedDynamicGenericInt<T, U>
177where
178 T: CiphertextOps + PolynomialCiphertextOps,
179 U: Sign,
180{
181 pub(crate) bit_len: u32,
182 pub(crate) ct: Arc<AtomicRefCell<T>>,
183 pub(crate) _phantom: PhantomData<U>,
184}
185
186impl<T, U> From<(u32, T)> for PackedDynamicGenericInt<T, U>
187where
188 T: CiphertextOps + PolynomialCiphertextOps,
189 U: Sign,
190{
191 fn from(value: (u32, T)) -> Self {
192 Self {
193 bit_len: value.0,
194 ct: Arc::new(AtomicRefCell::new(value.1)),
195 _phantom: PhantomData,
196 }
197 }
198}
199
200impl<T: CiphertextOps + PolynomialCiphertextOps, U: Sign> GetSize
201 for PackedDynamicGenericInt<T, U>
202{
203 fn get_size(params: &crate::Params) -> usize {
204 size_of::<u32>() + T::get_size(params)
205 }
206
207 fn check_is_valid(&self, params: &crate::Params) -> crate::Result<()> {
208 self.ct.borrow().check_is_valid(params)
209 }
210}
211
212impl<T, U> PackedDynamicGenericInt<T, U>
213where
214 T: CiphertextOps + PolynomialCiphertextOps,
215 U: Sign,
216{
217 pub fn encrypt(val: U::PlaintextType, enc: &Encryption, pk: &PublicKey, n: usize) -> Self {
220 val.assert_in_bounds(n);
221
222 let msg = Self::encode(val, enc, n);
223
224 Self {
225 bit_len: n as u32,
226 ct: Arc::new(AtomicRefCell::new(T::encrypt(&msg, enc, pk))),
227 _phantom: PhantomData,
228 }
229 }
230
231 fn encode(val: U::PlaintextType, enc: &Encryption, n: usize) -> Polynomial<u64> {
232 assert!(n < T::poly_degree(&enc.params).0);
233
234 let coeffs = val
235 .to_bits(n)
236 .map(|x| x as u64)
237 .chain(std::iter::repeat(0))
238 .take(enc.params.l1_poly_degree().0)
239 .collect::<Vec<_>>();
240
241 Polynomial::<u64>::new(&coeffs)
242 }
243
244 pub fn decrypt(&self, enc: &Encryption, sk: &SecretKey) -> U::PlaintextType {
246 let n = self.bit_len as usize;
247
248 assert!(n < T::poly_degree(&enc.params).0);
249
250 let poly = <T as PolynomialCiphertextOps>::decrypt(&self.ct.borrow(), enc, sk);
251
252 U::PlaintextType::from_bits(
253 poly.coeffs()
254 .iter()
255 .map(|x| *x == 0x1)
256 .take(self.bit_len as usize),
257 )
258 }
259
260 pub fn graph_input(&self, ctx: &FheCircuitCtx) -> PackedDynamicGenericIntGraphNode<T, U> {
262 PackedDynamicGenericIntGraphNode {
263 bit_len: self.bit_len,
264 id: ctx.circuit.borrow_mut().add_node(T::graph_input(&self.ct)),
265 _phantom: PhantomData,
266 }
267 }
268
269 pub fn trivial_encrypt(val: U::PlaintextType, enc: &Encryption, n: usize) -> Self {
271 let msg = Self::encode(val, enc, n);
272
273 Self {
274 bit_len: n as u32,
275 ct: Arc::new(AtomicRefCell::new(
276 <T as PolynomialCiphertextOps>::trivial_encryption(&msg, enc),
277 )),
278 _phantom: PhantomData,
279 }
280 }
281
282 pub fn inner(&self) -> T {
284 self.ct.borrow().clone()
285 }
286}
287
288impl<U: Sign> PackedDynamicGenericInt<L1GlweCiphertext, U> {
289 pub fn recrypt(
291 &self,
292 enc: &Encryption,
293 eval: &KeylessEvaluation,
294 otp: &PublicOneTimePad,
295 ) -> EncryptedRecryptedGenricInt<U> {
296 let t = recrypt_one_time_pad(&self.ct.borrow(), otp, eval, enc);
297
298 EncryptedRecryptedGenricInt::new(self.bit_len, t)
299 }
300}