#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
#include <wolfssl/wolfcrypt/wc_slhdsa.h>
#ifdef WOLFSSL_HAVE_SLHDSA
#include <wolfssl/wolfcrypt/cpuid.h>
#include <wolfssl/wolfcrypt/error-crypt.h>
#ifdef NO_INLINE
#include <wolfssl/wolfcrypt/misc.h>
#else
#define WOLFSSL_MISC_INCLUDED
#include <wolfcrypt/src/misc.c>
#endif
#include <wolfssl/wolfcrypt/hash.h>
#include <wolfssl/wolfcrypt/sha3.h>
/* CPU feature flags consulted by the single-block SHAKE fast paths to choose
 * between the AVX2, BMI2 and portable Keccak permutations. Starts as
 * WC_CPUID_INITIALIZER.
 * NOTE(review): where the flags are actually populated is not visible in this
 * section - presumably during key initialization; confirm. */
#if defined(USE_INTEL_SPEEDUP)
static cpuid_flags_t cpuid_flags = WC_CPUID_INITIALIZER;
#endif
/* Winternitz parameter w (lg_w = 4), shared by all SLH-DSA parameter sets. */
#define SLHDSA_W 16
/* w - 1: largest base-w digit and the number of steps in a full WOTS+ chain. */
#define SLHDSA_WM1 (SLHDSA_W - 1)
/* Buffer-sizing maxima over the compiled-in parameter sets:
 *   SLHDSA_MAX_N          - largest n (security parameter, in bytes).
 *   SLHDSA_MAX_INDICES_SZ - largest k (number of FORS trees / indices).
 * The values track the n and k columns of SlhDsaParams below. */
#ifndef WOLFSSL_SLHDSA_PARAM_NO_256
#define SLHDSA_MAX_N 32
#ifndef WOLFSSL_SLHDSA_PARAM_NO_FAST
#define SLHDSA_MAX_INDICES_SZ 35
#else
#define SLHDSA_MAX_INDICES_SZ 22
#endif
#elif !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
#define SLHDSA_MAX_N 24
#ifndef WOLFSSL_SLHDSA_PARAM_NO_FAST
#define SLHDSA_MAX_INDICES_SZ 33
#else
#define SLHDSA_MAX_INDICES_SZ 17
#endif
#else
#define SLHDSA_MAX_N 16
#ifndef WOLFSSL_SLHDSA_PARAM_NO_FAST
#define SLHDSA_MAX_INDICES_SZ 33
#else
#define SLHDSA_MAX_INDICES_SZ 14
#endif
#endif
/* SLHDSA_MAX_A - largest a (FORS tree height) over the compiled-in parameter
 * sets; tracks the a column of SlhDsaParams below. */
#ifndef WOLFSSL_SLHDSA_PARAM_NO_SMALL
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
#define SLHDSA_MAX_A 14
#elif !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
#define SLHDSA_MAX_A 14
#else
#define SLHDSA_MAX_A 12
#endif
#else
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
#define SLHDSA_MAX_A 9
#elif !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
#define SLHDSA_MAX_A 8
#else
#define SLHDSA_MAX_A 6
#endif
#endif
/* SLHDSA_MAX_H_M - largest h' = h / d (height of one hypertree layer) over the
 * compiled-in parameter sets; used as a buffer-sizing bound.
 *   small sets: 128s/192s have h' = 9, 256s has h' = 8  -> 9
 *   fast sets:  128f/192f have h' = 3, 256f has h' = 4  -> 4 when the 256-bit
 *               sets are built, otherwise 3.
 * A fast-only build including SHAKE-256f must not use 3 here, or buffers sized
 * from this bound are one level too small. */
#ifndef WOLFSSL_SLHDSA_PARAM_NO_SMALL
#define SLHDSA_MAX_H_M 9
#elif !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
#define SLHDSA_MAX_H_M 4
#else
#define SLHDSA_MAX_H_M 3
#endif
/* Largest WOTS+ len: len = 2n + 3 for lg_w = 4 (the same expression used for
 * the len field in SLHDSA_PARAMETERS). */
#define SLHDSA_MAX_MSG_SZ ((2 * SLHDSA_MAX_N) + 3)
/* Largest message digest size m (bytes) over the compiled-in parameter sets:
 *   m = ceil(k * a / 8) + ceil((h - h / d) / 8) + ceil(h / (8 * d))
 * i.e. the sum of the three derived byte counts in SLHDSA_PARAMETERS.
 * The values are ordered 256f > 256s > 192f > 192s > 128f > 128s, so the first
 * enabled set in this chain gives the maximum. */
#ifndef WOLFSSL_SLHDSA_PARAM_NO_256F
#define SLHDSA_MAX_MD 49
#elif !defined(WOLFSSL_SLHDSA_PARAM_NO_256S)
#define SLHDSA_MAX_MD 47
#elif !defined(WOLFSSL_SLHDSA_PARAM_NO_192F)
#define SLHDSA_MAX_MD 42
#elif !defined(WOLFSSL_SLHDSA_PARAM_NO_192S)
#define SLHDSA_MAX_MD 39
#elif !defined(WOLFSSL_SLHDSA_PARAM_NO_128F)
#define SLHDSA_MAX_MD 34
#else
#define SLHDSA_MAX_MD 30
#endif
/* Hash address (ADRS) type values, stored in word 4 of a HashAddress. */
#define HA_WOTS_HASH 0
#define HA_WOTS_PK 1
#define HA_TREE 2
#define HA_FORS_TREE 3
#define HA_FORS_ROOTS 4
#define HA_WOTS_PRF 5
#define HA_FORS_PRF 6
/* Size in bytes of an encoded hash address (8 words x 4 bytes). */
#define SLHDSA_HA_SZ 32
/* Accessors for a HashAddress (8 x word32). Word usage, as set by the macros
 * below:
 *   [0] layer address      [1..3] tree address   [4] type
 *   [5] key pair address   [6] chain address / tree height
 *   [7] hash address / tree index */
#define HA_Init(a) XMEMSET(a, 0, sizeof(HashAddress))
#define HA_Copy(a, b) XMEMCPY(a, b, sizeof(HashAddress))
#define HA_SetLayerAddress(a, l) (a)[0] = (l)
/* Copies three words of tree address from t into words 1..3. */
#define HA_SetTreeAddress(a, t) \
do { (a)[1] = (t)[0]; (a)[2] = (t)[1]; (a)[3] = (t)[2]; } while (0)
/* Sets the type and clears the type-specific words 5..7. */
#define HA_SetTypeAndClear(a, y) \
do { (a)[4] = y; (a)[5] = 0; (a)[6] = 0; (a)[7] = 0; } while (0)
/* As HA_SetTypeAndClear but keeps the key pair address (word 5). */
#define HA_SetTypeAndClearNotKPA(a, y) \
do { (a)[4] = y; (a)[6] = 0; (a)[7] = 0; } while (0)
#define HA_SetKeyPairAddress(a, i) (a)[5] = (i)
#define HA_SetChainAddress(a, i) (a)[6] = (i)
#define HA_SetTreeHeight(a, i) (a)[6] = (i)
/* Writes i with c32toa at byte offset 6 * 4.
 * NOTE(review): a appears to be an encoded (byte) address here rather than a
 * HashAddress - confirm at call sites. */
#define HA_SetTreeHeightBE(a, i) c32toa(i, (a) + (6 * 4))
#define HA_SetHashAddress(a, i) (a)[7] = (i)
#define HA_SetTreeIndex(a, i) (a)[7] = (i)
#define HA_CopyKeyPairAddress(a, b) (a)[5] = (b)[5]
/* In-memory hash address: eight 32-bit words; HA_Encode() produces the
 * SLHDSA_HA_SZ-byte form that is fed to the hash. */
typedef word32 HashAddress[8];
/* Serialize a hash address into its SLHDSA_HA_SZ-byte form: each of the eight
 * 32-bit words is written with c32toa (big-endian), in order, 4 bytes apart.
 *
 * @param [in]  adrs     Hash address words.
 * @param [out] address  Buffer of at least SLHDSA_HA_SZ bytes.
 */
static void HA_Encode(const word32* adrs, byte* address)
{
#ifdef WOLFSSL_WC_SLHDSA_SMALL
    int w;

    for (w = 0; w < 8; w++) {
        c32toa(adrs[w], address + (w * 4));
    }
#else
    /* Fully unrolled for speed. */
    c32toa(adrs[0], address +  0);
    c32toa(adrs[1], address +  4);
    c32toa(adrs[2], address +  8);
    c32toa(adrs[3], address + 12);
    c32toa(adrs[4], address + 16);
    c32toa(adrs[5], address + 20);
    c32toa(adrs[6], address + 24);
    c32toa(adrs[7], address + 28);
#endif
}
/* Tree index helpers. The tree index appears to be kept as a 64-bit value
 * split across t[1] (high word) and t[2] (low word).
 * INDEX_TREE_MASK:       low bits of the tree index selected by mask.
 * INDEX_TREE_SHIFT_DOWN: shift the 64-bit tree index right by b bits. t[2] is
 *                        computed first so it still sees the old t[1].
 * NOTE(review): b must be in 1..31 - the (32 - (b)) shift is invalid for
 * b == 0. The expansion is two statements (not wrapped in do/while), so it
 * must not be used as the unbraced body of an if/else/loop. */
#define INDEX_TREE_MASK(t, mask) ((t)[2] & (mask))
#define INDEX_TREE_SHIFT_DOWN(t, b) \
(t)[2] = ((t)[1] << (32 - (b))) | ((t)[2] >> (b)); \
(t)[1] = (t)[1] >> (b);
/* Build a SlhDsaParameters initializer from the primary parameters
 * (parameter set id p, n, h, d, h' = h_m, a, k). The remaining fields are
 * derived:
 *   2n + 3                               - WOTS+ len (lg_w = 4)
 *   ceil(k * a / 8)                      - digest bytes for the FORS indices
 *   ceil((h - h / d) / 8)                - digest bytes for the tree index
 *   ceil(h / (8 * d))                    - digest bytes for the leaf index
 *   (1 + k * (1 + a) + d * (h' + len)) * n - signature size in bytes
 * NOTE(review): field names and order come from SlhDsaParameters in
 * wc_slhdsa.h - confirm there when changing this list. */
#define SLHDSA_PARAMETERS(p, n, h, d, h_m, a, k) \
{ (p), (n), (h), (d), (h_m), (a), (k), \
2 * (n) + 3, \
(((k) * (a)) + 7) / 8, \
(((h) - ((h) / (d))) + 7) / 8, \
((h) + ((8 * (d)) - 1)) / (8 * (d)), \
(1 + (k) * (1 + (a)) + (d) * ((h_m) + 2*(n) + 3)) * (n) }
/* SHAKE parameter sets (n, h, d, h', a, k as in FIPS 205 Table 2), filtered by
 * the WOLFSSL_SLHDSA_PARAM_NO_* build options. */
static const SlhDsaParameters SlhDsaParams[] =
{
#ifndef WOLFSSL_SLHDSA_PARAM_NO_128S
SLHDSA_PARAMETERS(SLHDSA_SHAKE128S, 16, 63, 7, 9, 12, 14),
#endif
#ifndef WOLFSSL_SLHDSA_PARAM_NO_128F
SLHDSA_PARAMETERS(SLHDSA_SHAKE128F, 16, 66, 22, 3, 6, 33),
#endif
#ifndef WOLFSSL_SLHDSA_PARAM_NO_192S
SLHDSA_PARAMETERS(SLHDSA_SHAKE192S, 24, 63, 7, 9, 14, 17),
#endif
#ifndef WOLFSSL_SLHDSA_PARAM_NO_192F
SLHDSA_PARAMETERS(SLHDSA_SHAKE192F, 24, 66, 22, 3, 8, 33),
#endif
#ifndef WOLFSSL_SLHDSA_PARAM_NO_256S
SLHDSA_PARAMETERS(SLHDSA_SHAKE256S, 32, 64, 8, 8, 14, 22),
#endif
#ifndef WOLFSSL_SLHDSA_PARAM_NO_256F
SLHDSA_PARAMETERS(SLHDSA_SHAKE256F, 32, 68, 17, 4, 9, 35),
#endif
};
/* Number of compiled-in parameter sets in SlhDsaParams. */
#define SLHDSA_PARAM_LEN \
((int)(sizeof(SlhDsaParams) / sizeof(SlhDsaParameters)))
#ifndef WOLFSSL_WC_SLHDSA_SMALL
/* One-shot SHAKE-256 of data1 || ADRS || data2, where ADRS is the
 * SLHDSA_HA_SZ-byte encoding of adrs (HA_Encode).
 *
 * Outside WOLFSSL_SLHDSA_FULL_HASH the whole input is handled as one block:
 * callers (HASH_PRF, HASH_F, HASH_H) pass at most n + SLHDSA_HA_SZ + 2n bytes,
 * which must fit in the SHAKE-256 rate (WC_SHA3_256_COUNT * 8 bytes).
 * NOTE(review): no length check is performed here.
 *
 * @param [in]  shake      SHAKE-256 object.
 * @param [in]  data1      First input (PK.seed).
 * @param [in]  data1_len  Length of data1 in bytes.
 * @param [in]  adrs       Hash address, encoded after data1.
 * @param [in]  data2      Input after the address.
 * @param [in]  data2_len  Length of data2 in bytes.
 * @param [out] hash       Output buffer.
 * @param [in]  hash_len   Number of output bytes. The USE_INTEL_SPEEDUP path
 *                         copies from a single permuted state.
 * @return  0 on success, otherwise a wc_Shake256_* error.
 */
static int slhdsakey_hash_shake_3(wc_Shake* shake, const byte* data1,
byte data1_len, const word32* adrs, const byte* data2, byte data2_len,
byte* hash, byte hash_len)
{
#ifdef WOLFSSL_SLHDSA_FULL_HASH
int ret;
byte address[SLHDSA_HA_SZ];
HA_Encode(adrs, address);
ret = wc_Shake256_Update(shake, data1, data1_len);
if (ret == 0) {
ret = wc_Shake256_Update(shake, address, SLHDSA_HA_SZ);
}
if (ret == 0) {
ret = wc_Shake256_Update(shake, data2, data2_len);
}
if (ret == 0) {
ret = wc_Shake256_Final(shake, hash, hash_len);
}
return ret;
#elif defined(USE_INTEL_SPEEDUP)
word64* state = shake->s;
word8* state8 = (word8*)shake->s;
word32 o = 0;
/* Build the padded block directly in the Keccak state. Overwriting rather
 * than XORing is equivalent to absorbing into an all-zero state because every
 * byte of shake->s is rewritten below. */
XMEMCPY(state8 + o, data1, data1_len);
o += data1_len;
HA_Encode(adrs, state8 + o);
o += SLHDSA_HA_SZ;
XMEMCPY(state8 + o, data2, data2_len);
o += data2_len;
/* SHAKE domain-separation bits plus first padding bit. */
state8[o] = 0x1f;
o += 1;
XMEMSET(state8 + o, 0, sizeof(shake->s) - o);
/* Final padding bit at the last byte of the rate. */
state8[WC_SHA3_256_COUNT * 8 - 1] ^= 0x80;
#ifndef WC_SHA3_NO_ASM
if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
sha3_block_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
else if (IS_INTEL_BMI2(cpuid_flags)) {
sha3_block_bmi2(state);
}
else
#endif
{
BlockSha3(state);
}
/* Squeeze from the permuted state. shake->s is left non-zero. */
XMEMCPY(hash, shake->s, hash_len);
return 0;
#else
/* Put the input straight into the SHAKE input buffer and let
 * wc_Shake256_Final() pad, permute and squeeze.
 * NOTE(review): assumes shake->s is all-zero here (fresh or re-initialized
 * object) - confirm wc_Shake256_Final() re-initializes the state. */
XMEMCPY(shake->t, data1, data1_len);
HA_Encode(adrs, shake->t + data1_len);
XMEMCPY(shake->t + data1_len + SLHDSA_HA_SZ, data2, data2_len);
shake->i = data1_len + SLHDSA_HA_SZ + data2_len;
return wc_Shake256_Final(shake, hash, hash_len);
#endif
}
#endif
/* One-shot SHAKE-256 of data1 || ADRS || data2 || data3, where ADRS is the
 * SLHDSA_HA_SZ-byte encoding of adrs (HA_Encode).
 *
 * Used by HASH_H_2 (H over two separate n-byte halves) and, in the
 * WOLFSSL_WC_SLHDSA_SMALL build, by HASH_PRF/HASH_F/HASH_H with data3 == NULL
 * and data3_len == 0. memcpy() with a NULL source is undefined behaviour even
 * for a zero length, so data3 is only touched when data3_len is non-zero.
 *
 * Outside WOLFSSL_SLHDSA_FULL_HASH the whole input is handled as one SHAKE-256
 * block (rate WC_SHA3_256_COUNT * 8 bytes); callers pass at most
 * n + SLHDSA_HA_SZ + 2n bytes. NOTE(review): not length-checked here.
 *
 * @param [in]  shake      SHAKE-256 object.
 * @param [in]  data1      First input (PK.seed).
 * @param [in]  data1_len  Length of data1 in bytes.
 * @param [in]  adrs       Hash address, encoded after data1.
 * @param [in]  data2      Second input.
 * @param [in]  data2_len  Length of data2 in bytes.
 * @param [in]  data3      Third input. May be NULL when data3_len is 0.
 * @param [in]  data3_len  Length of data3 in bytes.
 * @param [out] hash       Output buffer.
 * @param [in]  hash_len   Number of output bytes.
 * @return  0 on success, otherwise a wc_Shake256_* error.
 */
static int slhdsakey_hash_shake_4(wc_Shake* shake, const byte* data1,
    byte data1_len, const word32* adrs, const byte* data2, byte data2_len,
    const byte* data3, byte data3_len, byte* hash, byte hash_len)
{
#ifdef WOLFSSL_SLHDSA_FULL_HASH
    int ret;
    byte address[SLHDSA_HA_SZ];

    HA_Encode(adrs, address);
    ret = wc_Shake256_Update(shake, data1, data1_len);
    if (ret == 0) {
        ret = wc_Shake256_Update(shake, address, SLHDSA_HA_SZ);
    }
    if (ret == 0) {
        ret = wc_Shake256_Update(shake, data2, data2_len);
    }
    /* Skip the empty third input rather than passing a NULL pointer. */
    if ((ret == 0) && (data3_len > 0)) {
        ret = wc_Shake256_Update(shake, data3, data3_len);
    }
    if (ret == 0) {
        ret = wc_Shake256_Final(shake, hash, hash_len);
    }
    return ret;
#elif defined(USE_INTEL_SPEEDUP)
    word64* state = shake->s;
    word8* state8 = (word8*)shake->s;
    word32 o = 0;

    /* Build the padded block directly in the Keccak state. Every byte of
     * shake->s is rewritten, which equals absorbing into an all-zero state. */
    XMEMCPY(state8 + o, data1, data1_len);
    o += data1_len;
    HA_Encode(adrs, state8 + o);
    o += SLHDSA_HA_SZ;
    XMEMCPY(state8 + o, data2, data2_len);
    o += data2_len;
    if (data3_len > 0) {
        XMEMCPY(state8 + o, data3, data3_len);
        o += data3_len;
    }
    /* SHAKE domain-separation bits plus first padding bit. */
    state8[o] = 0x1f;
    o += 1;
    XMEMSET(state8 + o, 0, sizeof(shake->s) - o);
    /* Final padding bit at the last byte of the rate. */
    state8[WC_SHA3_256_COUNT * 8 - 1] ^= 0x80;
#ifndef WC_SHA3_NO_ASM
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        sha3_block_avx2(state);
        RESTORE_VECTOR_REGISTERS();
    }
    else if (IS_INTEL_BMI2(cpuid_flags)) {
        sha3_block_bmi2(state);
    }
    else
#endif
    {
        BlockSha3(state);
    }
    /* Squeeze from the permuted state. shake->s is left non-zero. */
    XMEMCPY(hash, shake->s, hash_len);
    return 0;
#else
    /* Fill the SHAKE input buffer; wc_Shake256_Final() pads and squeezes.
     * NOTE(review): assumes shake->s is all-zero on entry. */
    XMEMCPY(shake->t, data1, data1_len);
    HA_Encode(adrs, shake->t + data1_len);
    XMEMCPY(shake->t + data1_len + SLHDSA_HA_SZ, data2, data2_len);
    if (data3_len > 0) {
        XMEMCPY(shake->t + data1_len + SLHDSA_HA_SZ + data2_len, data3,
            data3_len);
    }
    shake->i = data1_len + SLHDSA_HA_SZ + data2_len + data3_len;
    return wc_Shake256_Final(shake, hash, hash_len);
#endif
}
/* Tweakable hash functions for the SHAKE parameter sets, each producing n
 * bytes:
 *   HASH_PRF: SHAKE256(PK.seed || ADRS || SK.seed)
 *   HASH_F:   SHAKE256(PK.seed || ADRS || m)        with |m| = n
 *   HASH_H:   SHAKE256(PK.seed || ADRS || node)     with |node| = 2n
 * The WOLFSSL_WC_SLHDSA_SMALL build routes everything through
 * slhdsakey_hash_shake_4 with an empty third input, so
 * slhdsakey_hash_shake_3 is not compiled there. Every argument used inside an
 * arithmetic expression is parenthesized. */
#ifndef WOLFSSL_WC_SLHDSA_SMALL
#define HASH_PRF(shake, pk_seed, sk_seed, adrs, n, hash) \
    slhdsakey_hash_shake_3(shake, pk_seed, n, adrs, sk_seed, n, hash, n)
#define HASH_F(shake, pk_seed, adrs, m, n, hash) \
    slhdsakey_hash_shake_3(shake, pk_seed, n, adrs, m, n, hash, n)
#define HASH_H(shake, pk_seed, adrs, node, n, hash) \
    slhdsakey_hash_shake_3(shake, pk_seed, n, adrs, node, 2 * (n), hash, (n))
#else
#define HASH_PRF(shake, pk_seed, sk_seed, adrs, n, hash) \
    slhdsakey_hash_shake_4(shake, pk_seed, n, adrs, sk_seed, n, NULL, 0, \
        hash, n)
#define HASH_F(shake, pk_seed, adrs, m, n, hash) \
    slhdsakey_hash_shake_4(shake, pk_seed, n, adrs, m, n, NULL, 0, hash, n)
#define HASH_H(shake, pk_seed, adrs, node, n, hash) \
    slhdsakey_hash_shake_4(shake, pk_seed, n, adrs, node, 2 * (n), NULL, 0, \
        hash, n)
#endif
/* H over a 2n-byte input supplied as two separate n-byte halves m1 || m2, so
 * the caller does not need to concatenate them first. */
#define HASH_H_2(shake, pk_seed, adrs, m1, m2, n, hash) \
    slhdsakey_hash_shake_4(shake, pk_seed, n, adrs, m1, n, m2, n, hash, n)
/* Begin an incremental SHAKE-256 computation with data as the first input.
 * More input is added with slhdsakey_hash_update() and the output taken with
 * slhdsakey_hash_final().
 *
 * @param [in] shake  SHAKE-256 object.
 * @param [in] data   Initial input.
 * @param [in] len    Length of data in bytes.
 * @return  0 on success, otherwise a wc_Shake256_Update() error.
 */
static int slhdsakey_hash_start(wc_Shake* shake, const byte* data, byte len)
{
    int ret = 0;

#if defined(USE_INTEL_SPEEDUP)
    /* The single-block fast paths may leave shake->s holding a permuted
     * state; start the generic absorb from zero. */
    XMEMSET(shake->s, 0, sizeof(shake->s));
#endif
#ifndef WOLFSSL_SLHDSA_FULL_HASH
    /* Short input: place it directly in the SHAKE input buffer. */
    XMEMCPY(shake->t, data, len);
    shake->i = len;
#else
    ret = wc_Shake256_Update(shake, data, len);
#endif
    return ret;
}
/* Begin an incremental SHAKE-256 computation with PK.seed || ADRS, where ADRS
 * is the SLHDSA_HA_SZ-byte encoding of adrs.
 *
 * @param [in] shake    SHAKE-256 object.
 * @param [in] pk_seed  Public seed (n bytes).
 * @param [in] adrs     Hash address.
 * @param [in] n        Length of pk_seed in bytes.
 * @return  0 on success, otherwise a wc_Shake256_Update() error.
 */
static int slhdsakey_hash_start_addr(wc_Shake* shake, const byte* pk_seed,
    const word32* adrs, byte n)
{
#ifdef WOLFSSL_SLHDSA_FULL_HASH
    int ret;
    byte encAdrs[SLHDSA_HA_SZ];
#endif

#if defined(USE_INTEL_SPEEDUP)
    /* Discard any state left by the single-block fast paths. */
    XMEMSET(shake->s, 0, sizeof(shake->s));
#endif

#ifdef WOLFSSL_SLHDSA_FULL_HASH
    HA_Encode(adrs, encAdrs);
    ret = wc_Shake256_Update(shake, pk_seed, n);
    if (ret == 0) {
        ret = wc_Shake256_Update(shake, encAdrs, SLHDSA_HA_SZ);
    }
    return ret;
#else
    /* Both pieces fit in the input buffer: write them in place. */
    XMEMCPY(shake->t, pk_seed, n);
    HA_Encode(adrs, shake->t + n);
    shake->i = n + SLHDSA_HA_SZ;
    return 0;
#endif
}
/* Absorb len more bytes of data into an in-progress SHAKE-256 computation.
 *
 * @return  Result of wc_Shake256_Update(): 0 on success.
 */
static int slhdsakey_hash_update(wc_Shake* shake, const byte* data, word32 len)
{
    int err = wc_Shake256_Update(shake, data, len);

    return err;
}
/* Finish a SHAKE-256 computation, writing len bytes of output to hash.
 *
 * @return  Result of wc_Shake256_Final(): 0 on success.
 */
static int slhdsakey_hash_final(wc_Shake* shake, byte* hash, word32 len)
{
    int err = wc_Shake256_Final(shake, hash, len);

    return err;
}
/* Split a byte string into outLen integers of b bits each, most significant
 * bits first (FIPS 205 Algorithm 4, base_2b).
 *
 * @param [in]  x       Input bytes; at least ceil(outLen * b / 8) are read.
 * @param [in]  b       Bits per output value. The output mask is a word16, so
 *                      b is expected to be at most 16.
 * @param [in]  outLen  Number of output values.
 * @param [out] baseb   Output values, each in [0, 2^b).
 */
static void slhdsakey_base_2b(const byte* x, byte b, byte outLen, word16* baseb)
{
    int j;
    int i = 0;
    int bits = 0;
    /* Unsigned accumulator. Only the low (bits + b) < b + 8 bits are ever
     * read, and unsigned left shifts wrap modulo 2^32 without disturbing them.
     * A signed int would overflow after a few bytes, which is undefined
     * behaviour. */
    word32 total = 0;
    word16 mask = (word16)((1U << b) - 1);

    for (j = 0; j < outLen; j++) {
        /* Pull in whole bytes until at least b unused bits are available. */
        while (bits < b) {
            total = (total << 8) + x[i++];
            bits += 8;
        }
        bits -= b;
        baseb[j] = (word16)((total >> bits) & mask);
    }
}
/* Apply the WOTS+ chaining function F s times (FIPS 205 chain()).
 *
 * node = F applied s times to x, using hash addresses i, i+1, ..., i+s-1.
 *
 * @param [in]      key      SLH-DSA key (supplies n and the SHAKE object).
 * @param [in]      x        Chain input (n bytes). May be the same as node.
 * @param [in]      i        First hash address.
 * @param [in]      s        Number of steps.
 * @param [in]      pk_seed  Public seed.
 * @param [in, out] adrs     Hash address; its hash-address word is updated.
 * @param [out]     node     Result (n bytes).
 * @return  0 on success, otherwise the hash error.
 */
