poulpy_hal/reference/vec_znx/shift.rs

use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxLsh, VecZnxLshInplace, VecZnxRsh, VecZnxRshInplace},
    layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::{
        vec_znx::vec_znx_copy,
        znx::{
            ZnxCopy, ZnxNormalizeFinalStep, ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly,
            ZnxNormalizeFirstStepInplace, ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace,
            ZnxZero,
        },
    },
    source::Source,
};

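/// Returns the scratch space, in bytes, required by the `vec_znx_lsh*` functions:
/// one i64 carry slot per coefficient (the `carry` slice must hold `n` values).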
pub fn vec_znx_lsh_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

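/// Left-shifts column `res_col` of `res` by `k` bits, in place: whole limbs move
/// by `k / base2k` positions toward the most significant limb, and the remaining
/// `k % base2k` bits are handled by a normalizing pass with an intra-limb left shift.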
pub fn vec_znx_lsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    ZNXARI: ZnxZero
        + ZnxCopy
        + ZnxNormalizeFirstStepInplace
        + ZnxNormalizeMiddleStepInplace
        + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    let n: usize = res.n();
    let cols: usize = res.cols();
    let size: usize = res.size();
    let steps: usize = k / base2k;
    let k_rem: usize = k % base2k;

    if steps >= size {
        for j in 0..size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
        return;
    }

    // In-place shift of the limbs by k / base2k
    if steps > 0 {
        let start: usize = n * res_col;
        let end: usize = start + n;
        let slice_size: usize = n * cols;
        let res_raw: &mut [i64] = res.raw_mut();

        (0..size - steps).for_each(|j| {
            let (lhs, rhs) = res_raw.split_at_mut(slice_size * (j + steps));
            ZNXARI::znx_copy(
                &mut lhs[start + j * slice_size..end + j * slice_size],
                &rhs[start..end],
            );
        });

        for j in size - steps..size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
    }

    // In-place normalization with left shift of k % base2k
    if !k.is_multiple_of(base2k) {
        for j in (0..size - steps).rev() {
            if j == size - steps - 1 {
                ZNXARI::znx_normalize_first_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
            } else if j == 0 {
                ZNXARI::znx_normalize_final_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_inplace(base2k, k_rem, res.at_mut(res_col, j), carry);
            }
        }
    }
}

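/// Left-shifts column `a_col` of `a` by `k` bits into column `res_col` of `res`.
/// Same semantics as [vec_znx_lsh_inplace], but out of place; when `k` is a
/// multiple of `base2k`, this degenerates to a plain limb copy.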
pub fn vec_znx_lsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxZero + ZnxNormalizeFirstStep + ZnxNormalizeMiddleStep + ZnxCopy + ZnxNormalizeFinalStep,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    let res_size: usize = res.size();
    let a_size: usize = a.size();
    let steps: usize = k / base2k;
    let k_rem: usize = k % base2k;

    if steps >= res_size.min(a_size) {
        for j in 0..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
        return;
    }

    let min_size: usize = a_size.min(res_size) - steps;

    // Simply a left-shifted normalization: limbs shift by k / base2k
    // and the intra-limb shift is k % base2k.
    if !k.is_multiple_of(base2k) {
        for j in (0..min_size).rev() {
            if j == min_size - 1 {
                ZNXARI::znx_normalize_first_step(
                    base2k,
                    k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j + steps),
                    carry,
                );
            } else if j == 0 {
                ZNXARI::znx_normalize_final_step(
                    base2k,
                    k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j + steps),
                    carry,
                );
            } else {
                ZNXARI::znx_normalize_middle_step(
                    base2k,
                    k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j + steps),
                    carry,
                );
            }
        }
    } else {
        // If k % base2k == 0, then this is simply a copy.
        for j in (0..min_size).rev() {
            ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j + steps));
        }
    }

    // Zeroes the bottom limbs
    for j in min_size..res_size {
        ZNXARI::znx_zero(res.at_mut(res_col, j));
    }
}

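/// Returns the scratch space, in bytes, required by the `vec_znx_rsh*` functions:
/// like the left shift, one i64 carry slot per coefficient.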
pub fn vec_znx_rsh_tmp_bytes(n: usize) -> usize {
    n * size_of::<i64>()
}

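/// Right-shifts column `res_col` of `res` by `k` bits, in place. When `k` is not
/// a multiple of `base2k`, the shift is implemented as a right shift by one extra
/// limb followed by a normalization with an intra-limb left shift of
/// `base2k - k % base2k`, so the output comes out normalized.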
pub fn vec_znx_rsh_inplace<R, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    ZNXARI: ZnxZero
        + ZnxCopy
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeMiddleStepInplace
        + ZnxNormalizeFirstStepInplace
        + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let n: usize = res.n();
    let cols: usize = res.cols();
    let size: usize = res.size();

    let mut steps: usize = k / base2k;
    let k_rem: usize = k % base2k;

    if k == 0 {
        return;
    }

    if steps >= size {
        for j in 0..size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
        return;
    }

    let start: usize = n * res_col;
    let end: usize = start + n;
    let slice_size: usize = n * cols;

    if !k.is_multiple_of(base2k) {
        // We rsh by an additional base2k and then lsh by base2k - k % base2k.
        // This reuses the efficient normalization code, avoids overflows, and
        // produces output that is normalized. E.g. with base2k = 50 and k = 3,
        // limbs move down by one and the intra-limb left shift of 47 bits gives
        // a net right shift of 3 bits.
        steps += 1;

        // Limbs whose shifted position falls outside of res are discarded,
        // but their carry still needs to be computed.
        (size - steps..size).rev().for_each(|j| {
            if j == size - 1 {
                ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, res.at(res_col, j), carry);
            }
        });

        // Continues with the shifted normalization
        let res_raw: &mut [i64] = res.raw_mut();
        (steps..size).rev().for_each(|j| {
            let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
            let rhs_slice: &mut [i64] = &mut rhs[start..end];
            let lhs_slice: &[i64] = &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end];
            ZNXARI::znx_normalize_middle_step(base2k, base2k - k_rem, rhs_slice, lhs_slice, carry);
        });

        // Propagates the carry through the remaining top limbs of res
        for j in (0..steps).rev() {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
            if j == 0 {
                ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
            }
        }
    } else {
        // Shift by a multiple of base2k: a plain limb copy.
        let res_raw: &mut [i64] = res.raw_mut();
        (steps..size).rev().for_each(|j| {
            let (lhs, rhs) = res_raw.split_at_mut(slice_size * j);
            ZNXARI::znx_copy(
                &mut rhs[start..end],
                &lhs[(j - steps) * slice_size + start..(j - steps) * slice_size + end],
            );
        });

        // Zeroes the top limbs
        (0..steps).for_each(|j| {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        });
    }
}

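/// Right-shifts column `a_col` of `a` by `k` bits into column `res_col` of `res`,
/// following the same extra-limb strategy as [vec_znx_rsh_inplace]. Limbs of `a`
/// that do not fit into `res` are discarded, but their carry is still folded in.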
pub fn vec_znx_rsh<R, A, ZNXARI>(base2k: usize, k: usize, res: &mut R, res_col: usize, a: &A, a_col: usize, carry: &mut [i64])
where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxZero
        + ZnxCopy
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeFirstStep
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeMiddleStepInplace
        + ZnxNormalizeFirstStepInplace
        + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let mut steps: usize = k / base2k;
    let k_rem: usize = k % base2k;

    if k == 0 {
        vec_znx_copy::<_, _, ZNXARI>(&mut res, res_col, &a, a_col);
        return;
    }

    if steps >= res_size {
        for j in 0..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }
        return;
    }

    if !k.is_multiple_of(base2k) {
        // We rsh by an additional base2k and then lsh by base2k - k % base2k.
        // This reuses the efficient normalization code, avoids overflows, and
        // produces output that is normalized.
        steps += 1;

        // All limbs of a that land outside of the limbs of res are discarded,
        // but their carry still needs to be computed.
        for j in (res_size..a_size + steps).rev() {
            if j == a_size + steps - 1 {
                ZNXARI::znx_normalize_first_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_carry_only(base2k, base2k - k_rem, a.at(a_col, j - steps), carry);
            }
        }

        // Avoids overflowing the limbs of res
        let min_size: usize = res_size.min(a_size + steps);

        // Zeroes the lower limbs of res if a_size + steps < res_size
        (min_size..res_size).for_each(|j| {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        });

        // Continues with the shifted normalization
        for j in (steps..min_size).rev() {
            // Case where no limb of a was previously discarded
            if res_size.saturating_sub(steps) >= a_size && j == min_size - 1 {
                ZNXARI::znx_normalize_first_step(
                    base2k,
                    base2k - k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j - steps),
                    carry,
                );
            } else {
                ZNXARI::znx_normalize_middle_step(
                    base2k,
                    base2k - k_rem,
                    res.at_mut(res_col, j),
                    a.at(a_col, j - steps),
                    carry,
                );
            }
        }

        // Propagates the carry through the remaining top limbs of res
        for j in (0..steps).rev() {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
            if j == 0 {
                ZNXARI::znx_normalize_final_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_inplace(base2k, base2k - k_rem, res.at_mut(res_col, j), carry);
            }
        }
    } else {
        let min_size: usize = res_size.min(a_size + steps);

        // Zeroes the top limbs
        (0..steps).for_each(|j| {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        });

        // Shifts a into res, up to min_size
        for j in (steps..min_size).rev() {
            ZNXARI::znx_copy(res.at_mut(res_col, j), a.at(a_col, j - steps));
        }

        // Zeroes the bottom limbs if a_size + steps < res_size
        (min_size..res_size).for_each(|j| {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        });
    }
}

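// Benchmark parameter triples are [log2(n), cols, size]; each runner shifts all
// columns by base2k - 1 bits per iteration, exercising the non-multiple path.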
pub fn bench_vec_znx_lsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: ModuleNew<B> + VecZnxLshInplace<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_lsh_inplace::{label}");

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxLshInplace<B> + ModuleNew<B>,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());

        // Fill a and b with uniform random values
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_lsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_lsh<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxLsh<B> + ModuleNew<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_lsh::{label}");

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxLsh<B> + ModuleNew<B>,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());

        // Fill a and res with uniform random values
        a.fill_uniform(50, &mut source);
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_lsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_rsh_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_rsh_inplace::{label}");

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxRshInplace<B> + ModuleNew<B>,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut b: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());

        // Fill a and b with uniform random values
        a.fill_uniform(50, &mut source);
        b.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_rsh_inplace(base2k, base2k - 1, &mut b, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

pub fn bench_vec_znx_rsh<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxRsh<B> + ModuleNew<B>,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_rsh::{label}");

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxRsh<B> + ModuleNew<B>,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(n * size_of::<i64>());

        // Fill a and res with uniform random values
        a.fill_uniform(50, &mut source);
        res.fill_uniform(50, &mut source);

        move || {
            for i in 0..cols {
                module.vec_znx_rsh(base2k, base2k - 1, &mut res, i, &a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

#[cfg(test)]
mod tests {
    use crate::{
        layouts::{FillUniform, VecZnx, ZnxView},
        reference::{
            vec_znx::{
                vec_znx_copy, vec_znx_lsh, vec_znx_lsh_inplace, vec_znx_normalize_inplace, vec_znx_rsh, vec_znx_rsh_inplace,
                vec_znx_sub_inplace,
            },
            znx::ZnxRef,
        },
        source::Source,
    };

    #[test]
    fn test_vec_znx_lsh() {
        let n: usize = 8;
        let cols: usize = 2;
        let size: usize = 7;

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut source: Source = Source::new([0u8; 32]);

        let mut carry: Vec<i64> = vec![0i64; n];

        let base2k: usize = 50;

        for k in 0..256 {
            a.fill_uniform(50, &mut source);

            for i in 0..cols {
                vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
                vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
            }

            for i in 0..cols {
                vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, i, &mut carry);
                vec_znx_lsh::<_, _, ZnxRef>(base2k, k, &mut res_test, i, &a, i, &mut carry);
                vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, i, &mut carry);
            }

            assert_eq!(res_ref, res_test);
        }
    }

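    // A minimal sketch of the pure limb-copy fast path, added for illustration:
    // when k is an exact multiple of base2k, the left shift moves limb j + steps
    // of the input into limb j of the result and zeroes the bottom steps limbs.
    // Only helpers already used in this module are assumed.
    #[test]
    fn test_vec_znx_lsh_limb_multiple() {
        let n: usize = 8;
        let cols: usize = 2;
        let size: usize = 7;
        let base2k: usize = 50;
        let steps: usize = 2;

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        let mut source: Source = Source::new([1u8; 32]);
        let mut carry: Vec<i64> = vec![0i64; n];

        a.fill_uniform(50, &mut source);

        for i in 0..cols {
            vec_znx_lsh::<_, _, ZnxRef>(base2k, steps * base2k, &mut res, i, &a, i, &mut carry);
        }

        let zero: Vec<i64> = vec![0i64; n];
        for i in 0..cols {
            // Limbs shift up by exactly `steps` positions...
            for j in 0..size - steps {
                assert_eq!(res.at(i, j), a.at(i, j + steps));
            }
            // ...and the vacated bottom limbs are zeroed.
            for j in size - steps..size {
                assert_eq!(res.at(i, j), zero);
            }
        }
    }
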
    #[test]
    fn test_vec_znx_rsh() {
        let n: usize = 8;
        let cols: usize = 2;

        let res_size: usize = 7;

        let mut res_ref: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);
        let mut res_test: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, res_size);

        let mut carry: Vec<i64> = vec![0i64; n];

        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let zero: Vec<i64> = vec![0i64; n];

        for a_size in [res_size - 1, res_size, res_size + 1] {
            let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, a_size);

            for k in 0..res_size * base2k {
                a.fill_uniform(50, &mut source);

                for i in 0..cols {
                    vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut a, i, &mut carry);
                    vec_znx_copy::<_, _, ZnxRef>(&mut res_ref, i, &a, i);
                }

                res_test.fill_uniform(50, &mut source);

                for j in 0..cols {
                    vec_znx_rsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
                    vec_znx_rsh::<_, _, ZnxRef>(base2k, k, &mut res_test, j, &a, j, &mut carry);
                }

                for j in 0..cols {
                    vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_ref, j, &mut carry);
                    vec_znx_lsh_inplace::<_, ZnxRef>(base2k, k, &mut res_test, j, &mut carry);
                }

                // Case where res has enough limbs to fully store the right shift of a
                // without any loss; here we can check exact equality.
                if a_size + k.div_ceil(base2k) <= res_size {
                    assert_eq!(res_ref, res_test);

                    for i in 0..cols {
                        for j in 0..a_size {
                            assert_eq!(res_ref.at(i, j), a.at(i, j), "r0 {} {}", i, j);
                            assert_eq!(res_test.at(i, j), a.at(i, j), "r1 {} {}", i, j);
                        }

                        for j in a_size..res_size {
                            assert_eq!(res_ref.at(i, j), zero, "r0 {} {}", i, j);
                            assert_eq!(res_test.at(i, j), zero, "r1 {} {}", i, j);
                        }
                    }
                // Some loss occurs, either because a initially has more precision than
                // res, or because storing the right shift of a requires more precision
                // than res provides.
                } else {
                    for j in 0..cols {
                        vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_ref, j, &a, j);
                        vec_znx_sub_inplace::<_, _, ZnxRef>(&mut res_test, j, &a, j);

                        vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_ref, j, &mut carry);
                        vec_znx_normalize_inplace::<_, ZnxRef>(base2k, &mut res_test, j, &mut carry);

                        assert!(res_ref.stats(base2k, j).std().log2() - (k as f64) <= (k * base2k) as f64);
                        assert!(res_test.stats(base2k, j).std().log2() - (k as f64) <= (k * base2k) as f64);
                    }
                }
            }
        }
    }
}