Skip to main content

rns_core/link/
handshake.rs

1use alloc::vec::Vec;
2
3use rns_crypto::ed25519::{Ed25519PrivateKey, Ed25519PublicKey};
4use rns_crypto::hkdf::hkdf;
5use rns_crypto::x25519::{X25519PrivateKey, X25519PublicKey};
6
7use super::types::{LinkError, LinkId, LinkMode};
8use crate::constants::{LINK_ECPUBSIZE, LINK_MODE_BYTEMASK, LINK_MTU_BYTEMASK, LINK_MTU_SIZE};
9use crate::hash::truncated_hash;
10
11/// Compute link_id from a LINKREQUEST packet's hashable part.
12///
13/// The signalling bytes (if present) are stripped from the end before hashing.
14/// `extra_bytes_len` is `data.len() - ECPUBSIZE` (0 or MTU_SIZE).
15pub fn compute_link_id(hashable_part: &[u8], extra_bytes_len: usize) -> LinkId {
16    let end = if extra_bytes_len > 0 && hashable_part.len() > extra_bytes_len {
17        hashable_part.len() - extra_bytes_len
18    } else {
19        hashable_part.len()
20    };
21    truncated_hash(&hashable_part[..end])
22}
23
24/// Build signalling bytes: 3 big-endian bytes encoding MTU (21 bits) + mode (3 bits).
25///
26/// Format: `(mtu & 0x1FFFFF) + (((mode << 5) & 0xE0) << 16)`, packed as big-endian u32,
27/// then take the last 3 bytes.
28pub fn build_signalling_bytes(mtu: u32, mode: LinkMode) -> [u8; 3] {
29    let mode_bits = ((mode.mode_byte() << 5) & LINK_MODE_BYTEMASK) as u32;
30    let signalling_value = (mtu & LINK_MTU_BYTEMASK) + (mode_bits << 16);
31    let bytes = signalling_value.to_be_bytes();
32    [bytes[1], bytes[2], bytes[3]]
33}
34
35/// Parse signalling bytes → (mtu, mode).
36pub fn parse_signalling_bytes(bytes: &[u8; 3]) -> Result<(u32, LinkMode), LinkError> {
37    let mtu = ((bytes[0] as u32) << 16) | ((bytes[1] as u32) << 8) | (bytes[2] as u32);
38    let mode_byte = (bytes[0] & LINK_MODE_BYTEMASK) >> 5;
39    let mtu_val = mtu & LINK_MTU_BYTEMASK;
40    let mode = LinkMode::from_byte(mode_byte)?;
41    Ok((mtu_val, mode))
42}
43
44/// Build LINKREQUEST data: `[x25519_pub:32][ed25519_pub:32][signalling:0-3]`.
45pub fn build_linkrequest_data(
46    pub_bytes: &[u8; 32],
47    sig_pub_bytes: &[u8; 32],
48    mtu: Option<u32>,
49    mode: LinkMode,
50) -> Vec<u8> {
51    let mut data = Vec::with_capacity(LINK_ECPUBSIZE + LINK_MTU_SIZE);
52    data.extend_from_slice(pub_bytes);
53    data.extend_from_slice(sig_pub_bytes);
54    if let Some(mtu_val) = mtu {
55        let sig_bytes = build_signalling_bytes(mtu_val, mode);
56        data.extend_from_slice(&sig_bytes);
57    }
58    data
59}
60
61/// Parse LINKREQUEST data. Returns `(x25519_pub, ed25519_pub, mtu, mode)`.
62#[allow(clippy::type_complexity)]
63pub fn parse_linkrequest_data(
64    data: &[u8],
65) -> Result<([u8; 32], [u8; 32], Option<u32>, LinkMode), LinkError> {
66    if data.len() != LINK_ECPUBSIZE && data.len() != LINK_ECPUBSIZE + LINK_MTU_SIZE {
67        return Err(LinkError::InvalidData);
68    }
69
70    let mut x25519_pub = [0u8; 32];
71    let mut ed25519_pub = [0u8; 32];
72    x25519_pub.copy_from_slice(&data[..32]);
73    ed25519_pub.copy_from_slice(&data[32..64]);
74
75    if data.len() == LINK_ECPUBSIZE + LINK_MTU_SIZE {
76        let mut sig_bytes = [0u8; 3];
77        sig_bytes.copy_from_slice(&data[LINK_ECPUBSIZE..LINK_ECPUBSIZE + LINK_MTU_SIZE]);
78        let (mtu, mode) = parse_signalling_bytes(&sig_bytes)?;
79        Ok((x25519_pub, ed25519_pub, Some(mtu), mode))
80    } else {
81        Ok((x25519_pub, ed25519_pub, None, LinkMode::Aes256Cbc))
82    }
83}
84
85/// Build LRPROOF data: `[signature:64][x25519_pub:32][signalling:0-3]`.
86///
87/// Signs: `link_id + pub_bytes + sig_pub_bytes + signalling_bytes`.
88pub fn build_lrproof(
89    link_id: &LinkId,
90    pub_bytes: &[u8; 32],
91    sig_pub_bytes: &[u8; 32],
92    sig_prv: &Ed25519PrivateKey,
93    mtu: Option<u32>,
94    mode: LinkMode,
95) -> Vec<u8> {
96    let signalling_bytes = if let Some(mtu_val) = mtu {
97        build_signalling_bytes(mtu_val, mode).to_vec()
98    } else {
99        Vec::new()
100    };
101
102    let mut signed_data = Vec::with_capacity(16 + 32 + 32 + signalling_bytes.len());
103    signed_data.extend_from_slice(link_id);
104    signed_data.extend_from_slice(pub_bytes);
105    signed_data.extend_from_slice(sig_pub_bytes);
106    signed_data.extend_from_slice(&signalling_bytes);
107
108    let signature = sig_prv.sign(&signed_data);
109
110    let mut proof_data = Vec::with_capacity(64 + 32 + signalling_bytes.len());
111    proof_data.extend_from_slice(&signature);
112    proof_data.extend_from_slice(pub_bytes);
113    proof_data.extend_from_slice(&signalling_bytes);
114
115    proof_data
116}
117
118/// Validate LRPROOF. Returns peer X25519 public key bytes on success.
119///
120/// Expects `proof_data = [signature:64][x25519_pub:32][signalling:0-3]`.
121/// Validates against `signed_data = link_id + peer_x25519_pub + peer_sig_pub + signalling`.
122pub fn validate_lrproof(
123    proof_data: &[u8],
124    link_id: &LinkId,
125    peer_sig_pub: &Ed25519PublicKey,
126    peer_sig_pub_bytes: &[u8; 32],
127) -> Result<([u8; 32], Option<u32>, LinkMode), LinkError> {
128    let sig_len = 64;
129    let pub_len = 32;
130
131    if proof_data.len() != sig_len + pub_len
132        && proof_data.len() != sig_len + pub_len + LINK_MTU_SIZE
133    {
134        return Err(LinkError::InvalidData);
135    }
136
137    let mut signature = [0u8; 64];
138    signature.copy_from_slice(&proof_data[..sig_len]);
139
140    let mut peer_pub = [0u8; 32];
141    peer_pub.copy_from_slice(&proof_data[sig_len..sig_len + pub_len]);
142
143    let signalling_bytes = &proof_data[sig_len + pub_len..];
144
145    let (mtu, mode) = if signalling_bytes.len() == LINK_MTU_SIZE {
146        let mut sb = [0u8; 3];
147        sb.copy_from_slice(signalling_bytes);
148        let (m, md) = parse_signalling_bytes(&sb)?;
149        (Some(m), md)
150    } else {
151        (None, LinkMode::Aes256Cbc)
152    };
153
154    let mut signed_data = Vec::with_capacity(16 + 32 + 32 + signalling_bytes.len());
155    signed_data.extend_from_slice(link_id);
156    signed_data.extend_from_slice(&peer_pub);
157    signed_data.extend_from_slice(peer_sig_pub_bytes);
158    signed_data.extend_from_slice(signalling_bytes);
159
160    if peer_sig_pub.verify(&signature, &signed_data) {
161        Ok((peer_pub, mtu, mode))
162    } else {
163        Err(LinkError::InvalidSignature)
164    }
165}
166
167/// Derive session key using HKDF.
168///
169/// `shared_key` is the raw ECDH output (32 bytes).
170/// Salt = link_id, context = None.
171/// Output length depends on mode: 32 for AES-128, 64 for AES-256.
172pub fn derive_session_key(
173    shared_key: &[u8; 32],
174    link_id: &LinkId,
175    mode: LinkMode,
176) -> Result<Vec<u8>, LinkError> {
177    let length = mode.derived_key_length();
178    hkdf(length, shared_key, Some(link_id), None).map_err(|_| LinkError::CryptoError)
179}
180
181/// Perform ECDH key exchange and derive session key.
182pub fn perform_key_exchange(
183    prv: &X25519PrivateKey,
184    peer_pub_bytes: &[u8; 32],
185    link_id: &LinkId,
186    mode: LinkMode,
187) -> Result<Vec<u8>, LinkError> {
188    let peer_pub = X25519PublicKey::from_bytes(peer_pub_bytes);
189    let shared_key = prv.exchange(&peer_pub);
190    derive_session_key(&shared_key, link_id, mode)
191}
192
/// Encode a round-trip time as a MessagePack `float 64` value.
///
/// The result is always 9 bytes: the `0xcb` type marker followed by the
/// IEEE-754 bits of `rtt` in big-endian order.
pub fn pack_rtt(rtt: f64) -> Vec<u8> {
    // MessagePack type marker for a 64-bit float.
    const MSGPACK_FLOAT64: u8 = 0xcb;
    let mut packed = [0u8; 9];
    packed[0] = MSGPACK_FLOAT64;
    packed[1..].copy_from_slice(&rtt.to_be_bytes());
    packed.to_vec()
}
200
/// Decode a round-trip time produced by `pack_rtt`.
///
/// Only a MessagePack `float 64` encoding is accepted: exactly 9 bytes starting
/// with the `0xcb` marker. Anything else returns `None`, including a wrong
/// length or a different type marker.
pub fn unpack_rtt(data: &[u8]) -> Option<f64> {
    match data {
        [0xcb, b0, b1, b2, b3, b4, b5, b6, b7] => {
            Some(f64::from_be_bytes([*b0, *b1, *b2, *b3, *b4, *b5, *b6, *b7]))
        }
        _ => None,
    }
}
211
#[cfg(test)]
mod tests {
    use super::*;
    use rns_crypto::FixedRng;