static int slhdsakey_chain(SlhDsaKey* key, const byte* x, byte i, byte s,
    const byte* pk_seed, word32* adrs, byte* node)
{
    int ret = 0;
    int step;
    const byte* in = x;
    byte n = key->params->n;

    if (s == 0) {
        /* No steps: the output is the input. */
        if (x != node) {
            XMEMCPY(node, x, n);
        }
        return 0;
    }

    /* Each step's output is the next step's input. */
    for (step = i; (ret == 0) && (step < i + s); step++) {
        HA_SetHashAddress(adrs, step);
        ret = HASH_F(&key->shake, pk_seed, adrs, in, n, node);
        in = node;
    }
    return ret;
}
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
/* Interleaved x4 SHAKE-256 state helpers for n = 16.
 * Word w of lane l lives at state[4 * w + l]. Per-lane input layout (words):
 * PK.seed [0..1], ADRS [2..5], hash input [6..7]; padding starts at word 8
 * (state index 32).
 *   SET_SEED_HA: broadcast PK.seed and the encoded address to all four lanes.
 *   SET_HASH:    load four different 16-byte inputs (lane l from hash + l*16).
 *   GET_HASH:    store the first 16 bytes of each lane back to hash + l*16.
 * NOTE(review): SET_HASH/GET_HASH access hash through word64 pointers, so hash
 * is assumed to be suitably aligned. */
#ifndef WOLFSSL_SLHDSA_PARAM_NO_128
#define SHAKE256_SET_SEED_HA_X4_16(state, seed, addr) \
do { \
\
(state)[0] = (state)[1] = (state)[2] = (state)[3] = \
readUnalignedWord64((seed) + (0 * sizeof(word64))); \
(state)[4] = (state)[5] = (state)[6] = (state)[7] = \
readUnalignedWord64((seed) + (1 * sizeof(word64))); \
\
(state)[ 8] = (state)[ 9] = (state)[10] = (state)[11] = \
readUnalignedWord64((addr) + (0 * sizeof(word64))); \
(state)[12] = (state)[13] = (state)[14] = (state)[15] = \
readUnalignedWord64((addr) + (1 * sizeof(word64))); \
(state)[16] = (state)[17] = (state)[18] = (state)[19] = \
readUnalignedWord64((addr) + (2 * sizeof(word64))); \
(state)[20] = (state)[21] = (state)[22] = (state)[23] = \
readUnalignedWord64((addr) + (3 * sizeof(word64))); \
} while (0)
#define SHAKE256_SET_HASH_X4_16(state, hash) \
do { \
(state)[24] = ((word64*)((hash) + 0 * 16))[0]; \
(state)[25] = ((word64*)((hash) + 1 * 16))[0]; \
(state)[26] = ((word64*)((hash) + 2 * 16))[0]; \
(state)[27] = ((word64*)((hash) + 3 * 16))[0]; \
(state)[28] = ((word64*)((hash) + 0 * 16))[1]; \
(state)[29] = ((word64*)((hash) + 1 * 16))[1]; \
(state)[30] = ((word64*)((hash) + 2 * 16))[1]; \
(state)[31] = ((word64*)((hash) + 3 * 16))[1]; \
} while (0)
#define SHAKE256_GET_HASH_X4_16(state, hash) \
do { \
((word64*)((hash) + 0 * 16))[0] = (state)[0]; \
((word64*)((hash) + 1 * 16))[0] = (state)[1]; \
((word64*)((hash) + 2 * 16))[0] = (state)[2]; \
((word64*)((hash) + 3 * 16))[0] = (state)[3]; \
((word64*)((hash) + 0 * 16))[1] = (state)[4]; \
((word64*)((hash) + 1 * 16))[1] = (state)[5]; \
((word64*)((hash) + 2 * 16))[1] = (state)[6]; \
((word64*)((hash) + 3 * 16))[1] = (state)[7]; \
} while (0)
#endif
/* Interleaved x4 SHAKE-256 state helpers for n = 24.
 * Word w of lane l lives at state[4 * w + l]. Per-lane input layout (words):
 * PK.seed [0..2], ADRS [3..6], hash input [7..9]; padding starts at word 10
 * (state index 40). SET_HASH/GET_HASH move lane l to/from hash + l * 24.
 * NOTE(review): hash is accessed through word64 pointers; alignment assumed. */
#ifndef WOLFSSL_SLHDSA_PARAM_NO_192
#define SHAKE256_SET_SEED_HA_X4_24(state, seed, addr) \
do { \
(state)[0] = (state)[1] = (state)[ 2] = (state)[ 3] = \
readUnalignedWord64((seed) + (0 * sizeof(word64))); \
(state)[4] = (state)[5] = (state)[ 6] = (state)[ 7] = \
readUnalignedWord64((seed) + (1 * sizeof(word64))); \
(state)[8] = (state)[9] = (state)[10] = (state)[11] = \
readUnalignedWord64((seed) + (2 * sizeof(word64))); \
\
(state)[12] = (state)[13] = (state)[14] = (state)[15] = \
readUnalignedWord64((addr) + (0 * sizeof(word64))); \
(state)[16] = (state)[17] = (state)[18] = (state)[19] = \
readUnalignedWord64((addr) + (1 * sizeof(word64))); \
(state)[20] = (state)[21] = (state)[22] = (state)[23] = \
readUnalignedWord64((addr) + (2 * sizeof(word64))); \
(state)[24] = (state)[25] = (state)[26] = (state)[27] = \
readUnalignedWord64((addr) + (3 * sizeof(word64))); \
} while (0)
#define SHAKE256_SET_HASH_X4_24(state, hash) \
do { \
(state)[28] = ((word64*)((hash) + 0 * 24))[0]; \
(state)[29] = ((word64*)((hash) + 1 * 24))[0]; \
(state)[30] = ((word64*)((hash) + 2 * 24))[0]; \
(state)[31] = ((word64*)((hash) + 3 * 24))[0]; \
(state)[32] = ((word64*)((hash) + 0 * 24))[1]; \
(state)[33] = ((word64*)((hash) + 1 * 24))[1]; \
(state)[34] = ((word64*)((hash) + 2 * 24))[1]; \
(state)[35] = ((word64*)((hash) + 3 * 24))[1]; \
(state)[36] = ((word64*)((hash) + 0 * 24))[2]; \
(state)[37] = ((word64*)((hash) + 1 * 24))[2]; \
(state)[38] = ((word64*)((hash) + 2 * 24))[2]; \
(state)[39] = ((word64*)((hash) + 3 * 24))[2]; \
} while (0)
#define SHAKE256_GET_HASH_X4_24(state, hash) \
do { \
((word64*)((hash) + 0 * 24))[0] = (state)[ 0]; \
((word64*)((hash) + 1 * 24))[0] = (state)[ 1]; \
((word64*)((hash) + 2 * 24))[0] = (state)[ 2]; \
((word64*)((hash) + 3 * 24))[0] = (state)[ 3]; \
((word64*)((hash) + 0 * 24))[1] = (state)[ 4]; \
((word64*)((hash) + 1 * 24))[1] = (state)[ 5]; \
((word64*)((hash) + 2 * 24))[1] = (state)[ 6]; \
((word64*)((hash) + 3 * 24))[1] = (state)[ 7]; \
((word64*)((hash) + 0 * 24))[2] = (state)[ 8]; \
((word64*)((hash) + 1 * 24))[2] = (state)[ 9]; \
((word64*)((hash) + 2 * 24))[2] = (state)[10]; \
((word64*)((hash) + 3 * 24))[2] = (state)[11]; \
} while (0)
#endif
/* Interleaved x4 SHAKE-256 state helpers for n = 32.
 * Word w of lane l lives at state[4 * w + l]. Per-lane input layout (words):
 * PK.seed [0..3], ADRS [4..7], hash input [8..11]; padding starts at word 12
 * (state index 48). SET_HASH/GET_HASH move lane l to/from hash + l * 32.
 * NOTE(review): hash is accessed through word64 pointers; alignment assumed. */
#ifndef WOLFSSL_SLHDSA_PARAM_NO_256
#define SHAKE256_SET_SEED_HA_X4_32(state, seed, addr) \
do { \
(state)[ 0] = (state)[ 1] = (state)[ 2] = (state)[ 3] = \
readUnalignedWord64((seed) + (0 * sizeof(word64))); \
(state)[ 4] = (state)[ 5] = (state)[ 6] = (state)[ 7] = \
readUnalignedWord64((seed) + (1 * sizeof(word64))); \
(state)[ 8] = (state)[ 9] = (state)[10] = (state)[11] = \
readUnalignedWord64((seed) + (2 * sizeof(word64))); \
(state)[12] = (state)[13] = (state)[14] = (state)[15] = \
readUnalignedWord64((seed) + (3 * sizeof(word64))); \
\
(state)[16] = (state)[17] = (state)[18] = (state)[19] = \
readUnalignedWord64((addr) + (0 * sizeof(word64))); \
(state)[20] = (state)[21] = (state)[22] = (state)[23] = \
readUnalignedWord64((addr) + (1 * sizeof(word64))); \
(state)[24] = (state)[25] = (state)[26] = (state)[27] = \
readUnalignedWord64((addr) + (2 * sizeof(word64))); \
(state)[28] = (state)[29] = (state)[30] = (state)[31] = \
readUnalignedWord64((addr) + (3 * sizeof(word64))); \
} while (0)
#define SHAKE256_SET_HASH_X4_32(state, hash) \
do { \
(state)[32] = ((word64*)((hash) + 0 * 32))[0]; \
(state)[33] = ((word64*)((hash) + 1 * 32))[0]; \
(state)[34] = ((word64*)((hash) + 2 * 32))[0]; \
(state)[35] = ((word64*)((hash) + 3 * 32))[0]; \
(state)[36] = ((word64*)((hash) + 0 * 32))[1]; \
(state)[37] = ((word64*)((hash) + 1 * 32))[1]; \
(state)[38] = ((word64*)((hash) + 2 * 32))[1]; \
(state)[39] = ((word64*)((hash) + 3 * 32))[1]; \
(state)[40] = ((word64*)((hash) + 0 * 32))[2]; \
(state)[41] = ((word64*)((hash) + 1 * 32))[2]; \
(state)[42] = ((word64*)((hash) + 2 * 32))[2]; \
(state)[43] = ((word64*)((hash) + 3 * 32))[2]; \
(state)[44] = ((word64*)((hash) + 0 * 32))[3]; \
(state)[45] = ((word64*)((hash) + 1 * 32))[3]; \
(state)[46] = ((word64*)((hash) + 2 * 32))[3]; \
(state)[47] = ((word64*)((hash) + 3 * 32))[3]; \
} while (0)
#define SHAKE256_GET_HASH_X4_32(state, hash) \
do { \
((word64*)((hash) + 0 * 32))[0] = (state)[ 0]; \
((word64*)((hash) + 1 * 32))[0] = (state)[ 1]; \
((word64*)((hash) + 2 * 32))[0] = (state)[ 2]; \
((word64*)((hash) + 3 * 32))[0] = (state)[ 3]; \
((word64*)((hash) + 0 * 32))[1] = (state)[ 4]; \
((word64*)((hash) + 1 * 32))[1] = (state)[ 5]; \
((word64*)((hash) + 2 * 32))[1] = (state)[ 6]; \
((word64*)((hash) + 3 * 32))[1] = (state)[ 7]; \
((word64*)((hash) + 0 * 32))[2] = (state)[ 8]; \
((word64*)((hash) + 1 * 32))[2] = (state)[ 9]; \
((word64*)((hash) + 2 * 32))[2] = (state)[10]; \
((word64*)((hash) + 3 * 32))[2] = (state)[11]; \
((word64*)((hash) + 0 * 32))[3] = (state)[12]; \
((word64*)((hash) + 1 * 32))[3] = (state)[13]; \
((word64*)((hash) + 2 * 32))[3] = (state)[14]; \
((word64*)((hash) + 3 * 32))[3] = (state)[15]; \
} while (0)
#endif
/* Apply SHAKE-256 padding to all four lanes of an interleaved x4 state.
 * o is the interleaved index of the first word after the input (a multiple of
 * 4). Each lane gets 0x1f in that word; the rest of the 25-word-per-lane state
 * (including the capacity) is zeroed; and 0x80 is XORed into the last byte of
 * word WC_SHA3_256_COUNT - 1, the final word of the rate, in every lane. */
#define SHAKE256_SET_END_X4(state, o) \
do { \
\
(state)[(o) + 0] = (word64)0x1f; \
(state)[(o) + 1] = (word64)0x1f; \
(state)[(o) + 2] = (word64)0x1f; \
(state)[(o) + 3] = (word64)0x1f; \
XMEMSET((state) + (o) + 4, 0, (25 * 4 - ((o) + 4)) * sizeof(word64)); \
\
((word8*)((state) + 4 * WC_SHA3_256_COUNT - 4))[7] ^= 0x80; \
((word8*)((state) + 4 * WC_SHA3_256_COUNT - 3))[7] ^= 0x80; \
((word8*)((state) + 4 * WC_SHA3_256_COUNT - 2))[7] ^= 0x80; \
((word8*)((state) + 4 * WC_SHA3_256_COUNT - 1))[7] ^= 0x80; \
} while (0)
/* Broadcast PK.seed (n bytes) followed by the SLHDSA_HA_SZ-byte encoded
 * address into all four lanes of an interleaved x4 Keccak state, where word w
 * of lane l is state[4 * w + l].
 *
 * @param [out] state  Interleaved x4 state.
 * @param [in]  seed   Public seed (n bytes, a multiple of 8).
 * @param [in]  addr   Encoded hash address.
 * @param [in]  n      Length of seed in bytes.
 * @return  Interleaved index of the first word after the address.
 */
