1#![allow(clippy::too_many_arguments)]
5#![allow(clippy::uninit_vec)]
6
7use alloc::vec::Vec;
8
9use crate::{
10 common::is_short_half,
11 fft,
12 fpr::*,
13 rng::{prng_get_u64, prng_get_u8, prng_init, Prng},
14 shake::InnerShake256Context,
15};
16
/// Size, in `Fpr` slots, of a tree of depth `logn`: `(logn + 1) * 2^logn`.
///
/// The tree builders in this file use it to locate the second sub-tree of a
/// node, which starts `ffldl_treesize(logn - 1)` slots after the first one.
fn ffldl_treesize(logn: u32) -> usize {
    let levels = (logn + 1) as usize;
    levels << logn
}
26
27fn ffldl_fft_inner(tree: &mut [Fpr], g0: &mut [Fpr], g1: &mut [Fpr], logn: u32, tmp: &mut [Fpr]) {
33 let n: usize = 1 << logn;
34 if n == 1 {
35 tree[0] = g0[0];
36 return;
37 }
38 let hn = n >> 1;
39
40 fft::poly_ldlmv_fft(tmp, tree, g0, g1, g0, logn);
42
43 {
46 let (g1_lo, g1_hi) = g1.split_at_mut(hn);
47 fft::poly_split_fft(g1_lo, g1_hi, &g0[..n], logn);
48 }
49 {
50 let (g0_lo, g0_hi) = g0.split_at_mut(hn);
51 fft::poly_split_fft(g0_lo, g0_hi, &tmp[..n], logn);
52 }
53
54 {
56 let (g1_lo, g1_rest) = g1.split_at_mut(hn);
57 ffldl_fft_inner(&mut tree[n..], g1_lo, g1_rest, logn - 1, tmp);
58 }
59
60 let off = n + ffldl_treesize(logn - 1);
62 {
63 let (g0_lo, g0_rest) = g0.split_at_mut(hn);
64 ffldl_fft_inner(&mut tree[off..], g0_lo, g0_rest, logn - 1, tmp);
65 }
66}
67
68fn ffldl_fft(tree: &mut [Fpr], g00: &[Fpr], g01: &[Fpr], g11: &[Fpr], logn: u32, tmp: &mut [Fpr]) {
71 let n: usize = 1 << logn;
72 if n == 1 {
73 tree[0] = g00[0];
74 return;
75 }
76 let hn = n >> 1;
77
78 let d00 = &mut tmp[..n];
79 d00.copy_from_slice(&g00[..n]);
80
81 let (d00_slice, rest) = tmp.split_at_mut(n);
83 let (d11_slice, scratch) = rest.split_at_mut(n);
84
85 fft::poly_ldlmv_fft(d11_slice, tree, g00, g01, g11, logn);
86
87 {
89 let (s0, s1) = scratch.split_at_mut(hn);
90 fft::poly_split_fft(s0, s1, d00_slice, logn);
91 }
92 {
94 let (d0, d1) = d00_slice.split_at_mut(hn);
95 fft::poly_split_fft(d0, d1, d11_slice, logn);
96 }
97 d11_slice[..n].copy_from_slice(&scratch[..n]);
99
100 {
102 let (d11_lo, d11_hi) = d11_slice.split_at_mut(hn);
103 ffldl_fft_inner(&mut tree[n..], d11_lo, d11_hi, logn - 1, scratch);
104 }
105
106 let off = n + ffldl_treesize(logn - 1);
108 {
109 let (d00_lo, d00_hi) = d00_slice.split_at_mut(hn);
110 ffldl_fft_inner(&mut tree[off..], d00_lo, d00_hi, logn - 1, scratch);
111 }
112}
113
114fn ffldl_binary_normalize(tree: &mut [Fpr], orig_logn: u32, logn: u32) {
117 let n: usize = 1 << logn;
118 if n == 1 {
119 tree[0] = fpr_mul(fpr_sqrt(tree[0]), FPR_INV_SIGMA[orig_logn as usize]);
120 } else {
121 ffldl_binary_normalize(&mut tree[n..], orig_logn, logn - 1);
122 let off = n + ffldl_treesize(logn - 1);
123 ffldl_binary_normalize(&mut tree[off..], orig_logn, logn - 1);
124 }
125}
126
127fn smallints_to_fpr(r: &mut [Fpr], t: &[i8], logn: u32) {
134 let n: usize = 1 << logn;
135 for u in 0..n {
136 r[u] = fpr_of(t[u] as i64);
137 }
138}
139
/// Offset of the `b00` polynomial inside an expanded key: always the start.
#[inline(always)]
fn skoff_b00(_logn: u32) -> usize {
    0usize
}
/// Offset of the `b01` polynomial inside an expanded key: `2^logn`.
#[inline(always)]
fn skoff_b01(logn: u32) -> usize {
    1usize << logn
}
/// Offset of the `b10` polynomial inside an expanded key: `2 * 2^logn`.
#[inline(always)]
fn skoff_b10(logn: u32) -> usize {
    2usize << logn
}
/// Offset of the `b11` polynomial inside an expanded key: `3 * 2^logn`.
#[inline(always)]
fn skoff_b11(logn: u32) -> usize {
    3usize << logn
}
/// Offset of the tree inside an expanded key: `4 * 2^logn`, i.e. right after
/// the four basis polynomials.
#[inline(always)]
fn skoff_tree(logn: u32) -> usize {
    4usize << logn
}
161
162pub fn expand_privkey(
165 expanded_key: &mut [Fpr],
166 f: &[i8],
167 g: &[i8],
168 big_f: &[i8],
169 big_g: &[i8],
170 logn: u32,
171 tmp: &mut [u8],
172) {
173 let n: usize = 1 << logn;
174
175 let b00_off = skoff_b00(logn);
176 let b01_off = skoff_b01(logn);
177 let b10_off = skoff_b10(logn);
178 let b11_off = skoff_b11(logn);
179 let tree_off = skoff_tree(logn);
180
181 smallints_to_fpr(&mut expanded_key[b01_off..], f, logn);
183 smallints_to_fpr(&mut expanded_key[b00_off..], g, logn);
184 smallints_to_fpr(&mut expanded_key[b11_off..], big_f, logn);
185 smallints_to_fpr(&mut expanded_key[b10_off..], big_g, logn);
186
187 fft::fft(&mut expanded_key[b01_off..b01_off + n], logn);
188 fft::fft(&mut expanded_key[b00_off..b00_off + n], logn);
189 fft::fft(&mut expanded_key[b11_off..b11_off + n], logn);
190 fft::fft(&mut expanded_key[b10_off..b10_off + n], logn);
191 fft::poly_neg(&mut expanded_key[b01_off..b01_off + n], logn);
192 fft::poly_neg(&mut expanded_key[b11_off..b11_off + n], logn);
193
194 let ftmp: &mut [Fpr] = unsafe {
196 core::slice::from_raw_parts_mut(
197 tmp.as_mut_ptr() as *mut Fpr,
198 tmp.len() / core::mem::size_of::<Fpr>(),
199 )
200 };
201
202 let (g00, rest) = ftmp.split_at_mut(n);
203 let (g01_g, rest) = rest.split_at_mut(n);
204 let (g11_g, gxx) = rest.split_at_mut(n);
205
206 g00.copy_from_slice(&expanded_key[b00_off..b00_off + n]);
207 fft::poly_mulselfadj_fft(g00, logn);
208 gxx[..n].copy_from_slice(&expanded_key[b01_off..b01_off + n]);
209 fft::poly_mulselfadj_fft(&mut gxx[..n], logn);
210 fft::poly_add(g00, &gxx[..n], logn);
211
212 g01_g.copy_from_slice(&expanded_key[b00_off..b00_off + n]);
213 fft::poly_muladj_fft(g01_g, &expanded_key[b10_off..b10_off + n], logn);
214 gxx[..n].copy_from_slice(&expanded_key[b01_off..b01_off + n]);
215 fft::poly_muladj_fft(&mut gxx[..n], &expanded_key[b11_off..b11_off + n], logn);
216 fft::poly_add(g01_g, &gxx[..n], logn);
217
218 g11_g.copy_from_slice(&expanded_key[b10_off..b10_off + n]);
219 fft::poly_mulselfadj_fft(g11_g, logn);
220 gxx[..n].copy_from_slice(&expanded_key[b11_off..b11_off + n]);
221 fft::poly_mulselfadj_fft(&mut gxx[..n], logn);
222 fft::poly_add(g11_g, &gxx[..n], logn);
223
224 ffldl_fft(&mut expanded_key[tree_off..], g00, g01_g, g11_g, logn, gxx);
225 ffldl_binary_normalize(&mut expanded_key[tree_off..], logn, logn);
226}
227
/// State handed to the integer sampler (`sampler`) on every call.
pub struct SamplerContext {
    /// PRNG from which `sampler`, `gaussian0_sampler` and `ber_exp` draw
    /// their random bytes.
    pub p: Prng,
    /// Multiplied with `isigma` inside `sampler` to obtain the `ccs` factor
    /// passed on to `ber_exp`; `sign_tree` / `sign_dyn` set it from
    /// `FPR_SIGMA_MIN[logn]`.
    pub sigma_min: Fpr,
}
237
/// Signature of the integer sampler used by the `ff_sampling_*` routines:
/// takes the sampler state and two `Fpr` parameters (named `mu` and `isigma`
/// in `sampler`) and returns the sampled integer.
type SamplerZ = fn(&mut SamplerContext, Fpr, Fpr) -> i32;
240
241fn ff_sampling_fft_dyntree(
247 samp: SamplerZ,
248 samp_ctx: &mut SamplerContext,
249 t0: &mut [Fpr],
250 t1: &mut [Fpr],
251 g00: &mut [Fpr],
252 g01: &mut [Fpr],
253 g11: &mut [Fpr],
254 orig_logn: u32,
255 logn: u32,
256 tmp: &mut [Fpr],
257) {
258 if logn == 0 {
259 let leaf = fpr_mul(fpr_sqrt(g00[0]), FPR_INV_SIGMA[orig_logn as usize]);
260 t0[0] = fpr_of(samp(samp_ctx, t0[0], leaf) as i64);
261 t1[0] = fpr_of(samp(samp_ctx, t1[0], leaf) as i64);
262 return;
263 }
264
265 let n: usize = 1 << logn;
266 let hn = n >> 1;
267
268 fft::poly_ldl_fft(g00, g01, g11, logn);
270
271 {
273 let (t_lo, t_hi) = tmp.split_at_mut(hn);
274 fft::poly_split_fft(t_lo, t_hi, g00, logn);
275 }
276 g00[..n].copy_from_slice(&tmp[..n]);
277 {
278 let (t_lo, t_hi) = tmp.split_at_mut(hn);
279 fft::poly_split_fft(t_lo, t_hi, g11, logn);
280 }
281 g11[..n].copy_from_slice(&tmp[..n]);
282 tmp[..n].copy_from_slice(&g01[..n]);
283 g01[..hn].copy_from_slice(&g00[..hn]);
284 g01[hn..n].copy_from_slice(&g11[..hn]);
285
286 {
288 let z1 = &mut tmp[n..];
289 let (z1_lo, z1_hi_and_rest) = z1.split_at_mut(hn);
290 let (z1_hi, scratch) = z1_hi_and_rest.split_at_mut(hn);
291 fft::poly_split_fft(z1_lo, z1_hi, t1, logn);
292
293 let (g11_lo, g11_hi) = g11.split_at_mut(hn);
294 ff_sampling_fft_dyntree(
295 samp,
296 samp_ctx,
297 z1_lo,
298 z1_hi,
299 g11_lo,
300 g11_hi,
301 &mut g01[hn..],
302 orig_logn,
303 logn - 1,
304 scratch,
305 );
306
307 let mut z1_lo_copy = Vec::<Fpr>::with_capacity(hn);
309 let mut z1_hi_copy = Vec::<Fpr>::with_capacity(hn);
310 unsafe {
311 z1_lo_copy.set_len(hn);
312 z1_hi_copy.set_len(hn);
313 }
314 z1_lo_copy.copy_from_slice(z1_lo);
315 z1_hi_copy.copy_from_slice(z1_hi);
316
317 fft::poly_merge_fft(&mut scratch[..n], &z1_lo_copy, &z1_hi_copy, logn);
319
320 z1_lo.copy_from_slice(&t1[..hn]);
323 z1_hi.copy_from_slice(&t1[hn..n]);
324 let mut merged_copy = Vec::<Fpr>::with_capacity(n);
326 unsafe {
327 merged_copy.set_len(n);
328 }
329 merged_copy.copy_from_slice(&scratch[..n]);
330 t1[..n].copy_from_slice(&merged_copy);
331 }
332 {
334 let (l10, z1_full) = tmp.split_at_mut(n);
335 let _ = l10;
336 fft::poly_sub(&mut z1_full[..n], &t1[..n], logn);
337 }
338
339 {
342 let (l10, diff) = tmp.split_at_mut(n);
343 fft::poly_mul_fft(l10, &diff[..n], logn);
344 }
345 fft::poly_add(t0, &tmp[..n], logn);
346
347 {
349 let (z0, rest_tmp) = tmp.split_at_mut(n);
350 let (z0_lo, z0_hi) = z0.split_at_mut(hn);
351 fft::poly_split_fft(z0_lo, z0_hi, t0, logn);
352 let (g00_lo, g00_hi) = g00.split_at_mut(hn);
353 ff_sampling_fft_dyntree(
354 samp,
355 samp_ctx,
356 z0_lo,
357 z0_hi,
358 g00_lo,
359 g00_hi,
360 g01,
361 orig_logn,
362 logn - 1,
363 rest_tmp,
364 );
365 let mut z0_lo_copy = Vec::<Fpr>::with_capacity(hn);
366 let mut z0_hi_copy = Vec::<Fpr>::with_capacity(hn);
367 unsafe {
368 z0_lo_copy.set_len(hn);
369 z0_hi_copy.set_len(hn);
370 }
371 z0_lo_copy.copy_from_slice(z0_lo);
372 z0_hi_copy.copy_from_slice(z0_hi);
373 fft::poly_merge_fft(t0, &z0_lo_copy, &z0_hi_copy, logn);
374 }
375}
376
377fn ff_sampling_fft(
384 samp: SamplerZ,
385 samp_ctx: &mut SamplerContext,
386 z0: &mut [Fpr],
387 z1: &mut [Fpr],
388 tree: &[Fpr],
389 t0: &[Fpr],
390 t1: &[Fpr],
391 logn: u32,
392 tmp: &mut [Fpr],
393) {
394 if logn == 2 {
396 let tree0 = &tree[4..];
397 let tree1 = &tree[8..];
398
399 let a_re = t1[0];
401 let a_im = t1[2];
402 let b_re = t1[1];
403 let b_im = t1[3];
404 let c_re = fpr_add(a_re, b_re);
405 let c_im = fpr_add(a_im, b_im);
406 let mut w0 = fpr_half(c_re);
407 let mut w1 = fpr_half(c_im);
408 let c_re = fpr_sub(a_re, b_re);
409 let c_im = fpr_sub(a_im, b_im);
410 let mut w2 = fpr_mul(fpr_add(c_re, c_im), FPR_INVSQRT8);
411 let mut w3 = fpr_mul(fpr_sub(c_im, c_re), FPR_INVSQRT8);
412
413 let x0 = w2;
414 let x1 = w3;
415 let sigma = tree1[3];
416 w2 = fpr_of(samp(samp_ctx, x0, sigma) as i64);
417 w3 = fpr_of(samp(samp_ctx, x1, sigma) as i64);
418 let a_re = fpr_sub(x0, w2);
419 let a_im = fpr_sub(x1, w3);
420 let b_re = tree1[0];
421 let b_im = tree1[1];
422 let c_re = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
423 let c_im = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
424 let x0 = fpr_add(c_re, w0);
425 let x1 = fpr_add(c_im, w1);
426 let sigma = tree1[2];
427 w0 = fpr_of(samp(samp_ctx, x0, sigma) as i64);
428 w1 = fpr_of(samp(samp_ctx, x1, sigma) as i64);
429
430 let a_re = w0;
431 let a_im = w1;
432 let c_re = fpr_mul(fpr_sub(w2, w3), FPR_INVSQRT2);
433 let c_im = fpr_mul(fpr_add(w2, w3), FPR_INVSQRT2);
434 z1[0] = fpr_add(a_re, c_re);
435 z1[2] = fpr_add(a_im, c_im);
436 z1[1] = fpr_sub(a_re, c_re);
437 z1[3] = fpr_sub(a_im, c_im);
438
439 w0 = fpr_sub(t1[0], z1[0]);
441 w1 = fpr_sub(t1[1], z1[1]);
442 w2 = fpr_sub(t1[2], z1[2]);
443 w3 = fpr_sub(t1[3], z1[3]);
444
445 {
446 let (a_re, a_im) = (w0, w2);
447 let (b_re, b_im) = (tree[0], tree[2]);
448 w0 = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
449 w2 = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
450 }
451 {
452 let (a_re, a_im) = (w1, w3);
453 let (b_re, b_im) = (tree[1], tree[3]);
454 w1 = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
455 w3 = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
456 }
457
458 w0 = fpr_add(w0, t0[0]);
459 w1 = fpr_add(w1, t0[1]);
460 w2 = fpr_add(w2, t0[2]);
461 w3 = fpr_add(w3, t0[3]);
462
463 let a_re = w0;
465 let a_im = w2;
466 let b_re = w1;
467 let b_im = w3;
468 let c_re = fpr_add(a_re, b_re);
469 let c_im = fpr_add(a_im, b_im);
470 w0 = fpr_half(c_re);
471 w1 = fpr_half(c_im);
472 let c_re = fpr_sub(a_re, b_re);
473 let c_im = fpr_sub(a_im, b_im);
474 w2 = fpr_mul(fpr_add(c_re, c_im), FPR_INVSQRT8);
475 w3 = fpr_mul(fpr_sub(c_im, c_re), FPR_INVSQRT8);
476
477 let x0 = w2;
478 let x1 = w3;
479 let sigma = tree0[3];
480 let y0 = fpr_of(samp(samp_ctx, x0, sigma) as i64);
481 let y1 = fpr_of(samp(samp_ctx, x1, sigma) as i64);
482 w2 = y0;
483 w3 = y1;
484 let a_re = fpr_sub(x0, y0);
485 let a_im = fpr_sub(x1, y1);
486 let b_re = tree0[0];
487 let b_im = tree0[1];
488 let c_re = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
489 let c_im = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
490 let x0 = fpr_add(c_re, w0);
491 let x1 = fpr_add(c_im, w1);
492 let sigma = tree0[2];
493 w0 = fpr_of(samp(samp_ctx, x0, sigma) as i64);
494 w1 = fpr_of(samp(samp_ctx, x1, sigma) as i64);
495
496 let a_re = w0;
497 let a_im = w1;
498 let c_re = fpr_mul(fpr_sub(w2, w3), FPR_INVSQRT2);
499 let c_im = fpr_mul(fpr_add(w2, w3), FPR_INVSQRT2);
500 z0[0] = fpr_add(a_re, c_re);
501 z0[2] = fpr_add(a_im, c_im);
502 z0[1] = fpr_sub(a_re, c_re);
503 z0[3] = fpr_sub(a_im, c_im);
504
505 return;
506 }
507
508 if logn == 1 {
510 let x0 = t1[0];
511 let x1 = t1[1];
512 let sigma = tree[3];
513 let y0 = fpr_of(samp(samp_ctx, x0, sigma) as i64);
514 let y1 = fpr_of(samp(samp_ctx, x1, sigma) as i64);
515 z1[0] = y0;
516 z1[1] = y1;
517 let a_re = fpr_sub(x0, y0);
518 let a_im = fpr_sub(x1, y1);
519 let b_re = tree[0];
520 let b_im = tree[1];
521 let c_re = fpr_sub(fpr_mul(a_re, b_re), fpr_mul(a_im, b_im));
522 let c_im = fpr_add(fpr_mul(a_re, b_im), fpr_mul(a_im, b_re));
523 let x0 = fpr_add(c_re, t0[0]);
524 let x1 = fpr_add(c_im, t0[1]);
525 let sigma = tree[2];
526 z0[0] = fpr_of(samp(samp_ctx, x0, sigma) as i64);
527 z0[1] = fpr_of(samp(samp_ctx, x1, sigma) as i64);
528 return;
529 }
530
531 let n: usize = 1 << logn;
533 let hn = n >> 1;
534 let tree0 = &tree[n..];
535 let tree1 = &tree[n + ffldl_treesize(logn - 1)..];
536
537 {
539 let (z1_lo, z1_hi) = z1.split_at_mut(hn);
540 fft::poly_split_fft(z1_lo, z1_hi, t1, logn);
541 }
542
543 {
545 let (tmp_lo, tmp_rest) = tmp.split_at_mut(hn);
546 let (tmp_hi, scratch) = tmp_rest.split_at_mut(hn);
547 let (z1_lo, z1_hi) = z1.split_at_mut(hn);
548 ff_sampling_fft(
549 samp,
550 samp_ctx,
551 tmp_lo,
552 tmp_hi,
553 tree1,
554 z1_lo,
555 z1_hi,
556 logn - 1,
557 scratch,
558 );
559 }
560 {
561 let mut tmp_lo_copy = Vec::<Fpr>::with_capacity(hn);
562 let mut tmp_hi_copy = Vec::<Fpr>::with_capacity(hn);
563 unsafe {
564 tmp_lo_copy.set_len(hn);
565 tmp_hi_copy.set_len(hn);
566 }
567 tmp_lo_copy.copy_from_slice(&tmp[..hn]);
568 tmp_hi_copy.copy_from_slice(&tmp[hn..n]);
569 fft::poly_merge_fft(z1, &tmp_lo_copy, &tmp_hi_copy, logn);
570 }
571
572 tmp[..n].copy_from_slice(&t1[..n]);
574 fft::poly_sub(&mut tmp[..n], &z1[..n], logn);
575 fft::poly_mul_fft(&mut tmp[..n], tree, logn);
576 fft::poly_add(&mut tmp[..n], t0, logn);
577
578 {
580 let (z0_lo, z0_hi) = z0.split_at_mut(hn);
581 fft::poly_split_fft(z0_lo, z0_hi, &tmp[..n], logn);
582 }
583 {
584 let (tmp_lo, tmp_rest) = tmp.split_at_mut(hn);
585 let (tmp_hi, scratch) = tmp_rest.split_at_mut(hn);
586 let (z0_lo, z0_hi) = z0.split_at_mut(hn);
587 ff_sampling_fft(
588 samp,
589 samp_ctx,
590 tmp_lo,
591 tmp_hi,
592 tree0,
593 z0_lo,
594 z0_hi,
595 logn - 1,
596 scratch,
597 );
598 }
599 {
600 let mut tmp_lo_copy = Vec::<Fpr>::with_capacity(hn);
601 let mut tmp_hi_copy = Vec::<Fpr>::with_capacity(hn);
602 unsafe {
603 tmp_lo_copy.set_len(hn);
604 tmp_hi_copy.set_len(hn);
605 }
606 tmp_lo_copy.copy_from_slice(&tmp[..hn]);
607 tmp_hi_copy.copy_from_slice(&tmp[hn..n]);
608 fft::poly_merge_fft(z0, &tmp_lo_copy, &tmp_hi_copy, logn);
609 }
610}
611
612fn do_sign_tree(
619 samp: SamplerZ,
620 samp_ctx: &mut SamplerContext,
621 s2: &mut [i16],
622 expanded_key: &[Fpr],
623 hm: &[u16],
624 logn: u32,
625 tmp: &mut [Fpr],
626) -> bool {
627 let n: usize = 1 << logn;
628
629 let b00 = &expanded_key[skoff_b00(logn)..];
630 let b01 = &expanded_key[skoff_b01(logn)..];
631 let b10 = &expanded_key[skoff_b10(logn)..];
632 let b11 = &expanded_key[skoff_b11(logn)..];
633 let tree = &expanded_key[skoff_tree(logn)..];
634
635 for u in 0..n {
637 tmp[u] = fpr_of(hm[u] as i64);
638 }
639
640 fft::fft(&mut tmp[0..n], logn);
642 let ni = FPR_INVERSE_OF_Q;
643 unsafe {
644 let p = tmp.as_mut_ptr();
645 core::ptr::copy_nonoverlapping(p, p.add(n), n);
646 }
647 fft::poly_mul_fft(&mut tmp[n..2 * n], &b01[..n], logn);
648 fft::poly_mulconst(&mut tmp[n..2 * n], fpr_neg(ni), logn);
649 fft::poly_mul_fft(&mut tmp[0..n], &b11[..n], logn);
650 fft::poly_mulconst(&mut tmp[0..n], ni, logn);
651
652 {
655 let ptr = tmp.as_mut_ptr();
658 let t0 = unsafe { core::slice::from_raw_parts(ptr, n) };
659 let t1 = unsafe { core::slice::from_raw_parts(ptr.add(n), n) };
660 let tx = unsafe { core::slice::from_raw_parts_mut(ptr.add(2 * n), n) };
661 let ty = unsafe { core::slice::from_raw_parts_mut(ptr.add(3 * n), n) };
662 let scratch = unsafe { core::slice::from_raw_parts_mut(ptr.add(4 * n), tmp.len() - 4 * n) };
663 ff_sampling_fft(samp, samp_ctx, tx, ty, tree, t0, t1, logn, scratch);
664 }
665
666 {
669 let ptr = tmp.as_mut_ptr();
670 unsafe {
671 core::ptr::copy_nonoverlapping(ptr.add(2 * n), ptr, n);
673 core::ptr::copy_nonoverlapping(ptr.add(3 * n), ptr.add(n), n);
675 }
676 }
677 fft::poly_mul_fft(&mut tmp[2 * n..3 * n], &b00[..n], logn);
679 fft::poly_mul_fft(&mut tmp[3 * n..4 * n], &b10[..n], logn);
681 {
683 let (front, back) = tmp.split_at_mut(3 * n);
684 fft::poly_add(&mut front[2 * n..], &back[..n], logn);
685 }
686 {
688 let ptr = tmp.as_mut_ptr();
689 unsafe {
690 core::ptr::copy_nonoverlapping(ptr, ptr.add(3 * n), n);
691 }
692 }
693 fft::poly_mul_fft(&mut tmp[3 * n..4 * n], &b01[..n], logn);
694
695 {
697 let ptr = tmp.as_mut_ptr();
698 unsafe {
699 core::ptr::copy_nonoverlapping(ptr.add(2 * n), ptr, n);
700 }
701 }
702 fft::poly_mul_fft(&mut tmp[n..2 * n], &b11[..n], logn);
704 {
706 let (front, back) = tmp.split_at_mut(3 * n);
707 fft::poly_add(&mut front[n..2 * n], &back[..n], logn);
708 }
709
710 fft::ifft(&mut tmp[0..n], logn);
711 fft::ifft(&mut tmp[n..2 * n], logn);
712
713 let s1tmp: &mut [i16] =
715 unsafe { core::slice::from_raw_parts_mut(tmp[2 * n..].as_mut_ptr() as *mut i16, n) };
716 let mut sqn: u32 = 0;
717 let mut ng: u32 = 0;
718 for u in 0..n {
719 let z = (hm[u] as i32) - (fpr_rint(tmp[u]) as i32);
720 sqn = sqn.wrapping_add((z * z) as u32);
721 ng |= sqn;
722 s1tmp[u] = z as i16;
723 }
724 sqn |= (ng >> 31).wrapping_neg();
725
726 let mut s2_vals: Vec<i16> = Vec::with_capacity(n);
728 for u in 0..n {
729 s2_vals.push(-(fpr_rint(tmp[n + u]) as i16));
730 }
731
732 if is_short_half(sqn, &s2_vals, logn) {
733 s2[..n].copy_from_slice(&s2_vals);
734 let s1_out: &mut [i16] =
736 unsafe { core::slice::from_raw_parts_mut(tmp.as_mut_ptr() as *mut i16, n) };
737 s1_out[..n].copy_from_slice(&s1tmp[..n]);
738 return true;
739 }
740 false
741}
742
743fn do_sign_dyn(
753 samp: SamplerZ,
754 samp_ctx: &mut SamplerContext,
755 s2: &mut [i16],
756 f: &[i8],
757 g: &[i8],
758 big_f: &[i8],
759 big_g: &[i8],
760 hm: &[u16],
761 logn: u32,
762 tmp: &mut [Fpr],
763) -> bool {
764 let n: usize = 1 << logn;
765 let ptr = tmp.as_mut_ptr();
766
767 {
770 let b00 = unsafe { core::slice::from_raw_parts_mut(ptr, n) };
771 let b01 = unsafe { core::slice::from_raw_parts_mut(ptr.add(n), n) };
772 let b10 = unsafe { core::slice::from_raw_parts_mut(ptr.add(2 * n), n) };
773 let b11 = unsafe { core::slice::from_raw_parts_mut(ptr.add(3 * n), n) };
774
775 smallints_to_fpr(b01, f, logn);
776 smallints_to_fpr(b00, g, logn);
777 smallints_to_fpr(b11, big_f, logn);
778 smallints_to_fpr(b10, big_g, logn);
779 fft::fft(b01, logn);
780 fft::fft(b00, logn);
781 fft::fft(b11, logn);
782 fft::fft(b10, logn);
783 fft::poly_neg(b01, logn);
784 fft::poly_neg(b11, logn);
785 }
786
787 {
790 let b00 = unsafe { core::slice::from_raw_parts_mut(ptr, n) };
791 let b01 = unsafe { core::slice::from_raw_parts_mut(ptr.add(n), n) };
792 let b10 = unsafe { core::slice::from_raw_parts_mut(ptr.add(2 * n), n) };
793 let b11 = unsafe { core::slice::from_raw_parts_mut(ptr.add(3 * n), n) };
794 let t0 = unsafe { core::slice::from_raw_parts_mut(ptr.add(4 * n), n) };
795 let t1 = unsafe { core::slice::from_raw_parts_mut(ptr.add(5 * n), n) };
796
797 t0.copy_from_slice(b01);
799 fft::poly_mulselfadj_fft(t0, logn);
800
801 t1.copy_from_slice(b00);
803 fft::poly_muladj_fft(t1, b10, logn);
804
805 fft::poly_mulselfadj_fft(b00, logn);
807 fft::poly_add(b00, t0, logn);
808
809 t0.copy_from_slice(b01);
811 fft::poly_muladj_fft(b01, b11, logn);
812 fft::poly_add(b01, t1, logn);
813
814 fft::poly_mulselfadj_fft(b10, logn);
816 t1.copy_from_slice(b11);
817 fft::poly_mulselfadj_fft(t1, logn);
818 fft::poly_add(b10, t1, logn);
819 }
820
821 {
827 let t0 = unsafe { core::slice::from_raw_parts_mut(ptr.add(5 * n), n) };
828 for u in 0..n {
829 t0[u] = fpr_of(hm[u] as i64);
830 }
831 }
832
833 {
835 let t0 = unsafe { core::slice::from_raw_parts_mut(ptr.add(5 * n), n) };
836 let t1 = unsafe { core::slice::from_raw_parts_mut(ptr.add(6 * n), n) };
837 let b01_saved = unsafe { core::slice::from_raw_parts(ptr.add(4 * n), n) };
838 let b11 = unsafe { core::slice::from_raw_parts(ptr.add(3 * n), n) };
839
840 fft::fft(t0, logn);
841 let ni = FPR_INVERSE_OF_Q;
842 t1.copy_from_slice(t0);
843 fft::poly_mul_fft(t1, b01_saved, logn);
844 fft::poly_mulconst(t1, fpr_neg(ni), logn);
845 fft::poly_mul_fft(t0, b11, logn);
846 fft::poly_mulconst(t0, ni, logn);
847 }
848
849 unsafe {
851 core::ptr::copy(ptr.add(5 * n), ptr.add(3 * n), 2 * n);
852 }
853
854 {
859 let g00 = unsafe { core::slice::from_raw_parts_mut(ptr, n) };
860 let g01 = unsafe { core::slice::from_raw_parts_mut(ptr.add(n), n) };
861 let g11 = unsafe { core::slice::from_raw_parts_mut(ptr.add(2 * n), n) };
862 let t0 = unsafe { core::slice::from_raw_parts_mut(ptr.add(3 * n), n) };
863 let t1 = unsafe { core::slice::from_raw_parts_mut(ptr.add(4 * n), n) };
864 let scratch = unsafe { core::slice::from_raw_parts_mut(ptr.add(5 * n), tmp.len() - 5 * n) };
865 ff_sampling_fft_dyntree(samp, samp_ctx, t0, t1, g00, g01, g11, logn, logn, scratch);
866 }
867
868 unsafe {
871 core::ptr::copy(ptr.add(3 * n), ptr.add(5 * n), 2 * n);
872 }
873
874 {
876 let b00 = unsafe { core::slice::from_raw_parts_mut(ptr, n) };
877 let b01 = unsafe { core::slice::from_raw_parts_mut(ptr.add(n), n) };
878 let b10 = unsafe { core::slice::from_raw_parts_mut(ptr.add(2 * n), n) };
879 let b11 = unsafe { core::slice::from_raw_parts_mut(ptr.add(3 * n), n) };
880
881 smallints_to_fpr(b01, f, logn);
882 smallints_to_fpr(b00, g, logn);
883 smallints_to_fpr(b11, big_f, logn);
884 smallints_to_fpr(b10, big_g, logn);
885 fft::fft(b01, logn);
886 fft::fft(b00, logn);
887 fft::fft(b11, logn);
888 fft::fft(b10, logn);
889 fft::poly_neg(b01, logn);
890 fft::poly_neg(b11, logn);
891 }
892
893 unsafe {
896 core::ptr::copy_nonoverlapping(ptr.add(5 * n), ptr.add(7 * n), n);
897 core::ptr::copy_nonoverlapping(ptr.add(6 * n), ptr.add(8 * n), n);
898 }
899
900 {
902 let b00 = unsafe { core::slice::from_raw_parts(ptr, n) };
903 let b01 = unsafe { core::slice::from_raw_parts(ptr.add(n), n) };
904 let b10 = unsafe { core::slice::from_raw_parts(ptr.add(2 * n), n) };
905 let _b11 = unsafe { core::slice::from_raw_parts(ptr.add(3 * n), n) };
906 let tx = unsafe { core::slice::from_raw_parts_mut(ptr.add(7 * n), n) };
907 let ty = unsafe { core::slice::from_raw_parts_mut(ptr.add(8 * n), n) };
908
909 fft::poly_mul_fft(tx, b00, logn);
910 fft::poly_mul_fft(ty, b10, logn);
911 fft::poly_add(tx, ty, logn);
912
913 let t0_slice = unsafe { core::slice::from_raw_parts(ptr.add(5 * n), n) };
915 ty.copy_from_slice(t0_slice);
916 fft::poly_mul_fft(ty, b01, logn);
917 }
918
919 unsafe {
921 core::ptr::copy_nonoverlapping(ptr.add(7 * n), ptr.add(5 * n), n);
922 }
923
924 {
926 let t1 = unsafe { core::slice::from_raw_parts_mut(ptr.add(6 * n), n) };
927 let b11 = unsafe { core::slice::from_raw_parts(ptr.add(3 * n), n) };
928 fft::poly_mul_fft(t1, b11, logn);
929 }
930
931 {
933 let t1 = unsafe { core::slice::from_raw_parts_mut(ptr.add(6 * n), n) };
934 let ty = unsafe { core::slice::from_raw_parts(ptr.add(8 * n), n) };
935 fft::poly_add(t1, ty, logn);
936 }
937
938 {
940 let t0 = unsafe { core::slice::from_raw_parts_mut(ptr.add(5 * n), n) };
941 fft::ifft(t0, logn);
942 }
943 {
944 let t1 = unsafe { core::slice::from_raw_parts_mut(ptr.add(6 * n), n) };
945 fft::ifft(t1, logn);
946 }
947
948 let s1tmp: &mut [i16] =
950 unsafe { core::slice::from_raw_parts_mut(ptr.add(7 * n) as *mut i16, n) };
951 let mut sqn: u32 = 0;
952 let mut ng: u32 = 0;
953 for u in 0..n {
954 let t0_u = unsafe { *ptr.add(5 * n + u) };
955 let z = (hm[u] as i32) - (fpr_rint(t0_u) as i32);
956 sqn = sqn.wrapping_add((z * z) as u32);
957 ng |= sqn;
958 s1tmp[u] = z as i16;
959 }
960 sqn |= (ng >> 31).wrapping_neg();
961
962 let mut s2_vals: Vec<i16> = Vec::with_capacity(n);
963 for u in 0..n {
964 let t1_u = unsafe { *ptr.add(6 * n + u) };
965 s2_vals.push(-(fpr_rint(t1_u) as i16));
966 }
967
968 if is_short_half(sqn, &s2_vals, logn) {
969 s2[..n].copy_from_slice(&s2_vals);
970 let s1_out: &mut [i16] = unsafe { core::slice::from_raw_parts_mut(ptr as *mut i16, n) };
971 s1_out[..n].copy_from_slice(&s1tmp[..n]);
972 return true;
973 }
974 false
975}
976
/// Threshold table used by `gaussian0_sampler`.
///
/// It holds 18 rows of three limbs each. `gaussian0_sampler` reads a row as
/// (`w2`, `w1`, `w0`) = (most significant, middle, least significant) limb and
/// compares it against a random value cut into 24-bit limbs; the sampled
/// value is the number of rows that are greater than the random value.
///
/// NOTE(review): presumably the reverse cumulative distribution of the base
/// half-Gaussian from the Falcon specification - confirm against the
/// reference before editing any entry.
static GAUSS0_DIST: [u32; 54] = [
    10745844, 3068844, 3741698, 5559083, 1580863, 8248194, 2260429, 13669192, 2736639, 708981,
    4421575, 10046180, 169348, 7122675, 4136815, 30538, 13063405, 7650655, 4132, 14505003, 7826148,
    417, 16768101, 11363290, 31, 8444042, 8086568, 1, 12844466, 265321, 0, 1232676, 13644283, 0,
    38047, 9111839, 0, 870, 6138264, 0, 14, 12545723, 0, 0, 3104126, 0, 0, 28824, 0, 0, 198, 0, 0,
    1,
];
989
990pub fn gaussian0_sampler(p: &mut Prng) -> i32 {
993 let lo = prng_get_u64(p);
994 let hi = prng_get_u8(p);
995 let v0 = (lo as u32) & 0xFFFFFF;
996 let v1 = ((lo >> 24) as u32) & 0xFFFFFF;
997 let v2 = ((lo >> 48) as u32) | (hi << 16);
998
999 let mut z: i32 = 0;
1000 let mut u = 0;
1001 while u < GAUSS0_DIST.len() {
1002 unsafe {
1003 let w0 = *GAUSS0_DIST.get_unchecked(u + 2);
1004 let w1 = *GAUSS0_DIST.get_unchecked(u + 1);
1005 let w2 = *GAUSS0_DIST.get_unchecked(u);
1006 let cc = v0.wrapping_sub(w0) >> 31;
1007 let cc = v1.wrapping_sub(w1).wrapping_sub(cc) >> 31;
1008 let cc = v2.wrapping_sub(w2).wrapping_sub(cc) >> 31;
1009 z += cc as i32;
1010 }
1011 u += 3;
1012 }
1013 z
1014}
1015
1016fn ber_exp(p: &mut Prng, x: Fpr, ccs: Fpr) -> bool {
1018 let s = fpr_trunc(fpr_mul(x, FPR_INV_LOG2)) as i32;
1019 let r = fpr_sub(x, fpr_mul(fpr_of(s as i64), FPR_LOG2));
1020
1021 let mut sw = s as u32;
1023 sw ^= (sw ^ 63) & (63u32.wrapping_sub(sw) >> 31).wrapping_neg();
1024 let s = sw as i32;
1025
1026 let z = ((fpr_expm_p63(r, ccs) << 1).wrapping_sub(1)) >> (s as u32);
1027
1028 let mut i: i32 = 64;
1029 let mut w: u32;
1030 loop {
1031 i -= 8;
1032 w = prng_get_u8(p).wrapping_sub(((z >> (i as u32)) & 0xFF) as u32);
1033 if w != 0 || i <= 0 {
1034 break;
1035 }
1036 }
1037 (w >> 31) != 0
1038}
1039
1040pub fn sampler(ctx: &mut SamplerContext, mu: Fpr, isigma: Fpr) -> i32 {
1043 let s = fpr_floor(mu) as i32;
1044 let r = fpr_sub(mu, fpr_of(s as i64));
1045 let dss = fpr_half(fpr_sqr(isigma));
1046 let ccs = fpr_mul(isigma, ctx.sigma_min);
1047
1048 for _ in 0..1000 {
1051 let z0 = gaussian0_sampler(&mut ctx.p);
1052 let b = (prng_get_u8(&mut ctx.p) & 1) as i32;
1053 let z = b + ((b << 1) - 1) * z0;
1054
1055 let x = fpr_mul(fpr_sqr(fpr_sub(fpr_of(z as i64), r)), dss);
1056 let x = fpr_sub(x, fpr_mul(fpr_of((z0 * z0) as i64), FPR_INV_2SQRSIGMA0));
1057 if ber_exp(&mut ctx.p, x, ccs) {
1058 return s + z;
1059 }
1060 }
1061 s
1063}
1064
1065pub fn sign_tree(
1071 sig: &mut [i16],
1072 rng: &mut InnerShake256Context,
1073 expanded_key: &[Fpr],
1074 hm: &[u16],
1075 logn: u32,
1076 tmp: &mut [u8],
1077) {
1078 let ftmp: &mut [Fpr] = unsafe {
1079 core::slice::from_raw_parts_mut(
1080 tmp.as_mut_ptr() as *mut Fpr,
1081 tmp.len() / core::mem::size_of::<Fpr>(),
1082 )
1083 };
1084 loop {
1085 let mut spc = SamplerContext {
1086 p: Prng::new(),
1087 sigma_min: FPR_SIGMA_MIN[logn as usize],
1088 };
1089 prng_init(&mut spc.p, rng);
1090
1091 if do_sign_tree(sampler, &mut spc, sig, expanded_key, hm, logn, ftmp) {
1092 break;
1093 }
1094 }
1095}
1096
1097pub fn sign_dyn(
1099 sig: &mut [i16],
1100 rng: &mut InnerShake256Context,
1101 f: &[i8],
1102 g: &[i8],
1103 big_f: &[i8],
1104 big_g: &[i8],
1105 hm: &[u16],
1106 logn: u32,
1107 tmp: &mut [u8],
1108) {
1109 let ftmp: &mut [Fpr] = unsafe {
1110 core::slice::from_raw_parts_mut(
1111 tmp.as_mut_ptr() as *mut Fpr,
1112 tmp.len() / core::mem::size_of::<Fpr>(),
1113 )
1114 };
1115 loop {
1116 let mut spc = SamplerContext {
1117 p: Prng::new(),
1118 sigma_min: FPR_SIGMA_MIN[logn as usize],
1119 };
1120 prng_init(&mut spc.p, rng);
1121
1122 if do_sign_dyn(sampler, &mut spc, sig, f, g, big_f, big_g, hm, logn, ftmp) {
1123 break;
1124 }
1125 }
1126}