poulpy_hal/reference/vec_znx/rotate.rs

1use std::hint::black_box;
2
3use criterion::{BenchmarkId, Criterion};
4
5use crate::{
6    api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxRotate, VecZnxRotateInplace, VecZnxRotateInplaceTmpBytes},
7    layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
8    reference::znx::{ZnxCopy, ZnxRotate, ZnxZero},
9    source::Source,
10};
11
/// Returns the number of scratch bytes needed by `vec_znx_rotate_inplace` for `n` elements.
///
/// The in-place rotation works through a `tmp: &mut [i64]` buffer whose length must equal
/// `res.n()`, so the requirement is one `i64` per element, i.e. `n * size_of::<i64>()` bytes.
pub fn vec_znx_rotate_inplace_tmp_bytes(n: usize) -> usize {
    let bytes_per_elem: usize = std::mem::size_of::<i64>();
    bytes_per_elem * n
}
15
16pub fn vec_znx_rotate<R, A, ZNXARI>(p: i64, res: &mut R, res_col: usize, a: &A, a_col: usize)
17where
18    R: VecZnxToMut,
19    A: VecZnxToRef,
20    ZNXARI: ZnxRotate + ZnxZero,
21{
22    let mut res: VecZnx<&mut [u8]> = res.to_mut();
23    let a: VecZnx<&[u8]> = a.to_ref();
24
25    #[cfg(debug_assertions)]
26    {
27        assert_eq!(res.n(), a.n())
28    }
29
30    let res_size: usize = res.size();
31    let a_size: usize = a.size();
32
33    let min_size: usize = res_size.min(a_size);
34
35    for j in 0..min_size {
36        ZNXARI::znx_rotate(p, res.at_mut(res_col, j), a.at(a_col, j))
37    }
38
39    for j in min_size..res_size {
40        ZNXARI::znx_zero(res.at_mut(res_col, j));
41    }
42}
43
44pub fn vec_znx_rotate_inplace<R, ZNXARI>(p: i64, res: &mut R, res_col: usize, tmp: &mut [i64])
45where
46    R: VecZnxToMut,
47    ZNXARI: ZnxRotate + ZnxCopy,
48{
49    let mut res: VecZnx<&mut [u8]> = res.to_mut();
50    #[cfg(debug_assertions)]
51    {
52        assert_eq!(res.n(), tmp.len());
53    }
54    for j in 0..res.size() {
55        ZNXARI::znx_rotate(p, tmp, res.at(res_col, j));
56        ZNXARI::znx_copy(res.at_mut(res_col, j), tmp);
57    }
58}
59
60pub fn bench_vec_znx_rotate<B: Backend>(c: &mut Criterion, label: &str)
61where
62    Module<B>: VecZnxRotate + ModuleNew<B>,
63{
64    let group_name: String = format!("vec_znx_rotate::{label}");
65
66    let mut group = c.benchmark_group(group_name);
67
68    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
69    where
70        Module<B>: VecZnxRotate + ModuleNew<B>,
71    {
72        let n: usize = 1 << params[0];
73        let cols: usize = params[1];
74        let size: usize = params[2];
75
76        let module: Module<B> = Module::<B>::new(n as u64);
77
78        let mut source: Source = Source::new([0u8; 32]);
79
80        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
81        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
82
83        // Fill a with random i64
84        a.fill_uniform(50, &mut source);
85        res.fill_uniform(50, &mut source);
86
87        move || {
88            for i in 0..cols {
89                module.vec_znx_rotate(-7, &mut res, i, &a, i);
90            }
91            black_box(());
92        }
93    }
94
95    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
96        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
97        let mut runner = runner::<B>(params);
98        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
99    }
100
101    group.finish();
102}
103
104pub fn bench_vec_znx_rotate_inplace<B: Backend>(c: &mut Criterion, label: &str)
105where
106    Module<B>: VecZnxRotateInplace<B> + VecZnxRotateInplaceTmpBytes + ModuleNew<B>,
107    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
108{
109    let group_name: String = format!("vec_znx_rotate_inplace::{label}");
110
111    let mut group = c.benchmark_group(group_name);
112
113    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
114    where
115        Module<B>: VecZnxRotateInplace<B> + ModuleNew<B> + VecZnxRotateInplaceTmpBytes,
116        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
117    {
118        let n: usize = 1 << params[0];
119        let cols: usize = params[1];
120        let size: usize = params[2];
121
122        let module: Module<B> = Module::<B>::new(n as u64);
123
124        let mut source: Source = Source::new([0u8; 32]);
125
126        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
127
128        let mut scratch = ScratchOwned::alloc(module.vec_znx_rotate_inplace_tmp_bytes());
129
130        // Fill a with random i64
131        res.fill_uniform(50, &mut source);
132
133        move || {
134            for i in 0..cols {
135                module.vec_znx_rotate_inplace(-7, &mut res, i, scratch.borrow());
136            }
137            black_box(());
138        }
139    }
140
141    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
142        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2],));
143        let mut runner = runner::<B>(params);
144        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
145    }
146
147    group.finish();
148}