    /// Deterministic Ed25519 key material shared by the LRPROOF tests.
    fn fixed_signing_key() -> (Ed25519PrivateKey, Ed25519PublicKey, [u8; 32]) {
        let mut rng = FixedRng::new(&[0x11; 64]);
        let sig_prv = Ed25519PrivateKey::generate(&mut rng);
        let sig_pub = sig_prv.public_key();
        let sig_pub_bytes = sig_pub.public_bytes();
        (sig_prv, sig_pub, sig_pub_bytes)
    }

    #[test]
    fn test_signalling_bytes_roundtrip() {
        let mtu = 500u32;
        let mode = LinkMode::Aes256Cbc;
        let bytes = build_signalling_bytes(mtu, mode);
        let (parsed_mtu, parsed_mode) = parse_signalling_bytes(&bytes).unwrap();
        assert_eq!(parsed_mtu, mtu);
        assert_eq!(parsed_mode, mode);
    }

    #[test]
    fn test_signalling_bytes_aes128() {
        let mtu = 1234u32;
        let mode = LinkMode::Aes128Cbc;
        let bytes = build_signalling_bytes(mtu, mode);
        let (parsed_mtu, parsed_mode) = parse_signalling_bytes(&bytes).unwrap();
        assert_eq!(parsed_mtu, mtu);
        assert_eq!(parsed_mode, mode);
    }

