#ifdef HAVE_CONFIG_H
#include <config.h>
#endif
#include <wolfssl/wolfcrypt/settings.h>
#if defined(WOLFSSL_HAVE_PSA)
#if !defined(WOLFSSL_PSA_NO_AES)
#if !defined(NO_AES)
#include <wolfssl/wolfcrypt/error-crypt.h>
#include <wolfssl/wolfcrypt/port/psa/psa.h>
#include <wolfssl/wolfcrypt/logging.h>
#ifndef NO_INLINE
#define WOLFSSL_MISC_INCLUDED
#include <wolfcrypt/src/misc.c>
#else
#include <wolfssl/wolfcrypt/misc.h>
#endif
/**
 * wc_psa_aes_import_key() - import raw AES key bytes into PSA for @aes
 * @aes: Aes object that receives the resulting PSA key identifier
 * @key: raw key bytes
 * @key_length: number of bytes in @key
 * @alg: PSA algorithm the imported key is permitted for
 * @dir: AES_ENCRYPTION or AES_DECRYPTION; any other value produces a key
 *       with no usage flags set
 *
 * The key id and the context-initialized flag stored in @aes are cleared
 * before the import is attempted. A key id previously stored in @aes is
 * overwritten here without being destroyed.
 *
 * returns: 0 on success, WC_HW_E if psa_import_key() fails
 */
static int wc_psa_aes_import_key(Aes *aes, const uint8_t *key,
                                 size_t key_length, psa_algorithm_t alg,
                                 int dir)
{
    psa_key_attributes_t attributes;
    psa_key_id_t new_id;
    psa_status_t status;

    /* forget whatever key/operation state the object carried so far */
    aes->ctx_initialized = 0;
    aes->key_id = PSA_KEY_ID_NULL;

    /* describe the key: AES, size in bits, single algorithm, one direction */
    XMEMSET(&attributes, 0, sizeof(attributes));
    psa_set_key_type(&attributes, PSA_KEY_TYPE_AES);
    psa_set_key_bits(&attributes, key_length * 8);
    psa_set_key_algorithm(&attributes, alg);

    if (dir == AES_ENCRYPTION) {
        psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_ENCRYPT);
    }
    else if (dir == AES_DECRYPTION) {
        psa_set_key_usage_flags(&attributes, PSA_KEY_USAGE_DECRYPT);
    }
    else {
        psa_set_key_usage_flags(&attributes, 0);
    }

    PSA_LOCK();
    status = psa_import_key(&attributes, key, key_length, &new_id);
    PSA_UNLOCK();

    if (status == PSA_SUCCESS) {
        /* the key now lives in PSA; no deferred import is pending anymore */
        aes->key_need_importing = 0;
        aes->key_id = new_id;
        return 0;
    }

    return WC_HW_E;
}
/**
 * wc_psa_aes_init() - put the PSA-related fields of @aes in their idle state
 * @aes: Aes object
 *
 * Zeroes the PSA cipher operation context, clears the "context initialized"
 * and "key needs importing" flags and marks the object as holding no PSA key.
 *
 * returns: always 0
 */
int wc_psa_aes_init(Aes *aes)
{
    XMEMSET(&aes->psa_ctx, 0, sizeof(aes->psa_ctx));

    aes->key_need_importing = 0;
    aes->ctx_initialized = 0;
    aes->key_id = PSA_KEY_ID_NULL;

    return 0;
}
/**
 * wc_psa_aes_get_key_size() - get the size in bytes of the key held by @aes
 * @aes: Aes object
 * @keySize: where to store the size of the key
 *
 * When the raw key is still waiting for a deferred import, the size recorded
 * in aes->keylen is reported. Otherwise the size is queried from PSA using
 * the key id stored in @aes.
 *
 * returns: 0 on success, BAD_FUNC_ARG if @aes or @keySize is NULL or if no
 * key is associated with @aes, WC_HW_E if PSA fails to report the key
 * attributes
 */
