#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
#ifdef WC_MLKEM_NO_ASM
#undef USE_INTEL_SPEEDUP
#undef WOLFSSL_ARMASM
#undef WOLFSSL_RISCV_ASM
#endif
#include <wolfssl/wolfcrypt/wc_mlkem.h>
#include <wolfssl/wolfcrypt/sha3.h>
#include <wolfssl/wolfcrypt/cpuid.h>
#ifdef WOLFSSL_WC_MLKEM
#ifdef NO_INLINE
#include <wolfssl/wolfcrypt/misc.h>
#else
#define WOLFSSL_MISC_INCLUDED
#include <wolfcrypt/src/misc.c>
#endif
/* Prototypes needed only by the small-memory key generation and encapsulation
 * builds. These are declarations only; being static, the definitions are
 * expected further down this translation unit (not visible from here). */
#if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \
defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM)
/* NOTE(review): from the name and parameters this presumably generates one
 * row/column (index i) of the k x k matrix from seed - confirm at definition. */
static int mlkem_gen_matrix_i(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
int i, int transposed);
/* NOTE(review): presumably samples the i-th noise polynomial from seed -
 * confirm at definition. */
static int mlkem_get_noise_i(MLKEM_PRF_T* prf, int k, sword16* vec2,
byte* seed, int i, int make);
/* NOTE(review): presumably the C (non-assembly) eta = 2 noise sampler -
 * confirm at definition. */
static int mlkem_get_noise_eta2_c(MLKEM_PRF_T* prf, sword16* p,
const byte* seed);
#endif
/* Declared here and defined in another translation unit.
 * NOTE(review): the name suggests it returns a value the compiler cannot
 * reason about, to stop unwanted optimizations - confirm at definition. */
extern sword16 wc_mlkem_opt_blocker(void);
/* CPU feature flags are only needed when an assembly implementation that
 * depends on run-time feature detection may be compiled in. */
#if defined(USE_INTEL_SPEEDUP) || (defined(__aarch64__) && \
defined(WOLFSSL_ARMASM))
/* Cached CPU feature flags. Filled in by mlkem_init() and tested through
 * IS_AARCH64_RDM() in the AArch64 code paths of this file. */
static cpuid_flags_t cpuid_flags = WC_CPUID_INITIALIZER;
#endif
/* Half of the modulus, rounded up: (q + 1) / 2. */
#define MLKEM_Q_1_HALF      ((MLKEM_Q + 1) / 2)
/* Half of the modulus, rounded down: q / 2. */
#define MLKEM_Q_HALF        (MLKEM_Q / 2)

/* Multiplier used by MLKEM_MONT_RED.
 * With q = 3329 (the ML-KEM modulus; MLKEM_Q is defined in the header),
 * 3329 * 62209 == 1 modulo 2^16, i.e. this is q^-1 mod 2^16. */
#define MLKEM_QINV          62209
/* Multiplier used by MLKEM_BARRETT_RED: 2^26 / q rounded to nearest. */
#define MLKEM_V             (((1U << 26) + (MLKEM_Q / 2)) / MLKEM_Q)
/* 2^32 modulo q. */
#define MLKEM_F             (((word64)1 << 32) % MLKEM_Q)

/* Byte counts derived from the SHA-3 word counts (8 bytes per 64-bit word). */
#define SHA3_128_BYTES      (WC_SHA3_128_COUNT * 8)
#define SHA3_256_BYTES      (WC_SHA3_256_COUNT * 8)

/* Number of XOF blocks to squeeze when generating one matrix polynomial:
 * 12 bits per candidate for MLKEM_N coefficients, scaled by 2^12 / q, plus
 * one extra block, expressed in whole blocks. */
#define GEN_MATRIX_NBLOCKS \
    ((12 * MLKEM_N / 8 * (1 << 12) / MLKEM_Q + XOF_BLOCK_SIZE) / XOF_BLOCK_SIZE)
/* Number of bytes squeezed for one matrix polynomial.
 * Fully parenthesized so the product stays one unit when the macro is used
 * next to '/', '%' or other operators. */
#define GEN_MATRIX_SIZE     (GEN_MATRIX_NBLOCKS * XOF_BLOCK_SIZE)

/* Random bytes consumed for one polynomial: 3 (eta = 3) or 2 (eta = 2) bits
 * per sample, four samples per coefficient -> (eta * MLKEM_N) / 4 bytes. */
#define ETA3_RAND_SIZE      ((3 * MLKEM_N) / 4)
#define ETA2_RAND_SIZE      ((2 * MLKEM_N) / 4)

/* Montgomery reduction of a 32-bit value.
 *
 * Computes (a - ((sword16)(a * QINV)) * q) >> 16, i.e. a * 2^-16 modulo q,
 * as a signed 16-bit result.
 *
 * The argument is evaluated twice: pass an expression without side effects.
 *
 * @param  [in]  a  32-bit signed value to reduce.
 * @return  Reduced value as sword16.
 */
#define MLKEM_MONT_RED(a)                                                     \
    ((sword16)(((a) - (sword32)(((sword16)((sword16)(a) *                     \
                                           (sword16)MLKEM_QINV)) *            \
                                (sword32)MLKEM_Q)) >> 16))

/* Barrett reduction of a 16-bit value.
 *
 * Subtracts ((MLKEM_V * a) >> 26) * q from a, bringing the value close to
 * the range of the modulus.
 *
 * The argument is evaluated twice: pass an expression without side effects.
 *
 * @param  [in]  a  16-bit signed value to reduce.
 * @return  Reduced value as sword16.
 */
#define MLKEM_BARRETT_RED(a)                                                  \
    ((sword16)((sword16)(a) - (sword16)((sword16)(                            \
        ((sword32)((sword32)MLKEM_V * (sword16)(a))) >> 26) *                 \
        (word16)MLKEM_Q)))
/* Twiddle factors for the forward NTT and the base multiplication.
 *
 * mlkem_ntt() consumes entries 1..127 in increasing index order, one per
 * butterfly group; the base multiplication reads entries 64..127
 * (zetas + 64). Entry 0 is not read by the code visible in this part of the
 * file.
 * NOTE(review): values are presumably powers of the root of unity in
 * Montgomery form and bit-reversed order, as every product with them is
 * passed through MLKEM_MONT_RED - confirm against the specification. */
const sword16 zetas[MLKEM_N / 2] = {
2285, 2571, 2970, 1812, 1493, 1422, 287, 202,
3158, 622, 1577, 182, 962, 2127, 1855, 1468,
573, 2004, 264, 383, 2500, 1458, 1727, 3199,
2648, 1017, 732, 608, 1787, 411, 3124, 1758,
1223, 652, 2777, 1015, 2036, 1491, 3047, 1785,
516, 3321, 3009, 2663, 1711, 2167, 126, 1469,
2476, 3239, 3058, 830, 107, 1908, 3082, 2378,
2931, 961, 1821, 2604, 448, 2264, 677, 2054,
2226, 430, 555, 843, 2078, 871, 1550, 105,
422, 587, 177, 3094, 3038, 2869, 1574, 1653,
3083, 778, 1159, 3182, 2552, 1483, 2727, 1119,
1739, 644, 2457, 349, 418, 329, 3173, 3254,
817, 1097, 603, 610, 1322, 2044, 1864, 384,
2114, 3193, 1218, 1994, 2455, 220, 2142, 1670,
2144, 1799, 2051, 794, 1819, 2475, 2459, 478,
3221, 3021, 996, 991, 958, 1869, 1522, 1628
};
#if !defined(WOLFSSL_ARMASM)
/* Forward number-theoretic transform of a polynomial, in place.
 *
 * Runs butterfly layers with distances MLKEM_N / 2 down to 2. Each butterfly
 * computes t = MONT_RED(zeta * r[j + len]) and then stores r[j] - t in
 * r[j + len] and r[j] + t in r[j]. After the last layer every coefficient is
 * Barrett reduced.
 *
 * Four implementations are selected at compile time:
 *   WOLFSSL_MLKEM_SMALL         - one generic loop nest for all layers.
 *   WOLFSSL_MLKEM_NO_LARGE_CODE - first layer peeled, generic loop for rest.
 *   WOLFSSL_MLKEM_NTT_UNROLL    - one explicit loop nest per layer.
 *   otherwise                   - layers merged so that eight coefficients
 *                                 are held in locals across 2 or 3 layers.
 *
 * @param  [in, out]  r  Polynomial of MLKEM_N coefficients to transform.
 */
static void mlkem_ntt(sword16* r)
{
#ifdef WOLFSSL_MLKEM_SMALL
unsigned int len;
unsigned int k;
unsigned int j;
/* zetas[0] is skipped; the first butterfly group uses zetas[1]. */
k = 1;
/* One pass per layer, halving the butterfly distance each time. */
for (len = MLKEM_N / 2; len >= 2; len >>= 1) {
unsigned int start;
/* j ends the inner loop at start + len, so the next group starts at
 * start + 2 * len. */
for (start = 0; start < MLKEM_N; start = j + len) {
/* One zeta per group of len butterflies. */
sword16 zeta = zetas[k++];
for (j = start; j < start + len; ++j) {
sword32 p = (sword32)zeta * r[j + len];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[j];
r[j + len] = (sword16)(rj - t);
r[j] = (sword16)(rj + t);
}
}
}
/* Reduce all coefficients once, after the final layer. */
for (j = 0; j < MLKEM_N; ++j) {
r[j] = MLKEM_BARRETT_RED(r[j]);
}
#elif defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
unsigned int len;
unsigned int k = 1;
unsigned int j;
unsigned int start;
/* First layer (distance MLKEM_N / 2) has a single group: zetas[1]. */
sword16 zeta = zetas[k++];
for (j = 0; j < MLKEM_N / 2; ++j) {
sword32 p = (sword32)zeta * r[j + MLKEM_N / 2];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[j];
r[j + MLKEM_N / 2] = (sword16)(rj - t);
r[j] = (sword16)(rj + t);
}
/* Remaining layers, distances MLKEM_N / 4 down to 2. */
for (len = MLKEM_N / 4; len >= 2; len >>= 1) {
/* Next group starts at start + 2 * len (j == start + len here). */
for (start = 0; start < MLKEM_N; start = j + len) {
zeta = zetas[k++];
for (j = start; j < start + len; ++j) {
sword32 p = (sword32)zeta * r[j + len];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[j];
r[j + len] = (sword16)(rj - t);
r[j] = (sword16)(rj + t);
}
}
}
/* Reduce all coefficients once, after the final layer. */
for (j = 0; j < MLKEM_N; ++j) {
r[j] = MLKEM_BARRETT_RED(r[j]);
}
#elif defined(WOLFSSL_MLKEM_NTT_UNROLL)
unsigned int k = 1;
unsigned int j;
unsigned int start;
/* Layer with distance MLKEM_N / 2: single group using zetas[1]. */
sword16 zeta = zetas[k++];
for (j = 0; j < MLKEM_N / 2; ++j) {
sword32 p = (sword32)zeta * r[j + MLKEM_N / 2];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[j];
r[j + MLKEM_N / 2] = rj - t;
r[j] = rj + t;
}
/* Layer with distance 64. */
for (start = 0; start < MLKEM_N; start += 2 * 64) {
zeta = zetas[k++];
for (j = 0; j < 64; ++j) {
sword32 p = (sword32)zeta * r[start + j + 64];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 64] = rj - t;
r[start + j] = rj + t;
}
}
/* Layer with distance 32. */
for (start = 0; start < MLKEM_N; start += 2 * 32) {
zeta = zetas[k++];
for (j = 0; j < 32; ++j) {
sword32 p = (sword32)zeta * r[start + j + 32];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 32] = rj - t;
r[start + j] = rj + t;
}
}
/* Layer with distance 16. */
for (start = 0; start < MLKEM_N; start += 2 * 16) {
zeta = zetas[k++];
for (j = 0; j < 16; ++j) {
sword32 p = (sword32)zeta * r[start + j + 16];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 16] = rj - t;
r[start + j] = rj + t;
}
}
/* Layer with distance 8. */
for (start = 0; start < MLKEM_N; start += 2 * 8) {
zeta = zetas[k++];
for (j = 0; j < 8; ++j) {
sword32 p = (sword32)zeta * r[start + j + 8];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 8] = rj - t;
r[start + j] = rj + t;
}
}
/* Layer with distance 4. */
for (start = 0; start < MLKEM_N; start += 2 * 4) {
zeta = zetas[k++];
for (j = 0; j < 4; ++j) {
sword32 p = (sword32)zeta * r[start + j + 4];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 4] = rj - t;
r[start + j] = rj + t;
}
}
/* Layer with distance 2. */
for (start = 0; start < MLKEM_N; start += 2 * 2) {
zeta = zetas[k++];
for (j = 0; j < 2; ++j) {
sword32 p = (sword32)zeta * r[start + j + 2];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 2] = rj - t;
r[start + j] = rj + t;
}
}
/* Reduce all coefficients once, after the final layer. */
for (j = 0; j < MLKEM_N; ++j) {
r[j] = MLKEM_BARRETT_RED(r[j]);
}
#else
unsigned int j;
sword16 t0;
sword16 t1;
sword16 t2;
sword16 t3;
/* Zetas for the first two layers: one for distance 128, two for 64. */
sword16 zeta128 = zetas[1];
sword16 zeta64_0 = zetas[2];
sword16 zeta64_1 = zetas[3];
/* Pass 1: layers with distances 128 and 64 on eight coefficients spaced 32
 * apart, kept in locals between the two layers. */
