1use aligned_vec::avec;
2
3#[allow(unused_imports)]
4use pulp::*;
5
6use crate::native32::mul_mod32;
7
/// Negacyclic polynomial multiplication plan for 32-bit coefficients.
///
/// Works by running two independent 32-bit NTTs, one modulo `primes32::P0`
/// and one modulo `primes32::P1` (see [`Plan32::try_new`]); results are
/// recombined with a CRT lift in [`reconstruct_32bit_01`].
#[derive(Clone, Debug)]
pub struct Plan32(crate::prime32::Plan, crate::prime32::Plan);
12
/// Negacyclic polynomial multiplication plan for 32-bit coefficients using a
/// single 64-bit NTT modulo `primes52::P0`, accelerated with AVX-512 IFMA.
///
/// Only available on x86/x86_64 with the `nightly` feature; construction
/// fails at runtime if the CPU lacks the required instruction set
/// (see [`Plan52::try_new`]).
#[cfg(all(feature = "nightly", any(target_arch = "x86", target_arch = "x86_64")))]
#[cfg_attr(docsrs, doc(cfg(feature = "nightly")))]
#[derive(Clone, Debug)]
pub struct Plan52(crate::prime64::Plan, crate::V4IFma);
20
21#[inline(always)]
22pub(crate) fn reconstruct_32bit_01(mod_p0: u32, mod_p1: u32) -> u32 {
23 use crate::primes32::*;
24
25 let v0 = mod_p0;
26 let v1 = mul_mod32(P1, P0_INV_MOD_P1, 2 * P1 + mod_p1 - v0);
27
28 let sign = v1 > (P1 / 2);
29
30 const _0: u32 = P0;
31 const _01: u32 = _0.wrapping_mul(P1);
32
33 let pos = v0.wrapping_add(v1.wrapping_mul(_0));
34 let neg = pos.wrapping_sub(_01);
35
36 if sign {
37 neg
38 } else {
39 pos
40 }
41}
42
// AVX2 (8-lane) version of `reconstruct_32bit_01`: CRT lift from residues
// modulo `primes32::P0`/`P1` to signed-centered 32-bit representatives.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[inline(always)]
pub(crate) fn reconstruct_32bit_01_avx2(simd: crate::V3, mod_p0: u32x8, mod_p1: u32x8) -> u32x8 {
    use crate::{native32::mul_mod32_avx2, primes32::*};

    // Broadcast the constants needed for the lift (see the scalar
    // `reconstruct_32bit_01` for the derivation).
    let p0 = simd.splat_u32x8(P0);
    let p1 = simd.splat_u32x8(P1);
    let two_p1 = simd.splat_u32x8(2 * P1);
    let half_p1 = simd.splat_u32x8(P1 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x8(P0_INV_MOD_P1);
    // Precomputed companion constant — the `_SHOUP` suffix suggests Shoup
    // modular multiplication inside `mul_mod32_avx2`; confirm in `native32`.
    let p0_inv_mod_p1_shoup = simd.splat_u32x8(P0_INV_MOD_P1_SHOUP);

    // Combined modulus P0 * P1, reduced mod 2^32 (wrapping is intentional).
    let p01 = simd.splat_u32x8(P0.wrapping_mul(P1));

    let v0 = mod_p0;
    // v1 = P0^-1 * (mod_p1 - v0) mod P1; `two_p1` keeps the operand nonnegative.
    let v1 = mul_mod32_avx2(
        simd,
        p1,
        simd.wrapping_sub_u32x8(simd.wrapping_add_u32x8(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );

    // Lanes whose lift exceeds (P0 * P1) / 2 (detected via v1 > P1 / 2)
    // represent negatives and get shifted down by the combined modulus.
    let sign = simd.cmp_gt_u32x8(v1, half_p1);
    let pos = simd.wrapping_add_u32x8(v0, simd.wrapping_mul_u32x8(v1, p0));

    let neg = simd.wrapping_sub_u32x8(pos, p01);

    simd.select_u32x8(sign, neg, pos)
}
74
// AVX-512 (16-lane) version of `reconstruct_32bit_01`: CRT lift from residues
// modulo `primes32::P0`/`P1` to signed-centered 32-bit representatives.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_32bit_01_avx512(simd: crate::V4IFma, mod_p0: u32x16, mod_p1: u32x16) -> u32x16 {
    use crate::{native32::mul_mod32_avx512, primes32::*};

    // Broadcast the constants needed for the lift (see the scalar
    // `reconstruct_32bit_01` for the derivation).
    let p0 = simd.splat_u32x16(P0);
    let p1 = simd.splat_u32x16(P1);
    let two_p1 = simd.splat_u32x16(2 * P1);
    let half_p1 = simd.splat_u32x16(P1 / 2);

    let p0_inv_mod_p1 = simd.splat_u32x16(P0_INV_MOD_P1);
    // Precomputed companion constant — the `_SHOUP` suffix suggests Shoup
    // modular multiplication inside `mul_mod32_avx512`; confirm in `native32`.
    let p0_inv_mod_p1_shoup = simd.splat_u32x16(P0_INV_MOD_P1_SHOUP);

    // Combined modulus P0 * P1, reduced mod 2^32 (wrapping is intentional).
    let p01 = simd.splat_u32x16(P0.wrapping_mul(P1));

    let v0 = mod_p0;
    // v1 = P0^-1 * (mod_p1 - v0) mod P1; `two_p1` keeps the operand nonnegative.
    let v1 = mul_mod32_avx512(
        simd,
        p1,
        simd.wrapping_sub_u32x16(simd.wrapping_add_u32x16(two_p1, mod_p1), v0),
        p0_inv_mod_p1,
        p0_inv_mod_p1_shoup,
    );

    // Lanes whose lift exceeds (P0 * P1) / 2 (detected via v1 > P1 / 2)
    // represent negatives and get shifted down by the combined modulus.
    let sign = simd.cmp_gt_u32x16(v1, half_p1);
    let pos = simd.wrapping_add_u32x16(v0, simd.wrapping_mul_u32x16(v1, p0));

    let neg = simd.wrapping_sub_u32x16(pos, p01);

    simd.select_u32x16(sign, neg, pos)
}
107
// Single-prime reconstruction (8 lanes): maps residues modulo `primes52::P0`
// to signed-centered representatives, then narrows u64 -> u32.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
#[inline(always)]
fn reconstruct_52bit_0_avx512(simd: crate::V4IFma, mod_p0: u64x8) -> u32x8 {
    use crate::primes52::*;

    let p0 = simd.splat_u64x8(P0);
    let half_p0 = simd.splat_u64x8(P0 / 2);

    let v0 = mod_p0;

    // Lanes above P0 / 2 represent negative values.
    let sign = simd.cmp_gt_u64x8(v0, half_p0);

    let pos = v0;
    // Wrapping subtraction yields the two's-complement encoding of the
    // negative representative.
    let neg = simd.wrapping_sub_u64x8(pos, p0);

    // NOTE(review): presumably a truncating u64 -> u32 conversion, so negative
    // lanes come out as the correct 32-bit two's-complement values — confirm
    // against the `pulp`/simd definition of `convert_u64x8_to_u32x8`.
    simd.convert_u64x8_to_u32x8(simd.select_u64x8(sign, neg, pos))
}
126
127#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
128fn reconstruct_slice_32bit_01_avx2(
129 simd: crate::V3,
130 value: &mut [u32],
131 mod_p0: &[u32],
132 mod_p1: &[u32],
133) {
134 simd.vectorize(
135 #[inline(always)]
136 move || {
137 let value = pulp::as_arrays_mut::<8, _>(value).0;
138 let mod_p0 = pulp::as_arrays::<8, _>(mod_p0).0;
139 let mod_p1 = pulp::as_arrays::<8, _>(mod_p1).0;
140 for (value, &mod_p0, &mod_p1) in crate::izip!(value, mod_p0, mod_p1) {
141 *value = cast(reconstruct_32bit_01_avx2(simd, cast(mod_p0), cast(mod_p1)));
142 }
143 },
144 );
145}
146
/// Slice-level driver for [`reconstruct_32bit_01_avx512`]: recombines the two
/// residue slices into `value`, 16 lanes at a time.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_32bit_01_avx512(
    simd: crate::V4IFma,
    value: &mut [u32],
    mod_p0: &[u32],
    mod_p1: &[u32],
) {
    simd.vectorize(
        #[inline(always)]
        move || {
            // NOTE(review): any tail shorter than 16 elements is left
            // untouched — assumes slice lengths are multiples of 16; confirm
            // callers guarantee this.
            let (out_chunks, _) = pulp::as_arrays_mut::<16, _>(value);
            let (p0_chunks, _) = pulp::as_arrays::<16, _>(mod_p0);
            let (p1_chunks, _) = pulp::as_arrays::<16, _>(mod_p1);
            for (out, &c0, &c1) in crate::izip!(out_chunks, p0_chunks, p1_chunks) {
                *out = cast(reconstruct_32bit_01_avx512(simd, cast(c0), cast(c1)));
            }
        },
    );
}
171
/// Slice-level driver for [`reconstruct_52bit_0_avx512`]: narrows the single
/// 64-bit residue slice into `value`, 8 lanes at a time.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
fn reconstruct_slice_52bit_0_avx512(simd: crate::V4IFma, value: &mut [u32], mod_p0: &[u64]) {
    simd.vectorize(
        #[inline(always)]
        move || {
            // NOTE(review): any tail shorter than 8 elements is left
            // untouched — assumes slice lengths are multiples of 8; confirm
            // callers guarantee this.
            let (out_chunks, _) = pulp::as_arrays_mut::<8, _>(value);
            let (p0_chunks, _) = pulp::as_arrays::<8, _>(mod_p0);
            for (out, &c0) in crate::izip!(out_chunks, p0_chunks) {
                *out = cast(reconstruct_52bit_0_avx512(simd, cast(c0)));
            }
        },
    );
}
186
impl Plan32 {
    /// Returns a two-prime negacyclic NTT plan for polynomials of size `n`,
    /// or `None` if either per-prime plan cannot be built for that size.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime32::Plan, primes32::*};
        Some(Self(Plan::try_new(n, P0)?, Plan::try_new(n, P1)?))
    }

    /// Returns the polynomial size (number of coefficients) this plan handles.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Forward transform: reduces `value` modulo each prime into
    /// `mod_p0`/`mod_p1`, then runs the forward NTT in each residue domain
    /// (in place).
    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
        for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
            *mod_p0 = value % crate::primes32::P0;
            *mod_p1 = value % crate::primes32::P1;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    /// Forward transform for inputs that need no modular reduction: values
    /// are copied verbatim into both residue slices before the NTTs.
    /// NOTE(review): assumes every entry is already smaller than both primes
    /// (the tests pass binary 0/1 values) — confirm for other callers.
    pub fn fwd_binary(&self, value: &[u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
        for (value, mod_p0, mod_p1) in crate::izip!(value, &mut *mod_p0, &mut *mod_p1) {
            *mod_p0 = *value;
            *mod_p1 = *value;
        }
        self.0.fwd(mod_p0);
        self.1.fwd(mod_p1);
    }

    /// Inverse transform: runs the inverse NTT in each residue domain, then
    /// CRT-recombines the residues into `value`, using the fastest available
    /// reconstruction (AVX-512 > AVX2 > scalar).
    pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u32], mod_p1: &mut [u32]) {
        self.0.inv(mod_p0);
        self.1.inv(mod_p1);

        #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
        {
            // Prefer AVX-512 when compiled with `nightly` and available at
            // runtime, otherwise fall back to AVX2, otherwise scalar below.
            #[cfg(feature = "nightly")]
            if let Some(simd) = crate::V4IFma::try_new() {
                reconstruct_slice_32bit_01_avx512(simd, value, mod_p0, mod_p1);
                return;
            }
            if let Some(simd) = crate::V3::try_new() {
                reconstruct_slice_32bit_01_avx2(simd, value, mod_p0, mod_p1);
                return;
            }
        }

        // Scalar fallback (also the only path on non-x86 targets).
        for (value, &mod_p0, &mod_p1) in crate::izip!(value, &*mod_p0, &*mod_p1) {
            *value = reconstruct_32bit_01(mod_p0, mod_p1);
        }
    }

    /// Computes the negacyclic product of `lhs` and `rhs_binary` into `prod`.
    ///
    /// `rhs_binary` is forwarded without modular reduction (see
    /// [`Self::fwd_binary`]). All three slices must have the same length.
    ///
    /// # Panics
    ///
    /// Panics if `lhs` or `rhs_binary` differs in length from `prod`.
    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs_binary: &[u32]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs_binary.len());

        // Scratch buffers for the residue-domain representations.
        let mut lhs0 = avec![0; n];
        let mut lhs1 = avec![0; n];

        let mut rhs0 = avec![0; n];
        let mut rhs1 = avec![0; n];

        self.fwd(lhs, &mut lhs0, &mut lhs1);
        self.fwd_binary(rhs_binary, &mut rhs0, &mut rhs1);

        // Pointwise multiply (with normalization folded in) per prime.
        self.0.mul_assign_normalize(&mut lhs0, &rhs0);
        self.1.mul_assign_normalize(&mut lhs1, &rhs1);

        self.inv(prod, &mut lhs0, &mut lhs1);
    }
}
263
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
#[cfg(feature = "nightly")]
impl Plan52 {
    /// Returns a single-prime 52-bit plan for polynomials of size `n`, or
    /// `None` if AVX-512 IFMA is unavailable on this CPU or the size is
    /// unsupported by the underlying 64-bit plan.
    pub fn try_new(n: usize) -> Option<Self> {
        use crate::{prime64::Plan, primes52::*};
        let simd = crate::V4IFma::try_new()?;
        Some(Self(Plan::try_new(n, P0)?, simd))
    }