int wc_psa_aes_get_key_size(Aes *aes, word32 *keySize)
{
    psa_key_attributes_t attr;
    psa_status_t s;

    if (aes == NULL || keySize == NULL)
        return BAD_FUNC_ARG;

    /* key not handed over to PSA yet: the length is tracked locally */
    if (aes->key_need_importing == 1) {
        *keySize = aes->keylen;
        return 0;
    }

    if (aes->key_id == PSA_KEY_ID_NULL)
        return BAD_FUNC_ARG;

    /* The attribute object must be in an initialized state before it is
     * handed to psa_get_key_attributes(); passing stack garbage is undefined
     * behavior. Zeroing matches what wc_psa_aes_import_key() does. */
    XMEMSET(&attr, 0, sizeof(attr));

    PSA_LOCK();
    s = psa_get_key_attributes(aes->key_id, &attr);
    PSA_UNLOCK();
    if (s != PSA_SUCCESS)
        return WC_HW_E;

    /* PSA reports the size in bits */
    *keySize = (word32)(psa_get_key_bits(&attr) / 8);

    /* release anything the implementation attached to the attribute object */
    PSA_LOCK();
    psa_reset_key_attributes(&attr);
    PSA_UNLOCK();

    return 0;
}
/**
 * wc_psa_aes_set_key() - set key / iv on object @aes
 * @aes: Aes object
 * @key: raw key bytes
 * @key_length: number of bytes in @key; must be 16, 24 or 32
 * @iv: initialization vector, forwarded unchanged to wc_AesSetIV()
 * @alg: PSA algorithm to bind the key to, or PSA_ALG_NONE when the algorithm
 *       is not known yet
 * @dir: AES_ENCRYPTION or AES_DECRYPTION
 *
 * Any cipher operation in progress on @aes is aborted and any key previously
 * imported for @aes is destroyed. With @alg == PSA_ALG_NONE the raw key is
 * copied into aes->key and the PSA import is deferred until the first
 * encrypt/decrypt call; otherwise the key is imported immediately.
 *
 * NOTE(review): aes->keylen is not written here although the deferred import
 * and wc_psa_aes_get_key_size() read it; presumably the caller records it
 * before calling this function -- TODO confirm.
 *
 * returns: 0 on success, BAD_FUNC_ARG on NULL @aes / @key or an invalid
 * @key_length, WC_HW_E on PSA failure, otherwise the result of
 * wc_AesSetIV()
 */
int wc_psa_aes_set_key(Aes *aes, const uint8_t *key, size_t key_length,
                       uint8_t *iv, psa_algorithm_t alg, int dir)
{
    psa_status_t s;
    int ret;

    /* validate before touching existing state, so that a bad call does not
     * tear down a working key/operation */
    if (aes == NULL || key == NULL)
        return BAD_FUNC_ARG;

    /* only real AES key sizes are accepted; this also bounds the copy into
     * aes->key on the deferred-import path below */
    if (key_length != AES_128_KEY_SIZE &&
        key_length != AES_192_KEY_SIZE &&
        key_length != AES_256_KEY_SIZE)
        return BAD_FUNC_ARG;

    /* the object was already used for another operation: reset the context */
    if (aes->ctx_initialized) {
        PSA_LOCK();
        s = psa_cipher_abort(&aes->psa_ctx);
        PSA_UNLOCK();
        if (s != PSA_SUCCESS)
            return WC_HW_E;
        aes->ctx_initialized = 0;
    }

    /* a key was already imported: destroy it first */
    if (aes->key_id != PSA_KEY_ID_NULL) {
        PSA_LOCK();
        psa_destroy_key(aes->key_id);
        PSA_UNLOCK();
        aes->key_id = PSA_KEY_ID_NULL;
    }

    if (alg == PSA_ALG_NONE) {
        /* algorithm unknown yet: keep the raw key and import it later */
        XMEMCPY(aes->key, key, key_length);
        aes->key_need_importing = 1;
    } else {
        ret = wc_psa_aes_import_key(aes, key, key_length, alg, dir);
        if (ret != 0)
            return ret;
    }

    return wc_AesSetIV(aes, iv);
}
/**
 * wc_psa_aes_encrypt_decrypt() - run @length bytes through the PSA cipher
 * @aes: Aes object
 * @input: data to process
 * @output: where the processed data is written; @length bytes are offered
 *          to PSA as output capacity
 * @length: number of bytes in @input
 * @alg: PSA cipher algorithm
 * @direction: AES_ENCRYPTION to encrypt, anything else decrypts
 *
 * On the first call after a key was set, a pending deferred key import is
 * performed (and the raw key copy is wiped), the PSA cipher operation is set
 * up and, for every algorithm except PSA_ALG_ECB_NO_PADDING, the IV is
 * loaded from aes->reg. Subsequent calls only feed data into the existing
 * operation.
 *
 * returns: 0 on success; BAD_FUNC_ARG if @aes, @input or @output is NULL;
 * the error of the deferred key import if that fails; WC_HW_E if a PSA call
 * fails or PSA produced a different number of bytes than @length, in which
 * case wc_psa_aes_free() has been called on @aes
 */
