use aligned_vec::avec;

#[allow(unused_imports)]
use pulp::*;

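/// Negacyclic NTT plan for multiplying two polynomials with native `u32`
/// coefficients, using a residue number system over the three NTT-friendly
/// primes from `crate::primes32`. Coefficients are reduced modulo each prime,
/// multiplied in the NTT domain, and recombined via the Chinese remainder
/// theorem; the product of the three primes is large enough that, for the
/// supported sizes, the convolution is recovered exactly before the final
/// reduction modulo `2^32`.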
#[derive(Clone, Debug)]
pub struct Plan32(
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
);

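/// Negacyclic NTT plan for multiplying two polynomials with native `u32`
/// coefficients, using two 52-bit primes from `crate::primes52` and AVX-512
/// IFMA arithmetic (`crate::V4IFma`). Only available with the `nightly`
/// feature, and only constructible when the required instruction sets are
/// detected at runtime.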
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(crate::prime64::Plan, crate::prime64::Plan, crate::V4IFma);

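// Scalar modular multiplication: computes `a * b mod p` through a 64-bit
// intermediate product, so the multiplication itself cannot overflow.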
#[inline(always)]
pub(crate) fn mul_mod32(p: u32, a: u32, b: u32) -> u32 {
    let wide = a as u64 * b as u64;
    (wide % p as u64) as u32
}

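// CRT reconstruction (Garner's mixed-radix algorithm) of `x mod P0*P1*P2`
// from its residues modulo P0, P1, P2. The mixed-radix digits are
//     v0 = x mod P0
//     v1 = (x - v0) / P0 mod P1
//     v2 = (x - v0 - v1*P0) / (P0*P1) mod P2
// so that x = v0 + v1*P0 + v2*P0*P1. Adding `2 * P1` (resp. `2 * P2`) before
// the subtractions keeps the intermediate values non-negative. The top digit
// `v2` decides the sign: values above half the full modulus are taken as
// negative, so the full modulus is subtracted (all arithmetic wrapping
// modulo 2^32).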
#[inline(always)]
pub(crate) fn reconstruct_32bit_012(mod_p0: u32, mod_p1: u32, mod_p2: u32) -> u32 {
    use crate::primes32::*;

    let v0 = mod_p0;
    let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
    let v2 = mul_mod32(
        P2,
        P01_INV_MOD_P2,
        2 * P2 + mod_p2 - (v0 + mul_mod32(P2, P0, v1)),
    );

    let sign = v2 > (P2 / 2);

    const _0: u32 = P0;
    const _01: u32 = _0.wrapping_mul(P1);
    const _012: u32 = _01.wrapping_mul(P2);

    let pos = v0
        .wrapping_add(v1.wrapping_mul(_0))
        .wrapping_add(v2.wrapping_mul(_01));

    let neg = pos.wrapping_sub(_012);

    if sign {
        neg
    } else {
        pos
    }
}

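// Vectorized Shoup modular multiplication by a constant `b`, where `b_shoup`
// is the precomputed value `floor(b * 2^32 / p)`. The high half of
// `a * b_shoup` estimates the quotient of `a * b` by `p`; subtracting
// `quotient * p` (wrapping modulo 2^32) leaves a result in `[0, 2p)`, which
// `small_mod_u32x8` brings back into `[0, p)`.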
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn mul_mod32_avx2(
    simd: crate::V3,
    p: u32x8,
    a: u32x8,
    b: u32x8,
    b_shoup: u32x8,
) -> u32x8 {
    let shoup_q = simd.widening_mul_u32x8(a, b_shoup).1;
    let t = simd.wrapping_sub_u32x8(
        simd.wrapping_mul_u32x8(a, b),
        simd.wrapping_mul_u32x8(shoup_q, p),
    );
    simd.small_mod_u32x8(p, t)
}

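// 16-lane AVX-512 variant of `mul_mod32_avx2`.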
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod32_avx512(
    simd: crate::V4IFma,
    p: u32x16,
    a: u32x16,
    b: u32x16,
    b_shoup: u32x16,
) -> u32x16 {
    let shoup_q = simd.widening_mul_u32x16(a, b_shoup).1;
    let t = simd.wrapping_sub_u32x16(
        simd.wrapping_mul_u32x16(a, b),
        simd.wrapping_mul_u32x16(shoup_q, p),
    );
    simd.small_mod_u32x16(p, t)
}

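// 52-bit Shoup modular multiplication built on the AVX-512 IFMA widening
// multiplies; the `quotient * p` correction is applied as a wrapping
// multiply-add with the negated modulus `neg_p`, leaving a value in `[0, 2p)`.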
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
pub(crate) fn mul_mod52_avx512(
    simd: crate::V4IFma,
    p: u64x8,
    neg_p: u64x8,
    a: u64x8,
    b: u64x8,
    b_shoup: u64x8,
) -> u64x8 {
    let shoup_q = simd.widening_mul_u52x8(a, b_shoup).1;
    let t = simd.wrapping_mul_add_u52x8(shoup_q, neg_p, simd.widening_mul_u52x8(a, b).0);
    simd.small_mod_u64x8(p, t)
}

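// 8-lane AVX2 version of `reconstruct_32bit_012`; the constant factors use
// their precomputed Shoup companions from `crate::primes32`.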
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn reconstruct_32bit_012_avx2(
    simd: crate::V3,
    mod_p0: u32x8,
    mod_p1: u32x8,
    mod_p2: u32x8,
) -> u32x8 {
    use crate::primes32::*;

    let p0 = simd.splat_u32x8(P0);
    let p1 = simd.splat_u32x8(P1);
    let p2 = simd.splat_u32x8(P2);
    let two_p1 = simd.splat_u32x8(2 * P1);
    let two_p2 = simd.splat_u32x8(2 * P2);
    let half_p2 = simd.splat_u32x8(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
    let p01_inv_mod_p2 = simd.splat_u32x8(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x8(P01_INV_MOD_P2_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x8(P0_MOD_P2_SHOUP);

    let p01 = simd.splat_u32x8(P0.wrapping_mul(P1));
    let p012 = simd.splat_u32x8(P0.wrapping_mul(P1).wrapping_mul(P2));

    let v0 = mod_p0;
    let v1 = mul_mod32_avx2(
        simd,
        p1,
        simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx2(
        simd,
        p2,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p2, mod_p2),
            simd.wrapping_add_u32x8(v0, mul_mod32_avx2(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    let sign = simd.cmp_gt_u32x8(v2, half_p2);
    let pos = simd.wrapping_add_u32x8(
        simd.wrapping_add_u32x8(v0, simd.wrapping_mul_u32x8(v1, p0)),
        simd.wrapping_mul_u32x8(v2, p01),
    );
    let neg = simd.wrapping_sub_u32x8(pos, p012);

    simd.select_u32x8(sign, neg, pos)
}

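// 16-lane AVX-512 version of `reconstruct_32bit_012`.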
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_32bit_012_avx512(
    simd: crate::V4IFma,
    mod_p0: u32x16,
    mod_p1: u32x16,
    mod_p2: u32x16,
) -> u32x16 {
    use crate::primes32::*;

    let p0 = simd.splat_u32x16(P0);
    let p1 = simd.splat_u32x16(P1);
    let p2 = simd.splat_u32x16(P2);
    let two_p1 = simd.splat_u32x16(2 * P1);
    let two_p2 = simd.splat_u32x16(2 * P2);
    let half_p2 = simd.splat_u32x16(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);
    let p01_inv_mod_p2 = simd.splat_u32x16(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x16(P01_INV_MOD_P2_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x16(P0_MOD_P2_SHOUP);

    let p01 = simd.splat_u32x16(P0.wrapping_mul(P1));
    let p012 = simd.splat_u32x16(P0.wrapping_mul(P1).wrapping_mul(P2));

    let v0 = mod_p0;
    let v1 = mul_mod32_avx512(
        simd,
        p1,
        simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx512(
        simd,
        p2,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p2, mod_p2),
            simd.wrapping_add_u32x16(v0, mul_mod32_avx512(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    let sign = simd.cmp_gt_u32x16(v2, half_p2);
    let pos = simd.wrapping_add_u32x16(
        simd.wrapping_add_u32x16(v0, simd.wrapping_mul_u32x16(v1, p0)),
        simd.wrapping_mul_u32x16(v2, p01),
    );
    let neg = simd.wrapping_sub_u32x16(pos, p012);

    simd.select_u32x16(sign, neg, pos)
}

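// CRT reconstruction from residues modulo the two 52-bit primes, following
// the same mixed-radix scheme as `reconstruct_32bit_012` but with a single
// correction digit; the reconstructed value is then truncated to 32 bits.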
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_01_avx512(simd: crate::V4IFma, mod_p0: u64x8, mod_p1: u64x8) -> u32x8 {
    use crate::primes52::*;

    let p0 = simd.splat_u64x8(P0);
    let p1 = simd.splat_u64x8(P1);
    let neg_p1 = simd.splat_u64x8(P1.wrapping_neg());
    let two_p1 = simd.splat_u64x8(2 * P1);
    let half_p1 = simd.splat_u64x8(P1 / 2);

    let p0_inv_mod_p1 = simd.splat_u64x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u64x8(P0_INV_MOD_P1_SHOUP);

    let p01 = simd.splat_u64x8(P0.wrapping_mul(P1));

    let v0 = mod_p0;
    let v1 = mul_mod52_avx512(
        simd,
        p1,
        neg_p1,
        simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );

    let sign = simd.cmp_gt_u64x8(v1, half_p1);

    let pos = simd.wrapping_add_u64x8(v0, simd.wrapping_mul_u64x8(v1, p0));
    let neg = simd.wrapping_sub_u64x8(pos, p01);

    simd.convert_u64x8_to_u32x8(simd.select_u64x8(sign, neg, pos))
}

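// Slice-level wrappers over the SIMD reconstruction kernels. Note that
// `pulp::as_arrays` yields the full 8- or 16-element chunks and `.0` drops
// any remainder; the callers avoid this by only passing power-of-two sizes
// that are multiples of the register width.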
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_012_avx2(
    simd: crate::V3,
    value: &mut [u32],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_32bit_012_avx2(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_012_avx512(
    simd: crate::V4IFma,
    value: &mut [u32],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<16, _>(value).0;
            let mod_p0 = pulp::as_arrays::<16, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<16, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<16, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_32bit_012_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_01_avx512(
    simd: crate::V4IFma,
    value: &mut [u32],
    mod_p0: &[u64],
    mod_p1: &[u64],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
                *value = cast(reconstruct_52bit_01_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                ));
            }
        },
    );
}

impl Plan32 {
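    /// Returns a negacyclic NTT plan for polynomials of size `n`, or `None`
    /// if the size is unsupported by the underlying per-prime plans.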
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
        ))
    }

    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    #[inline]
    pub fn ntt_0(&self) -> &crate::prime32::Plan {
        &self.0
    }
    #[inline]
    pub fn ntt_1(&self) -> &crate::prime32::Plan {
        &self.1
    }
    #[inline]
    pub fn ntt_2(&self) -> &crate::prime32::Plan {
        &self.2
    }

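    /// Reduces each coefficient of `value` modulo the three primes, then
    /// computes the forward NTT of each residue slice in place.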
    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32], mod_p2: &mut [u32]) {
        for (value, mod_p0, mod_p1, mod_p2) in
            crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
        {
            *mod_p0 = value % crate::primes32::P0;
            *mod_p1 = value % crate::primes32::P1;
            *mod_p2 = value % crate::primes32::P2;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
    }

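    /// Computes the inverse NTT of each residue slice, then reconstructs the
    /// coefficients modulo `2^32` into `value`, dispatching to the fastest
    /// SIMD implementation available at runtime.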
    pub fn inv(
        &self,
        value: &mut [u32],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_012_avx512(simd, value, mod_p0, mod_p1, mod_p2);
                return;
            }
            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_012_avx2(simd, value, mod_p0, mod_p1, mod_p2);
                return;
            }
        }

        for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2)
        {
            *value = reconstruct_32bit_012(mod_p0, mod_p1, mod_p2);
        }
    }

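    /// Computes the negacyclic product of `lhs` and `rhs`, i.e. the product
    /// modulo `X^n + 1` with coefficients taken modulo `2^32`, storing the
    /// result in `prod`. All three slices must have the same length `n`.
    ///
    /// A minimal usage sketch (assuming `1024` is a size accepted by
    /// [`Plan32::try_new`]):
    ///
    /// ```ignore
    /// let plan = Plan32::try_new(1024).unwrap();
    /// let lhs = vec![1u32; 1024];
    /// let rhs = vec![2u32; 1024];
    /// let mut prod = vec![0u32; 1024];
    /// plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
    /// // prod now holds lhs * rhs mod (X^1024 + 1, 2^32)
    /// ```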
    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs: &[u32]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2);
        self.fwd(rhs, &mut rhs0, &mut rhs1, &mut rhs2);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);

        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
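    /// Returns a negacyclic NTT plan for polynomials of size `n`, or `None`
    /// if the size is unsupported or the required AVX-512 instruction sets
    /// are not available at runtime.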
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime64::Plan, primes52::*};
        let simd = crate::V4IFma::try_new()?;
        Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?, simd))
    }

    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

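    /// Widens each coefficient of `value` to `u64`, then computes the forward
    /// NTT of each residue slice in place. No reduction is needed here, since
    /// both 52-bit primes exceed `2^32`.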
    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        self.2.vectorize(
            #[inline(always)]
            || {
                for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
                    *mod_p0 = *value as u64;
                    *mod_p1 = *value as u64;
                }
            },
        );
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);

        let simd = self.2;
        reconstruct_slice_52bit_01_avx512(simd, value, mod_p0, mod_p1);
    }

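    /// Computes the negacyclic product of `lhs` and `rhs` modulo `X^n + 1`,
    /// with coefficients taken modulo `2^32`; see
    /// [`Plan32::negacyclic_polymul`] for the equivalent three-prime version.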
    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs: &[u32]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs.len());

        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1);
        self.fwd(rhs, &mut rhs0, &mut rhs1);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);

        self.inv(prod, &mut lhs0, &mut lhs1);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime32::tests::random_lhs_rhs_with_negacyclic_convolution;
    use rand::random;

    extern crate alloc;
    use alloc::{vec, vec::Vec};

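    // The inverse transform is unnormalized, so a fwd/inv roundtrip scales
    // every coefficient by `n`; the roundtrip assertions below account for
    // this, while `negacyclic_polymul` normalizes internally.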
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let plan = Plan32::try_new(n).unwrap();

            let value = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
            let mut value_roundtrip = vec![0u32; n];
            let mut mod_p0 = vec![0u32; n];
            let mut mod_p1 = vec![0u32; n];
            let mut mod_p2 = vec![0u32; n];

            plan.fwd(&value, &mut mod_p0, &mut mod_p1, &mut mod_p2);
            plan.inv(&mut value_roundtrip, &mut mod_p0, &mut mod_p1, &mut mod_p2);
            for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
                assert_eq!(value_roundtrip, value.wrapping_mul(n as u32));
            }

            let (lhs, rhs, negacyclic_convolution) =
                random_lhs_rhs_with_negacyclic_convolution(n, 0);

            let mut prod = vec![0; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, negacyclic_convolution);
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn reconstruct_52bit() {
        for n in [32, 64, 256, 1024, 2048] {
            if let Some(plan) = Plan52::try_new(n) {
                let value = (0..n).map(|_| random::<u32>()).collect::<Vec<_>>();
                let mut value_roundtrip = vec![0u32; n];
                let mut mod_p0 = vec![0u64; n];
                let mut mod_p1 = vec![0u64; n];

                plan.fwd(&value, &mut mod_p0, &mut mod_p1);
                plan.inv(&mut value_roundtrip, &mut mod_p0, &mut mod_p1);
                for (&value, &value_roundtrip) in crate::izip!(&value, &value_roundtrip) {
                    assert_eq!(value_roundtrip, value.wrapping_mul(n as u32));
                }

                let (lhs, rhs, negacyclic_convolution) =
                    random_lhs_rhs_with_negacyclic_convolution(n, 0);

                let mut prod = vec![0; n];
                plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
                assert_eq!(prod, negacyclic_convolution);
            }
        }
    }

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[test]
    fn reconstruct_32bit_avx() {
        for n in [16, 32, 64, 256, 1024, 2048] {
            use crate::primes32::*;

            let mut value = vec![0u32; n];
            let mut value_avx2 = vec![0u32; n];
            #[cfg(feature = "nightly")]
            let mut value_avx512 = vec![0u32; n];
            let mod_p0 = (0..n).map(|_| random::<u32>() % P0).collect::<Vec<_>>();
            let mod_p1 = (0..n).map(|_| random::<u32>() % P1).collect::<Vec<_>>();
            let mod_p2 = (0..n).map(|_| random::<u32>() % P2).collect::<Vec<_>>();

            for (value, &mod_p0, &mod_p1, &mod_p2) in
                crate::izip!(&mut value, &mod_p0, &mod_p1, &mod_p2)
            {
                *value = reconstruct_32bit_012(mod_p0, mod_p1, mod_p2);
            }

            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_012_avx2(simd, &mut value_avx2, &mod_p0, &mod_p1, &mod_p2);
                assert_eq!(value, value_avx2);
            }
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_012_avx512(
                    simd,
                    &mut value_avx512,
                    &mod_p0,
                    &mod_p1,
                    &mod_p2,
                );
                assert_eq!(value, value_avx512);
            }
        }
    }
}