use std::hint::black_box;

use criterion::{BenchmarkId, Criterion};

use crate::{
    api::{ModuleNew, ScratchOwnedAlloc, ScratchOwnedBorrow, VecZnxNormalize, VecZnxNormalizeInplace, VecZnxNormalizeTmpBytes},
    layouts::{Backend, FillUniform, Module, ScratchOwned, VecZnx, VecZnxToMut, VecZnxToRef, ZnxInfos, ZnxView, ZnxViewMut},
    reference::znx::{
        ZnxAddInplace, ZnxCopy, ZnxExtractDigitAddMul, ZnxMulPowerOfTwoInplace, ZnxNormalizeDigit, ZnxNormalizeFinalStep,
        ZnxNormalizeFinalStepInplace, ZnxNormalizeFirstStep, ZnxNormalizeFirstStepCarryOnly, ZnxNormalizeFirstStepInplace,
        ZnxNormalizeMiddleStep, ZnxNormalizeMiddleStepCarryOnly, ZnxNormalizeMiddleStepInplace, ZnxZero,
    },
    source::Source,
};

/// Returns the scratch size in bytes required by [`vec_znx_normalize`]:
/// room for two length-`n` `i64` buffers (normalized limb + carry).
pub fn vec_znx_normalize_tmp_bytes(n: usize) -> usize {
    2 * n * size_of::<i64>()
}
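
// A minimal sketch of how the byte count above maps onto the `carry: &mut [i64]`
// scratch slice taken by `vec_znx_normalize` below (the value of `n` is
// illustrative):
//
//     let n = 1024;
//     let carry_len = vec_znx_normalize_tmp_bytes(n) / size_of::<i64>(); // == 2 * n
//     let mut carry: Vec<i64> = vec![0i64; carry_len];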

/// Normalizes column `a_col` of `a` (limbs in base `2^a_base2k`) into column
/// `res_col` of `res` (limbs in base `2^res_base2k`), propagating carries so
/// that `res` holds a normalized representation of the same value up to the
/// output precision. `carry` must hold at least `2 * n` elements
/// (see [`vec_znx_normalize_tmp_bytes`]).
pub fn vec_znx_normalize<R, A, ZNXARI>(
    res_base2k: usize,
    res: &mut R,
    res_col: usize,
    a_base2k: usize,
    a: &A,
    a_col: usize,
    carry: &mut [i64],
) where
    R: VecZnxToMut,
    A: VecZnxToRef,
    ZNXARI: ZnxZero
        + ZnxCopy
        + ZnxAddInplace
        + ZnxMulPowerOfTwoInplace
        + ZnxNormalizeFirstStepCarryOnly
        + ZnxNormalizeMiddleStepCarryOnly
        + ZnxNormalizeMiddleStep
        + ZnxNormalizeFinalStep
        + ZnxNormalizeFirstStep
        + ZnxExtractDigitAddMul
        + ZnxNormalizeDigit,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();
    let a: VecZnx<&[u8]> = a.to_ref();

    #[cfg(debug_assertions)]
    {
        assert!(carry.len() >= 2 * res.n());
        assert_eq!(res.n(), a.n());
    }

    let n: usize = res.n();
    let res_size: usize = res.size();
    let a_size: usize = a.size();

    let carry: &mut [i64] = &mut carry[..2 * n];

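    // Fast path: input and output share the same limb base, so limbs map
    // one-to-one and only carries have to be propagated. Otherwise, fall
    // through to the digit re-chunking conversion below.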
    if res_base2k == a_base2k {
        if a_size > res_size {
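            // The input has more limbs than the output can hold: fold the excess
            // least significant limbs into the carry without writing them out.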
            for j in (res_size..a_size).rev() {
                if j == a_size - 1 {
                    ZNXARI::znx_normalize_first_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
                } else {
                    ZNXARI::znx_normalize_middle_step_carry_only(res_base2k, 0, a.at(a_col, j), carry);
                }
            }

            for j in (1..res_size).rev() {
                ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
            }

            ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, 0), a.at(a_col, 0), carry);
        } else {
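            // The input fits within the output: normalize limb by limb, then
            // zero the remaining least significant output limbs, if any.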
            for j in (0..a_size).rev() {
                if j == a_size - 1 {
                    ZNXARI::znx_normalize_first_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
                } else if j == 0 {
                    ZNXARI::znx_normalize_final_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
                } else {
                    ZNXARI::znx_normalize_middle_step(res_base2k, 0, res.at_mut(res_col, j), a.at(a_col, j), carry);
                }
            }

            for j in a_size..res_size {
                ZNXARI::znx_zero(res.at_mut(res_col, j));
            }
        }
    } else {
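        // Base conversion path: normalize `a` limb by limb into the scratch
        // buffer `a_norm` (still in base 2^a_base2k), then re-chunk its bits
        // into base-2^res_base2k digits of `res`, least significant limb first.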
        let (a_norm, carry) = carry.split_at_mut(n);

        // Number of output limbs the input precision can actually fill, and the
        // number of input limbs the output precision can actually absorb.
        let res_min_size: usize = (a_size * a_base2k).div_ceil(res_base2k).min(res_size);

        let a_min_size: usize = (res_size * res_base2k).div_ceil(a_base2k).min(a_size);

        // Input limbs beyond the output precision only contribute through their
        // carry; if there are none, start from a zero carry instead.
        for j in (a_min_size..a_size).rev() {
            if j == a_size - 1 {
                ZNXARI::znx_normalize_first_step_carry_only(a_base2k, 0, a.at(a_col, j), carry);
            } else {
                ZNXARI::znx_normalize_middle_step_carry_only(a_base2k, 0, a.at(a_col, j), carry);
            }
        }

        if a_min_size == a_size {
            ZNXARI::znx_zero(carry);
        }

        // Bit precision actually covered by the retained input and output limbs.
        let a_prec: usize = a_min_size * a_base2k;

        let res_prec: usize = res_min_size * res_base2k;

        // Cursor over the output limbs (starting from the least significant one)
        // and the number of bits still free in the current output limb.
        let mut res_idx: usize = res_min_size - 1;

        let mut res_left: usize = res_base2k;

        for j in 0..res_size {
            ZNXARI::znx_zero(res.at_mut(res_col, j));
        }

        for j in (0..a_min_size).rev() {
            let mut a_left: usize = a_base2k;

            // Normalize the current input limb into `a_norm`, absorbing the
            // carry coming from the less significant limbs.
            if j != 0 {
                ZNXARI::znx_normalize_middle_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
            } else {
                ZNXARI::znx_normalize_final_step(a_base2k, 0, a_norm, a.at(a_col, j), carry);
            }

            // On the least significant retained limb, align the two precisions:
            // shift away excess input bits, or leave the low bits of the first
            // output limb empty.
            if j == a_min_size - 1 {
                if a_prec > res_prec {
                    ZNXARI::znx_mul_power_of_two_inplace(res_prec as i64 - a_prec as i64, a_norm);
                    a_left -= a_prec - res_prec;
                } else if res_prec > a_prec {
                    res_left -= res_prec - a_prec;
                }
            }

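            // Redistribute the bits of the normalized limb into base-2^res_base2k
            // digits: each pass moves `a_take` bits of `a_norm` into the current
            // output limb and advances to the next limb once it is full.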
            loop {
                let a_take: usize = a_base2k.min(a_left).min(res_left);

                let res_slice: &mut [i64] = res.at_mut(res_col, res_idx);

                let lsh: usize = res_base2k - res_left;

                ZNXARI::znx_extract_digit_addmul(a_take, lsh, res_slice, a_norm);

                a_left -= a_take;
                res_left -= a_take;

                if res_left == 0 {
                    res_left += res_base2k;

                    ZNXARI::znx_normalize_digit(res_base2k, res_slice, a_norm);

                    if res_idx == 0 {
                        ZNXARI::znx_add_inplace(carry, a_norm);
                        break;
                    }

                    res_idx -= 1;
                }

                if a_left == 0 {
                    ZNXARI::znx_add_inplace(carry, a_norm);
                    break;
                }
            }
        }
    }
}
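
// A minimal usage sketch of the cross-base path, mirroring the test at the
// bottom of this file (sizes and bases are illustrative; `ZnxRef` is the
// reference backend used by that test):
//
//     use crate::reference::znx::ZnxRef;
//
//     let n = 8;
//     let mut carry = vec![0i64; 2 * n];
//     let a: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, 7);       // 7 limbs in base 2^19
//     let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, 5); // 5 limbs in base 2^26
//     vec_znx_normalize::<_, _, ZnxRef>(26, &mut res, 0, 19, &a, 0, &mut carry);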

/// In-place variant of [`vec_znx_normalize`] for a single base: propagates the
/// carries of column `res_col` from the least to the most significant limb.
/// `carry` must hold at least `n` elements.
pub fn vec_znx_normalize_inplace<R: VecZnxToMut, ZNXARI>(base2k: usize, res: &mut R, res_col: usize, carry: &mut [i64])
where
    ZNXARI: ZnxNormalizeFirstStepInplace + ZnxNormalizeMiddleStepInplace + ZnxNormalizeFinalStepInplace,
{
    let mut res: VecZnx<&mut [u8]> = res.to_mut();

    #[cfg(debug_assertions)]
    {
        assert!(carry.len() >= res.n());
    }

    let res_size: usize = res.size();

    for j in (0..res_size).rev() {
        if j == res_size - 1 {
            ZNXARI::znx_normalize_first_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
        } else if j == 0 {
            ZNXARI::znx_normalize_final_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
        } else {
            ZNXARI::znx_normalize_middle_step_inplace(base2k, 0, res.at_mut(res_col, j), carry);
        }
    }
}
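
// In-place usage sketch (illustrative sizes; `ZnxRef` as in the test below).
// Unlike the conversion above, the in-place variant only needs an `n`-element
// carry:
//
//     let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(8, 1, 4);
//     let mut carry = vec![0i64; 8];
//     vec_znx_normalize_inplace::<_, ZnxRef>(20, &mut a, 0, &mut carry);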

/// Criterion benchmark for `vec_znx_normalize` over a grid of ring degrees,
/// columns, and limb counts.
pub fn bench_vec_znx_normalize<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_normalize::{label}");

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNormalize<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);
        let mut res: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        a.fill_uniform(base2k, &mut source);
        res.fill_uniform(base2k, &mut source);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());

        move || {
            for i in 0..cols {
                module.vec_znx_normalize(base2k, &mut res, i, base2k, &a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}

/// Criterion benchmark for `vec_znx_normalize_inplace` over the same parameter
/// grid as [`bench_vec_znx_normalize`].
pub fn bench_vec_znx_normalize_inplace<B: Backend>(c: &mut Criterion, label: &str)
where
    Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
    ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
{
    let group_name: String = format!("vec_znx_normalize_inplace::{label}");

    let mut group = c.benchmark_group(group_name);

    fn runner<B: Backend>(params: [usize; 3]) -> impl FnMut()
    where
        Module<B>: VecZnxNormalizeInplace<B> + ModuleNew<B> + VecZnxNormalizeTmpBytes,
        ScratchOwned<B>: ScratchOwnedAlloc<B> + ScratchOwnedBorrow<B>,
    {
        let n: usize = 1 << params[0];
        let cols: usize = params[1];
        let size: usize = params[2];

        let module: Module<B> = Module::<B>::new(n as u64);

        let base2k: usize = 50;

        let mut source: Source = Source::new([0u8; 32]);

        let mut a: VecZnx<Vec<u8>> = VecZnx::alloc(n, cols, size);

        a.fill_uniform(base2k, &mut source);

        let mut scratch: ScratchOwned<B> = ScratchOwned::alloc(module.vec_znx_normalize_tmp_bytes());

        move || {
            for i in 0..cols {
                module.vec_znx_normalize_inplace(base2k, &mut a, i, scratch.borrow());
            }
            black_box(());
        }
    }

    for params in [[10, 2, 2], [11, 2, 4], [12, 2, 8], [13, 2, 16], [14, 2, 32]] {
        let id: BenchmarkId = BenchmarkId::from_parameter(format!("{}x({}x{})", 1 << params[0], params[1], params[2]));
        let mut runner = runner::<B>(params);
        group.bench_with_input(id, &(), |b, _| b.iter(&mut runner));
    }

    group.finish();
}
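
// Sketch of how these entry points would typically be registered with Criterion
// (the backend type `FFT64` and the wiring below are assumptions, not part of
// this crate's confirmed API):
//
//     fn benches(c: &mut Criterion) {
//         bench_vec_znx_normalize::<FFT64>(c, "fft64");
//         bench_vec_znx_normalize_inplace::<FFT64>(c, "fft64");
//     }
//
//     criterion::criterion_group!(group, benches);
//     criterion::criterion_main!(group);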

// Cross-base conversion test: encodes the same i128 coefficients directly in
// base 2^end_base2k ("want") and via base 2^start_base2k followed by a
// `vec_znx_normalize` conversion ("have"), and requires the two decodings to
// agree to within the output precision.
#[test]
fn test_vec_znx_normalize_conv() {
    use rug::ops::SubAssignRound;
    use rug::{Float, float::Round};

    use crate::reference::znx::ZnxRef;

    let n: usize = 8;

    let mut carry: Vec<i64> = vec![0i64; 2 * n];

    let mut source: Source = Source::new([1u8; 32]);

    let prec: usize = 128;

    let mut data: Vec<i128> = vec![0i128; n];

    data.iter_mut().for_each(|x| *x = source.next_i128());

    for start_base2k in 1..50 {
        for end_base2k in 1..50 {
            let end_size: usize = prec.div_ceil(end_base2k);

            // Reference: encode directly in the target base and normalize.
            let mut want: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
            want.encode_vec_i128(end_base2k, 0, prec, &data);
            vec_znx_normalize_inplace::<_, ZnxRef>(end_base2k, &mut want, 0, &mut carry);

            // Input: encode in the source base and normalize.
            let mut tmp: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, prec.div_ceil(start_base2k));
            tmp.encode_vec_i128(start_base2k, 0, prec, &data);

            vec_znx_normalize_inplace::<_, ZnxRef>(start_base2k, &mut tmp, 0, &mut carry);

            let mut data_tmp: Vec<Float> = (0..n).map(|_| Float::with_val(prec as u32, 0)).collect();
            tmp.decode_vec_float(start_base2k, 0, &mut data_tmp);

            // Convert from the source base to the target base.
            let mut have: VecZnx<Vec<u8>> = VecZnx::alloc(n, 1, end_size);
            vec_znx_normalize::<_, _, ZnxRef>(end_base2k, &mut have, 0, start_base2k, &tmp, 0, &mut carry);

            let out_prec: u32 = (end_size * end_base2k) as u32;

            let mut data_have: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();
            let mut data_want: Vec<Float> = (0..n).map(|_| Float::with_val(out_prec, 0)).collect();

            have.decode_vec_float(end_base2k, 0, &mut data_have);
            want.decode_vec_float(end_base2k, 0, &mut data_want);

            for i in 0..n {
                let mut err: Float = data_have[i].clone();
                err.sub_assign_round(&data_want[i], Round::Nearest);
                err = err.abs();

                let err_log2: f64 = err
                    .clone()
                    .max(&Float::with_val(prec as u32, 1e-60))
                    .log2()
                    .to_f64();

                assert!(
                    err_log2 <= -(out_prec as f64) + 1.,
                    "log2(err) = {} exceeds bound {}",
                    err_log2,
                    -(out_prec as f64) + 1.
                );
            }
        }
    }
}