#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
#if defined(WOLFSSL_HAVE_MLKEM) && !defined(WOLFSSL_WC_MLKEM)
#include <wolfssl/wolfcrypt/ext_mlkem.h>
#ifdef NO_INLINE
#include <wolfssl/wolfcrypt/misc.h>
#else
#define WOLFSSL_MISC_INCLUDED
#include <wolfcrypt/src/misc.c>
#endif
#if defined (HAVE_LIBOQS)
#include <wolfssl/wolfcrypt/port/liboqs/liboqs.h>
/**
 * Translate a wolfSSL KEM type identifier into the liboqs algorithm name.
 *
 * @param  [in]  id  wolfSSL ML-KEM / Kyber type identifier.
 * @return  liboqs algorithm name string for a compiled-in type.
 * @return  NULL when id is not one of the compiled-in types.
 */
static const char* OQS_ID2name(int id)
{
    const char* name = NULL;

    switch (id) {
#ifndef WOLFSSL_NO_ML_KEM
    case WC_ML_KEM_512:
        name = OQS_KEM_alg_ml_kem_512;
        break;
    case WC_ML_KEM_768:
        name = OQS_KEM_alg_ml_kem_768;
        break;
    case WC_ML_KEM_1024:
        name = OQS_KEM_alg_ml_kem_1024;
        break;
#endif
#ifdef WOLFSSL_MLKEM_KYBER
    case KYBER_LEVEL1:
        name = OQS_KEM_alg_kyber_512;
        break;
    case KYBER_LEVEL3:
        name = OQS_KEM_alg_kyber_768;
        break;
    case KYBER_LEVEL5:
        name = OQS_KEM_alg_kyber_1024;
        break;
#endif
    default:
        /* Unknown or not compiled in: name stays NULL. */
        break;
    }

    return name;
}
/**
 * Ask liboqs whether the algorithm for a wolfSSL KEM type is enabled.
 *
 * @param  [in]  id  wolfSSL ML-KEM / Kyber type identifier.
 * @return  The value of OQS_KEM_alg_is_enabled() for the mapped name.
 *          An unrecognized id is mapped to a NULL name. liboqs is expected to
 *          report NULL as "not enabled" (NOTE(review): confirm for the
 *          liboqs version in use).
 */
int ext_mlkem_enabled(int id)
{
    return OQS_KEM_alg_is_enabled(OQS_ID2name(id));
}
#endif
/**
 * Initialize an ML-KEM / Kyber key object for the given parameter set.
 *
 * The whole structure is zeroed and then the type (and, with crypto
 * callbacks, the device id) is recorded.
 *
 * @param  [out]  key    Key object to initialize.
 * @param  [in]   type   Parameter set identifier. Without liboqs only
 *                       WC_ML_KEM_512 / KYBER_LEVEL1 are accepted.
 * @param  [in]   heap   Unused by this backend.
 * @param  [in]   devId  Device id for crypto callbacks (stored only when
 *                       WOLF_CRYPTO_CB is defined).
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key is NULL or type is not supported.
 */
int wc_MlKemKey_Init(MlKemKey* key, int type, void* heap, int devId)
{
    int ret = BAD_FUNC_ARG;

    (void)heap;

    if (key != NULL) {
        /* Only types listed here are valid; anything else keeps the error. */
        switch (type) {
#ifndef WOLFSSL_NO_ML_KEM
        case WC_ML_KEM_512:
    #ifdef HAVE_LIBOQS
        case WC_ML_KEM_768:
        case WC_ML_KEM_1024:
    #endif
#endif
#ifdef WOLFSSL_MLKEM_KYBER
        case KYBER_LEVEL1:
    #ifdef HAVE_LIBOQS
        case KYBER_LEVEL3:
        case KYBER_LEVEL5:
    #endif
#endif
            ret = 0;
            break;
        default:
            break;
        }
    }

    if (ret == 0) {
        XMEMSET(key, 0, sizeof(*key));
        key->type = type;
    #ifdef WOLF_CRYPTO_CB
        key->devCtx = NULL;
        key->devId = devId;
    #endif
    }

    (void)devId;
    return ret;
}
/**
 * Dispose of an ML-KEM / Kyber key object.
 *
 * Nothing is deallocated here. The entire structure (including any key
 * material) is securely zeroed.
 *
 * @param  [in, out]  key  Key object to wipe. NULL is accepted.
 * @return  0 always.
 */