int wc_psa_aes_encrypt_decrypt(Aes *aes, const uint8_t *input,
                               uint8_t *output, size_t length,
                               psa_algorithm_t alg, int direction)
{
    size_t produced;
    psa_status_t status;
    int ret;

    if (aes == NULL || input == NULL || output == NULL)
        return BAD_FUNC_ARG;

    /* first use of this object since the key was set */
    if (aes->ctx_initialized == 0) {

        /* the key was only stored so far: hand it to PSA now */
        if (aes->key_need_importing == 1) {
            ret = wc_psa_aes_import_key(aes, (uint8_t*)aes->key, aes->keylen,
                                        alg, direction);
            if (ret != 0)
                return ret;

            /* PSA owns the key now, drop the local copy */
            ForceZero(aes->key, aes->keylen);
            aes->keylen = 0;
            aes->key_need_importing = 0;
        }

        PSA_LOCK();
        if (direction == AES_ENCRYPTION)
            status = psa_cipher_encrypt_setup(&aes->psa_ctx, aes->key_id,
                                              alg);
        else
            status = psa_cipher_decrypt_setup(&aes->psa_ctx, aes->key_id,
                                              alg);
        PSA_UNLOCK();

        /* flagged before the status check, so that the failure path below
         * also aborts the operation in wc_psa_aes_free() */
        aes->ctx_initialized = 1;
        if (status != PSA_SUCCESS)
            goto fail;

        /* ECB takes no IV */
        if (alg != PSA_ALG_ECB_NO_PADDING) {
            PSA_LOCK();
            status = psa_cipher_set_iv(&aes->psa_ctx,
                                       (uint8_t*)aes->reg, AES_IV_SIZE);
            PSA_UNLOCK();
            if (status != PSA_SUCCESS)
                goto fail;
        }
    }

    PSA_LOCK();
    status = psa_cipher_update(&aes->psa_ctx, input, length,
                               output, length, &produced);
    PSA_UNLOCK();

    /* a short or long output counts as a failure as well */
    if (status != PSA_SUCCESS || produced != length)
        goto fail;

    return 0;

fail:
    wc_psa_aes_free(aes);
    return WC_HW_E;
}
/**
 * wc_psa_aes_free() - release the PSA resources held by @aes
 * @aes: Aes object
 *
 * Destroys the PSA key referenced by @aes (if any), aborts the cipher
 * operation (if one was set up) and wipes a raw key that was stored for a
 * deferred import but never handed over to PSA.
 *
 * returns: 0 on success, BAD_FUNC_ARG if @aes is NULL
 */
int wc_psa_aes_free(Aes *aes)
{
    if (aes == NULL)
        return BAD_FUNC_ARG;

    if (aes->key_id != PSA_KEY_ID_NULL) {
        PSA_LOCK();
        psa_destroy_key(aes->key_id);
        PSA_UNLOCK();
        aes->key_id = PSA_KEY_ID_NULL;
    }

    if (aes->ctx_initialized == 1) {
        PSA_LOCK();
        psa_cipher_abort(&aes->psa_ctx);
        PSA_UNLOCK();
        aes->ctx_initialized = 0;
    }

    /* A key waiting for deferred import still sits in aes->key in the clear.
     * Wipe it the same way the encrypt/decrypt path does after importing, so
     * that freeing the object does not leave key material behind. */
    if (aes->key_need_importing == 1) {
        ForceZero(aes->key, aes->keylen);
        aes->keylen = 0;
    }
    aes->key_need_importing = 0;

    return 0;
}
/**
 * wc_AesEncrypt() - encrypt a single AES block through PSA
 * @aes: Aes object
 * @inBlock: input block, WC_AES_BLOCK_SIZE bytes are processed
 * @outBlock: where the encrypted block is written
 *
 * The block is processed with PSA_ALG_ECB_NO_PADDING.
 *
 * returns: the result of wc_psa_aes_encrypt_decrypt()
 */
