phantom_zone/
decomposer.rs

1use itertools::{izip, Itertools};
2use num_traits::{FromPrimitive, PrimInt, ToPrimitive, WrappingAdd, WrappingSub};
3use std::fmt::{Debug, Display};
4
5use crate::{
6    backend::ArithmeticOps,
7    parameters::{
8        DecompositionCount, DecompostionLogBase, DoubleDecomposerParams, SingleDecomposerParams,
9    },
10    utils::log2,
11};
12
13fn gadget_vector<T: PrimInt>(logq: usize, logb: usize, d: usize) -> Vec<T> {
14    assert!(logq >= (logb * d));
15    let ignored_bits = logq - (logb * d);
16
17    (0..d)
18        .into_iter()
19        .map(|i| T::one() << (logb * i + ignored_bits))
20        .collect_vec()
21}
22
23pub trait RlweDecomposer {
24    type Element;
25    type D: Decomposer<Element = Self::Element>;
26
27    /// Decomposer for RLWE Part A
28    fn a(&self) -> &Self::D;
29    /// Decomposer for RLWE Part B
30    fn b(&self) -> &Self::D;
31}
32
33impl<D> RlweDecomposer for (D, D)
34where
35    D: Decomposer,
36{
37    type D = D;
38    type Element = D::Element;
39    fn a(&self) -> &Self::D {
40        &self.0
41    }
42    fn b(&self) -> &Self::D {
43        &self.1
44    }
45}
46
47impl<D> DoubleDecomposerParams for D
48where
49    D: RlweDecomposer,
50{
51    type Base = DecompostionLogBase;
52    type Count = DecompositionCount;
53
54    fn decomposition_base(&self) -> Self::Base {
55        assert!(
56            Decomposer::decomposition_base(self.a()) == Decomposer::decomposition_base(self.b())
57        );
58        Decomposer::decomposition_base(self.a())
59    }
60    fn decomposition_count_a(&self) -> Self::Count {
61        Decomposer::decomposition_count(self.a())
62    }
63    fn decomposition_count_b(&self) -> Self::Count {
64        Decomposer::decomposition_count(self.b())
65    }
66}
67
68impl<D> SingleDecomposerParams for D
69where
70    D: Decomposer,
71{
72    type Base = DecompostionLogBase;
73    type Count = DecompositionCount;
74    fn decomposition_base(&self) -> Self::Base {
75        Decomposer::decomposition_base(self)
76    }
77    fn decomposition_count(&self) -> Self::Count {
78        Decomposer::decomposition_count(self)
79    }
80}
81
82pub trait Decomposer {
83    type Element;
84    type Iter: Iterator<Item = Self::Element>;
85    fn new(q: Self::Element, logb: usize, d: usize) -> Self;
86
87    fn decompose_to_vec(&self, v: &Self::Element) -> Vec<Self::Element>;
88    fn decompose_iter(&self, v: &Self::Element) -> Self::Iter;
89    fn decomposition_count(&self) -> DecompositionCount;
90    fn decomposition_base(&self) -> DecompostionLogBase;
91    fn gadget_vector(&self) -> Vec<Self::Element>;
92}
93
/// Decomposer with a power-of-two base `B = 2^logb` over modulus `q`.
///
/// Only the top `logb * d` bits of a `logq`-bit value are decomposed. The
/// lowest `ignore_bits = logq - logb * d` bits are rounded away first.
pub struct DefaultDecomposer<T> {
    /// Ciphertext modulus q.
    q: T,
    /// Log of the ciphertext modulus, as computed by `crate::utils::log2`.
    logq: usize,
    /// Log of the decomposition base B.
    logb: usize,
    /// The base B = 2^logb.
    b: T,
    /// B - 1. `x & b_mask` equals `x % B`, i.e. it extracts the `logb`
    /// least-significant bits.
    b_mask: T,
    /// B / 2.
    bby2: T,
    /// Decomposition count (number of limbs).
    d: usize,
    /// Number of low bits dropped, with rounding, before decomposing.
    ignore_bits: usize,
}
113
/// Exposes an unsigned integer type's bit-width as an associated constant,
/// so generic code can reach it through a trait bound.
pub trait NumInfo {
    /// Number of bits in the type's representation.
    const BITS: u32;
}

