1pub(crate) use crate::native64::{mul_mod32, mul_mod64};
2use aligned_vec::avec;
3
/// Plan for negacyclic polynomial arithmetic on `u128` coefficients,
/// decomposed over five 32-bit prime NTT plans (one per prime `P0..P4`
/// from `crate::primes32`) whose results are recombined with the CRT.
pub struct Plan32(
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
    crate::prime32::Plan,
);
11
/// Reconstructs a full `u128` coefficient from its residues modulo the five
/// 32-bit primes `P0..P4`, using a mixed-radix (Garner-style) CRT
/// recombination. The value is interpreted in the interval centered around
/// zero, and negative values are mapped into `u128` by wrapping around
/// `2^128` (all final arithmetic is mod `2^128`).
#[inline(always)]
fn reconstruct_32bit_01234_v2(
    mod_p0: u32,
    mod_p1: u32,
    mod_p2: u32,
    mod_p3: u32,
    mod_p4: u32,
) -> u128 {
    use crate::primes32::*;

    // Pairwise CRT: fold the residues mod P1 and mod P2 into one residue
    // mod P1*P2 (the 64-bit modulus P12). The `2 * P2 + mod_p2 - v1` form
    // keeps the subtraction from underflowing in unsigned arithmetic.
    // NOTE(review): mul_mod32 is assumed to compute a modular product mod its
    // first argument — confirm the exact contract against crate::native64.
    let mod_p12 = {
        let v1 = mod_p1;
        let v2 = mul_mod32(P2, P1_INV_MOD_P2, 2 * P2 + mod_p2 - v1);
        v1 as u64 + (v2 as u64 * P1 as u64)
    };
    // Same pairwise fold for P3 and P4 into a residue mod P3*P4 (= P34).
    let mod_p34 = {
        let v3 = mod_p3;
        let v4 = mul_mod32(P4, P3_INV_MOD_P4, 2 * P4 + mod_p4 - v3);
        v3 as u64 + (v4 as u64 * P3 as u64)
    };

    // Mixed-radix digits: the CRT value equals
    //     v0 + v12 * P0 + v34 * (P0 * P12)
    // with v0 < P0, v12 < P12, v34 < P34.
    let v0 = mod_p0 as u64;
    // NOTE(review): mul_mod64 appears to take the negated modulus plus a
    // Shoup-precomputed companion operand (the `*_SHOUP` constants) —
    // confirm against crate::native64 before relying on this in new code.
    let v12 = mul_mod64(
        P12.wrapping_neg(),
        2 * P12 + mod_p12 - v0,
        P0_INV_MOD_P12,
        P0_INV_MOD_P12_SHOUP,
    );
    let v34 = mul_mod64(
        P34.wrapping_neg(),
        2 * P34 + mod_p34 - (v0 + mul_mod64(P34.wrapping_neg(), v12, P0 as u64, P0_MOD_P34_SHOUP)),
        P012_INV_MOD_P34,
        P012_INV_MOD_P34_SHOUP,
    );

    // A top digit in the upper half of its range marks a value that is
    // negative in the centered representation.
    let sign = v34 > (P34 / 2);

    // Radix weights as u128. The full modulus P0*P12*P34 does not fit in
    // 128 bits, so the products use wrapping_mul: everything below is mod 2^128.
    const _0: u128 = P0 as u128;
    const _012: u128 = _0.wrapping_mul(P12 as u128);
    const _01234: u128 = _012.wrapping_mul(P34 as u128);

    // Weighted sum of the mixed-radix digits, mod 2^128.
    let pos = (v0 as u128)
        .wrapping_add((v12 as u128).wrapping_mul(_0))
        .wrapping_add((v34 as u128).wrapping_mul(_012));
    // Centered (negative) representative: the same value minus the full modulus.
    let neg = pos.wrapping_sub(_01234);

    if sign {
        neg
    } else {
        pos
    }
}
64
65impl Plan32 {
66 pub fn try_new(n: usize) -> Option<Self> {
69 use crate::{prime32::Plan, primes32::*};
70 Some(Self(
71 Plan::try_new(n, P0)?,
72 Plan::try_new(n, P1)?,
73 Plan::try_new(n, P2)?,
74 Plan::try_new(n, P3)?,
75 Plan::try_new(n, P4)?,
76 ))
77 }
78
79 #[inline]
81 pub fn ntt_size(&self) -> usize {
82 self.0.ntt_size()
83 }
84
85 pub fn fwd(
86 &self,
87 value: &[u128],
88 mod_p0: &mut [u32],
89 mod_p1: &mut [u32],
90 mod_p2: &mut [u32],
91 mod_p3: &mut [u32],
92 mod_p4: &mut [u32],
93 ) {
94 for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
95 value,
96 &mut *mod_p0,
97 &mut *mod_p1,
98 &mut *mod_p2,
99 &mut *mod_p3,
100 &mut *mod_p4,
101 ) {
102 *mod_p0 = (value % crate::primes32::P0 as u128) as u32;
103 *mod_p1 = (value % crate::primes32::P1 as u128) as u32;
104 *mod_p2 = (value % crate::primes32::P2 as u128) as u32;
105 *mod_p3 = (value % crate::primes32::P3 as u128) as u32;
106 *mod_p4 = (value % crate::primes32::P4 as u128) as u32;
107 }
108 self.0.fwd(mod_p0);
109 self.1.fwd(mod_p1);
110 self.2.fwd(mod_p2);
111 self.3.fwd(mod_p3);
112 self.4.fwd(mod_p4);
113 }
114
115 pub fn fwd_binary(
116 &self,
117 value: &[u128],
118 mod_p0: &mut [u32],
119 mod_p1: &mut [u32],
120 mod_p2: &mut [u32],
121 mod_p3: &mut [u32],
122 mod_p4: &mut [u32],
123 ) {
124 for (value, mod_p0, mod_p1, mod_p2, mod_p3, mod_p4) in crate::izip!(
125 value,
126 &mut *mod_p0,
127 &mut *mod_p1,
128 &mut *mod_p2,
129 &mut *mod_p3,
130 &mut *mod_p4,
131 ) {
132 *mod_p0 = *value as u32;
133 *mod_p1 = *value as u32;
134 *mod_p2 = *value as u32;
135 *mod_p3 = *value as u32;
136 *mod_p4 = *value as u32;
137 }
138 self.0.fwd(mod_p0);
139 self.1.fwd(mod_p1);
140 self.2.fwd(mod_p2);
141 self.3.fwd(mod_p3);
142 self.4.fwd(mod_p4);
143 }
144
145 pub fn inv(
146 &self,
147 value: &mut [u128],
148 mod_p0: &mut [u32],
149 mod_p1: &mut [u32],
150 mod_p2: &mut [u32],
151 mod_p3: &mut [u32],
152 mod_p4: &mut [u32],
153 ) {
154 self.0.inv(mod_p0);
155 self.1.inv(mod_p1);
156 self.2.inv(mod_p2);
157 self.3.inv(mod_p3);
158 self.4.inv(mod_p4);
159
160 for (value, &mod_p0, &mod_p1, &mod_p2, &mod_p3, &mod_p4) in
161 crate::izip!(value, &*mod_p0, &*mod_p1, &*mod_p2, &*mod_p3, &*mod_p4)
162 {
163 *value = reconstruct_32bit_01234_v2(mod_p0, mod_p1, mod_p2, mod_p3, mod_p4);
164 }
165 }
166
167 pub fn negacyclic_polymul(&self, prod: &mut [u128], lhs: &[u128], rhs: &[u128]) {
170 let n = prod.len();
171 assert_eq!(n, lhs.len());
172 assert_eq!(n, rhs.len());
173
174 let mut lhs0 = avec![0; n];
175 let mut lhs1 = avec![0; n];
176 let mut lhs2 = avec![0; n];
177 let mut lhs3 = avec![0; n];
178 let mut lhs4 = avec![0; n];
179
180 let mut rhs0 = avec![0; n];
181 let mut rhs1 = avec![0; n];
182 let mut rhs2 = avec![0; n];
183 let mut rhs3 = avec![0; n];
184 let mut rhs4 = avec![0; n];
185
186 self.fwd(lhs, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
187 self.fwd_binary(rhs, &mut rhs0, &mut rhs1, &mut rhs2, &mut rhs3, &mut rhs4);
188
189 self.0.mul_assign_normalize(&mut lhs0, &rhs0);
190 self.1.mul_assign_normalize(&mut lhs1, &rhs1);
191 self.2.mul_assign_normalize(&mut lhs2, &rhs2);
192 self.3.mul_assign_normalize(&mut lhs3, &rhs3);
193 self.4.mul_assign_normalize(&mut lhs4, &rhs4);
194
195 self.inv(prod, &mut lhs0, &mut lhs1, &mut lhs2, &mut lhs3, &mut lhs4);
196 }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use crate::native128::tests::negacyclic_convolution;
    use alloc::{vec, vec::Vec};
    use rand::random;

    extern crate alloc;

    /// Round-trip check: the NTT-based product must match the schoolbook
    /// negacyclic convolution for a random `u128` polynomial times a random
    /// binary polynomial, across several power-of-two sizes.
    #[test]
    fn reconstruct_32bit() {
        for size in [32, 64, 256, 1024, 2048] {
            let plan = Plan32::try_new(size).unwrap();

            let a: Vec<u128> = (0..size).map(|_| random()).collect();
            let b: Vec<u128> = (0..size).map(|_| random::<u128>() % 2).collect();
            let expected = negacyclic_convolution(size, &a, &b);

            let mut actual = vec![0u128; size];
            plan.negacyclic_polymul(&mut actual, &a, &b);
            assert_eq!(actual, expected);
        }
    }
}