    /// Returns the polynomial size (number of coefficients) this plan handles.
    #[inline]
    pub fn ntt_size(&self) -> usize {
        self.0.ntt_size()
    }

    /// Forward transform: widens each `u32` coefficient to `u64` (no modular
    /// reduction is performed — presumably the 52-bit prime exceeds every u32
    /// value), then runs the forward NTT in place on `mod_p0`.
    pub fn fwd(&self, value: &[u32], mod_p0: &mut [u64]) {
        self.1.vectorize(
            #[inline(always)]
            || {
                for (value, mod_p0) in crate::izip!(value, &mut *mod_p0) {
                    *mod_p0 = *value as u64;
                }
            },
        );
        self.0.fwd(mod_p0);
    }

    /// Forward transform for binary inputs. With a single prime no separate
    /// reduction strategy is needed, so this just delegates to [`Self::fwd`]
    /// (kept for interface parity with [`Plan32::fwd_binary`]).
    pub fn fwd_binary(&self, value: &[u32], mod_p0: &mut [u64]) {
        self.fwd(value, mod_p0);
    }

    /// Inverse transform: runs the inverse NTT, then narrows the 64-bit
    /// residues into `value` via the AVX-512 reconstruction helper.
    pub fn inv(&self, value: &mut [u32], mod_p0: &mut [u64]) {
        self.0.inv(mod_p0);

        let simd = self.1;
        reconstruct_slice_52bit_0_avx512(simd, value, mod_p0);
    }

