#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <gmssl/sm3.h>
#include <gmssl/hmac.h>
#include <gmssl/error.h>
/*
 * HKDF-Extract (RFC 5869, section 2.2): PRK = HMAC-Hash(salt, IKM).
 *
 * digest  - hash algorithm used for the HMAC; must not be NULL.
 * salt    - optional salt; when NULL or saltlen == 0, a string of
 *           digest->digest_size zero bytes is used instead (RFC 5869).
 * ikm     - input keying material, ikmlen bytes.
 * prk     - output buffer for the pseudorandom key; its length is
 *           written to *prklen by hmac_finish.
 *
 * Returns 1 on success, -1 on error.
 */
int hkdf_extract(const DIGEST *digest, const uint8_t *salt, size_t saltlen,
	const uint8_t *ikm, size_t ikmlen,
	uint8_t *prk, size_t *prklen)
{
	uint8_t zeros[DIGEST_MAX_SIZE] = {0};
	HMAC_CTX hmac_ctx;
	volatile uint8_t *wipe;
	size_t i;
	int ret = -1;

	/* digest->digest_size is read below, so reject NULL up front. */
	if (!digest) {
		error_print();
		return -1;
	}

	/* An absent salt is replaced by HashLen zero bytes (RFC 5869 2.2). */
	if (!salt || saltlen == 0) {
		salt = zeros;
		saltlen = digest->digest_size;
	}

	if (hmac_init(&hmac_ctx, digest, salt, saltlen) != 1
		|| hmac_update(&hmac_ctx, ikm, ikmlen) != 1
		|| hmac_finish(&hmac_ctx, prk, prklen) != 1) {
		error_print();
		goto end;
	}
	ret = 1;

end:
	/*
	 * The context has absorbed the IKM and may still hold keyed/buffered
	 * state. Clear it through a volatile pointer so the stores are not
	 * elided as dead writes.
	 */
	wipe = (volatile uint8_t *)&hmac_ctx;
	for (i = 0; i < sizeof(hmac_ctx); i++) {
		wipe[i] = 0;
	}
	return ret;
}
int hkdf_expand(const DIGEST *digest, const uint8_t *prk, size_t prklen,
const uint8_t *opt_info, size_t opt_infolen,
size_t L, uint8_t *okm)
{
HMAC_CTX hmac_ctx;
uint8_t T[HMAC_MAX_SIZE];
uint8_t counter = 0x01;
size_t len;
if (L > 0) {
if (hmac_init(&hmac_ctx, digest, prk, prklen) != 1
|| hmac_update(&hmac_ctx, opt_info, opt_infolen) < 0
|| hmac_update(&hmac_ctx, &counter, 1) != 1
|| hmac_finish(&hmac_ctx, T, &len) != 1) {
error_print();
return -1;
}
counter++;
if (len > L) {
len = L;
}
memcpy(okm, T, len);
okm += len;
L -= len;
}
while (L > 0) {
if (counter == 0) {
error_print();
return -1;
}
if (hmac_init(&hmac_ctx, digest, prk, prklen) != 1
|| hmac_update(&hmac_ctx, T, len) != 1
|| hmac_update(&hmac_ctx, opt_info, opt_infolen) < 0
|| hmac_update(&hmac_ctx, &counter, 1) != 1
|| hmac_finish(&hmac_ctx, T, &len) != 1) {
error_print();
return -1;
}
counter++;
if (len > L) {
len = L;
}
memcpy(okm, T, len);
okm += len;
L -= len;
}
return 1;
}