// Implements `NumInfo` for each listed type by forwarding to that type's
// inherent `BITS` constant.
macro_rules! impl_num_info {
    ($($ty:ty),* $(,)?) => {
        $(
            impl NumInfo for $ty {
                const BITS: u32 = <$ty>::BITS;
            }
        )*
    };
}

impl_num_info!(u64, u32, u128);
127
128impl<T: PrimInt + NumInfo + Debug> DefaultDecomposer<T> {
129    fn recompose<Op>(&self, limbs: &[T], modq_op: &Op) -> T
130    where
131        Op: ArithmeticOps<Element = T>,
132    {
133        let mut value = T::zero();
134        let gadget_vector = gadget_vector(self.logq, self.logb, self.d);
135        assert!(limbs.len() == gadget_vector.len());
136        izip!(limbs.iter(), gadget_vector.iter())
137            .for_each(|(d_el, beta)| value = modq_op.add(&value, &modq_op.mul(d_el, beta)));
138
139        value
140    }
141}
142
143impl<
144        T: PrimInt
145            + ToPrimitive
146            + FromPrimitive
147            + WrappingSub
148            + WrappingAdd
149            + NumInfo
150            + From<bool>
151            + Display
152            + Debug,
153    > Decomposer for DefaultDecomposer<T>
154{
155    type Element = T;
156    type Iter = DecomposerIter<T>;
157
158    fn new(q: T, logb: usize, d: usize) -> DefaultDecomposer<T> {
159        // if q is power of 2, then `BITS - leading_zeros` outputs logq + 1.
160        let logq = log2(&q);
161        assert!(
162            logq >= (logb * d),
163            "Decomposer wants logq >= logb*d but got logq={logq}, logb={logb}, d={d}"
164        );
165
166        let ignore_bits = logq - (logb * d);
167
168        DefaultDecomposer {
169            q,
170            logq,
171            logb,
172            b: T::one() << logb,
173            b_mask: (T::one() << logb) - T::one(),
174            bby2: T::one() << (logb - 1),
175            d,
176            ignore_bits,
177        }
178    }
179
180    fn decompose_to_vec(&self, value: &T) -> Vec<T> {
181        let q = self.q;
182        let logb = self.logb;
183        let b = T::one() << logb;
184        let full_mask = b - T::one();
185        let bby2 = b >> 1;
186
187        let mut value = *value;
188        if value >= (q >> 1) {
189            value = !(q - value) + T::one()
190        }
191        value = round_value(value, self.ignore_bits);
192        let mut out = Vec::with_capacity(self.d);
193        for _ in 0..(self.d) {
194            let k_i = value & full_mask;
195
196            value = (value - k_i) >> logb;
197
198            if k_i > bby2 || (k_i == bby2 && ((value & T::one()) == T::one())) {
199                out.push(q - (b - k_i));
200                value = value + T::one();
201            } else {
202                out.push(k_i);
203            }
204        }
205
206        return out;
207    }
208
209    fn decomposition_count(&self) -> DecompositionCount {
210        DecompositionCount(self.d)
211    }
212
213    fn decomposition_base(&self) -> DecompostionLogBase {
214        DecompostionLogBase(self.logb)
215    }
216
217    fn decompose_iter(&self, value: &T) -> DecomposerIter<T> {
218        let mut value = *value;
219        if value >= (self.q >> 1) {
220            value = !(self.q - value) + T::one()
221        }
222        value = round_value(value, self.ignore_bits);
223
224        DecomposerIter {
225            value,
226            q: self.q,
227            logq: self.logq,
228            logb: self.logb,
229            b: self.b,
230            bby2: self.bby2,
231            b_mask: self.b_mask,
232            steps_left: self.d,
233        }
234    }
235
236    fn gadget_vector(&self) -> Vec<T> {
237        return gadget_vector(self.logq, self.logb, self.d);
238    }
239}
240
241impl<T: PrimInt> DefaultDecomposer<T> {}
242
243pub struct DecomposerIter<T> {
244    /// Value to decompose
245    value: T,
246    steps_left: usize,
247    /// (1 << logb) - 1 (for % (1<<logb); i.e. to extract least signiciant logb
248    /// bits)
249    b_mask: T,
250    logb: usize,
251    // b/2 = 1 << (logb-1)
252    bby2: T,
253    /// Ciphertext modulus
254    q: T,
255    /// Log of ciphertext modulus
256    logq: usize,
257    /// b = 1 << logb
258    b: T,
259}
260
261impl<T: PrimInt + From<bool> + WrappingSub + Display> Iterator for DecomposerIter<T> {
262    type Item = T;
263
264    fn next(&mut self) -> Option<Self::Item> {
265        if self.steps_left != 0 {
266            self.steps_left -= 1;
267            let k_i = self.value & self.b_mask;
268
269            self.value = (self.value - k_i) >> self.logb;
270
271            // if k_i > self.bby2 || (k_i == self.bby2 && ((self.value &
272            // T::one()) == T::one())) {     self.value = self.value
273            // + T::one();     Some(self.q + k_i - self.b)
274            // } else {
275            //     Some(k_i)
276            // }
277
278            // Following is without branching impl of the commented version above. It
279            // happens to speed up bootstrapping for `SMALL_MP_BOOL_PARAMS` (& other
280            // parameters as well but I haven't tested) by roughly 15ms.
281            // Suprisingly the improvement does not show up when I benchmark
282            // `decomposer_iter` in isolation. Putting this remark here as a
283            // future task to investiage (TODO).
284            let carry_bool =
285                k_i > self.bby2 || (k_i == self.bby2 && ((self.value & T::one()) == T::one()));
286            let carry = <T as From<bool>>::from(carry_bool);
287            let neg_carry = (T::zero().wrapping_sub(&carry));
288            self.value = self.value + carry;
289            Some((neg_carry & self.q) + k_i - (carry << self.logb))
290
291            // Some(
292            //     (self.q & ((carry << self.logq) - (T::one() & carry))) + k_i
293            // - (carry << self.logb), )
294
295            // Some(k_i)
296        } else {
297            None
298        }
299    }
300}
301
302fn round_value<T: PrimInt + WrappingAdd>(value: T, ignore_bits: usize) -> T {
303    if ignore_bits == 0 {
304        return value;
305    }
306
307    let ignored_msb = (value & ((T::one() << ignore_bits) - T::one())) >> (ignore_bits - 1);
308    (value >> ignore_bits).wrapping_add(&ignored_msb)
309}
310
#[cfg(test)]
mod tests {
    use itertools::Itertools;
    use rand::{thread_rng, Rng};

