1use std::sync::Arc;
2
3use bumpalo::Bump;
4use parasol_concurrency::AtomicRefCell;
5use petgraph::stable_graph::NodeIndex;
6use sunscreen_tfhe::{
7 PolynomialDegree,
8 entities::{GlevCiphertext, Polynomial, PolynomialRef},
9};
10
11use crate::{
12 CiphertextType, Encryption, Evaluation, FheCircuit, FheOp, L0LweCiphertext, L1GgswCiphertext,
13 L1GlweCiphertext, L1LweCiphertext, Params, SecretKey, TrivialZero,
14 crypto::{L1GlevCiphertext, PublicKey},
15 fhe_circuit::MuxMode,
16 safe_bincode::GetSize,
17};
18
19mod bit;
20mod dynamic_generic_int;
21mod dynamic_generic_int_graph_nodes;
22mod generic_int;
23mod generic_int_graph_nodes;
24mod int;
25mod packed_dynamic_generic_int_graph_node;
26mod packed_generic_int;
27mod packed_generic_int_graph_node;
28mod recrypted_int;
29mod uint;
30
31pub use bit::*;
32pub use dynamic_generic_int::*;
33pub use dynamic_generic_int_graph_nodes::*;
34pub use generic_int::*;
35pub use generic_int_graph_nodes::*;
36pub use int::*;
37pub use packed_dynamic_generic_int_graph_node::*;
38pub use packed_generic_int::*;
39pub use packed_generic_int_graph_node::*;
40pub use recrypted_int::*;
41pub use uint::*;
42
43pub struct FheCircuitCtx {
51 pub circuit: AtomicRefCell<FheCircuit>,
53 one_cache: AtomicRefCell<[Option<NodeIndex>; 4]>,
54 zero_cache: AtomicRefCell<[Option<NodeIndex>; 4]>,
55 allocator: Bump,
56}
57
58impl Default for FheCircuitCtx {
59 fn default() -> Self {
60 Self::new()
61 }
62}
63
64impl FheCircuitCtx {
65 pub fn new() -> Self {
67 Self {
68 circuit: AtomicRefCell::new(FheCircuit::new()),
69 one_cache: AtomicRefCell::new([None; 4]),
70 zero_cache: AtomicRefCell::new([None; 4]),
71 allocator: Bump::new(),
72 }
73 }
74}
75
76pub trait PolynomialCiphertextOps {
79 fn encrypt_secret(msg: &PolynomialRef<u64>, enc: &Encryption, sk: &SecretKey) -> Self;
81
82 fn encrypt(msg: &PolynomialRef<u64>, enc: &Encryption, pk: &PublicKey) -> Self;
84
85 fn decrypt(&self, enc: &Encryption, sk: &SecretKey) -> Polynomial<u64>;
87
88 fn trivial_encryption(polynomial: &PolynomialRef<u64>, encryption: &Encryption) -> Self;
90
91 fn poly_degree(params: &Params) -> PolynomialDegree;
93}
94
95impl PolynomialCiphertextOps for L1GlweCiphertext {
96 fn encrypt_secret(msg: &PolynomialRef<u64>, encryption: &Encryption, sk: &SecretKey) -> Self {
97 encryption.encrypt_glwe_l1_secret(msg, sk)
98 }
99
100 fn encrypt(msg: &PolynomialRef<u64>, encryption: &Encryption, pk: &PublicKey) -> Self {
101 encryption.encrypt_rlwe_l1(msg, pk)
102 }
103
104 fn trivial_encryption(polynomial: &PolynomialRef<u64>, encryption: &Encryption) -> Self {
105 encryption.trivial_glwe_l1(polynomial)
106 }
107
108 fn poly_degree(params: &Params) -> PolynomialDegree {
109 params.l1_poly_degree()
110 }
111
112 fn decrypt(&self, enc: &Encryption, sk: &SecretKey) -> Polynomial<u64> {
113 enc.decrypt_glwe_l1(self, sk)
114 }
115}
116
117pub trait CiphertextOps: GetSize + Clone
119where
120 Self: Sized,
121{
122 const CIPHERTEXT_TYPE: CiphertextType;
124
125 fn allocate(encryption: &Encryption) -> Self;
127
128 fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self;
130
131 fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool;
133
134 fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp;
136
137 fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp;
139
140 fn trivial_encryption(bit: bool, encryption: &Encryption, eval: &Evaluation) -> Self;
146
147 fn trivial_zero_from_existing(&self) -> Self;
150
151 fn graph_trivial_one() -> FheOp;
153
154 fn graph_trivial_zero() -> FheOp;
156}
157
158impl CiphertextOps for L0LweCiphertext {
159 const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L0LweCiphertext;
160
161 fn allocate(encryption: &Encryption) -> Self {
162 encryption.allocate_lwe_l0()
163 }
164
165 fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
166 encryption.encrypt_lwe_l0_secret(msg, sk)
167 }
168
169 fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
170 encryption.decrypt_lwe_l0(self, sk)
171 }
172
173 fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
174 FheOp::InputLwe0(bit.clone())
175 }
176
177 fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
178 FheOp::OutputLwe0(bit.clone())
179 }
180
181 fn trivial_encryption(bit: bool, encryption: &Encryption, _eval: &Evaluation) -> Self {
182 if bit {
183 encryption.trivial_lwe_l0_one()
184 } else {
185 encryption.trivial_lwe_l0_zero()
186 }
187 }
188
189 fn trivial_zero_from_existing(&self) -> Self {
190 <L0LweCiphertext as TrivialZero>::trivial_zero_from_existing(self)
191 }
192
193 fn graph_trivial_one() -> FheOp {
194 FheOp::OneLwe0
195 }
196
197 fn graph_trivial_zero() -> FheOp {
198 FheOp::ZeroLwe0
199 }
200}
201impl CiphertextOps for L1LweCiphertext {
202 const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L1LweCiphertext;
203
204 fn allocate(encryption: &Encryption) -> Self {
205 encryption.allocate_lwe_l1()
206 }
207
208 fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
209 encryption.encrypt_lwe_l1_secret(msg, sk)
210 }
211
212 fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
213 encryption.decrypt_lwe_l1(self, sk)
214 }
215
216 fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
217 FheOp::InputLwe1(bit.clone())
218 }
219
220 fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
221 FheOp::OutputLwe1(bit.clone())
222 }
223
224 fn trivial_encryption(bit: bool, encryption: &Encryption, _eval: &Evaluation) -> Self {
225 if bit {
226 encryption.trivial_lwe_l1_one()
227 } else {
228 encryption.trivial_lwe_l1_zero()
229 }
230 }
231
232 fn trivial_zero_from_existing(&self) -> Self {
233 <L1LweCiphertext as TrivialZero>::trivial_zero_from_existing(self)
234 }
235
236 fn graph_trivial_one() -> FheOp {
237 unimplemented!()
238 }
239
240 fn graph_trivial_zero() -> FheOp {
241 unimplemented!()
242 }
243}
244impl CiphertextOps for L1GgswCiphertext {
245 const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L1GgswCiphertext;
246
247 fn allocate(encryption: &Encryption) -> Self {
248 encryption.allocate_ggsw_l1()
249 }
250
251 fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
252 encryption.encrypt_ggsw_l1_secret(msg, sk)
253 }
254
255 fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
256 encryption.decrypt_ggsw_l1(self, sk)
257 }
258
259 fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
260 FheOp::InputGgsw1(bit.clone())
261 }
262
263 fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
264 FheOp::OutputGgsw1(bit.clone())
265 }
266
267 fn trivial_encryption(bit: bool, _encryption: &Encryption, eval: &Evaluation) -> Self {
268 if bit {
269 eval.l1ggsw_one().to_owned()
270 } else {
271 eval.l1ggsw_zero().to_owned()
272 }
273 }
274
275 fn trivial_zero_from_existing(&self) -> Self {
276 <L1GgswCiphertext as TrivialZero>::trivial_zero_from_existing(self)
277 }
278
279 fn graph_trivial_one() -> FheOp {
280 FheOp::OneGgsw1
281 }
282
283 fn graph_trivial_zero() -> FheOp {
284 FheOp::ZeroGgsw1
285 }
286}
287impl CiphertextOps for L1GlweCiphertext {
288 const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L1GlweCiphertext;
289
290 fn allocate(encryption: &Encryption) -> Self {
291 encryption.allocate_glwe_l1()
292 }
293
294 fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
295 let mut poly = Polynomial::new(&vec![
296 0u64;
297 encryption.params.l1_params.dim.polynomial_degree.0
298 ]);
299 poly.coeffs_mut()[0] = msg as u64;
300
301 encryption.encrypt_glwe_l1_secret(&poly, sk)
302 }
303
304 fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
305 encryption.decrypt_glwe_l1(self, sk).coeffs()[0] == 1
306 }
307
308 fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
309 FheOp::InputGlwe1(bit.clone())
310 }
311
312 fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
313 FheOp::OutputGlwe1(bit.clone())
314 }
315
316 fn trivial_encryption(bit: bool, encryption: &Encryption, _eval: &Evaluation) -> Self {
317 if bit {
318 encryption.trivial_glwe_l1_one()
319 } else {
320 encryption.trivial_glwe_l1_zero()
321 }
322 }
323
324 fn trivial_zero_from_existing(&self) -> Self {
325 <L1GlweCiphertext as TrivialZero>::trivial_zero_from_existing(self)
326 }
327
328 fn graph_trivial_one() -> FheOp {
329 FheOp::OneGlwe1
330 }
331
332 fn graph_trivial_zero() -> FheOp {
333 FheOp::ZeroGlwe1
334 }
335}
336
337impl CiphertextOps for L1GlevCiphertext {
338 const CIPHERTEXT_TYPE: CiphertextType = CiphertextType::L1GlevCiphertext;
339
340 fn allocate(encryption: &Encryption) -> Self {
341 GlevCiphertext::new(&encryption.params.l1_params, &encryption.params.cbs_radix).into()
342 }
343
344 fn decrypt(&self, encryption: &Encryption, sk: &SecretKey) -> bool {
345 encryption.decrypt_glev_l1(self, sk).coeffs()[0] == 1
346 }
347
348 fn encrypt_secret(msg: bool, encryption: &Encryption, sk: &SecretKey) -> Self {
349 let mut poly = Polynomial::zero(encryption.params.l1_params.dim.polynomial_degree.0);
350 poly.coeffs_mut()[0] = msg as u64;
351
352 encryption.encrypt_glev_l1_secret(&poly, sk)
353 }
354
355 fn graph_input(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
356 FheOp::InputGlev1(bit.to_owned())
357 }
358
359 fn graph_output(bit: &Arc<AtomicRefCell<Self>>) -> FheOp {
360 FheOp::OutputGlev1(bit.to_owned())
361 }
362
363 fn graph_trivial_zero() -> FheOp {
364 FheOp::ZeroGlev1
365 }
366
367 fn graph_trivial_one() -> FheOp {
368 FheOp::OneGlev1
369 }
370
371 fn trivial_encryption(bit: bool, encryption: &Encryption, _eval: &Evaluation) -> Self {
372 if bit {
373 encryption.trivial_glev_l1_one()
374 } else {
375 encryption.trivial_glev_l1_zero()
376 }
377 }
378
379 fn trivial_zero_from_existing(&self) -> Self {
380 <L1GlevCiphertext as TrivialZero>::trivial_zero_from_existing(self)
381 }
382}
383
384pub trait Muxable: CiphertextOps {
387 const MUX_MODE: MuxMode;
390}
391
392impl Muxable for L1GlweCiphertext {
393 const MUX_MODE: MuxMode = MuxMode::Glwe;
394}
395
396impl Muxable for L1GlevCiphertext {
397 const MUX_MODE: MuxMode = MuxMode::Glev;
398}
399
400#[cfg(test)]
401mod tests {
402 use bit::Bit;
403 use generic_int::GenericInt;
404 use rand::{RngCore, rng};
405 use uint::UInt;
406
407 use crate::test_utils::{
408 get_encryption_128, get_evaluation_128, get_secret_keys_128, make_uproc_128,
409 };
410
411 use super::*;
412
413 fn roundtrip<T: CiphertextOps, U: Sign, F: Fn() -> U::PlaintextType>(gen_pt: F) {
414 let sk = get_secret_keys_128();
415 let enc = get_encryption_128();
416
417 for _ in 0..32 {
418 let val = gen_pt();
420 let ct = GenericInt::<16, T, U>::encrypt_secret(val, &enc, &sk);
421 let actual = ct.decrypt(&enc, &sk);
422
423 assert_eq!(val, actual);
424 }
425 }
426
427 fn rand_u16() -> u128 {
428 (rng().next_u64() & 0xFFFF) as u128
429 }
430
431 fn rand_i16() -> i128 {
432 (rng().next_u64() & 0xFFFF) as i16 as i128
433 }
434
435 fn rand_u32() -> u128 {
436 (rng().next_u64() & 0xFFFFFFFF) as u128
437 }
438
439 fn rand_i32() -> i128 {
440 (rng().next_u64() & 0xFFFFFFFF) as i32 as i128
441 }
442
/// Encrypt/decrypt roundtrip of unsigned, then signed, 16-bit integers over
/// L0 LWE ciphertexts.
#[test]
fn can_roundtrip_l0_lwe() {
    type Ct = L0LweCiphertext;

    roundtrip::<Ct, Unsigned, _>(rand_u16);
    roundtrip::<Ct, Signed, _>(rand_i16);
}
448
/// Encrypt/decrypt roundtrip of unsigned, then signed, 16-bit integers over
/// L1 LWE ciphertexts.
#[test]
fn can_roundtrip_l1_lwe() {
    type Ct = L1LweCiphertext;

    roundtrip::<Ct, Unsigned, _>(rand_u16);
    roundtrip::<Ct, Signed, _>(rand_i16);
}
454
/// Encrypt/decrypt roundtrip of unsigned, then signed, 16-bit integers over
/// L1 GLWE ciphertexts.
#[test]
fn can_roundtrip_l1_glwe() {
    type Ct = L1GlweCiphertext;

    roundtrip::<Ct, Unsigned, _>(rand_u16);
    roundtrip::<Ct, Signed, _>(rand_i16);
}
460
/// Encrypt/decrypt roundtrip of unsigned, then signed, 16-bit integers over
/// L1 GGSW ciphertexts.
#[test]
fn can_roundtrip_l1_ggsw() {
    type Ct = L1GgswCiphertext;

    roundtrip::<Ct, Unsigned, _>(rand_u16);
    roundtrip::<Ct, Signed, _>(rand_i16);
}
466
467 fn input_output<T: CiphertextOps, U: Sign>(test_val: U::PlaintextType) {
468 let (uproc, fc) = make_uproc_128();
469 let enc = get_encryption_128();
470
471 let input = GenericInt::<16, T, U>::encrypt_secret(
472 test_val,
473 &get_encryption_128(),
474 &get_secret_keys_128(),
475 );
476
477 let graph = FheCircuitCtx::new();
478
479 let in_node = input.graph_inputs(&graph);
480 let output = in_node.collect_outputs(&graph, &enc);
481
482 uproc
483 .lock()
484 .unwrap()
485 .run_graph_blocking(&graph.circuit.borrow(), &fc)
486 .unwrap();
487
488 let actual = output.decrypt(&enc, &get_secret_keys_128());
489 assert_eq!(actual, test_val);
490 }
491
/// Passes an unsigned and a signed value through an input→output graph using
/// L0 LWE ciphertexts.
#[test]
fn can_input_output_generic_int_graph_l0_lwe() {
    type Ct = L0LweCiphertext;

    input_output::<Ct, Unsigned>(1234);
    input_output::<Ct, Signed>(-104);
}
497
/// Passes an unsigned and a signed value through an input→output graph using
/// L1 LWE ciphertexts.
#[test]
fn can_input_output_generic_int_graph_l1_lwe() {
    type Ct = L1LweCiphertext;

    input_output::<Ct, Unsigned>(1234);
    input_output::<Ct, Signed>(-104);
}
503
/// Passes an unsigned and a signed value through an input→output graph using
/// L1 GGSW ciphertexts.
#[test]
fn can_input_output_generic_int_graph_l1_ggsw() {
    type Ct = L1GgswCiphertext;

    input_output::<Ct, Unsigned>(1234);
    input_output::<Ct, Signed>(-104);
}
509
/// Passes an unsigned and a signed value through an input→output graph using
/// L1 GLWE ciphertexts.
#[test]
fn can_input_output_generic_int_graph_l1_glwe() {
    type Ct = L1GlweCiphertext;

    input_output::<Ct, Unsigned>(1234);
    input_output::<Ct, Signed>(-104);
}
515
/// Converts 16-bit integers between ciphertext types inside a circuit and
/// checks the decrypted result is unchanged.
#[test]
fn can_convert_ciphertexts() {
    /// Encrypts `test_val` over `T`, converts it to `U` in a graph, runs the
    /// graph, and asserts the output decrypts to `test_val`.
    fn convert_test<T: CiphertextOps, U: CiphertextOps, V: Sign>(test_val: V::PlaintextType) {
        let graph = FheCircuitCtx::new();
        let enc = get_encryption_128();
        let (uproc, fc) = make_uproc_128();
        let sk = get_secret_keys_128();

        let source = GenericInt::<16, T, V>::encrypt_secret(test_val, &enc, &sk);

        let outputs = source
            .graph_inputs(&graph)
            .convert::<U>(&graph)
            .collect_outputs(&graph, &enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&graph.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(outputs.decrypt(&enc, &sk), test_val);
    }

    /// Runs the unsigned case followed by the signed case for one `T` → `U`
    /// conversion.
    fn both_signs<T: CiphertextOps, U: CiphertextOps>() {
        convert_test::<T, U, Unsigned>(1234);
        convert_test::<T, U, Signed>(-106);
    }

    // Conversions starting from L0 LWE.
    both_signs::<L0LweCiphertext, L1GgswCiphertext>();
    both_signs::<L0LweCiphertext, L1GlweCiphertext>();
    both_signs::<L0LweCiphertext, L1LweCiphertext>();
    both_signs::<L0LweCiphertext, L0LweCiphertext>();

    // Conversions involving GLEV.
    both_signs::<L1GlevCiphertext, L1GgswCiphertext>();
    both_signs::<L1GgswCiphertext, L1GlevCiphertext>();
    both_signs::<L0LweCiphertext, L1GlevCiphertext>();
}
557
/// Exercises `cmp` on 16-bit GGSW integers for every `(gt, eq)` flag
/// combination and both output ciphertext types.
#[test]
fn can_cmp() {
    /// `test_vals.0` must be strictly greater than `test_vals.1`: the
    /// assertions expect `a.cmp(b)` to decrypt to `gt`, `b.cmp(a)` to `!gt`,
    /// and `b.cmp(b)` to `eq`.
    fn case<OutCt: Muxable, U: Sign>(
        gt: bool,
        eq: bool,
        test_vals: (U::PlaintextType, U::PlaintextType),
    ) {
        let (larger, smaller) = test_vals;

        let enc = &get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let a = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(larger, enc, &sk);
        let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(smaller, enc, &sk);

        let a_nodes = a.graph_inputs(&ctx);
        let b_nodes = b.graph_inputs(&ctx);

        let a_vs_b = a_nodes
            .cmp::<OutCt>(&b_nodes, &ctx, gt, eq)
            .collect_output(&ctx, enc);
        let b_vs_a = b_nodes
            .cmp::<OutCt>(&a_nodes, &ctx, gt, eq)
            .collect_output(&ctx, enc);
        let b_vs_b = b_nodes
            .cmp::<OutCt>(&b_nodes, &ctx, gt, eq)
            .collect_output(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(a_vs_b.decrypt(enc, &sk), gt);
        assert_eq!(b_vs_a.decrypt(enc, &sk), !gt);
        assert_eq!(b_vs_b.decrypt(enc, &sk), eq);
    }

    /// Runs every flag combination for one output ciphertext type.
    fn cases<OutCt: Muxable>() {
        // gt = false, eq = false
        case::<OutCt, Unsigned>(false, false, (43, 42));
        case::<OutCt, Signed>(false, false, (-35, -36));
        case::<OutCt, Signed>(false, false, (1, -3));

        // gt = false, eq = true
        case::<OutCt, Unsigned>(false, true, (43, 42));
        case::<OutCt, Signed>(false, true, (-37, -38));
        case::<OutCt, Signed>(false, true, (1, -3));

        // gt = true, eq = false
        case::<OutCt, Unsigned>(true, false, (43, 42));
        case::<OutCt, Signed>(true, false, (-37, -38));
        case::<OutCt, Signed>(true, false, (1, -3));

        // gt = true, eq = true
        case::<OutCt, Unsigned>(true, true, (43, 42));
        case::<OutCt, Signed>(true, true, (-37, -38));
        case::<OutCt, Signed>(true, true, (1, -3));
    }

    cases::<L1GlweCiphertext>();
    cases::<L1GlevCiphertext>();
}
615
/// Exercises `eq` on 16-bit GGSW integers for equal and unequal operands and
/// both output ciphertext types.
#[test]
fn can_eq() {
    /// Asserts that `test_vals.0 == test_vals.1`, evaluated homomorphically,
    /// decrypts to `eq`.
    fn case<OutCt: Muxable, U: Sign>(
        eq: bool,
        test_vals: (U::PlaintextType, U::PlaintextType),
    ) {
        let (lhs_val, rhs_val) = test_vals;

        let enc = &get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let lhs = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(lhs_val, enc, &sk);
        let rhs = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(rhs_val, enc, &sk);

        let lhs_nodes = lhs.graph_inputs(&ctx);
        let rhs_nodes = rhs.graph_inputs(&ctx);

        let result = lhs_nodes
            .eq::<OutCt>(&rhs_nodes, &ctx)
            .collect_output(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(result.decrypt(enc, &sk), eq);
    }

    /// Runs the unequal cases, then the equal cases, for one output type.
    fn cases<OutCt: Muxable>() {
        // Operands differ.
        case::<OutCt, Unsigned>(false, (43, 42));
        case::<OutCt, Signed>(false, (-37, -38));

        // Operands match.
        case::<OutCt, Unsigned>(true, (43, 43));
        case::<OutCt, Signed>(true, (-37, -37));
    }

    cases::<L1GlweCiphertext>();
    cases::<L1GlevCiphertext>();
}
656
/// Exercises `eq` between unsigned integers of different bit widths.
#[test]
fn can_eq_size_mismatch() {
    /// Compares an `N`-bit 43 with an `M`-bit 43 (when `eq`) or 42
    /// (otherwise) and asserts the result decrypts to `eq`.
    fn case<const N: usize, const M: usize, OutCt: Muxable>(eq: bool) {
        let enc = &get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let val_a = 43;
        let val_b = if eq { 43 } else { 42 };

        let lhs = UInt::<N, L1GgswCiphertext>::encrypt_secret(val_a, enc, &sk);
        let rhs = UInt::<M, L1GgswCiphertext>::encrypt_secret(val_b, enc, &sk);

        let lhs_nodes = lhs.graph_inputs(&ctx);
        let rhs_nodes = rhs.graph_inputs(&ctx);

        let result = lhs_nodes
            .eq::<OutCt>(&rhs_nodes, &ctx)
            .collect_output(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(result.decrypt(enc, &sk), eq);
    }

    // GLWE output: narrower-vs-wider, then wider-vs-narrower.
    for expected in [false, true] {
        case::<8, 16, L1GlweCiphertext>(expected);
    }
    for expected in [false, true] {
        case::<16, 8, L1GlweCiphertext>(expected);
    }

    // GLEV output: same two width pairings.
    for expected in [false, true] {
        case::<8, 16, L1GlevCiphertext>(expected);
    }
    for expected in [false, true] {
        case::<16, 8, L1GlevCiphertext>(expected);
    }
}
699
/// Exercises `neq` on 16-bit GGSW integers for equal and unequal operands and
/// both output ciphertext types.
#[test]
fn can_neq() {
    /// Asserts that `test_vals.0 != test_vals.1`, evaluated homomorphically,
    /// decrypts to `neq`.
    fn case<OutCt: Muxable, U: Sign>(
        neq: bool,
        test_vals: (U::PlaintextType, U::PlaintextType),
    ) {
        let (lhs_val, rhs_val) = test_vals;

        let enc = &get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let lhs = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(lhs_val, enc, &sk);
        let rhs = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(rhs_val, enc, &sk);

        let lhs_nodes = lhs.graph_inputs(&ctx);
        let rhs_nodes = rhs.graph_inputs(&ctx);

        let result = lhs_nodes
            .neq::<OutCt>(&rhs_nodes, &ctx)
            .collect_output(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(result.decrypt(enc, &sk), neq);
    }

    /// Runs the equal cases, then the unequal cases, for one output type.
    fn cases<OutCt: Muxable>() {
        // Operands match.
        case::<OutCt, Unsigned>(false, (43, 43));
        case::<OutCt, Signed>(false, (-43, -43));

        // Operands differ.
        case::<OutCt, Unsigned>(true, (43, 42));
        case::<OutCt, Signed>(true, (43, 42));
        case::<OutCt, Signed>(true, (-43, -42));
        case::<OutCt, Signed>(true, (42, -42));
    }

    cases::<L1GlweCiphertext>();
    cases::<L1GlevCiphertext>();
}
742
/// Exercises `neq` between unsigned integers of different bit widths.
#[test]
fn can_neq_size_mismatch() {
    /// Compares an `N`-bit 43 with an `M`-bit 42 (when `neq`) or 43
    /// (otherwise) and asserts the result decrypts to `neq`.
    fn case<const N: usize, const M: usize, OutCt: Muxable>(neq: bool) {
        let enc = &get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let val_a = 43;
        let val_b = if neq { 42 } else { 43 };

        let lhs = UInt::<N, L1GgswCiphertext>::encrypt_secret(val_a, enc, &sk);
        let rhs = UInt::<M, L1GgswCiphertext>::encrypt_secret(val_b, enc, &sk);

        let lhs_nodes = lhs.graph_inputs(&ctx);
        let rhs_nodes = rhs.graph_inputs(&ctx);

        let result = lhs_nodes
            .neq::<OutCt>(&rhs_nodes, &ctx)
            .collect_output(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(result.decrypt(enc, &sk), neq);
    }

    // GLWE output: narrower-vs-wider, then wider-vs-narrower.
    for expected in [false, true] {
        case::<8, 16, L1GlweCiphertext>(expected);
    }
    for expected in [false, true] {
        case::<16, 8, L1GlweCiphertext>(expected);
    }

    // GLEV output: same two width pairings.
    for expected in [false, true] {
        case::<8, 16, L1GlevCiphertext>(expected);
    }
    for expected in [false, true] {
        case::<16, 8, L1GlevCiphertext>(expected);
    }
}
785
/// Exercises `cmp` between unsigned integers of different bit widths.
#[test]
fn can_cmp_size_mismatch() {
    /// Compares an `N`-bit 43 against an `M`-bit 42 in both directions, plus
    /// the `M`-bit value against itself, with the given `cmp` flags.
    fn case<const N: usize, const M: usize, OutCt: Muxable>(gt: bool, eq: bool) {
        let enc = &get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let a = UInt::<N, L1GgswCiphertext>::encrypt_secret(43, enc, &sk);
        let b = UInt::<M, L1GgswCiphertext>::encrypt_secret(42, enc, &sk);

        let a_nodes = a.graph_inputs(&ctx);
        let b_nodes = b.graph_inputs(&ctx);

        let a_vs_b = a_nodes
            .cmp::<OutCt>(&b_nodes, &ctx, gt, eq)
            .collect_output(&ctx, enc);
        let b_vs_a = b_nodes
            .cmp::<OutCt>(&a_nodes, &ctx, gt, eq)
            .collect_output(&ctx, enc);
        let b_vs_b = b_nodes
            .cmp::<OutCt>(&b_nodes, &ctx, gt, eq)
            .collect_output(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(a_vs_b.decrypt(enc, &sk), gt);
        assert_eq!(b_vs_a.decrypt(enc, &sk), !gt);
        assert_eq!(b_vs_b.decrypt(enc, &sk), eq);
    }

    /// Runs all four flag combinations for each width pairing.
    fn cases<OutCt: Muxable>() {
        const FLAGS: [(bool, bool); 4] =
            [(false, false), (false, true), (true, false), (true, true)];

        for (gt, eq) in FLAGS {
            case::<8, 16, OutCt>(gt, eq);
        }

        for (gt, eq) in FLAGS {
            case::<16, 8, OutCt>(gt, eq);
        }
    }

    cases::<L1GlweCiphertext>();
    cases::<L1GlevCiphertext>();
}
838
/// Exercises `cmp` where one operand is a trivial GGSW encryption and the
/// other is a secret-key encryption.
#[test]
fn can_cmp_trivial_nontrivial_ggsw() {
    /// `test_vals.0` (encrypted trivially) must be strictly greater than
    /// `test_vals.1` (encrypted under the secret key).
    fn case<OutCt: Muxable, U: Sign>(
        gt: bool,
        eq: bool,
        test_vals: (U::PlaintextType, U::PlaintextType),
    ) {
        let (larger, smaller) = test_vals;

        let enc = &get_encryption_128();
        let eval = &get_evaluation_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let a = GenericInt::<16, L1GgswCiphertext, U>::trivial(larger, enc, eval);
        let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(smaller, enc, &sk);

        let a_nodes = a.graph_inputs(&ctx);
        let b_nodes = b.graph_inputs(&ctx);

        let a_vs_b = a_nodes
            .cmp::<OutCt>(&b_nodes, &ctx, gt, eq)
            .collect_output(&ctx, enc);
        let b_vs_a = b_nodes
            .cmp::<OutCt>(&a_nodes, &ctx, gt, eq)
            .collect_output(&ctx, enc);
        let b_vs_b = b_nodes
            .cmp::<OutCt>(&b_nodes, &ctx, gt, eq)
            .collect_output(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(a_vs_b.decrypt(enc, &sk), gt);
        assert_eq!(b_vs_a.decrypt(enc, &sk), !gt);
        assert_eq!(b_vs_b.decrypt(enc, &sk), eq);
    }

    /// Runs the same three operand pairs under every flag combination.
    fn cases<OutCt: Muxable>() {
        for (gt, eq) in [(false, false), (false, true), (true, false), (true, true)] {
            case::<OutCt, Unsigned>(gt, eq, (43, 42));
            case::<OutCt, Signed>(gt, eq, (-35, -36));
            case::<OutCt, Signed>(gt, eq, (1, -42));
        }
    }

    cases::<L1GlweCiphertext>();
    cases::<L1GlevCiphertext>();
}
897
/// Exercises `select` with secret-key-encrypted GGSW selector bits over
/// 16-bit GLWE integers.
#[test]
fn can_select() {
    /// A `true` selector must yield `test_vals.0`; a `false` selector must
    /// yield `test_vals.1`.
    fn case<U: Sign>(test_vals: (U::PlaintextType, U::PlaintextType)) {
        let (first, second) = test_vals;

        let enc = &get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let false_bit =
            Bit::<L1GgswCiphertext>::encrypt_secret(false, enc, &sk).graph_input(&ctx);
        let true_bit = Bit::<L1GgswCiphertext>::encrypt_secret(true, enc, &sk).graph_input(&ctx);

        let a: GenericIntGraphNodes<'_, 16, L1GlweCiphertext, U> =
            GenericInt::<16, L1GlweCiphertext, U>::encrypt_secret(first, enc, &sk)
                .graph_inputs(&ctx)
                .into();
        let b: GenericIntGraphNodes<'_, 16, L1GlweCiphertext, U> =
            GenericInt::<16, L1GlweCiphertext, U>::encrypt_secret(second, enc, &sk)
                .graph_inputs(&ctx)
                .into();

        let picked_by_false = false_bit
            .select::<16, U, L1GlweCiphertext>(&a, &b, &ctx)
            .collect_outputs(&ctx, enc);
        let picked_by_true = true_bit
            .select::<16, U, L1GlweCiphertext>(&a, &b, &ctx)
            .collect_outputs(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(picked_by_false.decrypt(enc, &sk), second);
        assert_eq!(picked_by_true.decrypt(enc, &sk), first);
    }

    case::<Unsigned>((42, 24));
    case::<Signed>((-94, -112));
}
940
/// Exercises `select` with trivially-encrypted GGSW selector bits over
/// 16-bit GLWE integers.
#[test]
fn can_select_plain() {
    /// A `true` selector must yield `test_vals.0`; a `false` selector must
    /// yield `test_vals.1`.
    fn case<U: Sign>(test_vals: (U::PlaintextType, U::PlaintextType)) {
        let (first, second) = test_vals;

        let enc = &get_encryption_128();
        let eval = &get_evaluation_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let false_bit =
            Bit::<L1GgswCiphertext>::trivial_encryption(false, enc, eval).graph_input(&ctx);
        let true_bit =
            Bit::<L1GgswCiphertext>::trivial_encryption(true, enc, eval).graph_input(&ctx);

        let a: GenericIntGraphNodes<'_, 16, L1GlweCiphertext, U> =
            GenericInt::<16, L1GlweCiphertext, U>::encrypt_secret(first, enc, &sk)
                .graph_inputs(&ctx)
                .into();
        let b: GenericIntGraphNodes<'_, 16, L1GlweCiphertext, U> =
            GenericInt::<16, L1GlweCiphertext, U>::encrypt_secret(second, enc, &sk)
                .graph_inputs(&ctx)
                .into();

        let picked_by_false = false_bit
            .select::<16, U, L1GlweCiphertext>(&a, &b, &ctx)
            .collect_outputs(&ctx, enc);
        let picked_by_true = true_bit
            .select::<16, U, L1GlweCiphertext>(&a, &b, &ctx)
            .collect_outputs(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(picked_by_false.decrypt(enc, &sk), second);
        assert_eq!(picked_by_true.decrypt(enc, &sk), first);
    }

    case::<Unsigned>((42, 24));
    case::<Signed>((-94, -112));
}
984
/// Exercises `sub` on 16-bit GGSW integers for both output ciphertext types.
#[test]
fn can_sub() {
    /// `test_vals` is `(minuend, subtrahend, expected difference)`.
    fn case<OutCt: Muxable, U: Sign>(
        test_vals: (U::PlaintextType, U::PlaintextType, U::PlaintextType),
    ) {
        let (minuend, subtrahend, expected) = test_vals;

        let enc = &get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (uproc, fc) = make_uproc_128();

        let a = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(minuend, enc, &sk)
            .graph_inputs(&ctx);
        let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(subtrahend, enc, &sk)
            .graph_inputs(&ctx);

        let difference = a.sub::<OutCt>(&b, &ctx).collect_outputs(&ctx, enc);

        uproc
            .lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(difference.decrypt(enc, &sk), expected);
    }

    /// Runs one unsigned case and four signed cases for one output type.
    fn cases<OutCt: Muxable>() {
        case::<OutCt, Unsigned>((42, 16, 26));

        case::<OutCt, Signed>((-7, -9, 2));
        case::<OutCt, Signed>((-7, 2, -9));
        case::<OutCt, Signed>((-7, -5, -2));
        case::<OutCt, Signed>((2, -7, 9));
    }

    cases::<L1GlweCiphertext>();
    cases::<L1GlevCiphertext>();
}
1022
/// Checks that trivially-encrypted 32-bit integers decrypt to the value they
/// were built from, for four ciphertext types.
#[test]
fn trivial_generic_int_encryption() {
    /// Builds a trivial 32-bit `GenericInt` from one value produced by
    /// `gen_pt` and asserts it decrypts to that value.
    fn case<T: CiphertextOps, U: Sign, F: Fn() -> U::PlaintextType>(gen_pt: F) {
        let enc = get_encryption_128();
        let eval = &get_evaluation_128();
        let sk = get_secret_keys_128();

        let expected = gen_pt();

        let trivial = GenericInt::<32, T, U>::trivial(expected, &enc, eval);
        let decrypted = trivial.decrypt(&enc, &sk);

        assert_eq!(decrypted, expected);
    }

    /// Runs the unsigned case followed by the signed case for one type.
    fn both_signs<T: CiphertextOps>() {
        case::<T, Unsigned, _>(rand_u32);
        case::<T, Signed, _>(rand_i32);
    }

    both_signs::<L0LweCiphertext>();
    both_signs::<L1LweCiphertext>();
    both_signs::<L1GlweCiphertext>();
    both_signs::<L1GgswCiphertext>();
}
1046
/// Exercises `resize` of a 16-bit integer to 24 bits and to 8 bits.
#[test]
fn can_resize() {
    /// `test_vals` is `(input, expected after resize to 24, expected after
    /// resize to 8)`.
    fn case<T: CiphertextOps, U: Sign>(
        test_vals: (U::PlaintextType, U::PlaintextType, U::PlaintextType),
    ) {
        let (input, expected_wide, expected_narrow) = test_vals;

        let enc = get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (proc, fc) = make_uproc_128();

        let val = GenericInt::<16, T, U>::encrypt_secret(input, &enc, &sk);

        // Resize to 24 bits.
        let widened = val
            .graph_inputs(&ctx)
            .resize(&ctx, 24)
            .collect_outputs(&ctx, &enc);

        proc.lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(widened.decrypt(&enc, &sk), expected_wide);

        // Resize to 8 bits. The nodes are added to the same context, so the
        // second run executes the circuit with both resizes in it.
        let narrowed = val
            .graph_inputs(&ctx)
            .resize(&ctx, 8)
            .collect_outputs(&ctx, &enc);

        proc.lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(narrowed.decrypt(&enc, &sk), expected_narrow);
    }

    /// Runs the unsigned, positive-signed and negative-signed cases for one
    /// ciphertext type.
    fn all_signs<T: CiphertextOps>() {
        case::<T, Unsigned>((1234, 1234, 210));
        case::<T, Signed>((1234, 1234, 82));
        case::<T, Signed>((-1234, -1234, -82));
    }

    all_signs::<L0LweCiphertext>();
    all_signs::<L1GlweCiphertext>();
    all_signs::<L1GgswCiphertext>();
}
1095
/// Exercises `add` on 16-bit GGSW integers for both output ciphertext types.
#[test]
fn can_add() {
    /// `test_vals` is `(lhs, rhs, expected sum)`.
    fn case<OutCt: Muxable, U: Sign>(
        test_vals: (U::PlaintextType, U::PlaintextType, U::PlaintextType),
    ) {
        let enc = get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (proc, fc) = make_uproc_128();

        let a = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.0, &enc, &sk)
            .graph_inputs(&ctx);
        let b = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(test_vals.1, &enc, &sk)
            .graph_inputs(&ctx);

        let c = a.add::<OutCt>(&b, &ctx).collect_outputs(&ctx, &enc);

        // The circuit is intentionally not printed here: pretty-printing the
        // whole graph on every case was leftover debug output.
        proc.lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(c.decrypt(&enc, &sk), test_vals.2);
    }

    /// Runs one unsigned case and three signed cases for one output type.
    fn cases<OutCt: Muxable>() {
        case::<OutCt, Unsigned>((42, 16, 58));
        case::<OutCt, Signed>((-6, 16, 10));
        case::<OutCt, Signed>((-6, -7, -13));
        case::<OutCt, Signed>((-8, 2, -6));
    }

    cases::<L1GlweCiphertext>();
    cases::<L1GlevCiphertext>();
}
1133
/// Exercises `mul` on 16-bit GGSW integers with GLWE outputs.
#[test]
fn can_mul() {
    /// `test_vals` is `(lhs, rhs, expected product)`.
    fn case<OutCt: Muxable, U: Sign>(
        test_vals: (U::PlaintextType, U::PlaintextType, U::PlaintextType),
    ) {
        let (lhs_val, rhs_val, expected) = test_vals;

        let enc = get_encryption_128();
        let sk = get_secret_keys_128();
        let ctx = FheCircuitCtx::new();
        let (proc, fc) = make_uproc_128();

        let lhs = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(lhs_val, &enc, &sk)
            .graph_inputs(&ctx);
        let rhs = GenericInt::<16, L1GgswCiphertext, U>::encrypt_secret(rhs_val, &enc, &sk)
            .graph_inputs(&ctx);

        let product = lhs.mul::<OutCt>(&rhs, &ctx).collect_outputs(&ctx, &enc);

        proc.lock()
            .unwrap()
            .run_graph_blocking(&ctx.circuit.borrow(), &fc)
            .unwrap();

        assert_eq!(product.decrypt(&enc, &sk), expected);
    }

    // Unsigned operands.
    case::<L1GlweCiphertext, Unsigned>((42, 16, 672));

    // Signed operands: both positive, mixed signs, both negative.
    case::<L1GlweCiphertext, Signed>((42, 16, 672));
    case::<L1GlweCiphertext, Signed>((42, -16, -672));
    case::<L1GlweCiphertext, Signed>((-42, -16, 672));
}
1164}