1use gbp_mls::MlsContext;
2use hkdf::Hkdf;
3use sha2::Sha256;
4
5use crate::error::SFrameError;
6
/// AEAD cipher suites available for SFrame media encryption.
///
/// Each variant's explicit discriminant is its one-byte wire identifier, as
/// produced by [`CipherSuite::as_u8`] and accepted by [`CipherSuite::from_u8`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CipherSuite {
    /// AES-GCM with a 128-bit key (wire identifier `0`).
    Aes128Gcm = 0,
    /// AES-GCM with a 256-bit key (wire identifier `1`).
    Aes256Gcm = 1,
}

impl CipherSuite {
    /// Size of this suite's AEAD key, in bytes (16 or 32).
    pub(crate) fn key_len(self) -> usize {
        let key_bits: usize = match self {
            Self::Aes128Gcm => 128,
            Self::Aes256Gcm => 256,
        };
        key_bits / 8
    }

    /// Parses a one-byte suite identifier.
    ///
    /// Returns `None` when `v` does not name a known suite.
    pub fn from_u8(v: u8) -> Option<Self> {
        [Self::Aes128Gcm, Self::Aes256Gcm]
            .iter()
            .copied()
            .find(|suite| suite.as_u8() == v)
    }

    /// Returns the one-byte wire identifier of this suite.
    pub fn as_u8(self) -> u8 {
        // Discriminants are declared explicitly above, so this cast is the
        // wire encoding.
        self as u8
    }
}
45
/// Key material derived for a single participant (one MLS leaf).
///
/// Holds secret bytes, so it deliberately derives no `Debug`/`Clone`.
pub(crate) struct ParticipantKeys {
    /// AEAD key, sized by the cipher suite's `key_len()` at derivation time.
    pub key: Vec<u8>,
    /// 12-byte HKDF-derived base nonce. Presumably passed as the `salt`
    /// argument of `make_nonce` (verify at call sites).
    pub base_nonce: [u8; 12],
}
53
54pub fn derive_base_key(mls: &MlsContext, label: &str, epoch: u64) -> Result<[u8; 32], SFrameError> {
61 let context = epoch.to_be_bytes();
62 let raw = mls
63 .export_raw(label, &context, 32)
64 .map_err(|e| SFrameError::MlsExport(e.to_string()))?;
65 let mut out = [0u8; 32];
66 out.copy_from_slice(&raw);
67 Ok(out)
68}
69
/// HKDF-Expand `info` prefix for a participant's AEAD key. The participant's
/// leaf index (big-endian `u32`) is appended after this prefix.
const HKDF_LABEL_KEY: &[u8] = "gbp sframe key ".as_bytes();
/// HKDF-Expand `info` prefix for a participant's base nonce (salt). The
/// participant's leaf index (big-endian `u32`) is appended after this prefix.
const HKDF_LABEL_NONCE: &[u8] = "gbp sframe salt ".as_bytes();
75
76pub(crate) fn derive_participant(
82 base_key: &[u8; 32],
83 leaf_index: u32,
84 suite: CipherSuite,
85) -> ParticipantKeys {
86 let hk =
88 Hkdf::<Sha256>::from_prk(base_key).expect("base_key is exactly SHA-256 HashLen (32 bytes)");
89
90 let leaf_be = leaf_index.to_be_bytes();
91
92 let mut label = HKDF_LABEL_KEY.to_vec();
93 label.extend_from_slice(&leaf_be);
94 let mut key = vec![0u8; suite.key_len()];
95 hk.expand(&label, &mut key)
96 .expect("key length is well within 255 * HashLen");
97
98 let mut label = HKDF_LABEL_NONCE.to_vec();
99 label.extend_from_slice(&leaf_be);
100 let mut base_nonce: [u8; 12] = Default::default();
101 hk.expand(&label, &mut base_nonce)
102 .expect("nonce length (12) is well within 255 * HashLen");
103
104 ParticipantKeys { key, base_nonce }
105}
106
/// Builds a per-frame AEAD nonce by XOR-ing the frame counter into `salt`.
///
/// `ctr` is encoded as 8 little-endian bytes and XOR-ed into the first 8
/// bytes of `salt`. The last 4 bytes of `salt` are copied unchanged. For a
/// fixed salt, distinct counters therefore always yield distinct nonces.
///
/// NOTE(review): RFC 9605 (SFrame) XORs the counter into the nonce big-endian
/// and right-aligned. This layout differs, and changing it would alter the
/// wire format, so confirm against peers before modifying.
pub(crate) fn make_nonce(salt: &[u8; 12], ctr: u64) -> [u8; 12] {
    let mut nonce = *salt;
    let ctr_le = ctr.to_le_bytes();
    // `zip` stops after the 8 counter bytes, so only the leading 8 bytes change.
    for (byte, ctr_byte) in nonce.iter_mut().zip(ctr_le.iter()) {
        *byte ^= *ctr_byte;
    }
    nonce
}