int wc_MlKemKey_Free(MlKemKey* key)
{
    if (key == NULL) {
        return 0;
    }

    ForceZero(key, sizeof(*key));
    return 0;
}
/**
 * Get the encoded private (secret) key size for the key's parameter set.
 *
 * @param  [in]   key  Key object whose type selects the size.
 * @param  [out]  len  Receives the size in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or len is NULL, or (with liboqs) the type is
 *          not supported.
 *
 * Without liboqs there is no software implementation to take the size from.
 * In that case *len is set to 0 rather than left unwritten, so callers never
 * read an uninitialized length.
 */
int wc_MlKemKey_PrivateKeySize(MlKemKey* key, word32* len)
{
    int ret = 0;

    if ((key == NULL) || (len == NULL)) {
        ret = BAD_FUNC_ARG;
    }

#ifdef HAVE_LIBOQS
    if (ret == 0) {
        switch (key->type) {
    #ifndef WOLFSSL_NO_ML_KEM
        case WC_ML_KEM_512:
            *len = OQS_KEM_ml_kem_512_length_secret_key;
            break;
        case WC_ML_KEM_768:
            *len = OQS_KEM_ml_kem_768_length_secret_key;
            break;
        case WC_ML_KEM_1024:
            *len = OQS_KEM_ml_kem_1024_length_secret_key;
            break;
    #endif
    #ifdef WOLFSSL_MLKEM_KYBER
        case KYBER_LEVEL1:
            *len = OQS_KEM_kyber_512_length_secret_key;
            break;
        case KYBER_LEVEL3:
            *len = OQS_KEM_kyber_768_length_secret_key;
            break;
        case KYBER_LEVEL5:
            *len = OQS_KEM_kyber_1024_length_secret_key;
            break;
    #endif
        default:
            ret = BAD_FUNC_ARG;
            break;
        }
    }
#else
    if (ret == 0) {
        /* Size unknown without a software backend; never leave *len unset. */
        *len = 0;
    }
#endif

    return ret;
}
/**
 * Get the encoded public key size for the key's parameter set.
 *
 * @param  [in]   key  Key object whose type selects the size.
 * @param  [out]  len  Receives the size in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or len is NULL, or (with liboqs) the type is
 *          not supported.
 *
 * Without liboqs there is no software implementation to take the size from.
 * In that case *len is set to 0 rather than left unwritten, so callers never
 * read an uninitialized length.
 */
int wc_MlKemKey_PublicKeySize(MlKemKey* key, word32* len)
{
    int ret = 0;

    if ((key == NULL) || (len == NULL)) {
        ret = BAD_FUNC_ARG;
    }

#ifdef HAVE_LIBOQS
    if (ret == 0) {
        switch (key->type) {
    #ifndef WOLFSSL_NO_ML_KEM
        case WC_ML_KEM_512:
            *len = OQS_KEM_ml_kem_512_length_public_key;
            break;
        case WC_ML_KEM_768:
            *len = OQS_KEM_ml_kem_768_length_public_key;
            break;
        case WC_ML_KEM_1024:
            *len = OQS_KEM_ml_kem_1024_length_public_key;
            break;
    #endif
    #ifdef WOLFSSL_MLKEM_KYBER
        case KYBER_LEVEL1:
            *len = OQS_KEM_kyber_512_length_public_key;
            break;
        case KYBER_LEVEL3:
            *len = OQS_KEM_kyber_768_length_public_key;
            break;
        case KYBER_LEVEL5:
            *len = OQS_KEM_kyber_1024_length_public_key;
            break;
    #endif
        default:
            ret = BAD_FUNC_ARG;
            break;
        }
    }
#else
    if (ret == 0) {
        /* Size unknown without a software backend; never leave *len unset. */
        *len = 0;
    }
#endif

    return ret;
}
/**
 * Get the ciphertext size for the key's parameter set.
 *
 * @param  [in]   key  Key object whose type selects the size.
 * @param  [out]  len  Receives the size in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or len is NULL, or (with liboqs) the type is
 *          not supported.
 *
 * Without liboqs there is no software implementation to take the size from.
 * In that case *len is set to 0 rather than left unwritten, so callers never
 * read an uninitialized length.
 */