for (j = 0; j < MLKEM_N / 8; j++) {
sword16 r0 = r[j + 0];
sword16 r1 = r[j + 32];
sword16 r2 = r[j + 64];
sword16 r3 = r[j + 96];
sword16 r4 = r[j + 128];
sword16 r5 = r[j + 160];
sword16 r6 = r[j + 192];
sword16 r7 = r[j + 224];
/* Distance 128: pairs (r0,r4) (r1,r5) (r2,r6) (r3,r7). */
t0 = MLKEM_MONT_RED((sword32)zeta128 * r4);
t1 = MLKEM_MONT_RED((sword32)zeta128 * r5);
t2 = MLKEM_MONT_RED((sword32)zeta128 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta128 * r7);
r4 = (sword16)(r0 - t0);
r5 = (sword16)(r1 - t1);
r6 = (sword16)(r2 - t2);
r7 = (sword16)(r3 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r2 = (sword16)(r2 + t2);
r3 = (sword16)(r3 + t3);
/* Distance 64: pairs (r0,r2) (r1,r3) with zeta64_0 and (r4,r6) (r5,r7)
 * with zeta64_1. */
t0 = MLKEM_MONT_RED((sword32)zeta64_0 * r2);
t1 = MLKEM_MONT_RED((sword32)zeta64_0 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta64_1 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta64_1 * r7);
r2 = (sword16)(r0 - t0);
r3 = (sword16)(r1 - t1);
r6 = (sword16)(r4 - t2);
r7 = (sword16)(r5 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r4 = (sword16)(r4 + t2);
r5 = (sword16)(r5 + t3);
r[j + 0] = r0;
r[j + 32] = r1;
r[j + 64] = r2;
r[j + 96] = r3;
r[j + 128] = r4;
r[j + 160] = r5;
r[j + 192] = r6;
r[j + 224] = r7;
}
/* Pass 2: layers with distances 32, 16 and 8 within each block of 64
 * coefficients, on eight coefficients spaced 8 apart. */
for (j = 0; j < MLKEM_N; j += 64) {
unsigned int i;
/* Table ranges used: zetas[4..7], zetas[8..15], zetas[16..31]. */
sword16 zeta32 = zetas[ 4 + j / 64 + 0];
sword16 zeta16_0 = zetas[ 8 + j / 32 + 0];
sword16 zeta16_1 = zetas[ 8 + j / 32 + 1];
sword16 zeta8_0 = zetas[16 + j / 16 + 0];
sword16 zeta8_1 = zetas[16 + j / 16 + 1];
sword16 zeta8_2 = zetas[16 + j / 16 + 2];
sword16 zeta8_3 = zetas[16 + j / 16 + 3];
for (i = 0; i < 8; i++) {
sword16 r0 = r[j + i + 0];
sword16 r1 = r[j + i + 8];
sword16 r2 = r[j + i + 16];
sword16 r3 = r[j + i + 24];
sword16 r4 = r[j + i + 32];
sword16 r5 = r[j + i + 40];
sword16 r6 = r[j + i + 48];
sword16 r7 = r[j + i + 56];
/* Distance 32. */
t0 = MLKEM_MONT_RED((sword32)zeta32 * r4);
t1 = MLKEM_MONT_RED((sword32)zeta32 * r5);
t2 = MLKEM_MONT_RED((sword32)zeta32 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta32 * r7);
r4 = (sword16)(r0 - t0);
r5 = (sword16)(r1 - t1);
r6 = (sword16)(r2 - t2);
r7 = (sword16)(r3 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r2 = (sword16)(r2 + t2);
r3 = (sword16)(r3 + t3);
/* Distance 16. */
t0 = MLKEM_MONT_RED((sword32)zeta16_0 * r2);
t1 = MLKEM_MONT_RED((sword32)zeta16_0 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta16_1 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta16_1 * r7);
r2 = (sword16)(r0 - t0);
r3 = (sword16)(r1 - t1);
r6 = (sword16)(r4 - t2);
r7 = (sword16)(r5 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r4 = (sword16)(r4 + t2);
r5 = (sword16)(r5 + t3);
/* Distance 8: each adjacent pair of locals has its own zeta. */
t0 = MLKEM_MONT_RED((sword32)zeta8_0 * r1);
t1 = MLKEM_MONT_RED((sword32)zeta8_1 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta8_2 * r5);
t3 = MLKEM_MONT_RED((sword32)zeta8_3 * r7);
r1 = (sword16)(r0 - t0);
r3 = (sword16)(r2 - t1);
r5 = (sword16)(r4 - t2);
r7 = (sword16)(r6 - t3);
r0 = (sword16)(r0 + t0);
r2 = (sword16)(r2 + t1);
r4 = (sword16)(r4 + t2);
r6 = (sword16)(r6 + t3);
r[j + i + 0] = r0;
r[j + i + 8] = r1;
r[j + i + 16] = r2;
r[j + i + 24] = r3;
r[j + i + 32] = r4;
r[j + i + 40] = r5;
r[j + i + 48] = r6;
r[j + i + 56] = r7;
}
}
/* Pass 3: layers with distances 4 and 2 on eight consecutive coefficients,
 * followed by the final Barrett reduction on store. */
for (j = 0; j < MLKEM_N; j += 8) {
/* Table ranges used: zetas[32..63] and zetas[64..127]. */
sword16 zeta4 = zetas[32 + j / 8 + 0];
sword16 zeta2_0 = zetas[64 + j / 4 + 0];
sword16 zeta2_1 = zetas[64 + j / 4 + 1];
sword16 r0 = r[j + 0];
sword16 r1 = r[j + 1];
sword16 r2 = r[j + 2];
sword16 r3 = r[j + 3];
sword16 r4 = r[j + 4];
sword16 r5 = r[j + 5];
sword16 r6 = r[j + 6];
sword16 r7 = r[j + 7];
/* Distance 4. */
t0 = MLKEM_MONT_RED((sword32)zeta4 * r4);
t1 = MLKEM_MONT_RED((sword32)zeta4 * r5);
t2 = MLKEM_MONT_RED((sword32)zeta4 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta4 * r7);
r4 = (sword16)(r0 - t0);
r5 = (sword16)(r1 - t1);
r6 = (sword16)(r2 - t2);
r7 = (sword16)(r3 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r2 = (sword16)(r2 + t2);
r3 = (sword16)(r3 + t3);
/* Distance 2. */
t0 = MLKEM_MONT_RED((sword32)zeta2_0 * r2);
t1 = MLKEM_MONT_RED((sword32)zeta2_0 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta2_1 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta2_1 * r7);
r2 = (sword16)(r0 - t0);
r3 = (sword16)(r1 - t1);
r6 = (sword16)(r4 - t2);
r7 = (sword16)(r5 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r4 = (sword16)(r4 + t2);
r5 = (sword16)(r5 + t3);
/* Final reduction is folded into the store. */
r[j + 0] = MLKEM_BARRETT_RED(r0);
r[j + 1] = MLKEM_BARRETT_RED(r1);
r[j + 2] = MLKEM_BARRETT_RED(r2);
r[j + 3] = MLKEM_BARRETT_RED(r3);
r[j + 4] = MLKEM_BARRETT_RED(r4);
r[j + 5] = MLKEM_BARRETT_RED(r5);
r[j + 6] = MLKEM_BARRETT_RED(r6);
r[j + 7] = MLKEM_BARRETT_RED(r7);
}
#endif
}
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
!defined(WOLFSSL_MLKEM_NO_DECAPSULATE)
/* Twiddle factors for the inverse NTT.
 *
 * mlkem_invntt() consumes entries 0..126 in increasing index order, one per
 * butterfly group, starting with the distance 2 layer. The last entry,
 * zetas_inv[127], is not a butterfly factor: every coefficient is multiplied
 * by it (with Montgomery reduction) at the end of the transform.
 * NOTE(review): presumably inverse powers of the root of unity in Montgomery
 * form, with the final entry carrying the 1/128 scaling - confirm against
 * the specification. */
const sword16 zetas_inv[MLKEM_N / 2] = {
1701, 1807, 1460, 2371, 2338, 2333, 308, 108,
2851, 870, 854, 1510, 2535, 1278, 1530, 1185,
1659, 1187, 3109, 874, 1335, 2111, 136, 1215,
2945, 1465, 1285, 2007, 2719, 2726, 2232, 2512,
75, 156, 3000, 2911, 2980, 872, 2685, 1590,
2210, 602, 1846, 777, 147, 2170, 2551, 246,
1676, 1755, 460, 291, 235, 3152, 2742, 2907,
3224, 1779, 2458, 1251, 2486, 2774, 2899, 1103,
1275, 2652, 1065, 2881, 725, 1508, 2368, 398,
951, 247, 1421, 3222, 2499, 271, 90, 853,
1860, 3203, 1162, 1618, 666, 320, 8, 2813,
1544, 282, 1838, 1293, 2314, 552, 2677, 2106,
1571, 205, 2918, 1542, 2721, 2597, 2312, 681,
130, 1602, 1871, 829, 2946, 3065, 1325, 2756,
1861, 1474, 1202, 2367, 3147, 1752, 2707, 171,
3127, 3042, 1907, 1836, 1517, 359, 758, 1441
};
/* Inverse number-theoretic transform of a polynomial, in place.
 *
 * Runs butterfly layers with distances 2 up to MLKEM_N / 2. Each butterfly
 * stores the sum r[j] + r[j + len] in r[j] (Barrett reduced on some layers,
 * depending on the variant) and MONT_RED(zeta * (r[j] - r[j + len])) in
 * r[j + len]. Finally every coefficient is multiplied by zetas_inv[127] and
 * Montgomery reduced.
 *
 * Four implementations are selected at compile time:
 *   WOLFSSL_MLKEM_SMALL         - one generic loop nest for all layers.
 *   WOLFSSL_MLKEM_NO_LARGE_CODE - generic loop, last layer merged with the
 *                                 final scaling.
 *   WOLFSSL_MLKEM_INVNTT_UNROLL - one explicit loop nest per layer.
 *   otherwise                   - layers merged so that eight coefficients
 *                                 are held in locals across 2 or 3 layers.
 *
 * @param  [in, out]  r  Polynomial of MLKEM_N coefficients to transform.
 */
static void mlkem_invntt(sword16* r)
{
#ifdef WOLFSSL_MLKEM_SMALL
unsigned int len;
unsigned int k;
unsigned int j;
sword16 zeta;
k = 0;
/* One pass per layer, doubling the butterfly distance each time. */
for (len = 2; len <= MLKEM_N / 2; len <<= 1) {
unsigned int start;
/* j ends the inner loop at start + len, so the next group starts at
 * start + 2 * len. */
for (start = 0; start < MLKEM_N; start = j + len) {
zeta = zetas_inv[k++];
for (j = start; j < start + len; ++j) {
sword32 p;
sword16 rj = r[j];
sword16 rjl = r[j + len];
sword16 t = (sword16)(rj + rjl);
/* Sum is Barrett reduced on every layer in this variant. */
r[j] = MLKEM_BARRETT_RED(t);
rjl = (sword16)(rj - rjl);
p = (sword32)zeta * rjl;
r[j + len] = MLKEM_MONT_RED(p);
}
}
}
/* Final scaling of every coefficient by the last table entry. */
zeta = zetas_inv[127];
for (j = 0; j < MLKEM_N; ++j) {
sword32 p = (sword32)zeta * r[j];
r[j] = MLKEM_MONT_RED(p);
}
#elif defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
unsigned int len;
unsigned int k;
unsigned int j;
sword16 zeta;
sword16 zeta2;
k = 0;
/* All layers except the last (distances 2 up to MLKEM_N / 4). */
for (len = 2; len <= MLKEM_N / 4; len <<= 1) {
unsigned int start;
/* Next group starts at start + 2 * len (j == start + len here). */
for (start = 0; start < MLKEM_N; start = j + len) {
zeta = zetas_inv[k++];
for (j = start; j < start + len; ++j) {
sword32 p;
sword16 rj = r[j];
sword16 rjl = r[j + len];
sword16 t = (sword16)(rj + rjl);
r[j] = MLKEM_BARRETT_RED(t);
rjl = (sword16)(rj - rjl);
p = (sword32)zeta * rjl;
r[j + len] = MLKEM_MONT_RED(p);
}
}
}
/* Last layer (distance MLKEM_N / 2, zetas_inv[126]) merged with the final
 * scaling by zetas_inv[127]. */
zeta = zetas_inv[126];
zeta2 = zetas_inv[127];
for (j = 0; j < MLKEM_N / 2; ++j) {
sword32 p;
sword16 rj = r[j];
sword16 rjl = r[j + MLKEM_N / 2];
sword16 t = (sword16)(rj + rjl);
rjl = (sword16)(rj - rjl);
p = (sword32)zeta * rjl;
/* Sum is stored unreduced; the scaling below reduces it. */
r[j] = (sword16)t;
r[j + MLKEM_N / 2] = MLKEM_MONT_RED(p);
/* Scale both halves of the butterfly. */
p = (sword32)zeta2 * r[j];
r[j] = MLKEM_MONT_RED(p);
p = (sword32)zeta2 * r[j + MLKEM_N / 2];
r[j + MLKEM_N / 2] = MLKEM_MONT_RED(p);
}
#elif defined(WOLFSSL_MLKEM_INVNTT_UNROLL)
unsigned int k;
unsigned int j;
unsigned int start;
sword16 zeta;
sword16 zeta2;
k = 0;
/* In this variant the sums are Barrett reduced only on the distance 8 and
 * distance 64 layers; the other layers store the sum unreduced. */
/* Layer with distance 2. */
for (start = 0; start < MLKEM_N; start += 2 * 2) {
zeta = zetas_inv[k++];
for (j = 0; j < 2; ++j) {
sword32 p;
sword16 rj = r[start + j];
sword16 rjl = r[start + j + 2];
sword16 t = rj + rjl;
r[start + j] = t;
rjl = rj - rjl;
p = (sword32)zeta * rjl;
r[start + j + 2] = MLKEM_MONT_RED(p);
}
}
/* Layer with distance 4. */
for (start = 0; start < MLKEM_N; start += 2 * 4) {
zeta = zetas_inv[k++];
for (j = 0; j < 4; ++j) {
sword32 p;
sword16 rj = r[start + j];
sword16 rjl = r[start + j + 4];
sword16 t = rj + rjl;
r[start + j] = t;
rjl = rj - rjl;
p = (sword32)zeta * rjl;
r[start + j + 4] = MLKEM_MONT_RED(p);
}
}
/* Layer with distance 8 (sum Barrett reduced). */
for (start = 0; start < MLKEM_N; start += 2 * 8) {
zeta = zetas_inv[k++];
for (j = 0; j < 8; ++j) {
sword32 p;
sword16 rj = r[start + j];
sword16 rjl = r[start + j + 8];
sword16 t = rj + rjl;
r[start + j] = MLKEM_BARRETT_RED(t);
rjl = rj - rjl;
p = (sword32)zeta * rjl;
r[start + j + 8] = MLKEM_MONT_RED(p);
}
}
/* Layer with distance 16. */
for (start = 0; start < MLKEM_N; start += 2 * 16) {
zeta = zetas_inv[k++];
for (j = 0; j < 16; ++j) {
sword32 p;
sword16 rj = r[start + j];
sword16 rjl = r[start + j + 16];
sword16 t = rj + rjl;
r[start + j] = t;
rjl = rj - rjl;
p = (sword32)zeta * rjl;
r[start + j + 16] = MLKEM_MONT_RED(p);
}
}
/* Layer with distance 32. */
for (start = 0; start < MLKEM_N; start += 2 * 32) {
zeta = zetas_inv[k++];
for (j = 0; j < 32; ++j) {
sword32 p;
sword16 rj = r[start + j];
sword16 rjl = r[start + j + 32];
sword16 t = rj + rjl;
r[start + j] = t;
rjl = rj - rjl;
p = (sword32)zeta * rjl;
r[start + j + 32] = MLKEM_MONT_RED(p);
}
}
/* Layer with distance 64 (sum Barrett reduced). */
for (start = 0; start < MLKEM_N; start += 2 * 64) {
zeta = zetas_inv[k++];
for (j = 0; j < 64; ++j) {
sword32 p;
sword16 rj = r[start + j];
sword16 rjl = r[start + j + 64];
sword16 t = rj + rjl;
r[start + j] = MLKEM_BARRETT_RED(t);
rjl = rj - rjl;
p = (sword32)zeta * rjl;
r[start + j + 64] = MLKEM_MONT_RED(p);
}
}
/* Last layer (distance MLKEM_N / 2, zetas_inv[126]) merged with the final
 * scaling by zetas_inv[127]. */
zeta = zetas_inv[126];
zeta2 = zetas_inv[127];
for (j = 0; j < MLKEM_N / 2; ++j) {
sword32 p;
sword16 rj = r[j];
sword16 rjl = r[j + MLKEM_N / 2];
sword16 t = rj + rjl;
rjl = rj - rjl;
p = (sword32)zeta * rjl;
r[j] = t;
r[j + MLKEM_N / 2] = MLKEM_MONT_RED(p);
/* Scale both halves of the butterfly. */
p = (sword32)zeta2 * r[j];
r[j] = MLKEM_MONT_RED(p);
p = (sword32)zeta2 * r[j + MLKEM_N / 2];
r[j + MLKEM_N / 2] = MLKEM_MONT_RED(p);
}
#else
unsigned int j;
sword16 t0;
sword16 t1;
sword16 t2;
sword16 t3;
sword16 zeta64_0;
sword16 zeta64_1;
sword16 zeta128;
sword16 zeta256;
sword32 p;
/* Pass 1: layers with distances 2 and 4 on eight consecutive coefficients.
 * Sums are stored unreduced in this pass. */
for (j = 0; j < MLKEM_N; j += 8) {
/* Table ranges used: zetas_inv[0..63] and zetas_inv[64..95]. */
sword16 zeta2_0 = zetas_inv[ 0 + j / 4 + 0];
sword16 zeta2_1 = zetas_inv[ 0 + j / 4 + 1];
sword16 zeta4 = zetas_inv[64 + j / 8 + 0];
sword16 r0 = r[j + 0];
sword16 r1 = r[j + 1];
sword16 r2 = r[j + 2];
sword16 r3 = r[j + 3];
sword16 r4 = r[j + 4];
sword16 r5 = r[j + 5];
sword16 r6 = r[j + 6];
sword16 r7 = r[j + 7];
/* Distance 2: differences are computed before the sums overwrite r0.. */
p = (sword32)zeta2_0 * (sword16)(r0 - r2);
t0 = MLKEM_MONT_RED(p);
p = (sword32)zeta2_0 * (sword16)(r1 - r3);
t1 = MLKEM_MONT_RED(p);
p = (sword32)zeta2_1 * (sword16)(r4 - r6);
t2 = MLKEM_MONT_RED(p);
p = (sword32)zeta2_1 * (sword16)(r5 - r7);
t3 = MLKEM_MONT_RED(p);
r0 = (sword16)(r0 + r2);
r1 = (sword16)(r1 + r3);
r4 = (sword16)(r4 + r6);
r5 = (sword16)(r5 + r7);
r2 = t0;
r3 = t1;
r6 = t2;
r7 = t3;
/* Distance 4. */
p = (sword32)zeta4 * (sword16)(r0 - r4);
t0 = MLKEM_MONT_RED(p);
p = (sword32)zeta4 * (sword16)(r1 - r5);
t1 = MLKEM_MONT_RED(p);
p = (sword32)zeta4 * (sword16)(r2 - r6);
t2 = MLKEM_MONT_RED(p);
p = (sword32)zeta4 * (sword16)(r3 - r7);
t3 = MLKEM_MONT_RED(p);
r0 = (sword16)(r0 + r4);
r1 = (sword16)(r1 + r5);
r2 = (sword16)(r2 + r6);
r3 = (sword16)(r3 + r7);
r4 = t0;
r5 = t1;
r6 = t2;
r7 = t3;
r[j + 0] = r0;
r[j + 1] = r1;
r[j + 2] = r2;
r[j + 3] = r3;
r[j + 4] = r4;
r[j + 5] = r5;
r[j + 6] = r6;
r[j + 7] = r7;
}
/* Pass 2: layers with distances 8, 16 and 32 within each block of 64
 * coefficients, on eight coefficients spaced 8 apart. */
for (j = 0; j < MLKEM_N; j += 64) {
unsigned int i;
/* Table ranges used: zetas_inv[96..111], [112..119] and [120..123]. */
sword16 zeta8_0 = zetas_inv[ 96 + j / 16 + 0];
sword16 zeta8_1 = zetas_inv[ 96 + j / 16 + 1];
sword16 zeta8_2 = zetas_inv[ 96 + j / 16 + 2];
sword16 zeta8_3 = zetas_inv[ 96 + j / 16 + 3];
sword16 zeta16_0 = zetas_inv[112 + j / 32 + 0];
sword16 zeta16_1 = zetas_inv[112 + j / 32 + 1];
sword16 zeta32 = zetas_inv[120 + j / 64 + 0];
for (i = 0; i < 8; i++) {
sword16 r0 = r[j + i + 0];
sword16 r1 = r[j + i + 8];
sword16 r2 = r[j + i + 16];
sword16 r3 = r[j + i + 24];
sword16 r4 = r[j + i + 32];
sword16 r5 = r[j + i + 40];
sword16 r6 = r[j + i + 48];
sword16 r7 = r[j + i + 56];
/* Distance 8: each adjacent pair of locals has its own zeta. */
p = (sword32)zeta8_0 * (sword16)(r0 - r1);
t0 = MLKEM_MONT_RED(p);
p = (sword32)zeta8_1 * (sword16)(r2 - r3);
t1 = MLKEM_MONT_RED(p);
p = (sword32)zeta8_2 * (sword16)(r4 - r5);
t2 = MLKEM_MONT_RED(p);
p = (sword32)zeta8_3 * (sword16)(r6 - r7);
t3 = MLKEM_MONT_RED(p);
/* Sums are Barrett reduced on this layer. */
r0 = MLKEM_BARRETT_RED(r0 + r1);
r2 = MLKEM_BARRETT_RED(r2 + r3);
r4 = MLKEM_BARRETT_RED(r4 + r5);
r6 = MLKEM_BARRETT_RED(r6 + r7);
r1 = t0;
r3 = t1;
r5 = t2;
r7 = t3;
/* Distance 16. */
p = (sword32)zeta16_0 * (sword16)(r0 - r2);
t0 = MLKEM_MONT_RED(p);
p = (sword32)zeta16_0 * (sword16)(r1 - r3);
t1 = MLKEM_MONT_RED(p);
p = (sword32)zeta16_1 * (sword16)(r4 - r6);
t2 = MLKEM_MONT_RED(p);
p = (sword32)zeta16_1 * (sword16)(r5 - r7);
t3 = MLKEM_MONT_RED(p);
r0 = (sword16)(r0 + r2);
r1 = (sword16)(r1 + r3);
r4 = (sword16)(r4 + r6);
r5 = (sword16)(r5 + r7);
r2 = t0;
r3 = t1;
r6 = t2;
r7 = t3;
/* Distance 32. */
p = (sword32)zeta32 * (sword16)(r0 - r4);
t0 = MLKEM_MONT_RED(p);
p = (sword32)zeta32 * (sword16)(r1 - r5);
t1 = MLKEM_MONT_RED(p);
p = (sword32)zeta32 * (sword16)(r2 - r6);
t2 = MLKEM_MONT_RED(p);
p = (sword32)zeta32 * (sword16)(r3 - r7);
t3 = MLKEM_MONT_RED(p);
r0 = (sword16)(r0 + r4);
r1 = (sword16)(r1 + r5);
r2 = (sword16)(r2 + r6);
r3 = (sword16)(r3 + r7);
r4 = t0;
r5 = t1;
r6 = t2;
r7 = t3;
r[j + i + 0] = r0;
r[j + i + 8] = r1;
r[j + i + 16] = r2;
r[j + i + 24] = r3;
r[j + i + 32] = r4;
r[j + i + 40] = r5;
r[j + i + 48] = r6;
r[j + i + 56] = r7;
}
}
/* Pass 3: layers with distances 64 and 128 on eight coefficients spaced 32
 * apart, followed by the final scaling by zetas_inv[127]. */
zeta64_0 = zetas_inv[124];
zeta64_1 = zetas_inv[125];
zeta128 = zetas_inv[126];
zeta256 = zetas_inv[127];
for (j = 0; j < MLKEM_N / 8; j++) {
sword16 r0 = r[j + 0];
sword16 r1 = r[j + 32];
sword16 r2 = r[j + 64];
sword16 r3 = r[j + 96];
sword16 r4 = r[j + 128];
sword16 r5 = r[j + 160];
sword16 r6 = r[j + 192];
sword16 r7 = r[j + 224];
/* Distance 64. */
p = (sword32)zeta64_0 * (sword16)(r0 - r2);
t0 = MLKEM_MONT_RED(p);
p = (sword32)zeta64_0 * (sword16)(r1 - r3);
t1 = MLKEM_MONT_RED(p);
p = (sword32)zeta64_1 * (sword16)(r4 - r6);
t2 = MLKEM_MONT_RED(p);
p = (sword32)zeta64_1 * (sword16)(r5 - r7);
t3 = MLKEM_MONT_RED(p);
/* Sums are Barrett reduced on this layer. */
r0 = MLKEM_BARRETT_RED(r0 + r2);
r1 = MLKEM_BARRETT_RED(r1 + r3);
r4 = MLKEM_BARRETT_RED(r4 + r6);
r5 = MLKEM_BARRETT_RED(r5 + r7);
r2 = t0;
r3 = t1;
r6 = t2;
r7 = t3;
/* Distance 128. */
p = (sword32)zeta128 * (sword16)(r0 - r4);
t0 = MLKEM_MONT_RED(p);
p = (sword32)zeta128 * (sword16)(r1 - r5);
t1 = MLKEM_MONT_RED(p);
p = (sword32)zeta128 * (sword16)(r2 - r6);
t2 = MLKEM_MONT_RED(p);
p = (sword32)zeta128 * (sword16)(r3 - r7);
t3 = MLKEM_MONT_RED(p);
r0 = (sword16)(r0 + r4);
r1 = (sword16)(r1 + r5);
r2 = (sword16)(r2 + r6);
r3 = (sword16)(r3 + r7);
r4 = t0;
r5 = t1;
r6 = t2;
r7 = t3;
/* Final scaling of all eight coefficients. */
p = (sword32)zeta256 * r0;
r0 = MLKEM_MONT_RED(p);
p = (sword32)zeta256 * r1;
r1 = MLKEM_MONT_RED(p);
p = (sword32)zeta256 * r2;
r2 = MLKEM_MONT_RED(p);
p = (sword32)zeta256 * r3;
r3 = MLKEM_MONT_RED(p);
p = (sword32)zeta256 * r4;
r4 = MLKEM_MONT_RED(p);
p = (sword32)zeta256 * r5;
r5 = MLKEM_MONT_RED(p);
p = (sword32)zeta256 * r6;
r6 = MLKEM_MONT_RED(p);
p = (sword32)zeta256 * r7;
r7 = MLKEM_MONT_RED(p);
r[j + 0] = r0;
r[j + 32] = r1;
r[j + 64] = r2;
r[j + 96] = r3;
r[j + 128] = r4;
r[j + 160] = r5;
r[j + 192] = r6;
r[j + 224] = r7;
}
#endif
}
#endif
/* Multiply two degree-one polynomials modulo (X^2 - zeta).
 *
 *   r[0] = MONT_RED(a[0] * b[0] + zeta * MONT_RED(a[1] * b[1]))
 *   r[1] = MONT_RED(a[0] * b[1] + a[1] * b[0])
 *
 * All four input coefficients are read before either output is written.
 *
 * @param  [out]  r     Two result coefficients.
 * @param  [in]   a     Two coefficients of the first operand.
 * @param  [in]   b     Two coefficients of the second operand.
 * @param  [in]   zeta  Twiddle factor applied to the high-by-high product.
 */
static void mlkem_basemul(sword16* r, const sword16* a, const sword16* b,
    sword16 zeta)
{
    const sword16 a_lo = a[0];
    const sword16 a_hi = a[1];
    const sword16 b_lo = b[0];
    const sword16 b_hi = b[1];
    sword32 hi_hi;
    sword32 acc;
    sword16 hi_hi_red;

    /* Constant term: the high-by-high product is reduced before it is
     * multiplied by zeta so that the accumulation stays in 32 bits. */
    hi_hi = (sword32)a_hi * b_hi;
    hi_hi_red = MLKEM_MONT_RED(hi_hi);
    acc = (sword32)zeta * hi_hi_red;
    acc += (sword32)a_lo * b_lo;
    r[0] = MLKEM_MONT_RED(acc);

    /* Linear term: sum of the two cross products. */
    acc = (sword32)a_lo * b_hi;
    acc += (sword32)a_hi * b_lo;
    r[1] = MLKEM_MONT_RED(acc);
}
/* Multiply two polynomials in the NTT domain: r = a * b.
 *
 * The polynomial is handled as MLKEM_N / 2 coefficient pairs. Consecutive
 * pairs share one table entry: the first is multiplied modulo (X^2 - zeta),
 * the second modulo (X^2 + zeta). Entries come from zetas[64..127].
 *
 * The three build variants differ only in how many pairs are handled per
 * loop iteration (2, 4 or 8).
 *
 * @param  [out]  r  Result polynomial of MLKEM_N coefficients.
 * @param  [in]   a  First operand of MLKEM_N coefficients.
 * @param  [in]   b  Second operand of MLKEM_N coefficients.
 */
static void mlkem_basemul_mont(sword16* r, const sword16* a, const sword16* b)
{
    const sword16* zeta = zetas + 64;
    unsigned int i;

#if defined(WOLFSSL_MLKEM_SMALL)
    /* Four coefficients (one zeta) per iteration. */
    for (i = 0; i < MLKEM_N; i += 4) {
        const sword16 z0 = zeta[0];

        mlkem_basemul(r + i + 0, a + i + 0, b + i + 0, z0);
        mlkem_basemul(r + i + 2, a + i + 2, b + i + 2, (sword16)(-z0));

        zeta += 1;
    }
#elif defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
    /* Eight coefficients (two zetas) per iteration. */
    for (i = 0; i < MLKEM_N; i += 8) {
        const sword16 z0 = zeta[0];
        const sword16 z1 = zeta[1];

        mlkem_basemul(r + i + 0, a + i + 0, b + i + 0, z0);
        mlkem_basemul(r + i + 2, a + i + 2, b + i + 2, (sword16)(-z0));
        mlkem_basemul(r + i + 4, a + i + 4, b + i + 4, z1);
        mlkem_basemul(r + i + 6, a + i + 6, b + i + 6, (sword16)(-z1));

        zeta += 2;
    }
#else
    /* Sixteen coefficients (four zetas) per iteration. */
    for (i = 0; i < MLKEM_N; i += 16) {
        const sword16 z0 = zeta[0];
        const sword16 z1 = zeta[1];
        const sword16 z2 = zeta[2];
        const sword16 z3 = zeta[3];

        mlkem_basemul(r + i +  0, a + i +  0, b + i +  0, z0);
        mlkem_basemul(r + i +  2, a + i +  2, b + i +  2, (sword16)(-z0));
        mlkem_basemul(r + i +  4, a + i +  4, b + i +  4, z1);
        mlkem_basemul(r + i +  6, a + i +  6, b + i +  6, (sword16)(-z1));
        mlkem_basemul(r + i +  8, a + i +  8, b + i +  8, z2);
        mlkem_basemul(r + i + 10, a + i + 10, b + i + 10, (sword16)(-z2));
        mlkem_basemul(r + i + 12, a + i + 12, b + i + 12, z3);
        mlkem_basemul(r + i + 14, a + i + 14, b + i + 14, (sword16)(-z3));

        zeta += 4;
    }
#endif
}
/* Multiply two polynomials in the NTT domain and accumulate: r += a * b.
 *
 * Works like the plain base multiplication, except that each chunk of
 * products is first computed into a local buffer and then added to the
 * matching coefficients of r (truncated to 16 bits, no modular reduction).
 * Entries from zetas[64..127] are used, each for one pair with +zeta and
 * the following pair with -zeta.
 *
 * The three build variants differ only in the chunk size (4, 8 or 16
 * coefficients per loop iteration).
 *
 * @param  [in, out]  r  Accumulator polynomial of MLKEM_N coefficients.
 * @param  [in]       a  First operand of MLKEM_N coefficients.
 * @param  [in]       b  Second operand of MLKEM_N coefficients.
 */
static void mlkem_basemul_mont_add(sword16* r, const sword16* a,
    const sword16* b)
{
    const sword16* zeta = zetas + 64;
    unsigned int i;
    unsigned int n;

#if defined(WOLFSSL_MLKEM_SMALL)
    for (i = 0; i < MLKEM_N; i += 4, zeta++) {
        sword16 prod[4];

        /* All products of the chunk are computed before r is touched. */
        mlkem_basemul(prod + 0, a + i + 0, b + i + 0, zeta[0]);
        mlkem_basemul(prod + 2, a + i + 2, b + i + 2, (sword16)(-zeta[0]));

        for (n = 0; n < 4; ++n) {
            r[i + n] = (sword16)(r[i + n] + prod[n]);
        }
    }
#elif defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
    for (i = 0; i < MLKEM_N; i += 8, zeta += 2) {
        sword16 prod[8];

        /* All products of the chunk are computed before r is touched. */
        mlkem_basemul(prod + 0, a + i + 0, b + i + 0, zeta[0]);
        mlkem_basemul(prod + 2, a + i + 2, b + i + 2, (sword16)(-zeta[0]));
        mlkem_basemul(prod + 4, a + i + 4, b + i + 4, zeta[1]);
        mlkem_basemul(prod + 6, a + i + 6, b + i + 6, (sword16)(-zeta[1]));

        for (n = 0; n < 8; ++n) {
            r[i + n] = (sword16)(r[i + n] + prod[n]);
        }
    }
#else
    for (i = 0; i < MLKEM_N; i += 16, zeta += 4) {
        sword16 prod[16];

        /* All products of the chunk are computed before r is touched. */
        mlkem_basemul(prod +  0, a + i +  0, b + i +  0, zeta[0]);
        mlkem_basemul(prod +  2, a + i +  2, b + i +  2, (sword16)(-zeta[0]));
        mlkem_basemul(prod +  4, a + i +  4, b + i +  4, zeta[1]);
        mlkem_basemul(prod +  6, a + i +  6, b + i +  6, (sword16)(-zeta[1]));
        mlkem_basemul(prod +  8, a + i +  8, b + i +  8, zeta[2]);
        mlkem_basemul(prod + 10, a + i + 10, b + i + 10, (sword16)(-zeta[2]));
        mlkem_basemul(prod + 12, a + i + 12, b + i + 12, zeta[3]);
        mlkem_basemul(prod + 14, a + i + 14, b + i + 14, (sword16)(-zeta[3]));

        for (n = 0; n < 16; ++n) {
            r[i + n] = (sword16)(r[i + n] + prod[n]);
        }
    }
#endif
}
#endif
/* Inner product of two vectors of polynomials in the NTT domain.
 *
 *   r = a[0] * b[0] + a[1] * b[1] + ... + a[k-1] * b[k-1]
 *
 * The first product initializes r; every further product is accumulated.
 *
 * A single loop bound of k is used for all builds. Peeling the last
 * iteration with an index of (k - 1) gives identical results for k >= 2 but
 * adds a[0] * b[0] twice when k is 1 and wraps the unsigned index when k is
 * 0, so it is avoided.
 *
 * @param  [out]  r  Result polynomial of MLKEM_N coefficients.
 * @param  [in]   a  First vector: k polynomials of MLKEM_N coefficients.
 * @param  [in]   b  Second vector: k polynomials of MLKEM_N coefficients.
 * @param  [in]   k  Number of polynomials in each vector. The first product
 *                   is always computed, so both vectors must hold at least
 *                   one polynomial.
 */
static void mlkem_pointwise_acc_mont(sword16* r, const sword16* a,
    const sword16* b, unsigned int k)
{
    unsigned int i;

    /* r = a[0] * b[0] */
    mlkem_basemul_mont(r, a, b);
    /* r += a[i] * b[i] for the remaining polynomials. */
    for (i = 1; i < k; ++i) {
        mlkem_basemul_mont_add(r, a + i * MLKEM_N, b + i * MLKEM_N);
    }
}
/* One-time initialization for the ML-KEM polynomial code.
 *
 * When an Intel or AArch64 assembly implementation may be compiled in, the
 * CPU feature flags are fetched into the file-level cache. Otherwise this
 * function does nothing.
 */
void mlkem_init(void)
{
#if defined(USE_INTEL_SPEEDUP) || \
    (defined(__aarch64__) && defined(WOLFSSL_ARMASM))
    cpuid_get_flags_ex(&cpuid_flags);
#endif
}
#if defined(__aarch64__) && defined(WOLFSSL_ARMASM)
#ifndef WOLFSSL_MLKEM_NO_MAKE_KEY
/* Key generation arithmetic (AArch64 assembly build): t = NTT-domain(a.s + e).
 *
 * The secret vector s is transformed in place. For each row of the matrix,
 * the row is multiplied by s into the matching polynomial of t, converted to
 * Montgomery form, and the transformed error polynomial is added with
 * reduction. The error vector e is transformed in place as a side effect.
 *
 * When the CPU reports the RDM feature (and SQRDMLSH code is compiled in),
 * the *_sqrdmlsh routines are used; otherwise the generic ones.
 *
 * @param  [in, out]  s  Secret vector: k polynomials, transformed in place.
 * @param  [out]      t  Result vector: k polynomials.
 * @param  [in, out]  e  Error vector: k polynomials, transformed in place.
 * @param  [in]       a  Matrix: k rows of k polynomials.
 * @param  [in]       k  Number of polynomials in a vector.
 */
void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k)
{
    int row;

#ifndef WOLFSSL_AARCH64_NO_SQRDMLSH
    if (IS_AARCH64_RDM(cpuid_flags)) {
        /* The whole of s is needed, transformed, by every row below. */
        for (row = 0; row < k; ++row) {
            mlkem_ntt_sqrdmlsh(s + row * MLKEM_N);
        }

        for (row = 0; row < k; ++row) {
            sword16* t_row = t + row * MLKEM_N;
            sword16* e_row = e + row * MLKEM_N;

            mlkem_pointwise_acc_mont(t_row, a + row * k * MLKEM_N, s, k);
            mlkem_to_mont_sqrdmlsh(t_row);
            mlkem_ntt_sqrdmlsh(e_row);
            mlkem_add_reduce(t_row, e_row);
        }
    }
    else
#endif
    {
        /* The whole of s is needed, transformed, by every row below. */
        for (row = 0; row < k; ++row) {
            mlkem_ntt(s + row * MLKEM_N);
        }

        for (row = 0; row < k; ++row) {
            sword16* t_row = t + row * MLKEM_N;
            sword16* e_row = e + row * MLKEM_N;

            mlkem_pointwise_acc_mont(t_row, a + row * k * MLKEM_N, s, k);
            mlkem_to_mont(t_row);
            mlkem_ntt(e_row);
            mlkem_add_reduce(t_row, e_row);
        }
    }
}
#endif
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
!defined(WOLFSSL_MLKEM_NO_DECAPSULATE)
/* Encapsulation arithmetic (AArch64 assembly build).
 *
 *   u = InvNTT(a.NTT(y)) + e1
 *   v = InvNTT(t.NTT(y)) + e2 + m
 *
 * The vector y is transformed in place. Each row of the matrix is multiplied
 * by y into the matching polynomial of u, inverse transformed, and e1 is
 * added with reduction. v is the inner product of t and y, inverse
 * transformed, with e2 and m added with reduction.
 *
 * When the CPU reports the RDM feature (and SQRDMLSH code is compiled in),
 * the *_sqrdmlsh routines are used; otherwise the generic ones.
 *
 * @param  [in]       t   Public vector: k polynomials.
 * @param  [out]      u   Result vector: k polynomials.
 * @param  [out]      v   Result polynomial.
 * @param  [in]       a   Matrix: k rows of k polynomials.
 * @param  [in, out]  y   Random vector: k polynomials, transformed in place.
 * @param  [in]       e1  Error vector: k polynomials.
 * @param  [in]       e2  Error polynomial.
 * @param  [in]       m   Message polynomial.
 * @param  [in]       k   Number of polynomials in a vector.
 */
void mlkem_encapsulate(const sword16* t, sword16* u , sword16* v,
    const sword16* a, sword16* y, const sword16* e1, const sword16* e2,
    const sword16* m, int k)
{
    int row;

#ifndef WOLFSSL_AARCH64_NO_SQRDMLSH
    if (IS_AARCH64_RDM(cpuid_flags)) {
        /* The whole of y is needed, transformed, for both u and v. */
        for (row = 0; row < k; ++row) {
            mlkem_ntt_sqrdmlsh(y + row * MLKEM_N);
        }

        for (row = 0; row < k; ++row) {
            sword16* u_row = u + row * MLKEM_N;

            mlkem_pointwise_acc_mont(u_row, a + row * k * MLKEM_N, y, k);
            mlkem_invntt_sqrdmlsh(u_row);
            mlkem_add_reduce(u_row, e1 + row * MLKEM_N);
        }

        mlkem_pointwise_acc_mont(v, t, y, k);
        mlkem_invntt_sqrdmlsh(v);
    }
    else
#endif
    {
        /* The whole of y is needed, transformed, for both u and v. */
        for (row = 0; row < k; ++row) {
            mlkem_ntt(y + row * MLKEM_N);
        }

        for (row = 0; row < k; ++row) {
            sword16* u_row = u + row * MLKEM_N;

            mlkem_pointwise_acc_mont(u_row, a + row * k * MLKEM_N, y, k);
            mlkem_invntt(u_row);
            mlkem_add_reduce(u_row, e1 + row * MLKEM_N);
        }

        mlkem_pointwise_acc_mont(v, t, y, k);
        mlkem_invntt(v);
    }

    /* Common tail: v += e2 + m, with reduction. */
    mlkem_add3_reduce(v, e2, m);
}
#endif
#ifndef WOLFSSL_MLKEM_NO_DECAPSULATE
/* Decapsulation arithmetic (AArch64 assembly build).
 *
 * The vector u is transformed in place, the inner product of s and u is
 * computed into w and inverse transformed, and finally w is combined with v
 * by mlkem_rsub_reduce().
 * NOTE(review): from its name mlkem_rsub_reduce presumably computes
 * w = v - w with reduction - confirm at its definition.
 *
 * When the CPU reports the RDM feature (and SQRDMLSH code is compiled in),
 * the *_sqrdmlsh routines are used; otherwise the generic ones.
 *
 * @param  [in]       s  Secret vector: k polynomials.
 * @param  [out]      w  Result polynomial.
 * @param  [in, out]  u  Ciphertext vector: k polynomials, transformed in
 *                       place.
 * @param  [in]       v  Ciphertext polynomial.
 * @param  [in]       k  Number of polynomials in a vector.
 */
void mlkem_decapsulate(const sword16* s, sword16* w, sword16* u,
    const sword16* v, int k)
{
    int idx;

#ifndef WOLFSSL_AARCH64_NO_SQRDMLSH
    if (IS_AARCH64_RDM(cpuid_flags)) {
        for (idx = 0; idx < k; ++idx) {
            mlkem_ntt_sqrdmlsh(u + idx * MLKEM_N);
        }

        mlkem_pointwise_acc_mont(w, s, u, k);
        mlkem_invntt_sqrdmlsh(w);
    }
    else
#endif
    {
        for (idx = 0; idx < k; ++idx) {
            mlkem_ntt(u + idx * MLKEM_N);
        }

        mlkem_pointwise_acc_mont(w, s, u, k);
        mlkem_invntt(w);
    }

    /* Common tail: combine w with v. */
    mlkem_rsub_reduce(w, v);
}
#endif
#else
#ifndef WOLFSSL_MLKEM_NO_MAKE_KEY
#if !defined(WOLFSSL_MLKEM_SMALL) && !defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
static void mlkem_ntt_add_to(sword16* r, sword16* a)
{
#if defined(WOLFSSL_MLKEM_NTT_UNROLL)
unsigned int k = 1;
unsigned int j;
unsigned int start;
sword16 zeta = zetas[k++];
for (j = 0; j < MLKEM_N / 2; ++j) {
sword32 p = (sword32)zeta * r[j + MLKEM_N / 2];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[j];
r[j + MLKEM_N / 2] = rj - t;
r[j] = rj + t;
}
for (start = 0; start < MLKEM_N; start += 2 * 64) {
zeta = zetas[k++];
for (j = 0; j < 64; ++j) {
sword32 p = (sword32)zeta * r[start + j + 64];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 64] = rj - t;
r[start + j] = rj + t;
}
}
for (start = 0; start < MLKEM_N; start += 2 * 32) {
zeta = zetas[k++];
for (j = 0; j < 32; ++j) {
sword32 p = (sword32)zeta * r[start + j + 32];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 32] = rj - t;
r[start + j] = rj + t;
}
}
for (start = 0; start < MLKEM_N; start += 2 * 16) {
zeta = zetas[k++];
for (j = 0; j < 16; ++j) {
sword32 p = (sword32)zeta * r[start + j + 16];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 16] = rj - t;
r[start + j] = rj + t;
}
}
for (start = 0; start < MLKEM_N; start += 2 * 8) {
zeta = zetas[k++];
for (j = 0; j < 8; ++j) {
sword32 p = (sword32)zeta * r[start + j + 8];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 8] = rj - t;
r[start + j] = rj + t;
}
}
for (start = 0; start < MLKEM_N; start += 2 * 4) {
zeta = zetas[k++];
for (j = 0; j < 4; ++j) {
sword32 p = (sword32)zeta * r[start + j + 4];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 4] = rj - t;
r[start + j] = rj + t;
}
}
for (start = 0; start < MLKEM_N; start += 2 * 2) {
zeta = zetas[k++];
for (j = 0; j < 2; ++j) {
sword32 p = (sword32)zeta * r[start + j + 2];
sword16 t = MLKEM_MONT_RED(p);
sword16 rj = r[start + j];
r[start + j + 2] = rj - t;
r[start + j] = rj + t;
}
}
for (j = 0; j < MLKEM_N; ++j) {
sword16 t = a[j] + r[j];
a[j] = MLKEM_BARRETT_RED(t);
}
#else
unsigned int j;
sword16 t0;
sword16 t1;
sword16 t2;
sword16 t3;
sword16 zeta128 = zetas[1];
sword16 zeta64_0 = zetas[2];
sword16 zeta64_1 = zetas[3];
for (j = 0; j < MLKEM_N / 8; j++) {
sword16 r0 = r[j + 0];
sword16 r1 = r[j + 32];
sword16 r2 = r[j + 64];
sword16 r3 = r[j + 96];
sword16 r4 = r[j + 128];
sword16 r5 = r[j + 160];
sword16 r6 = r[j + 192];
sword16 r7 = r[j + 224];
t0 = MLKEM_MONT_RED((sword32)zeta128 * r4);
t1 = MLKEM_MONT_RED((sword32)zeta128 * r5);
t2 = MLKEM_MONT_RED((sword32)zeta128 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta128 * r7);
r4 = (sword16)(r0 - t0);
r5 = (sword16)(r1 - t1);
r6 = (sword16)(r2 - t2);
r7 = (sword16)(r3 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r2 = (sword16)(r2 + t2);
r3 = (sword16)(r3 + t3);
t0 = MLKEM_MONT_RED((sword32)zeta64_0 * r2);
t1 = MLKEM_MONT_RED((sword32)zeta64_0 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta64_1 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta64_1 * r7);
r2 = (sword16)(r0 - t0);
r3 = (sword16)(r1 - t1);
r6 = (sword16)(r4 - t2);
r7 = (sword16)(r5 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r4 = (sword16)(r4 + t2);
r5 = (sword16)(r5 + t3);
r[j + 0] = r0;
r[j + 32] = r1;
r[j + 64] = r2;
r[j + 96] = r3;
r[j + 128] = r4;
r[j + 160] = r5;
r[j + 192] = r6;
r[j + 224] = r7;
}
for (j = 0; j < MLKEM_N; j += 64) {
unsigned int i;
sword16 zeta32 = zetas[ 4 + j / 64 + 0];
sword16 zeta16_0 = zetas[ 8 + j / 32 + 0];
sword16 zeta16_1 = zetas[ 8 + j / 32 + 1];
sword16 zeta8_0 = zetas[16 + j / 16 + 0];
sword16 zeta8_1 = zetas[16 + j / 16 + 1];
sword16 zeta8_2 = zetas[16 + j / 16 + 2];
sword16 zeta8_3 = zetas[16 + j / 16 + 3];
for (i = 0; i < 8; i++) {
sword16 r0 = r[j + i + 0];
sword16 r1 = r[j + i + 8];
sword16 r2 = r[j + i + 16];
sword16 r3 = r[j + i + 24];
sword16 r4 = r[j + i + 32];
sword16 r5 = r[j + i + 40];
sword16 r6 = r[j + i + 48];
sword16 r7 = r[j + i + 56];
t0 = MLKEM_MONT_RED((sword32)zeta32 * r4);
t1 = MLKEM_MONT_RED((sword32)zeta32 * r5);
t2 = MLKEM_MONT_RED((sword32)zeta32 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta32 * r7);
r4 = (sword16)(r0 - t0);
r5 = (sword16)(r1 - t1);
r6 = (sword16)(r2 - t2);
r7 = (sword16)(r3 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r2 = (sword16)(r2 + t2);
r3 = (sword16)(r3 + t3);
t0 = MLKEM_MONT_RED((sword32)zeta16_0 * r2);
t1 = MLKEM_MONT_RED((sword32)zeta16_0 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta16_1 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta16_1 * r7);
r2 = (sword16)(r0 - t0);
r3 = (sword16)(r1 - t1);
r6 = (sword16)(r4 - t2);
r7 = (sword16)(r5 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r4 = (sword16)(r4 + t2);
r5 = (sword16)(r5 + t3);
t0 = MLKEM_MONT_RED((sword32)zeta8_0 * r1);
t1 = MLKEM_MONT_RED((sword32)zeta8_1 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta8_2 * r5);
t3 = MLKEM_MONT_RED((sword32)zeta8_3 * r7);
r1 = (sword16)(r0 - t0);
r3 = (sword16)(r2 - t1);
r5 = (sword16)(r4 - t2);
r7 = (sword16)(r6 - t3);
r0 = (sword16)(r0 + t0);
r2 = (sword16)(r2 + t1);
r4 = (sword16)(r4 + t2);
r6 = (sword16)(r6 + t3);
r[j + i + 0] = r0;
r[j + i + 8] = r1;
r[j + i + 16] = r2;
r[j + i + 24] = r3;
r[j + i + 32] = r4;
r[j + i + 40] = r5;
r[j + i + 48] = r6;
r[j + i + 56] = r7;
}
}
for (j = 0; j < MLKEM_N; j += 8) {
sword16 zeta4 = zetas[32 + j / 8 + 0];
sword16 zeta2_0 = zetas[64 + j / 4 + 0];
sword16 zeta2_1 = zetas[64 + j / 4 + 1];
sword16 r0 = r[j + 0];
sword16 r1 = r[j + 1];
sword16 r2 = r[j + 2];
sword16 r3 = r[j + 3];
sword16 r4 = r[j + 4];
sword16 r5 = r[j + 5];
sword16 r6 = r[j + 6];
sword16 r7 = r[j + 7];
t0 = MLKEM_MONT_RED((sword32)zeta4 * r4);
t1 = MLKEM_MONT_RED((sword32)zeta4 * r5);
t2 = MLKEM_MONT_RED((sword32)zeta4 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta4 * r7);
r4 = (sword16)(r0 - t0);
r5 = (sword16)(r1 - t1);
r6 = (sword16)(r2 - t2);
r7 = (sword16)(r3 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r2 = (sword16)(r2 + t2);
r3 = (sword16)(r3 + t3);
t0 = MLKEM_MONT_RED((sword32)zeta2_0 * r2);
t1 = MLKEM_MONT_RED((sword32)zeta2_0 * r3);
t2 = MLKEM_MONT_RED((sword32)zeta2_1 * r6);
t3 = MLKEM_MONT_RED((sword32)zeta2_1 * r7);
r2 = (sword16)(r0 - t0);
r3 = (sword16)(r1 - t1);
r6 = (sword16)(r4 - t2);
r7 = (sword16)(r5 - t3);
r0 = (sword16)(r0 + t0);
r1 = (sword16)(r1 + t1);
r4 = (sword16)(r4 + t2);
r5 = (sword16)(r5 + t3);
r0 = (sword16)(r0 + a[j + 0]);
r1 = (sword16)(r1 + a[j + 1]);
r2 = (sword16)(r2 + a[j + 2]);
r3 = (sword16)(r3 + a[j + 3]);
r4 = (sword16)(r4 + a[j + 4]);
r5 = (sword16)(r5 + a[j + 5]);
r6 = (sword16)(r6 + a[j + 6]);
r7 = (sword16)(r7 + a[j + 7]);
a[j + 0] = MLKEM_BARRETT_RED(r0);
a[j + 1] = MLKEM_BARRETT_RED(r1);
a[j + 2] = MLKEM_BARRETT_RED(r2);
a[j + 3] = MLKEM_BARRETT_RED(r3);
a[j + 4] = MLKEM_BARRETT_RED(r4);
a[j + 5] = MLKEM_BARRETT_RED(r5);
a[j + 6] = MLKEM_BARRETT_RED(r6);
a[j + 7] = MLKEM_BARRETT_RED(r7);
}
#endif
}
#endif
#ifndef WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM
/* Compute the public key vector from the private vector, matrix and errors.
 *
 * Portable C implementation.
 *
 * @param  [in, out]  s  Private key vector of k polynomials. Each polynomial
 *                       is passed through mlkem_ntt().
 * @param  [out]      t  Public key vector of k polynomials.
 * @param  [in, out]  e  Error vector of k polynomials. Modified while being
 *                       folded into t.
 * @param  [in]       a  Matrix of k rows, each of k polynomials.
 * @param  [in]       k  Number of polynomials in a vector.
 */
static void mlkem_keygen_c(sword16* s, sword16* t, sword16* e, const sword16* a,
    int k)
{
    int row;

    /* Transform every polynomial of the private vector. */
    for (row = 0; row < k; ++row) {
        mlkem_ntt(s + row * MLKEM_N);
    }

    for (row = 0; row < k; ++row) {
        sword16* tRow = t + row * MLKEM_N;
        sword16* eRow = e + row * MLKEM_N;
        int idx;

        /* Accumulate matrix row times private vector into this output. */
        mlkem_pointwise_acc_mont(tRow, a + row * k * MLKEM_N, s,
            (unsigned int)k);
        /* Scale by MLKEM_F and Montgomery reduce each coefficient. */
        for (idx = 0; idx < MLKEM_N; ++idx) {
            sword32 prod = tRow[idx] * (sword32)MLKEM_F;
            tRow[idx] = MLKEM_MONT_RED(prod);
        }
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
        /* Transform the error polynomial then add it in with reduction. */
        mlkem_ntt(eRow);
        for (idx = 0; idx < MLKEM_N; ++idx) {
            sword16 sum = (sword16)(tRow[idx] + eRow[idx]);
            tRow[idx] = MLKEM_BARRETT_RED(sum);
        }
#else
        /* Combined transform-and-add of the error polynomial into t. */
        mlkem_ntt_add_to(eRow, tRow);
#endif
    }
}
/* Compute the public key vector from the private vector, matrix and errors.
 *
 * Uses the AVX2 implementation when built with USE_INTEL_SPEEDUP, the CPU
 * flags report AVX2 and the vector registers could be saved; otherwise the
 * C implementation is used.
 *
 * @param  [in, out]  s  Private key vector of polynomials.
 * @param  [out]      t  Public key vector of polynomials.
 * @param  [in, out]  e  Error values as a vector of polynomials.
 * @param  [in]       a  Matrix of polynomials.
 * @param  [in]       k  Number of polynomials in a vector.
 */
void mlkem_keygen(sword16* s, sword16* t, sword16* e, const sword16* a, int k)
{
#ifdef USE_INTEL_SPEEDUP
    if ((IS_INTEL_AVX2(cpuid_flags)) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_keygen_avx2(s, t, e, a, k);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Fallback when no accelerated path was taken. */
    mlkem_keygen_c(s, t, e, a, k);
}
#else
/* Compute the public key vector, generating matrix rows and errors on the fly.
 *
 * One matrix row and one error polynomial are generated at a time into the
 * same temporary buffer, so the whole matrix is never held in memory.
 *
 * @param  [in, out]  s      Private key vector of k polynomials. Each
 *                           polynomial is passed through mlkem_ntt().
 * @param  [out]      t      Public key vector of k polynomials.
 * @param  [in]       prf    Object passed to the matrix and noise generators.
 * @param  [in]       tv     Temporary buffer used for a matrix row and then
 *                           for an error polynomial.
 * @param  [in]       k      Number of polynomials in a vector.
 * @param  [in]       rho    Seed passed to mlkem_gen_matrix_i().
 * @param  [in]       sigma  Seed passed to mlkem_get_noise_i().
 * @return  0 on success.
 * @return  Non-zero value returned by mlkem_gen_matrix_i() or
 *          mlkem_get_noise_i() on failure.
 */
int mlkem_keygen_seeds(sword16* s, sword16* t, MLKEM_PRF_T* prf,
    sword16* tv, int k, byte* rho, byte* sigma)
{
    int row;
    int ret = 0;
    /* Matrix row and error polynomial alias the same temporary storage. */
    sword16* matRow = tv;
    sword16* err = tv;

    /* Transform every polynomial of the private vector. */
    for (row = 0; row < k; ++row) {
        mlkem_ntt(s + row * MLKEM_N);
    }

    for (row = 0; row < k; ++row) {
        sword16* tRow = t + row * MLKEM_N;
        int idx;

        /* Generate this row of the matrix (not transposed). */
        ret = mlkem_gen_matrix_i(prf, matRow, k, rho, row, 0);
        if (ret != 0) {
            break;
        }

        /* Accumulate matrix row times private vector into this output. */
        mlkem_pointwise_acc_mont(tRow, matRow, s, (unsigned int)k);
        /* Scale by MLKEM_F and Montgomery reduce each coefficient. */
        for (idx = 0; idx < MLKEM_N; ++idx) {
            sword32 prod = tRow[idx] * (sword32)MLKEM_F;
            tRow[idx] = MLKEM_MONT_RED(prod);
        }

        /* Generate the error polynomial, overwriting the matrix row. */
        ret = mlkem_get_noise_i(prf, k, err, sigma, row, 1);
        if (ret != 0) {
            break;
        }
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
        /* Transform the error polynomial then add it in with reduction. */
        mlkem_ntt(err);
        for (idx = 0; idx < MLKEM_N; ++idx) {
            sword16 sum = (sword16)(tRow[idx] + err[idx]);
            tRow[idx] = MLKEM_BARRETT_RED(sum);
        }
#else
        /* Combined transform-and-add of the error polynomial into t. */
        mlkem_ntt_add_to(err, tRow);
#endif
    }

    return ret;
}
#endif
#endif
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
!defined(WOLFSSL_MLKEM_NO_DECAPSULATE)
#ifndef WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM
/* Compute the two ciphertext components u and v.
 *
 * Portable C implementation.
 *
 * @param  [in]       pub  Public key vector of polynomials.
 * @param  [out]      u    Vector of k polynomials.
 * @param  [out]      v    Polynomial.
 * @param  [in]       a    Matrix of k rows, each of k polynomials.
 * @param  [in, out]  y    Vector of k polynomials. Each polynomial is passed
 *                         through mlkem_ntt().
 * @param  [in]       e1   Error vector of k polynomials added to u.
 * @param  [in]       e2   Error polynomial added to v.
 * @param  [in]       m    Message polynomial added to v.
 * @param  [in]       k    Number of polynomials in a vector.
 */
static void mlkem_encapsulate_c(const sword16* pub, sword16* u, sword16* v,
    const sword16* a, sword16* y, const sword16* e1, const sword16* e2,
    const sword16* m, int k)
{
    int n;

    /* Transform every polynomial of y. */
    for (n = 0; n < k; ++n) {
        mlkem_ntt(y + n * MLKEM_N);
    }

    for (n = 0; n < k; ++n) {
        sword16* uPoly = u + n * MLKEM_N;
        const sword16* e1Poly = e1 + n * MLKEM_N;
        int c;

        /* Matrix row times y, then inverse transform. */
        mlkem_pointwise_acc_mont(uPoly, a + n * k * MLKEM_N, y,
            (unsigned int)k);
        mlkem_invntt(uPoly);

#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
        /* Add errors and Barrett reduce, one coefficient at a time. */
        for (c = 0; c < MLKEM_N; ++c) {
            sword16 sum = (sword16)(uPoly[c] + e1Poly[c]);
            uPoly[c] = MLKEM_BARRETT_RED(sum);
        }
#else
        /* Add errors and Barrett reduce, eight coefficients per iteration.
         * All eight sums are formed before any result is stored. */
        for (c = 0; c < MLKEM_N; c += 8) {
            sword16 s0 = (sword16)(uPoly[c + 0] + e1Poly[c + 0]);
            sword16 s1 = (sword16)(uPoly[c + 1] + e1Poly[c + 1]);
            sword16 s2 = (sword16)(uPoly[c + 2] + e1Poly[c + 2]);
            sword16 s3 = (sword16)(uPoly[c + 3] + e1Poly[c + 3]);
            sword16 s4 = (sword16)(uPoly[c + 4] + e1Poly[c + 4]);
            sword16 s5 = (sword16)(uPoly[c + 5] + e1Poly[c + 5]);
            sword16 s6 = (sword16)(uPoly[c + 6] + e1Poly[c + 6]);
            sword16 s7 = (sword16)(uPoly[c + 7] + e1Poly[c + 7]);

            uPoly[c + 0] = MLKEM_BARRETT_RED(s0);
            uPoly[c + 1] = MLKEM_BARRETT_RED(s1);
            uPoly[c + 2] = MLKEM_BARRETT_RED(s2);
            uPoly[c + 3] = MLKEM_BARRETT_RED(s3);
            uPoly[c + 4] = MLKEM_BARRETT_RED(s4);
            uPoly[c + 5] = MLKEM_BARRETT_RED(s5);
            uPoly[c + 6] = MLKEM_BARRETT_RED(s6);
            uPoly[c + 7] = MLKEM_BARRETT_RED(s7);
        }
#endif
    }

    /* Public key times y, then inverse transform. */
    mlkem_pointwise_acc_mont(v, pub, y, (unsigned int)k);
    mlkem_invntt(v);
    /* Add error and message polynomials and Barrett reduce. */
    for (n = 0; n < MLKEM_N; ++n) {
        sword16 sum = (sword16)(v[n] + e2[n] + m[n]);
        v[n] = MLKEM_BARRETT_RED(sum);
    }
}
/* Compute the two ciphertext components u and v.
 *
 * Uses the AVX2 implementation when built with USE_INTEL_SPEEDUP, the CPU
 * flags report AVX2 and the vector registers could be saved; otherwise the
 * C implementation is used.
 *
 * @param  [in]       pub  Public key vector of polynomials.
 * @param  [out]      u    Vector of polynomials.
 * @param  [out]      v    Polynomial.
 * @param  [in]       a    Matrix of polynomials.
 * @param  [in, out]  y    Vector of polynomials.
 * @param  [in]       e1   Error vector of polynomials.
 * @param  [in]       e2   Error polynomial.
 * @param  [in]       m    Message polynomial.
 * @param  [in]       k    Number of polynomials in a vector.
 */
void mlkem_encapsulate(const sword16* pub, sword16* u, sword16* v,
    const sword16* a, sword16* y, const sword16* e1, const sword16* e2,
    const sword16* m, int k)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_encapsulate_avx2(pub, u, v, a, y, e1, e2, m, k);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Fallback when no accelerated path was taken. */
    mlkem_encapsulate_c(pub, u, v, a, y, e1, e2, m, k);
}
#else
/* Compute the ciphertext components, generating matrix rows and noise on the
 * fly.
 *
 * One matrix row and one error polynomial are generated at a time into the
 * temporary buffer. v is left in the first MLKEM_N entries of tp.
 *
 * A failure while generating a matrix row or an error polynomial is returned
 * to the caller; the remaining steps are skipped so the error is not
 * overwritten by the result of a later call.
 *
 * @param  [in]       pub    Public key vector of polynomials.
 * @param  [in]       prf    Object passed to the matrix and noise generators.
 * @param  [out]      u      Vector of k polynomials.
 * @param  [in, out]  tp     Temporary buffer. Holds a matrix row, then an
 *                           error polynomial, and finally v followed by e2.
 * @param  [in, out]  y      Vector of k polynomials. Each polynomial is passed
 *                           through mlkem_ntt(); the first polynomial is
 *                           later overwritten with the message polynomial.
 * @param  [in]       k      Number of polynomials in a vector.
 * @param  [in]       msg    Message passed to mlkem_from_msg().
 * @param  [in]       seed   Seed passed to mlkem_gen_matrix_i().
 * @param  [in, out]  coins  Seed passed to the noise generators. The byte at
 *                           index WC_ML_KEM_SYM_SZ is set to 2 * k.
 * @return  0 on success.
 * @return  Non-zero value returned by mlkem_gen_matrix_i(),
 *          mlkem_get_noise_i() or mlkem_get_noise_eta2_c() on failure.
 */
int mlkem_encapsulate_seeds(const sword16* pub, MLKEM_PRF_T* prf, sword16* u,
    sword16* tp, sword16* y, int k, const byte* msg, byte* seed, byte* coins)
{
    int ret = 0;
    int i;
    /* Matrix row, e1 and v all reuse the start of the temporary buffer. */
    sword16* a = tp;
    sword16* e1 = tp;
    sword16* v = tp;
    sword16* e2 = tp + MLKEM_N;
    /* Message polynomial reuses y once y is no longer needed. */
    sword16* m = y;

    /* Transform every polynomial of y. */
    for (i = 0; i < k; ++i) {
        mlkem_ntt(y + i * MLKEM_N);
    }

    for (i = 0; i < k; ++i) {
        sword16* ui = u + i * MLKEM_N;
        int j;

        /* Generate this row of the matrix (transposed). */
        ret = mlkem_gen_matrix_i(prf, a, k, seed, i, 1);
        if (ret != 0) {
            break;
        }

        /* Matrix row times y, then inverse transform. */
        mlkem_pointwise_acc_mont(ui, a, y, (unsigned int)k);
        mlkem_invntt(ui);

        /* Generate the error polynomial, overwriting the matrix row. */
        ret = mlkem_get_noise_i(prf, k, e1, coins, i, 0);
        if (ret != 0) {
            break;
        }

#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
        /* Add errors and Barrett reduce, one coefficient at a time. */
        for (j = 0; j < MLKEM_N; ++j) {
            sword16 t = (sword16)(ui[j] + e1[j]);
            ui[j] = MLKEM_BARRETT_RED(t);
        }
#else
        /* Add errors and Barrett reduce, eight coefficients per iteration. */
        for (j = 0; j < MLKEM_N; j += 8) {
            sword16 t0 = (sword16)(ui[j + 0] + e1[j + 0]);
            sword16 t1 = (sword16)(ui[j + 1] + e1[j + 1]);
            sword16 t2 = (sword16)(ui[j + 2] + e1[j + 2]);
            sword16 t3 = (sword16)(ui[j + 3] + e1[j + 3]);
            sword16 t4 = (sword16)(ui[j + 4] + e1[j + 4]);
            sword16 t5 = (sword16)(ui[j + 5] + e1[j + 5]);
            sword16 t6 = (sword16)(ui[j + 6] + e1[j + 6]);
            sword16 t7 = (sword16)(ui[j + 7] + e1[j + 7]);

            ui[j + 0] = MLKEM_BARRETT_RED(t0);
            ui[j + 1] = MLKEM_BARRETT_RED(t1);
            ui[j + 2] = MLKEM_BARRETT_RED(t2);
            ui[j + 3] = MLKEM_BARRETT_RED(t3);
            ui[j + 4] = MLKEM_BARRETT_RED(t4);
            ui[j + 5] = MLKEM_BARRETT_RED(t5);
            ui[j + 6] = MLKEM_BARRETT_RED(t6);
            ui[j + 7] = MLKEM_BARRETT_RED(t7);
        }
#endif
    }

    /* Only continue when every row was processed successfully - otherwise the
     * error from the loop would be overwritten below. */
    if (ret == 0) {
        /* Public key times y, then inverse transform. */
        mlkem_pointwise_acc_mont(v, pub, y, (unsigned int)k);
        mlkem_invntt(v);

        /* y is no longer needed - reuse it for the message polynomial. */
        mlkem_from_msg(m, msg);

        /* Select the input for the final error polynomial. */
        coins[WC_ML_KEM_SYM_SZ] = (byte)(2 * k);
        ret = mlkem_get_noise_eta2_c(prf, e2, coins);
    }
    if (ret == 0) {
#if defined(WOLFSSL_MLKEM_SMALL) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE)
        /* Add error and message polynomials and Barrett reduce. */
        for (i = 0; i < MLKEM_N; ++i) {
            sword16 t = (sword16)(v[i] + e2[i] + m[i]);
            v[i] = MLKEM_BARRETT_RED(t);
        }
#else
        /* Add error and message polynomials and Barrett reduce, eight
         * coefficients per iteration. */
        for (i = 0; i < MLKEM_N; i += 8) {
            sword16 t0 = (sword16)(v[i + 0] + e2[i + 0] + m[i + 0]);
            sword16 t1 = (sword16)(v[i + 1] + e2[i + 1] + m[i + 1]);
            sword16 t2 = (sword16)(v[i + 2] + e2[i + 2] + m[i + 2]);
            sword16 t3 = (sword16)(v[i + 3] + e2[i + 3] + m[i + 3]);
            sword16 t4 = (sword16)(v[i + 4] + e2[i + 4] + m[i + 4]);
            sword16 t5 = (sword16)(v[i + 5] + e2[i + 5] + m[i + 5]);
            sword16 t6 = (sword16)(v[i + 6] + e2[i + 6] + m[i + 6]);
            sword16 t7 = (sword16)(v[i + 7] + e2[i + 7] + m[i + 7]);

            v[i + 0] = MLKEM_BARRETT_RED(t0);
            v[i + 1] = MLKEM_BARRETT_RED(t1);
            v[i + 2] = MLKEM_BARRETT_RED(t2);
            v[i + 3] = MLKEM_BARRETT_RED(t3);
            v[i + 4] = MLKEM_BARRETT_RED(t4);
            v[i + 5] = MLKEM_BARRETT_RED(t5);
            v[i + 6] = MLKEM_BARRETT_RED(t6);
            v[i + 7] = MLKEM_BARRETT_RED(t7);
        }
#endif
    }

    return ret;
}
#endif
#endif
#ifndef WOLFSSL_MLKEM_NO_DECAPSULATE
/* Recover the message polynomial from the ciphertext components.
 *
 * Portable C implementation.
 *
 * @param  [in]       s  Private key vector of polynomials.
 * @param  [out]      w  Resulting polynomial.
 * @param  [in, out]  u  Vector of k polynomials. Each polynomial is passed
 *                       through mlkem_ntt().
 * @param  [in]       v  Polynomial.
 * @param  [in]       k  Number of polynomials in a vector.
 */
static void mlkem_decapsulate_c(const sword16* s, sword16* w, sword16* u,
    const sword16* v, int k)
{
    int n;
    sword16* poly = u;

    /* Transform every polynomial of u. */
    for (n = 0; n < k; ++n, poly += MLKEM_N) {
        mlkem_ntt(poly);
    }

    /* Private key times u, then inverse transform. */
    mlkem_pointwise_acc_mont(w, s, u, (unsigned int)k);
    mlkem_invntt(w);

    /* Subtract from v and Barrett reduce. */
    for (n = 0; n < MLKEM_N; ++n) {
        sword16 diff = (sword16)(v[n] - w[n]);
        w[n] = MLKEM_BARRETT_RED(diff);
    }
}
/* Recover the message polynomial from the ciphertext components.
 *
 * Uses the AVX2 implementation when built with USE_INTEL_SPEEDUP, the CPU
 * flags report AVX2 and the vector registers could be saved; otherwise the
 * C implementation is used.
 *
 * @param  [in]       s  Private key vector of polynomials.
 * @param  [out]      w  Resulting polynomial.
 * @param  [in, out]  u  Vector of polynomials.
 * @param  [in]       v  Polynomial.
 * @param  [in]       k  Number of polynomials in a vector.
 */
void mlkem_decapsulate(const sword16* s, sword16* w, sword16* u,
    const sword16* v, int k)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_decapsulate_avx2(s, w, u, v, k);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Fallback when no accelerated path was taken. */
    mlkem_decapsulate_c(s, w, u, v, k);
}
#endif
#endif
#if defined(USE_INTEL_SPEEDUP) && !defined(WC_SHA3_NO_ASM)
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
/* Generate the four polynomials of a 2x2 matrix using 4-way AVX2 hashing.
 *
 * All four polynomials are sampled in parallel, one per hashing lane.
 *
 * @param  [out]  a           Four consecutive polynomials of MLKEM_N
 *                            coefficients.
 * @param  [in]   seed        Seed passed to the 4-way hash.
 * @param  [in]   transposed  When non-zero the indices of lanes 1 and 2 are
 *                            swapped.
 * @return  0 on success.
 * @return  MEMORY_E when dynamic memory allocation fails (only with
 *          WOLFSSL_SMALL_STACK).
 */
static int mlkem_gen_matrix_k2_avx2(sword16* a, byte* seed, int transposed)
{
    int off;
    int lane;
#ifdef WOLFSSL_SMALL_STACK
    byte *rand = NULL;
    word64 *state = NULL;
#else
    byte rand[4 * GEN_MATRIX_SIZE + 2];
    word64 state[25 * 4];
#endif
    /* Number of coefficients accepted so far in each lane. */
    unsigned int ctr[4];

#ifdef WOLFSSL_SMALL_STACK
    rand = (byte*)XMALLOC(4 * GEN_MATRIX_SIZE + 2, NULL,
        DYNAMIC_TYPE_TMP_BUFFER);
    state = (word64*)XMALLOC(sizeof(word64) * 25 * 4, NULL,
        DYNAMIC_TYPE_TMP_BUFFER);
    if ((rand == NULL) || (state == NULL)) {
        XFREE(rand, NULL, DYNAMIC_TYPE_TMP_BUFFER);
        XFREE(state, NULL, DYNAMIC_TYPE_TMP_BUFFER);
        return MEMORY_E;
    }
#endif

    /* Two bytes past the random data are set to a fixed value. */
    rand[4 * GEN_MATRIX_SIZE + 0] = 0xff;
    rand[4 * GEN_MATRIX_SIZE + 1] = 0xff;

    /* Lanes 0 and 3 have the same indices either way round. */
    state[4*4 + 0] = 0x1f0000 + 0x000;
    state[4*4 + 3] = 0x1f0000 + 0x101;
    if (transposed) {
        state[4*4 + 1] = 0x1f0000 + 0x100;
        state[4*4 + 2] = 0x1f0000 + 0x001;
    }
    else {
        state[4*4 + 1] = 0x1f0000 + 0x001;
        state[4*4 + 2] = 0x1f0000 + 0x100;
    }

    /* First block of output for each lane. */
    sha3_128_blocksx4_seed_avx2(state, seed);
    mlkem_redistribute_21_rand_avx2(state, rand + 0 * GEN_MATRIX_SIZE,
        rand + 1 * GEN_MATRIX_SIZE, rand + 2 * GEN_MATRIX_SIZE,
        rand + 3 * GEN_MATRIX_SIZE);
    /* Remaining blocks to fill each lane's buffer. */
    for (off = SHA3_128_BYTES; off < GEN_MATRIX_SIZE; off += SHA3_128_BYTES) {
        sha3_blocksx4_avx2(state);
        mlkem_redistribute_21_rand_avx2(state,
            rand + off + 0 * GEN_MATRIX_SIZE, rand + off + 1 * GEN_MATRIX_SIZE,
            rand + off + 2 * GEN_MATRIX_SIZE, rand + off + 3 * GEN_MATRIX_SIZE);
    }

    /* Sample each lane's buffer into its polynomial. */
    for (lane = 0; lane < 4; lane++) {
        ctr[lane] = mlkem_rej_uniform_n_avx2(a + lane * MLKEM_N, MLKEM_N,
            rand + lane * GEN_MATRIX_SIZE, GEN_MATRIX_SIZE);
    }
    /* Generate one more block at a time until every polynomial is full. */
    while ((ctr[0] < MLKEM_N) || (ctr[1] < MLKEM_N) || (ctr[2] < MLKEM_N) ||
           (ctr[3] < MLKEM_N)) {
        sha3_blocksx4_avx2(state);
        mlkem_redistribute_21_rand_avx2(state, rand + 0 * GEN_MATRIX_SIZE,
            rand + 1 * GEN_MATRIX_SIZE, rand + 2 * GEN_MATRIX_SIZE,
            rand + 3 * GEN_MATRIX_SIZE);

        for (lane = 0; lane < 4; lane++) {
            ctr[lane] += mlkem_rej_uniform_avx2(
                a + lane * MLKEM_N + ctr[lane], MLKEM_N - ctr[lane],
                rand + lane * GEN_MATRIX_SIZE, XOF_BLOCK_SIZE);
        }
    }

    WC_FREE_VAR_EX(rand, NULL, DYNAMIC_TYPE_TMP_BUFFER);
    WC_FREE_VAR_EX(state, NULL, DYNAMIC_TYPE_TMP_BUFFER);
    return 0;
}
#endif
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
/* Generate the nine polynomials of a 3x3 matrix using AVX2 hashing.
 *
 * Eight polynomials are sampled in two passes of the 4-way hash; the ninth is
 * sampled with a single-lane hash.
 *
 * @param  [out]  a           Nine consecutive polynomials of MLKEM_N
 *                            coefficients.
 * @param  [in]   seed        Seed passed to the hash.
 * @param  [in]   transposed  When non-zero the two index bytes of each
 *                            polynomial are swapped.
 * @return  0 on success.
 * @return  MEMORY_E when dynamic memory allocation fails (only with
 *          WOLFSSL_SMALL_STACK).
 */
static int mlkem_gen_matrix_k3_avx2(sword16* a, byte* seed, int transposed)
{
int i;
int k;
#ifdef WOLFSSL_SMALL_STACK
byte *rand = NULL;
word64 *state = NULL;
#else
byte rand[4 * GEN_MATRIX_SIZE + 2];
word64 state[25 * 4];
#endif
unsigned int ctr0;
unsigned int ctr1;
unsigned int ctr2;
unsigned int ctr3;
byte* p;
#ifdef WOLFSSL_SMALL_STACK
rand = (byte*)XMALLOC(4 * GEN_MATRIX_SIZE + 2, NULL,
DYNAMIC_TYPE_TMP_BUFFER);
state = (word64*)XMALLOC(sizeof(word64) * 25 * 4, NULL,
DYNAMIC_TYPE_TMP_BUFFER);
if ((rand == NULL) || (state == NULL)) {
XFREE(rand, NULL, DYNAMIC_TYPE_TMP_BUFFER);
XFREE(state, NULL, DYNAMIC_TYPE_TMP_BUFFER);
return MEMORY_E;
}
#endif
/* Two bytes past the random data are set to a fixed value. */
rand[4 * GEN_MATRIX_SIZE + 0] = 0xff;
rand[4 * GEN_MATRIX_SIZE + 1] = 0xff;
/* Two passes of four lanes: polynomials 0..3 then 4..7. */
for (k = 0; k < 2; k++) {
/* Polynomial number k*4+i is split into row (/3) and column (%3). */
for (i = 0; i < 4; i++) {
if (!transposed) {
state[4*4 + i] = (word32)(0x1f0000 + (((k*4+i)/3) << 8) +
((k*4+i)%3));
}
else {
state[4*4 + i] = (word32)(0x1f0000 + (((k*4+i)%3) << 8) +
((k*4+i)/3));
}
}
/* First block of output for each lane. */
sha3_128_blocksx4_seed_avx2(state, seed);
mlkem_redistribute_21_rand_avx2(state,
rand + 0 * GEN_MATRIX_SIZE, rand + 1 * GEN_MATRIX_SIZE,
rand + 2 * GEN_MATRIX_SIZE, rand + 3 * GEN_MATRIX_SIZE);
/* Remaining blocks to fill each lane's buffer. */
for (i = SHA3_128_BYTES; i < GEN_MATRIX_SIZE; i += SHA3_128_BYTES) {
sha3_blocksx4_avx2(state);
mlkem_redistribute_21_rand_avx2(state,
rand + i + 0 * GEN_MATRIX_SIZE, rand + i + 1 * GEN_MATRIX_SIZE,
rand + i + 2 * GEN_MATRIX_SIZE, rand + i + 3 * GEN_MATRIX_SIZE);
}
/* Sample each lane's buffer into its polynomial. */
p = rand;
ctr0 = mlkem_rej_uniform_n_avx2(a + 0 * MLKEM_N, MLKEM_N, p,
GEN_MATRIX_SIZE);
p += GEN_MATRIX_SIZE;
ctr1 = mlkem_rej_uniform_n_avx2(a + 1 * MLKEM_N, MLKEM_N, p,
GEN_MATRIX_SIZE);
p += GEN_MATRIX_SIZE;
ctr2 = mlkem_rej_uniform_n_avx2(a + 2 * MLKEM_N, MLKEM_N, p,
GEN_MATRIX_SIZE);
p += GEN_MATRIX_SIZE;
ctr3 = mlkem_rej_uniform_n_avx2(a + 3 * MLKEM_N, MLKEM_N, p,
GEN_MATRIX_SIZE);
/* Generate one more block at a time until every polynomial is full. */
while ((ctr0 < MLKEM_N) || (ctr1 < MLKEM_N) || (ctr2 < MLKEM_N) ||
(ctr3 < MLKEM_N)) {
sha3_blocksx4_avx2(state);
mlkem_redistribute_21_rand_avx2(state, rand + 0 * GEN_MATRIX_SIZE,
rand + 1 * GEN_MATRIX_SIZE, rand + 2 * GEN_MATRIX_SIZE,
rand + 3 * GEN_MATRIX_SIZE);
p = rand;
ctr0 += mlkem_rej_uniform_avx2(a + 0 * MLKEM_N + ctr0,
MLKEM_N - ctr0, p, XOF_BLOCK_SIZE);
p += GEN_MATRIX_SIZE;
ctr1 += mlkem_rej_uniform_avx2(a + 1 * MLKEM_N + ctr1,
MLKEM_N - ctr1, p, XOF_BLOCK_SIZE);
p += GEN_MATRIX_SIZE;
ctr2 += mlkem_rej_uniform_avx2(a + 2 * MLKEM_N + ctr2,
MLKEM_N - ctr2, p, XOF_BLOCK_SIZE);
p += GEN_MATRIX_SIZE;
ctr3 += mlkem_rej_uniform_avx2(a + 3 * MLKEM_N + ctr3,
MLKEM_N - ctr3, p, XOF_BLOCK_SIZE);
}
/* Move on to the next four polynomials. */
a += 4 * MLKEM_N;
}
/* Ninth polynomial (row 2, column 2): single-lane state built by hand.
 * The indices are equal so transposition makes no difference. */
readUnalignedWords64(state, seed, 4);
state[4] = 0x1f0000 + (2 << 8) + 2;
XMEMSET(state + 5, 0, sizeof(*state) * (25 - 5));
/* Top bit of word 20 set - presumably the end-of-rate padding bit. */
state[20] = W64LIT(0x8000000000000000);
/* Fill the buffer one block at a time. */
for (i = 0; i < GEN_MATRIX_SIZE; i += SHA3_128_BYTES) {
#ifndef WC_SHA3_NO_ASM
if (IS_INTEL_BMI2(cpuid_flags)) {
sha3_block_bmi2(state);
}
else if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0))
{
sha3_block_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
BlockSha3(state);
}
XMEMCPY(rand + i, state, SHA3_128_BYTES);
}
ctr0 = mlkem_rej_uniform_n_avx2(a, MLKEM_N, rand, GEN_MATRIX_SIZE);
/* Generate one more block at a time until the polynomial is full. */
while (ctr0 < MLKEM_N) {
#ifndef WC_SHA3_NO_ASM
if (IS_INTEL_BMI2(cpuid_flags)) {
sha3_block_bmi2(state);
}
else if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0))
{
sha3_block_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
BlockSha3(state);
}
XMEMCPY(rand, state, SHA3_128_BYTES);
ctr0 += mlkem_rej_uniform_avx2(a + ctr0, MLKEM_N - ctr0, rand,
XOF_BLOCK_SIZE);
}
WC_FREE_VAR_EX(rand, NULL, DYNAMIC_TYPE_TMP_BUFFER);
WC_FREE_VAR_EX(state, NULL, DYNAMIC_TYPE_TMP_BUFFER);
return 0;
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Generate the sixteen polynomials of a 4x4 matrix using 4-way AVX2 hashing.
 *
 * Each pass of the 4-way hash samples the four polynomials of one row.
 *
 * @param  [out]  a           Sixteen consecutive polynomials of MLKEM_N
 *                            coefficients.
 * @param  [in]   seed        Seed passed to the 4-way hash.
 * @param  [in]   transposed  When non-zero the two index bytes of each
 *                            polynomial are swapped.
 * @return  0 on success.
 * @return  MEMORY_E when dynamic memory allocation fails (only with
 *          WOLFSSL_SMALL_STACK).
 */
