he_ring/circuit/evaluator.rs

1use std::marker::PhantomData;
2
3use feanor_math::homomorphism::Homomorphism;
4use feanor_math::ring::*;
5
6use crate::cyclotomic::{CyclotomicGaloisGroupEl, CyclotomicRing, CyclotomicRingStore};
7
8use super::Coefficient;
9
10///
11/// Trait for objects that can evaluate arithmetic circuits.
12/// 
13/// This clearly has some similarity with rings, since we can always
14/// evaluate an arithmetic circuit over a ring. However, it is more general,
15/// such as to allow for the evaluation of circuits on more general inputs,
16/// in particular of course on encrypted data.
17/// 
18/// Hence, if we consider circuits to be "programs", this would be the
19/// equivalent of a "virtual machine" running those programs.
20/// 
21/// If you want to evaluate a circuit on ring elements, use [`HomEvaluator`]
22/// or [`HomEvaluatorGal`]. Otherwise, you can build a custom evaluator
23/// using [`DefaultCircuitEvaluator`], for example as follows:
24/// ```
25/// # use he_ring::circuit::*;
26/// # use he_ring::circuit::evaluator::*;
27/// # use feanor_math::ring::*;
28/// # use feanor_math::primitive_int::*;
29/// let ring = StaticRing::<i64>::RING;
30/// let square_xy = PlaintextCircuit::square(ring).compose(PlaintextCircuit::mul(ring), ring);
31/// // assume that, for some reason, we want to wrap the integers in Box; instead of
32/// // implementing our own ring which has boxed integers as elements, we use DefaultCircuitEvaluator
33/// assert_eq!(36, *square_xy.evaluate_generic(
34///     &[Box::new(2), Box::new(3)],
35///     DefaultCircuitEvaluator::new(
36///         /* multiplication = */ |lhs: Box<i64>, rhs| Box::new(*lhs * *rhs),
37///         /* create constant = */ |x| Box::new(x.to_ring_el(ring)),
38///         /* add product = */ |base, lhs, rhs| Box::new(*base + lhs.to_ring_el(ring) * **rhs)
39///     )
40///     // this is optional, but may improve performance if squaring is cheaper than general multiplication
41///         .with_square(|x| Box::new(ring.pow(*x, 2)))
42/// ).into_iter().next().unwrap());
43/// ```
44/// 
45pub trait CircuitEvaluator<'a, T, R: ?Sized + RingBase> {
46
47    fn mul(&mut self, lhs: T, rhs: T) -> T;
48    fn square(&mut self, val: T) -> T;
49    fn constant(&mut self, constant: &'a Coefficient<R>) -> T;
50    fn add_inner_prod(&mut self, dst: T, lhs: &'a [Coefficient<R>], rhs: &[T]) -> T;
51    fn gal(&mut self, val: T, gs: &'a [CyclotomicGaloisGroupEl]) -> Vec<T>;
52    fn supports_gal(&self) -> bool;
53}
54
55pub struct HomEvaluator<R, S, H>
56    where R: ?Sized + RingBase,
57        S: ?Sized + RingBase,
58        H: Homomorphism<R, S>
59{
60    from: PhantomData<Box<R>>,
61    to: PhantomData<Box<S>>,
62    hom: H
63}
64
65impl<R, S, H> HomEvaluator<R, S, H>
66    where R: ?Sized + RingBase,
67        S: ?Sized + RingBase,
68        H: Homomorphism<R, S>
69{
70    pub fn new(hom: H) -> Self {
71        Self {
72            from: PhantomData,
73            to: PhantomData,
74            hom: hom
75        }
76    }
77}
78
79impl<'a, R, S, H> CircuitEvaluator<'a, S::Element, R> for HomEvaluator<R, S, H>
80    where R: ?Sized + RingBase,
81        S: ?Sized + RingBase,
82        H: Homomorphism<R, S>
83{
84    fn add_inner_prod(&mut self, dst: S::Element, lhs: &[Coefficient<R>], rhs: &[S::Element]) -> S::Element {
85        self.hom.codomain().sum(
86            [dst].into_iter().chain(lhs.iter().zip(rhs.iter()).filter_map(|(l, r)| match l {
87                Coefficient::Zero => None,
88                Coefficient::One => Some(self.hom.codomain().clone_el(r)),
89                Coefficient::NegOne => Some(self.hom.codomain().negate(self.hom.codomain().clone_el(r))),
90                Coefficient::Integer(x) => Some(self.hom.codomain().int_hom().mul_ref_fst_map(r, *x)),
91                Coefficient::Other(x) => Some(self.hom.mul_ref_map(r, x))
92            }))
93        )
94    }
95
96    fn constant(&mut self, constant: &Coefficient<R>) -> S::Element {
97        self.hom.map(constant.clone(self.hom.domain()).to_ring_el(self.hom.domain()))
98    }
99
100    fn gal(&mut self, _val: S::Element, _gs: &[CyclotomicGaloisGroupEl]) -> Vec<S::Element> {
101        panic!()
102    }
103
104    fn supports_gal(&self) -> bool {
105        false
106    }
107
108    fn mul(&mut self, lhs: S::Element, rhs: S::Element) -> S::Element {
109        self.hom.codomain().mul(lhs, rhs)
110    }
111
112    fn square(&mut self, val: S::Element) -> S::Element {
113        self.hom.codomain().pow(val, 2)
114    }
115}
116
117pub struct HomEvaluatorGal<R, S, H>
118    where R: ?Sized + RingBase,
119        S: ?Sized + RingBase + CyclotomicRing,
120        H: Homomorphism<R, S>
121{
122    from: PhantomData<Box<R>>,
123    to: PhantomData<Box<S>>,
124    hom: H
125}
126
127impl<R, S, H> HomEvaluatorGal<R, S, H>
128    where R: ?Sized + RingBase,
129        S: ?Sized + RingBase + CyclotomicRing,
130        H: Homomorphism<R, S>
131{
132    pub fn new(hom: H) -> Self {
133        Self {
134            from: PhantomData,
135            to: PhantomData,
136            hom: hom
137        }
138    }
139}
140
141impl<'a, R, S, H> CircuitEvaluator<'a, S::Element, R> for HomEvaluatorGal<R, S, H>
142    where R: ?Sized + RingBase,
143        S: ?Sized + RingBase + CyclotomicRing,
144        H: Homomorphism<R, S>
145{
146    fn add_inner_prod(&mut self, dst: S::Element, lhs: &[Coefficient<R>], rhs: &[S::Element]) -> S::Element {
147        self.hom.codomain().sum(
148            [dst].into_iter().chain(lhs.iter().zip(rhs.iter()).filter_map(|(l, r)| match l {
149                Coefficient::Zero => None,
150                Coefficient::One => Some(self.hom.codomain().clone_el(r)),
151                Coefficient::NegOne => Some(self.hom.codomain().negate(self.hom.codomain().clone_el(r))),
152                Coefficient::Integer(x) => Some(self.hom.codomain().int_hom().mul_ref_fst_map(r, *x)),
153                Coefficient::Other(x) => Some(self.hom.mul_ref_map(r, x))
154            }))
155        )
156    }
157
158    fn constant(&mut self, constant: &Coefficient<R>) -> S::Element {
159        self.hom.map(constant.clone(self.hom.domain()).to_ring_el(self.hom.domain()))
160    }
161
162    fn gal(&mut self, val: S::Element, gs: &[CyclotomicGaloisGroupEl]) -> Vec<S::Element> {
163        self.hom.codomain().apply_galois_action_many(&val, gs)
164    }
165
166    fn supports_gal(&self) -> bool {
167        true
168    }
169
170    fn mul(&mut self, lhs: S::Element, rhs: S::Element) -> S::Element {
171        self.hom.codomain().mul(lhs, rhs)
172    }
173
174    fn square(&mut self, val: S::Element) -> S::Element {
175        self.hom.codomain().pow(val, 2)
176    }
177}
178
/// A type-level stand-in for `Option`, implemented by [`Present`] (which always
/// holds a value) and by [`Absent`] (which never does).
///
/// [`DefaultCircuitEvaluator`] stores its optional closures in such wrappers, so
/// whether a closure was provided is visible in the evaluator's type.
pub trait Possibly {
    /// The type of the value that may be contained.
    type T;

    /// Returns a shared reference to the contained value, or `None` if there is none.
    fn get(&self) -> Option<&Self::T>;

    /// Returns a mutable reference to the contained value, or `None` if there is none.
    fn get_mut(&mut self) -> Option<&mut Self::T>;
}
184
/// Implementation of [`Possibly`] that always holds a value of type `T`.
pub struct Present<T> {
    // the wrapped value
    t: T
}
188
189impl<T> Possibly for Present<T> {
190    type T = T;
191    fn get_mut(&mut self) -> Option<&mut T> {
192        Some(&mut self.t)
193    }
194    fn get(&self) -> Option<&Self::T> {
195        Some(&self.t)
196    }
197}
198
/// Implementation of [`Possibly`] that never holds a value.
///
/// `T` only records the type the value would have had.
pub struct Absent<T> {
    // zero-sized marker for the missing value's type
    t: PhantomData<T>
}
202
203impl<T> Possibly for Absent<T> {
204    type T = T;
205    fn get_mut(&mut self) -> Option<&mut T> {
206        None
207    }
208    fn get(&self) -> Option<&Self::T> {
209        None
210    }
211}
212
213pub struct DefaultCircuitEvaluator<'a, T, R: ?Sized + RingBase, FnMul, FnConst, FnAddProd, FnSquare, FnGal, FnInnerProd>
214    where FnMul: FnMut(T, T) -> T,
215        FnConst: FnMut(&'a Coefficient<R>) -> T,
216        FnAddProd: Possibly, FnAddProd::T: FnMut(T, &'a Coefficient<R>, &T) -> T,
217        FnSquare: Possibly, FnSquare::T: FnMut(T) -> T,
218        FnGal: Possibly, FnGal::T: FnMut(T, &'a [CyclotomicGaloisGroupEl]) -> Vec<T>,
219        FnInnerProd: Possibly, FnInnerProd::T: FnMut(T, &'a [Coefficient<R>], &[T]) -> T,
220        R: 'a
221{
222    element: PhantomData<T>,
223    ring: PhantomData<&'a R>,
224    mul: FnMul,
225    constant: FnConst,
226    add_prod: FnAddProd,
227    square: FnSquare,
228    gal: FnGal,
229    inner_product: FnInnerProd
230}
231
232impl<'a, T, R: ?Sized + RingBase, FnMul, FnConst, FnAddProd> DefaultCircuitEvaluator<'a, T, R, FnMul, FnConst, Present<FnAddProd>, Absent<fn(T) -> T>, Absent<fn(T, &[CyclotomicGaloisGroupEl]) -> Vec<T>>, Absent<fn(T, &[Coefficient<R>], &[T]) -> T>>
233    where FnMul: FnMut(T, T) -> T,
234        FnConst: FnMut(&'a Coefficient<R>) -> T,
235        FnAddProd: FnMut(T, &'a Coefficient<R>, &T) -> T,
236        R: 'a
237{
238    pub fn new(mul: FnMul, constant: FnConst, add_prod: FnAddProd) -> Self {
239        Self {
240            element: PhantomData,
241            add_prod: Present { t: add_prod },
242            constant: constant,
243            mul: mul,
244            gal: Absent { t: PhantomData },
245            inner_product: Absent { t: PhantomData },
246            square: Absent { t: PhantomData },
247            ring: PhantomData
248        }
249    }
250}
251
252impl<'a, T, R: ?Sized + RingBase, FnMul, FnConst, FnAddProd, FnSquare, FnGal, FnInnerProd> CircuitEvaluator<'a, T, R> for DefaultCircuitEvaluator<'a, T, R, FnMul, FnConst, FnAddProd, FnSquare, FnGal, FnInnerProd>
253    where FnMul: FnMut(T, T) -> T,
254        FnConst: FnMut(&'a Coefficient<R>) -> T,
255        FnAddProd: Possibly, FnAddProd::T: FnMut(T, &'a Coefficient<R>, &T) -> T,
256        FnSquare: Possibly, FnSquare::T: FnMut(T) -> T,
257        FnGal: Possibly, FnGal::T: FnMut(T, &'a [CyclotomicGaloisGroupEl]) -> Vec<T>,
258        FnInnerProd: Possibly, FnInnerProd::T: FnMut(T, &'a [Coefficient<R>], &[T]) -> T,
259        R: 'a,
260        T: 'a
261{
262    fn add_inner_prod(&mut self, dst: T, lhs: &'a [Coefficient<R>], rhs: &[T]) -> T {
263        assert_eq!(lhs.len(), rhs.len());
264        if let Some(inner_prod) = self.inner_product.get_mut() {
265            return inner_prod(dst, lhs, rhs);
266        } else {
267            let mut current = dst;
268            for i in 0..lhs.len() {
269                current = self.add_prod.get_mut().unwrap()(current, &lhs[i], &rhs[i]);
270            }
271            return current;
272        }
273    }
274
275    fn mul(&mut self, lhs: T, rhs: T) -> T {
276        (self.mul)(lhs, rhs)
277    }
278
279    fn constant(&mut self, constant: &'a Coefficient<R>) -> T {
280        (self.constant)(constant)
281    }
282
283    fn gal(&mut self, val: T, gs: &'a [CyclotomicGaloisGroupEl]) -> Vec<T> {
284        if let Some(gal) = self.gal.get_mut() {
285            gal(val, gs)
286        } else {
287            panic!("Circuit contains Galois gates, but no galois function has been specified during evaluator creation")
288        }
289    }
290
291    fn supports_gal(&self) -> bool {
292        self.gal.get().is_some()
293    }
294
295    fn square(&mut self, val: T) -> T {
296        if let Some(square) = self.square.get_mut() {
297            square(val)
298        } else {
299            let zero = (self.constant)(&Coefficient::Zero);
300            let val_copy = self.add_inner_prod(zero, &[Coefficient::One], std::slice::from_ref(&val));
301            (self.mul)(val, val_copy)
302        }
303    }
304}
305
306impl<'a, T, R: ?Sized + RingBase, FnMul, FnConst, FnAddProd, FnGal, FnInnerProd> DefaultCircuitEvaluator<'a, T, R, FnMul, FnConst, FnAddProd, Absent<fn(T) -> T>, FnGal, FnInnerProd>
307    where FnMul: FnMut(T, T) -> T,
308        FnConst: FnMut(&'a Coefficient<R>) -> T,
309        FnAddProd: Possibly, FnAddProd::T: FnMut(T, &'a Coefficient<R>, &T) -> T,
310        FnGal: Possibly, FnGal::T: FnMut(T, &'a [CyclotomicGaloisGroupEl]) -> Vec<T>,
311        FnInnerProd: Possibly, FnInnerProd::T: FnMut(T, &'a [Coefficient<R>], &[T]) -> T,
312        R: 'a,
313        T: 'a
314{
315    pub fn with_square<FnSquare>(self, square: FnSquare) -> DefaultCircuitEvaluator<'a, T, R, FnMul, FnConst, FnAddProd, Present<FnSquare>, FnGal, FnInnerProd>
316        where FnSquare: FnMut(T) -> T
317    {
318        DefaultCircuitEvaluator {
319            add_prod: self.add_prod,
320            constant: self.constant,
321            element: self.element,
322            gal: self.gal,
323            inner_product: self.inner_product,
324            mul: self.mul,
325            ring: self.ring,
326            square: Present { t: square }
327        }
328    }
329}
330
331
332impl<'a, T, R: ?Sized + RingBase, FnMul, FnConst, FnAddProd, FnSquare, FnInnerProd> DefaultCircuitEvaluator<'a, T, R, FnMul, FnConst, FnAddProd, FnSquare, Absent<fn(T, &[CyclotomicGaloisGroupEl]) -> Vec<T>>, FnInnerProd>
333    where FnMul: FnMut(T, T) -> T,
334        FnConst: FnMut(&'a Coefficient<R>) -> T,
335        FnAddProd: Possibly, FnAddProd::T: FnMut(T, &'a Coefficient<R>, &T) -> T,
336        FnSquare: Possibly, FnSquare::T: FnMut(T) -> T,
337        FnInnerProd: Possibly, FnInnerProd::T: FnMut(T, &'a [Coefficient<R>], &[T]) -> T,
338        R: 'a,
339        T: 'a
340{
341    pub fn with_gal<FnGal>(self, gal: FnGal) -> DefaultCircuitEvaluator<'a, T, R, FnMul, FnConst, FnAddProd, FnSquare, Present<FnGal>, FnInnerProd>
342        where FnGal: FnMut(T, &'a [CyclotomicGaloisGroupEl]) -> Vec<T>
343    {
344        DefaultCircuitEvaluator {
345            add_prod: self.add_prod,
346            constant: self.constant,
347            element: self.element,
348            gal: Present { t: gal },
349            inner_product: self.inner_product,
350            mul: self.mul,
351            ring: self.ring,
352            square: self.square
353        }
354    }
355}
356
357impl<'a, T, R: ?Sized + RingBase, FnMul, FnConst, FnAddProd, FnSquare, FnGal> DefaultCircuitEvaluator<'a, T, R, FnMul, FnConst, FnAddProd, FnSquare, FnGal, Absent<fn(T, &[Coefficient<R>], &[T]) -> T>>
358    where FnMul: FnMut(T, T) -> T,
359        FnConst: FnMut(&'a Coefficient<R>) -> T,
360        FnAddProd: Possibly, FnAddProd::T: FnMut(T, &'a Coefficient<R>, &T) -> T,
361        FnSquare: Possibly, FnSquare::T: FnMut(T) -> T,
362        FnGal: Possibly, FnGal::T: FnMut(T, &'a [CyclotomicGaloisGroupEl]) -> Vec<T>,
363        R: 'a,
364        T: 'a
365{
366    pub fn with_inner_product<FnInnerProd>(self, inner_product: FnInnerProd) -> DefaultCircuitEvaluator<'a, T, R, FnMul, FnConst, FnAddProd, FnSquare, FnGal, Present<FnInnerProd>>
367        where FnInnerProd: FnMut(T, &'a [Coefficient<R>], &[T]) -> T
368    {
369        DefaultCircuitEvaluator {
370            add_prod: self.add_prod,
371            constant: self.constant,
372            element: self.element,
373            gal: self.gal,
374            inner_product: Present { t: inner_product },
375            mul: self.mul,
376            ring: self.ring,
377            square: self.square
378        }
379    }
380}