#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
#include <wolfssl/wolfcrypt/sha256.h>
#if defined(WOLFSSL_HAVE_XMSS) && defined(HAVE_LIBXMSS)
#include <wolfssl/wolfcrypt/ext_xmss.h>
#ifdef NO_INLINE
#include <wolfssl/wolfcrypt/misc.h>
#else
#define WOLFSSL_MISC_INCLUDED
#include <wolfcrypt/src/misc.c>
#endif
#include <xmss_callbacks.h>
#include <utils.h>
#ifndef WOLFSSL_XMSS_VERIFY_ONLY
/* Per-thread RNG that rng_cb draws from. It is assigned by
 * wc_XmssKey_MakeKey immediately before XMSS key generation. */
static THREAD_LS_T WC_RNG * xmssRng = NULL;

/* Random-bytes callback registered with the XMSS library (xmss_set_rng_cb).
 *
 * output - buffer to fill with random bytes.
 * length - number of bytes requested.
 *
 * Returns 0 on success (including a zero-length request) and -1 when output
 * or xmssRng is NULL or when wc_RNG_GenerateBlock reports an error.
 */
static int rng_cb(void * output, size_t length)
{
    int rc;

    if (output != NULL && xmssRng != NULL) {
        if (length > 0) {
            rc = wc_RNG_GenerateBlock(xmssRng, (byte *)output,
                                      (word32)length);
            if (rc != 0) {
                WOLFSSL_MSG("error: XMSS rng_cb failed");
                return -1;
            }
        }
        return 0;
    }

    return -1;
}
#endif
static int sha256_cb(const unsigned char *in, unsigned long long inlen,
unsigned char *out)
{
wc_Sha256 sha;
if (wc_InitSha256_ex(&sha, NULL, INVALID_DEVID) != 0) {
WOLFSSL_MSG("SHA256 Init failed");
return -1;
}
if (wc_Sha256Update(&sha, in, (word32) inlen) != 0) {
WOLFSSL_MSG("SHA256 Update failed");
return -1;
}
if (wc_Sha256Final(&sha, out) != 0) {
WOLFSSL_MSG("SHA256 Final failed");
wc_Sha256Free(&sha);
return -1;
}
wc_Sha256Free(&sha);
return 0;
}
/* Initialize an XmssKey to a clean, empty state (WC_XMSS_STATE_INITED).
 *
 * key   - key object to initialize.
 * heap  - accepted for API symmetry; unused by this implementation.
 * devId - accepted for API symmetry; unused by this implementation.
 *
 * Returns 0 on success, BAD_FUNC_ARG if key is NULL.
 */
int wc_XmssKey_Init(XmssKey * key, void * heap, int devId)
{
    (void) heap;
    (void) devId;

    if (key == NULL) {
        return BAD_FUNC_ARG;
    }

    /* Wipe the whole object, then set the pointer/length fields explicitly. */
    ForceZero(key, sizeof(XmssKey));
#ifndef WOLFSSL_XMSS_VERIFY_ONLY
    key->context = NULL;
    key->read_private_key = NULL;
    key->write_private_key = NULL;
    key->sk_len = 0;
    key->sk = NULL;
#endif
    key->state = WC_XMSS_STATE_INITED;

    return 0;
}
/* Parse oid into key->params, check it is a supported parameter set, and
 * register the hash (and, when signing is enabled, RNG) callbacks.
 *
 * key       - key whose params/oid/is_xmssmt/state are set.
 * oid       - XMSS or XMSS^MT OID; 0 is rejected.
 * is_xmssmt - non-zero to parse as XMSS^MT, zero for XMSS.
 *
 * Returns 0 and moves key to WC_XMSS_STATE_PARMSET on success;
 * BAD_FUNC_ARG for NULL key or zero oid; -1 on any other failure.
 */
