use core::{cell::OnceCell, fmt, marker::PhantomData};
use buggy::{Bug, BugExt as _};
use derive_where::derive_where;
use serde::{Deserialize, Serialize};
use spideroak_crypto::{
aead::Tag,
hex::Hex,
kdf::{self, Kdf},
keys::SecretKeyBytes,
};
use zerocopy::{ByteEq, Immutable, IntoBytes, KnownLayout, Unaligned};
use crate::{
Csprng, Random,
aranya::{Encap, EncryptionKey, EncryptionPublicKey},
ciphersuite::{CipherSuite, CipherSuiteExt as _},
engine::unwrapped,
error::Error,
generic_array::GenericArray,
hpke::{self, Mode},
id::{IdError, Identified, custom_id},
policy::{GroupId, PolicyId},
subtle::{Choice, ConstantTimeEq},
tls::{self, CipherSuiteId},
util,
zeroize::{Zeroize as _, ZeroizeOnDrop, Zeroizing},
};
/// The pseudorandom key (PRK) type produced by the cipher suite's KDF.
type Prk<CS> = kdf::Prk<<<CS as CipherSuite>::Kdf as Kdf>::PrkSize>;
/// Domain-separation string used when deriving a [`PskSeed`]'s PRK and ID.
const SEED_DOMAIN: &[u8] = b"SeedForAranyaTls-v1";
/// Domain-separation string used when expanding individual PSK secrets
/// from a [`PskSeed`]'s PRK.
const PSK_DOMAIN: &[u8] = b"PskForAranyaTls-v1";
// Identifier type for a `PskSeed`; its value is derived from the seed's
// PRK (see `PskSeed::try_id`).
custom_id! {
pub struct PskSeedId;
}
/// A secret seed from which TLS pre-shared keys ([`Psk`]s) are derived.
///
/// Internally this is a KDF pseudorandom key (PRK). The seed's
/// [`PskSeedId`] is derived from the PRK on first use and cached.
#[derive_where(Clone, Debug)]
pub struct PskSeed<CS: CipherSuite> {
/// Secret key material; excluded from `Debug` output.
#[derive_where(skip(Debug))]
prk: Prk<CS>,
/// Lazily computed, cached result of deriving the seed's ID (a
/// derivation failure is cached too).
id: OnceCell<Result<PskSeedId, Bug>>,
_marker: PhantomData<CS>,
}
impl<CS: CipherSuite> PskSeed<CS> {
    /// Generates a new random PSK seed bound to `group`.
    ///
    /// 32 bytes of input keying material are drawn from `rng` and run
    /// through the same derivation as [`PskSeed::import_from_ikm`].
    pub fn new<R>(rng: R, group: &GroupId) -> Self
    where
        R: Csprng,
    {
        // `Zeroizing` wipes the raw IKM once the PRK has been extracted.
        let ikm = Zeroizing::new(Random::random(rng));
        Self::from_ikm(&ikm, group)
    }

    /// Derives a PSK seed from externally supplied input keying
    /// material `ikm`, bound to `group`.
    pub fn import_from_ikm(ikm: &[u8; 32], group: &GroupId) -> Self {
        Self::from_ikm(ikm, group)
    }

    /// Extracts the seed's PRK:
    ///
    /// `prk = LabeledExtract(SEED_DOMAIN, salt = [], "prk", [group, ikm])`
    pub(crate) fn from_ikm(ikm: &[u8; 32], group: &GroupId) -> Self {
        let prk = CS::labeled_extract(SEED_DOMAIN, &[], b"prk", [group.as_bytes(), ikm]);
        Self::from_prk(prk)
    }

    /// Wraps an existing PRK; the ID is computed lazily by
    /// [`PskSeed::try_id`].
    fn from_prk(prk: Prk<CS>) -> Self {
        Self {
            prk,
            id: OnceCell::new(),
            _marker: PhantomData,
        }
    }

    /// Returns the seed's ID, deriving it on the first call as
    /// `LabeledExpand(SEED_DOMAIN, prk, "id", [])`.
    ///
    /// The outcome (including a failure) is cached, so every call after
    /// the first returns the same result.
    fn try_id(&self) -> Result<&PskSeedId, &Bug> {
        self.id
            .get_or_init(|| {
                let id = CS::labeled_expand(SEED_DOMAIN, &self.prk, b"id", [])
                    .assume("should be able to generate PSK seed ID")?;
                Ok(PskSeedId::from_bytes(id))
            })
            .as_ref()
    }

