#include <wolfssl/wolfcrypt/libwolfssl_sources.h>
#include <wolfssl/wolfcrypt/port/liboqs/liboqs.h>
#if defined(HAVE_LIBOQS)
/* wolfCrypt RNG initialized by wolfSSL_liboqsInit() and freed by
 * wolfSSL_liboqsClose(); liboqsCurrentRNG points here whenever no
 * caller-supplied RNG has been installed. */
static WC_RNG liboqsDefaultRNG;
/* RNG that wolfSSL_liboqsGetRandomData() draws from. Set to a caller's RNG by
 * wolfSSL_liboqsRngMutexLock() and reset to &liboqsDefaultRNG by
 * wolfSSL_liboqsRngMutexUnlock() and wolfSSL_liboqsInit(). */
static WC_RNG* liboqsCurrentRNG;
/* Taken by wolfSSL_liboqsRngMutexLock() and released by
 * wolfSSL_liboqsRngMutexUnlock(); presumably serializes liboqs operations
 * that use liboqsCurrentRNG - confirm against callers. */
static wolfSSL_Mutex liboqsRNGMutex;
/* Set to 1 by wolfSSL_liboqsInit() after the mutex and default RNG were
 * successfully initialized.
 * NOTE(review): plain int read without holding the mutex, so concurrent
 * first-time initialization is not protected by this file. */
static int liboqs_init = 0;
/* Randomness callback registered with liboqs through
 * OQS_randombytes_custom_algorithm(). Fills all numOfBytes bytes of buffer
 * with output from liboqsCurrentRNG.
 *
 * wc_RNG_GenerateBlock() takes a word32 length and, with the Hash DRBG,
 * rejects any single request larger than RNG_MAX_BLOCK_LEN. The request is
 * therefore split into bounded chunks, and buffer is advanced after each one
 * so every byte gets fresh output.
 *
 * The callback returns void, so an RNG failure cannot be reported back to
 * liboqs. The process is aborted rather than handing out non-random bytes.
 */
static void wolfSSL_liboqsGetRandomData(uint8_t* buffer, size_t numOfBytes)
{
    int ret;
    word32 numOfBytes_word32;
    /* Largest single request handed to wc_RNG_GenerateBlock(). Falls back to
     * wolfCrypt's default limit (64 KiB) when the macro is not visible. */
#ifdef RNG_MAX_BLOCK_LEN
    const size_t maxChunk = (size_t)RNG_MAX_BLOCK_LEN;
#else
    const size_t maxChunk = (size_t)0x10000;
#endif

    while (numOfBytes > 0) {
        /* Bounded by maxChunk, so the word32 cast cannot truncate (assuming
         * RNG_MAX_BLOCK_LEN fits in a word32, as the default does). It also
         * never yields 0 here, so the loop always makes progress. */
        numOfBytes_word32 = (word32)((numOfBytes < maxChunk) ?
                                     numOfBytes : maxChunk);
        ret = wc_RNG_GenerateBlock(liboqsCurrentRNG, buffer,
                                   numOfBytes_word32);
        if (ret != 0) {
            WOLFSSL_MSG_EX(
                "wc_RNG_GenerateBlock(..., %u) failed with ret %d "
                "in wolfSSL_liboqsGetRandomData().", numOfBytes_word32, ret
            );
            abort();
        }
        buffer += numOfBytes_word32;
        numOfBytes -= numOfBytes_word32;
    }
}
/* Lazily sets up the liboqs glue the first time it is called. This covers
 * the RNG mutex, the default wolfCrypt RNG, OQS_init(), and registering
 * wolfSSL_liboqsGetRandomData() as liboqs' randomness source. Later calls
 * return 0 immediately.
 *
 * Returns 0 on success, or the wolfCrypt error code from the failing step.
 * On failure every resource acquired so far is released and liboqs_init
 * stays 0, so a later call retries from a clean state. The callback is only
 * registered once the default RNG is usable.
 *
 * NOTE(review): the liboqs_init check itself is not synchronized; concurrent
 * first calls are presumably prevented by the caller (e.g. library init).
 */
int wolfSSL_liboqsInit(void)
{
    int ret;

    if (liboqs_init != 0) {
        return 0;
    }

    ret = wc_InitMutex(&liboqsRNGMutex);
    if (ret != 0) {
        return ret;
    }
    ret = wc_LockMutex(&liboqsRNGMutex);
    if (ret != 0) {
        /* Free it so a retry does not re-initialize a live mutex. */
        (void)wc_FreeMutex(&liboqsRNGMutex);
        return ret;
    }
    ret = wc_InitRng(&liboqsDefaultRNG);
    if (ret != 0) {
        /* Do not point liboqs at an RNG that failed to initialize. */
        wc_UnLockMutex(&liboqsRNGMutex);
        (void)wc_FreeMutex(&liboqsRNGMutex);
        return ret;
    }

    OQS_init();
    liboqsCurrentRNG = &liboqsDefaultRNG;
    liboqs_init = 1;
    wc_UnLockMutex(&liboqsRNGMutex);

    OQS_randombytes_custom_algorithm(wolfSSL_liboqsGetRandomData);

    return 0;
}
void wolfSSL_liboqsClose(void)
{
wc_FreeRng(&liboqsDefaultRNG);
}
/* Ensures the liboqs glue is initialized, then acquires the liboqs RNG
 * mutex. When rng is non-NULL it becomes the source used by the liboqs
 * randomness callback until wolfSSL_liboqsRngMutexUnlock() is called; when
 * rng is NULL the current source is left unchanged.
 *
 * Returns 0 on success, otherwise the error from initialization or from
 * taking the mutex (in which case the mutex is not held).
 */
int wolfSSL_liboqsRngMutexLock(WC_RNG* rng)
{
    int err;

    err = wolfSSL_liboqsInit();
    if (err != 0) {
        return err;
    }

    err = wc_LockMutex(&liboqsRNGMutex);
    if (err != 0) {
        return err;
    }

    if (rng != NULL) {
        liboqsCurrentRNG = rng;
    }

    return 0;
}
/* Restores the default RNG as the liboqs randomness source, then releases
 * the liboqs RNG mutex taken by wolfSSL_liboqsRngMutexLock().
 *
 * Returns the result of wc_UnLockMutex(), or BAD_MUTEX_E if the liboqs glue
 * has not been initialized (the mutex then does not exist to release).
 */
int wolfSSL_liboqsRngMutexUnlock(void)
{
    int status = BAD_MUTEX_E;

    /* Reset before unlocking so the next locker starts from the default. */
    liboqsCurrentRNG = &liboqsDefaultRNG;

    if (liboqs_init != 0) {
        status = wc_UnLockMutex(&liboqsRNGMutex);
    }

    return status;
}
#endif