#include <openssl/byteorder.h>
#include "ml_dsa_local.h"
#include "ml_dsa_vector.h"
#include "ml_dsa_matrix.h"
#include "ml_dsa_hash.h"
#include "internal/sha3.h"
#include "internal/packet.h"
/* Size in bytes of the buffers squeezed per call by the samplers below. */
#define SHAKE128_BLOCKSIZE SHA3_BLOCKSIZE(128)
#define SHAKE256_BLOCKSIZE SHA3_BLOCKSIZE(256)
/*
 * n mod 5 computed with a multiply and a shift instead of the % operator.
 * 0x3335 / 2^16 is marginally larger than 1/5, so (0x3335 * n) >> 16 equals
 * floor(n / 5) for small n; in this file it is only applied to nibble values
 * below 15.  Note that the argument is evaluated twice.
 */
#define MOD5(n) ((n) - 5 * (0x3335 * (n) >> 16))
/*
 * rej_ntt_poly() walks its buffer in 3 byte steps, so the buffer size must
 * be a whole number of 3 byte groups.
 */
#if SHAKE128_BLOCKSIZE % 3 != 0
#error "rej_ntt_poly() requires SHAKE128_BLOCKSIZE to be a multiple of 3"
#endif
/*
 * Signature shared by the nibble-to-coefficient converters.  A converter
 * returns 1 and stores a coefficient in |out| when |nibble| is accepted,
 * or returns 0 (leaving |out| untouched) when it is rejected.
 */
typedef int(COEFF_FROM_NIBBLE_FUNC)(uint32_t nibble, uint32_t *out);
/* Forward declarations; the definitions follow below. */
static COEFF_FROM_NIBBLE_FUNC coeff_from_nibble_4;
static COEFF_FROM_NIBBLE_FUNC coeff_from_nibble_2;
/*
 * Build a 23 bit little-endian candidate coefficient from the 3 bytes at |s|.
 *
 * The top bit of s[2] is discarded.  The candidate is always stored in |out|,
 * even when it is rejected.
 *
 * Returns 1 if the candidate is less than ML_DSA_Q, otherwise 0.
 */
static ossl_inline int coeff_from_three_bytes(const uint8_t *s, uint32_t *out)
{
    const uint32_t low = (uint32_t)s[0];
    const uint32_t middle = (uint32_t)s[1] << 8;
    /* Mask with 0x7f so that only 7 bits of the last byte contribute */
    const uint32_t high = ((uint32_t)s[2] & 0x7f) << 16;
    const uint32_t candidate = high | middle | low;

    *out = candidate;
    return candidate < ML_DSA_Q;
}
/*
 * Nibble converter used when eta is ML_DSA_ETA_4.
 *
 * Nibbles 0..8 are accepted and |out| is set to mod_sub(4, nibble)
 * (presumably 4 - nibble reduced mod q - confirm against mod_sub()).
 * Any other nibble is rejected and |out| is not written.
 *
 * Returns 1 if the nibble was accepted, otherwise 0.
 */
static ossl_inline int coeff_from_nibble_4(uint32_t nibble, uint32_t *out)
{
    /* The comparison result is passed through value_barrier_32() before use */
    if (!value_barrier_32(nibble < 9))
        return 0;

    *out = mod_sub(4, nibble);
    return 1;
}
/*
 * Nibble converter used when eta is not ML_DSA_ETA_4.
 *
 * Nibbles 0..14 are accepted and |out| is set to mod_sub(2, nibble mod 5)
 * (presumably 2 - (nibble mod 5) reduced mod q - confirm against mod_sub()).
 * The nibble 15 is rejected and |out| is not written.
 *
 * Returns 1 if the nibble was accepted, otherwise 0.
 */
static ossl_inline int coeff_from_nibble_2(uint32_t nibble, uint32_t *out)
{
    /* The comparison result is passed through value_barrier_32() before use */
    if (!value_barrier_32(nibble < 15))
        return 0;

    *out = mod_sub(2, MOD5(nibble));
    return 1;
}
/*
 * Fill every coefficient of |out| by rejection sampling 3 byte groups taken
 * from XOF output derived from |seed|.
 *
 * The first buffer is obtained with shake_xof() and subsequent buffers with
 * EVP_DigestSqueeze() on the same context.  A whole SHAKE128_BLOCKSIZE buffer
 * is fetched at a time and consumed 3 bytes per candidate; the buffer size is
 * checked at compile time to be a multiple of 3 so no group straddles two
 * buffers.
 *
 * Returns 1 once ML_DSA_NUM_POLY_COEFFICIENTS values have been accepted, or
 * 0 if either XOF call fails.
 */