    #[test]
    fn test_signalling_bytes_max_mtu() {
        let mtu = LINK_MTU_BYTEMASK; // maximum 21-bit value
        let mode = LinkMode::Aes256Cbc;
        let bytes = build_signalling_bytes(mtu, mode);
        let (parsed_mtu, parsed_mode) = parse_signalling_bytes(&bytes).unwrap();
        assert_eq!(parsed_mtu, mtu);
        assert_eq!(parsed_mode, mode);
    }

    #[test]
    fn test_signalling_bytes_mtu_is_masked() {
        // An MTU wider than the field must be truncated, not spill into the mode bits.
        let mtu = LINK_MTU_BYTEMASK + 1 + 500;
        let mode = LinkMode::Aes128Cbc;
        let bytes = build_signalling_bytes(mtu, mode);
        let (parsed_mtu, parsed_mode) = parse_signalling_bytes(&bytes).unwrap();
        assert_eq!(parsed_mtu, mtu & LINK_MTU_BYTEMASK);
        assert_eq!(parsed_mode, mode);
    }

    #[test]
    fn test_linkrequest_data_roundtrip() {
        let pub_bytes = [0xAAu8; 32];
        let sig_pub_bytes = [0xBBu8; 32];
        let data =
            build_linkrequest_data(&pub_bytes, &sig_pub_bytes, Some(500), LinkMode::Aes256Cbc);
        assert_eq!(data.len(), LINK_ECPUBSIZE + LINK_MTU_SIZE);

        let (p, s, mtu, mode) = parse_linkrequest_data(&data).unwrap();
        assert_eq!(p, pub_bytes);
        assert_eq!(s, sig_pub_bytes);
        assert_eq!(mtu, Some(500));
        assert_eq!(mode, LinkMode::Aes256Cbc);
    }