static int wc_XmssKey_SetOid(XmssKey * key, uint32_t oid, int is_xmssmt)
{
    int rc;

    if (key == NULL || oid == 0) {
        return BAD_FUNC_ARG;
    }

    rc = is_xmssmt ? xmssmt_parse_oid(&key->params, oid)
                   : xmss_parse_oid(&key->params, oid);
    if (rc != 0) {
        WOLFSSL_MSG("error: XMSS parse oid failed");
        return -1;
    }

    /* Only the SHA-256 parameter family served by sha256_cb is accepted. */
    if (!(key->params.func == XMSS_SHA2 &&
          key->params.n == XMSS_SHA256_N &&
          key->params.padding_len == XMSS_SHA256_PADDING_LEN &&
          key->params.wots_w == 16 &&
          key->params.wots_len == XMSS_SHA256_WOTS_LEN)) {
        WOLFSSL_MSG("error: unsupported XMSS/XMSS^MT parameter set");
        return -1;
    }

    if (!(key->params.full_height >= WOLFSSL_XMSS_MIN_HEIGHT &&
          key->params.full_height <= WOLFSSL_XMSS_MAX_HEIGHT)) {
        WOLFSSL_MSG("error: unsupported XMSS/XMSS^MT parameter set - height");
        return -1;
    }

    if (xmss_set_sha_cb(sha256_cb) != 0) {
        WOLFSSL_MSG("error: xmss_set_sha_cb failed");
        return -1;
    }

#ifndef WOLFSSL_XMSS_VERIFY_ONLY
    if (xmss_set_rng_cb(rng_cb) != 0) {
        WOLFSSL_MSG("error: xmss_set_rng_cb failed");
        return -1;
    }
#endif

    key->oid = oid;
    key->is_xmssmt = is_xmssmt;
    key->state = WC_XMSS_STATE_PARMSET;

    return 0;
}
/* Select the XMSS/XMSS^MT parameter set by its string name.
 *
 * The variant is chosen from the string length: XMSS_NAME_LEN selects XMSS,
 * while XMSSMT_NAME_MIN_LEN or XMSSMT_NAME_MAX_LEN selects XMSS^MT.
 *
 * key - key in WC_XMSS_STATE_INITED.
 * str - parameter-set name, NUL-terminated.
 *
 * Returns the result of wc_XmssKey_SetOid on success; BAD_FUNC_ARG for NULL
 * arguments, wrong key state or unrecognized length; -1 if the name cannot
 * be mapped to an OID.
 */
int wc_XmssKey_SetParamStr(XmssKey * key, const char * str)
{
    int ret = 0;
    uint32_t oid = 0;
    int is_xmssmt = 0;

    if (key == NULL || str == NULL) {
        return BAD_FUNC_ARG;
    }

    if (key->state != WC_XMSS_STATE_INITED) {
        WOLFSSL_MSG("error: XMSS key needs init");
        return BAD_FUNC_ARG;
    }

    switch (XSTRLEN(str)) {
    case XMSS_NAME_LEN:
        is_xmssmt = 0;
        break;
    case XMSSMT_NAME_MIN_LEN:
    case XMSSMT_NAME_MAX_LEN:
        is_xmssmt = 1;
        break;
    default:
        WOLFSSL_MSG("error: XMSS param str invalid length");
        return BAD_FUNC_ARG;
    }

    /* Report the converter that actually failed, so the diagnostic is not
     * misleading for plain XMSS names. */
    if (is_xmssmt) {
        ret = xmssmt_str_to_oid(&oid, str);
        if (ret != 0) {
            WOLFSSL_MSG("error: xmssmt_str_to_oid failed");
            return -1;
        }
    }
    else {
        ret = xmss_str_to_oid(&oid, str);
        if (ret != 0) {
            WOLFSSL_MSG("error: xmss_str_to_oid failed");
            return -1;
        }
    }

    return wc_XmssKey_SetOid(key, oid, is_xmssmt);
}
/* Release an XmssKey: scrub and free the secret-key buffer (if any), wipe
 * the object, and mark it WC_XMSS_STATE_FREED. A NULL key is ignored.
 */
void wc_XmssKey_Free(XmssKey* key)
{
    if (key != NULL) {
#ifndef WOLFSSL_XMSS_VERIFY_ONLY
        if (key->sk != NULL) {
            /* Scrub secret material before handing memory back. */
            ForceZero(key->sk, key->sk_len);
            XFREE(key->sk, NULL, DYNAMIC_TYPE_TMP_BUFFER);
            key->sk = NULL;
            key->sk_len = 0;
        }
#endif
        ForceZero(key, sizeof(XmssKey));
        key->state = WC_XMSS_STATE_FREED;
    }
}
#ifndef WOLFSSL_XMSS_VERIFY_ONLY
/* Install the callback used to persist the private key.
 *
 * Returns 0 on success, BAD_FUNC_ARG for NULL arguments, or -1 if the key
 * is already in WC_XMSS_STATE_OK (callbacks cannot change while in use).
 */