int wc_MlKemKey_CipherTextSize(MlKemKey* key, word32* len)
{
    int ret = 0;

    if ((key == NULL) || (len == NULL)) {
        ret = BAD_FUNC_ARG;
    }

#ifdef HAVE_LIBOQS
    if (ret == 0) {
        switch (key->type) {
    #ifndef WOLFSSL_NO_ML_KEM
        case WC_ML_KEM_512:
            *len = OQS_KEM_ml_kem_512_length_ciphertext;
            break;
        case WC_ML_KEM_768:
            *len = OQS_KEM_ml_kem_768_length_ciphertext;
            break;
        case WC_ML_KEM_1024:
            *len = OQS_KEM_ml_kem_1024_length_ciphertext;
            break;
    #endif
    #ifdef WOLFSSL_MLKEM_KYBER
        case KYBER_LEVEL1:
            *len = OQS_KEM_kyber_512_length_ciphertext;
            break;
        case KYBER_LEVEL3:
            *len = OQS_KEM_kyber_768_length_ciphertext;
            break;
        case KYBER_LEVEL5:
            *len = OQS_KEM_kyber_1024_length_ciphertext;
            break;
    #endif
        default:
            ret = BAD_FUNC_ARG;
            break;
        }
    }
#else
    if (ret == 0) {
        /* Size unknown without a software backend; never leave *len unset. */
        *len = 0;
    }
#endif

    return ret;
}
/**
 * Get the shared secret size.
 *
 * The size is the constant KYBER_SS_SZ for every parameter set, so the key
 * is not consulted and may be NULL.
 *
 * @param  [in]   key  Unused.
 * @param  [out]  len  Receives KYBER_SS_SZ.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when len is NULL.
 */
int wc_MlKemKey_SharedSecretSize(MlKemKey* key, word32* len)
{
    int ret = 0;

    (void)key;

    if (len != NULL) {
        *len = KYBER_SS_SZ;
    }
    else {
        ret = BAD_FUNC_ARG;
    }

    return ret;
}
/**
 * Generate a new key pair into key->pub / key->priv.
 *
 * With crypto callbacks, the device is tried first and its result is
 * returned unless it reports CRYPTOCB_UNAVAILABLE. Otherwise liboqs (when
 * compiled in) generates the key while holding the liboqs RNG mutex.
 *
 * @param  [in, out]  key  Initialized key object.
 * @param  [in]       rng  Random number generator handed to the device
 *                         callback and to wolfSSL_liboqsRngMutexLock().
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key is NULL, the type has no liboqs name,
 *          OQS_KEM_new() fails, or key generation fails.
 * @return  Error from wolfSSL_liboqsRngMutexLock() or the crypto callback.
 *
 * On a software-path failure the entire key structure (including type) is
 * zeroed, so the key must be re-initialized before reuse.
 */
int wc_MlKemKey_MakeKey(MlKemKey* key, WC_RNG* rng)
{
    int ret = 0;
#ifdef HAVE_LIBOQS
    const char* algName = NULL;
    OQS_KEM *kem = NULL;
    /* Only release the RNG mutex if this call actually acquired it. */
    int locked = 0;
#endif

    (void)rng;

    if (key == NULL) {
        return BAD_FUNC_ARG;
    }

#ifdef WOLF_CRYPTO_CB
    #ifndef WOLF_CRYPTO_CB_FIND
    if (key->devId != INVALID_DEVID)
    #endif
    {
        ret = wc_CryptoCb_MakePqcKemKey(rng, WC_PQC_KEM_TYPE_KYBER,
            key->type, key);
        if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE)) {
            return ret;
        }
        /* Device declined: fall back to software. */
        ret = 0;
    }
#endif

