1use chacha20poly1305::{
2 KeyInit, XChaCha20Poly1305, XNonce,
3 aead::{Aead, Payload},
4};
5use hkdf::Hkdf;
6use rand_core::{OsRng, RngCore};
7use sha2::Sha256;
8use x25519_dalek::{PublicKey, StaticSecret};
9use zeroize::{Zeroize, Zeroizing};
10
11use crate::{
12 CoreError,
13 frame::{Frame, FrameHeader, PROFILE_X25519_HKDF_XCHACHA20POLY1305},
14};
15
/// Direction of a frame, used to pick which of the two traffic keys applies.
///
/// Judging by the names, `C2S` is client-to-server and `S2C` is
/// server-to-client traffic (NOTE(review): confirm against the protocol spec).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Selects the `c2s` traffic key.
    C2S,
    /// Selects the `s2c` traffic key.
    S2C,
}
24
/// Pair of 32-byte symmetric traffic keys, one per [`Direction`], tagged with
/// the key identifier they belong to.
///
/// `Debug` is implemented by hand rather than derived so that formatting this
/// value (logs, panic messages, failed `assert_eq!`) never prints key bytes.
#[derive(Clone, Eq, PartialEq)]
pub struct TrafficKeys {
    /// Identifier of this key generation.
    pub key_id: u8,
    /// Key returned for [`Direction::C2S`].
    pub c2s: [u8; 32],
    /// Key returned for [`Direction::S2C`].
    pub s2c: [u8; 32],
}