static int mlkem_gen_matrix_k4_avx2(sword16* a, byte* seed, int transposed)
{
    int off;
    int lane;
    int row;
#ifdef WOLFSSL_SMALL_STACK
    byte *rand = NULL;
    word64 *state = NULL;
#else
    byte rand[4 * GEN_MATRIX_SIZE + 2];
    word64 state[25 * 4];
#endif
    /* Number of coefficients accepted so far in each lane. */
    unsigned int ctr[4];

#ifdef WOLFSSL_SMALL_STACK
    rand = (byte*)XMALLOC(4 * GEN_MATRIX_SIZE + 2, NULL,
        DYNAMIC_TYPE_TMP_BUFFER);
    state = (word64*)XMALLOC(sizeof(word64) * 25 * 4, NULL,
        DYNAMIC_TYPE_TMP_BUFFER);
    if ((rand == NULL) || (state == NULL)) {
        XFREE(rand, NULL, DYNAMIC_TYPE_TMP_BUFFER);
        XFREE(state, NULL, DYNAMIC_TYPE_TMP_BUFFER);
        return MEMORY_E;
    }
#endif

    /* Two bytes past the random data are set to a fixed value. */
    rand[4 * GEN_MATRIX_SIZE + 0] = 0xff;
    rand[4 * GEN_MATRIX_SIZE + 1] = 0xff;

    for (row = 0; row < 4; row++) {
        /* Row and lane numbers form the two index bytes; swap them when
         * transposed. */
        for (lane = 0; lane < 4; lane++) {
            int hi = transposed ? lane : row;
            int lo = transposed ? row : lane;

            state[4*4 + lane] = (word32)(0x1f0000 + (hi << 8) + lo);
        }

        /* First block of output for each lane. */
        sha3_128_blocksx4_seed_avx2(state, seed);
        mlkem_redistribute_21_rand_avx2(state,
            rand + 0 * GEN_MATRIX_SIZE, rand + 1 * GEN_MATRIX_SIZE,
            rand + 2 * GEN_MATRIX_SIZE, rand + 3 * GEN_MATRIX_SIZE);
        /* Remaining blocks to fill each lane's buffer. */
        for (off = SHA3_128_BYTES; off < GEN_MATRIX_SIZE;
             off += SHA3_128_BYTES) {
            sha3_blocksx4_avx2(state);
            mlkem_redistribute_21_rand_avx2(state,
                rand + off + 0 * GEN_MATRIX_SIZE,
                rand + off + 1 * GEN_MATRIX_SIZE,
                rand + off + 2 * GEN_MATRIX_SIZE,
                rand + off + 3 * GEN_MATRIX_SIZE);
        }

        /* Sample each lane's buffer into its polynomial. */
        for (lane = 0; lane < 4; lane++) {
            ctr[lane] = mlkem_rej_uniform_n_avx2(a + lane * MLKEM_N, MLKEM_N,
                rand + lane * GEN_MATRIX_SIZE, GEN_MATRIX_SIZE);
        }
        /* Generate one more block at a time until every polynomial is
         * full. */
        while ((ctr[0] < MLKEM_N) || (ctr[1] < MLKEM_N) ||
               (ctr[2] < MLKEM_N) || (ctr[3] < MLKEM_N)) {
            sha3_blocksx4_avx2(state);
            mlkem_redistribute_21_rand_avx2(state, rand + 0 * GEN_MATRIX_SIZE,
                rand + 1 * GEN_MATRIX_SIZE, rand + 2 * GEN_MATRIX_SIZE,
                rand + 3 * GEN_MATRIX_SIZE);

            for (lane = 0; lane < 4; lane++) {
                ctr[lane] += mlkem_rej_uniform_avx2(
                    a + lane * MLKEM_N + ctr[lane], MLKEM_N - ctr[lane],
                    rand + lane * GEN_MATRIX_SIZE, XOF_BLOCK_SIZE);
            }
        }

        /* Move on to the next row of four polynomials. */
        a += 4 * MLKEM_N;
    }

    WC_FREE_VAR_EX(rand, NULL, DYNAMIC_TYPE_TMP_BUFFER);
    WC_FREE_VAR_EX(state, NULL, DYNAMIC_TYPE_TMP_BUFFER);
    return 0;
}
#endif
#elif defined(WOLFSSL_ARMASM) && defined(__aarch64__)
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
/* Generate the four polynomials of a 2x2 matrix using 3-way NEON hashing.
 *
 * Three polynomials are sampled in parallel; the fourth is sampled with a
 * single-lane hash.
 *
 * @param  [out]  a           Four consecutive polynomials of MLKEM_N
 *                            coefficients.
 * @param  [in]   seed        Seed passed to the hash.
 * @param  [in]   transposed  When non-zero the indices of lanes 1 and 2 are
 *                            swapped.
 * @return  0 always.
 */