    /// Derives one TLS 1.3 [`Psk`] per cipher suite in `suites`.
    ///
    /// Each PSK's identity is an `ImportedIdentity` made of this seed's
    /// ID, `group`, `policy`, TLS 1.3, and the suite. Its secret is
    /// `LabeledExpand(PSK_DOMAIN, prk, "psk", [identity bytes, context])`.
    ///
    /// `suites` may be any iterable (a `Vec`, an array, or an iterator).
    /// The returned iterator is lazy: each PSK is derived as it is
    /// yielded.
    ///
    /// # Errors
    ///
    /// An item is `Err` if the seed ID cannot be derived or if expanding
    /// the PSK secret fails.
    pub fn generate_psks<I>(
        self,
        context: &'static [u8],
        group: GroupId,
        policy: PolicyId,
        suites: I,
    ) -> impl Iterator<Item = Result<Psk<CS>, Error>>
    where
        // `IntoIterator` (rather than `Iterator`) makes the `into_iter()`
        // call below meaningful; every `Iterator` still satisfies it.
        I: IntoIterator<Item = CipherSuiteId>,
    {
        suites.into_iter().map(move |suite| {
            let id = ImportedIdentity {
                external_identity: *self.try_id().map_err(Bug::clone)?,
                context: PskCtx { group, policy },
                target_protocol: tls::Version::Tls13,
                target_kdf: suite,
            };
            let secret =
                CS::labeled_expand(PSK_DOMAIN, &self.prk, b"psk", [id.as_bytes(), context])?;
            Ok(Psk {
                id: PskId(id),
                secret,
                _marker: PhantomData,
            })
        })
    }
}
// Marker: dropping a `PskSeed` wipes its secret PRK. The `Drop` impl
// below checks that the PRK type itself upholds this.
impl<CS: CipherSuite> ZeroizeOnDrop for PskSeed<CS> {}