    #[test]
    fn test_linkrequest_data_no_mtu() {
        let pub_bytes = [0xAAu8; 32];
        let sig_pub_bytes = [0xBBu8; 32];
        let data = build_linkrequest_data(&pub_bytes, &sig_pub_bytes, None, LinkMode::Aes256Cbc);
        assert_eq!(data.len(), LINK_ECPUBSIZE);

        let (p, s, mtu, mode) = parse_linkrequest_data(&data).unwrap();
        assert_eq!(p, pub_bytes);
        assert_eq!(s, sig_pub_bytes);
        assert_eq!(mtu, None);
        assert_eq!(mode, LinkMode::Aes256Cbc); // default when no signalling
    }

    #[test]
    fn test_linkrequest_data_invalid_size() {
        let data = [0u8; 10];
        assert_eq!(parse_linkrequest_data(&data), Err(LinkError::InvalidData));
    }

    #[test]
    fn test_compute_link_id_no_extra() {
        let hashable = [0x42u8; 40];
        let id = compute_link_id(&hashable, 0);
        assert_eq!(id.len(), 16);
    }

    #[test]
    fn test_compute_link_id_with_extra() {
        let hashable = [0x42u8; 43]; // 40 base + 3 signalling
        let id_with_extra = compute_link_id(&hashable, 3);
        let id_base = compute_link_id(&hashable[..40], 0);
        assert_eq!(id_with_extra, id_base);
    }

    #[test]
    fn test_compute_link_id_extra_not_smaller_than_input() {
        // A trim that would consume the whole slice falls back to hashing all of it.
        let short = [0x42u8; 3];
        assert_eq!(compute_link_id(&short, 3), compute_link_id(&short, 0));
        assert_eq!(compute_link_id(&short, 5), compute_link_id(&short, 0));
    }

    #[test]
    fn test_lrproof_sign_verify() {
        let (sig_prv, sig_pub, sig_pub_bytes) = fixed_signing_key();

        let mut rng2 = FixedRng::new(&[0x22; 64]);
        let x_prv = rns_crypto::x25519::X25519PrivateKey::generate(&mut rng2);
        let pub_bytes = x_prv.public_key().public_bytes();

        let link_id: LinkId = [0xAA; 16];
        let mtu = Some(500u32);
        let mode = LinkMode::Aes256Cbc;

        let proof = build_lrproof(&link_id, &pub_bytes, &sig_pub_bytes, &sig_prv, mtu, mode);

        let (peer_pub, parsed_mtu, parsed_mode) =
            validate_lrproof(&proof, &link_id, &sig_pub, &sig_pub_bytes)
                .expect("proof signed with the matching key must validate");
        assert_eq!(peer_pub, pub_bytes);
        assert_eq!(parsed_mtu, mtu);
        assert_eq!(parsed_mode, mode);
    }

    #[test]
    fn test_lrproof_no_mtu_roundtrip() {
        let (sig_prv, sig_pub, sig_pub_bytes) = fixed_signing_key();
        let pub_bytes = [0x33u8; 32];
        let link_id: LinkId = [0xAA; 16];

        let proof = build_lrproof(
            &link_id,
            &pub_bytes,
            &sig_pub_bytes,
            &sig_prv,
            None,
            LinkMode::Aes128Cbc,
        );
        assert_eq!(proof.len(), 64 + 32);

        let (peer_pub, mtu, mode) = validate_lrproof(&proof, &link_id, &sig_pub, &sig_pub_bytes)
            .expect("proof without signalling must validate");
        assert_eq!(peer_pub, pub_bytes);
        assert_eq!(mtu, None);
        // With no signalling bytes the mode is not transmitted and falls back to the default.
        assert_eq!(mode, LinkMode::Aes256Cbc);
    }