static int mlkem_gen_matrix_k2_aarch64(sword16* a, byte* seed, int transposed)
{
    /* Three hash states of 25 words, laid out one after another. */
    word64 state[3 * 25];
    /* Number of coefficients accepted so far in each lane. */
    unsigned int ctr[3];
    int lane;
    byte* p;

    /* Lane 0 has the same indices either way round. */
    state[0*25 + 4] = 0x1f0000 + (0 << 8) + 0;
    if (transposed) {
        state[1*25 + 4] = 0x1f0000 + (1 << 8) + 0;
        state[2*25 + 4] = 0x1f0000 + (0 << 8) + 1;
    }
    else {
        state[1*25 + 4] = 0x1f0000 + (0 << 8) + 1;
        state[2*25 + 4] = 0x1f0000 + (1 << 8) + 0;
    }

    /* First block of output is sampled directly from each state. */
    mlkem_shake128_blocksx3_seed_neon(state, seed);
    for (lane = 0; lane < 3; lane++) {
        ctr[lane] = mlkem_rej_uniform_neon(a + lane * MLKEM_N, MLKEM_N,
            (byte*)(state + lane * 25), XOF_BLOCK_SIZE);
    }
    /* Generate one more block at a time until every polynomial is full. */
    while ((ctr[0] < MLKEM_N) || (ctr[1] < MLKEM_N) || (ctr[2] < MLKEM_N)) {
        mlkem_sha3_blocksx3_neon(state);
        for (lane = 0; lane < 3; lane++) {
            ctr[lane] += mlkem_rej_uniform_neon(
                a + lane * MLKEM_N + ctr[lane], MLKEM_N - ctr[lane],
                (byte*)(state + lane * 25), XOF_BLOCK_SIZE);
        }
    }

    /* Fourth polynomial (row 1, column 1): single-lane state built by hand.
     * The indices are equal so transposition makes no difference. */
    a += 3 * MLKEM_N;
    readUnalignedWords64(state, seed, 4);
    state[4] = 0x1f0000 + (1 << 8) + 1;
    XMEMSET(state + 5, 0, sizeof(*state) * (25 - 5));
    state[20] = W64LIT(0x8000000000000000);
    BlockSha3(state);
    p = (byte*)state;
    ctr[0] = mlkem_rej_uniform_neon(a, MLKEM_N, p, XOF_BLOCK_SIZE);
    while (ctr[0] < MLKEM_N) {
        BlockSha3(state);
        ctr[0] += mlkem_rej_uniform_neon(a + ctr[0], MLKEM_N - ctr[0], p,
            XOF_BLOCK_SIZE);
    }

    return 0;
}
#endif
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
/* Generate the nine polynomials of a 3x3 matrix using 3-way NEON hashing.
 *
 * Each pass of the 3-way hash samples the three polynomials of one row.
 *
 * @param  [out]  a           Nine consecutive polynomials of MLKEM_N
 *                            coefficients.
 * @param  [in]   seed        Seed passed to the hash.
 * @param  [in]   transposed  When non-zero the two index bytes of each
 *                            polynomial are swapped.
 * @return  0 always.
 */
static int mlkem_gen_matrix_k3_aarch64(sword16* a, byte* seed, int transposed)
{
    int lane;
    int row;
    /* Three hash states of 25 words, laid out one after another. */
    word64 state[3 * 25];
    /* Number of coefficients accepted so far in each lane. */
    unsigned int ctr[3];

    for (row = 0; row < 3; row++) {
        /* Row and lane numbers form the two index bytes; swap them when
         * transposed. */
        for (lane = 0; lane < 3; lane++) {
            int hi = transposed ? lane : row;
            int lo = transposed ? row : lane;

            state[lane*25 + 4] = 0x1f0000 + ((hi << 8) + lo);
        }

        /* First block of output is sampled directly from each state. */
        mlkem_shake128_blocksx3_seed_neon(state, seed);
        for (lane = 0; lane < 3; lane++) {
            ctr[lane] = mlkem_rej_uniform_neon(a + lane * MLKEM_N, MLKEM_N,
                (byte*)(state + lane * 25), XOF_BLOCK_SIZE);
        }
        /* Generate one more block at a time until every polynomial is
         * full. */
        while ((ctr[0] < MLKEM_N) || (ctr[1] < MLKEM_N) ||
               (ctr[2] < MLKEM_N)) {
            mlkem_sha3_blocksx3_neon(state);
            for (lane = 0; lane < 3; lane++) {
                ctr[lane] += mlkem_rej_uniform_neon(
                    a + lane * MLKEM_N + ctr[lane], MLKEM_N - ctr[lane],
                    (byte*)(state + lane * 25), XOF_BLOCK_SIZE);
            }
        }

        /* Move on to the next row of three polynomials. */
        a += 3 * MLKEM_N;
    }

    return 0;
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Generate the sixteen polynomials of a 4x4 matrix using 3-way NEON hashing.
 *
 * Fifteen polynomials are sampled in five passes of the 3-way hash; the
 * sixteenth is sampled with a single-lane hash.
 *
 * @param  [out]  a           Sixteen consecutive polynomials of MLKEM_N
 *                            coefficients.
 * @param  [in]   seed        Seed passed to the hash.
 * @param  [in]   transposed  When non-zero the two index bytes of each
 *                            polynomial are swapped.
 * @return  0 always.
 */
static int mlkem_gen_matrix_k4_aarch64(sword16* a, byte* seed, int transposed)
{
    int lane;
    int pass;
    /* Three hash states of 25 words, laid out one after another. */
    word64 state[3 * 25];
    /* Number of coefficients accepted so far in each lane. */
    unsigned int ctr[3];
    byte* p;

    for (pass = 0; pass < 5; pass++) {
        /* Polynomial number pass*3+lane is split into row (/4) and
         * column (%4). */
        for (lane = 0; lane < 3; lane++) {
            byte bi = ((pass * 3) + lane) / 4;
            byte bj = ((pass * 3) + lane) % 4;

            if (transposed) {
                state[lane*25 + 4] = 0x1f0000 + (bj << 8) + bi;
            }
            else {
                state[lane*25 + 4] = 0x1f0000 + (bi << 8) + bj;
            }
        }

        /* First block of output is sampled directly from each state. */
        mlkem_shake128_blocksx3_seed_neon(state, seed);
        for (lane = 0; lane < 3; lane++) {
            ctr[lane] = mlkem_rej_uniform_neon(a + lane * MLKEM_N, MLKEM_N,
                (byte*)(state + lane * 25), XOF_BLOCK_SIZE);
        }
        /* Generate one more block at a time until every polynomial is
         * full. */
        while ((ctr[0] < MLKEM_N) || (ctr[1] < MLKEM_N) ||
               (ctr[2] < MLKEM_N)) {
            mlkem_sha3_blocksx3_neon(state);
            for (lane = 0; lane < 3; lane++) {
                ctr[lane] += mlkem_rej_uniform_neon(
                    a + lane * MLKEM_N + ctr[lane], MLKEM_N - ctr[lane],
                    (byte*)(state + lane * 25), XOF_BLOCK_SIZE);
            }
        }

        /* Move on to the next three polynomials. */
        a += 3 * MLKEM_N;
    }

    /* Sixteenth polynomial (row 3, column 3): single-lane state built by
     * hand. The indices are equal so transposition makes no difference. */
    readUnalignedWords64(state, seed, 4);
    state[4] = 0x1f0000 + (3 << 8) + 3;
    XMEMSET(state + 5, 0, sizeof(*state) * (25 - 5));
    state[20] = W64LIT(0x8000000000000000);
    BlockSha3(state);
    p = (byte*)state;
    ctr[0] = mlkem_rej_uniform_neon(a, MLKEM_N, p, XOF_BLOCK_SIZE);
    while (ctr[0] < MLKEM_N) {
        BlockSha3(state);
        ctr[0] += mlkem_rej_uniform_neon(a + ctr[0], MLKEM_N - ctr[0], p,
            XOF_BLOCK_SIZE);
    }

    return 0;
}
#endif
#endif
#if !(defined(WOLFSSL_ARMASM) && defined(__aarch64__))
/* Initialize a SHAKE-128 object and absorb the seed.
 *
 * @param  [in, out]  shake128  SHAKE-128 object.
 * @param  [in]       seed      Data to absorb.
 * @param  [in]       len       Length of data in bytes.
 * @return  0 on success.
 * @return  Non-zero value returned by wc_InitShake128() or
 *          wc_Shake128_Absorb() on failure.
 */
static int mlkem_xof_absorb(wc_Shake* shake128, byte* seed, int len)
{
    int ret = wc_InitShake128(shake128, NULL, INVALID_DEVID);

    if (ret != 0) {
        return ret;
    }
    return wc_Shake128_Absorb(shake128, seed, (word32)len);
}
/* Squeeze a number of blocks of output from a SHAKE-128 object.
 *
 * @param  [in, out]  shake128  SHAKE-128 object.
 * @param  [out]      out       Buffer to receive the output.
 * @param  [in]       blocks    Number of blocks to write.
 * @return  Value returned by wc_Shake128_SqueezeBlocks().
 */
static int mlkem_xof_squeezeblocks(wc_Shake* shake128, byte* out, int blocks)
{
    int ret;

    ret = wc_Shake128_SqueezeBlocks(shake128, out, (word32)blocks);

    return ret;
}
#endif
/* Initialize a SHA-3 hash object.
 *
 * @param  [in, out]  hash   SHA-3 object.
 * @param  [in]       heap   Dynamic memory allocator hint.
 * @param  [in]       devId  Device id.
 * @return  Value returned by wc_InitSha3_256().
 */
int mlkem_hash_new(wc_Sha3* hash, void* heap, int devId)
{
    int ret;

    ret = wc_InitSha3_256(hash, heap, devId);

    return ret;
}
/* Dispose of a SHA-3 hash object.
 *
 * @param  [in, out]  hash  SHA-3 object passed to wc_Sha3_256_Free().
 */
void mlkem_hash_free(wc_Sha3* hash)
{
wc_Sha3_256_Free(hash);
}
/* Calculate the SHA3-256 hash of the data.
 *
 * @param  [in, out]  hash     SHA-3 object.
 * @param  [in]       data     Data to hash.
 * @param  [in]       dataLen  Length of data in bytes.
 * @param  [out]      out      Buffer to receive the hash.
 * @return  0 on success.
 * @return  Non-zero value returned by wc_Sha3_256_Update() or
 *          wc_Sha3_256_Final() on failure.
 */
int mlkem_hash256(wc_Sha3* hash, const byte* data, word32 dataLen, byte* out)
{
    int ret = wc_Sha3_256_Update(hash, data, dataLen);

    if (ret != 0) {
        return ret;
    }
    return wc_Sha3_256_Final(hash, out);
}
/* Calculate the SHA3-512 hash of one or two pieces of data.
 *
 * @param  [in, out]  hash      SHA-3 object.
 * @param  [in]       data1     First piece of data to hash.
 * @param  [in]       data1Len  Length of first piece in bytes.
 * @param  [in]       data2     Second piece of data to hash. Only used when
 *                              data2Len is greater than zero.
 * @param  [in]       data2Len  Length of second piece in bytes.
 * @param  [out]      out       Buffer to receive the hash.
 * @return  0 on success.
 * @return  Non-zero value returned by wc_Sha3_512_Update() or
 *          wc_Sha3_512_Final() on failure.
 */
int mlkem_hash512(wc_Sha3* hash, const byte* data1, word32 data1Len,
    const byte* data2, word32 data2Len, byte* out)
{
    int ret = wc_Sha3_512_Update(hash, data1, data1Len);

    /* Second piece is optional. */
    if ((ret == 0) && (data2Len > 0)) {
        ret = wc_Sha3_512_Update(hash, data2, data2Len);
    }
    if (ret != 0) {
        return ret;
    }
    return wc_Sha3_512_Final(hash, out);
}
/* Initialize a SHAKE-256 object with no heap hint and no device.
 *
 * The void interface offers no way to report a failure, so the return value
 * of wc_InitShake256() is explicitly discarded.
 *
 * @param  [in, out]  prf  SHAKE-256 object.
 */
void mlkem_prf_init(wc_Shake* prf)
{
    /* INVALID_DEVID matches the other NULL-heap SHAKE initializations in this
     * file and avoids selecting a device by accident. */
    (void)wc_InitShake256(prf, NULL, INVALID_DEVID);
}
/* Initialize a SHAKE-256 object.
 *
 * @param  [in, out]  prf    SHAKE-256 object.
 * @param  [in]       heap   Dynamic memory allocator hint.
 * @param  [in]       devId  Device id.
 * @return  Value returned by wc_InitShake256().
 */
int mlkem_prf_new(wc_Shake* prf, void* heap, int devId)
{
    int ret;

    ret = wc_InitShake256(prf, heap, devId);

    return ret;
}
/* Dispose of a SHAKE-256 object.
 *
 * @param  [in, out]  prf  SHAKE-256 object passed to wc_Shake256_Free().
 */
void mlkem_prf_free(wc_Shake* prf)
{
wc_Shake256_Free(prf);
}
#if !(defined(WOLFSSL_ARMASM) && defined(__aarch64__))
/* Generate pseudo-random output from a key of WC_ML_KEM_SYM_SZ + 1 bytes.
 *
 * With USE_INTEL_SPEEDUP a hash state is built on the stack and the shake256
 * parameter is not used; otherwise the SHAKE-256 object is used.
 *
 * @param  [in, out]  shake256  SHAKE-256 object. Unused with
 *                              USE_INTEL_SPEEDUP.
 * @param  [out]      out       Buffer to receive the output.
 * @param  [in]       outLen    Number of bytes to generate.
 * @param  [in]       key       WC_ML_KEM_SYM_SZ bytes of key followed by one
 *                              extra byte.
 * @return  0 on success.
 * @return  Non-zero value returned by wc_Shake256_Update() or
 *          wc_Shake256_Final() on failure (non-Intel build only).
 */
static int mlkem_prf(wc_Shake* shake256, byte* out, unsigned int outLen,
const byte* key)
{
#ifdef USE_INTEL_SPEEDUP
word64 state[25];
(void)shake256;
/* Load the first WC_ML_KEM_SYM_SZ bytes of the key into the state. */
readUnalignedWords64(state, key, WC_ML_KEM_SYM_SZ / sizeof(word64));
/* Next word: last key byte in the low byte with 0x1f in the byte above. */
state[WC_ML_KEM_SYM_SZ / 8] = 0x1f00 | key[WC_ML_KEM_SYM_SZ];
/* Zero the remainder of the state. */
XMEMSET(state + WC_ML_KEM_SYM_SZ / 8 + 1, 0,
(25 - WC_ML_KEM_SYM_SZ / 8 - 1) * sizeof(word64));
/* Top bit of the last rate word set - presumably the final padding bit. */
state[WC_SHA3_256_COUNT - 1] = W64LIT(0x8000000000000000);
/* One permutation per block of output, copying only what is needed. */
while (outLen > 0) {
unsigned int len = min(outLen, WC_SHA3_256_BLOCK_SIZE);
#ifndef WC_SHA3_NO_ASM
if (IS_INTEL_BMI2(cpuid_flags)) {
sha3_block_bmi2(state);
}
else if (IS_INTEL_AVX2(cpuid_flags) &&
(SAVE_VECTOR_REGISTERS2() == 0)) {
sha3_block_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
BlockSha3(state);
}
XMEMCPY(out, state, len);
out += len;
outLen -= len;
}
return 0;
#else
int ret;
/* Absorb the key and extra byte, then produce all the output at once. */
ret = wc_Shake256_Update(shake256, key, WC_ML_KEM_SYM_SZ + 1);
if (ret == 0) {
ret = wc_Shake256_Final(shake256, out, outLen);
}
return ret;
#endif
}
#endif
#ifdef WOLFSSL_MLKEM_KYBER
#ifdef USE_INTEL_SPEEDUP
/* Derive output from a seed with a single permutation of a hand-built state.
 *
 * Intel build: uses the BMI2 or AVX2 block function when available.
 *
 * NOTE(review): only seedLen / 8 whole words of the seed are loaded and the
 * output is copied from a single block - presumably callers pass a seed
 * length that is a multiple of 8 and an output length no larger than one
 * block; confirm against callers.
 *
 * @param  [in]   seed     Seed data.
 * @param  [in]   seedLen  Length of seed in bytes.
 * @param  [out]  out      Buffer to receive the output.
 * @param  [in]   outLen   Number of bytes to copy out of the state.
 * @return  0 always.
 */
int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen)
{
word64 state[25];
word32 len64 = seedLen / 8;
/* Load the seed as whole 64-bit words. */
readUnalignedWords64(state, seed, len64);
/* Word after the seed holds 0x1f; the rest of the state is zeroed. */
state[len64] = 0x1f;
XMEMSET(state + len64 + 1, 0, (25 - len64 - 1) * sizeof(word64));
/* Top bit of the last rate word set - presumably the final padding bit. */
state[WC_SHA3_256_COUNT - 1] = W64LIT(0x8000000000000000);
#ifndef WC_SHA3_NO_ASM
if (IS_INTEL_BMI2(cpuid_flags)) {
sha3_block_bmi2(state);
}
else if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
sha3_block_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
BlockSha3(state);
}
/* Output is the leading bytes of the permuted state. */
XMEMCPY(out, state, outLen);
return 0;
}
#endif
#if defined(WOLFSSL_ARMASM) && defined(__aarch64__)
/* Derive output from a seed with a single permutation of a hand-built state.
 *
 * AArch64 build.
 *
 * @param  [in]   seed     Seed data. seedLen / 8 whole 64-bit words are used.
 * @param  [in]   seedLen  Length of seed in bytes.
 * @param  [out]  out      Buffer to receive the output.
 * @param  [in]   outLen   Number of bytes to copy out of the state.
 * @return  0 always.
 */