int wc_AesEncrypt(Aes *aes, const byte *inBlock, byte *outBlock)
{
    int ret;

    ret = wc_psa_aes_encrypt_decrypt(aes, inBlock, outBlock,
                                     WC_AES_BLOCK_SIZE,
                                     PSA_ALG_ECB_NO_PADDING, AES_ENCRYPTION);
    return ret;
}
#if defined(HAVE_AES_DECRYPT)
/**
 * wc_AesDecrypt() - decrypt a single AES block through PSA
 * @aes: Aes object
 * @inBlock: input block, WC_AES_BLOCK_SIZE bytes are processed
 * @outBlock: where the decrypted block is written
 *
 * The block is processed with PSA_ALG_ECB_NO_PADDING.
 *
 * returns: the result of wc_psa_aes_encrypt_decrypt()
 */
int wc_AesDecrypt(Aes *aes, const byte *inBlock, byte *outBlock)
{
    int ret;

    ret = wc_psa_aes_encrypt_decrypt(aes, inBlock, outBlock,
                                     WC_AES_BLOCK_SIZE,
                                     PSA_ALG_ECB_NO_PADDING, AES_DECRYPTION);
    return ret;
}
#endif
#if defined(WOLFSSL_AES_COUNTER)
/**
 * wc_AesCtrEncrypt() - process data in AES counter mode through PSA
 * @aes: Aes object
 * @out: where the processed data is written
 * @in: input data
 * @sz: number of bytes in @in
 *
 * The request is always issued with PSA_ALG_CTR in the AES_ENCRYPTION
 * direction.
 *
 * returns: the result of wc_psa_aes_encrypt_decrypt()
 */
int wc_AesCtrEncrypt(Aes *aes, byte *out, const byte *in, word32 sz)
{
    int ret;

    ret = wc_psa_aes_encrypt_decrypt(aes, in, out, sz, PSA_ALG_CTR,
                                     AES_ENCRYPTION);
    return ret;
}
#endif
#if defined (HAVE_AES_CBC)
/**
 * wc_AesCbcEncrypt() - encrypt data in AES-CBC mode (no padding) through PSA
 * @aes: Aes object
 * @out: where the ciphertext is written
 * @in: plaintext
 * @sz: number of bytes in @in; must be a multiple of WC_AES_BLOCK_SIZE
 *
 * returns: BAD_LENGTH_E (when WOLFSSL_AES_CBC_LENGTH_CHECKS is defined) or
 * BAD_FUNC_ARG (otherwise) if @sz is not a multiple of the block size; in
 * all other cases the result of wc_psa_aes_encrypt_decrypt()
 */
int wc_AesCbcEncrypt(Aes *aes, byte *out, const byte *in, word32 sz)
{
    /* partial blocks cannot be processed without padding */
    if ((sz % WC_AES_BLOCK_SIZE) != 0) {
#if defined (WOLFSSL_AES_CBC_LENGTH_CHECKS)
        return BAD_LENGTH_E;
#else
        return BAD_FUNC_ARG;
#endif
    }

    return wc_psa_aes_encrypt_decrypt(aes, in, out, sz,
                                      PSA_ALG_CBC_NO_PADDING, AES_ENCRYPTION);
}
/**
 * wc_AesCbcDecrypt() - decrypt data in AES-CBC mode (no padding) through PSA
 * @aes: Aes object
 * @out: where the plaintext is written
 * @in: ciphertext
 * @sz: number of bytes in @in; must be a multiple of WC_AES_BLOCK_SIZE
 *
 * returns: BAD_LENGTH_E (when WOLFSSL_AES_CBC_LENGTH_CHECKS is defined) or
 * BAD_FUNC_ARG (otherwise) if @sz is not a multiple of the block size; in
 * all other cases the result of wc_psa_aes_encrypt_decrypt()
 */
int wc_AesCbcDecrypt(Aes *aes, byte *out, const byte *in, word32 sz)
{
    /* partial blocks cannot be processed without padding */
    if ((sz % WC_AES_BLOCK_SIZE) != 0) {
#if defined (WOLFSSL_AES_CBC_LENGTH_CHECKS)
        return BAD_LENGTH_E;
#else
        return BAD_FUNC_ARG;
#endif
    }

    return wc_psa_aes_encrypt_decrypt(aes, in, out, sz,
                                      PSA_ALG_CBC_NO_PADDING, AES_DECRYPTION);
}
#endif
#endif
#endif
#endif