    #[test]
    fn test_lrproof_wrong_link_id() {
        let (sig_prv, sig_pub, sig_pub_bytes) = fixed_signing_key();

        let pub_bytes = [0x33u8; 32];
        let link_id: LinkId = [0xAA; 16];
        let wrong_id: LinkId = [0xBB; 16];

        let proof = build_lrproof(
            &link_id,
            &pub_bytes,
            &sig_pub_bytes,
            &sig_prv,
            None,
            LinkMode::Aes256Cbc,
        );
        let result = validate_lrproof(&proof, &wrong_id, &sig_pub, &sig_pub_bytes);
        assert_eq!(result, Err(LinkError::InvalidSignature));
    }

    #[test]
    fn test_lrproof_tampered_signalling_rejected() {
        let (sig_prv, sig_pub, sig_pub_bytes) = fixed_signing_key();
        let pub_bytes = [0x33u8; 32];
        let link_id: LinkId = [0xAA; 16];

        let mut proof = build_lrproof(
            &link_id,
            &pub_bytes,
            &sig_pub_bytes,
            &sig_prv,
            Some(500),
            LinkMode::Aes256Cbc,
        );
        // Flip the lowest MTU bit (500 -> 501); the mode bits stay valid, so only
        // the signature check can catch the change.
        let last = proof.len() - 1;
        proof[last] ^= 0x01;
        assert_eq!(
            validate_lrproof(&proof, &link_id, &sig_pub, &sig_pub_bytes),
            Err(LinkError::InvalidSignature)
        );
    }

    #[test]
    fn test_lrproof_tampered_peer_key_rejected() {
        let (sig_prv, sig_pub, sig_pub_bytes) = fixed_signing_key();
        let pub_bytes = [0x33u8; 32];
        let link_id: LinkId = [0xAA; 16];

        let mut proof = build_lrproof(
            &link_id,
            &pub_bytes,
            &sig_pub_bytes,
            &sig_prv,
            None,
            LinkMode::Aes256Cbc,
        );
        proof[64] ^= 0xFF; // first byte of the embedded X25519 key
        assert_eq!(
            validate_lrproof(&proof, &link_id, &sig_pub, &sig_pub_bytes),
            Err(LinkError::InvalidSignature)
        );
    }

    #[test]
    fn test_lrproof_invalid_length() {
        let (_sig_prv, sig_pub, sig_pub_bytes) = fixed_signing_key();
        let link_id: LinkId = [0xAA; 16];
        let buf = [0u8; 100];
        // Only 96 (no signalling) and 96 + LINK_MTU_SIZE bytes are accepted.
        for &len in &[0usize, 64, 95, 97, 98, 100] {
            assert_eq!(
                validate_lrproof(&buf[..len], &link_id, &sig_pub, &sig_pub_bytes),
                Err(LinkError::InvalidData)
            );
        }
    }

    #[test]
    fn test_derive_session_key_aes128() {
        let shared = [0x42u8; 32];
        let link_id = [0xAA; 16];
        let key = derive_session_key(&shared, &link_id, LinkMode::Aes128Cbc).unwrap();
        assert_eq!(key.len(), 32);
    }

    #[test]
    fn test_derive_session_key_aes256() {
        let shared = [0x42u8; 32];
        let link_id = [0xAA; 16];
        let key = derive_session_key(&shared, &link_id, LinkMode::Aes256Cbc).unwrap();
        assert_eq!(key.len(), 64);
    }

    #[test]
    fn test_rtt_pack_unpack() {
        let rtt = 0.123456789;
        let packed = pack_rtt(rtt);
        assert_eq!(packed.len(), 9);
        assert_eq!(packed[0], 0xcb);
        let unpacked = unpack_rtt(&packed).unwrap();
        assert_eq!(unpacked, rtt);
    }

    #[test]
    fn test_rtt_unpack_invalid() {
        assert_eq!(unpack_rtt(&[0xcb, 0x00]), None);
        assert_eq!(unpack_rtt(&[0xca, 0, 0, 0, 0, 0, 0, 0, 0]), None);
    }

    #[test]
    fn test_rtt_pack_zero() {
        let packed = pack_rtt(0.0);
        let unpacked = unpack_rtt(&packed).unwrap();
        assert_eq!(unpacked, 0.0);
    }
}