    use super::{Decomposer, DefaultDecomposer};
    use crate::{
        backend::{ModInit, ModularOpsU64},
        decomposer::round_value,
        utils::generate_prime,
    };

    /// Round-trip check for both prime and power-of-two moduli.
    ///
    /// The eager (`Vec`) and lazy (iterator) decompositions must agree.
    /// Recomposing the limbs must recover the rounded input to within 1.
    #[test]
    fn decomposition_works() {
        let ring_size = 1 << 11;
        let logb = 11;
        let d = 3;
        let mut rng = thread_rng();

        for logq in [37, 55] {
            for use_prime in [true, false] {
                let q = if use_prime {
                    generate_prime(logq, 2 * ring_size, 1u64 << logq).unwrap()
                } else {
                    1u64 << logq
                };
                let decomposer = DefaultDecomposer::new(q, logb, d);
                let modq_op = ModularOpsU64::new(q);

                for _ in 0..1_000_000 {
                    let value = rng.gen_range(0..q);

                    let eager_limbs = decomposer.decompose_to_vec(&value);
                    let lazy_limbs = decomposer.decompose_iter(&value).collect_vec();
                    assert_eq!(eager_limbs, lazy_limbs);

                    let recomposed = decomposer.recompose(&eager_limbs, &modq_op);
                    let value_back = round_value(recomposed, decomposer.ignore_bits);
                    let rounded_value = round_value(value, decomposer.ignore_bits);
                    assert!((rounded_value as i64 - value_back as i64).abs() <= 1);
                }
            }
        }
    }
}