#ifdef HAVE_LIBOQS
    if (ret == 0) {
        algName = OQS_ID2name(key->type);
        if (algName == NULL) {
            ret = BAD_FUNC_ARG;
        }
    }
    if (ret == 0) {
        kem = OQS_KEM_new(algName);
        if (kem == NULL) {
            ret = BAD_FUNC_ARG;
        }
    }
    if (ret == 0) {
        ret = wolfSSL_liboqsRngMutexLock(rng);
        if (ret == 0) {
            locked = 1;
        }
    }
    if (ret == 0) {
        if (OQS_KEM_keypair(kem, key->pub, key->priv) != OQS_SUCCESS) {
            ret = BAD_FUNC_ARG;
        }
    }
    /* Unlocking a mutex this thread does not hold is undefined behavior and
     * could release another thread's lock. */
    if (locked) {
        wolfSSL_liboqsRngMutexUnlock();
    }
    OQS_KEM_free(kem);
#endif

    if (ret != 0) {
        ForceZero(key, sizeof(*key));
    }

    return ret;
}
/**
 * Generate a key pair "from the supplied randomness".
 *
 * NOTE: this backend ignores rand and len. It forwards to
 * wc_MlKemKey_MakeKey() with a NULL RNG, so the result is NOT derived from
 * rand and is not reproducible. Which RNG is used for a NULL rng depends on
 * wolfSSL_liboqsRngMutexLock() / the crypto callback (presumably the liboqs
 * port's default RNG; confirm).
 *
 * @param  [in, out]  key   Initialized key object.
 * @param  [in]       rand  Ignored.
 * @param  [in]       len   Ignored.
 * @return  Whatever wc_MlKemKey_MakeKey() returns.
 */
int wc_MlKemKey_MakeKeyWithRandom(MlKemKey* key, const unsigned char* rand,
    int len)
{
    int ret;

    (void)len;
    (void)rand;

    ret = wc_MlKemKey_MakeKey(key, NULL);
    return ret;
}
/**
 * Encapsulate: produce a ciphertext and shared secret from key->pub.
 *
 * With crypto callbacks, the device is tried first and its result is
 * returned unless it reports CRYPTOCB_UNAVAILABLE. Otherwise liboqs (when
 * compiled in) encapsulates while holding the liboqs RNG mutex.
 *
 * @param  [in]   key  Key object holding the public key.
 * @param  [out]  ct   Ciphertext buffer (size per wc_MlKemKey_CipherTextSize).
 * @param  [out]  ss   Shared secret buffer (KYBER_SS_SZ bytes).
 * @param  [in]   rng  Random number generator for the callback and the
 *                     liboqs RNG lock.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key, ct or ss is NULL, the type has no liboqs
 *          name, OQS_KEM_new() fails, or encapsulation fails.
 * @return  Error from wolfSSL_liboqsRngMutexLock() or the crypto callback.
 */
int wc_MlKemKey_Encapsulate(MlKemKey* key, unsigned char* ct, unsigned char* ss,
    WC_RNG* rng)
{
    int ret = 0;
#ifdef WOLF_CRYPTO_CB
    word32 ctlen = 0;
#endif
#ifdef HAVE_LIBOQS
    const char* algName = NULL;
    OQS_KEM *kem = NULL;
    /* Only release the RNG mutex if this call actually acquired it. */
    int locked = 0;
#endif

    (void)rng;

    if ((key == NULL) || (ct == NULL) || (ss == NULL)) {
        ret = BAD_FUNC_ARG;
    }

#ifdef WOLF_CRYPTO_CB
    if (ret == 0) {
        ret = wc_MlKemKey_CipherTextSize(key, &ctlen);
    }
    if ((ret == 0)
    #ifndef WOLF_CRYPTO_CB_FIND
        && (key->devId != INVALID_DEVID)
    #endif
    ) {
        ret = wc_CryptoCb_PqcEncapsulate(ct, ctlen, ss, KYBER_SS_SZ, rng,
            WC_PQC_KEM_TYPE_KYBER, key);
        if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE)) {
            return ret;
        }
        /* Device declined: fall back to software. */
        ret = 0;
    }
#endif

