1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
/*
Copyright (C) 2012 William Hart
Copyright (C) 2025 Fredrik Johansson
This file is part of FLINT.
FLINT is free software: you can redistribute it and/or modify it under
the terms of the GNU Lesser General Public License (LGPL) as published
by the Free Software Foundation; either version 3 of the License, or
(at your option) any later version. See <https://www.gnu.org/licenses/>.
*/
#include "mpn_extras.h"
/*
Hack: flint_mpn_mod_preinvn is currently slow for divisions of
size (n+2,n) or (n+1,n), so we use the following code adapted from
flint_mpn_mod_preinv1. It would be better to improve flint_mpn_mod_preinvn,
but this is not straightforward; flint_mpn_mod_preinvn is only
allowed to overwrite n limbs while the following overwrites m limbs
(which is fine for the use local to this file).
Note: flint_mpn_divrem21_preinv is documented as requiring the
precomputed inverse generated by flint_mpn_preinv1, but it turns out
that it works to take the top limb of an inverse computed by
flint_mpn_preinvn.
*/
/* Reduce {a, m} modulo {b, n} in place; on return the remainder occupies
   the low n limbs of a. The quotient is discarded, and all m limbs of a
   may be overwritten (see the note above). dinv is a single precomputed
   inverse limb; per the note above, the top limb of an inverse computed
   by flint_mpn_preinvn also works here.
   NOTE(review): assumes m >= n and that b is normalised (top bit of
   b[n - 1] set) — both callers in this file guarantee normalisation via
   mpn_lshift by norm; confirm against flint_mpn_divrem21_preinv's
   documented requirements. */
static void flint_mpn_mod_preinv1(mp_ptr a, mp_size_t m,
mp_srcptr b, mp_size_t n, mp_limb_t dinv)
{
mp_size_t i;
mp_limb_t q;
/* Pre-reduce the top n limbs so the first quotient limb fits. */
if (mpn_cmp(a + m - n, b, n) >= 0)
mpn_sub_n(a + m - n, a + m - n, b, n);
for (i = m - 1; i >= n; i--)
{
/* Estimate a quotient limb from the top two limbs of the current
   partial remainder using the precomputed inverse. */
flint_mpn_divrem21_preinv(q, a[i], a[i - 1], dinv);
a[i] -= mpn_submul_1(a + i - n, b, n, q);
/* The estimate can undershoot: apply one corrective subtraction when
   the n-limb window is still >= b or a carry limb remains. */
if (mpn_cmp(a + i - n, b, n) >= 0 || a[i] != 0)
a[i] -= mpn_sub_n(a + i - n, a + i - n, b, n);
}
}
/* Return the number of limbs a caller must allocate for a precomputed
   mulmod matrix with n-limb operands. The matrix proper needs n^2 limbs;
   one extra limb is reserved so that
   flint_mpn_mulmod_precond_matrix_precompute can use it as scratch space
   and thereby avoid an intermediate copy. */
mp_size_t flint_mpn_mulmod_precond_matrix_alloc(mp_size_t n)
{
    mp_size_t matrix_limbs = n * n;
    return matrix_limbs + 1;
}
/* Fill the n x n limb matrix apre for operand a:
   row 0 = a * 2^norm, and for i >= 1,
   row i = (row i-1 shifted up by one limb) mod d,
   so row i is congruent to a * 2^(norm + i*FLINT_BITS) modulo d.
   d is the normalised (left-shifted by norm) n-limb modulus and dinv its
   precomputed inverse; the preinv1 fallback only reads dinv[n - 1].
   Note: copying n limbs to offset i*n + 1 writes one limb past row i.
   For i < n - 1 that limb is row i+1's low limb, which the next
   iteration overwrites; for the last row it is the extra scratch limb
   reserved by flint_mpn_mulmod_precond_matrix_alloc. */
void
flint_mpn_mulmod_precond_matrix_precompute(mp_ptr apre, mp_srcptr a, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm)
{
slong i;
FLINT_ASSERT(n >= 2);
/* Row 0: a shifted into the same normalisation as d. */
if (norm == 0)
flint_mpn_copyi(apre, a, n);
else
mpn_lshift(apre, a, n, norm);
for (i = 1; i < n; i++)
{
/* Multiply the previous row by 2^FLINT_BITS: zero low limb, copy the
   n limbs one position up (producing an n+1 limb value). */
apre[i * n] = 0;
flint_mpn_copyi(apre + i * n + 1, apre + (i - 1) * n, n);
#if 0
flint_mpn_mod_preinvn(apre + i * n, apre + i * n, n + 1, d, n, dinv);
#else
/* Reduce the n+1 limb value back to n limbs (see the hack note at the
   top of this file for why preinv1 is used instead of preinvn). */
flint_mpn_mod_preinv1(apre + i * n, n + 1, d, n, dinv[n - 1]);
#endif
}
}
/* p-mulmod_precond */
/* Heuristic selector: given the operand size n (limbs), the expected
   number of multiplications num sharing the same fixed operand, and the
   modulus normalisation shift norm, decide which preconditioning
   strategy (if any) is worth the setup cost. Thresholds are tuning
   constants. */