int wc_XmssKey_SetWriteCb(XmssKey * key, wc_xmss_write_private_key_cb write_cb)
{
    int ret = 0;

    if (key == NULL || write_cb == NULL) {
        ret = BAD_FUNC_ARG;
    }
    else if (key->state == WC_XMSS_STATE_OK) {
        WOLFSSL_MSG("error: wc_XmssKey_SetWriteCb: key in use");
        ret = -1;
    }
    else {
        key->write_private_key = write_cb;
    }

    return ret;
}
/* Install the callback used to load the private key.
 *
 * Returns 0 on success, BAD_FUNC_ARG for NULL arguments, or -1 if the key
 * is already in WC_XMSS_STATE_OK (callbacks cannot change while in use).
 */
int wc_XmssKey_SetReadCb(XmssKey * key, wc_xmss_read_private_key_cb read_cb)
{
    int ret = 0;

    if (key == NULL || read_cb == NULL) {
        ret = BAD_FUNC_ARG;
    }
    else if (key->state == WC_XMSS_STATE_OK) {
        WOLFSSL_MSG("error: wc_XmssKey_SetReadCb: key in use");
        ret = -1;
    }
    else {
        key->read_private_key = read_cb;
    }

    return ret;
}
/* Set the opaque context pointer passed to the read/write callbacks.
 *
 * Returns 0 on success, BAD_FUNC_ARG for NULL arguments, or -1 if the key
 * is already in WC_XMSS_STATE_OK.
 */
int wc_XmssKey_SetContext(XmssKey * key, void * context)
{
    int ret = 0;

    if (key == NULL || context == NULL) {
        ret = BAD_FUNC_ARG;
    }
    else if (key->state == WC_XMSS_STATE_OK) {
        WOLFSSL_MSG("error: wc_XmssKey_SetContext: key in use");
        ret = -1;
    }
    else {
        key->context = context;
    }

    return ret;
}
/* Allocate a zeroed key->sk buffer sized by wc_XmssKey_GetPrivLen.
 *
 * Returns 0 on success, BAD_FUNC_ARG for a NULL key, or -1 if a buffer
 * already exists, the length cannot be determined, or allocation fails.
 */
static int wc_XmssKey_AllocSk(XmssKey* key)
{
    unsigned char * buf;

    if (key == NULL) {
        return BAD_FUNC_ARG;
    }
    if (key->sk != NULL) {
        WOLFSSL_MSG("error: XMSS secret key already exists");
        return -1;
    }
    if (wc_XmssKey_GetPrivLen(key, &key->sk_len) != 0 || key->sk_len <= 0) {
        WOLFSSL_MSG("error: wc_XmssKey_GetPrivLen failed");
        return -1;
    }

    buf = (unsigned char *)XMALLOC(key->sk_len, NULL, DYNAMIC_TYPE_TMP_BUFFER);
    key->sk = buf;
    if (buf == NULL) {
        WOLFSSL_MSG("error: malloc XMSS key->sk failed");
        return -1;
    }

    ForceZero(buf, key->sk_len);
    return 0;
}
/* Generate a new XMSS/XMSS^MT key pair and persist the private key.
 *
 * key - key in WC_XMSS_STATE_PARMSET with read/write callbacks and context
 *       already set.
 * rng - random source used (via rng_cb) during key generation.
 *
 * The public key is kept in key->pk. The private key is handed to
 * write_private_key and then wiped from memory.
 *
 * Returns 0 and moves key to WC_XMSS_STATE_OK on success; BAD_FUNC_ARG for
 * NULL arguments; an error from wc_XmssKey_AllocSk; or -1 otherwise. Key
 * generation or storage failures leave the key in WC_XMSS_STATE_BAD.
 */