#ifdef HAVE_LIBOQS
    if (ret == 0) {
        algName = OQS_ID2name(key->type);
        if (algName == NULL) {
            ret = BAD_FUNC_ARG;
        }
    }
    if (ret == 0) {
        kem = OQS_KEM_new(algName);
        if (kem == NULL) {
            ret = BAD_FUNC_ARG;
        }
    }
    if (ret == 0) {
        ret = wolfSSL_liboqsRngMutexLock(rng);
        if (ret == 0) {
            locked = 1;
        }
    }
    if (ret == 0) {
        if (OQS_KEM_encaps(kem, ct, ss, key->pub) != OQS_SUCCESS) {
            ret = BAD_FUNC_ARG;
        }
    }
    /* Unlocking a mutex this thread does not hold is undefined behavior and
     * could release another thread's lock. */
    if (locked) {
        wolfSSL_liboqsRngMutexUnlock();
    }
    OQS_KEM_free(kem);
#endif

    return ret;
}
/**
 * Encapsulate "using the supplied randomness".
 *
 * NOTE: this backend ignores rand and len. It forwards to
 * wc_MlKemKey_Encapsulate() with a NULL RNG, so ct/ss are NOT derived from
 * rand and are not reproducible.
 *
 * @param  [in]   key   Key object holding the public key.
 * @param  [out]  ct    Ciphertext buffer.
 * @param  [out]  ss    Shared secret buffer.
 * @param  [in]   rand  Ignored.
 * @param  [in]   len   Ignored.
 * @return  Whatever wc_MlKemKey_Encapsulate() returns.
 */
int wc_MlKemKey_EncapsulateWithRandom(MlKemKey* key, unsigned char* ct,
    unsigned char* ss, const unsigned char* rand, int len)
{
    int ret;

    (void)len;
    (void)rand;

    ret = wc_MlKemKey_Encapsulate(key, ct, ss, NULL);
    return ret;
}
/**
 * Decapsulate: recover the shared secret from a ciphertext using key->priv.
 *
 * The ciphertext length must exactly match wc_MlKemKey_CipherTextSize().
 * With crypto callbacks, the device is tried first and its result is
 * returned unless it reports CRYPTOCB_UNAVAILABLE. Otherwise liboqs is used
 * when compiled in.
 *
 * @param  [in]   key  Key object holding the private key.
 * @param  [out]  ss   Shared secret buffer (KYBER_SS_SZ bytes).
 * @param  [in]   ct   Ciphertext.
 * @param  [in]   len  Length of ct in bytes.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key, ss or ct is NULL, the type is unsupported,
 *          OQS_KEM_new() fails, or decapsulation fails.
 * @return  BUFFER_E when len is not the expected ciphertext size.
 * @return  Error from the crypto callback.
 */
int wc_MlKemKey_Decapsulate(MlKemKey* key, unsigned char* ss,
    const unsigned char* ct, word32 len)
{
    int ret;
    word32 expectedLen = 0;
#ifdef HAVE_LIBOQS
    const char* oqsName = NULL;
    OQS_KEM* oqsKem = NULL;
#endif

    if ((key == NULL) || (ss == NULL) || (ct == NULL)) {
        ret = BAD_FUNC_ARG;
    }
    else {
        ret = wc_MlKemKey_CipherTextSize(key, &expectedLen);
        if ((ret == 0) && (len != expectedLen)) {
            ret = BUFFER_E;
        }
    }

#ifdef WOLF_CRYPTO_CB
    if ((ret == 0)
    #ifndef WOLF_CRYPTO_CB_FIND
        && (key->devId != INVALID_DEVID)
    #endif
    ) {
        ret = wc_CryptoCb_PqcDecapsulate(ct, expectedLen, ss, KYBER_SS_SZ,
            WC_PQC_KEM_TYPE_KYBER, key);
        if (ret != WC_NO_ERR_TRACE(CRYPTOCB_UNAVAILABLE)) {
            return ret;
        }
        /* Device declined: fall back to software. */
        ret = 0;
    }
#endif

#ifdef HAVE_LIBOQS
    if (ret == 0) {
        oqsName = OQS_ID2name(key->type);
        ret = (oqsName != NULL) ? 0 : BAD_FUNC_ARG;
    }
    if (ret == 0) {
        oqsKem = OQS_KEM_new(oqsName);
        ret = (oqsKem != NULL) ? 0 : BAD_FUNC_ARG;
    }
    if ((ret == 0) &&
            (OQS_KEM_decaps(oqsKem, ss, ct, key->priv) != OQS_SUCCESS)) {
        ret = BAD_FUNC_ARG;
    }
    OQS_KEM_free(oqsKem);
#endif

    return ret;
}
/**
 * Load a raw encoded private key into key->priv.
 *
 * @param  [in, out]  key  Initialized key object (type selects the size).
 * @param  [in]       in   Encoded private key.
 * @param  [in]       len  Length of in; must equal the private key size.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or in is NULL, or the type is unsupported.
 * @return  BUFFER_E when len does not match the expected size.
 */
