1use core::{cell::OnceCell, fmt, marker::PhantomData};
2
3use buggy::{Bug, BugExt as _};
4use derive_where::derive_where;
5use serde::{Deserialize, Serialize};
6use spideroak_crypto::{
7 aead::Tag,
8 hex::Hex,
9 kdf::{self, Kdf},
10 keys::SecretKeyBytes,
11};
12use zerocopy::{ByteEq, Immutable, IntoBytes, KnownLayout, Unaligned};
13
14use crate::{
15 Csprng, Random,
16 aranya::{Encap, EncryptionKey, EncryptionPublicKey},
17 ciphersuite::{CipherSuite, CipherSuiteExt as _},
18 engine::unwrapped,
19 error::Error,
20 generic_array::GenericArray,
21 hpke::{self, Mode},
22 id::{IdError, Identified, custom_id},
23 policy::{GroupId, PolicyId},
24 subtle::{Choice, ConstantTimeEq},
25 tls::{self, CipherSuiteId},
26 util,
27 zeroize::{Zeroize as _, ZeroizeOnDrop, Zeroizing},
28};
29
30type Prk<CS> = kdf::Prk<<<CS as CipherSuite>::Kdf as Kdf>::PrkSize>;
31
/// Label passed as the first argument to `CS::labeled_extract` / `CS::labeled_expand` when
/// deriving a seed's PRK and the seed's ID.
const SEED_DOMAIN: &[u8] = b"SeedForAranyaTls-v1";