static int slhdsakey_shake256_set_seed_ha_x4(word64* state, const byte* seed,
    const byte* addr, int n)
{
    int off;
    int w = 0;

    for (off = 0; off < n; off += 8) {
        word64 v = readUnalignedWord64(seed + off);
        state[w + 0] = v;
        state[w + 1] = v;
        state[w + 2] = v;
        state[w + 3] = v;
        w += 4;
    }
    for (off = 0; off < SLHDSA_HA_SZ; off += 8) {
        word64 v = readUnalignedWord64(addr + off);
        state[w + 0] = v;
        state[w + 1] = v;
        state[w + 2] = v;
        state[w + 3] = v;
        w += 4;
    }
    return w;
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
/* Load PK.seed || ADRS || hash, each broadcast to all four lanes, into an
 * interleaved x4 Keccak state and apply SHAKE-256 padding.
 *
 * @param [out] state  Interleaved x4 state (25 * 4 words).
 * @param [in]  seed   Public seed (n bytes).
 * @param [in]  addr   Encoded hash address.
 * @param [in]  hash   Third input (n bytes).
 * @param [in]  n      Length of seed and hash in bytes.
 * @return  Interleaved index just past the address: the last address word of
 *          lane l is at (index - 4 + l).
 */
static int slhdsakey_shake256_set_seed_ha_hash_x4(word64* state,
    const byte* seed, const byte* addr, const byte* hash, int n)
{
    int afterAdrs;
    int w;
    int off;

    afterAdrs = slhdsakey_shake256_set_seed_ha_x4(state, seed, addr, n);
    w = afterAdrs;
    for (off = 0; off < n; off += 8) {
        word64 v = readUnalignedWord64(hash + off);
        state[w + 0] = v;
        state[w + 1] = v;
        state[w + 2] = v;
        state[w + 3] = v;
        w += 4;
    }
    SHAKE256_SET_END_X4(state, w);
    return afterAdrs;
}
#endif
/* Copy the first n bytes of output of each lane of an interleaved x4 Keccak
 * state into four consecutive n-byte buffers (lane l at hash + l * n).
 * NOTE(review): writes through word64 pointers, so hash + l * n is assumed to
 * be word64-aligned.
 *
 * @param [in]  state  Interleaved x4 state after the permutation.
 * @param [out] hash   Output, 4 * n bytes.
 * @param [in]  n      Bytes per lane (a multiple of 8).
 */
static void slhdsakey_shake256_get_hash_x4(const word64* state, byte* hash,
    int n)
{
    int w;
    int lane;
    int words = n / 8;

    for (w = 0; w < words; w++) {
        for (lane = 0; lane < 4; lane++) {
            ((word64*)(hash + lane * n))[w] = state[4 * w + lane];
        }
    }
}
/* Per-lane updates of the encoded ADRS inside an interleaved x4 state.
 * o is the interleaved index just past the address, so (state) + (o) - 4 + l is
 * lane l's last address word, holding encoded bytes 24..31. Byte [3] of that
 * word is encoded byte 27, the low byte of word 6 (chain address / tree
 * height). Byte [7] is encoded byte 31, the low byte of word 7 (hash address /
 * tree index).
 * The *_ADDRESS macros write only that low byte, so values must fit in 8 bits.
 * The TREE_INDEX/TREE_HEIGHT macros write a full 32-bit word with c32toa into
 * word32 [1] (word 7) or [0] (word 6) of that word64.
 * Variants: CHAIN_ADDRESS gives lanes a, a+1, a+2, a+3; the *_IDX forms take
 * one value per lane from an array; the rest write the same value to all
 * lanes. */
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
#define SHAKE256_SET_CHAIN_ADDRESS(state, o, a) \
do { \
((word8*)((state) + (o) - 4))[3] = (a) + 0; \
((word8*)((state) + (o) - 3))[3] = (a) + 1; \
((word8*)((state) + (o) - 2))[3] = (a) + 2; \
((word8*)((state) + (o) - 1))[3] = (a) + 3; \
} while (0)
#endif
#define SHAKE256_SET_CHAIN_ADDRESS_IDX(state, o, idx) \
do { \
((word8*)((state) + (o) - 4))[3] = (idx)[0]; \
((word8*)((state) + (o) - 3))[3] = (idx)[1]; \
((word8*)((state) + (o) - 2))[3] = (idx)[2]; \
((word8*)((state) + (o) - 1))[3] = (idx)[3]; \
} while (0)
#define SHAKE256_SET_HASH_ADDRESS(state, o, a) \
do { \
((word8*)((state) + (o) - 4))[7] = (a); \
((word8*)((state) + (o) - 3))[7] = (a); \
((word8*)((state) + (o) - 2))[7] = (a); \
((word8*)((state) + (o) - 1))[7] = (a); \
} while (0)
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
#define SHAKE256_SET_TREE_INDEX(state, o, ti) \
do { \
c32toa((ti) + 0, (byte*)&((word32*)((state) + (o) - 4))[1]); \
c32toa((ti) + 1, (byte*)&((word32*)((state) + (o) - 3))[1]); \
c32toa((ti) + 2, (byte*)&((word32*)((state) + (o) - 2))[1]); \
c32toa((ti) + 3, (byte*)&((word32*)((state) + (o) - 1))[1]); \
} while (0)
#endif
#define SHAKE256_SET_TREE_INDEX_IDX(state, o, ti) \
do { \
c32toa((ti)[0], (byte*)&((word32*)((state) + (o) - 4))[1]); \
c32toa((ti)[1], (byte*)&((word32*)((state) + (o) - 3))[1]); \
c32toa((ti)[2], (byte*)&((word32*)((state) + (o) - 2))[1]); \
c32toa((ti)[3], (byte*)&((word32*)((state) + (o) - 1))[1]); \
} while (0)
#define SHAKE256_SET_TREE_HEIGHT(state, o, th) \
do { \
c32toa((th), (byte*)&((word32*)((state) + (o) - 4))[0]); \
c32toa((th), (byte*)&((word32*)((state) + (o) - 3))[0]); \
c32toa((th), (byte*)&((word32*)((state) + (o) - 2))[0]); \
c32toa((th), (byte*)&((word32*)((state) + (o) - 1))[0]); \
} while (0)
#ifndef WOLFSSL_SLHDSA_PARAM_NO_128
/* Advance four n = 16 WOTS+ chains in parallel by s steps, using hash
 * addresses i .. i+s-1.
 *
 * Lane l operates on sk + l * 16 with chain address idx[l]; all other address
 * fields come from addr.
 *
 * @param [in, out] sk       Four 16-byte chain values, updated in place.
 * @param [in]      i        First hash address.
 * @param [in]      s        Number of steps. With 0, sk is left unchanged.
 * @param [in]      pk_seed  Public seed (16 bytes).
 * @param [in]      addr     Encoded hash address.
 * @param [in]      idx      Chain address for each of the four lanes.
 * @param [in]      heap     Dynamic memory hint.
 * @return  0 on success.
 * @return  MEMORY_E on allocation failure.
 * @return  Error from SAVE_VECTOR_REGISTERS2() when vector registers are not
 *          available.
 */
static int slhdsakey_chain_idx_x4_16(byte* sk, byte i, byte s,
    const byte* pk_seed, byte* addr, byte* idx, void* heap)
{
    int ret = 0;
    int j;
    WC_DECLARE_VAR(fixed, word64, 6 * 4, heap);
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;

    WC_ALLOC_VAR_EX(fixed, word64, 6 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret == 0) {
        WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
            ret = MEMORY_E);
    }
    if (ret == 0) {
        /* Step-invariant prefix: PK.seed || ADRS with per-lane chain
         * address. */
        SHAKE256_SET_SEED_HA_X4_16(fixed, pk_seed, addr);
        SHAKE256_SET_CHAIN_ADDRESS_IDX(fixed, 24, idx);
        SHAKE256_SET_HASH_X4_16(state, sk);
        for (j = i; j < i + s; j++) {
            if (j != i) {
                /* Previous output (2 words per lane) becomes the input. */
                XMEMCPY(state + 24, state, 16 * 4);
            }
            XMEMCPY(state, fixed, (6 * 4) * sizeof(word64));
            SHAKE256_SET_HASH_ADDRESS(state, 24, j);
            SHAKE256_SET_END_X4(state, 32);
            ret = SAVE_VECTOR_REGISTERS2();
            if (ret != 0) {
                /* Fall through to the frees below - returning here would
                 * leak the buffers in small-stack builds. */
                break;
            }
            sha3_blocksx4_avx2(state);
            RESTORE_VECTOR_REGISTERS();
        }
        /* With s == 0 no block was computed and the output words of state
         * were never written; leave sk untouched. */
        if ((ret == 0) && (s > 0)) {
            SHAKE256_GET_HASH_X4_16(state, sk);
        }
    }

    WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    WC_FREE_VAR_EX(fixed, heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#endif
#ifndef WOLFSSL_SLHDSA_PARAM_NO_192
/* Advance four n = 24 WOTS+ chains in parallel by s steps, using hash
 * addresses i .. i+s-1.
 *
 * Lane l operates on sk + l * 24 with chain address idx[l]; all other address
 * fields come from addr.
 *
 * @param [in, out] sk       Four 24-byte chain values, updated in place.
 * @param [in]      i        First hash address.
 * @param [in]      s        Number of steps. With 0, sk is left unchanged.
 * @param [in]      pk_seed  Public seed (24 bytes).
 * @param [in]      addr     Encoded hash address.
 * @param [in]      idx      Chain address for each of the four lanes.
 * @param [in]      heap     Dynamic memory hint.
 * @return  0 on success.
 * @return  MEMORY_E on allocation failure.
 * @return  Error from SAVE_VECTOR_REGISTERS2() when vector registers are not
 *          available.
 */
static int slhdsakey_chain_idx_x4_24(byte* sk, byte i, byte s,
    const byte* pk_seed, byte* addr, byte* idx, void* heap)
{
    int ret = 0;
    int j;
    WC_DECLARE_VAR(fixed, word64, 7 * 4, heap);
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;

    WC_ALLOC_VAR_EX(fixed, word64, 7 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret == 0) {
        WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
            ret = MEMORY_E);
    }
    if (ret == 0) {
        /* Step-invariant prefix: PK.seed || ADRS with per-lane chain
         * address. */
        SHAKE256_SET_SEED_HA_X4_24(fixed, pk_seed, addr);
        SHAKE256_SET_CHAIN_ADDRESS_IDX(fixed, 28, idx);
        SHAKE256_SET_HASH_X4_24(state, sk);
        for (j = i; j < i + s; j++) {
            if (j != i) {
                /* Previous output (3 words per lane) becomes the input. */
                XMEMCPY(state + 28, state, 24 * 4);
            }
            XMEMCPY(state, fixed, 28 * sizeof(word64));
            SHAKE256_SET_HASH_ADDRESS(state, 28, j);
            SHAKE256_SET_END_X4(state, 40);
            ret = SAVE_VECTOR_REGISTERS2();
            if (ret != 0) {
                /* Fall through to the frees below - returning here would
                 * leak the buffers in small-stack builds. */
                break;
            }
            sha3_blocksx4_avx2(state);
            RESTORE_VECTOR_REGISTERS();
        }
        /* With s == 0 no block was computed and the output words of state
         * were never written; leave sk untouched. */
        if ((ret == 0) && (s > 0)) {
            SHAKE256_GET_HASH_X4_24(state, sk);
        }
    }

    WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    WC_FREE_VAR_EX(fixed, heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#endif
#ifndef WOLFSSL_SLHDSA_PARAM_NO_256
/* Advance four n = 32 WOTS+ chains in parallel by s steps, using hash
 * addresses i .. i+s-1.
 *
 * Lane l operates on sk + l * 32 with chain address idx[l]; all other address
 * fields come from addr.
 *
 * @param [in, out] sk       Four 32-byte chain values, updated in place.
 * @param [in]      i        First hash address.
 * @param [in]      s        Number of steps. With 0, sk is left unchanged.
 * @param [in]      pk_seed  Public seed (32 bytes).
 * @param [in]      addr     Encoded hash address.
 * @param [in]      idx      Chain address for each of the four lanes.
 * @param [in]      heap     Dynamic memory hint.
 * @return  0 on success.
 * @return  MEMORY_E on allocation failure.
 * @return  Error from SAVE_VECTOR_REGISTERS2() when vector registers are not
 *          available.
 */
static int slhdsakey_chain_idx_x4_32(byte* sk, byte i, byte s,
    const byte* pk_seed, byte* addr, byte* idx, void* heap)
{
    int ret = 0;
    int j;
    WC_DECLARE_VAR(fixed, word64, 8 * 4, heap);
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;

    WC_ALLOC_VAR_EX(fixed, word64, 8 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret == 0) {
        WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
            ret = MEMORY_E);
    }
    if (ret == 0) {
        /* Step-invariant prefix: PK.seed || ADRS with per-lane chain
         * address. */
        SHAKE256_SET_SEED_HA_X4_32(fixed, pk_seed, addr);
        SHAKE256_SET_CHAIN_ADDRESS_IDX(fixed, 32, idx);
        SHAKE256_SET_HASH_X4_32(state, sk);
        for (j = i; j < i + s; j++) {
            if (j != i) {
                /* Previous output (4 words per lane) becomes the input. */
                XMEMCPY(state + 32, state, 32 * 4);
            }
            XMEMCPY(state, fixed, 32 * sizeof(word64));
            SHAKE256_SET_HASH_ADDRESS(state, 32, j);
            SHAKE256_SET_END_X4(state, 48);
            ret = SAVE_VECTOR_REGISTERS2();
            if (ret != 0) {
                /* Fall through to the frees below - returning here would
                 * leak the buffers in small-stack builds. */
                break;
            }
            sha3_blocksx4_avx2(state);
            RESTORE_VECTOR_REGISTERS();
        }
        /* With s == 0 no block was computed and the output words of state
         * were never written; leave sk untouched. */
        if ((ret == 0) && (s > 0)) {
            SHAKE256_GET_HASH_X4_32(state, sk);
        }
    }

    WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    WC_FREE_VAR_EX(fixed, heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#endif
#endif
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
/* PRF for four consecutive WOTS+ chains at once: for l = 0..3,
 * sk + l * n = SHAKE256(PK.seed || ADRS_l || SK.seed, 8n), where ADRS_l is addr
 * with chain address ca + l.
 *
 * @param [in]  pk_seed  Public seed (n bytes).
 * @param [in]  sk_seed  Secret seed (n bytes).
 * @param [in]  addr     Encoded hash address.
 * @param [in]  n        Security parameter in bytes.
 * @param [in]  ca       Chain address of lane 0.
 * @param [out] sk       Output, 4 * n bytes.
 * @param [in]  heap     Dynamic memory hint.
 * @return  0 on success, MEMORY_E on allocation failure, or the error from
 *          SAVE_VECTOR_REGISTERS2().
 */
static int slhdsakey_hash_prf_x4(const byte* pk_seed, const byte* sk_seed,
    byte* addr, byte n, byte ca, byte* sk, void* heap)
{
    int err = 0;
    word32 adrsEnd = 0;
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;

    WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        err = MEMORY_E);
    if (err != 0) {
        return err;
    }

    /* The chain address sits in the last address word, just before
     * adrsEnd. */
    adrsEnd = slhdsakey_shake256_set_seed_ha_hash_x4(state, pk_seed, addr,
        sk_seed, n);
    SHAKE256_SET_CHAIN_ADDRESS(state, adrsEnd, ca);

    err = SAVE_VECTOR_REGISTERS2();
    if (err == 0) {
        sha3_blocksx4_avx2(state);
        slhdsakey_shake256_get_hash_x4(state, sk, n);
        RESTORE_VECTOR_REGISTERS();
    }

    WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    return err;
}
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_128)
/* Run four adjacent n = 16 WOTS+ chains in parallel for a full chain of
 * 15 (SLHDSA_WM1) steps, hash addresses 0..14.
 *
 * Lane l operates on sk + l * 16 with chain address ca + l.
 *
 * @param [in, out] sk       Four 16-byte chain start values; replaced by the
 *                           chain ends on success.
 * @param [in]      pk_seed  Public seed (16 bytes).
 * @param [in]      addr     Encoded hash address.
 * @param [in]      ca       Chain address of lane 0.
 * @param [in]      heap     Dynamic memory hint.
 * @return  0 on success, MEMORY_E on allocation failure, or the error from
 *          SAVE_VECTOR_REGISTERS2().
 */
static int slhdsakey_chain_x4_16(byte* sk, const byte* pk_seed, byte* addr,
byte ca, void* heap)
{
int ret = 0;
int j;
/* fixed holds PK.seed || ADRS for all lanes; only 24 of the 32 words are
 * used for n = 16. */
WC_DECLARE_VAR(fixed, word64, 8 * 4, heap);
WC_DECLARE_VAR(state, word64, 25 * 4, heap);
(void)heap;
WC_ALLOC_VAR_EX(fixed, word64, 8 * 4, heap, DYNAMIC_TYPE_SLHDSA,
ret = MEMORY_E);
if (ret == 0) {
WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
ret = MEMORY_E);
}
if (ret == 0) {
SHAKE256_SET_SEED_HA_X4_16(fixed, pk_seed, addr);
SHAKE256_SET_CHAIN_ADDRESS(fixed, 24, ca);
SHAKE256_SET_HASH_X4_16(state, sk);
for (j = 0; j < 15; j++) {
if (j != 0) {
/* Previous output (2 words per lane) becomes the input. */
XMEMCPY(state + 24, state, 16 * 4);
}
/* Restore prefix (the permutation overwrote it) and set step j. */
XMEMCPY(state, fixed, 24 * sizeof(word64));
SHAKE256_SET_HASH_ADDRESS(state, 24, j);
SHAKE256_SET_END_X4(state, 32);
ret = SAVE_VECTOR_REGISTERS2();
if (ret != 0)
break;
sha3_blocksx4_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
if (ret == 0)
SHAKE256_GET_HASH_X4_16(state, sk);
}
WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
WC_FREE_VAR_EX(fixed, heap, DYNAMIC_TYPE_SLHDSA);
return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
/* Run four adjacent n = 24 WOTS+ chains in parallel for a full chain of
 * 15 (SLHDSA_WM1) steps, hash addresses 0..14.
 *
 * Lane l operates on sk + l * 24 with chain address ca + l.
 *
 * @param [in, out] sk       Four 24-byte chain start values; replaced by the
 *                           chain ends on success.
 * @param [in]      pk_seed  Public seed (24 bytes).
 * @param [in]      addr     Encoded hash address.
 * @param [in]      ca       Chain address of lane 0.
 * @param [in]      heap     Dynamic memory hint.
 * @return  0 on success, MEMORY_E on allocation failure, or the error from
 *          SAVE_VECTOR_REGISTERS2().
 */
static int slhdsakey_chain_x4_24(byte* sk, const byte* pk_seed, byte* addr,
byte ca, void* heap)
{
int ret = 0;
int j;
/* fixed holds PK.seed || ADRS for all lanes; only 28 of the 32 words are
 * used for n = 24. */
WC_DECLARE_VAR(fixed, word64, 8 * 4, heap);
WC_DECLARE_VAR(state, word64, 25 * 4, heap);
(void)heap;
WC_ALLOC_VAR_EX(fixed, word64, 8 * 4, heap, DYNAMIC_TYPE_SLHDSA,
ret = MEMORY_E);
if (ret == 0) {
WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
ret = MEMORY_E);
}
if (ret == 0) {
SHAKE256_SET_SEED_HA_X4_24(fixed, pk_seed, addr);
SHAKE256_SET_CHAIN_ADDRESS(fixed, 28, ca);
SHAKE256_SET_HASH_X4_24(state, sk);
for (j = 0; j < 15; j++) {
if (j != 0) {
/* Previous output (3 words per lane) becomes the input. */
XMEMCPY(state + 28, state, 24 * 4);
}
/* Restore prefix (the permutation overwrote it) and set step j. */
XMEMCPY(state, fixed, 28 * sizeof(word64));
SHAKE256_SET_HASH_ADDRESS(state, 28, j);
SHAKE256_SET_END_X4(state, 40);
ret = SAVE_VECTOR_REGISTERS2();
if (ret != 0)
break;
sha3_blocksx4_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
if (ret == 0)
SHAKE256_GET_HASH_X4_24(state, sk);
}
WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
WC_FREE_VAR_EX(fixed, heap, DYNAMIC_TYPE_SLHDSA);
return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
/* Run four adjacent n = 32 WOTS+ chains in parallel for a full chain of
 * 15 (SLHDSA_WM1) steps, hash addresses 0..14.
 *
 * Lane l operates on sk + l * 32 with chain address ca + l.
 *
 * @param [in, out] sk       Four 32-byte chain start values; replaced by the
 *                           chain ends on success.
 * @param [in]      pk_seed  Public seed (32 bytes).
 * @param [in]      addr     Encoded hash address.
 * @param [in]      ca       Chain address of lane 0.
 * @param [in]      heap     Dynamic memory hint.
 * @return  0 on success, MEMORY_E on allocation failure, or the error from
 *          SAVE_VECTOR_REGISTERS2().
 */
static int slhdsakey_chain_x4_32(byte* sk, const byte* pk_seed, byte* addr,
byte ca, void* heap)
{
int ret = 0;
int j;
/* fixed holds PK.seed || ADRS for all lanes (32 words for n = 32). */
WC_DECLARE_VAR(fixed, word64, 8 * 4, heap);
WC_DECLARE_VAR(state, word64, 25 * 4, heap);
(void)heap;
WC_ALLOC_VAR_EX(fixed, word64, 8 * 4, heap, DYNAMIC_TYPE_SLHDSA,
ret = MEMORY_E);
if (ret == 0) {
WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
ret = MEMORY_E);
}
if (ret == 0) {
SHAKE256_SET_SEED_HA_X4_32(fixed, pk_seed, addr);
SHAKE256_SET_CHAIN_ADDRESS(fixed, 32, ca);
SHAKE256_SET_HASH_X4_32(state, sk);
for (j = 0; j < 15; j++) {
if (j != 0) {
/* Previous output (4 words per lane) becomes the input. */
XMEMCPY(state + 32, state, 32 * 4);
}
/* Restore prefix (the permutation overwrote it) and set step j. */
XMEMCPY(state, fixed, 32 * sizeof(word64));
SHAKE256_SET_HASH_ADDRESS(state, 32, j);
SHAKE256_SET_END_X4(state, 48);
ret = SAVE_VECTOR_REGISTERS2();
if (ret != 0)
break;
sha3_blocksx4_avx2(state);
RESTORE_VECTOR_REGISTERS();
}
if (ret == 0)
SHAKE256_GET_HASH_X4_32(state, sk);
}
WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
WC_FREE_VAR_EX(fixed, heap, DYNAMIC_TYPE_SLHDSA);
return ret;
}
#endif
/* PRF for four arbitrary WOTS+ chains at once: for l = 0..3,
 * sk + l * n = SHAKE256(PK.seed || ADRS_l || SK.seed, 8n), where ADRS_l is addr
 * with chain address idx[l].
 *
 * @param [in]  pk_seed  Public seed (n bytes).
 * @param [in]  sk_seed  Secret seed (n bytes).
 * @param [in]  addr     Encoded hash address.
 * @param [in]  n        Security parameter in bytes.
 * @param [in]  idx      Chain address for each lane.
 * @param [out] sk       Output, 4 * n bytes.
 * @param [in]  heap     Dynamic memory hint.
 * @return  0 on success, MEMORY_E on allocation failure, or the error from
 *          SAVE_VECTOR_REGISTERS2().
 */
static int slhdsakey_hash_prf_idx_x4(const byte* pk_seed, const byte* sk_seed,
    byte* addr, byte n, byte* idx, byte* sk, void* heap)
{
    int err = 0;
    word32 adrsEnd = 0;
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;

    WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        err = MEMORY_E);
    if (err != 0) {
        return err;
    }

    adrsEnd = slhdsakey_shake256_set_seed_ha_hash_x4(state, pk_seed, addr,
        sk_seed, n);
    SHAKE256_SET_CHAIN_ADDRESS_IDX(state, adrsEnd, idx);

    err = SAVE_VECTOR_REGISTERS2();
    if (err == 0) {
        sha3_blocksx4_avx2(state);
        RESTORE_VECTOR_REGISTERS();
        slhdsakey_shake256_get_hash_x4(state, sk, n);
    }

    WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    return err;
}
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_128)
/* Compute up to four n = 16 WOTS+ signature values and store them in sig.
 *
 * sk holds four chain start values (lane l at sk + l * 16) for chains idx[0..3]
 * at step 0. All four lanes are advanced together to j, then to msg[idx[2]],
 * then to msg[idx[1]]. Lane 3, then lane 2, then lane 1 is copied out as it
 * reaches its target. Lane 0 finishes alone with the scalar chain() up to
 * msg[idx[0]]. Each value goes to sig + idx[l] * 16.
 * NOTE(review): assumes j == msg[idx[3]] (when cnt > 3) and
 * j <= msg[idx[2]] <= msg[idx[1]] <= msg[idx[0]]. A negative difference would
 * wrap when passed as the byte step count. Lane 3 is only stored when cnt > 3.
 *
 * @param [in]      key      SLH-DSA key.
 * @param [in, out] sk       Four chain values (64 bytes), used as scratch.
 * @param [in]      pk_seed  Public seed.
 * @param [in, out] adrs     Hash address used for the scalar lane-0 chain.
 * @param [in]      addr     Encoded hash address for the x4 chains.
 * @param [in]      msg      Base-w message digits (chain targets).
 * @param [in]      idx      Chain indices of the four lanes.
 * @param [in]      j        Target step of lane 3 / initial step count.
 * @param [in]      cnt      Number of valid lanes.
 * @param [out]     sig      WOTS+ signature buffer.
 * @return  0 on success, otherwise an error from the chain functions.
 */
static int slhdsakey_chain_idx_16(SlhDsaKey* key, byte* sk,
const byte* pk_seed, word32* adrs, byte* addr, const byte* msg, byte* idx,
int j, int cnt, byte* sig)
{
int ret = 0;
/* Advance all lanes from step 0 to step j. */
if (j != 0) {
ret = slhdsakey_chain_idx_x4_16(sk, 0, j, pk_seed, addr, idx,
key->heap);
}
if (ret == 0) {
if (cnt > 3) {
XMEMCPY(sig + idx[3] * 16, sk + 3 * 16, 16);
}
if (msg[idx[2]] != j) {
ret = slhdsakey_chain_idx_x4_16(sk, j, msg[idx[2]] - j, pk_seed,
addr, idx, key->heap);
j = msg[idx[2]];
}
}
if (ret == 0) {
XMEMCPY(sig + idx[2] * 16, sk + 2 * 16, 16);
if (msg[idx[1]] != j) {
ret = slhdsakey_chain_idx_x4_16(sk, j, msg[idx[1]] - j, pk_seed,
addr, idx, key->heap);
j = msg[idx[1]];
}
}
if (ret == 0) {
XMEMCPY(sig + idx[1] * 16, sk + 1 * 16, 16);
/* Only lane 0 remains: finish it with the scalar chain. */
if (msg[idx[0]] != j) {
HA_SetChainAddress(adrs, idx[0]);
ret = slhdsakey_chain(key, sk, j, msg[idx[0]] - j, pk_seed, adrs,
sk);
}
}
if (ret == 0) {
XMEMCPY(sig + idx[0] * 16, sk + 0 * 16, 16);
}
return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
/* Compute up to four n = 24 WOTS+ signature values and store them in sig.
 *
 * sk holds four chain start values (lane l at sk + l * 24) for chains idx[0..3]
 * at step 0. All four lanes are advanced together to j, then to msg[idx[2]],
 * then to msg[idx[1]]. Lane 3, then lane 2, then lane 1 is copied out as it
 * reaches its target. Lane 0 finishes alone with the scalar chain() up to
 * msg[idx[0]]. Each value goes to sig + idx[l] * 24.
 * NOTE(review): assumes j == msg[idx[3]] (when cnt > 3) and
 * j <= msg[idx[2]] <= msg[idx[1]] <= msg[idx[0]]. A negative difference would
 * wrap when passed as the byte step count. Lane 3 is only stored when cnt > 3.
 *
 * @return  0 on success, otherwise an error from the chain functions.
 */
static int slhdsakey_chain_idx_24(SlhDsaKey* key, byte* sk,
const byte* pk_seed, word32* adrs, byte* addr, const byte* msg, byte* idx,
int j, int cnt, byte* sig)
{
int ret = 0;
/* Advance all lanes from step 0 to step j. */
if (j != 0) {
ret = slhdsakey_chain_idx_x4_24(sk, 0, j, pk_seed, addr, idx,
key->heap);
}
if (ret == 0) {
if (cnt > 3) {
XMEMCPY(sig + idx[3] * 24, sk + 3 * 24, 24);
}
if (msg[idx[2]] != j) {
ret = slhdsakey_chain_idx_x4_24(sk, j, msg[idx[2]] - j, pk_seed,
addr, idx, key->heap);
j = msg[idx[2]];
}
}
if (ret == 0) {
XMEMCPY(sig + idx[2] * 24, sk + 2 * 24, 24);
if (msg[idx[1]] != j) {
ret = slhdsakey_chain_idx_x4_24(sk, j, msg[idx[1]] - j, pk_seed,
addr, idx, key->heap);
j = msg[idx[1]];
}
}
if (ret == 0) {
XMEMCPY(sig + idx[1] * 24, sk + 1 * 24, 24);
/* Only lane 0 remains: finish it with the scalar chain. */
if (msg[idx[0]] != j) {
HA_SetChainAddress(adrs, idx[0]);
ret = slhdsakey_chain(key, sk, j, msg[idx[0]] - j, pk_seed, adrs,
sk);
}
}
if (ret == 0) {
XMEMCPY(sig + idx[0] * 24, sk + 0 * 24, 24);
}
return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
/* Compute up to four n = 32 WOTS+ signature values and store them in sig.
 *
 * sk holds four chain start values (lane l at sk + l * 32) for chains idx[0..3]
 * at step 0. All four lanes are advanced together to j, then to msg[idx[2]],
 * then to msg[idx[1]]. Lane 3, then lane 2, then lane 1 is copied out as it
 * reaches its target. Lane 0 finishes alone with the scalar chain() up to
 * msg[idx[0]]. Each value goes to sig + idx[l] * 32.
 * NOTE(review): assumes j == msg[idx[3]] (when cnt > 3) and
 * j <= msg[idx[2]] <= msg[idx[1]] <= msg[idx[0]]. A negative difference would
 * wrap when passed as the byte step count. Lane 3 is only stored when cnt > 3.
 *
 * @return  0 on success, otherwise an error from the chain functions.
 */
static int slhdsakey_chain_idx_32(SlhDsaKey* key, byte* sk,
const byte* pk_seed, word32* adrs, byte* addr, const byte* msg, byte* idx,
int j, int cnt, byte* sig)
{
int ret = 0;
/* Advance all lanes from step 0 to step j. */
if (j != 0) {
ret = slhdsakey_chain_idx_x4_32(sk, 0, j, pk_seed, addr, idx,
key->heap);
}
if (ret == 0) {
if (cnt > 3) {
XMEMCPY(sig + idx[3] * 32, sk + 3 * 32, 32);
}
if (msg[idx[2]] != j) {
ret = slhdsakey_chain_idx_x4_32(sk, j, msg[idx[2]] - j, pk_seed,
addr, idx, key->heap);
j = msg[idx[2]];
}
}
if (ret == 0) {
XMEMCPY(sig + idx[2] * 32, sk + 2 * 32, 32);
if (msg[idx[1]] != j) {
ret = slhdsakey_chain_idx_x4_32(sk, j, msg[idx[1]] - j, pk_seed,
addr, idx, key->heap);
j = msg[idx[1]];
}
}
if (ret == 0) {
XMEMCPY(sig + idx[1] * 32, sk + 1 * 32, 32);
/* Only lane 0 remains: finish it with the scalar chain. */
if (msg[idx[0]] != j) {
HA_SetChainAddress(adrs, idx[0]);
ret = slhdsakey_chain(key, sk, j, msg[idx[0]] - j, pk_seed, adrs,
sk);
}
}
if (ret == 0) {
XMEMCPY(sig + idx[0] * 32, sk + 0 * 32, 32);
}
return ret;
}
#endif
#endif
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_128)
/* WOTS+ public-key chain computation for n = 16.
 *
 * For every chain, derive its secret with PRF (sk_addr) and run the full
 * 15-step chain (addr), four chains at a time. Then absorb the len chain ends
 * into key->shake2.
 * The loop handles groups of four while i < len - 3; the group after it is
 * done separately and may include padding lanes past len. The buffer has room
 * for SLHDSA_MAX_MSG_SZ + 3 values, and only len * 16 bytes are hashed.
 * NOTE(review): the loop plus the tail covers chains [0, len) exactly when
 * len % 4 == 3, which holds for len = 2n + 3 = 35.
 *
 * @param [in] key      SLH-DSA key (params->len, heap, shake2).
 * @param [in] sk_seed  Secret seed.
 * @param [in] pk_seed  Public seed.
 * @param [in] addr     Encoded hash address for chaining.
 * @param [in] sk_addr  Encoded hash address for the PRF.
 * @return  0 on success, otherwise MEMORY_E or a hash/vector-register error.
 */
static int slhdsakey_wots_pkgen_chain_x4_16(SlhDsaKey* key, const byte* sk_seed,
const byte* pk_seed, byte* addr, byte* sk_addr)
{
int ret = 0;
int i;
byte len = key->params->len;
WC_DECLARE_VAR(sk, byte, (SLHDSA_MAX_MSG_SZ + 3) * 16, key->heap);
WC_ALLOC_VAR_EX(sk, byte, (SLHDSA_MAX_MSG_SZ + 3) * 16, key->heap,
DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
if (ret == 0) {
for (i = 0; i < len - 3; i += 4) {
ret = slhdsakey_hash_prf_x4(pk_seed, sk_seed, sk_addr, 16, i,
sk + i * 16, key->heap);
if (ret != 0) {
break;
}
ret = slhdsakey_chain_x4_16(sk + i * 16, pk_seed, addr, i,
key->heap);
if (ret != 0) {
break;
}
}
}
/* Final group of four; lanes at or beyond len are padding. */
if (ret == 0) {
ret = slhdsakey_hash_prf_x4(pk_seed, sk_seed, sk_addr, 16, i,
sk + i * 16, key->heap);
if (ret == 0) {
ret = slhdsakey_chain_x4_16(sk + i * 16, pk_seed, addr, i,
key->heap);
}
}
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake2, sk, len * 16);
}
WC_FREE_VAR_EX(sk, key->heap, DYNAMIC_TYPE_SLHDSA);
return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
/* Compute all len WOTS+ chain ends for n = 24, four chains per call, and
 * absorb them into key->shake2.
 *
 * The scratch buffer has room for three extra chains so that the last group
 * of four may extend past len; only the first len * 24 bytes are hashed.
 *
 * key      SLH-DSA key: params, heap hint and key->shake2.
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * addr     Encoded WOTS_HASH address.
 * sk_addr  Encoded WOTS_PRF address.
 * Returns 0 on success, MEMORY_E on allocation failure, or a helper error.
 */
static int slhdsakey_wots_pkgen_chain_x4_24(SlhDsaKey* key, const byte* sk_seed,
    const byte* pk_seed, byte* addr, byte* sk_addr)
{
    int ret = 0;
    int chain;
    byte len = key->params->len;
    WC_DECLARE_VAR(sk, byte, (SLHDSA_MAX_MSG_SZ + 3) * 24, key->heap);

    WC_ALLOC_VAR_EX(sk, byte, (SLHDSA_MAX_MSG_SZ + 3) * 24, key->heap,
        DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);

    /* Groups of four chains that lie entirely below len. */
    for (chain = 0; (ret == 0) && (chain + 3 < len); chain += 4) {
        ret = slhdsakey_hash_prf_x4(pk_seed, sk_seed, sk_addr, 24, chain,
            sk + chain * 24, key->heap);
        if (ret == 0) {
            ret = slhdsakey_chain_x4_24(sk + chain * 24, pk_seed, addr, chain,
                key->heap);
        }
    }
    /* Final group: may run into the spare chains past len. */
    if (ret == 0) {
        ret = slhdsakey_hash_prf_x4(pk_seed, sk_seed, sk_addr, 24, chain,
            sk + chain * 24, key->heap);
    }
    if (ret == 0) {
        ret = slhdsakey_chain_x4_24(sk + chain * 24, pk_seed, addr, chain,
            key->heap);
    }
    if (ret == 0) {
        ret = slhdsakey_hash_update(&key->shake2, sk, len * 24);
    }
    WC_FREE_VAR_EX(sk, key->heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
/* Compute all len WOTS+ chain ends for n = 32, four chains per call, and
 * absorb them into key->shake2.
 *
 * The scratch buffer has room for three extra chains so that the last group
 * of four may extend past len; only the first len * 32 bytes are hashed.
 *
 * key      SLH-DSA key: params, heap hint and key->shake2.
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * addr     Encoded WOTS_HASH address.
 * sk_addr  Encoded WOTS_PRF address.
 * Returns 0 on success, MEMORY_E on allocation failure, or a helper error.
 */
static int slhdsakey_wots_pkgen_chain_x4_32(SlhDsaKey* key, const byte* sk_seed,
    const byte* pk_seed, byte* addr, byte* sk_addr)
{
    int ret = 0;
    int chain;
    byte len = key->params->len;
    WC_DECLARE_VAR(sk, byte, (SLHDSA_MAX_MSG_SZ + 3) * 32, key->heap);

    WC_ALLOC_VAR_EX(sk, byte, (SLHDSA_MAX_MSG_SZ + 3) * 32, key->heap,
        DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);

    /* Groups of four chains that lie entirely below len. */
    for (chain = 0; (ret == 0) && (chain + 3 < len); chain += 4) {
        ret = slhdsakey_hash_prf_x4(pk_seed, sk_seed, sk_addr, 32, chain,
            sk + chain * 32, key->heap);
        if (ret == 0) {
            ret = slhdsakey_chain_x4_32(sk + chain * 32, pk_seed, addr, chain,
                key->heap);
        }
    }
    /* Final group: may run into the spare chains past len. */
    if (ret == 0) {
        ret = slhdsakey_hash_prf_x4(pk_seed, sk_seed, sk_addr, 32, chain,
            sk + chain * 32, key->heap);
    }
    if (ret == 0) {
        ret = slhdsakey_chain_x4_32(sk + chain * 32, pk_seed, addr, chain,
            key->heap);
    }
    if (ret == 0) {
        ret = slhdsakey_hash_update(&key->shake2, sk, len * 32);
    }
    WC_FREE_VAR_EX(sk, key->heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#endif
/* Compute the WOTS+ public-key chain ends with the 4-way (AVX2) SHAKE code,
 * dispatching on n to the size-specific implementation.
 *
 * key      SLH-DSA key.
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address.
 * sk_adrs  WOTS_PRF hash address; its hash address is reset to 0 (modified).
 * Returns the implementation's result, or NOT_COMPILED_IN when no
 * implementation for n is built in.
 */
static int slhdsakey_wots_pkgen_chain_x4(SlhDsaKey* key, const byte* sk_seed,
    const byte* pk_seed, word32* adrs, word32* sk_adrs)
{
    int ret;
    byte addr[SLHDSA_HA_SZ];
    byte sk_addr[SLHDSA_HA_SZ];

    /* The 4-way helpers work on encoded (byte) addresses. */
    HA_SetHashAddress(sk_adrs, 0);
    HA_Encode(sk_adrs, sk_addr);
    HA_Encode(adrs, addr);

    switch (key->params->n) {
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_128)
        case 16:
            ret = slhdsakey_wots_pkgen_chain_x4_16(key, sk_seed, pk_seed,
                addr, sk_addr);
            break;
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
        case 24:
            ret = slhdsakey_wots_pkgen_chain_x4_24(key, sk_seed, pk_seed,
                addr, sk_addr);
            break;
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
        case 32:
            ret = slhdsakey_wots_pkgen_chain_x4_32(key, sk_seed, pk_seed,
                addr, sk_addr);
            break;
#endif
        default:
            ret = NOT_COMPILED_IN;
            break;
    }
    return ret;
}
#endif
/* Compute all len WOTS+ chain ends with the portable (scalar) code and
 * absorb them into key->shake2.
 *
 * Chain k starts from PRF(pk_seed, sk_seed, sk_adrs with chain address k) and
 * is advanced w - 1 steps. Without WOLFSSL_WC_SLHDSA_SMALL_MEM the chain ends
 * are buffered and absorbed with one update; with it, each end is absorbed
 * as soon as it is computed.
 *
 * key      SLH-DSA key: params, heap hint and hash states.
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address (chain address modified).
 * sk_adrs  WOTS_PRF hash address (chain address modified).
 * Returns 0 on success, MEMORY_E on allocation failure, or a helper error.
 */
static int slhdsakey_wots_pkgen_chain_c(SlhDsaKey* key, const byte* sk_seed,
    const byte* pk_seed, word32* adrs, word32* sk_adrs)
{
    int ret = 0;
    int k;
    byte n = key->params->n;
    byte len = key->params->len;
#if !defined(WOLFSSL_WC_SLHDSA_SMALL_MEM)
    WC_DECLARE_VAR(sk, byte, (SLHDSA_MAX_MSG_SZ + 3) * SLHDSA_MAX_N, key->heap);

    WC_ALLOC_VAR_EX(sk, byte, (SLHDSA_MAX_MSG_SZ + 3) * SLHDSA_MAX_N,
        key->heap, DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
    for (k = 0; (ret == 0) && (k < len); k++) {
        byte* chain_val = sk + k * n;

        HA_SetChainAddress(sk_adrs, k);
        ret = HASH_PRF(&key->shake, pk_seed, sk_seed, sk_adrs, n, chain_val);
        if (ret == 0) {
            HA_SetChainAddress(adrs, k);
            ret = slhdsakey_chain(key, chain_val, 0, SLHDSA_WM1, pk_seed, adrs,
                chain_val);
        }
    }
    if (ret == 0) {
        ret = slhdsakey_hash_update(&key->shake2, sk, len * n);
    }
    WC_FREE_VAR_EX(sk, key->heap, DYNAMIC_TYPE_SLHDSA);
#else
    byte sk[SLHDSA_MAX_N];

    for (k = 0; (ret == 0) && (k < len); k++) {
        HA_SetChainAddress(sk_adrs, k);
        ret = HASH_PRF(&key->shake, pk_seed, sk_seed, sk_adrs, n, sk);
        if (ret == 0) {
            HA_SetChainAddress(adrs, k);
            ret = slhdsakey_chain(key, sk, 0, SLHDSA_WM1, pk_seed, adrs, sk);
        }
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake2, sk, n);
        }
    }
#endif
    return ret;
}
/* Generate a WOTS+ public key.
 *
 * Every chain end is computed from the secret seed and hashed together
 * under a WOTS_PK address derived from adrs.
 *
 * key      SLH-DSA key: params and hash states (key->shake2 accumulates the
 *          compression hash).
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address of the key pair (chain address modified).
 * node     Output: n-byte WOTS+ public key.
 * Returns 0 on success or an error from the hash/chain helpers.
 */
static int slhdsakey_wots_pkgen(SlhDsaKey* key, const byte* sk_seed,
    const byte* pk_seed, word32* adrs, byte* node)
{
    int ret;
    byte n = key->params->n;
    HashAddress pk_adrs;
    HashAddress prf_adrs;

    /* Start the compression hash keyed with the WOTS_PK address. */
    HA_Copy(pk_adrs, adrs);
    HA_SetTypeAndClearNotKPA(pk_adrs, HA_WOTS_PK);
    ret = slhdsakey_hash_start_addr(&key->shake2, pk_seed, pk_adrs, n);
    if (ret == 0) {
        /* Secret chain starts are derived under a WOTS_PRF address. */
        HA_Copy(prf_adrs, adrs);
        HA_SetTypeAndClearNotKPA(prf_adrs, HA_WOTS_PRF);
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
        if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
            ret = slhdsakey_wots_pkgen_chain_x4(key, sk_seed, pk_seed, adrs,
                prf_adrs);
            RESTORE_VECTOR_REGISTERS();
        }
        else
#endif
        {
            ret = slhdsakey_wots_pkgen_chain_c(key, sk_seed, pk_seed, adrs,
                prf_adrs);
        }
    }
    if (ret == 0) {
        ret = slhdsakey_hash_final(&key->shake2, node, n);
    }
    return ret;
}
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_128)
/* Compute the WOTS+ signature chains for n = 16 with the 4-way (AVX2) SHAKE
 * code.
 *
 * Chain indices are gathered four at a time while counting the target digit
 * j down from w - 1. The lanes are therefore in descending order of target,
 * so the chain helper can advance all four together and retire the lane with
 * the smallest target first.
 *
 * len (2n + 3 digits for w = 16) leaves three chains for the final group. It
 * is passed with cnt = 3, and lane 3 (stale idx[3]) is presumably not written
 * out by the helper.
 *
 * key      SLH-DSA key: params and heap hint.
 * msg      len base-w digits (message followed by checksum).
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address (chain address modified).
 * addr     Encoded form of adrs.
 * sk_addr  Encoded WOTS_PRF address.
 * sig      Output: len chain values of 16 bytes.
 * Returns 0 on success, MEMORY_E on allocation failure, or a helper error.
 */
static int slhdsakey_wots_sign_chain_x4_16(SlhDsaKey* key, const byte* msg,
    const byte* sk_seed, const byte* pk_seed, word32* adrs, byte* addr,
    byte* sk_addr, byte* sig)
{
    int ret = 0;
    int i;
    sword8 j;
    byte ii;
    byte idx[4] = {0};
    byte n = key->params->n;
    byte len = key->params->len;
    WC_DECLARE_VAR(sk, byte, 4 * 16, key->heap);

    WC_ALLOC_VAR_EX(sk, byte, 4 * 16, key->heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret == 0) {
        ii = 0;
        /* The outer loop must also stop on error. Breaking the inner loop
         * alone leaves ii == 4, and the next matching digit would then be
         * stored past the end of idx[]. */
        for (j = SLHDSA_WM1; (ret == 0) && (j >= 0); j--) {
            for (i = 0; i < len; i++) {
                if ((sword8)msg[i] == j) {
                    idx[ii++] = i;
                    if (ii == 4) {
                        ret = slhdsakey_hash_prf_idx_x4(pk_seed, sk_seed,
                            sk_addr, n, idx, sk, key->heap);
                        if (ret != 0) {
                            break;
                        }
                        ret = slhdsakey_chain_idx_16(key, sk, pk_seed, adrs,
                            addr, msg, idx, j, 4, sig);
                        if (ret != 0) {
                            break;
                        }
                        ii = 0;
                    }
                }
            }
        }
    }
    /* Remaining three chains. */
    if (ret == 0) {
        ret = slhdsakey_hash_prf_idx_x4(pk_seed, sk_seed, sk_addr, n, idx, sk,
            key->heap);
    }
    if (ret == 0) {
        j = min(min(msg[idx[0]], msg[idx[1]]), msg[idx[2]]);
        ret = slhdsakey_chain_idx_16(key, sk, pk_seed, adrs, addr, msg, idx, j,
            3, sig);
    }
    WC_FREE_VAR_EX(sk, key->heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
/* Compute the WOTS+ signature chains for n = 24 with the 4-way (AVX2) SHAKE
 * code.
 *
 * Chain indices are gathered four at a time while counting the target digit
 * j down from w - 1. The lanes are therefore in descending order of target,
 * so the chain helper can advance all four together and retire the lane with
 * the smallest target first.
 *
 * len (2n + 3 digits for w = 16) leaves three chains for the final group. It
 * is passed with cnt = 3 so that lane 3 (stale idx[3]) is not written out.
 *
 * key      SLH-DSA key: params and heap hint.
 * msg      len base-w digits (message followed by checksum).
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address (chain address modified).
 * addr     Encoded form of adrs.
 * sk_addr  Encoded WOTS_PRF address.
 * sig      Output: len chain values of 24 bytes.
 * Returns 0 on success, MEMORY_E on allocation failure, or a helper error.
 */
static int slhdsakey_wots_sign_chain_x4_24(SlhDsaKey* key, const byte* msg,
    const byte* sk_seed, const byte* pk_seed, word32* adrs, byte* addr,
    byte* sk_addr, byte* sig)
{
    int ret = 0;
    int i;
    sword8 j;
    byte ii;
    byte idx[4] = {0};
    byte n = key->params->n;
    byte len = key->params->len;
    WC_DECLARE_VAR(sk, byte, 4 * 24, key->heap);

    WC_ALLOC_VAR_EX(sk, byte, 4 * 24, key->heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret == 0) {
        ii = 0;
        /* The outer loop must also stop on error. Breaking the inner loop
         * alone leaves ii == 4, and the next matching digit would then be
         * stored past the end of idx[]. */
        for (j = SLHDSA_WM1; (ret == 0) && (j >= 0); j--) {
            for (i = 0; i < len; i++) {
                if ((sword8)msg[i] == j) {
                    idx[ii++] = i;
                    if (ii == 4) {
                        ret = slhdsakey_hash_prf_idx_x4(pk_seed, sk_seed,
                            sk_addr, n, idx, sk, key->heap);
                        if (ret != 0) {
                            break;
                        }
                        ret = slhdsakey_chain_idx_24(key, sk, pk_seed, adrs,
                            addr, msg, idx, j, 4, sig);
                        if (ret != 0) {
                            break;
                        }
                        ii = 0;
                    }
                }
            }
        }
    }
    /* Remaining three chains. */
    if (ret == 0) {
        ret = slhdsakey_hash_prf_idx_x4(pk_seed, sk_seed, sk_addr, n, idx, sk,
            key->heap);
    }
    if (ret == 0) {
        j = min(min(msg[idx[0]], msg[idx[1]]), msg[idx[2]]);
        ret = slhdsakey_chain_idx_24(key, sk, pk_seed, adrs, addr,
            msg, idx, j, 3, sig);
    }
    WC_FREE_VAR_EX(sk, key->heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
/* Compute the WOTS+ signature chains for n = 32 with the 4-way (AVX2) SHAKE
 * code.
 *
 * Chain indices are gathered four at a time while counting the target digit
 * j down from w - 1. The lanes are therefore in descending order of target,
 * so slhdsakey_chain_idx_32 can advance all four together and retire the lane
 * with the smallest target first.
 *
 * len (2n + 3 digits for w = 16) leaves three chains for the final group. It
 * is passed with cnt = 3 so that lane 3 (stale idx[3]) is not written out.
 *
 * key      SLH-DSA key: params and heap hint.
 * msg      len base-w digits (message followed by checksum).
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address (chain address modified).
 * addr     Encoded form of adrs.
 * sk_addr  Encoded WOTS_PRF address.
 * sig      Output: len chain values of 32 bytes.
 * Returns 0 on success, MEMORY_E on allocation failure, or a helper error.
 */
static int slhdsakey_wots_sign_chain_x4_32(SlhDsaKey* key, const byte* msg,
    const byte* sk_seed, const byte* pk_seed, word32* adrs, byte* addr,
    byte* sk_addr, byte* sig)
{
    int ret = 0;
    int i;
    sword8 j;
    byte ii;
    byte idx[4] = {0};
    byte n = key->params->n;
    byte len = key->params->len;
    WC_DECLARE_VAR(sk, byte, 4 * 32, key->heap);

    WC_ALLOC_VAR_EX(sk, byte, 4 * 32, key->heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret == 0) {
        ii = 0;
        /* The outer loop must also stop on error. Breaking the inner loop
         * alone leaves ii == 4, and the next matching digit would then be
         * stored past the end of idx[]. */
        for (j = SLHDSA_WM1; (ret == 0) && (j >= 0); j--) {
            for (i = 0; i < len; i++) {
                if ((sword8)msg[i] == j) {
                    idx[ii++] = i;
                    if (ii == 4) {
                        ret = slhdsakey_hash_prf_idx_x4(pk_seed, sk_seed,
                            sk_addr, n, idx, sk, key->heap);
                        if (ret != 0) {
                            break;
                        }
                        ret = slhdsakey_chain_idx_32(key, sk, pk_seed, adrs,
                            addr, msg, idx, j, 4, sig);
                        if (ret != 0) {
                            break;
                        }
                        ii = 0;
                    }
                }
            }
        }
    }
    /* Remaining three chains. */
    if (ret == 0) {
        ret = slhdsakey_hash_prf_idx_x4(pk_seed, sk_seed, sk_addr, n, idx, sk,
            key->heap);
    }
    if (ret == 0) {
        j = min(min(msg[idx[0]], msg[idx[1]]), msg[idx[2]]);
        ret = slhdsakey_chain_idx_32(key, sk, pk_seed, adrs, addr, msg, idx, j,
            3, sig);
    }
    WC_FREE_VAR_EX(sk, key->heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#endif
/* Compute the WOTS+ signature chains with the 4-way (AVX2) SHAKE code,
 * dispatching on n to the size-specific implementation.
 *
 * key      SLH-DSA key.
 * msg      len base-w digits (message followed by checksum).
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address.
 * sk_adrs  WOTS_PRF hash address; its hash address is reset to 0 (modified).
 * sig      Output: len chain values of n bytes.
 * Returns the implementation's result, or NOT_COMPILED_IN when no
 * implementation for n is built in.
 */
static int slhdsakey_wots_sign_chain_x4(SlhDsaKey* key, const byte* msg,
    const byte* sk_seed, const byte* pk_seed, word32* adrs, word32* sk_adrs,
    byte* sig)
{
    int ret;
    byte addr[SLHDSA_HA_SZ];
    byte sk_addr[SLHDSA_HA_SZ];

    /* The 4-way helpers work on encoded (byte) addresses. */
    HA_SetHashAddress(sk_adrs, 0);
    HA_Encode(sk_adrs, sk_addr);
    HA_Encode(adrs, addr);

    switch (key->params->n) {
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_128)
        case 16:
            ret = slhdsakey_wots_sign_chain_x4_16(key, msg, sk_seed, pk_seed,
                adrs, addr, sk_addr, sig);
            break;
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
        case 24:
            ret = slhdsakey_wots_sign_chain_x4_24(key, msg, sk_seed, pk_seed,
                adrs, addr, sk_addr, sig);
            break;
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
        case 32:
            ret = slhdsakey_wots_sign_chain_x4_32(key, msg, sk_seed, pk_seed,
                adrs, addr, sk_addr, sig);
            break;
#endif
        default:
            ret = NOT_COMPILED_IN;
            break;
    }
    return ret;
}
#endif
/* Produce a WOTS+ signature of the n-byte message m.
 *
 * m is expanded into 2n base-16 digits, followed by three digits of the
 * checksum sum(w - 1 - digit). Chain k is then advanced from its secret
 * start by msg[k] steps.
 *
 * key      SLH-DSA key: params and hash state.
 * m        n-byte message.
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address of the key pair (chain address modified).
 * sig      Output: len chain values of n bytes.
 * Returns 0 on success or an error from the hash/chain helpers.
 */
static int slhdsakey_wots_sign(SlhDsaKey* key, const byte* m,
    const byte* sk_seed, const byte* pk_seed, word32* adrs, byte* sig)
{
    int ret;
    int k;
    word16 csum = 0;
    HashAddress sk_adrs;
    byte n = key->params->n;
    byte len = key->params->len;
    byte msg[SLHDSA_MAX_MSG_SZ];

    /* High nibble first; accumulate the checksum as we go. */
    for (k = 0; k < n; k++) {
        byte hi = (byte)((m[k] >> 4) & 0xf);
        byte lo = (byte)(m[k] & 0xf);

        msg[2 * k + 0] = hi;
        msg[2 * k + 1] = lo;
        csum += (word16)((SLHDSA_WM1 - hi) + (SLHDSA_WM1 - lo));
    }
    /* Checksum as three base-16 digits, most significant first. */
    msg[2 * n + 0] = (byte)((csum >> 8) & 0xf);
    msg[2 * n + 1] = (byte)((csum >> 4) & 0xf);
    msg[2 * n + 2] = (byte)(csum & 0xf);

    HA_Copy(sk_adrs, adrs);
    HA_SetTypeAndClearNotKPA(sk_adrs, HA_WOTS_PRF);
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        ret = slhdsakey_wots_sign_chain_x4(key, msg, sk_seed, pk_seed, adrs,
            sk_adrs, sig);
        RESTORE_VECTOR_REGISTERS();
    }
    else
#endif
    {
        ret = 0;
        for (k = 0; (ret == 0) && (k < len); k++) {
            byte* chain_val = sig + k * n;

            HA_SetChainAddress(sk_adrs, k);
            ret = HASH_PRF(&key->shake, pk_seed, sk_seed, sk_adrs, n,
                chain_val);
            if (ret == 0) {
                HA_SetChainAddress(adrs, k);
                ret = slhdsakey_chain(key, chain_val, 0, msg[k], pk_seed, adrs,
                    chain_val);
            }
        }
    }
    return ret;
}
#endif
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_128)
/* Advance a group of four (or three) WOTS+ signature chains for n = 16 from
 * their signature values to the chain end (digit w - 1) and store the ends.
 *
 * The caller gathers idx[] while counting j UP from 0, so lanes are in
 * ascending order of start digit: msg[idx[0]] <= msg[idx[1]] <= msg[idx[2]]
 * <= j. Lanes join as the shared chain position reaches their start:
 *   1. lane 0 alone (scalar): msg[idx[0]] -> msg[idx[1]]
 *   2. 4-way: msg[idx[1]] -> msg[idx[2]] (lanes 2-3 are zero filler)
 *   3. 4-way: msg[idx[2]] -> j, only when cnt > 3 (for cnt == 3 the caller
 *      passes j == max start, i.e. msg[idx[2]])
 *   4. 4-way: j -> w - 1 for all lanes
 * The 4-way helper is presumably called as (node, start, steps, ...).
 *
 * key      SLH-DSA key: heap hint and scalar hash state.
 * sig      WOTS+ signature; chain idx[k] is read at sig + idx[k] * 16.
 * pk_seed  Public seed.
 * adrs     Hash address; chain address set to idx[0] (modified).
 * msg      Base-w digits: start position of each chain.
 * idx      Chain indices of the lanes.
 * j        Start digit of the last lane (largest in the group).
 * cnt      Number of valid lanes: 4, or 3 when lane 3 is unused.
 * nodes    Output; chain end idx[k] is written at nodes + idx[k] * 16.
 * Returns 0 on success or an error from the chain helpers.
 */
static int slhdsakey_chain_idx_to_max_16(SlhDsaKey* key, const byte* sig,
    const byte* pk_seed, word32* adrs, const byte* msg, byte* idx, int j,
    int cnt, byte* nodes)
{
    int ret = 0;
    byte node[4 * 16];
    byte addr[SLHDSA_HA_SZ];
    HA_SetChainAddress(adrs, idx[0]);
    HA_Encode(adrs, addr);
    /* Step 1: lane 0 alone catches up with lane 1's start. */
    XMEMCPY(node + 0 * 16, sig + idx[0] * 16, 16);
    if ((msg[idx[0]] != j) && (msg[idx[0]] != msg[idx[1]])) {
        ret = slhdsakey_chain(key, node, msg[idx[0]],
            msg[idx[1]] - msg[idx[0]], pk_seed, adrs, node);
    }
    if (ret == 0) {
        /* Step 2: lanes 0-1 catch up with lane 2's start. Lanes 2-3 are
         * zeroed so the 4-way code never reads uninitialized data. */
        XMEMCPY(node + 1 * 16, sig + idx[1] * 16, 16);
        XMEMSET(node + 2 * 16, 0, sizeof(node) - 2 * 16);
        if ((msg[idx[1]] != j) && (msg[idx[1]] != msg[idx[2]])) {
            ret = slhdsakey_chain_idx_x4_16(node, msg[idx[1]],
                msg[idx[2]] - msg[idx[1]], pk_seed, addr, idx, key->heap);
        }
    }
    if (ret == 0) {
        /* Step 3: lanes 0-2 catch up with lane 3's start (full groups). */
        XMEMCPY(node + 2 * 16, sig + idx[2] * 16, 16);
        if ((cnt > 3) && (msg[idx[2]] != j)) {
            ret = slhdsakey_chain_idx_x4_16(node, msg[idx[2]],
                j - msg[idx[2]], pk_seed, addr, idx, key->heap);
        }
    }
    if (ret == 0) {
        /* Step 4: all lanes run to the end of the chain. With cnt == 3 lane 3
         * stays zero filler and its result is discarded below. */
        if (cnt > 3) {
            XMEMCPY(node + 3 * 16, sig + idx[3] * 16, 16);
        }
        if (j != SLHDSA_WM1) {
            ret = slhdsakey_chain_idx_x4_16(node, j, SLHDSA_WM1 - j, pk_seed,
                addr, idx, key->heap);
        }
    }
    if (ret == 0) {
        XMEMCPY(nodes + idx[0] * 16, node + 0 * 16, 16);
        XMEMCPY(nodes + idx[1] * 16, node + 1 * 16, 16);
        XMEMCPY(nodes + idx[2] * 16, node + 2 * 16, 16);
        if (cnt > 3) {
            XMEMCPY(nodes + idx[3] * 16, node + 3 * 16, 16);
        }
    }
    return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
/* Advance a group of four (or three) WOTS+ signature chains for n = 24 from
 * their signature values to the chain end (digit w - 1) and store the ends.
 *
 * The caller gathers idx[] while counting j UP from 0, so lanes are in
 * ascending order of start digit: msg[idx[0]] <= msg[idx[1]] <= msg[idx[2]]
 * <= j. Lanes join as the shared chain position reaches their start:
 *   1. lane 0 alone (scalar): msg[idx[0]] -> msg[idx[1]]
 *   2. 4-way: msg[idx[1]] -> msg[idx[2]] (lanes 2-3 are zero filler)
 *   3. 4-way: msg[idx[2]] -> j, only when cnt > 3 (for cnt == 3 the caller
 *      passes j == max start, i.e. msg[idx[2]])
 *   4. 4-way: j -> w - 1 for all lanes
 * The 4-way helper is presumably called as (node, start, steps, ...).
 *
 * key      SLH-DSA key: heap hint and scalar hash state.
 * sig      WOTS+ signature; chain idx[k] is read at sig + idx[k] * 24.
 * pk_seed  Public seed.
 * adrs     Hash address; chain address set to idx[0] (modified).
 * msg      Base-w digits: start position of each chain.
 * idx      Chain indices of the lanes.
 * j        Start digit of the last lane (largest in the group).
 * cnt      Number of valid lanes: 4, or 3 when lane 3 is unused.
 * nodes    Output; chain end idx[k] is written at nodes + idx[k] * 24.
 * Returns 0 on success or an error from the chain helpers.
 */
static int slhdsakey_chain_idx_to_max_24(SlhDsaKey* key, const byte* sig,
    const byte* pk_seed, word32* adrs, const byte* msg, byte* idx, int j,
    int cnt, byte* nodes)
{
    int ret = 0;
    byte node[4 * 24];
    byte addr[SLHDSA_HA_SZ];
    HA_SetChainAddress(adrs, idx[0]);
    HA_Encode(adrs, addr);
    /* Step 1: lane 0 alone catches up with lane 1's start. */
    XMEMCPY(node + 0 * 24, sig + idx[0] * 24, 24);
    if ((msg[idx[0]] != j) && (msg[idx[0]] != msg[idx[1]])) {
        ret = slhdsakey_chain(key, node, msg[idx[0]],
            msg[idx[1]] - msg[idx[0]], pk_seed, adrs, node);
    }
    if (ret == 0) {
        /* Step 2: lanes 0-1 catch up with lane 2's start. Lanes 2-3 are
         * zeroed so the 4-way code never reads uninitialized data. */
        XMEMCPY(node + 1 * 24, sig + idx[1] * 24, 24);
        XMEMSET(node + 2 * 24, 0, sizeof(node) - 2 * 24);
        if ((msg[idx[1]] != j) && (msg[idx[1]] != msg[idx[2]])) {
            ret = slhdsakey_chain_idx_x4_24(node, msg[idx[1]],
                msg[idx[2]] - msg[idx[1]], pk_seed, addr, idx, key->heap);
        }
    }
    if (ret == 0) {
        /* Step 3: lanes 0-2 catch up with lane 3's start (full groups). */
        XMEMCPY(node + 2 * 24, sig + idx[2] * 24, 24);
        if ((cnt > 3) && (msg[idx[2]] != j)) {
            ret = slhdsakey_chain_idx_x4_24(node, msg[idx[2]],
                j - msg[idx[2]], pk_seed, addr, idx, key->heap);
        }
    }
    if (ret == 0) {
        /* Step 4: all lanes run to the end of the chain. With cnt == 3 lane 3
         * stays zero filler and its result is discarded below. */
        if (cnt > 3) {
            XMEMCPY(node + 3 * 24, sig + idx[3] * 24, 24);
        }
        if (j != SLHDSA_WM1) {
            ret = slhdsakey_chain_idx_x4_24(node, j, SLHDSA_WM1 - j, pk_seed,
                addr, idx, key->heap);
        }
    }
    if (ret == 0) {
        XMEMCPY(nodes + idx[0] * 24, node + 0 * 24, 24);
        XMEMCPY(nodes + idx[1] * 24, node + 1 * 24, 24);
        XMEMCPY(nodes + idx[2] * 24, node + 2 * 24, 24);
        if (cnt > 3) {
            XMEMCPY(nodes + idx[3] * 24, node + 3 * 24, 24);
        }
    }
    return ret;
}
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
/* Advance a group of four (or three) WOTS+ signature chains for n = 32 from
 * their signature values to the chain end (digit w - 1) and store the ends.
 *
 * The caller gathers idx[] while counting j UP from 0, so lanes are in
 * ascending order of start digit: msg[idx[0]] <= msg[idx[1]] <= msg[idx[2]]
 * <= j. Lanes join as the shared chain position reaches their start:
 *   1. lane 0 alone (scalar): msg[idx[0]] -> msg[idx[1]]
 *   2. 4-way: msg[idx[1]] -> msg[idx[2]] (lanes 2-3 are zero filler)
 *   3. 4-way: msg[idx[2]] -> j, only when cnt > 3 (for cnt == 3 the caller
 *      passes j == max start, i.e. msg[idx[2]])
 *   4. 4-way: j -> w - 1 for all lanes
 * The 4-way helper is presumably called as (node, start, steps, ...).
 *
 * key      SLH-DSA key: heap hint and scalar hash state.
 * sig      WOTS+ signature; chain idx[k] is read at sig + idx[k] * 32.
 * pk_seed  Public seed.
 * adrs     Hash address; chain address set to idx[0] (modified).
 * msg      Base-w digits: start position of each chain.
 * idx      Chain indices of the lanes.
 * j        Start digit of the last lane (largest in the group).
 * cnt      Number of valid lanes: 4, or 3 when lane 3 is unused.
 * nodes    Output; chain end idx[k] is written at nodes + idx[k] * 32.
 * Returns 0 on success or an error from the chain helpers.
 */
static int slhdsakey_chain_idx_to_max_32(SlhDsaKey* key, const byte* sig,
    const byte* pk_seed, word32* adrs, const byte* msg, byte* idx, int j,
    int cnt, byte* nodes)
{
    int ret = 0;
    byte node[4 * 32];
    byte addr[SLHDSA_HA_SZ];
    HA_SetChainAddress(adrs, idx[0]);
    HA_Encode(adrs, addr);
    /* Step 1: lane 0 alone catches up with lane 1's start. */
    XMEMCPY(node + 0 * 32, sig + idx[0] * 32, 32);
    if ((msg[idx[0]] != j) && (msg[idx[0]] != msg[idx[1]])) {
        ret = slhdsakey_chain(key, node, msg[idx[0]],
            msg[idx[1]] - msg[idx[0]], pk_seed, adrs, node);
    }
    if (ret == 0) {
        /* Step 2: lanes 0-1 catch up with lane 2's start. Lanes 2-3 are
         * zeroed so the 4-way code never reads uninitialized data. */
        XMEMCPY(node + 1 * 32, sig + idx[1] * 32, 32);
        XMEMSET(node + 2 * 32, 0, sizeof(node) - 2 * 32);
        if ((msg[idx[1]] != j) && (msg[idx[1]] != msg[idx[2]])) {
            ret = slhdsakey_chain_idx_x4_32(node, msg[idx[1]],
                msg[idx[2]] - msg[idx[1]], pk_seed, addr, idx, key->heap);
        }
    }
    if (ret == 0) {
        /* Step 3: lanes 0-2 catch up with lane 3's start (full groups). */
        XMEMCPY(node + 2 * 32, sig + idx[2] * 32, 32);
        if ((cnt > 3) && (msg[idx[2]] != j)) {
            ret = slhdsakey_chain_idx_x4_32(node, msg[idx[2]],
                j - msg[idx[2]], pk_seed, addr, idx, key->heap);
        }
    }
    if (ret == 0) {
        /* Step 4: all lanes run to the end of the chain. With cnt == 3 lane 3
         * stays zero filler and its result is discarded below. */
        if (cnt > 3) {
            XMEMCPY(node + 3 * 32, sig + idx[3] * 32, 32);
        }
        if (j != SLHDSA_WM1) {
            ret = slhdsakey_chain_idx_x4_32(node, j, SLHDSA_WM1 - j, pk_seed,
                addr, idx, key->heap);
        }
    }
    if (ret == 0) {
        XMEMCPY(nodes + idx[0] * 32, node + 0 * 32, 32);
        XMEMCPY(nodes + idx[1] * 32, node + 1 * 32, 32);
        XMEMCPY(nodes + idx[2] * 32, node + 2 * 32, 32);
        if (cnt > 3) {
            XMEMCPY(nodes + idx[3] * 32, node + 3 * 32, 32);
        }
    }
    return ret;
}
#endif
#endif
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
/* Recompute a WOTS+ public key from a signature with the 4-way (AVX2) SHAKE
 * code.
 *
 * Chains are grouped four at a time while counting the start digit j up
 * from 0. This lets slhdsakey_chain_idx_to_max_* bring lanes into the 4-way
 * computation as the shared chain position reaches them. len (2n + 3 digits
 * for w = 16) leaves three chains for the final group (cnt = 3).
 *
 * key      SLH-DSA key: params, heap hint and key->shake2.
 * sig      WOTS+ signature: len chain values of n bytes.
 * msg      len base-w digits (message followed by checksum).
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address (chain address modified).
 * pk_sig   Output: n-byte WOTS+ public key.
 * Returns 0 on success, MEMORY_E on allocation failure, NOT_COMPILED_IN when
 * no implementation for n is built in, or a helper error.
 */
static int slhdsakey_wots_pk_from_sig_x4(SlhDsaKey* key, const byte* sig,
    const byte* msg, const byte* pk_seed, word32* adrs, byte* pk_sig)
{
    int ret = 0;
    byte idx[4] = {0};
    int i;
    byte ii;
    sword8 j;
    HashAddress wotspk_adrs;
    byte n = key->params->n;
    byte len = key->params->len;
    WC_DECLARE_VAR(nodes, byte, SLHDSA_MAX_MSG_SZ * SLHDSA_MAX_N, key->heap);

    WC_ALLOC_VAR_EX(nodes, byte, SLHDSA_MAX_MSG_SZ * SLHDSA_MAX_N, key->heap,
        DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
    /* In each digit scan below the outer loop also tests ret. Breaking the
     * inner loop alone leaves ii == 4, and the next matching digit would then
     * be stored past the end of idx[]. */
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_128)
    if ((ret == 0) && (n == 16)) {
        ii = 0;
        for (j = 0; (ret == 0) && (j <= SLHDSA_WM1); j++) {
            for (i = 0; i < len; i++) {
                if ((sword8)msg[i] == j) {
                    idx[ii++] = i;
                    if (ii == 4) {
                        ret = slhdsakey_chain_idx_to_max_16(key, sig,
                            pk_seed, adrs, msg, idx, j, 4, nodes);
                        if (ret != 0) {
                            break;
                        }
                        ii = 0;
                    }
                }
            }
        }
        if (ret == 0) {
            /* Remaining three chains: the last lane's start is the max. */
            j = max(max(msg[idx[0]], msg[idx[1]]), msg[idx[2]]);
            ret = slhdsakey_chain_idx_to_max_16(key, sig, pk_seed, adrs, msg,
                idx, j, 3, nodes);
        }
    }
    else
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_192)
    if ((ret == 0) && (n == 24)) {
        ii = 0;
        for (j = 0; (ret == 0) && (j <= SLHDSA_WM1); j++) {
            for (i = 0; i < len; i++) {
                if ((sword8)msg[i] == j) {
                    idx[ii++] = i;
                    if (ii == 4) {
                        ret = slhdsakey_chain_idx_to_max_24(key, sig,
                            pk_seed, adrs, msg, idx, j, 4, nodes);
                        if (ret != 0) {
                            break;
                        }
                        ii = 0;
                    }
                }
            }
        }
        if (ret == 0) {
            /* Remaining three chains: the last lane's start is the max. */
            j = max(max(msg[idx[0]], msg[idx[1]]), msg[idx[2]]);
            ret = slhdsakey_chain_idx_to_max_24(key, sig, pk_seed, adrs, msg,
                idx, j, 3, nodes);
        }
    }
    else
#endif
#if !defined(WOLFSSL_SLHDSA_PARAM_NO_256)
    if ((ret == 0) && (n == 32)) {
        ii = 0;
        for (j = 0; (ret == 0) && (j <= SLHDSA_WM1); j++) {
            for (i = 0; i < len; i++) {
                if ((sword8)msg[i] == j) {
                    idx[ii++] = i;
                    if (ii == 4) {
                        ret = slhdsakey_chain_idx_to_max_32(key, sig,
                            pk_seed, adrs, msg, idx, j, 4, nodes);
                        if (ret != 0) {
                            break;
                        }
                        ii = 0;
                    }
                }
            }
        }
        if (ret == 0) {
            /* Remaining three chains: the last lane's start is the max. */
            j = max(max(msg[idx[0]], msg[idx[1]]), msg[idx[2]]);
            ret = slhdsakey_chain_idx_to_max_32(key, sig, pk_seed, adrs, msg,
                idx, j, 3, nodes);
        }
    }
    else
#endif
    if (ret == 0) {
        ret = NOT_COMPILED_IN;
    }
    /* Compress the chain ends under a WOTS_PK address. */
    if (ret == 0) {
        HA_Copy(wotspk_adrs, adrs);
        HA_SetTypeAndClearNotKPA(wotspk_adrs, HA_WOTS_PK);
        ret = slhdsakey_hash_start_addr(&key->shake2, pk_seed, wotspk_adrs, n);
    }
    if (ret == 0) {
        ret = slhdsakey_hash_update(&key->shake2, nodes, len * n);
    }
    if (ret == 0) {
        ret = slhdsakey_hash_final(&key->shake2, pk_sig, n);
    }
    WC_FREE_VAR_EX(nodes, key->heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#endif
#if !defined(WOLFSSL_WC_SLHDSA_SMALL_MEM)
/* Recompute a WOTS+ public key from a signature with the portable code,
 * buffering all chain ends and hashing them with one update.
 *
 * Chain k is advanced from digit msg[k] to w - 1.
 *
 * key      SLH-DSA key: params, heap hint and hash states.
 * sig      WOTS+ signature: len chain values of n bytes.
 * msg      len base-w digits (message followed by checksum).
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address (chain address modified).
 * pk_sig   Output: n-byte WOTS+ public key.
 * Returns 0 on success, MEMORY_E on allocation failure, or a helper error.
 */
static int slhdsakey_wots_pk_from_sig_c(SlhDsaKey* key, const byte* sig,
    const byte* msg, const byte* pk_seed, word32* adrs, byte* pk_sig)
{
    int ret = 0;
    int k;
    byte n = key->params->n;
    byte len = key->params->len;
    HashAddress pk_adrs;
    WC_DECLARE_VAR(nodes, byte, SLHDSA_MAX_MSG_SZ * SLHDSA_MAX_N, key->heap);

    WC_ALLOC_VAR_EX(nodes, byte, SLHDSA_MAX_MSG_SZ * SLHDSA_MAX_N, key->heap,
        DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
    for (k = 0; (ret == 0) && (k < len); k++) {
        HA_SetChainAddress(adrs, k);
        ret = slhdsakey_chain(key, sig + k * n, msg[k], SLHDSA_WM1 - msg[k],
            pk_seed, adrs, nodes + k * n);
    }
    if (ret == 0) {
        HA_Copy(pk_adrs, adrs);
        HA_SetTypeAndClearNotKPA(pk_adrs, HA_WOTS_PK);
        ret = slhdsakey_hash_start_addr(&key->shake2, pk_seed, pk_adrs, n);
    }
    if (ret == 0) {
        ret = slhdsakey_hash_update(&key->shake2, nodes, len * n);
    }
    if (ret == 0) {
        ret = slhdsakey_hash_final(&key->shake2, pk_sig, n);
    }
    WC_FREE_VAR_EX(nodes, key->heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#else
/* Recompute a WOTS+ public key from a signature with the portable code,
 * absorbing each chain end as soon as it is computed (small-memory build).
 *
 * pk_sig doubles as the scratch buffer for each chain end before it receives
 * the final hash.
 *
 * key      SLH-DSA key: params and hash states.
 * sig      WOTS+ signature: len chain values of n bytes.
 * msg      len base-w digits (message followed by checksum).
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address (chain address modified).
 * pk_sig   Output: n-byte WOTS+ public key.
 * Returns 0 on success or a helper error.
 */
static int slhdsakey_wots_pk_from_sig_c(SlhDsaKey* key, const byte* sig,
    const byte* msg, const byte* pk_seed, word32* adrs, byte* pk_sig)
{
    int ret;
    int k;
    byte n = key->params->n;
    byte len = key->params->len;
    HashAddress pk_adrs;

    HA_Copy(pk_adrs, adrs);
    HA_SetTypeAndClearNotKPA(pk_adrs, HA_WOTS_PK);
    ret = slhdsakey_hash_start_addr(&key->shake2, pk_seed, pk_adrs, n);
    for (k = 0; (ret == 0) && (k < len); k++) {
        HA_SetChainAddress(adrs, k);
        ret = slhdsakey_chain(key, sig + k * n, msg[k], SLHDSA_WM1 - msg[k],
            pk_seed, adrs, pk_sig);
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake2, pk_sig, n);
        }
    }
    if (ret == 0) {
        ret = slhdsakey_hash_final(&key->shake2, pk_sig, n);
    }
    return ret;
}
#endif
/* Recompute a WOTS+ public key from a signature of the n-byte message m.
 *
 * m is expanded into 2n base-16 digits plus three checksum digits, exactly
 * as when signing. The chosen implementation then finishes every chain.
 *
 * key      SLH-DSA key.
 * sig      WOTS+ signature.
 * m        n-byte message.
 * pk_seed  Public seed.
 * adrs     WOTS_HASH hash address of the key pair (chain address modified).
 * pk_sig   Output: n-byte WOTS+ public key.
 * Returns 0 on success or an error from the implementation.
 */
static int slhdsakey_wots_pk_from_sig(SlhDsaKey* key, const byte* sig,
    const byte* m, const byte* pk_seed, word32* adrs, byte* pk_sig)
{
    int ret;
    int k;
    word16 csum = 0;
    byte n = key->params->n;
    byte msg[SLHDSA_MAX_MSG_SZ];

    /* High nibble first; accumulate the checksum as we go. */
    for (k = 0; k < n; k++) {
        byte hi = (byte)((m[k] >> 4) & 0xf);
        byte lo = (byte)(m[k] & 0xf);

        msg[2 * k + 0] = hi;
        msg[2 * k + 1] = lo;
        csum += (word16)((SLHDSA_WM1 - hi) + (SLHDSA_WM1 - lo));
    }
    /* Checksum as three base-16 digits, most significant first. */
    msg[2 * n + 0] = (byte)((csum >> 8) & 0xf);
    msg[2 * n + 1] = (byte)((csum >> 4) & 0xf);
    msg[2 * n + 2] = (byte)(csum & 0xf);

#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
    if (IS_INTEL_AVX2(cpuid_flags) && (SAVE_VECTOR_REGISTERS2() == 0)) {
        ret = slhdsakey_wots_pk_from_sig_x4(key, sig, msg, pk_seed, adrs,
            pk_sig);
        RESTORE_VECTOR_REGISTERS();
    }
    else
#endif
    {
        ret = slhdsakey_wots_pk_from_sig_c(key, sig, msg, pk_seed, adrs,
            pk_sig);
    }
    return ret;
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
#ifndef WOLFSSL_WC_SLHDSA_RECURSIVE
/* Compute the XMSS node at height z and index i (non-recursive treehash).
 *
 * Leaves (WOTS+ public keys) are generated left to right, and a pair is
 * merged with H as soon as its right child exists. nodes is used as a stack
 * of slots: slots k and k+1 hold the left and right child at height z-1-k.
 * The two children of the result therefore end up in slots 0 and 1. Slots
 * 0..z are used, so z must not exceed SLHDSA_MAX_H_M + 1.
 *
 * key      SLH-DSA key: params, heap hint and hash states.
 * sk_seed  Secret seed.
 * i        Node index at height z.
 * z        Height of the node to compute.
 * pk_seed  Public seed.
 * adrs     Hash address; type, key pair and tree fields are modified.
 * node     Output: n-byte node value.
 * Returns 0 on success, MEMORY_E on allocation failure, or a helper error.
 */
static int slhdsakey_xmss_node(SlhDsaKey* key, const byte* sk_seed, int i,
    int z, const byte* pk_seed, word32* adrs, byte* node)
{
    int ret = 0;
    if (z == 0) {
        /* A height-0 node is the WOTS+ public key of key pair i. */
        HA_SetTypeAndClearNotKPA(adrs, HA_WOTS_HASH);
        HA_SetKeyPairAddress(adrs, i);
        ret = slhdsakey_wots_pkgen(key, sk_seed, pk_seed, adrs, node);
    }
    else {
        WC_DECLARE_VAR(nodes, byte, (SLHDSA_MAX_H_M + 2) * SLHDSA_MAX_N,
            key->heap);
        word32 j;
        word32 k;
        word32 m = (word32)1 << z;
        byte n = key->params->n;
        WC_ALLOC_VAR_EX(nodes, byte, (SLHDSA_MAX_H_M + 2) * SLHDSA_MAX_N,
            key->heap, DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
        if (ret == 0) {
            for (j = 0; j < m; j++) {
                /* Leaf j: even leaves go to slot z-1, odd leaves to slot z. */
                HA_SetTypeAndClearNotKPA(adrs, HA_WOTS_HASH);
                HA_SetKeyPairAddress(adrs, m * i + j);
                ret = slhdsakey_wots_pkgen(key, sk_seed, pk_seed, adrs,
                    nodes + (z - 1 + (j & 1)) * n);
                if (ret != 0) {
                    break;
                }
                /* Walk up while the node just completed is a right child:
                 * merge slots k and k+1 into the parent's slot one level up
                 * (k-1 for a left parent, k for a right parent). */
                for (k = z-1; k > 0; k--) {
                    if (((j >> (z-1-k)) & 1) == 1) {
                        HA_SetTypeAndClear(adrs, HA_TREE);
                        HA_SetTreeHeight(adrs, z - k);
                        HA_SetTreeIndex(adrs, (m * i + j) >> (z - k));
                        ret = HASH_H(&key->shake, pk_seed, adrs, nodes + k * n,
                            n, nodes + (k - 1 + ((j >> (z-k)) & 1)) * n);
                        if (ret != 0) {
                            break;
                        }
                    }
                    else {
                        /* Left child: its sibling is not computed yet. */
                        break;
                    }
                }
                if (ret != 0) {
                    break;
                }
            }
            /* Slots 0 and 1 now hold the two children at height z-1. */
            if (ret == 0) {
                HA_SetTypeAndClear(adrs, HA_TREE);
                HA_SetTreeHeight(adrs, z);
                HA_SetTreeIndex(adrs, i);
                ret = HASH_H(&key->shake, pk_seed, adrs, nodes, n, node);
            }
        }
        WC_FREE_VAR_EX(nodes, key->heap, DYNAMIC_TYPE_SLHDSA);
    }
    return ret;
}
#else
/* Compute the XMSS node at height z and index i (recursive treehash).
 *
 * A height-0 node is the WOTS+ public key of key pair i. Higher nodes hash
 * their two children, computed recursively at indices 2i and 2i+1.
 *
 * key      SLH-DSA key: params and hash states.
 * sk_seed  Secret seed.
 * i        Node index at height z.
 * z        Height of the node to compute.
 * pk_seed  Public seed.
 * adrs     Hash address; type, key pair and tree fields are modified.
 * node     Output: n-byte node value.
 * Returns 0 on success or a helper error.
 */
static int slhdsakey_xmss_node(SlhDsaKey* key, const byte* sk_seed, int i,
    int z, const byte* pk_seed, word32* adrs, byte* node)
{
    int ret;
    byte children[2 * SLHDSA_MAX_N];

    if (z != 0) {
        byte n = key->params->n;
        byte* left = children;
        byte* right = children + n;

        ret = slhdsakey_xmss_node(key, sk_seed, 2 * i, z - 1, pk_seed, adrs,
            left);
        if (ret == 0) {
            ret = slhdsakey_xmss_node(key, sk_seed, 2 * i + 1, z - 1, pk_seed,
                adrs, right);
        }
        if (ret == 0) {
            HA_SetTypeAndClear(adrs, HA_TREE);
            HA_SetTreeHeight(adrs, z);
            HA_SetTreeIndex(adrs, i);
            ret = HASH_H(&key->shake, pk_seed, adrs, children, n, node);
        }
    }
    else {
        HA_SetTypeAndClearNotKPA(adrs, HA_WOTS_HASH);
        HA_SetKeyPairAddress(adrs, i);
        ret = slhdsakey_wots_pkgen(key, sk_seed, pk_seed, adrs, node);
    }
    return ret;
}
#endif
/* Produce an XMSS signature of the n-byte message m with leaf idx.
 *
 * Output layout: the WOTS+ signature (len * n bytes) followed by the
 * authentication path (h' nodes of n bytes). Each path node is the sibling
 * of the ancestor of leaf idx at that height.
 *
 * key       SLH-DSA key: params and hash states.
 * m         n-byte message.
 * sk_seed   Secret seed.
 * idx       Leaf index within the XMSS tree.
 * pk_seed   Public seed.
 * adrs      Hash address of the tree (type and fields modified).
 * sig_xmss  Output: (len + h') * n bytes.
 * Returns 0 on success or a helper error.
 */
static int slhdsakey_xmss_sign(SlhDsaKey* key, const byte* m,
    const byte* sk_seed, word32 idx, const byte* pk_seed, word32* adrs,
    byte* sig_xmss)
{
    /* Initialized so a tree height of 0 (no authentication path) does not
     * read an indeterminate value below. */
    int ret = 0;
    byte n = key->params->n;
    byte len = key->params->len;
    byte h_m = key->params->h_m;
    byte* auth = sig_xmss + (len * n);
    word32 i = idx;
    int j;

    /* Authentication path: sibling (index ^ 1) at each height j. */
    for (j = 0; j < h_m; j++) {
        word32 k = i ^ 1;
        ret = slhdsakey_xmss_node(key, sk_seed, k, j, pk_seed, adrs, auth);
        if (ret != 0) {
            break;
        }
        auth += n;
        i >>= 1;
    }
    if (ret == 0) {
        HA_SetTypeAndClearNotKPA(adrs, HA_WOTS_HASH);
        HA_SetKeyPairAddress(adrs, idx);
        ret = slhdsakey_wots_sign(key, m, sk_seed, pk_seed, adrs, sig_xmss);
    }
    return ret;
}
#endif
/* Recompute the root of an XMSS tree from a signature.
 *
 * The WOTS+ public key of leaf idx is recovered first. It is then hashed up
 * the tree with the authentication path; the bit of idx at each height says
 * whether the current node is the left (0) or right (1) child.
 *
 * key       SLH-DSA key: params and hash states.
 * idx       Leaf index within the XMSS tree.
 * sig_xmss  XMSS signature: WOTS+ signature then authentication path.
 * m         n-byte message that was signed.
 * pk_seed   Public seed.
 * adrs      Hash address of the tree (type and fields modified).
 * node      Output: n-byte root (may alias m).
 * Returns 0 on success or a helper error.
 */
static int slhdsakey_xmss_pk_from_sig(SlhDsaKey* key, word32 idx,
    const byte* sig_xmss, const byte* m, const byte* pk_seed, word32* adrs,
    byte* node)
{
    int ret;
    int height;
    byte n = key->params->n;
    byte h_m = key->params->h_m;
    const byte* auth = sig_xmss + (key->params->len * n);

    HA_SetTypeAndClear(adrs, HA_WOTS_HASH);
    HA_SetKeyPairAddress(adrs, idx);
    ret = slhdsakey_wots_pk_from_sig(key, sig_xmss, m, pk_seed, adrs, node);
    if (ret == 0) {
        HA_SetTypeAndClear(adrs, HA_TREE);
        HA_SetTreeIndex(adrs, idx);
    }
    for (height = 1; (ret == 0) && (height <= h_m); height++) {
        int is_right = (int)(idx & 1);

        idx >>= 1;
        HA_SetTreeHeight(adrs, height);
        HA_SetTreeIndex(adrs, idx);
        if (is_right) {
            ret = HASH_H_2(&key->shake, pk_seed, adrs, auth, node, n, node);
        }
        else {
            ret = HASH_H_2(&key->shake, pk_seed, adrs, node, auth, n, node);
        }
        auth += n;
    }
    return ret;
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
/* Sign a message (the FORS public key) with the hypertree: one XMSS
 * signature per layer, from layer 0 up to layer d - 1.
 *
 * Each layer above 0 signs the root of the XMSS tree below it. The root of
 * the top layer is the public root and is never used, so it is not
 * recomputed (FIPS 205 Algorithm 12: "if j < d - 1").
 *
 * key       SLH-DSA key: params and hash states.
 * pk_fors   n-byte message signed at layer 0.
 * sk_seed   Secret seed.
 * pk_seed   Public seed.
 * idx_tree  Tree index; shifted down by h' per layer (presumably modified in
 *           place by INDEX_TREE_SHIFT_DOWN).
 * idx_leaf  Leaf index within the layer-0 tree.
 * sig_ht    Output: d XMSS signatures of (h' + len) * n bytes each.
 * Returns 0 on success or a helper error.
 */
static int slhdsakey_ht_sign(SlhDsaKey* key, const byte* pk_fors,
    const byte* sk_seed, const byte* pk_seed, word32* idx_tree, word32 idx_leaf,
    byte* sig_ht)
{
    int ret;
    HashAddress adrs;
    byte root[SLHDSA_MAX_N];
    byte n = key->params->n;
    byte h_m = key->params->h_m;
    byte len = key->params->len;
    byte d = key->params->d;
    int j;
    word32 mask = ((word32)1 << h_m) - 1;

    HA_Init(adrs);
    HA_SetTreeAddress(adrs, idx_tree);
    ret = slhdsakey_xmss_sign(key, pk_fors, sk_seed, idx_leaf, pk_seed, adrs,
        sig_ht);
    if (ret == 0) {
        ret = slhdsakey_xmss_pk_from_sig(key, idx_leaf, sig_ht, pk_fors,
            pk_seed, adrs, root);
        sig_ht += (h_m + len) * n;
    }
    if (ret == 0) {
        for (j = 1; j < d; j++) {
            idx_leaf = INDEX_TREE_MASK(idx_tree, mask);
            INDEX_TREE_SHIFT_DOWN(idx_tree, h_m);
            HA_SetLayerAddress(adrs, j);
            HA_SetTreeAddress(adrs, idx_tree);
            ret = slhdsakey_xmss_sign(key, root, sk_seed, idx_leaf, pk_seed,
                adrs, sig_ht);
            if (ret != 0) {
                break;
            }
            /* The next layer signs this root; the top layer has none. */
            if (j < d - 1) {
                ret = slhdsakey_xmss_pk_from_sig(key, idx_leaf, sig_ht, root,
                    pk_seed, adrs, root);
                if (ret != 0) {
                    break;
                }
            }
            sig_ht += (h_m + len) * n;
        }
    }
    return ret;
}
#endif
/* Verify a hypertree signature.
 *
 * The XMSS root is recomputed layer by layer, starting from message m at
 * layer 0, and the top root is compared with pk_root.
 *
 * key       SLH-DSA key: params and hash states.
 * m         n-byte message (the FORS public key).
 * sig_ht    d XMSS signatures of (h' + len) * n bytes each.
 * pk_seed   Public seed.
 * idx_tree  Tree index; shifted down by h' per layer (presumably modified in
 *           place by INDEX_TREE_SHIFT_DOWN).
 * idx_leaf  Leaf index within the layer-0 tree.
 * pk_root   Expected n-byte public root.
 * Returns 0 when the roots match, SIG_VERIFY_E when they differ, or a helper
 * error.
 */
static int slhdsakey_ht_verify(SlhDsaKey* key, const byte* m,
    const byte* sig_ht, const byte* pk_seed, word32* idx_tree, word32 idx_leaf,
    const byte* pk_root)
{
    int ret;
    int layer;
    HashAddress adrs;
    byte node[SLHDSA_MAX_N];
    byte n = key->params->n;
    byte h_m = key->params->h_m;
    byte d = key->params->d;
    int xmss_sig_sz = (h_m + key->params->len) * n;
    word32 mask = ((word32)1 << h_m) - 1;

    HA_Init(adrs);
    HA_SetTreeAddress(adrs, idx_tree);
    ret = slhdsakey_xmss_pk_from_sig(key, idx_leaf, sig_ht, m, pk_seed, adrs,
        node);
    sig_ht += xmss_sig_sz;
    for (layer = 1; (ret == 0) && (layer < d); layer++) {
        idx_leaf = INDEX_TREE_MASK(idx_tree, mask);
        INDEX_TREE_SHIFT_DOWN(idx_tree, h_m);
        HA_SetLayerAddress(adrs, layer);
        HA_SetTreeAddress(adrs, idx_tree);
        ret = slhdsakey_xmss_pk_from_sig(key, idx_leaf, sig_ht, node,
            pk_seed, adrs, node);
        sig_ht += xmss_sig_sz;
    }
    if ((ret == 0) && (XMEMCMP(node, pk_root, n) != 0)) {
        ret = SIG_VERIFY_E;
    }
    return ret;
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
/* Derive the FORS secret value at tree index idx.
 *
 * PRF is applied under a FORS_PRF copy of adrs; adrs itself is not changed.
 *
 * key      SLH-DSA key: params and hash state.
 * sk_seed  Secret seed.
 * pk_seed  Public seed.
 * adrs     FORS hash address to derive from.
 * idx      Tree index of the leaf.
 * node     Output: n-byte secret value.
 * Returns the result of HASH_PRF.
 */
static int slhdsakey_fors_sk_gen(SlhDsaKey* key, const byte* sk_seed,
    const byte* pk_seed, word32* adrs, word32 idx, byte* node)
{
    HashAddress prf_adrs;
    byte n = key->params->n;

    HA_Copy(prf_adrs, adrs);
    HA_SetTypeAndClearNotKPA(prf_adrs, HA_FORS_PRF);
    HA_SetTreeIndex(prf_adrs, idx);
    return HASH_PRF(&key->shake, pk_seed, sk_seed, prf_adrs, n, node);
}
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
/* Compute four PRF outputs with one 4-way SHAKE256 permutation.
 *
 * The four outputs use tree index ti (presumably ti..ti+3 across the lanes;
 * the caller stores them as four consecutive nodes).
 *
 * pk_seed  Public seed.
 * sk_seed  Secret seed.
 * addr     Encoded FORS_PRF address.
 * n        Output length per lane.
 * ti       Base tree index.
 * node     Output: 4 * n bytes.
 * heap     Heap hint for the state allocation.
 * Returns 0 on success, MEMORY_E on allocation failure, or the error from
 * SAVE_VECTOR_REGISTERS2.
 */
static int slhdsakey_hash_prf_ti_x4(const byte* pk_seed, const byte* sk_seed,
    byte* addr, byte n, int ti, byte* node, void* heap)
{
    int ret = 0;
    word32 pos;
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;
    WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret != 0) {
        return ret;
    }

    pos = slhdsakey_shake256_set_seed_ha_hash_x4(state, pk_seed, addr,
        sk_seed, n);
    SHAKE256_SET_TREE_INDEX(state, pos, ti);
    ret = SAVE_VECTOR_REGISTERS2();
    if (ret == 0) {
        sha3_blocksx4_avx2(state);
        RESTORE_VECTOR_REGISTERS();
        slhdsakey_shake256_get_hash_x4(state, node, n);
    }
    WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
/* Apply F to four consecutive n-byte nodes in place with one 4-way SHAKE256
 * permutation.
 *
 * pk_seed  Public seed.
 * addr     Encoded hash address.
 * node     4 * n bytes: inputs, overwritten with the outputs.
 * n        Node length (a multiple of 8).
 * ti       Base tree index for the four lanes.
 * heap     Heap hint for the state allocation.
 * Returns 0 on success, MEMORY_E on allocation failure, or the error from
 * SAVE_VECTOR_REGISTERS2.
 */
static int slhdsakey_hash_f_ti_x4(const byte* pk_seed, byte* addr, byte* node,
    byte n, word32 ti, void* heap)
{
    int ret = 0;
    int w;
    word32 pos;
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;
    WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret != 0) {
        return ret;
    }

    pos = slhdsakey_shake256_set_seed_ha_x4(state, pk_seed, addr, n);
    SHAKE256_SET_TREE_INDEX(state, pos, ti);
    /* Interleave the four inputs: one 64-bit word of each lane at a time. */
    for (w = 0; w < n / 8; w++) {
        int lane;

        for (lane = 0; lane < 4; lane++) {
            state[pos + lane] = ((word64*)(node + lane * n))[w];
        }
        pos += 4;
    }
    SHAKE256_SET_END_X4(state, pos);
    ret = SAVE_VECTOR_REGISTERS2();
    if (ret == 0) {
        sha3_blocksx4_avx2(state);
        RESTORE_VECTOR_REGISTERS();
        slhdsakey_shake256_get_hash_x4(state, node, n);
    }
    WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
/* Apply H to four consecutive child pairs with one 4-way SHAKE256
 * permutation.
 *
 * pk_seed  Public seed.
 * addr     Encoded hash address.
 * m        8 * n bytes: four (left || right) pairs of n-byte children.
 * n        Node length (a multiple of 8).
 * ti       Base tree index for the four lanes.
 * hash     Output: 4 * n bytes (may overlap the start of m).
 * heap     Heap hint for the state allocation.
 * Returns 0 on success, MEMORY_E on allocation failure, or the error from
 * SAVE_VECTOR_REGISTERS2.
 */
static int slhdsakey_hash_h_ti_x4(const byte* pk_seed, byte* addr,
    const byte* m, byte n, word32 ti, byte* hash, void* heap)
{
    int ret = 0;
    int w;
    word32 pos;
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;
    WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret != 0) {
        return ret;
    }

    pos = slhdsakey_shake256_set_seed_ha_x4(state, pk_seed, addr, n);
    SHAKE256_SET_TREE_INDEX(state, pos, ti);
    /* Interleave the four 2n-byte inputs one 64-bit word per lane. */
    for (w = 0; w < 2 * n / 8; w++) {
        int lane;

        for (lane = 0; lane < 4; lane++) {
            state[pos + lane] = ((const word64*)(m + lane * 2 * n))[w];
        }
        pos += 4;
    }
    SHAKE256_SET_END_X4(state, pos);
    ret = SAVE_VECTOR_REGISTERS2();
    if (ret == 0) {
        sha3_blocksx4_avx2(state);
        RESTORE_VECTOR_REGISTERS();
        slhdsakey_shake256_get_hash_x4(state, hash, n);
    }
    WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
/* Number of FORS tree levels whose nodes are buffered by
 * slhdsakey_fors_node_x4_low(); its nodes buffer holds
 * (1 << SLHDSA_MAX_FORS_NODE_DEPTH) nodes. Capped at 8 levels, or one less
 * than SLHDSA_MAX_A (presumably the largest FORS tree height 'a') when that
 * is smaller. */
#if SLHDSA_MAX_A < 9
#define SLHDSA_MAX_FORS_NODE_DEPTH (SLHDSA_MAX_A-1)
#else
#define SLHDSA_MAX_FORS_NODE_DEPTH 8
#endif
/* Remaining levels above SLHDSA_MAX_FORS_NODE_DEPTH; sizes the nodes buffer
 * of slhdsakey_fors_node_x4_high(). */
#define SLHDSA_MAX_FORS_NODE_TOP_DEPTH \
    (SLHDSA_MAX_A - SLHDSA_MAX_FORS_NODE_DEPTH)
/* Compute the FORS leaf at index i: F applied to the FORS secret value.
 *
 * key      SLH-DSA key: params and hash state.
 * sk_seed  Secret seed.
 * i        Leaf index.
 * pk_seed  Public seed.
 * adrs     FORS tree address (tree height/index modified).
 * node     Output: n-byte leaf.
 * Returns 0 on success or a hash error.
 */
static int slhdsakey_fors_node_x4_z0(SlhDsaKey* key, const byte* sk_seed,
    word32 i, const byte* pk_seed, word32* adrs, byte* node)
{
    byte n = key->params->n;
    int ret = slhdsakey_fors_sk_gen(key, sk_seed, pk_seed, adrs, i, node);

    if (ret != 0) {
        return ret;
    }
    HA_SetTreeHeight(adrs, 0);
    HA_SetTreeIndex(adrs, i);
    return HASH_F(&key->shake, pk_seed, adrs, node, n, node);
}
/* Compute the FORS node at height 1 and index i from its two leaves 2i and
 * 2i+1.
 *
 * key      SLH-DSA key: params and hash state.
 * sk_seed  Secret seed.
 * i        Node index at height 1.
 * pk_seed  Public seed.
 * adrs     FORS tree address (tree height/index modified).
 * node     Output: n-byte node.
 * Returns 0 on success or a hash error.
 */
static int slhdsakey_fors_node_x4_z1(SlhDsaKey* key, const byte* sk_seed,
    word32 i, const byte* pk_seed, word32* adrs, byte* node)
{
    int ret = 0;
    int leaf;
    byte n = key->params->n;
    byte leaves[2 * SLHDSA_MAX_N];

    /* Leaves: F(secret value) at tree indices 2i and 2i+1. */
    for (leaf = 0; (ret == 0) && (leaf < 2); leaf++) {
        byte* out = leaves + leaf * n;

        ret = slhdsakey_fors_sk_gen(key, sk_seed, pk_seed, adrs, 2 * i + leaf,
            out);
        if (ret == 0) {
            HA_SetTreeHeight(adrs, 0);
            HA_SetTreeIndex(adrs, 2 * i + leaf);
            ret = HASH_F(&key->shake, pk_seed, adrs, out, n, out);
        }
    }
    if (ret == 0) {
        HA_SetTreeHeight(adrs, 1);
        HA_SetTreeIndex(adrs, i);
        ret = HASH_H(&key->shake, pk_seed, adrs, leaves, n, node);
    }
    return ret;
}
/* Compute the FORS node at height z and index i, using the 4-way (AVX2)
 * SHAKE code for the lower levels.
 *
 * All 2^z leaves are generated four at a time (PRF then F). They are then
 * combined level by level, four at a time, with H until four nodes at height
 * z - 2 remain. The last two levels use the scalar H.
 *
 * Requires z >= 2 (at least four leaves), presumably guaranteed by callers
 * using the _z0/_z1 variants below that. z must not exceed
 * SLHDSA_MAX_FORS_NODE_DEPTH, the capacity of the nodes buffer.
 *
 * key      SLH-DSA key: params, heap hint and hash state.
 * sk_seed  Secret seed.
 * i        Node index at height z.
 * z        Height of the node.
 * pk_seed  Public seed.
 * adrs     FORS tree address (tree height/index modified).
 * node     Output: n-byte node.
 * Returns 0 on success, MEMORY_E on allocation failure, or a hash error.
 */
static int slhdsakey_fors_node_x4_low(SlhDsaKey* key, const byte* sk_seed,
    word32 i, word32 z, const byte* pk_seed, word32* adrs, byte* node)
{
    int ret = 0;
    byte n = key->params->n;
    HashAddress sk_adrs;
    byte addr[SLHDSA_HA_SZ];
    int j;
    int m = 1 << z;
    WC_DECLARE_VAR(nodes, byte, (1 << SLHDSA_MAX_FORS_NODE_DEPTH) *
        SLHDSA_MAX_N, key->heap);

    WC_ALLOC_VAR_EX(nodes, byte, (1 << SLHDSA_MAX_FORS_NODE_DEPTH) *
        SLHDSA_MAX_N, key->heap, DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
    if (ret == 0) {
        byte sk_addr[SLHDSA_HA_SZ];
        /* Leaf secret values: PRF at tree indices m*i .. m*i + m - 1. */
        HA_SetTreeHeight(adrs, 0);
        HA_Copy(sk_adrs, adrs);
        HA_SetTypeAndClearNotKPA(sk_adrs, HA_FORS_PRF);
        HA_Encode(sk_adrs, sk_addr);
        HA_Encode(adrs, addr);
        for (j = 0; j < m; j += 4) {
            ret = slhdsakey_hash_prf_ti_x4(pk_seed, sk_seed, sk_addr, n,
                m * i + j, nodes + j * n, key->heap);
            if (ret != 0) {
                break;
            }
        }
    }
    if (ret == 0) {
        /* Leaves: F of each secret value at height 0. addr is an encoded
         * byte array, so use the encoded-address setter rather than casting
         * it to word32* (alignment / strict-aliasing violation). */
        HA_SetTreeHeightBE(addr, 0);
        for (j = 0; j < m; j += 4) {
            ret = slhdsakey_hash_f_ti_x4(pk_seed, addr, nodes + j * n, n,
                m * i + j, key->heap);
            if (ret != 0) {
                break;
            }
        }
    }
    if (ret == 0) {
        word32 k;
        /* Heights 1 .. z-2: halve the node count each level, writing the
         * parents over the front of the buffer. */
        for (k = 1; k < z - 1; k++) {
            m >>= 1;
            HA_SetTreeHeightBE(addr, k);
            for (j = 0; j < m; j += 4) {
                ret = slhdsakey_hash_h_ti_x4(pk_seed, addr, nodes + 2 * j * n,
                    n, m * i + j, nodes + j * n, key->heap);
                if (ret != 0) {
                    break;
                }
            }
            if (ret != 0) {
                break;
            }
        }
    }
    /* Four nodes remain at height z-2: finish the last two levels. */
    if (ret == 0) {
        HA_SetTreeHeight(adrs, z - 1);
        HA_SetTreeIndex(adrs, 2 * i + 0);
        ret = HASH_H(&key->shake, pk_seed, adrs, nodes, n, nodes);
    }
    if (ret == 0) {
        HA_SetTreeIndex(adrs, 2 * i + 1);
        ret = HASH_H(&key->shake, pk_seed, adrs, nodes + 2 * n, n,
            nodes + 1 * n);
    }
    if (ret == 0) {
        HA_SetTreeHeight(adrs, z);
        HA_SetTreeIndex(adrs, i);
        ret = HASH_H(&key->shake, pk_seed, adrs, nodes, n, node);
    }
    WC_FREE_VAR_EX(nodes, key->heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#if SLHDSA_MAX_FORS_NODE_DEPTH < SLHDSA_MAX_A-1
/* Compute the FORS tree node at height z with index i using the 4-way
 * SHAKE256 helpers, for heights above SLHDSA_MAX_FORS_NODE_DEPTH.
 *
 * The subtree is split. z2 = z % SLHDSA_MAX_FORS_NODE_DEPTH (or
 * SLHDSA_MAX_FORS_NODE_DEPTH when that is 0) is the number of top levels.
 * Each of the 2^z2 nodes at height z - z2 is computed by
 * slhdsakey_fors_node_x4_low, and the top z2 levels are then reduced here.
 *
 * key      [in]      SLH-DSA key (parameters, SHAKE object, heap hint).
 * sk_seed  [in]      Private seed.
 * i        [in]      Index of the node at height z.
 * z        [in]      Height of the node to compute.
 * pk_seed  [in]      Public seed.
 * adrs     [in,out]  Hash address; tree height and index are modified.
 * node     [out]     Buffer receiving the n-byte node.
 * returns 0 on success, MEMORY_E on allocation failure, or a hash error.
 */
static int slhdsakey_fors_node_x4_high(SlhDsaKey* key, const byte* sk_seed,
word32 i, word32 z, const byte* pk_seed, word32* adrs, byte* node)
{
int ret = 0;
byte n = key->params->n;
int j;
int z2 = z % SLHDSA_MAX_FORS_NODE_DEPTH;
int m;
WC_DECLARE_VAR(nodes, byte, (1 << SLHDSA_MAX_FORS_NODE_TOP_DEPTH) *
SLHDSA_MAX_N, key->heap);
WC_ALLOC_VAR_EX(nodes, byte, (1 << SLHDSA_MAX_FORS_NODE_TOP_DEPTH) *
SLHDSA_MAX_N, key->heap, DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
if (ret == 0) {
/* z a multiple of the depth: the top part is a full depth. */
if (z2 == 0) {
z2 = SLHDSA_MAX_FORS_NODE_DEPTH;
}
m = 1 << z2;
/* The m nodes at height z - z2, each the root of a "low" subtree. */
for (j = 0; j < m; j++) {
ret = slhdsakey_fors_node_x4_low(key, sk_seed, m * i + j, z - z2,
pk_seed, adrs, nodes + j * n);
if (ret != 0) {
break;
}
}
}
/* Levels z - z2 + 1 .. z - 2: reduce pairs four at a time (only needed when
 * more than two top levels remain). */
if ((ret == 0) && (z2 > 2)) {
word32 k;
for (k = z - z2 + 1; k < z - 1; k++) {
byte addr[SLHDSA_HA_SZ];
m >>= 1;
HA_SetTreeHeight(adrs, k);
HA_Encode(adrs, addr);
for (j = 0; j < m; j += 4) {
ret = slhdsakey_hash_h_ti_x4(pk_seed, addr, nodes + 2 * j * n,
n, m * i + j, nodes + j * n, key->heap);
if (ret != 0) {
break;
}
}
if (ret != 0) {
break;
}
}
}
/* Level z - 1: two nodes, scalar H. Skipped when z2 == 1 because the two
 * nodes at height z - 1 came straight from the low subtrees. */
if ((ret == 0) && (z2 > 1)) {
HA_SetTreeHeight(adrs, z - 1);
HA_SetTreeIndex(adrs, 2 * i + 0);
ret = HASH_H(&key->shake, pk_seed, adrs, nodes, n, nodes);
}
if ((ret == 0) && (z2 > 1)) {
HA_SetTreeIndex(adrs, 2 * i + 1);
ret = HASH_H(&key->shake, pk_seed, adrs, nodes + 2 * n, n,
nodes + 1 * n);
}
/* Level z: the requested node. */
if (ret == 0) {
HA_SetTreeHeight(adrs, z);
HA_SetTreeIndex(adrs, i);
ret = HASH_H(&key->shake, pk_seed, adrs, nodes, n, node);
}
WC_FREE_VAR_EX(nodes, key->heap, DYNAMIC_TYPE_SLHDSA);
return ret;
}
#endif
static int slhdsakey_fors_node_x4(SlhDsaKey* key, const byte* sk_seed, word32 i,
    word32 z, const byte* pk_seed, word32* adrs, byte* node)
{
    /* Choose the FORS node routine by the requested height z:
     *   z > SLHDSA_MAX_FORS_NODE_DEPTH  -> split computation ("high"), only
     *                                      compiled when such heights exist;
     *   2 <= z <= depth                 -> single-buffer 4-way ("low");
     *   z == 1 / z == 0                 -> direct computation.
     * Returns 0 on success or the error of the chosen routine. */
    if (z > SLHDSA_MAX_FORS_NODE_DEPTH) {
#if SLHDSA_MAX_FORS_NODE_DEPTH < SLHDSA_MAX_A-1
        return slhdsakey_fors_node_x4_high(key, sk_seed, i, z, pk_seed, adrs,
            node);
#else
        return 0;
#endif
    }
    if (z >= 2) {
        return slhdsakey_fors_node_x4_low(key, sk_seed, i, z, pk_seed, adrs,
            node);
    }
    if (z == 1) {
        return slhdsakey_fors_node_x4_z1(key, sk_seed, i, pk_seed, adrs, node);
    }
    return slhdsakey_fors_node_x4_z0(key, sk_seed, i, pk_seed, adrs, node);
}
#endif
#if !defined(WOLFSSL_WC_SLHDSA_RECURSIVE)
/* Compute the FORS tree node at height z with index i without recursion
 * (FIPS 205 fors_node).
 *
 * Leaves are generated left to right. nodes acts as a stack of pending
 * siblings:
 * - slots k and k+1 hold the two children that are combined into the node at
 *   height z - k;
 * - the result goes into slot k-1 (left child) or slot k (right child).
 * A node is combined as soon as its right sibling exists, so at most z + 1
 * slots are in use.
 *
 * key      [in]      SLH-DSA key (parameters, SHAKE object, heap hint).
 * sk_seed  [in]      Private seed.
 * i        [in]      Index of the node at height z.
 * z        [in]      Height of the node to compute.
 * pk_seed  [in]      Public seed.
 * adrs     [in,out]  Hash address; tree height and index are modified.
 * node     [out]     Buffer receiving the n-byte node.
 * returns 0 on success, MEMORY_E on allocation failure, or a hash error.
 */
static int slhdsakey_fors_node_c(SlhDsaKey* key, const byte* sk_seed, word32 i,
    word32 z, const byte* pk_seed, word32* adrs, byte* node)
{
    int ret = 0;
    byte n = key->params->n;

    if (z == 0) {
        ret = slhdsakey_fors_sk_gen(key, sk_seed, pk_seed, adrs, i, node);
        if (ret == 0) {
            HA_SetTreeHeight(adrs, 0);
            HA_SetTreeIndex(adrs, i);
            ret = HASH_F(&key->shake, pk_seed, adrs, node, n, node);
        }
    }
    else {
        WC_DECLARE_VAR(nodes, byte, (SLHDSA_MAX_A + 1) * SLHDSA_MAX_N,
            key->heap);
        word32 j;
        word32 k;
        word32 m = (word32)1 << z;

        WC_ALLOC_VAR_EX(nodes, byte, (SLHDSA_MAX_A + 1) * SLHDSA_MAX_N,
            key->heap, DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
        if (ret == 0) {
            for (j = 0; j < m; j++) {
                /* Even leaves go to slot z-1, odd leaves to slot z. */
                int o = (z - 1 + (j & 1)) * n;

                ret = slhdsakey_fors_sk_gen(key, sk_seed, pk_seed, adrs,
                    m * i + j, nodes + o);
                if (ret != 0) {
                    break;
                }
                HA_SetTreeHeight(adrs, 0);
                HA_SetTreeIndex(adrs, m * i + j);
                ret = HASH_F(&key->shake, pk_seed, adrs, nodes + o, n,
                    nodes + o);
                if (ret != 0) {
                    break;
                }
                /* Combine upwards while the node just produced is a right
                 * child. */
                for (k = z-1; k > 0; k--) {
                    if (((j >> (z-1-k)) & 1) == 1) {
                        HA_SetTreeHeight(adrs, z - k);
                        HA_SetTreeIndex(adrs, (m * i + j) >> (z - k));
                        ret = HASH_H(&key->shake, pk_seed, adrs, nodes + k * n,
                            n, nodes + (k - 1 + ((j >> (z-k)) & 1)) * n);
                        if (ret != 0) {
                            break;
                        }
                    }
                    else {
                        break;
                    }
                }
                /* Propagate a failed H. Without this the next leaf's calls
                 * would overwrite ret and a wrong node would be returned as
                 * success. */
                if (ret != 0) {
                    break;
                }
            }
            if (ret == 0) {
                HA_SetTreeHeight(adrs, z);
                HA_SetTreeIndex(adrs, i);
                ret = HASH_H(&key->shake, pk_seed, adrs, nodes, n, node);
            }
        }
        WC_FREE_VAR_EX(nodes, key->heap, DYNAMIC_TYPE_SLHDSA);
    }
    return ret;
}
#else
static int slhdsakey_fors_node_c(SlhDsaKey* key, const byte* sk_seed, word32 i,
    word32 z, const byte* pk_seed, word32* adrs, byte* node)
{
    /* Recursive FORS node computation: a leaf (z == 0) is F applied to the
     * secret value i. An interior node is H of its children 2i and 2i+1 at
     * height z-1.
     * Returns 0 on success or the first hashing error. */
    byte n = key->params->n;
    int err;

    if (z != 0) {
        byte children[2 * SLHDSA_MAX_N];

        err = slhdsakey_fors_node_c(key, sk_seed, 2 * i, z - 1, pk_seed,
            adrs, children);
        if (err == 0) {
            err = slhdsakey_fors_node_c(key, sk_seed, 2 * i + 1, z - 1,
                pk_seed, adrs, children + n);
        }
        if (err == 0) {
            HA_SetTreeHeight(adrs, z);
            HA_SetTreeIndex(adrs, i);
            err = HASH_H(&key->shake, pk_seed, adrs, children, n, node);
        }
        return err;
    }

    err = slhdsakey_fors_sk_gen(key, sk_seed, pk_seed, adrs, i, node);
    if (err == 0) {
        HA_SetTreeHeight(adrs, 0);
        HA_SetTreeIndex(adrs, i);
        err = HASH_F(&key->shake, pk_seed, adrs, node, n, node);
    }
    return err;
}
#endif
static int slhdsakey_fors_sign(SlhDsaKey* key, const byte* md,
    const byte* sk_seed, const byte* pk_seed, word32* adrs, byte* sig_fors)
{
    /* Produce the FORS signature for message digest md. For each of the k
     * trees, write the revealed secret value and then its a authentication
     * path nodes (n bytes each) consecutively to sig_fors.
     * Returns 0 on success or the first error encountered. */
    int ret = 0;
    word16 indices[SLHDSA_MAX_INDICES_SZ];
    int i;
    int j;
    byte n = key->params->n;
    byte a = key->params->a;
    byte k = key->params->k;

    slhdsakey_base_2b(md, a, k, indices);
    for (i = 0; (ret == 0) && (i < k); i++) {
        word16 idx = indices[i];
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
        int useX4;
#endif

        /* Revealed secret value of tree i. */
        ret = slhdsakey_fors_sk_gen(key, sk_seed, pk_seed, adrs,
            ((word32)i << a) + idx, sig_fors);
        if (ret != 0) {
            break;
        }
        sig_fors += n;

#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
        /* Decided once per tree, after the secret value is generated. */
        useX4 = IS_INTEL_AVX2(cpuid_flags) && CAN_SAVE_VECTOR_REGISTERS();
#endif
        /* Authentication path: sibling node at each height j. */
        for (j = 0; (ret == 0) && (j < a); j++) {
            word32 sibling = ((word32)i << (a - j)) + (word32)(idx ^ 1);

#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
            if (useX4) {
                ret = slhdsakey_fors_node_x4(key, sk_seed, sibling, j, pk_seed,
                    adrs, sig_fors);
            }
            else
#endif
            {
                ret = slhdsakey_fors_node_c(key, sk_seed, sibling, j, pk_seed,
                    adrs, sig_fors);
            }
            sig_fors += n;
            idx >>= 1;
        }
    }
    return ret;
}
#endif
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
static int slhdsakey_hash_f_ti4_x4(const byte* pk_seed, byte* addr,
    const byte* sig_fors, int so, byte n, word32* ti, byte* node, void* heap)
{
    /* Four F computations in parallel with 4-way SHAKE256. Lane l hashes the
     * n bytes at sig_fors + l * so * n using tree index ti[l], and the four
     * n-byte results are written consecutively to node.
     * Returns 0 on success, MEMORY_E, or the vector-register save error. */
    int ret = 0;
    int word;
    int lane;
    word32 pos = 0;
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;
    WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret == 0) {
        pos = slhdsakey_shake256_set_seed_ha_x4(state, pk_seed, addr, n);
        SHAKE256_SET_TREE_INDEX_IDX(state, pos, ti);
        /* Interleave the inputs: one 64-bit word per lane per row. */
        for (word = 0; word < n / 8; word++) {
            for (lane = 0; lane < 4; lane++) {
                state[pos + lane] =
                    ((const word64*)(sig_fors + lane * so * n))[word];
            }
            pos += 4;
        }
        SHAKE256_SET_END_X4(state, pos);
        ret = SAVE_VECTOR_REGISTERS2();
        if (ret == 0) {
            sha3_blocksx4_avx2(state);
            RESTORE_VECTOR_REGISTERS();
            slhdsakey_shake256_get_hash_x4(state, node, n);
        }
        WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    }
    return ret;
}
static int slhdsakey_hash_h_2_x4(const byte* pk_seed, byte* addr, byte* node,
    const byte* sig_fors, int so, word32* bit, byte n, word32 th, word32* ti,
    void* heap)
{
    /* Four H computations in parallel with 4-way SHAKE256, at tree height th
     * and tree indices ti[]. For lane l, the current node (node + l * n) is
     * the left input when bit[l] == 0, else the right input. The other input
     * is the authentication node at sig_fors + l * so * n. Results overwrite
     * node.
     * Returns 0 on success, MEMORY_E, or the vector-register save error. */
    int ret = 0;
    int word;
    int lane;
    word32 pos = 0;
    const byte* left[4];
    const byte* right[4];
    WC_DECLARE_VAR(state, word64, 25 * 4, heap);

    (void)heap;
    WC_ALLOC_VAR_EX(state, word64, 25 * 4, heap, DYNAMIC_TYPE_SLHDSA,
        ret = MEMORY_E);
    if (ret == 0) {
        for (lane = 0; lane < 4; lane++) {
            const byte* own = node + lane * n;
            const byte* auth = sig_fors + lane * so * n;

            left[lane] = (bit[lane] == 0) ? own : auth;
            right[lane] = (bit[lane] == 0) ? auth : own;
        }

        pos = slhdsakey_shake256_set_seed_ha_x4(state, pk_seed, addr, n);
        SHAKE256_SET_TREE_HEIGHT(state, pos, th);
        SHAKE256_SET_TREE_INDEX_IDX(state, pos, ti);
        for (word = 0; word < n / 8; word++) {
            for (lane = 0; lane < 4; lane++) {
                state[pos + lane] = ((const word64*)left[lane])[word];
            }
            pos += 4;
        }
        for (word = 0; word < n / 8; word++) {
            for (lane = 0; lane < 4; lane++) {
                state[pos + lane] = ((const word64*)right[lane])[word];
            }
            pos += 4;
        }
        SHAKE256_SET_END_X4(state, pos);
        ret = SAVE_VECTOR_REGISTERS2();
        if (ret == 0) {
            sha3_blocksx4_avx2(state);
            RESTORE_VECTOR_REGISTERS();
            slhdsakey_shake256_get_hash_x4(state, node, n);
        }
        WC_FREE_VAR_EX(state, heap, DYNAMIC_TYPE_SLHDSA);
    }
    return ret;
}
static int slhdsakey_fors_pk_from_sig_i_x4(SlhDsaKey* key, const byte* sig_fors,
    const byte* pk_seed, byte* addr, const word16* indices, int i, byte* node)
{
    /* Compute the roots of FORS trees i .. i+3 from the signature, four
     * trees in parallel. Each tree's signature part is (1 + a) * n bytes:
     * the secret value followed by a authentication nodes. The four n-byte
     * roots are written consecutively to node.
     * Returns 0 on success or the first hashing error. */
    int ret;
    int height;
    int lane;
    byte n = key->params->n;
    byte a = key->params->a;
    word32 ti[4];
    word32 bit[4];

    for (lane = 0; lane < 4; lane++) {
        ti[lane] = ((word32)(i + lane) << a) + indices[i + lane];
    }
    /* Leaves: F of the revealed secret values. */
    ret = slhdsakey_hash_f_ti4_x4(pk_seed, addr, sig_fors, 1 + a, n, ti, node,
        key->heap);
    /* Climb one level per authentication node. */
    for (height = 1; (ret == 0) && (height <= a); height++) {
        sig_fors += n;
        for (lane = 0; lane < 4; lane++) {
            bit[lane] = ti[lane] & 1;
            ti[lane] >>= 1;
        }
        ret = slhdsakey_hash_h_2_x4(pk_seed, addr, node, sig_fors, 1 + a,
            bit, n, (word32)height, ti, key->heap);
    }
    return ret;
}
/* Compute all k FORS tree roots from the FORS signature and absorb them into
 * key->shake2 (which the caller has already started).
 *
 * Trees are processed four at a time with the 4-way helpers. Any remaining
 * (k mod 4) trees are processed one at a time with scalar F/H.
 *
 * key       [in]      SLH-DSA key; key->shake2 receives the k roots.
 * sig_fors  [in]      FORS signature: k * (1 + a) * n bytes.
 * indices   [in]      k leaf indices derived from the message digest.
 * pk_seed   [in]      Public seed.
 * adrs      [in,out]  FORS tree address; tree height and index are modified.
 * returns 0 on success, MEMORY_E on allocation failure, or a hash error.
 */
static int slhdsakey_fors_pk_from_sig_x4(SlhDsaKey* key, const byte* sig_fors,
const word16* indices, const byte* pk_seed, word32* adrs)
{
int ret = 0;
int i;
int j;
byte n = key->params->n;
byte a = key->params->a;
byte k = key->params->k;
byte addr[SLHDSA_HA_SZ];
WC_DECLARE_VAR(node, byte, SLHDSA_MAX_INDICES_SZ * SLHDSA_MAX_N, key->heap);
WC_ALLOC_VAR_EX(node, byte, SLHDSA_MAX_INDICES_SZ * SLHDSA_MAX_N, key->heap,
DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
if (ret == 0) {
/* Encoded address shared by the 4-way calls; they set the per-lane tree
 * index themselves. */
HA_SetTreeHeight(adrs, 0);
HA_Encode(adrs, addr);
/* Groups of four complete trees. */
for (i = 0; i < k-3; i += 4) {
ret = slhdsakey_fors_pk_from_sig_i_x4(key, sig_fors, pk_seed, addr,
indices, i, node + i * n);
if (ret != 0) {
break;
}
sig_fors += 4 * (1 + a) * n;
}
}
/* Remaining trees (continuing from i), computed one at a time. */
if (ret == 0) {
for (; i < k; i++) {
word32 idx = ((word32)i << a) + indices[i];
HA_SetTreeHeight(adrs, 0);
HA_SetTreeIndex(adrs, idx);
ret = HASH_F(&key->shake, pk_seed, adrs, sig_fors, n, node + i * n);
if (ret != 0) {
break;
}
sig_fors += n;
for (j = 0; j < a; j++) {
/* side: whether the current node is the right child. */
word32 side = idx & 1;
idx >>= 1;
HA_SetTreeHeight(adrs, j + 1);
HA_SetTreeIndex(adrs, idx);
if (side == 0) {
ret = HASH_H_2(&key->shake, pk_seed, adrs, node + i * n,
sig_fors, n, node + i * n);
}
else {
ret = HASH_H_2(&key->shake, pk_seed, adrs, sig_fors,
node + i * n, n, node + i * n);
}
if (ret != 0) {
break;
}
sig_fors += n;
}
if (ret != 0) {
break;
}
}
}
/* On success i == k: absorb all k roots at once. */
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake2, node, i * n);
}
WC_FREE_VAR_EX(node, key->heap, DYNAMIC_TYPE_SLHDSA);
return ret;
}
#endif
#if !defined(WOLFSSL_WC_SLHDSA_SMALL_MEM)
static int slhdsakey_fors_pk_from_sig_c(SlhDsaKey* key, const byte* sig_fors,
    const word16* indices, const byte* pk_seed, word32* adrs, byte* pk_fors)
{
    /* Scalar FORS root computation. All k roots are built in a scratch
     * buffer and then absorbed into key->shake2 with one update.
     * pk_fors is unused in this variant.
     * Returns 0 on success, MEMORY_E on allocation failure, or a hash error. */
    int ret = 0;
    int i = 0;
    int j;
    byte n = key->params->n;
    byte a = key->params->a;
    byte k = key->params->k;
    WC_DECLARE_VAR(node, byte, SLHDSA_MAX_INDICES_SZ * SLHDSA_MAX_N, key->heap);

    (void)pk_fors;
    WC_ALLOC_VAR_EX(node, byte, SLHDSA_MAX_INDICES_SZ * SLHDSA_MAX_N, key->heap,
        DYNAMIC_TYPE_SLHDSA, ret = MEMORY_E);
    for (; (ret == 0) && (i < k); i++) {
        byte* root = node + i * n;
        word32 idx = ((word32)i << a) + indices[i];

        HA_SetTreeHeight(adrs, 0);
        HA_SetTreeIndex(adrs, idx);
        ret = HASH_F(&key->shake, pk_seed, adrs, sig_fors, n, root);
        sig_fors += n;
        for (j = 0; (ret == 0) && (j < a); j++) {
            const byte* lhs = root;
            const byte* rhs = sig_fors;

            /* A right child is hashed after its authentication sibling. */
            if ((idx & 1) != 0) {
                lhs = sig_fors;
                rhs = root;
            }
            idx >>= 1;
            HA_SetTreeHeight(adrs, j + 1);
            HA_SetTreeIndex(adrs, idx);
            ret = HASH_H_2(&key->shake, pk_seed, adrs, lhs, rhs, n, root);
            sig_fors += n;
        }
    }
    if (ret == 0) {
        ret = slhdsakey_hash_update(&key->shake2, node, i * n);
    }
    WC_FREE_VAR_EX(node, key->heap, DYNAMIC_TYPE_SLHDSA);
    return ret;
}
#else
static int slhdsakey_fors_pk_from_sig_c(SlhDsaKey* key, const byte* sig_fors,
    const word16* indices, const byte* pk_seed, word32* adrs, byte* node)
{
    /* Low-memory scalar FORS root computation. Each of the k roots is built
     * in the caller's n-byte buffer node and immediately absorbed into
     * key->shake2.
     * Returns 0 on success or the first hashing error. */
    int ret = 0;
    int tree;
    int height;
    byte n = key->params->n;
    byte a = key->params->a;
    byte k = key->params->k;

    for (tree = 0; (ret == 0) && (tree < k); tree++) {
        word32 idx = ((word32)tree << a) + indices[tree];

        HA_SetTreeHeight(adrs, 0);
        HA_SetTreeIndex(adrs, idx);
        ret = HASH_F(&key->shake, pk_seed, adrs, sig_fors, n, node);
        sig_fors += n;
        for (height = 0; (ret == 0) && (height < a); height++) {
            const byte* lhs = node;
            const byte* rhs = sig_fors;

            /* A right child is hashed after its authentication sibling. */
            if ((idx & 1) != 0) {
                lhs = sig_fors;
                rhs = node;
            }
            idx >>= 1;
            HA_SetTreeHeight(adrs, height + 1);
            HA_SetTreeIndex(adrs, idx);
            ret = HASH_H_2(&key->shake, pk_seed, adrs, lhs, rhs, n, node);
            sig_fors += n;
        }
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake2, node, n);
        }
    }
    return ret;
}
#endif
/* Compute the FORS public key from a FORS signature.
 *
 * The leaf indices are derived from md. key->shake2 is started with PK.seed
 * and a FORS_ROOTS copy of adrs. The k tree roots are absorbed into it, and
 * the n-byte output is written to pk_fors.
 *
 * key       [in]      SLH-DSA key.
 * sig_fors  [in]      FORS signature: k * (1 + a) * n bytes.
 * md        [in]      Message digest the indices are taken from.
 * pk_seed   [in]      Public seed.
 * adrs      [in,out]  FORS tree address; tree height and index are modified.
 * pk_fors   [out]     Buffer receiving the n-byte FORS public key.
 * returns 0 on success or an error from hashing/allocation.
 */
static int slhdsakey_fors_pk_from_sig(SlhDsaKey* key, const byte* sig_fors,
const byte* md, const byte* pk_seed, word32* adrs, byte* pk_fors)
{
int ret;
word16 indices[SLHDSA_MAX_INDICES_SZ];
HashAddress forspk_adrs;
byte n = key->params->n;
byte a = key->params->a;
byte k = key->params->k;
slhdsakey_base_2b(md, a, k, indices);
HA_Copy(forspk_adrs, adrs);
HA_SetTypeAndClearNotKPA(forspk_adrs, HA_FORS_ROOTS);
ret = slhdsakey_hash_start_addr(&key->shake2, pk_seed, forspk_adrs, n);
/* Use the 4-way path only when AVX2 is present and the vector registers can
 * be saved. Otherwise the "else" below joins with the scalar "if" after the
 * #endif. NOTE(review): this outer save nests with the saves done inside the
 * x4 hash helpers; confirm nesting is supported on all targets. */
#if defined(USE_INTEL_SPEEDUP) && !defined(WOLFSSL_WC_SLHDSA_SMALL)
if ((ret == 0) && IS_INTEL_AVX2(cpuid_flags) &&
(SAVE_VECTOR_REGISTERS2() == 0)) {
ret = slhdsakey_fors_pk_from_sig_x4(key, sig_fors, indices, pk_seed,
adrs);
RESTORE_VECTOR_REGISTERS();
}
else
#endif
if (ret == 0) {
ret = slhdsakey_fors_pk_from_sig_c(key, sig_fors, indices, pk_seed,
adrs, pk_fors);
}
if (ret == 0) {
ret = slhdsakey_hash_final(&key->shake2, pk_fors, n);
}
return ret;
}
/* Initialize an SLH-DSA key for the given parameter set.
 *
 * key    [in,out]  Key object to initialize (zeroed first).
 * param  [in]      Parameter set; must be compiled into SlhDsaParams.
 * heap   [in]      Heap hint, stored in the key and used by the SHAKE objects.
 * devId  [in]      Device id (stored only with WOLF_CRYPTO_CB).
 * returns 0 on success, BAD_FUNC_ARG when key is NULL, NOT_COMPILED_IN when
 *         param is unsupported, or the SHAKE initialization error.
 */
int wc_SlhDsaKey_Init(SlhDsaKey* key, enum SlhDsaParam param, void* heap,
    int devId)
{
    int ret = 0;
    int idx = -1;

    if (key == NULL) {
        ret = BAD_FUNC_ARG;
    }
    if (ret == 0) {
        int i;
        for (i = 0; i < SLHDSA_PARAM_LEN; i++) {
            if (param == SlhDsaParams[i].param) {
                idx = i;
                break;
            }
        }
        if (idx == -1) {
            ret = NOT_COMPILED_IN;
        }
    }
    if (ret == 0) {
        XMEMSET(key, 0, sizeof(SlhDsaKey));
        /* Record the heap hint before initializing the SHAKE objects. The
         * memset above cleared key->heap, so passing key->heap would drop the
         * caller's hint. */
        key->heap = heap;
        ret = wc_InitShake256(&key->shake, heap, INVALID_DEVID);
    }
    if (ret == 0) {
        ret = wc_InitShake256(&key->shake2, heap, INVALID_DEVID);
        if (ret != 0) {
            /* key->params is still NULL, so wc_SlhDsaKey_Free would not
             * release the first SHAKE object; release it here. */
            wc_Shake256_Free(&key->shake);
        }
    }
    if (ret == 0) {
        key->params = &SlhDsaParams[idx];
    #ifdef WOLF_CRYPTO_CB
        key->devId = devId;
    #endif
    }
    (void)devId;
#if defined(USE_INTEL_SPEEDUP)
    cpuid_get_flags_ex(&cpuid_flags);
#endif
    return ret;
}
void wc_SlhDsaKey_Free(SlhDsaKey* key)
{
    /* Release an initialized key. Keys without parameters (NULL or never
     * successfully initialized) are ignored. */
    if ((key == NULL) || (key->params == NULL)) {
        return;
    }
    /* Wipe SK.seed and SK.prf (the first 2 * n bytes of key->sk). */
    ForceZero(key->sk, key->params->n * 2);
    wc_Shake256_Free(&key->shake2);
    wc_Shake256_Free(&key->shake);
}
/* Derive the hypertree tree index and leaf index from the message digest and
 * set up the FORS tree address.
 *
 * md is laid out as: digest (dl1 bytes) || tree index bytes (dl2) ||
 * leaf index bytes (dl3).
 * - The tree index is the last 8 bytes ending at dl1 + dl2, read big-endian
 *   (ato32) into t[1..2] and reduced to h - h/d bits. t[0] is always 0.
 * - The leaf index l is the last 4 bytes ending at dl1 + dl2 + dl3, reduced
 *   to h/d bits.
 *
 * key   [in]   SLH-DSA key (parameters).
 * md    [in]   Message digest.
 * adrs  [out]  Address set to FORS_TREE with tree address t, key pair l.
 * t     [out]  Three-word tree address.
 * l     [out]  Leaf (key pair) index.
 */
static void slhdsakey_set_ha_from_md(SlhDsaKey* key, const byte* md,
HashAddress adrs, word32* t, word32* l)
{
const byte* p;
int bits;
HA_Init(adrs);
p = md + key->params->dl1 + (key->params->dl2 - 8);
t[0] = 0;
ato32(p + 0, &t[1]);
ato32(p + 4, &t[2]);
bits = key->params->h - (key->params->h / key->params->d);
/* NOTE(review): assumes h - h/d > 32 (true for the FIPS 205 sets); a
 * smaller value would make the shift count below negative. Bytes read from
 * before the tree index field are removed by this mask. */
if (bits < 64) {
t[1] &= ((word32)1 << (bits - 32)) - 1;
}
p = md + key->params->dl1 + key->params->dl2 + (key->params->dl3 - 4);
ato32(p, l);
bits = key->params->h / key->params->d;
*l &= ((word32)1 << bits) - 1;
HA_SetTreeAddress(adrs, t);
HA_SetTypeAndClearNotKPA(adrs, HA_FORS_TREE);
HA_SetKeyPairAddress(adrs, *l);
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
int wc_SlhDsaKey_MakeKey(SlhDsaKey* key, WC_RNG* rng)
{
    /* Generate a fresh key pair: draw SK.seed || SK.prf || PK.seed (3 * n
     * random bytes) straight into the key buffer, then derive PK.root.
     * Returns 0 on success, BAD_FUNC_ARG, or the RNG / derivation error. */
    int ret;
    byte n;

    if ((key == NULL) || (key->params == NULL) || (rng == NULL)) {
        return BAD_FUNC_ARG;
    }
    n = key->params->n;
    ret = wc_RNG_GenerateBlock(rng, key->sk, 3 * n);
    if (ret == 0) {
        ret = wc_SlhDsaKey_MakeKeyWithRandom(key, key->sk, n, key->sk + n, n,
            key->sk + 2 * n, n);
    }
    return ret;
}
int wc_SlhDsaKey_MakeKeyWithRandom(SlhDsaKey* key, const byte* sk_seed,
    word32 sk_seed_len, const byte* sk_prf, word32 sk_prf_len,
    const byte* pk_seed, word32 pk_seed_len)
{
    /* Build a key pair from caller-supplied seeds, each exactly n bytes.
     * The key buffer becomes SK.seed || SK.prf || PK.seed || PK.root.
     * Returns 0 on success, BAD_FUNC_ARG on a bad argument or length, or the
     * tree computation error. */
    int ret;
    byte n;
    HashAddress adrs;

    if ((key == NULL) || (key->params == NULL)) {
        return BAD_FUNC_ARG;
    }
    n = key->params->n;
    if ((sk_seed == NULL) || (sk_seed_len != n) ||
            (sk_prf == NULL) || (sk_prf_len != n) ||
            (pk_seed == NULL) || (pk_seed_len != n)) {
        return BAD_FUNC_ARG;
    }

    /* MakeKey passes the key's own buffer; nothing to copy in that case. */
    if (sk_seed != key->sk) {
        XMEMCPY(key->sk, sk_seed, n);
        XMEMCPY(key->sk + n, sk_prf, n);
        XMEMCPY(key->sk + 2 * n, pk_seed, n);
    }
    /* PK.root: root of the XMSS tree in the top layer (d - 1). */
    HA_Init(adrs);
    HA_SetLayerAddress(adrs, key->params->d - 1);
    ret = slhdsakey_xmss_node(key, sk_seed, 0, key->params->h_m, pk_seed,
        adrs, &key->sk[3 * n]);
    if (ret == 0) {
        key->flags = WC_SLHDSA_FLAG_BOTH_KEYS;
    }
    return ret;
}
static int slhdsakey_sign(SlhDsaKey* key, byte* md, byte* sig)
{
    /* Sign the message digest md. sig points just past the randomizer R:
     * the FORS signature (k * (1 + a) * n bytes) is written first, followed
     * by the hypertree signature over the FORS public key.
     * Returns 0 on success or the first error encountered. */
    HashAddress adrs;
    word32 t[3];
    word32 l;
    byte pk_fors[SLHDSA_MAX_N];
    byte n = key->params->n;
    byte* skSeed = key->sk;
    byte* pkSeed = key->sk + 2 * n;
    int ret;

    slhdsakey_set_ha_from_md(key, md, adrs, t, &l);
    ret = slhdsakey_fors_sign(key, md, skSeed, pkSeed, adrs, sig);
    if (ret == 0) {
        ret = slhdsakey_fors_pk_from_sig(key, sig, md, pkSeed, adrs, pk_fors);
    }
    if (ret == 0) {
        byte* sigHt = sig + key->params->k * (1 + key->params->a) * n;
        ret = slhdsakey_ht_sign(key, pk_fors, skSeed, pkSeed, t, l, sigHt);
    }
    return ret;
}
/* Sign a message (pure SLH-DSA) with the given additional randomness.
 *
 * M' = 0x00 || ctxSz || ctx || msg.
 * R (first n bytes of sig) is SHAKE256 over SK.prf || addRnd || M'.
 * The digest md is SHAKE256 over R || PK.seed || PK.root || M'
 * (dl1 + dl2 + dl3 bytes). This matches PRF_msg and H_msg of FIPS 205 for the
 * SHAKE parameter sets. md is then signed by slhdsakey_sign.
 * The order of the hash updates below is the byte stream being hashed.
 *
 * key     [in]      Key with the private part present.
 * ctx     [in]      Context string (may be NULL when ctxSz is 0).
 * ctxSz   [in]      Context length.
 * msg     [in]      Message.
 * msgSz   [in]      Message length.
 * sig     [out]     Signature buffer.
 * sigSz   [in,out]  Buffer size in; signature length out on success.
 * addRnd  [in]      n bytes of additional randomness (opt_rand).
 * returns 0 on success, BAD_FUNC_ARG, BAD_LENGTH_E, MISSING_KEY, or an error
 *         from hashing/signing.
 */
static int slhdsakey_sign_external(SlhDsaKey* key, const byte* ctx, byte ctxSz,
const byte* msg, word32 msgSz, byte* sig, word32* sigSz,
const byte* addRnd)
{
int ret = 0;
if ((key == NULL) || (key->params == NULL) ||
((ctx == NULL) && (ctxSz > 0)) || (msg == NULL) || (sig == NULL) ||
(sigSz == NULL)) {
ret = BAD_FUNC_ARG;
}
else if (*sigSz < key->params->sigLen) {
ret = BAD_LENGTH_E;
}
else if (addRnd == NULL) {
ret = BAD_FUNC_ARG;
}
else if ((key->flags & WC_SLHDSA_FLAG_PRIVATE) == 0) {
ret = MISSING_KEY;
}
if (ret == 0) {
byte md[SLHDSA_MAX_MD];
byte hdr[2];
byte n = key->params->n;
/* Domain separator 0 (pure signing) and context length. */
hdr[0] = 0;
hdr[1] = ctxSz;
/* R = SHAKE256(SK.prf || addRnd || M'), n bytes, written to sig. */
ret = slhdsakey_hash_start(&key->shake, key->sk + n, n);
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, addRnd, n);
}
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, hdr, sizeof(hdr));
}
if ((ret == 0) && (ctxSz > 0)) {
ret = slhdsakey_hash_update(&key->shake, ctx, ctxSz);
}
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, msg, msgSz);
}
if (ret == 0) {
ret = slhdsakey_hash_final(&key->shake, sig, n);
}
/* md = SHAKE256(R || PK.seed || PK.root || M'). */
if (ret == 0) {
ret = slhdsakey_hash_start(&key->shake, sig, n);
sig += n;
}
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, key->sk + 2 * n, 2 * n);
}
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, hdr, sizeof(hdr));
}
if ((ret == 0) && (ctxSz > 0)) {
ret = slhdsakey_hash_update(&key->shake, ctx, ctxSz);
}
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, msg, msgSz);
}
if (ret == 0) {
ret = slhdsakey_hash_final(&key->shake, md, key->params->dl1 +
key->params->dl2 + key->params->dl3);
}
if (ret == 0) {
ret = slhdsakey_sign(key, md, sig);
}
if (ret == 0) {
*sigSz = key->params->sigLen;
}
}
return ret;
}
int wc_SlhDsaKey_SignDeterministic(SlhDsaKey* key, const byte* ctx, byte ctxSz,
    const byte* msg, word32 msgSz, byte* sig, word32* sigSz)
{
    /* Deterministic signing: PK.seed (key->sk + 2 * n) is used as the
     * additional randomness, as in the FIPS 205 deterministic variant.
     * Other validation is done by slhdsakey_sign_external. */
    const byte* optRand;

    if ((key == NULL) || (key->params == NULL)) {
        return BAD_FUNC_ARG;
    }
    optRand = key->sk + 2 * key->params->n;
    return slhdsakey_sign_external(key, ctx, ctxSz, msg, msgSz, sig, sigSz,
        optRand);
}
int wc_SlhDsaKey_SignWithRandom(SlhDsaKey* key, const byte* ctx, byte ctxSz,
    const byte* msg, word32 msgSz, byte* sig, word32* sigSz, const byte* addRnd)
{
    /* Sign with caller-provided additional randomness (n bytes). All
     * argument, length and private-key checks are done by
     * slhdsakey_sign_external. */
    int ret = slhdsakey_sign_external(key, ctx, ctxSz, msg, msgSz, sig, sigSz,
        addRnd);

    return ret;
}
int wc_SlhDsaKey_Sign(SlhDsaKey* key, const byte* ctx, byte ctxSz,
    const byte* msg, word32 msgSz, byte* sig, word32* sigSz, WC_RNG* rng)
{
    /* Randomized signing: n bytes of additional randomness are drawn from
     * rng, after the arguments are validated, and passed to
     * wc_SlhDsaKey_SignWithRandom.
     * Returns 0 on success, BAD_FUNC_ARG, BAD_LENGTH_E, MISSING_KEY, or the
     * RNG / signing error. */
    int ret;
    byte addRnd[SLHDSA_MAX_N];

    if ((key == NULL) || (key->params == NULL) || (msg == NULL) ||
            (sig == NULL) || (sigSz == NULL) || (rng == NULL) ||
            ((ctx == NULL) && (ctxSz > 0))) {
        ret = BAD_FUNC_ARG;
    }
    else if (*sigSz < key->params->sigLen) {
        ret = BAD_LENGTH_E;
    }
    else if ((key->flags & WC_SLHDSA_FLAG_PRIVATE) == 0) {
        ret = MISSING_KEY;
    }
    else {
        ret = wc_RNG_GenerateBlock(rng, addRnd, key->params->n);
        if (ret == 0) {
            ret = wc_SlhDsaKey_SignWithRandom(key, ctx, ctxSz, msg, msgSz,
                sig, sigSz, addRnd);
        }
    }
    return ret;
}
#endif
static int slhdsakey_verify(SlhDsaKey* key, byte* md, const byte* sig)
{
    /* Verify a signature against the message digest md. The signature layout
     * is R (n bytes) || FORS signature (k * (1 + a) * n bytes) || hypertree
     * signature. The FORS public key recomputed from the FORS part is checked
     * by the hypertree against PK.root.
     * Returns 0 when valid, otherwise an error. */
    int ret;
    HashAddress adrs;
    word32 t[3];
    word32 l;
    byte pk_fors[SLHDSA_MAX_N];
    byte n = key->params->n;
    const byte* sigFors = sig + n;
    const byte* sigHt = sigFors + key->params->k * (1 + key->params->a) * n;

    slhdsakey_set_ha_from_md(key, md, adrs, t, &l);
    ret = slhdsakey_fors_pk_from_sig(key, sigFors, md, key->sk + 2 * n, adrs,
        pk_fors);
    if (ret == 0) {
        ret = slhdsakey_ht_verify(key, pk_fors, sigHt, key->sk + 2 * n, t, l,
            key->sk + 3 * n);
    }
    return ret;
}
/* Verify a pure SLH-DSA signature.
 *
 * Recomputes md = SHAKE256(R || PK.seed || PK.root || 0x00 || ctxSz || ctx ||
 * msg), where R is the first n bytes of sig, then checks the rest of the
 * signature with slhdsakey_verify.
 *
 * key    [in]  Key with the public part present.
 * ctx    [in]  Context string (may be NULL when ctxSz is 0).
 * ctxSz  [in]  Context length.
 * msg    [in]  Message.
 * msgSz  [in]  Message length.
 * sig    [in]  Signature.
 * sigSz  [in]  Signature length; must equal the parameter set's sigLen.
 * returns 0 when valid, BAD_FUNC_ARG, BAD_LENGTH_E, MISSING_KEY, or a
 *         hashing/verification error.
 */
int wc_SlhDsaKey_Verify(SlhDsaKey* key, const byte* ctx, byte ctxSz,
const byte* msg, word32 msgSz, const byte* sig, word32 sigSz)
{
int ret = 0;
if ((key == NULL) || (key->params == NULL) ||
((ctx == NULL) && (ctxSz > 0)) || (msg == NULL) ||
(sig == NULL)) {
ret = BAD_FUNC_ARG;
}
else if (sigSz != key->params->sigLen) {
ret = BAD_LENGTH_E;
}
else if ((key->flags & WC_SLHDSA_FLAG_PUBLIC) == 0) {
ret = MISSING_KEY;
}
if (ret == 0) {
byte md[SLHDSA_MAX_MD];
byte n = key->params->n;
/* Hash input order: R, PK.seed || PK.root, header, ctx, msg. */
ret = slhdsakey_hash_start(&key->shake, sig, n);
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, key->sk + 2 * n, 2 * n);
}
if (ret == 0) {
/* Domain separator 0 (pure) and context length. */
byte hdr[2];
hdr[0] = 0;
hdr[1] = ctxSz;
ret = slhdsakey_hash_update(&key->shake, hdr, sizeof(hdr));
}
if ((ret == 0) && (ctxSz > 0)) {
ret = slhdsakey_hash_update(&key->shake, ctx, ctxSz);
}
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, msg, msgSz);
}
if (ret == 0) {
ret = slhdsakey_hash_final(&key->shake, md, key->params->dl1 +
key->params->dl2 + key->params->dl3);
}
if (ret == 0) {
ret = slhdsakey_verify(key, md, sig);
}
}
return ret;
}
/* DER-encoded OBJECT IDENTIFIERs (tag 0x06, length 9) of the hash functions
 * allowed for pre-hashing. All are under 2.16.840.1.101.3.4.2 (NIST hash
 * algorithms); the last byte selects the algorithm. These bytes are absorbed
 * into the message hash by HashSLH-DSA signing/verification. */