int mlkem_kdf(byte* seed, int seedLen, byte* out, int outLen)
{
    word64 state[25];
    word32 words = seedLen / 8;
    word64* tail = state + words;

    /* Load the seed as whole 64-bit words. */
    readUnalignedWords64(state, seed, words);
    /* Word after the seed holds 0x1f; everything after that is zeroed. */
    *tail = 0x1f;
    XMEMSET(tail + 1, 0, (25 - words - 1) * sizeof(word64));
    /* Top bit of the last rate word is set. */
    state[WC_SHA3_256_COUNT - 1] = W64LIT(0x8000000000000000);

    BlockSha3(state);

    /* Output is the leading bytes of the permuted state. */
    XMEMCPY(out, state, outLen);

    return 0;
}
#endif
#endif
#ifndef WOLFSSL_NO_ML_KEM
/* Derive the shared secret from z and the ciphertext using SHAKE-256.
 *
 * With USE_INTEL_SPEEDUP the first block (z followed by the start of the
 * ciphertext) is written straight into the object's buffer and the object is
 * not re-initialized here. Otherwise the object is initialized and z and the
 * ciphertext are passed through the update function.
 *
 * @param  [in, out]  shake256  SHAKE-256 object.
 * @param  [in]       z         WC_ML_KEM_SYM_SZ bytes.
 * @param  [in]       ct        Ciphertext.
 * @param  [in]       ctSz      Length of ciphertext in bytes.
 * @param  [out]      ss        Shared secret of WC_ML_KEM_SS_SZ bytes.
 * @return  0 on success.
 * @return  Non-zero value returned by the SHAKE-256 functions on failure.
 */
int mlkem_derive_secret(wc_Shake* shake256, const byte* z, const byte* ct,
    word32 ctSz, byte* ss)
{
    int ret;

#ifdef USE_INTEL_SPEEDUP
    /* Fill the buffer with a whole block: z then the start of ct. */
    XMEMCPY(shake256->t, z, WC_ML_KEM_SYM_SZ);
    XMEMCPY(shake256->t + WC_ML_KEM_SYM_SZ, ct,
        SHA3_256_BYTES - WC_ML_KEM_SYM_SZ);
    /* Record that the buffer holds a full block. */
    shake256->i = SHA3_256_BYTES;

    /* Feed in the rest of the ciphertext. */
    ret = wc_Shake256_Update(shake256,
        ct + (SHA3_256_BYTES - WC_ML_KEM_SYM_SZ),
        ctSz - (SHA3_256_BYTES - WC_ML_KEM_SYM_SZ));
    if (ret == 0) {
        ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ);
    }
#else
    ret = wc_InitShake256(shake256, NULL, INVALID_DEVID);
    if (ret != 0) {
        return ret;
    }
    ret = wc_Shake256_Update(shake256, z, WC_ML_KEM_SYM_SZ);
    if (ret != 0) {
        return ret;
    }
    ret = wc_Shake256_Update(shake256, ct, ctSz);
    if (ret != 0) {
        return ret;
    }
    ret = wc_Shake256_Final(shake256, ss, WC_ML_KEM_SS_SZ);
#endif

    return ret;
}
#endif
#if !defined(WOLFSSL_ARMASM)
/* Rejection sample 12-bit values from random bytes into a polynomial.
 *
 * Each 12-bit value that is less than MLKEM_Q is stored; others are dropped.
 *
 * The 64-bit little-endian variant loads 8 bytes for every 6 consumed, so
 * the buffer must be readable for 2 bytes beyond the data being sampled.
 *
 * @param  [out]  p     Buffer to receive accepted values.
 * @param  [in]   len   Maximum number of values to store.
 * @param  [in]   r     Random bytes to sample from.
 * @param  [in]   rLen  Number of random bytes available.
 * @return  Number of values stored in p.
 */
static unsigned int mlkem_rej_uniform_c(sword16* p, unsigned int len,
    const byte* r, unsigned int rLen)
{
    unsigned int i;
    unsigned int j;

#if defined(WOLFSSL_MLKEM_SMALL) || !defined(WC_64BIT_CPU) || \
    defined(BIG_ENDIAN_ORDER)
    /* Three bytes give two 12-bit values. The bound is written as an
     * addition so that rLen < 3 cannot wrap around in unsigned arithmetic. */
    for (i = 0, j = 0; (i < len) && (j + 3 <= rLen); j += 3) {
        sword16 v0 = ((r[0] >> 0) | ((word16)r[1] << 8)) & 0xFFF;
        sword16 v1 = ((r[1] >> 4) | ((word16)r[2] << 4)) & 0xFFF;

        if (v0 < MLKEM_Q) {
            p[i++] = v0;
        }
        /* Only take the second value when there is still room. */
        if ((i < len) && (v1 < MLKEM_Q)) {
            p[i++] = v1;
        }

        r += 3;
    }
#else
    unsigned int minJ;

    /* Bytes needed to fill the output if nothing is rejected, limited to
     * the data available. */
    minJ = len / 4 * 6;
    if (minJ > rLen)
        minJ = rLen;
    i = 0;
    /* No bounds check on i needed: at most len values can be produced. */
    for (j = 0; j < minJ; j += 6) {
        /* Six bytes give four 12-bit values. */
        word64 r_word = readUnalignedWord64(r);
        sword16 v0 =  r_word        & 0xfff;
        sword16 v1 = (r_word >> 12) & 0xfff;
        sword16 v2 = (r_word >> 24) & 0xfff;
        sword16 v3 = (r_word >> 36) & 0xfff;

        /* Always store, only advance when the value is accepted. */
        p[i] = v0;
        i += (v0 < MLKEM_Q);
        p[i] = v1;
        i += (v1 < MLKEM_Q);
        p[i] = v2;
        i += (v2 < MLKEM_Q);
        p[i] = v3;
        i += (v3 < MLKEM_Q);

        r += 6;
    }
    /* More data available - values may still be needed. */
    if (j < rLen) {
        /* While more than four values are needed, four unconditional stores
         * cannot go past the end of p. */
        for (; (i + 4 < len) && (j < rLen); j += 6) {
            word64 r_word = readUnalignedWord64(r);
            sword16 v0 =  r_word        & 0xfff;
            sword16 v1 = (r_word >> 12) & 0xfff;
            sword16 v2 = (r_word >> 24) & 0xfff;
            sword16 v3 = (r_word >> 36) & 0xfff;

            p[i] = v0;
            i += (v0 < MLKEM_Q);
            p[i] = v1;
            i += (v1 < MLKEM_Q);
            p[i] = v2;
            i += (v2 < MLKEM_Q);
            p[i] = v3;
            i += (v3 < MLKEM_Q);

            r += 6;
        }
        /* Close to the end of p - check for room before each store. */
        for (; (i < len) && (j < rLen); j += 6) {
            word64 r_word = readUnalignedWord64(r);
            sword16 v0 =  r_word        & 0xfff;
            sword16 v1 = (r_word >> 12) & 0xfff;
            sword16 v2 = (r_word >> 24) & 0xfff;
            sword16 v3 = (r_word >> 36) & 0xfff;

            if (v0 < MLKEM_Q) {
                p[i++] = v0;
            }
            if ((i < len) && (v1 < MLKEM_Q)) {
                p[i++] = v1;
            }
            if ((i < len) && (v2 < MLKEM_Q)) {
                p[i++] = v2;
            }
            if ((i < len) && (v3 < MLKEM_Q)) {
                p[i++] = v3;
            }

            r += 6;
        }
    }
#endif

    return i;
}
#endif
#if !defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \
!defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM)
#if !(defined(WOLFSSL_ARMASM) && defined(__aarch64__))
/* Generate a k x k matrix of polynomials from a seed.
 *
 * Portable C implementation. For each polynomial the seed is extended with
 * two index bytes, absorbed, and output blocks are rejection sampled until
 * MLKEM_N coefficients have been accepted.
 *
 * @param  [in, out]  prf         Object used to absorb the extended seed and
 *                                squeeze output blocks.
 * @param  [out]      a           Matrix of k rows, each of k polynomials.
 * @param  [in]       k           Number of polynomials in a row.
 * @param  [in]       seed        WC_ML_KEM_SYM_SZ bytes of seed.
 * @param  [in]       transposed  Selects the order of the two index bytes.
 * @return  0 on success.
 * @return  MEMORY_E when dynamic memory allocation fails.
 * @return  Non-zero value returned by mlkem_xof_absorb() or
 *          mlkem_xof_squeezeblocks() on failure.
 */
static int mlkem_gen_matrix_c(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
    int transposed)
{
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
    byte* rand;
#else
    byte rand[GEN_MATRIX_SIZE + 2];
#endif
    byte extSeed[WC_ML_KEM_SYM_SZ + 2];
    int ret = 0;
    int i;

    /* Seed is followed by two index bytes that change per polynomial. */
    XMEMCPY(extSeed, seed, WC_ML_KEM_SYM_SZ);

#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
    rand = (byte*)XMALLOC(GEN_MATRIX_SIZE + 2, NULL, DYNAMIC_TYPE_TMP_BUFFER);
    if (rand == NULL) {
        ret = MEMORY_E;
    }
#endif

#if !defined(WOLFSSL_MLKEM_SMALL) && defined(WC_64BIT_CPU)
    /* The sampler may load 2 bytes more than it uses - give them a fixed
     * value. */
    if (ret == 0) {
        rand[GEN_MATRIX_SIZE+0] = 0xff;
        rand[GEN_MATRIX_SIZE+1] = 0xff;
    }
#endif

    /* Each row of the matrix. */
    for (i = 0; (ret == 0) && (i < k); i++, a += k * MLKEM_N) {
        int j;

        /* Each polynomial in the row. */
        for (j = 0; (ret == 0) && (j < k); j++) {
            if (transposed) {
                extSeed[WC_ML_KEM_SYM_SZ + 0] = (byte)i;
                extSeed[WC_ML_KEM_SYM_SZ + 1] = (byte)j;
            }
            else {
                extSeed[WC_ML_KEM_SYM_SZ + 0] = (byte)j;
                extSeed[WC_ML_KEM_SYM_SZ + 1] = (byte)i;
            }

            /* Absorb the extended seed and generate the initial blocks. */
            ret = mlkem_xof_absorb(prf, extSeed, (int)sizeof(extSeed));
            if (ret == 0) {
                ret = mlkem_xof_squeezeblocks(prf, rand, GEN_MATRIX_NBLOCKS);
            }
            if (ret == 0) {
                unsigned int ctr;

                ctr = mlkem_rej_uniform_c(a + j * MLKEM_N, MLKEM_N, rand,
                    GEN_MATRIX_SIZE);
                /* Too many rejections - generate one more block at a time.
                 * A failed squeeze stops sampling and is reported, rather
                 * than re-sampling stale bytes. */
                while ((ret == 0) && (ctr < MLKEM_N)) {
                    ret = mlkem_xof_squeezeblocks(prf, rand, 1);
                    if (ret == 0) {
                        ctr += mlkem_rej_uniform_c(a + j * MLKEM_N + ctr,
                            MLKEM_N - ctr, rand, XOF_BLOCK_SIZE);
                    }
                }
            }
        }
    }

#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
    XFREE(rand, NULL, DYNAMIC_TYPE_TMP_BUFFER);
#endif

    return ret;
}
#endif
/* Deterministically generate a k x k matrix of polynomials from a seed.
 *
 * Dispatches on k to the implementation for that parameter set:
 *  - the AArch64 assembly routine when built with WOLFSSL_ARMASM on
 *    __aarch64__,
 *  - the AVX2 routine when built with USE_INTEL_SPEEDUP (and SHA-3 assembly),
 *    AVX2 is reported in cpuid_flags and the vector registers were saved,
 *  - otherwise the C implementation mlkem_gen_matrix_c().
 *
 * @param [in, out] prf         XOF object. Only passed to the C
 *                              implementation.
 * @param [out]     a           Matrix of polynomials to fill.
 * @param [in]      k           Dimension of the matrix; selects the
 *                              parameter set.
 * @param [in]      seed        Seed the matrix is derived from.
 * @param [in]      transposed  Passed through to the implementation to select
 *                              the transposed matrix.
 * @return  Result of the selected implementation.
 * @return  BAD_STATE_E when k matches no compiled-in parameter set.
 */
int mlkem_gen_matrix(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
int transposed)
{
int ret;
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
if (k == WC_ML_KEM_512_K) {
#if defined(WOLFSSL_ARMASM) && defined(__aarch64__)
ret = mlkem_gen_matrix_k2_aarch64(a, seed, transposed);
#else
#if defined(USE_INTEL_SPEEDUP) && !defined(WC_SHA3_NO_ASM)
/* Vector registers are only saved when AVX2 is available; if the save
 * fails the C implementation is used instead. */
if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
ret = mlkem_gen_matrix_k2_avx2(a, seed, transposed);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
ret = mlkem_gen_matrix_c(prf, a, WC_ML_KEM_512_K, seed, transposed);
}
#endif
}
else
#endif
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
if (k == WC_ML_KEM_768_K) {
#if defined(WOLFSSL_ARMASM) && defined(__aarch64__)
ret = mlkem_gen_matrix_k3_aarch64(a, seed, transposed);
#else
#if defined(USE_INTEL_SPEEDUP) && !defined(WC_SHA3_NO_ASM)
if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
ret = mlkem_gen_matrix_k3_avx2(a, seed, transposed);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
ret = mlkem_gen_matrix_c(prf, a, WC_ML_KEM_768_K, seed, transposed);
}
#endif
}
else
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
if (k == WC_ML_KEM_1024_K) {
#if defined(WOLFSSL_ARMASM) && defined(__aarch64__)
ret = mlkem_gen_matrix_k4_aarch64(a, seed, transposed);
#else
#if defined(USE_INTEL_SPEEDUP) && !defined(WC_SHA3_NO_ASM)
if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
ret = mlkem_gen_matrix_k4_avx2(a, seed, transposed);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
ret = mlkem_gen_matrix_c(prf, a, WC_ML_KEM_1024_K, seed,
transposed);
}
#endif
}
else
#endif
{
/* k does not correspond to any parameter set compiled in. */
ret = BAD_STATE_E;
}
/* prf is unused when only assembly implementations are compiled. */
(void)prf;
return ret;
}
#endif
#if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \
defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM)
/* Deterministically generate one row (index i) of the k x k matrix.
 *
 * For each of the k entries the seed is extended with two index bytes,
 * absorbed into the XOF, and the squeezed output is rejection sampled into
 * MLKEM_N coefficients. More XOF blocks are squeezed, one at a time, until the
 * polynomial is full.
 *
 * @param [in, out] prf         XOF object used to expand the seed.
 * @param [out]     a           Row of k polynomials of MLKEM_N coefficients.
 * @param [in]      k           Number of polynomials in the row.
 * @param [in]      seed        Seed of WC_ML_KEM_SYM_SZ bytes.
 * @param [in]      i           Index of the row to generate.
 * @param [in]      transposed  Non-zero appends (i, j) to the seed, zero
 *                              appends (j, i), for entry j of the row.
 * @return  0 on success.
 * @return  MEMORY_E when dynamic memory allocation fails.
 * @return  Error code from the XOF absorb/squeeze operations otherwise.
 */
static int mlkem_gen_matrix_i(MLKEM_PRF_T* prf, sword16* a, int k, byte* seed,
    int i, int transposed)
{
#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
    byte* rand;
#else
    byte rand[GEN_MATRIX_SIZE + 2];
#endif
    byte extSeed[WC_ML_KEM_SYM_SZ + 2];
    int ret = 0;
    int j;

    /* Seed is the prefix of every XOF input; the last two bytes are indices. */
    XMEMCPY(extSeed, seed, WC_ML_KEM_SYM_SZ);

#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
    rand = (byte*)XMALLOC(GEN_MATRIX_SIZE + 2, NULL, DYNAMIC_TYPE_TMP_BUFFER);
    if (rand == NULL) {
        ret = MEMORY_E;
    }
#endif

#if !defined(WOLFSSL_MLKEM_SMALL) && defined(WC_64BIT_CPU)
    if (ret == 0) {
        /* The 64-bit rejection sampler loads 8 bytes per 6-byte step, so it
         * reads two bytes past the sampled data: give them defined values. */
        rand[GEN_MATRIX_SIZE+0] = 0xff;
        rand[GEN_MATRIX_SIZE+1] = 0xff;
    }
#endif

    for (j = 0; (ret == 0) && (j < k); j++) {
        /* Index bytes select the matrix entry; order depends on whether the
         * transposed matrix is wanted. */
        if (transposed) {
            extSeed[WC_ML_KEM_SYM_SZ + 0] = (byte)i;
            extSeed[WC_ML_KEM_SYM_SZ + 1] = (byte)j;
        }
        else {
            extSeed[WC_ML_KEM_SYM_SZ + 0] = (byte)j;
            extSeed[WC_ML_KEM_SYM_SZ + 1] = (byte)i;
        }

        ret = mlkem_xof_absorb(prf, extSeed, sizeof(extSeed));
        if (ret == 0) {
            ret = mlkem_xof_squeezeblocks(prf, rand, GEN_MATRIX_NBLOCKS);
        }
        if (ret == 0) {
            unsigned int ctr;

            /* First pass over the initial blocks of XOF output. */
            ctr = mlkem_rej_uniform_c(a + j * MLKEM_N, MLKEM_N, rand,
                GEN_MATRIX_SIZE);
            /* Top up one block at a time. Stop on XOF failure rather than
             * sampling stale data or looping forever. */
            while ((ret == 0) && (ctr < MLKEM_N)) {
                ret = mlkem_xof_squeezeblocks(prf, rand, 1);
                if (ret == 0) {
                    ctr += mlkem_rej_uniform_c(a + j * MLKEM_N + ctr,
                        MLKEM_N - ctr, rand, XOF_BLOCK_SIZE);
                }
            }
        }
    }

#if defined(WOLFSSL_SMALL_STACK) && !defined(WOLFSSL_NO_MALLOC)
    XFREE(rand, NULL, DYNAMIC_TYPE_TMP_BUFFER);
#endif

    return ret;
}
#endif
/* Coefficient i from the pair-summed word d: the 2-bit sum at bit (i * 4)
 * minus the 2-bit sum at bit (i * 4 + 2). */
#define ETA2_SUB(d, i) \
    (sword16)(((sword16)(((d) >> ((i) * 4 + 0)) & 0x3)) - \
              ((sword16)(((d) >> ((i) * 4 + 2)) & 0x3)))

/* Compute polynomial coefficients from random bytes: centered binomial
 * distribution with eta = 2.
 *
 * Each coefficient uses 4 bits of input: the sum of the first two bits minus
 * the sum of the next two bits. 4 * MLKEM_N / 8 bytes are consumed.
 *
 * @param [out] p  Polynomial of MLKEM_N coefficients.
 * @param [in]  r  Random bytes. No alignment is required.
 */
static void mlkem_cbd_eta2(sword16* p, const byte* r)
{
    unsigned int i;

#ifndef WORD64_AVAILABLE
    /* 32 bits of input produce 8 coefficients. */
    for (i = 0; i < MLKEM_N; i += 8) {
    #ifdef WOLFSSL_MLKEM_SMALL
        unsigned int j;
    #endif
        /* Little-endian composition from bytes: independent of host byte
         * order and safe for any alignment of r. */
        word32 t = (((word32)(r[0])) <<  0) |
                   (((word32)(r[1])) <<  8) |
                   (((word32)(r[2])) << 16) |
                   (((word32)(r[3])) << 24);
        word32 d;

        /* Add adjacent bits: every 2-bit field of d holds a value 0..2. */
        d  = (t >> 0) & 0x55555555;
        d += (t >> 1) & 0x55555555;

    #ifdef WOLFSSL_MLKEM_SMALL
        for (j = 0; j < 8; j++) {
            p[i + j] = ETA2_SUB(d, j);
        }
    #else
        p[i + 0] = ETA2_SUB(d, 0);
        p[i + 1] = ETA2_SUB(d, 1);
        p[i + 2] = ETA2_SUB(d, 2);
        p[i + 3] = ETA2_SUB(d, 3);
        p[i + 4] = ETA2_SUB(d, 4);
        p[i + 5] = ETA2_SUB(d, 5);
        p[i + 6] = ETA2_SUB(d, 6);
        p[i + 7] = ETA2_SUB(d, 7);
    #endif

        r += 4;
    }
#else
    /* 64 bits of input produce 16 coefficients. */
    for (i = 0; i < MLKEM_N; i += 16) {
    #ifdef WOLFSSL_MLKEM_SMALL
        unsigned int j;
    #endif
    #ifdef BIG_ENDIAN_ORDER
        word64 t = ByteReverseWord64(readUnalignedWord64(r));
    #else
        word64 t = readUnalignedWord64(r);
    #endif
        word64 d;

        /* Add adjacent bits: every 2-bit field of d holds a value 0..2. */
        d  = (t >> 0) & 0x5555555555555555L;
        d += (t >> 1) & 0x5555555555555555L;

    #ifdef WOLFSSL_MLKEM_SMALL
        for (j = 0; j < 16; j++) {
            p[i + j] = ETA2_SUB(d, j);
        }
    #else
        p[i +  0] = ETA2_SUB(d,  0);
        p[i +  1] = ETA2_SUB(d,  1);
        p[i +  2] = ETA2_SUB(d,  2);
        p[i +  3] = ETA2_SUB(d,  3);
        p[i +  4] = ETA2_SUB(d,  4);
        p[i +  5] = ETA2_SUB(d,  5);
        p[i +  6] = ETA2_SUB(d,  6);
        p[i +  7] = ETA2_SUB(d,  7);
        p[i +  8] = ETA2_SUB(d,  8);
        p[i +  9] = ETA2_SUB(d,  9);
        p[i + 10] = ETA2_SUB(d, 10);
        p[i + 11] = ETA2_SUB(d, 11);
        p[i + 12] = ETA2_SUB(d, 12);
        p[i + 13] = ETA2_SUB(d, 13);
        p[i + 14] = ETA2_SUB(d, 14);
        p[i + 15] = ETA2_SUB(d, 15);
    #endif

        r += 8;
    }
#endif
}
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
/* Coefficient i from the triple-summed word d: the 3-bit sum at bit (i * 6)
 * minus the 3-bit sum at bit (i * 6 + 3). */
#define ETA3_SUB(d, i) \
    (sword16)(((sword16)(((d) >> ((i) * 6 + 0)) & 0x7)) - \
              ((sword16)(((d) >> ((i) * 6 + 3)) & 0x7)))

/* Compute polynomial coefficients from random bytes: centered binomial
 * distribution with eta = 3.
 *
 * Each coefficient uses 6 bits of input: the sum of the first three bits
 * minus the sum of the next three bits. 6 * MLKEM_N / 8 bytes are consumed.
 *
 * @param [out] p  Polynomial of MLKEM_N coefficients.
 * @param [in]  r  Random bytes. No alignment is required.
 */