int wc_XmssKey_MakeKey(XmssKey* key, WC_RNG * rng)
{
    int ret = 0;
    enum wc_XmssRc cb_rc = WC_XMSS_RC_NONE;

    if (key == NULL || rng == NULL) {
        return BAD_FUNC_ARG;
    }

    if (key->state != WC_XMSS_STATE_PARMSET) {
        WOLFSSL_MSG("error: XmssKey not ready for generation");
        return -1;
    }

    if (key->write_private_key == NULL || key->read_private_key == NULL) {
        WOLFSSL_MSG("error: XmssKey write/read callbacks are not set");
        return -1;
    }

    if (key->context == NULL) {
        WOLFSSL_MSG("error: XmssKey context is not set");
        return -1;
    }

    ret = wc_XmssKey_AllocSk(key);
    if (ret != 0) {
        return ret;
    }

    /* rng_cb reads this thread-local pointer during key generation. */
    xmssRng = rng;

    if (key->is_xmssmt) {
        ret = xmssmt_keypair(key->pk, key->sk, key->oid);
    }
    else {
        ret = xmss_keypair(key->pk, key->sk, key->oid);
    }

    /* Do not leave a dangling reference to the caller's RNG in the
     * thread-local slot once generation is done. */
    xmssRng = NULL;

    if (ret == 0) {
        cb_rc = key->write_private_key(key->sk, key->sk_len, key->context);
    }

    /* The private key lives in storage; never keep it resident. */
    ForceZero(key->sk, key->sk_len);

    if (ret != 0) {
        WOLFSSL_MSG("error: XMSS keypair failed");
        key->state = WC_XMSS_STATE_BAD;
        return -1;
    }

    if (cb_rc != WC_XMSS_RC_SAVED_TO_NV_MEMORY) {
        WOLFSSL_MSG("error: XMSS write to NV storage failed");
        key->state = WC_XMSS_STATE_BAD;
        return -1;
    }

    key->state = WC_XMSS_STATE_OK;
    return 0;
}
/* Reattach a parameterized key to a private key already held in storage.
 *
 * key - key in WC_XMSS_STATE_PARMSET with read/write callbacks and context
 *       set.
 *
 * The private key is read once to confirm it is reachable, then wiped from
 * memory; signing re-reads it from storage each time.
 *
 * Returns 0 and moves key to WC_XMSS_STATE_OK on success; BAD_FUNC_ARG for a
 * NULL key; an error from wc_XmssKey_AllocSk; or -1 otherwise. A failed
 * read leaves the key in WC_XMSS_STATE_BAD.
 */
int wc_XmssKey_Reload(XmssKey * key)
{
    int rc;
    enum wc_XmssRc read_rc;

    if (key == NULL) {
        return BAD_FUNC_ARG;
    }
    if (key->state != WC_XMSS_STATE_PARMSET) {
        WOLFSSL_MSG("error: XmssKey not ready for reload");
        return -1;
    }
    if (key->write_private_key == NULL || key->read_private_key == NULL) {
        WOLFSSL_MSG("error: XmssKey write/read callbacks are not set");
        return -1;
    }
    if (key->context == NULL) {
        WOLFSSL_MSG("error: XmssKey context is not set");
        return -1;
    }

    rc = wc_XmssKey_AllocSk(key);
    if (rc != 0) {
        return rc;
    }

    read_rc = key->read_private_key(key->sk, key->sk_len, key->context);
    ForceZero(key->sk, key->sk_len);

    if (read_rc == WC_XMSS_RC_READ_TO_MEMORY) {
        key->state = WC_XMSS_STATE_OK;
        return 0;
    }

    WOLFSSL_MSG("error: XMSS read from NV storage failed");
    key->state = WC_XMSS_STATE_BAD;
    return -1;
}
/* Report the private-key buffer length: XMSS_OID_LEN + params.sk_bytes.
 *
 * Returns 0 on success; BAD_FUNC_ARG for NULL arguments; -1 unless the key
 * is in WC_XMSS_STATE_OK or WC_XMSS_STATE_PARMSET.
 */
int wc_XmssKey_GetPrivLen(const XmssKey * key, word32 * len)
{
    int ret = -1;

    if (key == NULL || len == NULL) {
        return BAD_FUNC_ARG;
    }

    if (key->state == WC_XMSS_STATE_OK || key->state == WC_XMSS_STATE_PARMSET) {
        *len = XMSS_OID_LEN + (word32) key->params.sk_bytes;
        ret = 0;
    }

    return ret;
}
/* Produce one signature over msg with the stateful private key.
 *
 * The advanced private key is written back to storage before the signature
 * length is reported to the caller.
 *
 * key    - key in WC_XMSS_STATE_OK with read/write callbacks and context set
 *          (validated by the caller, wc_XmssKey_Sign).
 * sig    - output buffer. It is zeroed if signing fails or the write-back
 *          fails; it is not zeroed when the sign routine returns -2.
 * sigLen - in: capacity of sig; out: signature length on success, else 0.
 * msg    - message to sign.
 * msgLen - length of msg in bytes.
 *
 * The outcome is reported only through key->state:
 * WC_XMSS_STATE_OK on success, WC_XMSS_STATE_NOSIGS when the sign routine
 * returns -2, and WC_XMSS_STATE_BAD for every other failure.
 */