#ifdef WOLFSSL_SHA224
/* 2.16.840.1.101.3.4.2.4 */
static const byte slhdsakey_oid_sha224[] = {
0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x04
};
#endif
#ifndef NO_SHA256
/* 2.16.840.1.101.3.4.2.1 */
static const byte slhdsakey_oid_sha256[] = {
0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x01
};
#endif
#ifdef WOLFSSL_SHA384
/* 2.16.840.1.101.3.4.2.2 */
static const byte slhdsakey_oid_sha384[] = {
0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x02
};
#endif
#ifdef WOLFSSL_SHA512
/* 2.16.840.1.101.3.4.2.3 */
static const byte slhdsakey_oid_sha512[] = {
0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x03
};
#ifndef WOLFSSL_NOSHA512_224
/* 2.16.840.1.101.3.4.2.5 */
static const byte slhdsakey_oid_sha512_224[] = {
0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x05
};
#endif
#ifndef WOLFSSL_NOSHA512_256
/* 2.16.840.1.101.3.4.2.6 */
static const byte slhdsakey_oid_sha512_256[] = {
0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x06
};
#endif
#endif
#ifdef WOLFSSL_SHAKE128
/* 2.16.840.1.101.3.4.2.11 */
static const byte slhdsakey_oid_shake128[] = {
0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x0b
};
#endif
#ifdef WOLFSSL_SHAKE256
/* 2.16.840.1.101.3.4.2.12 */
static const byte slhdsakey_oid_shake256[] = {
0x06, 0x09, 0x60, 0x86, 0x48, 0x01, 0x65, 0x03, 0x04, 0x02, 0x0c
};
#endif
static int slhdsakey_prehash_msg(const byte* msg, word32 msgSz,
    enum wc_HashType hashType, byte* ph, byte* phLen, const byte** oid,
    byte* oidLen)
{
    /* Pre-hash msg with hashType for HashSLH-DSA. On a supported type, *oid,
     * *oidLen and *phLen are set and the digest is written to ph. SHAKE128
     * yields 32 bytes and SHAKE256 yields 64 bytes. Unsupported types return
     * NOT_COMPILED_IN and leave the outputs untouched. */
    const byte* selOid = NULL;
    byte selOidLen = 0;
    byte digestLen = 0;
    int err = NOT_COMPILED_IN;

    switch ((int)hashType) {
#ifdef WOLFSSL_SHA224
        case WC_HASH_TYPE_SHA224:
            selOid = slhdsakey_oid_sha224;
            selOidLen = (byte)sizeof(slhdsakey_oid_sha224);
            digestLen = WC_SHA224_DIGEST_SIZE;
            err = wc_Sha224Hash(msg, msgSz, ph);
            break;
#endif
#ifndef NO_SHA256
        case WC_HASH_TYPE_SHA256:
            selOid = slhdsakey_oid_sha256;
            selOidLen = (byte)sizeof(slhdsakey_oid_sha256);
            digestLen = WC_SHA256_DIGEST_SIZE;
            err = wc_Sha256Hash(msg, msgSz, ph);
            break;
#endif
#ifdef WOLFSSL_SHA384
        case WC_HASH_TYPE_SHA384:
            selOid = slhdsakey_oid_sha384;
            selOidLen = (byte)sizeof(slhdsakey_oid_sha384);
            digestLen = WC_SHA384_DIGEST_SIZE;
            err = wc_Sha384Hash(msg, msgSz, ph);
            break;
#endif
#ifdef WOLFSSL_SHA512
        case WC_HASH_TYPE_SHA512:
            selOid = slhdsakey_oid_sha512;
            selOidLen = (byte)sizeof(slhdsakey_oid_sha512);
            digestLen = WC_SHA512_DIGEST_SIZE;
            err = wc_Sha512Hash(msg, msgSz, ph);
            break;
    #ifndef WOLFSSL_NOSHA512_224
        case WC_HASH_TYPE_SHA512_224:
            selOid = slhdsakey_oid_sha512_224;
            selOidLen = (byte)sizeof(slhdsakey_oid_sha512_224);
            digestLen = WC_SHA512_224_DIGEST_SIZE;
            err = wc_Sha512_224Hash(msg, msgSz, ph);
            break;
    #endif
    #ifndef WOLFSSL_NOSHA512_256
        case WC_HASH_TYPE_SHA512_256:
            selOid = slhdsakey_oid_sha512_256;
            selOidLen = (byte)sizeof(slhdsakey_oid_sha512_256);
            digestLen = WC_SHA512_256_DIGEST_SIZE;
            err = wc_Sha512_256Hash(msg, msgSz, ph);
            break;
    #endif
#endif
#ifdef WOLFSSL_SHAKE128
        case WC_HASH_TYPE_SHAKE128:
            selOid = slhdsakey_oid_shake128;
            selOidLen = (byte)sizeof(slhdsakey_oid_shake128);
            digestLen = WC_SHA3_256_DIGEST_SIZE;
            err = wc_Shake128Hash(msg, msgSz, ph, WC_SHA3_256_DIGEST_SIZE);
            break;
#endif
#ifdef WOLFSSL_SHAKE256
        case WC_HASH_TYPE_SHAKE256:
            selOid = slhdsakey_oid_shake256;
            selOidLen = (byte)sizeof(slhdsakey_oid_shake256);
            digestLen = WC_SHA3_512_DIGEST_SIZE;
            err = wc_Shake256Hash(msg, msgSz, ph, WC_SHA3_512_DIGEST_SIZE);
            break;
#endif
        default:
            break;
    }

    if (selOid != NULL) {
        *oid = selOid;
        *oidLen = selOidLen;
        *phLen = digestLen;
    }
    return err;
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
/* Sign a pre-hashed message (HashSLH-DSA) with the given additional
 * randomness.
 *
 * msg is first hashed with hashType. Then M' = 0x01 || ctxSz || ctx || OID ||
 * PH(msg).
 * R (first n bytes of sig) is SHAKE256 over SK.prf || addRnd || M'.
 * md is SHAKE256 over R || PK.seed || PK.root || M' (dl1 + dl2 + dl3 bytes).
 * md is then signed by slhdsakey_sign.
 *
 * key       [in]      Key; the private part must be present.
 * ctx       [in]      Context string (may be NULL when ctxSz is 0).
 * ctxSz     [in]      Context length.
 * msg       [in]      Message to pre-hash.
 * msgSz     [in]      Message length.
 * hashType  [in]      Pre-hash algorithm.
 * sig       [out]     Signature buffer.
 * sigSz     [in,out]  Buffer size in; signature length out on success.
 * addRnd    [in]      n bytes of additional randomness.
 * returns 0 on success, BAD_FUNC_ARG, BAD_LENGTH_E, MISSING_KEY,
 *         NOT_COMPILED_IN for an unsupported hash, or a hashing/signing error.
 */
static int slhdsakey_signhash_external(SlhDsaKey* key, const byte* ctx,
    byte ctxSz, const byte* msg, word32 msgSz, enum wc_HashType hashType,
    byte* sig, word32* sigSz, byte* addRnd)
{
    int ret = 0;
    byte ph[WC_MAX_DIGEST_SIZE];
    byte phLen = 0;
    const byte* oid = NULL;
    byte oidLen = 0;

    if ((key == NULL) || (key->params == NULL) ||
            ((ctx == NULL) && (ctxSz > 0)) || (msg == NULL) || (sig == NULL) ||
            (sigSz == NULL)) {
        ret = BAD_FUNC_ARG;
    }
    else if (*sigSz < key->params->sigLen) {
        ret = BAD_LENGTH_E;
    }
    else if (addRnd == NULL) {
        ret = BAD_FUNC_ARG;
    }
    /* wc_SlhDsaKey_SignHashWithRandom calls this function directly, so the
     * private-key check must be here (as in slhdsakey_sign_external).
     * Otherwise a public-only key would sign with a zeroed or stale
     * SK.seed/SK.prf. */
    else if ((key->flags & WC_SLHDSA_FLAG_PRIVATE) == 0) {
        ret = MISSING_KEY;
    }
    if (ret == 0) {
        ret = slhdsakey_prehash_msg(msg, msgSz, hashType, ph, &phLen, &oid,
            &oidLen);
    }
    if (ret == 0) {
        byte n = key->params->n;
        byte md[SLHDSA_MAX_MD];
        byte hdr[2];

        /* Domain separator 1 (pre-hash) and context length. */
        hdr[0] = 1;
        hdr[1] = ctxSz;
        /* R = SHAKE256(SK.prf || addRnd || M'), n bytes, written to sig. */
        ret = slhdsakey_hash_start(&key->shake, key->sk + n, n);
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake, addRnd, n);
        }
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake, hdr, sizeof(hdr));
        }
        if ((ret == 0) && (ctxSz > 0)) {
            ret = slhdsakey_hash_update(&key->shake, ctx, ctxSz);
        }
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake, oid, oidLen);
        }
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake, ph, phLen);
        }
        if (ret == 0) {
            ret = slhdsakey_hash_final(&key->shake, sig, n);
        }
        /* md = SHAKE256(R || PK.seed || PK.root || M'). */
        if (ret == 0) {
            ret = slhdsakey_hash_start(&key->shake, sig, n);
            sig += n;
        }
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake, key->sk + 2 * n, 2 * n);
        }
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake, hdr, sizeof(hdr));
        }
        if ((ret == 0) && (ctxSz > 0)) {
            ret = slhdsakey_hash_update(&key->shake, ctx, ctxSz);
        }
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake, oid, oidLen);
        }
        if (ret == 0) {
            ret = slhdsakey_hash_update(&key->shake, ph, phLen);
        }
        if (ret == 0) {
            ret = slhdsakey_hash_final(&key->shake, md, key->params->dl1 +
                key->params->dl2 + key->params->dl3);
        }
        if (ret == 0) {
            ret = slhdsakey_sign(key, md, sig);
        }
        if (ret == 0) {
            *sigSz = key->params->sigLen;
        }
    }
    return ret;
}
int wc_SlhDsaKey_SignHashDeterministic(SlhDsaKey* key, const byte* ctx,
    byte ctxSz, const byte* msg, word32 msgSz, enum wc_HashType hashType,
    byte* sig, word32* sigSz)
{
    /* Deterministic HashSLH-DSA signing: PK.seed (key->sk + 2 * n) is used
     * as the additional randomness.
     * Returns 0 on success, BAD_FUNC_ARG, MISSING_KEY, or the signing error. */
    if ((key == NULL) || (key->params == NULL)) {
        return BAD_FUNC_ARG;
    }
    if ((key->flags & WC_SLHDSA_FLAG_PRIVATE) == 0) {
        return MISSING_KEY;
    }
    return slhdsakey_signhash_external(key, ctx, ctxSz, msg, msgSz, hashType,
        sig, sigSz, key->sk + 2 * key->params->n);
}
int wc_SlhDsaKey_SignHashWithRandom(SlhDsaKey* key, const byte* ctx, byte ctxSz,
    const byte* msg, word32 msgSz, enum wc_HashType hashType, byte* sig,
    word32* sigSz, byte* addRnd)
{
    /* HashSLH-DSA signing with caller-provided additional randomness
     * (n bytes). Validation is done by slhdsakey_signhash_external. */
    int ret = slhdsakey_signhash_external(key, ctx, ctxSz, msg, msgSz,
        hashType, sig, sigSz, addRnd);

    return ret;
}
int wc_SlhDsaKey_SignHash(SlhDsaKey* key, const byte* ctx, byte ctxSz,
    const byte* msg, word32 msgSz, enum wc_HashType hashType, byte* sig,
    word32* sigSz, WC_RNG* rng)
{
    /* Randomized HashSLH-DSA signing. After validation, n bytes of
     * additional randomness are drawn from rng.
     * Returns 0 on success, BAD_FUNC_ARG, BAD_LENGTH_E, MISSING_KEY, or the
     * RNG / signing error. */
    byte addRnd[SLHDSA_MAX_N];
    int ret;

    if ((key == NULL) || (key->params == NULL) || (rng == NULL) ||
            (msg == NULL) || (sig == NULL) || (sigSz == NULL) ||
            ((ctx == NULL) && (ctxSz > 0))) {
        return BAD_FUNC_ARG;
    }
    if (*sigSz < key->params->sigLen) {
        return BAD_LENGTH_E;
    }
    if ((key->flags & WC_SLHDSA_FLAG_PRIVATE) == 0) {
        return MISSING_KEY;
    }

    ret = wc_RNG_GenerateBlock(rng, addRnd, key->params->n);
    if (ret == 0) {
        ret = wc_SlhDsaKey_SignHashWithRandom(key, ctx, ctxSz, msg, msgSz,
            hashType, sig, sigSz, addRnd);
    }
    return ret;
}
#endif
/* Verify a HashSLH-DSA (pre-hash) signature.
 *
 * msg is hashed with hashType. Then md = SHAKE256(R || PK.seed || PK.root ||
 * 0x01 || ctxSz || ctx || OID || PH(msg)), where R is the first n bytes of
 * sig. The rest of the signature is checked with slhdsakey_verify.
 *
 * key       [in]  Key with the public part present.
 * ctx       [in]  Context string (may be NULL when ctxSz is 0).
 * ctxSz     [in]  Context length.
 * msg       [in]  Message to pre-hash.
 * msgSz     [in]  Message length.
 * hashType  [in]  Pre-hash algorithm.
 * sig       [in]  Signature.
 * sigSz     [in]  Signature length; must equal the parameter set's sigLen.
 * returns 0 when valid, BAD_FUNC_ARG, BAD_LENGTH_E, MISSING_KEY,
 *         NOT_COMPILED_IN for an unsupported hash, or a verification error.
 */
