parasol_runtime/fluent/
mod.rs

1use std::sync::Arc;
2
3use bumpalo::Bump;
4use parasol_concurrency::AtomicRefCell;
5use petgraph::stable_graph::NodeIndex;
6use sunscreen_tfhe::{
7    PolynomialDegree,
8    entities::{GlevCiphertext, Polynomial, PolynomialRef},
9};
10
11use crate::{
12    CiphertextType, Encryption, Evaluation, FheCircuit, FheOp, L0LweCiphertext, L1GgswCiphertext,
13    L1GlweCiphertext, L1LweCiphertext, Params, SecretKey, TrivialZero,
14    crypto::{L1GlevCiphertext, PublicKey},
15    fhe_circuit::MuxMode,
16    safe_bincode::GetSize,
17};
18
19mod bit;
20mod dynamic_generic_int;
21mod dynamic_generic_int_graph_nodes;
22mod generic_int;
23mod generic_int_graph_nodes;
24mod int;
25mod packed_dynamic_generic_int_graph_node;
26mod packed_generic_int;
27mod packed_generic_int_graph_node;
28mod recrypted_int;
29mod uint;
30
31pub use bit::*;
32pub use dynamic_generic_int::*;
33pub use dynamic_generic_int_graph_nodes::*;
34pub use generic_int::*;
35pub use generic_int_graph_nodes::*;
36pub use int::*;
37pub use packed_dynamic_generic_int_graph_node::*;
38pub use packed_generic_int::*;
39pub use packed_generic_int_graph_node::*;
40pub use recrypted_int::*;
41pub use uint::*;
42
43/// A context for building FHE circuits out of high-level primitives (e.g.
44/// [UIntGraphNodes]).
45///
46/// # Panics
47/// The APIs in this module take the context as an immutable borrow to hide the
48/// allocator details from you, but rest assured you'll get a panic if you try
49/// to mutate using said primitives concurrently on multiple threads.
50pub struct FheCircuitCtx {
51    /// The underlying [`FheCircuit`].
52    pub circuit: AtomicRefCell<FheCircuit>,
53    one_cache: AtomicRefCell<[Option<NodeIndex>; 4]>,
54    zero_cache: AtomicRefCell<[Option<NodeIndex>; 4]>,
55    allocator: Bump,
56}
57
58impl Default for FheCircuitCtx {
59    fn default() -> Self {
60        Self::new()
61    }
62}
63
64impl FheCircuitCtx {
65    /// Create a new [`FheCircuitCtx`].
66    pub fn new() -> Self {
67        Self {
68            circuit: AtomicRefCell::new(FheCircuit::new()),
69            one_cache: AtomicRefCell::new([None; 4]),
70            zero_cache: AtomicRefCell::new([None; 4]),
71            allocator: Bump::new(),
72        }
73    }
74}
75
76/// Operations one can perform on ciphertexts that encrypt polynomials (e.g. [`L1GlweCiphertext`] and
77/// [`L1GlevCiphertext`]).
78pub trait PolynomialCiphertextOps {
79    /// Encrypt a polynomial under the given secret key. Returns the ciphertext.
80    fn encrypt_secret(msg: &PolynomialRef<u64>, enc: &Encryption, sk: &SecretKey) -> Self;
81
82    /// Encrypt a polynomial using the given public key. Returns the ciphertext.
83    fn encrypt(msg: &PolynomialRef<u64>, enc: &Encryption, pk: &PublicKey) -> Self;
84
85    /// Decrypt an encrypted polynomial using the given secret key. Returns the message.
86    fn decrypt(&self, enc: &Encryption, sk: &SecretKey) -> Polynomial<u64>;
87
88    /// Create a trivial encryption of the given polynomial.
89    fn trivial_encryption(polynomial: &PolynomialRef<u64>, encryption: &Encryption) -> Self;
90
91    /// Get the polynomial degree of messages for the given params.
92    fn poly_degree(params: &Params) -> PolynomialDegree;
93}
94
95impl PolynomialCiphertextOps for L1GlweCiphertext {
96    fn encrypt_secret(msg: &PolynomialRef<u64>, encryption: &Encryption, sk: &SecretKey) -> Self {
97        encryption.encrypt_glwe_l1_secret(msg, sk)
98    }
99
100    fn encrypt(msg: &PolynomialRef<u64>, encryption: &Encryption, pk: &PublicKey) -> Self {
101        encryption.encrypt_rlwe_l1(msg, pk)
102    }
103
104    fn trivial_encryption(polynomial: &PolynomialRef<u64>, encryption: &Encryption) -> Self {
105        encryption.trivial_glwe_l1(polynomial)
106    }
107
108    fn poly_degree(params: &Params) -> PolynomialDegree {
109        params.l1_poly_degree()
110    }
111
112    fn decrypt(&self, enc: &Encryption, sk: &SecretKey) -> Polynomial<u64> {
113        enc.decrypt_glwe_l1(self, sk)
114    }
115}
116
117/// Operations supported by all ciphertext types.
118pub trait CiphertextOps: GetSize + Clone
119where
120    Self: Sized,
121{
122    /// This is used internally to facilitate ciphertext conversion.
123    const CIPHERTEXT_TYPE: CiphertextType;
124
125    /// Allocate a new trivial zero ciphertext.
126    fn allocate(encryption: &Encryption) -> Self;
127
128    /// Encrypt a bit under the given secret key. Returns the ciphertext.
129    fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self;
130
131    /// Decrypt and return the bit message contained in `self`.
132    fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool;
133
134    /// Create an [`FheOp`] input corresponding to this ciphertext.
135    fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp;
136
137    /// Create an [`FheOp`] output corresponding to this ciphertext.
138    fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp;
139
140    /// Create a trivial encryption of the given bit message with ciphertext type `Self`.
141    ///
142    /// # Remarks
143    /// In the case of [`L1GgswCiphertext`]s, this will return a pre-encrypted one or zero, as
144    /// trivial encryptions of one would require knowing and would reveal the secret key.
145    fn trivial_encryption(bit: bool, encryption: &Encryption, eval: &Evaluation) -> Self;
146
147    /// Creates a zero encryption out of an existing ciphertext, so that it has
148    /// the same parameters without needing to pass in the `Encryption` object.
149    fn trivial_zero_from_existing(&self) -> Self;
150
151    /// Add an [`FheOp`] corresponding to this ciphertext's trivial one node.
152    fn graph_trivial_one() -> FheOp;
153
154    /// Add an [`FheOp`] corresponding to this ciphertext's trivial zero node.
155    fn graph_trivial_zero() -> FheOp;
156}
157
158impl CiphertextOps for L0LweCiphertext {
159    const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L0LweCiphertext;
160
161    fn allocate(encryption: &Encryption) -> Self {
162        encryption.allocate_lwe_l0()
163    }
164
165    fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
166        encryption.encrypt_lwe_l0_secret(msg, sk)
167    }
168
169    fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
170        encryption.decrypt_lwe_l0(self, sk)
171    }
172
173    fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
174        FheOp::InputLwe0(bit.clone())
175    }
176
177    fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
178        FheOp::OutputLwe0(bit.clone())
179    }
180
181    fn trivial_encryption(bit: bool, encryption: &Encryption, _eval: &Evaluation) -> Self {
182        if bit {
183            encryption.trivial_lwe_l0_one()
184        } else {
185            encryption.trivial_lwe_l0_zero()
186        }
187    }
188
189    fn trivial_zero_from_existing(&self) -> Self {
190        <L0LweCiphertext as TrivialZero>::trivial_zero_from_existing(self)
191    }
192
193    fn graph_trivial_one() -> FheOp {
194        FheOp::OneLwe0
195    }
196
197    fn graph_trivial_zero() -> FheOp {
198        FheOp::ZeroLwe0
199    }
200}
201impl CiphertextOps for L1LweCiphertext {
202    const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L1LweCiphertext;
203
204    fn allocate(encryption: &Encryption) -> Self {
205        encryption.allocate_lwe_l1()
206    }
207
208    fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
209        encryption.encrypt_lwe_l1_secret(msg, sk)
210    }
211
212    fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
213        encryption.decrypt_lwe_l1(self, sk)
214    }
215
216    fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
217        FheOp::InputLwe1(bit.clone())
218    }
219
220    fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
221        FheOp::OutputLwe1(bit.clone())
222    }
223
224    fn trivial_encryption(bit: bool, encryption: &Encryption, _eval: &Evaluation) -> Self {
225        if bit {
226            encryption.trivial_lwe_l1_one()
227        } else {
228            encryption.trivial_lwe_l1_zero()
229        }
230    }
231
232    fn trivial_zero_from_existing(&self) -> Self {
233        <L1LweCiphertext as TrivialZero>::trivial_zero_from_existing(self)
234    }
235
236    fn graph_trivial_one() -> FheOp {
237        unimplemented!()
238    }
239
240    fn graph_trivial_zero() -> FheOp {
241        unimplemented!()
242    }
243}
244impl CiphertextOps for L1GgswCiphertext {
245    const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L1GgswCiphertext;
246
247    fn allocate(encryption: &Encryption) -> Self {
248        encryption.allocate_ggsw_l1()
249    }
250
251    fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
252        encryption.encrypt_ggsw_l1_secret(msg, sk)
253    }
254
255    fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
256        encryption.decrypt_ggsw_l1(self, sk)
257    }
258
259    fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
260        FheOp::InputGgsw1(bit.clone())
261    }
262
263    fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
264        FheOp::OutputGgsw1(bit.clone())
265    }
266
267    fn trivial_encryption(bit: bool, _encryption: &Encryption, eval: &Evaluation) -> Self {
268        if bit {
269            eval.l1ggsw_one().to_owned()
270        } else {
271            eval.l1ggsw_zero().to_owned()
272        }
273    }
274
275    fn trivial_zero_from_existing(&self) -> Self {
276        <L1GgswCiphertext as TrivialZero>::trivial_zero_from_existing(self)
277    }
278
279    fn graph_trivial_one() -> FheOp {
280        FheOp::OneGgsw1
281    }
282
283    fn graph_trivial_zero() -> FheOp {
284        FheOp::ZeroGgsw1
285    }
286}
287impl CiphertextOps for L1GlweCiphertext {
288    const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L1GlweCiphertext;
289
290    fn allocate(encryption: &Encryption) -> Self {
291        encryption.allocate_glwe_l1()
292    }
293
294    fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
295        let mut poly = Polynomial::new(&vec![
296            0u64;
297            encryption.params.l1_params.dim.polynomial_degree.0
298        ]);
299        poly.coeffs_mut()[0] = msg as u64;
300
301        encryption.encrypt_glwe_l1_secret(&poly, sk)
302    }
303
304    fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
305        encryption.decrypt_glwe_l1(self, sk).coeffs()[0] == 1
306    }
307
308    fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
309        FheOp::InputGlwe1(bit.clone())
310    }
311
312    fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
313        FheOp::OutputGlwe1(bit.clone())
314    }
315
316    fn trivial_encryption(bit: bool, encryption: &Encryption, _eval: &Evaluation) -> Self {
317        if bit {
318            encryption.trivial_glwe_l1_one()
319        } else {
320            encryption.trivial_glwe_l1_zero()
321        }
322    }
323
324    fn trivial_zero_from_existing(&self) -> Self {
325        <L1GlweCiphertext as TrivialZero>::trivial_zero_from_existing(self)
326    }
327
328    fn graph_trivial_one() -> FheOp {
329        FheOp::OneGlwe1
330    }
331
332    fn graph_trivial_zero() -> FheOp {
333        FheOp::ZeroGlwe1
334    }
335}
336
337impl CiphertextOps for L1GlevCiphertext {
338    const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L1GlevCiphertext;
339
340    fn allocate(encryption: &Encryption) -> Self {
341        GlevCiphertext::new(&encryption.params.l1_params, &encryption.params.cbs_radix).into()
342    }
343
344    fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
345        encryption.decrypt_glev_l1(self, sk).coeffs()[0] == 1
346    }
347
348    fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
349        let mut poly = Polynomial::zero(encryption.params.l1_params.dim.polynomial_degree.0);
350        poly.coeffs_mut()[0] = msg as u64;
351
352        encryption.encrypt_glev_l1_secret(&poly, sk)
353    }
354
355    fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
356        FheOp::InputGlev1(bit.to_owned())
357    }
358
359    fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
360        FheOp::OutputGlev1(bit.to_owned())
361    }
362
363    fn graph_trivial_zero() -> FheOp {
364        FheOp::ZeroGlev1
365    }
366
367    fn graph_trivial_one() -> FheOp {
368        FheOp::OneGlev1
369    }
370
371    fn trivial_encryption(bit: bool, encryption: &Encryption, _eval: &Evaluation) -> Self {
372        if bit {
373            encryption.trivial_glev_l1_one()
374        } else {
375            encryption.trivial_glev_l1_zero()
376        }
377    }
378
379    fn trivial_zero_from_existing(&self) -> Self {
380        <L1GlevCiphertext as TrivialZero>::trivial_zero_from_existing(self)
381    }
382}
383
384/// A trait indicating one can perform Mux Operations over this ciphertext with a [`L1GgswCiphertext`]
385/// select bit. Used to abstract Mux circuits over different ciphertext types.
386pub trait Muxable: CiphertextOps {
387    /// The type of the `a` and `b` inputs and output of a mux operation. Allows the runtime to
388    /// dynamically choose [`FheOp::CMux`] or [`FheOp::GlevCMux`] as appropriate.
389    const MUX_MODE: MuxMode;
390}
391
392impl Muxable for L1GlweCiphertext {
393    const MUX_MODE: MuxMode = MuxMode::Glwe;
394}
395
396impl Muxable for L1GlevCiphertext {
397    const MUX_MODE: MuxMode = MuxMode::Glev;
398}
399
400#[cfg(test)]
401mod tests {
402    use bit::Bit;
403    use generic_int::GenericInt;
404    use rand::{RngCore, rng};
405    use uint::UInt;
406
407    use crate::test_utils::{
408        get_encryption_128, get_evaluation_128, get_secret_keys_128, make_uproc_128,
409    };
410
411    use super::*;
412
413    fn roundtrip<T: CiphertextOps, U: Sign, F: Fn() -> U::PlaintextType>(gen_pt: F) {
414        let sk = get_secret_keys_128();
415        let enc = get_encryption_128();
416
417        for _ in 0..32 {
418            // Make 16-bit integers.
419            let val = gen_pt();
420            let ct = GenericInt::<16, T, U>::encrypt_secret(val, &enc, &sk);
421            let actual = ct.decrypt(&enc, &sk);
422
423            assert_eq!(val, actual);
424        }
425    }
426
427    fn rand_u16() -> u128 {
428        (rng().next_u64() & 0xFFFF) as u128
429    }
430
431    fn rand_i16() -> i128 {
432        (rng().next_u64() & 0xFFFF) as i16 as i128
433    }
434
435    fn rand_u32() -> u128 {
436        (rng().next_u64() & 0xFFFFFFFF) as u128
437    }
438
439    fn rand_i32() -> i128 {
440        (rng().next_u64() & 0xFFFFFFFF) as i32 as i128
441    }
442
443    #[test]
444    fn can_roundtrip_l0_lwe() {
445        roundtrip::<L0LweCiphertext, Unsigned, _>(rand_u16);
446        roundtrip::<L0LweCiphertext, Signed, _>(rand_i16);
447    }
448
449    #[test]
450    fn can_roundtrip_l1_lwe() {
451        roundtrip::<L1LweCiphertext, Unsigned, _>(rand_u16);
452        roundtrip::<L1LweCiphertext, Signed, _>(rand_i16);
453    }
454
455    #[test]
456    fn can_roundtrip_l1_glwe() {
457        roundtrip::<L1GlweCiphertext, Unsigned, _>(rand_u16);
458        roundtrip::<L1GlweCiphertext, Signed, _>(rand_i16);
459    }
460
461    #[test]
462    fn can_roundtrip_l1_ggsw() {
463        roundtrip::<L1GgswCiphertext, Unsigned, _>(rand_u16);
464        roundtrip::<L1GgswCiphertext, Signed, _>(rand_i16);
465    }
466
467    fn input_output<T: CiphertextOps, U: Sign>(test_val: U::PlaintextType) {
468        let (uproc, fc) = make_uproc_128();
469        let enc = get_encryption_128();
470
471        let input = GenericInt::<16, T, U>::encrypt_secret(
472            test_val,
473            &get_encryption_128(),
474            &get_secret_keys_128(),
475        );
476
477        let graph = FheCircuitCtx::new();
478
479        let in_node = input.graph_inputs(&graph);
480        let output = in_node.collect_outputs(&graph, &enc);
481
482        uproc
483            .lock()
484            .unwrap()
485            .run_graph_blocking(&graph.circuit.borrow(), &fc)
486            .unwrap();
487
488        let actual = output.decrypt(&enc, &get_secret_keys_128());
489        assert_eq!(actual, test_val);
490    }
491
492    #[test]
493    fn can_input_output_generic_int_graph_l0_lwe() {
494        input_output::<L0LweCiphertext, Unsigned>(1234);
495        input_output::<L0LweCiphertext, Signed>(-104);
496    }
497
498    #[test]
499    fn can_input_output_generic_int_graph_l1_lwe() {
500        input_output::<L1LweCiphertext, Unsigned>(1234);
501        input_output::<L1LweCiphertext, Signed>(-104);
502    }
503
504    #[test]
505    fn can_input_output_generic_int_graph_l1_ggsw() {
506        input_output::<L1GgswCiphertext, Unsigned>(1234);
507        input_output::<L1GgswCiphertext, Signed>(-104);
508    }
509
510    #[test]
511    fn can_input_output_generic_int_graph_l1_glwe() {
512        input_output::<L1GlweCiphertext, Unsigned>(1234);
513        input_output::<L1GlweCiphertext, Signed>(-104);
514    }
515
516    #[test]
517    fn can_convert_ciphertexts() {
518        fn convert_test<T: CiphertextOps, U: CiphertextOps, V: Sign>(test_val: V::PlaintextType) {
519            let graph = FheCircuitCtx::new();
520            let enc = get_encryption_128();
521            let (uproc, fc) = make_uproc_128();
522            let sk = get_secret_keys_128();
523
524            let val = GenericInt::<16, T, V>::encrypt_secret(test_val, &enc, &sk);
525
526            let inputs = val.graph_inputs(&graph);
527            let converted = inputs.convert::<U>(&graph);
528            let outputs = converted.collect_outputs(&graph, &enc);
529
530            uproc
531                .lock()
532                .unwrap()
533                .run_graph_blocking(&graph.circuit.borrow(), &fc)
534                .unwrap();
535
536            let actual = outputs.decrypt(&enc, &sk);
537            assert_eq!(actual, test_val);
538        }
539
540        convert_test::<L0LweCiphertext, L1GgswCiphertext, Unsigned>(1234);
541        convert_test::<L0LweCiphertext, L1GgswCiphertext, Signed>(-106);
542        convert_test::<L0LweCiphertext, L1GlweCiphertext, Unsigned>(1234);
543        convert_test::<L0LweCiphertext, L1GlweCiphertext, Signed>(-106);
544        convert_test::<L0LweCiphertext, L1LweCiphertext, Unsigned>(1234);
545        convert_test::<L0LweCiphertext, L1LweCiphertext, Signed>(-106);
546        convert_test::<L0LweCiphertext, L0LweCiphertext, Unsigned>(1234);
547        convert_test::<L0LweCiphertext, L0LweCiphertext, Signed>(-106);
548
549        // GLEV ciphertexts are weird children, so give them a few cases.
550        convert_test::<L1GlevCiphertext, L1GgswCiphertext, Unsigned>(1234);
551        convert_test::<L1GlevCiphertext, L1GgswCiphertext, Signed>(-106);
552        convert_test::<L1GgswCiphertext, L1GlevCiphertext, Unsigned>(1234);
553        convert_test::<L1GgswCiphertext, L1GlevCiphertext, Signed>(-106);
554        convert_test::<L0LweCiphertext, L1GlevCiphertext, Unsigned>(1234);
555        convert_test::<L0LweCiphertext, L1GlevCiphertext, Signed>(-106);
556    }
557
558    #[test]
559    fn can_cmp() {
560        fn case<OutCt: Muxable, U: Sign>(
561            gt: bool,
562            eq: bool,
563            test_vals: (U::PlaintextType, U::PlaintextType),
564        ) {
565            let enc = &get_encryption_128();
566            let sk = get_secret_keys_128();
567            let ctx = FheCircuitCtx::new();
568            let (uproc, fc) = make_uproc_128();
569
570            let a = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.0, enc, &sk);
571            let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.1, enc, &sk);
572
573            let a_input = a.graph_inputs(&ctx);
574            let b_input = b.graph_inputs(&ctx);
575
576            let expect_gt = a_input
577                .cmp::<OutCt>(&b_input, &ctx, gt, eq)
578                .collect_output(&ctx, enc);
579            let expect_lt = b_input
580                .cmp::<OutCt>(&a_input, &ctx, gt, eq)
581                .collect_output(&ctx, enc);
582            let expect_eq = b_input
583                .cmp::<OutCt>(&b_input, &ctx, gt, eq)
584                .collect_output(&ctx, enc);
585
586            uproc
587                .lock()
588                .unwrap()
589                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
590                .unwrap();
591
592            assert_eq!(expect_gt.decrypt(enc, &sk), gt);
593            assert_eq!(expect_lt.decrypt(enc, &sk), !gt);
594            assert_eq!(expect_eq.decrypt(enc, &sk), eq);
595        }
596
597        fn cases<OutCt: Muxable>() {
598            case::<OutCt, Unsigned>(false, false, (43, 42));
599            case::<OutCt, Signed>(false, false, (-35, -36));
600            case::<OutCt, Signed>(false, false, (1, -3));
601            case::<OutCt, Unsigned>(false, true, (43, 42));
602            case::<OutCt, Signed>(false, true, (-37, -38));
603            case::<OutCt, Signed>(false, true, (1, -3));
604            case::<OutCt, Unsigned>(true, false, (43, 42));
605            case::<OutCt, Signed>(true, false, (-37, -38));
606            case::<OutCt, Signed>(true, false, (1, -3));
607            case::<OutCt, Unsigned>(true, true, (43, 42));
608            case::<OutCt, Signed>(true, true, (-37, -38));
609            case::<OutCt, Signed>(true, true, (1, -3));
610        }
611
612        cases::<L1GlweCiphertext>();
613        cases::<L1GlevCiphertext>();
614    }
615
616    #[test]
617    fn can_eq() {
618        fn case<OutCt: Muxable, U: Sign>(
619            eq: bool,
620            test_vals: (U::PlaintextType, U::PlaintextType),
621        ) {
622            let enc = &get_encryption_128();
623            let sk = get_secret_keys_128();
624            let ctx = FheCircuitCtx::new();
625            let (uproc, fc) = make_uproc_128();
626
627            let a = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.0, enc, &sk);
628            let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.1, enc, &sk);
629
630            let a_input = a.graph_inputs(&ctx);
631            let b_input = b.graph_inputs(&ctx);
632
633            let calculated_eq = a_input
634                .eq::<OutCt>(&b_input, &ctx)
635                .collect_output(&ctx, enc);
636
637            uproc
638                .lock()
639                .unwrap()
640                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
641                .unwrap();
642
643            assert_eq!(calculated_eq.decrypt(enc, &sk), eq);
644        }
645
646        fn cases<OutCt: Muxable>() {
647            case::<OutCt, Unsigned>(false, (43, 42));
648            case::<OutCt, Signed>(false, (-37, -38));
649            case::<OutCt, Unsigned>(true, (43, 43));
650            case::<OutCt, Signed>(true, (-37, -37));
651        }
652
653        cases::<L1GlweCiphertext>();
654        cases::<L1GlevCiphertext>();
655    }
656
657    // TODO this requires changing the `eq` method to use the correct `resize` method for creating the interleaved
658    // input, I am not bothered at this time
659    #[test]
660    fn can_eq_size_mismatch() {
661        fn case<const N: usize, const M: usize, OutCt: Muxable>(eq: bool) {
662            let enc = &get_encryption_128();
663            let sk = get_secret_keys_128();
664            let ctx = FheCircuitCtx::new();
665            let (uproc, fc) = make_uproc_128();
666
667            let (val_a, val_b) = if eq { (43, 43) } else { (43, 42) };
668
669            let a = UInt::<N, L1GgswCiphertext>::encrypt_secret(val_a, enc, &sk);
670            let b = UInt::<M, L1GgswCiphertext>::encrypt_secret(val_b, enc, &sk);
671
672            let a_input = a.graph_inputs(&ctx);
673            let b_input = b.graph_inputs(&ctx);
674
675            let calculated_eq = a_input
676                .eq::<OutCt>(&b_input, &ctx)
677                .collect_output(&ctx, enc);
678
679            uproc
680                .lock()
681                .unwrap()
682                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
683                .unwrap();
684
685            assert_eq!(calculated_eq.decrypt(enc, &sk), eq);
686        }
687
688        // Test with 8-bit and 16-bit combinations
689        case::<8, 16, L1GlweCiphertext>(false);
690        case::<8, 16, L1GlweCiphertext>(true);
691        case::<16, 8, L1GlweCiphertext>(false);
692        case::<16, 8, L1GlweCiphertext>(true);
693
694        case::<8, 16, L1GlevCiphertext>(false);
695        case::<8, 16, L1GlevCiphertext>(true);
696        case::<16, 8, L1GlevCiphertext>(false);
697        case::<16, 8, L1GlevCiphertext>(true);
698    }
699
700    #[test]
701    fn can_neq() {
702        fn case<OutCt: Muxable, U: Sign>(
703            neq: bool,
704            test_vals: (U::PlaintextType, U::PlaintextType),
705        ) {
706            let enc = &get_encryption_128();
707            let sk = get_secret_keys_128();
708            let ctx = FheCircuitCtx::new();
709            let (uproc, fc) = make_uproc_128();
710
711            let a = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.0, enc, &sk);
712            let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.1, enc, &sk);
713
714            let a_input = a.graph_inputs(&ctx);
715            let b_input = b.graph_inputs(&ctx);
716
717            let calculated_neq = a_input
718                .neq::<OutCt>(&b_input, &ctx)
719                .collect_output(&ctx, enc);
720
721            uproc
722                .lock()
723                .unwrap()
724                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
725                .unwrap();
726
727            assert_eq!(calculated_neq.decrypt(enc, &sk), neq);
728        }
729
730        fn cases<OutCt: Muxable>() {
731            case::<OutCt, Unsigned>(false, (43, 43));
732            case::<OutCt, Signed>(false, (-43, -43));
733            case::<OutCt, Unsigned>(true, (43, 42));
734            case::<OutCt, Signed>(true, (43, 42));
735            case::<OutCt, Signed>(true, (-43, -42));
736            case::<OutCt, Signed>(true, (42, -42));
737        }
738
739        cases::<L1GlweCiphertext>();
740        cases::<L1GlevCiphertext>();
741    }
742
743    // TODO this requires changing the `neq` method to use the correct `resize` method for creating the interleaved
744    // input, I am not bothered at this time
745    #[test]
746    fn can_neq_size_mismatch() {
747        fn case<const N: usize, const M: usize, OutCt: Muxable>(neq: bool) {
748            let enc = &get_encryption_128();
749            let sk = get_secret_keys_128();
750            let ctx = FheCircuitCtx::new();
751            let (uproc, fc) = make_uproc_128();
752
753            let (val_a, val_b) = if neq { (43, 42) } else { (43, 43) };
754
755            let a = UInt::<N, L1GgswCiphertext>::encrypt_secret(val_a, enc, &sk);
756            let b = UInt::<M, L1GgswCiphertext>::encrypt_secret(val_b, enc, &sk);
757
758            let a_input = a.graph_inputs(&ctx);
759            let b_input = b.graph_inputs(&ctx);
760
761            let calculated_neq = a_input
762                .neq::<OutCt>(&b_input, &ctx)
763                .collect_output(&ctx, enc);
764
765            uproc
766                .lock()
767                .unwrap()
768                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
769                .unwrap();
770
771            assert_eq!(calculated_neq.decrypt(enc, &sk), neq);
772        }
773
774        // Test with 8-bit and 16-bit combinations
775        case::<8, 16, L1GlweCiphertext>(false);
776        case::<8, 16, L1GlweCiphertext>(true);
777        case::<16, 8, L1GlweCiphertext>(false);
778        case::<16, 8, L1GlweCiphertext>(true);
779
780        case::<8, 16, L1GlevCiphertext>(false);
781        case::<8, 16, L1GlevCiphertext>(true);
782        case::<16, 8, L1GlevCiphertext>(false);
783        case::<16, 8, L1GlevCiphertext>(true);
784    }
785
786    // TODO this requires changing the `cmp` method to use the correct `resize` method for creating the interleaved
787    // input, I am not bothered at this time
788    #[test]
789    fn can_cmp_size_mismatch() {
790        fn case<const N: usize, const M: usize, OutCt: Muxable>(gt: bool, eq: bool) {
791            let enc = &get_encryption_128();
792            let sk = get_secret_keys_128();
793            let ctx = FheCircuitCtx::new();
794            let (uproc, fc) = make_uproc_128();
795
796            let a = UInt::<N, L1GgswCiphertext>::encrypt_secret(43, enc, &sk);
797            let b = UInt::<M, L1GgswCiphertext>::encrypt_secret(42, enc, &sk);
798
799            let a_input = a.graph_inputs(&ctx);
800            let b_input = b.graph_inputs(&ctx);
801
802            let expect_gt = a_input
803                .cmp::<OutCt>(&b_input, &ctx, gt, eq)
804                .collect_output(&ctx, enc);
805            let expect_lt = b_input
806                .cmp::<OutCt>(&a_input, &ctx, gt, eq)
807                .collect_output(&ctx, enc);
808            let expect_eq = b_input
809                .cmp::<OutCt>(&b_input, &ctx, gt, eq)
810                .collect_output(&ctx, enc);
811
812            uproc
813                .lock()
814                .unwrap()
815                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
816                .unwrap();
817
818            assert_eq!(expect_gt.decrypt(enc, &sk), gt);
819            assert_eq!(expect_lt.decrypt(enc, &sk), !gt);
820            assert_eq!(expect_eq.decrypt(enc, &sk), eq);
821        }
822
823        fn cases<OutCt: Muxable>() {
824            case::<8, 16, OutCt>(false, false);
825            case::<8, 16, OutCt>(false, true);
826            case::<8, 16, OutCt>(true, false);
827            case::<8, 16, OutCt>(true, true);
828
829            case::<16, 8, OutCt>(false, false);
830            case::<16, 8, OutCt>(false, true);
831            case::<16, 8, OutCt>(true, false);
832            case::<16, 8, OutCt>(true, true);
833        }
834
835        cases::<L1GlweCiphertext>();
836        cases::<L1GlevCiphertext>();
837    }
838
839    #[test]
840    fn can_cmp_trivial_nontrivial_ggsw() {
841        fn case<OutCt: Muxable, U: Sign>(
842            gt: bool,
843            eq: bool,
844            test_vals: (U::PlaintextType, U::PlaintextType),
845        ) {
846            let enc = &get_encryption_128();
847            let eval = &get_evaluation_128();
848            let sk = get_secret_keys_128();
849            let ctx = FheCircuitCtx::new();
850            let (uproc, fc) = make_uproc_128();
851
852            let a = GenericInt::<16, L1GgswCiphertext, U>::trivial(test_vals.0, enc, eval);
853            let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.1, enc, &sk);
854
855            let a_input = a.graph_inputs(&ctx);
856            let b_input = b.graph_inputs(&ctx);
857
858            let expect_gt = a_input
859                .cmp::<OutCt>(&b_input, &ctx, gt, eq)
860                .collect_output(&ctx, enc);
861            let expect_lt = b_input
862                .cmp::<OutCt>(&a_input, &ctx, gt, eq)
863                .collect_output(&ctx, enc);
864            let expect_eq = b_input
865                .cmp::<OutCt>(&b_input, &ctx, gt, eq)
866                .collect_output(&ctx, enc);
867
868            uproc
869                .lock()
870                .unwrap()
871                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
872                .unwrap();
873
874            assert_eq!(expect_gt.decrypt(enc, &sk), gt);
875            assert_eq!(expect_lt.decrypt(enc, &sk), !gt);
876            assert_eq!(expect_eq.decrypt(enc, &sk), eq);
877        }
878
879        fn cases<OutCt: Muxable>() {
880            case::<OutCt, Unsigned>(false, false, (43, 42));
881            case::<OutCt, Signed>(false, false, (-35, -36));
882            case::<OutCt, Signed>(false, false, (1, -42));
883            case::<OutCt, Unsigned>(false, true, (43, 42));
884            case::<OutCt, Signed>(false, true, (-35, -36));
885            case::<OutCt, Signed>(false, true, (1, -42));
886            case::<OutCt, Unsigned>(true, false, (43, 42));
887            case::<OutCt, Signed>(true, false, (-35, -36));
888            case::<OutCt, Signed>(true, false, (1, -42));
889            case::<OutCt, Unsigned>(true, true, (43, 42));
890            case::<OutCt, Signed>(true, true, (-35, -36));
891            case::<OutCt, Signed>(true, true, (1, -42));
892        }
893
894        cases::<L1GlweCiphertext>();
895        cases::<L1GlevCiphertext>();
896    }
897
898    #[test]
899    fn can_select() {
900        fn case<U: Sign>(test_vals: (U::PlaintextType, U::PlaintextType)) {
901            let enc = &get_encryption_128();
902            let sk = get_secret_keys_128();
903            let ctx = FheCircuitCtx::new();
904            let (uproc, fc) = make_uproc_128();
905
906            let sel_false =
907                Bit::<L1GgswCiphertext>::encrypt_secret(false, enc, &sk).graph_input(&ctx);
908            let sel_true =
909                Bit::<L1GgswCiphertext>::encrypt_secret(true, enc, &sk).graph_input(&ctx);
910
911            let a: GenericIntGraphNodes<'_, 16, L1GlweCiphertext, U> =
912                GenericInt::<16, L1GlweCiphertext, U>::encrypt_secret(test_vals.0, enc, &sk)
913                    .graph_inputs(&ctx)
914                    .into();
915            let b: GenericIntGraphNodes<'_, 16, L1GlweCiphertext, U> =
916                GenericInt::<16, L1GlweCiphertext, U>::encrypt_secret(test_vals.1, enc, &sk)
917                    .graph_inputs(&ctx)
918                    .into();
919
920            let sel_false = sel_false
921                .select::<16, U, L1GlweCiphertext>(&a, &b, &ctx)
922                .collect_outputs(&ctx, enc);
923            let sel_true = sel_true
924                .select::<16, U, L1GlweCiphertext>(&a, &b, &ctx)
925                .collect_outputs(&ctx, enc);
926
927            uproc
928                .lock()
929                .unwrap()
930                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
931                .unwrap();
932
933            assert_eq!(sel_false.decrypt(enc, &sk), test_vals.1);
934            assert_eq!(sel_true.decrypt(enc, &sk), test_vals.0);
935        }
936
937        case::<Unsigned>((42, 24));
938        case::<Signed>((-94, -112));
939    }
940
941    #[test]
942    fn can_select_plain() {
943        fn case<U: Sign>(test_vals: (U::PlaintextType, U::PlaintextType)) {
944            let enc = &get_encryption_128();
945            let eval = &get_evaluation_128();
946            let sk = get_secret_keys_128();
947            let ctx = FheCircuitCtx::new();
948            let (uproc, fc) = make_uproc_128();
949
950            let sel_false =
951                Bit::<L1GgswCiphertext>::trivial_encryption(false, enc, eval).graph_input(&ctx);
952            let sel_true =
953                Bit::<L1GgswCiphertext>::trivial_encryption(true, enc, eval).graph_input(&ctx);
954
955            let a: GenericIntGraphNodes<'_, 16, L1GlweCiphertext, U> =
956                GenericInt::<16, L1GlweCiphertext, U>::encrypt_secret(test_vals.0, enc, &sk)
957                    .graph_inputs(&ctx)
958                    .into();
959            let b: GenericIntGraphNodes<'_, 16, L1GlweCiphertext, U> =
960                GenericInt::<16, L1GlweCiphertext, U>::encrypt_secret(test_vals.1, enc, &sk)
961                    .graph_inputs(&ctx)
962                    .into();
963
964            let sel_false = sel_false
965                .select::<16, U, L1GlweCiphertext>(&a, &b, &ctx)
966                .collect_outputs(&ctx, enc);
967            let sel_true = sel_true
968                .select::<16, U, L1GlweCiphertext>(&a, &b, &ctx)
969                .collect_outputs(&ctx, enc);
970
971            uproc
972                .lock()
973                .unwrap()
974                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
975                .unwrap();
976
977            assert_eq!(sel_false.decrypt(enc, &sk), test_vals.1);
978            assert_eq!(sel_true.decrypt(enc, &sk), test_vals.0);
979        }
980
981        case::<Unsigned>((42, 24));
982        case::<Signed>((-94, -112));
983    }
984
985    #[test]
986    fn can_sub() {
987        fn case<OutCt: Muxable, U: Sign>(
988            test_vals: (U::PlaintextType, U::PlaintextType, U::PlaintextType),
989        ) {
990            let enc = &get_encryption_128();
991            let sk = get_secret_keys_128();
992            let ctx = FheCircuitCtx::new();
993            let (uproc, fc) = make_uproc_128();
994
995            let a = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.0, enc, &sk)
996                .graph_inputs(&ctx);
997            let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.1, enc, &sk)
998                .graph_inputs(&ctx);
999
1000            let c = a.sub::<OutCt>(&b, &ctx).collect_outputs(&ctx, enc);
1001
1002            uproc
1003                .lock()
1004                .unwrap()
1005                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
1006                .unwrap();
1007
1008            assert_eq!(c.decrypt(enc, &sk), test_vals.2);
1009        }
1010
1011        fn cases<OutCt: Muxable>() {
1012            case::<OutCt, Unsigned>((42, 16, 26));
1013            case::<OutCt, Signed>((-7, -9, 2));
1014            case::<OutCt, Signed>((-7, 2, -9));
1015            case::<OutCt, Signed>((-7, -5, -2));
1016            case::<OutCt, Signed>((2, -7, 9));
1017        }
1018
1019        cases::<L1GlweCiphertext>();
1020        cases::<L1GlevCiphertext>();
1021    }
1022
1023    #[test]
1024    fn trivial_generic_int_encryption() {
1025        fn case<T: CiphertextOps, U: Sign, F: Fn() -> U::PlaintextType>(gen_pt: F) {
1026            let enc = get_encryption_128();
1027            let eval = &get_evaluation_128();
1028            let sk = get_secret_keys_128();
1029
1030            let expected = gen_pt();
1031
1032            let val = GenericInt::<32, T, U>::trivial(expected, &enc, eval);
1033
1034            assert_eq!(val.decrypt(&enc, &sk), expected);
1035        }
1036
1037        case::<L0LweCiphertext, Unsigned, _>(rand_u32);
1038        case::<L0LweCiphertext, Signed, _>(rand_i32);
1039        case::<L1LweCiphertext, Unsigned, _>(rand_u32);
1040        case::<L1LweCiphertext, Signed, _>(rand_i32);
1041        case::<L1GlweCiphertext, Unsigned, _>(rand_u32);
1042        case::<L1GlweCiphertext, Signed, _>(rand_i32);
1043        case::<L1GgswCiphertext, Unsigned, _>(rand_u32);
1044        case::<L1GgswCiphertext, Signed, _>(rand_i32);
1045    }
1046
1047    #[test]
1048    fn can_resize() {
1049        fn case<T: CiphertextOps, U: Sign>(
1050            test_vals: (U::PlaintextType, U::PlaintextType, U::PlaintextType),
1051        ) {
1052            let enc = get_encryption_128();
1053            let sk = get_secret_keys_128();
1054            let ctx = FheCircuitCtx::new();
1055            let (proc, fc) = make_uproc_128();
1056
1057            let val = GenericInt::<16, T, U>::encrypt_secret(test_vals.0, &enc, &sk);
1058            let res = val
1059                .graph_inputs(&ctx)
1060                .resize(&ctx, 24)
1061                .collect_outputs(&ctx, &enc);
1062
1063            proc.lock()
1064                .unwrap()
1065                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
1066                .unwrap();
1067
1068            assert_eq!(res.decrypt(&enc, &sk), test_vals.1);
1069
1070            let res = val
1071                .graph_inputs(&ctx)
1072                .resize(&ctx, 8)
1073                .collect_outputs(&ctx, &enc);
1074
1075            proc.lock()
1076                .unwrap()
1077                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
1078                .unwrap();
1079
1080            assert_eq!(res.decrypt(&enc, &sk), test_vals.2);
1081        }
1082
1083        case::<L0LweCiphertext, Unsigned>((1234, 1234, 210));
1084        case::<L0LweCiphertext, Signed>((1234, 1234, 82));
1085        case::<L0LweCiphertext, Signed>((-1234, -1234, -82));
1086        // unimplemented
1087        //case::<L1LweCiphertext>();
1088        case::<L1GlweCiphertext, Unsigned>((1234, 1234, 210));
1089        case::<L1GlweCiphertext, Signed>((1234, 1234, 82));
1090        case::<L1GlweCiphertext, Signed>((-1234, -1234, -82));
1091        case::<L1GgswCiphertext, Unsigned>((1234, 1234, 210));
1092        case::<L1GgswCiphertext, Signed>((1234, 1234, 82));
1093        case::<L1GgswCiphertext, Signed>((-1234, -1234, -82));
1094    }
1095
1096    #[test]
1097    fn can_add() {
1098        fn case<OutCt: Muxable, U: Sign>(
1099            test_vals: (U::PlaintextType, U::PlaintextType, U::PlaintextType),
1100        ) {
1101            let enc = get_encryption_128();
1102            let sk = get_secret_keys_128();
1103            let ctx = FheCircuitCtx::new();
1104            let (proc, fc) = make_uproc_128();
1105
1106            let a = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.0, &enc, &sk)
1107                .graph_inputs(&ctx);
1108            let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.1, &enc, &sk)
1109                .graph_inputs(&ctx);
1110
1111            let c = a.add::<OutCt>(&b, &ctx).collect_outputs(&ctx, &enc);
1112
1113            println!("{:#?}", *ctx.circuit.borrow());
1114
1115            proc.lock()
1116                .unwrap()
1117                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
1118                .unwrap();
1119
1120            assert_eq!(c.decrypt(&enc, &sk), test_vals.2);
1121        }
1122
1123        fn cases<OutCt: Muxable>() {
1124            case::<OutCt, Unsigned>((42, 16, 58));
1125            case::<OutCt, Signed>((-6, 16, 10));
1126            case::<OutCt, Signed>((-6, -7, -13));
1127            case::<OutCt, Signed>((-8, 2, -6));
1128        }
1129
1130        cases::<L1GlweCiphertext>();
1131        cases::<L1GlevCiphertext>();
1132    }
1133
1134    #[test]
1135    fn can_mul() {
1136        fn case<OutCt: Muxable, U: Sign>(
1137            test_vals: (U::PlaintextType, U::PlaintextType, U::PlaintextType),
1138        ) {
1139            let enc = get_encryption_128();
1140            let sk = get_secret_keys_128();
1141            let ctx = FheCircuitCtx::new();
1142            let (proc, fc) = make_uproc_128();
1143
1144            let a = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.0, &enc, &sk)
1145                .graph_inputs(&ctx);
1146            let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.1, &enc, &sk)
1147                .graph_inputs(&ctx);
1148
1149            let c = a.mul::<OutCt>(&b, &ctx).collect_outputs(&ctx, &enc);
1150
1151            proc.lock()
1152                .unwrap()
1153                .run_graph_blocking(&ctx.circuit.borrow(), &fc)
1154                .unwrap();
1155
1156            assert_eq!(c.decrypt(&enc, &sk), test_vals.2);
1157        }
1158
1159        case::<L1GlweCiphertext, Unsigned>((42, 16, 672));
1160        case::<L1GlweCiphertext, Signed>((42, 16, 672));
1161        case::<L1GlweCiphertext, Signed>((42, -16, -672));
1162        case::<L1GlweCiphertext, Signed>((-42, -16, 672));
1163    }
1164}