poulpy_hal/reference/vec_znx/normalize.rs

use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
    layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{
        ZnxAddInplace, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
        ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
        ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxZero,
    },
    source::Source,
};

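/// Returns the scratch size, in bytes, required by [`vec_znx_normalize`]:
/// room for two length-`n` buffers of `i64`, matching the `a_norm`/`carry`
/// split used by the cross-base path below.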
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
    2 * n * size_of::<i64>()
}

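/// Normalizes column `a_col` of `a`, interpreted in base `2^a_base2k`, and
/// writes the result into column `res_col` of `res` in base `2^res_base2k`,
/// rounding when `res` holds less precision than `a` and zero-padding when it
/// holds more. `carry` must provide at least `2 * n` words, as sized by
/// [`vec_znx_normalize_tmp_bytes`].
///
/// A minimal usage sketch with the reference backend `ZnxRef` (values are
/// illustrative):
///
/// ```ignore
/// // Converts a from base 2^12 limbs to base 2^20 limbs.
/// vec_znx_normalize::<_, _, ZnxRef>(20, &mut res, 0, 12, &a, 0, &mut carry);
/// ```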
pub fn vec_znx_normalize<R, A, ZNXARI>(
    res_base2k: usize,
    res: &mut R,
    res_col: usize,
    a_base2k: usize,
    a: &A,
    a_col: usize,
    carry: &mut [i64],
) where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxZero
        + ZnxCopy
        + ZnxAddInplace
        + ZnxMulPowerOfTwoInplace
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeFinalStep
        + ZnxNormalizeFirstStep
        + ZnxExtractDigitAddMul
        + ZnxNormalizeDigit,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert!(carry.len() >= 2 * res.n());
        assert_eq!(res.n(), a.n());
    }

    let n: usize = res.n();
    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let carry: &mut [i64] = &mut carry[..2 * n];

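    // Fast path: input and output share the same base. If a has more limbs
    // than res, the extra least significant limbs of a are first folded into
    // the carry; otherwise all limbs of a are normalized into res and the
    // remaining limbs of res are zeroed.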
    if res_base2k == a_base2k {
        if a_size > res_size {
            for j in (res_size..a_size).rev() {
                if j == a_size - 1 {
                    ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
                } else {
                    ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
                }
            }

            for j in (1..res_size).rev() {
                ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
            }

            ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
        } else {
            for j in (0..a_size).rev() {
                if j == a_size - 1 {
                    ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
                } else if j == 0 {
                    ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
                } else {
                    ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
                }
            }

            for j in a_size..res_size {
                ZNXARI::znx_zero(res.at_mut(res_col, j));
            }
        }
    } else {
        let (a_norm, carry) = carry.split_at_mut(n);

        // Relevant limbs of res
        let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size);

        // Relevant limbs of a
        let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size);

        // Folds into the carry the limbs of a that exceed the precision of res
        for j in (a_min_size..a_size).rev() {
            if j == a_size - 1 {
                ZNXARI::znx_normalize_first_step_carry_only(a_base2k, 0, a.at(a_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, 0, a.at(a_col, j), carry);
            }
        }

        // If no limb of a was folded above, the carry starts out at zero.
        if a_min_size == a_size {
            ZNXARI::znx_zero(carry);
        }

        // Maximum relevant precision of a
        let a_prec: usize = a_min_size * a_base2k;

        // Maximum relevant precision of res
        let res_prec: usize = res_min_size * res_base2k;

        // Res limb index
        let mut res_idx: usize = res_min_size - 1;

        // Tracker: how much of the current limb of res
        // is left to be populated.
        let mut res_left: usize = res_base2k;

        for j in 0..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }

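        // Main conversion loop: each limb of a is normalized into a_norm,
        // then flushed into res in chunks, re-chunking base-2^a_base2k digits
        // into base-2^res_base2k digits; a_left and res_left track how many
        // bits remain to be consumed/filled on each side.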
        for j in (0..a_min_size).rev() {
            // Tracker: how much of a_norm is left to
            // be flushed into res.
            let mut a_left: usize = a_base2k;

            // Normalizes the j-th limb of a and stores the result into a_norm.
            // This step is required to avoid overflow in the next step,
            // which assumes that |a| is bounded by 2^{a_base2k - 1}.
            if j != 0 {
                ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
            } else {
                ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
            }

            // In the first iteration we need to match the precision of the input/output.
            // If a_min_size * a_base2k > res_min_size * res_base2k,
            // then divround a_norm by the difference in precision and
            // act as if a_norm had already been partially consumed.
            // Otherwise act as if res had already been populated
            // by the difference.
            if j == a_min_size - 1 {
                if a_prec > res_prec {
                    ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm);
                    a_left -= a_prec - res_prec;
                } else if res_prec > a_prec {
                    res_left -= res_prec - a_prec;
                }
            }
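
            // Worked example: with a_base2k = 12, a_size = 3, res_base2k = 10
            // and res_size = 3, we get a_min_size = 3 (a_prec = 36) and
            // res_min_size = 3 (res_prec = 30): the branch above right-shifts
            // a_norm by 36 - 30 = 6 bits (with rounding) and sets
            // a_left = 12 - 6 = 6, so only the 6 surviving bits of the least
            // significant limb of a are flushed into res.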

            // Flushes a_norm into res
            loop {
                // Selects the maximum amount of a_norm that can be flushed
                let a_take: usize = a_base2k.min(a_left).min(res_left);

                // Output limb
                let res_slice: &mut [i64] = res.at_mut(res_col, res_idx);

                // Scaling of the value to flush
                let lsh: usize = res_base2k - res_left;

                // Extracts the bits to flush into the output and updates
                // a_norm accordingly.
                ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm);

                // Updates the trackers
                a_left -= a_take;
                res_left -= a_take;

                // If the current limb of res is full,
                // normalizes this limb and adds
                // the carry to a_norm.
                if res_left == 0 {
                    // Updates the tracker
                    res_left += res_base2k;

                    // Normalizes res and propagates the carry to a_norm.
                    ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm);

                    // If we have reached the last limb of res, break;
                    // the loop may still run again for this limb if the
                    // base2k of a is much smaller than the base2k of res.
                    if res_idx == 0 {
                        ZNXARI::znx_add_inplace(carry, a_norm);
                        break;
                    }

                    // Otherwise moves to the next limb of res.
                    res_idx -= 1;
                }

                // If a_norm is exhausted, breaks out of the loop.
                if a_left == 0 {
                    ZNXARI::znx_add_inplace(carry, a_norm);
                    break;
                }
            }
        }
    }
}

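/// Normalizes column `res_col` of `res` in place, propagating carries from the
/// least significant limb (index `size - 1`) up to the most significant
/// (index 0). `carry` must hold at least `n` words.
///
/// A minimal usage sketch with the reference backend `ZnxRef` (values are
/// illustrative):
///
/// ```ignore
/// let mut v: VecZnx<Vec<u8>> = VecZnx::alloc(8, 1, 4);
/// let mut carry: Vec<i64> = vec![0i64; 8];
/// vec_znx_normalize_inplace::<_, ZnxRef>(18, &mut v, 0, &mut carry);
/// ```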
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
    ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert!(carry.len() >= res.n());
    }

    let res_size: usize = res.size();

    for j in (0..res_size).rev() {
        if j == res_size - 1 {
            ZNXARI::znx_normalize_first_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
        } else if j == 0 {
            ZNXARI::znx_normalize_final_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
        } else {
            ZNXARI::znx_normalize_middle_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
        }
    }
}

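/// Criterion benchmark for [`vec_znx_normalize`], sweeping ring degrees
/// n = 2^10..=2^14 with 2 columns and 2..=32 limbs.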
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_normalize::{label}");

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fills a and res with uniformly random limbs
        a.fill_uniform(base2k, &mut source);
        res.fill_uniform(base2k, &mut source);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());

        move || {
            for i in 0..cols {
                module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

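/// Criterion benchmark for [`vec_znx_normalize_inplace`], sweeping the same
/// parameter grid as [`bench_vec_znx_normalize`].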
pub fn bench_vec_znx_normalize_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_normalize_inplace::{label}");

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        // Fills a with uniformly random limbs
        a.fill_uniform(base2k, &mut source);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());

        move || {
            for i in 0..cols {
                module.vec_znx_normalize_inplace(base2k, &mut a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

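// Checks that converting an encoding from base 2^start_base2k to base
// 2^end_base2k with vec_znx_normalize preserves the represented value, up to
// the output precision, against a reference encoded directly in base
// 2^end_base2k.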
#[test]
fn test_vec_znx_normalize_conv() {
    let n: usize = 8;

    let mut carry: Vec<i64> = vec![0i64; 2 * n];

    use crate::reference::znx::ZnxRef;
    use rug::ops::SubAssignRound;
    use rug::{Float, float::Round};

    let mut source: Source = Source::new([1u8; 32]);

    let prec: usize = 128;

    let mut data: Vec<i128> = vec![0i128; n];

    data.iter_mut().for_each(|x| *x = source.next_i128());

    for start_base2k in 1..50 {
        for end_base2k in 1..50 {
            let end_size: usize = prec.div_ceil(end_base2k);

            // Reference: encodes data directly in base 2^end_base2k.
            let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
            want.encode_vec_i128(end_base2k, 0, prec, &data);
            vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry);

            // Creates a temporary poly encoded in base 2^start_base2k
            let mut tmp: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k));
            tmp.encode_vec_i128(start_base2k, 0, prec, &data);

            vec_znx_normalize_inplace::<_, ZnxRef>(start_base2k, &mut tmp, 0, &mut carry);

            // Converts tmp from base 2^start_base2k to base 2^end_base2k.
            let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
            vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry);

            let out_prec: u32 = (end_size * end_base2k) as u32;

            let mut data_have: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();
            let mut data_want: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();

            have.decode_vec_float(end_base2k, 0, &mut data_have);
            want.decode_vec_float(end_base2k, 0, &mut data_want);

            for i in 0..n {
                let mut err: Float = data_have[i].clone();
                err.sub_assign_round(&data_want[i], Round::Nearest);
                err = err.abs();

                let err_log2: f64 = err
                    .clone()
                    .max(&Float::with_val(prec as u32, 1e-60))
                    .log2()
                    .to_f64();

                assert!(
                    err_log2 <= -(out_prec as f64) + 1.,
                    "{} {}",
                    err_log2,
                    -(out_prec as f64) + 1.
                )
            }
        }
    }
}
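
// A minimal companion sketch (not from the original source): it checks that
// normalization is idempotent, i.e. that a second pass of
// vec_znx_normalize_inplace over already-normalized data is the identity.
// This assumes normalized limbs are fixed points of the carry propagation,
// consistent with the bound |a| < 2^{base2k - 1} stated in vec_znx_normalize.
#[test]
fn test_vec_znx_normalize_inplace_idempotent() {
    use crate::reference::znx::ZnxRef;

    let n: usize = 8;
    let base2k: usize = 17;
    let size: usize = 8;

    let mut carry: Vec<i64> = vec![0i64; n];
    let mut source: Source = Source::new([2u8; 32]);

    let mut v: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, size);
    v.fill_uniform(base2k, &mut source);

    // First pass brings v into normal form.
    vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut v, 0, &mut carry);

    // Snapshots the normalized limbs.
    let snapshot: Vec<Vec<i64>> = (0..size).map(|j| v.at(0, j).to_vec()).collect();

    // A second pass must leave every limb unchanged.
    vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut v, 0, &mut carry);

    for (j, limb) in snapshot.iter().enumerate() {
        assert_eq!(limb.as_slice(), v.at(0, j));
    }
}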