static void mlkem_cbd_eta3(sword16* p, const byte* r)
{
    unsigned int i;

#if defined(WOLFSSL_SMALL_STACK) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE) || \
    defined(BIG_ENDIAN_ORDER)
#ifndef WORD64_AVAILABLE
    /* 24 bits of input produce 4 coefficients. */
    for (i = 0; i < MLKEM_N; i += 4) {
    #ifdef WOLFSSL_MLKEM_SMALL
        unsigned int j;
    #endif
        word32 t = (((word32)(r[0])) <<  0) |
                   (((word32)(r[1])) <<  8) |
                   (((word32)(r[2])) << 16);
        word32 d;

        /* Add three adjacent bits: every 3-bit field holds a value 0..3. */
        d  = (t >> 0) & 0x00249249;
        d += (t >> 1) & 0x00249249;
        d += (t >> 2) & 0x00249249;

    #ifdef WOLFSSL_MLKEM_SMALL
        for (j = 0; j < 4; j++) {
            p[i + j] = ETA3_SUB(d, j);
        }
    #else
        p[i + 0] = ETA3_SUB(d, 0);
        p[i + 1] = ETA3_SUB(d, 1);
        p[i + 2] = ETA3_SUB(d, 2);
        p[i + 3] = ETA3_SUB(d, 3);
    #endif

        r += 3;
    }
#else
    /* 48 bits of input produce 8 coefficients. */
    for (i = 0; i < MLKEM_N; i += 8) {
    #ifdef WOLFSSL_MLKEM_SMALL
        unsigned int j;
    #endif
        word64 t = (((word64)(r[0])) <<  0) |
                   (((word64)(r[1])) <<  8) |
                   (((word64)(r[2])) << 16) |
                   (((word64)(r[3])) << 24) |
                   (((word64)(r[4])) << 32) |
                   (((word64)(r[5])) << 40);
        word64 d;

        /* Add three adjacent bits: every 3-bit field holds a value 0..3. */
        d  = (t >> 0) & 0x0000249249249249L;
        d += (t >> 1) & 0x0000249249249249L;
        d += (t >> 2) & 0x0000249249249249L;

    #ifdef WOLFSSL_MLKEM_SMALL
        for (j = 0; j < 8; j++) {
            p[i + j] = ETA3_SUB(d, j);
        }
    #else
        p[i + 0] = ETA3_SUB(d, 0);
        p[i + 1] = ETA3_SUB(d, 1);
        p[i + 2] = ETA3_SUB(d, 2);
        p[i + 3] = ETA3_SUB(d, 3);
        p[i + 4] = ETA3_SUB(d, 4);
        p[i + 5] = ETA3_SUB(d, 5);
        p[i + 6] = ETA3_SUB(d, 6);
        p[i + 7] = ETA3_SUB(d, 7);
    #endif

        r += 6;
    }
#endif
#else
    /* 96 bits of input produce 16 coefficients, handled as four 24-bit
     * groups. Each group is composed from its three bytes, which gives the
     * same values as little-endian 32-bit loads but does not require r to be
     * 4-byte aligned. */
    for (i = 0; i < MLKEM_N; i += 16) {
        word32 t0 = (((word32)(r[ 0])) <<  0) |
                    (((word32)(r[ 1])) <<  8) |
                    (((word32)(r[ 2])) << 16);
        word32 t1 = (((word32)(r[ 3])) <<  0) |
                    (((word32)(r[ 4])) <<  8) |
                    (((word32)(r[ 5])) << 16);
        word32 t2 = (((word32)(r[ 6])) <<  0) |
                    (((word32)(r[ 7])) <<  8) |
                    (((word32)(r[ 8])) << 16);
        word32 t3 = (((word32)(r[ 9])) <<  0) |
                    (((word32)(r[10])) <<  8) |
                    (((word32)(r[11])) << 16);
        word32 d0;
        word32 d1;
        word32 d2;
        word32 d3;

        /* Add three adjacent bits: every 3-bit field holds a value 0..3. */
        d0  = (t0 >> 0) & 0x00249249;
        d0 += (t0 >> 1) & 0x00249249;
        d0 += (t0 >> 2) & 0x00249249;
        d1  = (t1 >> 0) & 0x00249249;
        d1 += (t1 >> 1) & 0x00249249;
        d1 += (t1 >> 2) & 0x00249249;
        d2  = (t2 >> 0) & 0x00249249;
        d2 += (t2 >> 1) & 0x00249249;
        d2 += (t2 >> 2) & 0x00249249;
        d3  = (t3 >> 0) & 0x00249249;
        d3 += (t3 >> 1) & 0x00249249;
        d3 += (t3 >> 2) & 0x00249249;

        p[i +  0] = ETA3_SUB(d0, 0);
        p[i +  1] = ETA3_SUB(d0, 1);
        p[i +  2] = ETA3_SUB(d0, 2);
        p[i +  3] = ETA3_SUB(d0, 3);
        p[i +  4] = ETA3_SUB(d1, 0);
        p[i +  5] = ETA3_SUB(d1, 1);
        p[i +  6] = ETA3_SUB(d1, 2);
        p[i +  7] = ETA3_SUB(d1, 3);
        p[i +  8] = ETA3_SUB(d2, 0);
        p[i +  9] = ETA3_SUB(d2, 1);
        p[i + 10] = ETA3_SUB(d2, 2);
        p[i + 11] = ETA3_SUB(d2, 3);
        p[i + 12] = ETA3_SUB(d3, 0);
        p[i + 13] = ETA3_SUB(d3, 1);
        p[i + 14] = ETA3_SUB(d3, 2);
        p[i + 15] = ETA3_SUB(d3, 3);

        r += 12;
    }
#endif
}
#endif
#if !(defined(__aarch64__) && defined(WOLFSSL_ARMASM))
/* Generate one noise polynomial from a seed using the PRF.
 *
 * When the 512-bit parameter set is compiled in and eta1 is MLKEM_CBD_ETA3,
 * ETA3_RAND_SIZE bytes are generated and sampled with eta = 3. In every other
 * case ETA2_RAND_SIZE bytes are generated and sampled with eta = 2.
 *
 * @param [in, out] prf   Pseudo-random function object.
 * @param [out]     p     Polynomial to fill.
 * @param [in]      seed  Seed passed to the PRF.
 * @param [in]      eta1  Size of the noise range: MLKEM_CBD_ETA2 or
 *                        MLKEM_CBD_ETA3.
 * @return  0 on success.
 * @return  Error code from mlkem_prf() otherwise.
 */
static int mlkem_get_noise_eta1_c(MLKEM_PRF_T* prf, sword16* p,
    const byte* seed, byte eta1)
{
    /* eta1 is unused when the 512-bit parameter set is not compiled in. */
    (void)eta1;

#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
    if (eta1 == MLKEM_CBD_ETA3) {
        byte bytes3[ETA3_RAND_SIZE];
        int err = mlkem_prf(prf, bytes3, sizeof(bytes3), seed);

        if (err != 0) {
            return err;
        }
        mlkem_cbd_eta3(p, bytes3);
        return 0;
    }
#endif

    {
        byte bytes2[ETA2_RAND_SIZE];
        int err = mlkem_prf(prf, bytes2, sizeof(bytes2), seed);

        if (err == 0) {
            mlkem_cbd_eta2(p, bytes2);
        }
        return err;
    }
}
/* Generate one noise polynomial with eta = 2 from a seed using the PRF.
 *
 * @param [in, out] prf   Pseudo-random function object.
 * @param [out]     p     Polynomial to fill.
 * @param [in]      seed  Seed passed to the PRF.
 * @return  0 on success.
 * @return  Error code from mlkem_prf() otherwise.
 */
static int mlkem_get_noise_eta2_c(MLKEM_PRF_T* prf, sword16* p,
    const byte* seed)
{
    byte bytes[ETA2_RAND_SIZE];
    int err = mlkem_prf(prf, bytes, sizeof(bytes), seed);

    if (err != 0) {
        return err;
    }

    /* Turn the pseudo-random bytes into coefficients. */
    mlkem_cbd_eta2(p, bytes);
    return 0;
}
#endif
#if defined(USE_INTEL_SPEEDUP) && !defined(WC_SHA3_NO_ASM)
/* Bytes reserved per lane for PRF output: two blocks of SHA3_256_BYTES. */
#define PRF_RAND_SZ (2 * SHA3_256_BYTES)
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768) || \
defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Generate the random bytes for four eta = 2 polynomials at once.
 *
 * The counter for lane n is (n + o). The output for lane n is written to
 * rand + n * ETA2_RAND_SIZE.
 *
 * @param [out] rand  Buffer for 4 * ETA2_RAND_SIZE bytes.
 * @param [in]  seed  Seed to hash.
 * @param [in]  o     Offset added to the lane index to form each counter.
 */
static void mlkem_get_noise_x4_eta2_avx2(byte* rand, byte* seed, byte o)
{
    word64 state[25 * 4];

    /* Entries 4*4 + lane: 0x1f00 plus the lane's counter value. */
    state[4*4 + 0] = (word32)(0x1f00 + 0 + o);
    state[4*4 + 1] = (word32)(0x1f00 + 1 + o);
    state[4*4 + 2] = (word32)(0x1f00 + 2 + o);
    state[4*4 + 3] = (word32)(0x1f00 + 3 + o);

    sha3_256_blocksx4_seed_avx2(state, seed);
    mlkem_redistribute_16_rand_avx2(state,
        rand + 0 * ETA2_RAND_SIZE,
        rand + 1 * ETA2_RAND_SIZE,
        rand + 2 * ETA2_RAND_SIZE,
        rand + 3 * ETA2_RAND_SIZE);
}
#endif
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512) || \
defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Generate one noise polynomial with eta = 2 using a single permutation block
 * and the AVX2 sampling routine.
 *
 * The 25-word state is built by hand from the seed and its trailing counter
 * byte, permuted once, and the resulting state bytes are passed to
 * mlkem_cbd_eta2_avx2().
 *
 * @param [in]  prf   Unused.
 * @param [out] p     Polynomial to fill.
 * @param [in]  seed  WC_ML_KEM_SYM_SZ bytes of seed followed by one counter
 *                    byte at index WC_ML_KEM_SYM_SZ.
 * @return  0 always.
 */
static int mlkem_get_noise_eta2_avx2(MLKEM_PRF_T* prf, sword16* p,
const byte* seed)
{
word64 state[25];
(void)prf;
/* Seed bytes fill the first WC_ML_KEM_SYM_SZ / 8 words of the state. */
readUnalignedWords64(state, seed, WC_ML_KEM_SYM_SZ / sizeof(word64));
/* Next word: counter byte in the low byte with 0x1f in the byte above it -
 * presumably the SHAKE padding/domain byte; confirm against the SHA-3 code. */
state[WC_ML_KEM_SYM_SZ / 8] = 0x1f00 | seed[WC_ML_KEM_SYM_SZ];
/* Zero the rest of the state. */
XMEMSET(state + WC_ML_KEM_SYM_SZ / 8 + 1, 0,
(25 - WC_ML_KEM_SYM_SZ / 8 - 1) * sizeof(word64));
/* Set the top bit of the last word of the WC_SHA3_256_COUNT-word block. */
state[WC_SHA3_256_COUNT - 1] = W64LIT(0x8000000000000000);
#ifndef WC_SHA3_NO_ASM
/* Prefer the BMI2 permutation, then AVX2 when the vector registers can be
 * saved, otherwise fall back to the C permutation. */
if (IS_INTEL_BMI2(cpuid_flags)) {
sha3_block_bmi2(state);
}
else if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
sha3_block_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
BlockSha3(state);
}
/* Sample the coefficients directly from the state bytes. */
mlkem_cbd_eta2_avx2(p, (byte*)state);
return 0;
}
#endif
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
/* Generate the random bytes for four eta = 3 polynomials at once.
 *
 * Lane n uses counter value n and its output starts at
 * rand + n * PRF_RAND_SZ. Two blocks are produced per lane: the first is
 * written at the start of the lane's buffer and the second SHA3_256_BYTES
 * further on.
 *
 * @param [out] rand  Buffer for 4 * PRF_RAND_SZ bytes.
 * @param [in]  seed  Seed to hash.
 */
static void mlkem_get_noise_x4_eta3_avx2(byte* rand, byte* seed)
{
    word64 state[25 * 4];
    byte* second = rand + SHA3_256_BYTES;
    int lane;

    /* Entries 4*4 + lane: 0x1f00 plus the lane's counter value. */
    for (lane = 0; lane < 4; lane++) {
        state[4*4 + lane] = 0x1f00 + lane;
    }

    /* First block of output for each lane. */
    sha3_256_blocksx4_seed_avx2(state, seed);
    mlkem_redistribute_17_rand_avx2(state,
        rand + 0 * PRF_RAND_SZ,
        rand + 1 * PRF_RAND_SZ,
        rand + 2 * PRF_RAND_SZ,
        rand + 3 * PRF_RAND_SZ);

    /* Second block of output for each lane. */
    sha3_blocksx4_avx2(state);
    mlkem_redistribute_8_rand_avx2(state,
        second + 0 * PRF_RAND_SZ,
        second + 1 * PRF_RAND_SZ,
        second + 2 * PRF_RAND_SZ,
        second + 3 * PRF_RAND_SZ);
}
/* Generate the noise vectors (and optionally a noise polynomial) for k = 2
 * using the AVX2 routines.
 *
 * Four lanes of random bytes are generated in one call. Lanes 0 and 1 become
 * the two polynomials of vec1 with eta = 3. Lanes 2 and 3 become the two
 * polynomials of vec2: with eta = 3 when poly is NULL, with eta = 2 otherwise.
 * When poly is not NULL it is generated separately with the seed's counter
 * byte set to 4.
 *
 * @param [in]      prf   Passed through to mlkem_get_noise_eta2_avx2().
 * @param [out]     vec1  First vector: 2 polynomials.
 * @param [out]     vec2  Second vector: 2 polynomials.
 * @param [out]     poly  Polynomial to fill. May be NULL.
 * @param [in, out] seed  Seed; byte at index WC_ML_KEM_SYM_SZ is overwritten
 *                        when poly is not NULL.
 * @return  0 on success.
 * @return  MEMORY_E - presumably returned by the allocation macro on failure;
 *          confirm against WC_ALLOC_VAR_EX.
 */
static int mlkem_get_noise_k2_avx2(MLKEM_PRF_T* prf, sword16* vec1,
sword16* vec2, sword16* poly, byte* seed)
{
int ret = 0;
/* Buffer of 4 * PRF_RAND_SZ bytes: one PRF_RAND_SZ region per lane. */
WC_DECLARE_VAR(rand, byte, 4 * PRF_RAND_SZ, 0);
WC_ALLOC_VAR_EX(rand, byte, 4 * PRF_RAND_SZ, NULL, DYNAMIC_TYPE_TMP_BUFFER,
return MEMORY_E);
mlkem_get_noise_x4_eta3_avx2(rand, seed);
mlkem_cbd_eta3_avx2(vec1 , rand + 0 * PRF_RAND_SZ);
mlkem_cbd_eta3_avx2(vec1 + MLKEM_N, rand + 1 * PRF_RAND_SZ);
if (poly == NULL) {
mlkem_cbd_eta3_avx2(vec2 , rand + 2 * PRF_RAND_SZ);
mlkem_cbd_eta3_avx2(vec2 + MLKEM_N, rand + 3 * PRF_RAND_SZ);
}
else {
/* Same random bytes, but sampled with eta = 2 for the second vector. */
mlkem_cbd_eta2_avx2(vec2 , rand + 2 * PRF_RAND_SZ);
mlkem_cbd_eta2_avx2(vec2 + MLKEM_N, rand + 3 * PRF_RAND_SZ);
/* Counter 4 follows the four values used for the two vectors. */
seed[WC_ML_KEM_SYM_SZ] = 4;
ret = mlkem_get_noise_eta2_avx2(prf, poly, seed);
}
WC_FREE_VAR_EX(rand, NULL, DYNAMIC_TYPE_TMP_BUFFER);
return ret;
}
#endif
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
/* Generate the noise vectors (and optionally a noise polynomial) for k = 3
 * using the AVX2 routines. All polynomials are sampled with eta = 2.
 *
 * Two four-lane batches are generated: counters 0..3 fill vec1 and the first
 * polynomial of vec2; counters 4..7 fill the rest of vec2 and, when wanted,
 * poly (the fourth lane of the second batch is not used).
 *
 * @param [out] vec1  First vector: 3 polynomials.
 * @param [out] vec2  Second vector: 3 polynomials.
 * @param [out] poly  Polynomial to fill. May be NULL.
 * @param [in]  seed  Seed to hash.
 * @return  0 always.
 */
static int mlkem_get_noise_k3_avx2(sword16* vec1, sword16* vec2, sword16* poly,
    byte* seed)
{
    byte rand[4 * ETA2_RAND_SIZE];
    int n;

    /* First batch: counters 0..3. */
    mlkem_get_noise_x4_eta2_avx2(rand, seed, 0);
    for (n = 0; n < 3; n++) {
        mlkem_cbd_eta2_avx2(vec1 + n * MLKEM_N, rand + n * ETA2_RAND_SIZE);
    }
    mlkem_cbd_eta2_avx2(vec2, rand + 3 * ETA2_RAND_SIZE);

    /* Second batch: counters 4..7. */
    mlkem_get_noise_x4_eta2_avx2(rand, seed, 4);
    for (n = 0; n < 2; n++) {
        mlkem_cbd_eta2_avx2(vec2 + (n + 1) * MLKEM_N,
            rand + n * ETA2_RAND_SIZE);
    }
    if (poly != NULL) {
        mlkem_cbd_eta2_avx2(poly, rand + 2 * ETA2_RAND_SIZE);
    }

    return 0;
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Generate the noise vectors (and optionally a noise polynomial) for k = 4
 * using the AVX2 routines. All polynomials are sampled with eta = 2.
 *
 * Counters 0..3 fill vec1 and counters 4..7 fill vec2. When poly is not NULL
 * it is generated separately with the seed's counter byte set to 8.
 *
 * @param [in]      prf   Passed through to mlkem_get_noise_eta2_avx2().
 * @param [out]     vec1  First vector: 4 polynomials.
 * @param [out]     vec2  Second vector: 4 polynomials.
 * @param [out]     poly  Polynomial to fill. May be NULL.
 * @param [in, out] seed  Seed; byte at index WC_ML_KEM_SYM_SZ is overwritten
 *                        when poly is not NULL.
 * @return  0 when poly is NULL.
 * @return  Result of mlkem_get_noise_eta2_avx2() otherwise.
 */
static int mlkem_get_noise_k4_avx2(MLKEM_PRF_T* prf, sword16* vec1,
    sword16* vec2, sword16* poly, byte* seed)
{
    byte rand[4 * ETA2_RAND_SIZE];
    int n;

    (void)prf;

    /* First vector: counters 0..3. */
    mlkem_get_noise_x4_eta2_avx2(rand, seed, 0);
    for (n = 0; n < 4; n++) {
        mlkem_cbd_eta2_avx2(vec1 + n * MLKEM_N, rand + n * ETA2_RAND_SIZE);
    }

    /* Second vector: counters 4..7. */
    mlkem_get_noise_x4_eta2_avx2(rand, seed, 4);
    for (n = 0; n < 4; n++) {
        mlkem_cbd_eta2_avx2(vec2 + n * MLKEM_N, rand + n * ETA2_RAND_SIZE);
    }

    if (poly == NULL) {
        return 0;
    }

    /* Extra polynomial: counter 8. */
    seed[WC_ML_KEM_SYM_SZ] = 8;
    return mlkem_get_noise_eta2_avx2(prf, poly, seed);
}
#endif
#endif
#if defined(__aarch64__) && defined(WOLFSSL_ARMASM)
/* Bytes reserved per lane for PRF output: two blocks of SHA3_256_BYTES. */
#define PRF_RAND_SZ (2 * SHA3_256_BYTES)

/* Generate the random bytes for three eta = 2 polynomials at once.
 *
 * The output buffer is used directly as three consecutive 25-word states, so
 * the bytes for lane n start at rand + n * 25 * 8.
 *
 * @param [out] rand  Buffer of at least 3 * 25 * 8 bytes. It is accessed as
 *                    word64 - NOTE(review): callers pass a byte array;
 *                    confirm the alignment is acceptable.
 * @param [in]  seed  Seed to hash.
 * @param [in]  o     Offset added to the lane index to form each counter.
 */
static void mlkem_get_noise_x3_eta2_aarch64(byte* rand, byte* seed, byte o)
{
    word64* state = (word64*)rand;
    int lane;

    /* Word 4 of each state: 0x1f00 plus the lane's counter value. */
    for (lane = 0; lane < 3; lane++) {
        state[lane*25 + 4] = 0x1f00 + lane + o;
    }

    mlkem_shake256_blocksx3_seed_neon(state, seed);
}
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
/* Generate the random bytes for three eta = 3 polynomials at once.
 *
 * Two blocks are produced per lane. The first SHA3_256_BYTES bytes of lane n
 * are copied to rand + n * ETA3_RAND_SIZE, and the remaining
 * ETA3_RAND_SIZE - SHA3_256_BYTES bytes are taken from the second block.
 *
 * @param [out] rand  Buffer for 3 * ETA3_RAND_SIZE bytes.
 * @param [in]  seed  Seed to hash.
 * @param [in]  o     Offset added to the lane index to form each counter.
 */
static void mlkem_get_noise_x3_eta3_aarch64(byte* rand, byte* seed, byte o)
{
    word64 state[3 * 25];
    int lane;

    /* Word 4 of each state: 0x1f00 plus the lane's counter value. */
    for (lane = 0; lane < 3; lane++) {
        state[lane*25 + 4] = 0x1f00 + lane + o;
    }

    /* First block: a full block of output per lane. */
    mlkem_shake256_blocksx3_seed_neon(state, seed);
    for (lane = 0; lane < 3; lane++) {
        XMEMCPY(rand + lane * ETA3_RAND_SIZE, state + lane*25,
            SHA3_256_BYTES);
    }

    /* Second block: only the bytes still needed per lane. */
    mlkem_sha3_blocksx3_neon(state);
    rand += SHA3_256_BYTES;
    for (lane = 0; lane < 3; lane++) {
        XMEMCPY(rand + lane * ETA3_RAND_SIZE, state + lane*25,
            ETA3_RAND_SIZE - SHA3_256_BYTES);
    }
}
/* Generate the random bytes for one eta = 3 polynomial.
 *
 * A 25-word state is built from four words of seed and the counter, permuted
 * twice, and ETA3_RAND_SIZE bytes are copied out across the two blocks.
 *
 * @param [out] rand  Buffer for ETA3_RAND_SIZE bytes.
 * @param [in]  seed  Seed; read as four word64 values.
 * @param [in]  o     Counter value.
 */
static void mlkem_get_noise_eta3_aarch64(byte* rand, byte* seed, byte o)
{
    word64 state[25];
    int w;

    /* Four words of seed, then the counter word, then zeros. */
    for (w = 0; w < 4; w++) {
        state[w] = ((word64*)seed)[w];
    }
    state[4] = 0x1f00 + o;
    XMEMSET(state + 5, 0, sizeof(*state) * (25 - 5));
    /* Top bit of word 16 set. */
    state[16] = W64LIT(0x8000000000000000);

    /* First block of output. */
    BlockSha3(state);
    XMEMCPY(rand, state, SHA3_256_BYTES);

    /* Second block supplies the remaining bytes. */
    BlockSha3(state);
    XMEMCPY(rand + SHA3_256_BYTES, state, ETA3_RAND_SIZE - SHA3_256_BYTES);
}
/* Generate the noise vectors (and optionally a noise polynomial) for k = 2
 * using the AArch64 routines.
 *
 * vec1 is always sampled with eta = 3 from counters 0 and 1. When poly is
 * NULL, vec2 is sampled with eta = 3 from counters 2 and 3. Otherwise vec2 and
 * poly are sampled with eta = 2 from counters 2, 3 and 4.
 *
 * @param [out] vec1  First vector: 2 polynomials.
 * @param [out] vec2  Second vector: 2 polynomials.
 * @param [out] poly  Polynomial to fill. May be NULL.
 * @param [in]  seed  Seed to hash.
 * @return  0 always.
 */
static int mlkem_get_noise_k2_aarch64(sword16* vec1, sword16* vec2,
    sword16* poly, byte* seed)
{
    byte rand[3 * 25 * 8];

    /* Counters 0..2 with eta = 3 sized output. */
    mlkem_get_noise_x3_eta3_aarch64(rand, seed, 0);
    mlkem_cbd_eta3(vec1, rand + 0 * ETA3_RAND_SIZE);
    mlkem_cbd_eta3(vec1 + MLKEM_N, rand + 1 * ETA3_RAND_SIZE);

    if (poly != NULL) {
        /* Counters 2..4, one 25-word state per lane. */
        mlkem_get_noise_x3_eta2_aarch64(rand, seed, 2);
        mlkem_cbd_eta2(vec2, rand + 0 * 25 * 8);
        mlkem_cbd_eta2(vec2 + MLKEM_N, rand + 1 * 25 * 8);
        mlkem_cbd_eta2(poly, rand + 2 * 25 * 8);
    }
    else {
        /* Lane 2 of the first batch, then counter 3 on its own. */
        mlkem_cbd_eta3(vec2, rand + 2 * ETA3_RAND_SIZE);
        mlkem_get_noise_eta3_aarch64(rand, seed, 3);
        mlkem_cbd_eta3(vec2 + MLKEM_N, rand);
    }

    return 0;
}
#endif
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
/* Generate the random bytes for one eta = 2 polynomial.
 *
 * The output buffer is used directly as the 25-word state: it is filled from
 * the seed and counter and permuted once in place.
 *
 * @param [out] rand  Buffer of at least 25 * 8 bytes. It is accessed as
 *                    word64 - NOTE(review): callers pass a byte array;
 *                    confirm the alignment is acceptable.
 * @param [in]  seed  Seed; read as four word64 values.
 * @param [in]  o     Counter value.
 */
static void mlkem_get_noise_eta2_aarch64(byte* rand, byte* seed, byte o)
{
    word64* state = (word64*)rand;
    int w;

    /* Four words of seed, then the counter word, then zeros. */
    for (w = 0; w < 4; w++) {
        state[w] = ((word64*)seed)[w];
    }
    state[4] = 0x1f00 + o;
    XMEMSET(state + 5, 0, sizeof(*state) * (25 - 5));
    /* Top bit of word 16 set. */
    state[16] = W64LIT(0x8000000000000000);

    BlockSha3(state);
}
/* Generate the noise vectors (and optionally a noise polynomial) for k = 3
 * using the AArch64 routines. All polynomials are sampled with eta = 2.
 *
 * Counters 0..2 fill vec1, counters 3..5 fill vec2 and, when poly is not NULL,
 * counter 6 fills poly.
 *
 * @param [out] vec1  First vector: 3 polynomials.
 * @param [out] vec2  Second vector: 3 polynomials.
 * @param [out] poly  Polynomial to fill. May be NULL.
 * @param [in]  seed  Seed to hash.
 * @return  0 always.
 */
static int mlkem_get_noise_k3_aarch64(sword16* vec1, sword16* vec2,
    sword16* poly, byte* seed)
{
    byte rand[3 * 25 * 8];
    int n;

    /* First vector: one 25-word state per lane. */
    mlkem_get_noise_x3_eta2_aarch64(rand, seed, 0);
    for (n = 0; n < 3; n++) {
        mlkem_cbd_eta2(vec1 + n * MLKEM_N, rand + n * 25 * 8);
    }

    /* Second vector. */
    mlkem_get_noise_x3_eta2_aarch64(rand, seed, 3);
    for (n = 0; n < 3; n++) {
        mlkem_cbd_eta2(vec2 + n * MLKEM_N, rand + n * 25 * 8);
    }

    /* Extra polynomial, generated on its own. */
    if (poly != NULL) {
        mlkem_get_noise_eta2_aarch64(rand, seed, 6);
        mlkem_cbd_eta2(poly, rand + 0 * 25 * 8);
    }

    return 0;
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Generate the noise vectors (and optionally a noise polynomial) for k = 4
 * using the AArch64 routines. All polynomials are sampled with eta = 2.
 *
 * Three batches of three lanes are generated with counters 0..2, 3..5 and
 * 6..8. In order they fill the four polynomials of vec1, the four polynomials
 * of vec2 and, when not NULL, poly.
 *
 * @param [out] vec1  First vector: 4 polynomials.
 * @param [out] vec2  Second vector: 4 polynomials.
 * @param [out] poly  Polynomial to fill. May be NULL.
 * @param [in]  seed  Seed to hash.
 * @return  0 always.
 */
static int mlkem_get_noise_k4_aarch64(sword16* vec1, sword16* vec2,
    sword16* poly, byte* seed)
{
    byte rand[3 * 25 * 8];
    /* Start of each lane's 25-word state within the buffer. */
    byte* lane0 = rand + 0 * 25 * 8;
    byte* lane1 = rand + 1 * 25 * 8;
    byte* lane2 = rand + 2 * 25 * 8;

    /* Counters 0..2. */
    mlkem_get_noise_x3_eta2_aarch64(rand, seed, 0);
    mlkem_cbd_eta2(vec1, lane0);
    mlkem_cbd_eta2(vec1 + 1 * MLKEM_N, lane1);
    mlkem_cbd_eta2(vec1 + 2 * MLKEM_N, lane2);

    /* Counters 3..5. */
    mlkem_get_noise_x3_eta2_aarch64(rand, seed, 3);
    mlkem_cbd_eta2(vec1 + 3 * MLKEM_N, lane0);
    mlkem_cbd_eta2(vec2, lane1);
    mlkem_cbd_eta2(vec2 + 1 * MLKEM_N, lane2);

    /* Counters 6..8. */
    mlkem_get_noise_x3_eta2_aarch64(rand, seed, 6);
    mlkem_cbd_eta2(vec2 + 2 * MLKEM_N, lane0);
    mlkem_cbd_eta2(vec2 + 3 * MLKEM_N, lane1);
    if (poly != NULL) {
        mlkem_cbd_eta2(poly, lane2);
    }

    return 0;
}
#endif
#endif
#if !(defined(__aarch64__) && defined(WOLFSSL_ARMASM))
/* Generate the noise vectors (and optionally a noise polynomial) in C.
 *
 * The byte at seed[WC_ML_KEM_SYM_SZ] is the counter: it starts at 0 and is
 * incremented after every polynomial. When the second vector is not generated
 * the counter is set to 2 * k before the optional polynomial is generated.
 *
 * @param [in, out] prf   Pseudo-random function object.
 * @param [in]      k     Number of polynomials in each vector.
 * @param [out]     vec1  First vector of k polynomials.
 * @param [in]      eta1  Size of the noise range for vec1.
 * @param [out]     vec2  Second vector of k polynomials. May be NULL.
 * @param [in]      eta2  Size of the noise range for vec2.
 * @param [out]     poly  Polynomial to fill with eta = 2. May be NULL.
 * @param [in, out] seed  Seed; the counter byte is modified.
 * @return  0 on success.
 * @return  Error code from the noise generation otherwise.
 */
static int mlkem_get_noise_c(MLKEM_PRF_T* prf, int k, sword16* vec1, int eta1,
    sword16* vec2, int eta2, sword16* poly, byte* seed)
{
    int err = 0;
    int idx = 0;

    /* First vector: counters 0 .. k-1. */
    seed[WC_ML_KEM_SYM_SZ] = 0;
    while ((err == 0) && (idx < k)) {
        err = mlkem_get_noise_eta1_c(prf, vec1 + idx * MLKEM_N, seed,
            (byte)eta1);
        seed[WC_ML_KEM_SYM_SZ]++;
        idx++;
    }

    if ((err != 0) || (vec2 == NULL)) {
        /* Second vector not generated: move the counter past it. */
        seed[WC_ML_KEM_SYM_SZ] = (byte)(2 * k);
    }
    else {
        /* Second vector: counters k .. 2k-1. */
        for (idx = 0; (err == 0) && (idx < k); idx++) {
            err = mlkem_get_noise_eta1_c(prf, vec2 + idx * MLKEM_N, seed,
                (byte)eta2);
            seed[WC_ML_KEM_SYM_SZ]++;
        }
    }

    /* Optional extra polynomial with the current counter. */
    if ((err == 0) && (poly != NULL)) {
        err = mlkem_get_noise_eta2_c(prf, poly, seed);
    }

    return err;
}
#endif
/* Generate the noise vectors (and optionally a noise polynomial) from a seed.
 *
 * Dispatches on k to the implementation for that parameter set:
 *  - the AArch64 routine when built with WOLFSSL_ARMASM on __aarch64__,
 *  - the AVX2 routine when built with USE_INTEL_SPEEDUP (and SHA-3 assembly),
 *    AVX2 is reported in cpuid_flags and the vector registers were saved,
 *  - otherwise the C implementation mlkem_get_noise_c().
 *
 * For k == WC_ML_KEM_512_K the C path samples vec1 with MLKEM_CBD_ETA3; vec2 is
 * sampled with MLKEM_CBD_ETA3 when poly is NULL and with MLKEM_CBD_ETA2
 * otherwise. The other parameter sets use MLKEM_CBD_ETA2 throughout.
 *
 * @param [in, out] prf   Pseudo-random function object. Not used by every
 *                        implementation.
 * @param [in]      k     Number of polynomials in each vector; selects the
 *                        parameter set.
 * @param [out]     vec1  First noise vector.
 * @param [out]     vec2  Second noise vector.
 * @param [out]     poly  Noise polynomial. May be NULL.
 * @param [in, out] seed  Seed; some implementations modify the byte at index
 *                        WC_ML_KEM_SYM_SZ.
 * @return  Result of the selected implementation.
 * @return  BAD_STATE_E when k matches no compiled-in parameter set.
 */
int mlkem_get_noise(MLKEM_PRF_T* prf, int k, sword16* vec1, sword16* vec2,
sword16* poly, byte* seed)
{
int ret;
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
if (k == WC_ML_KEM_512_K) {
#if defined(WOLFSSL_ARMASM) && defined(__aarch64__)
ret = mlkem_get_noise_k2_aarch64(vec1, vec2, poly, seed);
#else
#if defined(USE_INTEL_SPEEDUP) && !defined(WC_SHA3_NO_ASM)
/* Vector registers are only saved when AVX2 is available; if the save
 * fails the C implementation is used instead. */
if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
ret = mlkem_get_noise_k2_avx2(prf, vec1, vec2, poly, seed);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
/* No polynomial wanted: both vectors use eta = 3. */
if (poly == NULL) {
ret = mlkem_get_noise_c(prf, k, vec1, MLKEM_CBD_ETA3, vec2,
MLKEM_CBD_ETA3, NULL, seed);
}
/* Polynomial wanted: second vector uses eta = 2. */
else {
ret = mlkem_get_noise_c(prf, k, vec1, MLKEM_CBD_ETA3, vec2,
MLKEM_CBD_ETA2, poly, seed);
}
#endif
}
else
#endif
#if defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
if (k == WC_ML_KEM_768_K) {
#if defined(WOLFSSL_ARMASM) && defined(__aarch64__)
ret = mlkem_get_noise_k3_aarch64(vec1, vec2, poly, seed);
#else
#if defined(USE_INTEL_SPEEDUP) && !defined(WC_SHA3_NO_ASM)
if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
ret = mlkem_get_noise_k3_avx2(vec1, vec2, poly, seed);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
ret = mlkem_get_noise_c(prf, k, vec1, MLKEM_CBD_ETA2, vec2,
MLKEM_CBD_ETA2, poly, seed);
}
#endif
}
else
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
if (k == WC_ML_KEM_1024_K) {
#if defined(WOLFSSL_ARMASM) && defined(__aarch64__)
ret = mlkem_get_noise_k4_aarch64(vec1, vec2, poly, seed);
#else
#if defined(USE_INTEL_SPEEDUP) && !defined(WC_SHA3_NO_ASM)
if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
ret = mlkem_get_noise_k4_avx2(prf, vec1, vec2, poly, seed);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
{
ret = mlkem_get_noise_c(prf, k, vec1, MLKEM_CBD_ETA2, vec2,
MLKEM_CBD_ETA2, poly, seed);
}
#endif
}
else
#endif
{
/* k does not correspond to any parameter set compiled in. */
ret = BAD_STATE_E;
}
/* prf is unused when only assembly implementations are compiled. */
(void)prf;
return ret;
}
#if defined(WOLFSSL_MLKEM_MAKEKEY_SMALL_MEM) || \
defined(WOLFSSL_MLKEM_ENCAPSULATE_SMALL_MEM)
/* Generate polynomial i of the second noise vector.
 *
 * The PRF is initialized and the seed's counter byte is set to k + i. The
 * polynomial is sampled with eta = 3 only when the 512-bit parameter set is
 * compiled in, k is WC_ML_KEM_512_K and make is non-zero; otherwise eta = 2.
 *
 * @param [in, out] prf   Pseudo-random function object.
 * @param [in]      k     Number of polynomials in a vector.
 * @param [out]     vec2  Polynomial to fill.
 * @param [in, out] seed  Seed; byte at index WC_ML_KEM_SYM_SZ is overwritten.
 * @param [in]      i     Index of the polynomial in the second vector.
 * @param [in]      make  Non-zero when called from key generation -
 *                        presumably; confirm against callers.
 * @return  0 on success.
 * @return  Error code from the noise generation otherwise.
 */
static int mlkem_get_noise_i(MLKEM_PRF_T* prf, int k, sword16* vec2,
    byte* seed, int i, int make)
{
    byte eta = MLKEM_CBD_ETA2;

    mlkem_prf_init(prf);

    /* Counter values 0 .. k-1 belong to the first vector. */
    seed[WC_ML_KEM_SYM_SZ] = (byte)(k + i);

#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512)
    if ((k == WC_ML_KEM_512_K) && make) {
        eta = MLKEM_CBD_ETA3;
    }
#endif
    /* make is unused when the 512-bit parameter set is not compiled in. */
    (void)make;

    return mlkem_get_noise_eta1_c(prf, vec2, seed, eta);
}
#endif
#if !(defined(__aarch64__) && defined(WOLFSSL_ARMASM))
static int mlkem_cmp_c(const byte* a, const byte* b, int sz)
{
int i;
byte r = 0;
for (i = 0; i < sz; i++) {
r |= a[i] ^ b[i];
}
return (int)(0 - ((-(word32)r) >> 31));
}
#endif
/* Compare two byte arrays for equality.
 *
 * Uses the NEON routine on AArch64 assembly builds, the AVX2 routine when it
 * is compiled in, AVX2 is detected and the vector registers could be saved,
 * and the C implementation otherwise.
 *
 * @param [in] a   First array.
 * @param [in] b   Second array.
 * @param [in] sz  Number of bytes to compare.
 * @return  Result of the selected implementation; the C implementation
 *          returns 0 when equal and -1 otherwise.
 */
int mlkem_cmp(const byte* a, const byte* b, int sz)
{
#if defined(__aarch64__) && defined(WOLFSSL_ARMASM)
    return mlkem_cmp_neon(a, b, sz);
#else
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        int fail = mlkem_cmp_avx2(a, b, sz);

        RESTORE_VECTOR_REGISTERS();
        return fail;
    }
#endif
    return mlkem_cmp_c(a, b, sz);
#endif
}
#if !defined(WOLFSSL_ARMASM)
/* Conditionally subtract q from each coefficient of a polynomial.
 *
 * q is subtracted from every coefficient and added back when the result is
 * negative. No branch depends on a coefficient's value.
 *
 * @param [in, out] p  Polynomial of MLKEM_N coefficients.
 */