int
flint_mpn_mulmod_want_precond(mp_size_t n, slong num, ulong norm)
{
    /* Too few calls to amortise any precomputation; likewise skip the
       n == 2, already-normalised case. */
    if (num < 4)
        return MPN_MULMOD_PRECOND_NONE;
    if (n == 2 && norm == 0)
        return MPN_MULMOD_PRECOND_NONE;

    /* Small operands: Shoup-style preconditioning. */
    if (n <= 10)
        return MPN_MULMOD_PRECOND_SHOUP;
    if (n <= 12 && num <= 12)
        return MPN_MULMOD_PRECOND_SHOUP;

    /* Medium operands: the precomputed matrix pays off, needing more
       calls as n grows. */
    if (n <= 64)
        return MPN_MULMOD_PRECOND_MATRIX;
    if ((n <= 128 && num >= 6) || (n <= 192 && num >= 20))
        return MPN_MULMOD_PRECOND_MATRIX;

    /* Large operands: back to Shoup when there are enough calls. */
    if ((n <= 320 && num >= 9) || (n <= 768 && num >= 20))
        return MPN_MULMOD_PRECOND_SHOUP;

    return MPN_MULMOD_PRECOND_NONE;
}
/* Compute rp = a * b mod (original, unshifted) d, where apre is the
   matrix built by flint_mpn_mulmod_precond_matrix_precompute for a,
   b is an n-limb multiplicand, d is the normalised (left-shifted by
   norm) modulus with precomputed inverse dinv. Internally the product
   sum_i apre_row_i * b[i] is congruent to a * b * 2^norm mod d, so the
   final right shift by norm undoes the normalisation. */
void
flint_mpn_mulmod_precond_matrix(mp_ptr rp, mp_srcptr apre, mp_srcptr b, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm)
{
/*
Note: it is possible to add a special case for n = 2.
For example, we can do something like
FLINT_MPN_MUL_2X1(t[2], t[1], t[0], apre[1], apre[0], b[0]);
FLINT_MPN_MUL_2X1(u[2], u[1], u[0], apre[3], apre[2], b[1]);
add_ssssaaaaaaaa(t[3], t[2], t[1], t[0], 0, t[2], t[1], t[0], 0, u[2], u[1], u[0]);
and then reduce mod d using the same operation sequence as in
flint_mpn_mulmod_preinvn_2.
We omit this special case as the resulting code does not run
appreciable faster than flint_mpn_mulmod_preinvn_2.
*/
if (n == 2)
{
mp_limb_t cy, r0, r1;
mp_limb_t t[10];
mp_limb_t u[3];
/* mpn_mul_n(t, a, b, n) */
FLINT_MPN_MUL_2X1(t[2], t[1], t[0], apre[1], apre[0], b[0]);
FLINT_MPN_MUL_2X1(u[2], u[1], u[0], apre[3], apre[2], b[1]);
add_ssssaaaaaaaa(t[3], t[2], t[1], t[0], 0, t[2], t[1], t[0], 0, u[2], u[1], u[0]);
/* mpn_mul_n(t + 3*n, t + n, dinv, n) */
FLINT_MPN_MUL_2X2(t[9], t[8], t[7], t[6], t[3], t[2], dinv[1], dinv[0]);
/* mpn_add_n(t + 4*n, t + 4*n, t + n, n) */
add_ssaaaa(t[9], t[8], t[9], t[8], t[3], t[2]);
/* mpn_mul_n(t + 2*n, t + 4*n, d, n) */
FLINT_MPN_MUL_3P2X2(t[6], t[5], t[4], t[9], t[8], d[1], d[0]);
/* cy = t[n] - t[3*n] - mpn_sub_n(r, t, t + 2*n, n) */
sub_dddmmmsss(cy, r1, r0, t[2], t[1], t[0], t[6], t[5], t[4]);
while (cy > 0)
{
/* cy -= mpn_sub_n(r, r, d, n) */
sub_dddmmmsss(cy, r1, r0, cy, r1, r0, 0, d[1], d[0]);
}
if ((r1 > d[1]) || (r1 == d[1] && r0 >= d[0]))
{
/* mpn_sub_n(r, r, d, n) */
sub_ddmmss(r1, r0, r1, r0, d[1], d[0]);
}
/* Undo the modulus normalisation before storing the result. */
if (norm)
{
rp[0] = (r0 >> norm) | (r1 << (FLINT_BITS - norm));
rp[1] = (r1 >> norm);
}
else
{
rp[0] = r0;
rp[1] = r1;
}
return;
}
mp_ptr tmp;
mp_limb_t cy, cy1, cy2;
slong i, rn;
TMP_INIT;
TMP_START;
/* n limbs for the dot product plus two overflow limbs. */
tmp = TMP_ALLOC((n + 2) * sizeof(mp_limb_t));
/* Accumulate sum_i apre_row_i * b[i]; cy1/cy2 collect the carries
   overflowing the low n limbs. */
cy1 = mpn_mul_1(tmp, apre, n, b[0]);
cy2 = 0;
for (i = 1; i < n; i++)
{
cy = mpn_addmul_1(tmp, apre + i * n, n, b[i]);
add_ssaaaa(cy2, cy1, cy2, cy1, 0, cy);
}
tmp[n] = cy1;
tmp[n + 1] = cy2;
/* Drop the top limb from the reduction when it is zero. */
rn = (n + 2) - (tmp[n + 1] == 0);
#if 0
flint_mpn_mod_preinvn(tmp, tmp, rn, d, n, dinv);
#else
/* See the hack note at the top of this file. */
flint_mpn_mod_preinv1(tmp, rn, d, n, dinv[n - 1]);
#endif
/* Undo the modulus normalisation before storing the result. */
if (norm == 0)
flint_mpn_copyi(rp, tmp, n);
else
mpn_rshift(rp, tmp, n, norm);
TMP_END;
}
/* Fused multiply-multiply-add mod: compute
   rp = (a1 * b1 + a2 * b2) mod (original, unshifted) d,
   where apre1, apre2 are the matrices built by
   flint_mpn_mulmod_precond_matrix_precompute for a1, a2, and b1, b2 are
   n-limb multiplicands. d is the normalised (left-shifted by norm)
   modulus with precomputed inverse dinv; the final right shift by norm
   undoes the normalisation, as in flint_mpn_mulmod_precond_matrix. */