int wc_MlKemKey_DecodePrivateKey(MlKemKey* key, const unsigned char* in,
    word32 len)
{
    word32 expected = 0;
    int ret = BAD_FUNC_ARG;

    if ((key != NULL) && (in != NULL)) {
        ret = wc_MlKemKey_PrivateKeySize(key, &expected);
        if ((ret == 0) && (len != expected)) {
            ret = BUFFER_E;
        }
        if (ret == 0) {
            XMEMCPY(key->priv, in, expected);
        }
    }

    return ret;
}
/**
 * Load a raw encoded public key into key->pub.
 *
 * @param  [in, out]  key  Initialized key object (type selects the size).
 * @param  [in]       in   Encoded public key.
 * @param  [in]       len  Length of in; must equal the public key size.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or in is NULL, or the type is unsupported.
 * @return  BUFFER_E when len does not match the expected size.
 */
int wc_MlKemKey_DecodePublicKey(MlKemKey* key, const unsigned char* in,
    word32 len)
{
    word32 expected = 0;
    int ret = BAD_FUNC_ARG;

    if ((key != NULL) && (in != NULL)) {
        ret = wc_MlKemKey_PublicKeySize(key, &expected);
        if ((ret == 0) && (len != expected)) {
            ret = BUFFER_E;
        }
        if (ret == 0) {
            XMEMCPY(key->pub, in, expected);
        }
    }

    return ret;
}
/**
 * Copy the raw private key out of key->priv.
 *
 * @param  [in]   key  Key object (type selects the size).
 * @param  [out]  out  Destination buffer.
 * @param  [in]   len  Size of out; must equal the private key size.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or out is NULL, or the type is unsupported.
 * @return  BUFFER_E when len does not match the expected size.
 */
int wc_MlKemKey_EncodePrivateKey(MlKemKey* key, unsigned char* out, word32 len)
{
    int ret = 0;
    /* Must be word32 to match wc_MlKemKey_PrivateKeySize()'s out parameter;
     * word32 is not guaranteed to be unsigned int on every target. */
    word32 privLen = 0;

    if ((key == NULL) || (out == NULL)) {
        ret = BAD_FUNC_ARG;
    }
    if (ret == 0) {
        ret = wc_MlKemKey_PrivateKeySize(key, &privLen);
    }
    if ((ret == 0) && (len != privLen)) {
        ret = BUFFER_E;
    }
    if (ret == 0) {
        XMEMCPY(out, key->priv, privLen);
    }

    return ret;
}
/**
 * Copy the raw public key out of key->pub.
 *
 * @param  [in]   key  Key object (type selects the size).
 * @param  [out]  out  Destination buffer.
 * @param  [in]   len  Size of out; must equal the public key size.
 * @return  0 on success.
 * @return  BAD_FUNC_ARG when key or out is NULL, or the type is unsupported.
 * @return  BUFFER_E when len does not match the expected size.
 */
int wc_MlKemKey_EncodePublicKey(MlKemKey* key, unsigned char* out, word32 len)
{
    int ret = 0;
    /* Must be word32 to match wc_MlKemKey_PublicKeySize()'s out parameter;
     * word32 is not guaranteed to be unsigned int on every target. */
    word32 pubLen = 0;

    if ((key == NULL) || (out == NULL)) {
        ret = BAD_FUNC_ARG;
    }
    if (ret == 0) {
        ret = wc_MlKemKey_PublicKeySize(key, &pubLen);
    }
    if ((ret == 0) && (len != pubLen)) {
        ret = BUFFER_E;
    }
    if (ret == 0) {
        XMEMCPY(out, key->pub, pubLen);
    }

    return ret;
}
#endif