static MLKEM_NOINLINE void mlkem_csubq_c(sword16* p)
{
    sword16* const end = p + MLKEM_N;

    for (; p != end; ++p) {
        sword16 t = (sword16)(*p - MLKEM_Q);
        /* All ones when t is negative (top bit set), zero otherwise. */
        word16 mask = (word16)(-((word16)t >> 15));

        *p = (sword16)((mask & MLKEM_Q) + (word16)t);
    }
}
#elif defined(__aarch64__)
/* Assembly builds map the C name onto the platform's routine. */
#define mlkem_csubq_c mlkem_csubq_neon
#elif defined(WOLFSSL_ARMASM_THUMB2)
#define mlkem_csubq_c mlkem_thumb2_csubq
#else
#define mlkem_csubq_c mlkem_arm32_csubq
#endif
#if defined(CONV_WITH_DIV) || !defined(WORD64_AVAILABLE)

/* Compress coefficient (j + k) of polynomial i of vector v to the bits
 * selected by mask m: ((value << s) + q/2) / q, masked.
 *
 * Arguments and the whole expansion are parenthesized so that the macro can
 * be used inside larger expressions (e.g. compared with ==). */
#define TO_COMP_WORD_VEC(v, i, j, k, s, m) \
    (((((word32)(v)[(i) * MLKEM_N + (j) + (k)] << (s)) + MLKEM_Q_HALF) / \
      MLKEM_Q) & (m))

/* Compress a coefficient to 10 bits using division. */
#define TO_COMP_WORD_10(v, i, j, k) \
    TO_COMP_WORD_VEC(v, i, j, k, 10, 0x3ff)

/* Compress a coefficient to 11 bits using division. */
#define TO_COMP_WORD_11(v, i, j, k) \
    TO_COMP_WORD_VEC(v, i, j, k, 11, 0x7ff)

#else

/* Multipliers and rounding terms for compressing with a 64-bit multiply and
 * shift (by 53 or 54 bits) in place of the division by q. */
#define MLKEM_V53         0x275f6ed0176UL
#define MLKEM_V53_HALF    0x10013afb768076UL
#define MLKEM_V54         0x4ebedda02ecUL
#define MLKEM_V54_HALF    0x200275f6ed00ecUL

/* Compress coefficient (j + k) of polynomial i of vector v to 10 bits using
 * multiplication and shift. */
#define TO_COMP_WORD_10(v, i, j, k) \
    ((sword16)((((MLKEM_V54 << 10) * \
                 (word64)(v)[(i) * MLKEM_N + (j) + (k)]) + \
                MLKEM_V54_HALF) >> 54))

/* Compress coefficient (j + k) of polynomial i of vector v to 11 bits using
 * multiplication and shift. */
#define TO_COMP_WORD_11(v, i, j, k) \
    ((sword16)((((MLKEM_V53 << 11) * \
                 (word64)(v)[(i) * MLKEM_N + (j) + (k)]) + \
                MLKEM_V53_HALF) >> 53))

#endif
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
!defined(WOLFSSL_MLKEM_NO_DECAPSULATE)
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512) || \
defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
/* Compress a vector of k polynomials to 10 bits per coefficient and pack the
 * result into bytes.
 *
 * Every polynomial first has q conditionally subtracted in place by
 * mlkem_csubq_c(), so v is modified. Four coefficients are packed into five
 * bytes, least significant bits first.
 *
 * @param [out]     r  Output buffer: 5 bytes per 4 coefficients, i.e.
 *                     k * MLKEM_N * 10 / 8 bytes.
 * @param [in, out] v  Vector of k polynomials of MLKEM_N coefficients.
 * @param [in]      k  Number of polynomials in the vector.
 */
static void mlkem_vec_compress_10_c(byte* r, sword16* v, unsigned int k)
{
unsigned int i;
unsigned int j;
/* Conditionally subtract q from every polynomial before compressing. */
for (i = 0; i < k; i++) {
mlkem_csubq_c(v + i * MLKEM_N);
}
for (i = 0; i < k; i++) {
#if defined(WOLFSSL_SMALL_STACK) || defined(WOLFSSL_MLKEM_NO_LARGE_CODE) || \
defined(BIG_ENDIAN_ORDER)
/* Byte-wise packing: 4 coefficients into 5 bytes per iteration. */
for (j = 0; j < MLKEM_N; j += 4) {
#ifdef WOLFSSL_MLKEM_SMALL
unsigned int l;
sword16 t[4];
for (l = 0; l < 4; l++) {
t[l] = TO_COMP_WORD_10(v, i, j, l);
}
/* Assignments to byte keep only the low 8 bits of each expression. */
r[ 0] = (t[0] >> 0);
r[ 1] = (t[0] >> 8) | (t[1] << 2);
r[ 2] = (t[1] >> 6) | (t[2] << 4);
r[ 3] = (t[2] >> 4) | (t[3] << 6);
r[ 4] = (t[3] >> 2);
#else
sword16 t0 = TO_COMP_WORD_10(v, i, j, 0);
sword16 t1 = TO_COMP_WORD_10(v, i, j, 1);
sword16 t2 = TO_COMP_WORD_10(v, i, j, 2);
sword16 t3 = TO_COMP_WORD_10(v, i, j, 3);
r[ 0] = (byte)( t0 >> 0);
r[ 1] = (byte)((t0 >> 8) | (t1 << 2));
r[ 2] = (byte)((t1 >> 6) | (t2 << 4));
r[ 3] = (byte)((t2 >> 4) | (t3 << 6));
r[ 4] = (byte)( t3 >> 2);
#endif
r += 5;
}
#else
/* Word-wise packing: 16 coefficients into five 32-bit words (20 bytes)
 * per iteration. */
for (j = 0; j < MLKEM_N; j += 16) {
sword16 t0 = TO_COMP_WORD_10(v, i, j, 0);
sword16 t1 = TO_COMP_WORD_10(v, i, j, 1);
sword16 t2 = TO_COMP_WORD_10(v, i, j, 2);
sword16 t3 = TO_COMP_WORD_10(v, i, j, 3);
sword16 t4 = TO_COMP_WORD_10(v, i, j, 4);
sword16 t5 = TO_COMP_WORD_10(v, i, j, 5);
sword16 t6 = TO_COMP_WORD_10(v, i, j, 6);
sword16 t7 = TO_COMP_WORD_10(v, i, j, 7);
sword16 t8 = TO_COMP_WORD_10(v, i, j, 8);
sword16 t9 = TO_COMP_WORD_10(v, i, j, 9);
sword16 t10 = TO_COMP_WORD_10(v, i, j, 10);
sword16 t11 = TO_COMP_WORD_10(v, i, j, 11);
sword16 t12 = TO_COMP_WORD_10(v, i, j, 12);
sword16 t13 = TO_COMP_WORD_10(v, i, j, 13);
sword16 t14 = TO_COMP_WORD_10(v, i, j, 14);
sword16 t15 = TO_COMP_WORD_10(v, i, j, 15);
/* NOTE(review): r is written through a word32 pointer; this assumes r may
 * be accessed as 32-bit words - confirm alignment of callers' buffers. */
word32* r32 = (word32*)r;
/* Coefficients that straddle a word boundary (t3, t6, t9, t12) have their
 * low bits in one word and their high bits in the next. */
r32[0] = (word32)t0 | ((word32)t1 << 10) |
((word32)t2 << 20) | ((word32)t3 << 30);
r32[1] = ((word32)t3 >> 2) | ((word32)t4 << 8) |
((word32)t5 << 18) | ((word32)t6 << 28);
r32[2] = ((word32)t6 >> 4) | ((word32)t7 << 6) |
((word32)t8 << 16) | ((word32)t9 << 26);
r32[3] = ((word32)t9 >> 6) | ((word32)t10 << 4) |
((word32)t11 << 14) | ((word32)t12 << 24);
r32[4] = ((word32)t12 >> 8) | ((word32)t13 << 2) |
((word32)t14 << 12) | ((word32)t15 << 22);
r += 20;
}
#endif
}
}
/* Compress a vector of k polynomials to 10 bits per coefficient.
 *
 * Uses the AVX2 routine when it is compiled in, AVX2 is detected and the
 * vector registers could be saved; the C implementation otherwise.
 *
 * @param [out]     r  Output buffer for the packed coefficients.
 * @param [in, out] v  Vector of k polynomials; the C implementation modifies
 *                     it in place.
 * @param [in]      k  Number of polynomials in the vector.
 */
void mlkem_vec_compress_10(byte* r, sword16* v, unsigned int k)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_compress_10_avx2(r, v, (int)k);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    mlkem_vec_compress_10_c(r, v, k);
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Compress a vector of 4 polynomials to 11 bits per coefficient and pack the
 * result into bytes.
 *
 * Every polynomial first has q conditionally subtracted in place by
 * mlkem_csubq_c(), so v is modified. Eight coefficients are packed into eleven
 * bytes, least significant bits first.
 *
 * @param [out]     r  Output buffer: 11 bytes per 8 coefficients, i.e.
 *                     4 * MLKEM_N * 11 / 8 bytes.
 * @param [in, out] v  Vector of 4 polynomials of MLKEM_N coefficients.
 */
static void mlkem_vec_compress_11_c(byte* r, sword16* v)
{
unsigned int i;
unsigned int j;
#ifdef WOLFSSL_MLKEM_SMALL
unsigned int k;
#endif
/* Conditionally subtract q from every polynomial before compressing. */
for (i = 0; i < 4; i++) {
mlkem_csubq_c(v + i * MLKEM_N);
}
for (i = 0; i < 4; i++) {
/* 8 coefficients into 11 bytes per iteration. */
for (j = 0; j < MLKEM_N; j += 8) {
#ifdef WOLFSSL_MLKEM_SMALL
sword16 t[8];
for (k = 0; k < 8; k++) {
t[k] = TO_COMP_WORD_11(v, i, j, k);
}
/* t[2] and t[5] each span three output bytes. */
r[ 0] = (byte)( t[0] >> 0);
r[ 1] = (byte)((t[0] >> 8) | (t[1] << 3));
r[ 2] = (byte)((t[1] >> 5) | (t[2] << 6));
r[ 3] = (byte)( t[2] >> 2);
r[ 4] = (byte)((t[2] >> 10) | (t[3] << 1));
r[ 5] = (byte)((t[3] >> 7) | (t[4] << 4));
r[ 6] = (byte)((t[4] >> 4) | (t[5] << 7));
r[ 7] = (byte)( t[5] >> 1);
r[ 8] = (byte)((t[5] >> 9) | (t[6] << 2));
r[ 9] = (byte)((t[6] >> 6) | (t[7] << 5));
r[10] = (byte)( t[7] >> 3);
#else
sword16 t0 = TO_COMP_WORD_11(v, i, j, 0);
sword16 t1 = TO_COMP_WORD_11(v, i, j, 1);
sword16 t2 = TO_COMP_WORD_11(v, i, j, 2);
sword16 t3 = TO_COMP_WORD_11(v, i, j, 3);
sword16 t4 = TO_COMP_WORD_11(v, i, j, 4);
sword16 t5 = TO_COMP_WORD_11(v, i, j, 5);
sword16 t6 = TO_COMP_WORD_11(v, i, j, 6);
sword16 t7 = TO_COMP_WORD_11(v, i, j, 7);
/* t2 and t5 each span three output bytes. */
r[ 0] = (byte)( t0 >> 0);
r[ 1] = (byte)((t0 >> 8) | (t1 << 3));
r[ 2] = (byte)((t1 >> 5) | (t2 << 6));
r[ 3] = (byte)( t2 >> 2);
r[ 4] = (byte)((t2 >> 10) | (t3 << 1));
r[ 5] = (byte)((t3 >> 7) | (t4 << 4));
r[ 6] = (byte)((t4 >> 4) | (t5 << 7));
r[ 7] = (byte)( t5 >> 1);
r[ 8] = (byte)((t5 >> 9) | (t6 << 2));
r[ 9] = (byte)((t6 >> 6) | (t7 << 5));
r[10] = (byte)( t7 >> 3);
#endif
r += 11;
}
}
}
/* Compress a vector of 4 polynomials to 11 bits per coefficient.
 *
 * Uses the AVX2 routine when it is compiled in, AVX2 is detected and the
 * vector registers could be saved; the C implementation otherwise.
 *
 * @param [out]     r  Output buffer for the packed coefficients.
 * @param [in, out] v  Vector of 4 polynomials; the C implementation modifies
 *                     it in place.
 */
void mlkem_vec_compress_11(byte* r, sword16* v)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_compress_11_avx2(r, v, 4);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    mlkem_vec_compress_11_c(r, v);
}
#endif
#endif
#ifndef WOLFSSL_MLKEM_NO_DECAPSULATE
/* Decompress a 10-bit value into coefficient (4 * j + k) of polynomial i of
 * vector v: ((t & 0x3ff) * q + 512) >> 10, i.e. scaled by q / 2^10 with
 * rounding. The macro is an assignment expression. */
#define DECOMP_10(v, i, j, k, t) \
v[(i) * MLKEM_N + 4 * (j) + (k)] = \
(sword16)((((word32)((t) & 0x3ff) * MLKEM_Q) + 512) >> 10)
/* Decompress an 11-bit value into coefficient (8 * j + k) of polynomial i of
 * vector v: ((t & 0x7ff) * q + 1024) >> 11, i.e. scaled by q / 2^11 with
 * rounding. The macro is an assignment expression. */
#define DECOMP_11(v, i, j, k, t) \
v[(i) * MLKEM_N + 8 * (j) + (k)] = \
(sword16)((((word32)((t) & 0x7ff) * MLKEM_Q) + 1024) >> 11)
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512) || \
defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
/* Unpack and decompress a vector of k polynomials stored with 10 bits per
 * coefficient.
 *
 * Five input bytes hold four coefficients, least significant bits first. The
 * unpacked words may carry extra high bits; DECOMP_10 masks each to 10 bits.
 *
 * @param [out] v  Vector of k polynomials of MLKEM_N coefficients.
 * @param [in]  b  Packed data: 5 bytes per 4 coefficients.
 * @param [in]  k  Number of polynomials in the vector.
 */
static void mlkem_vec_decompress_10_c(sword16* v, const byte* b, unsigned int k)
{
unsigned int i;
unsigned int j;
#ifdef WOLFSSL_MLKEM_SMALL
unsigned int l;
#endif
for (i = 0; i < k; i++) {
/* 5 bytes into 4 coefficients per iteration. */
for (j = 0; j < MLKEM_N / 4; j++) {
#ifdef WOLFSSL_MLKEM_SMALL
word16 t[4];
t[0] = (word16)((b[0] >> 0) | ((word16)b[ 1] << 8));
t[1] = (word16)((b[1] >> 2) | ((word16)b[ 2] << 6));
t[2] = (word16)((b[2] >> 4) | ((word16)b[ 3] << 4));
t[3] = (word16)((b[3] >> 6) | ((word16)b[ 4] << 2));
b += 5;
for (l = 0; l < 4; l++) {
DECOMP_10(v, i, j, l, t[l]);
}
#else
word16 t0 = (word16)((b[0] >> 0) | ((word16)b[ 1] << 8));
word16 t1 = (word16)((b[1] >> 2) | ((word16)b[ 2] << 6));
word16 t2 = (word16)((b[2] >> 4) | ((word16)b[ 3] << 4));
word16 t3 = (word16)((b[3] >> 6) | ((word16)b[ 4] << 2));
b += 5;
DECOMP_10(v, i, j, 0, t0);
DECOMP_10(v, i, j, 1, t1);
DECOMP_10(v, i, j, 2, t2);
DECOMP_10(v, i, j, 3, t3);
#endif
}
}
}
/* Unpack and decompress a vector of k polynomials stored with 10 bits per
 * coefficient.
 *
 * Uses the AVX2 routine when it is compiled in, AVX2 is detected and the
 * vector registers could be saved; the C implementation otherwise.
 *
 * @param [out] v  Vector of k polynomials.
 * @param [in]  b  Packed data.
 * @param [in]  k  Number of polynomials in the vector.
 */
void mlkem_vec_decompress_10(sword16* v, const byte* b, unsigned int k)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_decompress_10_avx2(v, b, (int)k);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    mlkem_vec_decompress_10_c(v, b, k);
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Unpack and decompress a vector of 4 polynomials stored with 11 bits per
 * coefficient.
 *
 * Eleven input bytes hold eight coefficients, least significant bits first.
 * The unpacked words may carry extra high bits; DECOMP_11 masks each to 11
 * bits.
 *
 * @param [out] v  Vector of 4 polynomials of MLKEM_N coefficients.
 * @param [in]  b  Packed data: 11 bytes per 8 coefficients.
 */
static void mlkem_vec_decompress_11_c(sword16* v, const byte* b)
{
unsigned int i;
unsigned int j;
#ifdef WOLFSSL_MLKEM_SMALL
unsigned int l;
#endif
for (i = 0; i < 4; i++) {
/* 11 bytes into 8 coefficients per iteration. */
for (j = 0; j < MLKEM_N / 8; j++) {
#ifdef WOLFSSL_MLKEM_SMALL
word16 t[8];
/* t[2] and t[5] are assembled from three input bytes each. */
t[0] = (word16)((b[0] >> 0) | ((word16)b[ 1] << 8));
t[1] = (word16)((b[1] >> 3) | ((word16)b[ 2] << 5));
t[2] = (word16)((b[2] >> 6) | ((word16)b[ 3] << 2) |
((word16)b[4] << 10));
t[3] = (word16)((b[4] >> 1) | ((word16)b[ 5] << 7));
t[4] = (word16)((b[5] >> 4) | ((word16)b[ 6] << 4));
t[5] = (word16)((b[6] >> 7) | ((word16)b[ 7] << 1) |
((word16)b[8] << 9));
t[6] = (word16)((b[8] >> 2) | ((word16)b[ 9] << 6));
t[7] = (word16)((b[9] >> 5) | ((word16)b[10] << 3));
b += 11;
for (l = 0; l < 8; l++) {
DECOMP_11(v, i, j, l, t[l]);
}
#else
/* t2 and t5 are assembled from three input bytes each. */
word16 t0 = (word16)((b[0] >> 0) | ((word16)b[ 1] << 8));
word16 t1 = (word16)((b[1] >> 3) | ((word16)b[ 2] << 5));
word16 t2 = (word16)((b[2] >> 6) | ((word16)b[ 3] << 2) |
((word16)b[4] << 10));
word16 t3 = (word16)((b[4] >> 1) | ((word16)b[ 5] << 7));
word16 t4 = (word16)((b[5] >> 4) | ((word16)b[ 6] << 4));
word16 t5 = (word16)((b[6] >> 7) | ((word16)b[ 7] << 1) |
((word16)b[8] << 9));
word16 t6 = (word16)((b[8] >> 2) | ((word16)b[ 9] << 6));
word16 t7 = (word16)((b[9] >> 5) | ((word16)b[10] << 3));
b += 11;
DECOMP_11(v, i, j, 0, t0);
DECOMP_11(v, i, j, 1, t1);
DECOMP_11(v, i, j, 2, t2);
DECOMP_11(v, i, j, 3, t3);
DECOMP_11(v, i, j, 4, t4);
DECOMP_11(v, i, j, 5, t5);
DECOMP_11(v, i, j, 6, t6);
DECOMP_11(v, i, j, 7, t7);
#endif
}
}
}
/* Unpack and decompress a vector of 4 polynomials stored with 11 bits per
 * coefficient.
 *
 * Uses the AVX2 routine when it is compiled in, AVX2 is detected and the
 * vector registers could be saved; the C implementation otherwise.
 *
 * @param [out] v  Vector of 4 polynomials.
 * @param [in]  b  Packed data.
 */