/// Label passed as the first argument to `CS::labeled_expand` when deriving a PSK secret
/// from a seed's PRK.
const PSK_DOMAIN: &[u8] = b"PskForAranyaTls-v1";
37
38custom_id! {
39 pub struct PskSeedId;
41}
42
43#[derive_where(Clone, Debug)]
45pub struct PskSeed<CS: CipherSuite> {
46 #[derive_where(skip(Debug))]
47 prk: Prk<CS>,
48 id: OnceCell<Result<PskSeedId, Bug>>,
55 _marker: PhantomData<CS>,
56}
57
58impl<CS: CipherSuite> PskSeed<CS> {
59 pub fn new<R>(rng: &mut R, group: &GroupId) -> Self
61 where
62 R: Csprng,
63 {
64 let ikm = Zeroizing::new(Random::random(rng));
65 Self::from_ikm(&ikm, group)
66 }
67
68 pub fn import_from_ikm(ikm: &[u8; 32], group: &GroupId) -> Self {
76 Self::from_ikm(ikm, group)
77 }
78
79 pub(crate) fn from_ikm(ikm: &[u8; 32], group: &GroupId) -> Self {
83 let prk = CS::labeled_extract(SEED_DOMAIN, &[], b"prk", [group.as_bytes(), ikm]);
84 Self::from_prk(prk)
85 }
86
87 fn from_prk(prk: Prk<CS>) -> Self {
89 Self {
90 prk,
91 id: OnceCell::new(),
92 _marker: PhantomData,
93 }
94 }
95
96 fn try_id(&self) -> Result<&PskSeedId, &Bug> {
98 self.id
99 .get_or_init(|| {
100 let id = CS::labeled_expand(SEED_DOMAIN, &self.prk, b"id", [])
119 .assume("should be able to generate PSK seed ID")?;
120 Ok(PskSeedId::from_bytes(id))
121 })
122 .as_ref()
123 }
124
125 pub fn generate_psks<I>(
136 self,
137 context: &'static [u8],
138 group: GroupId,
139 policy: PolicyId,
140 suites: I,
141 ) -> impl Iterator<Item = Result<Psk<CS>, Error>>
142 where
143 I: Iterator<Item = CipherSuiteId>,
144 {
145 suites.into_iter().map(move |suite| {
146 let id = ImportedIdentity {
147 external_identity: *self.try_id().map_err(Bug::clone)?,
148 context: PskCtx { group, policy },
149 target_protocol: tls::Version::Tls13,
150 target_kdf: suite,
151 };
152 let secret =
153 CS::labeled_expand(PSK_DOMAIN, &self.prk, b"psk", [id.as_bytes(), context])?;
154 Ok(Psk {
155 id: PskId(id),
156 secret,
157 _marker: PhantomData,
158 })
159 })
160 }
161}
162
163impl<CS: CipherSuite> ZeroizeOnDrop for PskSeed<CS> {}
164impl<CS: CipherSuite> Drop for PskSeed<CS> {
165 #[inline]
166 fn drop(&mut self) {
167 util::val_is_zeroize_on_drop(&self.prk);
168 }
169}
170
171unwrapped! {
172 name: PskSeed;
173 type: Prk;
174 into: |key: Self| { key.prk.clone() };
175 from: |prk| { Self::from_prk(prk) };
176}
177
178impl<CS: CipherSuite> Identified for PskSeed<CS> {
179 type Id = PskSeedId;
180
181 #[inline]
182 fn id(&self) -> Result<Self::Id, IdError> {
183 let id = self.try_id().map_err(Bug::clone)?;
184 Ok(*id)
185 }
186}
187
188impl<CS: CipherSuite> ConstantTimeEq for PskSeed<CS> {
189 #[inline]
190 fn ct_eq(&self, other: &Self) -> Choice {
191 self.prk.ct_eq(&other.prk)
193 }
194}
195
196#[repr(C)]
209#[derive(Copy, Clone, Debug, Immutable, IntoBytes, KnownLayout, Serialize, Deserialize)]
210struct ImportedIdentity {
211 external_identity: PskSeedId,
215 context: PskCtx,
216 target_protocol: tls::Version,
217 target_kdf: CipherSuiteId,
220}
221
222#[repr(C)]
223#[derive(Copy, Clone, Debug, Immutable, IntoBytes, KnownLayout, Serialize, Deserialize)]
224struct PskCtx {
225 group: GroupId,
226 policy: PolicyId,
227}
228
229impl<CS: CipherSuite> EncryptionKey<CS> {
230 pub fn seal_psk_seed<R: Csprng>(
236 &self,
237 rng: &mut R,
238 seed: &PskSeed<CS>,
239 peer_pk: &EncryptionPublicKey<CS>,
240 group: &GroupId,
241 ) -> Result<(Encap<CS>, EncryptedPskSeed<CS>), Error> {
242 if &self.public()? == peer_pk {
243 return Err(Error::InvalidArgument("same `EncryptionKey`"));
244 }
245 let info = Info {
250 domain: *b"PskSeed-v1",
251 group: *group,
252 };
253 let (enc, mut ctx) =
254 hpke::setup_send::<CS, _>(rng, Mode::Auth(&self.sk), &peer_pk.pk, [info.as_bytes()])?;
255 let mut ciphertext = seed.prk.clone().into_bytes().into_bytes();
256 let mut tag = Tag::<CS::Aead>::default();
257 ctx.seal_in_place(&mut ciphertext, &mut tag, info.as_bytes())
258 .inspect_err(|_| ciphertext.zeroize())?;
259 Ok((Encap(enc), EncryptedPskSeed { ciphertext, tag }))
260 }
261
262 pub fn open_psk_seed(
265 &self,
266 encap: &Encap<CS>,
267 ciphertext: EncryptedPskSeed<CS>,
268 peer_pk: &EncryptionPublicKey<CS>,
269 group: &GroupId,
270 ) -> Result<PskSeed<CS>, Error> {
271 let EncryptedPskSeed {
272 mut ciphertext,
273 tag,
274 } = ciphertext;
275
276 let info = Info {
281 domain: *b"PskSeed-v1",
282 group: *group,
283 };
284 let mut ctx = hpke::setup_recv::<CS>(
285 Mode::Auth(&peer_pk.pk),
286 &encap.0,
287 &self.sk,
288 [info.as_bytes()],
289 )?;
290 ctx.open_in_place(&mut ciphertext, &tag, info.as_bytes())?;
291
292 let prk = Prk::<CS>::new(SecretKeyBytes::new(ciphertext));
293 Ok(PskSeed::from_prk(prk))
294 }
295}
296
297#[repr(C)]
299#[derive(Copy, Clone, Debug, ByteEq, Immutable, IntoBytes, KnownLayout, Unaligned)]
300struct Info {
301 domain: [u8; 10],
303 group: GroupId,
304}
305
306#[derive_where(Clone, Debug, Serialize, Deserialize)]
308pub struct EncryptedPskSeed<CS: CipherSuite> {
309 pub(crate) ciphertext: GenericArray<u8, <<CS as CipherSuite>::Kdf as Kdf>::PrkSize>,
311 pub(crate) tag: Tag<CS::Aead>,
312}
313
314#[derive_where(Clone, Debug)]
321pub struct Psk<CS> {
322 #[derive_where(skip(Debug))]
323 secret: [u8; 32],
324 id: PskId,
325 _marker: PhantomData<CS>,
326}
327
328impl<CS: CipherSuite> Psk<CS> {
329 pub fn identity(&self) -> &PskId {
336 &self.id
337 }
338
339 pub fn raw_secret_bytes(&self) -> &[u8] {
346 &self.secret
347 }
348}
349
350impl<CS> ZeroizeOnDrop for Psk<CS> {}
351impl<CS> Drop for Psk<CS> {
352 #[inline]
353 fn drop(&mut self) {
354 self.secret.zeroize();
355 }
356}
357
358impl<CS> ConstantTimeEq for Psk<CS> {
359 #[inline]
360 fn ct_eq(&self, other: &Self) -> Choice {
361 self.secret.ct_eq(&other.secret)
367 }
368}
369
370#[derive(Copy, Clone, Debug, ByteEq, Immutable, IntoBytes, KnownLayout, Serialize, Deserialize)]
379pub struct PskId(ImportedIdentity);
380
381impl PskId {
382 pub const fn seed_id(&self) -> &PskSeedId {
384 &self.0.external_identity
385 }
386
387 pub const fn group_id(&self) -> &GroupId {
389 &self.0.context.group
390 }
391
392 pub const fn cipher_suite(&self) -> CipherSuiteId {
394 self.0.target_kdf
395 }
396
397 pub const fn as_bytes(&self) -> &[u8] {
399 let bytes: &[u8; 100] = zerocopy::transmute_ref!(self);
400 bytes
401 }
402}
403
404impl ConstantTimeEq for PskId {
405 #[inline]
406 fn ct_eq(&self, other: &Self) -> Choice {
407 self.as_bytes().ct_eq(other.as_bytes())
408 }
409}
410
411impl fmt::Display for PskId {
412 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
413 Hex::new(self.as_bytes()).fmt(f)
414 }
415}