use aligned_vec::avec;

#[allow(unused_imports)]
use pulp::*;

pub(crate) use crate::native32::mul_mod32;

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub(crate) use crate::native32::mul_mod32_avx2;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
pub(crate) use crate::native32::{mul_mod32_avx512, mul_mod52_avx512};

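/// Negacyclic NTT plan for multiplying a polynomial of `u64` coefficients with a binary
/// polynomial, implemented via the Chinese remainder theorem over three 32-bit primes.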
#[derive(Clone, Debug)]
pub struct Plan32(
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
);

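/// Negacyclic NTT plan for multiplying a polynomial of `u64` coefficients with a binary
/// polynomial, implemented via the Chinese remainder theorem over two 52-bit primes.
/// Requires AVX512-IFMA, so plan creation can fail on unsupported CPUs.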
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(crate::prime64::Plan, crate::prime64::Plan, crate::V4IFma);

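// Garner-style CRT reconstruction: recover the mixed-radix digits of the value from its
// residues, so that
//   value = v0 + v1 * P0 + v2 * P0 * P1  (mod P0 * P1 * P2),
// then return the centered representative, using the top digit v2 to decide whether to
// subtract P0 * P1 * P2 (all arithmetic wrapping in u64).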
#[inline(always)]
#[allow(dead_code)]
fn reconstruct_32bit_012(mod_p0: u32, mod_p1: u32, mod_p2: u32) -> u64 {
    use crate::primes32::*;

    let v0 = mod_p0;
    let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
    let v2 = mul_mod32(
        P2,
        P01_INV_MOD_P2,
        2 * P2 + mod_p2 - (v0 + mul_mod32(P2, P0, v1)),
    );

    let sign = v2 > (P2 / 2);

    const _0: u64 = P0 as u64;
    const _01: u64 = _0.wrapping_mul(P1 as u64);
    const _012: u64 = _01.wrapping_mul(P2 as u64);

    let pos = (v0 as u64)
        .wrapping_add((v1 as u64).wrapping_mul(_0))
        .wrapping_add((v2 as u64).wrapping_mul(_01));

    let neg = pos.wrapping_sub(_012);

    if sign {
        neg
    } else {
        pos
    }
}

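// AVX2 version of `reconstruct_32bit_012`: processes eight residues per call and returns
// the reconstructed values as two vectors of four u64 lanes.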
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_012_avx2(
    simd: crate::V3,
    mod_p0: u32x8,
    mod_p1: u32x8,
    mod_p2: u32x8,
) -> [u64x4; 2] {
    use crate::primes32::*;

    let p0 = simd.splat_u32x8(P0);
    let p1 = simd.splat_u32x8(P1);
    let p2 = simd.splat_u32x8(P2);
    let two_p1 = simd.splat_u32x8(2 * P1);
    let two_p2 = simd.splat_u32x8(2 * P2);
    let half_p2 = simd.splat_u32x8(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x8(P0_MOD_P2_SHOUP);

    let p01_inv_mod_p2 = simd.splat_u32x8(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x8(P01_INV_MOD_P2_SHOUP);

    let p01 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64));
    let p012 = simd.splat_u64x4((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));

    let v0 = mod_p0;
    let v1 = mul_mod32_avx2(
        simd,
        p1,
        simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx2(
        simd,
        p2,
        simd.wrapping_sub_u32x8(
            simd.wrapping_add_u32x8(two_p2, mod_p2),
            simd.wrapping_add_u32x8(v0, mul_mod32_avx2(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    let sign = simd.cmp_gt_u32x8(v2, half_p2);
    let sign: [i32x4; 2] = pulp::cast(sign);
    let sign0: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[0])) };
    let sign1: m64x4 = unsafe { core::mem::transmute(simd.convert_i32x4_to_i64x4(sign[1])) };

    let v0: [u32x4; 2] = pulp::cast(v0);
    let v1: [u32x4; 2] = pulp::cast(v1);
    let v2: [u32x4; 2] = pulp::cast(v2);
    let v00 = simd.convert_u32x4_to_u64x4(v0[0]);
    let v01 = simd.convert_u32x4_to_u64x4(v0[1]);
    let v10 = simd.convert_u32x4_to_u64x4(v1[0]);
    let v11 = simd.convert_u32x4_to_u64x4(v1[1]);
    let v20 = simd.convert_u32x4_to_u64x4(v2[0]);
    let v21 = simd.convert_u32x4_to_u64x4(v2[1]);

    let pos0 = v00;
    let pos0 = simd.wrapping_add_u64x4(pos0, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v10));
    let pos0 = simd.wrapping_add_u64x4(
        pos0,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v20),
    );

    let pos1 = v01;
    let pos1 = simd.wrapping_add_u64x4(pos1, simd.mul_low_32_bits_u64x4(pulp::cast(p0), v11));
    let pos1 = simd.wrapping_add_u64x4(
        pos1,
        simd.wrapping_mul_lhs_with_low_32_bits_of_rhs_u64x4(p01, v21),
    );

    let neg0 = simd.wrapping_sub_u64x4(pos0, p012);
    let neg1 = simd.wrapping_sub_u64x4(pos1, p012);

    [
        simd.select_u64x4(sign0, neg0, pos0),
        simd.select_u64x4(sign1, neg1, pos1),
    ]
}

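// AVX512 version of `reconstruct_32bit_012`: processes sixteen residues per call and returns
// the reconstructed values as two vectors of eight u64 lanes.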
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[allow(dead_code)]
#[inline(always)]
fn reconstruct_32bit_012_avx512(
    simd: crate::V4IFma,
    mod_p0: u32x16,
    mod_p1: u32x16,
    mod_p2: u32x16,
) -> [u64x8; 2] {
    use crate::primes32::*;

    let p0 = simd.splat_u32x16(P0);
    let p1 = simd.splat_u32x16(P1);
    let p2 = simd.splat_u32x16(P2);
    let two_p1 = simd.splat_u32x16(2 * P1);
    let two_p2 = simd.splat_u32x16(2 * P2);
    let half_p2 = simd.splat_u32x16(P2 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);
    let p0_mod_p2_shoup = simd.splat_u32x16(P0_MOD_P2_SHOUP);

    let p01_inv_mod_p2 = simd.splat_u32x16(P01_INV_MOD_P2);
    let p01_inv_mod_p2_shoup = simd.splat_u32x16(P01_INV_MOD_P2_SHOUP);

    let p01 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64));
    let p012 = simd.splat_u64x8((P0 as u64).wrapping_mul(P1 as u64).wrapping_mul(P2 as u64));

    let v0 = mod_p0;
    let v1 = mul_mod32_avx512(
        simd,
        p1,
        simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );
    let v2 = mul_mod32_avx512(
        simd,
        p2,
        simd.wrapping_sub_u32x16(
            simd.wrapping_add_u32x16(two_p2, mod_p2),
            simd.wrapping_add_u32x16(v0, mul_mod32_avx512(simd, p2, v1, p0, p0_mod_p2_shoup)),
        ),
        p01_inv_mod_p2,
        p01_inv_mod_p2_shoup,
    );

    let sign = simd.cmp_gt_u32x16(v2, half_p2).0;
    let sign0 = b8(sign as u8);
    let sign1 = b8((sign >> 8) as u8);
    let v0: [u32x8; 2] = pulp::cast(v0);
    let v1: [u32x8; 2] = pulp::cast(v1);
    let v2: [u32x8; 2] = pulp::cast(v2);
    let v00 = simd.convert_u32x8_to_u64x8(v0[0]);
    let v01 = simd.convert_u32x8_to_u64x8(v0[1]);
    let v10 = simd.convert_u32x8_to_u64x8(v1[0]);
    let v11 = simd.convert_u32x8_to_u64x8(v1[1]);
    let v20 = simd.convert_u32x8_to_u64x8(v2[0]);
    let v21 = simd.convert_u32x8_to_u64x8(v2[1]);

    let pos0 = v00;
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v10));
    let pos0 = simd.wrapping_add_u64x8(pos0, simd.wrapping_mul_u64x8(p01, v20));

    let pos1 = v01;
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.mul_low_32_bits_u64x8(pulp::cast(p0), v11));
    let pos1 = simd.wrapping_add_u64x8(pos1, simd.wrapping_mul_u64x8(p01, v21));

    let neg0 = simd.wrapping_sub_u64x8(pos0, p012);
    let neg1 = simd.wrapping_sub_u64x8(pos1, p012);

    [
        simd.select_u64x8(sign0, neg0, pos0),
        simd.select_u64x8(sign1, neg1, pos1),
    ]
}

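// Two-prime (52-bit) reconstruction: a single Garner step gives
//   value = v0 + v1 * P0  (mod P0 * P1),
// centered using the top digit v1.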
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_01_avx512(simd: crate::V4IFma, mod_p0: u64x8, mod_p1: u64x8) -> u64x8 {
    use crate::primes52::*;

    let p0 = simd.splat_u64x8(P0);
    let p1 = simd.splat_u64x8(P1);
    let neg_p1 = simd.splat_u64x8(P1.wrapping_neg());
    let two_p1 = simd.splat_u64x8(2 * P1);
    let half_p1 = simd.splat_u64x8(P1 / 2);

    let p0_inv_mod_p1 = simd.splat_u64x8(P0_INV_MOD_P1);
    let p0_inv_mod_p1_shoup = simd.splat_u64x8(P0_INV_MOD_P1_SHOUP);

    let p01 = simd.splat_u64x8(P0.wrapping_mul(P1));

    let v0 = mod_p0;
    let v1 = mul_mod52_avx512(
        simd,
        p1,
        neg_p1,
        simd.wrapping_sub_u64x8(simd.wrapping_add_u64x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );

    let sign = simd.cmp_gt_u64x8(v1, half_p1);

    let pos = simd.wrapping_add_u64x8(v0, simd.wrapping_mul_u64x8(v1, p0));
    let neg = simd.wrapping_sub_u64x8(pos, p01);

    simd.select_u64x8(sign, neg, pos)
}

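// Slice-level drivers: apply the vectorized reconstruction across whole slices from inside
// `simd.vectorize`, so the closure is compiled with the corresponding target features.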
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
fn reconstruct_slice_32bit_012_avx2(
    simd: crate::V3,
    value: &mut [u64],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<8, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_32bit_012_avx2(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_012_avx512(
    simd: crate::V4IFma,
    value: &mut [u64],
    mod_p0: &[u32],
    mod_p1: &[u32],
    mod_p2: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<16, _>(value).0;
            let mod_p0 = pulp::as_arrays::<16, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<16, _>(mod_p1).0;
            let mod_p2 = pulp::as_arrays::<16, _>(mod_p2).0;
            for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, mod_p0, mod_p1, mod_p2) {
                *value = cast(reconstruct_32bit_012_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                    cast(mod_p2),
                ));
            }
        },
    );
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_01_avx512(
    simd: crate::V4IFma,
    value: &mut [u64],
    mod_p0: &[u64],
    mod_p1: &[u64],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            let value = pulp::as_arrays_mut::<8, _>(value).0;
            let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
            let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
            for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
                *value = cast(reconstruct_52bit_01_avx512(
                    simd,
                    cast(mod_p0),
                    cast(mod_p1),
                ));
            }
        },
    );
}

impl Plan32 {
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        Some(Self(
            Plan::try_new(n, P0)?,
            Plan::try_new(n, P1)?,
            Plan::try_new(n, P2)?,
        ))
    }

    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

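    // Reduce each coefficient modulo the three primes, then apply the forward NTT in each
    // residue domain.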
    pub fn fwd(&self, value: &[u64], mod_p0: &mut [u32], mod_p1: &mut [u32], mod_p2: &mut [u32]) {
        for (value, mod_p0, mod_p1, mod_p2) in
            crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
        {
            *mod_p0 = (value % crate::primes32::P0 as u64) as u32;
            *mod_p1 = (value % crate::primes32::P1 as u64) as u32;
            *mod_p2 = (value % crate::primes32::P2 as u64) as u32;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
    }

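    // For binary coefficients (0 or 1), the values are already reduced modulo each prime,
    // so a plain truncating cast replaces the modular reductions done in `fwd`.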
    pub fn fwd_binary(
        &self,
        value: &[u64],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
    ) {
        for (value, mod_p0, mod_p1, mod_p2) in
            crate::izip!(value, &mut *mod_p0, &mut *mod_p1, &mut *mod_p2)
        {
            *mod_p0 = *value as u32;
            *mod_p1 = *value as u32;
            *mod_p2 = *value as u32;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
        self.2.fwd(mod_p2);
    }

    pub fn inv(
        &self,
        value: &mut [u64],
        mod_p0: &mut [u32],
        mod_p1: &mut [u32],
        mod_p2: &mut [u32],
    ) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);
        self.2.inv(mod_p2);

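        // Runtime dispatch for the CRT reconstruction: prefer AVX512-IFMA when available,
        // then AVX2, falling back to the scalar loop below.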
        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_012_avx512(simd, value, mod_p0, mod_p1, mod_p2);
                return;
            }
            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_012_avx2(simd, value, mod_p0, mod_p1, mod_p2);
                return;
            }
        }

        for (value, &mod_p0, &mod_p1, &mod_p2) in crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2)
        {
            *value = reconstruct_32bit_012(mod_p0, mod_p1, mod_p2);
        }
    }

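    // Full negacyclic product: forward NTTs of both operands, pointwise multiplication with
    // normalization modulo each prime, then inverse NTTs and CRT reconstruction into `prod`.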
    pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs_binary: &[u64]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs_binary.len());

        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];
        let mut lhs2 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];
        let mut rhs2 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2);
        self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1, &mut rhs2);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);
        self.2.mul_assign_normalize(&mut lhs2, &rhs2);

        self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2);
    }
}

#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime64::Plan, primes52::*};
        let simd = crate::V4IFma::try_new()?;
        Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?, simd))
    }

    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

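    // The reduction loop runs inside `vectorize`, which compiles the closure with the
    // detected target features enabled; this likely helps the compiler vectorize the
    // `u64` modulo operations.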
    pub fn fwd(&self, value: &[u64], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        use crate::primes52::*;
        self.2.vectorize(
            #[inline(always)]
            || {
                for (&value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
                    *mod_p0 = value % P0;
                    *mod_p1 = value % P1;
                }
            },
        );
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    pub fn fwd_binary(&self, value: &[u64], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        self.2.vectorize(
            #[inline(always)]
            || {
                for (&value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
                    *mod_p0 = value;
                    *mod_p1 = value;
                }
            },
        );
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    pub fn inv(&self, value: &mut [u64], mod_p0: &mut [u64], mod_p1: &mut [u64]) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);

        reconstruct_slice_52bit_01_avx512(self.2, value, mod_p0, mod_p1);
    }

    pub fn negacyclic_polymul(&self, prod: &mut [u64], lhs: &[u64], rhs_binary: &[u64]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs_binary.len());

        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1);
        self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1);

        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);

        self.inv(prod, &mut lhs0, &mut lhs1);
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime64::tests::negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let plan = Plan32::try_new(n).unwrap();

            let lhs = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
            let rhs = (0..n).map(|_| random::<u64>() % 2).collect::<Vec<_>>();
            let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);

            let mut prod = vec![0; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, negacyclic_convolution);
        }
    }
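
    // Additional sanity-check sketch (not from the original suite): any `u64` value is its
    // own centered representative modulo P0 * P1 * P2 (the three-prime product comfortably
    // exceeds 2^65), so scalar reconstruction from the three residues should round-trip
    // exactly.
    #[test]
    fn reconstruct_32bit_scalar_roundtrip() {
        use crate::primes32::{P0, P1, P2};

        for _ in 0..1000 {
            let x = random::<u64>();
            let mod_p0 = (x % P0 as u64) as u32;
            let mod_p1 = (x % P1 as u64) as u32;
            let mod_p2 = (x % P2 as u64) as u32;
            assert_eq!(reconstruct_32bit_012(mod_p0, mod_p1, mod_p2), x);
        }
    }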

    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn reconstruct_52bit() {
        for n in [32, 64, 256, 1024, 2048] {
            if let Some(plan) = Plan52::try_new(n) {
                let lhs = (0..n).map(|_| random::<u64>()).collect::<Vec<_>>();
                let rhs = (0..n).map(|_| random::<u64>() % 2).collect::<Vec<_>>();
                let negacyclic_convolution = negacyclic_convolution(n, 0, &lhs, &rhs);

                let mut prod = vec![0; n];
                plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
                assert_eq!(prod, negacyclic_convolution);
            }
        }
    }
}