impl<CS: CipherSuite> Drop for PskSeed<CS> {
    #[inline]
    fn drop(&mut self) {
        // Presumably a type-level assertion (see `util::val_is_zeroize_on_drop`)
        // that `Prk` zeroizes itself when dropped; `PskSeed` then relies on
        // the field's own drop glue to do the wiping.
        let prk = &self.prk;
        util::val_is_zeroize_on_drop(prk);
    }
}
// Conversions between `PskSeed` and its raw `Prk`, presumably used when the
// crypto engine wraps/unwraps this key (see `engine::unwrapped`).
// `into` clones the PRK because `PskSeed` implements `Drop`, so the field
// cannot be moved out; `from` rebuilds the seed via `from_prk`, so the ID is
// recomputed lazily.
unwrapped! {
name: PskSeed;
type: Prk;
into: |key: Self| { key.prk.clone() };
from: |prk| { Self::from_prk(prk) };
}
impl<CS: CipherSuite> Identified for PskSeed<CS> {
type Id = PskSeedId;
#[inline]
fn id(&self) -> Result<Self::Id, IdError> {
let id = self.try_id().map_err(Bug::clone)?;
Ok(*id)
}
}
impl<CS: CipherSuite> ConstantTimeEq for PskSeed<CS> {
    /// Compares two seeds in constant time by their PRKs alone; the
    /// cached ID is not consulted.
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (&self.prk, &other.prk);
        lhs.ct_eq(rhs)
    }
}
/// The identity of a derived [`Psk`].
///
/// Its raw bytes are fed into the PSK secret derivation
/// (`PskSeed::generate_psks`) and exposed via `PskId::as_bytes`, so the
/// `#[repr(C)]` field order is part of the derivation and must not change.
/// NOTE(review): field names appear to mirror RFC 9258's `ImportedIdentity`;
/// confirm the intended correspondence.
#[repr(C)]
#[derive(Copy, Clone, Debug, Immutable, IntoBytes, KnownLayout, Serialize, Deserialize)]
struct ImportedIdentity {
/// ID of the `PskSeed` this PSK was derived from.
external_identity: PskSeedId,
/// Group and policy the PSK is bound to.
context: PskCtx,
/// Always `tls::Version::Tls13` when built by `generate_psks`.
target_protocol: tls::Version,
/// TLS cipher suite the PSK is intended for.
target_kdf: CipherSuiteId,
}
/// Context embedded in an [`ImportedIdentity`] that binds a PSK to a
/// group and a policy. `#[repr(C)]` fixes its byte layout.
#[repr(C)]
#[derive(Copy, Clone, Debug, Immutable, IntoBytes, KnownLayout, Serialize, Deserialize)]
struct PskCtx {
group: GroupId,
policy: PolicyId,
}
impl<CS: CipherSuite> EncryptionKey<CS> {
    /// Encrypts `seed` for the holder of the secret key behind `peer_pk`.
    ///
    /// Uses HPKE in `Auth` mode with this key as the sender. The HPKE
    /// `info` and the AEAD associated data are both the bytes of
    /// `Info { domain: "PskSeed-v1", group }`, which binds the ciphertext
    /// to `group`. The recipient decrypts with
    /// [`EncryptionKey::open_psk_seed`].
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidArgument`] if `peer_pk` is this key's own public key.
    /// - Any error from computing this key's public key, HPKE setup, or
    ///   sealing.
    pub fn seal_psk_seed<R: Csprng>(
        &self,
        rng: R,
        seed: &PskSeed<CS>,
        peer_pk: &EncryptionPublicKey<CS>,
        group: &GroupId,
    ) -> Result<(Encap<CS>, EncryptedPskSeed<CS>), Error> {
        // Refuse to encrypt a seed to ourselves.
        if &self.public()? == peer_pk {
            return Err(Error::InvalidArgument("same `EncryptionKey`"));
        }

        // info = "PskSeed-v1" || group
        let info = Info {
            domain: *b"PskSeed-v1",
            group: *group,
        };
        let (enc, mut ctx) =
            hpke::setup_send::<CS, _>(rng, Mode::Auth(&self.sk), &peer_pk.pk, [info.as_bytes()])?;

        // The buffer starts out holding the plaintext PRK, so wipe it if
        // sealing fails instead of dropping secret bytes.
        let mut ciphertext = seed.prk.clone().into_bytes().into_bytes();
        let mut tag = Tag::<CS::Aead>::default();
        ctx.seal_in_place(&mut ciphertext, &mut tag, info.as_bytes())
            .inspect_err(|_| ciphertext.zeroize())?;
        Ok((Encap(enc), EncryptedPskSeed { ciphertext, tag }))
    }

    /// Decrypts a [`PskSeed`] produced by [`EncryptionKey::seal_psk_seed`].
    ///
    /// `peer_pk` is the sender's public key (HPKE `Auth` mode). `group`
    /// must match the group used when sealing, because it feeds both the
    /// HPKE `info` and the AEAD associated data.
    ///
    /// # Errors
    ///
    /// Returns an error if HPKE setup fails or the ciphertext does not
    /// authenticate (e.g. wrong sender, recipient, or group).
    pub fn open_psk_seed(
        &self,
        encap: &Encap<CS>,
        ciphertext: EncryptedPskSeed<CS>,
        peer_pk: &EncryptionPublicKey<CS>,
        group: &GroupId,
    ) -> Result<PskSeed<CS>, Error> {
        let EncryptedPskSeed {
            mut ciphertext,
            tag,
        } = ciphertext;

        // Must match the `info` used by `seal_psk_seed`.
        let info = Info {
            domain: *b"PskSeed-v1",
            group: *group,
        };
        let mut ctx = hpke::setup_recv::<CS>(
            Mode::Auth(&peer_pk.pk),
            &encap.0,
            &self.sk,
            [info.as_bytes()],
        )?;
        // Nothing here guarantees the buffer's contents after a failed open
        // (it may hold partially decrypted key material). Wipe it, as
        // `seal_psk_seed` does on failure, rather than dropping it as-is.
        ctx.open_in_place(&mut ciphertext, &tag, info.as_bytes())
            .inspect_err(|_| ciphertext.zeroize())?;
        let prk = Prk::<CS>::new(SecretKeyBytes::new(ciphertext));
        Ok(PskSeed::from_prk(prk))
    }
}
/// HPKE `info` (also used as AEAD associated data) for sealing and
/// opening a [`PskSeed`].
///
/// Because it is `#[repr(C)]`, `IntoBytes` and `Unaligned`, its bytes are
/// exactly `domain` followed by `group`, with no padding.
#[repr(C)]
#[derive(Copy, Clone, Debug, ByteEq, Immutable, IntoBytes, KnownLayout, Unaligned)]
struct Info {
domain: [u8; 10],
group: GroupId,
}
/// A [`PskSeed`] encrypted for a specific peer.
///
/// Produced by `EncryptionKey::seal_psk_seed` and decrypted by
/// `EncryptionKey::open_psk_seed`.
#[derive_where(Clone, Debug, Serialize, Deserialize)]
pub struct EncryptedPskSeed<CS: CipherSuite> {
/// The encrypted PRK; the same length as the KDF's PRK.
pub(crate) ciphertext: GenericArray<u8, <<CS as CipherSuite>::Kdf as Kdf>::PrkSize>,
/// AEAD authentication tag over `ciphertext`.
pub(crate) tag: Tag<CS::Aead>,
}
/// A TLS pre-shared key derived from a [`PskSeed`]: an identity plus a
/// 32-byte secret.
#[derive_where(Clone, Debug)]
pub struct Psk<CS> {
/// Secret key material; excluded from `Debug` output and wiped on drop.
#[derive_where(skip(Debug))]
secret: [u8; 32],
/// Public identity of this PSK.
id: PskId,
_marker: PhantomData<CS>,
}
impl<CS: CipherSuite> Psk<CS> {
    /// Returns this PSK's identity.
    pub fn identity(&self) -> &PskId {
        let Self { id, .. } = self;
        id
    }