int wc_SlhDsaKey_VerifyHash(SlhDsaKey* key, const byte* ctx, byte ctxSz,
const byte* msg, word32 msgSz, enum wc_HashType hashType, const byte* sig,
word32 sigSz)
{
int ret = 0;
byte ph[WC_MAX_DIGEST_SIZE];
byte phLen = 0;
const byte* oid = NULL;
byte oidLen = 0;
if ((key == NULL) || (key->params == NULL) ||
((ctx == NULL) && (ctxSz > 0)) || (msg == NULL) || (sig == NULL)) {
ret = BAD_FUNC_ARG;
}
else if (sigSz != key->params->sigLen) {
ret = BAD_LENGTH_E;
}
else if ((key->flags & WC_SLHDSA_FLAG_PUBLIC) == 0) {
ret = MISSING_KEY;
}
if (ret == 0) {
ret = slhdsakey_prehash_msg(msg, msgSz, hashType, ph, &phLen, &oid,
&oidLen);
}
if (ret == 0) {
byte n = key->params->n;
byte md[SLHDSA_MAX_MD];
/* Hash input order: R, PK.seed || PK.root, header, ctx, OID, PH(msg). */
ret = slhdsakey_hash_start(&key->shake, sig, n);
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, key->sk + 2 * n, 2 * n);
}
if (ret == 0) {
/* Domain separator 1 (pre-hash) and context length. */
byte hdr[2];
hdr[0] = 1;
hdr[1] = ctxSz;
ret = slhdsakey_hash_update(&key->shake, hdr, sizeof(hdr));
}
if ((ret == 0) && (ctxSz > 0)) {
ret = slhdsakey_hash_update(&key->shake, ctx, ctxSz);
}
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, oid, oidLen);
}
if (ret == 0) {
ret = slhdsakey_hash_update(&key->shake, ph, phLen);
}
if (ret == 0) {
ret = slhdsakey_hash_final(&key->shake, md, key->params->dl1 +
key->params->dl2 + key->params->dl3);
}
if (ret == 0) {
ret = slhdsakey_verify(key, md, sig);
}
}
return ret;
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
int wc_SlhDsaKey_ImportPrivate(SlhDsaKey* key, const byte* priv, word32 privLen)
{
    /* Import a full private key: SK.seed || SK.prf || PK.seed || PK.root,
     * exactly 4 * n bytes. The key is then marked as holding both parts.
     * Returns 0, BAD_FUNC_ARG, or BAD_LENGTH_E. */
    word32 expected;

    if ((key == NULL) || (key->params == NULL) || (priv == NULL)) {
        return BAD_FUNC_ARG;
    }
    expected = 4 * (word32)key->params->n;
    if (privLen != expected) {
        return BAD_LENGTH_E;
    }
    XMEMCPY(key->sk, priv, expected);
    key->flags = WC_SLHDSA_FLAG_BOTH_KEYS;
    return 0;
}
#endif
int wc_SlhDsaKey_ImportPublic(SlhDsaKey* key, const byte* pub, word32 pubLen)
{
    /* Import a public key: PK.seed || PK.root, exactly 2 * n bytes, stored
     * in the upper half of the key buffer. The flags are replaced, so the
     * key is marked public-only.
     * Returns 0, BAD_FUNC_ARG, or BAD_LENGTH_E. */
    int ret = 0;

    if ((key == NULL) || (key->params == NULL) || (pub == NULL)) {
        ret = BAD_FUNC_ARG;
    }
    else {
        word32 pubSz = 2 * (word32)key->params->n;

        if (pubLen != pubSz) {
            ret = BAD_LENGTH_E;
        }
        else {
            XMEMCPY(key->sk + pubSz, pub, pubSz);
            key->flags = WC_SLHDSA_FLAG_PUBLIC;
        }
    }
    return ret;
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
/* Check that the stored PK.root matches the root recomputed from SK.seed and
 * PK.seed.
 *
 * The stored root is restored afterwards, so the key is left unchanged
 * whether or not the check passes.
 *
 * key  [in]  Key with the private part present.
 * returns 0 when consistent, WC_KEY_MISMATCH_E when the roots differ,
 *         BAD_FUNC_ARG, MISSING_KEY, or the recomputation error.
 */
int wc_SlhDsaKey_CheckKey(SlhDsaKey* key)
{
    int ret = 0;

    if ((key == NULL) || (key->params == NULL)) {
        ret = BAD_FUNC_ARG;
    }
    else if ((key->flags & WC_SLHDSA_FLAG_PRIVATE) == 0) {
        ret = MISSING_KEY;
    }
    if (ret == 0) {
        byte root[SLHDSA_MAX_N];
        byte n = key->params->n;

        XMEMCPY(root, key->sk + 3 * n, n);
        /* Recomputes PK.root in place at key->sk + 3 * n. */
        ret = wc_SlhDsaKey_MakeKeyWithRandom(key, key->sk, n, key->sk + n, n,
            key->sk + 2 * n, n);
        if ((ret == 0) && (XMEMCMP(root, key->sk + 3 * n, n) != 0)) {
            ret = WC_KEY_MISMATCH_E;
        }
        /* Put the imported root back. On a mismatch or error, the
         * recomputation would otherwise silently replace the key's public
         * root. On success this is a no-op. */
        XMEMCPY(key->sk + 3 * n, root, n);
    }
    return ret;
}
/* Export the encoded SLH-DSA private key held in key.
 *
 * key      [in]       SLH-DSA key with parameters set.
 * priv     [out]      Buffer to receive the 4 * n byte private key.
 * privLen  [in, out]  On input, size of priv; on success, bytes written.
 * returns BAD_FUNC_ARG when any pointer argument or the key's parameters
 *         are NULL,
 *         BAD_LENGTH_E when priv is smaller than 4 * n bytes,
 *         0 on success.
 */
int wc_SlhDsaKey_ExportPrivate(SlhDsaKey* key, byte* priv, word32* privLen)
{
    word32 skLen;

    if ((key == NULL) || (key->params == NULL) || (priv == NULL) ||
            (privLen == NULL)) {
        return BAD_FUNC_ARG;
    }

    skLen = (word32)key->params->n * 4;
    if (*privLen < skLen) {
        return BAD_LENGTH_E;
    }

    XMEMCPY(priv, key->sk, skLen);
    *privLen = skLen;
    return 0;
}
#endif
/* Export the encoded SLH-DSA public key held in key.
 *
 * The public key is the last 2 * n bytes of key->sk.
 *
 * key     [in]       SLH-DSA key with parameters set.
 * pub     [out]      Buffer to receive the 2 * n byte public key.
 * pubLen  [in, out]  On input, size of pub; on success, bytes written.
 * returns BAD_FUNC_ARG when any pointer argument or the key's parameters
 *         are NULL,
 *         BAD_LENGTH_E when pub is smaller than 2 * n bytes,
 *         0 on success.
 */
int wc_SlhDsaKey_ExportPublic(SlhDsaKey* key, byte* pub, word32* pubLen)
{
    int ret = BAD_FUNC_ARG;

    if ((key != NULL) && (key->params != NULL) && (pub != NULL) &&
            (pubLen != NULL)) {
        word32 pkLen = 2 * (word32)key->params->n;

        if (*pubLen >= pkLen) {
            XMEMCPY(pub, key->sk + pkLen, pkLen);
            *pubLen = pkLen;
            ret = 0;
        }
        else {
            ret = BAD_LENGTH_E;
        }
    }

    return ret;
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
/* Get the size of an encoded private key for the key's parameter set.
 *
 * key  [in]  SLH-DSA key with parameters set.
 * returns 4 * n on success, BAD_FUNC_ARG when key or its parameters are NULL.
 */
int wc_SlhDsaKey_PrivateSize(SlhDsaKey* key)
{
    if ((key != NULL) && (key->params != NULL)) {
        return 4 * key->params->n;
    }
    return BAD_FUNC_ARG;
}
#endif
/* Get the size of an encoded public key for the key's parameter set.
 *
 * key  [in]  SLH-DSA key with parameters set.
 * returns 2 * n on success, BAD_FUNC_ARG when key or its parameters are NULL.
 */
int wc_SlhDsaKey_PublicSize(SlhDsaKey* key)
{
    if ((key != NULL) && (key->params != NULL)) {
        return 2 * key->params->n;
    }
    return BAD_FUNC_ARG;
}
/* Get the signature length for the key's parameter set.
 *
 * key  [in]  SLH-DSA key with parameters set.
 * returns the parameter set's sigLen on success,
 *         BAD_FUNC_ARG when key or its parameters are NULL.
 */
int wc_SlhDsaKey_SigSize(SlhDsaKey* key)
{
    if ((key != NULL) && (key->params != NULL)) {
        return key->params->sigLen;
    }
    return BAD_FUNC_ARG;
}
#ifndef WOLFSSL_SLHDSA_VERIFY_ONLY
/* Get the encoded private key size for an SLH-DSA parameter set.
 *
 * param  [in]  SLH-DSA parameter set identifier.
 * returns the private key length in bytes for a known parameter set,
 *         NOT_COMPILED_IN for any other value.
 */
int wc_SlhDsaKey_PrivateSizeFromParam(enum SlhDsaParam param)
{
    switch (param) {
        case SLHDSA_SHAKE128S: return WC_SLHDSA_SHAKE128S_PRIV_LEN;
        case SLHDSA_SHAKE128F: return WC_SLHDSA_SHAKE128F_PRIV_LEN;
        case SLHDSA_SHAKE192S: return WC_SLHDSA_SHAKE192S_PRIV_LEN;
        case SLHDSA_SHAKE192F: return WC_SLHDSA_SHAKE192F_PRIV_LEN;
        case SLHDSA_SHAKE256S: return WC_SLHDSA_SHAKE256S_PRIV_LEN;
        case SLHDSA_SHAKE256F: return WC_SLHDSA_SHAKE256F_PRIV_LEN;
        default:               break;
    }
    return NOT_COMPILED_IN;
}
#endif
/* Get the encoded public key size for an SLH-DSA parameter set.
 *
 * param  [in]  SLH-DSA parameter set identifier.
 * returns the public key length in bytes for a known parameter set,
 *         NOT_COMPILED_IN for any other value.
 */
int wc_SlhDsaKey_PublicSizeFromParam(enum SlhDsaParam param)
{
    switch (param) {
        case SLHDSA_SHAKE128S: return WC_SLHDSA_SHAKE128S_PUB_LEN;
        case SLHDSA_SHAKE128F: return WC_SLHDSA_SHAKE128F_PUB_LEN;
        case SLHDSA_SHAKE192S: return WC_SLHDSA_SHAKE192S_PUB_LEN;
        case SLHDSA_SHAKE192F: return WC_SLHDSA_SHAKE192F_PUB_LEN;
        case SLHDSA_SHAKE256S: return WC_SLHDSA_SHAKE256S_PUB_LEN;
        case SLHDSA_SHAKE256F: return WC_SLHDSA_SHAKE256F_PUB_LEN;
        default:               break;
    }
    return NOT_COMPILED_IN;
}
/* Get the signature length for an SLH-DSA parameter set.
 *
 * param  [in]  SLH-DSA parameter set identifier.
 * returns the signature length in bytes for a known parameter set,
 *         NOT_COMPILED_IN for any other value.
 */
int wc_SlhDsaKey_SigSizeFromParam(enum SlhDsaParam param)
{
    switch (param) {
        case SLHDSA_SHAKE128S: return WC_SLHDSA_SHAKE128S_SIG_LEN;
        case SLHDSA_SHAKE128F: return WC_SLHDSA_SHAKE128F_SIG_LEN;
        case SLHDSA_SHAKE192S: return WC_SLHDSA_SHAKE192S_SIG_LEN;
        case SLHDSA_SHAKE192F: return WC_SLHDSA_SHAKE192F_SIG_LEN;
        case SLHDSA_SHAKE256S: return WC_SLHDSA_SHAKE256S_SIG_LEN;
        case SLHDSA_SHAKE256F: return WC_SLHDSA_SHAKE256F_SIG_LEN;
        default:               break;
    }
    return NOT_COMPILED_IN;
}
#endif