    /// Computes the negacyclic product of `lhs` and `rhs_binary` into `prod`.
    ///
    /// # Panics
    ///
    /// Panics if `lhs` or `rhs_binary` differs in length from `prod`.
    pub fn negacyclic_polymul(&self, prod: &mut [u32], lhs: &[u32], rhs_binary: &[u32]) {
        let n = prod.len();
        assert_eq!(n, lhs.len());
        assert_eq!(n, rhs_binary.len());

        // Scratch buffers for the 64-bit residue-domain representations.
        let mut lhs0 = avec![0; n];
        let mut rhs0 = avec![0; n];

        self.fwd(lhs, &mut lhs0);
        self.fwd_binary(rhs_binary, &mut rhs0);

        // Pointwise multiply (with normalization folded in).
        self.0.mul_assign_normalize(&mut lhs0, &rhs0);

        self.inv(prod, &mut lhs0);
    }
}
323
#[cfg(test)]
mod tests {
    use super::*;
    use crate::prime32::tests::negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Checks the two-prime 32-bit plan against the reference negacyclic
    /// convolution on random inputs with a binary right-hand side.
    #[test]
    fn reconstruct_32bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let plan = Plan32::try_new(n).unwrap();

            let lhs: Vec<u32> = (0..n).map(|_| random::<u32>()).collect();
            let rhs: Vec<u32> = (0..n).map(|_| random::<u32>() % 2).collect();
            let expected = negacyclic_convolution(n, 0, &lhs, &rhs);

            let mut prod = vec![0u32; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, expected);
        }
    }

    /// Same check for the 52-bit AVX-512 IFMA plan; sizes are skipped
    /// silently when the required instruction set is unavailable at runtime.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    #[cfg(feature = "nightly")]
    #[test]
    fn reconstruct_52bit() {
        for n in [32, 64, 256, 1024, 2048] {
            let plan = match Plan52::try_new(n) {
                Some(plan) => plan,
                None => continue,
            };

            let lhs: Vec<u32> = (0..n).map(|_| random::<u32>()).collect();
            let rhs: Vec<u32> = (0..n).map(|_| random::<u32>() % 2).collect();
            let expected = negacyclic_convolution(n, 0, &lhs, &rhs);

            let mut prod = vec![0u32; n];
            plan.negacyclic_polymul(&mut prod, &lhs, &rhs);
            assert_eq!(prod, expected);
        }
    }
}