1use std::cell::RefCell;
2use std::cmp::max;
3
4use evaluator::CircuitEvaluator;
5use evaluator::HomEvaluator;
6use evaluator::HomEvaluatorGal;
7use feanor_math::homomorphism::Homomorphism;
8use feanor_math::ring::*;
9
10use crate::cyclotomic::*;
11
12pub mod serialization;
13pub mod evaluator;
14
15pub enum Coefficient<R: ?Sized + RingBase> {
23 Zero, One, NegOne, Integer(i32), Other(R::Element)
24}
25
26impl<R> Clone for Coefficient<R>
27 where R: ?Sized + RingBase,
28 R::Element: Clone
29{
30 fn clone(&self) -> Self {
31 match self {
32 Coefficient::Zero => Coefficient::Zero,
33 Coefficient::One => Coefficient::One,
34 Coefficient::NegOne => Coefficient::NegOne,
35 Coefficient::Integer(x) => Coefficient::Integer(*x),
36 Coefficient::Other(x) => Coefficient::Other(x.clone())
37 }
38 }
39}
40
41impl<R> Copy for Coefficient<R>
42 where R: ?Sized + RingBase,
43 R::Element: Copy
44{}
45
46impl<R: ?Sized + RingBase> Coefficient<R> {
47
48 pub fn clone<S: RingStore<Type = R>>(&self, ring: S) -> Self {
49 match self {
50 Coefficient::Zero => Coefficient::Zero,
51 Coefficient::One => Coefficient::One,
52 Coefficient::NegOne => Coefficient::NegOne,
53 Coefficient::Integer(x) => Coefficient::Integer(*x),
54 Coefficient::Other(x) => Coefficient::Other(ring.clone_el(x))
55 }
56 }
57
58 pub fn eq<S: RingStore<Type = R> + Copy>(&self, other: &Self, ring: S) -> bool {
59 ring.eq_el(&self.clone(ring).to_ring_el(ring), &other.clone(ring).to_ring_el(ring))
60 }
61
62 pub fn add_to<S: RingStore<Type = R> + Copy>(&self, x: El<S>, ring: S) -> El<S> {
66 match self {
67 Coefficient::Zero => x,
68 Coefficient::One => ring.add(x, ring.one()),
69 Coefficient::NegOne => ring.add(x, ring.neg_one()),
70 Coefficient::Integer(y) => ring.add(x, ring.int_hom().map(*y)),
71 Coefficient::Other(y) => ring.add_ref_snd(x, y)
72 }
73 }
74
75 pub fn mul_to<S: RingStore<Type = R> + Copy>(&self, x: El<S>, ring: S) -> El<S> {
80 match self {
81 Coefficient::Zero => ring.zero(),
82 Coefficient::One => x,
83 Coefficient::NegOne => ring.negate(x),
84 Coefficient::Integer(y) => ring.int_hom().mul_map(x, *y),
85 Coefficient::Other(y) => ring.mul_ref_snd(x, y)
86 }
87 }
88
89 pub fn is_zero(&self) -> bool {
90 match self {
91 Coefficient::Zero => true,
92 _ => false
93 }
94 }
95
96 fn from<S: RingStore<Type = R> + Copy>(el: El<S>, ring: S) -> Self {
97 if ring.is_zero(&el) {
98 Coefficient::Zero
99 } else if ring.is_one(&el) {
100 Coefficient::One
101 } else {
102 Coefficient::Other(el)
103 }
104 }
105
106 pub fn to_ring_el<S: RingStore<Type = R>>(self, ring: S) -> El<S> {
107 match self {
108 Coefficient::Zero => ring.zero(),
109 Coefficient::One => ring.one(),
110 Coefficient::NegOne => ring.neg_one(),
111 Coefficient::Integer(x) => ring.int_hom().map(x),
112 Coefficient::Other(x) => x
113 }
114 }
115
116 pub fn negate<S: RingStore<Type = R>>(self, ring: S) -> Self {
117 match self {
118 Coefficient::Zero => Coefficient::Zero,
119 Coefficient::One => Coefficient::NegOne,
120 Coefficient::NegOne => Coefficient::One,
121 Coefficient::Integer(x) => Coefficient::Integer(-x),
122 Coefficient::Other(x) => Coefficient::Other(ring.negate(x))
123 }
124 }
125
126 pub fn add<S: RingStore<Type = R> + Copy>(self, other: Self, ring: S) -> Self {
127 match (self, other) {
128 (Coefficient::Zero, rhs) => rhs,
129 (lhs, Coefficient::Zero) => lhs,
130 (Coefficient::One, Coefficient::Integer(x)) => Coefficient::Integer(x + 1),
131 (Coefficient::NegOne, Coefficient::Integer(x)) => Coefficient::Integer(x - 1),
132 (Coefficient::Integer(x), Coefficient::One) => Coefficient::Integer(x + 1),
133 (Coefficient::Integer(x), Coefficient::NegOne) => Coefficient::Integer(x - 1),
134 (lhs, rhs) => Coefficient::Other(ring.add(lhs.to_ring_el(ring), rhs.to_ring_el(ring)))
135 }
136 }
137
138 pub fn mul<S: RingStore<Type = R> + Copy>(self, other: Self, ring: S) -> Self {
139 match (self, other) {
140 (Coefficient::Zero, _) => Coefficient::Zero,
141 (_, Coefficient::Zero) => Coefficient::Zero,
142 (Coefficient::One, rhs) => rhs,
143 (lhs, Coefficient::One) => lhs,
144 (lhs, Coefficient::NegOne) => lhs.negate(ring),
145 (Coefficient::NegOne, rhs) => rhs.negate(ring),
146 (lhs, rhs) => Coefficient::Other(ring.mul(lhs.to_ring_el(ring), rhs.to_ring_el(ring)))
147 }
148 }
149
150 pub fn change_ring<S, F>(self, mut f: F) -> Coefficient<S>
155 where F: FnMut(R::Element) -> S::Element,
156 S: ?Sized + RingBase
157 {
158 match self {
159 Coefficient::Integer(x) => Coefficient::Integer(x),
160 Coefficient::NegOne => Coefficient::NegOne,
161 Coefficient::Zero => Coefficient::Zero,
162 Coefficient::One => Coefficient::One,
163 Coefficient::Other(x) => Coefficient::Other(f(x))
164 }
165 }
166
167}
168
169struct LinearCombination<R: ?Sized + RingBase> {
175 factors: Vec<Coefficient<R>>,
176 constant: Coefficient<R>
177}
178
179impl<R: ?Sized + RingBase> LinearCombination<R> {
180
181 fn clone<S: RingStore<Type = R> + Copy>(&self, ring: S) -> Self {
182 Self {
183 factors: self.factors.iter().map(|c| c.clone(ring)).collect(),
184 constant: self.constant.clone(ring)
185 }
186 }
187
188 fn evaluate_generic<'a, T, E>(&'a self, first_inputs: &[T], second_inputs: &[T], evaluator: &mut E) -> T
189 where E: CircuitEvaluator<'a, T, R>
190 {
191 assert_eq!(self.factors.len(), first_inputs.len() + second_inputs.len());
192 let current = evaluator.constant(&self.constant);
193 let current = evaluator.add_inner_prod(
194 current,
195 &self.factors[..first_inputs.len()],
196 first_inputs
197 );
198 evaluator.add_inner_prod(
199 current,
200 &self.factors[first_inputs.len()..],
201 second_inputs
202 )
203 }
204
205 fn compose<S>(self, input_transforms: &[LinearCombination<R>], ring: S) -> LinearCombination<R>
206 where S: RingStore<Type = R> + Copy
207 {
208 assert_eq!(self.factors.len(), input_transforms.len());
209 if input_transforms.len() == 0 {
210 return self.clone(ring);
211 }
212 let new_input_count = input_transforms[0].factors.len();
213 assert!(input_transforms.iter().all(|t| t.factors.len() == new_input_count));
214 let mut result_factors = (0..new_input_count).map(|_| Coefficient::Zero).collect::<Vec<_>>();
215 let mut result_constant = self.constant.clone(ring);
216 for (factor, t) in self.factors.into_iter().zip(input_transforms.iter()) {
217 for i in 0..new_input_count {
218 take_mut::take(&mut result_factors[i], |x| x.add(factor.clone(ring).mul(t.factors[i].clone(ring), ring), ring));
219 }
220 result_constant = result_constant.add(factor.mul(t.constant.clone(ring), ring), ring);
221 }
222 return LinearCombination {
223 constant: result_constant,
224 factors: result_factors
225 };
226 }
227
228 fn change_ring<S, F1, F2>(self, change_summand: &mut F1, change_factor: &mut F2) -> LinearCombination<S>
229 where F1: FnMut(Coefficient<R>) -> Coefficient<S>,
230 F2: FnMut(Coefficient<R>) -> Coefficient<S>,
231 S: ?Sized + RingBase
232 {
233 LinearCombination {
234 constant: change_summand(self.constant),
235 factors: self.factors.into_iter().map(|c| change_factor(c)).collect()
236 }
237 }
238}
239
240impl<R: RingBase + Default> PartialEq for LinearCombination<R> {
241
242 fn eq(&self, other: &Self) -> bool {
243 assert_eq!(self.factors.len(), other.factors.len());
244 let ring = RingValue::<R>::default();
245 return self.constant.eq(&other.constant, &ring) &&
246 self.factors.iter().zip(other.factors.iter()).all(|(lhs, rhs)| lhs.eq(rhs, &ring));
247 }
248}
249
250enum PlaintextCircuitGate<R: ?Sized + RingBase> {
254 Mul(LinearCombination<R>, LinearCombination<R>),
255 Square(LinearCombination<R>),
256 Gal(Vec<CyclotomicGaloisGroupEl>, LinearCombination<R>)
257}
258
259impl<R: ?Sized + RingBase> PlaintextCircuitGate<R> {
260
261 fn clone<S: RingStore<Type = R> + Copy>(&self, ring: S) -> Self {
262 match self {
263 PlaintextCircuitGate::Mul(lhs, rhs) => PlaintextCircuitGate::Mul(lhs.clone(ring), rhs.clone(ring)),
264 PlaintextCircuitGate::Gal(gs, t) => PlaintextCircuitGate::Gal(gs.clone(), t.clone(ring)),
265 PlaintextCircuitGate::Square(t) => PlaintextCircuitGate::Square(t.clone(ring))
266 }
267 }
268}
269
270impl<R: RingBase + Default> PartialEq for PlaintextCircuitGate<R> {
271
272 fn eq(&self, other: &Self) -> bool {
273 match (self, other) {
274 (PlaintextCircuitGate::Mul(self_lhs, self_rhs), PlaintextCircuitGate::Mul(other_lhs, other_rhs)) => self_lhs == other_lhs && self_rhs == other_rhs,
275 (PlaintextCircuitGate::Square(self_t), PlaintextCircuitGate::Square(other_t)) => self_t == other_t,
276 _ => false
277 }
278 }
279}
280
281pub struct PlaintextCircuit<R: ?Sized + RingBase> {
304 input_count: usize,
305 gates: Vec<PlaintextCircuitGate<R>>,
306 output_transforms: Vec<LinearCombination<R>>
307}
308
309impl<R: RingBase + Default> PartialEq for PlaintextCircuit<R> {
310
311 fn eq(&self, other: &Self) -> bool {
312 self.input_count == other.input_count && self.gates == other.gates && self.output_transforms == other.output_transforms
313 }
314}
315
316impl<R: ?Sized + RingBase> PlaintextCircuit<R> {
317
318 fn check_invariants(&self) {
319 let mut current_count = self.input_count;
320 for gate in &self.gates {
321 match gate {
322 PlaintextCircuitGate::Mul(lhs, rhs) => {
323 assert_eq!(current_count, lhs.factors.len());
324 assert_eq!(current_count, rhs.factors.len());
325 current_count += 1;
326 },
327 PlaintextCircuitGate::Gal(gs, t) => {
328 assert_eq!(current_count, t.factors.len());
329 current_count += gs.len();
330 },
331 PlaintextCircuitGate::Square(t) => {
332 assert_eq!(current_count, t.factors.len());
333 current_count += 1;
334 }
335 }
336 }
337 for out in &self.output_transforms {
338 assert_eq!(current_count, out.factors.len());
339 }
340 }
341
342 pub fn clone<S: RingStore<Type = R> + Copy>(&self, ring: S) -> Self {
343 Self {
344 gates: self.gates.iter().map(|gate| gate.clone(ring)).collect(),
345 input_count: self.input_count,
346 output_transforms: self.output_transforms.iter().map(|t| t.clone(ring)).collect()
347 }
348 }
349
350 fn computed_wire_count(&self) -> usize {
351 self.gates.iter().map(|gate| match gate {
352 PlaintextCircuitGate::Mul(_, _) => 1,
353 PlaintextCircuitGate::Square(_) => 1,
354 PlaintextCircuitGate::Gal(gs, _) => gs.len()
355 }).sum()
356 }
357
358 pub fn empty() -> Self {
362 Self {
363 input_count: 0,
364 gates: Vec::new(),
365 output_transforms: Vec::new()
366 }
367 }
368
369 pub fn constant_i32<S: RingStore<Type = R>>(c: i32, _ring: S) -> Self {
378 let result = Self {
379 input_count: 0,
380 gates: Vec::new(),
381 output_transforms: vec![LinearCombination {
382 constant: if c == 0{
383 Coefficient::Zero
384 } else if c == 1 {
385 Coefficient::One
386 } else {
387 Coefficient::Integer(c)
388 },
389 factors: Vec::new()
390 }]
391 };
392 result.check_invariants();
393 return result;
394 }
395
396 pub fn constant<S: RingStore<Type = R>>(el: El<S>, ring: S) -> Self {
405 let result = Self {
406 input_count: 0,
407 gates: Vec::new(),
408 output_transforms: vec![LinearCombination {
409 constant: Coefficient::from(el, &ring),
410 factors: Vec::new()
411 }]
412 };
413 result.check_invariants();
414 return result;
415 }
416
417 pub fn change_ring<S, F1, F2>(self, mut change_summand: F1, mut change_factor: F2) -> PlaintextCircuit<S>
422 where F1: FnMut(Coefficient<R>) -> Coefficient<S>,
423 F2: FnMut(Coefficient<R>) -> Coefficient<S>,
424 S: ?Sized + RingBase
425 {
426 PlaintextCircuit {
427 input_count: self.input_count,
428 gates: self.gates.into_iter().map(|gate| match gate {
429 PlaintextCircuitGate::Gal(gs, t) => PlaintextCircuitGate::Gal(gs, t.change_ring(&mut change_summand, &mut change_factor)),
430 PlaintextCircuitGate::Mul(l, r) => PlaintextCircuitGate::Mul(l.change_ring(&mut change_summand, &mut change_factor), r.change_ring(&mut change_summand, &mut change_factor)),
431 PlaintextCircuitGate::Square(t) => PlaintextCircuitGate::Square(t.change_ring(&mut change_summand, &mut change_factor))
432 }).collect(),
433 output_transforms: self.output_transforms.into_iter().map(|t| t.change_ring(&mut change_summand, &mut change_factor)).collect()
434 }
435 }
436
437 pub fn change_ring_uniform<S, F>(self, f: F) -> PlaintextCircuit<S>
442 where F: FnMut(Coefficient<R>) -> Coefficient<S>,
443 S: ?Sized + RingBase
444 {
445 let f_refcell = RefCell::new(f);
446 return self.change_ring(|x| (f_refcell.borrow_mut())(x), |x| (f_refcell.borrow_mut())(x));
447 }
448
449 pub fn linear_transform<S: RingStore<Type = R>>(coeffs: &[Coefficient<R>], ring: S) -> Self {
463 let result = Self {
464 input_count: coeffs.len(),
465 gates: Vec::new(),
466 output_transforms: vec![LinearCombination {
467 constant: Coefficient::Zero,
468 factors: coeffs.iter().map(|c| c.clone(&ring)).collect()
469 }]
470 };
471 result.check_invariants();
472 return result;
473 }
474
475 pub fn linear_transform_ring<S: RingStore<Type = R>>(coeffs: &[El<S>], ring: S) -> Self {
487 let result = Self {
488 input_count: coeffs.len(),
489 gates: Vec::new(),
490 output_transforms: vec![LinearCombination {
491 constant: Coefficient::Zero,
492 factors: coeffs.iter().map(|c| Coefficient::from(ring.clone_el(c), &ring)).collect()
493 }]
494 };
495 result.check_invariants();
496 return result;
497 }
498
499 pub fn add<S: RingStore<Type = R>>(_ring: S) -> Self {
512 let result = Self {
513 input_count: 2,
514 gates: Vec::new(),
515 output_transforms: vec![LinearCombination {
516 constant: Coefficient::Zero,
517 factors: vec![Coefficient::One, Coefficient::One]
518 }]
519 };
520 return result;
521 }
522
523 pub fn square<S: RingStore<Type = R>>(_ring: S) -> Self {
546 let result = Self {
547 input_count: 1,
548 gates: vec![PlaintextCircuitGate::Square(LinearCombination {
549 constant: Coefficient::Zero,
550 factors: vec![Coefficient::One]
551 })],
552 output_transforms: vec![LinearCombination {
553 constant: Coefficient::Zero,
554 factors: vec![Coefficient::Zero, Coefficient::One]
555 }]
556 };
557 return result;
558 }
559
560 pub fn sub<S: RingStore<Type = R>>(_ring: S) -> Self {
573 let result = Self {
574 input_count: 2,
575 gates: Vec::new(),
576 output_transforms: vec![LinearCombination {
577 constant: Coefficient::Zero,
578 factors: vec![Coefficient::One, Coefficient::NegOne]
579 }]
580 };
581 return result;
582 }
583
584 pub fn mul<S: RingStore<Type = R>>(_ring: S) -> Self {
595 let result = Self {
596 input_count: 2,
597 gates: vec![PlaintextCircuitGate::Mul(
598 LinearCombination {
599 constant: Coefficient::Zero,
600 factors: vec![Coefficient::One, Coefficient::Zero]
601 },
602 LinearCombination {
603 constant: Coefficient::Zero,
604 factors: vec![Coefficient::Zero, Coefficient::One]
605 }
606 )],
607 output_transforms: vec![LinearCombination {
608 constant: Coefficient::Zero,
609 factors: vec![Coefficient::Zero, Coefficient::Zero, Coefficient::One]
610 }]
611 };
612 result.check_invariants();
613 return result;
614 }
615
616 pub fn gal<S: RingStore<Type = R>>(g: CyclotomicGaloisGroupEl, _ring: S) -> Self {
627 let result = Self {
628 input_count: 1,
629 gates: vec![PlaintextCircuitGate::Gal(vec![g], LinearCombination {
630 constant: Coefficient::Zero,
631 factors: vec![Coefficient::One]
632 })],
633 output_transforms: vec![LinearCombination {
634 constant: Coefficient::Zero,
635 factors: vec![Coefficient::Zero, Coefficient::One]
636 }]
637 };
638 result.check_invariants();
639 return result;
640 }
641
642 pub fn gal_many<S: RingStore<Type = R>>(gs: &[CyclotomicGaloisGroupEl], _ring: S) -> Self {
653 let result = Self {
654 input_count: 1,
655 gates: vec![PlaintextCircuitGate::Gal(
656 gs.to_owned(),
657 LinearCombination {
658 constant: Coefficient::Zero,
659 factors: vec![Coefficient::One]
660 }
661 )],
662 output_transforms: (0..gs.len()).map(|i| LinearCombination {
663 constant: Coefficient::Zero,
664 factors: (0..=gs.len()).map(|j| if j == i + 1 { Coefficient::One } else { Coefficient::Zero }).collect()
665 }).collect()
666 };
667 result.check_invariants();
668 return result;
669 }
670
671 pub fn output_twice<S: RingStore<Type = R> + Copy>(self, ring: S) -> Self {
694 self.output_times(2, ring)
695 }
696
697 pub fn drop(wire_count: usize) -> Self {
706 let result = Self {
707 input_count: wire_count,
708 gates: Vec::new(),
709 output_transforms: Vec::new()
710 };
711 result.check_invariants();
712 return result;
713 }
714
715 pub fn identity<S: RingStore<Type = R>>(wire_count: usize, _ring: S) -> Self {
723 let result = Self {
724 input_count: wire_count,
725 gates: Vec::new(),
726 output_transforms: (0..wire_count).map(|i| LinearCombination {
727 constant: Coefficient::Zero,
728 factors: (0..wire_count).map(|j| if j == i { Coefficient::One } else { Coefficient::Zero }).collect()
729 }).collect()
730 };
731 result.check_invariants();
732 return result;
733 }
734
735 pub fn select<S: RingStore<Type = R>>(input_wire_count: usize, output_wires: &[usize], _ring: S) -> Self {
740 let result = Self {
741 input_count: input_wire_count,
742 gates: Vec::new(),
743 output_transforms: output_wires.iter().map(|i| {
744 assert!(*i < input_wire_count);
745 LinearCombination {
746 constant: Coefficient::Zero,
747 factors: (0..input_wire_count).map(|j| if *i == j { Coefficient::One } else { Coefficient::Zero }).collect()
748 }
749 }).collect()
750 };
751 result.check_invariants();
752 return result;
753 }
754
755 pub fn output_times<S: RingStore<Type = R> + Copy>(self, times: usize, ring: S) -> Self {
756 let result = Self {
757 input_count: self.input_count,
758 gates: self.gates.iter().map(|gate| gate.clone(ring)).collect(),
759 output_transforms: (0..times).flat_map(|_| self.output_transforms.iter()).map(|lin| lin.clone(ring)).collect()
760 };
761 result.check_invariants();
762 return result;
763 }
764
765 pub fn tensor<S: RingStore<Type = R>>(self, rhs: Self, ring: S) -> Self {
784 let add_zeros = |vec: &[Coefficient<R>], index: usize, count: usize|
785 vec[0..index].iter().map(|c| c.clone(&ring))
786 .chain(std::iter::repeat_with(|| Coefficient::Zero).take(count))
787 .chain(vec[index..].iter().map(|c| c.clone(&ring)))
788 .collect::<Vec<_>>();
789
790 let map_self_transform = |t: &LinearCombination<R>| LinearCombination {
791 constant: t.constant.clone(&ring),
792 factors: add_zeros(&t.factors, self.input_count, rhs.input_count)
793 };
794 let map_rhs_transform = |t: &LinearCombination<R>| LinearCombination {
795 constant: t.constant.clone(&ring),
796 factors: add_zeros(&add_zeros(&t.factors, rhs.input_count, self.computed_wire_count()), 0, self.input_count)
797 };
798 let result = Self {
799 input_count: self.input_count + rhs.input_count,
800 gates: self.gates.iter().map(|gate| match gate {
801 PlaintextCircuitGate::Mul(lhs, rhs) => PlaintextCircuitGate::Mul(
802 map_self_transform(&lhs),
803 map_self_transform(&rhs)
804 ),
805 PlaintextCircuitGate::Gal(gs, t) => PlaintextCircuitGate::Gal(
806 gs.clone(),
807 map_self_transform(t)
808 ),
809 PlaintextCircuitGate::Square(t) => PlaintextCircuitGate::Square(
810 map_self_transform(t)
811 )
812 }).chain(
813 rhs.gates.iter().map(|gate| match gate {
814 PlaintextCircuitGate::Mul(lhs, rhs) => PlaintextCircuitGate::Mul(
815 map_rhs_transform(&lhs),
816 map_rhs_transform(&rhs)
817 ),
818 PlaintextCircuitGate::Gal(gs, t) => PlaintextCircuitGate::Gal(
819 gs.clone(),
820 map_rhs_transform(t)
821 ),
822 PlaintextCircuitGate::Square(t) => PlaintextCircuitGate::Square(
823 map_rhs_transform(t)
824 )
825 })
826 ).collect(),
827 output_transforms: self.output_transforms.iter().map(|t| {
828 assert_eq!(self.computed_wire_count() + self.input_count, t.factors.len());
829 let added_inputs_t = map_self_transform(t);
830 LinearCombination {
831 factors: add_zeros(&added_inputs_t.factors, self.input_count + rhs.input_count + self.computed_wire_count(), rhs.computed_wire_count()),
832 constant: added_inputs_t.constant
833 }
834 }).chain(rhs.output_transforms.iter().map(|t| {
835 assert_eq!(rhs.computed_wire_count() + rhs.input_count, t.factors.len());
836 map_rhs_transform(t)
837 })).collect()
838 };
839 result.check_invariants();
840 return result;
841 }
842
843 pub fn compose<S: RingStore<Type = R> + Copy>(self, first: Self, ring: S) -> Self {
869 assert_eq!(first.output_count(), self.input_count());
870
871 let map_transform = |t: &LinearCombination<R>| {
872 let input_transform = LinearCombination {
873 constant: t.constant.clone(&ring),
874 factors: t.factors[0..self.input_count].iter().map(|c| c.clone(&ring)).collect()
875 };
876 let mut result = input_transform.compose(&first.output_transforms, ring);
877 result.factors.extend(t.factors[self.input_count..].iter().map(|c| c.clone(&ring)));
878 return result;
879 };
880 let result = Self {
881 input_count: first.input_count,
882 gates: first.gates.iter().map(|gate| gate.clone(ring)).chain(
883 self.gates.iter().map(|gate| match gate {
884 PlaintextCircuitGate::Mul(lhs, rhs) => PlaintextCircuitGate::Mul(
885 map_transform(lhs),
886 map_transform(rhs),
887 ),
888 PlaintextCircuitGate::Gal(gs, t) => PlaintextCircuitGate::Gal(
889 gs.clone(),
890 map_transform(t)
891 ),
892 PlaintextCircuitGate::Square(t) => PlaintextCircuitGate::Square(
893 map_transform(t)
894 )
895 })
896 ).collect(),
897 output_transforms: self.output_transforms.iter().map(map_transform).collect()
898 };
899 result.check_invariants();
900 return result;
901 }
902
903 pub fn input_count(&self) -> usize {
904 self.input_count
905 }
906
907 pub fn output_count(&self) -> usize {
908 self.output_transforms.len()
909 }
910
911 pub fn evaluate_generic<'a, T, E>(&'a self, inputs: &[T], mut evaluator: E) -> Vec<T>
951 where E: CircuitEvaluator<'a, T, R>
952 {
953 assert_eq!(self.input_count, inputs.len());
954 assert!(evaluator.supports_gal() || !self.has_galois_gates());
955 let mut current = Vec::new();
956 for gate in &self.gates {
957 match gate {
958 PlaintextCircuitGate::Mul(lhs, rhs) => {
959 let lhs = lhs.evaluate_generic(inputs, ¤t, &mut evaluator);
960 let rhs = rhs.evaluate_generic(inputs, ¤t, &mut evaluator);
961 current.push(evaluator.mul(lhs, rhs));
962 },
963 PlaintextCircuitGate::Gal(gs, t) => {
964 let val = t.evaluate_generic(inputs, ¤t, &mut evaluator);
965 current.extend(evaluator.gal(val, gs));
966 },
967 PlaintextCircuitGate::Square(t) => {
968 let val = t.evaluate_generic(inputs, ¤t, &mut evaluator);
969 current.push(evaluator.square(val));
970 }
971 }
972 }
973 return self.output_transforms.iter().map(|t| t.evaluate_generic(inputs, ¤t, &mut evaluator)).collect()
974 }
975
976 pub fn evaluate_no_galois<S, H>(&self, inputs: &[S::Element], hom: H) -> Vec<S::Element>
983 where S: ?Sized + RingBase,
984 H: Homomorphism<R, S>
985 {
986 assert!(!self.has_galois_gates());
987 return self.evaluate_generic(inputs, HomEvaluator::new(hom));
988 }
989
990 pub fn evaluate<S, H>(&self, inputs: &[S::Element], hom: H) -> Vec<S::Element>
999 where S: ?Sized + RingBase + CyclotomicRing,
1000 H: Homomorphism<R, S>
1001 {
1002 return self.evaluate_generic(inputs, HomEvaluatorGal::new(hom));
1003 }
1004
1005 pub fn has_galois_gates(&self) -> bool {
1006 self.gates.iter().any(|gate| match gate {
1007 PlaintextCircuitGate::Gal(_, _) => true,
1008 PlaintextCircuitGate::Mul(_, _) => false,
1009 PlaintextCircuitGate::Square(_) => false
1010 })
1011 }
1012
1013 pub fn has_multiplication_gates(&self) -> bool {
1014 self.gates.iter().any(|gate| match gate {
1015 PlaintextCircuitGate::Gal(_, _) => false,
1016 PlaintextCircuitGate::Mul(_, _) => true,
1017 PlaintextCircuitGate::Square(_) => true
1018 })
1019 }
1020
1021 pub fn multiplication_gate_count(&self) -> usize {
1022 self.gates.iter().filter(|gate| match gate {
1023 PlaintextCircuitGate::Gal(_, _) => false,
1024 PlaintextCircuitGate::Mul(_, _) => true,
1025 PlaintextCircuitGate::Square(_) => true
1026 }).count()
1027 }
1028
1029 pub fn required_galois_keys(&self, galois_group: &CyclotomicGaloisGroup) -> Vec<CyclotomicGaloisGroupEl> {
1038 let mut result = self.gates.iter().flat_map(|gate| match gate {
1039 PlaintextCircuitGate::Gal(gs, _) => gs.iter().copied(),
1040 PlaintextCircuitGate::Mul(_, _) => [].iter().copied(),
1041 PlaintextCircuitGate::Square(_) => [].iter().copied()
1042 }).collect::<Vec<_>>();
1043 result.sort_unstable_by_key(|g| galois_group.representative(*g));
1044 result.dedup_by_key(|g| galois_group.representative(*g));
1045 return result;
1046 }
1047
1048 pub fn is_linear(&self) -> bool {
1054 !self.has_multiplication_gates()
1055 }
1056
1057 pub fn mul_depth(&self, i: usize) -> usize {
1063 let mut multiplicative_depths = Vec::new();
1064 multiplicative_depths.resize(self.input_count(), 0);
1065 let mult_depth_of_linear_combination = |lin_combination: &LinearCombination<_>, multiplicative_depths: &[usize]| {
1066 assert_eq!(lin_combination.factors.len(), multiplicative_depths.len());
1067 lin_combination.factors.iter().zip(multiplicative_depths.iter()).filter(|(factor, _)| !factor.is_zero()).map(|(_, d)| *d).max().unwrap_or(0)
1068 };
1069 for gate in &self.gates {
1070 let (new_depth, count) = match gate {
1071 PlaintextCircuitGate::Mul(lhs, rhs) => (max(mult_depth_of_linear_combination(lhs, &multiplicative_depths), mult_depth_of_linear_combination(rhs, &multiplicative_depths)) + 1, 1),
1072 PlaintextCircuitGate::Gal(gs, t) => (mult_depth_of_linear_combination(t, &multiplicative_depths), gs.len()),
1073 PlaintextCircuitGate::Square(t) => (mult_depth_of_linear_combination(t, &multiplicative_depths) + 1, 1)
1074 };
1075 multiplicative_depths.extend((0..count).map(|_| new_depth));
1076 }
1077 return mult_depth_of_linear_combination(&self.output_transforms[i], &multiplicative_depths);
1078 }
1079
1080 pub fn max_mul_depth(&self) -> usize {
1086 (0..self.output_count()).map(|i| self.mul_depth(i)).max().unwrap_or(0)
1087 }
1088}
1089
1090#[cfg(test)]
1091use feanor_math::assert_el_eq;
1092#[cfg(test)]
1093use feanor_math::primitive_int::*;
1094#[cfg(test)]
1095use feanor_math::rings::zn::zn_64::Zn;
1096#[cfg(test)]
1097use feanor_math::rings::extension::FreeAlgebraStore;
1098#[cfg(test)]
1099use crate::number_ring::quotient::NumberRingQuotientBase;
1100#[cfg(test)]
1101use crate::number_ring::pow2_cyclotomic::Pow2CyclotomicNumberRing;
1102#[cfg(test)]
1103use serde::de::DeserializeSeed;
1104#[cfg(test)]
1105use serde::Serialize;
1106#[cfg(test)]
1107use serialization::DeserializeSeedPlaintextCircuit;
1108#[cfg(test)]
1109use serialization::SerializablePlaintextCircuit;
1110
#[test]
fn test_circuit_tensor_compose() {
    let ring = StaticRing::<i64>::RING;

    // composing `(x, y) -> x * y` with `x -> (x, x)` should give a single multiplication gate
    let x = PlaintextCircuit::linear_transform_ring(&[1], ring);
    let x_sqr = PlaintextCircuit::mul(ring).compose(x.output_twice(ring), ring);
    let expected = PlaintextCircuit {
        input_count: 1,
        gates: vec![PlaintextCircuitGate::Mul(
            LinearCombination {
                constant: Coefficient::Zero,
                factors: vec![Coefficient::One]
            },
            LinearCombination {
                constant: Coefficient::Zero,
                factors: vec![Coefficient::One]
            }
        )],
        output_transforms: vec![LinearCombination {
            constant: Coefficient::Zero,
            factors: vec![Coefficient::Zero, Coefficient::One]
        }]
    };
    assert!(expected == x_sqr);

    // w(x, y) = x * y * (2x + 3y)
    let x = PlaintextCircuit::identity(1, ring);
    let y = PlaintextCircuit::identity(1, ring);
    let x_y_x_y = x.clone(&ring).tensor(y, ring).output_twice(ring);
    let x_y_z = x.clone(ring)
        .tensor(x.clone(ring), ring)
        .tensor(PlaintextCircuit::linear_transform_ring(&[2, 3], ring), ring)
        .compose(x_y_x_y, ring);
    let xy_z = PlaintextCircuit::mul(ring).tensor(x, ring).compose(x_y_z, ring);
    let w = PlaintextCircuit::mul(ring).compose(xy_z, ring);
    for x_val in -5..5 {
        for y_val in -5..5 {
            let expected = x_val * y_val * (2 * x_val + 3 * y_val);
            let actual = w.evaluate_no_galois(&[x_val, y_val], ring.identity()).into_iter().next().unwrap();
            assert_eq!(expected, actual);
        }
    }

    // (w(x, y) + 1)^2
    let w_plus_one = PlaintextCircuit::add(ring).compose(w.tensor(PlaintextCircuit::constant(1, ring), ring), ring);
    let w_1_sqr = PlaintextCircuit::mul(ring).compose(w_plus_one.output_twice(ring), ring);
    for x_val in -5..5 {
        for y_val in -5..5 {
            let expected = StaticRing::<i64>::RING.pow(x_val * y_val * (2 * x_val + 3 * y_val) + 1, 2);
            let actual = w_1_sqr.evaluate_no_galois(&[x_val, y_val], ring.identity()).into_iter().next().unwrap();
            assert_eq!(expected, actual);
        }
    }
}
1155
#[test]
fn test_circuit_tensor_compose_with_galois() {
    let ring = NumberRingQuotientBase::new(Pow2CyclotomicNumberRing::new(16), Zn::new(17));

    // partial_trace(x, y) = x * y + conj(x * y), where conj is the Galois element -1
    let x = PlaintextCircuit::identity(1, &ring);
    let y = PlaintextCircuit::identity(1, &ring);
    let xy = PlaintextCircuit::mul(&ring).compose(x.tensor(y, &ring), &ring);
    let conjugation = PlaintextCircuit::gal(ring.galois_group().from_representative(-1), &ring);
    let conj_xy = conjugation.compose(xy.clone(&ring), &ring);
    let duplicate_inputs = PlaintextCircuit::identity(2, &ring).output_twice(&ring);
    let partial_trace_xy = PlaintextCircuit::add(&ring)
        .compose(xy.tensor(conj_xy, &ring), &ring)
        .compose(duplicate_inputs, &ring);

    for x_e in 0..8 {
        for y_e in 0..8 {
            let x = ring.pow(ring.canonical_gen(), x_e);
            let y = ring.pow(ring.canonical_gen(), y_e);
            // conjugation maps the generator X to X^-1 = X^15
            let conj_x = ring.pow(ring.canonical_gen(), 16 - x_e);
            let conj_y = ring.pow(ring.canonical_gen(), 16 - y_e);
            let expected = ring.add(ring.mul_ref(&x, &y), ring.mul(conj_x, conj_y));
            let actual = partial_trace_xy.evaluate(&[x, y], ring.identity()).into_iter().next().unwrap();
            assert_el_eq!(&ring, expected, actual);
        }
    }
}
1180
#[test]
fn test_giant_step_circuit() {
    let ring = StaticRing::<i64>::RING;

    // x -> (x, x, x) -> (x^2, x^2, x^2, x^2, x) -> (x^2, x^4, x^3)
    let combine_powers = PlaintextCircuit::identity(1, ring)
        .tensor(PlaintextCircuit::mul(ring), ring)
        .tensor(PlaintextCircuit::mul(ring), ring);
    let squares_and_x = PlaintextCircuit::mul(ring)
        .output_times(4, ring)
        .tensor(PlaintextCircuit::identity(1, ring), ring);
    let triplicate_input = PlaintextCircuit::identity(1, ring).output_times(3, ring);
    let powers = combine_powers.compose(squares_and_x, ring).compose(triplicate_input, ring);
    assert_eq!(vec![4, 16, 8], powers.evaluate_no_galois(&[2], ring.identity()));

    // (x, x) -> (1, x, x^2, x^4, x^3)
    let permuted_baby_step_dupl_input = PlaintextCircuit::constant(1, ring)
        .tensor(PlaintextCircuit::identity(1, ring), ring)
        .tensor(powers, ring);
    assert_eq!(vec![1, 2, 4, 16, 8], permuted_baby_step_dupl_input.evaluate_no_galois(&[2, 2], ring.identity()));

    let copy_input = PlaintextCircuit::identity(1, ring).output_twice(ring);
    assert_eq!(vec![2, 2], copy_input.evaluate_no_galois(&[2], ring.identity()));

    let permuted_baby_steps = permuted_baby_step_dupl_input.compose(copy_input, ring);
    assert_eq!(vec![1, 2, 4, 16, 8], permuted_baby_steps.evaluate_no_galois(&[2], ring.identity()));

    // reorder to (1, x, x^2, x^3, x^4)
    let reorder = PlaintextCircuit::select(5, &[0, 1, 2, 4, 3], ring);
    let baby_steps = reorder.compose(permuted_baby_steps, ring);
    assert_eq!(1, baby_steps.input_count());
    assert_eq!(5, baby_steps.output_count());
    assert_eq!(vec![1, 2, 4, 8, 16], baby_steps.evaluate_no_galois(&[2], ring.identity()));

    // keep the first four baby steps, then output the giant steps 1 and x^4
    let giant_steps_before_baby_steps = PlaintextCircuit::constant(1, ring)
        .tensor(PlaintextCircuit::identity(1, ring), ring);
    let baby_and_giant_steps = PlaintextCircuit::identity(4, ring)
        .tensor(giant_steps_before_baby_steps, ring)
        .compose(baby_steps, ring);
    assert_eq!(vec![1, 2, 4, 8, 1, 16], baby_and_giant_steps.evaluate_no_galois(&[2], ring.identity()));
}
1211
#[test]
fn test_serialization() {
    let ring = StaticRing::<i64>::RING;

    // circuit computing (2 + x * (-x))^2 = (2 - x^2)^2
    let x = PlaintextCircuit::linear_transform_ring(&[1], ring);
    let neg_x = PlaintextCircuit::linear_transform_ring(&[-1], ring);
    let x_times_neg_x = PlaintextCircuit::mul(ring)
        .compose(x.clone(ring).tensor(neg_x, ring), ring)
        .compose(x.output_twice(ring), ring);
    let two_minus_x_sqr = PlaintextCircuit::add(ring)
        .compose(x_times_neg_x.tensor(PlaintextCircuit::constant(2, ring), ring), ring);
    let circuit = PlaintextCircuit::square(ring).compose(two_minus_x_sqr, ring);

    for x in -100..100 {
        let expected = (2 - x * x) * (2 - x * x);
        assert_eq!(expected, circuit.evaluate_no_galois(&[x], ring.identity()).into_iter().next().unwrap());
    }

    // the serialization round-trip must work both in human-readable and in binary mode
    for &human_readable in &[true, false] {
        let serializer = serde_assert::Serializer::builder().is_human_readable(human_readable).build();
        let tokens = SerializablePlaintextCircuit::new_no_galois(&ring, &circuit).serialize(&serializer).unwrap();
        let mut deserializer = serde_assert::Deserializer::builder(tokens).is_human_readable(human_readable).build();
        let deserialized_circuit = DeserializeSeedPlaintextCircuit::new_no_galois(&ring).deserialize(&mut deserializer).unwrap();
        assert!(deserialized_circuit == circuit);
    }
}