Skip to main content

foctet_core/
crypto.rs

1use chacha20poly1305::{
2    KeyInit, XChaCha20Poly1305, XNonce,
3    aead::{Aead, Payload},
4};
5use hkdf::Hkdf;
6use rand_core::{OsRng, RngCore};
7use sha2::Sha256;
8use x25519_dalek::{PublicKey, StaticSecret};
9use zeroize::{Zeroize, Zeroizing};
10
11use crate::{
12    CoreError,
13    frame::{Frame, FrameHeader, PROFILE_X25519_HKDF_XCHACHA20POLY1305},
14};
15
/// Which way a protected frame travels; selects the matching traffic key.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    /// Traffic sent from the client to the server.
    C2S,
    /// Traffic sent from the server to the client.
    S2C,
}
24
25/// Bidirectional traffic keys bound to a single `key_id`.
26#[derive(Clone, Debug, Eq, PartialEq)]
27pub struct TrafficKeys {
28    /// Active key identifier carried in frame headers.
29    pub key_id: u8,
30    /// Client-to-server key bytes.
31    pub c2s: [u8; 32],
32    /// Server-to-client key bytes.
33    pub s2c: [u8; 32],
34}
35
36impl TrafficKeys {
37    /// Returns key bytes for the specified direction.
38    pub fn key_for(&self, direction: Direction) -> [u8; 32] {
39        match direction {
40            Direction::C2S => self.c2s,
41            Direction::S2C => self.s2c,
42        }
43    }
44}
45
46impl Drop for TrafficKeys {
47    fn drop(&mut self) {
48        self.c2s.zeroize();
49        self.s2c.zeroize();
50    }
51}
52
/// Builds a Draft v0 XChaCha nonce from frame metadata.
///
/// Layout of the 24-byte nonce: `key_id` (1 byte) ‖ `stream_id` (4 bytes,
/// big-endian) ‖ `seq` (8 bytes, big-endian) ‖ 11 zero bytes.
pub fn make_nonce(key_id: u8, stream_id: u32, seq: u64) -> [u8; 24] {
    let prefix = std::iter::once(key_id)
        .chain(stream_id.to_be_bytes())
        .chain(seq.to_be_bytes());
    let mut out = [0u8; 24];
    // The 13 prefix bytes fill the front; the zipped iterator stops there, so
    // the remaining 11 bytes keep their zero initialisation.
    for (dst, src) in out.iter_mut().zip(prefix) {
        *dst = src;
    }
    out
}
61
62/// Derives initial traffic keys from a shared secret and session salt.
63pub fn derive_traffic_keys(
64    shared_secret: &[u8],
65    session_salt: &[u8; 32],
66    key_id: u8,
67) -> Result<TrafficKeys, CoreError> {
68    let hk = Hkdf::<Sha256>::new(Some(session_salt), shared_secret);
69    let mut c2s = [0u8; 32];
70    let mut s2c = [0u8; 32];
71    hk.expand(b"foctet c2s", &mut c2s)
72        .map_err(|_| CoreError::Hkdf)?;
73    hk.expand(b"foctet s2c", &mut s2c)
74        .map_err(|_| CoreError::Hkdf)?;
75    Ok(TrafficKeys { key_id, c2s, s2c })
76}
77
78/// Derives rekeyed traffic keys from shared/session/rekey salt inputs.
79pub fn derive_rekey_traffic_keys(
80    shared_secret: &[u8; 32],
81    session_salt: &[u8; 32],
82    rekey_salt: &[u8; 32],
83    key_id: u8,
84) -> Result<TrafficKeys, CoreError> {
85    let mut salt = Zeroizing::new([0u8; 64]);
86    salt[..32].copy_from_slice(session_salt);
87    salt[32..].copy_from_slice(rekey_salt);
88    let hk = Hkdf::<Sha256>::new(Some(&salt[..]), shared_secret);
89
90    let mut c2s = [0u8; 32];
91    let mut s2c = [0u8; 32];
92
93    let mut info_c2s = [0u8; 17];
94    info_c2s[..16].copy_from_slice(b"foctet rekey c2s");
95    info_c2s[16] = key_id;
96    let mut info_s2c = [0u8; 17];
97    info_s2c[..16].copy_from_slice(b"foctet rekey s2c");
98    info_s2c[16] = key_id;
99
100    hk.expand(&info_c2s, &mut c2s)
101        .map_err(|_| CoreError::Hkdf)?;
102    hk.expand(&info_s2c, &mut s2c)
103        .map_err(|_| CoreError::Hkdf)?;
104
105    Ok(TrafficKeys { key_id, c2s, s2c })
106}
107
108/// Generates a random session salt for key derivation.
109pub fn random_session_salt() -> [u8; 32] {
110    let mut out = [0u8; 32];
111    OsRng.fill_bytes(&mut out);
112    out
113}
114
115/// Ephemeral X25519 key pair used during native handshake.
116#[derive(Clone, Debug)]
117pub struct EphemeralKeyPair {
118    private: Zeroizing<[u8; 32]>,
119    /// Public key bytes.
120    pub public: [u8; 32],
121}
122
123impl EphemeralKeyPair {
124    /// Generates a fresh ephemeral X25519 key pair.
125    pub fn generate() -> Self {
126        let private = StaticSecret::random_from_rng(OsRng);
127        let public = PublicKey::from(&private);
128        Self {
129            private: Zeroizing::new(private.to_bytes()),
130            public: public.to_bytes(),
131        }
132    }
133
134    /// Computes shared secret with peer ephemeral public key.
135    pub fn shared_secret(&self, peer_public: [u8; 32]) -> Result<[u8; 32], CoreError> {
136        let private = StaticSecret::from(*self.private);
137        let peer = PublicKey::from(peer_public);
138        let shared = private.diffie_hellman(&peer).to_bytes();
139        if shared.iter().all(|byte| *byte == 0) {
140            return Err(CoreError::InvalidSharedSecret);
141        }
142        Ok(shared)
143    }
144}
145
146/// Encrypts plaintext into a Foctet frame using AEAD profile `0x01`.
147pub fn encrypt_frame(
148    keys: &TrafficKeys,
149    direction: Direction,
150    flags: u8,
151    stream_id: u32,
152    seq: u64,
153    plaintext: &[u8],
154) -> Result<Frame, CoreError> {
155    let key = Zeroizing::new(keys.key_for(direction));
156    let cipher =
157        XChaCha20Poly1305::new_from_slice(&key[..]).map_err(|_| CoreError::InvalidKeyLength)?;
158
159    let mut header = FrameHeader::new(
160        flags,
161        PROFILE_X25519_HKDF_XCHACHA20POLY1305,
162        keys.key_id,
163        stream_id,
164        seq,
165        0,
166    );
167
168    let nonce_raw = make_nonce(keys.key_id, stream_id, seq);
169    let nonce = XNonce::from_slice(&nonce_raw);
170
171    let mut aad_header = header.clone();
172    aad_header.ct_len = (plaintext.len() + 16) as u32;
173    let aad = aad_header.encode();
174
175    let ciphertext = cipher
176        .encrypt(
177            nonce,
178            Payload {
179                msg: plaintext,
180                aad: &aad,
181            },
182        )
183        .map_err(|_| CoreError::Aead)?;
184
185    header.ct_len = ciphertext.len() as u32;
186    Ok(Frame { header, ciphertext })
187}
188
189/// Decrypts a frame and enforces `key_id` equality with `keys`.
190pub fn decrypt_frame(
191    keys: &TrafficKeys,
192    direction: Direction,
193    frame: &Frame,
194) -> Result<Vec<u8>, CoreError> {
195    frame.header.validate_v0()?;
196    if frame.header.key_id != keys.key_id {
197        return Err(CoreError::UnexpectedKeyId {
198            expected: keys.key_id,
199            actual: frame.header.key_id,
200        });
201    }
202    decrypt_frame_with_key(keys, direction, frame)
203}
204
205/// Decrypts a frame with a specific key record, without key-id equality check.
206pub fn decrypt_frame_with_key(
207    keys: &TrafficKeys,
208    direction: Direction,
209    frame: &Frame,
210) -> Result<Vec<u8>, CoreError> {
211    frame.header.validate_v0()?;
212    if frame.ciphertext.len() != frame.header.ct_len as usize {
213        return Err(CoreError::CiphertextLengthMismatch {
214            expected: frame.header.ct_len as usize,
215            actual: frame.ciphertext.len(),
216        });
217    }
218
219    let key = Zeroizing::new(keys.key_for(direction));
220    let cipher =
221        XChaCha20Poly1305::new_from_slice(&key[..]).map_err(|_| CoreError::InvalidKeyLength)?;
222    let nonce_raw = make_nonce(
223        frame.header.key_id,
224        frame.header.stream_id,
225        frame.header.seq,
226    );
227    let nonce = XNonce::from_slice(&nonce_raw);
228    let aad = frame.header.encode();
229    cipher
230        .decrypt(
231            nonce,
232            Payload {
233                msg: &frame.ciphertext,
234                aad: &aad,
235            },
236        )
237        .map_err(|_| CoreError::Aead)
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Performs an X25519 exchange between two fresh key pairs and returns the
    /// agreed secret, asserting that both sides computed the same bytes.
    fn agreed_secret() -> [u8; 32] {
        let eph_a = EphemeralKeyPair::generate();
        let eph_b = EphemeralKeyPair::generate();
        let ss_a = eph_a.shared_secret(eph_b.public).expect("shared secret a");
        let ss_b = eph_b.shared_secret(eph_a.public).expect("shared secret b");
        assert_eq!(ss_a, ss_b);
        ss_a
    }

    /// Derives initial traffic keys for `key_id` from a fresh exchange.
    fn fresh_keys(key_id: u8) -> TrafficKeys {
        let salt = random_session_salt();
        derive_traffic_keys(&agreed_secret(), &salt, key_id).expect("derive traffic keys")
    }

    /// Encrypts `plaintext` client-to-server on stream 10, sequence 42.
    fn sealed(keys: &TrafficKeys, plaintext: &[u8]) -> Frame {
        encrypt_frame(keys, Direction::C2S, 0b10, 10, 42, plaintext).expect("encrypt frame")
    }

    #[test]
    fn frame_roundtrip_encrypt_decrypt() {
        let eph_a = EphemeralKeyPair::generate();
        let eph_b = EphemeralKeyPair::generate();
        let ss_a = eph_a.shared_secret(eph_b.public).expect("shared secret a");
        let ss_b = eph_b.shared_secret(eph_a.public).expect("shared secret b");
        assert_eq!(ss_a, ss_b);

        let salt = random_session_salt();
        let keys = derive_traffic_keys(&ss_a, &salt, 7).expect("derive traffic keys");

        let plaintext = b"foctet core frame roundtrip";
        let frame =
            encrypt_frame(&keys, Direction::C2S, 0b10, 10, 42, plaintext).expect("encrypt frame");
        let bytes = frame.to_bytes();

        let parsed = Frame::from_bytes(&bytes).expect("parse frame");
        let out = decrypt_frame(&keys, Direction::C2S, &parsed).expect("decrypt frame");
        assert_eq!(out, plaintext);
    }

    #[test]
    fn tampered_ciphertext_is_rejected() {
        let keys = fresh_keys(7);
        let mut frame = sealed(&keys, b"payload");
        frame.ciphertext[0] ^= 0x01;
        let result = decrypt_frame(&keys, Direction::C2S, &frame);
        assert!(matches!(result, Err(CoreError::Aead)));
    }

    #[test]
    fn tampered_header_is_rejected() {
        let keys = fresh_keys(7);
        let mut frame = sealed(&keys, b"payload");
        // `seq` feeds both the nonce and the authenticated header bytes.
        frame.header.seq += 1;
        let result = decrypt_frame(&keys, Direction::C2S, &frame);
        assert!(matches!(result, Err(CoreError::Aead)));
    }

    #[test]
    fn wrong_direction_is_rejected() {
        let keys = fresh_keys(7);
        let frame = sealed(&keys, b"payload");
        let result = decrypt_frame(&keys, Direction::S2C, &frame);
        assert!(matches!(result, Err(CoreError::Aead)));
    }

    #[test]
    fn ciphertext_length_mismatch_is_rejected() {
        let keys = fresh_keys(7);
        let mut frame = sealed(&keys, b"payload");
        let declared = frame.ciphertext.len();
        frame.ciphertext.pop();
        let result = decrypt_frame(&keys, Direction::C2S, &frame);
        assert!(matches!(
            result,
            Err(CoreError::CiphertextLengthMismatch { expected, actual })
                if expected == declared && actual == declared - 1
        ));
    }

    #[test]
    fn key_id_check_is_enforced_only_by_decrypt_frame() {
        let keys = fresh_keys(7);
        let frame = sealed(&keys, b"payload");
        let mut relabelled = keys.clone();
        relabelled.key_id = 8;

        let result = decrypt_frame(&relabelled, Direction::C2S, &frame);
        assert!(matches!(
            result,
            Err(CoreError::UnexpectedKeyId {
                expected: 8,
                actual: 7
            })
        ));

        // Same key bytes, so the unchecked variant still authenticates: its
        // nonce is rebuilt from the frame header's key_id.
        let out = decrypt_frame_with_key(&relabelled, Direction::C2S, &frame)
            .expect("decrypt without key-id check");
        assert_eq!(out, b"payload");
    }

    #[test]
    fn rekey_keys_are_deterministic_and_separated() {
        let secret = agreed_secret();
        let session_salt = random_session_salt();
        let rekey_salt = random_session_salt();

        let initial = derive_traffic_keys(&secret, &session_salt, 7).expect("initial keys");
        let first = derive_rekey_traffic_keys(&secret, &session_salt, &rekey_salt, 8)
            .expect("rekey");
        let again = derive_rekey_traffic_keys(&secret, &session_salt, &rekey_salt, 8)
            .expect("rekey again");
        let other_id = derive_rekey_traffic_keys(&secret, &session_salt, &rekey_salt, 9)
            .expect("rekey other id");

        assert_eq!(first.key_id, 8);
        assert_eq!(first.c2s, again.c2s);
        assert_eq!(first.s2c, again.s2c);
        assert_ne!(first.c2s, first.s2c);
        assert_ne!(first.c2s, initial.c2s);
        assert_ne!(first.s2c, initial.s2c);
        assert_ne!(first.c2s, other_id.c2s);
        assert_ne!(first.s2c, other_id.s2c);
    }

    #[test]
    fn nonce_layout_matches_spec() {
        let nonce = make_nonce(0xAB, 0x0102_0304, 0x0102_0304_0506_0708);
        assert_eq!(nonce[0], 0xAB);
        assert_eq!(&nonce[1..5], &[0x01, 0x02, 0x03, 0x04]);
        assert_eq!(
            &nonce[5..13],
            &[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]
        );
        assert_eq!(&nonce[13..], &[0u8; 11]);
    }
}