use aligned_vec::avec;

#[allow(unused_imports)]
use pulp::*;

pub(crate) use crate::native32::mul_mod32;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) use crate::native32::mul_mod32_avx2;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) use crate::native32::{mul_mod32_avx512, mul_mod52_avx512};

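/// Negacyclic NTT plan for multiplying two polynomials with `u64` coefficients
/// (arithmetic modulo 2^64), implemented with five 32-bit prime NTTs combined by
/// CRT reconstruction.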
#[derive(Clone, Debug)]
pub struct Plan32(
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
);

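/// Negacyclic NTT plan for multiplying two polynomials with `u64` coefficients
/// (arithmetic modulo 2^64), implemented with three 52-bit prime NTTs and
/// AVX-512 IFMA instructions.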
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(
    crate::prime64::Plan,
    crate::prime64::Plan,
    crate::prime64::Plan,
    crate::V4IFma,
);

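/// Shoup modular multiplication: computes `a * b mod p` given the precomputed
/// constant `b_shoup = floor(b * 2^64 / p)`. The modulus is passed negated
/// (`p_neg = p.wrapping_neg()`) so that `-q * p` can be formed with a wrapping
/// multiply-add. The candidate remainder `r` lies in `[0, 2p)`, and
/// `r.min(r.wrapping_add(p_neg))` (i.e. `min(r, r - p)` with wraparound)
/// selects the representative in `[0, p)` without a branch.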
#[inline(always)]
pub(crate) fn mul_mod64(p_neg: u64, a: u64, b: u64, b_shoup: u64) -> u64 {
    let q = ((a as u128 * b_shoup as u128) >> 64) as u64;
    let r = a.wrapping_mul(b).wrapping_add(p_neg.wrapping_mul(q));
    r.min(r.wrapping_add(p_neg))
}

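/// CRT reconstruction (Garner's mixed-radix algorithm) of a value from its
/// residues modulo the five primes `P0..P4`. The digits `v0..v4` express the
/// value as `v0 + v1*P0 + v2*P0*P1 + v3*P0*P1*P2 + v4*P0*P1*P2*P3` modulo the
/// product of all five primes. Residues above half the product are treated as
/// negative (centered representatives), which the final select maps back onto
/// wrapping `u64` arithmetic.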
#[inline(always)]
#[allow(dead_code)]
fn reconstruct_32bit_01234(mod_p0: u32, mod_p1: u32, mod_p2: u32, mod_p3: u32, mod_p4: u32) -> u64 {
    use crate::primes32::*;

    let v0 = mod_p0;
    let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
    let v2 = mul_mod32(
        P2,
        P01_INV_MOD_P2,
        2 * P2 + mod_p2 - (v0 + mul_mod32(P2, P0, v1)),
    );
    let v3 = mul_mod32(
        P3,
        P012_INV_MOD_P3,
        2 * P3 + mod_p3 - (v0 + mul_mod32(P3, P0, v1 + mul_mod32(P3, P1, v2))),
    );
    let v4 = mul_mod32(
        P4,
        P0123_INV_MOD_P4,
        2 * P4 + mod_p4
            - (v0 + mul_mod32(P4, P0, v1 + mul_mod32(P4, P1, v2 + mul_mod32(P4, P2, v3)))),
    );

    let sign = v4 > (P4 / 2);

    const _0: u64 = P0 as u64;
    const _01: u64 = _0.wrapping_mul(P1 as u64);
    const _012: u64 = _01.wrapping_mul(P2 as u64);
    const _0123: u64 = _012.wrapping_mul(P3 as u64);
    const _01234: u64 = _0123.wrapping_mul(P4 as u64);

    let pos = (v0 as u64)
        .wrapping_add((v1 as u64).wrapping_mul(_0))
        .wrapping_add((v2 as u64).wrapping_mul(_01))
        .wrapping_add((v3 as u64).wrapping_mul(_012))
        .wrapping_add((v4 as u64).wrapping_mul(_0123));

    let neg = pos.wrapping_sub(_01234);

    if sign {
        neg
    } else {
        pos
    }
}

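/// Two-level variant of [`reconstruct_32bit_01234`]: first combines the residue
/// pairs (P1, P2) and (P3, P4) into residues modulo P12 = P1*P2 and
/// P34 = P3*P4 using 32-bit arithmetic, then runs Garner's algorithm on the
/// three-modulus base (P0, P12, P34) with 64-bit Shoup multiplications. This
/// shortens the chain of dependent modular multiplications.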
#[inline(always)]
fn reconstruct_32bit_01234_v2(
    mod_p0: u32,
    mod_p1: u32,
    mod_p2: u32,
    mod_p3: u32,
    mod_p4: u32,
) -> u64 {
    use crate::primes32::*;

    let mod_p12 = {
        let v1 = mod_p1;
        let v2 = mul_mod32(P2, P1_INV_MOD_P2, 2 * P2 + mod_p2 - v1);
        v1 as u64 + (v2 as u64 * P1 as u64)
    };
    let mod_p34 = {
        let v3 = mod_p3;
        let v4 = mul_mod32(P4, P3_INV_MOD_P4, 2 * P4 + mod_p4 - v3);
        v3 as u64 + (v4 as u64 * P3 as u64)
    };

    let v0 = mod_p0 as u64;
    let v12 = mul_mod64(
        P12.wrapping_neg(),
        2 * P12 + mod_p12 - v0,
        P0_INV_MOD_P12,
        P0_INV_MOD_P12_SHOUP,
    );
    let v34 = mul_mod64(
        P34.wrapping_neg(),
        2 * P34 + mod_p34 - (v0 + mul_mod64(P34.wrapping_neg(), v12, P0 as u64, P0_MOD_P34_SHOUP)),
        P012_INV_MOD_P34,
        P012_INV_MOD_P34_SHOUP,
    );

    let sign = v34 > (P34 / 2);

    const _0: u64 = P0 as u64;
    const _012: u64 = _0.wrapping_mul(P12);
    const _01234: u64 = _012.wrapping_mul(P34);

    let pos = v0
        .wrapping_add(v12.wrapping_mul(_0))
        .wrapping_add(v34.wrapping_mul(_012));
    let neg = pos.wrapping_sub(_01234);

    if sign {
        neg
    } else {
        pos
    }
}

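/// AVX2 Shoup multiplication modulo a 32-bit prime, on four 64-bit lanes whose
/// values fit in the low 32 bits. Only low-half products are needed: the
/// quotient estimate is the high 32 bits of `a * b_shoup`, and `a*b - q*p` is
/// masked to 32 bits before the final conditional subtraction brings it into
/// `[0, p)`.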
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn mul_mod32_v2_avx2(
    simd: crate::V3,
    p: u64x4,
    a: u64x4,
    b: u64x4,
    b_shoup: u64x4,
) -> u64x4 {
    let shoup_q = simd.shr_const_u64x4::<32>(simd.mul_low_32_bits_u64x4(a, b_shoup));
    let t = simd.and_u64x4(
        simd.splat_u64x4((1u64 << 32) - 1),
        simd.wrapping_sub_u64x4(
            simd.mul_low_32_bits_u64x4(a, b),
            simd.mul_low_32_bits_u64x4(shoup_q, p),
        ),
    );
    simd.small_mod_u64x4(p, t)
}

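/// AVX-512 analogue of [`mul_mod32_v2_avx2`], on eight 64-bit lanes.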
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod32_v2_avx512(
    simd: crate::V4IFma,
    p: u64x8,
    a: u64x8,
    b: u64x8,
    b_shoup: u64x8,
) -> u64x8 {
    let shoup_q = simd.shr_const_u64x8::<32>(simd.mul_low_32_bits_u64x8(a, b_shoup));
    let t = simd.and_u64x8(
        simd.splat_u64x8((1u64 << 32) - 1),
        simd.wrapping_sub_u64x8(
            simd.mul_low_32_bits_u64x8(a, b),
            simd.mul_low_32_bits_u64x8(shoup_q, p),
        ),
    );
    simd.small_mod_u64x8(p, t)
}

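/// AVX2 Shoup multiplication modulo a 64-bit prime: the vector counterpart of
/// [`mul_mod64`], except that `p` is passed directly rather than negated.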
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn mul_mod64_avx2(
    simd: crate::V3,
    p: u64x4,
    a: u64x4,
    b: u64x4,
    b_shoup: u64x4,
) -> u64x4 {
    let q = simd.widening_mul_u64x4(a, b_shoup).1;
    let r = simd.wrapping_sub_u64x4(
        simd.widening_mul_u64x4(a, b).0,
        simd.widening_mul_u64x4(p, q).0,
    );
    simd.small_mod_u64x4(p, r)
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod64_avx512(
    simd: crate::V4IFma,
    p: u64x8,
    a: u64x8,
    b: u64x8,
    b_shoup: u64x8,
) -> u64x8 {
    let q = simd.widening_mul_u64x8(a, b_shoup).1;
    let r = simd.wrapping_sub_u64x8(simd.wrapping_mul_u64x8(a, b), simd.wrapping_mul_u64x8(p, q));
    simd.small_mod_u64x8(p, r)
}

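/// Vectorized (AVX2, four values at a time) version of
/// [`reconstruct_32bit_01234_v2`].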
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
fn reconstruct_32bit_01234_v2_avx2(
    simd: crate::V3,
    mod_p0: u32x4,
    mod_p1: u32x4,
    mod_p2: u32x4,
    mod_p3: u32x4,
    mod_p4: u32x4,
) -> u64x4 {
    use crate::primes32::*;

    let p0 = simd.splat_u64x4(P0 as u64);
    let p1 = simd.splat_u64x4(P1 as u64);
    let p2 = simd.splat_u64x4(P2 as u64);
    let p3 = simd.splat_u64x4(P3 as u64);
    let p4 = simd.splat_u64x4(P4 as u64);
    let p12 = simd.splat_u64x4(P12);
    let p34 = simd.splat_u64x4(P34);
    let p012 = simd.splat_u64x4((P0 as u64).wrapping_mul(P12));
    let p01234 = simd.splat_u64x4((P0 as u64).wrapping_mul(P12).wrapping_mul(P34));

    let two_p2 = simd.splat_u64x4(2 * P2 as u64);
    let two_p4 = simd.splat_u64x4(2 * P4 as u64);
    let two_p12 = simd.splat_u64x4(2 * P12);
    let two_p34 = simd.splat_u64x4(2 * P34);
    let half_p34 = simd.splat_u64x4(P34 / 2);

    let p0_inv_mod_p12 = simd.splat_u64x4(P0_INV_MOD_P12);
    let p0_inv_mod_p12_shoup = simd.splat_u64x4(P0_INV_MOD_P12_SHOUP);
    let p1_inv_mod_p2 = simd.splat_u64x4(P1_INV_MOD_P2 as u64);
    let p1_inv_mod_p2_shoup = simd.splat_u64x4(P1_INV_MOD_P2_SHOUP as u64);
    let p3_inv_mod_p4 = simd.splat_u64x4(P3_INV_MOD_P4 as u64);
    let p3_inv_mod_p4_shoup = simd.splat_u64x4(P3_INV_MOD_P4_SHOUP as u64);

    let p012_inv_mod_p34 = simd.splat_u64x4(P012_INV_MOD_P34);
    let p012_inv_mod_p34_shoup = simd.splat_u64x4(P012_INV_MOD_P34_SHOUP);
    let p0_mod_p34_shoup = simd.splat_u64x4(P0_MOD_P34_SHOUP);

    let mod_p0 = simd.convert_u32x4_to_u64x4(mod_p0);
    let mod_p1 = simd.convert_u32x4_to_u64x4(mod_p1);
    let mod_p2 = simd.convert_u32x4_to_u64x4(mod_p2);
    let mod_p3 = simd.convert_u32x4_to_u64x4(mod_p3);
    let mod_p4 = simd.convert_u32x4_to_u64x4(mod_p4);

    let mod_p12 = {
        let v1 = mod_p1;
        let v2 = mul_mod32_v2_avx2(
            simd,
            p2,
            simd.wrapping_sub_u64x4(simd.wrapping_add_u64x4(two_p2, mod_p2), v1),
            p1_inv_mod_p2,
            p1_inv_mod_p2_shoup,
        );
        simd.wrapping_add_u64x4(v1, simd.mul_low_32_bits_u64x4(v2, p1))
    };
    let mod_p34 = {
        let v3 = mod_p3;
        let v4 = mul_mod32_v2_avx2(
            simd,
            p4,
            simd.wrapping_sub_u64x4(simd.wrapping_add_u64x4(two_p4, mod_p4), v3),
            p3_inv_mod_p4,
            p3_inv_mod_p4_shoup,
        );
        simd.wrapping_add_u64x4(v3, simd.mul_low_32_bits_u64x4(v4, p3))
    };

    let v0 = mod_p0;
    let v12 = mul_mod64_avx2(
        simd,
        p12,
        simd.wrapping_sub_u64x4(simd.wrapping_add_u64x4(two_p12, mod_p12), v0),
        p0_inv_mod_p12,
        p0_inv_mod_p12_shoup,
    );
    let v34 = mul_mod64_avx2(
        simd,
        p34,
        simd.wrapping_sub_u64x4(
            simd.wrapping_add_u64x4(two_p34, mod_p34),
            simd.wrapping_add_u64x4(v0, mul_mod64_avx2(simd, p34, v12, p0, p0_mod_p34_shoup)),
        ),
        p012_inv_mod_p34,
        p012_inv_mod_p34_shoup,
    );

    let sign = simd.cmp_gt_u64x4(v34, half_p34);
    let pos = v0;
    let pos = simd.wrapping_add_u64x4(
        pos,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(v12, p0),
    );
    let pos = simd.wrapping_add_u64x4(pos, simd.widening_mul_u64x4(v34, p012).0);
    let neg = simd.wrapping_sub_u64x4(pos, p01234);
    simd.select_u64x4(sign, neg, pos)
}

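/// Vectorized (AVX2, eight values at a time) version of
/// [`reconstruct_32bit_01234`]. Currently unused; the `_v2` variant above is
/// the one driving [`Plan32::inv`].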
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_01234_avx2(
    simd: crate::V3,
    mod_p0: u32x8,
    mod_p1: u32x8,
    mod_p2: u32x8,
    mod_p3: u32x8,
    mod_p4: u32x8,
) -> [u64x4; 2] {
    use crate::primes32::*;

    let p0 = simd.splat_u32x8(P0);
    let p1 = simd.splat_u32x8(P1);
    let p2 = simd.splat_u32x8(P2);
    let p3 = simd.splat_u32x8(P3);
    let p4 = simd.splat_u32x8(P4);
    let two_p1 = simd.splat_u32x8(2 * P1);
    let two_p2 = simd.splat_u32x8(2 * P2);
    let two_p3 = simd.splat_u32x8(2 * P3);
    let two_p4 = simd.splat_u32x8(2 * P4);
    let half_p4 = simd.splat_u32x8(P4 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x8(P0_MOD_P2_SHOUP);
    let p0_mod_p3_shoup = simd.splat_u32x8(P0_MOD_P3_SHOUP);
    let p1_mod_p3_shoup = simd.splat_u32x8(P1_MOD_P3_SHOUP);
    let p0_mod_p4_shoup = simd.splat_u32x8(P0_MOD_P4_SHOUP);
    let p1_mod_p4_shoup = simd.splat_u32x8(P1_MOD_P4_SHOUP);
    let p2_mod_p4_shoup = simd.splat_u32x8(P2_MOD_P4_SHOUP);

    let p01_inv_mod_p2 = simd.splat_u32x8(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x8(P01_INV_MOD_P2_SHOUP);
    let p012_inv_mod_p3 = simd.splat_u32x8(P012_INV_MOD_P3);
    let p012_inv_mod_p3_shoup = simd.splat_u32x8(P012_INV_MOD_P3_SHOUP);
    let p0123_inv_mod_p4 = simd.splat_u32x8(P0123_INV_MOD_P4);
    let p0123_inv_mod_p4_shoup = simd.splat_u32x8(P0123_INV_MOD_P4_SHOUP);

    let p01 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64));
    let p012 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));
    let p0123 = simd.splat_u64x4(
        (P0 as u64)
            .wrapping_mul(P1 as u64)
            .wrapping_mul(P2 as u64)
            .wrapping_mul(P3 as u64),
    );
    let p01234 = simd.splat_u64x4(
        (P0 as u64)
            .wrapping_mul(P1 as u64)
            .wrapping_mul(P2 as u64)
            .wrapping_mul(P3 as u64)
            .wrapping_mul(P4 as u64),
    );

    let v0 = mod_p0;
    let v1 = mul_mod32_avx2(
        simd,
        p1,
        simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx2(
        simd,
        p2,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p2, mod_p2),
            simd.wrapping_add_u32x8(v0, mul_mod32_avx2(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );
    let v3 = mul_mod32_avx2(
        simd,
        p3,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p3, mod_p3),
            simd.wrapping_add_u32x8(
                v0,
                mul_mod32_avx2(
                    simd,
                    p3,
                    simd.wrapping_add_u32x8(v1, mul_mod32_avx2(simd, p3, v2, p1, p1_mod_p3_shoup)),
                    p0,
                    p0_mod_p3_shoup,
                ),
            ),
        ),
        p012_inv_mod_p3,
        p012_inv_mod_p3_shoup,
    );
    let v4 = mul_mod32_avx2(
        simd,
        p4,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p4, mod_p4),
            simd.wrapping_add_u32x8(
                v0,
                mul_mod32_avx2(
                    simd,
                    p4,
                    simd.wrapping_add_u32x8(
                        v1,
                        mul_mod32_avx2(
                            simd,
                            p4,
                            simd.wrapping_add_u32x8(
                                v2,
                                mul_mod32_avx2(simd, p4, v3, p2, p2_mod_p4_shoup),
                            ),
                            p1,
                            p1_mod_p4_shoup,
                        ),
                    ),
                    p0,
                    p0_mod_p4_shoup,
                ),
            ),
        ),
        p0123_inv_mod_p4,
        p0123_inv_mod_p4_shoup,
    );

    let sign = simd.cmp_gt_u32x8(v4, half_p4);
    let sign: [i32x4; 2] = pulp::cast(sign);
    let sign0: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[0])) };
    let sign1: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[1])) };

    let v0: [u32x4; 2] = pulp::cast(v0);
    let v1: [u32x4; 2] = pulp::cast(v1);
    let v2: [u32x4; 2] = pulp::cast(v2);
    let v3: [u32x4; 2] = pulp::cast(v3);
    let v4: [u32x4; 2] = pulp::cast(v4);
    let v00 = simd.convert_u32x4_to_u64x4(v0[0]);
    let v01 = simd.convert_u32x4_to_u64x4(v0[1]);
    let v10 = simd.convert_u32x4_to_u64x4(v1[0]);
    let v11 = simd.convert_u32x4_to_u64x4(v1[1]);
    let v20 = simd.convert_u32x4_to_u64x4(v2[0]);
    let v21 = simd.convert_u32x4_to_u64x4(v2[1]);
    let v30 = simd.convert_u32x4_to_u64x4(v3[0]);
    let v31 = simd.convert_u32x4_to_u64x4(v3[1]);
    let v40 = simd.convert_u32x4_to_u64x4(v4[0]);
    let v41 = simd.convert_u32x4_to_u64x4(v4[1]);

    let pos0 = v00;
    let pos0 = simd.wrapping_add_u64x4(pos0, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v10));
    let pos0 = simd.wrapping_add_u64x4(
        pos0,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v20),
    );
    let pos0 = simd.wrapping_add_u64x4(
        pos0,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p012, v30),
    );
    let pos0 = simd.wrapping_add_u64x4(
        pos0,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p0123, v40),
    );

    let pos1 = v01;
    let pos1 = simd.wrapping_add_u64x4(pos1, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v11));
    let pos1 = simd.wrapping_add_u64x4(
        pos1,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v21),
    );
    let pos1 = simd.wrapping_add_u64x4(
        pos1,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p012, v31),
    );
    let pos1 = simd.wrapping_add_u64x4(
        pos1,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p0123, v41),
    );

    let neg0 = simd.wrapping_sub_u64x4(pos0, p01234);
    let neg1 = simd.wrapping_sub_u64x4(pos1, p01234);

    [
        simd.select_u64x4(sign0, neg0, pos0),
        simd.select_u64x4(sign1, neg1, pos1),
    ]
}

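/// AVX-512 analogue of [`reconstruct_32bit_01234_avx2`], processing sixteen
/// values at a time. Also currently unused.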
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_01234_avx512(
    simd: crate::V4IFma,
    mod_p0: u32x16,
    mod_p1: u32x16,
    mod_p2: u32x16,
    mod_p3: u32x16,
    mod_p4: u32x16,
) -> [u64x8; 2] {
    use crate::primes32::*;

    let p0 = simd.splat_u32x16(P0);
    let p1 = simd.splat_u32x16(P1);
    let p2 = simd.splat_u32x16(P2);
    let p3 = simd.splat_u32x16(P3);
    let p4 = simd.splat_u32x16(P4);
    let two_p1 = simd.splat_u32x16(2 * P1);
    let two_p2 = simd.splat_u32x16(2 * P2);
    let two_p3 = simd.splat_u32x16(2 * P3);
    let two_p4 = simd.splat_u32x16(2 * P4);
    let half_p4 = simd.splat_u32x16(P4 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x16(P0_MOD_P2_SHOUP);
    let p0_mod_p3_shoup = simd.splat_u32x16(P0_MOD_P3_SHOUP);
    let p1_mod_p3_shoup = simd.splat_u32x16(P1_MOD_P3_SHOUP);
    let p0_mod_p4_shoup = simd.splat_u32x16(P0_MOD_P4_SHOUP);
    let p1_mod_p4_shoup = simd.splat_u32x16(P1_MOD_P4_SHOUP);
    let p2_mod_p4_shoup = simd.splat_u32x16(P2_MOD_P4_SHOUP);

    let p01_inv_mod_p2 = simd.splat_u32x16(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x16(P01_INV_MOD_P2_SHOUP);
    let p012_inv_mod_p3 = simd.splat_u32x16(P012_INV_MOD_P3);
    let p012_inv_mod_p3_shoup = simd.splat_u32x16(P012_INV_MOD_P3_SHOUP);
    let p0123_inv_mod_p4 = simd.splat_u32x16(P0123_INV_MOD_P4);
    let p0123_inv_mod_p4_shoup = simd.splat_u32x16(P0123_INV_MOD_P4_SHOUP);

    let p01 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64));
    let p012 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));
    let p0123 = simd.splat_u64x8(
        (P0 as u64)
            .wrapping_mul(P1 as u64)
            .wrapping_mul(P2 as u64)
            .wrapping_mul(P3 as u64),
    );
    let p01234 = simd.splat_u64x8(
        (P0 as u64)
            .wrapping_mul(P1 as u64)
            .wrapping_mul(P2 as u64)
            .wrapping_mul(P3 as u64)
            .wrapping_mul(P4 as u64),
    );

    let v0 = mod_p0;
    let v1 = mul_mod32_avx512(
        simd,
        p1,
        simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx512(
        simd,
        p2,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p2, mod_p2),
            simd.wrapping_add_u32x16(v0, mul_mod32_avx512(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );
    let v3 = mul_mod32_avx512(
        simd,
        p3,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p3, mod_p3),
            simd.wrapping_add_u32x16(
                v0,
                mul_mod32_avx512(
                    simd,
                    p3,
                    simd.wrapping_add_u32x16(
                        v1,
                        mul_mod32_avx512(simd, p3, v2, p1, p1_mod_p3_shoup),
                    ),
                    p0,
                    p0_mod_p3_shoup,
                ),
            ),
        ),
        p012_inv_mod_p3,
        p012_inv_mod_p3_shoup,
    );
    let v4 = mul_mod32_avx512(
        simd,
        p4,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p4, mod_p4),
            simd.wrapping_add_u32x16(
                v0,
                mul_mod32_avx512(
                    simd,
                    p4,
                    simd.wrapping_add_u32x16(
                        v1,
                        mul_mod32_avx512(
                            simd,
                            p4,
                            simd.wrapping_add_u32x16(
                                v2,
                                mul_mod32_avx512(simd, p4, v3, p2, p2_mod_p4_shoup),
                            ),
                            p1,
                            p1_mod_p4_shoup,
                        ),
                    ),
                    p0,
                    p0_mod_p4_shoup,
                ),
            ),
        ),
        p0123_inv_mod_p4,
        p0123_inv_mod_p4_shoup,
    );

    let sign = simd.cmp_gt_u32x16(v4, half_p4).0;
    let sign0 = b8(sign as u8);
    let sign1 = b8((sign >> 8) as u8);
    let v0: [u32x8; 2] = pulp::cast(v0);
    let v1: [u32x8; 2] = pulp::cast(v1);
    let v2: [u32x8; 2] = pulp::cast(v2);
    let v3: [u32x8; 2] = pulp::cast(v3);
    let v4: [u32x8; 2] = pulp::cast(v4);
    let v00 = simd.convert_u32x8_to_u64x8(v0[0]);
    let v01 = simd.convert_u32x8_to_u64x8(v0[1]);
    let v10 = simd.convert_u32x8_to_u64x8(v1[0]);
    let v11 = simd.convert_u32x8_to_u64x8(v1[1]);
    let v20 = simd.convert_u32x8_to_u64x8(v2[0]);
    let v21 = simd.convert_u32x8_to_u64x8(v2[1]);
    let v30 = simd.convert_u32x8_to_u64x8(v3[0]);
    let v31 = simd.convert_u32x8_to_u64x8(v3[1]);
    let v40 = simd.convert_u32x8_to_u64x8(v4[0]);
    let v41 = simd.convert_u32x8_to_u64x8(v4[1]);

    let pos0 = v00;
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v10));
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.wrapping_mul_u64x8(p01, v20));
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.wrapping_mul_u64x8(p012, v30));
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.wrapping_mul_u64x8(p0123, v40));

    let pos1 = v01;
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v11));
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.wrapping_mul_u64x8(p01, v21));
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.wrapping_mul_u64x8(p012, v31));
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.wrapping_mul_u64x8(p0123, v41));

    let neg0 = simd.wrapping_sub_u64x8(pos0, p01234);
    let neg1 = simd.wrapping_sub_u64x8(pos1, p01234);

    [
        simd.select_u64x8(sign0, neg0, pos0),
        simd.select_u64x8(sign1, neg1, pos1),
    ]
}

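/// AVX-512 analogue of [`reconstruct_32bit_01234_v2_avx2`], reconstructing
/// eight values at a time.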
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_32bit_01234_v2_avx512(
    simd: crate::V4IFma,
    mod_p0: u32x8,
    mod_p1: u32x8,
    mod_p2: u32x8,
    mod_p3: u32x8,
    mod_p4: u32x8,
) -> u64x8 {
    use crate::primes32::*;

    let p0 = simd.splat_u64x8(P0 as u64);
    let p1 = simd.splat_u64x8(P1 as u64);
    let p2 = simd.splat_u64x8(P2 as u64);
    let p3 = simd.splat_u64x8(P3 as u64);
    let p4 = simd.splat_u64x8(P4 as u64);
    let p12 = simd.splat_u64x8(P12);
    let p34 = simd.splat_u64x8(P34);
    let p012 = simd.splat_u64x8((P0 as u64).wrapping_mul(P12));
    let p01234 = simd.splat_u64x8((P0 as u64).wrapping_mul(P12).wrapping_mul(P34));

    let two_p2 = simd.splat_u64x8(2 * P2 as u64);
    let two_p4 = simd.splat_u64x8(2 * P4 as u64);
    let two_p12 = simd.splat_u64x8(2 * P12);
    let two_p34 = simd.splat_u64x8(2 * P34);
    let half_p34 = simd.splat_u64x8(P34 / 2);

    let p0_inv_mod_p12 = simd.splat_u64x8(P0_INV_MOD_P12);
    let p0_inv_mod_p12_shoup = simd.splat_u64x8(P0_INV_MOD_P12_SHOUP);
    let p1_inv_mod_p2 = simd.splat_u64x8(P1_INV_MOD_P2 as u64);
    let p1_inv_mod_p2_shoup = simd.splat_u64x8(P1_INV_MOD_P2_SHOUP as u64);
    let p3_inv_mod_p4 = simd.splat_u64x8(P3_INV_MOD_P4 as u64);
    let p3_inv_mod_p4_shoup = simd.splat_u64x8(P3_INV_MOD_P4_SHOUP as u64);

    let p012_inv_mod_p34 = simd.splat_u64x8(P012_INV_MOD_P34);
    let p012_inv_mod_p34_shoup = simd.splat_u64x8(P012_INV_MOD_P34_SHOUP);
    let p0_mod_p34_shoup = simd.splat_u64x8(P0_MOD_P34_SHOUP);

    let mod_p0 = simd.convert_u32x8_to_u64x8(mod_p0);
    let mod_p1 = simd.convert_u32x8_to_u64x8(mod_p1);
    let mod_p2 = simd.convert_u32x8_to_u64x8(mod_p2);
    let mod_p3 = simd.convert_u32x8_to_u64x8(mod_p3);
    let mod_p4 = simd.convert_u32x8_to_u64x8(mod_p4);

    let mod_p12 = {
        let v1 = mod_p1;
        let v2 = mul_mod32_v2_avx512(
            simd,
            p2,
            simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p2, mod_p2), v1),
            p1_inv_mod_p2,
            p1_inv_mod_p2_shoup,
        );
        simd.wrapping_add_u64x8(v1, simd.wrapping_mul_u64x8(v2, p1))
    };
    let mod_p34 = {
        let v3 = mod_p3;
        let v4 = mul_mod32_v2_avx512(
            simd,
            p4,
            simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p4, mod_p4), v3),
            p3_inv_mod_p4,
            p3_inv_mod_p4_shoup,
        );
        simd.wrapping_add_u64x8(v3, simd.wrapping_mul_u64x8(v4, p3))
    };

    let v0 = mod_p0;
    let v12 = mul_mod64_avx512(
        simd,
        p12,
        simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p12, mod_p12), v0),
        p0_inv_mod_p12,
        p0_inv_mod_p12_shoup,
    );
    let v34 = mul_mod64_avx512(
        simd,
        p34,
        simd.wrapping_sub_u64x8(
            simd.wrapping_add_u64x8(two_p34, mod_p34),
            simd.wrapping_add_u64x8(v0, mul_mod64_avx512(simd, p34, v12, p0, p0_mod_p34_shoup)),
        ),
        p012_inv_mod_p34,
        p012_inv_mod_p34_shoup,
    );

    let sign = simd.cmp_gt_u64x8(v34, half_p34);
    let pos = v0;
    let pos = simd.wrapping_add_u64x8(pos, simd.wrapping_mul_u64x8(v12, p0));
    let pos = simd.wrapping_add_u64x8(pos, simd.wrapping_mul_u64x8(v34, p012));

    let neg = simd.wrapping_sub_u64x8(pos, p01234);

    simd.select_u64x8(sign, neg, pos)
}

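/// CRT reconstruction from residues modulo the three 52-bit primes, using
/// AVX-512 IFMA 52-bit modular multiplications. Same mixed-radix scheme as the
/// 32-bit path, with digits (v0, v1, v2) weighted by (1, P0, P0*P1).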
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_012_avx512(
    simd: crate::V4IFma,
    mod_p0: u64x8,
    mod_p1: u64x8,
    mod_p2: u64x8,
) -> u64x8 {
    use crate::primes52::*;

    let p0 = simd.splat_u64x8(P0);
    let p1 = simd.splat_u64x8(P1);
    let p2 = simd.splat_u64x8(P2);
    let neg_p1 = simd.splat_u64x8(P1.wrapping_neg());
    let neg_p2 = simd.splat_u64x8(P2.wrapping_neg());
    let two_p1 = simd.splat_u64x8(2 * P1);
    let two_p2 = simd.splat_u64x8(2 * P2);
    let half_p2 = simd.splat_u64x8(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u64x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u64x8(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u64x8(P0_MOD_P2_SHOUP);
    let p01_inv_mod_p2 = simd.splat_u64x8(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u64x8(P01_INV_MOD_P2_SHOUP);

    let p01 = simd.splat_u64x8(P0.wrapping_mul(P1));
    let p012 = simd.splat_u64x8(P0.wrapping_mul(P1).wrapping_mul(P2));

    let v0 = mod_p0;
    let v1 = mul_mod52_avx512(
        simd,
        p1,
        neg_p1,
        simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod52_avx512(
        simd,
        p2,
        neg_p2,
        simd.wrapping_sub_u64x8(
            simd.wrapping_add_u64x8(two_p2, mod_p2),
            simd.wrapping_add_u64x8(
                v0,
                mul_mod52_avx512(simd, p2, neg_p2, v1, p0, p0_mod_p2_shoup),
            ),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    let sign = simd.cmp_gt_u64x8(v2, half_p2);

    let pos = simd.wrapping_add_u64x8(
        simd.wrapping_add_u64x8(v0, simd.wrapping_mul_u64x8(v1, p0)),
        simd.wrapping_mul_u64x8(v2, p01),
    );
    let neg = simd.wrapping_sub_u64x8(pos, p012);

    simd.select_u64x8(sign, neg, pos)
}

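// Slice drivers for the SIMD reconstructions above. `pulp::as_arrays` yields
// only the full vector-width chunks and the remainder is dropped, so the slice
// lengths are assumed to be multiples of the vector width; this holds for the
// power-of-two sizes accepted by the plans.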
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_01234_avx2(
    simd: crate::V3,
    value: &mut [u64],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
    mod_p3: &[u32],
    mod_p4: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<4, _>(value).0;
            let mod_p0 = pulp::as_arrays::<4, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<4, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<4, _>(mod_p2).0;
            let mod_p3 = pulp::as_arrays::<4, _>(mod_p3).0;
            let mod_p4 = pulp::as_arrays::<4, _>(mod_p4).0;
            for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
                crate::izip!(value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4)
            {
                *value = cast(reconstruct_32bit_01234_v2_avx2(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                    cast(mod_p3),
                    cast(mod_p4),
                ));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_01234_avx512(
    simd: crate::V4IFma,
    value: &mut [u64],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
    mod_p3: &[u32],
    mod_p4: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
            let mod_p3 = pulp::as_arrays::<8, _>(mod_p3).0;
            let mod_p4 = pulp::as_arrays::<8, _>(mod_p4).0;
            for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
                crate::izip!(value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4)
            {
                *value = cast(reconstruct_32bit_01234_v2_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                    cast(mod_p3),
                    cast(mod_p4),
                ));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_012_avx512(
    simd: crate::V4IFma,
    value: &mut [u64],
    mod_p0: &[u64],
    mod_p1: &[u64],
    mod_p2: &[u64],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_52bit_012_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}

impl Plan32 {
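    /// Returns a negacyclic NTT plan for polynomials of size `n`, or `None`
    /// if any of the five 32-bit prime NTT plans cannot be created for that
    /// size.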
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
            Plan::try_new(n, P3)?,
            Plan::try_new(n, P4)?,
        ))
    }

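    /// Returns the polynomial size handled by this plan.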
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    #[inline]
    pub fn ntt_0(&self) -> &crate::prime32::Plan {
        &self.0
    }
    #[inline]
    pub fn ntt_1(&self) -> &crate::prime32::Plan {
        &self.1
    }
    #[inline]
    pub fn ntt_2(&self) -> &crate::prime32::Plan {
        &self.2
    }
    #[inline]
    pub fn ntt_3(&self) -> &crate::prime32::Plan {
        &self.3
    }
    #[inline]
    pub fn ntt_4(&self) -> &crate::prime32::Plan {
        &self.4
    }

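    /// Reduces each coefficient of `value` modulo the five primes, then
    /// applies the forward NTT to each residue polynomial in place.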
    pub fn fwd(
        &self,
        value: &[u64],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
        mod_p3: &mut [u32],
        mod_p4: &mut [u32],
    ) {
        for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
            value,
            &mut *mod_p0,
            &mut *mod_p1,
            &mut *mod_p2,
            &mut *mod_p3,
            &mut *mod_p4
        ) {
            *mod_p0 = (value % crate::primes32::P0 as u64) as u32;
            *mod_p1 = (value % crate::primes32::P1 as u64) as u32;
            *mod_p2 = (value % crate::primes32::P2 as u64) as u32;
            *mod_p3 = (value % crate::primes32::P3 as u64) as u32;
            *mod_p4 = (value % crate::primes32::P4 as u64) as u32;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
        self.3.fwd(mod_p3);
        self.4.fwd(mod_p4);
    }

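    /// Applies the inverse NTT to each residue polynomial, then reconstructs
    /// the coefficients into `value` via CRT. The result is scaled by the NTT
    /// size, since the per-prime inverse transforms are not normalized.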
    pub fn inv(
        &self,
        value: &mut [u64],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
        mod_p3: &mut [u32],
        mod_p4: &mut [u32],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);
        self.3.inv(mod_p3);
        self.4.inv(mod_p4);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_01234_avx512(
                    simd, value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4,
                );
                return;
            }
            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_01234_avx2(
                    simd, value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4,
                );
                return;
            }
        }

        for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
            crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4)
        {
            *value = reconstruct_32bit_01234_v2(mod_p0, mod_p1, mod_p2, mod_p3, mod_p4);
        }
    }

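    /// Computes the negacyclic polynomial product of `lhs` and `rhs`, i.e.
    /// their product modulo `X^n + 1` with coefficients taken modulo 2^64,
    /// and stores the result in `prod`. All three slices must have length `n`.
    ///
    /// A minimal usage sketch:
    ///
    /// ```ignore
    /// let n = 32;
    /// let plan = Plan32::try_new(n).unwrap();
    /// let lhs = vec![1u64; n];
    /// let rhs = vec![2u64; n];
    /// let mut prod = vec![0u64; n];
    /// plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
    /// ```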
    pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs: &[u64]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];
        let mut lhs3 = avec![0; n];
        let mut lhs4 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];
        let mut rhs3 = avec![0; n];
        let mut rhs4 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
        self.fwd(rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);
        self.3.mul_assign_normalize(&mut lhs3, &rhs3);
        self.4.mul_assign_normalize(&mut lhs4, &rhs4);

        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
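    /// Returns a negacyclic NTT plan for polynomials of size `n`, or `None`
    /// if the size is unsupported or AVX-512 IFMA is unavailable on the
    /// current CPU.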
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime64::Plan, primes52::*};
        let simd = crate::V4IFma::try_new()?;
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
            simd,
        ))
    }

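    /// Returns the polynomial size handled by this plan.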
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    #[inline]
    pub fn ntt_0(&self) -> &crate::prime64::Plan {
        &self.0
    }
    #[inline]
    pub fn ntt_1(&self) -> &crate::prime64::Plan {
        &self.1
    }
    #[inline]
    pub fn ntt_2(&self) -> &crate::prime64::Plan {
        &self.2
    }

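    /// Reduces each coefficient of `value` modulo the three 52-bit primes,
    /// then applies the forward NTT to each residue polynomial in place.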
    pub fn fwd(&self, value: &[u64], mod_p0: &mut [u64], mod_p1: &mut [u64], mod_p2: &mut [u64]) {
        use crate::primes52::*;
        self.3.vectorize(
            #[inline(always)]
            || {
                for (&value, mod_p0, mod_p1, mod_p2) in
                    crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
                {
                    *mod_p0 = value % P0;
                    *mod_p1 = value % P1;
                    *mod_p2 = value % P2;
                }
            },
        );
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
    }

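    /// Applies the inverse NTT to each residue polynomial, then reconstructs
    /// the coefficients into `value` via CRT. As with [`Plan32::inv`], the
    /// result is scaled by the NTT size.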
    pub fn inv(
        &self,
        value: &mut [u64],
        mod_p0: &mut [u64],
        mod_p1: &mut [u64],
        mod_p2: &mut [u64],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);

        reconstruct_slice_52bit_012_avx512(self.3, value, mod_p0, mod_p1, mod_p2);
    }

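    /// Computes the negacyclic polynomial product of `lhs` and `rhs` modulo
    /// `X^n + 1`, storing it in `prod`. All three slices must have length `n`.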
    pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs: &[u64]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2);
        self.fwd(rhs, &mut rhs0, &mut rhs1, &mut rhs2);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);

        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2);
    }
}

#[cfg(test)]
mod tests {
    extern crate alloc;

    use super::*;
    use crate::prime64::tests::random_lhs_rhs_with_negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let value = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
            let mut value_roundtrip = vec![0; n];
            let mut mod_p0 = vec![0; n];
            let mut mod_p1 = vec![0; n];
            let mut mod_p2 = vec![0; n];
            let mut mod_p3 = vec![0; n];
            let mut mod_p4 = vec![0; n];

            let plan = Plan32::try_new(n).unwrap();
            plan.fwd(
                &value,
                &mut mod_p0,
                &mut mod_p1,
                &mut mod_p2,
                &mut mod_p3,
                &mut mod_p4,
            );
            plan.inv(
                &mut value_roundtrip,
                &mut mod_p0,
                &mut mod_p1,
                &mut mod_p2,
                &mut mod_p3,
                &mut mod_p4,
            );
            for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
                assert_eq!(value_roundtrip, value.wrapping_mul(n as u64));
            }

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, 0);

            let mut prod = vec![0; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, negacyclic_convolution);
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn reconstruct_52bit() {
        for n in [32, 64, 256, 1024, 2048] {
            if let Some(plan) = Plan52::try_new(n) {
                let value = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
                let mut value_roundtrip = vec![0; n];
                let mut mod_p0 = vec![0; n];
                let mut mod_p1 = vec![0; n];
                let mut mod_p2 = vec![0; n];

                plan.fwd(&value, &mut mod_p0, &mut mod_p1, &mut mod_p2);
                plan.inv(&mut value_roundtrip, &mut mod_p0, &mut mod_p1, &mut mod_p2);
                for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
                    assert_eq!(value_roundtrip, value.wrapping_mul(n as u64));
                }

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, 0);

                let mut prod = vec![0; n];
                plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
                assert_eq!(prod, negacyclic_convolution);
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn reconstruct_32bit_avx() {
        for n in [16, 32, 64, 256, 1024, 2048] {
            use crate::primes32::*;

            let mut value = vec![0; n];
            let mut value_avx2 = vec![0; n];
            #[cfg(feature = "nightly")]
            let mut value_avx512 = vec![0; n];
            let mod_p0 = (0..n).map(|_| random::<u32>() % P0).collect::<Vec<_>>();
            let mod_p1 = (0..n).map(|_| random::<u32>() % P1).collect::<Vec<_>>();
            let mod_p2 = (0..n).map(|_| random::<u32>() % P2).collect::<Vec<_>>();
            let mod_p3 = (0..n).map(|_| random::<u32>() % P3).collect::<Vec<_>>();
            let mod_p4 = (0..n).map(|_| random::<u32>() % P4).collect::<Vec<_>>();

            for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
                crate::izip!(&mut value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4)
            {
                *value = reconstruct_32bit_01234_v2(mod_p0, mod_p1, mod_p2, mod_p3, mod_p4);
            }

            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_01234_avx2(
                    simd,
                    &mut value_avx2,
                    &mod_p0,
                    &mod_p1,
                    &mod_p2,
                    &mod_p3,
                    &mod_p4,
                );
                assert_eq!(value, value_avx2);
            }
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_01234_avx512(
                    simd,
                    &mut value_avx512,
                    &mod_p0,
                    &mod_p1,
                    &mod_p2,
                    &mod_p3,
                    &mod_p4,
                );
                assert_eq!(value, value_avx512);
            }
        }
    }
}