impl std::fmt::Debug for TrafficKeys {
    /// Prints `key_id` but replaces both keys with a fixed placeholder.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TrafficKeys")
            .field("key_id", &self.key_id)
            .field("c2s", &format_args!("<redacted>"))
            .field("s2c", &format_args!("<redacted>"))
            .finish()
    }
}
35
36impl TrafficKeys {
37 pub fn key_for(&self, direction: Direction) -> [u8; 32] {
39 match direction {
40 Direction::C2S => self.c2s,
41 Direction::S2C => self.s2c,
42 }
43 }
44}
45
46impl Drop for TrafficKeys {
47 fn drop(&mut self) {
48 self.c2s.zeroize();
49 self.s2c.zeroize();
50 }
51}
52
/// Builds the 24-byte nonce for a frame.
///
/// Layout: byte 0 is `key_id`, bytes 1..5 are `stream_id` in big-endian,
/// bytes 5..13 are `seq` in big-endian, and the remaining 11 bytes are zero.
pub fn make_nonce(key_id: u8, stream_id: u32, seq: u64) -> [u8; 24] {
    // The three fields laid end to end occupy exactly the first 13 bytes.
    let fields = std::iter::once(key_id)
        .chain(stream_id.to_be_bytes())
        .chain(seq.to_be_bytes());

    let mut nonce = [0u8; 24];
    // `zip` stops once the 13 field bytes run out, leaving the tail zeroed.
    for (slot, byte) in nonce.iter_mut().zip(fields) {
        *slot = byte;
    }
    nonce
}
61
62pub fn derive_traffic_keys(
64 shared_secret: &[u8],
65 session_salt: &[u8; 32],
66 key_id: u8,
67) -> Result<TrafficKeys, CoreError> {
68 let hk = Hkdf::<Sha256>::new(Some(session_salt), shared_secret);
69 let mut c2s = [0u8; 32];
70 let mut s2c = [0u8; 32];
71 hk.expand(b"foctet c2s", &mut c2s)
72 .map_err(|_| CoreError::Hkdf)?;
73 hk.expand(b"foctet s2c", &mut s2c)
74 .map_err(|_| CoreError::Hkdf)?;
75 Ok(TrafficKeys { key_id, c2s, s2c })
76}
77
78pub fn derive_rekey_traffic_keys(
80 shared_secret: &[u8; 32],
81 session_salt: &[u8; 32],
82 rekey_salt: &[u8; 32],
83 key_id: u8,
84) -> Result<TrafficKeys, CoreError> {
85 let mut salt = Zeroizing::new([0u8; 64]);
86 salt[..32].copy_from_slice(session_salt);
87 salt[32..].copy_from_slice(rekey_salt);
88 let hk = Hkdf::<Sha256>::new(Some(&salt[..]), shared_secret);
89
90 let mut c2s = [0u8; 32];
91 let mut s2c = [0u8; 32];
92
93 let mut info_c2s = [0u8; 17];
94 info_c2s[..16].copy_from_slice(b"foctet rekey c2s");
95 info_c2s[16] = key_id;
96 let mut info_s2c = [0u8; 17];
97 info_s2c[..16].copy_from_slice(b"foctet rekey s2c");
98 info_s2c[16] = key_id;
99
100 hk.expand(&info_c2s, &mut c2s)
101 .map_err(|_| CoreError::Hkdf)?;
102 hk.expand(&info_s2c, &mut s2c)
103 .map_err(|_| CoreError::Hkdf)?;
104
105 Ok(TrafficKeys { key_id, c2s, s2c })
106}
107
108pub fn random_session_salt() -> [u8; 32] {
110 let mut out = [0u8; 32];
111 OsRng.fill_bytes(&mut out);
112 out
113}
114
115#[derive(Clone, Debug)]
117pub struct EphemeralKeyPair {
118 private: Zeroizing<[u8; 32]>,
119 pub public: [u8; 32],
121}
122
123impl EphemeralKeyPair {
124 pub fn generate() -> Self {
126 let private = StaticSecret::random_from_rng(OsRng);
127 let public = PublicKey::from(&private);
128 Self {
129 private: Zeroizing::new(private.to_bytes()),
130 public: public.to_bytes(),
131 }
132 }
133
134 pub fn shared_secret(&self, peer_public: [u8; 32]) -> Result<[u8; 32], CoreError> {
136 let private = StaticSecret::from(*self.private);
137 let peer = PublicKey::from(peer_public);
138 let shared = private.diffie_hellman(&peer).to_bytes();
139 if shared.iter().all(|byte| *byte == 0) {
140 return Err(CoreError::InvalidSharedSecret);
141 }
142 Ok(shared)
143 }
144}
145
146pub fn encrypt_frame(
148 keys: &TrafficKeys,
149 direction: Direction,
150 flags: u8,
151 stream_id: u32,
152 seq: u64,
153 plaintext: &[u8],
154) -> Result<Frame, CoreError> {
155 let key = Zeroizing::new(keys.key_for(direction));
156 let cipher =
157 XChaCha20Poly1305::new_from_slice(&key[..]).map_err(|_| CoreError::InvalidKeyLength)?;
158
159 let mut header = FrameHeader::new(
160 flags,
161 PROFILE_X25519_HKDF_XCHACHA20POLY1305,
162 keys.key_id,
163 stream_id,
164 seq,
165 0,
166 );
167
168 let nonce_raw = make_nonce(keys.key_id, stream_id, seq);
169 let nonce = XNonce::from_slice(&nonce_raw);
170
171 let mut aad_header = header.clone();
172 aad_header.ct_len = (plaintext.len() + 16) as u32;
173 let aad = aad_header.encode();
174
175 let ciphertext = cipher
176 .encrypt(
177 nonce,
178 Payload {
179 msg: plaintext,
180 aad: &aad,
181 },
182 )
183 .map_err(|_| CoreError::Aead)?;
184
185 header.ct_len = ciphertext.len() as u32;
186 Ok(Frame { header, ciphertext })
187}
188
189pub fn decrypt_frame(
191 keys: &TrafficKeys,
192 direction: Direction,
193 frame: &Frame,
194) -> Result<Vec<u8>, CoreError> {
195 frame.header.validate_v0()?;
196 if frame.header.key_id != keys.key_id {
197 return Err(CoreError::UnexpectedKeyId {
198 expected: keys.key_id,
199 actual: frame.header.key_id,
200 });
201 }
202 decrypt_frame_with_key(keys, direction, frame)
203}
204
205pub fn decrypt_frame_with_key(
207 keys: &TrafficKeys,
208 direction: Direction,
209 frame: &Frame,
210) -> Result<Vec<u8>, CoreError> {
211 frame.header.validate_v0()?;
212 if frame.ciphertext.len() != frame.header.ct_len as usize {
213 return Err(CoreError::CiphertextLengthMismatch {
214 expected: frame.header.ct_len as usize,
215 actual: frame.ciphertext.len(),
216 });
217 }
218
219 let key = Zeroizing::new(keys.key_for(direction));
220 let cipher =
221 XChaCha20Poly1305::new_from_slice(&key[..]).map_err(|_| CoreError::InvalidKeyLength)?;
222 let nonce_raw = make_nonce(
223 frame.header.key_id,
224 frame.header.stream_id,
225 frame.header.seq,
226 );
227 let nonce = XNonce::from_slice(&nonce_raw);
228 let aad = frame.header.encode();
229 cipher
230 .decrypt(
231 nonce,
232 Payload {
233 msg: &frame.ciphertext,
234 aad: &aad,
235 },
236 )
237 .map_err(|_| CoreError::Aead)
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Key agreement, key derivation, encryption, wire encoding, parsing and
    /// decryption must round-trip the original plaintext.
    #[test]
    fn frame_roundtrip_encrypt_decrypt() {
        let alice = EphemeralKeyPair::generate();
        let bob = EphemeralKeyPair::generate();
        let secret_alice = alice.shared_secret(bob.public).expect("shared secret a");
        let secret_bob = bob.shared_secret(alice.public).expect("shared secret b");
        assert_eq!(secret_alice, secret_bob);

        let session_salt = random_session_salt();
        let keys =
            derive_traffic_keys(&secret_alice, &session_salt, 7).expect("derive traffic keys");

        let message = b"foctet core frame roundtrip";
        let wire = encrypt_frame(&keys, Direction::C2S, 0b10, 10, 42, message)
            .expect("encrypt frame")
            .to_bytes();

        let received = Frame::from_bytes(&wire).expect("parse frame");
        let recovered = decrypt_frame(&keys, Direction::C2S, &received).expect("decrypt frame");
        assert_eq!(recovered, message);
    }

    /// The nonce is `key_id || stream_id (BE) || seq (BE)` followed by zeros.
    #[test]
    fn nonce_layout_matches_spec() {
        let mut expected = [0u8; 24];
        expected[..13].copy_from_slice(&[
            0xAB, // key_id
            0x01, 0x02, 0x03, 0x04, // stream_id
            0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, // seq
        ]);

        let nonce = make_nonce(0xAB, 0x0102_0304, 0x0102_0304_0506_0708);
        assert_eq!(nonce, expected);
    }
}