Skip to main content

chie_crypto/
tls13.rs

1//! TLS 1.3 Key Schedule Support
2//!
3//! This module implements the TLS 1.3 key schedule as defined in RFC 8446.
4//! Provides key derivation for handshake and application traffic secrets.
5//!
6//! # Examples
7//!
8//! ```
9//! use chie_crypto::tls13::Tls13KeySchedule;
10//!
11//! // Create key schedule with shared secret
12//! let shared_secret = [0u8; 32];
13//! let mut schedule = Tls13KeySchedule::new(&shared_secret);
14//!
15//! // Derive handshake traffic secrets
16//! let client_hello = b"client hello";
17//! let server_hello = b"server hello";
18//! let (client_hs_secret, server_hs_secret) = schedule.derive_handshake_secrets(
19//!     client_hello,
20//!     server_hello
21//! );
22//!
23//! // Derive application traffic secrets
24//! let (client_app_secret, server_app_secret) = schedule.derive_application_secrets().unwrap();
25//! ```
26
27use serde::{Deserialize, Serialize};
28use sha2::{Digest, Sha256};
29use thiserror::Error;
30
/// Errors reported by the TLS 1.3 key schedule.
///
/// Among the functions in this module, only [`Tls13Error::NotInitialized`] is
/// currently produced (by the `derive_*` methods that need the master secret).
/// The other variants are part of the public error surface.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Tls13Error {
    /// A key or secret had an unexpected length.
    ///
    /// `expected` is the required length and `actual` the length that was
    /// supplied; both are rendered into the error message.
    #[error("Invalid key length: expected {expected}, got {actual}")]
    InvalidLength { expected: usize, actual: usize },

    /// A secret that must be derived first is not available yet.
    ///
    /// Returned when the master secret is requested before
    /// `Tls13KeySchedule::derive_handshake_secrets` has stored it.
    #[error("Key schedule not initialized")]
    NotInitialized,

    /// The key schedule is in a state that does not allow the operation.
    ///
    /// The payload is a free-form description that is appended to the message.
    #[error("Invalid state: {0}")]
    InvalidState(String),
}
46
/// Result type for TLS 1.3 operations.
///
/// Shorthand for `Result<T, Tls13Error>`, used by the fallible key-schedule methods.
pub type Tls13Result<T> = Result<T, Tls13Error>;
49
/// TLS 1.3 Key Schedule
///
/// Manages the key derivation process for TLS 1.3 connections. All secrets are
/// 32 bytes long because the schedule is built on SHA-256.
///
/// The struct derives `Serialize`/`Deserialize`, so the raw secrets are written
/// out as-is when a value is encoded; treat any serialized form as key material.
#[derive(Clone, Serialize, Deserialize)]
pub struct Tls13KeySchedule {
    /// Early secret. `new` computes it from an all-zero salt and an all-zero
    /// input key (i.e. without a PSK).
    early_secret: [u8; 32],
    /// Handshake secret. `new` always stores `Some`; the `Option` can only be
    /// `None` for values that did not come from `new` (e.g. deserialized data).
    handshake_secret: Option<[u8; 32]>,
    /// Master secret. `None` until `derive_handshake_secrets` has been called.
    master_secret: Option<[u8; 32]>,
}
62
63impl Tls13KeySchedule {
64    /// Create a new TLS 1.3 key schedule
65    ///
66    /// # Arguments
67    /// * `shared_secret` - Shared secret from key exchange (e.g., ECDHE)
68    pub fn new(shared_secret: &[u8]) -> Self {
69        // Early secret = HKDF-Extract(salt=0, IKM=0)
70        let zero_salt = [0u8; 32];
71        let early_secret = hkdf_extract(&zero_salt, &zero_salt);
72
73        // Derive handshake secret
74        let handshake_secret = derive_secret(&early_secret, b"derived", &[]);
75        let handshake_secret = hkdf_extract(&handshake_secret, shared_secret);
76
77        Self {
78            early_secret,
79            handshake_secret: Some(handshake_secret),
80            master_secret: None,
81        }
82    }
83
84    /// Derive handshake traffic secrets
85    ///
86    /// # Arguments
87    /// * `client_hello` - Client hello message
88    /// * `server_hello` - Server hello message
89    ///
90    /// # Returns
91    /// Tuple of (client_handshake_traffic_secret, server_handshake_traffic_secret)
92    pub fn derive_handshake_secrets(
93        &mut self,
94        client_hello: &[u8],
95        server_hello: &[u8],
96    ) -> ([u8; 32], [u8; 32]) {
97        let handshake_secret = self
98            .handshake_secret
99            .expect("Handshake secret not initialized");
100
101        // Transcript hash = SHA-256(ClientHello || ServerHello)
102        let mut hasher = Sha256::new();
103        hasher.update(client_hello);
104        hasher.update(server_hello);
105        let transcript_hash = hasher.finalize();
106
107        // Client handshake traffic secret
108        let client_hs_traffic_secret =
109            derive_secret(&handshake_secret, b"c hs traffic", &transcript_hash);
110
111        // Server handshake traffic secret
112        let server_hs_traffic_secret =
113            derive_secret(&handshake_secret, b"s hs traffic", &transcript_hash);
114
115        // Derive master secret for application traffic
116        let derived = derive_secret(&handshake_secret, b"derived", &[]);
117        let master_secret = hkdf_extract(&derived, &[0u8; 32]);
118        self.master_secret = Some(master_secret);
119
120        (client_hs_traffic_secret, server_hs_traffic_secret)
121    }
122
123    /// Derive application traffic secrets
124    ///
125    /// # Returns
126    /// Tuple of (client_application_traffic_secret, server_application_traffic_secret)
127    pub fn derive_application_secrets(&self) -> Tls13Result<([u8; 32], [u8; 32])> {
128        let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
129
130        // Empty transcript hash for application traffic
131        let empty_hash = Sha256::digest([]);
132
133        // Client application traffic secret
134        let client_app_traffic_secret = derive_secret(&master_secret, b"c ap traffic", &empty_hash);
135
136        // Server application traffic secret
137        let server_app_traffic_secret = derive_secret(&master_secret, b"s ap traffic", &empty_hash);
138
139        Ok((client_app_traffic_secret, server_app_traffic_secret))
140    }
141
142    /// Derive exporter master secret
143    ///
144    /// Used for exporting keying material outside of TLS
145    pub fn derive_exporter_secret(&self) -> Tls13Result<[u8; 32]> {
146        let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
147
148        let empty_hash = Sha256::digest([]);
149        Ok(derive_secret(&master_secret, b"exp master", &empty_hash))
150    }
151
152    /// Derive resumption master secret
153    ///
154    /// Used for session resumption
155    pub fn derive_resumption_secret(&self, transcript_hash: &[u8]) -> Tls13Result<[u8; 32]> {
156        let master_secret = self.master_secret.ok_or(Tls13Error::NotInitialized)?;
157
158        Ok(derive_secret(
159            &master_secret,
160            b"res master",
161            transcript_hash,
162        ))
163    }
164
165    /// Update traffic keys (key update)
166    ///
167    /// Derives new traffic secret from current one
168    pub fn update_traffic_secret(current_secret: &[u8; 32]) -> [u8; 32] {
169        derive_secret(current_secret, b"traffic upd", &[])
170    }
171}
172
173/// HKDF-Extract operation
174fn hkdf_extract(salt: &[u8], ikm: &[u8]) -> [u8; 32] {
175    use hmac::digest::KeyInit;
176    use hmac::{Hmac, Mac};
177    type HmacSha256 = Hmac<Sha256>;
178
179    let mut mac =
180        <HmacSha256 as KeyInit>::new_from_slice(salt).expect("HMAC can take key of any size");
181    mac.update(ikm);
182    let result = mac.finalize();
183    let bytes = result.into_bytes();
184
185    let mut output = [0u8; 32];
186    output.copy_from_slice(&bytes);
187    output
188}
189
190/// HKDF-Expand-Label operation (TLS 1.3 specific)
191fn hkdf_expand_label(secret: &[u8], label: &[u8], context: &[u8], length: u16) -> Vec<u8> {
192    // HkdfLabel structure:
193    // struct {
194    //     uint16 length = Length;
195    //     opaque label<7..255> = "tls13 " + Label;
196    //     opaque context<0..255> = Context;
197    // } HkdfLabel;
198
199    let mut hkdf_label = Vec::new();
200
201    // Length (2 bytes)
202    hkdf_label.extend_from_slice(&length.to_be_bytes());
203
204    // Label = "tls13 " + label
205    let full_label = [b"tls13 ", label].concat();
206    hkdf_label.push(full_label.len() as u8);
207    hkdf_label.extend_from_slice(&full_label);
208
209    // Context
210    hkdf_label.push(context.len() as u8);
211    hkdf_label.extend_from_slice(context);
212
213    // HKDF-Expand
214    hkdf_expand(secret, &hkdf_label, length as usize)
215}
216
217/// HKDF-Expand operation
218fn hkdf_expand(prk: &[u8], info: &[u8], length: usize) -> Vec<u8> {
219    use hmac::digest::KeyInit;
220    use hmac::{Hmac, Mac};
221    type HmacSha256 = Hmac<Sha256>;
222
223    let mut output = Vec::with_capacity(length);
224    let mut t = Vec::new();
225    let mut counter = 1u8;
226
227    while output.len() < length {
228        let mut mac =
229            <HmacSha256 as KeyInit>::new_from_slice(prk).expect("HMAC can take key of any size");
230        mac.update(&t);
231        mac.update(info);
232        mac.update(&[counter]);
233
234        t = mac.finalize().into_bytes().to_vec();
235        output.extend_from_slice(&t);
236        counter += 1;
237    }
238
239    output.truncate(length);
240    output
241}
242
243/// Derive-Secret operation (TLS 1.3 specific)
244fn derive_secret(secret: &[u8], label: &[u8], messages: &[u8]) -> [u8; 32] {
245    // Transcript-Hash(Messages)
246    let transcript_hash = if messages.is_empty() {
247        Sha256::digest([]).to_vec()
248    } else {
249        messages.to_vec()
250    };
251
252    let expanded = hkdf_expand_label(secret, label, &transcript_hash, 32);
253    let mut output = [0u8; 32];
254    output.copy_from_slice(&expanded[..32]);
255    output
256}
257
258/// Derive traffic keys from traffic secret
259///
260/// # Returns
261/// Tuple of (key, iv) for AEAD encryption
262pub fn derive_traffic_keys(traffic_secret: &[u8; 32]) -> ([u8; 32], [u8; 12]) {
263    // Key = HKDF-Expand-Label(Secret, "key", "", key_length)
264    let key_bytes = hkdf_expand_label(traffic_secret, b"key", &[], 32);
265    let mut key = [0u8; 32];
266    key.copy_from_slice(&key_bytes[..32]);
267
268    // IV = HKDF-Expand-Label(Secret, "iv", "", iv_length)
269    let iv_bytes = hkdf_expand_label(traffic_secret, b"iv", &[], 12);
270    let mut iv = [0u8; 12];
271    iv.copy_from_slice(&iv_bytes[..12]);
272
273    (key, iv)
274}
275
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed input used wherever a test needs a 32-byte secret.
    const SECRET: [u8; 32] = [0x42u8; 32];

    /// Returns a schedule on which the handshake secrets have already been derived.
    fn schedule_after_handshake() -> Tls13KeySchedule {
        let mut schedule = Tls13KeySchedule::new(&SECRET);
        schedule.derive_handshake_secrets(b"client hello", b"server hello");
        schedule
    }

    #[test]
    fn test_key_schedule_creation() {
        let schedule = Tls13KeySchedule::new(&SECRET);

        // A fresh schedule has a handshake secret but no master secret yet.
        assert!(schedule.handshake_secret.is_some());
        assert!(schedule.master_secret.is_none());
    }

    #[test]
    fn test_handshake_secrets_derivation() {
        let mut schedule = Tls13KeySchedule::new(&SECRET);

        let (client_hs, server_hs) =
            schedule.derive_handshake_secrets(b"client hello message", b"server hello message");

        // Client and server get distinct secrets.
        assert_ne!(client_hs, server_hs);

        // Deriving the handshake secrets also stores the master secret.
        assert!(schedule.master_secret.is_some());
    }

    #[test]
    fn test_application_secrets_derivation() {
        // The handshake step has to run before application secrets exist.
        let schedule = schedule_after_handshake();

        let result = schedule.derive_application_secrets();
        assert!(result.is_ok());

        let (client_app, server_app) = result.unwrap();
        assert_ne!(client_app, server_app);
    }

    #[test]
    fn test_application_secrets_before_handshake() {
        let schedule = Tls13KeySchedule::new(&SECRET);

        // No master secret yet, so the derivation must be rejected.
        assert!(schedule.derive_application_secrets().is_err());
    }

    #[test]
    fn test_exporter_secret() {
        let schedule = schedule_after_handshake();

        let exporter_secret = schedule.derive_exporter_secret();
        assert!(exporter_secret.is_ok());
        assert_eq!(exporter_secret.unwrap().len(), 32);
    }

    #[test]
    fn test_resumption_secret() {
        let schedule = schedule_after_handshake();

        let transcript = Sha256::digest(b"full handshake transcript");
        let resumption_secret = schedule.derive_resumption_secret(&transcript);
        assert!(resumption_secret.is_ok());
        assert_eq!(resumption_secret.unwrap().len(), 32);
    }

    #[test]
    fn test_traffic_key_update() {
        let updated = Tls13KeySchedule::update_traffic_secret(&SECRET);

        // The next-generation secret must differ from its input.
        assert_ne!(SECRET, updated);
    }

    #[test]
    fn test_derive_traffic_keys() {
        let (key, iv) = derive_traffic_keys(&SECRET);

        assert_eq!(key.len(), 32);
        assert_eq!(iv.len(), 12);
    }

    #[test]
    fn test_hkdf_extract() {
        let salt = [0x01u8; 32];
        let ikm = [0x02u8; 32];

        let first = hkdf_extract(&salt, &ikm);
        assert_eq!(first.len(), 32);

        // Same inputs, same output.
        let second = hkdf_extract(&salt, &ikm);
        assert_eq!(first, second);
    }

    #[test]
    fn test_hkdf_expand() {
        let info = b"test info";

        let first = hkdf_expand(&SECRET, info, 64);
        assert_eq!(first.len(), 64);

        // Same inputs, same output.
        let second = hkdf_expand(&SECRET, info, 64);
        assert_eq!(first, second);
    }

    #[test]
    fn test_hkdf_expand_label() {
        let label = b"test label";
        let context = b"test context";

        let first = hkdf_expand_label(&SECRET, label, context, 32);
        assert_eq!(first.len(), 32);

        // Same inputs, same output.
        let second = hkdf_expand_label(&SECRET, label, context, 32);
        assert_eq!(first, second);
    }

    #[test]
    fn test_derive_secret() {
        let label = b"test";
        let messages = b"messages";

        let first = derive_secret(&SECRET, label, messages);
        assert_eq!(first.len(), 32);

        // Same inputs, same output.
        let second = derive_secret(&SECRET, label, messages);
        assert_eq!(first, second);
    }

    #[test]
    fn test_serialization() {
        let schedule = Tls13KeySchedule::new(&SECRET);

        // Round-trip through the crate codec and compare the stored secrets.
        let encoded = crate::codec::encode(&schedule).unwrap();
        let decoded: Tls13KeySchedule = crate::codec::decode(&encoded).unwrap();

        assert_eq!(decoded.early_secret, schedule.early_secret);
        assert_eq!(decoded.handshake_secret, schedule.handshake_secret);
    }
}