    /// Returns the raw 32-byte PSK secret.
    ///
    /// The returned slice borrows secret key material; avoid copying or
    /// logging it.
    pub fn raw_secret_bytes(&self) -> &[u8] {
        self.secret.as_slice()
    }
}
// Marker: dropping a `Psk` wipes its secret; upheld by the `Drop` impl below.
impl<CS> ZeroizeOnDrop for Psk<CS> {}

impl<CS> Drop for Psk<CS> {
    #[inline]
    fn drop(&mut self) {
        // Overwrite the secret bytes before the memory is released.
        let secret = &mut self.secret;
        secret.zeroize();
    }
}
impl<CS> ConstantTimeEq for Psk<CS> {
    /// Compares the two secrets byte-by-byte in constant time; identities
    /// are not compared.
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        let lhs: &[u8] = &self.secret;
        let rhs: &[u8] = &other.secret;
        lhs.ct_eq(rhs)
    }
}
/// The identity of a [`Psk`].
///
/// A newtype over `ImportedIdentity` whose byte encoding is 100 bytes
/// (see `PskId::as_bytes`). Equality via `ByteEq` compares those bytes.
#[derive(Copy, Clone, Debug, ByteEq, Immutable, IntoBytes, KnownLayout, Serialize, Deserialize)]
pub struct PskId(ImportedIdentity);
impl PskId {
    /// Returns the ID of the [`PskSeed`] this PSK was derived from.
    pub const fn seed_id(&self) -> &PskSeedId {
        let ident = &self.0;
        &ident.external_identity
    }

    /// Returns the group this PSK is bound to.
    pub const fn group_id(&self) -> &GroupId {
        let ctx = &self.0.context;
        &ctx.group
    }

    /// Returns the TLS cipher suite this PSK targets.
    pub const fn cipher_suite(&self) -> CipherSuiteId {
        let ident = &self.0;
        ident.target_kdf
    }

    /// Returns the identity's raw byte encoding (exactly 100 bytes).
    pub const fn as_bytes(&self) -> &[u8] {
        // `transmute_ref!` checks at compile time that `PskId` is exactly
        // 100 bytes.
        let raw: &[u8; 100] = zerocopy::transmute_ref!(self);
        raw
    }
}
impl ConstantTimeEq for PskId {
    /// Compares the two identities' byte encodings in constant time.
    #[inline]
    fn ct_eq(&self, other: &Self) -> Choice {
        let (lhs, rhs) = (self.as_bytes(), other.as_bytes());
        lhs.ct_eq(rhs)
    }
}
impl fmt::Display for PskId {
    /// Formats the identity's raw bytes as hex.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Name `Display` explicitly. Only the `fmt` module is imported, not
        // its traits, so a bare `.fmt(f)` method call would depend on trait
        // scope and on which of `Hex`'s formatting impls gets chosen.
        fmt::Display::fmt(&Hex::new(self.as_bytes()), f)
    }
}