static void wc_XmssKey_SignUpdate(XmssKey* key, byte * sig, word32 * sigLen,
    const byte * msg, int msgLen)
{
    int ret = -1;
    unsigned long long len = *sigLen;
    enum wc_XmssRc cb_rc = WC_XMSS_RC_NONE;

    /* Pessimistic defaults: only a successful sign + write-back below
     * restores WC_XMSS_STATE_OK and a non-zero *sigLen. */
    key->state = WC_XMSS_STATE_BAD;
    *sigLen = 0;

    /* Load the current private key from storage into key->sk. */
    cb_rc = key->read_private_key(key->sk, key->sk_len, key->context);
    if (cb_rc == WC_XMSS_RC_READ_TO_MEMORY) {
        if (key->is_xmssmt) {
            ret = xmssmt_sign(key->sk, sig, &len, msg, msgLen);
        }
        else {
            ret = xmss_sign(key->sk, sig, &len, msg, msgLen);
        }

        if (ret == 0 && len == key->params.sig_bytes) {
            /* Persist the updated private key before releasing the
             * signature, so the stored key state is never behind a
             * signature that was handed out. */
            cb_rc = key->write_private_key(key->sk, key->sk_len, key->context);
            if (cb_rc == WC_XMSS_RC_SAVED_TO_NV_MEMORY) {
                key->state = WC_XMSS_STATE_OK;
                *sigLen = (word32) len;
            }
            else {
                /* Could not save: withhold the signature. */
                ForceZero(sig, key->params.sig_bytes);
                WOLFSSL_MSG("error: XMSS write_private_key failed");
            }
        }
        else if (ret == -2) {
            /* -2 from the sign routine is treated as key exhaustion. */
            key->state = WC_XMSS_STATE_NOSIGS;
            WOLFSSL_MSG("error: no XMSS signatures remaining");
        }
        else {
            ForceZero(sig, key->params.sig_bytes);
            WOLFSSL_MSG("error: XMSS sign failed");
        }
    }
    else {
        WOLFSSL_MSG("error: XMSS read_private_key failed");
    }

    /* Never leave the private key resident in memory. */
    ForceZero(key->sk, key->sk_len);

    return;
}
/* Sign msg with a stateful XMSS/XMSS^MT key.
 *
 * key    - key in WC_XMSS_STATE_OK with read/write callbacks and context set.
 * sig    - output buffer of at least params.sig_bytes.
 * sigLen - in: capacity of sig; out: signature length (0 on failure).
 * msg    - message to sign; msgLen must be positive.
 *
 * Returns 0 on success; BAD_FUNC_ARG for NULL pointers or msgLen <= 0;
 * BUFFER_E if sig is too small; -1 otherwise (including exhausted keys).
 */
int wc_XmssKey_Sign(XmssKey* key, byte * sig, word32 * sigLen, const byte * msg,
    int msgLen)
{
    if (key == NULL || sig == NULL || sigLen == NULL || msg == NULL ||
        msgLen <= 0) {
        return BAD_FUNC_ARG;
    }

    if (*sigLen < key->params.sig_bytes) {
        WOLFSSL_MSG("error: XMSS sig buffer too small");
        return BUFFER_E;
    }

    switch (key->state) {
    case WC_XMSS_STATE_OK:
        break;
    case WC_XMSS_STATE_NOSIGS:
        WOLFSSL_MSG("error: XMSS signatures exhausted");
        return -1;
    default:
        WOLFSSL_MSG("error: can't sign, XMSS key not in good state");
        return -1;
    }

    if (key->write_private_key == NULL || key->read_private_key == NULL) {
        WOLFSSL_MSG("error: XmssKey write/read callbacks are not set");
        return -1;
    }

    if (key->context == NULL) {
        WOLFSSL_MSG("error: XmssKey context is not set");
        return -1;
    }

    wc_XmssKey_SignUpdate(key, sig, sigLen, msg, msgLen);

    if (key->state != WC_XMSS_STATE_OK) {
        return -1;
    }
    return 0;
}
/* Report whether the key can still produce signatures.
 *
 * Reads the private key from storage and decodes the big-endian signature
 * index that follows the OID prefix. XMSS^MT uses params.index_bytes bytes
 * and XMSS uses 4 bytes.
 *
 * Returns 1 when the index is below (2^full_height - 1). Returns 0 when it
 * is not, when key is NULL or not in WC_XMSS_STATE_OK, or when the read
 * fails. In-memory secret-key bytes are wiped on every path that reads
 * them.
 */