void mlkem_vec_decompress_11(sword16* v, const byte* b)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_decompress_11_avx2(v, b, 4);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    mlkem_vec_decompress_11_c(v, b);
}
#endif
#endif
#ifdef CONV_WITH_DIV

/* Compress coefficient (i + j) of polynomial v to the bits selected by mask
 * m: ((value << s) + q/2) / q, masked.
 *
 * Arguments and the whole expansion are parenthesized so that the macro can
 * be used inside larger expressions. */
#define TO_COMP_WORD(v, i, j, s, m) \
    (((((word32)(v)[(i) + (j)] << (s)) + MLKEM_Q_HALF) / MLKEM_Q) & (m))

/* Compress a coefficient to 4 bits using division. */
#define TO_COMP_WORD_4(p, i, j) \
    TO_COMP_WORD(p, i, j, 4, 0xf)

/* Compress a coefficient to 5 bits using division. */
#define TO_COMP_WORD_5(p, i, j) \
    TO_COMP_WORD(p, i, j, 5, 0x1f)

#else

/* Multipliers and rounding terms for compressing with a 32-bit multiply and
 * shift (by 28 or 27 bits) in place of the division by q. */
#define MLKEM_V28         ((word32)(((1U << 28) + MLKEM_Q_HALF)) / MLKEM_Q)
#define MLKEM_V28_HALF    ((word32)(MLKEM_V28 * (MLKEM_Q_HALF + 1)))
#define MLKEM_V27         ((word32)(((1U << 27) + MLKEM_Q_HALF)) / MLKEM_Q)
#define MLKEM_V27_HALF    ((word32)(MLKEM_V27 * MLKEM_Q_HALF))

/* Compress coefficient (i + j) of polynomial p to 4 bits using multiplication
 * and shift. The 32-bit arithmetic may wrap; only the low byte of the shifted
 * result is kept. */
#define TO_COMP_WORD_4(p, i, j) \
    ((byte)((((MLKEM_V28 << 4) * (word32)(p)[(i) + (j)]) + \
             MLKEM_V28_HALF) >> 28))

/* Compress coefficient (i + j) of polynomial p to 5 bits using multiplication
 * and shift. */
#define TO_COMP_WORD_5(p, i, j) \
    ((byte)((((MLKEM_V27 << 5) * (word32)(p)[(i) + (j)]) + \
             MLKEM_V27_HALF) >> 27))

#endif
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
!defined(WOLFSSL_MLKEM_NO_DECAPSULATE)
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512) || \
defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
/* Compress every coefficient of a polynomial to 4 bits and pack the result.
 *
 * The polynomial is first normalized in place with mlkem_csubq_c(). After
 * that each run of 8 coefficients yields 4 output bytes, with the
 * even-indexed value in the low nibble and the odd-indexed value in the high
 * nibble.
 *
 * @param  [out]      b  Output buffer; MLKEM_N / 2 bytes are written.
 * @param  [in, out]  p  Polynomial; modified by mlkem_csubq_c().
 */
static void mlkem_compress_4_c(byte* b, sword16* p)
{
    byte* out = b;
    unsigned int base;
#ifdef WOLFSSL_MLKEM_SMALL
    unsigned int n;
    byte nib[8];
#endif

    mlkem_csubq_c(p);

    for (base = 0; base < MLKEM_N; base += 8, out += 4) {
    #ifdef WOLFSSL_MLKEM_SMALL
        /* Compress all eight coefficients before any output is written. */
        for (n = 0; n < 8; n++) {
            nib[n] = TO_COMP_WORD_4(p, base, n);
        }
        /* Two nibbles per output byte. */
        for (n = 0; n < 4; n++) {
            out[n] = (byte)(nib[2 * n] | (nib[2 * n + 1] << 4));
        }
    #else
        /* Compress all eight coefficients before any output is written. */
        const byte n0 = TO_COMP_WORD_4(p, base, 0);
        const byte n1 = TO_COMP_WORD_4(p, base, 1);
        const byte n2 = TO_COMP_WORD_4(p, base, 2);
        const byte n3 = TO_COMP_WORD_4(p, base, 3);
        const byte n4 = TO_COMP_WORD_4(p, base, 4);
        const byte n5 = TO_COMP_WORD_4(p, base, 5);
        const byte n6 = TO_COMP_WORD_4(p, base, 6);
        const byte n7 = TO_COMP_WORD_4(p, base, 7);

        /* Two nibbles per output byte. */
        out[0] = (byte)(n0 | (n1 << 4));
        out[1] = (byte)(n2 | (n3 << 4));
        out[2] = (byte)(n4 | (n5 << 4));
        out[3] = (byte)(n6 | (n7 << 4));
    #endif
    }
}
/* Compress a polynomial to 4 bits per coefficient.
 *
 * Uses the AVX2 routine when USE_INTEL_SPEEDUP is defined, the CPU flags
 * report AVX2 and SAVE_VECTOR_REGISTERS2() returns 0; otherwise uses the C
 * implementation.
 *
 * @param  [out]      b  Output buffer for the packed bytes.
 * @param  [in, out]  p  Polynomial to compress.
 */
void mlkem_compress_4(byte* b, sword16* p)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_compress_4_avx2(b, p);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Portable fallback. */
    mlkem_compress_4_c(b, p);
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Compress every coefficient of a polynomial to 5 bits and pack the result.
 *
 * The polynomial is first normalized in place with mlkem_csubq_c(). After
 * that each run of 8 coefficients (40 bits) yields 5 output bytes, with
 * values laid out least-significant bits first.
 *
 * @param  [out]      b  Output buffer; 5 bytes are written per 8 coefficients.
 * @param  [in, out]  p  Polynomial; modified by mlkem_csubq_c().
 */
static void mlkem_compress_5_c(byte* b, sword16* p)
{
    byte* out = b;
    unsigned int base;
#ifdef WOLFSSL_MLKEM_SMALL
    unsigned int n;
    byte q5[8];
#endif

    mlkem_csubq_c(p);

    for (base = 0; base < MLKEM_N; base += 8, out += 5) {
    #ifdef WOLFSSL_MLKEM_SMALL
        /* Compress all eight coefficients before any output is written. */
        for (n = 0; n < 8; n++) {
            q5[n] = TO_COMP_WORD_5(p, base, n);
        }

        /* Stitch eight 5-bit values into five bytes. */
        out[0] = (byte)(q5[0] | (q5[1] << 5));
        out[1] = (byte)((q5[1] >> 3) | (q5[2] << 2) | (q5[3] << 7));
        out[2] = (byte)((q5[3] >> 1) | (q5[4] << 4));
        out[3] = (byte)((q5[4] >> 4) | (q5[5] << 1) | (q5[6] << 6));
        out[4] = (byte)((q5[6] >> 2) | (q5[7] << 3));
    #else
        /* Compress all eight coefficients before any output is written. */
        const byte q0 = TO_COMP_WORD_5(p, base, 0);
        const byte q1 = TO_COMP_WORD_5(p, base, 1);
        const byte q2 = TO_COMP_WORD_5(p, base, 2);
        const byte q3 = TO_COMP_WORD_5(p, base, 3);
        const byte q4 = TO_COMP_WORD_5(p, base, 4);
        const byte q5 = TO_COMP_WORD_5(p, base, 5);
        const byte q6 = TO_COMP_WORD_5(p, base, 6);
        const byte q7 = TO_COMP_WORD_5(p, base, 7);

        /* Stitch eight 5-bit values into five bytes. */
        out[0] = (byte)(q0 | (q1 << 5));
        out[1] = (byte)((q1 >> 3) | (q2 << 2) | (q3 << 7));
        out[2] = (byte)((q3 >> 1) | (q4 << 4));
        out[3] = (byte)((q4 >> 4) | (q5 << 1) | (q6 << 6));
        out[4] = (byte)((q6 >> 2) | (q7 << 3));
    #endif
    }
}
/* Compress a polynomial to 5 bits per coefficient.
 *
 * Uses the AVX2 routine when USE_INTEL_SPEEDUP is defined, the CPU flags
 * report AVX2 and SAVE_VECTOR_REGISTERS2() returns 0; otherwise uses the C
 * implementation.
 *
 * @param  [out]      b  Output buffer for the packed bytes.
 * @param  [in, out]  p  Polynomial to compress.
 */
void mlkem_compress_5(byte* b, sword16* p)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_compress_5_avx2(b, p);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Portable fallback. */
    mlkem_compress_5_c(b, p);
}
#endif
#endif
#ifndef WOLFSSL_MLKEM_NO_DECAPSULATE
/* Decompress a 4-bit value into coefficient p[i + j].
 *
 * Stores ((t * MLKEM_Q) + 8) >> 4, with the product held in a word16.
 * The value t is not masked here; the callers in this file pass values that
 * are already limited to 4 bits.
 *
 * @param  [out]  p  Array of coefficients.
 * @param  [in]   i  Base index into the array.
 * @param  [in]   j  Offset added to the base index.
 * @param  [in]   t  4-bit compressed value.
 */
#define DECOMP_4(p, i, j, t) \
p[(i) + (j)] = (sword16)(((word16)((t) * MLKEM_Q) + 8) >> 4)
/* Decompress a 5-bit value into coefficient p[i + j].
 *
 * Masks t to its low 5 bits, then stores ((t * MLKEM_Q) + 16) >> 5 with the
 * arithmetic done in a word32.
 *
 * @param  [out]  p  Array of coefficients.
 * @param  [in]   i  Base index into the array.
 * @param  [in]   j  Offset added to the base index.
 * @param  [in]   t  Compressed value; only the low 5 bits are used.
 */
#define DECOMP_5(p, i, j, t) \
p[(i) + (j)] = (sword16)((((word32)((t) & 0x1f) * MLKEM_Q) + 16) >> 5)
#if defined(WOLFSSL_KYBER512) || defined(WOLFSSL_WC_ML_KEM_512) || \
defined(WOLFSSL_KYBER768) || defined(WOLFSSL_WC_ML_KEM_768)
/* Decompress 4-bit packed values into a polynomial.
 *
 * Every input byte carries two coefficients: the low nibble is decompressed
 * into the even-indexed coefficient and the high nibble into the following
 * odd-indexed one.
 *
 * @param  [out]  p  Polynomial of MLKEM_N coefficients.
 * @param  [in]   b  Packed input; MLKEM_N / 2 bytes are read.
 */
static void mlkem_decompress_4_c(sword16* p, const byte* b)
{
    unsigned int n;

    for (n = 0; n < MLKEM_N / 2; n++) {
        /* Low nibble first, then high nibble. */
        DECOMP_4(p, 2 * n, 0, b[n] & 0xf);
        DECOMP_4(p, 2 * n, 1, b[n] >> 4);
    }
}
/* Decompress 4-bit packed values into a polynomial.
 *
 * Uses the AVX2 routine when USE_INTEL_SPEEDUP is defined, the CPU flags
 * report AVX2 and SAVE_VECTOR_REGISTERS2() returns 0; otherwise uses the C
 * implementation.
 *
 * @param  [out]  p  Polynomial to fill.
 * @param  [in]   b  Packed bytes to decompress.
 */
void mlkem_decompress_4(sword16* p, const byte* b)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_decompress_4_avx2(p, b);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Portable fallback. */
    mlkem_decompress_4_c(p, b);
}
#endif
#if defined(WOLFSSL_KYBER1024) || defined(WOLFSSL_WC_ML_KEM_1024)
/* Decompress 5-bit packed values into a polynomial.
 *
 * Each group of 5 input bytes holds eight 5-bit values. The values are
 * extracted first (upper bits are left in place because DECOMP_5 masks to
 * 5 bits) and then expanded into eight coefficients.
 *
 * @param  [out]  p  Polynomial of MLKEM_N coefficients.
 * @param  [in]   b  Packed input; 5 bytes are read per 8 coefficients.
 */
static void mlkem_decompress_5_c(sword16* p, const byte* b)
{
    const byte* in = b;
    unsigned int base;

    for (base = 0; base < MLKEM_N; base += 8, in += 5) {
    #ifdef WOLFSSL_MLKEM_SMALL
        unsigned int n;
        byte v[8];

        /* Pull the eight values out of the five bytes. */
        v[0] = in[0];
        v[1] = (byte)((in[0] >> 5) | (in[1] << 3));
        v[2] = (byte)(in[1] >> 2);
        v[3] = (byte)((in[1] >> 7) | (in[2] << 1));
        v[4] = (byte)((in[2] >> 4) | (in[3] << 4));
        v[5] = (byte)(in[3] >> 1);
        v[6] = (byte)((in[3] >> 6) | (in[4] << 2));
        v[7] = (byte)(in[4] >> 3);

        for (n = 0; n < 8; n++) {
            DECOMP_5(p, base, n, v[n]);
        }
    #else
        /* Pull the eight values out of the five bytes. */
        const byte v0 = in[0];
        const byte v1 = (byte)((in[0] >> 5) | (in[1] << 3));
        const byte v2 = (byte)(in[1] >> 2);
        const byte v3 = (byte)((in[1] >> 7) | (in[2] << 1));
        const byte v4 = (byte)((in[2] >> 4) | (in[3] << 4));
        const byte v5 = (byte)(in[3] >> 1);
        const byte v6 = (byte)((in[3] >> 6) | (in[4] << 2));
        const byte v7 = (byte)(in[4] >> 3);

        DECOMP_5(p, base, 0, v0);
        DECOMP_5(p, base, 1, v1);
        DECOMP_5(p, base, 2, v2);
        DECOMP_5(p, base, 3, v3);
        DECOMP_5(p, base, 4, v4);
        DECOMP_5(p, base, 5, v5);
        DECOMP_5(p, base, 6, v6);
        DECOMP_5(p, base, 7, v7);
    #endif
    }
}
/* Decompress 5-bit packed values into a polynomial.
 *
 * Uses the AVX2 routine when USE_INTEL_SPEEDUP is defined, the CPU flags
 * report AVX2 and SAVE_VECTOR_REGISTERS2() returns 0; otherwise uses the C
 * implementation.
 *
 * @param  [out]  p  Polynomial to fill.
 * @param  [in]   b  Packed bytes to decompress.
 */
void mlkem_decompress_5(sword16* p, const byte* b)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_decompress_5_avx2(p, b);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Portable fallback. */
    mlkem_decompress_5_c(p, b);
}
#endif
#endif
#if !(defined(__aarch64__) && defined(WOLFSSL_ARMASM))
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
!defined(WOLFSSL_MLKEM_NO_DECAPSULATE)
/* Convert bit j of message byte msg[i] into coefficient p[8 * i + j].
 *
 * The bit is turned into a mask with (0 - bit), giving 0 or all ones, which
 * is XORed with the value returned by wc_mlkem_opt_blocker() and then ANDed
 * with MLKEM_Q_1_HALF. No branch on the message bit is used.
 * NOTE(review): wc_mlkem_opt_blocker() is defined elsewhere; presumably it
 * returns 0 and only stops the compiler from optimizing the mask into a
 * branch - confirm.
 *
 * @param  [out]  p    Array of coefficients.
 * @param  [in]   msg  Message bytes.
 * @param  [in]   i    Index of the message byte.
 * @param  [in]   j    Index of the bit within the byte.
 */
#define FROM_MSG_BIT(p, msg, i, j) \
((p)[8 * (i) + (j)] = (((sword16)0 - (sword16)(((msg)[i] >> (j)) & 1)) ^ \
wc_mlkem_opt_blocker()) & MLKEM_Q_1_HALF)
/* Expand a message into a polynomial, one coefficient per message bit.
 *
 * Each of the MLKEM_N / 8 message bytes produces eight consecutive
 * coefficients via FROM_MSG_BIT, least-significant bit first.
 *
 * @param  [out]  p    Polynomial of MLKEM_N coefficients.
 * @param  [in]   msg  Message of MLKEM_N / 8 bytes.
 */
static void mlkem_from_msg_c(sword16* p, const byte* msg)
{
    unsigned int n;
#ifdef WOLFSSL_MLKEM_SMALL
    unsigned int bit;
#endif

    for (n = 0; n < MLKEM_N / 8; n++) {
    #ifdef WOLFSSL_MLKEM_SMALL
        for (bit = 0; bit < 8; bit++) {
            FROM_MSG_BIT(p, msg, n, bit);
        }
    #else
        /* Unrolled: one coefficient per bit of the current byte. */
        FROM_MSG_BIT(p, msg, n, 0);
        FROM_MSG_BIT(p, msg, n, 1);
        FROM_MSG_BIT(p, msg, n, 2);
        FROM_MSG_BIT(p, msg, n, 3);
        FROM_MSG_BIT(p, msg, n, 4);
        FROM_MSG_BIT(p, msg, n, 5);
        FROM_MSG_BIT(p, msg, n, 6);
        FROM_MSG_BIT(p, msg, n, 7);
    #endif
    }
}
/* Convert a message into a polynomial.
 *
 * Uses the AVX2 routine when USE_INTEL_SPEEDUP is defined, the CPU flags
 * report AVX2 and SAVE_VECTOR_REGISTERS2() returns 0; otherwise uses the C
 * implementation.
 *
 * @param  [out]  p    Polynomial to fill.
 * @param  [in]   msg  Message bytes.
 */
void mlkem_from_msg(sword16* p, const byte* msg)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_from_msg_avx2(p, msg);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Portable fallback. */
    mlkem_from_msg_c(p, msg);
}
#endif
#ifndef WOLFSSL_MLKEM_NO_DECAPSULATE
#ifdef CONV_WITH_DIV
/* Convert coefficient p[8 * i + j] to one message bit using a real division
 * and OR it into m[i] at bit position j.
 *
 * The bit is (((p << 1) + MLKEM_Q_HALF) / MLKEM_Q) & 1. Arguments are
 * parenthesized and the result is cast to byte, matching the multiply/shift
 * variant below.
 *
 * @param  [in, out]  m  Message bytes.
 * @param  [in]       p  Array of coefficients.
 * @param  [in]       i  Index of the message byte.
 * @param  [in]       j  Index of the bit within the byte.
 */
#define TO_MSG_BIT(m, p, i, j) \
    (m)[i] |= (byte)((((((sword16)(p)[8 * (i) + (j)] << 1) + MLKEM_Q_HALF) / \
                       MLKEM_Q) & 1) << (j))
#else
/* Multiplier standing in for a division by MLKEM_Q: 2^31 / MLKEM_Q rounded. */
#define MLKEM_V31 (((1U << 31) + (MLKEM_Q / 2)) / MLKEM_Q)
/* Multiplier with the doubling of the coefficient folded in. */
#define MLKEM_V31_2 ((word32)(MLKEM_V31 * 2))
/* Offset added before the shift by 31 so the result rounds to nearest. */
#define MLKEM_V31_HALF ((word32)(MLKEM_V31 * MLKEM_Q_HALF))
/* Convert coefficient p[8 * i + j] to one message bit with a multiply and
 * shift and OR it into m[i] at bit position j.
 * The 32-bit sum wraps, which takes care of the reduction modulo 2.
 *
 * @param  [in, out]  m  Message bytes.
 * @param  [in]       p  Array of coefficients.
 * @param  [in]       i  Index of the message byte.
 * @param  [in]       j  Index of the bit within the byte.
 */
#define TO_MSG_BIT(m, p, i, j) \
    (m)[i] |= (byte)((((MLKEM_V31_2 * (word16)(p)[8 * (i) + (j)]) + \
                    MLKEM_V31_HALF) >> 31) << (j))
#endif
/* Collapse a polynomial into a message, one bit per coefficient.
 *
 * The polynomial is first normalized in place with mlkem_csubq_c(). Each
 * output byte is cleared and then built from eight consecutive coefficients
 * via TO_MSG_BIT, least-significant bit first.
 *
 * @param  [out]      msg  Message buffer; MLKEM_N / 8 bytes are written.
 * @param  [in, out]  p    Polynomial; modified by mlkem_csubq_c().
 */
static void mlkem_to_msg_c(byte* msg, sword16* p)
{
    unsigned int n;
#ifdef WOLFSSL_MLKEM_SMALL
    unsigned int bit;
#endif

    mlkem_csubq_c(p);

    for (n = 0; n < MLKEM_N / 8; n++) {
        /* Start from a clean byte; bits are ORed in below. */
        msg[n] = 0;
    #ifdef WOLFSSL_MLKEM_SMALL
        for (bit = 0; bit < 8; bit++) {
            TO_MSG_BIT(msg, p, n, bit);
        }
    #else
        TO_MSG_BIT(msg, p, n, 0);
        TO_MSG_BIT(msg, p, n, 1);
        TO_MSG_BIT(msg, p, n, 2);
        TO_MSG_BIT(msg, p, n, 3);
        TO_MSG_BIT(msg, p, n, 4);
        TO_MSG_BIT(msg, p, n, 5);
        TO_MSG_BIT(msg, p, n, 6);
        TO_MSG_BIT(msg, p, n, 7);
    #endif
    }
}
/* Convert a polynomial into a message.
 *
 * Uses the AVX2 routine when USE_INTEL_SPEEDUP is defined, the CPU flags
 * report AVX2 and SAVE_VECTOR_REGISTERS2() returns 0; otherwise uses the C
 * implementation.
 *
 * @param  [out]      msg  Message buffer.
 * @param  [in, out]  p    Polynomial to convert.
 */
void mlkem_to_msg(byte* msg, sword16* p)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        mlkem_to_msg_avx2(msg, p);
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Portable fallback. */
    mlkem_to_msg_c(msg, p);
}
#endif
#else
#if !defined(WOLFSSL_MLKEM_NO_ENCAPSULATE) || \
!defined(WOLFSSL_MLKEM_NO_DECAPSULATE)
/* Convert a message into a polynomial (AArch64 build with WOLFSSL_ARMASM).
 *
 * Forwards directly to the NEON implementation.
 *
 * @param  [out]  p    Polynomial to fill.
 * @param  [in]   msg  Message bytes.
 */
void mlkem_from_msg(sword16* p, const byte* msg)
{
mlkem_from_msg_neon(p, msg);
}
#endif
#ifndef WOLFSSL_MLKEM_NO_DECAPSULATE
/* Convert a polynomial into a message (AArch64 build with WOLFSSL_ARMASM).
 *
 * Forwards directly to the NEON implementation.
 *
 * @param  [out]  msg  Message buffer.
 * @param  [in]   p    Polynomial to convert.
 */
void mlkem_to_msg(byte* msg, sword16* p)
{
mlkem_to_msg_neon(msg, p);
}
#endif
#endif
/* Decode a vector of polynomials from 12-bit packed bytes.
 *
 * Every 3 input bytes give 2 coefficients of 12 bits each. After each
 * polynomial the output advances by MLKEM_N coefficients and the input by
 * WC_ML_KEM_POLY_SIZE bytes.
 *
 * @param  [out]  p  Vector of k polynomials.
 * @param  [in]   b  Encoded bytes.
 * @param  [in]   k  Number of polynomials in the vector.
 */
static void mlkem_from_bytes_c(sword16* p, const byte* b, int k)
{
    int poly;

    for (poly = 0; poly < k; poly++) {
        const byte* in = b;
        sword16* out = p;
        int pair;

        for (pair = 0; pair < MLKEM_N / 2; pair++, in += 3, out += 2) {
            /* First value: all of byte 0 and the low nibble of byte 1. */
            out[0] = (sword16)(((in[0] >> 0) | ((word16)in[1] << 8)) & 0xfff);
            /* Second value: the high nibble of byte 1 and all of byte 2. */
            out[1] = (sword16)(((in[1] >> 4) | ((word16)in[2] << 4)) & 0xfff);
        }

        p += MLKEM_N;
        b += WC_ML_KEM_POLY_SIZE;
    }
}
/* Decode a vector of polynomials from bytes.
 *
 * When USE_INTEL_SPEEDUP is defined, the CPU flags report AVX2 and
 * SAVE_VECTOR_REGISTERS2() returns 0, the AVX2 routine is called once per
 * polynomial; otherwise the C implementation handles the whole vector.
 *
 * @param  [out]  p  Vector of k polynomials.
 * @param  [in]   b  Encoded bytes.
 * @param  [in]   k  Number of polynomials in the vector.
 */
void mlkem_from_bytes(sword16* p, const byte* b, int k)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        int left;

        /* One call per polynomial, stepping both buffers along. */
        for (left = k; left > 0; left--) {
            mlkem_from_bytes_avx2(p, b);
            p += MLKEM_N;
            b += WC_ML_KEM_POLY_SIZE;
        }
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Portable fallback. */
    mlkem_from_bytes_c(p, b, k);
}
/* Encode a vector of polynomials as 12-bit packed bytes.
 *
 * Each polynomial is first normalized in place with mlkem_csubq_c(); then
 * every 2 coefficients are written as 3 bytes. After each polynomial the
 * input advances by MLKEM_N coefficients and the output by
 * WC_ML_KEM_POLY_SIZE bytes.
 *
 * @param  [out]      b  Encoded bytes.
 * @param  [in, out]  p  Vector of k polynomials; modified by mlkem_csubq_c().
 * @param  [in]       k  Number of polynomials in the vector.
 */
static void mlkem_to_bytes_c(byte* b, sword16* p, int k)
{
    int poly;

    for (poly = 0; poly < k; poly++) {
        const sword16* in = p;
        byte* out = b;
        int pair;

        mlkem_csubq_c(p);

        for (pair = 0; pair < MLKEM_N / 2; pair++, in += 2, out += 3) {
            const word16 lo = (word16)in[0];
            const word16 hi = (word16)in[1];

            /* Low 8 bits of the first value. */
            out[0] = (byte)lo;
            /* Remaining bits of the first value, low nibble of the second. */
            out[1] = (byte)((lo >> 8) | (hi << 4));
            /* Remaining bits of the second value. */
            out[2] = (byte)(hi >> 4);
        }

        p += MLKEM_N;
        b += WC_ML_KEM_POLY_SIZE;
    }
}
/* Encode a vector of polynomials as bytes.
 *
 * When USE_INTEL_SPEEDUP is defined, the CPU flags report AVX2 and
 * SAVE_VECTOR_REGISTERS2() returns 0, the AVX2 routine is called once per
 * polynomial; otherwise the C implementation handles the whole vector.
 *
 * @param  [out]      b  Encoded bytes.
 * @param  [in, out]  p  Vector of k polynomials.
 * @param  [in]       k  Number of polynomials in the vector.
 */
void mlkem_to_bytes(byte* b, sword16* p, int k)
{
#ifdef USE_INTEL_SPEEDUP
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        int left;

        /* One call per polynomial, stepping both buffers along. */
        for (left = k; left > 0; left--) {
            mlkem_to_bytes_avx2(b, p);
            p += MLKEM_N;
            b += WC_ML_KEM_POLY_SIZE;
        }
        RESTORE_VECTOR_REGISTERS();
        return;
    }
#endif
    /* Portable fallback. */
    mlkem_to_bytes_c(b, p, k);
}
/* Check that every coefficient of a decoded public key is below MLKEM_Q.
 *
 * Scanning stops at the first coefficient that is out of range.
 *
 * @param  [in]  pub  Public key as a vector of k polynomials.
 * @param  [in]  k    Number of polynomials in the vector.
 * @return  0 when all k * MLKEM_N coefficients are less than MLKEM_Q.
 * @return  PUBLIC_KEY_E when any coefficient is MLKEM_Q or greater.
 */
int mlkem_check_public(sword16* pub, int k)
{
    int idx;

    for (idx = 0; idx < k * MLKEM_N; idx++) {
        if (pub[idx] >= MLKEM_Q) {
            /* Out-of-range coefficient: reject the key. */
            return PUBLIC_KEY_E;
        }
    }

    return 0;
}
#endif