he_ring/circuit/
serialization.rs

1use std::marker::PhantomData;
2
3use feanor_math::ring::*;
4use feanor_math::seq::{VectorFn, VectorView};
5use feanor_math::serialization::{DeserializeSeedSeq, DeserializeWithRing, SerializableElementRing, SerializableSeq, SerializeWithRing};
6use serde::de::DeserializeSeed;
7use serde::Serialize;
8
9use crate::cyclotomic::{CyclotomicGaloisGroup, CyclotomicGaloisGroupEl, DeserializeSeedCyclotomicGaloisGroupEl, SerializableCyclotomicGaloisGroupEl};
10use crate::{impl_deserialize_seed_for_dependent_enum, impl_deserialize_seed_for_dependent_struct};
11
12use super::{Coefficient, LinearCombination, PlaintextCircuit, PlaintextCircuitGate};
13
14#[derive(Serialize)]
15#[serde(rename = "CoefficientData", bound = "")]
16enum SerializableCoefficient<'a, R: RingStore>
17    where R::Type: SerializableElementRing
18{
19    Integer(i32),
20    Other(SerializeWithRing<'a, R>)
21}
22
23#[derive(Serialize)]
24#[serde(rename = "LinearCombinationData", bound = "")]
25struct SerializableLinearCombination<C: Serialize, S: Serialize> {
26    constant: C,
27    factors: S
28}
29
30#[derive(Serialize)]
31#[serde(rename = "MulGateData", bound = "")]
32struct SerializablePlaintextCircuitMulGate<L: Serialize> {
33    lhs: L,
34    rhs: L
35}
36
37#[derive(Serialize)]
38#[serde(rename = "SquareGateData", bound = "")]
39struct SerializablePlaintextCircuitSquareGate<L: Serialize> {
40    val: L
41}
42
43#[derive(Serialize)]
44#[serde(rename = "GalGateData", bound = "")]
45struct SerializablePlaintextCircuitGalGate<L: Serialize, G: Serialize> {
46    automorphisms: G,
47    input: L
48}
49
50#[derive(Serialize)]
51#[serde(rename = "GateData", bound = "")]
52enum SerializablePlaintextCircuitGate<L: Serialize, G: Serialize> {
53    Mul(SerializablePlaintextCircuitMulGate<L>),
54    Gal(SerializablePlaintextCircuitGalGate<L, G>),
55    Square(SerializablePlaintextCircuitSquareGate<L>)
56}
57
58#[derive(Serialize)]
59#[serde(rename = "PlaintextCircuitData", bound = "")]
60struct SerializablePlaintextCircuitData<G: Serialize, O: Serialize> {
61    input_count: usize,
62    gates: G,
63    output_transforms: O
64}
65
66pub struct SerializablePlaintextCircuit<'a, R: RingStore> {
67    circuit: &'a PlaintextCircuit<R::Type>,
68    ring: R,
69    galois_group: Option<&'a CyclotomicGaloisGroup>
70}
71
72impl<'a, R: RingStore + Copy> SerializablePlaintextCircuit<'a, R>
73    where R::Type: SerializableElementRing
74{
75    pub fn new(ring: R, galois_group: &'a CyclotomicGaloisGroup, circuit: &'a PlaintextCircuit<R::Type>) -> Self {
76        Self { circuit: circuit, ring: ring, galois_group: Some(galois_group) }
77    }
78
79    pub fn new_no_galois(ring: R, circuit: &'a PlaintextCircuit<R::Type>) -> Self {
80        assert!(!circuit.has_galois_gates());
81        Self { circuit: circuit, ring: ring, galois_group: None }
82    }
83}
84
85impl<'a, R: RingStore + Copy> Serialize for SerializablePlaintextCircuit<'a, R> 
86    where R::Type: SerializableElementRing
87{
88    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
89        where S: serde::Serializer
90    {
91        fn serialize_constant<'a, R: RingStore>(c: &'a Coefficient<R::Type>, ring: R) -> SerializableCoefficient<'a, R>
92            where R::Type: SerializableElementRing
93        {
94            match c {
95                Coefficient::Integer(x) => SerializableCoefficient::Integer(*x),
96                Coefficient::One => SerializableCoefficient::Integer(1),
97                Coefficient::NegOne => SerializableCoefficient::Integer(-1),
98                Coefficient::Zero => SerializableCoefficient::Integer(0),
99                Coefficient::Other(x) => SerializableCoefficient::Other(SerializeWithRing::new(x, ring))
100            }
101        }
102        fn serialize_lin_transform<'a, R: Copy + RingStore>(t: &'a LinearCombination<R::Type>, ring: R) -> SerializableLinearCombination<SerializableCoefficient<'a, R>, impl use<'a, R> + Serialize>
103            where R::Type: SerializableElementRing,
104                R: 'a
105        {
106            SerializableLinearCombination {
107                constant: serialize_constant(&t.constant, ring),
108                factors: SerializableSeq::new(t.factors.as_fn().map_fn(move |c| serialize_constant(c, ring)))
109            }
110        }
111        SerializablePlaintextCircuitData {
112            input_count: self.circuit.input_count,
113            gates: SerializableSeq::new(self.circuit.gates.as_fn().map_fn(|gate| match gate {
114                PlaintextCircuitGate::Mul(lhs, rhs) => SerializablePlaintextCircuitGate::Mul(SerializablePlaintextCircuitMulGate {
115                    lhs: serialize_lin_transform(lhs, self.ring), 
116                    rhs: serialize_lin_transform(rhs, self.ring)
117                }),
118                PlaintextCircuitGate::Gal(gs, val) => SerializablePlaintextCircuitGate::Gal(SerializablePlaintextCircuitGalGate {
119                    automorphisms: SerializableSeq::new(gs.as_fn().map_fn(|g| SerializableCyclotomicGaloisGroupEl::new(self.galois_group.unwrap(), *g))), 
120                    input: serialize_lin_transform(val, self.ring)
121                }),
122                PlaintextCircuitGate::Square(val) => SerializablePlaintextCircuitGate::Square(SerializablePlaintextCircuitSquareGate { 
123                    val: serialize_lin_transform(val, self.ring) 
124                })
125            })),
126            output_transforms: SerializableSeq::new(self.circuit.output_transforms.as_fn().map_fn(|t| serialize_lin_transform(t, self.ring)))
127        }.serialize(serializer)
128    }
129}
130
131#[derive(Clone)]
132struct DeserializeSeedCoefficient<R: RingStore>
133    where R::Type: SerializableElementRing
134{
135    deserializer: DeserializeWithRing<R>
136}
137
138impl_deserialize_seed_for_dependent_enum!{
139    <{'de, R}> pub enum CoefficientData<{'de, R}> using DeserializeSeedCoefficient<R> {
140        Integer(i32): |_: DeserializeSeedCoefficient<R>| PhantomData,
141        Other(El<R>): |d: DeserializeSeedCoefficient<R>| d.deserializer
142    } where R: RingStore, R::Type: SerializableElementRing
143}
144
145#[derive(Clone)]
146struct DeserializeSeedLinearCombination<R: RingStore + Copy>
147    where R::Type: SerializableElementRing
148{
149    deserializer: DeserializeWithRing<R>
150}
151
152impl_deserialize_seed_for_dependent_struct!{
153    <{'de, R}> pub struct LinearCombinationData<{'de, R}> using DeserializeSeedLinearCombination<R> {
154        constant: CoefficientData<'de, R>: |d: &DeserializeSeedLinearCombination<R>| DeserializeSeedCoefficient { deserializer: d.deserializer.clone() },
155        factors: Vec<CoefficientData<'de, R>>: |d: &DeserializeSeedLinearCombination<R>| DeserializeSeedSeq::new(
156            std::iter::repeat(DeserializeSeedCoefficient { deserializer: d.deserializer.clone() }),
157            Vec::new(),
158            |mut current, next| { current.push(next); current }
159        )
160    } where R: RingStore + Copy, R::Type: SerializableElementRing
161}
162
163#[derive(Clone)]
164struct DeserializeSeedPlaintextCircuitMulGate<R: RingStore + Copy>
165    where R::Type: SerializableElementRing
166{
167    deserializer: DeserializeWithRing<R>
168}
169
170impl_deserialize_seed_for_dependent_struct!{
171    <{'de, R}> pub struct MulGateData<{'de, R}> using DeserializeSeedPlaintextCircuitMulGate<R> {
172        lhs: LinearCombinationData<'de, R>: |d: &DeserializeSeedPlaintextCircuitMulGate<R>| DeserializeSeedLinearCombination { deserializer: d.deserializer.clone() },
173        rhs: LinearCombinationData<'de, R>: |d: &DeserializeSeedPlaintextCircuitMulGate<R>| DeserializeSeedLinearCombination { deserializer: d.deserializer.clone() }
174    } where R: RingStore + Copy, R::Type: SerializableElementRing
175}
176
177#[derive(Clone)]
178struct DeserializeSeedPlaintextCircuitSquareGate<R: RingStore + Copy>
179    where R::Type: SerializableElementRing
180{
181    deserializer: DeserializeWithRing<R>
182}
183
184impl_deserialize_seed_for_dependent_struct!{
185    <{'de, R}> pub struct SquareGateData<{'de, R}> using DeserializeSeedPlaintextCircuitSquareGate<R> {
186        val: LinearCombinationData<'de, R>: |d: &DeserializeSeedPlaintextCircuitSquareGate<R>| DeserializeSeedLinearCombination { deserializer: d.deserializer.clone() }
187    } where R: RingStore + Copy, R::Type: SerializableElementRing
188}
189
190#[derive(Clone)]
191struct DeserializeSeedPlaintextCircuitGalGate<'a, R: RingStore + Copy>
192    where R::Type: SerializableElementRing
193{
194    galois_group: Option<&'a CyclotomicGaloisGroup>,
195    deserializer: DeserializeWithRing<R>
196}
197
198fn derive_gal_gate_deserializer<'de, 'a, R>(d: &DeserializeSeedPlaintextCircuitGalGate<'a, R>) -> impl use<'a, 'de, R> + DeserializeSeed<'de, Value = Vec<CyclotomicGaloisGroupEl>>
199    where R: RingStore + Copy, R::Type: SerializableElementRing
200{
201    let galois_group: &'a CyclotomicGaloisGroup = d.galois_group.expect("cannot deserialize a circuit with galois gates if no galois group was specified");
202    DeserializeSeedSeq::new(
203        std::iter::repeat(DeserializeSeedCyclotomicGaloisGroupEl::new(galois_group)),
204        Vec::new(),
205        |mut current, next| { current.push(next); current }
206    )
207}
208
209impl_deserialize_seed_for_dependent_struct!{
210    <{'de, 'a, R}> pub struct GalGateData<{'de, R}> using DeserializeSeedPlaintextCircuitGalGate<'a, R> {
211        automorphisms: Vec<CyclotomicGaloisGroupEl>: derive_gal_gate_deserializer,
212        input: LinearCombinationData<'de, R>: |d: &DeserializeSeedPlaintextCircuitGalGate<R>| DeserializeSeedLinearCombination { deserializer: d.deserializer.clone() }
213    } where R: RingStore + Copy, R::Type: SerializableElementRing
214}
215
216#[derive(Clone)]
217struct DeserializeSeedPlaintextCircuitGate<'a, R: RingStore + Copy>
218    where R::Type: SerializableElementRing
219{
220    galois_group: Option<&'a CyclotomicGaloisGroup>,
221    deserializer: DeserializeWithRing<R>
222}
223
224impl_deserialize_seed_for_dependent_enum!{
225    <{'de, 'a, R}> pub enum GateData<{'de, R}> using DeserializeSeedPlaintextCircuitGate<'a, R> {
226        Mul(MulGateData<'de, R>): |d: DeserializeSeedPlaintextCircuitGate<'a, R>| DeserializeSeedPlaintextCircuitMulGate { deserializer: d.deserializer },
227        Gal(GalGateData<'de, R>): |d: DeserializeSeedPlaintextCircuitGate<'a, R>| DeserializeSeedPlaintextCircuitGalGate { deserializer: d.deserializer, galois_group: d.galois_group },
228        Square(SquareGateData<'de, R>): |d: DeserializeSeedPlaintextCircuitGate<'a, R>| DeserializeSeedPlaintextCircuitSquareGate { deserializer: d.deserializer }
229    } where R: RingStore + Copy, R::Type: SerializableElementRing
230}
231
232struct DeserializeSeedPlaintextCircuitData<'a, R: RingStore + Copy>
233    where R::Type: SerializableElementRing
234{
235    galois_group: Option<&'a CyclotomicGaloisGroup>,
236    deserializer: DeserializeWithRing<R>
237}
238
239impl_deserialize_seed_for_dependent_struct!{
240    <{'de, 'a, R}> pub struct PlaintextCircuitData<{'de, R}> using DeserializeSeedPlaintextCircuitData<'a, R> {
241        input_count: usize: |_| PhantomData,
242        gates: Vec<GateData<'de, R>>: |d: &DeserializeSeedPlaintextCircuitData<'a, R>| DeserializeSeedSeq::new(
243            std::iter::repeat(DeserializeSeedPlaintextCircuitGate { deserializer: d.deserializer.clone(), galois_group: d.galois_group }),
244            Vec::new(),
245            |mut current, next| { current.push(next); current }
246        ),
247        output_transforms: Vec<LinearCombinationData<'de, R>>: |d: &DeserializeSeedPlaintextCircuitData<'a, R>| DeserializeSeedSeq::new(
248            std::iter::repeat(DeserializeSeedLinearCombination { deserializer: d.deserializer.clone() }),
249            Vec::new(),
250            |mut current, next| { current.push(next); current }
251        )
252    } where R: RingStore + Copy, R::Type: SerializableElementRing
253}
254
255pub struct DeserializeSeedPlaintextCircuit<'a, R: RingStore + Copy>
256    where R::Type: SerializableElementRing
257{
258    ring: R,
259    galois_group: Option<&'a CyclotomicGaloisGroup>
260}
261
262impl<'a, R: RingStore + Copy> DeserializeSeedPlaintextCircuit<'a, R>
263    where R::Type: SerializableElementRing
264{
265    pub fn new(ring: R, galois_group: &'a CyclotomicGaloisGroup) -> Self {
266        Self { ring: ring, galois_group: Some(galois_group) }
267    }
268
269    pub fn new_no_galois(ring: R) -> Self {
270        Self { ring: ring, galois_group: None }
271    }
272}
273
274impl<'de, 'a, R: RingStore + Copy> DeserializeSeed<'de> for DeserializeSeedPlaintextCircuit<'a, R>
275    where R::Type: SerializableElementRing
276{
277    type Value = PlaintextCircuit<R::Type>;
278
279    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
280        where D: serde::Deserializer<'de>
281    {
282        let convert_coefficient = |c: CoefficientData<_>| match c {
283            CoefficientData::Integer((x, _)) if x == 0 => Coefficient::Zero,
284            CoefficientData::Integer((x, _)) if x == 1 => Coefficient::One,
285            CoefficientData::Integer((x, _)) if x == -1 => Coefficient::NegOne,
286            CoefficientData::Integer((x, _)) => Coefficient::Integer(x),
287            CoefficientData::Other((x, _)) => Coefficient::Other(x)
288        };
289        let convert_transform = |t: LinearCombinationData<_>| LinearCombination {
290            constant: convert_coefficient(t.constant),
291            factors: t.factors.into_iter().map(convert_coefficient).collect()
292        };
293        let res = DeserializeSeedPlaintextCircuitData {
294            deserializer: DeserializeWithRing::new(self.ring),
295            galois_group: self.galois_group
296        }.deserialize(deserializer)?;
297        let result = PlaintextCircuit {
298            gates: res.gates.into_iter().map(|gate| match gate {
299                GateData::Gal((gate, _)) => PlaintextCircuitGate::Gal(gate.automorphisms, convert_transform(gate.input)),
300                GateData::Mul((gate, _)) => PlaintextCircuitGate::Mul(convert_transform(gate.lhs), convert_transform(gate.rhs)),
301                GateData::Square((gate, _)) => PlaintextCircuitGate::Square(convert_transform(gate.val))
302            }).collect(),
303            input_count: res.input_count,
304            output_transforms: res.output_transforms.into_iter().map(convert_transform).collect()
305        };
306        return Ok(result);
307    }
308}