void
flint_mpn_fmmamod_precond_matrix(mp_ptr rp, mp_srcptr apre1, mp_srcptr b1, mp_srcptr apre2, mp_srcptr b2, mp_size_t n, mp_srcptr d, mp_srcptr dinv, ulong norm)
{
mp_ptr tmp;
mp_limb_t cy, cy1, cy2;
slong i, rn;
TMP_INIT;
/* Something like this if we want a special case for n = 2 */
/* NOTE(review): the sketch below looks wrong as written — apre2[3],
   apre2[2] are reused for both b2[0] and b2[1], and tmp[3] is read in
   the first add_ssssaaaaaaaa before being initialised. Re-derive it
   before enabling. */
/*
if (n == 2)
{
ulong tmp[4];
ulong ump[4];
FLINT_MPN_MUL_2X1(tmp[2], tmp[1], tmp[0], apre1[1], apre1[0], b1[0]);
FLINT_MPN_MUL_2X1(ump[2], ump[1], ump[0], apre1[3], apre1[2], b1[1]);
add_ssssaaaaaaaa(tmp[3], tmp[2], tmp[1], tmp[0], tmp[3], tmp[2], tmp[1], tmp[0], 0, ump[2], ump[1], ump[0]);
FLINT_MPN_MUL_2X1(ump[2], ump[1], ump[0], apre2[3], apre2[2], b2[0]);
add_ssssaaaaaaaa(tmp[3], tmp[2], tmp[1], tmp[0], tmp[3], tmp[2], tmp[1], tmp[0], 0, ump[2], ump[1], ump[0]);
FLINT_MPN_MUL_2X1(ump[2], ump[1], ump[0], apre2[3], apre2[2], b2[1]);
add_ssssaaaaaaaa(tmp[3], tmp[2], tmp[1], tmp[0], tmp[3], tmp[2], tmp[1], tmp[0], 0, ump[2], ump[1], ump[0]);
rn = (n + 2) - (tmp[n + 1] == 0);
flint_mpn_mod_preinv1(tmp, rn, d, n, dinv[n - 1]);
if (norm)
{
rp[0] = (tmp[0] >> norm) | (tmp[1] << (FLINT_BITS - norm));
rp[1] = (tmp[1] >> norm);
}
else
{
rp[0] = tmp[0];
rp[1] = tmp[1];
}
return;
}
*/
TMP_START;
/* n limbs for the combined dot product plus two overflow limbs. */
tmp = TMP_ALLOC((n + 2) * sizeof(mp_limb_t));
/* Accumulate sum_i apre1_row_i * b1[i] + sum_i apre2_row_i * b2[i];
   cy1/cy2 collect the carries overflowing the low n limbs. */
cy1 = mpn_mul_1(tmp, apre1, n, b1[0]);
cy2 = 0;
for (i = 1; i < n; i++)
{
cy = mpn_addmul_1(tmp, apre1 + i * n, n, b1[i]);
add_ssaaaa(cy2, cy1, cy2, cy1, 0, cy);
}
for (i = 0; i < n; i++)
{
cy = mpn_addmul_1(tmp, apre2 + i * n, n, b2[i]);
add_ssaaaa(cy2, cy1, cy2, cy1, 0, cy);
}
tmp[n] = cy1;
tmp[n + 1] = cy2;
/* Drop the top limb from the reduction when it is zero. */
rn = (n + 2) - (tmp[n + 1] == 0);
#if 0
flint_mpn_mod_preinvn(tmp, tmp, rn, d, n, dinv);
#else
/* See the hack note at the top of this file. */
flint_mpn_mod_preinv1(tmp, rn, d, n, dinv[n - 1]);
#endif
/* Undo the modulus normalisation before storing the result. */
if (norm == 0)
flint_mpn_copyi(rp, tmp, n);
else
mpn_rshift(rp, tmp, n, norm);
TMP_END;
}