int wc_XmssKey_SigsLeft(XmssKey* key)
{
    int ret = 0;

    if (key == NULL) {
        ret = 0;
    }
    else if (key->state == WC_XMSS_STATE_NOSIGS) {
        WOLFSSL_MSG("error: XMSS signatures exhausted");
        ret = 0;
    }
    else if (key->state != WC_XMSS_STATE_OK) {
        WOLFSSL_MSG("error: can't sign, XMSS key not in good state");
        ret = 0;
    }
    else if (key->read_private_key(key->sk, key->sk_len, key->context) !=
             WC_XMSS_RC_READ_TO_MEMORY) {
        WOLFSSL_MSG("error: XMSS read_private_key failed");
        /* The callback may have partially filled the buffer with secret
         * material before failing; scrub it as on the success path. */
        ForceZero(key->sk, key->sk_len);
        ret = 0;
    }
    else {
        const unsigned char* sk = (key->sk + XMSS_OID_LEN);
        const xmss_params* params = &key->params;
        unsigned long long idx = 0;

        if (key->is_xmssmt) {
            for (uint64_t i = 0; i < params->index_bytes; i++) {
                idx |= ((unsigned long long)sk[i])
                    << 8 * (params->index_bytes - 1 - i);
            }
        }
        else {
            idx = ((unsigned long)sk[0] << 24) |
                  ((unsigned long)sk[1] << 16) |
                  ((unsigned long)sk[2] << 8) | sk[3];
        }

        ret = idx < ((1ULL << params->full_height) - 1);
        ForceZero(key->sk, key->sk_len);
    }

    return ret;
}
#endif
/* Report the public-key length (XMSS_SHA256_PUBLEN for every supported set).
 *
 * Returns 0 on success, BAD_FUNC_ARG for NULL arguments.
 */
int wc_XmssKey_GetPubLen(const XmssKey * key, word32 * len)
{
    int ret = BAD_FUNC_ARG;

    if (key != NULL && len != NULL) {
        *len = XMSS_SHA256_PUBLEN;
        ret = 0;
    }

    return ret;
}
/* Copy the public part of keySrc into keyDst as a verify-only key.
 *
 * keyDst is wiped first. It then receives the public key, OID, XMSS^MT flag
 * and parsed parameter set, and is marked WC_XMSS_STATE_VERIFYONLY so it
 * cannot be used for signing.
 *
 * Returns 0 on success, BAD_FUNC_ARG for NULL arguments.
 */
int wc_XmssKey_ExportPub(XmssKey * keyDst, const XmssKey * keySrc)
{
    if (keyDst == NULL || keySrc == NULL) {
        return BAD_FUNC_ARG;
    }

    ForceZero(keyDst, sizeof(XmssKey));

    XMEMCPY(keyDst->pk, keySrc->pk, sizeof(keySrc->pk));
    keyDst->oid = keySrc->oid;
    keyDst->is_xmssmt = keySrc->is_xmssmt;

    /* Carry the parameter set too. Without it params.sig_bytes stays 0 and
     * wc_XmssKey_Verify's signature-length check is silently bypassed for
     * the exported key. */
    keyDst->params = keySrc->params;

    /* Mark keyDst as verify only, to prevent misuse. */
    keyDst->state = WC_XMSS_STATE_VERIFYONLY;

    return 0;
}
/* Copy the raw public key into out.
 *
 * out    - destination buffer.
 * outLen - in: capacity of out; out: bytes written.
 *
 * Returns 0 on success; BAD_FUNC_ARG for NULL arguments; BUFFER_E if out is
 * too small; -1 if the public-key length cannot be determined.
 */
