use libcrux_hkdf::{expand as hkdf_expand, Algorithm as HKDF_Algorithm};
use libcrux_hmac::{hmac, Algorithm as HMAC_Algorithm};
use libcrux_traits::kem::KEM;
use rand::CryptoRng;
use tls_codec::{Deserialize, Serialize, Size, TlsDeserialize, TlsSerialize, TlsSize};
use super::Error;
/// Length in bytes of a PSQ component, i.e. the pre-shared key returned by
/// encapsulation/decapsulation.
pub const PSQ_COMPONENT_LENGTH: usize = 32;
/// Length in bytes of the intermediate key `K0` derived from the KEM output.
const K0_LENGTH: usize = 32;
/// Length in bytes of the confirmation key `KM` that is derived from `K0` and
/// used as the HMAC key.
const KM_LENGTH: usize = 32;
/// HKDF-Expand `info` label used when deriving the confirmation key from `K0`.
/// The exact bytes are part of the derivation; changing them changes every output.
const CONFIRMATION_CONTEXT: &[u8] = b"Confirmation";
/// HKDF-Expand `info` label used when deriving the PSK from `K0`.
const PSK_CONTEXT: &[u8] = b"PSK";
/// Fixed message that is HMAC'd under the confirmation key to produce the MAC.
const MAC_INPUT: &[u8] = b"MAC-Input";
/// Length in bytes of the key-confirmation MAC (requested from HMAC-SHA256).
pub(crate) const MAC_LENGTH: usize = 32;
/// Key-confirmation MAC carried alongside the inner KEM ciphertext.
type Mac = [u8; MAC_LENGTH];
/// A PSQ component: the pre-shared key produced by `encapsulate_psq` and
/// recovered by `decapsulate_psq`.
pub type PSQComponent = [u8; PSQ_COMPONENT_LENGTH];
/// Crate-private module holding the supertrait that seals `PSQ`.
///
/// `PSQ` requires `private::Seal`. Because this module is only `pub(crate)`, code
/// outside the crate cannot name `Seal` and therefore cannot implement `PSQ`.
pub(crate) mod private {
/// Marker supertrait; implement it only for types that are meant to implement `PSQ`.
pub trait Seal {}
}
pub trait PSQ: private::Seal {
type InnerKEM: KEM<
Ciphertext: Deserialize + Serialize + Size,
SharedSecret: Serialize + Size,
EncapsulationKey: Serialize + Size,
>;
fn encapsulate_psq(
pk: &<Self::InnerKEM as KEM>::EncapsulationKey,
sctx: &[u8],
rng: &mut impl CryptoRng,
) -> Result<(PSQComponent, Ciphertext<Self::InnerKEM>), Error> {
let (ikm, enc) =
Self::InnerKEM::encapsulate(pk, rng).map_err(|_| Error::PSQGenerationError)?;
let mut pk_serialized = vec![0u8; pk.tls_serialized_len()];
let _ = pk
.tls_serialize(&mut &mut pk_serialized[..])
.map_err(|_| Error::Serialization)?;
let mut ikm_serialized = vec![0u8; ikm.tls_serialized_len()];
let _ = ikm
.tls_serialize(&mut &mut ikm_serialized[..])
.map_err(|_| Error::Serialization)?;
let mut enc_serialized = vec![0u8; enc.tls_serialized_len()];
let _ = enc
.tls_serialize(&mut &mut enc_serialized[..])
.map_err(|_| Error::Serialization)?;
let k0 = compute_k0(&pk_serialized, &ikm_serialized, &enc_serialized, sctx)?;
let mac = compute_mac(&k0)?;
let psk = compute_psk(&k0)?;
Ok((
psk,
Ciphertext {
inner_ctxt: enc,
mac,
},
))
}
fn decapsulate_psq(
sk: &<Self::InnerKEM as KEM>::DecapsulationKey,
pk: &<Self::InnerKEM as KEM>::EncapsulationKey,
enc: &Ciphertext<Self::InnerKEM>,
sctx: &[u8],
) -> Result<PSQComponent, Error> {
let Ciphertext { inner_ctxt, mac } = enc;
let ikm =
Self::InnerKEM::decapsulate(sk, inner_ctxt).map_err(|_| Error::PSQDerivationError)?;
let mut pk_serialized = vec![0u8; pk.tls_serialized_len()];
let _ = pk
.tls_serialize(&mut &mut pk_serialized[..])
.map_err(|_| Error::Serialization)?;
let mut ikm_serialized = vec![0u8; ikm.tls_serialized_len()];
let _ = ikm
.tls_serialize(&mut &mut ikm_serialized[..])
.map_err(|_| Error::Serialization)?;
let mut inner_ctxt_serialized = vec![0u8; inner_ctxt.tls_serialized_len()];
let _ = inner_ctxt
.tls_serialize(&mut &mut inner_ctxt_serialized[..])
.map_err(|_| Error::Serialization)?;
let k0 = compute_k0(
&pk_serialized,
&ikm_serialized,
&inner_ctxt_serialized,
sctx,
)?;
let recomputed_mac = compute_mac(&k0)?;
if compare(&recomputed_mac, mac) == 0 {
compute_psk(&k0)
} else {
Err(Error::PSQDerivationError)
}
}
}
/// Ciphertext produced by `PSQ::encapsulate_psq` and consumed by
/// `PSQ::decapsulate_psq`.
///
/// TLS (de)serialization is derived. Presumably the encoding is the fields in
/// declaration order (inner ciphertext, then MAC), as tls_codec derives do, so
/// reordering the fields would change the wire format. Confirm before touching.
///
/// The bounds on `T` mirror those on `PSQ::InnerKEM`.
#[derive(TlsSerialize, TlsDeserialize, TlsSize)]
pub struct Ciphertext<
T: KEM<
Ciphertext: Deserialize + Serialize + Size,
SharedSecret: Serialize + Size,
EncapsulationKey: Serialize + Size,
>,
> {
/// Ciphertext of the inner KEM encapsulation.
pub(crate) inner_ctxt: T::Ciphertext,
/// Key-confirmation MAC computed from `K0` by the encapsulator and checked
/// by the decapsulator.
pub(crate) mac: Mac,
}
/// Returns `1` if `value` is non-zero and `0` if it is zero, using only
/// arithmetic and bit operations (no comparison or branch).
///
/// The byte is widened to `u16`. For any non-zero `v`, the two's-complement
/// negation `-v` (mod 2^16) lies in `0xFF01..=0xFFFF`, so `v | -v` has bit 8 set.
/// For `v == 0` both terms are zero. Shifting right by 8 and masking the low bit
/// extracts that flag.
fn inz(value: u8) -> u8 {
    let wide = u16::from(value);
    let flag = (wide | wide.wrapping_neg()) >> 8;
    (flag & 1) as u8
}
/// Returns `1` if `value` is non-zero and `0` otherwise, delegating to `inz`.
///
/// `#[inline(never)]` and `core::hint::black_box` presumably exist to keep the
/// optimizer from merging this reduction into the caller's loop and turning it
/// into an early-exit or branching comparison. NOTE(review): `black_box` is only
/// a best-effort hint per its std documentation, not a guarantee.
#[inline(never)] fn is_non_zero(value: u8) -> u8 {
core::hint::black_box(inz(value))
}
/// Compares two byte slices for equality without exiting early on a mismatch.
///
/// Returns `0` when `lhs` and `rhs` are equal and `1` otherwise. Every byte pair
/// is XOR-ed into an accumulator, so the work done does not depend on where (or
/// whether) the contents differ. The final reduction to `0`/`1` goes through
/// `is_non_zero`.
///
/// Slices of different lengths are reported as unequal (`1`) immediately. Only
/// the lengths, never the contents, decide that early return.
fn compare(lhs: &[u8], rhs: &[u8]) -> u8 {
    // Handle mismatched lengths explicitly. Indexing `rhs` with `lhs`'s indices
    // would panic if `rhs` were shorter, and would silently compare only a prefix
    // (possibly reporting equality) if `rhs` were longer.
    if lhs.len() != rhs.len() {
        return 1;
    }
    let diff = lhs
        .iter()
        .zip(rhs)
        .fold(0u8, |acc, (a, b)| acc | (a ^ b));
    is_non_zero(diff)
}
fn compute_psk(k0: &[u8]) -> Result<PSQComponent, Error> {
let mut psk = [0u8; PSQ_COMPONENT_LENGTH];
hkdf_expand(HKDF_Algorithm::Sha256, &mut psk, k0, PSK_CONTEXT)
.map_err(|_| Error::PSQGenerationError)?;
Ok(psk)
}
fn compute_k0(pqpk: &[u8], ikm: &[u8], enc: &[u8], sctx: &[u8]) -> Result<Vec<u8>, Error> {
let mut info = Vec::from(pqpk);
info.extend_from_slice(enc);
info.extend_from_slice(sctx);
let mut k0 = [0u8; K0_LENGTH];
hkdf_expand(HKDF_Algorithm::Sha256, &mut k0, ikm, &info)
.map_err(|_| Error::PSQDerivationError)?;
Ok(k0.to_vec())
}
fn compute_mac(k0: &[u8]) -> Result<Mac, Error> {
let mut km = [0u8; KM_LENGTH];
hkdf_expand(HKDF_Algorithm::Sha256, &mut km, k0, CONFIRMATION_CONTEXT)
.map_err(|_| Error::PSQGenerationError)?;
let mac: Mac = hmac(HMAC_Algorithm::Sha256, &km, MAC_INPUT, Some(MAC_LENGTH))
.try_into()
.expect("should receive the correct number of bytes from HMAC");
Ok(mac)
}