static int rej_ntt_poly(EVP_MD_CTX *g_ctx, const EVP_MD *md,
    const uint8_t *seed, size_t seed_len, POLY *out)
{
    uint8_t xof_buf[SHAKE128_BLOCKSIZE];
    size_t pos;
    int accepted = 0;

    if (!shake_xof(g_ctx, md, seed, seed_len, xof_buf, sizeof(xof_buf)))
        return 0;

    for (;;) {
        for (pos = 0; pos < sizeof(xof_buf); pos += 3) {
            /* A rejected candidate is overwritten by the next attempt */
            if (!coeff_from_three_bytes(xof_buf + pos, &out->coeff[accepted]))
                continue;
            if (++accepted >= ML_DSA_NUM_POLY_COEFFICIENTS)
                return 1;
        }
        /* Buffer exhausted: fetch more output and keep going */
        if (!EVP_DigestSqueeze(g_ctx, xof_buf, sizeof(xof_buf)))
            return 0;
    }
}
/*
 * Fill every coefficient of |out| by rejection sampling nibbles taken from
 * XOF output derived from |seed|.
 *
 * Each output byte supplies two candidates: the low nibble is tried first,
 * then the high nibble.  |coef_from_nibble| decides whether a nibble is
 * accepted and, if so, writes the resulting coefficient.
 *
 * The first buffer is obtained with shake_xof() and subsequent buffers with
 * EVP_DigestSqueeze() on the same context.
 *
 * Returns 1 once ML_DSA_NUM_POLY_COEFFICIENTS values have been accepted, or
 * 0 if either XOF call fails.
 */
static int rej_bounded_poly(EVP_MD_CTX *h_ctx, const EVP_MD *md,
    COEFF_FROM_NIBBLE_FUNC *coef_from_nibble,
    const uint8_t *seed, size_t seed_len, POLY *out)
{
    uint8_t xof_buf[SHAKE256_BLOCKSIZE];
    size_t pos;
    unsigned int shift;
    uint32_t nibble;
    int accepted = 0;

    if (!shake_xof(h_ctx, md, seed, seed_len, xof_buf, sizeof(xof_buf)))
        return 0;

    for (;;) {
        for (pos = 0; pos < sizeof(xof_buf); pos++) {
            /* shift == 0 selects the low nibble, shift == 4 the high one */
            for (shift = 0; shift <= 4; shift += 4) {
                nibble = ((uint32_t)xof_buf[pos] >> shift) & 0x0F;
                if (!coef_from_nibble(nibble, &out->coeff[accepted]))
                    continue;
                if (++accepted >= ML_DSA_NUM_POLY_COEFFICIENTS)
                    return 1;
            }
        }
        /* Buffer exhausted: fetch more output and keep going */
        if (!EVP_DigestSqueeze(h_ctx, xof_buf, sizeof(xof_buf)))
            return 0;
    }
}
/*
 * Generate all out->k * out->l polynomials of the matrix |out| from |rho|.
 *
 * The polynomials are written consecutively starting at out->m_poly, row by
 * row.  The XOF input for the entry at (row, column) is the first
 * ML_DSA_RHO_BYTES bytes of |rho| followed by the column index byte and then
 * the row index byte.
 *
 * Returns 1 on success, or 0 if sampling any entry fails.
 */
int ossl_ml_dsa_matrix_expand_A(EVP_MD_CTX *g_ctx, const EVP_MD *md,
    const uint8_t *rho, MATRIX *out)
{
    uint8_t xof_input[ML_DSA_RHO_BYTES + 2];
    POLY *entry = out->m_poly;
    size_t row, col;

    memcpy(xof_input, rho, ML_DSA_RHO_BYTES);

    for (row = 0; row < out->k; row++) {
        /* The row index is the final byte of the XOF input */
        xof_input[ML_DSA_RHO_BYTES + 1] = (uint8_t)row;
        for (col = 0; col < out->l; col++, entry++) {
            /* The column index immediately follows rho */
            xof_input[ML_DSA_RHO_BYTES] = (uint8_t)col;
            if (!rej_ntt_poly(g_ctx, md, xof_input, sizeof(xof_input), entry))
                return 0;
        }
    }
    return 1;
}
/*
 * Sample every polynomial of the vectors |s1| and |s2| from |seed|.
 *
 * The XOF input for each polynomial is the first ML_DSA_PRIV_SEED_BYTES
 * bytes of |seed| followed by a counter byte and a zero byte.  The counter
 * starts at 0 for the first polynomial of |s1|, increases by one for each
 * polynomial, and carries straight on into |s2|.
 *
 * coeff_from_nibble_4() is used to convert nibbles when |eta| equals
 * ML_DSA_ETA_4, otherwise coeff_from_nibble_2() is used.
 *
 * Returns 1 on success, or 0 if sampling any polynomial fails.
 */