int wc_XmssKey_ExportPubRaw(const XmssKey * key, byte * out, word32 * outLen)
{
    word32 need = 0;

    if (key == NULL || out == NULL || outLen == NULL) {
        return BAD_FUNC_ARG;
    }

    if (wc_XmssKey_GetPubLen(key, &need) != 0) {
        WOLFSSL_MSG("error: wc_XmssKey_GetPubLen failed");
        return -1;
    }

    if (need > *outLen) {
        return BUFFER_E;
    }

    XMEMCPY(out, key->pk, need);
    *outLen = need;

    return 0;
}
/* Load a raw public key into a parameterized key, making it verify-only.
 *
 * key   - key in WC_XMSS_STATE_PARMSET.
 * in    - raw public key.
 * inLen - must equal the public-key length exactly.
 *
 * Returns 0 on success; BAD_FUNC_ARG for NULL arguments; BUFFER_E on a
 * length mismatch; -1 for wrong state or length lookup failure.
 */
int wc_XmssKey_ImportPubRaw(XmssKey * key, const byte * in, word32 inLen)
{
    word32 expected = 0;

    if (key == NULL || in == NULL) {
        return BAD_FUNC_ARG;
    }

    if (key->state != WC_XMSS_STATE_PARMSET) {
        WOLFSSL_MSG("error: XMSS key not ready for import");
        return -1;
    }

    if (wc_XmssKey_GetPubLen(key, &expected) != 0) {
        WOLFSSL_MSG("error: wc_XmssKey_GetPubLen failed");
        return -1;
    }

    if (expected != inLen) {
        return BUFFER_E;
    }

    XMEMCPY(key->pk, in, expected);
    key->state = WC_XMSS_STATE_VERIFYONLY;

    return 0;
}
/* Report the signature length for the key's parameter set.
 *
 * Returns 0 on success; BAD_FUNC_ARG for NULL arguments; -1 unless the key
 * is in WC_XMSS_STATE_OK or WC_XMSS_STATE_PARMSET.
 */
int wc_XmssKey_GetSigLen(const XmssKey * key, word32 * len)
{
    int ret = -1;

    if (key == NULL || len == NULL) {
        return BAD_FUNC_ARG;
    }

    if (key->state == WC_XMSS_STATE_OK || key->state == WC_XMSS_STATE_PARMSET) {
        *len = key->params.sig_bytes;
        ret = 0;
    }

    return ret;
}
/* Verify an XMSS/XMSS^MT signature over msg.
 *
 * key    - key in WC_XMSS_STATE_OK or WC_XMSS_STATE_VERIFYONLY.
 * sig    - signature; sigLen must be at least params.sig_bytes.
 * msg    - signed message.
 * msgLen - message length in bytes; must not be negative.
 *
 * Returns 0 if the signature verifies; BAD_FUNC_ARG for NULL pointers or a
 * negative msgLen; BUFFER_E if sigLen is too short; -1 otherwise.
 */
int wc_XmssKey_Verify(XmssKey * key, const byte * sig, word32 sigLen,
    const byte * msg, int msgLen)
{
    int ret = 0;
    unsigned long long msg_len = 0;

    if (key == NULL || sig == NULL || msg == NULL) {
        return BAD_FUNC_ARG;
    }

    /* A negative length would convert to a huge unsigned value and be
     * passed to the verifier as the message length. */
    if (msgLen < 0) {
        return BAD_FUNC_ARG;
    }
    msg_len = (unsigned long long)msgLen;

    if (sigLen < key->params.sig_bytes) {
        return BUFFER_E;
    }

    if (key->state != WC_XMSS_STATE_OK &&
        key->state != WC_XMSS_STATE_VERIFYONLY) {
        WOLFSSL_MSG("error: XMSS key not ready for verification");
        return -1;
    }

    if (key->is_xmssmt) {
        ret = xmssmt_sign_open(msg, &msg_len, sig, sigLen, key->pk);
    }
    else {
        ret = xmss_sign_open(msg, &msg_len, sig, sigLen, key->pk);
    }

    /* Compare at full width; an (int) cast could mask a mismatch. */
    if (ret != 0 || msg_len != (unsigned long long)msgLen) {
        WOLFSSL_MSG("error: XMSS verify failed");
        return -1;
    }

    return ret;
}
#endif