int ossl_ml_dsa_vector_expand_S(EVP_MD_CTX *h_ctx, const EVP_MD *md, int eta,
    const uint8_t *seed, VECTOR *s1, VECTOR *s2)
{
    uint8_t xof_input[ML_DSA_PRIV_SEED_BYTES + 2];
    COEFF_FROM_NIBBLE_FUNC *converter = coeff_from_nibble_2;
    const size_t s1_count = s1->num_poly;
    const size_t s2_count = s2->num_poly;
    uint8_t counter = 0;
    size_t n;

    if (eta == ML_DSA_ETA_4)
        converter = coeff_from_nibble_4;

    memcpy(xof_input, seed, ML_DSA_PRIV_SEED_BYTES);
    /* Only the first of the two trailing bytes ever changes */
    xof_input[ML_DSA_PRIV_SEED_BYTES + 1] = 0;

    for (n = 0; n < s1_count; n++) {
        xof_input[ML_DSA_PRIV_SEED_BYTES] = counter++;
        if (!rej_bounded_poly(h_ctx, md, converter,
                xof_input, sizeof(xof_input), &s1->poly[n]))
            return 0;
    }

    /* The counter is not reset between the two vectors */
    for (n = 0; n < s2_count; n++) {
        xof_input[ML_DSA_PRIV_SEED_BYTES] = counter++;
        if (!rej_bounded_poly(h_ctx, md, converter,
                xof_input, sizeof(xof_input), &s2->poly[n]))
            return 0;
    }
    return 1;
}
/*
 * Derive XOF output from |seed| and decode it into the polynomial |out|.
 *
 * 32 * 20 bytes are generated when |gamma1| equals
 * ML_DSA_GAMMA1_TWO_POWER_19, otherwise 32 * 18 bytes.  The bytes are then
 * handed to ossl_ml_dsa_poly_decode_expand_mask() together with |gamma1|.
 *
 * Returns 1 if both the XOF call and the decode succeed, otherwise 0.
 */
int ossl_ml_dsa_poly_expand_mask(POLY *out, const uint8_t *seed, size_t seed_len,
    uint32_t gamma1,
    EVP_MD_CTX *h_ctx, const EVP_MD *md)
{
    /* Sized for the larger of the two cases */
    uint8_t packed[32 * 20];
    size_t packed_len;

    if (gamma1 == ML_DSA_GAMMA1_TWO_POWER_19)
        packed_len = 32 * 20;
    else
        packed_len = 32 * 18;

    if (!shake_xof(h_ctx, md, seed, seed_len, packed, packed_len))
        return 0;

    return ossl_ml_dsa_poly_decode_expand_mask(out, packed, packed_len, gamma1) != 0;
}
/*
 * Build the polynomial |out_c| from XOF output derived from |seed|.
 *
 * |out_c| is zeroed first, then |tau| iterations each place one value
 * produced by mod_sub(1, 2 * sign_bit) (presumably +1 or -1 mod q - confirm
 * against mod_sub()).  The first 8 bytes of XOF output are loaded as a
 * little-endian 64 bit word that supplies one sign bit per iteration, lowest
 * bit first; the bytes after that are used one at a time to choose positions.
 *
 * Whole SHAKE256_BLOCKSIZE buffers are fetched and consumed byte by byte
 * rather than squeezing a byte at a time.
 *
 * NOTE(review): the loop bounds and the 64 bit sign word assume that |tau|
 * is no larger than 64 - confirm that callers guarantee this.
 *
 * Returns 1 on success, or 0 if either XOF call fails.
 */
int ossl_ml_dsa_poly_sample_in_ball(POLY *out_c, const uint8_t *seed, int seed_len,
    EVP_MD_CTX *h_ctx, const EVP_MD *md,
    uint32_t tau)
{
    uint8_t xof_buf[SHAKE256_BLOCKSIZE];
    uint64_t sign_bits;
    /* The first 8 bytes of the initial buffer are reserved for the signs */
    size_t cursor = 8;
    size_t end;

    if (!shake_xof(h_ctx, md, seed, seed_len, xof_buf, sizeof(xof_buf)))
        return 0;

    OPENSSL_load_u64_le(&sign_bits, xof_buf);
    poly_zero(out_c);

    for (end = 256 - tau; end < 256; end++) {
        size_t pick;

        /* Rejection sample a byte value in the range 0..end */
        do {
            if (cursor == sizeof(xof_buf)) {
                /* Every buffered byte has been used, so fetch more */
                if (!EVP_DigestSqueeze(h_ctx, xof_buf, sizeof(xof_buf)))
                    return 0;
                cursor = 0;
            }
            pick = xof_buf[cursor++];
        } while (pick > end);

        /*
         * Move whatever currently sits at the chosen position to position
         * |end| before overwriting it, so earlier writes are not lost.
         */
        out_c->coeff[end] = out_c->coeff[pick];
        out_c->coeff[pick] = mod_sub(1, 2 * (sign_bits & 1));
        sign_bits >>= 1;
    }
    return 1;
}